From c0726dfce447e87edf271fd3bcceb5efc639e1ca Mon Sep 17 00:00:00 2001
From: Paolo Veglio <paolo.veglio@ssec.wisc.edu>
Date: Thu, 26 Sep 2024 21:42:45 +0000
Subject: [PATCH] converted a couple tests to use xarray

---
 mvcm/conf.py                  | 104 +++++++++++++++++++++++
 mvcm/constants.py             |   1 +
 mvcm/main.py                  |   4 +-
 mvcm/preprocess_thresholds.py | 155 +++++++++++++++++++++++++++++++++-
 mvcm/read_data.py             |   4 +-
 tests/test_conf.py            |  16 +++-
 6 files changed, 279 insertions(+), 5 deletions(-)

diff --git a/mvcm/conf.py b/mvcm/conf.py
index d2b8676..2d5f040 100644
--- a/mvcm/conf.py
+++ b/mvcm/conf.py
@@ -3,10 +3,114 @@
 import logging
 
 import numpy as np
+import xarray as xr
 
 logger = logging.getLogger(__name__)
 
 
+def asymm_confidence(rad: xr.DataArray, thr: np.ndarray) -> xr.DataArray:
+    """Compute confidence based on thresholds."""
+    # NOTE: this only works for the linear case
+
+    confidence = xr.ones_like(rad, dtype=np.float32)
+
+    # First condition:
+    # need to divide by 2 because I want max value to be 0.5 at the midpoint
+    # Second condition:
+    # likewise, splitting the interval so my range is between 0 and 0.5,
+    # then I add 0.5 to shift it to the right place
+    if thr.ndim == 1:
+        thr = thr[:-1]
+        confidence = xr.where(
+            rad <= thr[1],
+            (np.clip(rad, np.min(thr), thr[1]) - thr[0]) / ((thr[1] - thr[0]) * 2),
+            (np.clip(rad, thr[1], np.max(thr)) - thr[1]) / ((thr[2] - thr[1]) * 2) + 0.5,
+        )
+    else:
+        thr = thr[:, :, :-1]
+        confidence = xr.where(
+            rad <= thr[:, :, 1],
+            (np.clip(rad, np.min(thr, axis=2), thr[:, :, 1]) - thr[:, :, 0])
+            / ((thr[:, :, 1] - thr[:, :, 0]) * 2),
+            (np.clip(rad, thr[:, :, 1], np.max(thr, axis=2)) - thr[:, :, 1])
+            / ((thr[:, :, 2] - thr[:, :, 1]) * 2)
+            + 0.5,
+        )
+
+    return np.abs(confidence)
+
+
+def conf_test_xr(rad: xr.DataArray, thr: np.ndarray) -> xr.DataArray:
+    """Compute confidence based on thresholds."""
+    """Assuming a linear function between min and max confidence level, the plot below shows
+    how the confidence (y axis) is computed as function of radiance (x axis).
+    This case illustrates alpha < gamma, obviously in case alpha > gamma, the plot would be
+    flipped.
+                         gamma
+    c  1                 ________
+    o  |                /
+    n  |               /
+    f  |              /
+    i  |      beta   /
+    d 1/2 |........./
+    e  |           /
+    n  |          /
+    c  |         /
+    e  0________/
+            |  alpha
+       --------- radiance ---------->
+    """
+
+    radshape = rad.shape
+    rad = np.asarray(rad).ravel()  # flatten as plain numpy so DataArray or ndarray inputs work
+    thr = np.array(thr)
+    if thr.ndim == 1:
+        thr = np.full((rad.shape[0], 4), thr[:4]).T
+
+    coeff = np.power(2, (thr[3, :] - 1))
+    locut = thr[0, :]
+    beta = thr[1, :]
+    hicut = thr[2, :]
+    power = thr[3, :]
+    confidence = np.zeros(rad.shape)
+
+    alpha, gamma = np.empty(rad.shape), np.empty(rad.shape)
+    flipped = np.zeros(rad.shape)
+
+    gamma[hicut > locut] = thr[2, hicut > locut]
+    alpha[hicut > locut] = thr[0, hicut > locut]
+    flipped[hicut > locut] = 0
+    gamma[hicut < locut] = thr[0, hicut < locut]
+    alpha[hicut < locut] = thr[2, hicut < locut]
+    flipped[hicut < locut] = 1
+
+    # Rad between alpha and beta
+    range_ = 2.0 * (beta - alpha)
+    s1 = (rad - alpha) / range_
+    idx = np.nonzero((rad <= beta) & (flipped == 0))
+    confidence[idx] = coeff[idx] * np.power(s1[idx], power[idx])
+    idx = np.nonzero((rad <= beta) & (flipped == 1))
+    confidence[idx] = 1.0 - (coeff[idx] * np.power(s1[idx], power[idx]))
+
+    # Rad between beta and gamma
+    range_ = 2.0 * (beta - gamma)
+    s1 = (rad - gamma) / range_
+    idx = np.nonzero((rad > beta) & (flipped == 0))
+    confidence[idx] = 1.0 - (coeff[idx] * np.power(s1[idx], power[idx]))
+    idx = np.nonzero((rad > beta) & (flipped == 1))
+    confidence[idx] = coeff[idx] * np.power(s1[idx], power[idx])
+
+    # Rad outside alpha-gamma interval
+    confidence[(rad > gamma) & (flipped == 0)] = 1
+    confidence[(rad < alpha) & (flipped == 0)] = 0
+    confidence[(rad > gamma) & (flipped == 1)] = 0
+    confidence[(rad < alpha) & (flipped == 1)] = 1
+
+    confidence = np.clip(confidence, 0, 1)
+
+    return confidence.reshape(radshape)
+
+
 def conf_test_new(rad: np.ndarray, thr: np.ndarray) -> np.ndarray:
     """Compute confidence based on thresholds."""
     """Assuming a linear function between min and max confidence level, the plot below shows
diff --git a/mvcm/constants.py b/mvcm/constants.py
index 1b6e101..ce4d986 100644
--- a/mvcm/constants.py
+++ b/mvcm/constants.py
@@ -30,6 +30,7 @@ class SensorConstants:
         "M03",
         "M04",
         "M05",
+        "M06",
         "M07",
         "M08",
         "M09",
diff --git a/mvcm/main.py b/mvcm/main.py
index 0e17ea5..8d0e1ca 100644
--- a/mvcm/main.py
+++ b/mvcm/main.py
@@ -13,7 +13,7 @@ from rich.logging import RichHandler
 from ruamel.yaml import YAML
 
 import mvcm.read_data as rd
-import mvcm.spectral_tests as tst
+import mvcm.temp_spectral_tests as tst
 import mvcm.utility_functions as utils
 from mvcm.constants import SceneConstants, SensorConstants
 from mvcm.restoral import Restoral
@@ -241,6 +241,8 @@ def main(
     # cmin_g3 = xr.ones_like(viirs_data.latitude, dtype=np.float32, chunks="auto")
     # cmin_g4 = xr.ones_like(viirs_data.latitude, dtype=np.float32, chunks="auto")
     # cmin_g5 = xr.ones_like(viirs_data.latitude, dtype=np.float32, chunks="auto")
+    cmin_temp = xr.ones_like(viirs_data.latitude, dtype=np.float32, chunks="auto")
+
     cmin_g1 = np.ones(viirs_data.latitude.shape, dtype=np.float32)
     cmin_g2 = np.ones(viirs_data.latitude.shape, dtype=np.float32)
     cmin_g3 = np.ones(viirs_data.latitude.shape, dtype=np.float32)
diff --git a/mvcm/preprocess_thresholds.py b/mvcm/preprocess_thresholds.py
index a1c1475..8339b0e 100644
--- a/mvcm/preprocess_thresholds.py
+++ b/mvcm/preprocess_thresholds.py
@@ -2,12 +2,13 @@
 
 import logging  # noqa
 
+import ancillary_data as anc
 import numpy as np
 import xarray as xr
 from numpy.lib.stride_tricks import sliding_window_view
 
-import ancillary_data as anc
 from mvcm import utils
+from mvcm.constants import ConstantsNamespace
 
 # _dtr = np.pi / 180
 _DTR = np.pi / 180
@@ -41,6 +42,121 @@ def prepare_11_12um_thresholds(thresholds: dict, dim1: int) -> dict:
     return thr_dict
 
 
+def prepare_11_12um_thresholds_2d(thresholds: dict, dim: int) -> dict:
+    """Prepare 11-12um test thresholds."""
+    coeff_values = np.empty((dim, 2))
+    coeff_values[:, 0] = np.full(dim, thresholds["coeffs"][0])
+    coeff_values[:, 1] = np.full(dim, thresholds["coeffs"][1])
+    cmult_values = np.full(dim, thresholds["cmult"])
+    adj_values = np.full(dim, thresholds["adj"])
+    if "bt1" in list(thresholds):
+        bt1 = np.full(dim, thresholds["bt1"])
+    else:
+        bt1 = np.full(dim, -999)
+    if "lat" in list(thresholds):
+        lat = np.full(dim, thresholds["lat"])
+    else:
+        lat = np.full(dim, -999)
+    thr_dict = {
+        "coeffs": coeff_values,
+        "cmult": cmult_values,
+        "adj": adj_values,
+        "bt1": bt1,
+        "lat": lat,
+    }
+
+    return thr_dict
+
+
+def xr_thresholds_11_12um(data: xr.Dataset, thresholds: dict, scene: str) -> np.ndarray:
+    """Compute 11-12um Test thresholds."""
+    cosvza = np.cos(data.sensor_zenith.values.ravel() * ConstantsNamespace.DTR)
+    schi = np.full(cosvza.shape, 99.0)
+    schi[np.abs(cosvza) > 0] = 1.0 / cosvza[np.abs(cosvza) > 0]
+    schi = np.array(schi, dtype=np.float32)  # this is because the C function expects a float
+
+    m15 = np.array(data.M15.values.ravel(), dtype=np.float32)
+    latitude = data.latitude.values.ravel()
+    thr = anc.py_cithr(1, schi, m15)
+    thr_dict = prepare_11_12um_thresholds(thresholds, m15.shape[0])
+
+    midpt = np.full(m15.shape, thr)
+    if scene in ["Day_Snow"]:
+        midpt = np.full(m15.shape, thr) + thr_dict["adj"]
+    idx = np.nonzero((thr < 0.1) | (np.abs(schi - 99.0) < 0.0001))
+    midpt[idx] = thr_dict["coeffs"][idx, 0]
+    locut = midpt + (thr_dict["cmult"] * midpt)
+
+    if scene in [
+        "Land_Day",
+        "Land_Day_Coast",
+        "Land_Day_Desert",
+        "Land_Day_Desert_Coast",
+        "Ocean_Day",
+        "Ocean_Night",
+        "Polar_Day_Ocean",
+        "Polar_Night_Ocean",
+    ]:
+        hicut = midpt - thr_dict["adj"]
+
+    elif scene in [
+        "Polar_Day_Land",
+        "Polar_Day_Coast",
+        "Polar_Day_Desert",
+        "Polar_Day_Desert_Coast",
+        "Polar_Day_Snow",
+    ]:
+        hicut = midpt - (thr_dict["adj"] * midpt)
+
+    elif scene in [
+        "Land_Night",
+        "Polar_Night_Land",
+        "Polar_Night_Snow",
+        "Day_Snow",
+        "Night_Snow",
+    ]:
+        _coeffs = {
+            "Land_Night": 0.3,
+            "Polar_Night_Land": 0.3,
+            "Polar_Night_Snow": 0.3,
+            "Day_Snow": 0.8,  # this is not needed anymore, replaced by "adj" in thresholds file
+            "Night_Snow": 0.3,
+        }
+
+        if scene in ["Day_Snow"]:
+            hicut = midpt - (thr_dict["cmult"] * midpt)
+        else:
+            # This needs to be rewritten properly
+            midpt = midpt - (_coeffs[scene] * locut)
+
+            if scene in ["Polar_Night_Land", "Polar_Night_Snow", "Night_Snow"]:
+                hicut = np.full(m15.shape, midpt - 1.25)
+                idx = np.nonzero(m15 < thr_dict["bt1"])
+                hicut[idx] = midpt[idx] - (0.2 * locut[idx])
+            elif scene in ["Land_Night"]:
+                hicut = np.full(m15.shape, 1.25)
+                idx = np.nonzero((m15 < thr_dict["bt1"]) & (latitude > thr_dict["lat"]))
+                hicut[idx] = -0.1 - np.power(90.0 - np.abs(latitude[idx]) / 60, 4) * 1.15
+            else:
+                pass
+                # err_msg = "Scene name not valid"
+                # raise ValueError(err_msg)
+    else:
+        err_msg = "Scene name not valid"
+        raise ValueError(err_msg)
+
+    thr_out = np.dstack(
+        (
+            locut.reshape(data.latitude.shape),
+            midpt.reshape(data.latitude.shape),
+            hicut.reshape(data.latitude.shape),
+            np.ones(data.latitude.shape),
+            np.ones(data.latitude.shape),
+        )
+    )
+    return np.squeeze(thr_out)
+
+
 def thresholds_11_12um(
     data: xr.Dataset, thresholds: dict, scene: str, scene_idx: tuple
 ) -> np.ndarray:
@@ -718,6 +834,43 @@ def thresholds_1_38um_test(data, thresholds, scene_name, scene_idx):
     return np.squeeze(out_thr.T)
 
 
+def xr_thresholds_1_38um_test(data: xr.Dataset, thresholds: dict, scene_name: str):
+    """Compute thresholds for 1.38um test."""
+    thresh = thresholds[scene_name]["1.38um_High_Cloud_Test"]
+    c = thresh["coeffs"]
+    vzcpow = thresholds["VZA_correction"]["vzcpow"][1]
+
+    hicut = (
+        c[0]
+        + c[1] * data.solar_zenith.values
+        + c[2] * np.power(data.solar_zenith.values, 2)
+        + c[3] * np.power(data.solar_zenith.values, 3)
+        + c[4] * np.power(data.solar_zenith.values, 4)
+        + c[5] * np.power(data.solar_zenith.values, 5)
+    )
+    hicut = hicut * 0.01 + (
+        np.maximum(
+            (data.solar_zenith.values / 90.0) * thresh["szafac"] * thresh["adj"], thresh["adj"]
+        )
+    )
+    midpt = hicut + 0.001
+    locut = midpt + 0.001
+
+    cosvza = np.cos(data.sensor_zenith.values * ConstantsNamespace.DTR)
+
+    locut = locut * (1 / np.power(cosvza, vzcpow))
+    midpt = midpt * (1 / np.power(cosvza, vzcpow))
+    hicut = hicut * (1 / np.power(cosvza, vzcpow))
+
+    # out_thr = xr.DataArray(data=np.dstack((locut, midpt, hicut,
+    #                                        np.ones(data.ndvi.shape),
+    #                                        np.ones(data.ndvi.shape))),
+    #                        dims=('number_of_lines', 'number_of_pixels', 'z'))
+    out_thr = np.stack([locut, midpt, hicut, np.ones(locut.shape), np.ones(locut.shape)], axis=2)
+
+    return out_thr
+
+
 # NOTE: 11-12um Cirrus Test
 # hicut is computed in different ways depending on the scene
 # 1. midpt - adj
diff --git a/mvcm/read_data.py b/mvcm/read_data.py
index 878c9f6..704e1b2 100644
--- a/mvcm/read_data.py
+++ b/mvcm/read_data.py
@@ -4,13 +4,13 @@ import logging  # noqa - pyright complains about imports not being sorted but py
 import os
 from datetime import datetime as dt
 
+import ancillary_data as anc
 import numpy as np
 import numpy.typing as npt
 import psutil
 import xarray as xr
 from attrs import Factory, define, field, validators
 
-import ancillary_data as anc
 import mvcm.scene as scn
 from mvcm.constants import ConstantsNamespace, SensorConstants
 
@@ -268,7 +268,7 @@ class ReadData(CollectInputs):
             "M02",
             "M03",
             "M04",
-            # "M06",
+            "M06",
             "M08",
             "M09",
             "M11",
diff --git a/tests/test_conf.py b/tests/test_conf.py
index 6da43a3..9539cf8 100644
--- a/tests/test_conf.py
+++ b/tests/test_conf.py
@@ -2,10 +2,11 @@
 
 import os  # noqa
 
+import ancillary_data as anc
 import numpy as np
 import pytest
+import xarray as xr
 
-import ancillary_data as anc
 from mvcm import conf
 
 
@@ -88,3 +89,16 @@ def test_c_single_threshold(rad, single_threshold, reference_data):
     c = anc.py_conf_test(np.array(rad, dtype=np.float32), locut, hicut, power, midpt)
 
     assert np.allclose(c, ref_confidence_flipped)
+
+
+def test_xr_single_threshold(rad, single_threshold, reference_data):
+    """Test xarray version of single threshold computation."""
+    ref_confidence = np.load(reference_data)["ref_sgl"]
+    ref_confidence_flipped = np.load(reference_data)["ref_sgl_flipped"]
+
+    c = conf.asymm_confidence(xr.DataArray(rad), np.array(single_threshold))
+    assert np.allclose(c.values, ref_confidence)
+
+    single_threshold[0:-1] = single_threshold[-2::-1]
+    c = conf.asymm_confidence(xr.DataArray(rad), np.array(single_threshold))
+    assert np.allclose(c.values, ref_confidence_flipped)
-- 
GitLab
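
For context, a minimal sketch of how the new asymm_confidence entry point is meant to be
called, mirroring test_xr_single_threshold above. The radiance values and the
[locut, midpt, hicut, power] numbers below are illustrative only, not the test fixtures:

    import numpy as np
    import xarray as xr

    from mvcm import conf

    # Illustrative inputs: a tiny radiance field and one linear threshold set.
    rad = xr.DataArray(np.array([[-1.0, 0.0, 0.5], [1.0, 1.5, 3.0]]))
    thr = np.array([0.0, 1.0, 2.0, 1.0])  # [locut, midpt, hicut, power]

    c = conf.asymm_confidence(rad, thr)
    # Confidence ramps linearly from 0 at locut to 0.5 at midpt to 1 at hicut;
    # values outside [locut, hicut] are clamped by the internal np.clip calls.
    print(c.values)  # approximately [[0.0, 0.0, 0.25], [0.5, 0.75, 1.0]]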