From bff3266abf46994d2f4acb87ea3590c1f11e5584 Mon Sep 17 00:00:00 2001
From: tomrink <rink@ssec.wisc.edu>
Date: Wed, 13 Apr 2022 15:56:58 -0500
Subject: [PATCH] snapshot...

---
 modules/icing/util.py | 60 ++++++++++++++++++++++++++-----------------
 1 file changed, 37 insertions(+), 23 deletions(-)

diff --git a/modules/icing/util.py b/modules/icing/util.py
index 944d85a1..f830a8a1 100644
--- a/modules/icing/util.py
+++ b/modules/icing/util.py
@@ -115,7 +115,12 @@ def run_make_images(clvrx_dir='/Users/tomrink/data/clavrx/RadC/', ckpt_dir_s_pat
 def run_icing_predict(clvrx_dir='/Users/tomrink/data/clavrx/RadC/', output_dir=homedir,
                       day_model_path=model_path_day, night_model_path=model_path_night,
                       prob_thresh=0.5, satellite='GOES16', domain='CONUS', day_night='AUTO',
-                      l1b_andor_l2='both', use_flight_altitude=True, res_fac=1, use_nan=False):
+                      l1b_andor_l2='both', use_flight_altitude=True, res_fac=1, use_nan=False, model_type='CNN'):
+    if model_type == 'CNN':
+        model_module = icing_cnn
+    elif model_type == 'FCN':
+        model_module = icing_fcn
+
     if use_flight_altitude is True:
         flight_levels = [0, 1, 2, 3, 4]
     else:
