Skip to content
Snippets Groups Projects
util.py 65.33 KiB
import numpy as np
from icing.pirep_goes import setup, time_filter_3
from icing.moon_phase import moon_phase
from util.util import get_time_tuple_utc, is_day, check_oblique, get_median, homedir, \
    get_cartopy_crs, \
    get_fill_attrs, get_grid_values, GenericException, get_timestamp, get_cf_nav_parameters
from util.plot import make_icing_image
from util.geos_nav import get_navigation, get_lon_lat_2d_mesh
from util.setup import model_path_day, model_path_night
from aeolus.datasource import CLAVRx, CLAVRx_VIIRS, CLAVRx_H08, CLAVRx_H09
import h5py
import datetime
from netCDF4 import Dataset
import joblib
import tensorflow as tf
import os
# from scipy.signal import medfilt2d


# Altitude bins for flight-level indices 0-4; each value is [lower, upper].
# NOTE(review): units are not stated here; presumably meters, matching the
# numeric labels used for output dataset names -- confirm.
flt_level_ranges = {
    0: [0.0, 2000.0],
    1: [2000.0, 4000.0],
    2: [4000.0, 6000.0],
    3: [6000.0, 8000.0],
    4: [8000.0, 15000.0],
}

# Taiwan sub-domain used for the H08/H09 grids, as pixel offsets and extents.
# Reference points the window was derived from:
#   Taiwan:            lon, lat = 120.955098, 23.834310 -> elem, line = (1789, 1505)
#   upper-right (UR):  lon, lat = 135.0, 35.0           -> elem, line = (2499, 995)
# (taiwan_i0 + taiwan_lenx == 2499 and taiwan_j0 == 995 reproduce the UR corner.)
taiwan_i0, taiwan_j0 = 1079, 995        # starting element (i) and line (j) of the window
taiwan_lenx, taiwan_leny = 1420, 1020   # window width (elements) and height (lines)

# Extent in GEOS projection coordinates (not line/elem), obtained from e.g.:
#   geos.transform_point(135.0, 35.0, ccrs.PlateCarree(), False)
#   geos.transform_point(106.61, 13.97, ccrs.PlateCarree(), False)
# NOTE: the name is misspelled ('taiwain'); it is kept unchanged in case other modules import it.
taiwain_extent = [-3342, -502, 1470, 3510]


def get_training_parameters(day_night='DAY', l1b_andor_l2='both', satellite='GOES16', use_dnb=False):
    """Return the dataset names used as icing-model inputs.

    Args:
        day_night: 'DAY' selects the daytime lists; any other value selects the night lists.
        l1b_andor_l2: 'both', 'l1b' or 'l2'; selects what ``train_params`` contains.
        satellite: 'GOES16' or 'H08'; selects the L1b channel list.
        use_dnb: night only -- when True, the L2 list uses the DCOMP-named cloud
            properties (the same list as daytime) instead of the ACHA-based ones.

    Returns:
        ``(train_params, train_params_l1b, train_params_l2)``. For ``'both'``,
        ``train_params`` is the L1b list followed by the L2 list.

    Raises:
        ValueError: if ``satellite`` or ``l1b_andor_l2`` is not a supported value
            (otherwise the selected list would never be bound).
    """
    if satellite not in ('GOES16', 'H08'):
        raise ValueError("unsupported satellite {!r}; expected 'GOES16' or 'H08'".format(satellite))
    if l1b_andor_l2 not in ('both', 'l1b', 'l2'):
        raise ValueError("unsupported l1b_andor_l2 {!r}; expected 'both', 'l1b' or 'l2'".format(l1b_andor_l2))

    if day_night == 'DAY':
        train_params_l2 = ['cld_height_acha', 'cld_geo_thick', 'cld_temp_acha', 'cld_press_acha', 'supercooled_cloud_fraction',
                           'cld_emiss_acha', 'conv_cloud_fraction', 'cld_reff_dcomp', 'cld_opd_dcomp', 'iwc_dcomp', 'lwc_dcomp']
        # train_params_l2 = ['cld_temp_acha', 'supercooled_cloud_fraction', 'cld_reff_dcomp', 'cld_opd_dcomp', 'iwc_dcomp', 'lwc_dcomp']

        if satellite == 'GOES16':
            train_params_l1b = ['temp_10_4um_nom', 'temp_11_0um_nom', 'temp_12_0um_nom', 'temp_13_3um_nom', 'temp_3_75um_nom',
                                'temp_6_2um_nom', 'temp_6_7um_nom', 'temp_7_3um_nom', 'temp_8_5um_nom', 'temp_9_7um_nom',
                                'refl_0_47um_nom', 'refl_0_65um_nom', 'refl_0_86um_nom', 'refl_1_38um_nom', 'refl_1_60um_nom']
                                # 'refl_2_10um_nom']
            # train_params_l1b = ['refl_1_38um_nom', 'refl_1_60um_nom', 'temp_8_5um_nom', 'temp_11_0um_nom']
        else:  # 'H08'
            train_params_l1b = ['temp_10_4um_nom', 'temp_12_0um_nom', 'temp_8_5um_nom', 'temp_3_75um_nom', 'refl_2_10um_nom',
                                'refl_1_60um_nom', 'refl_0_86um_nom', 'refl_0_47um_nom']
    else:
        train_params_l2 = ['cld_height_acha', 'cld_geo_thick', 'cld_temp_acha', 'cld_press_acha', 'supercooled_cloud_fraction',
                           'cld_emiss_acha', 'conv_cloud_fraction', 'cld_reff_acha', 'cld_opd_acha']
        # train_params_l2 = ['cld_temp_acha', 'supercooled_cloud_fraction', 'cld_emiss_acha', 'cld_reff_acha', 'cld_opd_acha']

        if use_dnb is True:
            # With the DNB, night models use the same DCOMP-named L2 inputs as daytime.
            train_params_l2 = ['cld_height_acha', 'cld_geo_thick', 'cld_temp_acha', 'cld_press_acha', 'supercooled_cloud_fraction',
                               'cld_emiss_acha', 'conv_cloud_fraction', 'cld_reff_dcomp', 'cld_opd_dcomp', 'iwc_dcomp', 'lwc_dcomp']

        if satellite == 'GOES16':
            train_params_l1b = ['temp_10_4um_nom', 'temp_11_0um_nom', 'temp_12_0um_nom', 'temp_13_3um_nom', 'temp_3_75um_nom',
                                'temp_6_2um_nom', 'temp_6_7um_nom', 'temp_7_3um_nom', 'temp_8_5um_nom', 'temp_9_7um_nom']
            # train_params_l1b = ['temp_8_5um_nom', 'temp_11_0um_nom']
        else:  # 'H08'
            train_params_l1b = ['temp_10_4um_nom', 'temp_12_0um_nom', 'temp_8_5um_nom', 'temp_3_75um_nom']

    if l1b_andor_l2 == 'both':
        train_params = train_params_l1b + train_params_l2
    elif l1b_andor_l2 == 'l1b':
        train_params = train_params_l1b
    else:  # 'l2'
        train_params = train_params_l2

    return train_params, train_params_l1b, train_params_l2


def make_for_full_domain_predict(h5f, name_list=None, satellite='GOES16', domain='FD', res_fac=1,
                                 data_extent=None):
    """Cut a satellite grid into 16x16 tiles for full-domain icing prediction.

    Args:
        h5f: file handle passed through to ``get_grid_values``.
        name_list: dataset names to read (required in practice: iterating ``None`` fails).
        satellite: 'GOES16', 'H08' or 'H09'. H08/H09 read only the Taiwan sub-domain, and
            for them 'temp_6_2um_nom' is not read but filled with NaN.
        domain: domain name passed to ``get_cartopy_crs`` / ``get_navigation``.
        res_fac: tile stride is ``int(16 / res_fac)``; ``res_fac > 1`` gives overlapping tiles.
        data_extent: optional pair of points (with ``.lon`` / ``.lat``); for GOES16 CONUS the
            read window becomes the line/element box spanned by the two points.

    Returns:
        ``(grd_dct, ll, cc)``: ``grd_dct`` maps each name to its tiles stacked on a new
        leading axis (tile row outer, tile column inner); ``ll`` / ``cc`` are the absolute
        line / element indices of the upper-left pixel of each tile row / column.

    Raises:
        GenericException: if some, but not all, datasets could be read.
    """
    win_x = win_y = 16
    step_x = int(win_x / res_fac)
    step_y = int(win_y / res_fac)
    elem_0 = line_0 = 0

    _, xlen, _, _, ylen, _, _ = get_cartopy_crs(satellite, domain)
    geos_nav = get_navigation(satellite, domain)

    if satellite == 'GOES16' and data_extent is not None:
        pt_a = data_extent[0]
        pt_b = data_extent[1]
        if domain == 'CONUS':
            # earth_to_lc's first value is used as the line (y), the second as the element (x).
            line_a, elem_a = geos_nav.earth_to_lc(pt_a.lon, pt_a.lat)
            line_b, elem_b = geos_nav.earth_to_lc(pt_b.lon, pt_b.lat)
            if elem_a < elem_b:
                elem_0, xlen = elem_a, elem_b - elem_a
            else:
                elem_0, xlen = elem_b, elem_a - elem_b
            if line_a < line_b:
                line_0, ylen = line_a, line_b - line_a
            else:
                line_0, ylen = line_b, line_a - line_b

    if satellite in ('H08', 'H09'):
        xlen, ylen = taiwan_lenx, taiwan_leny
        elem_0, line_0 = taiwan_i0, taiwan_j0

    grd_in = dict.fromkeys(name_list)
    n_read = 0
    for ds_name in name_list:
        if satellite in ('H08', 'H09') and ds_name == 'temp_6_2um_nom':
            # Not read for Himawari here; substitute an all-NaN field of the window's shape.
            values = np.full((ylen, xlen), np.nan)
        else:
            fill_value, fill_value_name = get_fill_attrs(ds_name)
            values = get_grid_values(h5f, ds_name, line_0, elem_0, None, num_j=ylen, num_i=xlen,
                                     fill_value_name=fill_value_name, fill_value=fill_value)
        if values is not None:
            grd_in[ds_name] = values
            n_read += 1

    if n_read > 0 and n_read != len(name_list):
        raise GenericException('weirdness')

    def count_tiles(length, step, win):
        # Start one short of length/step tiles, then reduce the count when the span left
        # after them is narrower than one window.
        n = int(length / step) - 1
        remainder = length - n * step
        return n if remainder >= win else n - int((win - remainder) / step)

    n_x = count_tiles(xlen, step_x, win_x)
    n_y = count_tiles(ylen, step_y, win_y)

    ll = [line_0 + j * step_y for j in range(n_y)]
    cc = [elem_0 + i * step_x for i in range(n_x)]

    tile_lists = {name: [] for name in name_list}
    for ds_name in name_list:
        grid = grd_in[ds_name]
        tile_lists[ds_name].extend(
            grid[j * step_y:j * step_y + win_y, i * step_x:i * step_x + win_x]
            for j in range(n_y)
            for i in range(n_x)
        )

    return {name: np.stack(tile_lists[name]) for name in name_list}, ll, cc


def make_for_full_domain_predict_viirs_clavrx(h5f, name_list=None, res_fac=1, day_night='DAY', use_dnb=False):
    """Cut a VIIRS CLAVRx granule into 16x16 tiles for icing prediction.

    The grid size comes from the lengths of the file's 'scan_lines_along_track_direction'
    and 'pixel_elements_along_scan_direction' datasets.

    Args:
        h5f: open file indexed by dataset name and passed to ``get_grid_values``.
        name_list: dataset names to read (required in practice: iterating ``None`` fails).
        res_fac: tile stride is ``int(16 / res_fac)``.
        day_night: with 'NIGHT' or 'AUTO' plus ``use_dnb``, 'cld_reff_dcomp' /
            'cld_opd_dcomp' are read from 'cld_reff_nlcomp' / 'cld_opd_nlcomp' instead
            (the results are still stored under the DCOMP names).
        use_dnb: see ``day_night``.

    Returns:
        ``(grd_dct, ll, cc, lats, lons)``: ``grd_dct`` maps each name to its tiles stacked
        on a new leading axis (tile row outer, tile column inner); ``ll`` / ``cc`` are the
        line / element indices of each tile row / column's upper-left pixel; ``lats`` /
        ``lons`` are latitude / longitude sampled at those pixels, shaped (len(ll), len(cc)).

    Raises:
        GenericException: if some, but not all, datasets could be read.
    """
    win_x = win_y = 16
    step_x = int(win_x / res_fac)
    step_y = int(win_y / res_fac)
    elem_0 = line_0 = 0

    ylen = h5f['scan_lines_along_track_direction'].shape[0]
    xlen = h5f['pixel_elements_along_scan_direction'].shape[0]

    nlcomp_names = {'cld_reff_dcomp': 'cld_reff_nlcomp', 'cld_opd_dcomp': 'cld_opd_nlcomp'}
    use_nl_comp = day_night in ('NIGHT', 'AUTO') and use_dnb

    grd_in = dict.fromkeys(name_list)
    n_read = 0
    for ds_name in name_list:
        src_name = nlcomp_names.get(ds_name, ds_name) if use_nl_comp else ds_name
        fill_value, fill_value_name = get_fill_attrs(src_name)
        values = get_grid_values(h5f, src_name, line_0, elem_0, None, num_j=ylen, num_i=xlen,
                                 fill_value_name=fill_value_name, fill_value=fill_value)
        if values is not None:
            grd_in[ds_name] = values
            n_read += 1

    if n_read > 0 and n_read != len(name_list):
        raise GenericException('weirdness')

    # TODO: need to investigate discrepancies with compute_lwc_iwc
    # if use_nl_comp:
    #     cld_phase = get_grid_values(h5f, 'cloud_phase', j_0, i_0, None, num_j=ylen, num_i=xlen)
    #     cld_dz = get_grid_values(h5f, 'cld_geo_thick', j_0, i_0, None, num_j=ylen, num_i=xlen)
    #     reff = grd_dct['cld_reff_dcomp']
    #     opd = grd_dct['cld_opd_dcomp']
    #
    #     lwc_nlcomp, iwc_nlcomp = compute_lwc_iwc(cld_phase, reff, opd, cld_dz)
    #     grd_dct['iwc_dcomp'] = iwc_nlcomp
    #     grd_dct['lwc_dcomp'] = lwc_nlcomp

    def count_tiles(length, step, win):
        # Start one short of length/step tiles, then reduce the count when the span left
        # after them is narrower than one window.
        n = int(length / step) - 1
        remainder = length - n * step
        return n if remainder >= win else n - int((win - remainder) / step)

    n_x = count_tiles(xlen, step_x, win_x)
    n_y = count_tiles(ylen, step_y, win_y)

    ll = [line_0 + j * step_y for j in range(n_y)]
    cc = [elem_0 + i * step_x for i in range(n_x)]

    tile_lists = {name: [] for name in name_list}
    for ds_name in name_list:
        grid = grd_in[ds_name]
        tile_lists[ds_name].extend(
            grid[j * step_y:j * step_y + win_y, i * step_x:i * step_x + win_x]
            for j in range(n_y)
            for i in range(n_x)
        )
    tiles = {name: np.stack(tile_lists[name]) for name in name_list}

    lats = get_grid_values(h5f, 'latitude', line_0, elem_0, None, num_j=ylen, num_i=xlen)
    lons = get_grid_values(h5f, 'longitude', line_0, elem_0, None, num_j=ylen, num_i=xlen)

    # Index the full lat/lon grids at every (tile row, tile column) upper-left pixel.
    rows, cols = np.meshgrid(ll, cc, indexing='ij')

    return tiles, ll, cc, lats[rows, cols], lons[rows, cols]


def make_for_full_domain_predict2(h5f, satellite='GOES16', domain='FD', res_fac=1):
    """Return solar and sensor zenith angles subsampled at the prediction tile spacing.

    The angles are read over the full ``domain`` grid, or over the Taiwan sub-domain for
    H08/H09 (the same window make_for_full_domain_predict uses without a data_extent).
    They are then sampled every ``int(16 / res_fac)`` lines/elements starting at 0.

    Args:
        h5f: file handle passed through to ``get_grid_values``.
        satellite: 'GOES16', 'H08' or 'H09'.
        domain: domain name passed to ``get_cartopy_crs``.
        res_fac: resolution factor; the sampling stride is ``int(16 / res_fac)``.

    Returns:
        ``(solzen, satzen)``, each with ``int(ylen / s_y) - 1`` rows and
        ``int(xlen / s_x) - 1`` columns (assuming get_grid_values returns a full
        (ylen, xlen) grid -- TODO confirm).
    """
    w_x = 16
    w_y = 16
    i_0 = 0
    j_0 = 0
    s_x = int(w_x / res_fac)
    s_y = int(w_y / res_fac)

    _, xlen, _, _, ylen, _, _ = get_cartopy_crs(satellite, domain)
    # H08 and H09 share the Taiwan sub-domain, as in make_for_full_domain_predict and
    # prepare_evaluate; restricting only H08 would misalign H09 angles with its tiles.
    if satellite in ('H08', 'H09'):
        xlen = taiwan_lenx
        ylen = taiwan_leny
        i_0 = taiwan_i0
        j_0 = taiwan_j0

    n_x = int(xlen/s_x)
    n_y = int(ylen/s_y)

    solzen = get_grid_values(h5f, 'solar_zenith_angle', j_0, i_0, None, num_j=ylen, num_i=xlen)
    satzen = get_grid_values(h5f, 'sensor_zenith_angle', j_0, i_0, None, num_j=ylen, num_i=xlen)
    # Slices stop at (n-1)*s, giving n-1 samples per axis.
    solzen = solzen[0:(n_y-1)*s_y:s_y, 0:(n_x-1)*s_x:s_x]
    satzen = satzen[0:(n_y-1)*s_y:s_y, 0:(n_x-1)*s_x:s_x]

    return solzen, satzen
# -------------------------------------------------------------------------------------------


def prepare_evaluate(h5f, name_list, satellite='GOES16', domain='FD', res_fac=1, offset=0):
    """Read full-domain fields for evaluation along with subsampled zenith angles.

    Args:
        h5f: file handle passed through to ``get_grid_values``.
        name_list: dataset names to read.
        satellite: 'GOES16', 'H08' or 'H09'; H08/H09 read only the Taiwan sub-domain.
        domain: domain name passed to ``get_cartopy_crs``.
        res_fac: the sampling stride is ``16 // res_fac`` lines/elements.
        offset: added to every index in ``ll`` / ``cc``.

    Returns:
        ``(grd_dct_n, solzen, satzen, ll, cc)``: ``grd_dct_n`` maps each name to its grid
        with a new leading length-1 axis (a name keeps ``[]`` only when no dataset could
        be read at all); ``solzen`` / ``satzen`` are sampled every ``s_y`` lines and ``s_x``
        elements; ``ll`` / ``cc`` are the sample line / element indices plus ``offset``.

    Raises:
        GenericException: if some, but not all, datasets could be read.
    """
    w_x = 16
    w_y = 16
    i_0 = 0
    j_0 = 0
    s_x = w_x // res_fac
    s_y = w_y // res_fac

    _, xlen, _, _, ylen, _, _ = get_cartopy_crs(satellite, domain)
    if satellite in ('H08', 'H09'):
        xlen = taiwan_lenx
        ylen = taiwan_leny
        i_0 = taiwan_i0
        j_0 = taiwan_j0

    n_x = (xlen // s_x)
    n_y = (ylen // s_y)

    ll = [(offset+j_0) + j*s_y for j in range(n_y)]
    cc = [(offset+i_0) + i*s_x for i in range(n_x)]

    grd_dct_n = {name: [] for name in name_list}

    cnt_a = 0
    for ds_name in name_list:
        fill_value, fill_value_name = get_fill_attrs(ds_name)
        gvals = get_grid_values(h5f, ds_name, j_0, i_0, None, num_j=ylen, num_i=xlen, fill_value_name=fill_value_name, fill_value=fill_value)
        # Test for a missing dataset *before* adding the batch axis: np.expand_dims(None)
        # returns a non-None object array, which would defeat this check.
        if gvals is not None:
            grd_dct_n[ds_name] = np.expand_dims(gvals, axis=0)
            cnt_a += 1

    if cnt_a > 0 and cnt_a != len(name_list):
        raise GenericException('weirdness')

    solzen = get_grid_values(h5f, 'solar_zenith_angle', j_0, i_0, None, num_j=ylen, num_i=xlen)
    satzen = get_grid_values(h5f, 'sensor_zenith_angle', j_0, i_0, None, num_j=ylen, num_i=xlen)
    solzen = solzen[0:n_y*s_y:s_y, 0:n_x*s_x:s_x]
    satzen = satzen[0:n_y*s_y:s_y, 0:n_x*s_x:s_x]

    return grd_dct_n, solzen, satzen, ll, cc


def prepare_evaluate1x1(h5f, name_list, satellite='GOES16', domain='FD', res_fac=1, offset=0):
    """Flatten full-domain fields into a per-pixel feature matrix for pixel-wise models.

    Args:
        h5f: file handle passed through to ``get_grid_values``.
        name_list: dataset names to read; each becomes one feature column.
        satellite: 'GOES16', 'H08' or 'H09'; H08/H09 read only the Taiwan sub-domain.
        domain: domain name passed to ``get_cartopy_crs``.
        res_fac: the sampling stride is ``1 // res_fac``.
        offset: added to every index in ``ll`` / ``cc``.

    Returns:
        ``(varX, solzen, satzen, cldmsk, ll, cc)``: ``varX`` stacks each flattened field as
        a column, giving (n_pixels, n_datasets). The solar zenith, sensor zenith and cloud-mask
        grids are subsampled by the stride and flattened. ``ll`` / ``cc`` are line / element
        indices plus ``offset``.

    Raises:
        GenericException: if some, but not all, datasets could be read.
    """
    step_x = step_y = 1 // res_fac

    _, xlen, _, _, ylen, _, _ = get_cartopy_crs(satellite, domain)
    elem_0 = line_0 = 0
    if satellite in ('H08', 'H09'):
        xlen, ylen = taiwan_lenx, taiwan_leny
        elem_0, line_0 = taiwan_i0, taiwan_j0

    n_x = xlen // step_x
    n_y = ylen // step_y

    ll = [offset + line_0 + j * step_y for j in range(n_y)]
    cc = [offset + elem_0 + i * step_x for i in range(n_x)]

    columns = []
    for ds_name in name_list:
        fill_value, fill_value_name = get_fill_attrs(ds_name)
        grid = get_grid_values(h5f, ds_name, line_0, elem_0, None, num_j=ylen, num_i=xlen,
                               fill_value_name=fill_value_name, fill_value=fill_value)
        if grid is not None:
            columns.append(grid.flatten())

    if columns and len(columns) != len(name_list):
        raise GenericException('weirdness')

    def subsample(grid):
        return grid[0:n_y * step_y:step_y, 0:n_x * step_x:step_x]

    solzen = get_grid_values(h5f, 'solar_zenith_angle', line_0, elem_0, None, num_j=ylen, num_i=xlen)
    satzen = get_grid_values(h5f, 'sensor_zenith_angle', line_0, elem_0, None, num_j=ylen, num_i=xlen)
    cldmsk = get_grid_values(h5f, 'cloud_mask', line_0, elem_0, None, num_j=ylen, num_i=xlen)
    solzen, satzen, cldmsk = subsample(solzen), subsample(satzen), subsample(cldmsk)

    var_x = np.stack(columns, axis=1)

    return var_x, solzen.flatten(), satzen.flatten(), cldmsk.flatten(), ll, cc


# Suffixes for output dataset names, keyed by flight-level index, e.g.
# 'icing_probability_level_0_2000'. Index -1 labels the 'CTH_layer' product.
# Built as a literal: the former range(6) placeholder also left a stray key 5 -> None,
# which would turn a bad level lookup into a None string instead of a KeyError.
flt_level_ranges_str = {
    0: '0_2000',
    1: '2000_4000',
    2: '4000_6000',
    3: '6000_8000',
    4: '8000_15000',
    -1: 'CTH_layer',
}

# flt_level_ranges_str = {k: None for k in range(1)}
# flt_level_ranges_str[0] = 'column'


def write_icing_file(clvrx_str_time, output_dir, preds_dct, probs_dct, x, y, lons, lats, elems, lines):
    """Write icing predictions and probabilities to an HDF5 file.

    The file is ``output_dir + 'icing_prediction_' + clvrx_str_time + '.h5'``. ``output_dir``
    is concatenated directly, so it must end with a path separator.

    Args:
        clvrx_str_time: time string embedded in the file name.
        output_dir: output directory prefix.
        preds_dct: 2-D prediction grids keyed by flight-level index (a key of
            ``flt_level_ranges_str``).
        probs_dct: 2-D probability grids; must contain every key of ``preds_dct``.
        x, y: optional 1-D projection coordinates (skipped when ``x`` is None).
        lons, lats: 2-D longitude / latitude grids.
        elems, lines: optional 1-D element / line indices (skipped when ``elems`` is None).

    Besides the per-level datasets, this writes the column maximum probability and the
    position (in ``preds_dct`` key order) of the level attaining it. The 'Projection'
    attributes are fixed Himawari values regardless of the input.
    """
    outfile_name = output_dir + 'icing_prediction_'+clvrx_str_time+'.h5'

    dim_0_name = 'x'
    dim_1_name = 'y'

    def add_grid_dataset(h5f, name, data, dtype, missing):
        # Create a 2-D gridded dataset with the attributes and dimension labels shared by
        # every prediction/probability output.
        # NOTE(review): dims[0] is the row axis but is labelled 'x', while 'coordinates'
        # reads 'y x' -- confirm the intended labelling.
        ds = h5f.create_dataset(name, data=data, dtype=dtype)
        ds.attrs.create('coordinates', data='y x')
        ds.attrs.create('grid_mapping', data='Projection')
        ds.attrs.create('missing', data=missing)
        ds.dims[0].label = dim_0_name
        ds.dims[1].label = dim_1_name

    # The context manager closes the file even if a write fails part-way.
    with h5py.File(outfile_name, 'w') as h5f_out:
        flt_lvls = list(preds_dct.keys())
        for flvl in flt_lvls:
            add_grid_dataset(h5f_out, 'icing_prediction_level_'+flt_level_ranges_str[flvl],
                             preds_dct[flvl], 'i2', -1)

        prob_s = []
        for flvl in flt_lvls:
            probs = probs_dct[flvl]
            prob_s.append(probs)
            add_grid_dataset(h5f_out, 'icing_probability_level_'+flt_level_ranges_str[flvl],
                             probs, 'f4', -1.0)

        # Levels stacked on a trailing axis: reduce over it for the column maximum and
        # the index of the level that attains it.
        prob_s = np.stack(prob_s, axis=-1)
        add_grid_dataset(h5f_out, 'max_icing_probability_column', np.max(prob_s, axis=2), 'f4', -1.0)
        add_grid_dataset(h5f_out, 'max_icing_probability_level', np.argmax(prob_s, axis=2), 'i2', -1)

        lon_ds = h5f_out.create_dataset('longitude', data=lons, dtype='f4')
        lon_ds.attrs.create('units', data='degrees_east')
        lon_ds.attrs.create('long_name', data='icing prediction longitude')
        lon_ds.dims[0].label = dim_0_name
        lon_ds.dims[1].label = dim_1_name

        lat_ds = h5f_out.create_dataset('latitude', data=lats, dtype='f4')
        lat_ds.attrs.create('units', data='degrees_north')
        lat_ds.attrs.create('long_name', data='icing prediction latitude')
        lat_ds.dims[0].label = dim_0_name
        lat_ds.dims[1].label = dim_1_name

        proj_ds = h5f_out.create_dataset('Projection', data=0, dtype='b')
        proj_ds.attrs.create('long_name', data='Himawari Imagery Projection')
        proj_ds.attrs.create('grid_mapping_name', data='geostationary')
        proj_ds.attrs.create('sweep_angle_axis', data='y')
        proj_ds.attrs.create('units', data='rad')
        proj_ds.attrs.create('semi_major_axis', data=6378.137)
        proj_ds.attrs.create('semi_minor_axis', data=6356.7523)
        proj_ds.attrs.create('inverse_flattening', data=298.257)
        proj_ds.attrs.create('perspective_point_height', data=35785.863)
        proj_ds.attrs.create('latitude_of_projection_origin', data=0.0)
        proj_ds.attrs.create('longitude_of_projection_origin', data=140.7)
        proj_ds.attrs.create('CFAC', data=20466275)
        proj_ds.attrs.create('LFAC', data=20466275)
        proj_ds.attrs.create('COFF', data=2750.5)
        proj_ds.attrs.create('LOFF', data=2750.5)

        if x is not None:
            # NOTE(review): long_name says "GOES PUG" although the projection above is Himawari.
            x_ds = h5f_out.create_dataset('x', data=x, dtype='f8')
            x_ds.dims[0].label = dim_0_name
            x_ds.attrs.create('units', data='rad')
            x_ds.attrs.create('standard_name', data='projection_x_coordinate')
            x_ds.attrs.create('long_name', data='GOES PUG W-E fixed grid viewing angle')
            x_ds.attrs.create('scale_factor', data=5.58879902955962e-05)
            x_ds.attrs.create('add_offset', data=-0.153719917308037)
            x_ds.attrs.create('CFAC', data=20466275)
            x_ds.attrs.create('COFF', data=2750.5)

            y_ds = h5f_out.create_dataset('y', data=y, dtype='f8')
            y_ds.dims[0].label = dim_1_name
            y_ds.attrs.create('units', data='rad')
            y_ds.attrs.create('standard_name', data='projection_y_coordinate')
            y_ds.attrs.create('long_name', data='GOES PUG S-N fixed grid viewing angle')
            y_ds.attrs.create('scale_factor', data=-5.58879902955962e-05)
            y_ds.attrs.create('add_offset', data=0.153719917308037)
            y_ds.attrs.create('LFAC', data=20466275)
            y_ds.attrs.create('LOFF', data=2750.5)

        if elems is not None:
            elem_ds = h5f_out.create_dataset('elems', data=elems, dtype='i2')
            elem_ds.dims[0].label = dim_0_name
            line_ds = h5f_out.create_dataset('lines', data=lines, dtype='i2')
            line_ds.dims[0].label = dim_1_name


def write_icing_file_nc4(clvrx_str_time, output_dir, preds_dct, probs_dct,
                         x, y, lons, lats, elems, lines, satellite='GOES16', domain='CONUS',
                         has_time=False, use_nan=False, prob_thresh=0.5, bt_10_4=None, cld_top_hgt_max=None):
    """Write icing predictions/probabilities on a geostationary grid to a CF NetCDF4 file.

    The file is ``output_dir + 'icing_prediction_' + clvrx_str_time + '.nc'``. ``output_dir``
    is concatenated directly, so it must end with a path separator.

    Args:
        clvrx_str_time: time string for the file name, converted with ``get_timestamp``
            for the ``time`` variable.
        output_dir: output directory prefix.
        preds_dct: 2-D (y, x) prediction grids keyed by flight-level index (a key of
            ``flt_level_ranges_str``).
        probs_dct: 2-D (y, x) probability grids; must contain every key of ``preds_dct``.
        x, y: 1-D projection coordinates; their lengths define the ``x`` / ``y`` dimensions.
        lons, lats: 2-D (y, x) longitude / latitude grids.
        elems, lines: optional 1-D element / line indices (skipped when ``elems`` is None).
        satellite: 'GOES16', 'H08' or 'H09'; with ``domain``, selects the CF navigation.
        domain: domain passed to ``get_cf_nav_parameters``.
        has_time: when True, gridded variables get a leading length-1 ``time`` dimension.
        use_nan: when True, probabilities below ``prob_thresh`` are written as NaN.
        prob_thresh: threshold for ``use_nan``. In 'max_icing_probability_level', a pixel
            whose levels are all below it is written as -1.
        bt_10_4, cld_top_hgt_max: optional extra gridded fields to include.

    Raises:
        ValueError: if ``satellite`` has no projection long_name.
    """
    outfile_name = output_dir + 'icing_prediction_'+clvrx_str_time+'.nc'

    dim_0_name = 'x'
    dim_1_name = 'y'
    time_dim_name = 'time'
    geo_coords = 'time y x'

    if satellite in ('H08', 'H09'):
        long_name = 'Himawari Imagery Projection'
    elif satellite == 'GOES16':
        long_name = 'GOES-16/17 Imagery Projection'
    else:
        raise ValueError('unsupported satellite: {!r}'.format(satellite))

    grid_shape = (1, y.shape[0], x.shape[0])
    var_dim_list = [time_dim_name, dim_1_name, dim_0_name] if has_time else [dim_1_name, dim_0_name]
    prob_missing = np.nan if use_nan else -1.0

    # The context manager closes the file even if a write fails part-way.
    with Dataset(outfile_name, 'w', format='NETCDF4') as rootgrp:
        rootgrp.setncattr('Conventions', 'CF-1.7')

        rootgrp.createDimension(dim_0_name, size=x.shape[0])
        rootgrp.createDimension(dim_1_name, size=y.shape[0])
        rootgrp.createDimension(time_dim_name, size=1)

        tvar = rootgrp.createVariable('time', 'f8', time_dim_name)
        tvar[0] = get_timestamp(clvrx_str_time)
        tvar.units = 'seconds since 1970-01-01 00:00:00'

        def grid_var(name, dtype):
            # All gridded outputs share the same dimensions, coordinates and grid mapping.
            var = rootgrp.createVariable(name, dtype, var_dim_list)
            var.setncattr('coordinates', geo_coords)
            var.setncattr('grid_mapping', 'Projection')
            return var

        flt_lvls = list(preds_dct.keys())
        for flvl in flt_lvls:
            preds = preds_dct[flvl]
            icing_pred_ds = grid_var('icing_prediction_level_'+flt_level_ranges_str[flvl], 'i2')
            icing_pred_ds.setncattr('missing', -1)
            if has_time:
                preds = preds.reshape(grid_shape)
            icing_pred_ds[:,] = preds

        prob_s = []
        for flvl in flt_lvls:
            probs = probs_dct[flvl]
            prob_s.append(probs)
            icing_prob_ds = grid_var('icing_probability_level_'+flt_level_ranges_str[flvl], 'f4')
            icing_prob_ds.setncattr('missing', prob_missing)
            if has_time:
                probs = probs.reshape(grid_shape)
            if use_nan:
                probs = np.where(probs < prob_thresh, np.nan, probs)
            icing_prob_ds[:,] = probs

        # Levels stacked on a trailing axis (position = preds_dct key order).
        prob_s = np.stack(prob_s, axis=-1)
        max_prob = np.max(prob_s, axis=2)
        if use_nan:
            max_prob = np.where(max_prob < prob_thresh, np.nan, max_prob)
        if has_time:
            max_prob = max_prob.reshape(grid_shape)

        icing_prob_ds = grid_var('max_icing_probability_column', 'f4')
        icing_prob_ds.setncattr('missing', prob_missing)
        icing_prob_ds[:,] = max_prob

        # Level (stack position) of the maximum probability; -1 where no level reaches
        # prob_thresh. This is an int16 variable, which cannot hold NaN, so -1 is kept as
        # the missing value even when use_nan is set (casting NaN to int16 is undefined).
        prob_s = np.where(prob_s < prob_thresh, -1.0, prob_s)
        max_lvl = np.where(np.all(prob_s == -1, axis=2), -1, np.argmax(prob_s, axis=2))
        if has_time:
            max_lvl = max_lvl.reshape(grid_shape)

        icing_pred_ds = grid_var('max_icing_probability_level', 'i2')
        icing_pred_ds.setncattr('missing', -1)
        icing_pred_ds[:,] = max_lvl

        if bt_10_4 is not None:
            bt_ds = grid_var('bt_10_4', 'f4')
            bt_ds[:,] = bt_10_4

        if cld_top_hgt_max is not None:
            cth_ds = grid_var('cld_top_hgt_max', 'f4')
            cth_ds.units = 'meter'
            cth_ds[:,] = cld_top_hgt_max

        lon_ds = rootgrp.createVariable('longitude', 'f4', [dim_1_name, dim_0_name])
        lon_ds.units = 'degrees_east'
        lon_ds[:,] = lons

        lat_ds = rootgrp.createVariable('latitude', 'f4', [dim_1_name, dim_0_name])
        lat_ds.units = 'degrees_north'
        lat_ds[:,] = lats

        cf_nav_dct = get_cf_nav_parameters(satellite, domain)

        proj_ds = rootgrp.createVariable('Projection', 'b')
        proj_ds.setncattr('long_name', long_name)
        proj_ds.setncattr('grid_mapping_name', 'geostationary')
        proj_ds.setncattr('sweep_angle_axis', cf_nav_dct['sweep_angle_axis'])
        proj_ds.setncattr('semi_major_axis', cf_nav_dct['semi_major_axis'])
        proj_ds.setncattr('semi_minor_axis', cf_nav_dct['semi_minor_axis'])
        proj_ds.setncattr('inverse_flattening', cf_nav_dct['inverse_flattening'])
        proj_ds.setncattr('perspective_point_height', cf_nav_dct['perspective_point_height'])
        proj_ds.setncattr('latitude_of_projection_origin', cf_nav_dct['latitude_of_projection_origin'])
        proj_ds.setncattr('longitude_of_projection_origin', cf_nav_dct['longitude_of_projection_origin'])

        if x is not None:
            x_ds = rootgrp.createVariable(dim_0_name, 'f8', [dim_0_name])
            x_ds.units = 'rad'
            x_ds.setncattr('axis', 'X')
            x_ds.setncattr('standard_name', 'projection_x_coordinate')
            x_ds.setncattr('long_name', 'fixed grid viewing angle')
            x_ds.setncattr('scale_factor', cf_nav_dct['x_scale_factor'])
            x_ds.setncattr('add_offset', cf_nav_dct['x_add_offset'])
            x_ds[:] = x

            y_ds = rootgrp.createVariable(dim_1_name, 'f8', [dim_1_name])
            y_ds.units = 'rad'
            y_ds.setncattr('axis', 'Y')
            y_ds.setncattr('standard_name', 'projection_y_coordinate')
            y_ds.setncattr('long_name', 'fixed grid viewing angle')
            y_ds.setncattr('scale_factor', cf_nav_dct['y_scale_factor'])
            y_ds.setncattr('add_offset', cf_nav_dct['y_add_offset'])
            y_ds[:] = y

        if elems is not None:
            elem_ds = rootgrp.createVariable('elems', 'i2', [dim_0_name])
            elem_ds[:] = elems
            line_ds = rootgrp.createVariable('lines', 'i2', [dim_1_name])
            line_ds[:] = lines


def write_icing_file_nc4_viirs(clvrx_str_time, output_dir, preds_dct, probs_dct, lons, lats,
                               has_time=False, use_nan=False, prob_thresh=0.5, bt_10_4=None):
    """Write VIIRS icing predictions/probabilities on a lat/lon-referenced grid to NetCDF4.

    The file is ``output_dir + 'icing_prediction_' + clvrx_str_time + '.nc'``. ``output_dir``
    is concatenated directly, so it must end with a path separator. The ``y`` / ``x``
    dimension sizes come from ``lons.shape``.

    Args:
        clvrx_str_time: time string for the file name, converted with ``get_timestamp``
            for the ``time`` variable.
        output_dir: output directory prefix.
        preds_dct: 2-D (y, x) prediction grids keyed by flight-level index (a key of
            ``flt_level_ranges_str``).
        probs_dct: 2-D (y, x) probability grids; must contain every key of ``preds_dct``.
        lons, lats: 2-D (y, x) longitude / latitude grids.
        has_time: when True, gridded variables get a leading length-1 ``time`` dimension.
        use_nan: when True, probabilities below ``prob_thresh`` are written as NaN.
        prob_thresh: threshold for ``use_nan``. In 'max_icing_probability_level', a pixel
            whose levels are all below it is written as -1.
        bt_10_4: optional extra gridded field to include.
    """
    outfile_name = output_dir + 'icing_prediction_'+clvrx_str_time+'.nc'

    dim_0_name = 'x'
    dim_1_name = 'y'
    time_dim_name = 'time'
    geo_coords = 'longitude latitude'

    dim_1_len, dim_0_len = lons.shape
    grid_shape = (1, dim_1_len, dim_0_len)
    var_dim_list = [time_dim_name, dim_1_name, dim_0_name] if has_time else [dim_1_name, dim_0_name]
    prob_missing = np.nan if use_nan else -1.0

    # The context manager closes the file even if a write fails part-way.
    with Dataset(outfile_name, 'w', format='NETCDF4') as rootgrp:
        rootgrp.setncattr('Conventions', 'CF-1.7')

        rootgrp.createDimension(dim_0_name, size=dim_0_len)
        rootgrp.createDimension(dim_1_name, size=dim_1_len)
        rootgrp.createDimension(time_dim_name, size=1)

        tvar = rootgrp.createVariable('time', 'f8', time_dim_name)
        tvar[0] = get_timestamp(clvrx_str_time)
        tvar.units = 'seconds since 1970-01-01 00:00:00'

        def grid_var(name, dtype):
            # All gridded outputs share the same dimensions, coordinates and grid mapping.
            var = rootgrp.createVariable(name, dtype, var_dim_list)
            var.setncattr('coordinates', geo_coords)
            var.setncattr('grid_mapping', 'Projection')
            return var

        flt_lvls = list(preds_dct.keys())
        for flvl in flt_lvls:
            preds = preds_dct[flvl]
            icing_pred_ds = grid_var('icing_prediction_level_'+flt_level_ranges_str[flvl], 'i2')
            icing_pred_ds.setncattr('missing', -1)
            if has_time:
                preds = preds.reshape(grid_shape)
            icing_pred_ds[:,] = preds

        prob_s = []
        for flvl in flt_lvls:
            probs = probs_dct[flvl]
            prob_s.append(probs)
            icing_prob_ds = grid_var('icing_probability_level_'+flt_level_ranges_str[flvl], 'f4')
            icing_prob_ds.setncattr('missing', prob_missing)
            if has_time:
                probs = probs.reshape(grid_shape)
            if use_nan:
                probs = np.where(probs < prob_thresh, np.nan, probs)
            icing_prob_ds[:,] = probs

        # Levels stacked on a trailing axis (position = preds_dct key order).
        prob_s = np.stack(prob_s, axis=-1)
        max_prob = np.max(prob_s, axis=2)
        if use_nan:
            max_prob = np.where(max_prob < prob_thresh, np.nan, max_prob)
        if has_time:
            max_prob = max_prob.reshape(grid_shape)

        icing_prob_ds = grid_var('max_icing_probability_column', 'f4')
        icing_prob_ds.setncattr('missing', prob_missing)
        icing_prob_ds[:,] = max_prob

        # Level (stack position) of the maximum probability; -1 where no level reaches
        # prob_thresh. This is an int16 variable, which cannot hold NaN, so -1 is kept as
        # the missing value even when use_nan is set (casting NaN to int16 is undefined).
        prob_s = np.where(prob_s < prob_thresh, -1.0, prob_s)
        max_lvl = np.where(np.all(prob_s == -1, axis=2), -1, np.argmax(prob_s, axis=2))
        if has_time:
            max_lvl = max_lvl.reshape(grid_shape)

        icing_pred_ds = grid_var('max_icing_probability_level', 'i2')
        icing_pred_ds.setncattr('missing', -1)
        icing_pred_ds[:,] = max_lvl

        if bt_10_4 is not None:
            bt_ds = grid_var('bt_10_4', 'f4')
            bt_ds[:,] = bt_10_4

        lon_ds = rootgrp.createVariable('longitude', 'f4', [dim_1_name, dim_0_name])
        lon_ds.units = 'degrees_east'
        lon_ds[:,] = lons

        lat_ds = rootgrp.createVariable('latitude', 'f4', [dim_1_name, dim_0_name])
        lat_ds.units = 'degrees_north'
        lat_ds[:,] = lats

        proj_ds = rootgrp.createVariable('Projection', 'b')
        proj_ds.setncattr('grid_mapping_name', 'latitude_longitude')


def _cth_to_flight_level(c_m):
    """Return the index of the ``flt_level_ranges`` bin containing cloud-top height ``c_m``.

    A value outside every bin (negative, >= the top of the last bin, or NaN) is returned unchanged.
    """
    for lvl, (lo, hi) in flt_level_ranges.items():
        if hi > c_m >= lo:
            return lvl
    return c_m


def _evaluate_tiles(model_module, model, data_dct, param_names, tile_idxs, cth_lvl, use_max_cth_level,
                    prob_thresh, flight_levels, preds_2d_dct, probs_2d_dct):
    """Evaluate ``model`` on the tiles ``tile_idxs`` and scatter the results into the full-grid arrays.

    Args:
        model_module: module providing ``run_evaluate_static_2``.
        model: loaded model passed to ``run_evaluate_static_2``.
        data_dct: predictor arrays indexed by flat tile index along the first axis.
        param_names: predictor names taken from ``data_dct`` for this model.
        tile_idxs: flat tile indices to evaluate (non-empty).
        cth_lvl: one binned cloud-top height per entry of ``tile_idxs``; passed to the model as
            'cth_high_avg' when ``use_max_cth_level`` is set.
        preds_2d_dct, probs_2d_dct: flight level -> flat full-grid array, updated in place.
    """
    tile_idxs = np.array(tile_idxs)
    grd_dct = {name: data_dct[name][tile_idxs, ] for name in param_names}
    if use_max_cth_level:
        grd_dct['cth_high_avg'] = np.array(cth_lvl)

    preds_dct, probs_dct = model_module.run_evaluate_static_2(model, grd_dct, len(tile_idxs),
                                                              prob_thresh=prob_thresh,
                                                              flight_levels=flight_levels)
    for flvl in flight_levels:
        preds_2d_dct[flvl][tile_idxs] = preds_dct[flvl][:]
        probs_2d_dct[flvl][tile_idxs] = probs_dct[flvl][:]


def run_icing_predict(clvrx_dir='/Users/tomrink/data/clavrx/RadC/', output_dir=homedir,
                      day_model_path=model_path_day, night_model_path=model_path_night,
                      prob_thresh=0.5, satellite='GOES16', domain='CONUS', day_night='AUTO',
                      l1b_andor_l2='both', use_flight_altitude=True, res_fac=1, use_nan=False,
                      has_time=False, model_type='FCN', use_dnb=False, use_max_cth_level=False):
    """Run the tile-based icing model over every CLAVRx file of ``clvrx_dir`` and write NetCDF4 files.

    For each file, the predictors and ancillary fields are extracted on a tile grid. Tiles rejected
    by ``check_oblique`` are skipped. The remaining tiles are split into day and night with
    ``is_day``, and the per-flight-level predictions/probabilities are written with
    ``write_icing_file_nc4`` (geostationary) or ``write_icing_file_nc4_viirs`` (VIIRS). Tiles that
    are not evaluated keep the fill value -1.

    Args:
        clvrx_dir: directory scanned by the CLAVRx data source.
        output_dir: directory the output files are written to.
        day_model_path, night_model_path: model locations; ``None`` skips loading that model, and the
            tiles it would have handled keep the fill value.
        prob_thresh: threshold forwarded to the model evaluation and the writers.
        satellite: 'GOES16', 'H08', 'H09' or 'VIIRS'.
        domain: domain used for the geostationary extraction and navigation.
        day_night: 'AUTO' (day model on day tiles, night model on night tiles), 'DAY' (day tiles
            only) or 'NIGHT' (night model on every non-oblique tile).
        l1b_andor_l2: predictor selection forwarded to ``get_training_parameters``/``load_model``.
        use_flight_altitude: predict flight levels 0-4 instead of the single level 0.
        res_fac: resolution factor forwarded to the data extraction.
        use_nan, has_time: output options forwarded to the writers.
        model_type: 'CNN' or 'FCN'.
        use_dnb: use the DNB night predictors for files where ``moon_phase`` is truthy.
        use_max_cth_level: together with ``use_flight_altitude``, predict the single level -1 and feed
            the binned cloud-top height to the model as 'cth_high_avg'.

    Raises:
        ValueError: for an unsupported ``model_type``, ``day_night`` or ``satellite``.
    """
    if model_type == 'CNN':
        import deeplearning.icing_cnn as icing_cnn
        model_module = icing_cnn
    elif model_type == 'FCN':
        import deeplearning.icing_fcn as icing_fcn
        model_module = icing_fcn
    else:
        raise ValueError('unsupported model_type: {}'.format(model_type))

    day_model = None
    night_model = None
    if day_model_path is not None:
        day_model = model_module.load_model(day_model_path, day_night='DAY', l1b_andor_l2=l1b_andor_l2,
                                            use_flight_altitude=use_flight_altitude)
    if night_model_path is not None:
        night_model = model_module.load_model(night_model_path, day_night='NIGHT', l1b_andor_l2=l1b_andor_l2,
                                              use_flight_altitude=use_flight_altitude)

    flight_levels = [0]
    if use_flight_altitude is True:
        flight_levels = [-1] if use_max_cth_level else [0, 1, 2, 3, 4]

    day_train_params, _, _ = get_training_parameters(day_night='DAY', l1b_andor_l2=l1b_andor_l2)
    nght_train_params, _, _ = get_training_parameters(day_night='NIGHT', l1b_andor_l2=l1b_andor_l2, use_dnb=False)
    nght_train_params_dnb, _, _ = get_training_parameters(day_night='NIGHT', l1b_andor_l2=l1b_andor_l2, use_dnb=True)

    # train_params / train_params_dnb list what is read from each file; each model then takes
    # only its own predictor list from that data.
    if day_night == 'AUTO':
        train_params = list(set(day_train_params + nght_train_params))
        train_params_dnb = list(set(day_train_params + nght_train_params_dnb))
    elif day_night == 'DAY':
        train_params = day_train_params
        train_params_dnb = day_train_params  # night tiles are not evaluated, so DNB changes nothing
    elif day_night == 'NIGHT':
        train_params = nght_train_params
        train_params_dnb = nght_train_params_dnb
    else:
        raise ValueError('unsupported day_night: {}'.format(day_night))

    if satellite == 'H08':
        clvrx_ds = CLAVRx_H08(clvrx_dir)
    elif satellite == 'H09':
        clvrx_ds = CLAVRx_H09(clvrx_dir)
    elif satellite == 'GOES16':
        clvrx_ds = CLAVRx(clvrx_dir)
    elif satellite == 'VIIRS':
        clvrx_ds = CLAVRx_VIIRS(clvrx_dir)
    else:
        raise ValueError('unsupported satellite: {}'.format(satellite))

    ancil_names = ['solar_zenith_angle', 'sensor_zenith_angle', 'cld_height_acha', 'cld_geo_thick']

    for fidx, fname in enumerate(clvrx_ds.flist):
        with h5py.File(fname, 'r') as h5f:
            dto = clvrx_ds.get_datetime(fname)
            lunar_illuminated = moon_phase(dto.timestamp())
            clvrx_str_time = dto.strftime('%Y-%m-%d_%H:%M')

            # DNB predictors are chosen per file; the shared parameter lists are never reassigned, so
            # a moonlit file cannot leak DNB predictors into the files processed after it.
            use_dnb_now = use_dnb and lunar_illuminated
            nght_params = nght_train_params_dnb if use_dnb_now else nght_train_params

            if satellite == 'VIIRS':
                data_dct, ll, cc, lats_2d, lons_2d = make_for_full_domain_predict_viirs_clavrx(
                    h5f, name_list=train_params_dnb if use_dnb_now else train_params,
                    day_night=day_night, res_fac=res_fac, use_dnb=use_dnb)
                bt_fld_name = 'temp_11_0um_nom'
                ancil_data_dct, _, _, _, _ = make_for_full_domain_predict_viirs_clavrx(
                    h5f, name_list=ancil_names + [bt_fld_name], res_fac=res_fac)
            else:  # geostationary: GOES16, H08, H09
                data_dct, ll, cc = make_for_full_domain_predict(h5f, name_list=train_params, satellite=satellite,
                                                                domain=domain, res_fac=res_fac)
                if fidx == 0:  # these don't change for geostationary fixed grids
                    nav = get_navigation(satellite, domain)
                    lons_2d, lats_2d, x_rad, y_rad = get_lon_lat_2d_mesh(nav, ll, cc, offset=int(8 / res_fac))
                bt_fld_name = 'temp_10_4um_nom'
                ancil_data_dct, _, _ = make_for_full_domain_predict(h5f, name_list=ancil_names + [bt_fld_name],
                                                                    satellite=satellite, domain=domain,
                                                                    res_fac=res_fac)

            num_elems, num_lines = len(cc), len(ll)
            num_grid = num_lines * num_elems

            satzen = ancil_data_dct['sensor_zenith_angle']
            solzen = ancil_data_dct['solar_zenith_angle']
            cth = ancil_data_dct['cld_height_acha']
            bt_10_4 = ancil_data_dct[bt_fld_name]

            avg_bt = []
            cth_avg = []
            day_idxs, day_cth_lvl = [], []
            nght_idxs, nght_cth_lvl = [], []
            all_idxs, all_cth_lvl = [], []

            # k walks the tiles line by line, element by element (k = elem + line * num_elems).
            for k in range(num_grid):
                avg_bt.append(get_median(bt_10_4[k]))

                c = cth[k].flatten()
                c = c[np.invert(np.isnan(c))]
                cth_avg.append(np.mean(c))
                # mean of the (up to) two highest valid cloud-top heights, binned into a flight level
                cth_lvl = _cth_to_flight_level(np.mean(np.sort(c)[-2:]))

                if not check_oblique(satzen[k]):
                    continue

                all_idxs.append(k)
                all_cth_lvl.append(cth_lvl)
                if is_day(solzen[k]):
                    day_idxs.append(k)
                    day_cth_lvl.append(cth_lvl)
                else:
                    nght_idxs.append(k)
                    nght_cth_lvl.append(cth_lvl)

            # initialize output arrays with the fill value
            preds_2d_dct = {flvl: np.full(num_grid, -1, dtype=np.int8) for flvl in flight_levels}
            probs_2d_dct = {flvl: np.full(num_grid, -1.0, dtype=np.float32) for flvl in flight_levels}

            if day_night == 'NIGHT':
                # The night model is applied to every non-oblique tile, whatever its solar zenith;
                # the tile count and cloud-top levels must therefore cover all of those tiles.
                if all_idxs and night_model is not None:
                    _evaluate_tiles(model_module, night_model, data_dct, nght_params, all_idxs, all_cth_lvl,
                                    use_max_cth_level, prob_thresh, flight_levels, preds_2d_dct, probs_2d_dct)
            else:
                if day_idxs and day_model is not None:
                    _evaluate_tiles(model_module, day_model, data_dct, day_train_params, day_idxs, day_cth_lvl,
                                    use_max_cth_level, prob_thresh, flight_levels, preds_2d_dct, probs_2d_dct)
                # In 'DAY' mode the night predictors were not read, so night tiles keep the fill value.
                if day_night == 'AUTO' and nght_idxs and night_model is not None:
                    _evaluate_tiles(model_module, night_model, data_dct, nght_params, nght_idxs, nght_cth_lvl,
                                    use_max_cth_level, prob_thresh, flight_levels, preds_2d_dct, probs_2d_dct)

            # combine day and night into the full 2D grid
            preds_2d_dct = {flvl: fd.reshape((num_lines, num_elems)) for flvl, fd in preds_2d_dct.items()}
            probs_2d_dct = {flvl: fd.reshape((num_lines, num_elems)) for flvl, fd in probs_2d_dct.items()}

            bt_10_4_2d = np.array(avg_bt).reshape((num_lines, num_elems))
            # NOTE(review): cld_top_hgt_max receives the per-tile MEAN cloud-top height (unchanged
            # behaviour); confirm that is intended.
            cth_avg_2d = np.array(cth_avg).reshape((num_lines, num_elems))

            if satellite == 'VIIRS':
                write_icing_file_nc4_viirs(clvrx_str_time, output_dir, preds_2d_dct, probs_2d_dct,
                                           lons_2d, lats_2d,
                                           use_nan=use_nan, has_time=has_time,
                                           prob_thresh=prob_thresh, bt_10_4=bt_10_4_2d)
            else:
                write_icing_file_nc4(clvrx_str_time, output_dir, preds_2d_dct, probs_2d_dct,
                                     x_rad, y_rad, lons_2d, lats_2d, cc, ll,
                                     satellite=satellite, domain=domain, use_nan=use_nan, has_time=has_time,
                                     prob_thresh=prob_thresh, bt_10_4=bt_10_4_2d, cld_top_hgt_max=cth_avg_2d)

            print('Done: ', clvrx_str_time)


def run_icing_predict_fcn(clvrx_dir='/Users/tomrink/data/clavrx/RadC/', output_dir=homedir,
                          day_model_path=model_path_day, night_model_path=model_path_night,
                          prob_thresh=0.5, satellite='GOES16', domain='CONUS', day_night='AUTO',
                          l1b_andor_l2='both', use_flight_altitude=False, use_max_cth_level=False, use_nan=False):
    """Run the fully-convolutional icing model on whole CLAVRx scenes and write NetCDF4 files.

    Each scene is evaluated by the day and/or night model. Pixels with solar zenith below 80 take the
    day result and pixels above 100 take the night result. Pixels in between, and pixels of an
    inactive mode, keep the fill value -1.

    Args:
        clvrx_dir: directory scanned by the CLAVRx data source.
        output_dir: directory passed to ``write_icing_file_nc4``.
        day_model_path, night_model_path: model locations; ``None`` skips that model.
        prob_thresh: threshold forwarded to the model evaluation and the writer.
        satellite: 'H08', 'H09'; anything else is read with the default ``CLAVRx`` source.
        domain: domain used for extraction and navigation.
        day_night: 'AUTO', 'DAY' or 'NIGHT'.
        l1b_andor_l2: predictor selection forwarded to ``get_training_parameters``/``load_model``.
        use_flight_altitude: predict flight levels 0-4 (or -1 with ``use_max_cth_level``) instead of 0.
        use_nan: output option forwarded to ``write_icing_file_nc4``.

    Raises:
        ValueError: for an unsupported ``day_night``.
    """
    import deeplearning.icing_fcn as icing_fcn

    day_model = None
    night_model = None
    if day_model_path is not None:
        day_model = icing_fcn.load_model(day_model_path, day_night='DAY', l1b_andor_l2=l1b_andor_l2,
                                         use_flight_altitude=use_flight_altitude)
    if night_model_path is not None:
        night_model = icing_fcn.load_model(night_model_path, day_night='NIGHT', l1b_andor_l2=l1b_andor_l2,
                                           use_flight_altitude=use_flight_altitude)

    if use_flight_altitude is True:
        flight_levels = [-1] if use_max_cth_level else [0, 1, 2, 3, 4]
    else:
        flight_levels = [0]

    day_train_params, _, _ = get_training_parameters(day_night='DAY', l1b_andor_l2=l1b_andor_l2)
    nght_train_params, _, _ = get_training_parameters(day_night='NIGHT', l1b_andor_l2=l1b_andor_l2)

    if day_night == 'AUTO':
        train_params = list(set(day_train_params + nght_train_params))
    elif day_night == 'DAY':
        train_params = day_train_params
    elif day_night == 'NIGHT':
        train_params = nght_train_params
    else:
        raise ValueError('unsupported day_night: {}'.format(day_night))

    if satellite == 'H08':
        clvrx_ds = CLAVRx_H08(clvrx_dir)
    elif satellite == 'H09':
        clvrx_ds = CLAVRx_H09(clvrx_dir)
    else:
        clvrx_ds = CLAVRx(clvrx_dir)

    for fidx, fname in enumerate(clvrx_ds.flist):
        with h5py.File(fname, 'r') as h5f:
            clvrx_str_time = clvrx_ds.get_datetime(fname).strftime('%Y-%m-%d_%H:%M')

            data_dct, solzen, satzen, ll, cc = prepare_evaluate(h5f, name_list=train_params, satellite=satellite,
                                                                domain=domain, offset=8)
            num_elems, num_lines = len(cc), len(ll)

            if fidx == 0:  # fixed geostationary grid: navigation is computed once
                nav = get_navigation(satellite, domain)
                lons_2d, lats_2d, x_rad, y_rad = get_lon_lat_2d_mesh(nav, ll, cc)

            # twilight pixels (80 <= solzen <= 100) belong to neither mask
            day_mask = (solzen < 80.0).flatten()
            nght_mask = (solzen > 100.0).flatten()

            # initialize output arrays with the fill value
            preds_2d_dct = {flvl: np.full(num_lines * num_elems, -1, dtype=np.int8) for flvl in flight_levels}
            probs_2d_dct = {flvl: np.full(num_lines * num_elems, -1.0, dtype=np.float32) for flvl in flight_levels}

            passes = ((day_model, day_mask, day_night in ('AUTO', 'DAY')),
                      (night_model, nght_mask, day_night in ('AUTO', 'NIGHT')))
            for model, mask, active in passes:
                if not active or model is None or not np.any(mask):
                    continue
                # the FCN evaluates the whole scene; only the pixels of this mask are kept
                preds_dct, probs_dct = icing_fcn.run_evaluate_static_2(model, data_dct, 1,
                                                                       prob_thresh=prob_thresh,
                                                                       flight_levels=flight_levels)
                for flvl in flight_levels:
                    preds_2d_dct[flvl][mask] = preds_dct[flvl].flatten()[mask]
                    probs_2d_dct[flvl][mask] = probs_dct[flvl].flatten()[mask]

            preds_2d_dct = {flvl: fd.reshape((num_lines, num_elems)) for flvl, fd in preds_2d_dct.items()}
            probs_2d_dct = {flvl: fd.reshape((num_lines, num_elems)) for flvl, fd in probs_2d_dct.items()}

            write_icing_file_nc4(clvrx_str_time, output_dir, preds_2d_dct, probs_2d_dct,
                                 x_rad, y_rad, lons_2d, lats_2d, cc, ll,
                                 satellite=satellite, domain=domain, use_nan=use_nan, prob_thresh=prob_thresh)

            print('Done: ', clvrx_str_time)


def run_icing_predict_image(clvrx_dir='/Users/tomrink/data/clavrx/RadC/', output_dir=homedir,
                            day_model_path=model_path_day, night_model_path=model_path_night,
                            prob_thresh=0.5, satellite='GOES16', domain='CONUS', day_night='AUTO',
                            l1b_andor_l2='BOTH', use_flight_altitude=True, flight_levels=None,
                            res_fac=1, model_type='CNN', extent=[-105, -70, 15, 50], pirep_file=None, icing_dict=None,
                            data_extent=None, use_max_cth_level=False,
                            obs_lons=None, obs_lats=None, obs_times=None, obs_alt=None, obs_intensity=None,
                            icing_intensity_plot_min=1):
    """Run the tile-based icing model over CLAVRx files and render one max-probability image per file.

    Tiles rejected by ``check_oblique`` are skipped. The rest are split with ``is_day`` and run
    through the day/night model. The per-pixel maximum over ``flight_levels`` is masked to NaN below
    ``prob_thresh`` and passed to ``make_icing_image``, together with icing observations within
    +/- 30 minutes of the file time.

    Args:
        flight_levels: levels to predict when ``use_flight_altitude`` is True (default 0-4); their
            ``flt_level_ranges`` span limits which observations are plotted.
        model_type: 'CNN' or 'FCN'.
        pirep_file: PIREP file loaded with ``setup``; ``icing_dict`` overrides the dictionary it gives.
            Either one selects observations with ``time_filter_3``.
        obs_lons, obs_lats, obs_times, obs_alt, obs_intensity: observation arrays used when neither
            ``pirep_file`` nor ``icing_dict`` is given; only intensities >= ``icing_intensity_plot_min``
            are kept.
        Remaining arguments are forwarded to the data extraction, the model and the plot.

    Raises:
        ValueError: for an unsupported ``model_type``.
    """
    if model_type == 'CNN':
        import deeplearning.icing_cnn as icing_cnn
        model_module = icing_cnn
    elif model_type == 'FCN':
        import deeplearning.icing_fcn as icing_fcn
        model_module = icing_fcn
    else:
        raise ValueError('unsupported model_type: {}'.format(model_type))

    day_model = None
    night_model = None
    if day_model_path is not None:
        day_model = model_module.load_model(day_model_path, day_night='DAY', l1b_andor_l2=l1b_andor_l2,
                                            use_flight_altitude=use_flight_altitude)
    if night_model_path is not None:
        night_model = model_module.load_model(night_model_path, day_night='NIGHT', l1b_andor_l2=l1b_andor_l2,
                                              use_flight_altitude=use_flight_altitude)

    alt_lo, alt_hi = 0.0, 15000.0
    if use_flight_altitude is True:
        if flight_levels is None:
            flight_levels = [0, 1, 2, 3, 4]
        # plotted observations are limited to the altitude span of the requested levels
        alt_lo = flt_level_ranges[flight_levels[0]][0]
        alt_hi = flt_level_ranges[flight_levels[-1]][1]
    else:
        # NOTE(review): here use_max_cth_level only takes effect WITHOUT use_flight_altitude, unlike
        # run_icing_predict; kept as-is, confirm which is intended.
        flight_levels = [-1] if use_max_cth_level else [0]

    ice_dict = None
    if pirep_file is not None:
        ice_dict, _, _ = setup(pirep_file)
    if icing_dict is not None:
        ice_dict = icing_dict

    day_train_params, _, _ = get_training_parameters(day_night='DAY', l1b_andor_l2=l1b_andor_l2)
    nght_train_params, _, _ = get_training_parameters(day_night='NIGHT', l1b_andor_l2=l1b_andor_l2)

    if day_night == 'AUTO':
        train_params = list(set(day_train_params + nght_train_params))
    elif day_night == 'DAY':
        train_params = day_train_params
    elif day_night == 'NIGHT':
        train_params = nght_train_params

    if satellite == 'H08':
        clvrx_ds = CLAVRx_H08(clvrx_dir)
    elif satellite == 'H09':
        clvrx_ds = CLAVRx_H09(clvrx_dir)
    else:
        clvrx_ds = CLAVRx(clvrx_dir)

    for fname in clvrx_ds.flist:
        with h5py.File(fname, 'r') as h5f:
            dto = clvrx_ds.get_datetime(fname)
            ts = dto.timestamp()
            clvrx_str_time = dto.strftime('%Y-%m-%d_%H:%M')

            data_dct, ll, cc = make_for_full_domain_predict(h5f, name_list=train_params, satellite=satellite,
                                                            domain=domain, res_fac=res_fac, data_extent=data_extent)
            num_elems, num_lines = len(cc), len(ll)

            ancil_data_dct, _, _ = make_for_full_domain_predict(h5f, name_list=
                                ['solar_zenith_angle', 'sensor_zenith_angle', 'cld_height_acha', 'cld_geo_thick'],
                                satellite=satellite, domain=domain, res_fac=res_fac,
                                data_extent=data_extent)

            satzen = ancil_data_dct['sensor_zenith_angle']
            solzen = ancil_data_dct['solar_zenith_angle']
            cth = ancil_data_dct['cld_height_acha']

            day_idxs, day_cth_lvl = [], []
            nght_idxs, nght_cth_lvl = [], []
            for k in range(num_lines * num_elems):
                if not check_oblique(satzen[k]):
                    continue

                c = cth[k].flatten()
                # mean of the (up to) two highest valid cloud-top heights ...
                c_m = np.mean(np.sort(c[np.invert(np.isnan(c))])[-2:])
                # ... binned into a flight level; heights outside every bin (or NaN) pass through
                for lvl, (lo, hi) in flt_level_ranges.items():
                    if hi > c_m >= lo:
                        c_m = lvl
                        break

                if is_day(solzen[k]):
                    day_idxs.append(k)
                    day_cth_lvl.append(c_m)
                else:
                    nght_idxs.append(k)
                    nght_cth_lvl.append(c_m)

            # probabilities only; the predictions are not needed for the image
            probs_2d_dct = {flvl: np.full(num_lines * num_elems, -1.0, dtype=np.float32) for flvl in flight_levels}

            passes = ((day_model, day_idxs, day_cth_lvl, day_train_params, day_night in ('AUTO', 'DAY')),
                      (night_model, nght_idxs, nght_cth_lvl, nght_train_params, day_night in ('AUTO', 'NIGHT')))
            for model, idxs, cth_lvl, names, active in passes:
                if not active or not idxs or model is None:
                    continue
                idxs = np.array(idxs)
                grd_dct = {name: data_dct[name][idxs, ] for name in names}
                if use_max_cth_level:
                    grd_dct['cth_high_avg'] = np.array(cth_lvl)

                _, probs_dct = model_module.run_evaluate_static_2(model, grd_dct, len(idxs),
                                                                  prob_thresh=prob_thresh,
                                                                  flight_levels=flight_levels)
                for flvl in flight_levels:
                    probs_2d_dct[flvl][idxs] = probs_dct[flvl][:]

            # observations within +/- 30 minutes of the file time
            dto_utc, _ = get_time_tuple_utc(ts)
            ts_0 = (dto_utc - datetime.timedelta(minutes=30)).timestamp()
            ts_1 = (dto_utc + datetime.timedelta(minutes=30)).timestamp()

            keep_lons = keep_lats = keep_obs_intensity = None
            if ice_dict is not None:
                _, keep_lons, keep_lats, _ = time_filter_3(ice_dict, ts_0, ts_1, alt_lo, alt_hi)
            elif obs_times is not None:
                keep = np.logical_and(obs_times >= ts_0, obs_times < ts_1)
                keep = np.where(keep, np.logical_and(obs_alt >= alt_lo, obs_alt < alt_hi), False)
                keep = np.where(keep, obs_intensity >= icing_intensity_plot_min, False)
                keep_lons = obs_lons[keep]
                keep_lats = obs_lats[keep]
                keep_obs_intensity = obs_intensity[keep]

            prob_s = np.stack([probs_2d_dct[flvl].reshape((num_lines, num_elems)) for flvl in flight_levels],
                              axis=-1)
            max_prob = np.max(prob_s, axis=2)
            max_prob = np.where(max_prob < prob_thresh, np.nan, max_prob)

            make_icing_image(h5f, max_prob, None, None, clvrx_str_time, satellite, domain,
                             ice_lons_vld=keep_lons, ice_lats_vld=keep_lats, ice_intensity=keep_obs_intensity,
                             extent=extent)

            print('Done: ', clvrx_str_time)


def run_icing_predict_image_fcn(clvrx_dir='/Users/tomrink/data/clavrx/RadC/', output_dir=homedir,
                                day_model_path=model_path_day, night_model_path=model_path_night,
                                prob_thresh=0.5, satellite='GOES16', domain='CONUS', day_night='AUTO',
                                l1b_andor_l2='BOTH', use_flight_altitude=True, use_max_cth_level=False,
                                extent=[-105, -70, 15, 50],
                                pirep_file='/Users/tomrink/data/pirep/pireps_202109200000_202109232359.csv',
                                obs_lons=None, obs_lats=None, obs_times=None, obs_alt=None, flight_level=None, obs_intensity=None):
    """Run the fully-convolutional icing model on whole CLAVRx scenes and render a max-probability image.

    Pixels with solar zenith below 80 take the day result and pixels above 100 take the night result.
    The per-pixel maximum over the flight levels is masked to NaN below ``prob_thresh`` and passed to
    ``make_icing_image``, together with icing observations within +/- 30 minutes of the file time.

    Args:
        prob_thresh: threshold for the model evaluation and for masking the image.
        pirep_file: PIREP file filtered with ``time_filter_3``; when ``None``, the ``obs_*`` arrays
            are used instead (if ``obs_times`` is given).
        flight_level: optional key of ``flt_level_ranges`` restricting the observation altitudes.
        obs_intensity: optional intensities of the ``obs_*`` observations, plotted when given.
        Remaining arguments are forwarded to the extraction, the model and the plot.

    Raises:
        ValueError: for an unsupported ``day_night``.
    """
    import deeplearning.icing_fcn as icing_fcn

    day_model = None
    night_model = None
    if day_model_path is not None:
        day_model = icing_fcn.load_model(day_model_path, day_night='DAY', l1b_andor_l2=l1b_andor_l2,
                                         use_flight_altitude=use_flight_altitude)
    if night_model_path is not None:
        night_model = icing_fcn.load_model(night_model_path, day_night='NIGHT', l1b_andor_l2=l1b_andor_l2,
                                           use_flight_altitude=use_flight_altitude)

    if use_flight_altitude is True:
        flight_levels = [-1] if use_max_cth_level else [0, 1, 2, 3, 4]
    else:
        flight_levels = [0]

    ice_dict = None
    if pirep_file is not None:
        ice_dict, _, _ = setup(pirep_file)

    alt_lo, alt_hi = 0.0, 15000.0
    if flight_level is not None:
        alt_lo, alt_hi = flt_level_ranges[flight_level]

    day_train_params, _, _ = get_training_parameters(day_night='DAY', l1b_andor_l2=l1b_andor_l2)
    nght_train_params, _, _ = get_training_parameters(day_night='NIGHT', l1b_andor_l2=l1b_andor_l2)

    if day_night == 'AUTO':
        train_params = list(set(day_train_params + nght_train_params))
    elif day_night == 'DAY':
        train_params = day_train_params
    elif day_night == 'NIGHT':
        train_params = nght_train_params
    else:
        raise ValueError('unsupported day_night: {}'.format(day_night))

    if satellite == 'H08':
        clvrx_ds = CLAVRx_H08(clvrx_dir)
    elif satellite == 'H09':
        clvrx_ds = CLAVRx_H09(clvrx_dir)
    else:
        clvrx_ds = CLAVRx(clvrx_dir)

    for fname in clvrx_ds.flist:
        with h5py.File(fname, 'r') as h5f:
            dto = clvrx_ds.get_datetime(fname)
            ts = dto.timestamp()
            clvrx_str_time = dto.strftime('%Y-%m-%d_%H:%M')

            # observations within +/- 30 minutes of the file time
            dto_utc, _ = get_time_tuple_utc(ts)
            ts_0 = (dto_utc - datetime.timedelta(minutes=30)).timestamp()
            ts_1 = (dto_utc + datetime.timedelta(minutes=30)).timestamp()

            keep_lons = keep_lats = keep_obs_intensity = None
            if ice_dict is not None:
                _, keep_lons, keep_lats, _ = time_filter_3(ice_dict, ts_0, ts_1, alt_lo, alt_hi)
            elif obs_times is not None:
                keep = np.logical_and(obs_times >= ts_0, obs_times < ts_1)
                keep = np.where(keep, np.logical_and(obs_alt >= alt_lo, obs_alt < alt_hi), False)
                keep_lons = obs_lons[keep]
                keep_lats = obs_lats[keep]
                if obs_intensity is not None:
                    keep_obs_intensity = obs_intensity[keep]

            data_dct, solzen, satzen, ll, cc = prepare_evaluate(h5f, name_list=train_params, satellite=satellite,
                                                                domain=domain, offset=8)
            num_elems, num_lines = len(cc), len(ll)

            # twilight pixels (80 <= solzen <= 100) belong to neither mask
            day_mask = (solzen < 80.0).flatten()
            nght_mask = (solzen > 100.0).flatten()

            # probabilities only; the predictions are not needed for the image
            probs_2d_dct = {flvl: np.full(num_lines * num_elems, -1.0, dtype=np.float32) for flvl in flight_levels}

            passes = ((day_model, day_mask, day_night in ('AUTO', 'DAY')),
                      (night_model, nght_mask, day_night in ('AUTO', 'NIGHT')))
            for model, mask, active in passes:
                if not active or model is None or not np.any(mask):
                    continue
                _, probs_dct = icing_fcn.run_evaluate_static_2(model, data_dct, 1,
                                                               prob_thresh=prob_thresh,
                                                               flight_levels=flight_levels)
                for flvl in flight_levels:
                    probs_2d_dct[flvl][mask] = probs_dct[flvl].flatten()[mask]

            prob_s = np.stack([probs_2d_dct[flvl].reshape((num_lines, num_elems)) for flvl in flight_levels],
                              axis=-1)
            max_prob = np.max(prob_s, axis=2)
            # mask with the caller's threshold (previously hard-coded to 0.5, the default)
            max_prob = np.where(max_prob < prob_thresh, np.nan, max_prob)

            # only pass intensities when there are some, so the plot call is unchanged otherwise
            image_kwargs = {} if keep_obs_intensity is None else {'ice_intensity': keep_obs_intensity}
            make_icing_image(h5f, max_prob, None, None, clvrx_str_time, satellite, domain,
                             ice_lons_vld=keep_lons, ice_lats_vld=keep_lats, extent=extent, **image_kwargs)

            print('Done: ', clvrx_str_time)


def run_icing_predict_image_1x1(clvrx_dir='/Users/tomrink/data/clavrx/RadC/', output_dir=homedir,
                                day_model_path=model_path_day, night_model_path=model_path_night,
                                prob_thresh=0.52, satellite='GOES16', domain='CONUS', day_night='AUTO',
                                l1b_andor_l2='BOTH', use_flight_altitude=True, use_max_cth_level=False,
                                extent=None,
                                pirep_file='/Users/tomrink/data/pirep/pireps_202109200000_202109232359.csv',
                                obs_lons=None, obs_lats=None, obs_times=None, obs_alt=None, flight_level=None, obs_intensity=None):
    """Render per-pixel (1x1) icing-probability images from CLAVRx files using pickled GBM models.

    For every file in the CLAVRx data source rooted at ``clvrx_dir``:

    * the day model is applied to pixels with solar zenith < 80 deg;
    * the night model is applied to every other pixel;
    * clear pixels (cloud mask < 2) and probabilities below ``prob_thresh`` become NaN;
    * the 2-D probability field is passed to ``make_icing_image``, together with the
      observation locations that fall within +/-30 minutes of the image time.

    Args:
        clvrx_dir: directory handed to the CLAVRx data-source class selected by ``satellite``.
        output_dir: not used by this function; kept for interface compatibility.
        day_model_path, night_model_path: not used. The scalers/models are loaded from
            fixed pickle paths below.
        prob_thresh: probabilities below this value are set to NaN; also forwarded to
            ``make_icing_image``.
        satellite: 'H08' / 'H09' select ``CLAVRx_H08`` / ``CLAVRx_H09``. Any other value
            uses ``CLAVRx``. Also forwarded to ``prepare_evaluate1x1`` and ``make_icing_image``.
        domain: forwarded to ``prepare_evaluate1x1`` and ``make_icing_image``.
        day_night: 'AUTO' runs both models, 'DAY' only the day model, 'NIGHT' only the
            night model. Pixels not covered by a model that ran end up NaN.
        l1b_andor_l2, use_flight_altitude, use_max_cth_level, obs_intensity: not used.
        extent: forwarded to ``make_icing_image``. Defaults to ``[-105, -70, 15, 50]``
            (a fresh list per call).
        pirep_file: when given, the ``ice_dict`` returned by ``setup(pirep_file)`` is
            filtered with ``time_filter_3`` to obtain the overlaid observation locations.
        obs_lons, obs_lats, obs_times, obs_alt: arrays used instead when ``pirep_file``
            is None and ``obs_times`` is given. ``obs_times`` is compared with
            ``datetime.timestamp()`` values (presumably epoch seconds -- confirm with callers).
        flight_level: optional key into ``flt_level_ranges``. When given, it restricts
            observation altitudes to that [lo, hi) range. Otherwise the range is [0, 15000).
    """
    # Sentinel instead of a mutable list default; same effective default as before.
    if extent is None:
        extent = [-105, -70, 15, 50]

    # Load the parameter scalers and GBM models from disk (fixed paths).
    stdSclr_day = joblib.load('/home/rink/stdSclr_day.pkl')
    day_model = joblib.load('/home/rink/icing_gbm_day.pkl')

    stdSclr_nght = joblib.load('/home/rink/stdSclr_nght.pkl')
    nght_model = joblib.load('/home/rink/icing_gbm_nght.pkl')

    if pirep_file is not None:
        ice_dict, _, _ = setup(pirep_file)

    alt_lo, alt_hi = 0.0, 15000.0
    if flight_level is not None:
        alt_lo, alt_hi = flt_level_ranges[flight_level]

    # Presumably the features each pickled scaler/model was fitted on, in that order -- confirm.
    day_train_params = ['cld_temp_acha', 'supercooled_cloud_fraction', 'cld_reff_dcomp', 'cld_opd_dcomp']
    nght_train_params = ['cld_temp_acha', 'supercooled_cloud_fraction', 'cld_reff_acha', 'cld_opd_acha']

    if satellite == 'H08':
        clvrx_ds = CLAVRx_H08(clvrx_dir)
    elif satellite == 'H09':
        clvrx_ds = CLAVRx_H09(clvrx_dir)
    else:
        clvrx_ds = CLAVRx(clvrx_dir)

    run_day = day_night in ('AUTO', 'DAY')
    run_nght = day_night in ('AUTO', 'NIGHT')

    for fname in clvrx_ds.flist:
        # Context manager guarantees the file is closed even if evaluation or plotting raises.
        with h5py.File(fname, 'r') as h5f:
            dto = clvrx_ds.get_datetime(fname)
            ts = dto.timestamp()
            clvrx_str_time = dto.strftime('%Y-%m-%d_%H:%M')

            # Observation matching window: +/- 30 minutes around the image time.
            dto, _ = get_time_tuple_utc(ts)
            ts_0 = (dto - datetime.timedelta(minutes=30)).timestamp()
            ts_1 = (dto + datetime.timedelta(minutes=30)).timestamp()

            if pirep_file is not None:
                _, keep_lons, keep_lats, _ = time_filter_3(ice_dict, ts_0, ts_1, alt_lo, alt_hi)
            elif obs_times is not None:
                keep = np.logical_and(obs_times >= ts_0, obs_times < ts_1)
                keep = np.where(keep, np.logical_and(obs_alt >= alt_lo, obs_alt < alt_hi), False)
                keep_lons = obs_lons[keep]
                keep_lats = obs_lats[keep]
            else:
                keep_lons = None
                keep_lats = None

            varX_day, solzen, _, cldmsk, ll, cc = prepare_evaluate1x1(h5f, name_list=day_train_params, satellite=satellite, domain=domain, offset=8)
            varX_nght, solzen, _, cldmsk, ll, cc = prepare_evaluate1x1(h5f, name_list=nght_train_params, satellite=satellite, domain=domain, offset=8)
            num_elems = len(cc)
            num_lines = len(ll)
            print('data: ', varX_day.shape, varX_nght.shape, solzen.shape)
            print('num_lines, num_elems: ', num_lines, num_elems)

            day_idxs = solzen < 80.0
            num_day_tiles = np.sum(day_idxs)
            print('num day tiles: ', num_day_tiles)

            nght_idxs = solzen > 100.0
            num_nght_tiles = np.sum(nght_idxs)
            print('num nght tiles: ', num_nght_tiles)

            # The night model fills every non-day pixel, including the 80-100 deg terminator
            # band. Gate it on that same mask so terminator-only scenes still get predictions.
            not_day_idxs = np.invert(day_idxs)

            clr_idxs = cldmsk < 2
            num_clr_tiles = np.sum(clr_idxs)
            print('num clear tiles: ', num_clr_tiles)

            # -1 marks "no prediction"; it falls below any sensible threshold and becomes NaN.
            fd_probs = np.full(num_lines * num_elems, -1.0, dtype=np.float32)

            if run_day and num_day_tiles > 0:
                varX_std = stdSclr_day.transform(varX_day)
                varX_std = np.where(np.isnan(varX_std), 0, varX_std)
                probs = day_model.predict_proba(varX_std)[:, 1]
                fd_probs[day_idxs] = probs[day_idxs]

            if run_nght and np.any(not_day_idxs):
                varX_std = stdSclr_nght.transform(varX_nght)
                varX_std = np.where(np.isnan(varX_std), 0, varX_std)
                probs = nght_model.predict_proba(varX_std)[:, 1]
                fd_probs[not_day_idxs] = probs[not_day_idxs]

            fd_probs[clr_idxs] = -1.0
            fd_probs = np.where(fd_probs < prob_thresh, np.nan, fd_probs)

            max_prob = fd_probs.reshape((num_lines, num_elems))

            make_icing_image(h5f, max_prob, None, None, clvrx_str_time, satellite, domain,
                             ice_lons_vld=keep_lons, ice_lats_vld=keep_lats, extent=extent, prob_thresh=prob_thresh)

            print('Done: ', clvrx_str_time)