From c0726dfce447e87edf271fd3bcceb5efc639e1ca Mon Sep 17 00:00:00 2001
From: Paolo Veglio <paolo.veglio@ssec.wisc.edu>
Date: Thu, 26 Sep 2024 21:42:45 +0000
Subject: [PATCH] converted a couple tests to use xarray

---
 mvcm/conf.py                  | 104 +++++++++++++++++++++++
 mvcm/constants.py             |   1 +
 mvcm/main.py                  |   4 +-
 mvcm/preprocess_thresholds.py | 155 +++++++++++++++++++++++++++++++++-
 mvcm/read_data.py             |   4 +-
 tests/test_conf.py            |  16 +++-
 6 files changed, 279 insertions(+), 5 deletions(-)

diff --git a/mvcm/conf.py b/mvcm/conf.py
index d2b8676..2d5f040 100644
--- a/mvcm/conf.py
+++ b/mvcm/conf.py
@@ -3,10 +3,114 @@
 import logging
 
 import numpy as np
+import xarray as xr
 
 logger = logging.getLogger(__name__)
 
 
+def asymm_confidence(rad: xr.DataArray, thr: np.ndarray) -> xr.DataArray:
+    """Compute confidence based on thresholds."""
+    # NOTE: this only works for the linear case
+
+    confidence = xr.ones_like(rad, dtype=np.float32)
+
+    # First condition:
+    # divide by 2 so the maximum value is 0.5 at the midpoint.
+    # Second condition:
+    # likewise, scale the interval so the range stays between 0 and 0.5,
+    # then add 0.5 to shift it into the upper half.
+    if thr.ndim == 1:
+        thr = thr[:-1]
+        confidence = xr.where(
+            rad <= thr[1],
+            (np.clip(rad, np.min(thr), thr[1]) - thr[0]) / ((thr[1] - thr[0]) * 2),
+            (np.clip(rad, thr[1], np.max(thr)) - thr[1]) / ((thr[2] - thr[1]) * 2) + 0.5,
+        )
+    else:
+        thr = thr[:, :, :-1]
+        confidence = xr.where(
+            rad <= thr[:, :, 1],
+            (np.clip(rad, np.min(thr, axis=2), thr[:, :, 1]) - thr[:, :, 0])
+            / ((thr[:, :, 1] - thr[:, :, 0]) * 2),
+            (np.clip(rad, thr[:, :, 1], np.max(thr, axis=2)) - thr[:, :, 1])
+            / ((thr[:, :, 2] - thr[:, :, 1]) * 2)
+            + 0.5,
+        )
+
+    return np.abs(confidence)
+
+
+def conf_test_xr(rad: xr.DataArray, thr: np.ndarray) -> xr.DataArray:
+    """Compute confidence based on thresholds."""
+    """Assuming a linear function between min and max confidence level, the plot below shows
+    how the confidence (y axis) is computed as a function of radiance (x axis).
+    This case illustrates alpha < gamma; in the opposite case (alpha > gamma) the plot is
+    flipped.
+                       gamma
+    c  1                 ________
+    o  |                /
+    n  |               /
+    f  |              /
+    i  |     beta    /
+    d 1/2    |....../
+    e  |           /
+    n  |          /
+    c  |         /
+    e  0________/
+       |      alpha
+    --------- radiance ---------->
+    """
+
+    radshape = rad.shape
+    rad = np.asarray(rad).ravel()  # flatten to a plain ndarray for the index math below
+    thr = np.array(thr)
+    if thr.ndim == 1:
+        thr = np.full((rad.shape[0], 4), thr[:4]).T
+
+    coeff = np.power(2, (thr[3, :] - 1))
+    locut = thr[0, :]
+    beta = thr[1, :]
+    hicut = thr[2, :]
+    power = thr[3, :]
+    confidence = np.zeros(rad.shape)
+
+    alpha, gamma = np.empty(rad.shape), np.empty(rad.shape)
+    flipped = np.zeros(rad.shape)
+
+    gamma[hicut > locut] = thr[2, hicut > locut]
+    alpha[hicut > locut] = thr[0, hicut > locut]
+    flipped[hicut > locut] = 0
+    gamma[hicut < locut] = thr[0, hicut < locut]
+    alpha[hicut < locut] = thr[2, hicut < locut]
+    flipped[hicut < locut] = 1
+
+    # Rad between alpha and beta
+    range_ = 2.0 * (beta - alpha)
+    s1 = (rad - alpha) / range_
+    idx = np.nonzero((rad <= beta) & (flipped == 0))
+    confidence[idx] = coeff[idx] * np.power(s1[idx], power[idx])
+    idx = np.nonzero((rad <= beta) & (flipped == 1))
+    confidence[idx] = 1.0 - (coeff[idx] * np.power(s1[idx], power[idx]))
+
+    # Rad between beta and gamma
+    range_ = 2.0 * (beta - gamma)
+    s1 = (rad - gamma) / range_
+    idx = np.nonzero((rad > beta) & (flipped == 0))
+    confidence[idx] = 1.0 - (coeff[idx] * np.power(s1[idx], power[idx]))
+    idx = np.nonzero((rad > beta) & (flipped == 1))
+    confidence[idx] = coeff[idx] * np.power(s1[idx], power[idx])
+
+    # Rad outside alpha-gamma interval
+    confidence[(rad > gamma) & (flipped == 0)] = 1
+    confidence[(rad < alpha) & (flipped == 0)] = 0
+    confidence[(rad > gamma) & (flipped == 1)] = 0
+    confidence[(rad < alpha) & (flipped == 1)] = 1
+
+    confidence = np.clip(confidence, 0, 1)
+
+    return confidence.reshape(radshape)
+
+
 def conf_test_new(rad: np.ndarray, thr: np.ndarray) -> np.ndarray:
     """Compute confidence based on thresholds."""
     """Assuming a linear function between min and max confidence level, the plot below shows
diff --git a/mvcm/constants.py b/mvcm/constants.py
index 1b6e101..ce4d986 100644
--- a/mvcm/constants.py
+++ b/mvcm/constants.py
@@ -30,6 +30,7 @@ class SensorConstants:
         "M03",
         "M04",
         "M05",
+        "M06",
         "M07",
         "M08",
         "M09",
diff --git a/mvcm/main.py b/mvcm/main.py
index 0e17ea5..8d0e1ca 100644
--- a/mvcm/main.py
+++ b/mvcm/main.py
@@ -13,7 +13,7 @@ from rich.logging import RichHandler
 from ruamel.yaml import YAML
 
 import mvcm.read_data as rd
-import mvcm.spectral_tests as tst
+import mvcm.temp_spectral_tests as tst
 import mvcm.utility_functions as utils
 from mvcm.constants import SceneConstants, SensorConstants
 from mvcm.restoral import Restoral
@@ -241,6 +241,8 @@ def main(
         # cmin_g3 = xr.ones_like(viirs_data.latitude, dtype=np.float32, chunks="auto")
         # cmin_g4 = xr.ones_like(viirs_data.latitude, dtype=np.float32, chunks="auto")
         # cmin_g5 = xr.ones_like(viirs_data.latitude, dtype=np.float32, chunks="auto")
+        cmin_temp = xr.ones_like(viirs_data.latitude, dtype=np.float32, chunks="auto")
+
         cmin_g1 = np.ones(viirs_data.latitude.shape, dtype=np.float32)
         cmin_g2 = np.ones(viirs_data.latitude.shape, dtype=np.float32)
         cmin_g3 = np.ones(viirs_data.latitude.shape, dtype=np.float32)
diff --git a/mvcm/preprocess_thresholds.py b/mvcm/preprocess_thresholds.py
index a1c1475..8339b0e 100644
--- a/mvcm/preprocess_thresholds.py
+++ b/mvcm/preprocess_thresholds.py
@@ -2,12 +2,13 @@
 
 import logging  # noqa
 
+import ancillary_data as anc
 import numpy as np
 import xarray as xr
 from numpy.lib.stride_tricks import sliding_window_view
 
-import ancillary_data as anc
 from mvcm import utils
+from mvcm.constants import ConstantsNamespace
 
 # _dtr = np.pi / 180
 _DTR = np.pi / 180
@@ -41,6 +42,121 @@ def prepare_11_12um_thresholds(thresholds: dict, dim1: int) -> dict:
     return thr_dict
 
 
+def prepare_11_12um_thresholds_2d(thresholds: dict, dim: int) -> dict:
+    """Prepare 11-12um test thresholds."""
+    coeff_values = np.empty((dim, 2))
+    coeff_values[:, 0] = np.full(dim, thresholds["coeffs"][0])
+    coeff_values[:, 1] = np.full(dim, thresholds["coeffs"][1])
+    cmult_values = np.full(dim, thresholds["cmult"])
+    adj_values = np.full(dim, thresholds["adj"])
+    if "bt1" in list(thresholds):
+        bt1 = np.full(dim, thresholds["bt1"])
+    else:
+        bt1 = np.full(dim, -999)
+    if "lat" in list(thresholds):
+        lat = np.full(dim, thresholds["lat"])
+    else:
+        lat = np.full(dim, -999)
+    thr_dict = {
+        "coeffs": coeff_values,
+        "cmult": cmult_values,
+        "adj": adj_values,
+        "bt1": bt1,
+        "lat": lat,
+    }
+
+    return thr_dict
+
+
+def xr_thresholds_11_12um(data: xr.Dataset, thresholds: dict, scene: str) -> np.ndarray:
+    """Compute 11-12um Test thresholds."""
+    cosvza = np.cos(data.sensor_zenith.values.ravel() * ConstantsNamespace.DTR)
+    schi = np.full(cosvza.shape, 99.0)
+    schi[np.abs(cosvza) > 0] = 1.0 / cosvza[np.abs(cosvza) > 0]
+    schi = np.array(schi, dtype=np.float32)  # this is because the C function expects a float
+
+    m15 = np.array(data.M15.values.ravel(), dtype=np.float32)
+    latitude = data.latitude.values.ravel()
+    thr = anc.py_cithr(1, schi, m15)
+    thr_dict = prepare_11_12um_thresholds(thresholds, m15.shape[0])
+
+    midpt = np.full(m15.shape, thr)
+    if scene in ["Day_Snow"]:
+        midpt = np.full(m15.shape, thr) + thr_dict["adj"]
+    idx = np.nonzero((thr < 0.1) | (np.abs(schi - 99.0) < 0.0001))
+    midpt[idx] = thr_dict["coeffs"][idx, 0]
+    locut = midpt + (thr_dict["cmult"] * midpt)
+
+    if scene in [
+        "Land_Day",
+        "Land_Day_Coast",
+        "Land_Day_Desert",
+        "Land_Day_Desert_Coast",
+        "Ocean_Day",
+        "Ocean_Night",
+        "Polar_Day_Ocean",
+        "Polar_Night_Ocean",
+    ]:
+        hicut = midpt - thr_dict["adj"]
+
+    elif scene in [
+        "Polar_Day_Land",
+        "Polar_Day_Coast",
+        "Polar_Day_Desert",
+        "Polar_Day_Desert_Coast",
+        "Polar_Day_Snow",
+    ]:
+        hicut = midpt - (thr_dict["adj"] * midpt)
+
+    elif scene in [
+        "Land_Night",
+        "Polar_Night_Land",
+        "Polar_Night_Snow",
+        "Day_Snow",
+        "Night_Snow",
+    ]:
+        _coeffs = {
+            "Land_Night": 0.3,
+            "Polar_Night_Land": 0.3,
+            "Polar_Night_Snow": 0.3,
+            "Day_Snow": 0.8,  # this is not needed anymore, replaced by "adj" in thresholds file
+            "Night_Snow": 0.3,
+        }
+
+        if scene in ["Day_Snow"]:
+            hicut = midpt - (thr_dict["cmult"] * midpt)
+        else:
+            # This needs to be rewritten properly
+            midpt = midpt - (_coeffs[scene] * locut)
+
+        if scene in ["Polar_Night_Land", "Polar_Night_Snow", "Night_Snow"]:
+            hicut = np.full(m15.shape, midpt - 1.25)
+            idx = np.nonzero(m15 < thr_dict["bt1"])
+            hicut[idx] = midpt[idx] - (0.2 * locut[idx])
+        elif scene in ["Land_Night"]:
+            hicut = np.full(m15.shape, 1.25)
+            idx = np.nonzero((m15 < thr_dict["bt1"]) & (latitude > thr_dict["lat"]))
+            hicut[idx] = -0.1 - np.power(90.0 - np.abs(latitude[idx]) / 60, 4) * 1.15
+        else:
+            pass
+            # err_msg = "Scene name not valid"
+            # raise ValueError(err_msg)
+    else:
+        err_msg = "Scene name not valid"
+        raise ValueError(err_msg)
+
+    thr_out = np.dstack(
+        (
+            locut.reshape(data.latitude.shape),
+            midpt.reshape(data.latitude.shape),
+            hicut.reshape(data.latitude.shape),
+            np.ones(data.latitude.shape),
+            np.ones(data.latitude.shape),
+        )
+    )
+    return np.squeeze(thr_out)
+
+
 def thresholds_11_12um(
     data: xr.Dataset, thresholds: dict, scene: str, scene_idx: tuple
 ) -> np.ndarray:
@@ -718,6 +834,43 @@ def thresholds_1_38um_test(data, thresholds, scene_name, scene_idx):
     return np.squeeze(out_thr.T)
 
 
+def xr_thresholds_1_38um_test(data: xr.Dataset, thresholds: dict, scene_name: str):
+    """Compute thresholds for 1.38um test."""
+    thresh = thresholds[scene_name]["1.38um_High_Cloud_Test"]
+    c = thresh["coeffs"]
+    vzcpow = thresholds["VZA_correction"]["vzcpow"][1]
+
+    hicut = (
+        c[0]
+        + c[1] * data.solar_zenith.values
+        + c[2] * np.power(data.solar_zenith.values, 2)
+        + c[3] * np.power(data.solar_zenith.values, 3)
+        + c[4] * np.power(data.solar_zenith.values, 4)
+        + c[5] * np.power(data.solar_zenith.values, 5)
+    )
+    hicut = hicut * 0.01 + (
+        np.maximum(
+            (data.solar_zenith.values / 90.0) * thresh["szafac"] * thresh["adj"], thresh["adj"]
+        )
+    )
+    midpt = hicut + 0.001
+    locut = midpt + 0.001
+
+    cosvza = np.cos(data.sensor_zenith.values * ConstantsNamespace.DTR)
+
+    locut = locut * (1 / np.power(cosvza, vzcpow))
+    midpt = midpt * (1 / np.power(cosvza, vzcpow))
+    hicut = hicut * (1 / np.power(cosvza, vzcpow))
+
+    # out_thr = xr.DataArray(data=np.dstack((locut, midpt, hicut,
+    #                                        np.ones(data.ndvi.shape),
+    #                                        np.ones(data.ndvi.shape))),
+    #                        dims=('number_of_lines', 'number_of_pixels', 'z'))
+    out_thr = np.stack([locut, midpt, hicut, np.ones(locut.shape), np.ones(locut.shape)], axis=2)
+
+    return out_thr
+
+
 # NOTE: 11-12um Cirrus Test
 # hicut is computed in different ways depending on the scene
 # 1. midpt - adj
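
To make the NOTE above a bit more concrete, a small sketch of how the thresholds built in
xr_thresholds_11_12um relate to each other, with a placeholder midpoint array standing in
for the anc.py_cithr output and made-up cmult/adj values:

    import numpy as np

    midpt = np.array([1.0, 1.5, 2.0])  # placeholder for the anc.py_cithr result
    cmult, adj = 0.3, 0.3              # placeholder threshold-file entries

    locut = midpt + cmult * midpt           # computed before the scene branches
    hicut_day = midpt - adj                 # e.g. Land_Day, Ocean_Day scenes
    hicut_polar_day = midpt - adj * midpt   # e.g. Polar_Day_Land-type scenes
    hicut_day_snow = midpt - cmult * midpt  # Day_Snow

    print(np.column_stack([locut, midpt, hicut_day]))
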
diff --git a/mvcm/read_data.py b/mvcm/read_data.py
index 878c9f6..704e1b2 100644
--- a/mvcm/read_data.py
+++ b/mvcm/read_data.py
@@ -4,13 +4,13 @@ import logging  # noqa - pyright complains about imports not being sorted but py
 import os
 from datetime import datetime as dt
 
+import ancillary_data as anc
 import numpy as np
 import numpy.typing as npt
 import psutil
 import xarray as xr
 from attrs import Factory, define, field, validators
 
-import ancillary_data as anc
 import mvcm.scene as scn
 from mvcm.constants import ConstantsNamespace, SensorConstants
 
@@ -268,7 +268,7 @@ class ReadData(CollectInputs):
                 "M02",
                 "M03",
                 "M04",
-                # "M06",
+                "M06",
                 "M08",
                 "M09",
                 "M11",
diff --git a/tests/test_conf.py b/tests/test_conf.py
index 6da43a3..9539cf8 100644
--- a/tests/test_conf.py
+++ b/tests/test_conf.py
@@ -2,10 +2,11 @@
 
 import os  # noqa
 
+import ancillary_data as anc
 import numpy as np
 import pytest
+import xarray as xr
 
-import ancillary_data as anc
 from mvcm import conf
 
 
@@ -88,3 +89,16 @@ def test_c_single_threshold(rad, single_threshold, reference_data):
 
     c = anc.py_conf_test(np.array(rad, dtype=np.float32), locut, hicut, power, midpt)
     assert np.allclose(c, ref_confidence_flipped)
+
+
+def test_xr_single_threshold(rad, single_threshold, reference_data):
+    """Test xarray version of single threshold computation."""
+    ref_confidence = np.load(reference_data)["ref_sgl"]
+    ref_confidence_flipped = np.load(reference_data)["ref_sgl_flipped"]
+
+    c = conf.asymm_confidence(xr.DataArray(rad), np.array(single_threshold))
+    assert np.allclose(c.values, ref_confidence)
+
+    single_threshold[0:-1] = single_threshold[-2::-1]
+    c = conf.asymm_confidence(xr.DataArray(rad), np.array(single_threshold))
+    assert np.allclose(c.values, ref_confidence_flipped)
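
For reference, the slice assignment in test_xr_single_threshold reverses every entry of the
threshold list except the last one, so locut and hicut swap places while the trailing term
(presumably the power) stays put; a quick illustration with a made-up threshold list:

    thr = [280.0, 290.0, 300.0, 1.0]  # hypothetical [locut, midpt, hicut, power]
    thr[0:-1] = thr[-2::-1]           # reverse all but the last entry
    print(thr)                        # [300.0, 290.0, 280.0, 1.0]
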
-- 
GitLab