diff --git a/modules/deeplearning/icing_cnn.py b/modules/deeplearning/icing_cnn.py
index 57bc98b02dfc3adec3d83ac23b7445e6447161ca..4fc9fc5287f55eb9eeeb833c46dbed002f15435a 100644
--- a/modules/deeplearning/icing_cnn.py
+++ b/modules/deeplearning/icing_cnn.py
@@ -9,6 +9,7 @@ import pickle
 import h5py
 
 from icing.pirep_goes import split_data, normalize
+from util.plot_cm import plot_confusion_matrix
 
 LOG_DEVICE_PLACEMENT = False
 
@@ -240,12 +241,13 @@ class IcingIntensityNN:
         dataset = dataset.map(self.data_function, num_parallel_calls=8)
         self.test_dataset = dataset
 
-    def setup_pipeline(self, filename):
+    def setup_pipeline(self, filename, trn_idxs=None, tst_idxs=None):
         self.filename = filename
         self.h5f = h5py.File(filename, 'r')
-        time = self.h5f['time']
-        num_obs = time.shape[0]
-        trn_idxs, tst_idxs = split_data(num_obs)
+        if trn_idxs is None and tst_idxs is None:
+            time = self.h5f['time']
+            num_obs = time.shape[0]
+            trn_idxs, tst_idxs = split_data(num_obs)
         self.num_data_samples = trn_idxs.shape[0]
 
         self.get_train_dataset(trn_idxs)
@@ -464,7 +466,7 @@ class IcingIntensityNN:
         t_loss = self.loss(labels, pred)
 
         self.test_labels.append(labels)
-        self.test_preds.append(pred.result().numpy())
+        self.test_preds.append(pred.numpy())
 
         self.test_loss(t_loss)
         self.test_accuracy(labels, pred)
@@ -594,9 +596,13 @@ class IcingIntensityNN:
             for mini_batch_test in ds:
                 self.predict(mini_batch_test)
         print('loss, acc: ', self.test_loss.result(), self.test_accuracy.result())
-        cm = tf.math.confusion_matrix(np.concatenate(self.test_labels), np.concatenate(self.test_preds), num_classes=2)
-        cm = cm.result().numpy()
-        print(cm)
+
+        labels = np.concatenate(self.test_labels)
+        preds = np.concatenate(self.test_preds)
+        preds = np.where(preds > 0.5, 1, 0)
+
+        self.test_labels = labels
+        self.test_preds = preds
 
     def run(self, filename, filename_l1b=None):
         with tf.device('/device:GPU:'+str(self.gpu_device)):