from icing.pireps import pirep_icing
import numpy as np
import pickle
import os
from util.util import get_time_tuple_utc, GenericException, add_time_range_to_filename, is_night, is_day, \
    check_oblique, get_timestamp, homedir, get_indexes_within_threshold
from aeolus.datasource import CLAVRx, CLAVRx_VIIRS, GOESL1B, CLAVRx_H08
import h5py
import datetime
from datetime import timezone
import glob
from skyfield import api, almanac

goes_date_format = '%Y%j%H'
goes16_directory = '/arcdata/goes/grb/goes16'  # /year/date/abi/L1b/RadC
clavrx_dir = '/ships19/cloud/scratch/ICING/'
# The next directory appears to be gone; kept here for reference.
clavrx_viirs_dir = '/apollo/cloud/scratch/Satellite_Output/NASA-SNPP_VIIRS/global/2019_DNB_for_Rink_wDBfix/level2_h5/'
clavrx_test_dir = '/data/Personal/rink/clavrx/'
dir_fmt = '%Y_%m_%d_%j'
# dir_list = [f.path for f in os.scandir('.') if f.is_dir()]
ds_dct = {}
goes_ds_dct = {}


# --- CLAVRx Radiometric parameters and metadata ------------------------------------------------
l1b_ds_list = ['temp_10_4um_nom', 'temp_11_0um_nom', 'temp_12_0um_nom', 'temp_13_3um_nom', 'temp_3_75um_nom',
               'temp_6_2um_nom', 'temp_6_7um_nom', 'temp_7_3um_nom', 'temp_8_5um_nom', 'temp_9_7um_nom',
               'refl_0_47um_nom', 'refl_0_65um_nom', 'refl_0_86um_nom', 'refl_1_38um_nom', 'refl_1_60um_nom']
l1b_ds_types = ['f4' for ds in l1b_ds_list]
l1b_ds_fill = [-32767 for i in range(10)] + [-32768 for i in range(5)]
l1b_ds_range = ['actual_range' for ds in l1b_ds_list]

# --- CLAVRx L2 parameters and metadata
ds_list = ['cld_height_acha', 'cld_geo_thick', 'cld_press_acha', 'sensor_zenith_angle', 'supercooled_prob_acha',
           'supercooled_cloud_fraction', 'cld_temp_acha', 'cld_opd_acha', 'solar_zenith_angle',
           'cld_reff_acha', 'cld_reff_dcomp', 'cld_reff_dcomp_1', 'cld_reff_dcomp_2', 'cld_reff_dcomp_3',
           'cld_opd_dcomp', 'cld_opd_dcomp_1', 'cld_opd_dcomp_2', 'cld_opd_dcomp_3', 'cld_cwp_dcomp', 'iwc_dcomp',
           'lwc_dcomp', 'cld_emiss_acha', 'conv_cloud_fraction', 'cloud_type', 'cloud_phase', 'cloud_mask']
ds_types = ['f4' for i in range(23)] + ['i1' for i in range(3)]
ds_fill = [-32768 for i in range(23)] + [-128 for i in range(3)]
ds_range = ['actual_range' for i in range(23)] + [None for i in range(3)]
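# The lists above (and the l1b_* lists) are index-aligned: entry i of ds_types/ds_fill/ds_range
# describes dataset ds_list[i] (23 'f4' float fields followed by 3 'i1' byte-coded categorical fields).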
# --------------------------------------------
# --- CLAVRx VIIRS L2 parameters and metadata
# ds_list = ['cld_height_acha', 'cld_geo_thick', 'cld_press_acha', 'sensor_zenith_angle', 'supercooled_prob_acha',
#            'supercooled_cloud_fraction', 'cld_temp_acha', 'cld_opd_acha', 'solar_zenith_angle',
#            'cld_reff_acha', 'cld_reff_dcomp', 'cld_reff_dcomp_1', 'cld_reff_dcomp_2', 'cld_reff_dcomp_3',
#            'cld_opd_dcomp', 'cld_opd_dcomp_1', 'cld_opd_dcomp_2', 'cld_opd_dcomp_3', 'cld_cwp_dcomp', 'iwc_dcomp',
#            'lwc_dcomp', 'cld_emiss_acha', 'conv_cloud_fraction', 'cld_opd_nlcomp', 'cld_reff_nlcomp', 'cloud_type',
#            'cloud_phase', 'cloud_mask']
# ds_types = ['f4' for i in range(25)] + ['i1' for i in range(3)]
# ds_fill = [-32768 for i in range(25)] + [-128 for i in range(3)]
# ds_range = ['actual_range' for i in range(25)] + [None for i in range(3)]
# --------------------------------------------


# An example file for accessing and copying metadata
a_clvr_file = homedir+'data/clavrx/clavrx_OR_ABI-L1b-RadC-M3C01_G16_s20190020002186.level2.nc'
#a_clvr_file = homedir+'data/clavrx/RadC/265/clavrx_OR_ABI-L1b-RadC-M6C01_G16_s20212651711172.level2.nc'
# VIIRS
#a_clvr_file = homedir+'data/clavrx/clavrx_snpp_viirs.A2019071.0000.001.2019071061610.uwssec_B00038187.level2.h5'

# Location of files for tile/FOV extraction
# icing_files = [f for f in glob.glob('/data/Personal/rink/icing_ml/icing_2*_DAY.h5')]
# icing_files = [f for f in glob.glob('/data/Personal/rink/icing_ml/icing_l1b_2*_DAY.h5')]
# icing_files = [f for f in glob.glob('/data/Personal/rink/icing_ml/icing_l1b_2*_ANY.h5')]
icing_l1b_files = []

# no_icing_files = [f for f in glob.glob('/data/Personal/rink/icing_ml/no_icing_2*_DAY.h5')]
# no_icing_files = [f for f in glob.glob('/data/Personal/rink/icing_ml/no_icing_l1b_2*_DAY.h5')]
# no_icing_files = [f for f in glob.glob('/data/Personal/rink/icing_ml/no_icing_l1b_2*_ANY.h5')]
no_icing_l1b_files = []


train_params_day = ['cld_height_acha', 'cld_geo_thick', 'supercooled_cloud_fraction', 'cld_temp_acha', 'cld_press_acha',
                    'solar_zenith_angle', 'cld_reff_dcomp', 'cld_opd_dcomp', 'cld_cwp_dcomp', 'iwc_dcomp', 'lwc_dcomp',
                    'cloud_phase', 'cloud_mask']

train_params_night = ['cld_height_acha', 'cld_geo_thick', 'supercooled_cloud_fraction', 'cld_temp_acha',
                      'cld_press_acha', 'cld_reff_acha', 'cld_opd_acha', 'cloud_phase', 'cloud_mask']

HALF_WIDTH = 20
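# Tiles extracted around each PIREP location span 2*HALF_WIDTH + 1 = 41 FOVs on a side
# (see the window slicing in get_grid_values).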

# A wrapper for the conversion of PIREP reports to in-memory dictionaries (icing, no-icing, negative icing)
def setup(pirep_file=homedir+'data/pirep/pireps_20180101_20200331.csv'):
    ice_dict, no_ice_dict, neg_ice_dict = pirep_icing(pirep_file)
    return ice_dict, no_ice_dict, neg_ice_dict
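# Example usage (a minimal sketch; assumes the default PIREP CSV exists at the path above):
#   ice_dct, no_ice_dct, neg_ice_dct = setup()
# Each dict maps a POSIX timestamp to a list of (lat, lon, flight_alt, intensity, unique_id, report_str) tuples.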


def get_clavrx_datasource(timestamp, platform, file_time_span=10):
    if platform == 'GOES':
        return get_clavrx_datasource_goes(timestamp, file_time_span=file_time_span)
    elif platform == 'VIIRS':
        return get_clavrx_datasource_viirs(timestamp)


def get_clavrx_datasource_goes(timestamp, file_time_span=10):
    dt_obj, time_tup = get_time_tuple_utc(timestamp)
    date_dir_str = dt_obj.strftime(dir_fmt)
    ds = ds_dct.get(date_dir_str)
    if ds is None:
        ds = CLAVRx(clavrx_dir + date_dir_str + '/', file_time_span=file_time_span)
        ds_dct[date_dir_str] = ds
    return ds


def get_clavrx_datasource_viirs(timestamp):
    dt_obj, time_tup = get_time_tuple_utc(timestamp)
    date_dir_str = dt_obj.strftime('%j')
    ds = ds_dct.get(date_dir_str)
    if ds is None:
        ds = CLAVRx_VIIRS(clavrx_viirs_dir + date_dir_str + '/')
        ds_dct[date_dir_str] = ds
    return ds


def get_goes_datasource(timestamp):
    dt_obj, time_tup = get_time_tuple_utc(timestamp)

    yr_dir = str(dt_obj.timetuple().tm_year)
    date_dir = dt_obj.strftime(dir_fmt)

    files_path = goes16_directory + '/' + yr_dir + '/' + date_dir + '/abi' + '/L1b' + '/RadC/'
    ds = goes_ds_dct.get(date_dir)
    if ds is None:
        ds = GOESL1B(files_path)
        goes_ds_dct[date_dir] = ds
    return ds


def get_grid_values(h5f, grid_name, j_c, i_c, half_width, num_j=None, num_i=None, scale_factor_name='scale_factor', add_offset_name='add_offset',
                    fill_value_name='_FillValue', range_name='actual_range', fill_value=None):
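    """Read a 2D window of values from dataset grid_name in the opened HDF5 file h5f.

    The window is centered at (j_c, i_c) with size (2*half_width + 1) on a side, or, if
    half_width is None, anchored at (j_c, i_c) with extents num_j/num_i. Values are scaled
    and offset using the dataset attributes when present, and fill or out-of-range values
    are replaced with NaN. Returns None if the window falls outside the grid or the
    dataset does not exist.
    """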
    try:
        hfds = h5f[grid_name]
    except Exception:
        return None
    attrs = hfds.attrs
    if attrs is None:
        raise GenericException('No attributes object for: '+grid_name)

    ylen, xlen = hfds.shape

    if half_width is not None:
        j_l = j_c-half_width
        i_l = i_c-half_width
        if j_l < 0 or i_l < 0:
            return None

        j_r = j_c+half_width+1
        i_r = i_c+half_width+1
        if j_r >= ylen or i_r >= xlen:
            return None
    else:
        j_l = j_c
        j_r = j_c + num_j + 1
        i_l = i_c
        i_r = i_c + num_i + 1

    grd_vals = hfds[j_l:j_r, i_l:i_r]
    if fill_value is not None:
        grd_vals = np.where(grd_vals == fill_value, np.nan, grd_vals)

    if scale_factor_name is not None:
        attr = attrs.get(scale_factor_name)
        if attr is not None:
            if np.isscalar(attr):
                scale_factor = attr
            else:
                scale_factor = attr[0]
            grd_vals = grd_vals * scale_factor

    if add_offset_name is not None:
        attr = attrs.get(add_offset_name)
        if attr is not None:
            if np.isscalar(attr):
                add_offset = attr
            else:
                add_offset = attr[0]
            grd_vals = grd_vals + add_offset

    if range_name is not None:
        attr = attrs.get(range_name)
        if attr is None:
            raise GenericException('Attribute: '+range_name+' not found for dataset: '+grid_name)
        low = attr[0]
        high = attr[1]
        grd_vals = np.where(grd_vals < low, np.nan, grd_vals)
        grd_vals = np.where(grd_vals > high, np.nan, grd_vals)
    elif fill_value_name is not None:
        attr = attrs.get(fill_value_name)
        if attr is None:
            raise GenericException('Attribute: '+fill_value_name+' not found for dataset: '+grid_name)
        if np.isscalar(attr):
            fill_value = attr
        else:
            fill_value = attr[0]
        grd_vals = np.where(grd_vals == fill_value, np.nan, grd_vals)

    return grd_vals


def create_file(filename, data_dct, ds_list, ds_types, lon_c, lat_c, time_s, fl_alt_s, icing_intensity, unq_ids, mask=None):
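    """Write extracted tiles plus per-report PIREP fields (lon/lat/time/altitude/intensity/ID,
    and optionally a QC FOV mask) to a new HDF5 file, copying the standard_name/long_name/units
    attributes for each dataset from the example CLAVRx file (a_clvr_file)."""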
    h5f_expl = h5py.File(a_clvr_file, 'r')
    h5f = h5py.File(filename, 'w')

    for idx, ds_name in enumerate(ds_list):
        data = data_dct[ds_name]
        h5f.create_dataset(ds_name, data=data, dtype=ds_types[idx])

    lon_ds = h5f.create_dataset('longitude', data=lon_c, dtype='f4')
    lon_ds.dims[0].label = 'time'
    lon_ds.attrs.create('units', data='degrees_east')
    lon_ds.attrs.create('long_name', data='PIREP longitude')

    lat_ds = h5f.create_dataset('latitude', data=lat_c, dtype='f4')
    lat_ds.dims[0].label = 'time'
    lat_ds.attrs.create('units', data='degrees_north')
    lat_ds.attrs.create('long_name', data='PIREP latitude')

    time_ds = h5f.create_dataset('time', data=time_s)
    time_ds.dims[0].label = 'time'
    time_ds.attrs.create('units', data='seconds since 1970-1-1 00:00:00')
    time_ds.attrs.create('long_name', data='PIREP time')

    ice_alt_ds = h5f.create_dataset('icing_altitude', data=fl_alt_s, dtype='f4')
    ice_alt_ds.dims[0].label = 'time'
    ice_alt_ds.attrs.create('units', data='m')
    ice_alt_ds.attrs.create('long_name', data='PIREP altitude')

    if icing_intensity is not None:
        icing_int_ds = h5f.create_dataset('icing_intensity', data=icing_intensity, dtype='i4')
        icing_int_ds.attrs.create('long_name', data='From PIREP. 0:No intensity report, 1:Trace, 2:Light, 3:Light Moderate, 4:Moderate, 5:Moderate Severe, 6:Severe')

    unq_ids_ds = h5f.create_dataset('unique_id', data=unq_ids, dtype='i4')
    unq_ids_ds.attrs.create('long_name', data='ID mapping to PIREP icing dictionary: see pireps.py')

    if mask is not None:
        mask = mask.astype(np.byte)
        mask_ds = h5f.create_dataset('FOV_mask', data=mask, dtype=np.byte)
        mask_ds.attrs.create('long_name', data='The FOVs which pass the cloudy icing report test')
        mask_ds.dims[0].label = 'time'
        mask_ds.dims[1].label = 'y'
        mask_ds.dims[2].label = 'x'

    # copy relevant attributes
    for ds_name in ds_list:
        h5f_ds = h5f[ds_name]
        h5f_ds.attrs.create('standard_name', data=h5f_expl[ds_name].attrs.get('standard_name'))
        h5f_ds.attrs.create('long_name', data=h5f_expl[ds_name].attrs.get('long_name'))
        h5f_ds.attrs.create('units', data=h5f_expl[ds_name].attrs.get('units'))

        h5f_ds.dims[0].label = 'time'
        h5f_ds.dims[1].label = 'y'
        h5f_ds.dims[2].label = 'x'

    h5f.close()
    h5f_expl.close()


def run(pirep_dct, platform, outfile=None, outfile_l1b=None, dt_str_start=None, dt_str_end=None, skip=True, file_time_span=10):
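    """For each PIREP time (optionally restricted to [dt_str_start, dt_str_end]), locate the
    matching CLAVRx file, navigate the report lon/lat to grid indices, extract L2 and L1b tiles
    around that point, and write the accumulated results with create_file. Reports that cannot
    be navigated, or that map to the same grid cell as the previous report for this time key
    (when skip is True), are dropped."""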
    time_keys = list(pirep_dct.keys())
    l1b_grd_dct = {name: [] for name in l1b_ds_list}
    ds_grd_dct = {name: [] for name in ds_list}

    t_start = None
    t_end = None
    if (dt_str_start is not None) and (dt_str_end is not None):
        dto = datetime.datetime.strptime(dt_str_start, '%Y-%m-%d_%H:%M').replace(tzinfo=timezone.utc)
        t_start = dto.timestamp()

        dto = datetime.datetime.strptime(dt_str_end, '%Y-%m-%d_%H:%M').replace(tzinfo=timezone.utc)
        t_end = dto.timestamp()

    lon_s = np.zeros(1)
    lat_s = np.zeros(1)
    last_clvr_file = None
    last_h5f = None
    nav = None

    lon_c = []
    lat_c = []
    time_s = []
    fl_alt_s = []
    ice_int_s = []
    unq_ids = []

    for idx, time in enumerate(time_keys):
        if t_start is not None:
            if time < t_start:
                continue
            if time > t_end:
                continue

        try:
            clvr_ds = get_clavrx_datasource(time, platform, file_time_span=file_time_span)
        except Exception:
            print('run: Problem retrieving Datasource', get_time_tuple_utc(time))
            continue

        clvr_file = clvr_ds.get_file(time)[0]
        if clvr_file is None:
            continue

        if clvr_file != last_clvr_file:
            h5f = None
            try:
                h5f = h5py.File(clvr_file, 'r')
                nav = clvr_ds.get_navigation(h5f)
            except Exception:
                if h5f is not None:  # the open succeeded but navigation failed
                    h5f.close()
                print('Problem with file: ', clvr_file)
                continue
            if last_h5f is not None:
                last_h5f.close()
            last_h5f = h5f
            last_clvr_file = clvr_file
        else:
            h5f = last_h5f

        cc = ll = -1
        reports = pirep_dct[time]
        for tup in reports:
            lat, lon, fl, I, uid, rpt_str = tup
            lat_s[0] = lat
            lon_s[0] = lon

            cc_a, ll_a = nav.earth_to_lc_s(lon_s, lat_s)
            if cc_a[0] < 0 or ll_a[0] < 0:  # non-navigable, skip
                continue

            if skip and (cc_a[0] == cc and ll_a[0] == ll):  # time adjacent duplicate, skip
                continue
            else:
                cc = cc_a[0]
                ll = ll_a[0]

            cnt_a = 0
            ds_lst = []
            for didx, ds_name in enumerate(ds_list):
                gvals = get_grid_values(h5f, ds_name, ll_a[0], cc_a[0], HALF_WIDTH, fill_value_name=None, range_name=ds_range[didx], fill_value=ds_fill[didx])
                if gvals is not None:
                    ds_lst.append(gvals)
                    cnt_a += 1

            cnt_b = 0
            l1b_lst = []
            for didx, ds_name in enumerate(l1b_ds_list):
                gvals = get_grid_values(h5f, ds_name, ll_a[0], cc_a[0], HALF_WIDTH, fill_value_name=None, range_name=l1b_ds_range[didx], fill_value=l1b_ds_fill[didx])
                if gvals is not None:
                    l1b_lst.append(gvals)
                    cnt_b += 1

            if cnt_a > 0 and cnt_a != len(ds_list):
                continue

            if cnt_b > 0 and cnt_b != len(l1b_ds_list):
                continue

            if cnt_a == len(ds_list) and cnt_b == len(l1b_ds_list):
                for didx, ds_name in enumerate(ds_list):
                    ds_grd_dct[ds_name].append(ds_lst[didx])

                for didx, ds_name in enumerate(l1b_ds_list):
                    l1b_grd_dct[ds_name].append(l1b_lst[didx])

                lon_c.append(lon_s[0])
                lat_c.append(lat_s[0])
                time_s.append(time)
                fl_alt_s.append(fl)
                ice_int_s.append(I)
                unq_ids.append(uid)

    if len(time_s) == 0:
        return

    t_start = time_s[0]
    t_end = time_s[-1]

    lon_c = np.array(lon_c)
    lat_c = np.array(lat_c)
    time_s = np.array(time_s)
    fl_alt_s = np.array(fl_alt_s)
    ice_int_s = np.array(ice_int_s)
    unq_ids = np.array(unq_ids)

    data_dct = {}
    for ds_name in ds_list:
        data_dct[ds_name] = np.array(ds_grd_dct[ds_name])

    if outfile is not None:
        outfile = add_time_range_to_filename(outfile, t_start, t_end)
        create_file(outfile, data_dct, ds_list, ds_types, lon_c, lat_c, time_s, fl_alt_s, ice_int_s, unq_ids)

    data_dct = {}
    for ds_name in l1b_ds_list:
        data_dct[ds_name] = np.array(l1b_grd_dct[ds_name])

    if outfile_l1b is not None:
        outfile_l1b = add_time_range_to_filename(outfile_l1b, t_start, t_end)
        create_file(outfile_l1b, data_dct, l1b_ds_list, l1b_ds_types, lon_c, lat_c, time_s, fl_alt_s, ice_int_s, unq_ids)


def run_viirs(pirep_dct, platform='VIIRS', outfile=None, outfile_l1b=None, dt_str_start=None, dt_str_end=None):
    time_keys = list(pirep_dct.keys())
    l1b_grd_dct = {name: [] for name in l1b_ds_list}
    ds_grd_dct = {name: [] for name in ds_list}

    t_start = None
    t_end = None
    if (dt_str_start is not None) and (dt_str_end is not None):
        dto = datetime.datetime.strptime(dt_str_start, '%Y-%m-%d_%H:%M').replace(tzinfo=timezone.utc)
        t_start = dto.timestamp()

        dto = datetime.datetime.strptime(dt_str_end, '%Y-%m-%d_%H:%M').replace(tzinfo=timezone.utc)
        t_end = dto.timestamp()

    lon_s = np.zeros(1)
    lat_s = np.zeros(1)
    last_clvr_file = None
    last_h5f = None
    nav = None

    lon_c = []
    lat_c = []
    time_s = []
    fl_alt_s = []
    ice_int_s = []
    unq_ids = []
    num_rpts = 0
    num_rpts_match = 0

    for idx, time in enumerate(time_keys):
        if t_start is not None:
            if time < t_start:
                continue
            if time > t_end:
                continue
        num_rpts += 1

        try:
            clvr_ds = get_clavrx_datasource(time, platform)
        except Exception:
            print('run_viirs: Problem retrieving Datasource')
            continue

        clvr_file = clvr_ds.get_file_containing_time(time)[0]
        if clvr_file is None:
            continue

        if clvr_file != last_clvr_file:
            h5f = None
            try:
                h5f = h5py.File(clvr_file, 'r')
                nav = clvr_ds.get_navigation(h5f)
            except Exception:
                if h5f is not None:  # the open succeeded but navigation failed
                    h5f.close()
                print('Problem with file: ', clvr_file)
                continue
            if last_h5f is not None:
                last_h5f.close()
            last_h5f = h5f
            last_clvr_file = clvr_file
        else:
            h5f = last_h5f

        cc = ll = -1
        reports = pirep_dct[time]
        for tup in reports:
            lat, lon, fl, I, uid, rpt_str = tup
            lat_s[0] = lat
            lon_s[0] = lon
            print('        ', lon, lat)
            if not nav.check_inside(lon, lat):
                print('missed range check')
                continue

            cc_a, ll_a = nav.earth_to_lc_s(lon_s, lat_s)
            if cc_a[0] < 0:  # non-navigable, skip
                print("can't navigate")
                continue

            if cc_a[0] == cc and ll_a[0] == ll:  # time adjacent duplicate, skip
                continue
            else:
                cc = cc_a[0]
                ll = ll_a[0]

            cnt_a = 0
            for didx, ds_name in enumerate(ds_list):
                gvals = get_grid_values(h5f, ds_name, ll_a[0], cc_a[0], 20, fill_value_name=None, range_name=ds_range[didx], fill_value=ds_fill[didx])
                if gvals is not None:
                    ds_grd_dct[ds_name].append(gvals)
                    cnt_a += 1

            cnt_b = 0
            for didx, ds_name in enumerate(l1b_ds_list):
                gvals = get_grid_values(h5f, ds_name, ll_a[0], cc_a[0], 20, fill_value_name=None, range_name=l1b_ds_range[didx], fill_value=l1b_ds_fill[didx])
                if gvals is not None:
                    l1b_grd_dct[ds_name].append(gvals)
                    cnt_b += 1

            if cnt_a > 0 and cnt_a != len(ds_list):
                raise GenericException('inconsistent number of L2 datasets extracted: ' + str(cnt_a))
            if cnt_b > 0 and cnt_b != len(l1b_ds_list):
                raise GenericException('inconsistent number of L1b datasets extracted: ' + str(cnt_b))

            if cnt_a == len(ds_list) and cnt_b == len(l1b_ds_list):
                lon_c.append(lon_s[0])
                lat_c.append(lat_s[0])
                time_s.append(time)
                fl_alt_s.append(fl)
                ice_int_s.append(I)
                unq_ids.append(uid)
    print('num reports: ', num_rpts)
    if len(time_s) == 0:
        return

    t_start = time_s[0]
    t_end = time_s[-1]

    data_dct = {}
    for ds_name in ds_list:
        data_dct[ds_name] = np.array(ds_grd_dct[ds_name])
    lon_c = np.array(lon_c)
    lat_c = np.array(lat_c)
    time_s = np.array(time_s)
    fl_alt_s = np.array(fl_alt_s)
    ice_int_s = np.array(ice_int_s)
    unq_ids = np.array(unq_ids)

    if outfile is not None:
        outfile = add_time_range_to_filename(outfile, t_start, t_end)
        create_file(outfile, data_dct, ds_list, ds_types, lon_c, lat_c, time_s, fl_alt_s, ice_int_s, unq_ids)

    data_dct = {}
    for ds_name in l1b_ds_list:
        data_dct[ds_name] = np.array(l1b_grd_dct[ds_name])

    if outfile_l1b is not None:
        outfile_l1b = add_time_range_to_filename(outfile_l1b, t_start, t_end)
        create_file(outfile_l1b, data_dct, l1b_ds_list, l1b_ds_types, lon_c, lat_c, time_s, fl_alt_s, ice_int_s, unq_ids)


def pirep_info(pirep_dct):
    time_keys = list(pirep_dct.keys())

    lat_s = []
    lon_s = []
    flt_lvl_s = []
    ice_intensity_s = []
    for tkey in time_keys:
        reports = pirep_dct[tkey]
        for tup in reports:
            lat, lon, fl, I, uid, rpt_str = tup
            lat_s.append(lat)
            lon_s.append(lon)
            flt_lvl_s.append(fl)
            ice_intensity_s.append(I)

    lat_s = np.array(lat_s)
    lon_s = np.array(lon_s)
    flt_lvl_s = np.array(flt_lvl_s)
    ice_intensity_s = np.array(ice_intensity_s)

    return flt_lvl_s, ice_intensity_s, lat_s, lon_s


lon_space_hdeg = np.linspace(-180, 180, 721)
lat_space_hdeg = np.linspace(-90, 90, 361)
hgt_space_3000 = np.linspace(0, 15000, 3000)


def check_no_overlap(lon, lat, ts, grd_bins, t_delta=600.0):
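    """Bin the report onto a half-degree lon/lat grid; return True (and record ts) only if
    more than t_delta seconds have elapsed since the last report accepted in that grid cell."""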

    grd_x_hi = lon_space_hdeg.shape[0] - 1
    grd_y_hi = lat_space_hdeg.shape[0] - 1

    lon_idx = np.searchsorted(lon_space_hdeg, lon)
    lat_idx = np.searchsorted(lat_space_hdeg, lat)

    if lon_idx < 0 or lon_idx > grd_x_hi:
        return False
    if lat_idx < 0 or lat_idx > grd_y_hi:
        return False

    last_ts = grd_bins[lat_idx, lon_idx]
    if ts - last_ts > t_delta:
        grd_bins[lat_idx, lon_idx] = ts
        return True
    else:
        return False


def check_no_overlap_alt(lon, lat, hgt, ts, grd_bins, t_delta=600.0):

    grd_x_hi = lon_space_hdeg.shape[0] - 1
    grd_y_hi = lat_space_hdeg.shape[0] - 1
    grd_z_hi = hgt_space_3000.shape[0] - 1

    lon_idx = np.searchsorted(lon_space_hdeg, lon)
    lat_idx = np.searchsorted(lat_space_hdeg, lat)
    hgt_idx = np.searchsorted(hgt_space_3000, hgt)

    if lon_idx < 0 or lon_idx > grd_x_hi:
        return False
    if lat_idx < 0 or lat_idx > grd_y_hi:
        return False
    if hgt_idx < 0 or hgt_idx > grd_z_hi:
        return False

    last_ts = grd_bins[hgt_idx, lat_idx, lon_idx]
    if ts - last_ts > t_delta:
        grd_bins[hgt_idx, lat_idx, lon_idx] = ts
        return True
    else:
        return False


# Reduces some categories for a degree of class balancing and removes reports with no icing intensity
def process(ice_dct, no_ice_dct, neg_ice_dct, t_delta=600):
    # t_delta in seconds

    new_ice_dct = {}
    new_no_ice_dct = {}
    new_neg_ice_dct = {}

    ice_keys_5_6 = []
    ice_tidx_5_6 = []
    ice_keys_1 = []
    ice_tidx_1 = []
    ice_keys_4 = []
    ice_tidx_4 = []
    ice_keys_3 = []
    ice_tidx_3 = []
    ice_keys_2 = []
    ice_tidx_2 = []

    print('num keys ice, no_ice, neg_ice: ', len(ice_dct), len(no_ice_dct), len(neg_ice_dct))
    no_intensity_cnt = 0
    num_ice_reports = 0

    for ts in list(ice_dct.keys()):
        rpts = ice_dct[ts]
        for idx, tup in enumerate(rpts):
            num_ice_reports += 1
            if tup[3] == 5 or tup[3] == 6:
                ice_keys_5_6.append(ts)
                ice_tidx_5_6.append(idx)
            elif tup[3] == 1:
                ice_keys_1.append(ts)
                ice_tidx_1.append(idx)
            elif tup[3] == 4:
                ice_keys_4.append(ts)
                ice_tidx_4.append(idx)
            elif tup[3] == 3:
                ice_keys_3.append(ts)
                ice_tidx_3.append(idx)
            elif tup[3] == 2:
                ice_keys_2.append(ts)
                ice_tidx_2.append(idx)
            else:
                no_intensity_cnt += 1

    no_ice_keys = []
    no_ice_tidx = []
    for ts in list(no_ice_dct.keys()):
        rpts = no_ice_dct[ts]
        for idx, tup in enumerate(rpts):
            no_ice_keys.append(ts)
            no_ice_tidx.append(idx)

    neg_ice_keys = []
    neg_ice_tidx = []
    for ts in list(neg_ice_dct.keys()):
        rpts = neg_ice_dct[ts]
        for idx, tup in enumerate(rpts):
            neg_ice_keys.append(ts)
            neg_ice_tidx.append(idx)

    print('num ice reports, no ice, neg ice: ', num_ice_reports, len(no_ice_keys), len(neg_ice_keys))
    print('------------------------------------------------')

    ice_keys_5_6 = np.array(ice_keys_5_6)
    ice_tidx_5_6 = np.array(ice_tidx_5_6, dtype='int64')
    print('5_6: ', ice_keys_5_6.shape[0])

    ice_keys_4 = np.array(ice_keys_4)
    ice_tidx_4 = np.array(ice_tidx_4, dtype='int64')
    print('4: ', ice_keys_4.shape[0])

    ice_keys_3 = np.array(ice_keys_3)
    ice_tidx_3 = np.array(ice_tidx_3, dtype='int64')
    print('3: ', ice_keys_3.shape[0])

    ice_keys_2 = np.array(ice_keys_2)
    ice_tidx_2 = np.array(ice_tidx_2, dtype='int64')
    print('2: ', ice_keys_2.shape[0])

    ice_keys_1 = np.array(ice_keys_1)
    ice_tidx_1 = np.array(ice_tidx_1, dtype='int64')
    print('1: ', ice_keys_1.shape[0])
    print('0: ', no_intensity_cnt)

    ice_keys = np.concatenate([ice_keys_1, ice_keys_2, ice_keys_3, ice_keys_4, ice_keys_5_6])
    ice_tidx = np.concatenate([ice_tidx_1, ice_tidx_2, ice_tidx_3, ice_tidx_4, ice_tidx_5_6])
    print('icing total reduced: ', ice_tidx.shape)
    print('----------------------------')

    sidxs = np.argsort(ice_keys)
    ice_keys = ice_keys[sidxs]
    ice_tidx = ice_tidx[sidxs]

    # -----------------------------------------------------
    no_ice_keys = np.array(no_ice_keys)
    no_ice_tidx = np.array(no_ice_tidx)
    print('no ice total: ', no_ice_keys.shape[0])
    np.random.seed(42)
    ridxs = np.random.permutation(np.arange(no_ice_keys.shape[0]))
    no_ice_keys = no_ice_keys[ridxs]
    no_ice_tidx = no_ice_tidx[ridxs]
    no_ice_keys = no_ice_keys[::3]
    no_ice_tidx = no_ice_tidx[::3]
    print('no ice reduced: ', no_ice_keys.shape[0])
    print('---------------------------------')

    sidxs = np.argsort(no_ice_keys)
    no_ice_keys = no_ice_keys[sidxs]
    no_ice_tidx = no_ice_tidx[sidxs]

    all_which = np.concatenate([np.full(len(ice_keys), 1), np.full(len(no_ice_keys), 0)])
    all_keys = np.concatenate([ice_keys, no_ice_keys])
    all_tidx = np.concatenate([ice_tidx, no_ice_tidx])
    sidxs = np.argsort(all_keys)
    all_keys = all_keys[sidxs]
    all_tidx = all_tidx[sidxs]
    all_which = all_which[sidxs]

    grd_bins = np.full((lat_space_hdeg.shape[0], lon_space_hdeg.shape[0]), -(t_delta+1))
    # grd_bins = np.full((hgt_space_3000.shape[0], lat_space_hdeg.shape[0], lon_space_hdeg.shape[0]), -(t_delta+1))
    cnt_i = 0
    cnt_ni = 0
    for idx, key in enumerate(all_keys):
        i_ni = all_which[idx]

        if i_ni == 0:
            rpts = no_ice_dct[key]
            tup = rpts[all_tidx[idx]]
        elif i_ni == 1:
            rpts = ice_dct[key]
            tup = rpts[all_tidx[idx]]

        lat, lon, falt = tup[0], tup[1], tup[2]
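        # NOTE: with 'continue' commented out below, the overlap-check result is ignored,
        # so the 'no overlap' counts printed afterward include every report.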

        # if not check_no_overlap_alt(lon, lat, falt, key, grd_bins, t_delta=t_delta):
        if not check_no_overlap(lon, lat, key, grd_bins, t_delta=t_delta):
            # continue
            pass

        if i_ni == 0:
            cnt_ni += 1
            n_rpts = new_no_ice_dct.get(key)
            if n_rpts is None:
                n_rpts = []
                new_no_ice_dct[key] = n_rpts
            n_rpts.append(tup)
        elif i_ni == 1:
            cnt_i += 1
            n_rpts = new_ice_dct.get(key)
            if n_rpts is None:
                n_rpts = []
                new_ice_dct[key] = n_rpts
            n_rpts.append(tup)

    print('no overlap ICE: ', cnt_i)
    print('no overlap NO ICE: ', cnt_ni)

    # -----------------------------------------------------
    neg_ice_keys = np.array(neg_ice_keys)
    neg_ice_tidx = np.array(neg_ice_tidx)
    print('neg ice total: ', neg_ice_keys.shape[0])

    # grd_bins = np.full((hgt_space_3000.shape[0], lat_space_hdeg.shape[0], lon_space_hdeg.shape[0]), -(t_delta+1))
    grd_bins = np.full((lat_space_hdeg.shape[0], lon_space_hdeg.shape[0]), -(t_delta+1))
    cnt = 0
    for idx, key in enumerate(neg_ice_keys):
        rpts = neg_ice_dct[key]
        tup = rpts[neg_ice_tidx[idx]]

        lat, lon, falt = tup[0], tup[1], tup[2]
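        # NOTE: as above, the overlap filter is effectively disabled here as well.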
        # if not check_no_overlap_alt(lon, lat, falt, key, grd_bins, t_delta=t_delta):
        if not check_no_overlap(lon, lat, key, grd_bins, t_delta=t_delta):
            # continue
            pass
        cnt += 1

        n_rpts = new_neg_ice_dct.get(key)
        if n_rpts is None:
            n_rpts = []
            new_neg_ice_dct[key] = n_rpts
        n_rpts.append(tup)
    print('neg icing total no overlap: ', cnt)
    # -------------------------------------------------

    return new_ice_dct, new_no_ice_dct, new_neg_ice_dct


def analyze2(filename, filename_l1b):
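    """Flatten QC'd tiles to individual FOVs (via FOV_mask), print counts and the skill of a
    simple baseline that predicts icing wherever cld_top_tmp <= 273.15 K (ACC, recall,
    precision, F1, MCC), and return selected per-FOV fields split into no-icing/icing arrays."""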
    f = h5py.File(filename, 'r')
    f_l1b = h5py.File(filename_l1b, 'r')

    iint = f['icing_intensity'][:]
    icing_alt = f['icing_altitude'][:]
    num_obs = iint.size
    iint = np.broadcast_to(iint.reshape(num_obs, 1), (num_obs, 256)).flatten()
    icing_alt = np.broadcast_to(icing_alt.reshape(num_obs, 1), (num_obs, 256)).flatten()

    keep_mask = f['FOV_mask'][:, :, :].reshape([num_obs, -1])
    keep_mask = keep_mask.astype(bool)  # np.bool was removed in modern NumPy
    keep_mask = keep_mask.flatten()

    iint = iint[keep_mask]
    icing_alt = icing_alt[keep_mask]
    num_obs = iint.size
    print('num obs: ', num_obs)

    no_ice = iint == -1  # All no icing
    ice = iint >= 1  # All icing
    tr_ice = (iint == 1) | (iint == 2)  # True for icing cats 1 and 2
    tr_plus_ice = iint > 2  # True for icing cats 3 and above

    print('num no ice: ', np.sum(no_ice))
    print('num ice:    ', np.sum(ice))
    print('trace ice: ', np.sum(tr_ice))
    print('trace plus ice:    ', np.sum(tr_plus_ice))

    # keep_no_1_2 = (iint == -1) | (iint == 1) | (iint == 2)
    # keep_no_3_4_5 = (iint == -1) | (iint > 2)
    # # iint = iint[keep_no_1_2]
    # iint = iint[keep_no_3_4_5]
    # # icing_alt = icing_alt[keep_no_1_2]
    # icing_alt = icing_alt[keep_no_3_4_5]
    # no_ice = iint == -1
    # ice = iint > 2

    cld_top_hgt = f['cld_height_acha'][:, :, :].flatten()
    cld_top_tmp = f['cld_temp_acha'][:, :, :].flatten()
    sc_cld_frac = f['supercooled_cloud_fraction'][:, :, :].flatten()
    cld_mask = f['cloud_mask'][:, :, :].flatten()
    cld_opd = f['cld_opd_dcomp'][:, :, :].flatten()
    # cld_opd = f['cld_opd_acha'][:, :, :].flatten()
    cld_reff = f['cld_reff_dcomp'][:, :, :].flatten()
    cld_lwc = f['lwc_dcomp'][:, :, :].flatten()
    cld_iwc = f['iwc_dcomp'][:, :, :].flatten()
    # cld_reff = f['cld_reff_acha'][:, :, :].flatten()
    cld_emiss = f['cld_emiss_acha'][:, :, :].flatten()
    # cld_phase = f['cloud_phase'][:, :, :].flatten()
    # solzen = f['solar_zenith_angle'][:, :, :].flatten()
    # bt_10_4 = f_l1b['temp_10_4um_nom'][:, :, :].flatten()

    cld_top_hgt = cld_top_hgt[keep_mask]
    cld_top_tmp = cld_top_tmp[keep_mask]
    sc_cld_frac = sc_cld_frac[keep_mask]
    cld_opd = cld_opd[keep_mask]
    cld_reff = cld_reff[keep_mask]
    cld_lwc = cld_lwc[keep_mask]
    cld_iwc = cld_iwc[keep_mask]
    cld_mask = cld_mask[keep_mask]

    cld_mask = np.where(cld_mask > 1, 1, 0)
    # cld_frac = np.sum(cld_mask, axis=1) / 256 # need to do this earlier

    #  Simple model: Everything below the melting point is icing=True else icing=False  ------------------------------
    preds = cld_top_tmp <= 273.15
    # preds = cld_top_tmp[keep_no_3_4_5] <= 273.15

    true_ice = ice & preds
    false_ice = no_ice & preds

    true_no_ice = no_ice & np.invert(preds)
    false_no_ice = ice & np.invert(preds)

    tp = np.sum(true_ice).astype(np.float32)
    tn = np.sum(true_no_ice).astype(np.float32)
    fp = np.sum(false_ice).astype(np.float32)
    fn = np.sum(false_no_ice).astype(np.float32)

    recall = tp / (tp + fn)
    precision = tp / (tp + fp)
    f1 = 2 * (precision * recall) / (precision + recall)
    mcc = ((tp * tn) - (fp * fn)) / np.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
    acc = (tp + tn) / (tp + tn + fp + fn)

    print('Total (Positive/Icing) Prediction: ')
    print('True icing:                         ', np.sum(true_ice))
    print('-------------------------')
    print('False no icing (False Negative/Miss): ', np.sum(false_no_ice))
    print('---------------------------------------------------')
    print('---------------------------------------------------')

    print('Total (Negative/No Icing) Prediction: ')
    print('True no icing:                             ', np.sum(true_no_ice))
    print('-------------------------')
    print('* False icing (False Positive/False Alarm) *:  ', np.sum(false_ice))
    print('-------------------------')

    print('ACC:         ', acc)
    print('Recall:      ', recall)
    print('Precision:   ', precision)
    print('F1:          ', f1)
    print('MCC:         ', mcc)
    # ------------------------------------------------------------------------------------------------------------

    f.close()
    f_l1b.close()

    return (icing_alt[no_ice], icing_alt[ice],
            cld_top_hgt[no_ice], cld_top_hgt[ice],
            cld_top_tmp[no_ice], cld_top_tmp[ice],
            cld_opd[no_ice], cld_opd[ice],
            cld_reff[no_ice], cld_reff[ice],
            cld_lwc[no_ice], cld_lwc[ice],
            cld_iwc[no_ice], cld_iwc[ice],
            sc_cld_frac[no_ice], sc_cld_frac[ice])


# --------------------------------------------
# x_a = 10
# x_b = 30
x_a = 12
x_b = 28
y_a = x_a
y_b = x_b
nx = ny = (x_b - x_a)
nx_x_ny = nx * ny
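# x_a:x_b selects the interior 16 x 16 window (nx_x_ny = 256 FOVs) of the full 41 x 41 tile.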


def run_daynight(filename, filename_l1b, day_night='ANY'):
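    """Subset the reports in filename/filename_l1b to those passing the sensor-zenith
    obliqueness check and the day/night test for day_night (DAY, NIGHT, or ANY), then
    write '_<day_night>' versions of both files."""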
    f = h5py.File(filename, 'r')
    f_l1b = h5py.File(filename_l1b, 'r')

    solzen = f['solar_zenith_angle'][:, y_a:y_b, x_a:x_b]
    satzen = f['sensor_zenith_angle'][:, y_a:y_b, x_a:x_b]
    num_obs = solzen.shape[0]

    idxs = []
    for i in range(num_obs):
        if not check_oblique(satzen[i,]):
            continue
        if day_night == 'NIGHT' and is_night(solzen[i,]):
            idxs.append(i)
        elif day_night == 'DAY' and is_day(solzen[i,]):
            idxs.append(i)
        elif day_night == 'ANY':
            if is_day(solzen[i,]) or is_night(solzen[i,]):
                idxs.append(i)

    keep_idxs = np.array(idxs)

    data_dct = {}
    for didx, ds_name in enumerate(ds_list):
        data_dct[ds_name] = f[ds_name][keep_idxs,]

    lon_c = f['longitude'][keep_idxs]
    lat_c = f['latitude'][keep_idxs]
    time_s = f['time'][keep_idxs]
    fl_alt_s = f['icing_altitude'][keep_idxs]
    ice_int_s = f['icing_intensity'][keep_idxs]
    unq_ids = f['unique_id'][keep_idxs]

    path, fname = os.path.split(filename)
    fbase, fext = os.path.splitext(fname)
    outfile = path + '/' + fbase + '_' + day_night + fext

    create_file(outfile, data_dct, ds_list, ds_types, lon_c, lat_c, time_s, fl_alt_s, ice_int_s, unq_ids)

    data_dct = {}
    for didx, ds_name in enumerate(l1b_ds_list):
        data_dct[ds_name] = f_l1b[ds_name][keep_idxs]

    path, fname = os.path.split(filename_l1b)
    fbase, fext = os.path.splitext(fname)
    outfile_l1b = path + '/' + fbase + '_' + day_night + fext

    create_file(outfile_l1b, data_dct, l1b_ds_list, l1b_ds_types, lon_c, lat_c, time_s, fl_alt_s, ice_int_s, unq_ids)

    f.close()
    f_l1b.close()


def run_qc(filename, filename_l1b, day_night='ANY', pass_thresh_frac=0.20, cloudy_frac=0.5, icing=True):
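    """Apply the FOV-level QC test (currently apply_qc_icing_pireps_v3 for both icing and
    no-icing reports) to every report in filename, keep only reports whose fraction of
    passing FOVs exceeds pass_thresh_frac, and write '_QC_<day_night>' versions of both
    the L2 and L1b files, including the FOV mask. Returns (total_reports, reports_kept)."""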

    f = h5py.File(filename, 'r')
    icing_alt = f['icing_altitude'][:]
    cld_top_hgt = f['cld_height_acha'][:, y_a:y_b, x_a:x_b]
    cld_geo_dz = f['cld_geo_thick'][:, y_a:y_b, x_a:x_b]
    cld_phase = f['cloud_phase'][:, y_a:y_b, x_a:x_b]
    cld_temp = f['cld_temp_acha'][:, y_a:y_b, x_a:x_b]

    if day_night == 'DAY':
        cld_opd = f['cld_opd_dcomp'][:, y_a:y_b, x_a:x_b]
    else:
        cld_opd = f['cld_opd_acha'][:, y_a:y_b, x_a:x_b]

    cld_mask = f['cloud_mask'][:, y_a:y_b, x_a:x_b]
    sol_zen = f['solar_zenith_angle'][:, y_a:y_b, x_a:x_b]
    sat_zen = f['sensor_zenith_angle'][:, y_a:y_b, x_a:x_b]

    f_l1b = h5py.File(filename_l1b, 'r')
    bt_11um = f_l1b['temp_11_0um_nom'][:, y_a:y_b, x_a:x_b]

    if icing:
        # mask, idxs, num_tested = apply_qc_icing_pireps(icing_alt, cld_top_hgt, cld_geo_dz, cld_phase, cld_opd, cld_mask, bt_11um, sol_zen, sat_zen, cld_temp, day_night=day_night, cloudy_frac=cloudy_frac)
        mask, idxs, num_tested = apply_qc_icing_pireps_v3(icing_alt, cld_top_hgt, cld_geo_dz, cld_phase, cld_opd, cld_mask, bt_11um, sol_zen, sat_zen, cld_temp, day_night=day_night, cloudy_frac=cloudy_frac)
    else:
        # mask, idxs, num_tested = apply_qc_no_icing_pireps(icing_alt, cld_top_hgt, cld_geo_dz, cld_phase, cld_opd, cld_mask, bt_11um, sol_zen, sat_zen, cld_temp, day_night=day_night, cloudy_frac=cloudy_frac)
        mask, idxs, num_tested = apply_qc_icing_pireps_v3(icing_alt, cld_top_hgt, cld_geo_dz, cld_phase, cld_opd, cld_mask, bt_11um, sol_zen, sat_zen, cld_temp, day_night=day_night, cloudy_frac=cloudy_frac)

    keep_idxs = []
    keep_mask = []

    for i in range(len(mask)):
        frac = np.sum(mask[i]) / num_tested[i]
        # icing and no-icing reports currently use the same pass criterion
        if frac > pass_thresh_frac:
            keep_idxs.append(idxs[i])
            keep_mask.append(mask[i])

    if len(keep_idxs) == 0:
        f.close()
        f_l1b.close()
        return len(icing_alt), 0

    print('day_night, icing, all, valid, pass: ', day_night, icing, len(icing_alt), len(mask), len(keep_idxs))
    print('-----------------------')
    keep_idxs = np.array(keep_idxs)
    mask = np.concatenate(keep_mask).reshape((-1, ny, nx))

    data_dct = {}
    for didx, ds_name in enumerate(ds_list):
        data_dct[ds_name] = f[ds_name][keep_idxs, y_a:y_b, x_a:x_b]

    lon_c = f['longitude'][keep_idxs]
    lat_c = f['latitude'][keep_idxs]
    time_s = f['time'][keep_idxs]
    fl_alt_s = f['icing_altitude'][keep_idxs]
    ice_int_s = f['icing_intensity'][keep_idxs]
    unq_ids = f['unique_id'][keep_idxs]

    path, fname = os.path.split(filename)
    fbase, fext = os.path.splitext(fname)
    outfile = path + '/' + fbase + '_' + 'QC' + '_' + day_night + fext

    create_file(outfile, data_dct, ds_list, ds_types, lon_c, lat_c, time_s, fl_alt_s, ice_int_s, unq_ids, mask)

    data_dct = {}
    for didx, ds_name in enumerate(l1b_ds_list):
        data_dct[ds_name] = f_l1b[ds_name][keep_idxs, y_a:y_b, x_a:x_b]

    path, fname = os.path.split(filename_l1b)
    fbase, fext = os.path.splitext(fname)
    outfile_l1b = path + '/' + fbase + '_' + 'QC' + '_' + day_night + fext

    create_file(outfile_l1b, data_dct, l1b_ds_list, l1b_ds_types, lon_c, lat_c, time_s, fl_alt_s, ice_int_s, unq_ids, mask)

    f.close()
    f_l1b.close()

    return len(icing_alt), len(keep_idxs)


def apply_qc_icing_pireps(icing_alt, cld_top_hgt, cld_geo_dz, cld_phase, cld_opd, cld_mask, bt_11um, solzen, satzen, cld_top_temp, day_night='ANY', cloudy_frac=0.5):

    if day_night == 'DAY':
        opd_thick_threshold = 20
        opd_thin_threshold = 1
    elif day_night == 'NIGHT' or day_night == 'ANY':
        opd_thick_threshold = 2
        opd_thin_threshold = 0.1

    closeness_top = 200.0  # meters
    max_cld_depth = 1000.0
    max_altitude = 4000.0
    max_cld_altitude = 4000.0

    num_obs = len(icing_alt)
    cld_mask = cld_mask.reshape((num_obs, -1))
    cld_top_hgt = cld_top_hgt.reshape((num_obs, -1))
    cld_geo_dz = cld_geo_dz.reshape((num_obs, -1))
    bt_11um = bt_11um.reshape((num_obs, -1))
    cld_top_temp = cld_top_temp.reshape((num_obs, -1))

    mask = []
    idxs = []
    num_tested = []
    for i in range(num_obs):
        if not check_oblique(satzen[i,]):
            continue
        if day_night == 'NIGHT' and not is_night(solzen[i,]):
            continue
        elif day_night == 'DAY' and not is_day(solzen[i,]):
            continue
        elif day_night == 'ANY':
            pass
            # if not (is_day(solzen[i,]) or is_night(solzen[i,])):
            #     continue

        # if not (icing_alt[i] < max_altitude):
        #     continue

        keep_0 = np.logical_or(cld_mask[i,] == 2, cld_mask[i,] == 3)  # cloudy
        keep_1 = np.invert(np.isnan(cld_top_hgt[i,]))
        keep_2 = np.invert(np.isnan(bt_11um[i,]))
        keep_3 = np.invert(np.isnan(cld_geo_dz[i,]))
        keep = keep_0 & keep_1 & keep_2 & keep_3
        num_keep = np.sum(keep)
        if num_keep == 0:
            continue
        if num_keep / nx_x_ny < cloudy_frac:  # At least this fraction cloudy
            continue

        keep = np.where(keep, (cld_top_hgt[i,] + closeness_top) > icing_alt[i], False)
        # keep = np.where(keep, cld_top_hgt[i,]  < max_cld_altitude, False)
        # keep = np.where(keep, cld_top_hgt[i,] > max_cld_depth, False)
        # keep = np.where(keep, (cld_top_hgt[i,] - max_cld_depth) < icing_alt[i], False)

        keep = np.where(keep, np.logical_and(bt_11um[i,] > 228.0, bt_11um[i,] < 274.0), False)
        # keep = np.where(keep, bt_11um[i,] < 275.0, False)
        # keep = np.where(keep, cld_top_temp[i,] < 274.2, False)

        # cld_hgt = cld_top_hgt[i, ].flatten()
        # med_cld_hgt = np.median(cld_hgt[keep])
        # if icing_alt[i] > med_cld_hgt:
        #     continue
        #
        # cld_tmp = cld_top_temp[i, ].flatten()
        # med_cld_tmp = np.median(cld_tmp[keep])
        # if med_cld_tmp > 275.0:
        #     continue
        #
        # if icing_alt[i] < med_cld_hgt - max_cld_depth:
        #     continue

        mask.append(keep)
        idxs.append(i)
        num_tested.append(num_keep)

    return mask, idxs, num_tested


def apply_qc_icing_pireps_exp1(icing_alt, cld_top_hgt, cld_geo_dz, cld_phase, cld_opd, cld_mask, bt_11um, solzen, satzen, cld_top_temp, day_night='ANY', cloudy_frac=0.5):

    if day_night == 'DAY':
        opd_thick_threshold = 20
        opd_thin_threshold = 1
    elif day_night == 'NIGHT' or day_night == 'ANY':
        opd_thick_threshold = 2
        opd_thin_threshold = 0.1

    closeness_top = 200.0  # meters
    max_cld_depth = 1000.0
    max_altitude = 4000.0
    max_cld_altitude = 4000.0

    num_obs = len(icing_alt)
    cld_mask = cld_mask.reshape((num_obs, -1))
    cld_top_hgt = cld_top_hgt.reshape((num_obs, -1))
    cld_geo_dz = cld_geo_dz.reshape((num_obs, -1))
    bt_11um = bt_11um.reshape((num_obs, -1))
    cld_top_temp = cld_top_temp.reshape((num_obs, -1))

    mask = []
    idxs = []
    num_tested = []
    for i in range(num_obs):
        if not check_oblique(satzen[i,]):
            continue
        if day_night == 'NIGHT' and not is_night(solzen[i,]):
            continue
        elif day_night == 'DAY' and not is_day(solzen[i,]):
            continue
        elif day_night == 'ANY':
            pass

        # keep_0 = np.logical_or(cld_mask[i,] == 2, cld_mask[i,] == 3)  # cloudy
        keep_0 = cld_mask[i, ] == 3  # Confident cloudy
        keep_1 = np.invert(np.isnan(cld_top_hgt[i,]))
        keep_2 = np.invert(np.isnan(bt_11um[i,]))
        keep_3 = np.invert(np.isnan(cld_geo_dz[i,]))
        keep = keep_0 & keep_1 & keep_2 & keep_3
        num_keep = np.sum(keep)
        if num_keep == 0:
            continue
        if num_keep / nx_x_ny < cloudy_frac:  # At least this fraction cloudy
            continue

        keep = np.where(keep, np.logical_and(cld_top_hgt[i,] > icing_alt[i,],
                                  (cld_top_hgt[i,] - 1500.0) < icing_alt[i,]), False)
        keep = np.where(keep, np.logical_and(cld_top_temp[i,] > 228.0, cld_top_temp[i,] < 274.0), False)

        mask.append(keep)
        idxs.append(i)
        num_tested.append(num_keep)

    return mask, idxs, num_tested


def apply_qc_icing_pireps_exp2(icing_alt, cld_top_hgt, cld_geo_dz, cld_phase, cld_opd, cld_mask, bt_11um, solzen, satzen, cld_top_temp, day_night='ANY', cloudy_frac=0.5):

    closeness_top = 200.0  # meters
    max_cld_depth = 1000.0
    max_altitude = 4000.0
    max_cld_altitude = 4000.0

    num_obs = len(icing_alt)
    cld_mask = cld_mask.reshape((num_obs, -1))
    cld_top_hgt = cld_top_hgt.reshape((num_obs, -1))
    cld_geo_dz = cld_geo_dz.reshape((num_obs, -1))
    bt_11um = bt_11um.reshape((num_obs, -1))
    cld_top_temp = cld_top_temp.reshape((num_obs, -1))

    mask = []
    idxs = []
    num_tested = []
    for i in range(num_obs):
        if not check_oblique(satzen[i,]):
            continue
        if day_night == 'NIGHT' and not is_night(solzen[i,]):
            continue
        elif day_night == 'DAY' and not is_day(solzen[i,]):
            continue
        elif day_night == 'ANY':
            pass

        # keep_0 = np.logical_or(cld_mask[i,] == 2, cld_mask[i,] == 3)  # cloudy
        keep_0 = cld_mask[i, ] == 3  # Confident cloudy
        keep_1 = np.invert(np.isnan(cld_top_hgt[i,]))
        keep_2 = np.invert(np.isnan(bt_11um[i,]))
        keep_3 = np.invert(np.isnan(cld_geo_dz[i,]))
        keep = keep_0 & keep_1 & keep_2 & keep_3
        num_keep = np.sum(keep)
        if num_keep == 0:
            continue
        if num_keep / nx_x_ny < cloudy_frac:  # At least this fraction cloudy
            continue

        keep = np.where(keep, np.logical_and(cld_top_hgt[i,] > icing_alt[i,],
                                  (cld_top_hgt[i,] - 3000.0) < icing_alt[i,]), False)
        keep = np.where(keep, np.logical_and(cld_top_temp[i,] > 228.0, cld_top_temp[i,] < 274.0), False)

        mask.append(keep)
        idxs.append(i)
        num_tested.append(num_keep)

    return mask, idxs, num_tested


def apply_qc_icing_pireps_v3(icing_alt, cld_top_hgt, cld_geo_dz, cld_phase, cld_opd, cld_mask, bt_11um, solzen, satzen, cld_top_temp, day_night='ANY', cloudy_frac=0.5):
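    """Per-FOV QC test for icing PIREPs: an FOV passes if it is cloudy (mask 2 or 3) with a
    valid cloud-top height, and the report altitude lies between cld_top_hgt - max_cld_depth
    and cld_top_hgt + closeness_top. Tiles failing the satzen/day-night checks, or with a
    cloudy fraction below cloudy_frac, are skipped entirely. Returns (mask, idxs, num_tested)."""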

    closeness_top = 100.0  # meters
    max_cld_depth = 1500.0
    max_altitude = 5000.0
    max_cld_altitude = 5000.0

    num_obs = len(icing_alt)
    cld_mask = cld_mask.reshape((num_obs, -1))
    cld_top_hgt = cld_top_hgt.reshape((num_obs, -1))
    cld_geo_dz = cld_geo_dz.reshape((num_obs, -1))
    bt_11um = bt_11um.reshape((num_obs, -1))
    cld_top_temp = cld_top_temp.reshape((num_obs, -1))

    mask = []
    idxs = []
    num_tested = []
    for i in range(num_obs):
        if not check_oblique(satzen[i,]):
            continue
        if day_night == 'NIGHT' and not is_night(solzen[i,]):
            continue
        elif day_night == 'DAY' and not is_day(solzen[i,]):
            continue
        elif day_night == 'ANY':
            pass

        # keep_0 = cld_mask[i, ] == 3  # Confident cloudy
        keep_0 = np.logical_or(cld_mask[i,] == 2, cld_mask[i,] == 3)  # cloudy
        keep_1 = np.invert(np.isnan(cld_top_hgt[i,]))
        # keep_2 = np.invert(np.isnan(cld_geo_dz[i,]))
        # keep_3 = np.invert(np.isnan(bt_11um[i,]))
        # keep = keep_0 & keep_1 & keep_2 & keep_3
        keep = keep_0 & keep_1
        num_keep = np.sum(keep)
        if num_keep == 0:
            continue
        if num_keep / nx_x_ny < cloudy_frac:  # At least this fraction cloudy
            continue

        # if icing_alt[i,] > max_altitude:
        #     continue

        # ctt_mean = np.nanmean(cld_top_temp[i,])
        # if ctt_mean > 274.0:
        #     continue

        # cth_mean = np.nanmean(cld_top_hgt[i,])
        # if not (closeness_top+cth_mean > icing_alt[i,] > cth_mean - max_cld_depth):
        #     continue

        # keep = np.where(keep, cld_top_hgt[i,] < max_cld_altitude, False)

        keep = np.where(keep, np.logical_and((cld_top_hgt[i,] + closeness_top) > icing_alt[i,],
                                  (cld_top_hgt[i,] - max_cld_depth) < icing_alt[i,]), False)

        # keep = np.where(keep, np.logical_and(cld_top_temp[i,] > 228.0, cld_top_temp[i,] < 274.0), False)

        mask.append(keep)
        idxs.append(i)
        num_tested.append(num_keep)

    return mask, idxs, num_tested


def apply_qc_no_icing_pireps(icing_alt, cld_top_hgt, cld_geo_dz, cld_phase, cld_opd, cld_mask, bt_11um, solzen, satzen, cld_top_temp, day_night='ANY', cloudy_frac=0.5):

    if day_night == 'DAY':
        opd_thick_threshold = 20
        opd_thin_threshold = 1
    elif day_night == 'NIGHT' or day_night == 'ANY':
        opd_thick_threshold = 2
        opd_thin_threshold = 0.1

    closeness_top = 200.0  # meters
    max_cld_depth = 1000.0
    max_altitude = 5000.0
    max_cld_altitude = 5000.0

    num_obs = len(icing_alt)
    cld_mask = cld_mask.reshape((num_obs, -1))
    cld_top_hgt = cld_top_hgt.reshape((num_obs, -1))
    cld_geo_dz = cld_geo_dz.reshape((num_obs, -1))
    bt_11um = bt_11um.reshape((num_obs, -1))
    cld_top_temp = cld_top_temp.reshape((num_obs, -1))

    mask = []
    idxs = []
    num_tested = []
    for i in range(num_obs):
        if not check_oblique(satzen[i,]):
            continue
        if day_night == 'NIGHT' and not is_night(solzen[i,]):
            continue
        elif day_night == 'DAY' and not is_day(solzen[i,]):
            continue
        elif day_night == 'ANY':
            pass
            # if not (is_day(solzen[i,]) or is_night(solzen[i,])):
            #     continue

        # if not (icing_alt[i] < max_altitude):
        #     continue

        keep_0 = np.logical_or(cld_mask[i,] == 2, cld_mask[i,] == 3)  # cloudy
        keep_1 = np.invert(np.isnan(cld_top_hgt[i,]))
        keep_2 = np.invert(np.isnan(bt_11um[i,]))
        keep_3 = np.invert(np.isnan(cld_geo_dz[i,]))
        keep = keep_0 & keep_1 & keep_2 & keep_3
        num_keep = np.sum(keep)
        if num_keep == 0:
            continue
        if num_keep / nx_x_ny < cloudy_frac:  # At least this fraction cloudy
            continue

        keep = np.where(keep, (cld_top_hgt[i,] + closeness_top) > icing_alt[i], False)
        # keep = np.where(keep, cld_top_hgt[i,]  < max_cld_altitude, False)
        # keep = np.where(keep, cld_top_hgt[i,] > max_cld_depth, False)
        # keep = np.where(keep, (cld_top_hgt[i,] - max_cld_depth) < icing_alt[i], False)

        keep = np.where(keep, np.logical_and(bt_11um[i,] > 228.0, bt_11um[i,] < 274.0), False)
        # keep = np.where(keep, bt_11um[i,] < 275.0, False)
        # keep = np.where(keep, cld_top_temp[i,] < 274.2, False)

        # cld_hgt = cld_top_hgt[i, ].flatten()
        # med_cld_hgt = np.median(cld_hgt[keep])
        # if icing_alt[i] > med_cld_hgt:
        #     continue
        #
        # cld_tmp = cld_top_temp[i, ].flatten()
        # med_cld_tmp = np.median(cld_tmp[keep])
        # if med_cld_tmp > 275.0:
        #     continue
        #
        # if icing_alt[i] < med_cld_hgt - max_cld_depth:
        #     continue

        mask.append(keep)
        idxs.append(i)
        num_tested.append(num_keep)

    return mask, idxs, num_tested


def apply_qc_no_icing_pireps_exp2(icing_alt, cld_top_hgt, cld_geo_dz, cld_phase, cld_opd, cld_mask, bt_11um, solzen, satzen, cld_top_temp, day_night='ANY', cloudy_frac=0.5):

    closeness_top = 200.0  # meters
    max_cld_depth = 1000.0
    max_altitude = 4000.0
    max_cld_altitude = 4000.0

    num_obs = len(icing_alt)
    cld_mask = cld_mask.reshape((num_obs, -1))
    cld_top_hgt = cld_top_hgt.reshape((num_obs, -1))
    cld_geo_dz = cld_geo_dz.reshape((num_obs, -1))
    bt_11um = bt_11um.reshape((num_obs, -1))
    cld_top_temp = cld_top_temp.reshape((num_obs, -1))

    mask = []
    idxs = []
    num_tested = []
    for i in range(num_obs):
        if not check_oblique(satzen[i,]):
            continue
        if day_night == 'NIGHT' and not is_night(solzen[i,]):
            continue
        elif day_night == 'DAY' and not is_day(solzen[i,]):
            continue
        elif day_night == 'ANY':
            pass

        # keep_0 = np.logical_or(cld_mask[i,] == 2, cld_mask[i,] == 3)  # cloudy
        keep_0 = cld_mask[i, ] == 3  # Confident cloudy
        keep_1 = np.invert(np.isnan(cld_top_hgt[i,]))
        keep_2 = np.invert(np.isnan(bt_11um[i,]))
        keep_3 = np.invert(np.isnan(cld_geo_dz[i,]))
        keep = keep_0 & keep_1 & keep_2 & keep_3
        num_keep = np.sum(keep)
        if num_keep == 0:
            continue
        if num_keep / nx_x_ny < cloudy_frac:  # At least this fraction cloudy
            continue

        keep = np.where(keep, np.logical_and(cld_top_hgt[i,] > icing_alt[i,],
                                  (cld_top_hgt[i,] - 3000.0) < icing_alt[i,]), False)
        # keep = np.where(keep, np.logical_and(cld_top_temp[i,] > 228.0, cld_top_temp[i,] < 274.0), False)

        mask.append(keep)
        idxs.append(i)
        num_tested.append(num_keep)

    return mask, idxs, num_tested


def fov_extract(icing_files, no_icing_files, trnfile='/home/rink/fovs_l1b_train.h5', tstfile='/home/rink/fovs_l1b_test.h5', L1B_or_L2='L1B', split=0.2):
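    """Build FOV-level (single-pixel) training and test sets: from each report's interior
    20 x 20 window, randomly sample up to 20 cloudy, cold (cld_temp_acha < 273 K) FOVs per
    icing report and up to 10 per no-icing report, sort everything by time, and write the
    first (1 - split) fraction to trnfile and the remainder to tstfile via write_file."""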
    ice_times = []
    icing_int_s = []
    ice_lons = []
    ice_lats = []

    no_ice_times = []
    no_ice_lons = []
    no_ice_lats = []

    h5_s_icing = []
    h5_s_no_icing = []

    if L1B_or_L2 == 'L1B':
        params = l1b_ds_list
        param_types = l1b_ds_types
    elif L1B_or_L2 == 'L2':
        params = ds_list
        param_types = ds_types

    icing_data_dct = {ds: [] for ds in params}
    no_icing_data_dct = {ds: [] for ds in params}

    sub_indexes = np.arange(400)

    num_ice = 0
    for fidx in range(len(icing_files)):
        fname = icing_files[fidx]
        f = h5py.File(fname, 'r')
        h5_s_icing.append(f)

        times = f['time'][:]
        num_obs = len(times)

        icing_int = f['icing_intensity'][:]
        lons = f['longitude'][:]
        lats = f['latitude'][:]

        cld_mask = f['cloud_mask'][:, 10:30, 10:30]
        cld_mask = cld_mask.reshape((num_obs, -1))

        cld_top_temp = f['cld_temp_acha'][:, 10:30, 10:30]
        cld_top_temp = cld_top_temp.reshape((num_obs, -1))

        for i in range(num_obs):
            keep_0 = np.logical_or(cld_mask[i,] == 2, cld_mask[i,] == 3)  # cloudy
            keep_1 = np.invert(np.isnan(cld_top_temp[i,]))
            keep = keep_0 & keep_1
            keep = np.where(keep, cld_top_temp[i,] < 273.0, False)
            k_idxs = sub_indexes[keep]
            np.random.shuffle(k_idxs)
            k_idxs = k_idxs[0:20]  # keep at most 20; slicing past the end is safe
            num_ice += len(k_idxs)

            for ds_name in params:
                dat = f[ds_name][i, 10:30, 10:30].flatten()
                icing_data_dct[ds_name].append(dat[k_idxs])

            icing_int_s.append(np.full(len(k_idxs), icing_int[i]))
            ice_times.append(np.full(len(k_idxs), times[i]))
            ice_lons.append(np.full(len(k_idxs), lons[i]))
            ice_lats.append(np.full(len(k_idxs), lats[i]))

        print(fname)

    for ds_name in params:
        lst = icing_data_dct[ds_name]
        icing_data_dct[ds_name] = np.concatenate(lst)

    icing_int_s = np.concatenate(icing_int_s)
    ice_times = np.concatenate(ice_times)
    ice_lons = np.concatenate(ice_lons)
    ice_lats = np.concatenate(ice_lats)

    num_no_ice = 0
    for fidx in range(len(no_icing_files)):
        fname = no_icing_files[fidx]
        f = h5py.File(fname, 'r')
        h5_s_no_icing.append(f)

        times = f['time']
        num_obs = len(times)
        lons = f['longitude']
        lats = f['latitude']

        cld_mask = f['cloud_mask'][:, 10:30, 10:30]
        cld_mask = cld_mask.reshape((num_obs, -1))

        cld_top_temp = f['cld_temp_acha'][:, 10:30, 10:30]
        cld_top_temp = cld_top_temp.reshape((num_obs, -1))

        for i in range(num_obs):
            keep_0 = np.logical_or(cld_mask[i,] == 2, cld_mask[i,] == 3)  # cloudy
            keep_1 = np.invert(np.isnan(cld_top_temp[i,]))
            keep = keep_0 & keep_1
            keep = np.where(keep, cld_top_temp[i,] < 273.0, False)
            k_idxs = sub_indexes[keep]
            np.random.shuffle(k_idxs)
            k_idxs = k_idxs[0:10]  # keep at most 10 per no-icing report
            num_no_ice += len(k_idxs)
            no_ice_times.append(np.full(len(k_idxs), times[i]))
            no_ice_lons.append(np.full(len(k_idxs), lons[i]))
            no_ice_lats.append(np.full(len(k_idxs), lats[i]))

            for ds_name in params:
                dat = f[ds_name][i, 10:30, 10:30].flatten()
                no_icing_data_dct[ds_name].append(dat[k_idxs])

        print(fname)

    for ds_name in params:
        lst = no_icing_data_dct[ds_name]
        no_icing_data_dct[ds_name] = np.concatenate(lst)
    no_icing_int_s = np.full(num_no_ice, -1)
    no_ice_times = np.concatenate(no_ice_times)
    no_ice_lons = np.concatenate(no_ice_lons)
    no_ice_lats = np.concatenate(no_ice_lats)

    icing_intensity = np.concatenate([icing_int_s, no_icing_int_s])
    icing_times = np.concatenate([ice_times, no_ice_times])
    icing_lons = np.concatenate([ice_lons, no_ice_lons])
    icing_lats = np.concatenate([ice_lats, no_ice_lats])

    data_dct = {}
    for ds_name in params:
        data_dct[ds_name] = np.concatenate([icing_data_dct[ds_name], no_icing_data_dct[ds_name]])

    # apply shuffle indexes
    # ds_indexes = np.arange(num_ice + num_no_ice)
    # np.random.shuffle(ds_indexes)
    #
    # for ds_name in train_params:
    #     data_dct[ds_name] = data_dct[ds_name][ds_indexes]
    # icing_intensity = icing_intensity[ds_indexes]
    # icing_times = icing_times[ds_indexes]
    # icing_lons = icing_lons[ds_indexes]
    # icing_lats = icing_lats[ds_indexes]

    # do sort
    ds_indexes = np.argsort(icing_times)
    for ds_name in params:
        data_dct[ds_name] = data_dct[ds_name][ds_indexes]
    icing_intensity = icing_intensity[ds_indexes]
    icing_times = icing_times[ds_indexes]
    icing_lons = icing_lons[ds_indexes]
    icing_lats = icing_lats[ds_indexes]

    # chronological split: first (1 - split) fraction to train, remainder to test
    # trn_idxs, tst_idxs = split_data(icing_intensity.shape[0], shuffle=False, perc=split)
    all_idxs = np.arange(icing_intensity.shape[0])
    splt_idx = int(icing_intensity.shape[0] * (1-split))
    trn_idxs = all_idxs[0:splt_idx]
    tst_idxs = all_idxs[splt_idx:]

    trn_data_dct = {}
    for ds_name in params:
        trn_data_dct[ds_name] = data_dct[ds_name][trn_idxs,]
    trn_icing_intensity = icing_intensity[trn_idxs,]
    trn_icing_times = icing_times[trn_idxs,]
    trn_icing_lons = icing_lons[trn_idxs,]
    trn_icing_lats = icing_lats[trn_idxs,]

    write_file(trnfile, params, param_types, trn_data_dct, trn_icing_intensity, trn_icing_times, trn_icing_lons, trn_icing_lats)

    tst_data_dct = {}
    for ds_name in params:
        tst_data_dct[ds_name] = data_dct[ds_name][tst_idxs,]
    tst_icing_intensity = icing_intensity[tst_idxs,]
    tst_icing_times = icing_times[tst_idxs,]
    tst_icing_lons = icing_lons[tst_idxs,]
    tst_icing_lats = icing_lats[tst_idxs,]

    # do sort
    ds_indexes = np.argsort(tst_icing_times)
    for ds_name in params:
        tst_data_dct[ds_name] = tst_data_dct[ds_name][ds_indexes]
    tst_icing_intensity = tst_icing_intensity[ds_indexes]
    tst_icing_times = tst_icing_times[ds_indexes]
    tst_icing_lons = tst_icing_lons[ds_indexes]
    tst_icing_lats = tst_icing_lats[ds_indexes]

    write_file(tstfile, params, param_types, tst_data_dct, tst_icing_intensity, tst_icing_times, tst_icing_lons, tst_icing_lats)

    # --- close files
    for h5f in h5_s_icing:
        h5f.close()

    for h5f in h5_s_no_icing:
        h5f.close()


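# Extract fixed-size tiles (16x16 by default; see n_a/n_b, m_a/m_b below) around
# each report and write train/validation/test HDF5 files. A minimal sketch of a
# call, assuming the same hypothetical file lists as above:
#
# tile_extract(ice_files, no_ice_files, L1B_or_L2='L2', augment=True)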
def tile_extract(icing_files, no_icing_files, trnfile='/home/rink/tiles_train.h5', vldfile='/home/rink/tiles_valid.h5', tstfile='/home/rink/tiles_test.h5', L1B_or_L2='L1B',
                 cld_mask_name='cloud_mask', augment=False, do_split=True, has_test=True):
    # 16x16
    # n_a, n_b = 12, 28
    # m_a, m_b = 12, 28
    n_a, n_b = 0, 16
    m_a, m_b = 0, 16
    # 10x10
    # n_a, n_b = 15, 25
    # m_a, m_b = 15, 25

    icing_int_s = []
    ice_time_s = []
    no_ice_time_s = []
    ice_lon_s = []
    no_ice_lon_s = []
    ice_lat_s = []
    no_ice_lat_s = []
    ice_flt_alt_s = []
    no_ice_flt_alt_s = []

    h5_s_icing = []
    h5_s_no_icing = []

    if L1B_or_L2 == 'L1B':
        params = l1b_ds_list
        param_types = l1b_ds_types
    elif L1B_or_L2 == 'L2':
        params = ds_list
        param_types = ds_types
    else:
        raise ValueError("L1B_or_L2 must be 'L1B' or 'L2'")
    # copy before extending so repeated calls don't grow the module-level lists
    params = params + ['FOV_mask']
    param_types = param_types + ['i1']

    icing_data_dct = {ds: [] for ds in params}
    no_icing_data_dct = {ds: [] for ds in params}

    for fidx in range(len(icing_files)):
        fname = icing_files[fidx]
        f = h5py.File(fname, 'r')
        h5_s_icing.append(f)

        times = f['time'][:]
        num_obs = len(times)
        lons = f['longitude'][:]
        lats = f['latitude'][:]

        icing_int = f['icing_intensity'][:]
        flt_altitude = f['icing_altitude'][:]

        for i in range(num_obs):
            cld_msk = f[cld_mask_name][i, n_a:n_b, m_a:m_b]
            for ds_name in params:
                dat = f[ds_name][i, n_a:n_b, m_a:m_b]
                if L1B_or_L2 == 'L2' and np.issubdtype(dat.dtype, np.floating):
                    keep = np.logical_or(cld_msk == 2, cld_msk == 3)  # cloudy
                    dat = np.where(keep, dat, np.nan)  # NaN-out non-cloudy FOVs (float fields only)
                icing_data_dct[ds_name].append(dat)

            icing_int_s.append(icing_int[i])
            ice_time_s.append(times[i])
            ice_lon_s.append(lons[i])
            ice_lat_s.append(lats[i])
            ice_flt_alt_s.append(flt_altitude[i])

        print(fname)

    for ds_name in params:
        lst = icing_data_dct[ds_name]
        icing_data_dct[ds_name] = np.stack(lst, axis=0)
    icing_int_s = np.array(icing_int_s)
    ice_time_s = np.array(ice_time_s)
    ice_lon_s = np.array(ice_lon_s)
    ice_lat_s = np.array(ice_lat_s)
    ice_flt_alt_s = np.array(ice_flt_alt_s)

    # No icing  ------------------------------------------------------------
    num_no_ice = 0
    for fidx in range(len(no_icing_files)):
        fname = no_icing_files[fidx]
        f = h5py.File(fname, 'r')
        h5_s_no_icing.append(f)

        times = f['time'][:]
        num_obs = len(times)
        lons = f['longitude'][:]
        lats = f['latitude'][:]
        flt_altitude = f['icing_altitude'][:]

        for i in range(num_obs):
            cld_msk = f[cld_mask_name][i, n_a:n_b, m_a:m_b]
            for ds_name in params:
                dat = f[ds_name][i, n_a:n_b, m_a:m_b]
                if L1B_or_L2 == 'L2' and np.issubdtype(dat.dtype, np.floating):
                    keep = np.logical_or(cld_msk == 2, cld_msk == 3)  # cloudy
                    dat = np.where(keep, dat, np.nan)  # NaN-out non-cloudy FOVs (float fields only)
                no_icing_data_dct[ds_name].append(dat)
            num_no_ice += 1
            no_ice_time_s.append(times[i])
            no_ice_lon_s.append(lons[i])
            no_ice_lat_s.append(lats[i])
            no_ice_flt_alt_s.append(flt_altitude[i])

        print(fname)

    for ds_name in params:
        lst = no_icing_data_dct[ds_name]
        no_icing_data_dct[ds_name] = np.stack(lst, axis=0)
    no_icing_int_s = np.full(num_no_ice, -1)
    no_ice_time_s = np.array(no_ice_time_s)
    no_ice_lon_s = np.array(no_ice_lon_s)
    no_ice_lat_s = np.array(no_ice_lat_s)
    no_ice_flt_alt_s = np.array(no_ice_flt_alt_s)

    icing_intensity = np.concatenate([icing_int_s, no_icing_int_s])
    icing_times = np.concatenate([ice_time_s, no_ice_time_s])
    icing_lons = np.concatenate([ice_lon_s, no_ice_lon_s])
    icing_lats = np.concatenate([ice_lat_s, no_ice_lat_s])
    icing_alt = np.concatenate([ice_flt_alt_s, no_ice_flt_alt_s])

    data_dct = {}
    for ds_name in params:
        data_dct[ds_name] = np.concatenate([icing_data_dct[ds_name], no_icing_data_dct[ds_name]])

    # do sort -------------------------------------
    ds_indexes = np.argsort(icing_times)
    for ds_name in params:
        data_dct[ds_name] = data_dct[ds_name][ds_indexes]
    icing_intensity = icing_intensity[ds_indexes]
    icing_times = icing_times[ds_indexes]
    icing_lons = icing_lons[ds_indexes]
    icing_lats = icing_lats[ds_indexes]
    icing_alt = icing_alt[ds_indexes]

    if do_split:
        # when has_test is False, split_data returns an empty test index list
        # and leaves the test time ranges in training
        trn_idxs, vld_idxs, tst_idxs = split_data(icing_times, has_test)
    else:
        trn_idxs = np.arange(icing_intensity.shape[0])
        tst_idxs = None
    # ---------------------------------------------

    trn_data_dct = {}
    for ds_name in params:
        trn_data_dct[ds_name] = data_dct[ds_name][trn_idxs,]
    trn_icing_intensity = icing_intensity[trn_idxs,]
    trn_icing_times = icing_times[trn_idxs,]
    trn_icing_lons = icing_lons[trn_idxs,]
    trn_icing_lats = icing_lats[trn_idxs,]
    trn_icing_alt = icing_alt[trn_idxs,]

    #  Data augmentation -------------------------------------------------------------
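    # For the rarer higher-intensity classes (3-6), three transformed copies of
    # each tile are added: a left-right flip, an up-down flip, and one 90-degree
    # rotation; the label and metadata are replicated once per copy.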
    if augment:
        trn_data_dct_aug = {ds_name: [] for ds_name in params}
        trn_icing_intensity_aug = []
        trn_icing_times_aug = []
        trn_icing_lons_aug = []
        trn_icing_lats_aug = []
        trn_icing_alt_aug = []

        for k in range(trn_icing_intensity.shape[0]):
            iceint = trn_icing_intensity[k]
            icetime = trn_icing_times[k]
            icelon = trn_icing_lons[k]
            icelat = trn_icing_lats[k]
            icealt = trn_icing_alt[k]
            if iceint == 3 or iceint == 4 or iceint == 5 or iceint == 6:
                for ds_name in params:
                    dat = trn_data_dct[ds_name]
                    trn_data_dct_aug[ds_name].append(np.fliplr(dat[k,]))
                    trn_data_dct_aug[ds_name].append(np.flipud(dat[k,]))
                    trn_data_dct_aug[ds_name].append(np.rot90(dat[k,]))

                trn_icing_intensity_aug.append(iceint)
                trn_icing_intensity_aug.append(iceint)
                trn_icing_intensity_aug.append(iceint)

                trn_icing_times_aug.append(icetime)
                trn_icing_times_aug.append(icetime)
                trn_icing_times_aug.append(icetime)

                trn_icing_lons_aug.append(icelon)
                trn_icing_lons_aug.append(icelon)
                trn_icing_lons_aug.append(icelon)

                trn_icing_lats_aug.append(icelat)
                trn_icing_lats_aug.append(icelat)
                trn_icing_lats_aug.append(icelat)

                trn_icing_alt_aug.append(icealt)
                trn_icing_alt_aug.append(icealt)
                trn_icing_alt_aug.append(icealt)

        for ds_name in params:
            trn_data_dct_aug[ds_name] = np.stack(trn_data_dct_aug[ds_name])
        trn_icing_intensity_aug = np.stack(trn_icing_intensity_aug)
        trn_icing_times_aug = np.stack(trn_icing_times_aug)
        trn_icing_lons_aug = np.stack(trn_icing_lons_aug)
        trn_icing_lats_aug = np.stack(trn_icing_lats_aug)
        trn_icing_alt_aug = np.stack(trn_icing_alt_aug)

        for ds_name in params:
            trn_data_dct[ds_name] = np.concatenate([trn_data_dct[ds_name], trn_data_dct_aug[ds_name]])
        trn_icing_intensity = np.concatenate([trn_icing_intensity, trn_icing_intensity_aug])
        trn_icing_times = np.concatenate([trn_icing_times, trn_icing_times_aug])
        trn_icing_lons = np.concatenate([trn_icing_lons, trn_icing_lons_aug])
        trn_icing_lats = np.concatenate([trn_icing_lats, trn_icing_lats_aug])
        trn_icing_alt = np.concatenate([trn_icing_alt, trn_icing_alt_aug])

    # do sort
    ds_indexes = np.argsort(trn_icing_times)
    for ds_name in params:
        trn_data_dct[ds_name] = trn_data_dct[ds_name][ds_indexes]
    trn_icing_intensity = trn_icing_intensity[ds_indexes]
    trn_icing_times = trn_icing_times[ds_indexes]
    trn_icing_lons = trn_icing_lons[ds_indexes]
    trn_icing_lats = trn_icing_lats[ds_indexes]
    trn_icing_alt = trn_icing_alt[ds_indexes]

    write_file(trnfile, params, param_types, trn_data_dct, trn_icing_intensity, trn_icing_times, trn_icing_lons, trn_icing_lats, trn_icing_alt)

    if do_split:
        if has_test:
            tst_data_dct = {}
            for ds_name in params:
                tst_data_dct[ds_name] = data_dct[ds_name][tst_idxs,]
            tst_icing_intensity = icing_intensity[tst_idxs,]
            tst_icing_times = icing_times[tst_idxs,]
            tst_icing_lons = icing_lons[tst_idxs,]
            tst_icing_lats = icing_lats[tst_idxs,]
            tst_icing_alt = icing_alt[tst_idxs,]

            # do sort
            ds_indexes = np.argsort(tst_icing_times)
            for ds_name in params:
                tst_data_dct[ds_name] = tst_data_dct[ds_name][ds_indexes]
            tst_icing_intensity = tst_icing_intensity[ds_indexes]
            tst_icing_times = tst_icing_times[ds_indexes]
            tst_icing_lons = tst_icing_lons[ds_indexes]
            tst_icing_lats = tst_icing_lats[ds_indexes]
            tst_icing_alt = tst_icing_alt[ds_indexes]

            write_file(tstfile, params, param_types, tst_data_dct, tst_icing_intensity, tst_icing_times, tst_icing_lons, tst_icing_lats, tst_icing_alt)

        vld_data_dct = {}
        for ds_name in params:
            vld_data_dct[ds_name] = data_dct[ds_name][vld_idxs,]
        vld_icing_intensity = icing_intensity[vld_idxs,]
        vld_icing_times = icing_times[vld_idxs,]
        vld_icing_lons = icing_lons[vld_idxs,]
        vld_icing_lats = icing_lats[vld_idxs,]
        vld_icing_alt = icing_alt[vld_idxs,]

        # do sort
        ds_indexes = np.argsort(vld_icing_times)
        for ds_name in params:
            vld_data_dct[ds_name] = vld_data_dct[ds_name][ds_indexes]
        vld_icing_intensity = vld_icing_intensity[ds_indexes]
        vld_icing_times = vld_icing_times[ds_indexes]
        vld_icing_lons = vld_icing_lons[ds_indexes]
        vld_icing_lats = vld_icing_lats[ds_indexes]
        vld_icing_alt = vld_icing_alt[ds_indexes]

        write_file(vldfile, params, param_types, vld_data_dct, vld_icing_intensity, vld_icing_times, vld_icing_lons, vld_icing_lats, vld_icing_alt)

    # --- close files
    for h5f in h5_s_icing:
        h5f.close()

    for h5f in h5_s_no_icing:
        h5f.close()


def write_file(outfile, params, param_types, data_dct, icing_intensity, icing_times, icing_lons, icing_lats, icing_alt=None):
    h5f_expl = h5py.File(a_clvr_file, 'r')
    h5f_out = h5py.File(outfile, 'w')

    for idx, ds_name in enumerate(params):
        dt = param_types[idx]
        data = data_dct[ds_name]
        h5f_out.create_dataset(ds_name, data=data, dtype=dt)

    icing_int_ds = h5f_out.create_dataset('icing_intensity', data=icing_intensity, dtype='i4')
    icing_int_ds.attrs.create('long_name', data='From PIREP. -1:No Icing, 1:Trace, 2:Light, 3:Light Moderate, 4:Moderate, 5:Moderate Severe, 6:Severe')

    time_ds = h5f_out.create_dataset('time', data=icing_times, dtype='f4')
    time_ds.attrs.create('units', data='seconds since 1970-1-1 00:00:00')
    time_ds.attrs.create('long_name', data='PIREP time')

    lon_ds = h5f_out.create_dataset('longitude', data=icing_lons, dtype='f4')
    lon_ds.attrs.create('units', data='degrees_east')
    lon_ds.attrs.create('long_name', data='PIREP longitude')

    lat_ds = h5f_out.create_dataset('latitude', data=icing_lats, dtype='f4')
    lat_ds.attrs.create('units', data='degrees_north')
    lat_ds.attrs.create('long_name', data='PIREP latitude')

    if icing_alt is not None:
        alt_ds = h5f_out.create_dataset('flight_altitude', data=icing_alt, dtype='f4')
        alt_ds.attrs.create('units', data='meter')
        alt_ds.attrs.create('long_name', data='PIREP altitude')

    # copy relevant attributes
    for ds_name in params:
        h5f_ds = h5f_out[ds_name]
        try:
            h5f_ds.attrs.create('standard_name', data=h5f_expl[ds_name].attrs.get('standard_name'))
            h5f_ds.attrs.create('long_name', data=h5f_expl[ds_name].attrs.get('long_name'))
            h5f_ds.attrs.create('units', data=h5f_expl[ds_name].attrs.get('units'))
            attr = h5f_expl[ds_name].attrs.get('actual_range')
            if attr is not None:
                h5f_ds.attrs.create('actual_range', data=attr)
            attr = h5f_expl[ds_name].attrs.get('flag_values')
            if attr is not None:
                h5f_ds.attrs.create('flag_values', data=attr)
        except Exception:
            pass

    # --- close files
    h5f_out.close()
    h5f_expl.close()
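# Example of reading a written file back (a minimal sketch; the path is the
# default train file name used above, and temp_11_0um_nom assumes the L1B case):
#
# with h5py.File('/home/rink/tiles_train.h5', 'r') as h5f:
#     intensity = h5f['icing_intensity'][:]
#     bt_11um = h5f['temp_11_0um_nom'][:]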


def run_mean_std(check_cloudy=False):
    ds_list = ['cld_height_acha', 'cld_geo_thick', 'cld_press_acha',
               'supercooled_cloud_fraction', 'cld_temp_acha', 'cld_opd_acha',
               'cld_reff_acha', 'cld_reff_dcomp', 'cld_reff_dcomp_1', 'cld_reff_dcomp_2', 'cld_reff_dcomp_3',
               'cld_opd_dcomp', 'cld_opd_dcomp_1', 'cld_opd_dcomp_2', 'cld_opd_dcomp_3', 'cld_cwp_dcomp', 'iwc_dcomp',
               'lwc_dcomp', 'cld_emiss_acha', 'conv_cloud_fraction']

    # ds_list = ['temp_10_4um_nom', 'temp_11_0um_nom', 'temp_12_0um_nom', 'temp_13_3um_nom', 'temp_3_75um_nom',
    #            'temp_6_2um_nom', 'temp_6_7um_nom', 'temp_7_3um_nom', 'temp_8_5um_nom', 'temp_9_7um_nom',
    #            'refl_0_47um_nom', 'refl_0_65um_nom', 'refl_0_86um_nom', 'refl_1_38um_nom', 'refl_1_60um_nom']

    mean_std_dct = {}

    ice_flist = [f for f in glob.glob('/data/Personal/rink/icing/icing_2*.h5')]
    no_ice_flist = [f for f in glob.glob('/data/Personal/rink/icing/no_icing_2*.h5')]

    ice_h5f_lst = [h5py.File(f, 'r') for f in ice_flist]
    no_ice_h5f_lst = [h5py.File(f, 'r') for f in no_ice_flist]

    if check_cloudy:
        cld_msk_i = []
        cld_msk_ni = []
        for idx, ice_h5f in enumerate(ice_h5f_lst):
            no_ice_h5f = no_ice_h5f_lst[idx]
            cld_msk_i.append(ice_h5f['cloud_mask'][:,].flatten())
            cld_msk_ni.append(no_ice_h5f['cloud_mask'][:,].flatten())
        cld_msk_i = np.concatenate(cld_msk_i)
        cld_msk_ni = np.concatenate(cld_msk_ni)

    for dname in ds_list:
        data_i = []
        data_ni = []
        for idx, ice_h5f in enumerate(ice_h5f_lst):
            no_ice_h5f = no_ice_h5f_lst[idx]
            data_i.append(ice_h5f[dname][:,].flatten())
            data_ni.append(no_ice_h5f[dname][:,].flatten())

        data_i = np.concatenate(data_i)
        if check_cloudy:
            keep = np.logical_or(cld_msk_i == 2, cld_msk_i == 3)
            data_i = data_i[keep]
        mean_i = np.nanmean(data_i)
        lo_i = np.nanmin(data_i)
        hi_i = np.nanmax(data_i)
        data_i -= mean_i
        std_i = np.nanstd(data_i)
        cnt_i = np.sum(np.invert(np.isnan(data_i)))

        data_ni = np.concatenate(data_ni)
        if check_cloudy:
            keep = np.logical_or(cld_msk_ni == 2, cld_msk_ni == 3)
            data_ni = data_ni[keep]
        mean_ni = np.nanmean(data_ni)
        lo_ni = np.nanmin(data_ni)
        hi_ni = np.nanmax(data_ni)
        data_ni -= mean_ni
        std_ni = np.nanstd(data_ni)
        cnt_ni = np.sum(np.invert(np.isnan(data_ni)))

        no_icing_to_icing_ratio = cnt_ni/cnt_i

        # ratio-weighted blend of the two classes; an approximation, not the
        # exact pooled statistics (lo/hi are blends, not the global min/max)
        mean = (mean_i + no_icing_to_icing_ratio*mean_ni)/(no_icing_to_icing_ratio + 1)
        std = (std_i + no_icing_to_icing_ratio*std_ni)/(no_icing_to_icing_ratio + 1)
        lo = (lo_i + no_icing_to_icing_ratio*lo_ni)/(no_icing_to_icing_ratio + 1)
        hi = (hi_i + no_icing_to_icing_ratio*hi_ni)/(no_icing_to_icing_ratio + 1)

        print(dname,': (', mean, mean_i, mean_ni, ') (', std, std_i, std_ni, ') ratio: ', no_icing_to_icing_ratio)
        print(dname,': (', lo, lo_i, lo_ni, ') (', hi, hi_i, hi_ni, ') ratio: ', no_icing_to_icing_ratio)

        mean_std_dct[dname] = (mean, std, lo, hi)

    for h5f in ice_h5f_lst:
        h5f.close()
    for h5f in no_ice_h5f_lst:
        h5f.close()

    f = open('/home/rink/data/icing_ml/mean_std_lo_hi.pkl', 'wb')
    pickle.dump(mean_std_dct, f)
    f.close()

    return mean_std_dct
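# The pickled dict maps parameter name -> (mean, std, lo, hi). A sketch of
# reloading it (path as written above):
#
# with open('/home/rink/data/icing_ml/mean_std_lo_hi.pkl', 'rb') as f:
#     mean_std_dct = pickle.load(f)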


def run_mean_std_2(check_cloudy=False, params=None):
    # default daytime L2 parameter set; pass params to override
    if params is None:
        params = ['cld_height_acha', 'cld_geo_thick', 'supercooled_cloud_fraction', 'cld_temp_acha', 'cld_press_acha',
                  'cld_reff_dcomp', 'cld_opd_dcomp', 'cld_cwp_dcomp', 'iwc_dcomp', 'lwc_dcomp']

    mean_std_dct = {}

    flist = [f for f in glob.glob('/Users/tomrink/data/icing/fov*.h5')]

    h5f_lst = [h5py.File(f, 'r') for f in flist]

    if check_cloudy:
        cld_msk = []
        for idx, h5f in enumerate(h5f_lst):
            cld_msk.append(h5f['cloud_mask'][:,].flatten())
        cld_msk = np.concatenate(cld_msk)

    for dname in params:
        data = []
        for idx, h5f in enumerate(h5f_lst):
            data.append(h5f[dname][:,].flatten())

        data = np.concatenate(data)

        if check_cloudy:
            keep = np.logical_or(cld_msk == 2, cld_msk == 3)
            data = data[keep]

        mean = np.nanmean(data)
        data -= mean
        std = np.nanstd(data)

        print(dname,': ', mean, std)

        mean_std_dct[dname] = (mean, std)

    for h5f in h5f_lst:
        h5f.close()

    f = open('/Users/tomrink/data/icing/fovs_mean_std_day.pkl', 'wb')
    pickle.dump(mean_std_dct, f)
    f.close()


def run_mean_std_3(train_file_path, check_cloudy=False, params=None):

    # params = ['cld_height_acha', 'cld_geo_thick', 'cld_press_acha', 'supercooled_cloud_fraction', 'cld_temp_acha', 'cld_opd_acha',
    #            'cld_reff_acha', 'cld_reff_dcomp', 'cld_reff_dcomp_1', 'cld_reff_dcomp_2', 'cld_reff_dcomp_3',
    #            'cld_opd_dcomp', 'cld_opd_dcomp_1', 'cld_opd_dcomp_2', 'cld_opd_dcomp_3', 'cld_cwp_dcomp', 'iwc_dcomp',
    #            'lwc_dcomp', 'cld_emiss_acha', 'conv_cloud_fraction']
    # check_cloudy = True

    # default L1B channel set; pass params to override
    if params is None:
        params = ['temp_10_4um_nom', 'temp_11_0um_nom', 'temp_12_0um_nom', 'temp_13_3um_nom', 'temp_3_75um_nom',
                  'temp_6_2um_nom', 'temp_6_7um_nom', 'temp_7_3um_nom', 'temp_8_5um_nom', 'temp_9_7um_nom',
                  'refl_0_47um_nom', 'refl_0_65um_nom', 'refl_0_86um_nom', 'refl_1_38um_nom', 'refl_1_60um_nom']

    mean_std_lo_hi_dct = {}

    h5f = h5py.File(train_file_path, 'r')

    if check_cloudy:
        cld_msk = h5f['cloud_mask'][:].flatten()

    for dname in params:
        data = h5f[dname][:,].flatten()

        if check_cloudy:
            keep = np.logical_or(cld_msk == 2, cld_msk == 3)
            data = data[keep]

        lo = np.nanmin(data)
        hi = np.nanmax(data)
        mean = np.nanmean(data)
        data -= mean
        std = np.nanstd(data)

        print(dname,': ', mean, std, lo, hi)

        mean_std_lo_hi_dct[dname] = (mean, std, lo, hi)

    h5f.close()

    f = open('/Users/tomrink/data/icing/mean_std_lo_hi_test.pkl', 'wb')
    pickle.dump(mean_std_lo_hi_dct, f)
    f.close()


def split_data(times, has_test):
    time_idxs = np.arange(times.shape[0])

    # time_ranges = [[get_timestamp('2018-01-01_00:00'), get_timestamp('2018-01-07_23:59')],
    #                [get_timestamp('2018-03-01_00:00'), get_timestamp('2018-03-07_23:59')],
    #                [get_timestamp('2018-05-01_00:00'), get_timestamp('2018-05-07_23:59')],
    #                [get_timestamp('2018-07-01_00:00'), get_timestamp('2018-07-07_23:59')],
    #                [get_timestamp('2018-09-01_00:00'), get_timestamp('2018-09-07_23:59')],
    #                [get_timestamp('2018-11-01_00:00'), get_timestamp('2018-11-07_23:59')],
    #                [get_timestamp('2019-01-01_00:00'), get_timestamp('2019-01-07_23:59')],
    #                [get_timestamp('2019-03-01_00:00'), get_timestamp('2019-03-07_23:59')],
    #                [get_timestamp('2019-05-01_00:00'), get_timestamp('2019-05-07_23:59')],
    #                [get_timestamp('2019-07-01_00:00'), get_timestamp('2019-07-07_23:59')],
    #                [get_timestamp('2019-09-01_00:00'), get_timestamp('2019-09-07_23:59')],
    #                [get_timestamp('2019-11-01_00:00'), get_timestamp('2019-11-07_23:59')],
    #                [get_timestamp('2021-09-24_00:00'), get_timestamp('2021-10-01_23:59')],
    #                [get_timestamp('2021-11-01_00:00'), get_timestamp('2021-11-07_23:59')],
    #                [get_timestamp('2022-01-01_00:00'), get_timestamp('2022-01-07_23:59')],
    #                [get_timestamp('2022-03-01_00:00'), get_timestamp('2022-03-07_23:59')],
    #                [get_timestamp('2022-04-01_00:00'), get_timestamp('2022-04-04_23:59')]]

    time_ranges = [[get_timestamp('2018-01-01_00:00'), get_timestamp('2018-01-04_23:59')],
                   [get_timestamp('2018-02-01_00:00'), get_timestamp('2018-02-04_23:59')],
                   [get_timestamp('2018-03-01_00:00'), get_timestamp('2018-03-04_23:59')],
                   [get_timestamp('2018-04-01_00:00'), get_timestamp('2018-04-04_23:59')],
                   [get_timestamp('2018-05-01_00:00'), get_timestamp('2018-05-04_23:59')],
                   [get_timestamp('2018-06-01_00:00'), get_timestamp('2018-06-04_23:59')],
                   [get_timestamp('2018-07-01_00:00'), get_timestamp('2018-07-04_23:59')],
                   [get_timestamp('2018-08-01_00:00'), get_timestamp('2018-08-04_23:59')],
                   [get_timestamp('2018-09-01_00:00'), get_timestamp('2018-09-04_23:59')],
                   [get_timestamp('2018-10-01_00:00'), get_timestamp('2018-10-04_23:59')],
                   [get_timestamp('2018-11-01_00:00'), get_timestamp('2018-11-04_23:59')],
                   [get_timestamp('2018-12-01_00:00'), get_timestamp('2018-12-04_23:59')],
                   [get_timestamp('2019-01-01_00:00'), get_timestamp('2019-01-04_23:59')],
                   [get_timestamp('2019-02-01_00:00'), get_timestamp('2019-02-04_23:59')],
                   [get_timestamp('2019-03-01_00:00'), get_timestamp('2019-03-04_23:59')],
                   [get_timestamp('2019-04-01_00:00'), get_timestamp('2019-04-04_23:59')],
                   [get_timestamp('2019-05-01_00:00'), get_timestamp('2019-05-04_23:59')],
                   [get_timestamp('2019-06-01_00:00'), get_timestamp('2019-06-04_23:59')],
                   [get_timestamp('2019-07-01_00:00'), get_timestamp('2019-07-04_23:59')],
                   [get_timestamp('2019-08-01_00:00'), get_timestamp('2019-08-04_23:59')],
                   [get_timestamp('2019-09-01_00:00'), get_timestamp('2019-09-04_23:59')],
                   [get_timestamp('2019-10-01_00:00'), get_timestamp('2019-10-04_23:59')],
                   [get_timestamp('2019-11-01_00:00'), get_timestamp('2019-11-04_23:59')],
                   [get_timestamp('2019-12-01_00:00'), get_timestamp('2019-12-04_23:59')],
                   [get_timestamp('2021-12-04_00:00'), get_timestamp('2021-12-08_23:59')],
                   [get_timestamp('2022-01-01_00:00'), get_timestamp('2022-01-04_23:59')],
                   [get_timestamp('2022-02-01_00:00'), get_timestamp('2022-02-04_23:59')],
                   [get_timestamp('2022-03-01_00:00'), get_timestamp('2022-03-04_23:59')],
                   [get_timestamp('2022-04-01_00:00'), get_timestamp('2022-04-04_23:59')],
                   [get_timestamp('2022-05-01_00:00'), get_timestamp('2022-05-04_23:59')],
                   [get_timestamp('2022-10-03_00:00'), get_timestamp('2022-10-07_23:59')],
                   [get_timestamp('2022-11-01_00:00'), get_timestamp('2022-11-04_23:59')],
                   [get_timestamp('2022-12-01_00:00'), get_timestamp('2022-12-04_23:59')],
                   [get_timestamp('2023-12-01_00:00'), get_timestamp('2023-12-04_23:59')]]

    keep_out = 10800  # +/- 3 hr buffer applied around each held-out time range

    vld_time_idxs = []
    for t_rng in time_ranges:
        t_rng[0] -= keep_out
        t_rng[1] += keep_out
        tidxs = np.searchsorted(times, t_rng)
        vld_time_idxs.append(np.arange(tidxs[0], tidxs[1], 1))
    vld_time_idxs = np.concatenate(vld_time_idxs, axis=None)

    # time_ranges = [[get_timestamp('2018-02-01_00:00'), get_timestamp('2018-02-04_23:59')],
    #                [get_timestamp('2018-04-01_00:00'), get_timestamp('2018-04-04_23:59')],
    #                [get_timestamp('2018-06-01_00:00'), get_timestamp('2018-06-04_23:59')],
    #                [get_timestamp('2018-08-01_00:00'), get_timestamp('2018-08-04_23:59')],
    #                [get_timestamp('2018-10-01_00:00'), get_timestamp('2018-10-04_23:59')],
    #                [get_timestamp('2018-12-01_00:00'), get_timestamp('2018-12-04_23:59')],
    #                [get_timestamp('2019-02-01_00:00'), get_timestamp('2019-02-04_23:59')],
    #                [get_timestamp('2019-04-01_00:00'), get_timestamp('2019-04-04_23:59')],
    #                [get_timestamp('2019-06-01_00:00'), get_timestamp('2019-06-04_23:59')],
    #                [get_timestamp('2019-08-01_00:00'), get_timestamp('2019-08-04_23:59')],
    #                [get_timestamp('2019-10-01_00:00'), get_timestamp('2019-10-04_23:59')],
    #                [get_timestamp('2019-12-01_00:00'), get_timestamp('2019-12-04_23:59')],
    #                [get_timestamp('2021-10-05_00:00'), get_timestamp('2021-10-10_23:59')],
    #                [get_timestamp('2021-12-01_00:00'), get_timestamp('2021-12-04_23:59')],
    #                [get_timestamp('2022-02-01_00:00'), get_timestamp('2022-02-04_23:59')],
    #                [get_timestamp('2022-03-26_00:00'), get_timestamp('2022-03-30_23:59')],
    #                [get_timestamp('2022-04-07_00:00'), get_timestamp('2022-04-10_23:59')],
    #                [get_timestamp('2022-05-01_00:00'), get_timestamp('2022-05-06_23:59')],
    #                [get_timestamp('2022-10-03_00:00'), get_timestamp('2022-10-08_23:59')],
    #                [get_timestamp('2022-10-17_00:00'), get_timestamp('2022-10-22_20:59')],
    #                [get_timestamp('2022-11-01_00:00'), get_timestamp('2022-11-05_23:59')],
    #                [get_timestamp('2022-11-10_00:00'), get_timestamp('2022-12-01_23:59')]]

    time_ranges = [[get_timestamp('2018-01-11_00:00'), get_timestamp('2018-01-12_23:59')],
                   [get_timestamp('2018-02-11_00:00'), get_timestamp('2018-02-13_23:59')],
                   [get_timestamp('2018-03-11_00:00'), get_timestamp('2018-03-13_23:59')],
                   [get_timestamp('2018-04-11_00:00'), get_timestamp('2018-04-13_23:59')],
                   [get_timestamp('2018-05-11_00:00'), get_timestamp('2018-05-13_23:59')],
                   [get_timestamp('2018-06-11_00:00'), get_timestamp('2018-06-13_23:59')],
                   [get_timestamp('2018-07-11_00:00'), get_timestamp('2018-07-13_23:59')],
                   [get_timestamp('2018-08-11_00:00'), get_timestamp('2018-08-13_23:59')],
                   [get_timestamp('2018-09-11_00:00'), get_timestamp('2018-09-13_23:59')],
                   [get_timestamp('2018-10-11_00:00'), get_timestamp('2018-10-13_23:59')],
                   [get_timestamp('2018-11-11_00:00'), get_timestamp('2018-11-13_23:59')],
                   [get_timestamp('2018-12-11_00:00'), get_timestamp('2018-12-13_23:59')],
                   [get_timestamp('2019-01-11_00:00'), get_timestamp('2019-01-13_23:59')],
                   [get_timestamp('2019-02-11_00:00'), get_timestamp('2019-02-13_23:59')],
                   [get_timestamp('2019-03-11_00:00'), get_timestamp('2019-03-13_23:59')],
                   [get_timestamp('2019-04-11_00:00'), get_timestamp('2019-04-13_23:59')],
                   [get_timestamp('2019-05-11_00:00'), get_timestamp('2019-05-13_23:59')],
                   [get_timestamp('2019-06-11_00:00'), get_timestamp('2019-06-13_23:59')],
                   [get_timestamp('2019-07-11_00:00'), get_timestamp('2019-07-13_23:59')],
                   [get_timestamp('2019-08-11_00:00'), get_timestamp('2019-08-13_23:59')],
                   [get_timestamp('2019-09-11_00:00'), get_timestamp('2019-09-13_23:59')],
                   [get_timestamp('2019-10-11_00:00'), get_timestamp('2019-10-13_23:59')],
                   [get_timestamp('2019-11-11_00:00'), get_timestamp('2019-11-13_23:59')],
                   [get_timestamp('2019-12-11_00:00'), get_timestamp('2019-12-13_23:59')],
                   [get_timestamp('2021-12-11_00:00'), get_timestamp('2021-12-13_23:59')],
                   [get_timestamp('2022-01-11_00:00'), get_timestamp('2022-01-13_23:59')],
                   [get_timestamp('2022-02-11_00:00'), get_timestamp('2022-02-13_23:59')],
                   [get_timestamp('2022-03-11_00:00'), get_timestamp('2022-03-13_23:59')],
                   [get_timestamp('2022-04-11_00:00'), get_timestamp('2022-04-13_23:59')],
                   [get_timestamp('2022-05-11_00:00'), get_timestamp('2022-05-13_23:59')],
                   [get_timestamp('2022-10-11_00:00'), get_timestamp('2022-10-13_23:59')],
                   [get_timestamp('2022-11-11_00:00'), get_timestamp('2022-11-13_23:59')],
                   [get_timestamp('2022-12-11_00:00'), get_timestamp('2022-12-13_23:59')],
                   [get_timestamp('2023-12-11_00:00'), get_timestamp('2023-12-13_23:59')]]

    tst_time_idxs = []
    if has_test:
        for t_rng in time_ranges:
            t_rng[0] -= keep_out
            t_rng[1] += keep_out
            tidxs = np.searchsorted(times, t_rng)
            tst_time_idxs.append(np.arange(tidxs[0], tidxs[1], 1))
        tst_time_idxs = np.concatenate(tst_time_idxs, axis=None)

        vld_tst_time_idxs = np.sort(np.concatenate([vld_time_idxs, tst_time_idxs]))
    else:
        vld_tst_time_idxs = np.sort(vld_time_idxs)

    # training indexes: everything not held out for validation or test
    train_time_idxs = time_idxs[np.in1d(time_idxs, vld_tst_time_idxs, invert=True)]

    return train_time_idxs, vld_time_idxs, tst_time_idxs
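# Example (sketch): split a time-sorted dataset into train/valid/test indexes.
# The keep_out buffer widens each held-out range so that reports just outside a
# validation/test window cannot leak temporally correlated scenes into training.
#
# trn_idxs, vld_idxs, tst_idxs = split_data(icing_times, has_test=True)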


def normalize(data, param, mean_std_dict, add_noise=False, noise_scale=1.0, seed=None):

    if mean_std_dict.get(param) is None:
        return data

    shape = data.shape
    data = data.flatten()

    mean, std, lo, hi = mean_std_dict.get(param)
    data -= mean
    data /= std

    if add_noise:
        if seed is not None:
            np.random.seed(seed)
        rnd = np.random.normal(loc=0, scale=noise_scale, size=data.size)
        data += rnd

    not_valid = np.isnan(data)
    data[not_valid] = 0

    data = np.reshape(data, shape)

    return data
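# Example (sketch): standardize one parameter with the (mean, std, lo, hi)
# statistics pickled by run_mean_std; NaNs become 0, i.e. the mean in
# standardized units.
#
# dat = normalize(h5f['cld_temp_acha'][:], 'cld_temp_acha', mean_std_dct)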


# 0.5-degree global grid used by spatial_filter
lon_space = np.linspace(-180, 180, 721)
lat_space = np.linspace(-90, 90, 361)


def spatial_filter(icing_dict):
    keys = icing_dict.keys()
    grd_x_hi = lon_space.shape[0] - 1
    grd_y_hi = lat_space.shape[0] - 1

    grd_bins = np.full((lat_space.shape[0], lon_space.shape[0]), 0)
    grd_bins_keys = [[[] for i in range(lon_space.shape[0])] for j in range(lat_space.shape[0])]

    for key in keys:
        rpts = icing_dict.get(key)
        for tup in rpts:
            lat = tup[0]
            lon = tup[1]

            lon_idx = np.searchsorted(lon_space, lon)
            lat_idx = np.searchsorted(lat_space, lat)

            if lon_idx < 0 or lon_idx > grd_x_hi:
                continue
            if lat_idx < 0 or lat_idx > grd_y_hi:
                continue

            grd_bins[lat_idx, lon_idx] += 1
            grd_bins_keys[lat_idx][lon_idx].append(key)

    return grd_bins, grd_bins_keys
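# Example (sketch): grid a report dictionary (time -> list of (lat, lon, ...)
# tuples, as used throughout this module) and find the occupied cells:
#
# grd_bins, grd_bins_keys = spatial_filter(icing_dct)
# occupied = np.argwhere(grd_bins > 0)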


def remove_common(boeing_dct, pirep_dct, threshold=3000):

    boeing_times = list(boeing_dct.keys())
    pirep_times = np.array(list(pirep_dct.keys()))

    pt_s = []
    bt_s = []
    bi_s = []
    for k, bt in enumerate(boeing_times):
        idx_s, v_s = get_indexes_within_threshold(pirep_times, bt, threshold=threshold)
        if len(idx_s) > 0:
            bt_s.append(bt)
            bi_s.append(k)
            pt_s.append(v_s[0])

    boeing_times = np.array(boeing_times)

    sub_pirep_dct = {}
    sub_boeing_dct = {}
    for key in pt_s:
        sub_pirep_dct[key] = pirep_dct.get(key)
    for key in bt_s:
        sub_boeing_dct[key] = boeing_dct.get(key)

    grd_bins, _ = spatial_filter(sub_pirep_dct)
    grd_bins_boeing, key_bins = spatial_filter(sub_boeing_dct)

    grd_bins = np.where(grd_bins > 0, 1, grd_bins)
    grd_bins_boeing = np.where(grd_bins_boeing > 0, 1, grd_bins_boeing)

    ovrlp_grd_bins = grd_bins + grd_bins_boeing

    ovlp_keys = []
    for j in range(lat_space.shape[0]):
        for i in range(lon_space.shape[0]):
            if ovrlp_grd_bins[j, i] == 2:
                keys = key_bins[j][i]
                nkeys = len(keys)
                for k in range(nkeys):
                    ovlp_keys.append(keys[k])
    set_a = set(ovlp_keys)
    set_b = set(boeing_times)
    set_b.difference_update(set_a)

    no_ovlp_dct = {}
    for key in set_b:
        no_ovlp_dct[key] = boeing_dct.get(key)

    return no_ovlp_dct
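# Note: overlap is detected on the 0.5-degree grid above, after first reducing
# both report sets to times within `threshold` seconds of each other; Boeing
# keys that share a grid cell with any such PIREP are dropped.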


# dt_str_0: start datetime string in format YYYY-MM-DD_HH:MM (default)
# dt_str_1: end datetime string in format YYYY-MM-DD_HH:MM (default)
# format_code: Python Datetime format code, default: '%Y-%m-%d_%H:%M'
# returns flattened lists of times, lons, lats, and icing reports
def time_filter(icing_dct, dt_str_0=None, dt_str_1=None, format_code='%Y-%m-%d_%H:%M'):
    ts_0 = 0
    if dt_str_0 is not None:
        dto_0 = datetime.datetime.strptime(dt_str_0, format_code).replace(tzinfo=timezone.utc)
        ts_0 = dto_0.timestamp()

    ts_1 = np.finfo(np.float64).max
    if dt_str_1 is not None:
        dto_1 = datetime.datetime.strptime(dt_str_1, format_code).replace(tzinfo=timezone.utc)
        ts_1 = dto_1.timestamp()

    keep_reports = []
    keep_times = []
    keep_lons = []
    keep_lats = []

    for ts in list(icing_dct.keys()):
        if ts_0 <= ts < ts_1:
            rpts = icing_dct[ts]
            for idx, tup in enumerate(rpts):
                keep_reports.append(tup)
                keep_times.append(ts)
                keep_lats.append(tup[0])
                keep_lons.append(tup[1])

    return keep_times, keep_lons, keep_lats, keep_reports
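# Example (sketch): all reports from the first week of January 2019
#
# times, lons, lats, reports = time_filter(icing_dct, '2019-01-01_00:00', '2019-01-07_23:59')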


# dt_str_0: start datetime string in format YYYY-MM-DD_HH:MM (default)
# dt_str_1: end datetime string in format YYYY-MM-DD_HH:MM (default)
# format_code: Python Datetime format code, default: '%Y-%m-%d_%H:%M'
# returns the times that fall within the range and their indexes into `times`
def time_filter_2(times, dt_str_0=None, dt_str_1=None, format_code='%Y-%m-%d_%H:%M'):
    ts_0 = 0
    if dt_str_0 is not None:
        dto_0 = datetime.datetime.strptime(dt_str_0, format_code).replace(tzinfo=timezone.utc)
        ts_0 = dto_0.timestamp()

    ts_1 = np.finfo(np.float64).max
    if dt_str_1 is not None:
        dto_1 = datetime.datetime.strptime(dt_str_1, format_code).replace(tzinfo=timezone.utc)
        ts_1 = dto_1.timestamp()

    keep_idxs = []
    keep_times = []

    for idx, ts in enumerate(times):
        if ts_0 <= ts < ts_1:
            keep_times.append(ts)
            keep_idxs.append(idx)

    return keep_times, keep_idxs


def time_filter_3(icing_dct, ts_0, ts_1, alt_lo=None, alt_hi=None):
    keep_reports = []
    keep_times = []
    keep_lons = []
    keep_lats = []

    for ts in list(icing_dct.keys()):
        if ts_0 <= ts < ts_1:
            rpts = icing_dct[ts]
            for idx, tup in enumerate(rpts):
                falt = tup[2]
                # keep every report when no altitude bounds are given
                if alt_lo is None or (alt_lo < falt <= alt_hi):
                    keep_reports.append(tup)
                    keep_times.append(ts)
                    keep_lats.append(tup[0])
                    keep_lons.append(tup[1])

    return keep_times, keep_lons, keep_lats, keep_reports


def analyze_moon_phase(icing_dict):
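    # count how many reports fall on dates when the moon phase is more than
    # 30 degrees from new, i.e. nights with at least some lunar illumination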
    ts = api.load.timescale()
    eph = api.load('de421.bsp')

    last_date = None
    moon_phase = None
    cnt = 0

    for key in list(icing_dict.keys()):
        dt_obj, dt_tup = get_time_tuple_utc(key)
        date = datetime.date(dt_tup.tm_year, dt_tup.tm_mon, dt_tup.tm_mday)
        if last_date != date:
            t = ts.utc(dt_tup.tm_year, dt_tup.tm_mon, dt_tup.tm_mday)
            moon_phase = almanac.moon_phase(eph, t)
            if 30 < moon_phase.degrees < 330:
                cnt += 1
            last_date = date
        else:
            if 30 < moon_phase.degrees < 330:
                cnt += 1

    print(len(icing_dict), cnt)


def tiles_info(filename):
    h5f = h5py.File(filename, 'r')
    iint = h5f['icing_intensity'][:]
    print('No Icing: ', np.sum(iint == -1))
    print('Icing:    ', np.sum(iint > 0))
    print('Icing 1:  ', np.sum(iint == 1))
    print('Icing 2:  ', np.sum(iint == 2))
    print('Icing 3:  ', np.sum(iint == 3))
    print('Icing 4:  ', np.sum(iint == 4))
    print('Icing 5:  ', np.sum(iint == 5))
    print('Icing 6:  ', np.sum(iint == 6))


def analyze(preds_file, labels, prob_avg, test_file):

    if preds_file is not None:
        labels, prob_avg = np.load(preds_file, allow_pickle=True)

    h5f = h5py.File(test_file, 'r')
    nda = h5f['flight_altitude'][:]
    num_obs = nda.shape[0]
    iint = h5f['icing_intensity'][:]
    cld_hgt = h5f['cld_height_acha'][:]
    cld_dz = h5f['cld_geo_thick'][:]
    cld_tmp = h5f['cld_temp_acha'][:]
    # cld_top_hgt = cld_hgt.reshape((num_obs, -1))

    print('report altitude (m): ', np.histogram(nda, bins=12))

    flt_alt = nda.copy()

    iint = np.where(iint == -1, 0, iint)
    iint = np.where(iint != 0, 1, iint)

    # bin flight altitude into four layers: 0:[100,4000) 1:[4000,15000) 2:[15000,20000) 3:[20000,25000)
    nda[np.logical_and(nda >= 100, nda < 4000)] = 0
    nda[np.logical_and(nda >= 4000, nda < 15000)] = 1
    nda[np.logical_and(nda >= 15000, nda < 20000)] = 2
    nda[np.logical_and(nda >= 20000, nda < 25000)] = 3
    # nda[np.logical_and(nda >= 8000, nda < 15000)] = 4

    print(np.sum(nda == 0), np.sum(nda == 1), np.sum(nda == 2), np.sum(nda == 3))
    print('No icing:       ', np.sum((iint == 0) & (nda == 0)), np.sum((iint == 0) & (nda == 1)), np.sum((iint == 0) & (nda == 2)), np.sum((iint == 0) & (nda == 3)))
    print('---------------------------')
    print('Icing:          ', np.sum((iint == 1) & (nda == 0)), np.sum((iint == 1) & (nda == 1)), np.sum((iint == 1) & (nda == 2)), np.sum((iint == 1) & (nda == 3)))
    print('---------------------------')

    print('----------------------------------------------------------')
    print('----------------------------------------------------------')

    if prob_avg is None:
        return

    preds = np.where(prob_avg > 0.5, 1, 0)

    num_0 = np.sum(nda == 0)
    num_1 = np.sum(nda == 1)
    num_2 = np.sum(nda == 2)
    num_3 = np.sum(nda == 3)

    true_ice = (labels == 1) & (preds == 1)
    false_ice = (labels == 0) & (preds == 1)

    true_no_ice = (labels == 0) & (preds == 0)
    false_no_ice = (labels == 1) & (preds == 0)

    tp_0 = np.sum(true_ice & (nda == 0))
    tp_1 = np.sum(true_ice & (nda == 1))
    tp_2 = np.sum(true_ice & (nda == 2))
    tp_3 = np.sum(true_ice & (nda == 3))

    tn_0 = np.sum(true_no_ice & (nda == 0))
    tn_1 = np.sum(true_no_ice & (nda == 1))
    tn_2 = np.sum(true_no_ice & (nda == 2))
    tn_3 = np.sum(true_no_ice & (nda == 3))

    fp_0 = np.sum(false_ice & (nda == 0))
    fp_1 = np.sum(false_ice & (nda == 1))
    fp_2 = np.sum(false_ice & (nda == 2))
    fp_3 = np.sum(false_ice & (nda == 3))

    fn_0 = np.sum(false_no_ice & (nda == 0))
    fn_1 = np.sum(false_no_ice & (nda == 1))
    fn_2 = np.sum(false_no_ice & (nda == 2))
    fn_3 = np.sum(false_no_ice & (nda == 3))

    recall_0 = tp_0 / (tp_0 + fn_0)
    recall_1 = tp_1 / (tp_1 + fn_1)
    recall_2 = tp_2 / (tp_2 + fn_2)
    recall_3 = tp_3 / (tp_3 + fn_3)

    precision_0 = tp_0 / (tp_0 + fp_0)
    precision_1 = tp_1 / (tp_1 + fp_1)
    precision_2 = tp_2 / (tp_2 + fp_2)
    precision_3 = tp_3 / (tp_3 + fp_3)

    f1_0 = 2 * (precision_0 * recall_0) / (precision_0 + recall_0)
    f1_1 = 2 * (precision_1 * recall_1) / (precision_1 + recall_1)
    f1_2 = 2 * (precision_2 * recall_2) / (precision_2 + recall_2)
    f1_3 = 2 * (precision_3 * recall_3) / (precision_3 + recall_3)

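    # Matthews correlation coefficient per altitude bin:
    # MCC = (TP*TN - FP*FN) / sqrt((TP+FP)(TP+FN)(TN+FP)(TN+FN))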
    mcc_0 = ((tp_0 * tn_0) - (fp_0 * fn_0)) / np.sqrt((tp_0 + fp_0) * (tp_0 + fn_0) * (tn_0 + fp_0) * (tn_0 + fn_0))
    mcc_1 = ((tp_1 * tn_1) - (fp_1 * fn_1)) / np.sqrt((tp_1 + fp_1) * (tp_1 + fn_1) * (tn_1 + fp_1) * (tn_1 + fn_1))
    mcc_2 = ((tp_2 * tn_2) - (fp_2 * fn_2)) / np.sqrt((tp_2 + fp_2) * (tp_2 + fn_2) * (tn_2 + fp_2) * (tn_2 + fn_2))
    mcc_3 = ((tp_3 * tn_3) - (fp_3 * fn_3)) / np.sqrt((tp_3 + fp_3) * (tp_3 + fn_3) * (tn_3 + fp_3) * (tn_3 + fn_3))

    acc_0 = np.sum(labels[nda == 0] == preds[nda == 0]) / num_0
    acc_1 = np.sum(labels[nda == 1] == preds[nda == 1]) / num_1
    acc_2 = np.sum(labels[nda == 2] == preds[nda == 2]) / num_2
    acc_3 = np.sum(labels[nda == 3] == preds[nda == 3]) / num_3

    print('Total (Positive/Icing Prediction: ')
    print('True icing:                         ', np.sum(true_ice & (nda == 0)), np.sum(true_ice & (nda == 1)), np.sum(true_ice & (nda == 2)), np.sum(true_ice & (nda == 3)))
    print('-------------------------')
    print('False no icing (False Negative/Miss): ', np.sum(false_no_ice & (nda == 0)), np.sum(false_no_ice & (nda == 1)), np.sum(false_no_ice & (nda == 2)), np.sum(false_no_ice & (nda == 3)))
    print('---------------------------------------------------')
    print('---------------------------------------------------')

    print('Total (Negative/No Icing Prediction: ')
    print('True no icing:                             ', np.sum(true_no_ice & (nda == 0)), np.sum(true_no_ice & (nda == 1)), np.sum(true_no_ice & (nda == 2)), np.sum(true_no_ice & (nda == 3)))
    print('-------------------------')
    print('* False icing (False Positive/False Alarm) *:  ', np.sum(false_ice & (nda == 0)), np.sum(false_ice & (nda == 1)), np.sum(false_ice & (nda == 2)), np.sum(false_ice & (nda == 3)))
    print('-------------------------')

    print('ACC:         ', acc_0, acc_1, acc_2, acc_3)
    print('Recall:      ', recall_0, recall_1, recall_2, recall_3)
    print('Precision:   ', precision_0, precision_1, precision_2, precision_3)
    print('F1:          ', f1_0, f1_1, f1_2, f1_3)
    print('MCC:         ', mcc_0, mcc_1, mcc_2, mcc_3)

    return labels[nda == 0], prob_avg[nda == 0], labels[nda == 1], prob_avg[nda == 1], labels[nda == 2], prob_avg[nda == 2]


def get_training_parameters(day_night='DAY', l1b_andor_l2='both'):
    if day_night == 'DAY':
        train_params_l2 = ['cld_height_acha', 'cld_geo_thick', 'cld_temp_acha', 'cld_press_acha', 'supercooled_cloud_fraction',
                           'cld_emiss_acha', 'conv_cloud_fraction', 'cld_reff_dcomp', 'cld_opd_dcomp', 'iwc_dcomp', 'lwc_dcomp']

        train_params_l1b = ['temp_10_4um_nom', 'temp_11_0um_nom', 'temp_12_0um_nom', 'temp_13_3um_nom', 'temp_3_75um_nom',
                            'temp_6_2um_nom', 'temp_6_7um_nom', 'temp_7_3um_nom', 'temp_8_5um_nom', 'temp_9_7um_nom',
                            'refl_0_47um_nom', 'refl_0_65um_nom', 'refl_0_86um_nom', 'refl_1_38um_nom', 'refl_1_60um_nom']
    else:
        train_params_l2 = ['cld_height_acha', 'cld_geo_thick', 'cld_temp_acha', 'cld_press_acha', 'supercooled_cloud_fraction',
                           'cld_emiss_acha', 'conv_cloud_fraction', 'cld_reff_acha', 'cld_opd_acha']

        train_params_l1b = ['temp_10_4um_nom', 'temp_11_0um_nom', 'temp_12_0um_nom', 'temp_13_3um_nom', 'temp_3_75um_nom',
                            'temp_6_2um_nom', 'temp_6_7um_nom', 'temp_7_3um_nom', 'temp_8_5um_nom', 'temp_9_7um_nom']

    if l1b_andor_l2 == 'both':
        train_params = train_params_l1b + train_params_l2
    elif l1b_andor_l2 == 'l1b':
        train_params = train_params_l1b
    elif l1b_andor_l2 == 'l2':
        train_params = train_params_l2
    else:
        raise ValueError("l1b_andor_l2 must be 'both', 'l1b', or 'l2'")

    return train_params
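# Example (sketch): the combined daytime L1B + L2 predictor list
#
# train_params = get_training_parameters(day_night='DAY', l1b_andor_l2='both')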


# flight-level altitude bins
flt_level_ranges = {0: [0.0, 2000.0],
                    1: [2000.0, 4000.0],
                    2: [4000.0, 6000.0],
                    3: [6000.0, 8000.0],
                    4: [8000.0, 15000.0]}