Skip to content
Snippets Groups Projects
util.py 31.25 KiB
import numpy as np
import deeplearning.icing_fcn as icing_fcn
import deeplearning.icing_cnn as icing_cnn
from deeplearning.icing_fcn import IcingIntensityFCN
from icing.pirep_goes import setup, time_filter_3
from icing.moon_phase import moon_phase
from util.util import get_time_tuple_utc, is_day, check_oblique, get_median, homedir, write_icing_file_nc4,\
    write_icing_file_nc4_viirs, get_training_parameters,\
    make_for_full_domain_predict, make_for_full_domain_predict_viirs_clavrx, prepare_evaluate
from util.plot import make_icing_image
from util.geos_nav import get_navigation, get_lon_lat_2d_mesh
from util.setup import model_path_day, model_path_night
from aeolus.datasource import CLAVRx, CLAVRx_VIIRS, CLAVRx_H08, CLAVRx_H09
import h5py
import datetime
import tensorflow as tf
import os
# from scipy.signal import medfilt2d


# Altitude bounds [lower, upper] for each flight-level index 0-4. The observation filters in this
# module keep altitudes with lower <= alt < upper. Units follow the altitude / cloud-height data
# compared against these bounds (presumably meters -- confirm).
flt_level_ranges = {
    0: [0.0, 2000.0],
    1: [2000.0, 4000.0],
    2: [4000.0, 6000.0],
    3: [6000.0, 8000.0],
    4: [8000.0, 15000.0],
}


def _stack_tile_fields(data_dct, names, idxs):
    """Return ``{name: np.stack([data_dct[name][k] for k in idxs])}`` for every ``name``.

    ``idxs`` are flat tile indices into the full-domain entries of ``data_dct``. It must be
    non-empty, because ``np.stack`` rejects an empty sequence.
    """
    return {name: np.stack([data_dct[name][k] for k in idxs]) for name in names}


def _scatter_tile_results(preds_2d_dct, probs_2d_dct, preds_dct, probs_dct, idxs, flight_levels):
    """Copy per-tile model output into the flat full-domain output arrays, in place.

    For each flight level, ``preds_dct[flvl]`` / ``probs_dct[flvl]`` are written, in order, to
    positions ``idxs`` of ``preds_2d_dct[flvl]`` / ``probs_2d_dct[flvl]``.
    """
    idxs = np.asarray(idxs)
    for flvl in flight_levels:
        preds_2d_dct[flvl][idxs] = preds_dct[flvl][:]
        probs_2d_dct[flvl][idxs] = probs_dct[flvl][:]


def run_icing_predict(clvrx_dir='/Users/tomrink/data/clavrx/RadC/', output_dir=homedir,
                      day_model_path=model_path_day, night_model_path=model_path_night,
                      prob_thresh=0.5, satellite='GOES16', domain='CONUS', day_night='AUTO',
                      l1b_andor_l2='l2', use_flight_altitude=True, res_fac=1, use_nan=False, has_time=False,
                      model_type='FCN', use_dnb=False):
    """Run the tile-based icing model on every CLAVRx file in ``clvrx_dir`` and write NetCDF4 output.

    For each file, tiles accepted by ``check_oblique`` (sensor zenith) are split into day and night
    tiles with ``is_day`` (solar zenith). What gets evaluated depends on ``day_night``:

    * ``'AUTO'``: the day model runs on day tiles and the night-time evaluation runs on night tiles.
    * ``'DAY'``: only the day model runs, on day tiles.
    * ``'NIGHT'``: only the night-time evaluation runs, on every accepted tile.

    Tiles that are not evaluated keep the fill values -1 (predictions) and -1.0 (probabilities).

    When ``use_dnb`` is true and ``moon_phase`` reports the file time as lunar-illuminated, the
    night-time evaluation uses ``day_model_path`` in ``'DAY'`` mode with the DNB parameter set.
    VIIRS fields are also read with that set. This decision is made independently for each file.

    Args:
        clvrx_dir: directory handed to the CLAVRx data-source class.
        output_dir: directory for the NetCDF4 output files.
        day_model_path: model used for day tiles (and for moonlit night tiles with ``use_dnb``).
        night_model_path: model used for night tiles otherwise.
        prob_thresh: probability threshold forwarded to the model and the writers.
        satellite: one of 'GOES16', 'H08', 'H09' (geostationary) or 'VIIRS'.
        domain: geostationary domain name (e.g. 'CONUS').
        day_night: 'AUTO', 'DAY' or 'NIGHT' (see above).
        l1b_andor_l2: selects the training-parameter set.
        use_flight_altitude: if True evaluate flight levels 0-4, otherwise only level 0.
        res_fac: resolution factor forwarded to the full-domain readers.
        use_nan, has_time: forwarded to the NetCDF4 writers.
        model_type: 'CNN' or 'FCN'.
        use_dnb: allow the moonlit DNB path described above.

    Raises:
        ValueError: for an unsupported ``model_type``, ``day_night`` or ``satellite``.
    """
    if model_type == 'CNN':
        model_module = icing_cnn
    elif model_type == 'FCN':
        model_module = icing_fcn
    else:
        raise ValueError('unsupported model_type: ' + str(model_type))

    flight_levels = [0, 1, 2, 3, 4] if use_flight_altitude is True else [0]

    day_train_params, _, _ = get_training_parameters(day_night='DAY', l1b_andor_l2=l1b_andor_l2)
    nght_train_params, _, _ = get_training_parameters(day_night='NIGHT', l1b_andor_l2=l1b_andor_l2, use_dnb=False)
    nght_train_params_dnb, _, _ = get_training_parameters(day_night='NIGHT', l1b_andor_l2=l1b_andor_l2, use_dnb=True)

    # train_params: fields read from each file. train_params_dnb: fields used instead for moonlit
    # DNB files. Neither is reassigned inside the file loop, so one moonlit file cannot change the
    # parameters used for the files that follow it.
    if day_night == 'AUTO':
        train_params = list(set(day_train_params + nght_train_params))
        train_params_dnb = list(set(day_train_params + nght_train_params_dnb))
    elif day_night == 'DAY':
        train_params = day_train_params
        # Night tiles are not evaluated in DAY mode, so the DNB switch adds nothing.
        train_params_dnb = day_train_params
    elif day_night == 'NIGHT':
        train_params = nght_train_params
        train_params_dnb = nght_train_params_dnb
    else:
        raise ValueError('unsupported day_night: ' + str(day_night))

    if satellite == 'H08':
        clvrx_ds = CLAVRx_H08(clvrx_dir)
    elif satellite == 'H09':
        clvrx_ds = CLAVRx_H09(clvrx_dir)
    elif satellite == 'GOES16':
        clvrx_ds = CLAVRx(clvrx_dir)
    elif satellite == 'VIIRS':
        clvrx_ds = CLAVRx_VIIRS(clvrx_dir)
    else:
        raise ValueError('unsupported satellite: ' + str(satellite))

    is_geo = satellite in ('GOES16', 'H08', 'H09')

    for fidx, fname in enumerate(clvrx_ds.flist):
        # 'with' guarantees the file is closed even if reading or evaluation raises.
        with h5py.File(fname, 'r') as h5f:
            dto = clvrx_ds.get_datetime(fname)
            lunar_illuminated = moon_phase(dto.timestamp())
            clvrx_str_time = dto.strftime('%Y-%m-%d_%H:%M')
            dnb_active = bool(use_dnb and lunar_illuminated)

            if is_geo:
                data_dct, ll, cc = make_for_full_domain_predict(h5f, name_list=train_params, satellite=satellite,
                                                                domain=domain, res_fac=res_fac)
                if fidx == 0:  # These don't change for geostationary fixed grids
                    nav = get_navigation(satellite, domain)
                    lons_2d, lats_2d, x_rad, y_rad = get_lon_lat_2d_mesh(nav, ll, cc, offset=int(8 / res_fac))
                bt_fld_name = 'temp_10_4um_nom'
                ancil_data_dct, _, _ = make_for_full_domain_predict(
                    h5f,
                    name_list=['solar_zenith_angle', 'sensor_zenith_angle', 'cld_height_acha', 'cld_geo_thick',
                               bt_fld_name],
                    satellite=satellite, domain=domain, res_fac=res_fac)
            else:  # VIIRS: navigation comes with each granule
                file_params = train_params_dnb if dnb_active else train_params
                data_dct, ll, cc, lats_2d, lons_2d = make_for_full_domain_predict_viirs_clavrx(
                    h5f, name_list=file_params, day_night=day_night, res_fac=res_fac, use_dnb=use_dnb)
                bt_fld_name = 'temp_11_0um_nom'
                ancil_data_dct, _, _, _, _ = make_for_full_domain_predict_viirs_clavrx(
                    h5f,
                    name_list=['solar_zenith_angle', 'sensor_zenith_angle', 'cld_height_acha', 'cld_geo_thick',
                               bt_fld_name],
                    res_fac=res_fac)

            num_elems, num_lines = len(cc), len(ll)
            num_fd = num_lines * num_elems

            satzen = ancil_data_dct['sensor_zenith_angle']
            solzen = ancil_data_dct['solar_zenith_angle']
            bt_10_4 = ancil_data_dct[bt_fld_name]

            # Tiles are stored flat in line-major order (k = i + j*num_elems); the median BT is
            # kept for every tile, oblique or not.
            avg_bt = np.array([get_median(bt_10_4[k]) for k in range(num_fd)])
            all_idxs = [k for k in range(num_fd) if check_oblique(satzen[k])]
            day_idxs, nght_idxs = [], []
            for k in all_idxs:
                (day_idxs if is_day(solzen[k]) else nght_idxs).append(k)

            # Flat output arrays pre-filled with the "not evaluated" values.
            preds_2d_dct = {flvl: np.full(num_fd, -1, dtype=np.int8) for flvl in flight_levels}
            probs_2d_dct = {flvl: np.full(num_fd, -1.0, dtype=np.float32) for flvl in flight_levels}

            if day_night in ('AUTO', 'DAY') and day_idxs:
                grd_dct = _stack_tile_fields(data_dct, day_train_params, day_idxs)
                preds_dct, probs_dct = model_module.run_evaluate_static(grd_dct, len(day_idxs), day_model_path,
                                                                        day_night='DAY', l1b_or_l2=l1b_andor_l2,
                                                                        prob_thresh=prob_thresh,
                                                                        use_flight_altitude=use_flight_altitude,
                                                                        flight_levels=flight_levels)
                _scatter_tile_results(preds_2d_dct, probs_2d_dct, preds_dct, probs_dct, day_idxs, flight_levels)

            # Night-time evaluation: moonlit DNB files use the day model with the DNB field set.
            if dnb_active:
                nght_model_path, nght_mode, nght_params = day_model_path, 'DAY', train_params_dnb
            else:
                nght_model_path, nght_mode, nght_params = night_model_path, 'NIGHT', nght_train_params

            # Which tiles it covers: night tiles (AUTO), every accepted tile (NIGHT), none (DAY).
            if day_night == 'AUTO':
                eval_idxs = nght_idxs
            elif day_night == 'NIGHT':
                eval_idxs = all_idxs
            else:
                eval_idxs = []

            if eval_idxs:  # np.stack cannot stack an empty tile list
                grd_dct = _stack_tile_fields(data_dct, nght_params, eval_idxs)
                preds_dct, probs_dct = model_module.run_evaluate_static(grd_dct, len(eval_idxs), nght_model_path,
                                                                        day_night=nght_mode,
                                                                        l1b_or_l2=l1b_andor_l2,
                                                                        prob_thresh=prob_thresh,
                                                                        use_flight_altitude=use_flight_altitude,
                                                                        flight_levels=flight_levels)
                _scatter_tile_results(preds_2d_dct, probs_2d_dct, preds_dct, probs_dct, eval_idxs, flight_levels)

            # Reshape the flat arrays to the (lines, elems) tile grid.
            for flvl in flight_levels:
                preds_2d_dct[flvl] = preds_2d_dct[flvl].reshape((num_lines, num_elems))
                probs_2d_dct[flvl] = probs_2d_dct[flvl].reshape((num_lines, num_elems))
            bt_10_4_2d = avg_bt.reshape((num_lines, num_elems))

            if is_geo:
                write_icing_file_nc4(clvrx_str_time, output_dir, preds_2d_dct, probs_2d_dct,
                                     x_rad, y_rad, lons_2d, lats_2d, cc, ll,
                                     satellite=satellite, domain=domain, use_nan=use_nan, has_time=has_time,
                                     prob_thresh=prob_thresh, bt_10_4=bt_10_4_2d)
            else:
                write_icing_file_nc4_viirs(clvrx_str_time, output_dir, preds_2d_dct, probs_2d_dct,
                                           lons_2d, lats_2d,
                                           use_nan=use_nan, has_time=has_time,
                                           prob_thresh=prob_thresh, bt_10_4=bt_10_4_2d)

            print('Done: ', clvrx_str_time)


def run_icing_predict_fcn(clvrx_dir='/Users/tomrink/data/clavrx/RadC/', output_dir=homedir,
                          day_model_path=model_path_day, night_model_path=model_path_night,
                          prob_thresh=0.5, satellite='GOES16', domain='CONUS', day_night='AUTO',
                          l1b_andor_l2='both', use_flight_altitude=False, res_fac=1, use_nan=False):
    """Run the FCN icing model over each CLAVRx file in ``clvrx_dir`` and write NetCDF4 output.

    The fields from ``prepare_evaluate`` are passed to ``icing_fcn.run_evaluate_static`` with
    ``num_tiles=1``, once per model. The day-model output is kept where ``solzen < 80.0`` and the
    night-model output where ``solzen > 100.0``. All other cells keep the fill values
    -1 / -1.0. Lon/lat and x/y coordinates are computed from the first file only.

    Args:
        clvrx_dir: directory handed to the CLAVRx data-source class.
        output_dir: directory for the NetCDF4 output files.
        day_model_path, night_model_path: model locations.
        prob_thresh: probability threshold forwarded to the model and the writer.
        satellite: 'H08', 'H09', or anything else for the default ``CLAVRx`` source.
        domain: domain name passed to the readers and the writer.
        day_night: 'AUTO', 'DAY' or 'NIGHT'.
        l1b_andor_l2: selects the training-parameter set.
        use_flight_altitude: if True evaluate flight levels 0-4, otherwise only level 0.
        res_fac: accepted for interface compatibility; not used here.
        use_nan: forwarded to ``write_icing_file_nc4``.

    Raises:
        ValueError: for an unsupported ``day_night``.
    """
    flight_levels = [0, 1, 2, 3, 4] if use_flight_altitude is True else [0]

    day_train_params, _, _ = get_training_parameters(day_night='DAY', l1b_andor_l2=l1b_andor_l2)
    nght_train_params, _, _ = get_training_parameters(day_night='NIGHT', l1b_andor_l2=l1b_andor_l2)

    if day_night == 'AUTO':
        train_params = list(set(day_train_params + nght_train_params))
    elif day_night == 'DAY':
        train_params = day_train_params
    elif day_night == 'NIGHT':
        train_params = nght_train_params
    else:
        raise ValueError('unsupported day_night: ' + str(day_night))

    if satellite == 'H08':
        clvrx_ds = CLAVRx_H08(clvrx_dir)
    elif satellite == 'H09':
        clvrx_ds = CLAVRx_H09(clvrx_dir)
    else:
        clvrx_ds = CLAVRx(clvrx_dir)

    for fidx, fname in enumerate(clvrx_ds.flist):
        # 'with' guarantees the file is closed even if evaluation raises.
        with h5py.File(fname, 'r') as h5f:
            clvrx_str_time = clvrx_ds.get_datetime(fname).strftime('%Y-%m-%d_%H:%M')

            data_dct, solzen, satzen, ll, cc = prepare_evaluate(h5f, name_list=train_params, satellite=satellite,
                                                                domain=domain, offset=8)
            num_elems = len(cc)
            num_lines = len(ll)

            if fidx == 0:
                nav = get_navigation(satellite, domain)
                lons_2d, lats_2d, x_rad, y_rad = get_lon_lat_2d_mesh(nav, ll, cc)

            # NOTE(review): these masks index both the flat output arrays and the model output
            # directly, so they presumably share a shape -- confirm against prepare_evaluate.
            day_idxs = solzen < 80.0
            nght_idxs = solzen > 100.0

            preds_2d_dct = {flvl: np.full(num_lines * num_elems, -1, dtype=np.int8) for flvl in flight_levels}
            probs_2d_dct = {flvl: np.full(num_lines * num_elems, -1.0, dtype=np.float32) for flvl in flight_levels}

            if day_night in ('AUTO', 'DAY') and np.sum(day_idxs) > 0:
                preds_day_dct, probs_day_dct = icing_fcn.run_evaluate_static(data_dct, 1, day_model_path,
                                                                             day_night='DAY', l1b_or_l2=l1b_andor_l2,
                                                                             prob_thresh=prob_thresh,
                                                                             use_flight_altitude=use_flight_altitude,
                                                                             flight_levels=flight_levels)
                for flvl in flight_levels:
                    preds_2d_dct[flvl][day_idxs] = preds_day_dct[flvl][day_idxs]
                    probs_2d_dct[flvl][day_idxs] = probs_day_dct[flvl][day_idxs]

            if day_night in ('AUTO', 'NIGHT') and np.sum(nght_idxs) > 0:
                preds_nght_dct, probs_nght_dct = icing_fcn.run_evaluate_static(data_dct, 1, night_model_path,
                                                                               day_night='NIGHT',
                                                                               l1b_or_l2=l1b_andor_l2,
                                                                               prob_thresh=prob_thresh,
                                                                               use_flight_altitude=use_flight_altitude,
                                                                               flight_levels=flight_levels)
                for flvl in flight_levels:
                    preds_2d_dct[flvl][nght_idxs] = preds_nght_dct[flvl][nght_idxs]
                    probs_2d_dct[flvl][nght_idxs] = probs_nght_dct[flvl][nght_idxs]

            for flvl in flight_levels:
                preds_2d_dct[flvl] = preds_2d_dct[flvl].reshape((num_lines, num_elems))
                probs_2d_dct[flvl] = probs_2d_dct[flvl].reshape((num_lines, num_elems))

            write_icing_file_nc4(clvrx_str_time, output_dir, preds_2d_dct, probs_2d_dct,
                                 x_rad, y_rad, lons_2d, lats_2d, cc, ll,
                                 satellite=satellite, domain=domain, use_nan=use_nan, prob_thresh=prob_thresh)

            print('Done: ', clvrx_str_time)


def _cth_to_flight_level(cth_tile):
    """Reduce a tile's cloud-top heights to a flight-level index.

    The tile value is the mean of its two largest non-NaN heights. It is replaced by the key of
    the ``flt_level_ranges`` bin satisfying ``lower <= value < upper``. Values outside every bin
    are returned unchanged: negative values, values >= the top bound, or NaN when the tile has
    no valid heights (numpy warns about the empty mean in that case).
    """
    c = cth_tile.flatten()
    c_m = np.mean(np.sort(c[np.invert(np.isnan(c))])[-2:])
    for lvl, (lo, hi) in flt_level_ranges.items():
        if lo <= c_m < hi:
            return lvl
    return c_m


def run_icing_predict_image(clvrx_dir='/Users/tomrink/data/clavrx/RadC/', output_dir=homedir,
                            day_model_path=model_path_day, night_model_path=model_path_night,
                            prob_thresh=0.5, satellite='GOES16', domain='CONUS', day_night='AUTO',
                            l1b_andor_l2='BOTH', use_flight_altitude=True, flight_levels=[0, 1, 2, 3, 4],
                            res_fac=1, model_type='CNN', extent=[-105, -70, 15, 50], pirep_file=None, icing_dict=None,
                            obs_lons=None, obs_lats=None, obs_times=None, obs_alt=None, obs_intensity=None):
    """Evaluate the tile-based icing model on each CLAVRx file and render a max-probability image.

    The day and night models are loaded once with ``model_module.load_model`` and applied through
    ``model_module.run_evaluate_static_2``. Tiles accepted by ``check_oblique`` are split with
    ``is_day``, and each tile also gets a ``cth_high_avg`` feature from ``_cth_to_flight_level``.
    For each file:

    * the maximum probability over ``flight_levels`` is taken;
    * values below ``prob_thresh`` become NaN;
    * the result is drawn with ``make_icing_image`` together with icing observations from a
      +/-30 minute window inside the selected altitude span.

    Args:
        clvrx_dir: directory handed to the CLAVRx data-source class.
        output_dir: accepted for interface compatibility; not used here.
        day_model_path, night_model_path: model locations. ``None`` skips loading that model and
            evaluating the corresponding tiles.
        prob_thresh: threshold forwarded to the model and used to mask the image.
        satellite: 'H08', 'H09', or anything else for the default ``CLAVRx`` source.
        domain: domain name passed to the readers and the image maker.
        day_night: 'AUTO', 'DAY' or 'NIGHT'.
        l1b_andor_l2: selects the training-parameter set and model configuration.
        use_flight_altitude: if True use ``flight_levels`` and restrict observations to their
            altitude span. Otherwise only level 0 is used, with the full 0-15000 span.
        flight_levels: keys of ``flt_level_ranges`` (read, not mutated).
        res_fac: resolution factor for ``make_for_full_domain_predict``.
        model_type: 'CNN' or 'FCN'.
        extent: map extent passed to ``make_icing_image``.
        pirep_file: PIREP CSV parsed with ``setup`` when ``icing_dict`` is not given.
        icing_dict: pre-parsed icing-report dict used instead of ``pirep_file``.
        obs_lons, obs_lats, obs_times, obs_alt, obs_intensity: observation arrays used when no
            icing-report dict is available. Only observations with intensity > 1 are plotted.

    Raises:
        ValueError: for an unsupported ``model_type`` or ``day_night``.
    """
    if model_type == 'CNN':
        model_module = icing_cnn
    elif model_type == 'FCN':
        model_module = icing_fcn
    else:
        raise ValueError('unsupported model_type: ' + str(model_type))

    if day_night not in ('AUTO', 'DAY', 'NIGHT'):
        raise ValueError('unsupported day_night: ' + str(day_night))

    day_model = None
    if day_model_path is not None:
        day_model = model_module.load_model(day_model_path, day_night='DAY', l1b_andor_l2=l1b_andor_l2,
                                            satellite=satellite, use_flight_altitude=use_flight_altitude)
    night_model = None
    if night_model_path is not None:
        night_model = model_module.load_model(night_model_path, day_night='NIGHT', l1b_andor_l2=l1b_andor_l2,
                                              satellite=satellite, use_flight_altitude=use_flight_altitude)

    alt_lo, alt_hi = 0.0, 15000.0
    if use_flight_altitude is True:
        alt_lo = flt_level_ranges[flight_levels[0]][0]
        alt_hi = flt_level_ranges[flight_levels[-1]][1]
    else:
        flight_levels = [0]

    # A caller-supplied icing_dict takes precedence over parsing pirep_file. Previously icing_dict
    # was only used when pirep_file was also given.
    ice_dict = None
    if icing_dict is not None:
        ice_dict = icing_dict
    elif pirep_file is not None:
        ice_dict, _, _ = setup(pirep_file)

    day_train_params, _, _ = get_training_parameters(day_night='DAY', l1b_andor_l2=l1b_andor_l2)
    nght_train_params, _, _ = get_training_parameters(day_night='NIGHT', l1b_andor_l2=l1b_andor_l2)

    if day_night == 'AUTO':
        train_params = list(set(day_train_params + nght_train_params))
    elif day_night == 'DAY':
        train_params = day_train_params
    else:
        train_params = nght_train_params

    if satellite == 'H08':
        clvrx_ds = CLAVRx_H08(clvrx_dir)
    elif satellite == 'H09':
        clvrx_ds = CLAVRx_H09(clvrx_dir)
    else:
        clvrx_ds = CLAVRx(clvrx_dir)

    for fname in clvrx_ds.flist:
        # 'with' guarantees the file is closed even if evaluation or plotting raises.
        with h5py.File(fname, 'r') as h5f:
            dto = clvrx_ds.get_datetime(fname)
            ts = dto.timestamp()
            clvrx_str_time = dto.strftime('%Y-%m-%d_%H:%M')

            data_dct, ll, cc = make_for_full_domain_predict(h5f, name_list=train_params, satellite=satellite,
                                                            domain=domain, res_fac=res_fac)
            num_elems, num_lines = len(cc), len(ll)

            ancil_data_dct, _, _ = make_for_full_domain_predict(
                h5f, name_list=['solar_zenith_angle', 'sensor_zenith_angle', 'cld_height_acha', 'cld_geo_thick'],
                satellite=satellite, domain=domain, res_fac=res_fac)

            satzen = ancil_data_dct['sensor_zenith_angle']
            solzen = ancil_data_dct['solar_zenith_angle']
            cth = ancil_data_dct['cld_height_acha']

            # Tiles are stored flat in line-major order (k = i + j*num_elems).
            day_idxs, nght_idxs = [], []
            day_cth_max, nght_cth_max = [], []
            for k in range(num_lines * num_elems):
                if not check_oblique(satzen[k]):
                    continue
                c_m = _cth_to_flight_level(cth[k])
                if is_day(solzen[k]):
                    day_idxs.append(k)
                    day_cth_max.append(c_m)
                else:
                    nght_idxs.append(k)
                    nght_cth_max.append(c_m)

            preds_2d_dct = {flvl: np.full(num_lines * num_elems, -1, dtype=np.int8) for flvl in flight_levels}
            probs_2d_dct = {flvl: np.full(num_lines * num_elems, -1.0, dtype=np.float32) for flvl in flight_levels}

            if day_night in ('AUTO', 'DAY') and day_idxs and day_model is not None:
                day_grd_dct = {name: np.stack([data_dct[name][k] for k in day_idxs]) for name in day_train_params}
                day_grd_dct['cth_high_avg'] = np.array(day_cth_max)

                preds_day_dct, probs_day_dct = \
                    model_module.run_evaluate_static_2(day_model, day_grd_dct, len(day_idxs), prob_thresh=prob_thresh,
                                                       use_flight_altitude=use_flight_altitude,
                                                       flight_levels=flight_levels)
                day_idx_arr = np.array(day_idxs)
                for flvl in flight_levels:
                    preds_2d_dct[flvl][day_idx_arr] = preds_day_dct[flvl][:]
                    probs_2d_dct[flvl][day_idx_arr] = probs_day_dct[flvl][:]

            if day_night in ('AUTO', 'NIGHT') and nght_idxs and night_model is not None:
                nght_grd_dct = {name: np.stack([data_dct[name][k] for k in nght_idxs]) for name in nght_train_params}
                nght_grd_dct['cth_high_avg'] = np.array(nght_cth_max)

                preds_nght_dct, probs_nght_dct = \
                    model_module.run_evaluate_static_2(night_model, nght_grd_dct, len(nght_idxs),
                                                       prob_thresh=prob_thresh,
                                                       use_flight_altitude=use_flight_altitude,
                                                       flight_levels=flight_levels)
                nght_idx_arr = np.array(nght_idxs)
                for flvl in flight_levels:
                    preds_2d_dct[flvl][nght_idx_arr] = preds_nght_dct[flvl][:]
                    probs_2d_dct[flvl][nght_idx_arr] = probs_nght_dct[flvl][:]

            for flvl in flight_levels:
                preds_2d_dct[flvl] = preds_2d_dct[flvl].reshape((num_lines, num_elems))
                probs_2d_dct[flvl] = probs_2d_dct[flvl].reshape((num_lines, num_elems))

            # Observation window: +/-30 minutes around the file time.
            dto_utc, _ = get_time_tuple_utc(ts)
            ts_0 = (dto_utc - datetime.timedelta(minutes=30)).timestamp()
            ts_1 = (dto_utc + datetime.timedelta(minutes=30)).timestamp()

            if ice_dict is not None:
                _, keep_lons, keep_lats, _ = time_filter_3(ice_dict, ts_0, ts_1, alt_lo, alt_hi)
            elif obs_times is not None:
                keep = np.logical_and(obs_times >= ts_0, obs_times < ts_1)
                keep = np.where(keep, np.logical_and(obs_alt >= alt_lo, obs_alt < alt_hi), False)
                keep = np.where(keep, obs_intensity > 1, False)
                keep_lons = obs_lons[keep]
                keep_lats = obs_lats[keep]
            else:
                keep_lons = None
                keep_lats = None

            # Maximum probability over the flight levels, masked below the threshold.
            prob_s = np.stack([probs_2d_dct[flvl] for flvl in flight_levels], axis=-1)
            max_prob = np.max(prob_s, axis=2)
            max_prob = np.where(max_prob < prob_thresh, np.nan, max_prob)

            make_icing_image(h5f, max_prob, None, None, clvrx_str_time, satellite, domain,
                             ice_lons_vld=keep_lons, ice_lats_vld=keep_lats, extent=extent)

            print('Done: ', clvrx_str_time)


