From a52fd683456e9e7794cebff5a33aa81724db172c Mon Sep 17 00:00:00 2001
From: Paolo Veglio <paolo.veglio@ssec.wisc.edu>
Date: Thu, 23 May 2024 14:43:10 +0000
Subject: [PATCH] changed data type to reduce memory footprint

---
 mvcm/main.py                  | 39 ++++++++++++++++++++++++++---------
 mvcm/preprocess_thresholds.py | 10 ++++-----
 mvcm/scene.py                 | 13 +++++++-----
 3 files changed, 42 insertions(+), 20 deletions(-)

diff --git a/mvcm/main.py b/mvcm/main.py
index 7049ec8..f565bc8 100644
--- a/mvcm/main.py
+++ b/mvcm/main.py
@@ -96,6 +96,8 @@ _mod_bands = [
     "M16",
 ]
 
+_img_bands = ["I01", "I02", "I03", "I04", "I05"]
+
 
 def timer(func):
     """Compute elapsed time."""
@@ -210,13 +212,22 @@ def main(
     # We are not processing night granules
     if use_hires is True:
         with xr.open_dataset(file_names["IMG02"], group="observation_data") as vnp02:
-            for b in _mod_bands:
+            for b in _img_bands:
                 try:
                     vnp02[b]
                 except KeyError:
                     logger.info(f"Band {b} not found in file. No output will be written.")
                     return
             logger.info(f"All bands found in file {file_names['IMG02']}. The code will run.")
+    else:
+        with xr.open_dataset(file_names["MOD02"], group="observation_data") as vnp02:
+            for b in _mod_bands:
+                try:
+                    vnp02[b]
+                except KeyError:
+                    logger.info(f"Band {b} not found in file. No output will be written.")
+                    return
+            logger.info(f"All bands found in file {file_names['MOD02']}. The code will run.")
 
     with Dataset(file_names["MOD03"]) as f:
         # time_coverage_start = f.getncattr("time_coverage_start")
@@ -236,7 +247,7 @@ def main(
         log_level=LOG_LEVELS[verbose_level],
     )
 
-    cmin_3d = np.ones((18, viirs_data.M11.shape[0], viirs_data.M11.shape[1]))
+    cmin_3d = np.ones((18, viirs_data.M11.shape[0], viirs_data.M11.shape[1]), dtype=np.float32)
 
     ##########################################################
     _bitlist = [
@@ -267,8 +278,8 @@ def main(
 
     for b in _bitlist:
         bits[b] = {
-            "test": np.zeros(viirs_data.M11.shape),
-            "qa": np.zeros(viirs_data.M11.shape),
+            "test": np.zeros(viirs_data.M11.shape, dtype=np.int8),
+            "qa": np.zeros(viirs_data.M11.shape, dtype=np.int8),
         }
     i = 0
 
@@ -282,7 +293,7 @@ def main(
         np.zeros(viirs_data.M11.shape) * np.nan, dims=viirs_data.M11.dims
     )
 
-    scene_types = np.zeros(viirs_data.M11.shape)
+    scene_types = np.zeros(viirs_data.M11.shape, dtype=np.int8)
     for _scene_i, scene_name in enumerate(_scene_list):
         scene_types[viirs_data[scene_name].values == 1] = i
 
@@ -295,11 +306,11 @@ def main(
         my_scene = tst.CloudTests(data=viirs_data, scene_name=scene_name, thresholds=thresholds)
 
         # Initialize the confidence arrays for the various test groups
-        cmin_g1 = np.ones(viirs_data.M11.shape)
-        cmin_g2 = np.ones(viirs_data.M11.shape)
-        cmin_g3 = np.ones(viirs_data.M11.shape)
-        cmin_g4 = np.ones(viirs_data.M11.shape)
-        cmin_g5 = np.ones(viirs_data.M11.shape)
+        cmin_g1 = np.ones(viirs_data.M11.shape, dtype=np.float32)
+        cmin_g2 = np.ones(viirs_data.M11.shape, dtype=np.float32)
+        cmin_g3 = np.ones(viirs_data.M11.shape, dtype=np.float32)
+        cmin_g4 = np.ones(viirs_data.M11.shape, dtype=np.float32)
+        cmin_g5 = np.ones(viirs_data.M11.shape, dtype=np.float32)
         # cmin_temp = np.ones(viirs_data.M11.shape)
 
         if use_hires is True:
@@ -570,6 +581,14 @@ def main(
         "M15-M16": {"dims": ("x", "y"), "data": viirs_data["M15-M16"].values},
         # "Land_Day": {"dims": ("x", "y"), "data": viirs_data.Land_Day.values},
         # "Land_Day_Desert": {"dims": ("x", "y"), "data": viirs_data.Land_Day_Desert.values},
+        "Ocean_Day": {"dims": ("x", "y"), "data": viirs_data.Ocean_Day.values},
+        "Polar_Day_Ocean": {"dims": ("x", "y"), "data": viirs_data.Polar_Day_Ocean.values},
+        "Polar_Day_Land": {"dims": ("x", "y"), "data": viirs_data.Polar_Day_Land.values},
+        "Polar_Day_Desert": {"dims": ("x", "y"), "data": viirs_data.Polar_Day_Desert.values},
+        "Polar_Day_Desert_Coast": {
+            "dims": ("x", "y"),
+            "data": viirs_data.Polar_Day_Desert_Coast.values,
+        },
         "elevation": {"dims": ("x", "y"), "data": viirs_data.height.values},
         "scene_type": {"dims": ("x", "y"), "data": scene_types},
         # 'thr': {'dims': ('x', 'y'),
diff --git a/mvcm/preprocess_thresholds.py b/mvcm/preprocess_thresholds.py
index 571f8ed..16066f1 100644
--- a/mvcm/preprocess_thresholds.py
+++ b/mvcm/preprocess_thresholds.py
@@ -1,14 +1,13 @@
 """Preprocessing thresholds module."""
 
-import logging
-from typing import Dict
+import logging  # noqa
 
 import numpy as np
 import xarray as xr
 from numpy.lib.stride_tricks import sliding_window_view
 
 import ancillary_data as anc
-import mvcm.utils as utils
+from mvcm import utils
 
 # _dtr = np.pi / 180
 _DTR = np.pi / 180
@@ -43,7 +42,7 @@ def prepare_11_12um_thresholds(thresholds: dict, dim1: int) -> dict:
 
 
 def thresholds_11_12um(
-    data: xr.Dataset, thresholds: Dict, scene: str, scene_idx: tuple
+    data: xr.Dataset, thresholds: dict, scene: str, scene_idx: tuple
 ) -> np.ndarray:
     """Compute 11-12um Test thresholds."""
     cosvza = np.cos(data.sensor_zenith.values[scene_idx].ravel() * _DTR)
@@ -132,7 +131,8 @@ def thresholds_NIR(data, thresholds, scene, test_name, scene_idx):
         band_n = 7
         vzcpow = thresholds["VZA_correction"]["vzcpow"][2]
     else:
-        raise ValueError("Test name not valid")
+        err_msg = "Test name not valid"
+        raise ValueError(err_msg)
 
     refang = data.sunglint_angle.values[scene_idx].ravel()
     sunglint_thresholds = thresholds["Sun_Glint"]
diff --git a/mvcm/scene.py b/mvcm/scene.py
index 6768f65..18ce837 100644
--- a/mvcm/scene.py
+++ b/mvcm/scene.py
@@ -1,4 +1,5 @@
 """Functions that define scene type."""
+
 import logging
 from typing import Dict
 
@@ -154,7 +155,7 @@ def find_scene(data: xr.Dataset, sunglint_angle: float) -> Dict:
     # tmp[day == 1] = day
     # tmp[day == 0] = night
 
-    scene_flag = {flg: np.zeros((dim1, dim2)) for flg in _flags}
+    scene_flag = {flg: np.zeros((dim1, dim2), dtype=np.int8) for flg in _flags}
 
     scene_flag["day"][sza <= 85] = 1
     scene_flag["visusd"][sza <= 85] = 1
@@ -172,8 +173,8 @@ def find_scene(data: xr.Dataset, sunglint_angle: float) -> Dict:
     eco[idx] = 14
 
     # start by defining everything as land
-    scene_flag["land"] = np.ones((dim1, dim2))
-    scene_flag["water"] = np.zeros((dim1, dim2))
+    scene_flag["land"] = np.ones((dim1, dim2), dtype=np.int8)
+    scene_flag["water"] = np.zeros((dim1, dim2), dtype=np.int8)
 
     # Fix-up for missing ecosystem data in eastern Greenland and north-eastern Siberia.
     # Without this, these regions become completely "coast".
@@ -292,7 +293,7 @@ def find_scene(data: xr.Dataset, sunglint_angle: float) -> Dict:
         | (eco == 71)
         | (eco == 50)
     )
-    scene_flag["vrused"] = np.ones((dim1, dim2))
+    scene_flag["vrused"] = np.ones((dim1, dim2), dtype=np.int8)
     scene_flag["vrused"][idx] = 0
 
     snow_fraction = data["geos_snow_fraction"]
@@ -417,12 +418,13 @@ def find_scene(data: xr.Dataset, sunglint_angle: float) -> Dict:
 def scene_id(scene_flag):
     """Define scene type."""
     dim1, dim2 = scene_flag["day"].shape[0], scene_flag["day"].shape[1]
-    scene = {scn: np.zeros((dim1, dim2)) for scn in _scene_list}
+    scene = {scn: np.zeros((dim1, dim2), dtype=np.int8) for scn in _scene_list}
 
     # Ocean Day
     idx = np.nonzero(
         (scene_flag["water"] == 1)
         & (scene_flag["day"] == 1)
+        & (scene_flag["polar"] == 0)
         & ((scene_flag["ice"] == 0) | (scene_flag["snow"] == 0))
     )  # &
     # (scene_flag['polar'] == 0) & (scene_flag['antarctica'] == 0) &
@@ -433,6 +435,7 @@ def scene_id(scene_flag):
     idx = np.nonzero(
         (scene_flag["water"] == 1)
         & (scene_flag["night"] == 1)
+        & (scene_flag["polar"] == 0)
         & ((scene_flag["ice"] == 0) | (scene_flag["snow"] == 0))
     )  # &
     # (scene_flag['polar'] == 0) & (scene_flag['antarctica'] == 0) &
-- 
GitLab