diff --git a/modules/util/augment.py b/modules/util/augment.py index c48f3e8341988ef5072d805ff8ce61b3b308c780..f88b93df49a92d1d03402416ebe28aca300be4c3 100644 --- a/modules/util/augment.py +++ b/modules/util/augment.py @@ -18,79 +18,79 @@ def augment_image( tf.data.Dataset mappable function for image augmentation """ - def augment_fn(low_resolution, high_resolution, *args, **kwargs): + def augment_fn(data, label, *args, **kwargs): # Augmenting data (~ 80%) - def augment_steps_fn(low_resolution, high_resolution): + def augment_steps_fn(data, label): # Randomly rotating image (~50%) - def rotate_fn(low_resolution, high_resolution): + def rotate_fn(data, label): times = tf.random.uniform(minval=1, maxval=4, dtype=tf.int32, shape=[]) - return (tf.image.rot90(low_resolution, times), - tf.image.rot90(high_resolution, times)) + return (tf.image.rot90(data, times), + tf.image.rot90(label, times)) - low_resolution, high_resolution = tf.cond( + data, label = tf.cond( tf.less_equal(tf.random.uniform([]), 0.5), - lambda: rotate_fn(low_resolution, high_resolution), - lambda: (low_resolution, high_resolution)) + lambda: rotate_fn(data, label), + lambda: (data, label)) # Randomly flipping image (~50%) - def flip_fn(low_resolution, high_resolution): - return (tf.image.flip_left_right(low_resolution), - tf.image.flip_left_right(high_resolution)) + def flip_fn(data, label): + return (tf.image.flip_left_right(data), + tf.image.flip_left_right(label)) - low_resolution, high_resolution = tf.cond( + data, label = tf.cond( tf.less_equal(tf.random.uniform([]), 0.5), - lambda: flip_fn(low_resolution, high_resolution), - lambda: (low_resolution, high_resolution)) + lambda: flip_fn(data, label), + lambda: (data, label)) # Randomly setting brightness of image (~50%) - # def brightness_fn(low_resolution, high_resolution): + # def brightness_fn(data, label): # delta = tf.random.uniform(minval=0, maxval=brightness_delta, dtype=tf.float32, shape=[]) - # return (tf.image.adjust_brightness(low_resolution, delta=delta), - # tf.image.adjust_brightness(high_resolution, delta=delta)) + # return (tf.image.adjust_brightness(data, delta=delta), + # tf.image.adjust_brightness(label, delta=delta)) # - # low_resolution, high_resolution = tf.cond( + # data, label = tf.cond( # tf.less_equal(tf.random.uniform([]), 0.5), - # lambda: brightness_fn(low_resolution, high_resolution), - # lambda: (low_resolution, high_resolution)) + # lambda: brightness_fn(data, label), + # lambda: (data, label)) # # # Randomly setting constrast (~50%) - # def contrast_fn(low_resolution, high_resolution): + # def contrast_fn(data, label): # factor = tf.random.uniform( # minval=contrast_factor[0], # maxval=contrast_factor[1], # dtype=tf.float32, shape=[]) - # return (tf.image.adjust_contrast(low_resolution, factor), - # tf.image.adjust_contrast(high_resolution, factor)) + # return (tf.image.adjust_contrast(data, factor), + # tf.image.adjust_contrast(label, factor)) # # if contrast_factor: - # low_resolution, high_resolution = tf.cond( + # data, label = tf.cond( # tf.less_equal(tf.random.uniform([]), 0.5), - # lambda: contrast_fn(low_resolution, high_resolution), - # lambda: (low_resolution, high_resolution)) + # lambda: contrast_fn(data, label), + # lambda: (data, label)) # # # Randomly setting saturation(~50%) - # def saturation_fn(low_resolution, high_resolution): + # def saturation_fn(data, label): # factor = tf.random.uniform( # minval=saturation[0], # maxval=saturation[1], # dtype=tf.float32, # shape=[]) - # return (tf.image.adjust_saturation(low_resolution, factor), - # tf.image.adjust_saturation(high_resolution, factor)) + # return (tf.image.adjust_saturation(data, factor), + # tf.image.adjust_saturation(label, factor)) # # if saturation: - # low_resolution, high_resolution = tf.cond( + # data, label = tf.cond( # tf.less_equal(tf.random.uniform([]), 0.5), - # lambda: saturation_fn(low_resolution, high_resolution), - # lambda: (low_resolution, high_resolution)) + # lambda: saturation_fn(data, label), + # lambda: (data, label)) - return low_resolution, high_resolution + return data, label # Randomly returning unchanged data (~20%) return tf.cond( tf.less_equal(tf.random.uniform([]), 0.2), - lambda: (low_resolution, high_resolution), - partial(augment_steps_fn, low_resolution, high_resolution)) + lambda: (data, label), + partial(augment_steps_fn, data, label)) return augment_fn