diff --git a/modules/deeplearning/cloudheight.py b/modules/deeplearning/cloudheight.py index eacd2fc33a7f94c3796bd85d645981ffa16b782f..55585f0e53a9f84e5d87dec611c95e5575f3b162 100644 --- a/modules/deeplearning/cloudheight.py +++ b/modules/deeplearning/cloudheight.py @@ -17,6 +17,7 @@ CACHE_GFS = True DISK_CACHE = True PROC_BATCH_SIZE = 40 +PROC_BATCH_BUFFER_SIZE = 50000 NumLabels = 1 BATCH_SIZE = 384 NUM_EPOCHS = 200 @@ -422,7 +423,7 @@ class CloudHeightNN: dataset = tf.data.Dataset.from_tensor_slices(time_keys) dataset = dataset.batch(PROC_BATCH_SIZE) dataset = dataset.map(self.data_function, num_parallel_calls=8) - dataset = dataset.shuffle(PROC_BATCH_SIZE*BATCH_SIZE) + dataset = dataset.shuffle(PROC_BATCH_BUFFER_SIZE) dataset = dataset.prefetch(buffer_size=1) self.train_dataset = dataset