From 3dfd2c50017c25c23327b9ee724621aa208a2980 Mon Sep 17 00:00:00 2001
From: tomrink <rink@ssec.wisc.edu>
Date: Fri, 22 Jan 2021 14:49:40 -0600
Subject: [PATCH] snapshot...

---
 modules/deeplearning/cloudheight.py | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/modules/deeplearning/cloudheight.py b/modules/deeplearning/cloudheight.py
index 8316f11c..1728c076 100644
--- a/modules/deeplearning/cloudheight.py
+++ b/modules/deeplearning/cloudheight.py
@@ -18,7 +18,7 @@ DISK_CACHE = True
 
 PROC_BATCH_SIZE = 40
 NumLabels = 1
-BATCH_SIZE = 384
+BATCH_SIZE = 512
 NUM_EPOCHS = 200
 GLOBAL_BATCH_SIZE = BATCH_SIZE * 3
 
@@ -809,9 +809,10 @@ class CloudHeightNN:
             self.test_accuracy.reset_states()
             for abi, temp, lbfp, sfc in self.test_dataset:
                 ds = tf.data.Dataset.from_tensor_slices((abi, temp, lbfp, sfc))
-                ds = ds.batch(BATCH_SIZE)
-                for mini_batch in ds:
-                    self.test_step(mini_batch)
+                ds = ds.batch(GLOBAL_BATCH_SIZE)
+                tst_dist_ds = self.strategy.experimental_distribute_dataset(ds)
+                for mini_batch_test in tst_dist_ds:
+                    self.distributed_test_step(mini_batch_test)
 
             print('loss, acc: ', self.test_loss.result(), self.test_accuracy.result())
             ckpt_manager.save()
-- 
GitLab