From e9e080f3817508640019b4c68d895bcf8842d7c5 Mon Sep 17 00:00:00 2001
From: tomrink <rink@ssec.wisc.edu>
Date: Mon, 23 Oct 2023 13:44:57 -0500
Subject: [PATCH] snapshot...

---
 modules/deeplearning/icing_fcn.py | 44 ++++++++++++++++---------------
 1 file changed, 23 insertions(+), 21 deletions(-)

diff --git a/modules/deeplearning/icing_fcn.py b/modules/deeplearning/icing_fcn.py
index 89166211..f70d69d7 100644
--- a/modules/deeplearning/icing_fcn.py
+++ b/modules/deeplearning/icing_fcn.py
@@ -2,6 +2,7 @@ import tensorflow as tf
 from util.setup import logdir, modeldir, cachepath, now, ancillary_path, home_dir
 from util.util import EarlyStop, normalize, make_for_full_domain_predict, get_training_parameters
 from util.geos_nav import get_navigation
+from util.augment import augment_image
 
 import os, datetime
 import numpy as np
@@ -315,26 +316,26 @@ class IcingIntensityFCN:
             label = np.where(np.invert(np.logical_or(label == 0, label == 1)), 2, label)
             label = label.reshape((label.shape[0], 1))
 
-        if is_training and DO_AUGMENT:
-            data_ud = np.flip(data, axis=1)
-            data_alt_ud = np.copy(data_alt)
-            label_ud = np.copy(label)
-
-            data_lr = np.flip(data, axis=2)
-            data_alt_lr = np.copy(data_alt)
-            label_lr = np.copy(label)
-
-            data_r1 = np.rot90(data, k=1, axes=(1, 2))
-            data_alt_r1 = np.copy(data_alt)
-            label_r1 = np.copy(label)
-
-            data_r2 = np.rot90(data, k=1, axes=(1, 2))
-            data_alt_r2 = np.copy(data_alt)
-            label_r2 = np.copy(label)
-
-            data = np.concatenate([data, data_ud, data_lr, data_r1, data_r2])
-            data_alt = np.concatenate([data_alt, data_alt_ud, data_alt_lr, data_alt_r1, data_alt_r2])
-            label = np.concatenate([label, label_ud, label_lr, label_r1, label_r2])
+        # if is_training and DO_AUGMENT:
+        #     data_ud = np.flip(data, axis=1)
+        #     data_alt_ud = np.copy(data_alt)
+        #     label_ud = np.copy(label)
+        #
+        #     data_lr = np.flip(data, axis=2)
+        #     data_alt_lr = np.copy(data_alt)
+        #     label_lr = np.copy(label)
+        #
+        #     data_r1 = np.rot90(data, k=1, axes=(1, 2))
+        #     data_alt_r1 = np.copy(data_alt)
+        #     label_r1 = np.copy(label)
+        #
+        #     data_r2 = np.rot90(data, k=1, axes=(1, 2))
+        #     data_alt_r2 = np.copy(data_alt)
+        #     label_r2 = np.copy(label)
+        #
+        #     data = np.concatenate([data, data_ud, data_lr, data_r1, data_r2])
+        #     data_alt = np.concatenate([data_alt, data_alt_ud, data_alt_lr, data_alt_r1, data_alt_r2])
+        #     label = np.concatenate([label, label_ud, label_lr, label_r1, label_r2])
 
         return data, data_alt, label
 
@@ -475,8 +476,9 @@ class IcingIntensityFCN:
         dataset = dataset.batch(PROC_BATCH_SIZE)
         dataset = dataset.map(self.data_function, num_parallel_calls=8)
         dataset = dataset.cache()
+        dataset = dataset.shuffle(PROC_BATCH_BUFFER_SIZE, reshuffle_each_iteration=True)
         if DO_AUGMENT:
-            dataset = dataset.shuffle(PROC_BATCH_BUFFER_SIZE)
+            dataset = dataset.map(augment_image(), num_parallel_calls=8)
         dataset = dataset.prefetch(buffer_size=1)
         self.train_dataset = dataset
 
-- 
GitLab