Skip to content
Snippets Groups Projects
Commit 34bafa0a authored by tomrink's avatar tomrink
Browse files

Merge remote-tracking branch 'origin/master'

parents 952f6873 0d655e54
No related branches found
No related tags found
No related merge requests found
...@@ -215,12 +215,14 @@ class CloudHeightNN: ...@@ -215,12 +215,14 @@ class CloudHeightNN:
self.accuracy_2 = None self.accuracy_2 = None
self.accuracy_3 = None self.accuracy_3 = None
self.accuracy_4 = None self.accuracy_4 = None
self.accuracy_5 = None
self.num_0 = 0 self.num_0 = 0
self.num_1 = 0 self.num_1 = 0
self.num_2 = 0 self.num_2 = 0
self.num_3 = 0 self.num_3 = 0
self.num_4 = 0 self.num_4 = 0
self.num_5 = 0
self.learningRateSchedule = None self.learningRateSchedule = None
self.num_data_samples = None self.num_data_samples = None
...@@ -628,6 +630,7 @@ class CloudHeightNN: ...@@ -628,6 +630,7 @@ class CloudHeightNN:
self.accuracy_2 = tf.keras.metrics.MeanAbsoluteError(name='acc_2') self.accuracy_2 = tf.keras.metrics.MeanAbsoluteError(name='acc_2')
self.accuracy_3 = tf.keras.metrics.MeanAbsoluteError(name='acc_3') self.accuracy_3 = tf.keras.metrics.MeanAbsoluteError(name='acc_3')
self.accuracy_4 = tf.keras.metrics.MeanAbsoluteError(name='acc_4') self.accuracy_4 = tf.keras.metrics.MeanAbsoluteError(name='acc_4')
self.accuracy_5 = tf.keras.metrics.MeanAbsoluteError(name='acc_5')
def build_predict(self): def build_predict(self):
_, pred = tf.nn.top_k(self.logits) _, pred = tf.nn.top_k(self.logits)
...@@ -695,6 +698,10 @@ class CloudHeightNN: ...@@ -695,6 +698,10 @@ class CloudHeightNN:
self.num_4 += np.sum(m) self.num_4 += np.sum(m)
self.accuracy_4(labels[m], pred[m]) self.accuracy_4(labels[m], pred[m])
m = np.logical_and(labels >= 0.01, labels < 0.5)
self.num_5 += np.sum(m)
self.accuracy_5(labels[m], pred[m])
def do_training(self, ckpt_dir=None): def do_training(self, ckpt_dir=None):
if ckpt_dir is None: if ckpt_dir is None:
...@@ -813,6 +820,7 @@ class CloudHeightNN: ...@@ -813,6 +820,7 @@ class CloudHeightNN:
print('acc_2', self.num_2, self.accuracy_2.result()) print('acc_2', self.num_2, self.accuracy_2.result())
print('acc_3', self.num_3, self.accuracy_3.result()) print('acc_3', self.num_3, self.accuracy_3.result())
print('acc_4', self.num_4, self.accuracy_4.result()) print('acc_4', self.num_4, self.accuracy_4.result())
print('acc_5', self.num_5, self.accuracy_5.result())
def run(self, matchup_dict, train_dict=None, valid_dict=None): def run(self, matchup_dict, train_dict=None, valid_dict=None):
self.setup_pipeline(matchup_dict, train_dict=train_dict, valid_test_dict=valid_dict) self.setup_pipeline(matchup_dict, train_dict=train_dict, valid_test_dict=valid_dict)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment