Skip to content
Snippets Groups Projects
tests.py 11.84 KiB
import numpy as np
import xarray as xr

from numpy.lib.stride_tricks import sliding_window_view

import utils
import conf
import conf_xr
import preprocess_thresholds as pt

import importlib

# this is used for testing, eventually we want to remove it
importlib.reload(pt)


# ############## GROUP 1 TESTS ############## #

def test_11um(rad, threshold):

    radshape = rad.shape
    rad = rad.reshape(np.prod(radshape))

    thr = np.array(threshold['bt11'])
    confidence = np.zeros(rad.shape)

    if thr[4] == 1:
        print("11um test running")
        # the C code has the line below that I don't quite understand the purpose of.
        # It seems to be setting the bit to 0 if the BT value is greater than the midpoint
        #
        # if (m31 >= dobt11[1]) (void) set_bit(13, pxout.testbits);

        # confidence = utils.conf_test(rad, thr)
        confidence = conf.conf_test(rad, thr)

    return confidence.reshape(radshape)


def test_11um_var(rad, threshold, var_threshold):
    """11 um spatial-variability (uniformity) test.

    A pixel is evaluated only when every cell of its 3x3 neighbourhood,
    centre included, differs from it by less than
    ``var_threshold['dovar11']`` in absolute value. ``rad`` is zero-padded
    by one cell, so border pixels compare their out-of-range neighbours
    against 0. Evaluated pixels receive
    ``conf.conf_test(value, threshold['11um_var'])``; all others get 0.

    Parameters
    ----------
    rad : numpy.ndarray
        2-D field (a 3x3 window with no explicit axis needs 2 dimensions).
    threshold : mapping
        Provides the ``'11um_var'`` confidence coefficients.
    var_threshold : mapping
        Provides ``'dovar11'``, the maximum neighbour difference.

    Returns
    -------
    numpy.ndarray
        Confidence values with the shape of ``rad``.
    """
    print("11um variability test running")
    coeffs = np.array(threshold['11um_var'])

    nrows, ncols = rad.shape[0], rad.shape[1]
    # TODO: the C code uses chk_spatial2() here (with
    # rg_var.num_small_diffs); its exact semantics still need checking.
    neighbours = sliding_window_view(np.pad(rad, [1, 1], mode='constant'), (3, 3))
    abs_diff = np.abs(neighbours - np.expand_dims(rad, (2, 3))).reshape(nrows, ncols, 9)
    n_close = (abs_diff < var_threshold['dovar11']).sum(axis=2).reshape(np.prod(rad.shape))

    flat_rad = rad.reshape(np.prod(rad.shape))
    confidence = np.zeros(flat_rad.shape)
    uniform = n_close == 9  # all nine neighbourhood cells are "close"
    confidence[uniform] = conf.conf_test(flat_rad[uniform], coeffs)

    return confidence.reshape(rad.shape)


def test_11_4diff(rad1, rad2, threshold, viirs_data, sg_thresh):
    """11um - 4um brightness-temperature difference test.

    The difference ``rad1 - rad2`` is tested only on pixels where
    ``viirs_data.solar_zenith <= 85`` and where
    ``viirs_data.sunglint_angle <= sg_thresh`` does NOT hold. All other
    pixels get confidence 0. Coefficients come from
    ``threshold['test11_4lo']``; no enable flag is checked here.

    Returns
    -------
    numpy.ndarray
        Confidence values with the shape of ``rad1``.
    """
    print("11um - 4um difference test running")
    grid_shape = rad1.shape
    btd = (rad1 - rad2).reshape(np.prod(grid_shape))

    daytime = np.zeros(grid_shape)
    daytime[viirs_data.solar_zenith <= 85] = 1
    glint = np.zeros(rad1.shape)
    glint[viirs_data.sunglint_angle <= sg_thresh] = 1
    coeffs = np.array(threshold['test11_4lo'])

    usable = (daytime.reshape(btd.shape) == 1) & (glint.reshape(btd.shape) == 0)
    confidence = np.zeros(btd.shape)
    confidence[usable] = conf.conf_test(btd[usable], coeffs)

    return confidence.reshape(grid_shape)


def vir_refl_test(rad, threshold, viirs_data):
    """Visible reflectance test.

    Thresholds come from ``utils.get_b1_thresholds()``. Their first three
    columns are scaled by ``threshold['b1_bias_adj']`` and then divided by
    ``cos(sensor_zenith) ** vzcpow``. The flattened ``rad`` is evaluated
    with ``conf.conf_test`` against them.

    Parameters
    ----------
    rad : numpy.ndarray
        Reflectance field; the result is reshaped back to ``rad.shape``.
    threshold : mapping
        Provides ``'vis_refl_test'`` and ``'b1_bias_adj'``.
    viirs_data :
        Provides ``sensor_zenith`` (with ``.values``) in degrees.

    Returns
    -------
    numpy.ndarray
        Confidence values with the shape of ``rad``.
    """
    print('Visible reflectance test running')

    thr = threshold['vis_refl_test']

    # ``shape`` is an attribute (a tuple), not a method.
    radshape = rad.shape
    rad = rad.reshape(np.prod(radshape))
    vzcpow = 0.75  # TODO: read this from the thresholds file

    vza = viirs_data.sensor_zenith.values
    dtr = np.pi/180
    cosvza = np.cos(vza*dtr)

    coeffs = utils.get_b1_thresholds()
    coeffs[:, :3] = coeffs[:, :3] * threshold['b1_bias_adj']

    # irtn mirrors the return code of get_b1_thresholds() in the C code.
    # It is kept so the control flow matches the original source for now.
    irtn = 0

    if irtn != 0:
        coeffs = thr

    # NOTE(review): cosvza keeps the (unflattened) shape of sensor_zenith
    # while rad has been flattened. Confirm that the array returned by
    # get_b1_thresholds() broadcasts against it as intended.
    coeffs[:, :3] = coeffs[:, :3] * 1/np.power(cosvza, vzcpow)

    confidence = conf.conf_test(rad, coeffs)

    return confidence.reshape(radshape)


def nir_refl_test(rad, threshold, sunglint_thresholds, viirs_data):
    """NIR reflectance test with solar- and view-zenith dependent thresholds.

    The high cut-off is a cubic polynomial in solar zenith built from
    ``threshold['b2coeffs']``. It is scaled by 0.01, offset by ``b2adj``
    and multiplied by ``b2bias_adj``. The midpoint and low cut-off are
    successive offsets by ``b2mid`` and ``b2lo`` (each times
    ``b2bias_adj``). Pixels with sunglint flag 0 use these thresholds.
    Pixels with flag 1-3 use ``utils.get_sunglint_thresholds`` instead.
    In both cases the first three threshold rows are divided by
    ``cos(VZA) ** 0.75``.

    Parameters
    ----------
    rad : numpy.ndarray
        Reflectance field; the result is reshaped back to ``rad.shape``.
    threshold : mapping
        Needs ``b2coeffs``, ``b2adj``, ``b2bias_adj``, ``b2mid``, ``b2lo``
        and ``ref2`` (element 3 becomes the fourth threshold row).
    sunglint_thresholds :
        Forwarded to ``utils.sunglint_scene`` and
        ``utils.get_sunglint_thresholds``.
    viirs_data :
        Provides ``solar_zenith``, ``sunglint_angle`` and ``sensor_zenith``
        (each via ``.values``); angles are in degrees.

    Returns
    -------
    numpy.ndarray
        Confidence values with the shape of ``rad``.
    """
    print("NIR reflectance test running")
    sza = viirs_data.solar_zenith.values
    refang = viirs_data.sunglint_angle.values
    vza = viirs_data.sensor_zenith.values
    deg2rad = np.pi/180
    # band_n uses MODIS band numbering (2 = 0.86 um, 7 = 2.1 um); the VIIRS
    # counterparts would be 2 = M07 (0.865 um) and 7 = M11 (2.25 um).
    band_n = 2
    vzcpow = 0.75  # TODO: read this from the thresholds file

    out_shape = rad.shape
    flat_rad = rad.reshape(np.prod(out_shape))
    npix = flat_rad.shape
    sza = sza.reshape(npix)
    vza = vza.reshape(npix)
    refang = refang.reshape(npix)
    glint = utils.sunglint_scene(refang, sunglint_thresholds).reshape(npix)

    # Threshold entries involved (bracketed sizes as noted for the C port):
    # ref2 [5], b2coeffs [4], b2mid [1], b2bias_adj [1], b2lo [1],
    # vzcpow [3] (stored elsewhere).
    cosvza = np.cos(vza*deg2rad)
    c = threshold['b2coeffs']
    hicut0 = np.array(c[0] + c[1]*sza + c[2]*np.power(sza, 2) + c[3]*np.power(sza, 3))
    hicut0 = (hicut0 * 0.01) + threshold['b2adj']
    bias = threshold['b2bias_adj']
    hicut0 = hicut0 * bias
    midpt0 = hicut0 + (threshold['b2mid'] * bias)
    locut0 = midpt0 + (threshold['b2lo'] * bias)
    base_thr = np.array([locut0, midpt0, hicut0, threshold['ref2'][3]*np.ones(npix)])

    corr_thr = np.zeros((4, npix[0]))

    no_glint = glint == 0
    corr_thr[:3, no_glint] = base_thr[:3, no_glint] * (1./np.power(cosvza[no_glint], vzcpow))
    corr_thr[3, no_glint] = base_thr[3, no_glint]

    for flag in (1, 2, 3):
        in_flag = glint == flag
        if not in_flag.any():
            continue
        sg_thr = utils.get_sunglint_thresholds(refang, sunglint_thresholds, band_n, flag, base_thr)
        corr_thr[:3, in_flag] = sg_thr[:3, in_flag] * (1./np.power(cosvza[in_flag], vzcpow))
        corr_thr[3, in_flag] = sg_thr[3, in_flag]

    return conf.conf_test(flat_rad, corr_thr).reshape(out_shape)


