Skip to content
Snippets Groups Projects
Commit 0d655e54 authored by rink's avatar rink
Browse files

snapshot..

parent 64290210
No related branches found
No related tags found
No related merge requests found
......@@ -215,12 +215,14 @@ class CloudHeightNN:
self.accuracy_2 = None
self.accuracy_3 = None
self.accuracy_4 = None
self.accuracy_5 = None
self.num_0 = 0
self.num_1 = 0
self.num_2 = 0
self.num_3 = 0
self.num_4 = 0
self.num_5 = 0
self.learningRateSchedule = None
self.num_data_samples = None
......@@ -628,6 +630,7 @@ class CloudHeightNN:
self.accuracy_2 = tf.keras.metrics.MeanAbsoluteError(name='acc_2')
self.accuracy_3 = tf.keras.metrics.MeanAbsoluteError(name='acc_3')
self.accuracy_4 = tf.keras.metrics.MeanAbsoluteError(name='acc_4')
self.accuracy_5 = tf.keras.metrics.MeanAbsoluteError(name='acc_5')
def build_predict(self):
_, pred = tf.nn.top_k(self.logits)
......@@ -695,6 +698,10 @@ class CloudHeightNN:
self.num_4 += np.sum(m)
self.accuracy_4(labels[m], pred[m])
m = np.logical_and(labels >= 0.01, labels < 0.5)
self.num_5 += np.sum(m)
self.accuracy_5(labels[m], pred[m])
def do_training(self, ckpt_dir=None):
if ckpt_dir is None:
......@@ -813,6 +820,7 @@ class CloudHeightNN:
print('acc_2', self.num_2, self.accuracy_2.result())
print('acc_3', self.num_3, self.accuracy_3.result())
print('acc_4', self.num_4, self.accuracy_4.result())
print('acc_5', self.num_5, self.accuracy_5.result())
def run(self, matchup_dict, train_dict=None, valid_dict=None):
self.setup_pipeline(matchup_dict, train_dict=train_dict, valid_test_dict=valid_dict)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment