From 5459e898a19fa93f6937e3a4e8a5bd9998b37653 Mon Sep 17 00:00:00 2001
From: tomrink <rink@ssec.wisc.edu>
Date: Tue, 19 Jan 2021 12:51:29 -0600
Subject: [PATCH] snapshot...

---
 modules/deeplearning/cloudheight.py | 17 ++++++++++-------
 1 file changed, 10 insertions(+), 7 deletions(-)

diff --git a/modules/deeplearning/cloudheight.py b/modules/deeplearning/cloudheight.py
index 157c7342..c7462664 100644
--- a/modules/deeplearning/cloudheight.py
+++ b/modules/deeplearning/cloudheight.py
@@ -260,6 +260,8 @@ class CloudHeightNN:
                 # Memory growth must be set before GPUs have been initialized
                 print(e)
 
+        self.strategy = tf.distribute.MirroredStrategy()
+
     def get_in_mem_data_batch(self, time_keys):
         images = []
         vprof = []
@@ -808,13 +810,14 @@ class CloudHeightNN:
         self.writer_valid.close()
 
     def build_model(self):
-        flat = self.build_cnn()
-        flat_1d = self.build_1d_cnn()
-        # flat_anc = self.build_anc_dnn()
-        # flat = tf.keras.layers.concatenate([flat, flat_1d, flat_anc])
-        flat = tf.keras.layers.concatenate([flat, flat_1d])
-        self.build_dnn(flat)
-        self.model = tf.keras.Model(self.inputs, self.logits)
+        with self.strategy.scope():
+            flat = self.build_cnn()
+            flat_1d = self.build_1d_cnn()
+            # flat_anc = self.build_anc_dnn()
+            # flat = tf.keras.layers.concatenate([flat, flat_1d, flat_anc])
+            flat = tf.keras.layers.concatenate([flat, flat_1d])
+            self.build_dnn(flat)
+            self.model = tf.keras.Model(self.inputs, self.logits)
 
     def restore(self, ckpt_dir):
 
-- 
GitLab