diff --git a/modules/util/augment.py b/modules/util/augment.py
index c48f3e8341988ef5072d805ff8ce61b3b308c780..f88b93df49a92d1d03402416ebe28aca300be4c3 100644
--- a/modules/util/augment.py
+++ b/modules/util/augment.py
@@ -18,79 +18,79 @@ def augment_image(
         tf.data.Dataset mappable function for image augmentation
     """
 
-    def augment_fn(low_resolution, high_resolution, *args, **kwargs):
+    def augment_fn(data, label, *args, **kwargs):
         # Augmenting data (~ 80%)
 
-        def augment_steps_fn(low_resolution, high_resolution):
+        def augment_steps_fn(data, label):
             # Randomly rotating image (~50%)
-            def rotate_fn(low_resolution, high_resolution):
+            def rotate_fn(data, label):
                 times = tf.random.uniform(minval=1, maxval=4, dtype=tf.int32, shape=[])
-                return (tf.image.rot90(low_resolution, times),
-                        tf.image.rot90(high_resolution, times))
+                return (tf.image.rot90(data, times),
+                        tf.image.rot90(label, times))
 
-            low_resolution, high_resolution = tf.cond(
+            data, label = tf.cond(
                 tf.less_equal(tf.random.uniform([]), 0.5),
-                lambda: rotate_fn(low_resolution, high_resolution),
-                lambda: (low_resolution, high_resolution))
+                lambda: rotate_fn(data, label),
+                lambda: (data, label))
 
             # Randomly flipping image (~50%)
-            def flip_fn(low_resolution, high_resolution):
-                return (tf.image.flip_left_right(low_resolution),
-                        tf.image.flip_left_right(high_resolution))
+            def flip_fn(data, label):
+                return (tf.image.flip_left_right(data),
+                        tf.image.flip_left_right(label))
 
-            low_resolution, high_resolution = tf.cond(
+            data, label = tf.cond(
                 tf.less_equal(tf.random.uniform([]), 0.5),
-                lambda: flip_fn(low_resolution, high_resolution),
-                lambda: (low_resolution, high_resolution))
+                lambda: flip_fn(data, label),
+                lambda: (data, label))
 
             # Randomly setting brightness of image (~50%)
-            # def brightness_fn(low_resolution, high_resolution):
+            # def brightness_fn(data, label):
             #     delta = tf.random.uniform(minval=0, maxval=brightness_delta, dtype=tf.float32, shape=[])
-            #     return (tf.image.adjust_brightness(low_resolution, delta=delta),
-            #             tf.image.adjust_brightness(high_resolution, delta=delta))
+            #     return (tf.image.adjust_brightness(data, delta=delta),
+            #             tf.image.adjust_brightness(label, delta=delta))
             #
-            # low_resolution, high_resolution = tf.cond(
+            # data, label = tf.cond(
             #     tf.less_equal(tf.random.uniform([]), 0.5),
-            #     lambda: brightness_fn(low_resolution, high_resolution),
-            #     lambda: (low_resolution, high_resolution))
+            #     lambda: brightness_fn(data, label),
+            #     lambda: (data, label))
             #
             # # Randomly setting constrast (~50%)
-            # def contrast_fn(low_resolution, high_resolution):
+            # def contrast_fn(data, label):
             #     factor = tf.random.uniform(
             #         minval=contrast_factor[0],
             #         maxval=contrast_factor[1],
             #         dtype=tf.float32, shape=[])
-            #     return (tf.image.adjust_contrast(low_resolution, factor),
-            #             tf.image.adjust_contrast(high_resolution, factor))
+            #     return (tf.image.adjust_contrast(data, factor),
+            #             tf.image.adjust_contrast(label, factor))
             #
             # if contrast_factor:
-            #     low_resolution, high_resolution = tf.cond(
+            #     data, label = tf.cond(
             #         tf.less_equal(tf.random.uniform([]), 0.5),
-            #         lambda: contrast_fn(low_resolution, high_resolution),
-            #         lambda: (low_resolution, high_resolution))
+            #         lambda: contrast_fn(data, label),
+            #         lambda: (data, label))
             #
             # # Randomly setting saturation(~50%)
-            # def saturation_fn(low_resolution, high_resolution):
+            # def saturation_fn(data, label):
             #     factor = tf.random.uniform(
             #         minval=saturation[0],
             #         maxval=saturation[1],
             #         dtype=tf.float32,
             #         shape=[])
-            #     return (tf.image.adjust_saturation(low_resolution, factor),
-            #             tf.image.adjust_saturation(high_resolution, factor))
+            #     return (tf.image.adjust_saturation(data, factor),
+            #             tf.image.adjust_saturation(label, factor))
             #
             # if saturation:
-            #     low_resolution, high_resolution = tf.cond(
+            #     data, label = tf.cond(
             #         tf.less_equal(tf.random.uniform([]), 0.5),
-            #         lambda: saturation_fn(low_resolution, high_resolution),
-            #         lambda: (low_resolution, high_resolution))
+            #         lambda: saturation_fn(data, label),
+            #         lambda: (data, label))
 
-            return low_resolution, high_resolution
+            return data, label
 
         # Randomly returning unchanged data (~20%)
         return tf.cond(
             tf.less_equal(tf.random.uniform([]), 0.2),
-            lambda: (low_resolution, high_resolution),
-            partial(augment_steps_fn, low_resolution, high_resolution))
+            lambda: (data, label),
+            partial(augment_steps_fn, data, label))
 
     return augment_fn