diff --git a/modules/deeplearning/cloudheight_mgpus.py b/modules/deeplearning/cloudheight_mgpus.py index ee1e2fb637a20aa9a248c98ce94f6a08f13214f9..49330f5b38a59f70a41dab057634614436c4b655 100644 --- a/modules/deeplearning/cloudheight_mgpus.py +++ b/modules/deeplearning/cloudheight_mgpus.py @@ -21,6 +21,7 @@ NumLabels = 1 BATCH_SIZE = 1024 NUM_EPOCHS = 200 GLOBAL_BATCH_SIZE = BATCH_SIZE * 3 +PROC_BATCH_BUFFER_SIZE = 50000 TRACK_MOVING_AVERAGE = False @@ -425,7 +426,7 @@ class CloudHeightNN: dataset = tf.data.Dataset.from_tensor_slices(time_keys) dataset = dataset.batch(PROC_BATCH_SIZE) dataset = dataset.map(self.data_function, num_parallel_calls=8) - dataset = dataset.shuffle(PROC_BATCH_SIZE*GLOBAL_BATCH_SIZE) + dataset = dataset.shuffle(PROC_BATCH_BUFFER_SIZE) dataset = dataset.prefetch(buffer_size=1) self.train_dataset = dataset @@ -817,8 +818,6 @@ class CloudHeightNN: print('loss, acc: ', self.test_loss.result(), self.test_accuracy.result()) ckpt_manager.save() - self.DISK_CACHE = True - if self.DISK_CACHE and epoch == 0: f = open(cachepath, 'wb') pickle.dump(self.in_mem_data_cache, f)