From 9028a71f4ec1e2b2fce59ec6189bf6e2f938bea6 Mon Sep 17 00:00:00 2001 From: tomrink <rink@ssec.wisc.edu> Date: Sun, 17 Apr 2022 13:02:35 -0500 Subject: [PATCH] minor... --- modules/deeplearning/unet.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/modules/deeplearning/unet.py b/modules/deeplearning/unet.py index 2f3ae5f1..7bea589d 100644 --- a/modules/deeplearning/unet.py +++ b/modules/deeplearning/unet.py @@ -209,11 +209,11 @@ class UNET: n_chans = 3 if TRIPLET: n_chans *= 3 - self.X_img = tf.keras.Input(shape=(64, 64, n_chans)) + self.X_img = tf.keras.Input(shape=(None, None, n_chans)) self.inputs.append(self.X_img) #self.inputs.append(tf.keras.Input(shape=(None, None, 5))) - self.inputs.append(tf.keras.Input(shape=(64, 64, 3))) + self.inputs.append(tf.keras.Input(shape=(None, None, 3))) self.flight_level = 0 @@ -999,6 +999,13 @@ class UNET: self.build_evaluation() self.do_training() + def run_restore_test(self, data_path='/Users/tomrink/data/salt/train/images/', label_path='/Users/tomrink/data/salt/train/masks/', ckpt_dir=None): + self.setup_salt_pipeline(data_path, label_path) + self.build_model() + self.build_training() + self.build_evaluation() + self.restore(ckpt_dir) + def run_restore(self, filename_l1b, filename_l2, ckpt_dir): self.setup_test_pipeline(filename_l1b, filename_l2) self.build_model() -- GitLab