From 97c0a8ad9e29761b677c888e69813bf7e1e9b795 Mon Sep 17 00:00:00 2001
From: tomrink <rink@ssec.wisc.edu>
Date: Wed, 5 May 2021 09:53:34 -0500
Subject: [PATCH] snapshot...

---
 modules/deeplearning/icing_cnn.py | 48 +++++++++++++++++++++++++++++--
 1 file changed, 46 insertions(+), 2 deletions(-)

diff --git a/modules/deeplearning/icing_cnn.py b/modules/deeplearning/icing_cnn.py
index 0f5f7ee5..27a14f0d 100644
--- a/modules/deeplearning/icing_cnn.py
+++ b/modules/deeplearning/icing_cnn.py
@@ -243,6 +243,50 @@ class IcingIntensityNN:
 
         return data, label
 
+    def get_in_mem_data_batch_test(self, idxs):  # build (or fetch cached) (data, label) arrays for a batch of indexes
+        key = frozenset(idxs)  # hashable, order-insensitive cache key
+
+        if CACHE_DATA_IN_MEM:  # reuse a previously assembled batch when caching is enabled
+            tup = self.in_mem_data_cache.get(key)
+            if tup is not None:
+                return tup[0], tup[1]  # cached (data, label)
+
+        # sort ascending before indexing self.h5f (presumably HDF5 fancy indexing needs increasing order - TODO confirm)
+        nd_idxs = np.array(idxs)
+        nd_idxs = np.sort(nd_idxs)  # output rows therefore follow sorted-index order, not the order of idxs
+
+        data = []  # one normalized array per entry of train_params
+        for param in train_params:
+            nda = self.h5f[param][nd_idxs, ]  # select the batch along the first axis
+            nda = normalize(nda, param, mean_std_dct)  # helper and stats dict are defined outside this hunk
+            data.append(nda)
+        data = np.stack(data)  # new leading axis: one slot per parameter
+        data = data.astype(np.float32)
+        data = np.transpose(data, axes=(1, 2, 3, 0))  # move the parameter axis to the end
+
+        label = self.h5f['icing_intensity'][nd_idxs]  # labels selected with the same sorted indexes as data
+        label = label.astype(np.int32)
+        label = np.where(label == -1, 0, label)  # fold -1 into class 0
+
+        # collapse raw label values according to NumClasses; any other NumClasses value leaves them unmerged and unreshaped
+        if NumClasses == 2:  # binary: 0 stays 0, every non-zero value becomes 1
+            label = np.where(label != 0, 1, label)
+            label = label.reshape((label.shape[0], 1))  # column vector, shape (N, 1)
+        elif NumClasses == 3:  # three classes: 0 stays 0, {1, 2} -> 1, everything else -> 2
+            label = np.where(np.logical_or(label == 1, label == 2), 1, label)
+            label = np.where(np.invert(np.logical_or(label == 0, label == 1)), 2, label)
+            label = label.reshape((label.shape[0], 1))  # column vector, shape (N, 1)
+
+        if CACHE_DATA_IN_MEM:
+            self.in_mem_data_cache[key] = (data, label)  # remember the batch under its order-insensitive key
+
+        return data, label
+
+    @tf.function(input_signature=[tf.TensorSpec(None, tf.int32)])  # accepts an int32 tensor of unspecified shape
+    def data_function_test(self, indexes):  # TensorFlow wrapper around the NumPy loader get_in_mem_data_batch_test
+        out = tf.numpy_function(self.get_in_mem_data_batch_test, [indexes], [tf.float32, tf.int32])  # Tout: float32 data, int32 labels
+        return out
+
     @tf.function(input_signature=[tf.TensorSpec(None, tf.int32)])
     def data_function(self, indexes):
         out = tf.numpy_function(self.get_in_mem_data_batch, [indexes], [tf.float32, tf.int32])
@@ -254,7 +298,7 @@ class IcingIntensityNN:
         dataset = tf.data.Dataset.from_tensor_slices(indexes)
         dataset = dataset.batch(PROC_BATCH_SIZE)
         dataset = dataset.map(self.data_function, num_parallel_calls=8)
-        # dataset = dataset.shuffle(PROC_BATCH_BUFFER_SIZE)
+        dataset = dataset.shuffle(PROC_BATCH_BUFFER_SIZE)
         dataset = dataset.prefetch(buffer_size=1)
         self.train_dataset = dataset
 
@@ -263,7 +307,7 @@ class IcingIntensityNN:
 
         dataset = tf.data.Dataset.from_tensor_slices(indexes)
         dataset = dataset.batch(PROC_BATCH_SIZE)
-        dataset = dataset.map(self.data_function, num_parallel_calls=8)
+        dataset = dataset.map(self.data_function_test, num_parallel_calls=8)
         self.test_dataset = dataset
 
     def setup_pipeline(self, filename, trn_idxs=None, tst_idxs=None, seed=None):
-- 
GitLab