diff --git a/modules/deeplearning/icing_cnn.py b/modules/deeplearning/icing_cnn.py
index 9128ae7e4384f2e1b8fe76a9073adb2dd250d021..08a57c1061a5f43c6de9982b55f3033decfa9382 100644
--- a/modules/deeplearning/icing_cnn.py
+++ b/modules/deeplearning/icing_cnn.py
@@ -248,9 +248,20 @@ class IcingIntensityNN:
 
         return data, label
 
+    def get_in_mem_data_batch_train(self, idxs):
+        return self.get_in_mem_data_batch(idxs, True)
+
+    def get_in_mem_data_batch_test(self, idxs):
+        return self.get_in_mem_data_batch(idxs, False)
+
     @tf.function(input_signature=[tf.TensorSpec(None, tf.int32)])
     def data_function(self, indexes):
-        out = tf.numpy_function(self.get_in_mem_data_batch, [indexes], [tf.float32, tf.int32])
+        out = tf.numpy_function(self.get_in_mem_data_batch_train, [indexes], [tf.float32, tf.int32])
+        return out
+
+    @tf.function(input_signature=[tf.TensorSpec(None, tf.int32)])
+    def data_function_test(self, indexes):
+        out = tf.numpy_function(self.get_in_mem_data_batch_test, [indexes], [tf.float32, tf.int32])
         return out
 
     def get_train_dataset(self, indexes):
@@ -269,7 +280,7 @@ class IcingIntensityNN:
 
         dataset = tf.data.Dataset.from_tensor_slices(indexes)
         dataset = dataset.batch(PROC_BATCH_SIZE)
-        dataset = dataset.map(self.data_function, num_parallel_calls=8)
+        dataset = dataset.map(self.data_function_test, num_parallel_calls=8)
         dataset = dataset.cache()
         self.test_dataset = dataset