From 3cfb756bef4f87eea4dfc72411f46195be322f10 Mon Sep 17 00:00:00 2001
From: tomrink <rink@ssec.wisc.edu>
Date: Wed, 5 May 2021 09:40:07 -0500
Subject: [PATCH] snapshot...

---
 modules/deeplearning/icing_cnn.py | 21 ++++++++++++++++++++-
 1 file changed, 20 insertions(+), 1 deletion(-)

diff --git a/modules/deeplearning/icing_cnn.py b/modules/deeplearning/icing_cnn.py
index d0d3d00f..0f5f7ee5 100644
--- a/modules/deeplearning/icing_cnn.py
+++ b/modules/deeplearning/icing_cnn.py
@@ -15,7 +15,7 @@ LOG_DEVICE_PLACEMENT = False
 
 CACHE_DATA_IN_MEM = True
 
-PROC_BATCH_SIZE = 2046
+PROC_BATCH_SIZE = 4096
 PROC_BATCH_BUFFER_SIZE = 50000
 NumClasses = 3
 NumLogits = 1
@@ -210,6 +210,25 @@ class IcingIntensityNN:
         label = label.astype(np.int32)
         label = np.where(label == -1, 0, label)
 
+        # Augmentation, TODO: work into Dataset.map (efficiency)
+        data_aug = []
+        label_aug = []
+        for k in range(label.shape[0]):
+            if label[k] == 3 or label[k] == 4 or label[k] == 5 or label[k] == 6:
+                data_aug.append(tf.image.flip_up_down(data[k,]).numpy())
+                data_aug.append(tf.image.flip_left_right(data[k,]).numpy())
+                data_aug.append(tf.image.rot90(data[k,]).numpy())
+                label_aug.append(label[k])
+                label_aug.append(label[k])
+                label_aug.append(label[k])
+
+        data_aug = np.stack(data_aug)
+        label_aug = np.stack(label_aug)
+
+        data = np.concatenate([data, data_aug])
+        label = np.concatenate([label, label_aug])
+        # --------------------------------------------------------
+
         # binary, two class
         if NumClasses == 2:
             label = np.where(label != 0, 1, label)
-- 
GitLab