augment.py 4.56 KiB
from functools import partial
import tensorflow as tf
def augment_image(skip_prob=0.2, rotate_prob=0.5, flip_prob=0.5):
    """Build a ``tf.data.Dataset.map`` function that jointly augments an image and its label.

    With probability ``skip_prob`` an element is returned unchanged. Otherwise two
    independent steps are applied, each using the SAME transform for ``data`` and
    ``label``:

    * with probability ``rotate_prob``: rotate by k * 90 degrees, k drawn uniformly
      from {1, 2, 3} (never the identity rotation);
    * with probability ``flip_prob``: flip left-right.

    Because ``label`` goes through ``tf.image.rot90`` / ``tf.image.flip_left_right``
    exactly like ``data``, it must be an image-shaped tensor those ops accept.

    Args:
        skip_prob: Probability of leaving an element untouched. Defaults to 0.2.
        rotate_prob: Probability of the rotation step when augmenting. Defaults to 0.5.
        flip_prob: Probability of the flip step when augmenting. Defaults to 0.5.

    Returns:
        A function ``augment_fn(data, label, *args, **kwargs)`` returning the tuple
        ``(data, label)``. Any extra positional/keyword arguments are accepted but
        ignored, and they are NOT part of the returned tuple.

    Raises:
        ValueError: If any probability lies outside [0, 1].
    """
    for name, prob in (("skip_prob", skip_prob),
                       ("rotate_prob", rotate_prob),
                       ("flip_prob", flip_prob)):
        if not 0.0 <= prob <= 1.0:
            raise ValueError(f"{name} must be in [0, 1], got {prob!r}")

    def _coin_flip(prob):
        # tf.random.uniform samples [0, 1), so `u < prob` is True with probability
        # exactly `prob`: never for 0.0, always for 1.0.
        return tf.less(tf.random.uniform([]), prob)

    def rotate_fn(data, label):
        # maxval is exclusive, so times is in {1, 2, 3}. The same draw is used for
        # both tensors so they stay aligned.
        times = tf.random.uniform(minval=1, maxval=4, dtype=tf.int32, shape=[])
        return tf.image.rot90(data, times), tf.image.rot90(label, times)

    def flip_fn(data, label):
        return tf.image.flip_left_right(data), tf.image.flip_left_right(label)

    def augment_steps_fn(data, label):
        # Rotation and flip are decided independently of each other.
        data, label = tf.cond(
            _coin_flip(rotate_prob),
            partial(rotate_fn, data, label),
            lambda: (data, label))
        data, label = tf.cond(
            _coin_flip(flip_prob),
            partial(flip_fn, data, label),
            lambda: (data, label))
        return data, label

    def augment_fn(data, label, *args, **kwargs):
        # *args/**kwargs are ignored and dropped from the output.
        return tf.cond(
            _coin_flip(skip_prob),
            lambda: (data, label),
            partial(augment_steps_fn, data, label))

    return augment_fn
def augment_icing(skip_prob=0.2, rotate_prob=0.5, flip_prob=0.5):
    """Build a ``tf.data.Dataset.map`` function that augments only the first input.

    With probability ``skip_prob`` an element is returned unchanged. Otherwise two
    independent steps are applied to ``data`` ONLY:

    * with probability ``rotate_prob``: rotate by k * 90 degrees, k drawn uniformly
      from {1, 2, 3} (never the identity rotation);
    * with probability ``flip_prob``: flip left-right.

    ``data_b`` and ``label`` always pass through untouched.
    NOTE(review): presumably they are not spatially aligned with ``data`` (e.g.
    auxiliary features / class targets). Confirm this is intended, because
    ``augment_image_3arg`` transforms all three tensors.

    Args:
        skip_prob: Probability of leaving an element untouched. Defaults to 0.2.
        rotate_prob: Probability of the rotation step when augmenting. Defaults to 0.5.
        flip_prob: Probability of the flip step when augmenting. Defaults to 0.5.

    Returns:
        A function ``augment_fn(data, data_b, label, *args, **kwargs)`` returning
        ``(data, data_b, label)``. Extra positional/keyword arguments are accepted
        but ignored, and they are NOT part of the returned tuple.

    Raises:
        ValueError: If any probability lies outside [0, 1].
    """
    for name, prob in (("skip_prob", skip_prob),
                       ("rotate_prob", rotate_prob),
                       ("flip_prob", flip_prob)):
        if not 0.0 <= prob <= 1.0:
            raise ValueError(f"{name} must be in [0, 1], got {prob!r}")

    def _coin_flip(prob):
        # tf.random.uniform samples [0, 1), so `u < prob` is True with probability
        # exactly `prob`: never for 0.0, always for 1.0.
        return tf.less(tf.random.uniform([]), prob)

    def rotate_fn(data, data_b, label):
        # maxval is exclusive, so times is in {1, 2, 3}.
        times = tf.random.uniform(minval=1, maxval=4, dtype=tf.int32, shape=[])
        return tf.image.rot90(data, times), data_b, label

    def flip_fn(data, data_b, label):
        return tf.image.flip_left_right(data), data_b, label

    def augment_steps_fn(data, data_b, label):
        # Rotation and flip are decided independently of each other.
        data, data_b, label = tf.cond(
            _coin_flip(rotate_prob),
            partial(rotate_fn, data, data_b, label),
            lambda: (data, data_b, label))
        data, data_b, label = tf.cond(
            _coin_flip(flip_prob),
            partial(flip_fn, data, data_b, label),
            lambda: (data, data_b, label))
        return data, data_b, label

    def augment_fn(data, data_b, label, *args, **kwargs):
        # *args/**kwargs are ignored and dropped from the output.
        return tf.cond(
            _coin_flip(skip_prob),
            lambda: (data, data_b, label),
            partial(augment_steps_fn, data, data_b, label))

    return augment_fn
