from functools import partial

import tensorflow as tf


def _random_rot90_flip(elements, transform_mask, p_skip, p_rotate, p_flip):
    """Jointly apply a random rot90 and/or left-right flip to selected tensors.

    Args:
        elements: Tuple of tensors making up one dataset element.
        transform_mask: Tuple of bools, same length as `elements`. Only the
            entries flagged True are rotated/flipped; the others are passed
            through untouched. All selected entries receive the *same* random
            rotation and flip, so e.g. an image and its mask stay aligned.
        p_skip: Probability of returning `elements` completely unchanged.
        p_rotate: Probability (when not skipped) of rotating counter-clockwise
            by 90, 180 or 270 degrees, each equally likely.
        p_flip: Probability (when not skipped) of a left-right flip, drawn
            independently of the rotation.

    Returns:
        A tuple with the same structure as `elements`.
    """

    def _apply(values, op):
        # Transform only the masked entries; pass the rest through as-is.
        return tuple(
            op(value) if selected else value
            for value, selected in zip(values, transform_mask)
        )

    def _rotate(values):
        # minval=1, maxval=4 -> 1..3 quarter turns, so a rotation always
        # changes orientation (0 turns is excluded).
        times = tf.random.uniform(minval=1, maxval=4, dtype=tf.int32, shape=[])
        return _apply(values, lambda x: tf.image.rot90(x, times))

    def _augment_steps(values):
        rotated = tf.cond(
            tf.less_equal(tf.random.uniform([]), p_rotate),
            lambda: _rotate(values),
            lambda: values)
        flipped = tf.cond(
            tf.less_equal(tf.random.uniform([]), p_flip),
            lambda: _apply(rotated, tf.image.flip_left_right),
            lambda: rotated)
        return flipped

    # Note: even when not skipped, both rotate and flip may fail to fire, so the
    # overall chance of an unchanged element is
    # p_skip + (1 - p_skip) * (1 - p_rotate) * (1 - p_flip).
    return tf.cond(
        tf.less_equal(tf.random.uniform([]), p_skip),
        lambda: elements,
        partial(_augment_steps, elements))


def augment_image(*, p_skip=0.2, p_rotate=0.5, p_flip=0.5):
    """Build a tf.data.Dataset mappable function augmenting (data, label) pairs.

    `data` and `label` receive the same random rotation/flip, so `label` must be
    image-shaped (e.g. a segmentation mask): it is passed to `tf.image.rot90`
    and `tf.image.flip_left_right` exactly like `data`.

    With the default probabilities an element is left completely unchanged
    ~40% of the time: 20% from the explicit skip, plus 80% * 50% * 50% when
    neither the rotation nor the flip fires.

    NOTE: a 90/270-degree rotation swaps height and width for non-square inputs.

    Args:
        p_skip: Probability of returning the element unchanged.
        p_rotate: Probability (when not skipped) of a 90/180/270-degree rotation.
        p_flip: Probability (when not skipped) of a left-right flip.

    Returns:
        `augment_fn(data, label, *args, **kwargs) -> (data, label)`. Any extra
        positional/keyword arguments are accepted and discarded.
    """

    def augment_fn(data, label, *args, **kwargs):
        return _random_rot90_flip(
            (data, label), (True, True), p_skip, p_rotate, p_flip)

    return augment_fn


def augment_icing(*, p_skip=0.2, p_rotate=0.5, p_flip=0.5):
    """Build a tf.data.Dataset mappable function augmenting only `data`.

    Unlike `augment_image`, only the first element (`data`) is rotated/flipped.
    `data_b` and `label` are passed through unchanged. Presumably they are not
    spatially aligned with `data`; confirm against the dataset definition.

    With the default probabilities `data` is left unchanged ~40% of the time
    (20% explicit skip + 80% * 50% * 50% when neither transform fires).

    NOTE: a 90/270-degree rotation swaps height and width of non-square `data`.

    Args:
        p_skip: Probability of returning the element unchanged.
        p_rotate: Probability (when not skipped) of a 90/180/270-degree rotation.
        p_flip: Probability (when not skipped) of a left-right flip.

    Returns:
        `augment_fn(data, data_b, label, *args, **kwargs) -> (data, data_b,
        label)`. Any extra positional/keyword arguments are accepted and
        discarded.
    """

    def augment_fn(data, data_b, label, *args, **kwargs):
        return _random_rot90_flip(
            (data, data_b, label), (True, False, False), p_skip, p_rotate, p_flip)

    return augment_fn


def augment_image_3arg(*, p_skip=0.2, p_rotate=0.5, p_flip=0.5):
    """Build a tf.data.Dataset mappable function augmenting (data, data_b, label).

    All three elements receive the same random rotation/flip, so each must be
    image-shaped. They are all passed to `tf.image.rot90` and
    `tf.image.flip_left_right`.

    With the default probabilities an element is left unchanged ~40% of the
    time (20% explicit skip + 80% * 50% * 50% when neither transform fires).

    NOTE: a 90/270-degree rotation swaps height and width for non-square inputs.

    Args:
        p_skip: Probability of returning the element unchanged.
        p_rotate: Probability (when not skipped) of a 90/180/270-degree rotation.
        p_flip: Probability (when not skipped) of a left-right flip.

    Returns:
        `augment_fn(data, data_b, label, *args, **kwargs) -> (data, data_b,
        label)`. Any extra positional/keyword arguments are accepted and
        discarded.
    """

    def augment_fn(data, data_b, label, *args, **kwargs):
        return _random_rot90_flip(
            (data, data_b, label), (True, True, True), p_skip, p_rotate, p_flip)

    return augment_fn