diff --git a/modules/deeplearning/unet.py b/modules/deeplearning/unet.py
index 15a89cf256d6e614a0b7fd97b2aee219ed5cbf1a..84a91bbea2f829788c4f73d4f7f6d78f4a1c5e18 100644
--- a/modules/deeplearning/unet.py
+++ b/modules/deeplearning/unet.py
@@ -1,7 +1,8 @@
+import glob
 import tensorflow as tf
 from util.setup import logdir, modeldir, cachepath, now, ancillary_path
 from util.util import EarlyStop, normalize, make_for_full_domain_predict
-from util.geos_nav import get_navigation, get_lon_lat_2d_mesh
+import matplotlib.pyplot as plt
 
 import os, datetime
 import numpy as np
@@ -11,7 +12,7 @@ import h5py
 
 LOG_DEVICE_PLACEMENT = False
 
-PROC_BATCH_SIZE = 4096
+PROC_BATCH_SIZE = 1024
 PROC_BATCH_BUFFER_SIZE = 50000
 
 NumClasses = 2
@@ -199,15 +200,20 @@ class UNET:
         self.initial_learning_rate = None
 
         self.data_dct = None
+        self.train_data_files = None
+        self.train_label_files = None
+        self.test_data_files = None
+        self.test_label_files = None
 
-        n_chans = len(self.train_params)
+        # n_chans = len(self.train_params)
+        n_chans = 3
         if TRIPLET:
             n_chans *= 3
-        #self.X_img = tf.keras.Input(shape=(None, None, n_chans))
-        self.X_img = tf.keras.Input(shape=(16, 16, n_chans))
+        self.X_img = tf.keras.Input(shape=(64, 64, n_chans))
 
         self.inputs.append(self.X_img)
-        self.inputs.append(tf.keras.Input(shape=(None, None, 5)))
+        #self.inputs.append(tf.keras.Input(shape=(None, None, 5)))
+        self.inputs.append(tf.keras.Input(shape=(64, 64, 3)))
 
         self.flight_level = 0
 
@@ -236,120 +242,120 @@ class UNET:
         #         # Memory growth must be set before GPUs have been initialized
         #         print(e)
 
-    def get_in_mem_data_batch(self, idxs, is_training):
-
-        # sort these to use as numpy indexing arrays
-        nd_idxs = np.array(idxs)
-        nd_idxs = np.sort(nd_idxs)
+    # def get_in_mem_data_batch(self, idxs, is_training):
+    #
+    #     # sort these to use as numpy indexing arrays
+    #     nd_idxs = np.array(idxs)
+    #     nd_idxs = np.sort(nd_idxs)
+    #
+    #     data = []
+    #     for param in self.train_params:
+    #         nda = self.get_parameter_data(param, nd_idxs, is_training)
+    #         nda = normalize(nda, param, mean_std_dct)
+    #         if DO_ZERO_OUT and is_training:
+    #             try:
+    #                 zero_out_params.index(param)
+    #                 nda[:,] = 0.0
+    #             except ValueError:
+    #                 pass
+    #         data.append(nda)
+    #     data = np.stack(data)
+    #     data = data.astype(np.float32)
+    #     data = np.transpose(data, axes=(1, 2, 3, 0))
+    #
+    #     data_alt = self.get_scalar_data(nd_idxs, is_training)
+    #
+    #     label = self.get_label_data(nd_idxs, is_training)
+    #     label = np.where(label == -1, 0, label)
+    #
+    #     # binary, two class
+    #     if NumClasses == 2:
+    #         label = np.where(label != 0, 1, label)
+    #         label = label.reshape((label.shape[0], 1))
+    #     elif NumClasses == 3:
+    #         label = np.where(np.logical_or(label == 1, label == 2), 1, label)
+    #         label = np.where(np.invert(np.logical_or(label == 0, label == 1)), 2, label)
+    #         label = label.reshape((label.shape[0], 1))
+    #
+    #     if is_training and DO_AUGMENT:
+    #         data_ud = np.flip(data, axis=1)
+    #         data_alt_ud = np.copy(data_alt)
+    #         label_ud = np.copy(label)
+    #
+    #         data_lr = np.flip(data, axis=2)
+    #         data_alt_lr = np.copy(data_alt)
+    #         label_lr = np.copy(label)
+    #
+    #         data = np.concatenate([data, data_ud, data_lr])
+    #         data_alt = np.concatenate([data_alt, data_alt_ud, data_alt_lr])
+    #         label = np.concatenate([label, label_ud, label_lr])
+    #
+    #     return data, data_alt, label
 
+    def get_in_mem_data_batch(self, idxs, is_training):
         data = []