def augment_image_3arg(skip_prob=0.2, rotate_prob=0.5, flip_prob=0.5):
    """Build a ``tf.data.Dataset.map`` function that jointly augments three aligned tensors.

    With probability ``skip_prob`` an element is returned unchanged. Otherwise two
    independent steps are applied, each using the SAME transform for ``data``,
    ``data_b`` and ``label``:

    * with probability ``rotate_prob``: rotate by k * 90 degrees, k drawn uniformly
      from {1, 2, 3} (never the identity rotation);
    * with probability ``flip_prob``: flip left-right.

    All three tensors go through ``tf.image.rot90`` / ``tf.image.flip_left_right``,
    so each must be an image-shaped tensor those ops accept.

    Args:
        skip_prob: Probability of leaving an element untouched. Defaults to 0.2.
        rotate_prob: Probability of the rotation step when augmenting. Defaults to 0.5.
        flip_prob: Probability of the flip step when augmenting. Defaults to 0.5.

    Returns:
        A function ``augment_fn(data, data_b, label, *args, **kwargs)`` returning
        ``(data, data_b, label)``. Extra positional/keyword arguments are accepted
        but ignored, and they are NOT part of the returned tuple.

    Raises:
        ValueError: If any probability lies outside [0, 1].
    """
    for name, prob in (("skip_prob", skip_prob),
                       ("rotate_prob", rotate_prob),
                       ("flip_prob", flip_prob)):
        if not 0.0 <= prob <= 1.0:
            raise ValueError(f"{name} must be in [0, 1], got {prob!r}")

    def _coin_flip(prob):
        # tf.random.uniform samples [0, 1), so `u < prob` is True with probability
        # exactly `prob`: never for 0.0, always for 1.0.
        return tf.less(tf.random.uniform([]), prob)

    def rotate_fn(data, data_b, label):
        # maxval is exclusive, so times is in {1, 2, 3}. One draw is shared by all
        # three tensors so they stay aligned.
        times = tf.random.uniform(minval=1, maxval=4, dtype=tf.int32, shape=[])
        return (tf.image.rot90(data, times),
                tf.image.rot90(data_b, times),
                tf.image.rot90(label, times))

    def flip_fn(data, data_b, label):
        return (tf.image.flip_left_right(data),
                tf.image.flip_left_right(data_b),
                tf.image.flip_left_right(label))

    def augment_steps_fn(data, data_b, label):
        # Rotation and flip are decided independently of each other.
        data, data_b, label = tf.cond(
            _coin_flip(rotate_prob),
            partial(rotate_fn, data, data_b, label),
            lambda: (data, data_b, label))
        data, data_b, label = tf.cond(
            _coin_flip(flip_prob),
            partial(flip_fn, data, data_b, label),
            lambda: (data, data_b, label))
        return data, data_b, label

    def augment_fn(data, data_b, label, *args, **kwargs):
        # *args/**kwargs are ignored and dropped from the output.
        return tf.cond(
            _coin_flip(skip_prob),
            lambda: (data, data_b, label),
            partial(augment_steps_fn, data, data_b, label))

    return augment_fn