From 51a158f9e20274f81cb420dfb49a93bd547dc27b Mon Sep 17 00:00:00 2001
From: Paolo Veglio <paolo.veglio@ssec.wisc.edu>
Date: Wed, 26 Oct 2022 20:45:35 +0000
Subject: [PATCH] reworked the single threshold function to work better with
 xarray

---
 conf_xr.py | 186 +++++++++++++++++++++++++++++++++++++++++++++++++++++
 main.py    |  51 +++++++++++----
 scene.py   |  64 +++++++++---------
 tests.py   |  40 +++++++++++-
 4 files changed, 295 insertions(+), 46 deletions(-)
 create mode 100644 conf_xr.py

diff --git a/conf_xr.py b/conf_xr.py
new file mode 100644
index 0000000..655f687
--- /dev/null
+++ b/conf_xr.py
@@ -0,0 +1,186 @@
+import numpy as np
+import xarray as xr
+
+
+def test(flipped=False):
+    bt = np.arange(265, 275)
+    if flipped is False:
+        thr = np.array([267, 270, 273, 1, 1])
+    else:
+        thr = np.array([273, 270, 267, 1, 1])
+    c = conf_test(bt, thr)
+    print(c)
+
+
+def test_dble(flipped=False):
+    bt = np.arange(260, 282)
+    if flipped is False:
+        thr = np.array([264, 267, 270, 273, 276, 279, 1, 1])
+    else:
+        thr = np.array([279, 276, 273, 270, 267, 264, 1, 1])
+    c = conf_test_dble(bt, thr)
+    print(c)
+
+
+def conf_test(data, band):
+    '''
+    Assuming a linear function between min and max confidence level, the plot below shows
+    how the confidence (y axis) is computed as function of radiance (x axis).
+    This case illustrates alpha < gamma, obviously in case alpha > gamma, the plot would be
+    flipped.
+                       gamma
+    c  1                 ________
+    o  |                /
+    n  |               /
+    f  |              /
+    i  |     beta    /
+    d 1/2    |....../
+    e  |           /
+    n  |          /
+    c  |         /
+    e  0________/
+       |      alpha
+    --------- radiance ---------->
+    '''
+
+    hicut = data.threshold[:, :, 2]
+    beta = data.threshold[:, :, 1]
+    locut = data.threshold[:, :, 0]
+    power = data.threshold[:, :, 3]
+    coeff = np.power(2, (power - 1))
+
+    gamma = data.threshold.where(hicut > locut, data.threshold[:, :, 0])[:, :, 2]
+    alpha = data.threshold.where(hicut > locut, data.threshold[:, :, 2])[:, :, 0]
+    flipped = xr.zeros_like(data[band]).where(hicut > locut, 1)
+
+    # Rad between alpha and beta
+    range_ = 2. * (beta - alpha)
+    s1 = (data[band].values - alpha)/range_
+    conf_tmp1 = (coeff * np.power(s1, power)).where((data[band] <= beta) & (flipped == 0))
+    conf_tmp2 = (1.0 - coeff * np.power(s1, power)).where((data[band] <= beta) & (flipped == 1))
+    conf_tmp12 = conf_tmp1.where(flipped == 0, conf_tmp2)
+
+    # Rad between beta and gamma
+    range_ = 2. * (beta - gamma)
+    s1 = (data[band].values - gamma)/range_
+    conf_tmp3 = (1.0 - coeff * np.power(s1, power)).where((data[band] <= beta) & (flipped == 0))
+    conf_tmp4 = (coeff * np.power(s1, power)).where((data[band] <= beta) & (flipped == 1))
+    conf_tmp34 = conf_tmp3.where(flipped == 0, conf_tmp4)
+
+    confidence = conf_tmp12.where(data[band] <= beta, conf_tmp34)
+
+    confidence = confidence.where(confidence > 0, 0)
+    confidence = confidence.where(confidence < 1, 1)
+
+    return confidence
+
+
+def conf_test_dble(rad, coeffs):
+    # '''
+    #            gamma1                         gamma2
+    #    c  1_______                               ________
+    #    o  |       \                             /
+    #    n  |        \                           /
+    #    f  |         \                         /
+    #    i  |          \   beta1       beta2   /
+    #    d 1/2          \....|          |...../
+    #    e  |            \                   /
+    #    n  |             \                 /
+    #    c  |              \               /
+    #    e  0               \_____________/
+    #       |             alpha1       alpha2
+    #    --------------------- radiance ------------------------->
+    # '''
+
+    coeffs = np.array(coeffs)
+    radshape = rad.shape
+    rad = rad.reshape(np.prod(radshape))
+    confidence = np.zeros(rad.shape)
+
+    alpha1, gamma1 = np.empty(rad.shape), np.empty(rad.shape)
+    alpha2, gamma2 = np.empty(rad.shape), np.empty(rad.shape)
+
+    if coeffs.ndim == 1:
+        coeffs = np.full((rad.shape[0], 7), coeffs[:7]).T
+
+    gamma1 = coeffs[0, :]
+    beta1 = coeffs[1, :]
+    alpha1 = coeffs[2, :]
+    alpha2 = coeffs[3, :]
+    beta2 = coeffs[4, :]
+    gamma2 = coeffs[5, :]
+    power = coeffs[6, :]
+
+    coeff = np.power(2, (power - 1))
+    # radshape = rad.shape
+    # rad = rad.reshape((rad.shape[0]*rad.shape[1]))
+
+    # ## Find if interval between inner cutoffs passes or fails test
+
+    # Inner region fails test
+
+    # Value is within range of lower set of limits
+    range_ = 2. * (beta1 - alpha1)
+    s1 = (rad - alpha1) / range_
+    idx = np.nonzero((rad <= alpha1) & (rad >= beta1) & (alpha1 - gamma1 > 0))
+    confidence[idx] = coeff[idx] * np.power(s1[idx], power[idx])
+
+    range_ = 2. * (beta1 - gamma1)
+    s1 = (rad - gamma1) / range_
+    idx = np.nonzero((rad >= gamma1) & (rad < beta1) & (alpha1 - gamma1 > 0))
+    confidence[idx] = 1.0 - coeff[idx] * np.power(s1[idx], power[idx])
+
+    # Value is within range of upper set of limits
+    range_ = 2. * (beta2 - alpha2)
+    s1 = (rad - alpha2) / range_
+    idx = np.nonzero((rad > alpha1) & (rad <= beta2) & (alpha1 - gamma1 > 0))
+    confidence[idx] = coeff[idx] * np.power(s1[idx], power[idx])
+
+    range_ = 2. * (beta2 - gamma2)
+    s1 = (rad - gamma2) / range_
+    idx = np.nonzero((rad > alpha1) & (rad > beta2) & (alpha1 - gamma1 > 0))
+    confidence[idx] = 1.0 - coeff[idx] * np.power(s1[idx], power[idx])
+
+    # Check for value beyond function range
+    confidence[(alpha1 - gamma1 > 0) & (rad > alpha1) & (rad < alpha2)] = 0
+    confidence[(alpha1 - gamma1 > 0) & ((rad < gamma1) | (rad > gamma2))] = 1
+
+    ###
+
+    # Inner region passes test
+    print("I NEED TO REVIEW THIS TO WRITE IT MORE CLEARLY")
+    # FOR NOW ALPHA AND GAMMA ARE SWITCHED BECAUSE OF HOW THE ARRAYS ARE DEFINED.
+    # THINK ON HOW THIS COULD BE WRITTEN SO THAT IT'S EASIER TO UNDERSTAND (AND DEBUG)
+    # Value is within range of lower set of limits
+    range_ = 2 * (beta1 - alpha1)
+    s1 = (rad - alpha1) / range_
+    idx = np.nonzero((rad > alpha1) & (rad <= gamma1) & (rad <= beta1) & (alpha1 - gamma1 <= 0))
+    confidence[idx] = 1.0 - coeff[idx] * np.power(s1[idx], power[idx])
+
+    range_ = 2 * (beta1 - gamma1)
+    s1 = (rad - gamma1) / range_
+    idx = np.nonzero((rad > alpha1) & (rad <= gamma1) & (rad > beta1) & (alpha1 - gamma1 <= 0))
+    confidence[idx] = coeff[idx] * np.power(s1[idx], power[idx])
+
+    # Values is within range of upper set of limits
+    range_ = 2 * (beta2 - alpha2)
+    s1 = (rad - alpha2) / range_
+    idx = np.nonzero((rad > gamma2) & (rad < alpha2) & (rad >= beta2) & (alpha1 - gamma1 <= 0))
+    confidence[idx] = 1.0 - coeff[idx] * np.power(s1[idx], power[idx])
+
+    range_ = 2 * (beta2 - gamma2)
+    s1 = (rad - gamma2) / range_
+    idx = np.nonzero((rad > gamma2) & (rad < alpha2) & (rad < beta2) & (alpha1 - gamma1 <= 0))
+    confidence[idx] = coeff[idx] * np.power(s1[idx], power[idx])
+
+    confidence[(alpha1 - gamma1 <= 0) & ((rad > gamma1) | (rad < gamma2))] = 0
+    confidence[(alpha1 - gamma1 <= 0) & (rad <= alpha1) & (rad >= alpha2)] = 1
+
+    confidence[confidence > 1] = 1
+    confidence[confidence < 0] = 0
+
+    return confidence
+
+
+if __name__ == "__main__":
+    test_dble()
diff --git a/main.py b/main.py
index 8439c89..c74fa58 100644
--- a/main.py
+++ b/main.py
@@ -1,11 +1,12 @@
 import ruamel_yaml as yml
 import numpy as np
