Skip to content
Snippets Groups Projects
Commit f23dca45 authored by tomrink's avatar tomrink
Browse files

snapshot...

parent dd45704c
Branches
No related tags found
No related merge requests found
......@@ -811,14 +811,13 @@ class CloudHeightNN:
self.writer_valid.close()
def build_model(self):
with self.strategy.scope():
flat = self.build_cnn()
flat_1d = self.build_1d_cnn()
# flat_anc = self.build_anc_dnn()
# flat = tf.keras.layers.concatenate([flat, flat_1d, flat_anc])
flat = tf.keras.layers.concatenate([flat, flat_1d])
self.build_dnn(flat)
self.model = tf.keras.Model(self.inputs, self.logits)
flat = self.build_cnn()
flat_1d = self.build_1d_cnn()
# flat_anc = self.build_anc_dnn()
# flat = tf.keras.layers.concatenate([flat, flat_1d, flat_anc])
flat = tf.keras.layers.concatenate([flat, flat_1d])
self.build_dnn(flat)
self.model = tf.keras.Model(self.inputs, self.logits)
def restore(self, ckpt_dir):
......@@ -844,11 +843,12 @@ class CloudHeightNN:
print('acc_5', self.num_5, self.accuracy_5.result())
def run(self, matchup_dict, train_dict=None, valid_dict=None):
self.setup_pipeline(matchup_dict, train_dict=train_dict, valid_test_dict=valid_dict)
self.build_model()
self.build_training()
self.build_evaluation()
self.do_training()
with self.strategy.scope():
self.setup_pipeline(matchup_dict, train_dict=train_dict, valid_test_dict=valid_dict)
self.build_model()
self.build_training()
self.build_evaluation()
self.do_training()
def run_restore(self, matchup_dict, ckpt_dir):
self.setup_pipeline(None, None, matchup_dict)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment