Skip to content
Snippets Groups Projects
preprocess_thresholds.py 23.4 KiB
Newer Older
import numpy as np
import xarray as xr

import ancillary_data as anc
Paolo Veglio's avatar
Paolo Veglio committed
import utils
from numpy.lib.stride_tricks import sliding_window_view



# this case is written for the 11-12um Cirrus Test for scenes that follow pattern 1 (see note below)
def prepare_thresholds(data, thresholds):
    """Broadcast the scalar '11-12um_Cirrus_Test' thresholds to per-pixel arrays.

    Every value is replicated to the shape of ``data.M01``. 'coeffs' uses its first two
    entries (stacked on a trailing axis of size 2), and 'bt1'/'lat' default to -999 when
    they are absent from the threshold dictionary.

    NOTE(review): as it appears here this function is truncated. The ``thr_dict``
    literal below is never closed and no ``return`` follows it. From the ``# 3. Others``
    comment onwards, the code looks like the body of a separate 11-12um cirrus threshold
    function whose ``def`` line is missing: it uses ``scene`` and ``_DTR``, which are not
    defined here, and it calls ``prepare_thresholds`` itself. That code reads the result as
    ``thr_xr.coeffs``, ``thr_xr.cmult``, ``thr_xr.adj``, ``thr_xr.bt1`` and ``thr_xr.lat``,
    so presumably ``thr_dict`` was meant to be wrapped in an ``xr.Dataset`` and returned.
    TODO confirm and restore the missing lines from version control.
    """

    coeff_values = np.empty((data.M01.shape[0], data.M01.shape[1], 2))
    coeff_values[:, :, 0] = np.full(data.M01.shape, thresholds['11-12um_Cirrus_Test']['coeffs'][0])
    coeff_values[:, :, 1] = np.full(data.M01.shape, thresholds['11-12um_Cirrus_Test']['coeffs'][1])
    cmult_values = np.full(data.M01.shape, thresholds['11-12um_Cirrus_Test']['cmult'])
    adj_values = np.full(data.M01.shape, thresholds['11-12um_Cirrus_Test']['adj'])
    # Optional entries: -999 marks "not configured for this scene".
    if 'bt1' in list(thresholds['11-12um_Cirrus_Test']):
        bt1 = np.full(data.M01.shape, thresholds['11-12um_Cirrus_Test']['bt1'])
    else:
        bt1 = np.full(data.M01.shape, -999)
    if 'lat' in list(thresholds['11-12um_Cirrus_Test']):
        lat = np.full(data.M01.shape, thresholds['11-12um_Cirrus_Test']['lat'])
    else:
        lat = np.full(data.M01.shape, -999)
    thr_dict = {'coeffs': (['number_of_lines', 'number_of_pixels', 'z'], coeff_values),
                'cmult': (['number_of_lines', 'number_of_pixels'], cmult_values),
                'adj': (['number_of_lines', 'number_of_pixels'], adj_values),
                'bt1': (['number_of_lines', 'number_of_pixels'], bt1),
                'lat': (['number_of_lines', 'number_of_pixels'], lat),
                # NOTE(review): the dict literal is not closed here (missing "}" and return).
                # See the docstring.
# 3. Others
# - Land_Night
# - Polar_Night_Land
# - Polar_Night_Snow
# - Day_Snow
# - Night_Snow
    # NOTE(review): from here on the code appears to belong to a different function
    # (its def line is missing). It uses `scene`, which is not a parameter of
    # prepare_thresholds.
    # Secant of the view zenith angle; pixels where cos(vza) <= 0 get the sentinel 99.0.
    # NOTE(review): `_DTR` is not defined in the visible module (other functions use `_dtr`).
    cosvza = np.cos(data.sensor_zenith * _DTR)
    schi = (1/cosvza).where(cosvza > 0, 99.0)

    # anc.py_cithr is called on flattened arrays; its result is reshaped back to the 2-D granule.
    schi = schi.values.reshape(np.prod(schi.shape))
    m15 = data.M15.values.reshape(np.prod(data.M15.shape))
    thr = anc.py_cithr(1, schi, m15)
    thr = thr.reshape(data.M15.shape)
    schi = schi.reshape(data.M15.shape)

    # thr_xr = xr.Dataset(np.full(data.sensor_zenith.shape, thresholds['coeffs']),
    #                     dims=('number_of_lines', 'number_of_pixels'))
    thr_xr = prepare_thresholds(data, thresholds)

    # midpt keeps coeffs[0] where thr < 0.1 or schi is the 99.0 sentinel; elsewhere it uses thr.
    midpt = thr_xr.coeffs[:, :, 0].where((thr < 0.1) | (np.abs(schi-99) < 0.0001), thr)
    locut = midpt + (thr_xr.cmult * midpt)
    # hicut depends on the scene group (see the "11-12um Cirrus Test" note at the end of the file).
    if scene in ['Land_Day', 'Land_Day_Coast', 'Land_Day_Desert', 'Land_Day_Desert_Coast',
                 'Ocean_Day', 'Ocean_Night', 'Polar_Ocean_Day', 'Polar_Ocean_Night']:
        hicut = midpt - thr_xr.adj
    elif scene in ['Polar_Day_Land', 'Polar_Day_Coast', 'Polar_Day_Desert',
                   'Polar_Day_Desert_Coast', 'Polar_Day_Snow']:
        hicut = midpt - (thr_xr.adj * midpt)
    elif scene in ['Land_Night', 'Polar_Night_Land', 'Polar_Night_Snow', 'Day_Snow', 'Night_Snow']:
        # midpt is shifted by a scene-specific fraction of locut. locut is NOT recomputed.
        _coeffs = {'Land_Night': 0.3, 'Polar_Night_Land': 0.3, 'Polar_Night_Snow': 0.3,
                   'Day_Snow': 0.0, 'Night_Snow': 0.3}
        midpt = midpt - (_coeffs[scene] * locut)
        if scene in ['Polar_Night_Land', 'Polar_Night_Snow', 'Night_Snow']:
            hicut = (midpt - (0.2 * locut)).where(data.M15 < thr_xr.bt1, midpt - 1.25)
        elif scene in ['Land_Night']:
            # NOTE(review): operator precedence makes this (90 - |lat|/60)**4. Possibly
            # ((90 - |lat|)/60)**4 was intended; confirm against the reference implementation.
            hicut = -0.1 - np.power(90.0 - np.abs(data.latitude)/60, 4) * 1.15
            hicut = hicut.where((data.M15 < thr_xr.bt1) & (data.latitude > thr_xr.lat), 1.25)
        elif scene in ['Day_Snow']:
            hicut = locut - (thr_xr.cmult * locut)
        else:
            print('Scene not recognized\n')
    else:
        # NOTE(review): after this message `hicut` is unbound, so the dstack below raises NameError.
        print('Scene not recognized\n')

    # z layers: (locut, midpt, hicut, 1, 1).
    thr_out = xr.DataArray(data=np.dstack((locut, midpt, hicut, np.ones(locut.shape), np.ones(locut.shape))),
                           dims=('number_of_lines', 'number_of_pixels', 'z'))
    return thr_out
    # return locut, hicut, midpt


def preproc_sst(data, thresholds):
    """Return the difference between the surface temperature and a modelled SST.

    The modelled SST is a split-window regression on the M15/M16 temperatures (in
    Celsius), the surface temperature ``geos_sfct`` and the secant of the view zenith
    angle:

        modsst = 273.16 + a0 + a1*T15 + a2*(T15 - T16)*Tsfc + a3*(T15 - T16)*(sec(vza) - 1)

    Parameters
    ----------
    data : dataset-like
        Must provide ``M15``, ``M16``, ``geos_sfct`` (Kelvin) and ``sensor_zenith`` (degrees).
    thresholds : dict
        Must provide ``coeffs`` with at least four regression coefficients.

    Returns
    -------
    ``data.geos_sfct - modsst``, with the same type and shape as the inputs.
    """
    m31c = data.M15 - 273.16
    m32c = data.M16 - 273.16
    m31c_m32c = m31c - m32c
    sstc = data.geos_sfct - 273.16
    # np.deg2rad replaces the module constant `_DTR`. That constant is not defined in the
    # visible module (other functions spell it `_dtr`), so referencing it risks a NameError.
    cosvza = np.cos(np.deg2rad(data.sensor_zenith))

    a = thresholds['coeffs']

    modsst = 273.16 + a[0] + a[1]*m31c + a[2]*m31c_m32c*sstc + a[3]*m31c_m32c*((1/cosvza) - 1)
    sfcdif = data.geos_sfct - modsst

    return sfcdif


Paolo Veglio's avatar
Paolo Veglio committed
def preproc_nir(data, thresholds, scene):
    """Compute the per-pixel thresholds of the scene's 'NIR_Reflectance_Test'.

    hicut is a cubic polynomial of the solar zenith angle (``coeffs``), scaled by 0.01 and
    shifted by ``adj`` and ``bias``. midpt and locut are offset from it by
    ``midpt_coeff * bias`` and ``locut_coeff * bias``. Pixels flagged 1-3 by
    ``utils.sunglint_scene`` take their thresholds from ``utils.get_sunglint_thresholds``.
    The first three thresholds are then multiplied by 1 / cos(vza)**vzcpow.

    Returns
    -------
    np.ndarray of shape (lines, pixels, 4) holding (locut, midpt, hicut, thr[3]); the
    fourth value is not VZA-corrected.
    """
    sza = data.solar_zenith
    band_n = 2
    # NOTE: the visud condition in the C code is equivalent to having sza <= 85
    # For the time being the visud filtering is not implemented

    nir_thresh = thresholds[scene]['NIR_Reflectance_Test']
    a = np.array(nir_thresh['coeffs'])
    vzcpow = thresholds['VZA_correction']['vzcpow'][0]
    refang = data.sunglint_angle.values.reshape(np.prod(data.sunglint_angle.shape))
    sunglint_thresholds = thresholds['Sun_Glint']
    sunglint_flag = utils.sunglint_scene(refang, sunglint_thresholds).reshape(refang.shape)

    hicut0 = a[0] + a[1]*sza + a[2]*np.power(sza, 2) + a[3]*np.power(sza, 3)
    hicut0 = (hicut0 * 0.01) + nir_thresh['adj']
    # NOTE(review): nir_refl multiplies hicut by 'bias' at this step while this function adds
    # it. Confirm which one matches the reference implementation.
    hicut0 = (hicut0 + nir_thresh['bias']).values.reshape(refang.shape)
    midpt0 = hicut0 + (nir_thresh['midpt_coeff'] * nir_thresh['bias'])
    locut0 = midpt0 + (nir_thresh['locut_coeff'] * nir_thresh['bias'])
    thr = np.array([locut0, midpt0, hicut0, nir_thresh['thr'][3]*np.ones(refang.shape)])

    # np.deg2rad replaces `_dtr`, which is not defined in the visible module (preproc_sst
    # spells it `_DTR`). Referencing it risks a NameError.
    cosvza = np.cos(np.deg2rad(data.sensor_zenith)).values.reshape(refang.shape)
    vza_corr = 1./np.power(cosvza, vzcpow)

    corr_thr = np.zeros((4, refang.shape[0]))

    no_glint = sunglint_flag == 0
    corr_thr[:3, no_glint] = thr[:3, no_glint] * vza_corr[no_glint]
    corr_thr[3, no_glint] = thr[3, no_glint]

    for flag in range(1, 4):
        glint = sunglint_flag == flag
        if np.any(glint):
            dosgref = utils.get_sunglint_thresholds(refang, sunglint_thresholds, band_n, flag, thr)
            corr_thr[:3, glint] = dosgref[:3, glint] * vza_corr[glint]
            corr_thr[3, glint] = dosgref[3, glint]

    corr_thr = np.transpose(corr_thr.reshape((4, sza.shape[0], sza.shape[1])), (1, 2, 0))

    return corr_thr


def preproc_surf_temp(data, thresholds):
    """Compute per-pixel thresholds for the surface temperature test.

    The base threshold is ``Surface_Temperature_Test_1``. It switches to
    ``Surface_Temperature_Test_2`` when M15-M16 >= df1[0], or when M15-M16 < df1[0] and
    M15-M13 lies outside (df2[0], df2[1]). Desert pixels (``data.Desert == 1``) always
    use ``Surface_Temperature_Test_1``.

    Where M15-M16 >= df1[1], midpt is raised by twice that difference. A view-zenith term
    3 * (vza / max_vza)**4 is then added, and locut/hicut are midpt +/- 2.

    Returns
    -------
    xr.DataArray with dims (number_of_lines, number_of_pixels, z) and
    z = (locut, midpt, hicut, 1, 1).
    """
    thr_sfc1 = thresholds['Surface_Temperature_Test_1']
    thr_sfc2 = thresholds['Surface_Temperature_Test_2']
    thr_df1 = thresholds['Surface_Temperature_Test_df1']
    thr_df2 = thresholds['Surface_Temperature_Test_df2']
    max_vza = 70.13  # This value is set based on sensor. Check mask_processing_constants.h for MODIS value

    rs = np.prod(data.M15.shape)
    df1 = (data.M15 - data.M16).values.reshape(rs)
    df2 = (data.M15 - data.M13).values.reshape(rs)
    desert_flag = data.Desert.values.reshape(rs)

    thresh = np.ones((rs, )) * thr_sfc1
    use_sfc2 = (df1 >= thr_df1[0]) | ((df1 < thr_df1[0]) & ((df2 <= thr_df2[0]) | (df2 >= thr_df2[1])))
    thresh[use_sfc2] = thr_sfc2
    # Desert pixels fall back to Surface_Temperature_Test_1. The previous code used `==`,
    # a discarded comparison, so this override never took effect.
    thresh[desert_flag == 1] = thr_sfc1

    # Copy so that adjusting midpt does not also modify `thresh` through an alias.
    midpt = thresh.copy()
    large_df1 = df1 >= thr_df1[1]
    midpt[large_df1] = thresh[large_df1] + 2.0*df1[large_df1]

    # View-zenith correction: 3 * (vza / max_vza)**4.
    corr = np.power(data.sensor_zenith.values/max_vza, 4) * 3.0
    midpt = midpt.reshape(corr.shape) + corr
    locut = midpt + 2.0
    hicut = midpt - 2.0

    thr_out = xr.DataArray(data=np.dstack((locut, midpt, hicut, np.ones(locut.shape), np.ones(locut.shape))),
                           dims=('number_of_lines', 'number_of_pixels', 'z'))

    return thr_out


def var_11um(data, thresholds):
    """Count, for every pixel, the 3x3 neighbours whose M15 value is within ``dovar11``.

    The pixel itself is part of its own 3x3 window. The image is zero-padded by one pixel,
    so edge pixels are compared against 0 outside the granule.

    Returns
    -------
    np.ndarray (float64) with the same 2-D shape as ``data.M15``, holding counts in 0..9.
    """
    bt11 = data.M15.values
    nlines, npixels = bt11.shape[0], bt11.shape[1]
    tolerance = thresholds['Daytime_Ocean_Spatial_Variability']['dovar11']

    padded = np.pad(bt11, [1, 1], mode='constant')
    windows = sliding_window_view(padded, (3, 3))
    centre = bt11[:, :, np.newaxis, np.newaxis]
    within = (np.abs(windows - centre) < tolerance).reshape(nlines, npixels, 9)

    return np.count_nonzero(within, axis=2).astype(np.float64)


def get_b1_thresholds(data, thresholds):
    """Compute the land band-1 reflectance thresholds for every pixel.

    Each threshold is a cubic polynomial in the scattering angle. Its coefficients are
    tabulated per NDVI bin of width 0.1:

    * ``Coeffs_Band1_land_thresh`` (10 bins) is the land table.
    * ``Coeffs_Band8_land_thresh`` (3 bins) is the table used where ndvi < ``des_ndvi``.

    Bins are chosen by NDVI range:

    * ``ndvi_bnd1 <= ndvi < ndvi_bnd2``: the polynomial values of bin ``indvi`` and
      ``indvi + 1`` are linearly interpolated in NDVI.
    * ``ndvi >= ndvi_bnd2``: the last bin is used.
    * Otherwise: the first bin is used.

    Each threshold is increased by ``thr * adj_fac_desert`` (ndvi < des_ndvi) or
    ``thr * adj_fac_land`` and divided by 100. Pixels with ndvi >= fill_ndvi[0] or
    ndvi <= fill_ndvi[1] are set to -999.

    Parameters
    ----------
    data : dataset-like
        Needs 2-D ``ndvi`` and ``scattering_angle`` variables exposing ``.values``.
    thresholds : dict
        Needs ``Misc``, ``Coeffs_Band1_land_thresh`` (10*3*4 values) and
        ``Coeffs_Band8_land_thresh`` (3*3*4 values).

    Returns
    -------
    locut, midpt, hicut : np.ndarray
        Flattened (1-D) arrays with one value per pixel.
    """
    npix = data.ndvi.shape[0]*data.ndvi.shape[1]
    ndvi = data.ndvi.values.reshape(npix)
    sctang = data.scattering_angle.values.reshape(npix)

    # this is hardcoded in the function
    delta_ndvi_bin = 0.1

    misc = thresholds['Misc']
    des_ndvi = misc['des_ndvi']
    thr_adj_fac_desert = misc['adj_fac_desert']
    thr_adj_fac_land = misc['adj_fac_land']
    ndvi_bnd1 = misc['ndvi_bnd1']
    ndvi_bnd2 = misc['ndvi_bnd2']
    fill_ndvi = misc['fill_ndvi']

    # coeff[table, ndvi_bin, threshold_index, poly_order]; table 0 = land, 1 = desert.
    coeff1 = np.array(thresholds['Coeffs_Band1_land_thresh']).reshape(10, 3, 4)
    coeff2 = np.zeros((10, 3, 4))
    coeff2[:3, :, :] = np.array(thresholds['Coeffs_Band8_land_thresh']).reshape(3, 3, 4)
    coeff = np.stack((coeff1, coeff2))
    nbins = coeff.shape[1]

    # Integer NDVI bin index (np.int was removed in NumPy 1.24; use the builtin int).
    indvi = np.zeros(ndvi.shape, dtype=int)
    indvi[ndvi >= ndvi_bnd2] = nbins - 1
    # this is equivalent to interp=1 in the C code
    interp = (ndvi >= ndvi_bnd1) & (ndvi < ndvi_bnd2)
    # Truncate to an integer bin BEFORE computing the bin centres. Using the fractional
    # index made x1 == ndvi, so the interpolation weight x was always 0.
    indvi[interp] = (ndvi[interp]/delta_ndvi_bin - 0.5).astype(int)
    indvi[ndvi < 0] = 0
    indvi = np.clip(indvi, 0, nbins - 1)

    # x is the fractional position of ndvi between the centre of bin indvi (x1) and the
    # centre of bin indvi + 1 (x2).
    x = np.zeros(ndvi.shape)
    x1 = delta_ndvi_bin*indvi + delta_ndvi_bin/2.0
    x2 = x1 + delta_ndvi_bin
    x[interp] = (ndvi[interp] - x1[interp])/(x2[interp] - x1[interp])
    x = np.clip(x, 0, 1)

    # Desert pixels read their coefficients from the second table, both for the bin
    # itself and for its neighbour. Previously the neighbour used the land table.
    table = np.where(ndvi < des_ndvi, 1, 0)
    next_bin = np.minimum(indvi + 1, nbins - 1)
    # NOTE(review): the desert table only has 3 populated bins. A desert pixel in bin 2
    # would interpolate towards an all-zero bin 3; confirm des_ndvi keeps indvi + 1 <= 2.

    def cubic(c):
        # c has shape (npix, 4): polynomial coefficients in the scattering angle.
        return c[:, 0] + c[:, 1]*sctang + c[:, 2]*sctang**2 + c[:, 3]*sctang**3

    thr = np.empty((npix, 3))
    for i in range(3):
        y1 = cubic(coeff[table, indvi, i])
        y2 = cubic(coeff[table, next_bin, i])
        # Linear interpolation between adjacent NDVI bins. The previous formula
        # (1 - x) + (x + y2) ignored y1 entirely.
        thr[:, i] = np.where(interp, (1.0 - x)*y1 + x*y2, y1)

    adj_fac = np.where(ndvi >= des_ndvi, thr_adj_fac_land, thr_adj_fac_desert)
    thr_adj = thr * adj_fac[:, np.newaxis]

    hicut = (thr[:, 0] + thr_adj[:, 0])/100
    midpt = (thr[:, 1] + thr_adj[:, 1])/100
    locut = (thr[:, 2] + thr_adj[:, 2])/100

    fill = (ndvi >= fill_ndvi[0]) | (ndvi <= fill_ndvi[1])
    hicut[fill] = -999
    midpt[fill] = -999
    locut[fill] = -999

    return locut, midpt, hicut


def get_pn_thresholds(data, thresholds, scene, test_name):
    """Compute per-pixel thresholds for a test whose thresholds depend on the M15 BT.

    For '4-12um_BTD_Thin_Cirrus_Test' over Land_Night/Night_Snow, and for
    '7.3-11um_BTD_Mid_Level_Cloud_Test' over Land_Night, the constant ``thr`` values are
    used everywhere. Otherwise, with ``b = bt11_bounds``:

    * M15 < b[0] uses the ``low`` thresholds.
    * M15 > b[3] uses the ``high`` thresholds.
    * b[0] <= M15 <= b[3]: midpt is linearly interpolated between the ``[0]``/``[1]`` values
      of the segment containing M15, and locut/hicut = midpt +/- segment ``[2]``, with
      power = segment ``[3]``. When b[1] and b[2] are both 0 the whole range is one
      ``mid1`` segment. Otherwise it is split at b[1] and b[2] into ``mid1``/``mid2``/``mid3``.

    Pixels matching none of these (e.g. NaN BTs) are NaN.

    Returns
    -------
    xr.DataArray with dims (number_of_lines, number_of_pixels, z) and
    z = (locut, midpt, hicut, 1, power).
    """
    test_thr = thresholds[scene][test_name]
    shape = data.M15.shape

    constant_case = ((test_name == '4-12um_BTD_Thin_Cirrus_Test' and scene in ['Land_Night', 'Night_Snow']) or
                     (test_name == '7.3-11um_BTD_Mid_Level_Cloud_Test' and scene == 'Land_Night'))
    if constant_case:
        locut = test_thr['thr'][0] * np.ones(shape)
        midpt = test_thr['thr'][1] * np.ones(shape)
        hicut = test_thr['thr'][2] * np.ones(shape)
        power = test_thr['thr'][3] * np.ones(shape)
        return xr.DataArray(data=np.dstack((locut, midpt, hicut, np.ones(shape), power)),
                            dims=('number_of_lines', 'number_of_pixels', 'z'))

    bt11 = data.M15.values.reshape(shape[0]*shape[1])
    bounds = test_thr['bt11_bounds']
    b0, b1, b2, b3 = bounds[0], bounds[1], bounds[2], bounds[3]

    # NaN-initialised: the previous np.empty buffers leaked uninitialised memory into the
    # output for pixels not covered by any range (e.g. BT == b3).
    locut = np.full(bt11.shape, np.nan)
    midpt = np.full(bt11.shape, np.nan)
    hicut = np.full(bt11.shape, np.nan)
    power = np.full(bt11.shape, np.nan)

    low = bt11 < b0
    high = bt11 > b3
    mid = (bt11 >= b0) & (bt11 <= b3)

    for mask, key in ((low, 'low'), (high, 'high')):
        locut[mask] = test_thr[key][0]
        midpt[mask] = test_thr[key][1]
        hicut[mask] = test_thr[key][2]
        power[mask] = test_thr[key][3]

    # Mutually exclusive interpolation segments inside [b0, b3]. Previously:
    # - the interpolation mask also covered out-of-range pixels whenever b1 or b2 was 0;
    # - in the single-segment case the 'mid3' branch (b2 == 0) overwrote 'mid1'.
    if b1 == 0 and b2 == 0:
        segments = ((mid, b0, b3, 'mid1'),)
    else:
        segments = ((mid & (bt11 < b1), b0, b1, 'mid1'),
                    (mid & (bt11 >= b1) & (bt11 < b2), b1, b2, 'mid2'),
                    (mid & (bt11 >= b2), b2, b3, 'mid3'))

    for seg, seg_lo, seg_hi, key in segments:
        if not np.any(seg):
            continue
        lo_thr = test_thr[key][0]
        hi_thr = test_thr[key][1]
        conf_range = test_thr[key][2]
        frac = (bt11[seg] - seg_lo)/(seg_hi - seg_lo)
        seg_midpt = lo_thr + (frac*(hi_thr - lo_thr))
        midpt[seg] = seg_midpt
        hicut[seg] = seg_midpt - conf_range
        locut[seg] = seg_midpt + conf_range
        power[seg] = test_thr[key][3]

    locut = locut.reshape(shape)
    midpt = midpt.reshape(shape)
    hicut = hicut.reshape(shape)
    power = power.reshape(shape)

    out_thr = xr.DataArray(data=np.dstack((locut, midpt, hicut, np.ones(shape), power)),
                           dims=('number_of_lines', 'number_of_pixels', 'z'))

    return out_thr


def get_nl_thresholds(data, threshold):
    """Compute per-pixel thresholds from the ``nl_11_4*`` entries as a function of M15-M16.

    midpt is linearly interpolated in ``data['M15-M16']``: it equals ``nl_11_4m[0]`` at
    ``bt_diff_bounds[0]`` and ``nl_11_4m[1]`` at ``bt_diff_bounds[1]``. locut/hicut are
    midpt +/- ``nl_11_4m[2]``, and power is ``nl_11_4m[3]``.

    Pixels with M15-M16 > bt_diff_bounds[0] take the ``nl_11_4l`` thresholds. Pixels with
    M15-M16 < bt_diff_bounds[1] take ``nl_11_4h``; this is applied second, so it wins
    where both conditions hold.

    Returns
    -------
    xr.DataArray with dims (number_of_lines, number_of_pixels, z) and
    z = (locut, midpt, hicut, 1, power).
    """
    btd = data['M15-M16'].values
    lo_val = threshold['bt_diff_bounds'][0]
    hi_val = threshold['bt_diff_bounds'][1]
    lo_val_thr = threshold['nl_11_4m'][0]
    hi_val_thr = threshold['nl_11_4m'][1]
    conf_range = threshold['nl_11_4m'][2]

    a = (btd - lo_val) / (hi_val - lo_val)
    midpt = lo_val_thr + a*(hi_val_thr - lo_val_thr)
    hicut = midpt - conf_range
    locut = midpt + conf_range
    # Per-pixel array. The previous code kept a scalar here, so `power[idx] = ...` raised TypeError.
    power = np.full(btd.shape, threshold['nl_11_4m'][3], dtype=float)

    # NOTE(review): these masks are disjoint only if bt_diff_bounds[0] >= bt_diff_bounds[1].
    # Otherwise every finite pixel is overwritten; confirm the bound ordering.
    for mask, key in ((btd > lo_val, 'nl_11_4l'), (btd < hi_val, 'nl_11_4h')):
        locut[mask] = threshold[key][0]
        midpt[mask] = threshold[key][1]
        hicut[mask] = threshold[key][2]
        power[mask] = threshold[key][3]

    # dstack (not stack) so the threshold axis is last, matching the declared dims.
    out_thr = xr.DataArray(data=np.dstack((locut, midpt, hicut, np.ones(btd.shape), power)),
                           dims=('number_of_lines', 'number_of_pixels', 'z'))
    return out_thr


def vis_refl_thresholds(data, thresholds, scene):
    """Compute the visible reflectance test thresholds and the reflectance to test.

    The land thresholds from ``get_b1_thresholds`` are multiplied by the scene's
    ``Visible_Reflectance_Test['adj']``. In the desert scenes, pixels with
    ndvi < ``ndvi_thr`` keep the unadjusted values and test M01 instead of M05. Pixels where
    ``get_b1_thresholds`` returned the -999 fill use the scene's fixed ``thr`` values and
    test M07. All three cut-offs are multiplied by 1 / cos(vza)**vzcpow.

    Returns
    -------
    out_thr : xr.DataArray (number_of_lines, number_of_pixels, z), z = (locut, midpt, hicut, 1, power)
    out_rad : xr.DataArray (number_of_lines, number_of_pixels) with the selected reflectance
    """
    locut, midpt, hicut = get_b1_thresholds(data, thresholds)
    vis_thr = thresholds[scene]['Visible_Reflectance_Test']
    bias_adj = vis_thr['adj']
    npix = data.ndvi.shape[0]*data.ndvi.shape[1]
    ndvi = data.ndvi.values.reshape(npix)
    m01 = data.M05.values.reshape(npix)
    m02 = data.M07.values.reshape(npix)
    m08 = data.M01.values.reshape(npix)

    # Copy. Reshaping `.values` can return a view of the dataset's own buffer, so writing
    # into an alias of m01 (as before) silently overwrote data.M05 in the caller's dataset.
    m128 = m01.copy()

    b1_locut = locut * bias_adj
    b1_midpt = midpt * bias_adj
    b1_hicut = hicut * bias_adj

    if scene in ('Land_Day_Desert', 'Land_Day_Desert_Coast'):
        low_ndvi = ndvi < vis_thr['ndvi_thr']
        b1_locut[low_ndvi] = locut[low_ndvi]
        b1_midpt[low_ndvi] = midpt[low_ndvi]
        b1_hicut[low_ndvi] = hicut[low_ndvi]
        m128[low_ndvi] = m08[low_ndvi]

    # Float array so a non-integer power from the configuration is not truncated.
    b1_power = np.full(b1_locut.shape, 2.0)

    fill = locut == -999
    b1_locut[fill] = vis_thr['thr'][0]
    b1_midpt[fill] = vis_thr['thr'][1]
    b1_hicut[fill] = vis_thr['thr'][2]
    b1_power[fill] = vis_thr['thr'][3]
    m128[fill] = m02[fill]

    # np.deg2rad replaces `_dtr`, which is not defined in the visible module (preproc_sst
    # spells it `_DTR`). Referencing it risks a NameError.
    cosvza = np.cos(np.deg2rad(data.sensor_zenith.values)).reshape(ndvi.shape)

    vzcpow = thresholds['VZA_correction']['vzcpow'][0]
    vza_corr = 1.0 / np.power(cosvza, vzcpow)
    b1_locut = (b1_locut * vza_corr).reshape(data.ndvi.shape)
    b1_midpt = (b1_midpt * vza_corr).reshape(data.ndvi.shape)
    b1_hicut = (b1_hicut * vza_corr).reshape(data.ndvi.shape)

    out_thr = xr.DataArray(data=np.dstack((b1_locut, b1_midpt, b1_hicut, np.ones(data.ndvi.shape),
                                           b1_power.reshape(data.ndvi.shape))),
                           dims=('number_of_lines', 'number_of_pixels', 'z'))
    out_rad = xr.DataArray(data=m128.reshape(data.M01.shape), dims=('number_of_lines', 'number_of_pixels'))

    return out_thr, out_rad
def nir_refl(data, thresholds, scene_name):
    """Compute the per-pixel thresholds of the scene's '1.6_2.1um_NIR_Reflectance_Test'.

    hicut is a quartic polynomial of the solar zenith angle (``coeffs``), scaled by 0.01,
    shifted by ``adj`` and multiplied by ``bias``. midpt and locut are offset from it by
    ``midpt_coeff * bias`` and ``locut_coeff * bias``. Pixels flagged 1-3 by
    ``utils.sunglint_scene`` take their thresholds from ``utils.get_sunglint_thresholds``.
    The first three thresholds are multiplied by 1 / cos(vza)**vzcpow.

    Returns
    -------
    np.ndarray of shape (lines, pixels, 4) holding (locut, midpt, hicut, 1); the fourth
    value comes from the sun-glint thresholds for glint pixels and is not VZA-corrected.
    """
    sza = data.solar_zenith.values
    band_n = 2

    thr_nir = thresholds[scene_name]['1.6_2.1um_NIR_Reflectance_Test']
    coeffs = thr_nir['coeffs']
    vzcpow = thresholds['VZA_correction']['vzcpow'][0]

    refang = data.sunglint_angle.values.reshape(np.prod(data.sunglint_angle.shape))

    sunglint_thresholds = thresholds['Sun_Glint']
    sunglint_flag = utils.sunglint_scene(refang, sunglint_thresholds).reshape(refang.shape)

    hicut0 = np.array(coeffs[0] + coeffs[1]*sza + coeffs[2]*np.power(sza, 2) +
                      coeffs[3]*np.power(sza, 3) + coeffs[4]*np.power(sza, 4))
    hicut0 = (hicut0*0.01) + thr_nir['adj']
    hicut0 = (hicut0*thr_nir['bias']).reshape(refang.shape)
    midpt0 = hicut0 + (thr_nir['midpt_coeff']*thr_nir['bias'])
    locut0 = midpt0 + (thr_nir['locut_coeff']*thr_nir['bias'])
    thr = np.array([locut0, midpt0, hicut0, np.ones(refang.shape)])

    # np.deg2rad replaces `_dtr`, which is not defined in the visible module (preproc_sst
    # spells it `_DTR`). Referencing it risks a NameError.
    cosvza = np.cos(np.deg2rad(data.sensor_zenith)).values.reshape(refang.shape)
    vza_corr = 1./np.power(cosvza, vzcpow)

    corr_thr = np.zeros((4, refang.shape[0]))

    no_glint = sunglint_flag == 0
    corr_thr[:3, no_glint] = thr[:3, no_glint] * vza_corr[no_glint]
    corr_thr[3, no_glint] = thr[3, no_glint]

    for flag in range(1, 4):
        glint = sunglint_flag == flag
        if np.any(glint):
            dosgref = utils.get_sunglint_thresholds(refang, sunglint_thresholds, band_n, flag, thr)
            corr_thr[:3, glint] = dosgref[:3, glint] * vza_corr[glint]
            corr_thr[3, glint] = dosgref[3, glint]

    corr_thr = np.transpose(corr_thr.reshape((4, sza.shape[0], sza.shape[1])), (1, 2, 0))

    return corr_thr


def GEMI_test(data, thresholds, scene_name):
    """Return the GEMI test thresholds, selected per pixel by NDVI bin.

    Pixels are assigned by NDVI bin; each bin takes the first three values of its entry
    as (locut, midpt, hicut):

    * ndvi < 0.1 uses ``gemi0``.
    * 0.1 <= ndvi < 0.2 uses ``gemi1``.
    * 0.2 <= ndvi < 0.3 uses ``gemi2``.

    Every other pixel keeps 1.0, and the last two z-layers are always 1.0.
    """
    gemi_cfg = thresholds[scene_name]['GEMI_Test']
    out = np.ones((data.M01.shape[0], data.M01.shape[1], 5))

    ndvi = np.asarray(data.ndvi)
    ndvi_bins = (
        (ndvi < 0.1, 'gemi0'),
        ((ndvi >= 0.1) & (ndvi < 0.2), 'gemi1'),
        ((ndvi >= 0.2) & (ndvi < 0.3), 'gemi2'),
    )
    for in_bin, key in ndvi_bins:
        out[in_bin, :3] = gemi_cfg[key][:3]

    return xr.DataArray(data=out, dims=('number_of_lines', 'number_of_pixels', 'z'))


def bt11_4um_preproc(data, thresholds, scene_name):
    """Build the two-sided 11-4um BT difference thresholds from the total precipitable water.

    A quadratic in ``geos_tpw`` plus ``corr`` gives the base threshold. Each side then adds
    its ``hicut_coeff`` entry to get hicut, and midpt/locut add the matching
    ``midpt_coeff``/``locut_coeff`` entry to that hicut.

    Returns
    -------
    np.ndarray of shape (8, npixels) with rows
    (hicut0, midpt0, locut0, locut1, midpt1, hicut1, 1, 1).
    """
    cfg = thresholds[scene_name]['11-4um_BT_Difference_Test']
    c = cfg['coeffs']
    tpw = data.geos_tpw.values
    npixels = np.prod(tpw.shape)

    base = c[0] + cfg['corr'] + c[1]*tpw + c[2]*np.power(tpw, 2)
    base = base.reshape(1, npixels)

    side0_hi = base + cfg['hicut_coeff'][0]
    side1_hi = base + cfg['hicut_coeff'][1]

    rows = [
        side0_hi,
        side0_hi + cfg['midpt_coeff'][0],
        side0_hi + cfg['locut_coeff'][0],
        side1_hi + cfg['locut_coeff'][1],
        side1_hi + cfg['midpt_coeff'][1],
        side1_hi,
        np.ones(side0_hi.shape),
        np.ones(side0_hi.shape),
    ]
    return np.vstack(rows)


def test_1_38um_preproc(data, thresholds, scene_name):
    """Compute the per-pixel thresholds of the scene's '1.38um_High_Cloud_Test'.

    hicut is a 5th-order polynomial of the solar zenith angle (``coeffs``) scaled by 0.01,
    plus max((sza/90)*szafac*adj, adj). midpt = hicut + 0.001 and locut = midpt + 0.001.
    All three are multiplied by 1 / cos(vza)**vzcpow[1].

    Returns
    -------
    xr.DataArray with dims (number_of_lines, number_of_pixels, z) and
    z = (locut, midpt, hicut, 1, 1).
    """
    sza = data.solar_zenith.values
    vza = data.sensor_zenith.values
    thresh = thresholds[scene_name]['1.38um_High_Cloud_Test']
    c = thresh['coeffs']
    vzcpow = thresholds['VZA_correction']['vzcpow'][1]

    hicut0 = c[0] + c[1]*sza + c[2]*np.power(sza, 2) + c[3]*np.power(sza, 3) + \
        c[4]*np.power(sza, 4) + c[5]*np.power(sza, 5)
    hicut0 = hicut0*0.01 + (np.maximum((sza/90.)*thresh['szafac']*thresh['adj'], thresh['adj']))
    midpt0 = hicut0 + 0.001
    locut0 = midpt0 + 0.001

    # np.deg2rad replaces `_dtr`, which is not defined in the visible module (preproc_sst
    # spells it `_DTR`). Referencing it risks a NameError.
    vza_corr = 1/np.power(np.cos(np.deg2rad(vza)), vzcpow)

    locut0 = locut0 * vza_corr
    midpt0 = midpt0 * vza_corr
    hicut0 = hicut0 * vza_corr

    out_thr = xr.DataArray(data=np.dstack((locut0, midpt0, hicut0, np.ones(locut0.shape),
                                           np.ones(locut0.shape))),
                           dims=('number_of_lines', 'number_of_pixels', 'z'))
    return out_thr

# NOTE: 11-12um Cirrus Test
# hicut is computed in different ways depending on the scene
# 1. midpt - adj
# - Land_Day
# - Land_Day_Coast
# - Land_Day_Desert
# - Land_Day_Desert_Coast
# - Ocean_Day
# - Ocean_Night
# - Polar_Ocean_Day
# - Polar_Ocean_Night
#
# 2. midpt - (btd_thr * adj)
# - Polar_Day_Land
# - Polar_Day_Coast
# - Polar_Day_Desert
# - Polar_Day_Desert_Coast
# - Polar_Day_Snow
#
# 3. Others
# - Land_Night
# - Polar_Night_Land
# - Polar_Night_Snow
# - Day_Snow
# - Night_Snow

# NOTE: 1.38um High Cloud Test
# The thresholds are not always computed the same way: scenes in group 1 need no
# preprocessing, while scenes in group 2 require some additional calculations.
# 1.
# - Land_Day
# - Land_Day_Coast
# - Land_Day_Desert
# - Land_Day_Desert_Coast
# - Polar_Day_Land
# - Polar_Day_Coast
# - Polar_Day_Desert
# - Polar_Day_Desert_Coast
# - Polar_Day_Snow
# - Day_Snow
#
# 2.
# - Ocean_Day
# - Polar_Ocean_Day