From 63b15479de6ddc6f5e5be459759744c80684cc8e Mon Sep 17 00:00:00 2001
From: tomrink <rink@ssec.wisc.edu>
Date: Fri, 7 May 2021 14:06:07 -0500
Subject: [PATCH] snapshot...

---
 modules/deeplearning/icing_cnn.py | 15 +++++++++++++--
 1 file changed, 13 insertions(+), 2 deletions(-)

diff --git a/modules/deeplearning/icing_cnn.py b/modules/deeplearning/icing_cnn.py
index 9128ae7e..08a57c10 100644
--- a/modules/deeplearning/icing_cnn.py
+++ b/modules/deeplearning/icing_cnn.py
@@ -248,9 +248,20 @@ class IcingIntensityNN:
 
         return data, label
 
+    def get_in_mem_data_batch_train(self, idxs):
+        return self.get_in_mem_data_batch(idxs, True)
+
+    def get_in_mem_data_batch_test(self, idxs):
+        return self.get_in_mem_data_batch(idxs, False)
+
     @tf.function(input_signature=[tf.TensorSpec(None, tf.int32)])
     def data_function(self, indexes):
-        out = tf.numpy_function(self.get_in_mem_data_batch, [indexes], [tf.float32, tf.int32])
+        out = tf.numpy_function(self.get_in_mem_data_batch_train, [indexes], [tf.float32, tf.int32])
+        return out
+
+    @tf.function(input_signature=[tf.TensorSpec(None, tf.int32)])
+    def data_function_test(self, indexes):
+        out = tf.numpy_function(self.get_in_mem_data_batch_test, [indexes], [tf.float32, tf.int32])
         return out
 
     def get_train_dataset(self, indexes):
@@ -269,7 +280,7 @@ class IcingIntensityNN:
 
         dataset = tf.data.Dataset.from_tensor_slices(indexes)
         dataset = dataset.batch(PROC_BATCH_SIZE)
-        dataset = dataset.map(self.data_function, num_parallel_calls=8)
+        dataset = dataset.map(self.data_function_test, num_parallel_calls=8)
         dataset = dataset.cache()
         self.test_dataset = dataset
 
-- 
GitLab