From 8ffbc901fa2429a9e46c25ad72cee30d8467022c Mon Sep 17 00:00:00 2001
From: tomrink <rink@ssec.wisc.edu>
Date: Mon, 6 Jun 2022 09:16:32 -0500
Subject: [PATCH] minor...

---
 modules/deeplearning/unet.py | 97 ++++++++++++------------------------
 1 file changed, 33 insertions(+), 64 deletions(-)

diff --git a/modules/deeplearning/unet.py b/modules/deeplearning/unet.py
index 9e420efa..f9522139 100644
--- a/modules/deeplearning/unet.py
+++ b/modules/deeplearning/unet.py
@@ -36,12 +36,12 @@ NOISE_TRAINING = True
 NOISE_STDDEV = 0.10
 DO_AUGMENT = True
 
-img_width = 16
-
 mean_std_file = home_dir+'/viirs_emis_rad_mean_std.pkl'
-f = open(mean_std_file, 'rb')
-mean_std_dct = pickle.load(f)
-f.close()
+f_stats = open(mean_std_file, 'rb')
+mean_std_dct = pickle.load(f_stats)
+f_stats.close()
+
+param = 'M15'
 
 # -- Zero out params (Experimentation Only) ------------
 zero_out_params = ['cld_reff_dcomp', 'cld_opd_dcomp', 'iwc_dcomp', 'lwc_dcomp']
@@ -93,11 +93,6 @@ class UNET:
         self.inner_handle = None
         self.in_mem_batch = None
 
-        self.h5f_l1b_trn = None
-        self.h5f_l1b_tst = None
-        self.h5f_l2_trn = None
-        self.h5f_l2_tst = None
-
         self.logits = None
 
         self.predict_data = None
@@ -120,12 +115,6 @@ class UNET:
 
         self.OUT_OF_RANGE = False
 
-        self.abi = None
-        self.temp = None
-        self.wv = None
-        self.lbfp = None
-        self.sfc = None
-
         self.in_mem_data_cache = {}
         self.in_mem_data_cache_test = {}
 
@@ -159,11 +148,6 @@ class UNET:
         self.test_data_files = None
         self.test_label_files = None
 
-        self.train_data_nda = None
-        self.train_label_nda = None
-        self.test_data_nda = None
-        self.test_label_nda = None
-
         # self.n_chans = len(self.train_params)
         self.n_chans = 1
         if TRIPLET:
@@ -171,7 +155,6 @@ class UNET:
         self.X_img = tf.keras.Input(shape=(None, None, self.n_chans))
 
         self.inputs.append(self.X_img)
-        # self.inputs.append(tf.keras.Input(shape=(None, None, 5)))
         self.inputs.append(tf.keras.Input(shape=(None, None, 1)))
 
         self.flight_level = 0
@@ -188,45 +171,34 @@ class UNET:
 
     def get_in_mem_data_batch(self, idxs, is_training):
         if is_training:
-            train_data = []
-            train_label = []
-            for k in idxs:
-                f = self.train_data_files[k]
-                nda = np.load(f)
-                train_data.append(nda)
-
-                f = self.train_label_files[k]
-                nda = np.load(f)
-                train_label.append(nda)
-
-            data = np.concatenate(train_data)
-            data = np.expand_dims(data, axis=3)
-
-            label = np.concatenate(train_label)
-            label = np.expand_dims(label, axis=3)
+            data_files = self.train_data_files
+            label_files = self.train_label_files
         else:
-            test_data = []
-            test_label = []
-            for k in idxs:
-                f = self.test_data_files[k]
-                nda = np.load(f)
-                test_data.append(nda)
+            data_files = self.test_data_files
+            label_files = self.test_label_files
+
+        data_s = []
+        label_s = []
+        for k in idxs:
+            f = data_files[k]
+            nda = np.load(f)
+            data_s.append(nda)
 
-                f = self.test_label_files[k]
-                nda = np.load(f)
-                test_label.append(nda)
+            f = label_files[k]
+            nda = np.load(f)
+            label_s.append(nda)
 
-            data = np.concatenate(test_data)
-            data = np.expand_dims(data, axis=3)
+        data = np.concatenate(data_s)
+        data = np.expand_dims(data, axis=3)
 
-            label = np.concatenate(test_label)
-            label = np.expand_dims(label, axis=3)
+        label = np.concatenate(label_s)
+        label = np.expand_dims(label, axis=3)
 
         data = data.astype(np.float32)
         label = label.astype(np.float32)
 
-        data = normalize(data, 'M15', mean_std_dct)
-        label = normalize(label, 'M15', mean_std_dct)
+        data = normalize(data, param, mean_std_dct)
+        label = normalize(label, param, mean_std_dct)
 
         if is_training and DO_AUGMENT:
             data_ud = np.flip(data, axis=1)
@@ -337,24 +309,21 @@ class UNET:
     #     print('num test samples: ', tst_idxs.shape[0])
     #     print('setup_pipeline: Done')
 
-    def setup_pipeline(self, data_files, label_files, perc=0.20):
-        num_files = len(data_files)
-        num_test_files = int(num_files * perc)
-        num_train_files = num_files - num_test_files
+    def setup_pipeline(self, train_data_files, train_label_files, test_data_files, test_label_files, num_train_samples):
 
-        self.train_data_files = data_files[0:num_train_files]
-        self.train_label_files = label_files[0:num_train_files]
-        self.test_data_files = data_files[num_train_files:]
-        self.test_label_files = label_files[num_train_files:]
+        self.train_data_files = train_data_files
+        self.train_label_files = train_label_files
+        self.test_data_files = test_data_files
+        self.test_label_files = test_label_files
 
-        trn_idxs = np.arange(num_train_files)
+        trn_idxs = np.arange(len(train_data_files))
         np.random.shuffle(trn_idxs)
-        tst_idxs = np.arange(num_test_files)
+        tst_idxs = np.arange(len(train_data_files))
 
         self.get_train_dataset(trn_idxs)
         self.get_test_dataset(tst_idxs)
 
-        self.num_data_samples = num_train_files * 30  # approximately
+        self.num_data_samples = num_train_samples  # approximately
 
         print('datetime: ', now)
         print('training and test data: ')
-- 
GitLab