Newer
Older
from functools import partial
import tensorflow as tf
def _validate_factor_range(name, factor_range):
    """Check that ``factor_range`` is a ``(minimum, maximum)`` pair.

    Args:
      name: parameter name, used in the error message.
      factor_range: sequence expected to hold exactly two numbers.

    Raises:
      ValueError: if ``factor_range`` does not have exactly two entries or its
        minimum exceeds its maximum.
    """
    if len(factor_range) != 2:
        raise ValueError(
            "%s must be a (minimum, maximum) pair, got %r" % (name, factor_range))
    minimum, maximum = factor_range
    if minimum > maximum:
        raise ValueError(
            "%s minimum must not exceed its maximum, got %r"
            % (name, factor_range))


def _apply_with_probability(fn, low_resolution, high_resolution,
                            probability=0.5):
    """Apply ``fn`` to the image pair with the given probability.

    A single random draw decides for both images, so the pair stays aligned.

    Returns:
      The ``(low_resolution, high_resolution)`` pair, transformed by ``fn`` or
      passed through unchanged.
    """
    return tf.cond(
        tf.less_equal(tf.random.uniform([]), probability),
        lambda: fn(low_resolution, high_resolution),
        lambda: (low_resolution, high_resolution))


def augment_image(
        brightness_delta=0.05,
        contrast_factor=(0.7, 1.3),
        saturation=(0.6, 1.6),
        *,
        color_augmentation=False):
    """Helper function used for augmentation of images in the dataset.

    Builds a function suitable for ``tf.data.Dataset.map`` that augments a
    (low-resolution, high-resolution) image pair. Every random decision and
    random parameter is drawn once and applied to both images, so the pair
    stays aligned.

    The returned function leaves the pair unchanged with probability ~0.2.
    Otherwise each of the following steps is applied independently with
    probability ~0.5:
      * rotation by 90, 180 or 270 degrees (``tf.image.rot90``);
      * horizontal flip (``tf.image.flip_left_right``);
      * only when ``color_augmentation`` is True: brightness, contrast and
        saturation adjustment.

    Args:
      brightness_delta: maximum value added to the pixel values when randomly
        adjusting brightness; the delta is drawn from [0, brightness_delta).
        Its scale must match the images' value range. None or 0 disables the
        step.
      contrast_factor: (minimum, maximum) range of the random contrast factor.
        None, if not to be used.
      saturation: (minimum, maximum) range of the random saturation factor.
        None, if not to be used. ``tf.image.adjust_saturation`` expects
        3-channel (RGB) images.
      color_augmentation: keyword-only. Enables the brightness / contrast /
        saturation steps. Defaults to False. In that case the three parameters
        above are ignored, which matches the previous behavior, where these
        steps were disabled.

    Returns:
      tf.data.Dataset mappable function for image augmentation. It takes
      ``(low_resolution, high_resolution, *args, **kwargs)``, ignores any
      extra arguments, and returns the (possibly augmented) pair.

    Raises:
      ValueError: if ``color_augmentation`` is True and ``contrast_factor`` or
        ``saturation`` is given but is not a valid (minimum, maximum) pair.
    """
    # Validate only when the ranges are actually used, so existing callers
    # that pass values which were previously ignored keep working.
    if color_augmentation:
        if contrast_factor:
            _validate_factor_range("contrast_factor", contrast_factor)
        if saturation:
            _validate_factor_range("saturation", saturation)

    def rotate_fn(low_resolution, high_resolution):
        # maxval is exclusive, so times is 1, 2 or 3: a real rotation always
        # happens once this branch is taken.
        times = tf.random.uniform(minval=1, maxval=4, dtype=tf.int32, shape=[])
        return (tf.image.rot90(low_resolution, times),
                tf.image.rot90(high_resolution, times))

    def flip_fn(low_resolution, high_resolution):
        return (tf.image.flip_left_right(low_resolution),
                tf.image.flip_left_right(high_resolution))

    def brightness_fn(low_resolution, high_resolution):
        delta = tf.random.uniform(
            minval=0, maxval=brightness_delta, dtype=tf.float32, shape=[])
        return (tf.image.adjust_brightness(low_resolution, delta=delta),
                tf.image.adjust_brightness(high_resolution, delta=delta))

    def contrast_fn(low_resolution, high_resolution):
        factor = tf.random.uniform(
            minval=contrast_factor[0],
            maxval=contrast_factor[1],
            dtype=tf.float32, shape=[])
        return (tf.image.adjust_contrast(low_resolution, factor),
                tf.image.adjust_contrast(high_resolution, factor))

    def saturation_fn(low_resolution, high_resolution):
        factor = tf.random.uniform(
            minval=saturation[0],
            maxval=saturation[1],
            dtype=tf.float32, shape=[])
        return (tf.image.adjust_saturation(low_resolution, factor),
                tf.image.adjust_saturation(high_resolution, factor))

    def augment_steps_fn(low_resolution, high_resolution):
        # Geometric steps, each applied with ~50% probability.
        low_resolution, high_resolution = _apply_with_probability(
            rotate_fn, low_resolution, high_resolution)
        low_resolution, high_resolution = _apply_with_probability(
            flip_fn, low_resolution, high_resolution)
        # Optional photometric steps, each applied with ~50% probability.
        if color_augmentation:
            if brightness_delta:
                low_resolution, high_resolution = _apply_with_probability(
                    brightness_fn, low_resolution, high_resolution)
            if contrast_factor:
                low_resolution, high_resolution = _apply_with_probability(
                    contrast_fn, low_resolution, high_resolution)
            if saturation:
                low_resolution, high_resolution = _apply_with_probability(
                    saturation_fn, low_resolution, high_resolution)
        return low_resolution, high_resolution

    def augment_fn(low_resolution, high_resolution, *args, **kwargs):
        # Extra dataset components (*args / **kwargs) are accepted so this can
        # be mapped over richer datasets, but they are dropped from the output.
        # Augmenting data (~80%); randomly returning unchanged data (~20%).
        return tf.cond(
            tf.less_equal(tf.random.uniform([]), 0.2),
            lambda: (low_resolution, high_resolution),
            partial(augment_steps_fn, low_resolution, high_resolution))

    return augment_fn