From 5459e898a19fa93f6937e3a4e8a5bd9998b37653 Mon Sep 17 00:00:00 2001 From: tomrink <rink@ssec.wisc.edu> Date: Tue, 19 Jan 2021 12:51:29 -0600 Subject: [PATCH] snapshot... --- modules/deeplearning/cloudheight.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/modules/deeplearning/cloudheight.py b/modules/deeplearning/cloudheight.py index 157c7342..c7462664 100644 --- a/modules/deeplearning/cloudheight.py +++ b/modules/deeplearning/cloudheight.py @@ -260,6 +260,8 @@ class CloudHeightNN: # Memory growth must be set before GPUs have been initialized print(e) + self.strategy = tf.distribute.MirroredStrategy() + def get_in_mem_data_batch(self, time_keys): images = [] vprof = [] @@ -808,13 +810,14 @@ class CloudHeightNN: self.writer_valid.close() def build_model(self): - flat = self.build_cnn() - flat_1d = self.build_1d_cnn() - # flat_anc = self.build_anc_dnn() - # flat = tf.keras.layers.concatenate([flat, flat_1d, flat_anc]) - flat = tf.keras.layers.concatenate([flat, flat_1d]) - self.build_dnn(flat) - self.model = tf.keras.Model(self.inputs, self.logits) + with self.strategy.scope(): + flat = self.build_cnn() + flat_1d = self.build_1d_cnn() + # flat_anc = self.build_anc_dnn() + # flat = tf.keras.layers.concatenate([flat, flat_1d, flat_anc]) + flat = tf.keras.layers.concatenate([flat, flat_1d]) + self.build_dnn(flat) + self.model = tf.keras.Model(self.inputs, self.logits) def restore(self, ckpt_dir): -- GitLab