def run_icing_predict_image_fcn(clvrx_dir='/Users/tomrink/data/clavrx/RadC/', output_dir=homedir,
                                day_model_path=model_path_day, night_model_path=model_path_night,
                                prob_thresh=0.5, satellite='GOES16', domain='CONUS', day_night='AUTO',
                                l1b_andor_l2='BOTH', use_flight_altitude=True, res_fac=1,
                                extent=[-105, -70, 15, 50],
                                pirep_file='/Users/tomrink/data/pirep/pireps_202109200000_202109232359.csv',
                                obs_lons=None, obs_lats=None, obs_times=None, obs_alt=None, flight_level=None, obs_intensity=None):
    """Run the FCN icing model on each CLAVRx file and render a max-probability icing image.

    The fields from ``prepare_evaluate`` are passed to ``icing_fcn.run_evaluate_static`` with
    ``num_tiles=1``, matching ``run_icing_predict_fcn``. The day-model output is kept where
    ``solzen < 80.0`` and the night-model output where ``solzen > 100.0``. Then, for each file:

    * the maximum probability over the flight levels is taken;
    * values below ``prob_thresh`` become NaN;
    * the result is drawn with ``make_icing_image`` together with icing observations from a
      +/-30 minute window.

    Args:
        clvrx_dir: directory handed to the CLAVRx data-source class.
        output_dir, res_fac: accepted for interface compatibility; not used here.
        day_model_path, night_model_path: model locations.
        prob_thresh: threshold forwarded to the model and used to mask the image.
        satellite: 'H08', 'H09', or anything else for the default ``CLAVRx`` source.
        domain: domain name passed to the readers and the image maker.
        day_night: 'AUTO', 'DAY' or 'NIGHT'.
        l1b_andor_l2: selects the training-parameter set.
        use_flight_altitude: if True evaluate flight levels 0-4, otherwise only level 0.
        extent: map extent passed to ``make_icing_image``.
        pirep_file: PIREP CSV parsed with ``setup``; when None, the ``obs_*`` arrays are used.
        obs_lons, obs_lats, obs_times, obs_alt: observation arrays.
        flight_level: key of ``flt_level_ranges`` restricting observation altitude. When None,
            the span is 0-15000.
        obs_intensity: if given, only observations with intensity > 1 are plotted (as in
            ``run_icing_predict_image``).

    Raises:
        ValueError: for an unsupported ``day_night``.
    """
    flight_levels = [0, 1, 2, 3, 4] if use_flight_altitude is True else [0]

    if pirep_file is not None:
        ice_dict, _, _ = setup(pirep_file)

    alt_lo, alt_hi = 0.0, 15000.0
    if flight_level is not None:
        alt_lo, alt_hi = flt_level_ranges[flight_level]

    day_train_params, _, _ = get_training_parameters(day_night='DAY', l1b_andor_l2=l1b_andor_l2)
    nght_train_params, _, _ = get_training_parameters(day_night='NIGHT', l1b_andor_l2=l1b_andor_l2)

    if day_night == 'AUTO':
        train_params = list(set(day_train_params + nght_train_params))
    elif day_night == 'DAY':
        train_params = day_train_params
    elif day_night == 'NIGHT':
        train_params = nght_train_params
    else:
        raise ValueError('unsupported day_night: ' + str(day_night))

    if satellite == 'H08':
        clvrx_ds = CLAVRx_H08(clvrx_dir)
    elif satellite == 'H09':
        clvrx_ds = CLAVRx_H09(clvrx_dir)
    else:
        clvrx_ds = CLAVRx(clvrx_dir)

    for fname in clvrx_ds.flist:
        # 'with' guarantees the file is closed even if evaluation or plotting raises.
        with h5py.File(fname, 'r') as h5f:
            dto = clvrx_ds.get_datetime(fname)
            ts = dto.timestamp()
            clvrx_str_time = dto.strftime('%Y-%m-%d_%H:%M')

            # Observation window: +/-30 minutes around the file time.
            dto_utc, _ = get_time_tuple_utc(ts)
            ts_0 = (dto_utc - datetime.timedelta(minutes=30)).timestamp()
            ts_1 = (dto_utc + datetime.timedelta(minutes=30)).timestamp()

            if pirep_file is not None:
                _, keep_lons, keep_lats, _ = time_filter_3(ice_dict, ts_0, ts_1, alt_lo, alt_hi)
            elif obs_times is not None:
                keep = np.logical_and(obs_times >= ts_0, obs_times < ts_1)
                keep = np.where(keep, np.logical_and(obs_alt >= alt_lo, obs_alt < alt_hi), False)
                if obs_intensity is not None:
                    keep = np.where(keep, obs_intensity > 1, False)
                keep_lons = obs_lons[keep]
                keep_lats = obs_lats[keep]
            else:
                keep_lons = None
                keep_lats = None

            data_dct, solzen, satzen, ll, cc = prepare_evaluate(h5f, name_list=train_params, satellite=satellite,
                                                                domain=domain, offset=8)
            num_elems = len(cc)
            num_lines = len(ll)

            # NOTE(review): these masks index both the flat output arrays and the model output
            # directly, so they presumably share a shape -- confirm against prepare_evaluate.
            day_idxs = solzen < 80.0
            nght_idxs = solzen > 100.0

            preds_2d_dct = {flvl: np.full(num_lines * num_elems, -1, dtype=np.int8) for flvl in flight_levels}
            probs_2d_dct = {flvl: np.full(num_lines * num_elems, -1.0, dtype=np.float32) for flvl in flight_levels}

            if day_night in ('AUTO', 'DAY') and np.sum(day_idxs) > 0:
                preds_day_dct, probs_day_dct = icing_fcn.run_evaluate_static(data_dct, 1, day_model_path,
                                                                             day_night='DAY', l1b_or_l2=l1b_andor_l2,
                                                                             prob_thresh=prob_thresh,
                                                                             use_flight_altitude=use_flight_altitude,
                                                                             flight_levels=flight_levels)
                for flvl in flight_levels:
                    preds_2d_dct[flvl][day_idxs] = preds_day_dct[flvl][day_idxs]
                    probs_2d_dct[flvl][day_idxs] = probs_day_dct[flvl][day_idxs]

            if day_night in ('AUTO', 'NIGHT') and np.sum(nght_idxs) > 0:
                preds_nght_dct, probs_nght_dct = icing_fcn.run_evaluate_static(data_dct, 1, night_model_path,
                                                                               day_night='NIGHT',
                                                                               l1b_or_l2=l1b_andor_l2,
                                                                               prob_thresh=prob_thresh,
                                                                               use_flight_altitude=use_flight_altitude,
                                                                               flight_levels=flight_levels)
                for flvl in flight_levels:
                    preds_2d_dct[flvl][nght_idxs] = preds_nght_dct[flvl][nght_idxs]
                    probs_2d_dct[flvl][nght_idxs] = probs_nght_dct[flvl][nght_idxs]

            for flvl in flight_levels:
                preds_2d_dct[flvl] = preds_2d_dct[flvl].reshape((num_lines, num_elems))
                probs_2d_dct[flvl] = probs_2d_dct[flvl].reshape((num_lines, num_elems))

            # Maximum probability over the flight levels, masked below the caller's threshold.
            prob_s = np.stack([probs_2d_dct[flvl] for flvl in flight_levels], axis=-1)
            max_prob = np.max(prob_s, axis=2)
            max_prob = np.where(max_prob < prob_thresh, np.nan, max_prob)

            make_icing_image(h5f, max_prob, None, None, clvrx_str_time, satellite, domain,
                             ice_lons_vld=keep_lons, ice_lats_vld=keep_lats, extent=extent)

            print('Done: ', clvrx_str_time)