def vis_nir_ratio_test(rad1, rad2, threshold, sg_threshold):
    """NIR / visible reflectance-ratio test.

    Parameters
    ----------
    rad1, rad2 : numpy.ndarray
        Arrays of equal shape; the tested ratio is ``rad2 / rad1``.
    threshold : mapping
        ``threshold['vis_nir_ratio']`` holds the coefficients passed to
        ``conf.conf_test_dble``. Its element 6 is the on/off switch (the
        test runs only when it equals 1).
    sg_threshold : mapping
        Sunglint thresholds; ``sg_threshold['snglnt']`` is intended for
        sunglint pixels once sunglint detection is wired in.

    Returns
    -------
    numpy.ndarray
        Confidence values with the shape of ``rad1``. When the test is
        switched off this is an all-zero array, as for the other tests here.
    """
    print("NIR-Visible ratio test running")
    radshape = rad1.shape
    if threshold['vis_nir_ratio'][6] != 1:
        return np.zeros(radshape)

    rad1 = rad1.reshape(np.prod(radshape))
    rad2 = rad2.reshape(np.prod(radshape))

    vrat = rad2/rad1

    # TODO: `sunglint` is a placeholder until a per-pixel sunglint mask is
    # available. The test should then run in two blocks:
    #   confidence[sunglint == 1] = conf.conf_test_dble(vrat[sunglint == 1], sg_threshold['snglnt'])
    #   confidence[sunglint == 0] = conf.conf_test_dble(vrat[sunglint == 0], threshold['vis_nir_ratio'])
    sunglint = 0
    if sunglint:
        thresh = sg_threshold['snglnt']
    else:
        thresh = threshold['vis_nir_ratio']

    confidence = conf.conf_test_dble(vrat, thresh)

    return confidence.reshape(radshape)


class CloudMaskTests:
    """Skeleton for a grouped cloud-mask test driver (not implemented yet).

    Only ``scene`` and ``coefficients`` are stored; ``radiance`` is accepted
    by the constructor but currently unused. Every other method is a
    placeholder that returns ``None``.
    """

    def __init__(self, scene, radiance, coefficients):
        self.scene = scene
        self.coefficients = coefficients

    def select_coefficients(self):
        """Placeholder: choose the coefficients for the scene."""

    def test_G1(self):
        """Placeholder for the group-1 tests."""

    def test_G2(self):
        """Placeholder for the group-2 tests."""

    def test_G3(self):
        """Placeholder for the group-3 tests."""

    def test_G4(self):
        """Placeholder for the group-4 tests."""

    def overall_confidence(self):
        """Placeholder: combine the group results into one confidence."""


