Skip to content
Snippets Groups Projects
util.py 65.3 KiB
Newer Older
import numpy as np
from icing.pirep_goes import setup, time_filter_3
tomrink's avatar
tomrink committed
from icing.moon_phase import moon_phase
tomrink's avatar
tomrink committed
from util.util import get_time_tuple_utc, is_day, check_oblique, get_median, homedir, \
    get_cartopy_crs, \
    get_fill_attrs, get_grid_values, GenericException, get_timestamp, get_cf_nav_parameters
from util.plot import make_icing_image
from util.geos_nav import get_navigation, get_lon_lat_2d_mesh
from util.setup import model_path_day, model_path_night
tomrink's avatar
tomrink committed
from aeolus.datasource import CLAVRx, CLAVRx_VIIRS, CLAVRx_H08, CLAVRx_H09
import h5py
import datetime
tomrink's avatar
tomrink committed
from netCDF4 import Dataset
tomrink's avatar
tomrink committed
import joblib
tomrink's avatar
tomrink committed
import tensorflow as tf
import os
tomrink's avatar
tomrink committed
# from scipy.signal import medfilt2d
tomrink's avatar
tomrink committed

# Flight-level bins keyed by level index 0-4; each value is [lower, upper].
# NOTE(review): units are not stated in this file - confirm against the PIREP source.
flt_level_ranges = {
    0: [0.0, 2000.0],
    1: [2000.0, 4000.0],
    2: [4000.0, 6000.0],
    3: [6000.0, 8000.0],
    4: [8000.0, 15000.0],
}

tomrink's avatar
tomrink committed
# Taiwan sub-domain of the Himawari (H08/H09) full-disk grid, in pixel indices.
# Reference points used to derive the box:
#   Taiwan:       lon, lat = 120.955098, 23.834310 -> elem, line = (1789, 1505)
#   UR of Taiwan: lon, lat = 135.0, 35.0           -> elem, line = (2499, 995)
taiwan_i0 = 1079  # first element (column); taiwan_i0 + taiwan_lenx == 2499
taiwan_j0 = 995  # first line (row)
taiwan_lenx = 1420  # number of elements
taiwan_leny = 1020  # number of lines
# Extent in GEOS projection coordinates (not line/elem), obtained from e.g.
#   geos.transform_point(135.0, 35.0, ccrs.PlateCarree(), False)
#   geos.transform_point(106.61, 13.97, ccrs.PlateCarree(), False)
taiwain_extent = [-3342, -502, 1470, 3510]
# Correctly spelled alias; the misspelled name above is kept for existing callers.
taiwan_extent = taiwain_extent


def get_training_parameters(day_night='DAY', l1b_andor_l2='both', satellite='GOES16', use_dnb=False):
    """Return the model input parameter names for a given configuration.

    :param day_night: 'DAY' selects the daytime set; any other value selects
        the nighttime set (no reflectance channels, ACHA microphysics).
    :param l1b_andor_l2: 'both', 'l1b' or 'l2' - which group(s) form the
        combined list.
    :param satellite: 'GOES16', 'H08' or 'H09' (H09 uses the H08 channel list,
        as the rest of this module treats the two identically).
    :param use_dnb: nighttime only; when exactly True the L2 list uses the
        DCOMP-named cloud properties (the daytime L2 list) instead of ACHA ones.
    :return: (train_params, train_params_l1b, train_params_l2). train_params is
        L1B followed by L2 for 'both', otherwise the selected group itself.
    :raises ValueError: for an unsupported ``satellite`` or ``l1b_andor_l2``
        (previously these surfaced as UnboundLocalError).
    """
    if satellite not in ('GOES16', 'H08', 'H09'):
        raise ValueError('unsupported satellite: {!r}'.format(satellite))
    if l1b_andor_l2 not in ('both', 'l1b', 'l2'):
        raise ValueError('l1b_andor_l2 must be both, l1b or l2, got {!r}'.format(l1b_andor_l2))

    is_day = day_night == 'DAY'

    l2_common = ['cld_height_acha', 'cld_geo_thick', 'cld_temp_acha', 'cld_press_acha', 'supercooled_cloud_fraction',
                 'cld_emiss_acha', 'conv_cloud_fraction']
    l2_dcomp = ['cld_reff_dcomp', 'cld_opd_dcomp', 'iwc_dcomp', 'lwc_dcomp']
    l2_acha = ['cld_reff_acha', 'cld_opd_acha']

    if is_day or use_dnb is True:
        train_params_l2 = l2_common + l2_dcomp
    else:
        train_params_l2 = l2_common + l2_acha

    if satellite == 'GOES16':
        temps = ['temp_10_4um_nom', 'temp_11_0um_nom', 'temp_12_0um_nom', 'temp_13_3um_nom', 'temp_3_75um_nom',
                 'temp_6_2um_nom', 'temp_6_7um_nom', 'temp_7_3um_nom', 'temp_8_5um_nom', 'temp_9_7um_nom']
        refls = ['refl_0_47um_nom', 'refl_0_65um_nom', 'refl_0_86um_nom', 'refl_1_38um_nom', 'refl_1_60um_nom']
    else:  # H08 / H09
        temps = ['temp_10_4um_nom', 'temp_12_0um_nom', 'temp_8_5um_nom', 'temp_3_75um_nom']
        refls = ['refl_2_10um_nom', 'refl_1_60um_nom', 'refl_0_86um_nom', 'refl_0_47um_nom']

    # Reflectance channels are only meaningful in daylight.
    train_params_l1b = temps + refls if is_day else temps

    if l1b_andor_l2 == 'both':
        train_params = train_params_l1b + train_params_l2
    elif l1b_andor_l2 == 'l1b':
        train_params = train_params_l1b
    else:
        train_params = train_params_l2

    return train_params, train_params_l1b, train_params_l2


