diff --git a/modules/deeplearning/unet.py b/modules/deeplearning/unet.py
index a2c34494ebb13bd3757e1a1c11b279c4274277b4..b6230fd94ac98fe45a806cce912375bd0e6d7a91 100644
--- a/modules/deeplearning/unet.py
+++ b/modules/deeplearning/unet.py
@@ -11,6 +11,7 @@ import h5py
 
 # L1B M/I-bands: /apollo/cloud/scratch/cwhite/VIIRS_HRES/2019/2019_01_01/
 # CLAVRx: /apollo/cloud/scratch/Satellite_Output/VIIRS_HRES/2019/2019_01_01/
+# /apollo/cloud/scratch/Satellite_Output/andi/NEW/VIIRS_HRES/2019
 
 LOG_DEVICE_PLACEMENT = False
 
@@ -24,7 +25,7 @@ else:
     NumLogits = NumClasses
 
 BATCH_SIZE = 128
-NUM_EPOCHS = 100
+NUM_EPOCHS = 40
 
 TRACK_MOVING_AVERAGE = False
 EARLY_STOP = True
@@ -207,15 +208,20 @@ class UNET:
         self.test_data_files = None
         self.test_label_files = None
 
-        # n_chans = len(self.train_params)
-        n_chans = 3
+        self.train_data_nda = None
+        self.train_label_nda = None
+        self.test_data_nda = None
+        self.test_label_nda = None
+
+        # self.n_chans = len(self.train_params)
+        self.n_chans = 1
         if TRIPLET:
-            n_chans *= 3
-        self.X_img = tf.keras.Input(shape=(None, None, n_chans))
+            self.n_chans *= 3
+        self.X_img = tf.keras.Input(shape=(None, None, self.n_chans))
 
         self.inputs.append(self.X_img)
-        #self.inputs.append(tf.keras.Input(shape=(None, None, 5)))
-        self.inputs.append(tf.keras.Input(shape=(None, None, 3)))
+        # self.inputs.append(tf.keras.Input(shape=(None, None, 5)))
+        self.inputs.append(tf.keras.Input(shape=(None, None, 1)))
 
         self.flight_level = 0
 
@@ -294,7 +300,7 @@ class UNET:
     #
     #     return data, data_alt, label
 
-    def get_in_mem_data_batch(self, idxs, is_training):
+    def get_in_mem_data_batch_salt(self, idxs, is_training):
         data = []
         for k in idxs:
             if is_training:
@@ -327,6 +333,33 @@ class UNET:
 
         return data, data, label
 
+    def get_in_mem_data_batch(self, idxs, is_training):
+        if is_training:
+            data = self.train_data_nda[idxs,]
+            data = np.expand_dims(data, axis=3)
+            label = self.train_label_nda[idxs,]
+            label = np.expand_dims(label, axis=3)
+        else:
+            data = self.test_data_nda[idxs,]
+            data = np.expand_dims(data, axis=3)
+            label = self.test_label_nda[idxs,]
+            label = np.expand_dims(label, axis=3)
+
+        data = data.astype(np.float32)
+        label = label.astype(np.float32)
+
+        if is_training and DO_AUGMENT:
+            data_ud = np.flip(data, axis=1)
+            label_ud = np.flip(label, axis=1)
+
+            data_lr = np.flip(data, axis=2)
+            label_lr = np.flip(label, axis=2)
+
+            data = np.concatenate([data, data_ud, data_lr])
+            label = np.concatenate([label, label_ud, label_lr])
+
+        return data, data, label
+
     # def get_parameter_data(self, param, nd_idxs, is_training):
     #     if is_training:
     #         if param in self.train_params_l1b:
@@ -386,12 +419,12 @@ class UNET:
 
     @tf.function(input_signature=[tf.TensorSpec(None, tf.int32)])
     def data_function(self, indexes):
-        out = tf.numpy_function(self.get_in_mem_data_batch_train, [indexes], [tf.float32, tf.float32, tf.int32])
+        out = tf.numpy_function(self.get_in_mem_data_batch_train, [indexes], [tf.float32, tf.float32, tf.float32])
         return out
 
     @tf.function(input_signature=[tf.TensorSpec(None, tf.int32)])
     def data_function_test(self, indexes):
-        out = tf.numpy_function(self.get_in_mem_data_batch_test, [indexes], [tf.float32, tf.float32, tf.int32])
+        out = tf.numpy_function(self.get_in_mem_data_batch_test, [indexes], [tf.float32, tf.float32, tf.float32])
         return out
 
     @tf.function(input_signature=[tf.TensorSpec(None, tf.int32)])
@@ -428,49 +461,77 @@ class UNET:
         dataset = dataset.map(self.data_function_evaluate, num_parallel_calls=8)
         self.eval_dataset = dataset
 
-    def setup_pipeline(self, filename_l1b_trn, filename_l1b_tst, filename_l2_trn, filename_l2_tst, trn_idxs=None, tst_idxs=None, seed=None):
-        if filename_l1b_trn is not None:
-            self.h5f_l1b_trn = h5py.File(filename_l1b_trn, 'r')
-        if filename_l1b_tst is not None:
-            self.h5f_l1b_tst = h5py.File(filename_l1b_tst, 'r')
-        if filename_l2_trn is not None:
-            self.h5f_l2_trn = h5py.File(filename_l2_trn, 'r')
-        if filename_l2_tst is not None:
-            self.h5f_l2_tst = h5py.File(filename_l2_tst, 'r')
-
-        if trn_idxs is None:
-            # Note: time is same across both L1B and L2 for idxs
-            if self.h5f_l1b_trn is not None:
-                h5f = self.h5f_l1b_trn
-            else:
-                h5f = self.h5f_l2_trn
-            time = h5f['time']
-            trn_idxs = np.arange(time.shape[0])
-            if seed is not None:
-                np.random.seed(seed)
-            np.random.shuffle(trn_idxs)
-
-            if self.h5f_l1b_tst is not None:
-                h5f = self.h5f_l1b_tst
-            else:
-                h5f = self.h5f_l2_tst
-            time = h5f['time']
-            tst_idxs = np.arange(time.shape[0])
-            if seed is not None:
-                np.random.seed(seed)
-            np.random.shuffle(tst_idxs)
-
-        self.num_data_samples = trn_idxs.shape[0]
+    # def setup_pipeline(self, filename_l1b_trn, filename_l1b_tst, filename_l2_trn, filename_l2_tst, trn_idxs=None, tst_idxs=None, seed=None):
+    #     if filename_l1b_trn is not None:
+    #         self.h5f_l1b_trn = h5py.File(filename_l1b_trn, 'r')
+    #     if filename_l1b_tst is not None:
+    #         self.h5f_l1b_tst = h5py.File(filename_l1b_tst, 'r')
+    #     if filename_l2_trn is not None:
+    #         self.h5f_l2_trn = h5py.File(filename_l2_trn, 'r')
+    #     if filename_l2_tst is not None:
+    #         self.h5f_l2_tst = h5py.File(filename_l2_tst, 'r')
+    #
+    #     if trn_idxs is None:
+    #         # Note: time is same across both L1B and L2 for idxs
+    #         if self.h5f_l1b_trn is not None:
+    #             h5f = self.h5f_l1b_trn
+    #         else:
+    #             h5f = self.h5f_l2_trn
+    #         time = h5f['time']
+    #         trn_idxs = np.arange(time.shape[0])
+    #         if seed is not None:
+    #             np.random.seed(seed)
+    #         np.random.shuffle(trn_idxs)
+    #
+    #         if self.h5f_l1b_tst is not None:
+    #             h5f = self.h5f_l1b_tst
+    #         else:
+    #             h5f = self.h5f_l2_tst
+    #         time = h5f['time']
+    #         tst_idxs = np.arange(time.shape[0])
+    #         if seed is not None:
+    #             np.random.seed(seed)
+    #         np.random.shuffle(tst_idxs)
+    #
+    #     self.num_data_samples = trn_idxs.shape[0]
+    #
+    #     self.get_train_dataset(trn_idxs)
+    #     self.get_test_dataset(tst_idxs)
+    #
+    #     print('datetime: ', now)
+    #     print('training and test data: ')
+    #     print(filename_l1b_trn)
+    #     print(filename_l1b_tst)
+    #     print(filename_l2_trn)
+    #     print(filename_l2_tst)
+    #     print('---------------------------')
+    #     print('num train samples: ', self.num_data_samples)
+    #     print('BATCH SIZE: ', BATCH_SIZE)
+    #     print('num test samples: ', tst_idxs.shape[0])
+    #     print('setup_pipeline: Done')
+
+    def setup_pipeline(self, data_nda, label_nda, perc=0.20):
+
+        num_samples = data_nda.shape[0]
+        num_test = int(num_samples * perc)
+        self.num_data_samples = num_samples - num_test
+        num_train = self.num_data_samples
+
+        self.train_data_nda = data_nda[0:num_train]
+        self.train_label_nda = label_nda[0:num_train]
+        self.test_data_nda = data_nda[num_train:]
+        self.test_label_nda = label_nda[num_train:]
+
+        trn_idxs = np.arange(self.train_data_nda.shape[0])
+        tst_idxs = np.arange(self.test_data_nda.shape[0])
+
+        np.random.shuffle(tst_idxs)
 
         self.get_train_dataset(trn_idxs)
         self.get_test_dataset(tst_idxs)
 
         print('datetime: ', now)
         print('training and test data: ')
-        print(filename_l1b_trn)
-        print(filename_l1b_tst)
-        print(filename_l2_trn)
-        print(filename_l2_tst)
         print('---------------------------')
         print('num train samples: ', self.num_data_samples)
         print('BATCH SIZE: ', BATCH_SIZE)
@@ -560,7 +621,7 @@ class UNET:
         momentum = 0.99
 
         # num_filters = len(self.train_params) * 4
-        num_filters = 3 * 4
+        num_filters = self.n_chans * 4
 
         input_2d = self.inputs[0]
         conv = tf.keras.layers.Conv2D(num_filters, kernel_size=5, strides=1, padding=padding, activation=None)(input_2d)
@@ -662,10 +723,14 @@ class UNET:
         conv = tf.keras.layers.Conv2DTranspose(num_filters, kernel_size=3, strides=2, padding=padding)(conv)
         print('8: ', conv.shape)
 
-        if NumClasses == 2:
-            activation = tf.nn.sigmoid  # For binary
-        else:
-            activation = tf.nn.softmax  # For multi-class
+        conv = tf.keras.layers.Conv2DTranspose(num_filters, kernel_size=3, strides=2, padding=padding)(conv)
+        print('9: ', conv.shape)
+
+        # if NumClasses == 2:
+        #     activation = tf.nn.sigmoid  # For binary
+        # else:
+        #     activation = tf.nn.softmax  # For multi-class
+        activation = tf.nn.sigmoid
 
         # Called logits, but these are actually probabilities, see activation
         self.logits = tf.keras.layers.Conv2D(1, kernel_size=1, strides=1, padding=padding, name='probability', activation=activation)(conv)
@@ -673,10 +738,11 @@ class UNET:
         print(self.logits.shape)
 
     def build_training(self):
-        if NumClasses == 2:
-            self.loss = tf.keras.losses.BinaryCrossentropy(from_logits=False)  # for two-class only
-        else:
-            self.loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False)  # For multi-class
+        # if NumClasses == 2:
+        #     self.loss = tf.keras.losses.BinaryCrossentropy(from_logits=False)  # for two-class only
+        # else:
+        #     self.loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False)  # For multi-class
+        self.loss = tf.keras.losses.MeanSquaredError()  # Regression
 
         # decayed_learning_rate = learning_rate * decay_rate ^ (global_step / decay_steps)
         initial_learning_rate = 0.006
@@ -698,22 +764,26 @@ class UNET:
         self.initial_learning_rate = initial_learning_rate
 
     def build_evaluation(self):
+        #self.train_loss = tf.keras.metrics.Mean(name='train_loss')
+        #self.test_loss = tf.keras.metrics.Mean(name='test_loss')
+        self.train_accuracy = tf.keras.metrics.MeanAbsoluteError(name='train_accuracy')
+        self.test_accuracy = tf.keras.metrics.MeanAbsoluteError(name='test_accuracy')
         self.train_loss = tf.keras.metrics.Mean(name='train_loss')
         self.test_loss = tf.keras.metrics.Mean(name='test_loss')
 
-        if NumClasses == 2:
-            self.train_accuracy = tf.keras.metrics.BinaryAccuracy(name='train_accuracy')
-            self.test_accuracy = tf.keras.metrics.BinaryAccuracy(name='test_accuracy')
-            self.test_auc = tf.keras.metrics.AUC(name='test_auc')
-            self.test_recall = tf.keras.metrics.Recall(name='test_recall')
-            self.test_precision = tf.keras.metrics.Precision(name='test_precision')
-            self.test_true_neg = tf.keras.metrics.TrueNegatives(name='test_true_neg')
-            self.test_true_pos = tf.keras.metrics.TruePositives(name='test_true_pos')
-            self.test_false_neg = tf.keras.metrics.FalseNegatives(name='test_false_neg')
-            self.test_false_pos = tf.keras.metrics.FalsePositives(name='test_false_pos')
-        else:
-            self.train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')
-            self.test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='test_accuracy')
+        # if NumClasses == 2:
+        #     self.train_accuracy = tf.keras.metrics.BinaryAccuracy(name='train_accuracy')
+        #     self.test_accuracy = tf.keras.metrics.BinaryAccuracy(name='test_accuracy')
+        #     self.test_auc = tf.keras.metrics.AUC(name='test_auc')
+        #     self.test_recall = tf.keras.metrics.Recall(name='test_recall')
+        #     self.test_precision = tf.keras.metrics.Precision(name='test_precision')
+        #     self.test_true_neg = tf.keras.metrics.TrueNegatives(name='test_true_neg')
+        #     self.test_true_pos = tf.keras.metrics.TruePositives(name='test_true_pos')
+        #     self.test_false_neg = tf.keras.metrics.FalseNegatives(name='test_false_neg')
+        #     self.test_false_pos = tf.keras.metrics.FalsePositives(name='test_false_pos')
+        # else:
+        #     self.train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')
+        #     self.test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='test_accuracy')
 
     @tf.function
     def train_step(self, mini_batch):
@@ -745,14 +815,14 @@ class UNET:
 
         self.test_loss(t_loss)
         self.test_accuracy(labels, pred)
-        if NumClasses == 2:
-            self.test_auc(labels, pred)
-            self.test_recall(labels, pred)
-            self.test_precision(labels, pred)
-            self.test_true_neg(labels, pred)
-            self.test_true_pos(labels, pred)
-            self.test_false_neg(labels, pred)
-            self.test_false_pos(labels, pred)
+        # if NumClasses == 2:
+        #     self.test_auc(labels, pred)
+        #     self.test_recall(labels, pred)
+        #     self.test_precision(labels, pred)
+        #     self.test_true_neg(labels, pred)
+        #     self.test_true_pos(labels, pred)
+        #     self.test_false_neg(labels, pred)
+        #     self.test_false_pos(labels, pred)
 
     def predict(self, mini_batch):
         inputs = [mini_batch[0], mini_batch[1]]
@@ -765,26 +835,26 @@ class UNET:
 
         self.test_loss(t_loss)
         self.test_accuracy(labels, pred)
-        if NumClasses == 2:
-            self.test_auc(labels, pred)
-            self.test_recall(labels, pred)
-            self.test_precision(labels, pred)
-            self.test_true_neg(labels, pred)
-            self.test_true_pos(labels, pred)
-            self.test_false_neg(labels, pred)
-            self.test_false_pos(labels, pred)
+        # if NumClasses == 2:
+        #     self.test_auc(labels, pred)
+        #     self.test_recall(labels, pred)
+        #     self.test_precision(labels, pred)
+        #     self.test_true_neg(labels, pred)
+        #     self.test_true_pos(labels, pred)
+        #     self.test_false_neg(labels, pred)
+        #     self.test_false_pos(labels, pred)
 
     def reset_test_metrics(self):
         self.test_loss.reset_states()
         self.test_accuracy.reset_states()
-        if NumClasses == 2:
-            self.test_auc.reset_states()
-            self.test_recall.reset_states()
-            self.test_precision.reset_states()
-            self.test_true_neg.reset_states()
-            self.test_true_pos.reset_states()
-            self.test_false_neg.reset_states()
-            self.test_false_pos.reset_states()
+        # if NumClasses == 2:
+        #     self.test_auc.reset_states()
+        #     self.test_recall.reset_states()
+        #     self.test_precision.reset_states()
+        #     self.test_true_neg.reset_states()
+        #     self.test_true_pos.reset_states()
+        #     self.test_false_neg.reset_states()
+        #     self.test_false_pos.reset_states()
 
     def get_metrics(self):
         recall = self.test_recall.result()
@@ -858,20 +928,20 @@ class UNET:
                             for mini_batch_test in tst_ds:
                                 self.test_step(mini_batch_test)
 
-                        if NumClasses == 2:
-                            f1, mcc = self.get_metrics()
+                        # if NumClasses == 2:
+                        #     f1, mcc = self.get_metrics()
 
                         with self.writer_valid.as_default():
                             tf.summary.scalar('loss_val', self.test_loss.result(), step=step)
                             tf.summary.scalar('acc_val', self.test_accuracy.result(), step=step)
-                            if NumClasses == 2:
-                                tf.summary.scalar('auc_val', self.test_auc.result(), step=step)
-                                tf.summary.scalar('recall_val', self.test_recall.result(), step=step)
-                                tf.summary.scalar('prec_val', self.test_precision.result(), step=step)
-                                tf.summary.scalar('f1_val', f1, step=step)
-                                tf.summary.scalar('mcc_val', mcc, step=step)
-                                tf.summary.scalar('num_train_steps', step, step=step)
-                                tf.summary.scalar('num_epochs', epoch, step=step)
+                            # if NumClasses == 2:
+                            #     tf.summary.scalar('auc_val', self.test_auc.result(), step=step)
+                            #     tf.summary.scalar('recall_val', self.test_recall.result(), step=step)
+                            #     tf.summary.scalar('prec_val', self.test_precision.result(), step=step)
+                            #     tf.summary.scalar('f1_val', f1, step=step)
+                            #     tf.summary.scalar('mcc_val', mcc, step=step)
+                            #     tf.summary.scalar('num_train_steps', step, step=step)
+                            #     tf.summary.scalar('num_epochs', epoch, step=step)
 
                         with self.writer_train_valid_loss.as_default():
                             tf.summary.scalar('loss_trn', loss.numpy(), step=step)
