diff --git a/modules/util/augment.py b/modules/util/augment.py index 9668d78081bb93291f82ca6d46b9e9040a3bed6e..68c39c61e8d90e2771fa5527bcf8735a8bd0137d 100644 --- a/modules/util/augment.py +++ b/modules/util/augment.py @@ -59,8 +59,8 @@ def augment_image_3arg(): def rotate_fn(data, data_b, label): times = tf.random.uniform(minval=1, maxval=4, dtype=tf.int32, shape=[]) return (tf.image.rot90(data, times), - tf.image.rot90(data_b, times), - tf.image.rot90(label, times)) + data_b, + label) data, data_b, label = tf.cond( tf.less_equal(tf.random.uniform([]), 0.5), @@ -70,8 +70,8 @@ def augment_image_3arg(): # Randomly flipping image (~50%) def flip_fn(data, data_b, label): return (tf.image.flip_left_right(data), - tf.image.flip_left_right(data_b), - tf.image.flip_left_right(label)) + data_b, + label) data, data_b, label = tf.cond( tf.less_equal(tf.random.uniform([]), 0.5),