From d89b3892c460c5dac78428bb26cc2708ef3aaf3e Mon Sep 17 00:00:00 2001
From: tomrink <rink@ssec.wisc.edu>
Date: Mon, 3 May 2021 15:15:03 -0500
Subject: [PATCH] snapshot...

---
 modules/deeplearning/icing_cnn.py | 38 ++++++++++++++++++++++---------
 1 file changed, 27 insertions(+), 11 deletions(-)

diff --git a/modules/deeplearning/icing_cnn.py b/modules/deeplearning/icing_cnn.py
index 4fc9fc52..58dd9c05 100644
--- a/modules/deeplearning/icing_cnn.py
+++ b/modules/deeplearning/icing_cnn.py
@@ -17,7 +17,8 @@ CACHE_DATA_IN_MEM = True
 
 PROC_BATCH_SIZE = 2046
 PROC_BATCH_BUFFER_SIZE = 50000
-NumLabels = 1
+NumClasses = 3
+NumLogits = 1
 BATCH_SIZE = 256
 NUM_EPOCHS = 50
 
@@ -210,8 +211,12 @@ class IcingIntensityNN:
         label = np.where(label == -1, 0, label)
 
         # binary, two class
-        label = np.where(label != 0, 1, label)
-        label = label.reshape((label.shape[0], 1))
+        if NumClasses == 2:
+            label = np.where(label != 0, 1, label)
+            label = label.reshape((label.shape[0], 1))
+        elif NumClasses == 3:
+            label = np.where((label == 1) | (label == 2), 1, label)
+            label = np.where((label == 3) | (label == 4) | (label == 5) | (label == 6), 2, label)
 
         if CACHE_DATA_IN_MEM:
             self.in_mem_data_cache[key] = (data, label)
@@ -379,14 +384,17 @@ class IcingIntensityNN:
         # activation = tf.nn.softmax # For multi-class
         activation = tf.nn.sigmoid  # For binary
 
-        logits = tf.keras.layers.Dense(NumLabels, activation=activation)(fc)
+        # Called logits, but these are actually probabilities; see activation above
+        logits = tf.keras.layers.Dense(NumLogits, activation=activation)(fc)
         print(logits.shape)
         
         self.logits = logits
 
     def build_training(self):
-        self.loss = tf.keras.losses.BinaryCrossentropy(from_logits=False)  # for two-class only
-        #self.loss = tf.keras.losses.SparseCategoricalCrossentropy()  # For multi-class
+        if NumClasses == 2:
+            self.loss = tf.keras.losses.BinaryCrossentropy(from_logits=False)  # for two-class only
+        else:
+            self.loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False)  # For multi-class
 
         # decayed_learning_rate = learning_rate * decay_rate ^ (global_step / decay_steps)
         initial_learning_rate = 0.002
@@ -411,14 +419,22 @@ class IcingIntensityNN:
         self.initial_learning_rate = initial_learning_rate
 
     def build_evaluation(self):
-        self.train_accuracy = tf.keras.metrics.BinaryAccuracy(name='train_accuracy')
-        self.test_accuracy = tf.keras.metrics.BinaryAccuracy(name='test_accuracy')
-        self.test_auc = tf.keras.metrics.AUC(name='test_auc')
-        self.test_recall = tf.keras.metrics.Recall(name='test_recall')
-        self.test_precision = tf.keras.metrics.Precision(name='test_precision')
         self.train_loss = tf.keras.metrics.Mean(name='train_loss')
         self.test_loss = tf.keras.metrics.Mean(name='test_loss')
 
+        if NumClasses == 2:
+            self.train_accuracy = tf.keras.metrics.BinaryAccuracy(name='train_accuracy')
+            self.test_accuracy = tf.keras.metrics.BinaryAccuracy(name='test_accuracy')
+            self.test_auc = tf.keras.metrics.AUC(name='test_auc')
+            self.test_recall = tf.keras.metrics.Recall(name='test_recall')
+            self.test_precision = tf.keras.metrics.Precision(name='test_precision')
+        else:
+            self.train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')
+            self.test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='test_accuracy')
+            self.test_auc = tf.keras.metrics.AUC(name='test_auc')
+            self.test_recall = tf.keras.metrics.Recall(name='test_recall')
+            self.test_precision = tf.keras.metrics.Precision(name='test_precision')
+
     def build_predict(self):
         _, pred = tf.nn.top_k(self.logits)
         self.pred_class = pred
-- 
GitLab