@@ -898,24 +968,25 @@ class UNET:
                 for mini_batch in ds:
                     self.test_step(mini_batch)
 
-            if NumClasses == 2:
-                f1, mcc = self.get_metrics()
-                print('loss, acc, recall, precision, auc, f1, mcc: ', self.test_loss.result().numpy(), self.test_accuracy.result().numpy(),
-                      self.test_recall.result().numpy(), self.test_precision.result().numpy(), self.test_auc.result().numpy(), f1.numpy(), mcc.numpy())
-            else:
-                print('loss, acc: ', self.test_loss.result().numpy(), self.test_accuracy.result().numpy())
+            print('loss, acc: ', self.test_loss.result().numpy(), self.test_accuracy.result().numpy())
+            # if NumClasses == 2:
+            #     f1, mcc = self.get_metrics()
+            #     print('loss, acc, recall, precision, auc, f1, mcc: ', self.test_loss.result().numpy(), self.test_accuracy.result().numpy(),
+            #           self.test_recall.result().numpy(), self.test_precision.result().numpy(), self.test_auc.result().numpy(), f1.numpy(), mcc.numpy())
+            # else:
+            #     print('loss, acc: ', self.test_loss.result().numpy(), self.test_accuracy.result().numpy())
             print('------------------------------------------------------')
 
             tst_loss = self.test_loss.result().numpy()
             if tst_loss < best_test_loss:
                 best_test_loss = tst_loss
-                if NumClasses == 2:
-                    best_test_acc = self.test_accuracy.result().numpy()
-                    best_test_recall = self.test_recall.result().numpy()
-                    best_test_precision = self.test_precision.result().numpy()
-                    best_test_auc = self.test_auc.result().numpy()
-                    best_test_f1 = f1.numpy()
-                    best_test_mcc = mcc.numpy()
+                # if NumClasses == 2:
+                #     best_test_acc = self.test_accuracy.result().numpy()
+                #     best_test_recall = self.test_recall.result().numpy()
+                #     best_test_precision = self.test_precision.result().numpy()
+                #     best_test_auc = self.test_auc.result().numpy()
+                #     best_test_f1 = f1.numpy()
+                #     best_test_mcc = mcc.numpy()
 
                 ckpt_manager.save()
 
@@ -941,9 +1012,9 @@ class UNET:
         if self.h5f_l2_tst is not None:
             self.h5f_l2_tst.close()
 
-        f = open(home_dir+'/best_stats_'+now+'.pkl', 'wb')
-        pickle.dump((best_test_loss, best_test_acc, best_test_recall, best_test_precision, best_test_auc, best_test_f1, best_test_mcc), f)
-        f.close()
+        # f = open(home_dir+'/best_stats_'+now+'.pkl', 'wb')
+        # pickle.dump((best_test_loss, best_test_acc, best_test_recall, best_test_precision, best_test_auc, best_test_f1, best_test_mcc), f)
+        # f.close()
 
     def build_model(self):
         self.build_unet()
@@ -1008,13 +1079,20 @@ class UNET:
         self.build_evaluation()
         self.do_training()
 
-    def run_test(self, data_path='/Users/tomrink/data/salt/train/images/', label_path='/Users/tomrink/data/salt/train/masks/'):
+    def run_salt(self, data_path='/Users/tomrink/data/salt/train/images/', label_path='/Users/tomrink/data/salt/train/masks/'):
         self.setup_salt_pipeline(data_path, label_path)
         self.build_model()
         self.build_training()
         self.build_evaluation()
         self.do_training()
 
+    def run_test(self, data_nda, label_nda):
+        self.setup_pipeline(data_nda, label_nda)
+        self.build_model()
+        self.build_training()
+        self.build_evaluation()
+        self.do_training()
+
     def run_restore_test(self, data_path='/Users/tomrink/data/salt/train/images/', label_path='/Users/tomrink/data/salt/train/masks/', ckpt_dir=None):
         #self.setup_salt_pipeline(data_path, label_path)
         self.setup_salt_restore()