diff --git a/modules/deeplearning/cloudheight.py b/modules/deeplearning/cloudheight.py index f57c5599855fe76e8f1bf2e90481b2e12ae26804..c4562315b0b87e92587f9ccbb9a678dfc22eaeb6 100644 --- a/modules/deeplearning/cloudheight.py +++ b/modules/deeplearning/cloudheight.py @@ -811,14 +811,13 @@ class CloudHeightNN: self.writer_valid.close() def build_model(self): - with self.strategy.scope(): - flat = self.build_cnn() - flat_1d = self.build_1d_cnn() - # flat_anc = self.build_anc_dnn() - # flat = tf.keras.layers.concatenate([flat, flat_1d, flat_anc]) - flat = tf.keras.layers.concatenate([flat, flat_1d]) - self.build_dnn(flat) - self.model = tf.keras.Model(self.inputs, self.logits) + flat = self.build_cnn() + flat_1d = self.build_1d_cnn() + # flat_anc = self.build_anc_dnn() + # flat = tf.keras.layers.concatenate([flat, flat_1d, flat_anc]) + flat = tf.keras.layers.concatenate([flat, flat_1d]) + self.build_dnn(flat) + self.model = tf.keras.Model(self.inputs, self.logits) def restore(self, ckpt_dir): @@ -844,11 +843,12 @@ class CloudHeightNN: print('acc_5', self.num_5, self.accuracy_5.result()) def run(self, matchup_dict, train_dict=None, valid_dict=None): - self.setup_pipeline(matchup_dict, train_dict=train_dict, valid_test_dict=valid_dict) - self.build_model() - self.build_training() - self.build_evaluation() - self.do_training() + with self.strategy.scope(): + self.setup_pipeline(matchup_dict, train_dict=train_dict, valid_test_dict=valid_dict) + self.build_model() + self.build_training() + self.build_evaluation() + self.do_training() def run_restore(self, matchup_dict, ckpt_dir): self.setup_pipeline(None, None, matchup_dict)