@@ -194,10 +199,11 @@ def run_icing_predict(clvrx_dir='/Users/tomrink/data/clavrx/RadC/', output_dir=h
             for ds_name in day_train_params:
                 day_grd_dct[ds_name] = np.stack(day_data_dct[ds_name])
 
-            preds_day_dct, probs_day_dct = icing_cnn.run_evaluate_static(day_grd_dct, num_day_tiles, day_model_path,
-                                                                         day_night='DAY', l1b_or_l2=l1b_andor_l2, prob_thresh=prob_thresh,
-                                                                         use_flight_altitude=use_flight_altitude,
-                                                                         flight_levels=flight_levels)
+            preds_day_dct, probs_day_dct = model_module.run_evaluate_static(day_grd_dct, num_day_tiles, day_model_path,
+                                                                            day_night='DAY', l1b_or_l2=l1b_andor_l2,
+                                                                            prob_thresh=prob_thresh,
+                                                                            use_flight_altitude=use_flight_altitude,
+                                                                            flight_levels=flight_levels)
             day_idxs = np.array(day_idxs)
             for flvl in flight_levels:
                 day_preds = preds_day_dct[flvl]
@@ -217,10 +223,11 @@ def run_icing_predict(clvrx_dir='/Users/tomrink/data/clavrx/RadC/', output_dir=h
             for ds_name in nght_train_params:
                 nght_grd_dct[ds_name] = np.stack(nght_data_dct[ds_name])
 
-            preds_nght_dct, probs_nght_dct = icing_cnn.run_evaluate_static(nght_grd_dct, num_nght_tiles, night_model_path,
-                                                                           day_night='NIGHT', l1b_or_l2=l1b_andor_l2, prob_thresh=prob_thresh,
-                                                                           use_flight_altitude=use_flight_altitude,
-                                                                           flight_levels=flight_levels)
+            preds_nght_dct, probs_nght_dct = model_module.run_evaluate_static(nght_grd_dct, num_nght_tiles, night_model_path,
+                                                                              day_night='NIGHT', l1b_or_l2=l1b_andor_l2,
+                                                                              prob_thresh=prob_thresh,
+                                                                              use_flight_altitude=use_flight_altitude,
+                                                                              flight_levels=flight_levels)
             nght_idxs = np.array(nght_idxs)
             for flvl in flight_levels:
                 nght_preds = preds_nght_dct[flvl]
@@ -316,11 +323,11 @@ def run_icing_predict_fcn(clvrx_dir='/Users/tomrink/data/clavrx/RadC/', output_d
                 fd_probs[day_idxs] = probs[day_idxs]
 
         if (day_night == 'AUTO' or day_night == 'NIGHT') and num_nght_tiles > 0:
-            preds_nght_dct, probs_nght_dct = icing_fcn.run_evaluate_static_fcn(data_dct, 1, night_model_path,
-                                                                               day_night='NIGHT', l1b_or_l2=l1b_andor_l2,
-                                                                               prob_thresh=prob_thresh,
-                                                                               use_flight_altitude=use_flight_altitude,
-                                                                               flight_levels=flight_levels)
+            preds_nght_dct, probs_nght_dct = icing_fcn.run_evaluate_static(data_dct, 1, night_model_path,
+                                                                           day_night='NIGHT', l1b_or_l2=l1b_andor_l2,
+                                                                           prob_thresh=prob_thresh,
+                                                                           use_flight_altitude=use_flight_altitude,
+                                                                           flight_levels=flight_levels)
             for flvl in flight_levels:
                 preds = preds_nght_dct[flvl]
                 probs = probs_nght_dct[flvl]
@@ -346,10 +353,14 @@ def run_icing_predict_fcn(clvrx_dir='/Users/tomrink/data/clavrx/RadC/', output_d
 def run_icing_predict_image(clvrx_dir='/Users/tomrink/data/clavrx/RadC/', output_dir=homedir,
                             day_model_path=model_path_day, night_model_path=model_path_night,
                             prob_thresh=0.5, satellite='GOES16', domain='CONUS', day_night='AUTO',
-                            l1b_andor_l2='BOTH', use_flight_altitude=True, res_fac=1,
+                            l1b_andor_l2='BOTH', use_flight_altitude=True, res_fac=1, model_type='CNN',
                             extent=[-105, -70, 15, 50],
                             pirep_file='/Users/tomrink/data/pirep/pireps_202109200000_202109232359.csv',
                             obs_lons=None, obs_lats=None, obs_times=None, obs_alt=None, flight_level=None):
+    if model_type == 'CNN':
+        model_module = icing_cnn
+    elif model_type == 'FCN':
+        model_module = icing_fcn
 
     if use_flight_altitude is True:
         flight_levels = [0, 1, 2, 3, 4]
@@ -435,10 +446,12 @@ def run_icing_predict_image(clvrx_dir='/Users/tomrink/data/clavrx/RadC/', output
             for ds_name in day_train_params:
                 day_grd_dct[ds_name] = np.stack(day_data_dct[ds_name])
 
-            preds_day_dct, probs_day_dct = icing_cnn.run_evaluate_static(day_grd_dct, num_day_tiles, day_model_path,
-                                                                         day_night='DAY', l1b_or_l2=l1b_andor_l2, prob_thresh=prob_thresh,
-                                                                         use_flight_altitude=use_flight_altitude,
-                                                                         flight_levels=flight_levels)
+            preds_day_dct, probs_day_dct = model_module.run_evaluate_static(day_grd_dct, num_day_tiles, day_model_path,
+                                                                            day_night='DAY', l1b_or_l2=l1b_andor_l2,
+                                                                            prob_thresh=prob_thresh,
+                                                                            use_flight_altitude=use_flight_altitude,
+                                                                            flight_levels=flight_levels)
+
             day_idxs = np.array(day_idxs)
             for flvl in flight_levels:
                 day_preds = preds_day_dct[flvl]
@@ -458,10 +471,11 @@ def run_icing_predict_image(clvrx_dir='/Users/tomrink/data/clavrx/RadC/', output
             for ds_name in nght_train_params:
                 nght_grd_dct[ds_name] = np.stack(nght_data_dct[ds_name])
 
-            preds_nght_dct, probs_nght_dct = icing_cnn.run_evaluate_static(nght_grd_dct, num_nght_tiles, night_model_path,
-                                                                           day_night='NIGHT', l1b_or_l2=l1b_andor_l2, prob_thresh=prob_thresh,
-                                                                           use_flight_altitude=use_flight_altitude,
-                                                                           flight_levels=flight_levels)
+            preds_nght_dct, probs_nght_dct = model_module.run_evaluate_static(nght_grd_dct, num_nght_tiles, night_model_path,
+                                                                              day_night='NIGHT', l1b_or_l2=l1b_andor_l2,
+                                                                              prob_thresh=prob_thresh,
+                                                                              use_flight_altitude=use_flight_altitude,
+                                                                              flight_levels=flight_levels)
             nght_idxs = np.array(nght_idxs)
             for flvl in flight_levels:
                 nght_preds = preds_nght_dct[flvl]
-- 
GitLab