diff --git a/modules/deeplearning/icing_cnn.py b/modules/deeplearning/icing_cnn.py index d0d3d00fd6b1c678dbb445f6b08c52d07398ff57..0f5f7ee5efd862490fe250292a789b5b5434d55d 100644 --- a/modules/deeplearning/icing_cnn.py +++ b/modules/deeplearning/icing_cnn.py @@ -15,7 +15,7 @@ LOG_DEVICE_PLACEMENT = False CACHE_DATA_IN_MEM = True -PROC_BATCH_SIZE = 2046 +PROC_BATCH_SIZE = 4096 PROC_BATCH_BUFFER_SIZE = 50000 NumClasses = 3 NumLogits = 1 @@ -210,6 +210,25 @@ class IcingIntensityNN: label = label.astype(np.int32) label = np.where(label == -1, 0, label) + # Augmentation, TODO: work into Dataset.map (efficiency) + data_aug = [] + label_aug = [] + for k in range(label.shape[0]): + if label[k] == 3 or label[k] == 4 or label[k] == 5 or label[k] == 6: + data_aug.append(tf.image.flip_up_down(data[k,]).numpy()) + data_aug.append(tf.image.flip_left_right(data[k,]).numpy()) + data_aug.append(tf.image.rot90(data[k,]).numpy()) + label_aug.append(label[k]) + label_aug.append(label[k]) + label_aug.append(label[k]) + + data_aug = np.stack(data_aug) + label_aug = np.stack(label_aug) + + data = np.concatenate([data, data_aug]) + label = np.concatenate([label, label_aug]) + # -------------------------------------------------------- + # binary, two class if NumClasses == 2: label = np.where(label != 0, 1, label)