From adca25e22a609f1522b83f680870f61437ca75f0 Mon Sep 17 00:00:00 2001
From: tomrink <rink@ssec.wisc.edu>
Date: Wed, 3 Aug 2022 16:24:40 -0500
Subject: [PATCH] snapshot...

---
 modules/deeplearning/espcn.py | 30 ++++++++----------------------
 1 file changed, 8 insertions(+), 22 deletions(-)

diff --git a/modules/deeplearning/espcn.py b/modules/deeplearning/espcn.py
index f665a5cf..be6d471e 100644
--- a/modules/deeplearning/espcn.py
+++ b/modules/deeplearning/espcn.py
@@ -210,14 +210,14 @@ class ESPCN:
 
         self.n_chans = 1
 
-        #self.X_img = tf.keras.Input(shape=(None, None, self.n_chans))
+        self.X_img = tf.keras.Input(shape=(None, None, self.n_chans))
         # self.X_img = tf.keras.Input(shape=(36, 36, self.n_chans))
-        self.X_img = tf.keras.Input(shape=(32, 32, self.n_chans))
+        # self.X_img = tf.keras.Input(shape=(32, 32, self.n_chans))
 
         self.inputs.append(self.X_img)
-        #self.inputs.append(tf.keras.Input(shape=(None, None, self.n_chans)))
+        self.inputs.append(tf.keras.Input(shape=(None, None, self.n_chans)))
         # self.inputs.append(tf.keras.Input(shape=(36, 36, self.n_chans)))
-        self.inputs.append(tf.keras.Input(shape=(32, 32, self.n_chans)))
+        # self.inputs.append(tf.keras.Input(shape=(32, 32, self.n_chans)))
 
         self.DISK_CACHE = False
 
@@ -225,24 +225,16 @@ class ESPCN:
 
     def get_in_mem_data_batch(self, idxs, is_training):
         if is_training:
-            data_files = self.train_data_files
-            label_files = self.train_label_files
+            label_files = self.train_data_files
         else:
-            data_files = self.test_data_files
-            label_files = self.test_label_files
+            label_files = self.test_data_files
 
-        data_s = []
         label_s = []
         for k in idxs:
-            f = data_files[k]
-            nda = np.load(f)
-            data_s.append(nda)
-
             f = label_files[k]
             nda = np.load(f)
             label_s.append(nda)
 
-        # data = np.concatenate(data_s)
         data = np.concatenate(label_s)
         label = np.concatenate(label_s)
 
@@ -350,12 +342,10 @@ class ESPCN:
         dataset = dataset.map(self.data_function_evaluate, num_parallel_calls=8)
         self.eval_dataset = dataset
 
-    def setup_pipeline(self, train_data_files, train_label_files, test_data_files, test_label_files, num_train_samples):
+    def setup_pipeline(self, train_data_files, test_data_files, num_train_samples):
 
         self.train_data_files = train_data_files
-        self.train_label_files = train_label_files
         self.test_data_files = test_data_files
-        self.test_label_files = test_label_files
 
         trn_idxs = np.arange(len(train_data_files))
         np.random.shuffle(trn_idxs)
@@ -807,15 +797,11 @@ class ESPCN:
     def run(self, directory):
         train_data_files = glob.glob(directory+'data_train*.npy')
         valid_data_files = glob.glob(directory+'data_valid*.npy')
-        train_label_files = glob.glob(directory+'label_train*.npy')
-        valid_label_files = glob.glob(directory+'label_valid*.npy')
 
         train_data_files.sort()
         valid_data_files.sort()
-        train_label_files.sort()
-        valid_label_files.sort()
 
-        self.setup_pipeline(train_data_files, train_label_files, valid_data_files, valid_label_files, 200000)
+        self.setup_pipeline(train_data_files, valid_data_files, 200000)
         self.build_model()
         self.build_training()
         self.build_evaluation()
-- 
GitLab