diff --git a/modules/deeplearning/icing_cnn.py b/modules/deeplearning/icing_cnn.py index eac05cb067c4b8f4d8781aee9f5cfabc96471290..c07277d19bd4bcb08226ae688879e84dd2323452 100644 --- a/modules/deeplearning/icing_cnn.py +++ b/modules/deeplearning/icing_cnn.py @@ -2,6 +2,7 @@ import tensorflow as tf from util.setup import logdir, modeldir, cachepath, now, ancillary_path, home_dir from util.util import EarlyStop, normalize from icing.util import make_for_full_domain_predict +from util.augment import augment_icing from util.geos_nav import get_navigation, get_lon_lat_2d_mesh import os, datetime @@ -440,8 +441,10 @@ class IcingIntensityNN: dataset = dataset.batch(PROC_BATCH_SIZE) dataset = dataset.map(self.data_function, num_parallel_calls=8) dataset = dataset.cache() + # dataset = dataset.shuffle(PROC_BATCH_BUFFER_SIZE, reshuffle_each_iteration=True) if DO_AUGMENT: dataset = dataset.shuffle(PROC_BATCH_BUFFER_SIZE) + # dataset = dataset.map(augment_icing(), num_parallel_calls=8) dataset = dataset.prefetch(buffer_size=1) self.train_dataset = dataset