From 3dfd2c50017c25c23327b9ee724621aa208a2980 Mon Sep 17 00:00:00 2001 From: tomrink <rink@ssec.wisc.edu> Date: Fri, 22 Jan 2021 14:49:40 -0600 Subject: [PATCH] snapshot... --- modules/deeplearning/cloudheight.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/modules/deeplearning/cloudheight.py b/modules/deeplearning/cloudheight.py index 8316f11c..1728c076 100644 --- a/modules/deeplearning/cloudheight.py +++ b/modules/deeplearning/cloudheight.py @@ -18,7 +18,7 @@ DISK_CACHE = True PROC_BATCH_SIZE = 40 NumLabels = 1 -BATCH_SIZE = 384 +BATCH_SIZE = 512 NUM_EPOCHS = 200 GLOBAL_BATCH_SIZE = BATCH_SIZE * 3 @@ -809,9 +809,10 @@ class CloudHeightNN: self.test_accuracy.reset_states() for abi, temp, lbfp, sfc in self.test_dataset: ds = tf.data.Dataset.from_tensor_slices((abi, temp, lbfp, sfc)) - ds = ds.batch(BATCH_SIZE) - for mini_batch in ds: - self.test_step(mini_batch) + ds = ds.batch(GLOBAL_BATCH_SIZE) + tst_dist_ds = self.strategy.experimental_distribute_dataset(ds) + for mini_batch_test in tst_dist_ds: + self.distributed_test_step(mini_batch_test) print('loss, acc: ', self.test_loss.result(), self.test_accuracy.result()) ckpt_manager.save() -- GitLab