From c487913e220ae971a553855edb8f5266134eaef0 Mon Sep 17 00:00:00 2001
From: tomrink <rink@ssec.wisc.edu>
Date: Thu, 28 Jan 2021 16:12:49 -0600
Subject: [PATCH] snapshot...

---
 modules/deeplearning/cloudheight_mgpus.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/modules/deeplearning/cloudheight_mgpus.py b/modules/deeplearning/cloudheight_mgpus.py
index ee1e2fb6..49330f5b 100644
--- a/modules/deeplearning/cloudheight_mgpus.py
+++ b/modules/deeplearning/cloudheight_mgpus.py
@@ -21,6 +21,7 @@ NumLabels = 1
 BATCH_SIZE = 1024
 NUM_EPOCHS = 200
 GLOBAL_BATCH_SIZE = BATCH_SIZE * 3
+PROC_BATCH_BUFFER_SIZE = 50000
 
 
 TRACK_MOVING_AVERAGE = False
@@ -425,7 +426,7 @@ class CloudHeightNN:
         dataset = tf.data.Dataset.from_tensor_slices(time_keys)
         dataset = dataset.batch(PROC_BATCH_SIZE)
         dataset = dataset.map(self.data_function, num_parallel_calls=8)
-        dataset = dataset.shuffle(PROC_BATCH_SIZE*GLOBAL_BATCH_SIZE)
+        dataset = dataset.shuffle(PROC_BATCH_BUFFER_SIZE)
         dataset = dataset.prefetch(buffer_size=1)
         self.train_dataset = dataset
 
@@ -817,8 +818,6 @@ class CloudHeightNN:
             print('loss, acc: ', self.test_loss.result(), self.test_accuracy.result())
             ckpt_manager.save()
 
-            self.DISK_CACHE = True
-
             if self.DISK_CACHE and epoch == 0:
                 f = open(cachepath, 'wb')
                 pickle.dump(self.in_mem_data_cache, f)
-- 
GitLab