diff --git a/modules/deeplearning/cloud_fraction_fcn_abi_hkm_refl.py b/modules/deeplearning/cloud_fraction_fcn_abi_hkm_refl.py index abf950cd8af00af7c14afae640e92fd92f41ef3b..f8fc9ef8010ec965b531c5962550d1c712fa42bf 100644 --- a/modules/deeplearning/cloud_fraction_fcn_abi_hkm_refl.py +++ b/modules/deeplearning/cloud_fraction_fcn_abi_hkm_refl.py @@ -1,7 +1,7 @@ import tensorflow as tf from util.plot_cm import confusion_matrix_values -from util.augment import augment_image +from util.augment import augment_image_3arg from util.setup_cloud_products import logdir, modeldir, now, ancillary_path from util.util import EarlyStop, normalize, denormalize, scale, scale2, get_grid_values_all, make_tf_callable_generator import glob @@ -445,7 +445,7 @@ class SRCNN: dataset = dataset.cache(filename=CACHE_FILE) dataset = dataset.shuffle(PROC_BATCH_BUFFER_SIZE, reshuffle_each_iteration=True) if DO_AUGMENT: - dataset = dataset.map(augment_image(), num_parallel_calls=8) + dataset = dataset.map(augment_image_3arg(), num_parallel_calls=8) dataset = dataset.prefetch(buffer_size=1) self.train_dataset = dataset