diff --git a/modules/deeplearning/unet.py b/modules/deeplearning/unet.py index 2f3ae5f1ef01bb5c4f39c601028b1d406a42af14..7bea589d8298152a8883f9d57ece8e4b53fcee88 100644 --- a/modules/deeplearning/unet.py +++ b/modules/deeplearning/unet.py @@ -209,11 +209,11 @@ class UNET: n_chans = 3 if TRIPLET: n_chans *= 3 - self.X_img = tf.keras.Input(shape=(64, 64, n_chans)) + self.X_img = tf.keras.Input(shape=(None, None, n_chans)) self.inputs.append(self.X_img) #self.inputs.append(tf.keras.Input(shape=(None, None, 5))) - self.inputs.append(tf.keras.Input(shape=(64, 64, 3))) + self.inputs.append(tf.keras.Input(shape=(None, None, 3))) self.flight_level = 0 @@ -999,6 +999,13 @@ class UNET: self.build_evaluation() self.do_training() + def run_restore_test(self, data_path='/Users/tomrink/data/salt/train/images/', label_path='/Users/tomrink/data/salt/train/masks/', ckpt_dir=None): + self.setup_salt_pipeline(data_path, label_path) + self.build_model() + self.build_training() + self.build_evaluation() + self.restore(ckpt_dir) + def run_restore(self, filename_l1b, filename_l2, ckpt_dir): self.setup_test_pipeline(filename_l1b, filename_l2) self.build_model()