diff --git a/modules/deeplearning/cloudheight.py b/modules/deeplearning/cloudheight.py
index 157c734220c6798f31e8596e5592ea70b02e22af..c74626644c6cbfedbffe9d5a4c52304ab2c152a5 100644
--- a/modules/deeplearning/cloudheight.py
+++ b/modules/deeplearning/cloudheight.py
@@ -260,6 +260,8 @@ class CloudHeightNN:
                 # Memory growth must be set before GPUs have been initialized
                 print(e)
 
+        self.strategy = tf.distribute.MirroredStrategy()
+
     def get_in_mem_data_batch(self, time_keys):
         images = []
         vprof = []
@@ -808,13 +810,14 @@ class CloudHeightNN:
         self.writer_valid.close()
 
     def build_model(self):
-        flat = self.build_cnn()
-        flat_1d = self.build_1d_cnn()
-        # flat_anc = self.build_anc_dnn()
-        # flat = tf.keras.layers.concatenate([flat, flat_1d, flat_anc])
-        flat = tf.keras.layers.concatenate([flat, flat_1d])
-        self.build_dnn(flat)
-        self.model = tf.keras.Model(self.inputs, self.logits)
+        with self.strategy.scope():
+            flat = self.build_cnn()
+            flat_1d = self.build_1d_cnn()
+            # flat_anc = self.build_anc_dnn()
+            # flat = tf.keras.layers.concatenate([flat, flat_1d, flat_anc])
+            flat = tf.keras.layers.concatenate([flat, flat_1d])
+            self.build_dnn(flat)
+            self.model = tf.keras.Model(self.inputs, self.logits)
 
     def restore(self, ckpt_dir):