From 44bb526e08c1fad812724b4eace953b979ce4d39 Mon Sep 17 00:00:00 2001
From: tomrink <rink@ssec.wisc.edu>
Date: Wed, 26 Jan 2022 12:44:09 -0600
Subject: [PATCH] snapshot...

---
 modules/icing/pirep_goes.py | 121 ++++++++++++++++++++++++++++--------
 1 file changed, 94 insertions(+), 27 deletions(-)

diff --git a/modules/icing/pirep_goes.py b/modules/icing/pirep_goes.py
index bab1b3f0..34a7e674 100644
--- a/modules/icing/pirep_goes.py
+++ b/modules/icing/pirep_goes.py
@@ -4,8 +4,8 @@ import pickle
 import matplotlib.pyplot as plt
 import os
 from util.util import get_time_tuple_utc, GenericException, add_time_range_to_filename, is_night, is_day, \
-    check_oblique, get_timestamp, homedir, write_icing_file, write_icing_file_nc4, make_for_full_domain_predict, \
-    make_for_full_domain_predict2, get_indexes_within_threshold
+    check_oblique, get_timestamp, homedir, write_icing_file_nc4, make_for_full_domain_predict, \
+    get_indexes_within_threshold
 from util.plot import make_icing_image
 from util.geos_nav import get_navigation, get_lon_lat_2d_mesh
 from util.setup import model_path_day, model_path_night
