diff --git a/modules/deeplearning/icing_cnn.py b/modules/deeplearning/icing_cnn.py
index e64d9e7fde96cdc8e0cf8b187e756bb9c5ea34d0..1cb21724891bc37ec6c43fa9074e419b21ee2f22 100644
--- a/modules/deeplearning/icing_cnn.py
+++ b/modules/deeplearning/icing_cnn.py
@@ -495,13 +495,6 @@ class IcingIntensityNN:
         else:
             self.train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')
             self.test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='test_accuracy')
-            self.test_auc = tf.keras.metrics.AUC(name='test_auc')
-            self.test_recall = tf.keras.metrics.Recall(name='test_recall')
-            self.test_precision = tf.keras.metrics.Precision(name='test_precision')
-            self.test_true_neg = tf.keras.metrics.TrueNegatives(name='test_true_neg')
-            self.test_true_pos = tf.keras.metrics.TruePositives(name='test_true_pos')
-            self.test_false_neg = tf.keras.metrics.FalseNegatives(name='test_false_neg')
-            self.test_false_pos = tf.keras.metrics.FalsePositives(name='test_false_pos')
 
     @tf.function
     def train_step(self, mini_batch):
@@ -562,13 +555,14 @@ class IcingIntensityNN:
     def reset_test_metrics(self):
         self.test_loss.reset_states()
         self.test_accuracy.reset_states()
-        self.test_auc.reset_states()
-        self.test_recall.reset_states()
-        self.test_precision.reset_states()
-        self.test_true_neg.reset_states()
-        self.test_true_pos.reset_states()
-        self.test_false_neg.reset_states()
-        self.test_false_pos.reset_states()
+        if NumClasses == 2:
+            self.test_auc.reset_states()
+            self.test_recall.reset_states()
+            self.test_precision.reset_states()
+            self.test_true_neg.reset_states()
+            self.test_true_pos.reset_states()
+            self.test_false_neg.reset_states()
+            self.test_false_pos.reset_states()
 
     def get_metrics(self):
         recall = self.test_recall.result()
@@ -634,18 +628,20 @@ class IcingIntensityNN:
                             for mini_batch_test in tst_ds:
                                 self.test_step(mini_batch_test)
 
-                        f1, mcc = self.get_metrics()
+                        if NumClasses == 2:
+                            f1, mcc = self.get_metrics()
 
                         with self.writer_valid.as_default():
                             tf.summary.scalar('loss_val', self.test_loss.result(), step=step)
                             tf.summary.scalar('acc_val', self.test_accuracy.result(), step=step)
-                            tf.summary.scalar('auc_val', self.test_auc.result(), step=step)
-                            tf.summary.scalar('recall_val', self.test_recall.result(), step=step)
-                            tf.summary.scalar('prec_val', self.test_precision.result(), step=step)
-                            tf.summary.scalar('f1_val', f1, step=step)
-                            tf.summary.scalar('mcc_val', mcc, step=step)
-                            tf.summary.scalar('num_train_steps', step, step=step)
-                            tf.summary.scalar('num_epochs', epoch, step=step)
+                            if NumClasses == 2:
+                                tf.summary.scalar('auc_val', self.test_auc.result(), step=step)
+                                tf.summary.scalar('recall_val', self.test_recall.result(), step=step)
+                                tf.summary.scalar('prec_val', self.test_precision.result(), step=step)
+                                tf.summary.scalar('f1_val', f1, step=step)
+                                tf.summary.scalar('mcc_val', mcc, step=step)
+                                tf.summary.scalar('num_train_steps', step, step=step)
+                                tf.summary.scalar('num_epochs', epoch, step=step)
 
                         print('****** test loss, acc: ', self.test_loss.result().numpy(), self.test_accuracy.result().numpy())
 
@@ -667,10 +663,12 @@ class IcingIntensityNN:
                 for mini_batch in ds:
                     self.test_step(mini_batch)
 
-            f1, mcc = self.get_metrics()
-
-            print('loss, acc, recall, precision, auc, f1, mcc: ', self.test_loss.result().numpy(), self.test_accuracy.result().numpy(),
-                  self.test_recall.result().numpy(), self.test_precision.result().numpy(), self.test_auc.result().numpy(), f1.numpy(), mcc.numpy())
+            if NumClasses == 2:
+                f1, mcc = self.get_metrics()
+                print('loss, acc, recall, precision, auc, f1, mcc: ', self.test_loss.result().numpy(), self.test_accuracy.result().numpy(),
+                      self.test_recall.result().numpy(), self.test_precision.result().numpy(), self.test_auc.result().numpy(), f1.numpy(), mcc.numpy())
+            else:
+                print('loss, acc: ', self.test_loss.result().numpy(), self.test_accuracy.result().numpy())
             print('------------------------------------------------------')
 
             if TRACK_MOVING_AVERAGE:  # This may not really work properly