From 0617bd07c4591f4b59b426cb0d8f80b3cacf442d Mon Sep 17 00:00:00 2001
From: tomrink <rink@ssec.wisc.edu>
Date: Mon, 9 Aug 2021 15:27:31 -0500
Subject: [PATCH] new method

---
 modules/deeplearning/icing_cnn.py | 41 ++++++++++++++++++++++---------
 1 file changed, 30 insertions(+), 11 deletions(-)

diff --git a/modules/deeplearning/icing_cnn.py b/modules/deeplearning/icing_cnn.py
index 335b407f..d73948e4 100644
--- a/modules/deeplearning/icing_cnn.py
+++ b/modules/deeplearning/icing_cnn.py
@@ -38,19 +38,19 @@ NOISE_TRAINING = False
 
 img_width = 16
 
-# mean_std_file = homedir+'data/icing/mean_std_no_ice.pkl'
-mean_std_file = homedir+'data/icing/mean_std_l1b_no_ice.pkl'
+mean_std_file = homedir+'data/icing/mean_std_no_ice.pkl'
+# mean_std_file = homedir+'data/icing/mean_std_l1b_no_ice.pkl'
 f = open(mean_std_file, 'rb')
 mean_std_dct = pickle.load(f)
 f.close()
 
 # train_params = ['cld_height_acha', 'cld_geo_thick', 'supercooled_cloud_fraction', 'cld_temp_acha', 'cld_press_acha',
 #                 'cld_reff_acha', 'cld_opd_acha', 'conv_cloud_fraction', 'cld_emiss_acha']
-# train_params = ['cld_height_acha', 'cld_geo_thick', 'supercooled_cloud_fraction', 'cld_temp_acha', 'cld_press_acha',
-#                 'cld_reff_dcomp', 'cld_opd_dcomp', 'cld_cwp_dcomp', 'iwc_dcomp', 'lwc_dcomp', 'conv_cloud_fraction', 'cld_emiss_acha']
-train_params = ['temp_10_4um_nom', 'temp_11_0um_nom', 'temp_12_0um_nom', 'temp_13_3um_nom', 'temp_3_75um_nom',
-                'temp_6_2um_nom', 'temp_6_7um_nom', 'temp_7_3um_nom', 'temp_8_5um_nom', 'temp_9_7um_nom',
-                'refl_0_47um_nom', 'refl_0_65um_nom', 'refl_0_86um_nom', 'refl_1_38um_nom', 'refl_1_60um_nom']
+train_params = ['cld_height_acha', 'cld_geo_thick', 'supercooled_cloud_fraction', 'cld_temp_acha', 'cld_press_acha',
+                'cld_reff_dcomp', 'cld_opd_dcomp', 'cld_cwp_dcomp', 'iwc_dcomp', 'lwc_dcomp', 'conv_cloud_fraction', 'cld_emiss_acha']
+# train_params = ['temp_10_4um_nom', 'temp_11_0um_nom', 'temp_12_0um_nom', 'temp_13_3um_nom', 'temp_3_75um_nom',
+#                 'temp_6_2um_nom', 'temp_6_7um_nom', 'temp_7_3um_nom', 'temp_8_5um_nom', 'temp_9_7um_nom',
+#                 'refl_0_47um_nom', 'refl_0_65um_nom', 'refl_0_86um_nom', 'refl_1_38um_nom', 'refl_1_60um_nom']
 # train_params = ['temp_10_4um_nom', 'temp_11_0um_nom', 'temp_12_0um_nom', 'temp_13_3um_nom', 'temp_3_75um_nom',
 #                 'temp_6_2um_nom', 'temp_6_7um_nom', 'temp_7_3um_nom', 'temp_8_5um_nom', 'temp_9_7um_nom']
 
@@ -418,8 +418,8 @@ class IcingIntensityNN:
         activation = tf.nn.leaky_relu
         momentum = 0.99
 
-        num_filters = 16
-        # num_filters = 12
+        # num_filters = 16
+        num_filters = 12
 
         conv = tf.keras.layers.Conv2D(num_filters, 5, strides=[1, 1], padding=padding, activation=activation)(self.inputs[0])
         conv = tf.keras.layers.MaxPool2D(padding=padding)(conv)
@@ -803,7 +803,7 @@ class IcingIntensityNN:
 
         self.h5f_tst.close()
 
-    def do_evaluate(self, ckpt_dir, ll, cc):
+    def do_evaluate(self, ckpt_dir, ll, cc, prob_thresh=0.5):
 
         ckpt = tf.train.Checkpoint(step=tf.Variable(1), model=self.model)
         ckpt_manager = tf.train.CheckpointManager(ckpt, ckpt_dir, max_to_keep=3)
@@ -821,7 +821,7 @@ class IcingIntensityNN:
         preds = np.concatenate(pred_s)
 
         if NumClasses == 2:
-            preds = np.where(preds > 0.6, 1, 0)
+            preds = np.where(preds > prob_thresh, 1, 0)
         else:
             preds = np.argmax(preds, axis=1)
         print(preds.shape[0], np.sum(preds == 1))
@@ -871,6 +871,25 @@ class IcingIntensityNN:
         return filename, ice_lons, ice_lats
 
 
+def run_restore_static(filename_tst, ckpt_dir_s):
+    cm_s = []
+    for ckpt_dir in ckpt_dir_s:
+        nn = IcingIntensityNN()
+        nn.run_restore(filename_tst, ckpt_dir)
+        cm_s.append(tf.math.confusion_matrix(nn.test_labels, nn.test_preds))
+    num = len(cm_s)
+    cm_avg = cm_s[0]
+    for k in range(num-1):
+        cm_avg += cm_s[k+1]
+    cm_avg /= num
+
+    return cm_avg
+
+
+def run_evaluate_static(filename, ckpt_dir_s):
+    nn = IcingIntensityNN()
+
+
 if __name__ == "__main__":
     nn = IcingIntensityNN()
     nn.run('matchup_filename')
-- 
GitLab