-# import xarray as xr
+import xarray as xr
 
 from glob import glob
 
 import read_data as rd
-from tests import CloudTests
+import scene as scn
+from tests import CloudTests_new
 
 # import tests
 import ocean_day_tests as odt
@@ -68,20 +69,47 @@ def main(*, data_path=_datapath, mod02=_fname_mod02, mod03=_fname_mod03,
 
     viirs_data = rd.get_data(file_names, sunglint_angle)
 
+    # scene_xr = xr.Dataset()
+    # for s in scn._scene_list:
+    #    scene_xr[s] = (('number_of_lines', 'number_of_pixels'), scn.scene_id[s])
+    # scene_xr['latitude'] = viirs_xr.latitude
+    # scene_xr['longitude'] = viirs_xr.longitude
+    #
+    # viirs_data = xr.Dataset(viirs_xr, coords=scene_xr)
+    # viirs_data.drop_vars(['latitude', 'longitude'])
+
     cmin_G1 = np.ones(viirs_data.M01.shape)
+    cmin_test = {'Ocean_Day': np.ones(viirs_data.M01.shape),
+                 'Polar_Ocean_Day': np.ones(viirs_data.M01.shape),
+                 'Polar_Ocean_Night': np.ones(viirs_data.M01.shape)
+                 }
     cmin2 = np.ones(viirs_data.M01.shape)
     cmin3 = np.ones(viirs_data.M01.shape)
     cmin4 = np.ones(viirs_data.M01.shape)
 
-    Ocean_Day = CloudTests(viirs_data, 'Ocean_Day', thresholds)
-    Polar_Ocean_Day = CloudTests(viirs_data, 'Polar_Ocean_Day', thresholds)
-    Polar_Ocean_Night = CloudTests(viirs_data, 'Polar_Ocean_Night', thresholds)
-
-
-    cmin_G1 = Ocean_Day.single_threshold_test('11BT_Test', viirs_data.M15.values, cmin_G1)
-    cmin_G1 = Polar_Ocean_Day.single_threshold_test('11BT_Test', viirs_data.M15.values, cmin_G1)
-    cmin_G1 = Polar_Ocean_Night.single_threshold_test('11BT_Test', viirs_data.M15.values, cmin_G1)
-
+    Ocean_Day = CloudTests_new(viirs_data, 'Ocean_Day', thresholds)
+    Polar_Ocean_Day = CloudTests_new(viirs_data, 'Polar_Ocean_Day', thresholds)
+    Polar_Ocean_Night = CloudTests_new(viirs_data, 'Polar_Ocean_Night', thresholds)
+
+    # Land_Day = CloudTests(viirs_data, 'Land_Day', thresholds)
+    # Night_Snow = CloudTests(viirs_data, 'Night_Snow', thresholds)
+    # Day_Snow = CloudTests(viirs_data, 'Day_Snow', thresholds)
+    # Land_Night = CloudTests(viirs_data, 'Land_Night', thresholds)
+    # Land_Day_Coast = CloudTests(viirs_data, 'Land_Day_Coast', thresholds)
+    # Land_Day_Desert = CloudTests(viirs_data, 'Land_Day_Desert', thresholds)
+    # Land_Day_Desert_Coast = CloudTests(viirs_data, 'Land_Day_Desert_Coast', thresholds)
+
+    # 11um BT Test
+    cmin_test['Ocean_Day'] = Ocean_Day.single_threshold_test('11um_Test', 'M15', cmin_G1)
+    cmin_test['Polar_Ocean_Day'] = Polar_Ocean_Day.single_threshold_test('11um_Test', 'M15', cmin_G1)
+    cmin_test['Polar_Ocean_Night'] = Polar_Ocean_Night.single_threshold_test('11um_Test', 'M15', cmin_G1)
+
+    return cmin_test
+    '''
+    # CO2 High Cloud Test
+    # cmin_G1 = Land_Day
+
+    # 11-12um BT Difference
     cmin_G1 = Ocean_Day.single_threshold_test('11-12BT_diff',
                                               viirs_data.M15.values-viirs_data.M16.values,
                                               cmin_G1)
@@ -146,6 +174,7 @@ def main(*, data_path=_datapath, mod02=_fname_mod02, mod03=_fname_mod03,
              lat=viirs_data.latitude.values, lon=viirs_data.longitude.values)
 
     return confidence
+    '''
 
 
 def test_main():
diff --git a/scene.py b/scene.py
index c28538d..ddfac88 100644
--- a/scene.py
+++ b/scene.py
@@ -7,10 +7,10 @@ import read_data as rd
 import ancillary_data as anc
 
 # lsf: land sea flag
-_scene_list = ['ocean_day', 'ocean_night', 'land_day', 'land_night', 'snow_day', 'snow_night', 'coast_day',
-               'desert_day', 'antarctic_day', 'polar_day_snow', 'polar_day_desert', 'polar_day_ocean',
-               'polar_day_desert_coast', 'polar_day_coast', 'polar_day_land', 'polar_night_snow',
-               'polar_night_land', 'polar_night_ocean', 'land_day_desert_coast']
+_scene_list = ['Ocean_Day', 'Ocean_Night', 'Land_Day', 'Land_Night', 'Snow_Day', 'Snow_Night', 'Coast_Day',
+               'Desert_Day', 'Antarctic_Day', 'Polar_Day_Snow', 'Polar_Day_Desert', 'Polar_Day_Ocean',
+               'Polar_Day_Desert_Coast', 'Polar_Day_Coast', 'Polar_Day_Land', 'Polar_Night_Snow',
+               'Polar_Night_Land', 'Polar_Night_Ocean', 'Land_Day_Desert_Coast', 'Land_Day_Coast']
 _flags = ['day', 'night', 'land', 'coast', 'sh_lake', 'sh_ocean', 'water', 'polar', 'sunglint',
           'greenland', 'high_elevation', 'antarctica', 'desert', 'visusd', 'vrused', 'map_snow', 'map_ice',
           'ndsi_snow', 'snow', 'ice', 'new_zealand', 'uniform']
@@ -23,7 +23,7 @@ _rtd = 180./np.pi
 
 # I'm defining here the flags for difference scenes. Eventually I want to find a better way of doing this
 land = 1
-#coast = .2
+# coast = .2
 sh_lake = .3
 sh_ocean = .4
 water = 5
@@ -235,13 +235,13 @@ def find_scene(data, sunglint_angle):
     perm_ice_fraction = data['geos_landicefr']
     ice_fraction = data['geos_icefr']
 
-    idx = np.nonzero((snow_fraction > 0.10) & (snow_fraction <= 1.0))
+    idx = tuple(np.nonzero((snow_fraction > 0.10) & (snow_fraction <= 1.0)))
     scene_flag['map_snow'][idx] = 1
 
-    idx = np.nonzero((perm_ice_fraction > 0.10) & (perm_ice_fraction <= 1.0))
+    idx = tuple(np.nonzero((perm_ice_fraction > 0.10) & (perm_ice_fraction <= 1.0)))
     scene_flag['map_snow'][idx] = 1
 
-    idx = np.nonzero((ice_fraction > 0.10) & (ice_fraction <= 1.0))
+    idx = tuple(np.nonzero((ice_fraction > 0.10) & (ice_fraction <= 1.0)))
     scene_flag['map_ice'][idx] = 1
 
     # need to define this function and write this block better
@@ -316,118 +316,118 @@ def scene_id(scene_flag):
                      (scene_flag['ice'] == 0) & (scene_flag['snow'] == 0) &
                      (scene_flag['polar'] == 0) & (scene_flag['antarctica'] == 0) &
                      (scene_flag['coast'] == 0) & (scene_flag['desert'] == 0))
