From 69f178e619929d538f65a14e5504db01f6670e16 Mon Sep 17 00:00:00 2001 From: tomrink <rink@ssec.wisc.edu> Date: Wed, 23 Mar 2022 09:08:22 -0500 Subject: [PATCH] snapshot... --- modules/deeplearning/icing_cnn.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/modules/deeplearning/icing_cnn.py b/modules/deeplearning/icing_cnn.py index 15b3bafc..c05cdcff 100644 --- a/modules/deeplearning/icing_cnn.py +++ b/modules/deeplearning/icing_cnn.py @@ -28,6 +28,7 @@ NUM_EPOCHS = 100 TRACK_MOVING_AVERAGE = False EARLY_STOP = False +DO_AUGMENT = True TRIPLET = False CONV3D = False @@ -311,6 +312,19 @@ class IcingIntensityNN: else: self.in_mem_data_cache_test[key] = (data, data_alt, label) + if is_training and DO_AUGMENT: + data_ud = np.flip(data, axis=1) + data_alt_ud = np.copy(data_alt) + label_ud = np.copy(label) + + data_lr = np.flip(data, axis=2) + data_alt_lr = np.copy(data_alt) + label_lr = np.copy(label) + + data = np.concatenate([data, data_ud, data_lr]) + data_alt = np.concatenate([data_alt, data_alt_ud, data_alt_lr]) + label = np.concatenate([label, label_ud, label_lr]) + return data, data_alt, label def get_parameter_data(self, param, nd_idxs, is_training): -- GitLab