import numpy as np
import xarray as xr

# from numpy.lib.stride_tricks import sliding_window_view
from typing import Dict

import functools

# import utils
import conf
import conf_xr
import scene as scn
import preprocess_thresholds as preproc
import restoral

import importlib

# Angle conversion factors.
_RTD = 180./np.pi  # multiply radians by this to get degrees
_DTR = np.pi/180  # multiply degrees by this to get radians

# Reload the project modules below every time this file is imported, so that
# edits to them are picked up without restarting the interpreter.
# this is used for testing, eventually we want to remove it
importlib.reload(preproc)
importlib.reload(conf)
importlib.reload(restoral)


class CloudTests(object):

    def __init__(self,
                 data: xr.Dataset,
                 scene_name: str,
                 thresholds: Dict) -> None:
        self.data = data
        self.scene_name = scene_name
        self.thresholds = thresholds
        self.scene_idx = tuple(np.nonzero(data[scene_name] == 1))
        if self.scene_idx[0].shape[0] == 0:
            self.pixels_in_scene = False
        else:
            self.pixels_in_scene = True

    def run_if_test_exists_for_scene(func):
        @functools.wraps(func)
        def wrapper(self, *args, test_name, **kwargs):
            if test_name not in self.thresholds[self.scene_name]:
                # print('Not running test for this scene')
                # returns cmin. This could be changed into a keyworded argument for readability
                test_bit = np.zeros(args[-1].shape)
                return args[-1], test_bit
            return func(self, *args, test_name, **kwargs)
        return wrapper

    @run_if_test_exists_for_scene
    def test_11um(self,
                  band: str,
                  cmin: np.ndarray,
                  test_name: str = '11um_Test') -> np.ndarray:

        confidence = np.ones(self.data[band].shape)
        qa_bit = np.zeros(self.data[band].shape)
        test_bit = np.zeros(self.data[band].shape)
        threshold = self.thresholds[self.scene_name][test_name]

        if (threshold['perform'] is True and self.pixels_in_scene is True):
            qa_bit[self.scene_idx] = 1
            print(f'Testing "{self.scene_name}"\n')
            rad = self.data[band].values[self.scene_idx]
            idx = np.nonzero((self.data[band].values >= threshold['thr'][1]) &
                             (self.data[self.scene_name] == 1))
            test_bit[idx] = 1
            confidence[self.scene_idx] = conf.conf_test_new(rad, threshold['thr'])

        cmin = np.fmin(cmin, confidence)

        return cmin, test_bit

    @run_if_test_exists_for_scene
    def surface_temperature_test(self,
                                 band: str,
                                 viirs_data: xr.Dataset,
                                 cmin: np.ndarray,
                                 test_name: str = 'Surface_Temperature_Test') -> np.ndarray:

        confidence = np.ones(self.data[band].shape)
        qa_bit = np.zeros(self.data[band].shape)
        test_bit = np.zeros(self.data[band].shape)
        threshold = self.thresholds[self.scene_name][test_name]

        if (threshold['perform'] is True and self.pixels_in_scene is True):
            qa_bit[self.scene_idx] = 1
            print(f'Testing "{self.scene_name}"\n')
            rad = self.data[band].values[self.scene_idx]
            sfcdif = viirs_data.geos_sfct.values[self.scene_idx] - rad
            # need to write the test_bit here
            thr = preproc.thresholds_surface_temperature(viirs_data, threshold, self.scene_idx)
            confidence[self.scene_idx] = conf.conf_test_new(sfcdif, thr)

        cmin = np.fmin(cmin, confidence)

        return cmin, test_bit

    @run_if_test_exists_for_scene
    def sst_test(self,
                 band31: str,
                 band32: str,
                 cmin: np.ndarray,
                 test_name: str = 'SST_Test') -> np.ndarray:

        confidence = np.ones(self.data[band31].shape)
        qa_bit = np.zeros(self.data[band31].shape)
        test_bit = np.zeros(self.data[band31].shape)
        threshold = self.thresholds[self.scene_name][test_name]

        if (threshold['perform'] is True and self.pixels_in_scene is True):
            qa_bit[self.scene_idx] = 1
            m31 = self.data[band31].values - 273.16
            bt_diff = self.data[band31].values - self.data[band32].values
            sst = self.data.sst.values - 273.16
            cosvza = np.cos(self.data.sensor_zenith.values*_DTR)

            c = threshold['coeffs']

            modsst = 273.16 + c[0] + c[1]*m31 + c[2]*bt_diff*sst + c[3]*bt_diff*((1.0/cosvza) - 1.0)
            sfcdif = self.data.sst.values - modsst

            idx = np.nonzero((sfcdif < threshold['thr'][1]) &
                             (self.data[self.scene_name] == 1))
            test_bit[idx] = 1
            print(f'Testing "{self.scene_name}"\n')
            confidence[self.scene_idx] = conf.conf_test_new(sfcdif[self.scene_idx], threshold['thr'])

        cmin = np.fmin(cmin, confidence)

        # return cmin, np.abs(1-test_bit)*qa_bit
        return cmin, test_bit

    @run_if_test_exists_for_scene
    def bt_diff_86_11um(self,
                        band: str,
                        cmin: np.ndarray,
                        test_name: str = '8.6-11um_Test') -> np.ndarray:

        confidence = np.ones(self.data[band].shape)
        qa_bit = np.zeros(self.data[band].shape)
        test_bit = np.zeros(self.data[band].shape)
        threshold = self.thresholds[self.scene_name][test_name]

        if (threshold['perform'] is True and self.pixels_in_scene is True):
            qa_bit[self.scene_idx] = 1
            print(f'Testing "{self.scene_name}"\n')
            rad = self.data[band].values[self.scene_idx]
            idx = np.nonzero((self.data[band].values < threshold['thr'][1]) &
                             (self.data[self.scene_name] == 1))
            test_bit[idx] = 1
            confidence[self.scene_idx] = conf.conf_test_new(rad, threshold['thr'])

        cmin = np.fmin(cmin, confidence)

        return cmin, test_bit  # np.abs(1-test_bit)*qa_bit

    @run_if_test_exists_for_scene
    def test_11_12um_diff(self,
                          band: str,
                          cmin: np.ndarray,
                          test_name: str = '11-12um_Cirrus_Test') -> np.ndarray:

        confidence = np.ones(self.data.M01.shape)
        qa_bit = np.zeros(self.data[band].shape)
        test_bit = np.zeros(self.data[band].shape)
        threshold = self.thresholds[self.scene_name][test_name]

        if (threshold['perform'] is True and self.pixels_in_scene is True):
            qa_bit[self.scene_idx] = 1
            thr = preproc.thresholds_11_12um(self.data, threshold, self.scene_name, self.scene_idx)
            print(f'Testing "{self.scene_name}"\n')
            rad = self.data[band].values[self.scene_idx]
            idx = np.nonzero((rad <= thr[1, :]) &
                             (self.data[self.scene_name].values[self.scene_idx] == 1))
            tmp_bit = test_bit[self.scene_idx]
            tmp_bit[idx] = 1
            test_bit[self.scene_idx] = tmp_bit
            confidence[self.scene_idx] = conf.conf_test_new(rad, thr)

        cmin = np.fmin(cmin, confidence)
        test_bit = test_bit.reshape(self.data[band].shape)
        # return cmin, np.abs(1-test_bit)*qa_bit
        return cmin, test_bit

    @run_if_test_exists_for_scene
    def bt_difference_11_4um_test(self,
                                  band: str,
                                  cmin: np.ndarray,
                                  test_name: str = '11-4um_BT_Difference_Test') -> np.ndarray:

        confidence = np.ones(self.data.M01.shape)
        qa_bit = np.zeros(self.data[band].shape)
        test_bit = np.zeros(self.data[band].shape)
        threshold = self.thresholds[self.scene_name][test_name]

        if (threshold['perform'] is True and self.pixels_in_scene is True):
            qa_bit[self.scene_idx] = 1
            thr = preproc.bt_diff_11_4um_thresholds(self.data, threshold)

        # CONTINUE FROM HERE...


    @run_if_test_exists_for_scene
    def midlevel_cloud_test():
        pass

    @run_if_test_exists_for_scene
    def water_vapor_cloud_test():
        pass

    @run_if_test_exists_for_scene
    def variability_11um_test():
        pass

    @run_if_test_exists_for_scene
    def oceanic_stratus_11_4um_test(self,
                                    band: str,
                                    cmin: np.ndarray,
                                    test_name: str = '11-4um_Oceanic_Stratus_Test') -> np.ndarray:

        confidence = np.ones(self.data.M01.shape)
        qa_bit = np.zeros(self.data[band].shape)
        test_bit = np.zeros(self.data[band].shape)
        threshold = self.thresholds[self.scene_name][test_name]

        if (threshold['perform'] is True and self.pixels_in_scene is True):
            qa_bit[self.scene_idx] = 1
            print(f'Testing "{self.scene_name}"\n')
            rad = self.data[band].values[self.scene_idx]
            thr = threshold['thr']
            if self.scene_name in ['Land_Day_Desert', 'Land_Day_Desert_Coast', 'Polar_Day_Desert',
                                   'Polar_Day_Desert_Coast']:
                confidence[self.scene_idx] = conf.conf_test_dble(rad, thr)

            # these scenes have not been implemented yet
            elif self.scene_name in ['Land_Night', 'Polar_Night_Land', 'Polar_Day_Snow', 'Polar_Night_Snow',
                                     'Day_Snow', 'Night_Snow', 'Antarctic_Day']:
                pass
            else:
                scene_flag = scn.find_scene(self.data, self.thresholds['Sun_Glint']['bounds'][3])
                idx = np.nonzero((self.data[band].values >= threshold['thr'][1]) &
                                 (self.data[self.scene_name] == 1) &
                                 (scene_flag['sunglint'] == 0))
                test_bit[idx] = 1
                idx = np.nonzero((self.data[self.scene_name] == 1) &
                                 (scene_flag['sunglint'] == 0))
                confidence[idx] = conf.conf_test_new(self.data[band].values[idx], thr)

        cmin = np.fmin(cmin, confidence)

        return cmin, test_bit  # np.abs(1-test_bit)*qa_bit

    @run_if_test_exists_for_scene
    def nir_reflectance_test(self,
                             band: str,
                             cmin: np.ndarray,
                             test_name: str = 'NIR_Reflectance_Test') -> np.ndarray:

        confidence = np.ones(self.data.M01.shape)
        qa_bit = np.zeros(self.data[band].shape)
        test_bit = np.zeros(self.data[band].shape)
        threshold = self.thresholds[self.scene_name][test_name]

        if (threshold['perform'] is True and self.pixels_in_scene is True):
            qa_bit[self.scene_idx] = 1
            thr = preproc.thresholds_NIR(self.data, self.thresholds, self.scene_name,
                                         test_name, self.scene_idx)
            print(f'Testing "{self.scene_name}"\n')
            rad = self.data[band].values[self.scene_idx]
            idx = np.nonzero((rad <= thr[1, :]) &
                             (self.data[self.scene_name].values[self.scene_idx] == 1))
            tmp_bit = test_bit[self.scene_idx]
            tmp_bit[idx] = 1
            test_bit[self.scene_idx] = tmp_bit
            confidence[self.scene_idx] = conf.conf_test_new(rad, thr)

        cmin = np.fmin(cmin, confidence)

        return cmin, np.abs(1-test_bit)*qa_bit

    @run_if_test_exists_for_scene
    def vis_nir_ratio_test(self,
                           band: str,
                           cmin: np.ndarray,
                           test_name: str = 'Vis/NIR_Reflectance_Test') -> np.ndarray:

        confidence = np.ones(self.data.M01.shape)
        qa_bit = np.zeros(self.data[band].shape)
        test_bit = np.zeros(self.data[band].shape)
        threshold = self.thresholds[self.scene_name][test_name]

        if (threshold['perform'] is True and self.pixels_in_scene is True):
            # I'm not using the self.scene_idx here because it messes up with the indexing of the
            # confidence array and I don't want to think of a solution at the moment. I will need
            # to figure out the logic to make it work once I'm at the stage where I want to optimize
            # the code
            qa_bit[self.scene_idx] = 1
            rad = self.data[band].values  # [self.scene_idx]
            solar_zenith = self.data.solar_zenith.values  # [self.scene_idx]
            sunglint = scn.find_scene(self.data, self.thresholds['Sun_Glint']['bounds'][3])['sunglint']

            idx = np.nonzero((solar_zenith <= 85) & (sunglint == 1))

            thr_no_sunglint = np.array([threshold['thr'][i] for i in range(8)])
            thr_sunglint = np.array([self.thresholds['Sun_Glint']['snglnt'][i] for i in range(8)])
            print(thr_no_sunglint)
            print(thr_sunglint)
            # tmp = self.thresholds['Sun_Glint']['snglnt']
            # thr_sunglint = np.array([tmp[0], tmp[1], tmp[2], tmp[3], tmp[4], tmp[5], 1])
            confidence[sunglint == 0] = conf.conf_test_dble(rad[sunglint == 0], thr_no_sunglint)
            confidence[idx] = conf.conf_test_dble(rad[idx], thr_sunglint)

            idx = np.nonzero(((rad < thr_no_sunglint[1]) | (rad > thr_no_sunglint[4])) &
                             (self.data[self.scene_name].values == 1) & (sunglint == 0))
            test_bit[tuple(idx)] = 1

            idx = np.nonzero(((rad < thr_sunglint[1]) | (rad > thr_sunglint[4])) &
                             ((self.data[self.scene_name].values == 1) &
                              (solar_zenith <= 85) & (sunglint == 1)))
            test_bit[tuple(idx)] = 1

        cmin = np.fmin(cmin, confidence)

        return cmin, np.abs(1-test_bit)*qa_bit

    @run_if_test_exists_for_scene
    def test_16_21um_Reflectance(self,
                                 band: str,
                                 cmin: np.ndarray,
                                 test_name: str = '1.6_2.1um_NIR_Reflectance_Test') -> np.ndarray:

        confidence = np.ones(self.data.M01.shape)
        qa_bit = np.zeros(self.data[band].shape)
        test_bit = np.zeros(self.data[band].shape)
        threshold = self.thresholds[self.scene_name][test_name]

        if (threshold['perform'] is True and self.pixels_in_scene is True):
            qa_bit[self.scene_idx] = 1
            thr = preproc.thresholds_NIR(self.data, self.thresholds, self.scene_name,
                                         test_name, self.scene_idx)
            print(f'Testing "{self.scene_name}"\n')
            rad = self.data[band].values[self.scene_idx]
            idx = np.nonzero((rad <= thr[1, :]) &
                             (self.data[self.scene_name].values[self.scene_idx] == 1))
            tmp_bit = test_bit[self.scene_idx]
            tmp_bit[idx] = 1
            test_bit[self.scene_idx] = tmp_bit
            confidence[self.scene_idx] = conf.conf_test_new(rad, thr)

        cmin = np.fmin(cmin, confidence)

        return cmin, np.abs(1-test_bit)*qa_bit

    @run_if_test_exists_for_scene
    def visible_reflectance_test(self,
                                 band: str,
                                 cmin: np.ndarray,
                                 test_name: str = 'Visible_Reflectance_Test') -> np.ndarray:

        confidence = np.ones(self.data.M01.shape)
        test_bit = np.zeros(self.data[band].shape)
        threshold = self.thresholds[self.scene_name][test_name]

        if (threshold['perform'] is True and self.pixels_in_scene is True):
            print(f'Testing "{self.scene_name}"\n')
            thr, rad = preproc.vis_refl_thresholds(self.data, self.thresholds, self.scene_name,
                                                   self.scene_idx)
            confidence[self.scene_idx] = conf.conf_test_new(rad, thr)

        cmin = np.fmin(cmin, confidence)

        return cmin, test_bit

    @run_if_test_exists_for_scene
    def gemi_test(self,
                  band: str,
                  cmin: np.ndarray,
                  test_name: str = 'GEMI_Test') -> np.ndarray:

        confidence = np.ones(self.data.M01.shape)
        test_bit = np.zeros(self.data[band].shape)
        threshold = self.thresholds[self.scene_name][test_name]

        if (threshold['perform'] is True and self.pixels_in_scene is True):
            thr = preproc.gemi_thresholds(self.data, threshold, self.scene_name,
                                          self.scene_idx)
            rad = self.data[band].values[self.scene_idx]
            confidence[self.scene_idx] = conf.conf_test_new(rad, thr)

        cmin = np.fmin(cmin, confidence)

        return cmin, test_bit

    @run_if_test_exists_for_scene
    def test_1_38um_high_clouds(self,
                                band: str,
                                cmin: np.ndarray,
                                test_name: str = '1.38um_High_Cloud_Test') -> np.ndarray:

        confidence = np.ones(self.data.M01.shape)
        qa_bit = np.zeros(self.data[band].shape)
        test_bit = np.zeros(self.data[band].shape)
        threshold = self.thresholds[self.scene_name][test_name]

        if (threshold['perform'] is True and self.pixels_in_scene is True):
            qa_bit[self.scene_idx] = 1
            if self.scene_name in ['Ocean_Day', 'Polar_Day_Ocean']:
                thr = preproc.thresholds_1_38um_test(self.data, self.thresholds, self.scene_name,
                                                     self.scene_idx)
            else:
                return cmin, test_bit
                thr = threshold['thr']

            print(f'Testing "{self.scene_name}"\n')
            rad = self.data[band].values[self.scene_idx]
            idx = np.nonzero((rad <= thr[1, :]) &
                             (self.data[self.scene_name].values[self.scene_idx] == 1))
            tmp_bit = test_bit[self.scene_idx]
            tmp_bit[idx] = 1
            test_bit[self.scene_idx] = tmp_bit
            confidence[self.scene_idx] = conf.conf_test_new(rad, thr)

        cmin = np.fmin(cmin, confidence)

        return cmin, np.abs(1-test_bit)*qa_bit

    @run_if_test_exists_for_scene
    def thin_cirrus_4_12um_BTD_test(self,
                                    band: str,
                                    cmin: np.ndarray,
                                    test_name: str = '4-12um_BTD_Thin_Cirrus_Test') -> np.ndarray:

        confidence = np.ones(self.data.M01.shape)
        test_bit = np.zeros(self.data[band].shape)
        threshold = self.thresholds[self.scene_name][test_name]

        if (threshold['perform'] is True and self.pixels_in_scene is True):
            thr = preproc.polar_night_thresholds(self.data, self.thresholds, self.scene_name,
                                                 test_name, self.scene_idx)
            rad = self.data[band].values[self.scene_idx]
            confidence[self.scene_idx] = conf.conf_test_new(rad, thr)

        cmin = np.fmin(cmin, confidence)

        return cmin, test_bit

    def single_threshold_test(self, test_name, band, cmin):

        if band == 'bad_data':
            return cmin
        print(f'Running test "{test_name}" for "{self.scene_name}"')

        # preproc_thresholds()
        if 'thr' in self.thresholds[self.scene_name][test_name]:
            thr = np.array(self.thresholds[self.scene_name][test_name]['thr'])
        else:
            thr = np.array(self.thresholds[self.scene_name][test_name])
        thr_xr = xr.Dataset()

        if test_name == '11-12um_Cirrus_Test':
            thr_xr['threshold'] = pt.preproc(self.data, self.thresholds[self.scene_name], self.scene_name)
            thr = np.ones((5,))  # This is only temporary to force the logic of the code
                                 # I need to find a better solution at some point
        elif test_name == 'SST_Test':
            thr_xr['threshold'] = (('number_of_lines', 'number_of_pixels', 'z'),
                                   np.ones((self.data[band].shape[0], self.data[band].shape[1], 5))*thr)
        elif test_name == '7.3-11um_BTD_Mid_Level_Cloud_Test':
            thr_xr['threshold'] = pt.get_pn_thresholds(self.data, self.thresholds, self.scene_name,
                                                       '7.3-11um_BTD_Mid_Level_Cloud_Test')
            thr = np.ones((5,))
        elif test_name == 'Surface_Temperature_Test':
            thr_xr['threshold'] = pt.preproc_surf_temp(self.data, self.thresholds[self.scene_name])
            thr = np.ones((5,))
        elif (test_name == '11-4um_Oceanic_Stratus_Test' and
              self.scene_name in ['Land_Day_Desert', 'Land_Day_Desert_Coast', 'Polar_Day_Desert',
                                  'Polar_Day_Desert_Coast']):
            thr = np.array([self.thresholds[self.scene_name][test_name][i] for i in range(8)])
        elif test_name == 'NIR_Reflectance_Test':
            corr_thr = pt.preproc_nir(self.data, self.thresholds, self.scene_name)
            thr_xr['threshold'] = (('number_of_lines', 'number_of_pixels', 'z'), corr_thr)
        elif test_name == 'Visible_Reflectance_Test':
            thr_xr['threshold'], self.data['M128'] = pt.vis_refl_thresholds(self.data,
                                                                            self.thresholds,
                                                                            self.scene_name)
        elif test_name == '1.6_2.1um_NIR_Reflectance_Test':
            corr_thr = pt.nir_refl(self.data, self.thresholds, self.scene_name)
            thr_xr['threshold'] = (('number_of_lines', 'number_of_pixels', 'z'), corr_thr)
            thr = np.ones((5,))
        elif test_name == '4-12um_BTD_Thin_Cirrus_Test':
            thr_xr['threshold'] = pt.get_pn_thresholds(self.data, self.thresholds, self.scene_name,
                                                       '4-12um_BTD_Thin_Cirrus_Test')
            thr = np.ones((5,))
        elif test_name == 'GEMI_Test':
            thr_xr['threshold'] = pt.GEMI_test(self.data, self.thresholds, self.scene_name)
            thr = np.ones((5,))
        elif (test_name == '1.38um_High_Cloud_Test' and self.scene_name in ['Ocean_Day', 'Polar_Ocean_Day']):
            thr_xr['threshold'] = pt.test_1_38um_preproc(self.data, self.thresholds, self.scene_name)
            thr = np.ones((5,))
        else:
            thr_xr['threshold'] = (('number_of_lines', 'number_of_pixels', 'z'),
                                   np.ones((self.data[band].shape[0], self.data[band].shape[1], 5))*thr)

        data = xr.Dataset(self.data, coords=thr_xr)

        if test_name == 'SST_Test':
            data['sfcdif'] = (('number_of_lines', 'number_of_pixels'),
                              pt.preproc_sst(data, self.thresholds[self.scene_name][test_name]).values)
            band = 'sfcdif'
        if test_name == '11um_Variability_Test':
            var = pt.var_11um(self.data, self.thresholds)
            data['11um_var'] = data.M15
            data['11um_var'].values[var != 9] = np.nan

        if thr[4] == 1:
            print('test running...')
            confidence = conf_xr.conf_test(data, band)

        cmin = np.fmin(cmin, confidence)

        return cmin

    def double_threshold_test(self, test_name, band, cmin):
        data = self.data
        if test_name == '11-4um_BT_Difference_Test':
            thr = pt.bt11_4um_preproc(self.data, self.thresholds, self.scene_name)
            print('test running...')
            confidence = conf.conf_test_dble(data['M15-M13'].values, thr)
            confidence = confidence.reshape(data.M01.shape)

        if test_name == 'Vis/NIR_Ratio_Test':
            print('test running...')
            thr_no_sunglint = np.array([self.thresholds[self.scene_name][test_name][i] for i in range(8)])
            thr_sunglint = np.array([self.thresholds['Sun_Glint']['snglnt'][i] for i in range(8)])
            vrat = data.M07.values/data.M05.values
            _dtr = np.pi/180.0
            sza = data.sensor_zenith.values
            raz = data.relative_azimuth.values
            vza = data.sensor_zenith.values
            cos_refang = np.sin(vza*_dtr) * np.sin(sza*_dtr) * np.cos(raz*_dtr) + \
                np.cos(vza*_dtr) * np.cos(sza*_dtr)
            refang = np.arccos(cos_refang) * 180./np.pi
            idx = np.nonzero((data.solar_zenith <= 85) & (refang <= data.sunglint_angle))

            confidence = conf.conf_test_dble(vrat, thr_no_sunglint)
            confidence = confidence.reshape(data.M01.shape)
            confidence[idx] = conf.conf_test_dble(vrat[idx], thr_sunglint)

            confidence = confidence.reshape(data.M01.shape)

        if (test_name == '11-4um_Oceanic_Stratus_Test' and
            self.scene_name in ['Land_Day_Desert', 'Land_Day_Desert_Coast', 'Polar_Day_Desert',
                                'Polar_Day_Desert_Coast']):
            thr = np.array([self.thresholds[self.scene_name][test_name][i] for i in range(8)])

            print('test running...')
            confidence = conf.conf_test_dble(data['M15-M16'].values, thr)
            confidence = confidence.reshape(data.M01.shape)

        cmin = np.fmin(cmin, confidence)

        return cmin


class ComputeTests(CloudTests):

    def __init__(self,
                 data: xr.Dataset,
                 scene_name: str,
                 thresholds: Dict) -> None:
        super().__init__(data, scene_name, thresholds)

    def run_tests(self):
        cmin_G1 = np.ones(self.data.M01.shape)
        cmin_G2 = np.ones(self.data.M01.shape)
        cmin_G3 = np.ones(self.data.M01.shape)
        cmin_G4 = np.ones(self.data.M01.shape)
        cmin_G5 = np.ones(self.data.M01.shape)

        # Group 1
        cmin_G1 = self.test_11um('M15', cmin_G1, test_name='11um_Test')
        cmin_G1 = self.sst_test('M15', 'M16', cmin_G1, test_name='SST_Test')

        # Group 2
        cmin_G2 = self.bt_diff_86_11um('M14-M15', cmin_G2, test_name='8.6-11um_Test')
        cmin_G2 = self.test_11_12um_diff('M15-M16', cmin_G2, test_name='11-12um_Cirrus_Test')
        cmin_G2 = self.oceanic_stratus_11_4um_test('M15-M13', cmin_G2,
                                                   test_name='11-4um_Oceanic_Stratus_Test')

        # Group 3
        cmin_G3 = self.nir_reflectance_test('M07', cmin_G3, test_name='NIR_Reflectance_Test')
        cmin_G3 = self.vis_nir_ratio_test('M07-M05ratio', cmin_G3, test_name='Vis/NIR_Ratio_Test')
        cmin_G3 = self.nir_reflectance_test('M10', cmin_G3, test_name='1.6_2.1um_NIR_Reflectance_Test')
        cmin_G3 = self.visible_reflectance_test('M128', cmin_G3, test_name='Visible_Reflectance_Test')
        cmin_G3 = self.gemi_test('GEMI', cmin_G3, test_name='GEMI_Test')

        # Group 4
        cmin_G4 = self.test_1_38um_high_clouds('M09', cmin_G4, test_name='1.38um_High_Cloud_Test')

        # Group 5
        cmin_G5 = self.thin_cirrus_4_12um_BTD_test('M13-M16', cmin_G5,
                                                   test_name='4-12um_BTD_Thin_Cirrus_Test')

        cmin = cmin_G1 * cmin_G2 * cmin_G3 * cmin_G4 * cmin_G5

        return cmin

    def clear_sky_restoral(self,
                           cmin: np.ndarray) -> np.ndarray:

        total_bit = np.full(self.data.M01.shape, 3)
        sunglint_angle = self.thresholds['Sun_Glint']['bounds'][3]
        scene_flags = scn.find_scene(self.data, sunglint_angle)

        cmin_tmp = cmin + 0
        idx = np.nonzero((scene_flags['water'] == 1) & (scene_flags['ice'] == 0) &
                         (scene_flags['uniform'] == 1) & (cmin <= 0.99) & (cmin >= 0.05))
        cmin[idx] = restoral.spatial(self.data, self.thresholds['Sun_Glint'], scene_flags, cmin_tmp)[idx]

        cmin_tmp = cmin + 0
        idx = np.nonzero((scene_flags['water'] == 1) & (scene_flags['sunglint'] == 1) &
                         (scene_flags['uniform'] == 1) & (cmin <= 0.95))
        cmin[idx] = restoral.sunglint(self.data, self.thresholds['Sun_Glint'], total_bit, cmin_tmp)[idx]

        """
        idx = np.nonzero((scene_flags['day'] == 1) & (scene_flags['land'] == 1) &
                         (scene_flags['snow'] == 0) & (scene_flags['ice'] == 0) &
                         (cmin <= 0.95))
        cmin[idx] = restoral.land(self.data, self.thresholds, scene_flags, cmin)[idx]

        idx = np.nonzero((scene_flags['day'] == 1) & (scene_flags['land'] == 1) &
                         (scene_flags['coast'] == 1) & (scene_flags['snow'] == 0) &
                         (scene_flags['ice'] == 0))
        cmin[idx] = restoral.coast(self.data, self.thresholds, scene_flags, cmin)[idx]
        """

        return cmin


def preproc_thresholds(thresholds, data):
    """Attach a per-pixel thresholds dataset to ``data`` as coordinates.

    Builds an ``xr.Dataset`` with a (lines, pixels, 5) variable, replaces
    entries based on ``data.desert`` and on the 'M15-M16' / 'M15-M13'
    variables using the 'Land_Night' surface-temperature thresholds, and
    returns a new ``xr.Dataset`` built from ``data`` with those coordinates.

    NOTE(review): this looks like leftover code and is presumably not
    runnable as written; see the inline notes below.
    """
    # NOTE(review): thresholds is indexed like a nested dict below, so
    # np.array(thresholds) would be a 0-d object array and the multiplication
    # two lines down presumably fails. Confirm what was meant to be passed.
    thr = np.array(thresholds)
    thr_xr = xr.Dataset()
    # NOTE(review): the variable name 'tresholds' looks like a typo of
    # 'thresholds'. Left untouched because it is a runtime key.
    thr_xr['tresholds'] = (('number_of_lines', 'number_of_pixels', 'z'),
                           np.ones((data['M01'].shape[0], data['M01'].shape[1], 5))*thr)

    nl_sfct1 = thresholds['Land_Night']['Surface_Temperature_Test'][0]
#    nl_sfct2 = thresholds['Land_Night']['Surface_Temperature_Test'][1]
#    nlsfct_pfm = thresholds['Land_Night']['Surface_Temperature_Test'][2]
    nl_df1 = thresholds['Land_Night']['Surface_Temperature_Test_difference'][0:2]
    nl_df2 = thresholds['Land_Night']['Surface_Temperature_Test_difference'][2:]

#    df1 = data.M15 - data.M16
#    df2 = data.M15 - data.M13
    # Dataset.where keeps values where the condition holds and substitutes the
    # second argument elsewhere, so desert pixels receive nl_sfct1 here.
    thr_xr = thr_xr.where(data.desert != 1, nl_sfct1)
    # NOTE(review): nl_sfct1 is used whole above but indexed with [0] here.
    # Confirm whether it is a scalar or a sequence.
    thr_xr = thr_xr.where((data['M15-M16'] > nl_df1[0]) |
                          ((data['M15-M16'] < nl_df1[0]) &
                           ((data['M15-M13'] <= nl_df2[0]) | (data['M15-M13'] >= nl_df2[1]))),
                          nl_sfct1[0])

    data = xr.Dataset(data, coords=thr_xr)

    return data


def single_threshold_test(test, rad, threshold):
    """Return the confidence of one single-threshold test, shaped like ``rad``.

    ``threshold[test]`` must hold at least five values; index 4 is the perform
    flag. When the flag is 1, the flattened ``rad`` is passed to
    ``conf.conf_test``; otherwise the confidence is all zeros.
    """
    original_shape = rad.shape
    flat_rad = rad.reshape(np.prod(original_shape))
    thr = np.array(threshold[test])

    if thr[4] != 1:
        # Test switched off: zero confidence for every pixel.
        return np.zeros(flat_rad.shape).reshape(original_shape)

    print(f"{test} test running")
    # the C code has the line below that I don't quite understand the purpose of.
    # It seems to be setting the bit to 0 if the BT value is greater than the midpoint
    #
    # if (m31 >= dobt11[1]) (void) set_bit(13, pxout.testbits);
    return conf.conf_test(flat_rad, thr).reshape(original_shape)


def test():
    """Smoke test: print random input values and the confidence computed for them.

    The module-level ``test_11um`` this used to call does not exist (it is a
    method of ``CloudTests``), so the check goes through the module-level
    ``single_threshold_test`` with the same five coefficients.
    """
    rad = np.random.randint(50, size=[4, 8])
    # coeffs = [5, 42, 20, 28, 15, 35, 1]
    # coeffs = [20, 28, 5, 42, 15, 35, 1]
    coeffs = [35, 15, 20, 1, 1]
    # confidence = conf_test_dble(rad, coeffs)
    confidence = single_threshold_test('11um_Test', rad, {'11um_Test': coeffs})
    print(rad)
    print('\n')
    print(confidence)


# Run the smoke test only when this file is executed as a script.
if __name__ == "__main__":
    test()