# old class, doesn't use xarray much
class CloudTests_old:
    """Older, mostly numpy-based per-scene cloud-test driver.

    Parameters
    ----------
    scene_ids : mapping
        ``scene_ids[scene_name] == 1`` marks the pixels belonging to the
        scene.
    scene_name : str
        Key of the scene in ``scene_ids`` and ``thresholds``.
    thresholds : mapping
        ``thresholds[scene_name][test_name]`` is the coefficient sequence
        of a test; its element 4 is the enable flag.
    """

    def __init__(self, scene_ids, scene_name, thresholds):
        self.scene = scene_ids
        self.scene_name = scene_name
        # Tuple of index arrays (one per dimension of the scene mask)
        # selecting this scene's pixels.
        self.idx = np.where(scene_ids[scene_name] == 1)
        self.threshold = thresholds[scene_name]

    def single_threshold_test(self, test_name, rad, cmin):
        """Run a single-threshold test over this scene's pixels.

        ``cmin`` is updated in place over the scene pixels with
        ``np.minimum(cmin, confidence)`` and also returned.

        ``rad`` is indexed with ``self.idx`` in its original layout, the
        same way as ``confidence`` and ``cmin``. ``self.idx`` carries one
        index array per mask dimension, so a flattened ``rad`` cannot be
        indexed with it when the data are 2-D.
        """
        thr = np.array(self.threshold[test_name])
        confidence = np.zeros(rad.shape)

        if thr[4] == 1:
            print('test running')
            confidence[self.idx] = conf.conf_test(rad[self.idx], thr)

        # NOTE(review): when the test is disabled confidence stays 0, so this
        # still lowers cmin to <= 0 over the scene; confirm that is intended.
        cmin[self.idx] = np.minimum(cmin[self.idx], confidence[self.idx])
        return cmin

    def double_threshold_test(self):
        """Placeholder: double-threshold tests are not implemented."""


# new class to try to use xarray more extensively
class CloudTests:
    """Run per-scene cloud tests on dataset-style input.

    Parameters
    ----------
    data :
        Band container indexed by band name (``data[band].shape``) and
        wrapped with ``xr.Dataset(data, coords=...)``. Presumably an
        ``xarray.Dataset``; TODO confirm.
    scene_name : str
        Key selecting this scene's entry in ``thresholds``.
    thresholds : mapping
        ``thresholds[scene_name][test_name]`` is either the coefficient
        sequence itself or a mapping holding it under ``'thr'``.
    """

    def __init__(self, data, scene_name, thresholds):
        self.data = data
        self.scene_name = scene_name
        self.thresholds = thresholds

    def single_threshold_test(self, test_name, band, cmin):
        """Run ``test_name`` on ``band`` and fold the result into ``cmin``.

        Returns ``cmin`` unchanged when ``band == 'bad_data'`` or when the
        test's enable flag (element 4 of its coefficients) is not 1.
        Otherwise returns ``np.fmin(cmin, confidence)``.
        """
        if band == 'bad_data':
            return cmin
        print(f'Running test "{test_name}" for "{self.scene_name}"')

        scene_thr = self.thresholds[self.scene_name]
        test_thr = scene_thr[test_name]
        if 'thr' in test_thr:
            thr = np.array(test_thr['thr'])
        else:
            thr = np.array(test_thr)

        grid_dims = ('number_of_lines', 'number_of_pixels', 'z')
        thr_xr = xr.Dataset()
        if test_name == '11-12um_Cirrus_Test':
            thr_xr['threshold'] = pt.preproc(self.data, scene_thr)
            # Temporary: force the enable flag to 1 so the test runs.
            # TODO: find a cleaner mechanism.
            thr = np.ones((5,))
        elif test_name == 'Surface_Temperature_Test':
            thr_xr['threshold'] = pt.preproc_surf_temp(self.data, scene_thr)
            thr = np.ones((5,))
        elif test_name == 'NIR_Reflectance_Test':
            corr_thr = pt.preproc_nir(self.data, self.thresholds, self.scene_name)
            thr_xr['threshold'] = (grid_dims, corr_thr)
        else:
            # Constant coefficients (SST_Test included) broadcast to every pixel.
            band_shape = self.data[band].shape
            thr_xr['threshold'] = (grid_dims,
                                   np.ones((band_shape[0], band_shape[1], 5))*thr)
        data = xr.Dataset(self.data, coords=thr_xr)

        if test_name == 'SST_Test':
            # The SST test is evaluated on the field computed by
            # pt.preproc_sst, stored as 'sfcdif'.
            data['sfcdif'] = (('number_of_lines', 'number_of_pixels'),
                              pt.preproc_sst(data, test_thr).values)
            band = 'sfcdif'

        if thr[4] != 1:
            # Test disabled for this scene: leave cmin untouched (as for
            # 'bad_data') instead of using an undefined confidence.
            return cmin

        print('test running...')
        confidence = conf_xr.conf_test(data, band)
        return np.fmin(cmin, confidence)