def make_for_full_domain_predict(h5f, name_list=None, satellite='GOES16', domain='FD', res_fac=1,
                                 data_extent=None):
    """Cut the requested fields into 16x16 tiles for whole-domain prediction.

    :param h5f: open CLAVRx file handle passed to ``get_grid_values``.
    :param name_list: dataset names to tile (required despite the None default).
    :param satellite: with ``domain``, selects the grid via ``get_cartopy_crs``.
        H08/H09 use the Taiwan sub-domain, and their 'temp_6_2um_nom' is not
        read but filled with NaN.
    :param domain: see ``satellite``.
    :param res_fac: the tile stride is ``int(16 / res_fac)`` pixels, so
        res_fac > 1 gives overlapping tiles.
    :param data_extent: GOES16 + CONUS only. Two points with ``.lon``/``.lat``
        whose line/element bounding box replaces the full domain.
    :return: (grd_dct, ll, cc). grd_dct maps each name to its tiles stacked
        along a new leading axis (tile rows outer, tile columns inner). ll and
        cc are the line/element origins of the tile rows/columns in the full grid.
    :raises GenericException: if only some of the names could be read.
    """
    tile = 16
    step_x = int(tile / res_fac)
    step_y = int(tile / res_fac)
    line0 = 0
    elem0 = 0

    geos, xlen, xmin, xmax, ylen, ymin, ymax = get_cartopy_crs(satellite, domain)
    geos_nav = get_navigation(satellite, domain)

    if satellite == 'GOES16' and data_extent is not None:
        corner_a = data_extent[0]
        corner_b = data_extent[1]
        if domain == 'CONUS':
            line_a, elem_a = geos_nav.earth_to_lc(corner_a.lon, corner_a.lat)
            line_b, elem_b = geos_nav.earth_to_lc(corner_b.lon, corner_b.lat)
            # Bounding box of the two corners in line/element space.
            elem0 = elem_a if elem_a < elem_b else elem_b
            line0 = line_a if line_a < line_b else line_b
            xlen = abs(elem_b - elem_a)
            ylen = abs(line_b - line_a)

    if satellite in ('H08', 'H09'):
        xlen, ylen = taiwan_lenx, taiwan_leny
        elem0, line0 = taiwan_i0, taiwan_j0

    grids = {name: None for name in name_list}
    n_read = 0
    for name in name_list:
        if satellite in ('H08', 'H09') and name == 'temp_6_2um_nom':
            vals = np.full((ylen, xlen), np.nan)
        else:
            fill_value, fill_value_name = get_fill_attrs(name)
            vals = get_grid_values(h5f, name, line0, elem0, None, num_j=ylen, num_i=xlen,
                                   fill_value_name=fill_value_name, fill_value=fill_value)
        if vals is not None:
            grids[name] = vals
            n_read += 1

    if n_read > 0 and n_read != len(name_list):
        raise GenericException('weirdness')

    n_cols = int(xlen / step_x) - 1
    n_rows = int(ylen / step_y) - 1
    # Trim the tile count when the remainder after n tiles is narrower than one window.
    spare_x = xlen - n_cols * step_x
    if spare_x < tile:
        n_cols -= int((tile - spare_x) / step_x)
    spare_y = ylen - n_rows * step_y
    if spare_y < tile:
        n_rows -= int((tile - spare_y) / step_y)

    ll = [line0 + r * step_y for r in range(n_rows)]
    cc = [elem0 + c * step_x for c in range(n_cols)]

    tiles = {name: [] for name in name_list}
    for name in name_list:
        grid = grids[name]
        tiles[name].extend(grid[r * step_y:r * step_y + tile, c * step_x:c * step_x + tile]
                           for r in range(n_rows) for c in range(n_cols))

    return {name: np.stack(tiles[name]) for name in name_list}, ll, cc


def make_for_full_domain_predict_viirs_clavrx(h5f, name_list=None, res_fac=1, day_night='DAY', use_dnb=False):
    """Cut VIIRS CLAVRx fields into 16x16 tiles and sample lat/lon at the tile origins.

    :param h5f: open CLAVRx file. The lengths of its 'scan_lines_along_track_direction'
        and 'pixel_elements_along_scan_direction' datasets give the grid size.
    :param name_list: dataset names to tile (required despite the None default).
    :param res_fac: the tile stride is ``int(16 / res_fac)`` pixels.
    :param day_night: when ``use_dnb`` is set and this is 'NIGHT' or 'AUTO',
        cld_reff_nlcomp / cld_opd_nlcomp are read, but still stored under the
        names cld_reff_dcomp / cld_opd_dcomp.
    :param use_dnb: see ``day_night``.
    :return: (tiles, ll, cc, lats, lons). tiles maps each name to its stacked
        tiles. ll and cc are the tile-origin line/element indices. lats and
        lons give the value at every (ll, cc) origin pair.
    :raises GenericException: if only some of the names could be read.
    """
    win = 16
    step_x = int(win / res_fac)
    step_y = int(win / res_fac)

    n_lines = h5f['scan_lines_along_track_direction'].shape[0]
    n_elems = h5f['pixel_elements_along_scan_direction'].shape[0]

    use_nl_comp = bool(use_dnb) and day_night in ('NIGHT', 'AUTO')
    nl_names = {'cld_reff_dcomp': 'cld_reff_nlcomp', 'cld_opd_dcomp': 'cld_opd_nlcomp'}

    fields = dict.fromkeys(name_list)
    n_read = 0
    for key in name_list:
        src = nl_names.get(key, key) if use_nl_comp else key
        fill_value, fill_value_name = get_fill_attrs(src)
        grid = get_grid_values(h5f, src, 0, 0, None, num_j=n_lines, num_i=n_elems,
                               fill_value_name=fill_value_name, fill_value=fill_value)
        if grid is None:
            continue
        fields[key] = grid
        n_read += 1

    if n_read > 0 and n_read != len(name_list):
        raise GenericException('weirdness')

    # TODO: need to investigate discrepencies with compute_lwc_iwc before deriving
    # iwc/lwc from the NLCOMP reff/opd values.

    n_cols = int(n_elems / step_x) - 1
    n_rows = int(n_lines / step_y) - 1
    # Trim the tile count when the remainder after n tiles is narrower than one window.
    spare_x = n_elems - n_cols * step_x
    if spare_x < win:
        n_cols -= int((win - spare_x) / step_x)
    spare_y = n_lines - n_rows * step_y
    if spare_y < win:
        n_rows -= int((win - spare_y) / step_y)

    ll = [r * step_y for r in range(n_rows)]
    cc = [c * step_x for c in range(n_cols)]

    tiles = {key: [] for key in name_list}
    for key in name_list:
        grid = fields[key]
        tiles[key].extend(grid[r:r + win, c:c + win] for r in ll for c in cc)

    stacked = {key: np.stack(tiles[key]) for key in name_list}

    lats = get_grid_values(h5f, 'latitude', 0, 0, None, num_j=n_lines, num_i=n_elems)
    lons = get_grid_values(h5f, 'longitude', 0, 0, None, num_j=n_lines, num_i=n_elems)

    line_idx, elem_idx = np.meshgrid(ll, cc, indexing='ij')
    return stacked, ll, cc, lats[line_idx, elem_idx], lons[line_idx, elem_idx]


def make_for_full_domain_predict2(h5f, satellite='GOES16', domain='FD', res_fac=1):
    """Return solar and sensor zenith angles sampled at the prediction tile origins.

    :param h5f: open CLAVRx file handle passed to ``get_grid_values``.
    :param satellite: with ``domain``, selects the grid via ``get_cartopy_crs``.
        H08 and H09 are restricted to the Taiwan sub-domain, as in
        ``make_for_full_domain_predict``.
    :param domain: see ``satellite``.
    :param res_fac: the sampling stride is ``int(16 / res_fac)`` pixels.
    :return: (solzen, satzen), each holding int(len/stride) - 1 samples per axis.
    """
    w_x = 16
    w_y = 16
    i_0 = 0
    j_0 = 0
    s_x = int(w_x / res_fac)
    s_y = int(w_y / res_fac)

    geos, xlen, xmin, xmax, ylen, ymin, ymax = get_cartopy_crs(satellite, domain)
    if satellite in ('H08', 'H09'):
        # H09 must use the same sub-domain as the tiles built for it elsewhere,
        # otherwise the zenith grids would not line up with the predictions.
        xlen = taiwan_lenx
        ylen = taiwan_leny
        i_0 = taiwan_i0
        j_0 = taiwan_j0

    n_x = int(xlen/s_x)
    n_y = int(ylen/s_y)

    solzen = get_grid_values(h5f, 'solar_zenith_angle', j_0, i_0, None, num_j=ylen, num_i=xlen)
    satzen = get_grid_values(h5f, 'sensor_zenith_angle', j_0, i_0, None, num_j=ylen, num_i=xlen)
    # One sample per tile origin: (n - 1) positions along each axis.
    solzen = solzen[0:(n_y-1)*s_y:s_y, 0:(n_x-1)*s_x:s_x]
    satzen = satzen[0:(n_y-1)*s_y:s_y, 0:(n_x-1)*s_x:s_x]

    return solzen, satzen