-    scene['ocean_day'][idx] = 1
+    scene['Ocean_Day'][idx] = 1
 
     # Ocean Night
     idx = np.nonzero((scene_flag['water'] == 1) & (scene_flag['night'] == 1) &
                      (scene_flag['ice'] == 0) & (scene_flag['snow'] == 0) &
                      (scene_flag['polar'] == 0) & (scene_flag['antarctica'] == 0) &
                      (scene_flag['coast'] == 0) & (scene_flag['desert'] == 0))
-    scene['ocean_night'][idx] = 1
+    scene['Ocean_Night'][idx] = 1
 
     # Land Day
     idx = np.nonzero((scene_flag['land'] == 1) & (scene_flag['day'] == 1) &
                      (scene_flag['ice'] == 0) & (scene_flag['snow'] == 0) &
                      (scene_flag['polar'] == 0) & (scene_flag['antarctica'] == 0) &
                      (scene_flag['coast'] == 0) & (scene_flag['desert'] == 0))
-    scene['land_day'][idx] = 1
+    scene['Land_Day'][idx] = 1
 
     # Land Night
     idx = np.nonzero((scene_flag['land'] == 1) & (scene_flag['night'] == 1) &
                      (scene_flag['ice'] == 0) & (scene_flag['snow'] == 0) &
                      (scene_flag['polar'] == 0) & (scene_flag['antarctica'] == 0) &
                      (scene_flag['coast'] == 0))
-    scene['land_night'][idx] = 1
+    scene['Land_Night'][idx] = 1
 
     # Snow Day
     idx = np.nonzero((scene_flag['day'] == 1) &
                      ((scene_flag['ice'] == 1) | (scene_flag['snow'] == 1)) &
-                     (scene_flag['polar']) & (scene_flag['antarctica']))
-    scene['snow_day'][idx] = 1
+                     (scene_flag['polar'] == 1) & (scene_flag['antarctica'] == 1))
+    scene['Snow_Day'][idx] = 1
 
     # Snow Night
     idx = np.nonzero((scene_flag['night'] == 1) &
                      ((scene_flag['ice'] == 1) | (scene_flag['snow'] == 1)) &
-                     (scene_flag['polar']) & (scene_flag['antarctica']))
-    scene['snow_night'][idx] = 1
+                     (scene_flag['polar'] == 1) & (scene_flag['antarctica'] == 1))
+    scene['Snow_Night'][idx] = 1
 
     # Land Day Coast
     idx = np.nonzero((scene_flag['land'] == 1) & (scene_flag['day'] == 1) &
                      (scene_flag['coast'] == 1) & (scene_flag['desert'] == 0) &
                      (scene_flag['polar'] == 0) & (scene_flag['antarctica'] == 0))