-        for param in self.train_params:
-            nda = self.get_parameter_data(param, nd_idxs, is_training)
-            # Manual Corruption Process. Better: see use of tf.keras.layers.GaussianNoise
-            # if NOISE_TRAINING and is_training:
-            #     nda = normalize(nda, param, mean_std_dct, add_noise=True, noise_scale=0.01, seed=42)
-            # else:
-            #     nda = normalize(nda, param, mean_std_dct)
-            nda = normalize(nda, param, mean_std_dct)
-            if DO_ZERO_OUT and is_training:
-                try:
-                    zero_out_params.index(param)
-                    nda[:,] = 0.0
-                except ValueError:
-                    pass
-            data.append(nda)
+        for k in idxs:
+            if is_training:
+                nda = plt.imread(self.train_data_files[k])
+            else:
+                nda = plt.imread(self.test_data_files[k])
+            data.append(nda[0:64, 0:64])
         data = np.stack(data)
         data = data.astype(np.float32)
-        data = np.transpose(data, axes=(1, 2, 3, 0))
-
-        data_alt = self.get_scalar_data(nd_idxs, is_training)
 
-        label = self.get_label_data(nd_idxs, is_training)
-        label = np.where(label == -1, 0, label)
-
-        # binary, two class
-        if NumClasses == 2:
-            label = np.where(label != 0, 1, label)
-            label = label.reshape((label.shape[0], 1))
-        elif NumClasses == 3:
-            label = np.where(np.logical_or(label == 1, label == 2), 1, label)
-            label = np.where(np.invert(np.logical_or(label == 0, label == 1)), 2, label)
-            label = label.reshape((label.shape[0], 1))
+        label = []
+        for k in idxs:
+            if is_training:
+                nda = plt.imread(self.train_label_files[k])
+            else:
+                nda = plt.imread(self.test_label_files[k])
+            label.append(nda[0:64, 0:64])
+        label = np.stack(label)
+        label = label.astype(np.int32)
 
         if is_training and DO_AUGMENT:
             data_ud = np.flip(data, axis=1)
-            data_alt_ud = np.copy(data_alt)
-            label_ud = np.copy(label)
+            label_ud = np.flip(label, axis=1)
 
             data_lr = np.flip(data, axis=2)
-            data_alt_lr = np.copy(data_alt)
-            label_lr = np.copy(label)
+            label_lr = np.flip(label, axis=2)
 
             data = np.concatenate([data, data_ud, data_lr])
-            data_alt = np.concatenate([data_alt, data_alt_ud, data_alt_lr])
             label = np.concatenate([label, label_ud, label_lr])
 
-        return data, data_alt, label
-
-    def get_parameter_data(self, param, nd_idxs, is_training):
-        if is_training:
-            if param in self.train_params_l1b:
-                h5f = self.h5f_l1b_trn
-            else:
-                h5f = self.h5f_l2_trn
-        else:
-            if param in self.train_params_l1b:
-                h5f = self.h5f_l1b_tst
-            else:
-                h5f = self.h5f_l2_tst
-
-        nda = h5f[param][nd_idxs,]
-        return nda
-
-    def get_scalar_data(self, nd_idxs, is_training):
-        param = 'flight_altitude'
-
-        if is_training:
-            if self.h5f_l1b_trn is not None:
-                h5f = self.h5f_l1b_trn
-            else:
-                h5f = self.h5f_l2_trn
-        else:
-            if self.h5f_l1b_tst is not None:
-                h5f = self.h5f_l1b_tst
-            else:
-                h5f = self.h5f_l2_tst
-
-        nda = h5f[param][nd_idxs,]
-
-        nda[np.logical_and(nda >= 0, nda < 2000)] = 0
-        nda[np.logical_and(nda >= 2000, nda < 4000)] = 1
-        nda[np.logical_and(nda >= 4000, nda < 6000)] = 2
-        nda[np.logical_and(nda >= 6000, nda < 8000)] = 3
-        nda[np.logical_and(nda >= 8000, nda < 15000)] = 4
-
-        nda = tf.one_hot(nda, 5).numpy()
-        nda = np.expand_dims(nda, axis=1)
-        nda = np.expand_dims(nda, axis=1)
-
-        return nda
-
-    def get_label_data(self, nd_idxs, is_training):
-        # Note: labels will be same for nd_idxs across both L1B and L2
-        if is_training:
-            if self.h5f_l1b_trn is not None:
-                h5f = self.h5f_l1b_trn
-            else:
-                h5f = self.h5f_l2_trn
-        else:
-            if self.h5f_l1b_tst is not None:
-                h5f = self.h5f_l1b_tst
-            else:
-                h5f = self.h5f_l2_tst
-
-        label = h5f['icing_intensity'][nd_idxs]
-        label = label.astype(np.int32)
-        return label
+        return data, data, label
+
+    # def get_parameter_data(self, param, nd_idxs, is_training):
+    #     if is_training:
+    #         if param in self.train_params_l1b:
+    #             h5f = self.h5f_l1b_trn
+    #         else:
+    #             h5f = self.h5f_l2_trn
+    #     else:
+    #         if param in self.train_params_l1b:
+    #             h5f = self.h5f_l1b_tst
+    #         else:
+    #             h5f = self.h5f_l2_tst
+    #
+    #     nda = h5f[param][nd_idxs,]
+    #     return nda
+    #
+    # def get_label_data(self, nd_idxs, is_training):
+    #     # Note: labels will be same for nd_idxs across both L1B and L2
+    #     if is_training:
+    #         if self.h5f_l1b_trn is not None:
+    #             h5f = self.h5f_l1b_trn
+    #         else:
+    #             h5f = self.h5f_l2_trn
+    #     else:
+    #         if self.h5f_l1b_tst is not None:
+    #             h5f = self.h5f_l1b_tst
+    #         else:
+    #             h5f = self.h5f_l2_tst
+    #
+    #     label = h5f['icing_intensity'][nd_idxs]
+    #     label = label.astype(np.int32)
+    #     return label
 
     def get_in_mem_data_batch_train(self, idxs):
         return self.get_in_mem_data_batch(idxs, True)
