diff --git a/modules/deeplearning/cnn_cld_frac.py b/modules/deeplearning/cnn_cld_frac.py index 417f0d83074d5aa5c514eb455f3bc32a943c3235..f408659d938ba7d5c3d73c2adeccf76bd311a515 100644 --- a/modules/deeplearning/cnn_cld_frac.py +++ b/modules/deeplearning/cnn_cld_frac.py @@ -357,12 +357,12 @@ class CNN: @tf.function(input_signature=[tf.TensorSpec(None, tf.int32)]) def data_function(self, indexes): - out = tf.numpy_function(self.get_in_mem_data_batch_train, [indexes], [tf.float32, tf.float32]) + out = tf.numpy_function(self.get_in_mem_data_batch_train, [indexes], [tf.float32, tf.int32]) return out @tf.function(input_signature=[tf.TensorSpec(None, tf.int32)]) def data_function_test(self, indexes): - out = tf.numpy_function(self.get_in_mem_data_batch_test, [indexes], [tf.float32, tf.float32]) + out = tf.numpy_function(self.get_in_mem_data_batch_test, [indexes], [tf.float32, tf.int32]) return out @tf.function(input_signature=[tf.TensorSpec(None, tf.int32)])