From 8455676119a05b3d791421434f47b73b8a5a4aa0 Mon Sep 17 00:00:00 2001
From: tomrink <rink@ssec.wisc.edu>
Date: Mon, 3 May 2021 13:16:57 -0500
Subject: [PATCH] snapshot...

---
 modules/deeplearning/icing_cnn.py | 22 ++++++++++++++--------
 1 file changed, 14 insertions(+), 8 deletions(-)

diff --git a/modules/deeplearning/icing_cnn.py b/modules/deeplearning/icing_cnn.py
index 57bc98b0..4fc9fc52 100644
--- a/modules/deeplearning/icing_cnn.py
+++ b/modules/deeplearning/icing_cnn.py
@@ -9,6 +9,7 @@ import pickle
 import h5py
 
 from icing.pirep_goes import split_data, normalize
+from util.plot_cm import plot_confusion_matrix
 
 LOG_DEVICE_PLACEMENT = False
 
@@ -240,12 +241,13 @@ class IcingIntensityNN:
         dataset = dataset.map(self.data_function, num_parallel_calls=8)
         self.test_dataset = dataset
 
-    def setup_pipeline(self, filename):
+    def setup_pipeline(self, filename, trn_idxs=None, tst_idxs=None):
         self.filename = filename
         self.h5f = h5py.File(filename, 'r')
-        time = self.h5f['time']
-        num_obs = time.shape[0]
-        trn_idxs, tst_idxs = split_data(num_obs)
+        if trn_idxs is None and tst_idxs is None:
+            time = self.h5f['time']
+            num_obs = time.shape[0]
+            trn_idxs, tst_idxs = split_data(num_obs)
         self.num_data_samples = trn_idxs.shape[0]
 
         self.get_train_dataset(trn_idxs)
@@ -464,7 +466,7 @@ class IcingIntensityNN:
         t_loss = self.loss(labels, pred)
 
         self.test_labels.append(labels)
-        self.test_preds.append(pred.result().numpy())
+        self.test_preds.append(pred.numpy())
 
         self.test_loss(t_loss)
         self.test_accuracy(labels, pred)
@@ -594,9 +596,13 @@ class IcingIntensityNN:
             for mini_batch_test in ds:
                 self.predict(mini_batch_test)
         print('loss, acc: ', self.test_loss.result(), self.test_accuracy.result())
-        cm = tf.math.confusion_matrix(np.concatenate(self.test_labels), np.concatenate(self.test_preds), num_classes=2)
-        cm = cm.result().numpy()
-        print(cm)
+
+        labels = np.concatenate(self.test_labels)
+        preds = np.concatenate(self.test_preds)
+        preds = np.where(preds > 0.5, 1, 0)
+
+        self.test_labels = labels
+        self.test_preds = preds
 
     def run(self, filename, filename_l1b=None):
         with tf.device('/device:GPU:'+str(self.gpu_device)):
-- 
GitLab