diff --git a/modules/deeplearning/icing_cnn.py b/modules/deeplearning/icing_cnn.py index 52b60959fe0f668964ae7ba9a3437e2c6238414c..061cbaa3c7ab884bc83f1d938414645e7628ed59 100644 --- a/modules/deeplearning/icing_cnn.py +++ b/modules/deeplearning/icing_cnn.py @@ -8,8 +8,7 @@ import numpy as np import pickle import h5py -from icing.pirep_goes import split_data, normalize -from util.plot_cm import plot_confusion_matrix +from icing.pirep_goes import normalize LOG_DEVICE_PLACEMENT = False @@ -151,6 +150,10 @@ class IcingIntensityNN: self.test_recall = None self.test_precision = None self.test_confusion_matrix = None + self.test_true_pos = None + self.test_true_neg = None + self.test_false_pos = None + self.test_false_neg = None self.test_labels = [] self.test_preds = [] @@ -479,12 +482,20 @@ class IcingIntensityNN: self.test_auc = tf.keras.metrics.AUC(name='test_auc') self.test_recall = tf.keras.metrics.Recall(name='test_recall') self.test_precision = tf.keras.metrics.Precision(name='test_precision') + self.test_true_neg = tf.keras.metrics.TrueNegatives(name='test_true_neg') + self.test_true_pos = tf.keras.metrics.TruePositives(name='test_true_pos') + self.test_false_neg = tf.keras.metrics.FalseNegatives(name='test_false_neg') + self.test_false_pos = tf.keras.metrics.FalsePositives(name='test_false_pos') else: self.train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy') self.test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='test_accuracy') self.test_auc = tf.keras.metrics.AUC(name='test_auc') self.test_recall = tf.keras.metrics.Recall(name='test_recall') self.test_precision = tf.keras.metrics.Precision(name='test_precision') + self.test_true_neg = tf.keras.metrics.TrueNegatives(name='test_true_neg') + self.test_true_pos = tf.keras.metrics.TruePositives(name='test_true_pos') + self.test_false_neg = tf.keras.metrics.FalseNegatives(name='test_false_neg') + self.test_false_pos = tf.keras.metrics.FalsePositives(name='test_false_pos') 
def build_predict(self): _, pred = tf.nn.top_k(self.logits) @@ -526,6 +537,10 @@ class IcingIntensityNN: self.test_auc(labels, pred) self.test_recall(labels, pred) self.test_precision(labels, pred) + self.test_true_neg(labels, pred) + self.test_true_pos(labels, pred) + self.test_false_neg(labels, pred) + self.test_false_pos(labels, pred) def predict(self, mini_batch): inputs = [mini_batch[0]] @@ -541,6 +556,10 @@ class IcingIntensityNN: self.test_auc(labels, pred) self.test_recall(labels, pred) self.test_precision(labels, pred) + self.test_true_neg(labels, pred) + self.test_true_pos(labels, pred) + self.test_false_neg(labels, pred) + self.test_false_pos(labels, pred) def do_training(self, ckpt_dir=None): @@ -587,6 +606,10 @@ class IcingIntensityNN: self.test_auc.reset_states() self.test_recall.reset_states() self.test_precision.reset_states() + self.test_true_neg.reset_states() + self.test_true_pos.reset_states() + self.test_false_neg.reset_states() + self.test_false_pos.reset_states() for data0_tst, label_tst in self.test_dataset: tst_ds = tf.data.Dataset.from_tensor_slices((data0_tst, label_tst)) @@ -621,6 +644,10 @@ class IcingIntensityNN: self.test_auc.reset_states() self.test_recall.reset_states() self.test_precision.reset_states() + self.test_true_neg.reset_states() + self.test_true_pos.reset_states() + self.test_false_neg.reset_states() + self.test_false_pos.reset_states() for data0, label in self.test_dataset: ds = tf.data.Dataset.from_tensor_slices((data0, label)) @@ -632,8 +659,15 @@ class IcingIntensityNN: precsn = self.test_precision.result().numpy() f1 = 2 * (precsn * recall) / (precsn + recall) - print('loss, acc, recall, precision, auc, f1: ', self.test_loss.result().numpy(), self.test_accuracy.result().numpy(), - recall, precsn, self.test_auc.result().numpy(), f1) + tn = self.test_true_neg.result().numpy() + tp = self.test_true_pos.result().numpy() + fn = self.test_false_neg.result().numpy() + fp = self.test_false_pos.result().numpy() + + mcc = 
((tp*tn) - (fp*fn)) / np.sqrt((tp+fp)*(tp+fn)*(tn+fp)*(tn+fn)) + + print('loss, acc, recall, precision, auc, f1, mcc: ', self.test_loss.result().numpy(), self.test_accuracy.result().numpy(), + recall, precsn, self.test_auc.result().numpy(), f1, mcc) print('--------------------------------------------------') ckpt_manager.save()