diff --git a/modules/deeplearning/unet.py b/modules/deeplearning/unet.py
index 2f3ae5f1ef01bb5c4f39c601028b1d406a42af14..7bea589d8298152a8883f9d57ece8e4b53fcee88 100644
--- a/modules/deeplearning/unet.py
+++ b/modules/deeplearning/unet.py
@@ -209,11 +209,11 @@ class UNET:
         n_chans = 3
         if TRIPLET:
             n_chans *= 3
-        self.X_img = tf.keras.Input(shape=(64, 64, n_chans))
+        self.X_img = tf.keras.Input(shape=(None, None, n_chans))
 
         self.inputs.append(self.X_img)
         #self.inputs.append(tf.keras.Input(shape=(None, None, 5)))
-        self.inputs.append(tf.keras.Input(shape=(64, 64, 3)))
+        self.inputs.append(tf.keras.Input(shape=(None, None, 3)))
 
         self.flight_level = 0
 
@@ -999,6 +999,13 @@ class UNET:
         self.build_evaluation()
         self.do_training()
 
+    def run_restore_test(self, data_path='/Users/tomrink/data/salt/train/images/', label_path='/Users/tomrink/data/salt/train/masks/', ckpt_dir=None):
+        self.setup_salt_pipeline(data_path, label_path)
+        self.build_model()
+        self.build_training()
+        self.build_evaluation()
+        self.restore(ckpt_dir)
+
     def run_restore(self, filename_l1b, filename_l2, ckpt_dir):
         self.setup_test_pipeline(filename_l1b, filename_l2)
         self.build_model()