Skip to content
Snippets Groups Projects
Commit 0d655e54 authored by rink's avatar rink
Browse files

snapshot..

parent 64290210
No related branches found
No related tags found
No related merge requests found
...@@ -215,12 +215,14 @@ class CloudHeightNN: ...@@ -215,12 +215,14 @@ class CloudHeightNN:
self.accuracy_2 = None self.accuracy_2 = None
self.accuracy_3 = None self.accuracy_3 = None
self.accuracy_4 = None self.accuracy_4 = None
self.accuracy_5 = None
self.num_0 = 0 self.num_0 = 0
self.num_1 = 0 self.num_1 = 0
self.num_2 = 0 self.num_2 = 0
self.num_3 = 0 self.num_3 = 0
self.num_4 = 0 self.num_4 = 0
self.num_5 = 0
self.learningRateSchedule = None self.learningRateSchedule = None
self.num_data_samples = None self.num_data_samples = None
...@@ -628,6 +630,7 @@ class CloudHeightNN: ...@@ -628,6 +630,7 @@ class CloudHeightNN:
self.accuracy_2 = tf.keras.metrics.MeanAbsoluteError(name='acc_2') self.accuracy_2 = tf.keras.metrics.MeanAbsoluteError(name='acc_2')
self.accuracy_3 = tf.keras.metrics.MeanAbsoluteError(name='acc_3') self.accuracy_3 = tf.keras.metrics.MeanAbsoluteError(name='acc_3')
self.accuracy_4 = tf.keras.metrics.MeanAbsoluteError(name='acc_4') self.accuracy_4 = tf.keras.metrics.MeanAbsoluteError(name='acc_4')
self.accuracy_5 = tf.keras.metrics.MeanAbsoluteError(name='acc_5')
def build_predict(self): def build_predict(self):
_, pred = tf.nn.top_k(self.logits) _, pred = tf.nn.top_k(self.logits)
...@@ -695,6 +698,10 @@ class CloudHeightNN: ...@@ -695,6 +698,10 @@ class CloudHeightNN:
self.num_4 += np.sum(m) self.num_4 += np.sum(m)
self.accuracy_4(labels[m], pred[m]) self.accuracy_4(labels[m], pred[m])
m = np.logical_and(labels >= 0.01, labels < 0.5)
self.num_5 += np.sum(m)
self.accuracy_5(labels[m], pred[m])
def do_training(self, ckpt_dir=None): def do_training(self, ckpt_dir=None):
if ckpt_dir is None: if ckpt_dir is None:
...@@ -813,6 +820,7 @@ class CloudHeightNN: ...@@ -813,6 +820,7 @@ class CloudHeightNN:
print('acc_2', self.num_2, self.accuracy_2.result()) print('acc_2', self.num_2, self.accuracy_2.result())
print('acc_3', self.num_3, self.accuracy_3.result()) print('acc_3', self.num_3, self.accuracy_3.result())
print('acc_4', self.num_4, self.accuracy_4.result()) print('acc_4', self.num_4, self.accuracy_4.result())
print('acc_5', self.num_5, self.accuracy_5.result())
def run(self, matchup_dict, train_dict=None, valid_dict=None): def run(self, matchup_dict, train_dict=None, valid_dict=None):
self.setup_pipeline(matchup_dict, train_dict=train_dict, valid_test_dict=valid_dict) self.setup_pipeline(matchup_dict, train_dict=train_dict, valid_test_dict=valid_dict)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment