From 8fdaec386074626d914776df583b5e949afb5ae9 Mon Sep 17 00:00:00 2001
From: tomrink <rink@ssec.wisc.edu>
Date: Thu, 21 Jan 2021 10:20:29 -0600
Subject: [PATCH] Distribute training across multiple GPUs with MirroredStrategy

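Distribute training across multiple GPUs via tf.distribute.MirroredStrategy:

* create the model, optimizer, and metrics under strategy.scope() instead
  of pinning the whole run to a single GPU device
* switch MeanSquaredError to Reduction.NONE and average the per-example
  losses explicitly with tf.nn.compute_average_loss so the loss (and its
  gradients) are scaled by the global batch size
* add tf.function-wrapped distributed_train_step/distributed_test_step
  helpers that dispatch to the per-replica steps via strategy.run and
  SUM-reduce the per-replica losses
* batch by GLOBAL_BATCH_SIZE (per-replica BATCH_SIZE x 3 replicas) and
  split each batch across GPUs with strategy.experimental_distribute_dataset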
---
 modules/deeplearning/cloudheight.py | 43 ++++++++++++++++++++++++++++++++++---------
 1 file changed, 34 insertions(+), 9 deletions(-)

diff --git a/modules/deeplearning/cloudheight.py b/modules/deeplearning/cloudheight.py
index eacd2fc3..c77ac2f5 100644
--- a/modules/deeplearning/cloudheight.py
+++ b/modules/deeplearning/cloudheight.py
@@ -20,6 +20,8 @@ PROC_BATCH_SIZE = 40
 NumLabels = 1
 BATCH_SIZE = 384
 NUM_EPOCHS = 200
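+# Effective batch across all replicas; the factor of 3 assumes three GPUs and must match strategy.num_replicas_in_sync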
+GLOBAL_BATCH_SIZE = BATCH_SIZE * 3
 
 
 TRACK_MOVING_AVERAGE = False
@@ -261,6 +263,9 @@ class CloudHeightNN:
                 # Memory growth must be set before GPUs have been initialized
                 print(e)
 
+        self.strategy = tf.distribute.MirroredStrategy()
+
     def get_in_mem_data_batch(self, time_keys):
         images = []
         vprof = []
@@ -422,7 +427,8 @@ class CloudHeightNN:
         dataset = tf.data.Dataset.from_tensor_slices(time_keys)
         dataset = dataset.batch(PROC_BATCH_SIZE)
         dataset = dataset.map(self.data_function, num_parallel_calls=8)
-        dataset = dataset.shuffle(PROC_BATCH_SIZE*BATCH_SIZE)
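+        # Scale the shuffle buffer with the global batch size (each element here is one mapped PROC_BATCH group)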
+        dataset = dataset.shuffle(PROC_BATCH_SIZE*GLOBAL_BATCH_SIZE)
         dataset = dataset.prefetch(buffer_size=1)
         self.train_dataset = dataset
 
@@ -614,7 +620,8 @@ class CloudHeightNN:
         self.logits = logits
 
     def build_training(self):
-        self.loss = tf.keras.losses.MeanSquaredError()
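+        # Reduction.NONE keeps per-example losses; with MirroredStrategy the averaging is done explicitly in train_step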
+        self.loss = tf.keras.losses.MeanSquaredError(reduction=tf.keras.losses.Reduction.NONE)
 
         # decayed_learning_rate = learning_rate * decay_rate ^ (global_step / decay_steps)
         initial_learning_rate = 0.0016
@@ -665,6 +672,8 @@ class CloudHeightNN:
         with tf.GradientTape() as tape:
             pred = self.model(inputs, training=True)
             loss = self.loss(labels, pred)
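+            # Divide this replica's summed loss by the full global batch so SUM-reduced gradients equal those of the
+            # global-mean loss; assumes full batches (a short final batch is slightly down-weighted)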
+            loss = tf.nn.compute_average_loss(loss, global_batch_size=GLOBAL_BATCH_SIZE)
             total_loss = loss
             if len(self.model.losses) > 0:
                 reg_loss = tf.math.add_n(self.model.losses)
@@ -683,10 +693,23 @@ class CloudHeightNN:
         labels = mini_batch[2]
         pred = self.model(inputs, training=False)
         t_loss = self.loss(labels, pred)
+        # t_loss stays per-example here: assuming self.test_loss is a tf.keras.metrics.Mean, it already averages
+        # across replicas, and pre-scaling with compute_average_loss would understate the loss by the replica count
 
         self.test_loss(t_loss)
         self.test_accuracy(labels, pred)
 
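+    # Run train_step on each replica's shard; summing the per-replica losses (each already divided by the global
+    # batch size) yields the global-mean loss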
+    @tf.function
+    def distributed_train_step(self, dataset_inputs):
+        per_replica_losses = self.strategy.run(self.train_step, args=(dataset_inputs,))
+        return self.strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)
+
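+    # Metrics updated inside test_step accumulate across replicas on their own, so no explicit reduction is needed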
+    @tf.function
+    def distributed_test_step(self, dataset_inputs):
+        return self.strategy.run(self.test_step, args=(dataset_inputs,))
+
     def predict(self, mini_batch):
         inputs = [mini_batch[0], mini_batch[1], mini_batch[3]]
         labels = mini_batch[2]
@@ -748,10 +773,12 @@ class CloudHeightNN:
 
             for abi, temp, lbfp, sfc in self.train_dataset:
                 trn_ds = tf.data.Dataset.from_tensor_slices((abi, temp, lbfp, sfc))
-                trn_ds = trn_ds.batch(BATCH_SIZE)
-                for mini_batch in trn_ds:
+                trn_ds = trn_ds.batch(GLOBAL_BATCH_SIZE)
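+                # experimental_distribute_dataset splits each GLOBAL_BATCH_SIZE batch evenly across the replicas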
+                trn_dist_ds = self.strategy.experimental_distribute_dataset(trn_ds)
+                for mini_batch in trn_dist_ds:
                     if self.learningRateSchedule is not None:
-                        loss = self.train_step(mini_batch)
+                        loss = self.distributed_train_step(mini_batch)
 
                     if (step % 100) == 0:
 
@@ -765,9 +792,10 @@ class CloudHeightNN:
 
                         for abi_tst, temp_tst, lbfp_tst, sfc_tst in self.test_dataset:
                             tst_ds = tf.data.Dataset.from_tensor_slices((abi_tst, temp_tst, lbfp_tst, sfc_tst))
-                            tst_ds = tst_ds.batch(BATCH_SIZE)
-                            for mini_batch_test in tst_ds:
-                                self.test_step(mini_batch_test)
+                            tst_ds = tst_ds.batch(GLOBAL_BATCH_SIZE)
+                            tst_dist_ds = self.strategy.experimental_distribute_dataset(tst_ds)
+                            for mini_batch_test in tst_dist_ds:
+                                self.distributed_test_step(mini_batch_test)
 
                         with self.writer_valid.as_default():
                             tf.summary.scalar('loss_val', self.test_loss.result(), step=step)
@@ -841,7 +869,8 @@ class CloudHeightNN:
         print('acc_5', self.num_5, self.accuracy_5.result())
 
     def run(self, matchup_dict, train_dict=None, valid_dict=None):
-        with tf.device('/device:GPU:'+str(self.gpu_device)):
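+        # Variables (model, optimizer, metrics) must be created inside strategy.scope() so they are mirrored across GPUs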
+        with self.strategy.scope():
             self.setup_pipeline(matchup_dict, train_dict=train_dict, valid_test_dict=valid_dict)
             self.build_model()
             self.build_training()
-- 
GitLab