from functools import partial

import tensorflow as tf


def augment_image(
        brightness_delta=0.05,
        contrast_factor=(0.7, 1.3),
        saturation=(0.6, 1.6)):
    """
    Helper function used for augmentation of images in the dataset.

    The returned function takes ``(data, label, *args, **kwargs)`` and returns
    a ``(data, label)`` pair. With probability ~20% the pair is returned
    unchanged. Otherwise:

    * ``data`` and ``label`` are rotated together by a random multiple of 90
      degrees (~50%) and flipped left/right together (~50%), so the same
      geometric transform is applied to both;
    * only ``data`` receives photometric jitter (brightness, contrast,
      saturation) as configured below. ``label`` is never colour-adjusted.

    Args:
        brightness_delta: maximum value for randomly assigning brightness of
            the image (must be non-negative). None, if not to be used.
        contrast_factor: list / tuple of minimum and maximum value of factor
            to set random contrast. None, if not to be used.
        saturation: list / tuple of minimum and maximum value of factor to set
            random saturation. Requires an RGB ``data`` tensor (last
            dimension 3). None, if not to be used.

    Returns:
        tf.data.Dataset mappable function for image augmentation
    """

    def color_fn(data):
        # Photometric jitter is applied to the image only. The label goes
        # through the same geometric ops as the image and presumably holds
        # target values (e.g. a mask) that must not be altered.
        # NOTE(review): for float images TF does not clip these results to
        # [0, 1]; confirm whether downstream code expects a bounded range.
        if brightness_delta is not None:
            data = tf.image.random_brightness(data, max_delta=brightness_delta)
        if contrast_factor is not None:
            lower, upper = contrast_factor
            data = tf.image.random_contrast(data, lower=lower, upper=upper)
        if saturation is not None:
            lower, upper = saturation
            data = tf.image.random_saturation(data, lower=lower, upper=upper)
        return data

    def augment_fn(data, label, *args, **kwargs):
        # Augmenting data (~ 80%)
        def augment_steps_fn(data, label):
            # Randomly rotating image (~50%) by 90, 180 or 270 degrees
            # (maxval is exclusive, so times is in {1, 2, 3}).
            def rotate_fn(data, label):
                times = tf.random.uniform(
                    minval=1, maxval=4, dtype=tf.int32, shape=[])
                return (tf.image.rot90(data, times),
                        tf.image.rot90(label, times))

            data, label = tf.cond(
                tf.less_equal(tf.random.uniform([]), 0.5),
                lambda: rotate_fn(data, label),
                lambda: (data, label))

            # Randomly flipping image (~50%)
            def flip_fn(data, label):
                return (tf.image.flip_left_right(data),
                        tf.image.flip_left_right(label))

            data, label = tf.cond(
                tf.less_equal(tf.random.uniform([]), 0.5),
                lambda: flip_fn(data, label),
                lambda: (data, label))

            # Brightness / contrast / saturation jitter on the image only.
            data = color_fn(data)

            return data, label

        # Randomly returning unchanged data (~20%)
        return tf.cond(
            tf.less_equal(tf.random.uniform([]), 0.2),
            lambda: (data, label),
            partial(augment_steps_fn, data, label))

    return augment_fn