@@ -2175,18 +2175,21 @@ def run_make_images(clvrx_dir='/Users/tomrink/data/clavrx/RadC/', ckpt_dir_s_pat
         print('Done: ', clvrx_str_time)
 
 
-def run_icing_predict(clvrx_dir='/Users/tomrink/data/clavrx/RadC/', output_dir=homedir, model_path=None,
-                      prob_thresh=0.5, satellite='GOES16', domain='CONUS', day_night='DAY',
+def run_icing_predict(clvrx_dir='/Users/tomrink/data/clavrx/RadC/', output_dir=homedir,
+                      day_model_path=model_path_day, night_model_path=model_path_night,
+                      prob_thresh=0.5, satellite='GOES16', domain='CONUS', day_night='AUTO',
                       l1b_andor_l2='BOTH', use_flight_altitude=True):
+    flight_levels = [0, 1, 2, 3, 4]
 
-    if day_night == 'DAY':
-        if model_path is None:
-            model_path = model_path_day
-    else:
-        if model_path is None:
-            model_path = model_path_night
+    day_train_params = get_training_parameters(day_night='DAY', l1b_andor_l2=l1b_andor_l2)
+    nght_train_params = get_training_parameters(day_night='NIGHT', l1b_andor_l2=l1b_andor_l2)
 
-    train_params = get_training_parameters(day_night=day_night, l1b_andor_l2=l1b_andor_l2)
+    if day_night == 'AUTO':
+        train_params = list(set(day_train_params + nght_train_params))
+    elif day_night == 'DAY':
+        train_params = day_train_params
+    elif day_night == 'NIGHT':
+        train_params = nght_train_params
 
     if satellite == 'H08':
         clvrx_ds = CLAVRx_H08(clvrx_dir)
@@ -2200,7 +2203,6 @@ def run_icing_predict(clvrx_dir='/Users/tomrink/data/clavrx/RadC/', output_dir=h
         clvrx_str_time = dto.strftime('%Y-%m-%d_%H:%M')
 
         data_dct, ll, cc = make_for_full_domain_predict(h5f, name_list=train_params, satellite=satellite, domain=domain)
-        # ancil_data_dct, _, _ = make_for_full_domain_predict(h5f, name_list=['cld_height_acha', 'cld_geo_thick'])
 
         if fidx == 0:
             num_elems = len(cc)
@@ -2208,21 +2210,86 @@ def run_icing_predict(clvrx_dir='/Users/tomrink/data/clavrx/RadC/', output_dir=h
             nav = get_navigation(satellite, domain)
             lons_2d, lats_2d, x_rad, y_rad = get_lon_lat_2d_mesh(nav, ll, cc)
 
-        solzen, satzen = make_for_full_domain_predict2(h5f, satellite=satellite, domain=domain)
-        keep = np.logical_or(lats_2d > -63.0, lats_2d < 63.0)
-        keep = np.where(keep, satzen < 70, False)
-        if day_night == 'DAY':
-            keep = np.where(keep, solzen < 80, False)
-
-        preds_2d_dct, probs_2d_dct = run_evaluate_static(data_dct, num_lines, num_elems, day_night=day_night,
-                                                         ckpt_dir_s_path=model_path, prob_thresh=prob_thresh,
-                                                         use_flight_altitude=use_flight_altitude)
-        flt_lvls = list(preds_2d_dct.keys())
-        for flvl in flt_lvls:
-            probs = probs_2d_dct[flvl]
-            preds = preds_2d_dct[flvl]
-            np.where(keep, preds, -1)
-            np.where(keep, probs, -1.0)
+        ancil_data_dct, _, _ = make_for_full_domain_predict(h5f, name_list=
+                            ['solar_zenith_angle', 'sensor_zenith_angle', 'cld_height_acha', 'cld_geo_thick'],
+                            satellite=satellite, domain=domain)
+
+        solzen = ancil_data_dct['solar_zenith_angle']
+        day_idxs = []
+        nght_idxs = []
+        for j in range(num_lines):
+            for i in range(num_elems):
+                k = i + j*num_elems
+                if is_day(solzen[k]):
+                    day_idxs.append(k)
+                else:
+                    nght_idxs.append(k)
+
+        num_tiles = num_lines * num_elems
+        num_day_tiles = len(day_idxs)
+        num_nght_tiles = len(nght_idxs)
+
+        # initialize output arrays
+        probs_2d_dct = {flvl: None for flvl in flight_levels}
+        preds_2d_dct = {flvl: None for flvl in flight_levels}
+        for flvl in flight_levels:
+            fd_preds = np.zeros(num_lines * num_elems, dtype=np.int8)
+            fd_preds[:] = -1
+            fd_probs = np.zeros(num_lines * num_elems, dtype=np.float32)
+            fd_probs[:] = -1.0
+
+            preds_2d_dct[flvl] = fd_preds
+            probs_2d_dct[flvl] = fd_probs
+
+        if (day_night == 'AUTO' or day_night == 'DAY') and num_day_tiles > 0:
+
+            day_data_dct = {name: [] for name in day_train_params}
+            for name in day_train_params:
+                for k in day_idxs:
+                    day_data_dct[name].append(data_dct[name][k])
+            day_grd_dct = {name: None for name in day_train_params}
+            for ds_name in day_train_params:
+                day_grd_dct[ds_name] = np.stack(day_data_dct[ds_name])
+
+            preds_day_dct, probs_day_dct = run_evaluate_static(day_grd_dct, num_day_tiles, day_night='DAY',
+                                                               ckpt_dir_s_path=day_model_path, prob_thresh=prob_thresh,
+                                                               use_flight_altitude=use_flight_altitude)
+            day_preds = preds_day_dct[flvl]
+            day_probs = probs_day_dct[flvl]
+            day_idxs = np.array(day_idxs)
+            fd_preds[day_idxs] = day_preds[:]
+            fd_probs[day_idxs] = day_probs[:]
+
+        if (day_night == 'AUTO' or day_night == 'NIGHT') and num_nght_tiles > 0:
+
+            nght_data_dct = {name: [] for name in nght_train_params}
+            for name in nght_train_params:
+                for k in nght_idxs:
+                    nght_data_dct[name].append(data_dct[name][k])
+            nght_grd_dct = {name: None for name in nght_train_params}
+            for ds_name in nght_train_params:
+                nght_grd_dct[ds_name] = np.stack(nght_data_dct[ds_name])
+
+            preds_nght_dct, probs_nght_dct = run_evaluate_static(nght_grd_dct, num_nght_tiles, day_night='NIGHT',
+                                                                 ckpt_dir_s_path=night_model_path, prob_thresh=prob_thresh,
+                                                                 use_flight_altitude=use_flight_altitude)
+            nght_preds = preds_nght_dct[flvl]
+            nght_probs = probs_nght_dct[flvl]
+            nght_idxs = np.array(nght_idxs)
+            fd_preds[nght_idxs] = nght_preds[:]
+            fd_probs[nght_idxs] = nght_probs[:]
+
+        # solzen, satzen = make_for_full_domain_predict2(h5f, satellite=satellite, domain=domain)
+        # keep = np.logical_or(lats_2d > -63.0, lats_2d < 63.0)
+        # keep = np.where(keep, satzen < 70, False)
+        # if day_night == 'DAY':
+        #     keep = np.where(keep, solzen < 80, False)
+
+        for flvl in flight_levels:
+            fd_preds = preds_2d_dct[flvl]
+            fd_probs = probs_2d_dct[flvl]
+            preds_2d_dct[flvl] = fd_preds.reshape((num_lines, num_elems))
+            probs_2d_dct[flvl] = fd_probs.reshape((num_lines, num_elems))
 
         write_icing_file_nc4(clvrx_str_time, output_dir, preds_2d_dct, probs_2d_dct,
                              x_rad, y_rad, lons_2d, lats_2d, cc, ll, satellite=satellite, domain=domain)
-- 
GitLab