diff --git a/modules/deeplearning/cloudheight.py b/modules/deeplearning/cloudheight.py index 157c734220c6798f31e8596e5592ea70b02e22af..c74626644c6cbfedbffe9d5a4c52304ab2c152a5 100644 --- a/modules/deeplearning/cloudheight.py +++ b/modules/deeplearning/cloudheight.py @@ -260,6 +260,8 @@ class CloudHeightNN: # Memory growth must be set before GPUs have been initialized print(e) + self.strategy = tf.distribute.MirroredStrategy() + def get_in_mem_data_batch(self, time_keys): images = [] vprof = [] @@ -808,13 +810,14 @@ class CloudHeightNN: self.writer_valid.close() def build_model(self): - flat = self.build_cnn() - flat_1d = self.build_1d_cnn() - # flat_anc = self.build_anc_dnn() - # flat = tf.keras.layers.concatenate([flat, flat_1d, flat_anc]) - flat = tf.keras.layers.concatenate([flat, flat_1d]) - self.build_dnn(flat) - self.model = tf.keras.Model(self.inputs, self.logits) + with self.strategy.scope(): + flat = self.build_cnn() + flat_1d = self.build_1d_cnn() + # flat_anc = self.build_anc_dnn() + # flat = tf.keras.layers.concatenate([flat, flat_1d, flat_anc]) + flat = tf.keras.layers.concatenate([flat, flat_1d]) + self.build_dnn(flat) + self.model = tf.keras.Model(self.inputs, self.logits) def restore(self, ckpt_dir):