diff --git a/modules/util/augment.py b/modules/util/augment.py index 4a4181794fd4c0e959ca44b0052603355ee0eaed..788bd6948be85e32a2d740133cc061663861abd9 100644 --- a/modules/util/augment.py +++ b/modules/util/augment.py @@ -20,6 +20,9 @@ def augment_image( def augment_fn(low_resolution, high_resolution, *args, **kwargs): # Augmenting data (~ 80%) + low_resolution = tf.convert_to_tensor(low_resolution) + high_resolution = tf.convert_to_tensor(high_resolution) + def augment_steps_fn(low_resolution, high_resolution): # Randomly rotating image (~50%) def rotate_fn(low_resolution, high_resolution):