def preproc_thresholds(thresholds, data):
    """Build a per-pixel threshold Dataset and attach it to ``data`` as coords.

    NOTE(review): this function looks unfinished; see the inline notes.

    Parameters
    ----------
    thresholds :
        Indexed below as ``thresholds['Land_Night'][...]``, so presumably
        the nested scene -> test -> coefficients mapping; TODO confirm.
    data :
        Must provide ``'M01'``, ``desert``, ``'M15-M16'`` and ``'M15-M13'``.
        Presumably an xarray.Dataset; TODO confirm.

    Returns
    -------
    xarray.Dataset
        ``xr.Dataset(data, coords=thr_xr)``.
    """
    # NOTE(review): if ``thresholds`` is a dict, np.array() wraps it in a 0-d
    # object array and the multiplication below will presumably fail;
    # confirm what is meant to be passed here.
    thr = np.array(thresholds)
    thr_xr = xr.Dataset()
    # NOTE(review): the variable is spelled 'tresholds'; left as-is in case
    # callers read it under that name.
    thr_xr['tresholds'] = (('number_of_lines', 'number_of_pixels', 'z'),
                           np.ones((data['M01'].shape[0], data['M01'].shape[1], 5))*thr)

    # Land_Night surface-temperature threshold (element 0 of the test list).
    nl_sfct1 = thresholds['Land_Night']['Surface_Temperature_Test'][0]
#    nl_sfct2 = thresholds['Land_Night']['Surface_Temperature_Test'][1]
#    nlsfct_pfm = thresholds['Land_Night']['Surface_Temperature_Test'][2]
    # First two / remaining elements of the difference-test thresholds.
    nl_df1 = thresholds['Land_Night']['Surface_Temperature_Test_difference'][0:2]
    nl_df2 = thresholds['Land_Night']['Surface_Temperature_Test_difference'][2:]

#    df1 = data.M15 - data.M16
#    df2 = data.M15 - data.M13
    # Keep values where desert != 1; where desert == 1 use nl_sfct1.
    thr_xr = thr_xr.where(data.desert != 1, nl_sfct1)
    # Keep values where M15-M16 > nl_df1[0], or where M15-M16 < nl_df1[0]
    # and M15-M13 lies outside (nl_df2[0], nl_df2[1]); elsewhere use
    # nl_sfct1[0].
    # NOTE(review): nl_sfct1 is already element [0] of the test list, so
    # nl_sfct1[0] only works if that element is itself indexable; confirm.
    thr_xr = thr_xr.where((data['M15-M16'] > nl_df1[0]) |
                          ((data['M15-M16'] < nl_df1[0]) &
                           ((data['M15-M13'] <= nl_df2[0]) | (data['M15-M13'] >= nl_df2[1]))),
                          nl_sfct1[0])

    data = xr.Dataset(data, coords=thr_xr)

    return data


def single_threshold_test(test, rad, threshold):
    """Run the generic single-threshold confidence test named ``test``.

    Parameters
    ----------
    test : str
        Key of the coefficient sequence in ``threshold``.
    rad : numpy.ndarray
        Input field; flattened for the test, result reshaped to
        ``rad.shape``.
    threshold : mapping
        ``threshold[test]`` holds the coefficients. Element 4 is the enable
        flag (the test runs only when it equals 1).

    Returns
    -------
    numpy.ndarray
        Confidence values with the shape of ``rad``; all zeros when the
        test is disabled.
    """
    shape_in = rad.shape
    values = rad.reshape(np.prod(shape_in))
    coeffs = np.array(threshold[test])

    if coeffs[4] != 1:
        return np.zeros(values.shape).reshape(shape_in)

    print(f"{test} test running")
    # The C implementation additionally executes
    #     if (m31 >= dobt11[1]) (void) set_bit(13, pxout.testbits);
    # (touching test bit 13 when the value is at/above dobt11[1]). Its
    # purpose is still unclear and it is not reproduced here.
    return conf.conf_test(values, coeffs).reshape(shape_in)


def test():
    """Manual smoke test: run ``test_11um`` on random integer data.

    Prints the random input array and the resulting confidence array.
    """
    rad = np.random.randint(50, size=[4, 8])
    # coeffs = [5, 42, 20, 28, 15, 35, 1]
    # coeffs = [20, 28, 5, 42, 15, 35, 1]
    coeffs = [35, 15, 20, 1, 1]
    # confidence = conf_test_dble(rad, coeffs)
    # test_11um() looks its coefficients up under the 'bt11' key, so they
    # must be wrapped in a mapping; a bare list raises TypeError there.
    confidence = test_11um(rad, {'bt11': coeffs})
    print(rad)
    print('\n')
    print(confidence)


if __name__ == "__main__":
    # Run the manual smoke test only when executed as a script, not on import.
    test()