-    scene['land_day_coast'][idx] = 1
+    scene['Land_Day_Coast'][idx] = 1
 
     # Land Day Desert
     idx = np.nonzero((scene_flag['land'] == 1) & (scene_flag['day'] == 1) &
                      (scene_flag['desert'] == 1) & (scene_flag['coast'] == 0) &
                      (scene_flag['polar'] == 0) & (scene_flag['antarctica'] == 0))
-    scene['desert_day'][idx] = 1
+    scene['Desert_Day'][idx] = 1
 
     # Land Day Desert Coast
     idx = np.nonzero((scene_flag['land'] == 1) & (scene_flag['day'] == 1) &
                      (scene_flag['desert'] == 1) & (scene_flag['coast'] == 1) &
                      (scene_flag['polar'] == 0) & (scene_flag['antarctica'] == 0))
-    scene['land_day_desert_coast'][idx] = 1
+    scene['Land_Day_Desert_Coast'][idx] = 1
 
     # Antarctic Day
     idx = np.nonzero((scene_flag['polar'] == 1) & (scene_flag['day'] == 1) &
                      (scene_flag['antarctica'] == 1) & (scene_flag['land'] == 1))
-    scene['antarctic_day'][idx] = 1
+    scene['Antarctic_Day'][idx] = 1
 
     # Polar Day Snow
     idx = np.nonzero((scene_flag['polar'] == 1) & (scene_flag['day'] == 1) &
                      ((scene_flag['snow'] == 1) | (scene_flag['ice'] == 1)) &
                      (scene_flag['antarctica'] == 0))
