diff --git a/modules/util/augment.py b/modules/util/augment.py new file mode 100644 index 0000000000000000000000000000000000000000..4a4181794fd4c0e959ca44b0052603355ee0eaed --- /dev/null +++ b/modules/util/augment.py @@ -0,0 +1,95 @@ +from functools import partial + +import tensorflow as tf + + +def augment_image( + brightness_delta=0.05, + contrast_factor=[0.7, 1.3], + saturation=[0.6, 1.6]): + """ Helper function used for augmentation of images in the dataset. + Args: + brightness_delta: maximum value for randomly assigning brightness of the image. + contrast_factor: list / tuple of minimum and maximum value of factor to set random contrast. + None, if not to be used. + saturation: list / tuple of minimum and maximum value of factor to set random saturation. + None, if not to be used. + Returns: + tf.data.Dataset mappable function for image augmentation + """ + + def augment_fn(low_resolution, high_resolution, *args, **kwargs): + # Augmenting data (~ 80%) + def augment_steps_fn(low_resolution, high_resolution): + # Randomly rotating image (~50%) + def rotate_fn(low_resolution, high_resolution): + times = tf.random.uniform(minval=1, maxval=4, dtype=tf.int32, shape=[]) + return (tf.image.rot90(low_resolution, times), + tf.image.rot90(high_resolution, times)) + + low_resolution, high_resolution = tf.cond( + tf.less_equal(tf.random.uniform([]), 0.5), + lambda: rotate_fn(low_resolution, high_resolution), + lambda: (low_resolution, high_resolution)) + + # Randomly flipping image (~50%) + def flip_fn(low_resolution, high_resolution): + return (tf.image.flip_left_right(low_resolution), + tf.image.flip_left_right(high_resolution)) + + low_resolution, high_resolution = tf.cond( + tf.less_equal(tf.random.uniform([]), 0.5), + lambda: flip_fn(low_resolution, high_resolution), + lambda: (low_resolution, high_resolution)) + + # Randomly setting brightness of image (~50%) + # def brightness_fn(low_resolution, high_resolution): + # delta = tf.random.uniform(minval=0, maxval=brightness_delta, dtype=tf.float32, shape=[]) + # return (tf.image.adjust_brightness(low_resolution, delta=delta), + # tf.image.adjust_brightness(high_resolution, delta=delta)) + # + # low_resolution, high_resolution = tf.cond( + # tf.less_equal(tf.random.uniform([]), 0.5), + # lambda: brightness_fn(low_resolution, high_resolution), + # lambda: (low_resolution, high_resolution)) + # + # # Randomly setting constrast (~50%) + # def contrast_fn(low_resolution, high_resolution): + # factor = tf.random.uniform( + # minval=contrast_factor[0], + # maxval=contrast_factor[1], + # dtype=tf.float32, shape=[]) + # return (tf.image.adjust_contrast(low_resolution, factor), + # tf.image.adjust_contrast(high_resolution, factor)) + # + # if contrast_factor: + # low_resolution, high_resolution = tf.cond( + # tf.less_equal(tf.random.uniform([]), 0.5), + # lambda: contrast_fn(low_resolution, high_resolution), + # lambda: (low_resolution, high_resolution)) + # + # # Randomly setting saturation(~50%) + # def saturation_fn(low_resolution, high_resolution): + # factor = tf.random.uniform( + # minval=saturation[0], + # maxval=saturation[1], + # dtype=tf.float32, + # shape=[]) + # return (tf.image.adjust_saturation(low_resolution, factor), + # tf.image.adjust_saturation(high_resolution, factor)) + # + # if saturation: + # low_resolution, high_resolution = tf.cond( + # tf.less_equal(tf.random.uniform([]), 0.5), + # lambda: saturation_fn(low_resolution, high_resolution), + # lambda: (low_resolution, high_resolution)) + + return low_resolution, high_resolution + + # Randomly returning unchanged data (~20%) + return tf.cond( + tf.less_equal(tf.random.uniform([]), 0.2), + lambda: (low_resolution, high_resolution), + partial(augment_steps_fn, low_resolution, high_resolution)) + + return augment_fn