# -------------------------------------------------------------------------------------------


def prepare_evaluate(h5f, name_list, satellite='GOES16', domain='FD', res_fac=1, offset=0):
    """Read full-grid fields plus sampled zenith angles for whole-domain evaluation.

    :param h5f: open CLAVRx file handle passed to ``get_grid_values``.
    :param name_list: dataset names to read.
    :param satellite: with ``domain``, selects the grid via ``get_cartopy_crs``.
        H08 and H09 are restricted to the Taiwan sub-domain.
    :param domain: see ``satellite``.
    :param res_fac: the sampling stride is ``16 // res_fac`` pixels.
    :param offset: added to every returned line/element position.
    :return: (grd_dct_n, solzen, satzen, ll, cc).
        - grd_dct_n maps each name to its grid with a new leading axis of
          length 1. Names that could not be read keep an empty list.
        - solzen and satzen are sampled every stride pixels.
        - ll and cc are the sampled line/element positions, plus ``offset``.
    :raises GenericException: if only some of the names could be read.
    """
    w_x = 16
    w_y = 16
    i_0 = 0
    j_0 = 0
    s_x = w_x // res_fac
    s_y = w_y // res_fac

    geos, xlen, xmin, xmax, ylen, ymin, ymax = get_cartopy_crs(satellite, domain)
    if satellite in ('H08', 'H09'):
        # Himawari is processed on the Taiwan sub-domain only.
        xlen = taiwan_lenx
        ylen = taiwan_leny
        i_0 = taiwan_i0
        j_0 = taiwan_j0

    n_x = (xlen // s_x)
    n_y = (ylen // s_y)

    ll = [(offset+j_0) + j*s_y for j in range(n_y)]
    cc = [(offset+i_0) + i*s_x for i in range(n_x)]

    grd_dct_n = {name: [] for name in name_list}

    cnt_a = 0
    for ds_name in name_list:
        fill_value, fill_value_name = get_fill_attrs(ds_name)
        gvals = get_grid_values(h5f, ds_name, j_0, i_0, None, num_j=ylen, num_i=xlen,
                                fill_value_name=fill_value_name, fill_value=fill_value)
        # Test for a missing dataset BEFORE adding the batch axis. np.expand_dims(None)
        # returns an object array, which would defeat this check and the
        # partial-read check below.
        if gvals is not None:
            grd_dct_n[ds_name] = np.expand_dims(gvals, axis=0)
            cnt_a += 1

    if cnt_a > 0 and cnt_a != len(name_list):
        raise GenericException('weirdness')

    solzen = get_grid_values(h5f, 'solar_zenith_angle', j_0, i_0, None, num_j=ylen, num_i=xlen)
    satzen = get_grid_values(h5f, 'sensor_zenith_angle', j_0, i_0, None, num_j=ylen, num_i=xlen)
    solzen = solzen[0:n_y*s_y:s_y, 0:n_x*s_x:s_x]
    satzen = satzen[0:n_y*s_y:s_y, 0:n_x*s_x:s_x]

    return grd_dct_n, solzen, satzen, ll, cc


tomrink's avatar
tomrink committed
def prepare_evaluate1x1(h5f, name_list, satellite='GOES16', domain='FD', res_fac=1, offset=0):
    """Read fields as a per-pixel feature matrix for pixel-wise evaluation.

    :param h5f: open CLAVRx file handle passed to ``get_grid_values``.
    :param name_list: dataset names; each becomes one column of ``varX``.
    :param satellite: with ``domain``, selects the grid via ``get_cartopy_crs``.
        H08 and H09 are restricted to the Taiwan sub-domain.
    :param domain: see ``satellite``.
    :param res_fac: unused, kept for signature compatibility with ``prepare_evaluate``.
        Sampling is always every pixel.
    :param offset: added to every returned line/element position.
    :return: (varX, solzen, satzen, cldmsk, ll, cc).
        - varX is the per-name flattened grids stacked as columns, one row
          per pixel.
        - solzen, satzen and cldmsk are flattened to match the rows of varX.
        - ll and cc are the line/element positions, plus ``offset``.
    :raises GenericException: if only some of the names could be read.
    """
    # The window is a single pixel and varX keeps every pixel, so the stride is
    # fixed at 1. (1 // res_fac would be 0 for res_fac > 1 and a float for a float
    # res_fac, and a larger stride would misalign the zenith/mask vectors with varX.)
    s_x = 1
    s_y = 1
    i_0 = 0
    j_0 = 0

    geos, xlen, xmin, xmax, ylen, ymin, ymax = get_cartopy_crs(satellite, domain)
    if satellite in ('H08', 'H09'):
        # Himawari is processed on the Taiwan sub-domain only.
        xlen = taiwan_lenx
        ylen = taiwan_leny
        i_0 = taiwan_i0
        j_0 = taiwan_j0

    n_x = (xlen // s_x)
    n_y = (ylen // s_y)

    ll = [(offset+j_0) + j*s_y for j in range(n_y)]
    cc = [(offset+i_0) + i*s_x for i in range(n_x)]

    grd_s = []

    cnt_a = 0
    for ds_name in name_list:
        fill_value, fill_value_name = get_fill_attrs(ds_name)
        gvals = get_grid_values(h5f, ds_name, j_0, i_0, None, num_j=ylen, num_i=xlen,
                                fill_value_name=fill_value_name, fill_value=fill_value)
        if gvals is not None:
            grd_s.append(gvals.flatten())
            cnt_a += 1

    if cnt_a > 0 and cnt_a != len(name_list):
        raise GenericException('weirdness')

    solzen = get_grid_values(h5f, 'solar_zenith_angle', j_0, i_0, None, num_j=ylen, num_i=xlen)
    satzen = get_grid_values(h5f, 'sensor_zenith_angle', j_0, i_0, None, num_j=ylen, num_i=xlen)
    cldmsk = get_grid_values(h5f, 'cloud_mask', j_0, i_0, None, num_j=ylen, num_i=xlen)
    solzen = solzen[0:n_y*s_y:s_y, 0:n_x*s_x:s_x]
    satzen = satzen[0:n_y*s_y:s_y, 0:n_x*s_x:s_x]
    cldmsk = cldmsk[0:n_y*s_y:s_y, 0:n_x*s_x:s_x]

    varX = np.stack(grd_s, axis=1)

    return varX, solzen.flatten(), satzen.flatten(), cldmsk.flatten(), ll, cc
tomrink's avatar
tomrink committed


tomrink's avatar
tomrink committed
# Flight-level key -> label used in output variable names
# ('icing_prediction_level_<label>', 'icing_probability_level_<label>').
# Key -1 is the cloud-top-height layer ('CTH_layer').
# NOTE(review): key 5 exists but maps to None - using it in a name would raise TypeError.
flt_level_ranges_str = {
    0: '0_2000',
    1: '2000_4000',
    2: '4000_6000',
    3: '6000_8000',
    4: '8000_15000',
    5: None,
    -1: 'CTH_layer',
}

# flt_level_ranges_str = {k: None for k in range(1)}
# flt_level_ranges_str[0] = 'column'


def write_icing_file(clvrx_str_time, output_dir, preds_dct, probs_dct, x, y, lons, lats, elems, lines):
    """Write icing predictions and probabilities to an HDF5 file.

    The file is created (overwritten) at
    ``output_dir + 'icing_prediction_' + clvrx_str_time + '.h5'``. The
    directory is concatenated directly, so it must end with a path separator.

    :param clvrx_str_time: time string used in the output file name.
    :param output_dir: output directory prefix.
    :param preds_dct: flight-level key -> prediction grid, written as int16.
    :param probs_dct: flight-level key -> probability grid, written as float32.
        It must contain every key of ``preds_dct``.
        NOTE(review): grids are assumed 2D, since the column products reduce
        over axis 2 after stacking levels on the last axis.
    :param x: projection x coordinates, written with ``y`` unless ``x`` is None.
    :param y: projection y coordinates, see ``x``.
    :param lons: longitude grid, written as float32.
    :param lats: latitude grid, written as float32.
    :param elems: element index vector, written with ``lines`` unless ``elems`` is None.
    :param lines: line index vector, see ``elems``.
    """
    outfile_name = output_dir + 'icing_prediction_'+clvrx_str_time+'.h5'

    dim_0_name = 'x'
    dim_1_name = 'y'

    def _annotate_grid(ds, missing):
        # Attributes and dimension labels shared by every gridded product below.
        ds.attrs.create('coordinates', data='y x')
        ds.attrs.create('grid_mapping', data='Projection')
        ds.attrs.create('missing', data=missing)
        ds.dims[0].label = dim_0_name
        ds.dims[1].label = dim_1_name

    flt_lvls = list(preds_dct.keys())

    # The context manager closes the file even if a write fails part-way.
    with h5py.File(outfile_name, 'w') as h5f_out:
        for flvl in flt_lvls:
            icing_pred_ds = h5f_out.create_dataset('icing_prediction_level_'+flt_level_ranges_str[flvl],
                                                   data=preds_dct[flvl], dtype='i2')
            _annotate_grid(icing_pred_ds, -1)

        prob_s = []
        for flvl in flt_lvls:
            probs = probs_dct[flvl]
            prob_s.append(probs)
            icing_prob_ds = h5f_out.create_dataset('icing_probability_level_'+flt_level_ranges_str[flvl],
                                                   data=probs, dtype='f4')
            _annotate_grid(icing_prob_ds, -1.0)

        # Levels stacked on the last axis: per-pixel maximum probability and the
        # index (in preds_dct key order) of the level that attains it.
        prob_s = np.stack(prob_s, axis=-1)

        max_prob_ds = h5f_out.create_dataset('max_icing_probability_column',
                                             data=np.max(prob_s, axis=2), dtype='f4')
        _annotate_grid(max_prob_ds, -1.0)

        max_lvl_ds = h5f_out.create_dataset('max_icing_probability_level',
                                            data=np.argmax(prob_s, axis=2), dtype='i2')
        _annotate_grid(max_lvl_ds, -1)

        lon_ds = h5f_out.create_dataset('longitude', data=lons, dtype='f4')
        lon_ds.attrs.create('units', data='degrees_east')
        lon_ds.attrs.create('long_name', data='icing prediction longitude')
        lon_ds.dims[0].label = dim_0_name
        lon_ds.dims[1].label = dim_1_name

        lat_ds = h5f_out.create_dataset('latitude', data=lats, dtype='f4')
        lat_ds.attrs.create('units', data='degrees_north')
        lat_ds.attrs.create('long_name', data='icing prediction latitude')
        lat_ds.dims[0].label = dim_0_name
        lat_ds.dims[1].label = dim_1_name

        # NOTE(review): projection parameters are hard-coded for Himawari here,
        # while the x/y long names below refer to the GOES PUG - confirm intent.
        proj_ds = h5f_out.create_dataset('Projection', data=0, dtype='b')
        proj_ds.attrs.create('long_name', data='Himawari Imagery Projection')
        proj_ds.attrs.create('grid_mapping_name', data='geostationary')
        proj_ds.attrs.create('sweep_angle_axis', data='y')
        proj_ds.attrs.create('units', data='rad')
        proj_ds.attrs.create('semi_major_axis', data=6378.137)
        proj_ds.attrs.create('semi_minor_axis', data=6356.7523)
        proj_ds.attrs.create('inverse_flattening', data=298.257)
        proj_ds.attrs.create('perspective_point_height', data=35785.863)
        proj_ds.attrs.create('latitude_of_projection_origin', data=0.0)
        proj_ds.attrs.create('longitude_of_projection_origin', data=140.7)
        proj_ds.attrs.create('CFAC', data=20466275)
        proj_ds.attrs.create('LFAC', data=20466275)
        proj_ds.attrs.create('COFF', data=2750.5)
        proj_ds.attrs.create('LOFF', data=2750.5)

        if x is not None:
            x_ds = h5f_out.create_dataset('x', data=x, dtype='f8')
            x_ds.dims[0].label = dim_0_name
            x_ds.attrs.create('units', data='rad')
            x_ds.attrs.create('standard_name', data='projection_x_coordinate')
            x_ds.attrs.create('long_name', data='GOES PUG W-E fixed grid viewing angle')
            x_ds.attrs.create('scale_factor', data=5.58879902955962e-05)
            x_ds.attrs.create('add_offset', data=-0.153719917308037)
            x_ds.attrs.create('CFAC', data=20466275)
            x_ds.attrs.create('COFF', data=2750.5)

            y_ds = h5f_out.create_dataset('y', data=y, dtype='f8')
            y_ds.dims[0].label = dim_1_name
            y_ds.attrs.create('units', data='rad')
            y_ds.attrs.create('standard_name', data='projection_y_coordinate')
            y_ds.attrs.create('long_name', data='GOES PUG S-N fixed grid viewing angle')
            y_ds.attrs.create('scale_factor', data=-5.58879902955962e-05)
            y_ds.attrs.create('add_offset', data=0.153719917308037)
            y_ds.attrs.create('LFAC', data=20466275)
            y_ds.attrs.create('LOFF', data=2750.5)

        if elems is not None:
            elem_ds = h5f_out.create_dataset('elems', data=elems, dtype='i2')
            elem_ds.dims[0].label = dim_0_name
            line_ds = h5f_out.create_dataset('lines', data=lines, dtype='i2')
            line_ds.dims[0].label = dim_1_name


def write_icing_file_nc4(clvrx_str_time, output_dir, preds_dct, probs_dct,
                         x, y, lons, lats, elems, lines, satellite='GOES16', domain='CONUS',
                         has_time=False, use_nan=False, prob_thresh=0.5, bt_10_4=None, cld_top_hgt_max=None):
    """Write icing predictions and probabilities to a CF-1.7 NetCDF4 file.

    The file is created (overwritten) at
    ``output_dir + 'icing_prediction_' + clvrx_str_time + '.nc'``.

    :param preds_dct: flight-level key -> 2D (y, x) prediction grid, stored as int16.
    :param probs_dct: flight-level key -> 2D (y, x) probability grid, stored as
        float32. It must contain every key of ``preds_dct``.
    :param x: projection x coordinate vector. When None, the grid size is taken
        from ``lons`` and no x/y variables are written.
    :param y: projection y coordinate vector, see ``x``.
    :param lons: 2D (y, x) longitude grid.
    :param lats: 2D (y, x) latitude grid.
    :param elems: element index vector, written with ``lines`` unless ``elems`` is None.
    :param lines: line index vector, see ``elems``.
    :param satellite: with ``domain``, passed to ``get_cf_nav_parameters``.
        It also picks the projection long name.
    :param domain: see ``satellite``.
    :param has_time: if True, gridded products get a leading length-1 time dimension.
    :param use_nan: if True, probabilities below ``prob_thresh`` are written as
        NaN and 'missing' is NaN. Otherwise values are written as-is with
        'missing' = -1.0.
    :param prob_thresh: probability threshold used for NaN masking and for
        choosing ``max_icing_probability_level``.
    :param bt_10_4: optional grid written as 'bt_10_4'.
    :param cld_top_hgt_max: optional grid written as 'cld_top_hgt_max' (units 'meter').
    """
    outfile_name = output_dir + 'icing_prediction_'+clvrx_str_time+'.nc'

    dim_0_name = 'x'
    dim_1_name = 'y'
    time_dim_name = 'time'
    geo_coords = 'time y x'

    if x is not None:
        dim_1_len, dim_0_len = y.shape[0], x.shape[0]
    else:
        # No projection coordinates: size the grid from the 2D longitude array.
        dim_1_len, dim_0_len = np.shape(lons)

    if not has_time:
        var_dim_list = [dim_1_name, dim_0_name]
    else:
        var_dim_list = [time_dim_name, dim_1_name, dim_0_name]
    time_shape = (1, dim_1_len, dim_0_len)

    prob_missing = np.nan if use_nan else -1.0

    long_names = {'H08': 'Himawari Imagery Projection',
                  'H09': 'Himawari Imagery Projection',
                  'GOES16': 'GOES-16/17 Imagery Projection'}
    # Unknown satellites previously left long_name unbound after most of the file was written.
    long_name = long_names.get(satellite, '{} Imagery Projection'.format(satellite))

    cf_nav_dct = get_cf_nav_parameters(satellite, domain)

    # The context manager closes the file even if a write fails part-way.
    with Dataset(outfile_name, 'w', format='NETCDF4') as rootgrp:
        rootgrp.setncattr('Conventions', 'CF-1.7')

        rootgrp.createDimension(dim_0_name, size=dim_0_len)
        rootgrp.createDimension(dim_1_name, size=dim_1_len)
        rootgrp.createDimension(time_dim_name, size=1)

        tvar = rootgrp.createVariable('time', 'f8', time_dim_name)
        tvar[0] = get_timestamp(clvrx_str_time)
        tvar.units = 'seconds since 1970-01-01 00:00:00'

        flt_lvls = list(preds_dct.keys())
        for flvl in flt_lvls:
            preds = preds_dct[flvl]

            icing_pred_ds = rootgrp.createVariable('icing_prediction_level_'+flt_level_ranges_str[flvl], 'i2', var_dim_list)
            icing_pred_ds.setncattr('coordinates', geo_coords)
            icing_pred_ds.setncattr('grid_mapping', 'Projection')
            icing_pred_ds.setncattr('missing', -1)
            if has_time:
                preds = preds.reshape(time_shape)
            icing_pred_ds[:,] = preds

        prob_s = []
        for flvl in flt_lvls:
            probs = probs_dct[flvl]
            # Keep the unmodified 2D grid for the column products computed below.
            prob_s.append(probs)

            icing_prob_ds = rootgrp.createVariable('icing_probability_level_'+flt_level_ranges_str[flvl], 'f4', var_dim_list)
            icing_prob_ds.setncattr('coordinates', geo_coords)
            icing_prob_ds.setncattr('grid_mapping', 'Projection')
            icing_prob_ds.setncattr('missing', prob_missing)
            if has_time:
                probs = probs.reshape(time_shape)
            if use_nan:
                probs = np.where(probs < prob_thresh, np.nan, probs)
            icing_prob_ds[:,] = probs

        prob_s = np.stack(prob_s, axis=-1)
        max_prob = np.max(prob_s, axis=2)
        if use_nan:
            max_prob = np.where(max_prob < prob_thresh, np.nan, max_prob)
        if has_time:
            max_prob = max_prob.reshape(time_shape)

        icing_prob_ds = rootgrp.createVariable('max_icing_probability_column', 'f4', var_dim_list)
        icing_prob_ds.setncattr('coordinates', geo_coords)
        icing_prob_ds.setncattr('grid_mapping', 'Projection')
        icing_prob_ds.setncattr('missing', prob_missing)
        icing_prob_ds[:,] = max_prob

        # Index (in preds_dct key order) of the most probable level among those
        # at or above prob_thresh, or -1 where none qualifies. The variable is
        # int16 with missing = -1, so -1 is kept even when use_nan is set: NaN
        # cannot be stored in an integer variable.
        eligible = np.where(prob_s < prob_thresh, -1.0, prob_s)
        max_lvl = np.where(np.all(eligible == -1, axis=2), -1, np.argmax(eligible, axis=2))
        if has_time:
            max_lvl = max_lvl.reshape(time_shape)

        icing_pred_ds = rootgrp.createVariable('max_icing_probability_level', 'i2', var_dim_list)
        icing_pred_ds.setncattr('coordinates', geo_coords)
        icing_pred_ds.setncattr('grid_mapping', 'Projection')
        icing_pred_ds.setncattr('missing', -1)
        icing_pred_ds[:,] = max_lvl

        if bt_10_4 is not None:
            bt_ds = rootgrp.createVariable('bt_10_4', 'f4', var_dim_list)
            bt_ds.setncattr('coordinates', geo_coords)
            bt_ds.setncattr('grid_mapping', 'Projection')
            bt_ds[:,] = bt_10_4

        if cld_top_hgt_max is not None:
            cth_ds = rootgrp.createVariable('cld_top_hgt_max', 'f4', var_dim_list)
            cth_ds.setncattr('coordinates', geo_coords)
            cth_ds.setncattr('grid_mapping', 'Projection')
            cth_ds.units = 'meter'
            cth_ds[:,] = cld_top_hgt_max

        lon_ds = rootgrp.createVariable('longitude', 'f4', [dim_1_name, dim_0_name])
        lon_ds.units = 'degrees_east'
        lon_ds[:,] = lons

        lat_ds = rootgrp.createVariable('latitude', 'f4', [dim_1_name, dim_0_name])
        lat_ds.units = 'degrees_north'
        lat_ds[:,] = lats

        proj_ds = rootgrp.createVariable('Projection', 'b')
        proj_ds.setncattr('long_name', long_name)
        proj_ds.setncattr('grid_mapping_name', 'geostationary')
        proj_ds.setncattr('sweep_angle_axis', cf_nav_dct['sweep_angle_axis'])
        proj_ds.setncattr('semi_major_axis', cf_nav_dct['semi_major_axis'])
        proj_ds.setncattr('semi_minor_axis', cf_nav_dct['semi_minor_axis'])
        proj_ds.setncattr('inverse_flattening', cf_nav_dct['inverse_flattening'])
        proj_ds.setncattr('perspective_point_height', cf_nav_dct['perspective_point_height'])
        proj_ds.setncattr('latitude_of_projection_origin', cf_nav_dct['latitude_of_projection_origin'])
        proj_ds.setncattr('longitude_of_projection_origin', cf_nav_dct['longitude_of_projection_origin'])

        if x is not None:
            x_ds = rootgrp.createVariable(dim_0_name, 'f8', [dim_0_name])
            x_ds.units = 'rad'
            x_ds.setncattr('axis', 'X')
            x_ds.setncattr('standard_name', 'projection_x_coordinate')
            x_ds.setncattr('long_name', 'fixed grid viewing angle')
            x_ds.setncattr('scale_factor', cf_nav_dct['x_scale_factor'])
            x_ds.setncattr('add_offset', cf_nav_dct['x_add_offset'])
            x_ds[:] = x

            y_ds = rootgrp.createVariable(dim_1_name, 'f8', [dim_1_name])
            y_ds.units = 'rad'
            y_ds.setncattr('axis', 'Y')
            y_ds.setncattr('standard_name', 'projection_y_coordinate')
            y_ds.setncattr('long_name', 'fixed grid viewing angle')
            y_ds.setncattr('scale_factor', cf_nav_dct['y_scale_factor'])
            y_ds.setncattr('add_offset', cf_nav_dct['y_add_offset'])
            y_ds[:] = y

        if elems is not None:
            elem_ds = rootgrp.createVariable('elems', 'i2', [dim_0_name])
            elem_ds[:] = elems
            line_ds = rootgrp.createVariable('lines', 'i2', [dim_1_name])
            line_ds[:] = lines


def write_icing_file_nc4_viirs(clvrx_str_time, output_dir, preds_dct, probs_dct, lons, lats,
                               has_time=False, use_nan=False, prob_thresh=0.5, bt_10_4=None):
    """Write VIIRS icing predictions and probabilities to a CF-1.7 NetCDF4 file.

    The file is created (overwritten) at
    ``output_dir + 'icing_prediction_' + clvrx_str_time + '.nc'``. The grid
    size comes from ``lons.shape`` (rows -> 'y', columns -> 'x'), and the
    'Projection' variable declares a latitude_longitude grid mapping.

    :param preds_dct: flight-level key -> 2D prediction grid, stored as int16.
    :param probs_dct: flight-level key -> 2D probability grid, stored as float32.
        It must contain every key of ``preds_dct``.
    :param lons: 2D longitude grid.
    :param lats: 2D latitude grid.
    :param has_time: if True, gridded products get a leading length-1 time dimension.
    :param use_nan: if True, probabilities below ``prob_thresh`` are written as
        NaN and 'missing' is NaN. Otherwise values are written as-is with
        'missing' = -1.0.
    :param prob_thresh: probability threshold used for NaN masking and for
        choosing ``max_icing_probability_level``.
    :param bt_10_4: optional grid written as 'bt_10_4'.
    """
    outfile_name = output_dir + 'icing_prediction_'+clvrx_str_time+'.nc'

    dim_0_name = 'x'
    dim_1_name = 'y'
    time_dim_name = 'time'
    geo_coords = 'longitude latitude'

    dim_1_len, dim_0_len = lons.shape

    if not has_time:
        var_dim_list = [dim_1_name, dim_0_name]
    else:
        var_dim_list = [time_dim_name, dim_1_name, dim_0_name]
    time_shape = (1, dim_1_len, dim_0_len)

    prob_missing = np.nan if use_nan else -1.0

    # The context manager closes the file even if a write fails part-way.
    with Dataset(outfile_name, 'w', format='NETCDF4') as rootgrp:
        rootgrp.setncattr('Conventions', 'CF-1.7')

        rootgrp.createDimension(dim_0_name, size=dim_0_len)
        rootgrp.createDimension(dim_1_name, size=dim_1_len)
        rootgrp.createDimension(time_dim_name, size=1)

        tvar = rootgrp.createVariable('time', 'f8', time_dim_name)
        tvar[0] = get_timestamp(clvrx_str_time)
        tvar.units = 'seconds since 1970-01-01 00:00:00'

        flt_lvls = list(preds_dct.keys())
        for flvl in flt_lvls:
            preds = preds_dct[flvl]

            icing_pred_ds = rootgrp.createVariable('icing_prediction_level_'+flt_level_ranges_str[flvl], 'i2', var_dim_list)
            icing_pred_ds.setncattr('coordinates', geo_coords)
            icing_pred_ds.setncattr('grid_mapping', 'Projection')
            icing_pred_ds.setncattr('missing', -1)
            if has_time:
                preds = preds.reshape(time_shape)
            icing_pred_ds[:,] = preds

        prob_s = []
        for flvl in flt_lvls:
            probs = probs_dct[flvl]
            # Keep the unmodified 2D grid for the column products computed below.
            prob_s.append(probs)

            icing_prob_ds = rootgrp.createVariable('icing_probability_level_'+flt_level_ranges_str[flvl], 'f4', var_dim_list)
            icing_prob_ds.setncattr('coordinates', geo_coords)
            icing_prob_ds.setncattr('grid_mapping', 'Projection')
            icing_prob_ds.setncattr('missing', prob_missing)
            if has_time:
                probs = probs.reshape(time_shape)
            if use_nan:
                probs = np.where(probs < prob_thresh, np.nan, probs)
            icing_prob_ds[:,] = probs

        prob_s = np.stack(prob_s, axis=-1)
        max_prob = np.max(prob_s, axis=2)
        if use_nan:
            max_prob = np.where(max_prob < prob_thresh, np.nan, max_prob)
        if has_time:
            max_prob = max_prob.reshape(time_shape)

        icing_prob_ds = rootgrp.createVariable('max_icing_probability_column', 'f4', var_dim_list)
        icing_prob_ds.setncattr('coordinates', geo_coords)
        icing_prob_ds.setncattr('grid_mapping', 'Projection')
        icing_prob_ds.setncattr('missing', prob_missing)
        icing_prob_ds[:,] = max_prob

        # Index (in preds_dct key order) of the most probable level among those
        # at or above prob_thresh, or -1 where none qualifies. The variable is
        # int16 with missing = -1, so -1 is kept even when use_nan is set: NaN
        # cannot be stored in an integer variable.
        eligible = np.where(prob_s < prob_thresh, -1.0, prob_s)
        max_lvl = np.where(np.all(eligible == -1, axis=2), -1, np.argmax(eligible, axis=2))
        if has_time:
            max_lvl = max_lvl.reshape(time_shape)

        icing_pred_ds = rootgrp.createVariable('max_icing_probability_level', 'i2', var_dim_list)
        icing_pred_ds.setncattr('coordinates', geo_coords)
        icing_pred_ds.setncattr('grid_mapping', 'Projection')
        icing_pred_ds.setncattr('missing', -1)
        icing_pred_ds[:,] = max_lvl

        if bt_10_4 is not None:
            bt_ds = rootgrp.createVariable('bt_10_4', 'f4', var_dim_list)
            bt_ds.setncattr('coordinates', geo_coords)
            bt_ds.setncattr('grid_mapping', 'Projection')
            bt_ds[:,] = bt_10_4

        lon_ds = rootgrp.createVariable('longitude', 'f4', [dim_1_name, dim_0_name])
        lon_ds.units = 'degrees_east'
        lon_ds[:,] = lons

        lat_ds = rootgrp.createVariable('latitude', 'f4', [dim_1_name, dim_0_name])
        lat_ds.units = 'degrees_north'
        lat_ds[:,] = lats

        proj_ds = rootgrp.createVariable('Projection', 'b')
        proj_ds.setncattr('grid_mapping_name', 'latitude_longitude')


def run_icing_predict(clvrx_dir='/Users/tomrink/data/clavrx/RadC/', output_dir=homedir,
                      day_model_path=model_path_day, night_model_path=model_path_night,
                      prob_thresh=0.5, satellite='GOES16', domain='CONUS', day_night='AUTO',
tomrink's avatar
tomrink committed
                      l1b_andor_l2='both', use_flight_altitude=True, res_fac=1, use_nan=False,
tomrink's avatar
tomrink committed
                      has_time=False, model_type='FCN', use_dnb=False, use_max_cth_level=False):
tomrink's avatar
tomrink committed

tomrink's avatar
tomrink committed
    if model_type == 'CNN':
tomrink's avatar
tomrink committed
        import deeplearning.icing_cnn as icing_cnn
tomrink's avatar
tomrink committed
        model_module = icing_cnn
    elif model_type == 'FCN':
tomrink's avatar
tomrink committed
        import deeplearning.icing_fcn as icing_fcn
tomrink's avatar
tomrink committed
        model_module = icing_fcn

tomrink's avatar
tomrink committed
    if day_model_path is not None:
        day_model = model_module.load_model(day_model_path, day_night='DAY', l1b_andor_l2=l1b_andor_l2,
                                            use_flight_altitude=use_flight_altitude)
    if night_model_path is not None:
        night_model = model_module.load_model(night_model_path, day_night='NIGHT', l1b_andor_l2=l1b_andor_l2,
                                              use_flight_altitude=use_flight_altitude)

tomrink's avatar
tomrink committed
    flight_levels = [0]
tomrink's avatar
tomrink committed
    if use_flight_altitude is True:
        flight_levels = [0, 1, 2, 3, 4]
tomrink's avatar
tomrink committed
        if use_max_cth_level:
            flight_levels = [-1]
tomrink's avatar
tomrink committed
    day_train_params, _, _ = get_training_parameters(day_night='DAY', l1b_andor_l2=l1b_andor_l2)
    nght_train_params, _, _ = get_training_parameters(day_night='NIGHT', l1b_andor_l2=l1b_andor_l2, use_dnb=False)
    nght_train_params_dnb, _, _ = get_training_parameters(day_night='NIGHT', l1b_andor_l2=l1b_andor_l2, use_dnb=True)

    if day_night == 'AUTO':
        train_params = list(set(day_train_params + nght_train_params))
tomrink's avatar
tomrink committed
        train_params_dnb = list(set(day_train_params + nght_train_params_dnb))
    elif day_night == 'DAY':
        train_params = day_train_params
    elif day_night == 'NIGHT':
        train_params = nght_train_params
tomrink's avatar
tomrink committed
        train_params_dnb = nght_train_params_dnb

    if satellite == 'H08':
        clvrx_ds = CLAVRx_H08(clvrx_dir)
tomrink's avatar
tomrink committed
    elif satellite == 'H09':
        clvrx_ds = CLAVRx_H09(clvrx_dir)
tomrink's avatar
tomrink committed
    elif satellite == 'GOES16':
        clvrx_ds = CLAVRx(clvrx_dir)
tomrink's avatar
tomrink committed
    elif satellite == 'VIIRS':
        clvrx_ds = CLAVRx_VIIRS(clvrx_dir)

    clvrx_files = clvrx_ds.flist

    for fidx, fname in enumerate(clvrx_files):
        h5f = h5py.File(fname, 'r')
        dto = clvrx_ds.get_datetime(fname)
        ts = dto.timestamp()
tomrink's avatar
tomrink committed
        lunar_illuminated = moon_phase(ts)
        clvrx_str_time = dto.strftime('%Y-%m-%d_%H:%M')

tomrink's avatar
tomrink committed
        if satellite == 'GOES16' or satellite == 'H08' or satellite == 'H09':
tomrink's avatar
tomrink committed
            data_dct, ll, cc = make_for_full_domain_predict(h5f, name_list=train_params, satellite=satellite, domain=domain, res_fac=res_fac)
tomrink's avatar
tomrink committed
            if fidx == 0:  # These don't change for geostationary fixed grids
                nav = get_navigation(satellite, domain)
                lons_2d, lats_2d, x_rad, y_rad = get_lon_lat_2d_mesh(nav, ll, cc, offset=int(8 / res_fac))
tomrink's avatar
tomrink committed
            bt_fld_name = 'temp_10_4um_nom'

tomrink's avatar
tomrink committed
            ancil_data_dct, _, _ = make_for_full_domain_predict(h5f, name_list=
tomrink's avatar
tomrink committed
                                ['solar_zenith_angle', 'sensor_zenith_angle', 'cld_height_acha', 'cld_geo_thick', bt_fld_name],
tomrink's avatar
tomrink committed
                                satellite=satellite, domain=domain, res_fac=res_fac)
        elif satellite == 'VIIRS':
tomrink's avatar
tomrink committed
            if use_dnb and lunar_illuminated:
tomrink's avatar
tomrink committed
                train_params = train_params_dnb
tomrink's avatar
tomrink committed
            data_dct, ll, cc, lats_2d, lons_2d = make_for_full_domain_predict_viirs_clavrx(h5f, name_list=train_params,
tomrink's avatar
tomrink committed
                                                    day_night=day_night, res_fac=res_fac, use_dnb=use_dnb)
tomrink's avatar
tomrink committed

            bt_fld_name = 'temp_11_0um_nom'
tomrink's avatar
tomrink committed
            ancil_data_dct, _, _, _, _ = make_for_full_domain_predict_viirs_clavrx(h5f, name_list=
tomrink's avatar
tomrink committed
                                ['solar_zenith_angle', 'sensor_zenith_angle', 'cld_height_acha', 'cld_geo_thick', bt_fld_name],
tomrink's avatar
tomrink committed
                                res_fac=res_fac)

        num_elems, num_lines = len(cc), len(ll)

        satzen = ancil_data_dct['sensor_zenith_angle']
        solzen = ancil_data_dct['solar_zenith_angle']
tomrink's avatar
tomrink committed
        cth = ancil_data_dct['cld_height_acha']
tomrink's avatar
tomrink committed
        bt_10_4 = ancil_data_dct[bt_fld_name]
        day_idxs = []
        nght_idxs = []
tomrink's avatar
tomrink committed
        all_idxs = []
tomrink's avatar
tomrink committed
        avg_bt = []
tomrink's avatar
tomrink committed
        day_cth_max = []
        nght_cth_max = []
tomrink's avatar
tomrink committed
        cth_max = []
tomrink's avatar
tomrink committed
        cth_avg = []
tomrink's avatar
tomrink committed

        for j in range(num_lines):
            for i in range(num_elems):
                k = i + j*num_elems
tomrink's avatar
tomrink committed
                avg_bt.append(get_median(bt_10_4[k]))
tomrink's avatar
tomrink committed

tomrink's avatar
tomrink committed
                c = cth[k].flatten()
                c_m = np.mean(np.sort(c[np.invert(np.isnan(c))])[-2:])
tomrink's avatar
tomrink committed
                cth_max.append(c_m)
tomrink's avatar
tomrink committed
                cth_avg.append(np.mean(c[np.invert(np.isnan(c))]))
tomrink's avatar
tomrink committed

                c_m = 0 if 2000 > c_m >= 0 else c_m
                c_m = 1 if 4000 > c_m >= 2000 else c_m
                c_m = 2 if 6000 > c_m >= 4000 else c_m
                c_m = 3 if 8000 > c_m >= 6000 else c_m
                c_m = 4 if 15000 > c_m >= 8000 else c_m

                if not check_oblique(satzen[k]):
                    continue
tomrink's avatar
tomrink committed

tomrink's avatar
tomrink committed
                all_idxs.append(k)
                if is_day(solzen[k]):
                    day_idxs.append(k)
tomrink's avatar
tomrink committed
                    day_cth_max.append(c_m)
                else:
                    nght_idxs.append(k)
tomrink's avatar
tomrink committed
                    nght_cth_max.append(c_m)
tomrink's avatar
tomrink committed
        num_tiles = len(all_idxs)
        num_day_tiles = len(day_idxs)
        num_nght_tiles = len(nght_idxs)
tomrink's avatar
tomrink committed
        day_cth_max = np.array(day_cth_max)
        nght_cth_max = np.array(nght_cth_max)

        # initialize output arrays
        probs_2d_dct = {flvl: None for flvl in flight_levels}
        preds_2d_dct = {flvl: None for flvl in flight_levels}
        for flvl in flight_levels:
            fd_preds = np.zeros(num_lines * num_elems, dtype=np.int8)
            fd_preds[:] = -1
            fd_probs = np.zeros(num_lines * num_elems, dtype=np.float32)
            fd_probs[:] = -1.0
            preds_2d_dct[flvl] = fd_preds
            probs_2d_dct[flvl] = fd_probs

tomrink's avatar
tomrink committed
        if day_night == 'AUTO' or day_night == 'DAY':
            if num_day_tiles > 0:
tomrink's avatar
tomrink committed
                day_idxs = np.array(day_idxs)
tomrink's avatar
tomrink committed
                day_grd_dct = {name: None for name in day_train_params}
tomrink's avatar
tomrink committed
                for name in day_train_params:
                    day_grd_dct[name] = data_dct[name][day_idxs, ]

tomrink's avatar
tomrink committed
                if use_max_cth_level:
                    day_grd_dct['cth_high_avg'] = day_cth_max
tomrink's avatar
tomrink committed

tomrink's avatar
tomrink committed
                preds_day_dct, probs_day_dct = \
tomrink's avatar
tomrink committed
                    model_module.run_evaluate_static_2(day_model, day_grd_dct, num_day_tiles,
                                                       prob_thresh=prob_thresh,
                                                       flight_levels=flight_levels)
tomrink's avatar
tomrink committed
                for flvl in flight_levels:
                    day_preds = preds_day_dct[flvl]
                    day_probs = probs_day_dct[flvl]
                    fd_preds = preds_2d_dct[flvl]
                    fd_probs = probs_2d_dct[flvl]
                    fd_preds[day_idxs] = day_preds[:]
                    fd_probs[day_idxs] = day_probs[:]

            if num_nght_tiles > 0:
                mode = 'NIGHT'
                model_path = night_model_path
                if use_dnb and lunar_illuminated:
                    model_path = day_model_path
                    mode = 'DAY'
                    nght_train_params = train_params_dnb

tomrink's avatar
tomrink committed
                nght_idxs = np.array(nght_idxs)
tomrink's avatar
tomrink committed
                nght_grd_dct = {name: None for name in nght_train_params}
tomrink's avatar
tomrink committed
                for name in nght_train_params:
                    nght_grd_dct[name] = data_dct[name][nght_idxs, ]

tomrink's avatar
tomrink committed
                if use_max_cth_level:
                    nght_grd_dct['cth_high_avg'] = nght_cth_max
tomrink's avatar
tomrink committed

tomrink's avatar
tomrink committed
                preds_nght_dct, probs_nght_dct = \
                    model_module.run_evaluate_static_2(night_model, nght_grd_dct, num_nght_tiles,
                                                       prob_thresh=prob_thresh,
tomrink's avatar
tomrink committed
                                                       flight_levels=flight_levels)
tomrink's avatar
tomrink committed
                for flvl in flight_levels:
                    nght_preds = preds_nght_dct[flvl]
                    nght_probs = probs_nght_dct[flvl]
                    fd_preds = preds_2d_dct[flvl]
                    fd_probs = probs_2d_dct[flvl]
                    fd_preds[nght_idxs] = nght_preds[:]
                    fd_probs[nght_idxs] = nght_probs[:]

        elif day_night == 'NIGHT':
tomrink's avatar
tomrink committed
            model_path = night_model_path
tomrink's avatar
tomrink committed
            mode = 'NIGHT'
tomrink's avatar
tomrink committed
            if use_dnb and lunar_illuminated:
tomrink's avatar
tomrink committed
                model_path = day_model_path
tomrink's avatar
tomrink committed
                mode = 'DAY'
tomrink's avatar
tomrink committed
                nght_train_params = train_params_dnb
tomrink's avatar
tomrink committed
            all_idxs = np.array(all_idxs)
            nght_grd_dct = {name: None for name in nght_train_params}
tomrink's avatar
tomrink committed
            for name in nght_train_params:
                nght_grd_dct[name] = data_dct[name][all_idxs,]

tomrink's avatar
tomrink committed
            if use_max_cth_level:
                nght_grd_dct['cth_high_avg'] = nght_cth_max

            preds_nght_dct, probs_nght_dct = \
                model_module.run_evaluate_static_2(night_model, nght_grd_dct, num_nght_tiles,
                                                   prob_thresh=prob_thresh,
tomrink's avatar
tomrink committed
                                                   flight_levels=flight_levels)
            for flvl in flight_levels:
                nght_preds = preds_nght_dct[flvl]
                nght_probs = probs_nght_dct[flvl]
                fd_preds = preds_2d_dct[flvl]
                fd_probs = probs_2d_dct[flvl]
tomrink's avatar
tomrink committed
                fd_preds[all_idxs] = nght_preds[:]
                fd_probs[all_idxs] = nght_probs[:]
tomrink's avatar
tomrink committed
        # combine day and night into full grid  ------------------------------------------
        for flvl in flight_levels:
            fd_preds = preds_2d_dct[flvl]
            fd_probs = probs_2d_dct[flvl]
            preds_2d_dct[flvl] = fd_preds.reshape((num_lines, num_elems))
            probs_2d_dct[flvl] = fd_probs.reshape((num_lines, num_elems))

tomrink's avatar
tomrink committed
        avg_bt = np.array(avg_bt)
        bt_10_4_2d = avg_bt.reshape((num_lines, num_elems))
tomrink's avatar
tomrink committed

tomrink's avatar
tomrink committed
        cth_max = np.array(cth_max)
        cth_max = cth_max.reshape((num_lines, num_elems))
tomrink's avatar
tomrink committed
        cth_avg = np.array(cth_avg)