-    scene['polar_day_snow'][idx] = 1
+    scene['Polar_Day_Snow'][idx] = 1
 
     # Polar Day Desert
     idx = np.nonzero((scene_flag['polar'] == 1) & (scene_flag['day'] == 1) &
                      (scene_flag['land'] == 1) & (scene_flag['desert'] == 1) &
                      (scene_flag['coast'] == 0) & (scene_flag['antarctica'] == 0) &
                      (scene_flag['ice'] == 0) & (scene_flag['snow'] == 0))
-    scene['polar_day_desert'][idx] = 1
+    scene['Polar_Day_Desert'][idx] = 1
 
     # Polar Day Desert Coast
     idx = np.nonzero((scene_flag['polar'] == 1) & (scene_flag['day'] == 1) &
                      (scene_flag['land'] == 1) & (scene_flag['desert'] == 1) &
                      (scene_flag['coast'] == 1) & (scene_flag['antarctica'] == 0) &
                      (scene_flag['ice'] == 0) & (scene_flag['snow'] == 0))
-    scene['polar_day_desert_coast'][idx] = 1
+    scene['Polar_Day_Desert_Coast'][idx] = 1
 
     # Polar Day Coast
     idx = np.nonzero((scene_flag['polar'] == 1) & (scene_flag['day'] == 1) &
-                     (scene_flag['land'] == 1) & (scene_flag['coast'] == 1)
+                     (scene_flag['land'] == 1) & (scene_flag['coast'] == 1) &
                      (scene_flag['desert'] == 0) & (scene_flag['antarctica'] == 0) &
                      (scene_flag['ice'] == 0) & (scene_flag['snow'] == 0))
-    scene['polar_day_coast'][idx] = 1
+    scene['Polar_Day_Coast'][idx] = 1
 
     # Polar Day Land
     idx = np.nonzero((scene_flag['polar'] == 1) & (scene_flag['day'] == 1) &
-                     (scene_flag['land'] == 1) & (scene_flag['coast'] == 0)
+                     (scene_flag['land'] == 1) & (scene_flag['coast'] == 0) &
                      (scene_flag['desert'] == 0) & (scene_flag['antarctica'] == 0) &
                      (scene_flag['ice'] == 0) & (scene_flag['snow'] == 0))
