From 9028a71f4ec1e2b2fce59ec6189bf6e2f938bea6 Mon Sep 17 00:00:00 2001
From: tomrink <rink@ssec.wisc.edu>
Date: Sun, 17 Apr 2022 13:02:35 -0500
Subject: [PATCH] minor...

---
 modules/deeplearning/unet.py | 11 +++++++++--
 1 file changed, 9 insertions(+), 2 deletions(-)

diff --git a/modules/deeplearning/unet.py b/modules/deeplearning/unet.py
index 2f3ae5f1..7bea589d 100644
--- a/modules/deeplearning/unet.py
+++ b/modules/deeplearning/unet.py
@@ -209,11 +209,11 @@ class UNET:
         n_chans = 3
         if TRIPLET:
             n_chans *= 3
-        self.X_img = tf.keras.Input(shape=(64, 64, n_chans))
+        self.X_img = tf.keras.Input(shape=(None, None, n_chans))
 
         self.inputs.append(self.X_img)
         #self.inputs.append(tf.keras.Input(shape=(None, None, 5)))
-        self.inputs.append(tf.keras.Input(shape=(64, 64, 3)))
+        self.inputs.append(tf.keras.Input(shape=(None, None, 3)))
 
         self.flight_level = 0
 
@@ -999,6 +999,13 @@ class UNET:
         self.build_evaluation()
         self.do_training()
 
+    def run_restore_test(self, data_path='/Users/tomrink/data/salt/train/images/', label_path='/Users/tomrink/data/salt/train/masks/', ckpt_dir=None):
+        self.setup_salt_pipeline(data_path, label_path)
+        self.build_model()
+        self.build_training()
+        self.build_evaluation()
+        self.restore(ckpt_dir)
+
     def run_restore(self, filename_l1b, filename_l2, ckpt_dir):
         self.setup_test_pipeline(filename_l1b, filename_l2)
         self.build_model()
-- 
GitLab