Skip to content
Snippets Groups Projects
Commit d6ca0acf authored by tomrink's avatar tomrink
Browse files

snapshot...

parent a22815c1
Branches
Tags
No related merge requests found
...@@ -18,79 +18,79 @@ def augment_image( ...@@ -18,79 +18,79 @@ def augment_image(
tf.data.Dataset mappable function for image augmentation tf.data.Dataset mappable function for image augmentation
""" """
def augment_fn(low_resolution, high_resolution, *args, **kwargs): def augment_fn(data, label, *args, **kwargs):
# Augmenting data (~ 80%) # Augmenting data (~ 80%)
def augment_steps_fn(low_resolution, high_resolution): def augment_steps_fn(data, label):
# Randomly rotating image (~50%) # Randomly rotating image (~50%)
def rotate_fn(low_resolution, high_resolution): def rotate_fn(data, label):
times = tf.random.uniform(minval=1, maxval=4, dtype=tf.int32, shape=[]) times = tf.random.uniform(minval=1, maxval=4, dtype=tf.int32, shape=[])
return (tf.image.rot90(low_resolution, times), return (tf.image.rot90(data, times),
tf.image.rot90(high_resolution, times)) tf.image.rot90(label, times))
low_resolution, high_resolution = tf.cond( data, label = tf.cond(
tf.less_equal(tf.random.uniform([]), 0.5), tf.less_equal(tf.random.uniform([]), 0.5),
lambda: rotate_fn(low_resolution, high_resolution), lambda: rotate_fn(data, label),
lambda: (low_resolution, high_resolution)) lambda: (data, label))
# Randomly flipping image (~50%) # Randomly flipping image (~50%)
def flip_fn(low_resolution, high_resolution): def flip_fn(data, label):
return (tf.image.flip_left_right(low_resolution), return (tf.image.flip_left_right(data),
tf.image.flip_left_right(high_resolution)) tf.image.flip_left_right(label))
low_resolution, high_resolution = tf.cond( data, label = tf.cond(
tf.less_equal(tf.random.uniform([]), 0.5), tf.less_equal(tf.random.uniform([]), 0.5),
lambda: flip_fn(low_resolution, high_resolution), lambda: flip_fn(data, label),
lambda: (low_resolution, high_resolution)) lambda: (data, label))
# Randomly setting brightness of image (~50%) # Randomly setting brightness of image (~50%)
# def brightness_fn(low_resolution, high_resolution): # def brightness_fn(data, label):
# delta = tf.random.uniform(minval=0, maxval=brightness_delta, dtype=tf.float32, shape=[]) # delta = tf.random.uniform(minval=0, maxval=brightness_delta, dtype=tf.float32, shape=[])
# return (tf.image.adjust_brightness(low_resolution, delta=delta), # return (tf.image.adjust_brightness(data, delta=delta),
# tf.image.adjust_brightness(high_resolution, delta=delta)) # tf.image.adjust_brightness(label, delta=delta))
# #
# low_resolution, high_resolution = tf.cond( # data, label = tf.cond(
# tf.less_equal(tf.random.uniform([]), 0.5), # tf.less_equal(tf.random.uniform([]), 0.5),
# lambda: brightness_fn(low_resolution, high_resolution), # lambda: brightness_fn(data, label),
# lambda: (low_resolution, high_resolution)) # lambda: (data, label))
# #
# # Randomly setting constrast (~50%) # # Randomly setting constrast (~50%)
# def contrast_fn(low_resolution, high_resolution): # def contrast_fn(data, label):
# factor = tf.random.uniform( # factor = tf.random.uniform(
# minval=contrast_factor[0], # minval=contrast_factor[0],
# maxval=contrast_factor[1], # maxval=contrast_factor[1],
# dtype=tf.float32, shape=[]) # dtype=tf.float32, shape=[])
# return (tf.image.adjust_contrast(low_resolution, factor), # return (tf.image.adjust_contrast(data, factor),
# tf.image.adjust_contrast(high_resolution, factor)) # tf.image.adjust_contrast(label, factor))
# #
# if contrast_factor: # if contrast_factor:
# low_resolution, high_resolution = tf.cond( # data, label = tf.cond(
# tf.less_equal(tf.random.uniform([]), 0.5), # tf.less_equal(tf.random.uniform([]), 0.5),
# lambda: contrast_fn(low_resolution, high_resolution), # lambda: contrast_fn(data, label),
# lambda: (low_resolution, high_resolution)) # lambda: (data, label))
# #
# # Randomly setting saturation(~50%) # # Randomly setting saturation(~50%)
# def saturation_fn(low_resolution, high_resolution): # def saturation_fn(data, label):
# factor = tf.random.uniform( # factor = tf.random.uniform(
# minval=saturation[0], # minval=saturation[0],
# maxval=saturation[1], # maxval=saturation[1],
# dtype=tf.float32, # dtype=tf.float32,
# shape=[]) # shape=[])
# return (tf.image.adjust_saturation(low_resolution, factor), # return (tf.image.adjust_saturation(data, factor),
# tf.image.adjust_saturation(high_resolution, factor)) # tf.image.adjust_saturation(label, factor))
# #
# if saturation: # if saturation:
# low_resolution, high_resolution = tf.cond( # data, label = tf.cond(
# tf.less_equal(tf.random.uniform([]), 0.5), # tf.less_equal(tf.random.uniform([]), 0.5),
# lambda: saturation_fn(low_resolution, high_resolution), # lambda: saturation_fn(data, label),
# lambda: (low_resolution, high_resolution)) # lambda: (data, label))
return low_resolution, high_resolution return data, label
# Randomly returning unchanged data (~20%) # Randomly returning unchanged data (~20%)
return tf.cond( return tf.cond(
tf.less_equal(tf.random.uniform([]), 0.2), tf.less_equal(tf.random.uniform([]), 0.2),
lambda: (low_resolution, high_resolution), lambda: (data, label),
partial(augment_steps_fn, low_resolution, high_resolution)) partial(augment_steps_fn, data, label))
return augment_fn return augment_fn
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please to comment