-    scene['polar_day_land'][idx] = 1
+    scene['Polar_Day_Land'][idx] = 1
 
     # Polar Day Ocean
     idx = np.nonzero((scene_flag['polar'] == 1) & (scene_flag['day'] == 1) &
-                     (scene_flag['water'] == 1) & (scene_flag['coast'] == 0)
+                     (scene_flag['water'] == 1) & (scene_flag['coast'] == 0) &
                      (scene_flag['desert'] == 0) & (scene_flag['antarctica'] == 0) &
                      (scene_flag['ice'] == 0) & (scene_flag['snow'] == 0))
-    scene['polar_day_ocean'][idx] = 1
+    scene['Polar_Day_Ocean'][idx] = 1
 
     # Polar Night Snow
     idx = np.nonzero((scene_flag['polar'] == 1) & (scene_flag['night'] == 1) &
                      ((scene_flag['snow'] == 1) | (scene_flag['ice'] == 1)))
-    scene['polar_night_snow'][idx] = 1
+    scene['Polar_Night_Snow'][idx] = 1
 
     # Polar Night Land
     idx = np.nonzero((scene_flag['polar'] == 1) & (scene_flag['night'] == 1) &
                      (scene_flag['land'] == 1))
-    scene['polar_night_land'][idx] = 1
+    scene['Polar_Night_Land'][idx] = 1
 
     # Polar Night Ocean
     idx = np.nonzero((scene_flag['polar'] == 1) & (scene_flag['night'] == 1) &
                      (scene_flag['water'] == 1))
-    scene['polar_night_ocean'][idx] = 1
+    scene['Polar_Night_Ocean'][idx] = 1
 
     return scene
diff --git a/tests.py b/tests.py
index 76b4c69..39860e5 100644
--- a/tests.py
+++ b/tests.py
@@ -1,9 +1,11 @@
 import numpy as np
+import xarray as xr
 
 from numpy.lib.stride_tricks import sliding_window_view
 
 import utils
 import conf
+import conf_xr
 
 
 # ############## GROUP 1 TESTS ############## #
@@ -224,18 +226,50 @@ class CloudTests:
         rad = rad.reshape(np.prod(radshape))
 
         thr = np.array(self.threshold[test_name])
-        confidence = np.zeros(rad.shape)
+        confidence = np.zeros(radshape)
 
         if thr[4] == 1:
             print('test running')
-            confidence = conf.conf_test(rad, thr)
+            confidence[self.idx] = conf.conf_test(rad[self.idx], thr)
 
-        return np.minimum(cmin, confidence)
+        cmin[self.idx] = np.minimum(cmin[self.idx], confidence[self.idx])
+        return cmin
 
     def double_threshold_test(self):
         pass
 
 
+# new class to try to use xarray more extensively
+class CloudTests_new:
+
+    def __init__(self, data, scene_name, thresholds):
+        self.data = data
+        self.scene_name = scene_name
+        self.thresholds = thresholds
+
+    def single_threshold_test(self, test_name, band, cmin):
+
+        # preproc_thresholds()
+        thr = np.array(self.thresholds[self.scene_name][test_name])
+        thr_xr = xr.Dataset()
+        thr_xr['threshold'] = (('number_of_lines', 'number_of_pixels', 'z'),
+                               np.ones((self.data[band].shape[0], self.data[band].shape[1], 5))*thr)
+        data = xr.Dataset(self.data, coords=thr_xr)
+
+        if thr[4] == 1:
+            print('test running')
+            confidence = conf_xr.conf_test(data, band)
+
+        cmin = np.fmin(cmin, confidence)
+
+        return cmin
+
+
+# single_threshold_test('11BT', 'M15', cmin)
+# single_threshold_test('12-11BT', 'M16-M15', cmin)
+# single_threshold_test('12BT', 'M16', cmin)
+
+
 def single_threshold_test(test, rad, threshold):
 
     radshape = rad.shape
-- 
GitLab