from functools import partial

import tensorflow as tf


def augment_image(
        brightness_delta=0.05,
        contrast_factor=(0.7, 1.3),
        saturation=(0.6, 1.6)):
    """ Helper function used for augmentation of images in the dataset.

    The returned function takes a (low_resolution, high_resolution) image
    pair and, with probability ~80%, runs it through the augmentation steps
    (each applied independently with probability ~50%):
      * rotation by 1, 2 or 3 quarter turns,
      * left-right flip.
    Whenever a step fires, the identical transform is applied to both images
    of the pair. In the remaining ~20% of calls the pair is returned as is.

    NOTE: the colour augmentations (random brightness, contrast and
    saturation) are currently disabled, so the three arguments below are
    accepted only to keep the call signature stable; none of them is read.

    Args:
      brightness_delta: maximum value for randomly assigning brightness of
                        the image. Currently unused.
      contrast_factor: list / tuple of minimum and maximum value of factor
                       to set random contrast. None, if not to be used.
                       Currently unused.
      saturation: list / tuple of minimum and maximum value of factor
                  to set random saturation. None, if not to be used.
                  Currently unused.
    Returns:
      tf.data.Dataset mappable function for image augmentation
    """

    def augment_fn(low_resolution, high_resolution, *args, **kwargs):
        # Any extra dataset components (*args / **kwargs) are accepted but
        # dropped: only the image pair is returned.

        # Augmenting data (~ 80%)
        def augment_steps_fn(low_resolution, high_resolution):
            # Randomly rotating image (~50%)
            def rotate_fn(low_resolution, high_resolution):
                # Integer sampling excludes maxval, so `times` is 1, 2 or 3.
                # One draw is shared so both images rotate the same way.
                times = tf.random.uniform(
                    shape=[], minval=1, maxval=4, dtype=tf.int32)
                return (tf.image.rot90(low_resolution, times),
                        tf.image.rot90(high_resolution, times))

            low_resolution, high_resolution = tf.cond(
                tf.less_equal(tf.random.uniform([]), 0.5),
                lambda: rotate_fn(low_resolution, high_resolution),
                lambda: (low_resolution, high_resolution))

            # Randomly flipping image (~50%)
            def flip_fn(low_resolution, high_resolution):
                return (tf.image.flip_left_right(low_resolution),
                        tf.image.flip_left_right(high_resolution))

            low_resolution, high_resolution = tf.cond(
                tf.less_equal(tf.random.uniform([]), 0.5),
                lambda: flip_fn(low_resolution, high_resolution),
                lambda: (low_resolution, high_resolution))

            # NOTE: three further ~50% steps used to follow here and are
            # disabled: brightness (tf.image.adjust_brightness with a delta
            # drawn from [0, brightness_delta)), contrast
            # (tf.image.adjust_contrast with a factor drawn from
            # contrast_factor) and saturation (tf.image.adjust_saturation
            # with a factor drawn from saturation). Each applied one shared
            # random value to both images.

            return low_resolution, high_resolution

        # Randomly returning unchanged data (~20%)
        return tf.cond(
            tf.less_equal(tf.random.uniform([]), 0.2),
            lambda: (low_resolution, high_resolution),
            partial(augment_steps_fn, low_resolution, high_resolution))

    return augment_fn