Skip to content
Snippets Groups Projects
Select Git revision
  • 6603fb4a7112e6dc281a0e0e72e5f4abfa323f82
  • master default protected
  • use_flight_altitude
  • distribute
4 results

plot_cm.py

Blame
  • augment.py 1.44 KiB
    from functools import partial
    
    import tensorflow as tf
    
    
    def augment_image(*, unchanged_prob=0.2, rotate_prob=0.5, flip_prob=0.5):
        """ Helper function used for augmentation of images in the dataset.

        The returned function applies the *same* random geometric transform to
        ``data`` and ``label``:

        - With probability ``unchanged_prob`` the pair is returned untouched.
        - Otherwise it is rotated by 1-3 quarter turns with probability
          ``rotate_prob``.
        - It is then flipped left-right with probability ``flip_prob``.

        Both steps may be skipped on the augmentation path. With the default
        values, the overall chance of an unchanged pair is therefore
        0.2 + 0.8 * 0.5 * 0.5 = 0.4.

          Args:
            unchanged_prob: probability of skipping augmentation entirely.
            rotate_prob: probability of rotating on the augmentation path.
            flip_prob: probability of flipping on the augmentation path.
          Returns:
            tf.data.Dataset mappable function for image augmentation
          Raises:
            ValueError: if any probability is outside [0, 1].
        """

        def _check_prob(name, value):
            # Validate eagerly so a bad value fails here, not inside the graph.
            value = float(value)
            if not 0.0 <= value <= 1.0:  # also rejects NaN
                raise ValueError(f"{name} must be in [0, 1], got {value!r}")
            return value

        unchanged_prob = _check_prob("unchanged_prob", unchanged_prob)
        rotate_prob = _check_prob("rotate_prob", rotate_prob)
        flip_prob = _check_prob("flip_prob", flip_prob)

        def coin(prob):
            # Boolean scalar that is True with probability ``prob``. Uniform
            # samples lie in [0, 1), so prob=0 never fires and prob=1 always
            # fires.
            return tf.less(tf.random.uniform([]), prob)

        def augment_fn(data, label, *args, **kwargs):
            # Extra positional/keyword elements are accepted but dropped:
            # the function always returns ``(data, label)``.

            def augment_steps_fn(data, label):
                # Randomly rotating image by 90, 180 or 270 degrees
                def rotate_fn(data, label):
                    # maxval is exclusive -> times in {1, 2, 3}. One shared
                    # value keeps data and label aligned.
                    # NOTE(review): odd quarter turns swap height and width,
                    # so non-square inputs change shape; presumably inputs
                    # are square, confirm.
                    times = tf.random.uniform(minval=1, maxval=4, dtype=tf.int32, shape=[])
                    return (tf.image.rot90(data, times),
                            tf.image.rot90(label, times))

                data, label = tf.cond(
                    coin(rotate_prob),
                    lambda: rotate_fn(data, label),
                    lambda: (data, label))

                # Randomly flipping image left-right
                def flip_fn(data, label):
                    return (tf.image.flip_left_right(data),
                            tf.image.flip_left_right(label))

                data, label = tf.cond(
                    coin(flip_prob),
                    lambda: flip_fn(data, label),
                    lambda: (data, label))

                return data, label

            # Randomly returning unchanged data
            return tf.cond(
                coin(unchanged_prob),
                lambda: (data, label),
                partial(augment_steps_fn, data, label))

        return augment_fn