diff --git a/modules/util/augment.py b/modules/util/augment.py
new file mode 100644
index 0000000000000000000000000000000000000000..4a4181794fd4c0e959ca44b0052603355ee0eaed
--- /dev/null
+++ b/modules/util/augment.py
@@ -0,0 +1,95 @@
+from functools import partial
+
+import tensorflow as tf
+
+
+def augment_image(
+        brightness_delta=0.05,
+        contrast_factor=[0.7, 1.3],
+        saturation=[0.6, 1.6]):
+    """ Helper function used for augmentation of images in the dataset.
+      Args:
+        brightness_delta: maximum value for randomly assigning brightness of the image.
+        contrast_factor: list / tuple of minimum and maximum value of factor to set random contrast.
+                          None, if not to be used.
+        saturation: list / tuple of minimum and maximum value of factor to set random saturation.
+                    None, if not to be used.
+      Returns:
+        tf.data.Dataset mappable function for image augmentation
+    """
+
+    def augment_fn(low_resolution, high_resolution, *args, **kwargs):
+        # Augmenting data (~ 80%)
+        def augment_steps_fn(low_resolution, high_resolution):
+            # Randomly rotating image (~50%)
+            def rotate_fn(low_resolution, high_resolution):
+                times = tf.random.uniform(minval=1, maxval=4, dtype=tf.int32, shape=[])
+                return (tf.image.rot90(low_resolution, times),
+                        tf.image.rot90(high_resolution, times))
+
+            low_resolution, high_resolution = tf.cond(
+                tf.less_equal(tf.random.uniform([]), 0.5),
+                lambda: rotate_fn(low_resolution, high_resolution),
+                lambda: (low_resolution, high_resolution))
+
+            # Randomly flipping image (~50%)
+            def flip_fn(low_resolution, high_resolution):
+                return (tf.image.flip_left_right(low_resolution),
+                        tf.image.flip_left_right(high_resolution))
+
+            low_resolution, high_resolution = tf.cond(
+                tf.less_equal(tf.random.uniform([]), 0.5),
+                lambda: flip_fn(low_resolution, high_resolution),
+                lambda: (low_resolution, high_resolution))
+
+            # Randomly setting brightness of image (~50%)
+            # def brightness_fn(low_resolution, high_resolution):
+            #     delta = tf.random.uniform(minval=0, maxval=brightness_delta, dtype=tf.float32, shape=[])
+            #     return (tf.image.adjust_brightness(low_resolution, delta=delta),
+            #             tf.image.adjust_brightness(high_resolution, delta=delta))
+            #
+            # low_resolution, high_resolution = tf.cond(
+            #     tf.less_equal(tf.random.uniform([]), 0.5),
+            #     lambda: brightness_fn(low_resolution, high_resolution),
+            #     lambda: (low_resolution, high_resolution))
+            #
+            # # Randomly setting constrast (~50%)
+            # def contrast_fn(low_resolution, high_resolution):
+            #     factor = tf.random.uniform(
+            #         minval=contrast_factor[0],
+            #         maxval=contrast_factor[1],
+            #         dtype=tf.float32, shape=[])
+            #     return (tf.image.adjust_contrast(low_resolution, factor),
+            #             tf.image.adjust_contrast(high_resolution, factor))
+            #
+            # if contrast_factor:
+            #     low_resolution, high_resolution = tf.cond(
+            #         tf.less_equal(tf.random.uniform([]), 0.5),
+            #         lambda: contrast_fn(low_resolution, high_resolution),
+            #         lambda: (low_resolution, high_resolution))
+            #
+            # # Randomly setting saturation(~50%)
+            # def saturation_fn(low_resolution, high_resolution):
+            #     factor = tf.random.uniform(
+            #         minval=saturation[0],
+            #         maxval=saturation[1],
+            #         dtype=tf.float32,
+            #         shape=[])
+            #     return (tf.image.adjust_saturation(low_resolution, factor),
+            #             tf.image.adjust_saturation(high_resolution, factor))
+            #
+            # if saturation:
+            #     low_resolution, high_resolution = tf.cond(
+            #         tf.less_equal(tf.random.uniform([]), 0.5),
+            #         lambda: saturation_fn(low_resolution, high_resolution),
+            #         lambda: (low_resolution, high_resolution))
+
+            return low_resolution, high_resolution
+
+        # Randomly returning unchanged data (~20%)
+        return tf.cond(
+            tf.less_equal(tf.random.uniform([]), 0.2),
+            lambda: (low_resolution, high_resolution),
+            partial(augment_steps_fn, low_resolution, high_resolution))
+
+    return augment_fn