From 63b15479de6ddc6f5e5be459759744c80684cc8e Mon Sep 17 00:00:00 2001 From: tomrink <rink@ssec.wisc.edu> Date: Fri, 7 May 2021 14:06:07 -0500 Subject: [PATCH] snapshot... --- modules/deeplearning/icing_cnn.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/modules/deeplearning/icing_cnn.py b/modules/deeplearning/icing_cnn.py index 9128ae7e..08a57c10 100644 --- a/modules/deeplearning/icing_cnn.py +++ b/modules/deeplearning/icing_cnn.py @@ -248,9 +248,20 @@ class IcingIntensityNN: return data, label + def get_in_mem_data_batch_train(self, idxs): + return self.get_in_mem_data_batch(idxs, True) + + def get_in_mem_data_batch_test(self, idxs): + return self.get_in_mem_data_batch(idxs, False) + @tf.function(input_signature=[tf.TensorSpec(None, tf.int32)]) def data_function(self, indexes): - out = tf.numpy_function(self.get_in_mem_data_batch, [indexes], [tf.float32, tf.int32]) + out = tf.numpy_function(self.get_in_mem_data_batch_train, [indexes], [tf.float32, tf.int32]) + return out + + @tf.function(input_signature=[tf.TensorSpec(None, tf.int32)]) + def data_function_test(self, indexes): + out = tf.numpy_function(self.get_in_mem_data_batch_test, [indexes], [tf.float32, tf.int32]) return out def get_train_dataset(self, indexes): @@ -269,7 +280,7 @@ class IcingIntensityNN: dataset = tf.data.Dataset.from_tensor_slices(indexes) dataset = dataset.batch(PROC_BATCH_SIZE) - dataset = dataset.map(self.data_function, num_parallel_calls=8) + dataset = dataset.map(self.data_function_test, num_parallel_calls=8) dataset = dataset.cache() self.test_dataset = dataset -- GitLab