From b781badd876105f0fc0694e7c833076f6219a23e Mon Sep 17 00:00:00 2001
From: tomrink <rink@ssec.wisc.edu>
Date: Mon, 23 Oct 2023 14:26:17 -0500
Subject: [PATCH] snapshot...

---
 modules/util/augment.py | 44 +++++++++++++++++++++++++++++++++++++++++
 1 file changed, 44 insertions(+)

diff --git a/modules/util/augment.py b/modules/util/augment.py
index 7c05cb50..9668d780 100644
--- a/modules/util/augment.py
+++ b/modules/util/augment.py
@@ -43,3 +43,47 @@ def augment_image():
             partial(augment_steps_fn, data, label))
 
     return augment_fn
+
+
+def augment_image_3arg():
+    """ Helper function used for augmentation of images in the dataset.
+      Returns:
+        tf.data.Dataset mappable function for image augmentation
+    """
+
+    def augment_fn(data, data_b, label, *args, **kwargs):
+        # Augmenting data (~ 80%)
+
+        def augment_steps_fn(data, data_b, label):
+            # Randomly rotating image (~50%)
+            def rotate_fn(data, data_b, label):
+                times = tf.random.uniform(minval=1, maxval=4, dtype=tf.int32, shape=[])
+                return (tf.image.rot90(data, times),
+                        tf.image.rot90(data_b, times),
+                        tf.image.rot90(label, times))
+
+            data, data_b, label = tf.cond(
+                tf.less_equal(tf.random.uniform([]), 0.5),
+                lambda: rotate_fn(data, data_b, label),
+                lambda: (data, data_b, label))
+
+            # Randomly flipping image (~50%)
+            def flip_fn(data, data_b, label):
+                return (tf.image.flip_left_right(data),
+                        tf.image.flip_left_right(data_b),
+                        tf.image.flip_left_right(label))
+
+            data, data_b, label = tf.cond(
+                tf.less_equal(tf.random.uniform([]), 0.5),
+                lambda: flip_fn(data, data_b, label),
+                lambda: (data, data_b, label))
+
+            return data, data_b, label
+
+        # Randomly returning unchanged data (~20%)
+        return tf.cond(
+            tf.less_equal(tf.random.uniform([]), 0.2),
+            lambda: (data, data_b, label),
+            partial(augment_steps_fn, data, data_b, label))
+
+    return augment_fn
-- 
GitLab