diff --git a/mvcm/constants.py b/mvcm/constants.py index 40e7dafd0ac9126358f4fc113a20e68766253b48..1b6e10121e6392b3a87710305509f7f07ead1109 100644 --- a/mvcm/constants.py +++ b/mvcm/constants.py @@ -24,7 +24,21 @@ class SensorConstants: VIIRS_VIS_BANDS = ("M01", "M02", "M03", "M04", "M05", "M06", "M07", "M08", "M09", "M10") VIIRS_IR_BANDS = ("M11", "M12", "M13", "M14", "M15", "M16") VIIRS_IMG_BANDS = ("I01", "I02", "I03", "I04", "I05") - REFLECTIVE_BANDS = ("M01", "M04", "M05", "M07", "M10", "M11", "I01", "I02", "I03") + REFLECTIVE_BANDS = ( + "M01", + "M02", + "M03", + "M04", + "M05", + "M07", + "M08", + "M09", + "M10", + "M11", + "I01", + "I02", + "I03", + ) EMISSIVE_BANDS = ("M12", "M13", "M14", "M15", "M16", "I04", "I05") diff --git a/mvcm/main.py b/mvcm/main.py index b59c0ae7dce6284809eb0ecce2be591dd5c40b6f..0e17ea5acc5c77f6d5dd276d5176ed7855af49a9 100644 --- a/mvcm/main.py +++ b/mvcm/main.py @@ -24,6 +24,8 @@ logging.basicConfig(level="NOTSET", datefmt="[%X]", format=_LOG_FORMAT, handlers logger = logging.getLogger(__name__) LOG_LEVELS = ["CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG"] +proc = psutil.Process(os.getpid()) + def main( satellite: str = "snpp", @@ -104,8 +106,6 @@ def main( verbose_level = np.minimum(verbose + 1, 4) logger.setLevel(LOG_LEVELS[verbose_level]) - proc = psutil.Process(os.getpid()) - if img02 is None or img03 is None: use_hires = False else: @@ -126,7 +126,7 @@ def main( "ECO": f"{eco_file}", "ANC_DIR": f"{data_path}", } - logger.info(f"Memory usage: {proc.memory_info().rss / 1e6} MB") + logger.info(f"Memory usage #1: {proc.memory_info().rss / 1e6} MB") if hires_only is False: logger.info("Running regular MVCM before high resolution") @@ -161,14 +161,17 @@ def main( with Dataset(file_names["MOD03"]) as f: attrs = {attr: f.getncattr(attr) for attr in f.ncattrs()} - logger.info(f"Memory usage: {proc.memory_info().rss / 1e6} MB") + logger.info(f"Memory usage #2: {proc.memory_info().rss / 1e6} MB") + # rd.get_data(satellite, sensor, file_names, thresholds, hires=use_hires) viirs_data, pixel_type = rd.get_data(satellite, sensor, file_names, thresholds, hires=use_hires) - logger.info(f"Memory usage: {proc.memory_info().rss / 1e6} MB") - restore = Restoral(data=viirs_data, thresholds=thresholds, scene_flags=pixel_type) - logger.debug("Instance of Restoral class created successfully.") + # viirs_data = xr.open_dataset("input_data.nc") + logger.info(f"Memory usage #3: {proc.memory_info().rss / 1e6} MB") cmin_3d = np.ones((18, viirs_data.M11.shape[0], viirs_data.M11.shape[1]), dtype=np.float32) + # cmin_3d = xr.ones_like(viirs_data.latitude, dtype=np.float32, chunks="auto") + + logger.info(f"Memory usage #4: {proc.memory_info().rss / 1e6} MB") ########################################################## _bitlist = [ @@ -199,22 +202,23 @@ def main( for b in _bitlist: bits[b] = { - "test": np.zeros_like(viirs_data.latitude.values, dtype=np.ubyte), - "qa": np.zeros_like(viirs_data.latitude.values, dtype=np.ubyte), + "test": np.zeros(viirs_data.latitude.shape, dtype=np.ubyte), + "qa": np.zeros(viirs_data.latitude.shape, dtype=np.ubyte), + # "qa": xr.zeros_like(viirs_data.latitude, dtype=np.ubyte, chunks="auto"), } i = 0 - viirs_data["locut"] = xr.full_like( - viirs_data.latitude, fill_value=np.nan, dtype=np.float32, chunks="auto" + viirs_data["locut"] = np.full_like( + viirs_data.latitude.shape, fill_value=np.nan, dtype=np.float32 ) - viirs_data["midpt"] = xr.full_like( - viirs_data.latitude, fill_value=np.nan, dtype=np.float32, chunks="auto" + viirs_data["midpt"] = np.full_like( + viirs_data.latitude.shape, fill_value=np.nan, dtype=np.float32 ) - viirs_data["hicut"] = xr.full_like( - viirs_data.latitude, fill_value=np.nan, dtype=np.float32, chunks="auto" + viirs_data["hicut"] = np.full_like( + viirs_data.latitude.shape, fill_value=np.nan, dtype=np.float32 ) - - logger.info(f"Memory usage: {proc.memory_info().rss / 1e6} MB") + logger.info(f"viirs_data: {viirs_data}") + logger.info(f"Memory usage #5: {proc.memory_info().rss / 1e6} MB") # scene_types = np.zeros(viirs_data.M11.shape, dtype=np.ubyte) for scene_name in SceneConstants.SCENE_LIST: # scene_types[viirs_data[scene_name].values == 1] = i @@ -227,7 +231,9 @@ def main( continue logger.debug("initializing CloudTests class") - my_scene = tst.CloudTests(data=viirs_data, scene_name=scene_name, thresholds=thresholds) + my_scene = tst.CloudTests( + data=viirs_data, scene_name=scene_name, thresholds=thresholds, hires=use_hires + ) # Initialize the confidence arrays for the various test groups # cmin_g1 = xr.ones_like(viirs_data.latitude, dtype=np.float32, chunks="auto") @@ -235,11 +241,11 @@ def main( # cmin_g3 = xr.ones_like(viirs_data.latitude, dtype=np.float32, chunks="auto") # cmin_g4 = xr.ones_like(viirs_data.latitude, dtype=np.float32, chunks="auto") # cmin_g5 = xr.ones_like(viirs_data.latitude, dtype=np.float32, chunks="auto") - cmin_g1 = np.ones_like(viirs_data.latitude.values, dtype=np.float32) - cmin_g2 = np.ones_like(viirs_data.latitude.values, dtype=np.float32) - cmin_g3 = np.ones_like(viirs_data.latitude.values, dtype=np.float32) - cmin_g4 = np.ones_like(viirs_data.latitude.values, dtype=np.float32) - cmin_g5 = np.ones_like(viirs_data.latitude.values, dtype=np.float32) + cmin_g1 = np.ones(viirs_data.latitude.shape, dtype=np.float32) + cmin_g2 = np.ones(viirs_data.latitude.shape, dtype=np.float32) + cmin_g3 = np.ones(viirs_data.latitude.shape, dtype=np.float32) + cmin_g4 = np.ones(viirs_data.latitude.shape, dtype=np.float32) + cmin_g5 = np.ones(viirs_data.latitude.shape, dtype=np.float32) # cmin_temp = np.ones(viirs_data.M11.shape) logger.debug("starting tests") @@ -248,7 +254,7 @@ def main( else: m15_name = "M15" - logger.info(f"Memory usage: {proc.memory_info().rss / 1e6} MB") + logger.info(f"Memory usage #6: {proc.memory_info().rss / 1e6} MB") # Group 1 cmin_g1, bits["01"] = my_scene.test_11um(m15_name, cmin_g1, bits["01"]) # this test needs to be implemented properly, @@ -282,7 +288,7 @@ def main( cmin_g4, bits["15"] = my_scene.test_1_38um_high_clouds("M09", cmin_g4, bits["15"]) # Group 5 - cmin_g5, bits["16"] = my_scene.thin_cirrus_4_12um_BTD_test("M13-M16", cmin_g5, bits["16"]) + # cmin_g5, bits["16"] = my_scene.thin_cirrus_4_12um_BTD_test("M13-M16", cmin_g5, bits["16"]) # logger.debug(f"Memory: {tracemalloc.get_traced_memory()}") bit = {} @@ -293,21 +299,23 @@ def main( qabit[b] = bits[f"{b}"]["qa"] # restoral_bits[b] = utils.restoral_flag(bits[f"{b}"]) # # if utils.group_count(bit) != 0: - """ - cmin_3d[i, :, :] = np.fmin(cmin_temp, - np.power(cmin_g1 * cmin_g2 * cmin_g3 * cmin_g4 * cmin_g5, - 1/utils.group_count(bit))) - """ + cmin = np.power( cmin_g1 * cmin_g2 * cmin_g3 * cmin_g4 * cmin_g5, 1 / utils.group_count(qabit) ) cmin_3d[i, :, :] = cmin + # cmin_3d = np.minimum(cmin_3d, cmin) i += 1 - logger.info(f"Memory usage: {proc.memory_info().rss / 1e6} MB") + logger.info(f"Memory usage #7: {proc.memory_info().rss / 1e6} MB") cmin = np.min(cmin_3d, axis=0) + # cmin = cmin_3d + # pixel_type = xr.open_dataset("pixels_data.nc") + restore = Restoral(data=viirs_data, thresholds=thresholds, scene_flags=pixel_type) + logger.debug("Instance of Restoral class created successfully.") + logger.info(f"Memory usage #8: {proc.memory_info().rss / 1e6} MB") # logger.debug(f"Memory: {tracemalloc.get_traced_memory()}") # bit = {} diff --git a/mvcm/read_data.py b/mvcm/read_data.py index 7ed8b626db12822e73c81d512f817e99b6f33f48..878c9f6ff3cf74a38cecc5c9ad3b808d0ef956ff 100644 --- a/mvcm/read_data.py +++ b/mvcm/read_data.py @@ -6,12 +6,13 @@ from datetime import datetime as dt import numpy as np import numpy.typing as npt +import psutil import xarray as xr from attrs import Factory, define, field, validators import ancillary_data as anc import mvcm.scene as scn -from mvcm.constants import SensorConstants +from mvcm.constants import ConstantsNamespace, SensorConstants _DTR = np.pi / 180.0 _RTD = 180.0 / np.pi @@ -60,6 +61,7 @@ _mod_bands = [ ] logger = logging.getLogger(__name__) +proc = psutil.Process(os.getpid()) @define(kw_only=True, slots=True) @@ -144,6 +146,7 @@ class ReadData(CollectInputs): dataset containing all geolocation data """ logger.debug(f"Reading {self.file_name_geo}") + logger.info(f"read_viirs_geo #1: Memory usage: {proc.memory_info().rss / 1e6} MB") if os.path.exists(self.file_name_geo) is False: err_msg = f"Could not find the file {self.file_name_geo}" @@ -151,31 +154,24 @@ class ReadData(CollectInputs): raise FileNotFoundError(err_msg) geo_data = xr.open_dataset( - self.file_name_geo, - group="geolocation_data", - engine="netcdf4", - chunks="auto", + self.file_name_geo, group="geolocation_data", engine="netcdf4", chunks="auto" ) - relazi = self.relative_azimuth_angle( - geo_data.sensor_azimuth.values, geo_data.solar_azimuth.values + geo_data["relative_azimuth"] = self.relative_azimuth_angle( + geo_data.sensor_azimuth, geo_data.solar_azimuth ) - sunglint = self.sun_glint_angle( - geo_data.sensor_zenith.values, geo_data.solar_zenith.values, relazi + geo_data["sunglint_angle"] = self.sun_glint_angle( + geo_data.sensor_zenith, geo_data.solar_zenith, geo_data.relative_azimuth ) - scatt_angle = self.scattering_angle( - geo_data.solar_zenith.values, geo_data.sensor_zenith.values, relazi + geo_data["scattering_angle"] = self.scattering_angle( + geo_data.solar_zenith, geo_data.sensor_zenith, geo_data.relative_azimuth ) - geo_data["relative_azimuth"] = (self.dims, relazi) - geo_data["sunglint_angle"] = (self.dims, sunglint) - geo_data["scattering_angle"] = (self.dims, scatt_angle) - logger.debug("Geolocation file read correctly") - + logger.info(f"read_viirs_geo #5: Memory usage: {proc.memory_info().rss / 1e6} MB") return geo_data - def read_viirs_l1b(self, solar_zenith: npt.NDArray) -> xr.Dataset: + def read_viirs_l1b(self, solar_zenith: xr.DataArray) -> xr.Dataset: """Read VIIRS L1b data. Parameters @@ -187,6 +183,7 @@ class ReadData(CollectInputs): """ logger.debug(f"Reading {self.file_name_l1b}") + logger.info(f"read_viirs_l1b #1: Memory usage: {proc.memory_info().rss / 1e6} MB") if os.path.exists(self.file_name_l1b) is False: err_msg = f"Could not find the file {self.file_name_l1b}" logger.error(err_msg) @@ -200,6 +197,7 @@ class ReadData(CollectInputs): chunks="auto", ) + logger.info(f"read_viirs_l1b #2: Memory usage: {proc.memory_info().rss / 1e6} MB") rad_data = xr.Dataset() for band in list(l1b_data.variables): if band in SensorConstants.REFLECTIVE_BANDS: @@ -212,7 +210,9 @@ class ReadData(CollectInputs): scale_factor = l1b_data[band].scale_factor rad_data[band] = ( self.dims, - l1b_data[band].values * scale_factor / np.cos(solar_zenith * _DTR), + l1b_data[band].data + * scale_factor + / np.cos(solar_zenith * ConstantsNamespace.DTR), ) else: logger.info(f"Reflective band {band} not found in L1b file") @@ -230,8 +230,9 @@ class ReadData(CollectInputs): pass logger.debug("L1b file read correctly") + logger.info(f"read_viiirs_l1b #3: Memory usage: {proc.memory_info().rss / 1e6} MB") - return rad_data.chunk("auto") + return rad_data.chunk() def preprocess_viirs( self, geo_data: xr.Dataset, viirs: xr.Dataset, hires_data: xr.Dataset | None = None @@ -264,12 +265,12 @@ class ReadData(CollectInputs): viirs_out = xr.Dataset() mod_bands = [ "M01", - # "M02", - # "M03", - # "M04", + "M02", + "M03", + "M04", # "M06", - # "M08", - # "M09", + "M08", + "M09", "M11", "M13", "M14", @@ -287,14 +288,14 @@ class ReadData(CollectInputs): viirs_out["M07"] = (self.dims, hires_data.I02.values) if ("M05" in viirs) and ("M07" in viirs): - m01 = viirs_out.M05.values - m02 = viirs_out.M07.values + m01 = viirs_out.M05 + m02 = viirs_out.M07 r1 = 2.0 * (np.power(m02, 2.0) - np.power(m01, 2.0)) + (1.5 * m02) + (0.5 * m01) r2 = m01 + m02 + 0.5 r3 = r1 / r2 gemi = r3 * (1.0 - 0.25 * r3) - ((m01 - 0.125) / (1.0 - m01)) else: - gemi = np.full((viirs_out.M16.shape), _bad_data) + gemi = xr.full_like(viirs_out.M16, fill_value=ConstantsNamespace.BAD_DATA) # Compute channel differences and ratios that are used in the tests # if ("M05" in viirs) and ("M07" in viirs): @@ -321,25 +322,25 @@ class ReadData(CollectInputs): # ) # viirs_out["M15-M13"] = (self.dims, viirs_out.M15.values - viirs_out.M13.values) # viirs_out["M15-M16"] = (self.dims, viirs_out.M15.values - viirs_out.M16.values) - viirs_out["GEMI"] = (self.dims, gemi) + viirs_out["GEMI"] = (self.dims, gemi.data) if hires_data is not None: - viirs_out["M10"] = (self.dims, hires_data.I03.values) - viirs_out["M12"] = (self.dims, hires_data.I04.values) - viirs_out["M15hi"] = (self.dims, hires_data.I05.values) + viirs_out["M10"] = (self.dims, hires_data.I03.data) + viirs_out["M12"] = (self.dims, hires_data.I04.data) + viirs_out["M15hi"] = (self.dims, hires_data.I05.data) # temp value to force the code to work - viirs_out["M128"] = (self.dims, np.zeros(viirs_out.M16.shape)) + viirs_out["M128"] = xr.zeros_like(viirs_out.M16) viirs_out.update(geo_data) - viirs_out = viirs_out.set_coords(["latitude", "longitude"]).chunk("auto") + viirs_out = viirs_out.set_coords(["latitude", "longitude"]) logger.debug("Viirs preprocessing completed successfully.") return viirs_out def relative_azimuth_angle( - self, sensor_azimuth: npt.NDArray, solar_azimuth: npt.NDArray - ) -> npt.NDArray: + self, sensor_azimuth: xr.DataArray, solar_azimuth: xr.DataArray + ) -> xr.DataArray: """Compute relative azimuth angle. Parameters @@ -354,17 +355,16 @@ class ReadData(CollectInputs): relative_azimuth: np.ndarray """ rel_azimuth = np.abs(180.0 - np.abs(sensor_azimuth - solar_azimuth)) - logger.debug("Relative azimuth calculated successfully.") return rel_azimuth def sun_glint_angle( self, - sensor_zenith: npt.NDArray, - solar_zenith: npt.NDArray, - rel_azimuth: npt.NDArray, - ) -> npt.NDArray: + sensor_zenith: xr.DataArray, + solar_zenith: xr.DataArray, + rel_azimuth: xr.DataArray, + ) -> xr.DataArray: """Compute sun glint angle. Parameters @@ -383,7 +383,8 @@ class ReadData(CollectInputs): cossna = np.sin(sensor_zenith * _DTR) * np.sin(solar_zenith * _DTR) * np.cos( rel_azimuth * _DTR ) + np.cos(sensor_zenith * _DTR) * np.cos(solar_zenith * _DTR) - cossna[cossna > 1] = 1 + # cossna = xr.where(cossna > 1, 1, cossna) + cossna = cossna.clip(None, 1) sunglint_angle = np.arccos(cossna) * _RTD logger.debug("Sunglint generated") @@ -392,10 +393,10 @@ class ReadData(CollectInputs): def scattering_angle( self, - solar_zenith: npt.NDArray, - sensor_zenith: npt.NDArray, - relative_azimuth: npt.NDArray, - ) -> npt.NDArray: + solar_zenith: xr.DataArray, + sensor_zenith: xr.DataArray, + relative_azimuth: xr.DataArray, + ) -> xr.DataArray: """Compute scattering angle. Parameters @@ -493,7 +494,7 @@ class ReadAncillary(CollectInputs): sst, ) logger.debug("SST file read successfully") - return sst.reshape(self.out_shape) + return xr.DataArray(dims=self.dims, data=sst.reshape(self.out_shape)) def get_ndvi(self) -> npt.NDArray: """Read NDVI file. @@ -523,7 +524,7 @@ class ReadAncillary(CollectInputs): ndvi, ) logger.debug("NDVI file read successfully") - return ndvi.reshape(self.out_shape) + return xr.DataArray(dims=self.dims, data=ndvi.reshape(self.out_shape)) def get_eco(self) -> npt.NDArray: """Read ECO file. @@ -554,7 +555,7 @@ class ReadAncillary(CollectInputs): eco, ) logger.debug("Olson ecosystem file read successfully") - return eco.reshape(self.out_shape) + return xr.DataArray(dims=self.dims, data=eco.reshape(self.out_shape)) def get_geos(self) -> dict: """Read GEOS-5 data and interpolate the fields to the sensor resolution. @@ -605,11 +606,11 @@ class ReadAncillary(CollectInputs): "land_ice_fraction", "surface_temperature", ] - geos_data = { + geos_data_dict = { var: np.empty(self.out_shape, dtype=np.float32).ravel() for var in geos_variables } - geos_data = anc.py_get_GEOS( + geos_data_dict = anc.py_get_GEOS( self.latitude.ravel(), self.longitude.ravel(), self.latitude.shape[0], @@ -621,11 +622,40 @@ class ReadAncillary(CollectInputs): self.geos_land, self.geos_ocean, self.geos_constants, - geos_data, + geos_data_dict, ) - for var in list(geos_variables): - geos_data[var] = geos_data[var].reshape(self.out_shape) + # for var in list(geos_variables): + # geos_data[var] = geos_data[var].reshape(self.out_shape) + + geos_data = xr.Dataset().from_dict( + { + "geos_tpw": { + "dims": self.dims, + "data": geos_data_dict["tpw"].reshape(self.out_shape), + }, + "geos_snow_fraction": { + "dims": self.dims, + "data": geos_data_dict["snow_fraction"].reshape(self.out_shape), + }, + "geos_ice_fraction": { + "dims": self.dims, + "data": geos_data_dict["ice_fraction"].reshape(self.out_shape), + }, + "geos_ocean_fraction": { + "dims": self.dims, + "data": geos_data_dict["ocean_fraction"].reshape(self.out_shape), + }, + "geos_land_ice_fraction": { + "dims": self.dims, + "data": geos_data_dict["land_ice_fraction"].reshape(self.out_shape), + }, + "geos_surface_temperature": { + "dims": self.dims, + "data": geos_data_dict["surface_temperature"].reshape(self.out_shape), + }, + } + ) logger.debug("GEOS data read successfully") return geos_data @@ -642,15 +672,32 @@ class ReadAncillary(CollectInputs): ancillary_data: xr.Dataset dataset containing all the ancillary data """ - ancillary_data = xr.Dataset() - ancillary_data["sst"] = (self.dims, self.get_sst()) - ancillary_data["ecosystem"] = (self.dims, self.get_eco()) - ancillary_data["ndvi"] = (self.dims, self.get_ndvi()) - - geos_tmp = self.get_geos() - for var in list(geos_tmp.keys()): - ancillary_data[f"geos_{var}"] = (self.dims, geos_tmp[var]) - + ancillary_data = self.get_geos() + logger.info(f"read_data: Memory usage: {proc.memory_info().rss / 1e6} MB") + ancillary_data.update({"ndvi": self.get_ndvi()}) + logger.info(f"read_data: Memory usage: {proc.memory_info().rss / 1e6} MB") + ancillary_data.update({"sst": self.get_sst()}) + logger.info(f"read_data: Memory usage: {proc.memory_info().rss / 1e6} MB") + ancillary_data.update({"ecosystem": self.get_eco()}) + logger.info(f"read_data: Memory usage: {proc.memory_info().rss / 1e6} MB") + # ancillary_data = xr.Dataset() + # logger.info(f"read_data: Memory usage: {proc.memory_info().rss / 1e6} MB") + # ancillary_data["sst"] = (self.dims, self.get_sst().data) + # logger.info(f"read_data: Memory usage: {proc.memory_info().rss / 1e6} MB") + # ancillary_data["ecosystem"] = (self.dims, self.get_eco().data) + # logger.info(f"read_data: Memory usage: {proc.memory_info().rss / 1e6} MB") + # ancillary_data["ndvi"] = (self.dims, self.get_ndvi().data) + # logger.info(f"read_data: Memory usage: {proc.memory_info().rss / 1e6} MB") + + # geos_tmp = self.get_geos() + # for var in list(geos_tmp.keys()): + # ancillary_data[f"geos_{var}"] = (self.dims, geos_tmp[var]) + + logger.info(f"read_data: Memory usage: {proc.memory_info().rss / 1e6} MB") + # ancillary_data.to_netcdf("tmp_ancillary_data.nc") + # ancillary_data = xr.open_dataset("tmp_ancillary_data.nc", chunks="auto") + # logger.info("dumping ancillary to temp file") + logger.info(f"read_data: Memory usage: {proc.memory_info().rss / 1e6} MB") return ancillary_data @@ -702,6 +749,7 @@ def get_data( geo_data = viirs_hires.read_viirs_geo() viirs_data_img = viirs_hires.read_viirs_l1b(geo_data.solar_zenith.values) viirs_data = viirs_hires.preprocess_viirs(geo_data, viirs_data_mod, viirs_data_img) + # viirs_data.update(geo_data) res = 1 else: viirs_data = viirs.preprocess_viirs(geo_data_mod, viirs_data_mod) @@ -725,15 +773,24 @@ def get_data( ) logger.info("Ancillary data read successfully") + logger.info(f"read_data: Memory usage: {proc.memory_info().rss / 1e6} MB") + # ancillary.pack_data() + # ancillary_data = xr.open_dataset("tmp_ancillary_data.nc", chunks="auto") + # viirs_data = xr.merge([viirs_data, ancillary_data]) viirs_data.update(ancillary.pack_data()) logger.info("Ancillary data added to the dataset") + logger.info(f"read_data: Memory usage: {proc.memory_info().rss / 1e6} MB") scene_xr, pixels_flags = scn.scene_id( viirs_data, geo_data.sunglint_angle, thresholds["Snow_Mask"] ) - input_data = xr.Dataset(viirs_data, coords=scene_xr).chunk("auto") + input_data = xr.Dataset(viirs_data, coords=scene_xr).chunk() input_data.drop_vars(["latitude", "longitude"]) logger.info("get_data() ran successfully") + logger.info(f"read_data: Memory usage: {proc.memory_info().rss / 1e6} MB") + + # input_data.to_netcdf("input_data.nc") + # pixels_flags.to_netcdf("pixels_flags.nc") return input_data, pixels_flags diff --git a/mvcm/scene.py b/mvcm/scene.py index 07ea472527ffbc9931a5e0ac9f324b55400bbeb6..9aa9f5f8181e9857e54b6f5094678fca86bab32a 100644 --- a/mvcm/scene.py +++ b/mvcm/scene.py @@ -1,8 +1,10 @@ """Functions that define scene type.""" import logging # noqa +import os import numpy as np +import psutil import xarray as xr from attrs import Factory, define, field, validators @@ -11,6 +13,7 @@ from mvcm.constants import ConstantsNamespace as Constants from mvcm.constants import SceneConstants logger = logging.getLogger(__name__) +proc = psutil.Process(os.getpid()) def compute_glint_angle(vza, sza, raz): @@ -46,26 +49,39 @@ class IdentifyPixels: def find_day(self): """Find daytime pixels.""" logger.debug("Finding daytime pixels") - return xr.where(self.data.solar_zenith <= SceneConstants.SOLAR_ZENITH_DAY, 1, 0) + logger.info(f"scene: Memory usage: {proc.memory_info().rss / 1e6} MB") + return xr.where(self.data.solar_zenith <= SceneConstants.SOLAR_ZENITH_DAY, 1, 0).astype( + np.ubyte + ) def find_visusd(self): """Find where daytime pixels are used.""" logger.debug("Finding pixels where VIS is used") - return xr.where(self.data.solar_zenith <= SceneConstants.SOLAR_ZENITH_DAY, 1, 0) + logger.info(f"scene: Memory usage: {proc.memory_info().rss / 1e6} MB") + return xr.where(self.data.solar_zenith <= SceneConstants.SOLAR_ZENITH_DAY, 1, 0).astype( + np.ubyte + ) def find_night(self): """Find nightime pixels.""" logger.debug("Finding nightime pixels") - return xr.where(self.data.solar_zenith > SceneConstants.SOLAR_ZENITH_DAY, 1, 0) + logger.info(f"scene: Memory usage: {proc.memory_info().rss / 1e6} MB") + return xr.where(self.data.solar_zenith > SceneConstants.SOLAR_ZENITH_DAY, 1, 0).astype( + np.ubyte + ) def find_polar(self): """Find polar pixels.""" logger.debug("Finding polar pixels") - return xr.where(np.abs(self.data.latitude) > SceneConstants.POLAR_LATITUDES, 1, 0) + logger.info(f"scene: Memory usage: {proc.memory_info().rss / 1e6} MB") + return xr.where(np.abs(self.data.latitude) > SceneConstants.POLAR_LATITUDES, 1, 0).astype( + np.ubyte + ) def find_coast(self, day_scenes, uniformity): """Find coast pixels.""" logger.debug("Finding coast pixels") + logger.info(f"scene: Memory usage: {proc.memory_info().rss / 1e6} MB") coast_scene = xr.where( ( (self.data.land_water_mask == SceneConstants.LAND) @@ -118,21 +134,29 @@ class IdentifyPixels: coast_scene = xr.where(uniformity < 0, 1, coast_scene) - return coast_scene + logger.info(f"scene: Memory usage: {proc.memory_info().rss / 1e6} MB") + return coast_scene.astype(np.ubyte) def find_shallow_lake(self): """Find shallow lake pixels.""" logger.debug("Finding shallow lake pixels") - return xr.where(self.data.land_water_mask == SceneConstants.SHALLOW_INLAND, 1, 0) + logger.info(f"scene: Memory usage: {proc.memory_info().rss / 1e6} MB") + return xr.where(self.data.land_water_mask == SceneConstants.SHALLOW_INLAND, 1, 0).astype( + np.ubyte + ) def find_shallow_ocean(self): """Find shallow ocean pixels.""" logger.debug("Finding shallow ocean pixels") - return xr.where(self.data.land_water_mask == SceneConstants.SHALLOW_OCEAN, 1, 0) + logger.info(f"scene: Memory usage: {proc.memory_info().rss / 1e6} MB") + return xr.where(self.data.land_water_mask == SceneConstants.SHALLOW_OCEAN, 1, 0).astype( + np.ubyte + ) def find_water(self, uniformity): """Find water pixels.""" logger.debug("Finding water pixels") + logger.info(f"scene: Memory usage: {proc.memory_info().rss / 1e6} MB") water_scene = xr.zeros_like(self.data.latitude, dtype=np.int8) water_scene = xr.where( @@ -152,13 +176,15 @@ class IdentifyPixels: water_scene, ) - water_scene = xr.where(uniformity < 0, 0, water_scene) + water_scene = xr.where(uniformity < 0, 0, water_scene).astype(np.ubyte) + logger.info(f"scene: Memory usage: {proc.memory_info().rss / 1e6} MB") return water_scene def find_snow(self, day_scene, land_scene, ndsi_snow, map_snow, new_zealand): """Find snow pixels.""" logger.debug("Finding snow pixels") + logger.info(f"scene: Memory usage: {proc.memory_info().rss / 1e6} MB") snow_scene = xr.where( ((day_scene == 1) & (land_scene == 1)) & ( @@ -182,11 +208,13 @@ class IdentifyPixels: snow_scene, ) - return snow_scene + logger.info(f"scene: Memory usage: {proc.memory_info().rss / 1e6} MB") + return snow_scene.astype(np.ubyte) def find_ice(self, day_scene, water_scene, ndsi_snow, map_ice): """Find ice pixels.""" logger.debug("Finding ice pixels") + logger.info(f"scene: Memory usage: {proc.memory_info().rss / 1e6} MB") ice_scene = xr.where( ( (day_scene == 1) @@ -217,11 +245,13 @@ class IdentifyPixels: ice_scene, ) - return ice_scene + logger.info(f"scene: Memory usage: {proc.memory_info().rss / 1e6} MB") + return ice_scene.astype(np.ubyte) def find_new_zealand(self, day_scene, land_scene): """Find new zealand pixels.""" logger.debug("Finding New Zealand pixels") + logger.info(f"scene: Memory usage: {proc.memory_info().rss / 1e6} MB") new_zealand = xr.where( (day_scene == 1) & (land_scene == 1) @@ -233,21 +263,25 @@ class IdentifyPixels: 0, ) - return new_zealand + logger.info(f"scene: Memory usage: {proc.memory_info().rss / 1e6} MB") + return new_zealand.astype(np.ubyte) def check_uniformity(self): """Check where surrounding pixels are uniform.""" + logger.info(f"scene: Memory usage: {proc.memory_info().rss / 1e6} MB") eco = np.array(self.data.ecosystem.values, dtype=np.uint8) lsf = np.array(self.data.land_water_mask.values, dtype=np.uint8) uniformity = anc.py_check_reg_uniformity( eco, eco, self.data.geos_snow_fraction.values, self.data.geos_ice_fraction.values, lsf )["loc_uniform"] + logger.info(f"scene: Memory usage: {proc.memory_info().rss / 1e6} MB") return uniformity def find_australia(self): """Find Australia pixels.""" logger.debug("Finding australia pixels") + logger.info(f"scene: Memory usage: {proc.memory_info().rss / 1e6} MB") australia = xr.where( (self.data.latitude < SceneConstants.AUSTRALIA_S_LAT_BB) & (self.data.latitude > SceneConstants.AUSTRALIA_N_LAT_BB) @@ -257,11 +291,13 @@ class IdentifyPixels: 0, ) - return australia + logger.info(f"scene: Memory usage: {proc.memory_info().rss / 1e6} MB") + return australia.astype(np.ubyte) def find_desert(self, land_scene): """Find desert pixels.""" logger.debug("Finding desert pixels") + logger.info(f"scene: Memory usage: {proc.memory_info().rss / 1e6} MB") desert_scene = xr.where( ( (land_scene == 1) & (self.data.ndvi < SceneConstants.DESERT_NDVI) @@ -271,11 +307,13 @@ class IdentifyPixels: 0, ) - return desert_scene + logger.info(f"scene: Memory usage: {proc.memory_info().rss / 1e6} MB") + return desert_scene.astype(np.ubyte) def find_land(self, uniformity): """Find land pixels.""" logger.debug("Finding land pixels") + logger.info(f"scene: Memory usage: {proc.memory_info().rss / 1e6} MB") land_scene = xr.ones_like(self.data.latitude, dtype=np.int8) land_scene = xr.where( @@ -321,18 +359,23 @@ class IdentifyPixels: land_scene, ) - land_scene = xr.where(uniformity < 0, 1, land_scene) + land_scene = xr.where(uniformity < 0, 1, land_scene).astype(np.ubyte) + logger.info(f"scene: Memory usage: {proc.memory_info().rss / 1e6} MB") return land_scene def find_sunglint(self, sunglint_angle_threshold, scene_day): """Find sun glint pixels.""" logger.debug("Finding sun glint pixels") - return xr.where((scene_day == 1) & (self.glint_angle <= sunglint_angle_threshold), 1, 0) + logger.info(f"scene: Memory usage: {proc.memory_info().rss / 1e6} MB") + return xr.where( + (scene_day == 1) & (self.glint_angle <= sunglint_angle_threshold), 1, 0 + ).astype(np.ubyte) def find_greenland(self, land_scene): """Find greenland pixels.""" logger.debug("Finding Greenland pixels") + logger.info(f"scene: Memory usage: {proc.memory_info().rss / 1e6} MB") greenland_scene = xr.where( (land_scene == 1) & ( @@ -356,11 +399,13 @@ class IdentifyPixels: 0, ) - return greenland_scene + logger.info(f"scene: Memory usage: {proc.memory_info().rss / 1e6} MB") + return greenland_scene.astype(np.ubyte) def find_high_elevation(self, land_scene, greenland_scene): """Find high elevation pixels.""" logger.debug("Finding high elevation pixels") + logger.info(f"scene: Memory usage: {proc.memory_info().rss / 1e6} MB") high_elevation_scene = xr.where( (self.data.height > SceneConstants.HIGH_ELEVATION) | ( @@ -379,16 +424,19 @@ class IdentifyPixels: 0, ) - return high_elevation_scene + logger.info(f"scene: Memory usage: {proc.memory_info().rss / 1e6} MB") + return high_elevation_scene.astype(np.ubyte) def find_antarctica(self): """Find antarctica pixels.""" logger.debug("Finding Antarctica pixels") - return xr.where(self.data.latitude < SceneConstants.ANTARCTICA_LAT, 1, 0) + logger.info(f"scene: Memory usage: {proc.memory_info().rss / 1e6} MB") + return xr.where(self.data.latitude < SceneConstants.ANTARCTICA_LAT, 1, 0).astype(np.ubyte) def find_vis_ratio_used(self): """Find vis ratio used pixels.""" logger.debug("Finding visible ratio used pixels") + logger.info(f"scene: Memory usage: {proc.memory_info().rss / 1e6} MB") scene_vis_ratio_used = xr.where( (self.data.ecosystem == SceneConstants.ECO_LOW_SPARSE_GRASSLAND) | (self.data.ecosystem == SceneConstants.ECO_BARE_DESERT) @@ -405,11 +453,13 @@ class IdentifyPixels: 1, ) - return scene_vis_ratio_used + logger.info(f"scene: Memory usage: {proc.memory_info().rss / 1e6} MB") + return scene_vis_ratio_used.astype(np.ubyte) def find_map_snow(self): """Find map snow pixels.""" logger.debug("Finding map snow pixels") + logger.info(f"scene: Memory usage: {proc.memory_info().rss / 1e6} MB") snow_scene = xr.where( ( (self.data.geos_snow_fraction > SceneConstants.MIN_SNOW_ICE_FRACTION) @@ -423,11 +473,13 @@ class IdentifyPixels: 0, ) - return snow_scene + logger.info(f"scene: Memory usage: {proc.memory_info().rss / 1e6} MB") + return snow_scene.astype(np.ubyte) def find_map_ice(self): """Find map ice pixels.""" logger.debug("Finding map ice pixels") + logger.info(f"scene: Memory usage: {proc.memory_info().rss / 1e6} MB") ice_scene = xr.where( (self.data.geos_ice_fraction > SceneConstants.MIN_SNOW_ICE_FRACTION) & (self.data.geos_ice_fraction <= SceneConstants.MAX_SNOW_ICE_FRACTION), @@ -435,13 +487,14 @@ class IdentifyPixels: 0, ) - return ice_scene + return ice_scene.astype(np.ubyte) def find_ndsi_snow( self, greenland, antarctica, land, water, sunglint, high_elevation, snow_mask_thresholds ): """Find ndsi snow pixels.""" logger.debug("Finding ndsi snow pixels") + logger.info(f"scene: Memory usage: {proc.memory_info().rss / 1e6} MB") radiances = { "m01": np.array(self.data.M05.values, dtype=np.float32), "m02": np.array(self.data.M07.values, dtype=np.float32), @@ -452,6 +505,7 @@ class IdentifyPixels: "nir": np.array(self.data.M10.values, dtype=np.float32), } + logger.info(f"scene: Memory usage: {proc.memory_info().rss / 1e6} MB") ancillary = { "greenland": np.array(greenland.values, dtype=np.ubyte), "antarctic": np.array(antarctica.values, dtype=np.ubyte), @@ -465,15 +519,17 @@ class IdentifyPixels: ), } + logger.info(f"scene: Memory usage: {proc.memory_info().rss / 1e6} MB") lsf = np.array(self.data.land_water_mask.values, dtype=np.ubyte) lat = np.array(self.data.latitude.values, dtype=np.float32) + logger.info(f"scene: Memory usage: {proc.memory_info().rss / 1e6} MB") logger.debug("Starting cython function") ndsi_arr = anc.py_snow_mask(lat, lsf, radiances, ancillary, snow_mask_thresholds) logger.debug("Cython function finished") ndsi_snow = xr.DataArray(ndsi_arr, dims=land.dims) - return ndsi_snow + return ndsi_snow.astype(np.ubyte) def scene_id( @@ -584,7 +640,7 @@ def scene_id( ) # Snow Day - scene_type["Snow_Day"] = xr.where( + scene_type["Day_Snow"] = xr.where( (scene_flags.day == 1) & ((scene_flags.ice == 1) | (scene_flags.snow == 1)) & (scene_flags.polar == 0) @@ -594,7 +650,7 @@ def scene_id( ) # Snow Night - scene_type["Snow_Night"] = xr.where( + scene_type["Night_Snow"] = xr.where( (scene_flags.night == 1) & ((scene_flags.ice == 1) | (scene_flags.snow == 1)) & (scene_flags.polar == 0) @@ -761,6 +817,9 @@ def scene_id( scene_type["Desert"] = xr.where(scene_flags.desert == 1, 1, 0) scene_type["Australia"] = xr.where(scene_flags.australia == 1, 1, 0) + for var in scene_type.data_vars: + scene_type[var] = scene_type[var].astype(np.ubyte) + logger.debug("scene_id run completed") - return scene_type, scene_flags.chunk("auto") + return scene_type.chunk(), scene_flags.chunk() diff --git a/mvcm/spectral_tests.py b/mvcm/spectral_tests.py index e18a84e4aa2694e7beb8ea7834a6d67d90bf4b92..d9a8978316052fb4ff1f1246b657d46f878a8993 100644 --- a/mvcm/spectral_tests.py +++ b/mvcm/spectral_tests.py @@ -3,8 +3,10 @@ import functools # noqa import importlib import logging +import os import numpy as np +import psutil import xarray as xr from attrs import Factory, define, field, validators from numpy.lib.stride_tricks import sliding_window_view @@ -24,6 +26,8 @@ importlib.reload(restoral) logger = logging.getLogger(__name__) +proc = psutil.Process(os.getpid()) + @define(kw_only=True, slots=True) class CloudTests: @@ -52,6 +56,9 @@ class CloudTests: validators.instance_of(dict), ] ) + hires: bool = field( + validator=[validators.instance_of(bool)], + ) scene_idx: tuple = field( init=False, default=Factory( @@ -89,7 +96,7 @@ class CloudTests: if test_name not in self.thresholds[self.scene_name]: return args[-2], args[-1] else: - kwargs["confidence"] = np.ones(self.data.latitude.shape) + kwargs["confidence"] = np.ones(self.data.latitude.shape, dtype=np.float32) kwargs["thresholds"] = self.thresholds[self.scene_name][test_name] return func(self, *args, **kwargs) @@ -97,6 +104,72 @@ class CloudTests: return decorate + def xr_compute_test_bits( + self, rad: np.ndarray, thr: np.ndarray, ineq: str, scene_idx=None + ) -> np.ndarray: + """Compute tests bits based on thresholds. + + Parameters + ---------- + rad: np.ndarray + array of radiances + thr: np.ndarray + array of thresholds + ineq: str + string representing inequality to use + + Returns + ------- + test_bits: np.ndarray + binary array where values are 0 if inequality is false, 1 otherwise + """ + scene_idx = scene_idx or self.scene_idx + + idx = None + test_bits = None + # rad = rad.ravel() + if ineq == "<=": + test_bits = xr.where( + (rad.data.ravel() <= thr) + & (rad.notnull()) + & (self.data[self.scene_name].values[scene_idx] == 1), + 1, + 0, + ) + # idx = np.nonzero((rad <= thr) & (self.data[self.scene_name].values[scene_idx] == 1)) + if ineq == "<": + test_bits = xr.where( + (rad < thr) & (rad.notnull()) & (self.data[self.scene_name].values[scene_idx] == 1), + 1, + 0, + ) + # idx = np.nonzero((rad < thr) & (self.data[self.scene_name].values[scene_idx] == 1)) + if ineq == ">": + test_bits = xr.where( + (rad > thr) & (rad.notnull()) & (self.data[self.scene_name].values[scene_idx] == 1), + 1, + 0, + ) + # idx = np.nonzero((rad > thr) & (self.data[self.scene_name].values[scene_idx] == 1)) + if ineq == ">=": + test_bits = xr.where( + (rad >= thr) + & (rad.notnull()) + & (self.data[self.scene_name].values[scene_idx] == 1), + 1, + 0, + ) + # idx = np.nonzero((rad >= thr) & (self.data[self.scene_name].values[scene_idx] == 1)) + + if idx is None: + errstr = "Something went wrong and idx is not defined." + raise ValueError(errstr) + + # test_bits = np.zeros(rad.shape) + # test_bits[idx] = 1 + + return test_bits + def compute_test_bits( self, rad: np.ndarray, thr: np.ndarray, ineq: str, scene_idx=None ) -> np.ndarray: @@ -252,13 +325,20 @@ class CloudTests: self, band: str, cmin: np.ndarray, bits: dict, **kwargs ) -> tuple[np.ndarray, dict]: """Perform 11-12um difference spectral test.""" + logger.info(f"spectral_tests: Memory usage #1: {proc.memory_info().rss / 1e6} MB") threshold = kwargs["thresholds"] - + logger.info("Step 1") + logger.info(f"spectral_tests: Memory usage #2: {proc.memory_info().rss / 1e6} MB") if threshold["perform"] is True and self.pixels_in_scene is True: + # bits["qa"] = xr.where(self.data[self.scene_name] == 1, 1, 0) bits["qa"][self.scene_idx] = 1 + logger.info(f"spectral_tests: Memory usage #3: {proc.memory_info().rss / 1e6} MB") - rad = self.data["M15"].values[self.scene_idx] - self.data["M16"].values[self.scene_idx] + rad = self.data.M15.values[self.scene_idx] - self.data.M16.values[self.scene_idx] + # rad = xr.where(self.data[self.scene_name] == 1, self.data.M15 - self.data.M16, np.nan) + # thr = np.full((5, self.data.latitude.shape[0] * self.data.latitude.shape[1]), np.nan) thr = preproc.thresholds_11_12um(self.data, threshold, self.scene_name, self.scene_idx) + logger.info(f"spectral_tests: Memory usage #4: {proc.memory_info().rss / 1e6} MB") # rad = self.data[band].values[self.scene_idx] logger.info(f'Running test 11-12um_Cirrus_Test on "{self.scene_name}"\n') @@ -268,18 +348,29 @@ class CloudTests: comparison = "<=" if np.ndim(thr) == 1: thr = thr.reshape(thr.shape[0], 1) + logger.info("Step 2") + logger.info(f"spectral_tests: Memory usage #5: {proc.memory_info().rss / 1e6} MB") bits["test"][self.scene_idx] = self.compute_test_bits(rad, thr[1, :], comparison) + # bits["test"] = self.xr_compute_test_bits(rad, thr[1, :], comparison) + logger.info(f"spectral_tests: Memory usage #6: {proc.memory_info().rss / 1e6} MB") + logger.info("Step 3") f_rad = np.array(rad, dtype=np.float32) locut = np.array(thr[0, :], dtype=np.float32) midpt = np.array(thr[1, :], dtype=np.float32) hicut = np.array(thr[2, :], dtype=np.float32) power = np.array(thr[3, :], dtype=np.float32) + logger.info("Step 4") + logger.info(f"spectral_tests: Memory usage #7: {proc.memory_info().rss / 1e6} MB") + # temp_confidence = anc.py_conf_test( f_rad, locut, hicut, power, midpt) kwargs["confidence"][self.scene_idx] = anc.py_conf_test( f_rad, locut, hicut, power, midpt ) + logger.info(f"spectral_tests: Memory usage #8: {proc.memory_info().rss / 1e6} MB") + logger.info("Step 5") cmin = np.fmin(cmin, kwargs["confidence"]) + logger.info(f"spectral_tests: Memory usage #9: {proc.memory_info().rss / 1e6} MB") bits["test"] = bits["test"].reshape(self.data.latitude.shape) return cmin, bits