Skip to content
Snippets Groups Projects
augment.py 1.92 KiB
Newer Older
tomrink's avatar
tomrink committed
from functools import partial

import tensorflow as tf


def augment_image(
        brightness_delta=0.05,
        contrast_factor=(0.7, 1.3),
        saturation=(0.6, 1.6)):
    """Build a ``tf.data.Dataset.map`` function that randomly augments samples.

    The returned function applies the *same* random geometric transform to
    ``data`` and ``label``:

      * with probability ~0.2 the pair is returned unchanged;
      * otherwise (~0.8):
          - with probability ~0.5 both are rotated by the same random
            multiple of 90 degrees (90, 180 or 270), then
          - with probability ~0.5 both are flipped left-right.

    Because ``label`` goes through the same ``tf.image`` ops as ``data``, it
    must be an image-like tensor accepted by ``tf.image.rot90`` /
    ``tf.image.flip_left_right`` (presumably a mask or target image with the
    same spatial layout as ``data`` -- confirm against the dataset pipeline).

    Args:
      brightness_delta: maximum brightness offset.
        NOTE: currently unused -- no brightness augmentation is applied.
      contrast_factor: (min, max) range of the random contrast factor, or None.
        NOTE: currently unused -- no contrast augmentation is applied.
      saturation: (min, max) range of the random saturation factor, or None.
        NOTE: currently unused -- no saturation augmentation is applied.

    Returns:
      ``augment_fn(data, label, *args, **kwargs)``, a function suitable for
      ``tf.data.Dataset.map`` that returns the (possibly transformed)
      ``(data, label)`` pair. Extra positional/keyword arguments are accepted
      and discarded.
    """
    # TODO(review): either implement the photometric augmentations described
    # by brightness_delta / contrast_factor / saturation or remove the
    # parameters; they are accepted only for interface compatibility today.

    def augment_fn(data, label, *args, **kwargs):
        """Randomly augment one (data, label) pair; extra args are ignored."""

        def augment_steps_fn(data, label):
            """Apply an optional rotation followed by an optional flip."""

            def rotate_fn(data, label):
                # maxval is exclusive, so times is 1, 2 or 3: a real rotation,
                # never the identity. Both tensors share the same `times`.
                times = tf.random.uniform(minval=1, maxval=4, dtype=tf.int32, shape=[])
                return (tf.image.rot90(data, times),
                        tf.image.rot90(label, times))

            # Randomly rotate (~50%).
            data, label = tf.cond(
                tf.less_equal(tf.random.uniform([]), 0.5),
                lambda: rotate_fn(data, label),
                lambda: (data, label))

            def flip_fn(data, label):
                return (tf.image.flip_left_right(data),
                        tf.image.flip_left_right(label))

            # Randomly flip left-right (~50%).
            data, label = tf.cond(
                tf.less_equal(tf.random.uniform([]), 0.5),
                lambda: flip_fn(data, label),
                lambda: (data, label))

            return data, label

        # Return the pair unchanged ~20% of the time; augment it ~80%.
        return tf.cond(
            tf.less_equal(tf.random.uniform([]), 0.2),
            lambda: (data, label),
            partial(augment_steps_fn, data, label))

    return augment_fn