@@ -500,6 +506,26 @@ class UNET:
 
         self.get_evaluate_dataset(idxs)
 
+    def setup_salt_pipeline(self, data_path, label_path, perc=0.2):
+        data_files = glob.glob(data_path+'*.png')
+        label_files = glob.glob(label_path+'*.png')
+        num_data_samples = len(data_files)
+        idxs = np.arange(num_data_samples)
+        num_train = int(num_data_samples*(1-perc))
+        self.num_data_samples = num_train
+        trn_idxs = idxs[0:num_train]
+
+        self.train_data_files = data_files[0:num_train]
+        self.train_label_files = label_files[0:num_train]
+
+        self.test_data_files = data_files[num_train:]
+        self.test_label_files = label_files[num_train:]
+        tst_idxs = np.arange(num_data_samples-num_train)
+
+        np.random.shuffle(trn_idxs)
+        self.get_train_dataset(trn_idxs)
+        self.get_test_dataset(tst_idxs)
+
     def build_unet(self):
         print('build_cnn')
         # padding = "VALID"
@@ -510,7 +536,8 @@ class UNET:
         activation = tf.nn.leaky_relu
         momentum = 0.99
 
-        num_filters = len(self.train_params) * 4
+        # num_filters = len(self.train_params) * 4
+        num_filters = 3 * 4
 
         input_2d = self.inputs[0]
         conv = tf.keras.layers.Conv2D(num_filters, kernel_size=5, strides=1, padding=padding, activation=None)(input_2d)
@@ -671,7 +698,6 @@ class UNET:
         labels = mini_batch[2]
         with tf.GradientTape() as tape:
             pred = self.model(inputs, training=True)
-            pred = tf.reshape(pred, (pred.shape[0], NumLogits))
             loss = self.loss(labels, pred)
             total_loss = loss
             if len(self.model.losses) > 0:
@@ -692,7 +718,6 @@ class UNET:
         inputs = [mini_batch[0], mini_batch[1]]
         labels = mini_batch[2]
         pred = self.model(inputs, training=False)
-        pred = tf.reshape(pred, (pred.shape[0], NumLogits))
         t_loss = self.loss(labels, pred)
 
         self.test_loss(t_loss)
@@ -954,14 +979,19 @@ class UNET:
         self.test_preds = preds
 
     def run(self, filename_l1b_trn, filename_l1b_tst, filename_l2_trn, filename_l2_tst):
-        # This doesn't really play well with SLURM
-        # with tf.device('/device:GPU:'+str(self.gpu_device)):
         self.setup_pipeline(filename_l1b_trn, filename_l1b_tst, filename_l2_trn, filename_l2_tst)
         self.build_model()
         self.build_training()
         self.build_evaluation()
         self.do_training()
 
+    def run_test(self, data_path='/Users/tomrink/data/salt/train/images/', label_path='/Users/tomrink/data/salt/train/masks/'):
+        self.setup_salt_pipeline(data_path, label_path)
+        self.build_model()
+        self.build_training()
+        self.build_evaluation()
+        self.do_training()
+
     def run_restore(self, filename_l1b, filename_l2, ckpt_dir):
         self.setup_test_pipeline(filename_l1b, filename_l2)
         self.build_model()