diff --git a/mvcm/constants.py b/mvcm/constants.py
index 47a017c2e86216ccf532639ff4651de1013267e8..962ce2eef1a1a2d758a1768be9f95f6d5ff109c2 100644
--- a/mvcm/constants.py
+++ b/mvcm/constants.py
@@ -21,6 +21,7 @@ class ConstantsNamespace:
 class SensorConstants:
     """Sensor-dependent constants."""
 
+    MAX_VZA = 70.13
     VIIRS_VIS_BANDS = ("M01", "M02", "M03", "M04", "M05", "M06", "M07", "M08", "M09", "M10")
     VIIRS_IR_BANDS = ("M11", "M12", "M13", "M14", "M15", "M16")
     VIIRS_IMG_BANDS = ("I01", "I02", "I03", "I04", "I05")
diff --git a/mvcm/main_tests_only.py b/mvcm/main_tests_only.py
index df0e7bc7e70066750fc4161c62c544276a788bab..7a94f57a3e3a1778226b8e8660538aafb99c4b35 100644
--- a/mvcm/main_tests_only.py
+++ b/mvcm/main_tests_only.py
@@ -163,6 +163,7 @@ def main(
     viirs_data["hicut"] = np.full_like(
         viirs_data.latitude.shape, fill_value=np.nan, dtype=np.float32
     )
+    viirs_data["high_elevation"] = pixel_type.high_elevation
     logger.info(f"Memory usage #5: {proc.memory_info().rss / 1e6} MB")
     # scene_types = np.zeros(viirs_data.M11.shape, dtype=np.ubyte)
     for scene_name in SceneConstants.SCENE_LIST:
@@ -205,7 +206,9 @@ def main(
         cmin_g1, bits["01"] = my_scene.test_11um(m15_name, cmin_g1, bits["01"])
         # this test needs to be implemented properly,
         # for now is commented since it's only used for night granules
-        # cmin_g1, bits["02"] = my_scene.surface_temperature_test("", viirs_data, cmin_g1, bits["02"])
+        cmin_g1, bits["02"] = my_scene.surface_temperature_test(
+            "M15", viirs_data, cmin_g1, bits["02"]
+        )
         cmin_g1, bits["03"] = my_scene.sst_test("M15", "M16", cmin_g1, bits["03"])
 
         # Group 2
diff --git a/mvcm/preprocess_thresholds.py b/mvcm/preprocess_thresholds.py
index 8339b0e4ce46e1caec5210de4427012f4a35b1da..dc0f1d777607451c214ca061ce67c021b96a4178 100644
--- a/mvcm/preprocess_thresholds.py
+++ b/mvcm/preprocess_thresholds.py
@@ -236,8 +236,12 @@ def thresholds_11_12um(
         err_msg = "Scene name not valid"
         raise ValueError(err_msg)
 
-    thr_out = np.dstack((locut, midpt, hicut, np.ones(locut.shape), np.ones(locut.shape)))
-    return np.squeeze(thr_out.T)
+    thr_out = np.squeeze(
+        np.dstack((locut, midpt, hicut, np.ones(locut.shape), np.ones(locut.shape)))
+    )
+    if thr_out.ndim == 1:
+        thr_out = np.reshape(thr_out, (1, thr_out.shape[0]))
+    return thr_out.T
 
 
 def thresholds_NIR(data, thresholds, scene, test_name, scene_idx):
@@ -301,37 +305,54 @@ def thresholds_NIR(data, thresholds, scene, test_name, scene_idx):
 def thresholds_surface_temperature(data, thresholds, scene_idx):
     """Compute thresholds for surface temperature test."""
     # def preproc_surf_temp(data, thresholds):
-    thr_sfc1 = thresholds["desert_thr"]
-    thr_sfc2 = thresholds["regular_thr"]
+    desert_thr = thresholds["desert_thr"]
+    regular_thr = thresholds["regular_thr"]
     thr_df1 = thresholds["channel_diff_11-12um_thr"]
     thr_df2 = thresholds["channel_diff_11-4um_thr"]
     max_vza = 70.13  # This values is set based on sensor.
     #                  Check mask_processing_constants.h for MODIS value
 
-    df1 = (data.M15 - data.M16).values[scene_idx].ravel()
-    df2 = (data.M15 - data.M13).values[scene_idx].ravel()
+    diff_11_12um = (data.M15 - data.M16).values[scene_idx].ravel()
+    diff_11_4um = (data.M15 - data.M12).values[scene_idx].ravel()
     desert_flag = data.Desert.values[scene_idx].ravel()
-    thresh = np.ones(df1.shape) * thr_sfc1
+    thresh = np.full(diff_11_12um.shape, desert_thr)
 
-    idx = np.where(
-        (df1 >= thr_df1[0]) | ((df1 < thr_df1[0]) & ((df2 <= thr_df2[0]) | (df2 >= thr_df2[1])))
+    # idx = np.nonzero(desert_flag == 1)
+    # thresh[idx] = desert_thr
+    idx = np.nonzero(
+        (desert_flag == 0)
+        & (
+            (diff_11_12um >= thr_df1[0])
+            | (
+                (diff_11_12um < thr_df1[0])
+                & ((diff_11_4um <= thr_df2[0]) | (diff_11_4um >= thr_df2[1]))
+            )
+        )
     )
-    thresh[idx] = thr_sfc2
-    idx = np.where(desert_flag == 1)
-    thresh[idx] = thr_sfc1
+    thresh[idx] = regular_thr
 
+    # thresh = 0
+
+    # # midpt = thresh
+    # midpt = np.where(diff_11_12um >= thr_df1[1], thresh + (2.0 * np.round(diff_11_12um)), thresh)
+    # # midpt[idx] = thresh[idx] + (2.0 * np.round(df1[idx]))
     midpt = thresh
-    idx = np.where(df1 >= thr_df1[1])
-    midpt[idx] = thresh[idx] + 2.0 * df1[idx]
+    idx = np.nonzero(diff_11_12um >= thr_df1[1])
+    midpt[idx] = thresh[idx] + (2.0 * np.round(diff_11_12um[idx]))
 
     corr = np.power(data.sensor_zenith.values[scene_idx].ravel() / max_vza, 4) * 3.0
     midpt = midpt + corr
     locut = midpt + 2.0
     hicut = midpt - 2.0
 
-    thr_out = np.dstack((locut, midpt, hicut, np.ones(locut.shape), np.ones(locut.shape)))
+    thr_out = np.squeeze(
+        np.dstack((locut, midpt, hicut, np.ones(locut.shape), np.ones(locut.shape)))
+    )
+
+    if thr_out.ndim == 1:
+        thr_out = np.reshape(thr_out, (1, thr_out.shape[0]))
 
-    return np.squeeze(thr_out.T)
+    return thr_out.T
 
 
 # This function is currently not used
@@ -589,9 +610,11 @@ def polar_night_thresholds(data, thresholds, scene, test_name, scene_idx):
     # out_thr = xr.DataArray(data=np.dstack((locut, midpt, hicut,
     #                                        np.ones(data.ndvi.shape), power)),
     #                        dims=('number_of_lines', 'number_of_pixels', 'z'))
-    out_thr = np.dstack((locut, midpt, hicut, np.ones(locut.shape), power))
+    out_thr = np.squeeze(np.dstack((locut, midpt, hicut, np.ones(locut.shape), power)))
+    if out_thr.ndim == 1:
+        out_thr = np.reshape(out_thr, (1, out_thr.shape[0]))
 
-    return np.squeeze(out_thr.T)
+    return out_thr.T
 
 
 # get_nl_thresholds
@@ -632,6 +655,10 @@ def land_night_thresholds(m15_m16, geos_tpw, threshold, coast=True):
                 power.ravel(),
             )
         )
+
+        if out_thr.ndim == 1:
+            out_thr = np.reshape(out_thr, (1, out_thr.shape[0]))
+
         return np.squeeze(out_thr.T)
     else:
         b0 = threshold["coeffs"][0]
@@ -736,7 +763,7 @@ def vis_refl_thresholds(data, thresholds, scene, scene_idx):
     #                        dims=('number_of_lines', 'number_of_pixels', 'z'))
     # out_rad = xr.DataArray(data=m128.reshape(data.M01.shape),
     #                        dims=('number_of_lines', 'number_of_pixels'))
-    out_thr = np.dstack((b1_locut, b1_midpt, b1_hicut, b1_power))
+    out_thr = np.squeeze(np.dstack((b1_locut, b1_midpt, b1_hicut, b1_power)))
 
     out_rad = m128
 
@@ -744,7 +771,10 @@ def vis_refl_thresholds(data, thresholds, scene, scene_idx):
     # data.midpt.values[scene_idx] = b1_midpt
     # data.hicut.values[scene_idx] = b1_hicut
 
-    return np.squeeze(out_thr.T), out_rad
+    if out_thr.ndim == 1:
+        out_thr = np.reshape(out_thr, (1, out_thr.shape[0]))
+
+    return out_thr.T, out_rad
 
 
 def gemi_thresholds(data, thresholds, scene_name, scene_idx):
@@ -782,19 +812,24 @@ def bt_diff_11_4um_thresholds(data, threshold, scene_idx):
     locut0 = hicut0 + threshold["locut_coeff"][0]
     locut1 = hicut1 + threshold["locut_coeff"][1]
 
-    thr_out = np.dstack(
-        [
-            locut0,
-            midpt0,
-            hicut0,
-            hicut1,
-            midpt1,
-            locut1,
-            np.ones(hicut0.shape),
-            np.ones(hicut0.shape),
-        ]
+    thr_out = np.squeeze(
+        np.dstack(
+            [
+                locut0,
+                midpt0,
+                hicut0,
+                hicut1,
+                midpt1,
+                locut1,
+                np.ones(hicut0.shape),
+                np.ones(hicut0.shape),
+            ]
+        )
     )
-    return np.squeeze(thr_out.T)
+    if thr_out.ndim == 1:
+        thr_out = np.reshape(thr_out, (1, thr_out.shape[0]))
+
+    return thr_out.T
 
 
 def thresholds_1_38um_test(data, thresholds, scene_name, scene_idx):
@@ -829,9 +864,14 @@ def thresholds_1_38um_test(data, thresholds, scene_name, scene_idx):
     #                                        np.ones(data.ndvi.shape),
     #                                        np.ones(data.ndvi.shape))),
     #                        dims=('number_of_lines', 'number_of_pixels', 'z'))
-    out_thr = np.dstack((locut, midpt, hicut, np.ones(locut.shape), np.ones(locut.shape)))
+    out_thr = np.squeeze(
+        np.dstack((locut, midpt, hicut, np.ones(locut.shape), np.ones(locut.shape)))
+    )
+
+    if out_thr.ndim == 1:
+        out_thr = np.reshape(out_thr, (1, out_thr.shape[0]))
 
-    return np.squeeze(out_thr.T)
+    return out_thr.T
 
 
 def xr_thresholds_1_38um_test(data: xr.Dataset, thresholds: dict, scene_name: str):
diff --git a/mvcm/scene.py b/mvcm/scene.py
index 97eacd204ad3a237e38849b77205b14cfa7e98dd..ff5808ed034eb2c0697cd8af3a69de93002b63a2 100644
--- a/mvcm/scene.py
+++ b/mvcm/scene.py
@@ -612,7 +612,8 @@ def scene_id(
         (scene_flags.water == 1)
         & (scene_flags.night == 1)
         & (scene_flags.polar == 0)
-        & ((scene_flags.ice == 0) | (scene_flags.snow == 0)),
+        & (scene_flags.ice == 0),
+        # & (scene_flags.snow == 0),
         1,
         0,
     )
diff --git a/mvcm/spectral_tests.py b/mvcm/spectral_tests.py
index 0499906062e28743de50bbf5335a072a94cb6b36..541fcd37f3da415ca68470c3f47a3d3b3d0665be 100644
--- a/mvcm/spectral_tests.py
+++ b/mvcm/spectral_tests.py
@@ -15,7 +15,7 @@ from numpy.lib.stride_tricks import sliding_window_view
 import mvcm.preprocess_thresholds as preproc
 import mvcm.scene as scn
 from mvcm import conf, restoral
-from mvcm.constants import SceneConstants
+from mvcm.constants import SceneConstants, SensorConstants
 
 _DTR = np.pi / 180
 
@@ -249,25 +249,79 @@ class CloudTests:
     def surface_temperature_test(
         self, band: str, viirs_data: xr.Dataset, cmin: np.ndarray, bits: dict, **kwargs
     ) -> tuple[np.ndarray, dict]:
-        """Perform surface temperature test over land."""
+        """Perform surface temperature test over land at night."""
         threshold = kwargs["thresholds"]
 
         if threshold["perform"] is True and self.pixels_in_scene is True:
-            bits["qa"][self.scene_idx] = 1
             logger.info(f'Running test Surface_Temperature_Test on "{self.scene_name}"\n')
-            # print(f'Testing "{self.scene_name}"\n')
-            if self.scene_name in ["Land_Night"]:
-                pass
-            rad = self.data[band].values[self.scene_idx]
-            sfcdif = viirs_data.geos_sfct.values[self.scene_idx] - rad
-            thr = preproc.thresholds_surface_temperature(viirs_data, threshold, self.scene_idx)
 
-            # idx = np.nonzero((sfcdif < thr[1, :]) & (self.data[self.scene_name] == 1))
-            # kwargs['test_bit'][idx] = 1
-            if np.ndim(thr) == 1:
-                thr = thr.reshape(thr.shape[0], 1)
-            bits["test"][self.scene_idx] = self.compute_test_bits(sfcdif, thr[:, 1], "<")
-            kwargs["confidence"][self.scene_idx] = conf.conf_test_new(sfcdif, thr)
+            # diff_11_12um = self.data["M15"].values - self.data["M16"].values
+            # diff_11_4um = self.data["M15"].values - self.data["M12"].values
+            #
+            scene_idx = np.where(
+                (self.data[self.scene_name] == 1)
+                & (self.data.high_elevation == 0)
+                & (self.data["ecosystem"] != SceneConstants.ECO_BARE_DESERT)
+            )
+
+            # bits["qa"][scene_idx] = 1
+            # thr = np.full(self.data.M15.shape, threshold["desert_thr"])
+            #
+            # desert_idx = np.where(
+            #     (self.data[self.scene_name] == 1)
+            #     & (self.data["ecosystem"] != SceneConstants.ECO_BARE_DESERT)
+            #     & (self.data["Desert"] == 1)
+            # )
+            #
+            # non_desert_idx = np.where(
+            #     (self.data[self.scene_name] == 1)
+            #     & (self.data["ecosystem"] != SceneConstants.ECO_BARE_DESERT)
+            #     & (self.data["Desert"] == 0)
+            #     & (
+            #         (diff_11_12um >= threshold["channel_diff_11-12um_thr"][0])
+            #         | (
+            #             (diff_11_12um < threshold["channel_diff_11-12um_thr"][0])
+            #             & (
+            #                 (diff_11_4um > threshold["channel_diff_11-4um_thr"][0])
+            #                 | (diff_11_4um < threshold["channel_diff_11-4um_thr"][0])
+            #             )
+            #         )
+            #     )
+            # )
+
+            # thr[non_desert_idx] = threshold["regular_thr"]
+            # thr[desert_idx] = threshold["desert_thr"]
+            #
+            # midpt = thr
+            # idx = np.where(diff_11_12um >= threshold["channel_diff_11-12um_thr"][0])
+            # midpt[idx] = thr[idx] + (2.0 * diff_11_12um[idx].astype("int32"))
+
+            sfc_thresholds = preproc.thresholds_surface_temperature(
+                viirs_data, threshold, scene_idx
+            )
+
+            # correction = (
+            #     np.power(viirs_data.sensor_zenith.values[scene_idx] / SensorConstants.MAX_VZA, 4.0)
+            #     * 3.0
+            # )
+            # midpt = midpt[scene_idx] + correction
+            # locut = midpt + 2.0
+            # hicut = midpt - 2.0
+            surface_diff_temp = (
+                viirs_data.geos_surface_temperature.values[scene_idx]
+                - viirs_data.M15.values[scene_idx]
+            )
+            # sfc_thresholds = np.dstack(
+            #     (locut, midpt, hicut, np.ones(midpt.shape), np.ones(midpt.shape))
+            # )
+
+            logger.info(f"threshold size: {sfc_thresholds.shape}")
+            logger.info(f"diff_temp size: {surface_diff_temp.shape}")
+
+            bits["test"][scene_idx] = self.compute_test_bits(
+                surface_diff_temp, sfc_thresholds[1, :], "<", scene_idx=scene_idx
+            )
+            kwargs["confidence"][scene_idx] = conf.conf_test_new(surface_diff_temp, sfc_thresholds)
 
         cmin = np.fmin(cmin, kwargs["confidence"])
 
@@ -286,7 +340,6 @@ class CloudTests:
             m31 = self.data[band31].values - 273.16
             m32 = self.data[band32].values - 273.16
             bt_diff = m31 - m32
-            # sst = self.data.sst.values - 273.16
             sst = self.data.geos_sst.values - 273.16
             cosvza = np.cos(self.data.sensor_zenith.values * _DTR)