diff --git a/modules/deeplearning/icing_cnn.py b/modules/deeplearning/icing_cnn.py index 9128ae7e4384f2e1b8fe76a9073adb2dd250d021..08a57c1061a5f43c6de9982b55f3033decfa9382 100644 --- a/modules/deeplearning/icing_cnn.py +++ b/modules/deeplearning/icing_cnn.py @@ -248,9 +248,20 @@ class IcingIntensityNN: return data, label + def get_in_mem_data_batch_train(self, idxs): + return self.get_in_mem_data_batch(idxs, True) + + def get_in_mem_data_batch_test(self, idxs): + return self.get_in_mem_data_batch(idxs, False) + @tf.function(input_signature=[tf.TensorSpec(None, tf.int32)]) def data_function(self, indexes): - out = tf.numpy_function(self.get_in_mem_data_batch, [indexes], [tf.float32, tf.int32]) + out = tf.numpy_function(self.get_in_mem_data_batch_train, [indexes], [tf.float32, tf.int32]) + return out + + @tf.function(input_signature=[tf.TensorSpec(None, tf.int32)]) + def data_function_test(self, indexes): + out = tf.numpy_function(self.get_in_mem_data_batch_test, [indexes], [tf.float32, tf.int32]) return out def get_train_dataset(self, indexes): @@ -269,7 +280,7 @@ class IcingIntensityNN: dataset = tf.data.Dataset.from_tensor_slices(indexes) dataset = dataset.batch(PROC_BATCH_SIZE) - dataset = dataset.map(self.data_function, num_parallel_calls=8) + dataset = dataset.map(self.data_function_test, num_parallel_calls=8) dataset = dataset.cache() self.test_dataset = dataset