Skip to content
Snippets Groups Projects
Commit 34bafa0a authored by tomrink's avatar tomrink
Browse files

Merge remote-tracking branch 'origin/master'

parents 952f6873 0d655e54
No related merge requests found
......@@ -215,12 +215,14 @@ class CloudHeightNN:
self.accuracy_2 = None
self.accuracy_3 = None
self.accuracy_4 = None
self.accuracy_5 = None
self.num_0 = 0
self.num_1 = 0
self.num_2 = 0
self.num_3 = 0
self.num_4 = 0
self.num_5 = 0
self.learningRateSchedule = None
self.num_data_samples = None
......@@ -628,6 +630,7 @@ class CloudHeightNN:
self.accuracy_2 = tf.keras.metrics.MeanAbsoluteError(name='acc_2')
self.accuracy_3 = tf.keras.metrics.MeanAbsoluteError(name='acc_3')
self.accuracy_4 = tf.keras.metrics.MeanAbsoluteError(name='acc_4')
self.accuracy_5 = tf.keras.metrics.MeanAbsoluteError(name='acc_5')
def build_predict(self):
_, pred = tf.nn.top_k(self.logits)
......@@ -695,6 +698,10 @@ class CloudHeightNN:
self.num_4 += np.sum(m)
self.accuracy_4(labels[m], pred[m])
m = np.logical_and(labels >= 0.01, labels < 0.5)
self.num_5 += np.sum(m)
self.accuracy_5(labels[m], pred[m])
def do_training(self, ckpt_dir=None):
if ckpt_dir is None:
......@@ -813,6 +820,7 @@ class CloudHeightNN:
print('acc_2', self.num_2, self.accuracy_2.result())
print('acc_3', self.num_3, self.accuracy_3.result())
print('acc_4', self.num_4, self.accuracy_4.result())
print('acc_5', self.num_5, self.accuracy_5.result())
def run(self, matchup_dict, train_dict=None, valid_dict=None):
self.setup_pipeline(matchup_dict, train_dict=train_dict, valid_test_dict=valid_dict)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment