From adca25e22a609f1522b83f680870f61437ca75f0 Mon Sep 17 00:00:00 2001 From: tomrink <rink@ssec.wisc.edu> Date: Wed, 3 Aug 2022 16:24:40 -0500 Subject: [PATCH] snapshot... --- modules/deeplearning/espcn.py | 30 ++++++++---------------------- 1 file changed, 8 insertions(+), 22 deletions(-) diff --git a/modules/deeplearning/espcn.py b/modules/deeplearning/espcn.py index f665a5cf..be6d471e 100644 --- a/modules/deeplearning/espcn.py +++ b/modules/deeplearning/espcn.py @@ -210,14 +210,14 @@ class ESPCN: self.n_chans = 1 - #self.X_img = tf.keras.Input(shape=(None, None, self.n_chans)) + self.X_img = tf.keras.Input(shape=(None, None, self.n_chans)) # self.X_img = tf.keras.Input(shape=(36, 36, self.n_chans)) - self.X_img = tf.keras.Input(shape=(32, 32, self.n_chans)) + # self.X_img = tf.keras.Input(shape=(32, 32, self.n_chans)) self.inputs.append(self.X_img) - #self.inputs.append(tf.keras.Input(shape=(None, None, self.n_chans))) + self.inputs.append(tf.keras.Input(shape=(None, None, self.n_chans))) # self.inputs.append(tf.keras.Input(shape=(36, 36, self.n_chans))) - self.inputs.append(tf.keras.Input(shape=(32, 32, self.n_chans))) + # self.inputs.append(tf.keras.Input(shape=(32, 32, self.n_chans))) self.DISK_CACHE = False @@ -225,24 +225,16 @@ class ESPCN: def get_in_mem_data_batch(self, idxs, is_training): if is_training: - data_files = self.train_data_files - label_files = self.train_label_files + label_files = self.train_data_files else: - data_files = self.test_data_files - label_files = self.test_label_files + label_files = self.test_data_files - data_s = [] label_s = [] for k in idxs: - f = data_files[k] - nda = np.load(f) - data_s.append(nda) - f = label_files[k] nda = np.load(f) label_s.append(nda) - # data = np.concatenate(data_s) data = np.concatenate(label_s) label = np.concatenate(label_s) @@ -350,12 +342,10 @@ class ESPCN: dataset = dataset.map(self.data_function_evaluate, num_parallel_calls=8) self.eval_dataset = dataset - def setup_pipeline(self, train_data_files, train_label_files, test_data_files, test_label_files, num_train_samples): + def setup_pipeline(self, train_data_files, test_data_files, num_train_samples): self.train_data_files = train_data_files - self.train_label_files = train_label_files self.test_data_files = test_data_files - self.test_label_files = test_label_files trn_idxs = np.arange(len(train_data_files)) np.random.shuffle(trn_idxs) @@ -807,15 +797,11 @@ class ESPCN: def run(self, directory): train_data_files = glob.glob(directory+'data_train*.npy') valid_data_files = glob.glob(directory+'data_valid*.npy') - train_label_files = glob.glob(directory+'label_train*.npy') - valid_label_files = glob.glob(directory+'label_valid*.npy') train_data_files.sort() valid_data_files.sort() - train_label_files.sort() - valid_label_files.sort() - self.setup_pipeline(train_data_files, train_label_files, valid_data_files, valid_label_files, 200000) + self.setup_pipeline(train_data_files, valid_data_files, 200000) self.build_model() self.build_training() self.build_evaluation() -- GitLab