Skip to content
Snippets Groups Projects
Select Git revision
  • 5b10ce73aa7062c7af5093b628a64f591333aded
  • master default protected
  • use_flight_altitude
  • distribute
4 results

pirep_goes.py

Blame
  • user avatar
    tomrink authored
    7490d84b
    History
    pirep_goes.py 76.31 KiB
    from icing.pireps import pirep_icing
    import numpy as np
    import pickle
    import matplotlib.pyplot as plt
    import os
    from util.util import get_time_tuple_utc, GenericException, add_time_range_to_filename, is_night, is_day, \
        check_oblique, make_times, find_bin_index, get_timestamp, homedir, write_icing_file, make_for_full_domain_predict, \
        make_for_full_domain_predict2
    from util.plot import make_icing_image
    from util.geos_nav import get_navigation, get_lon_lat_2d_mesh
    from util.setup import model_path_day, model_path_night
    from aeolus.datasource import CLAVRx, CLAVRx_VIIRS, GOESL1B, CLAVRx_H08
    import h5py
    import re
    import datetime
    from datetime import timezone
    import glob
    from skyfield import api, almanac
    from deeplearning.icing_cnn import run_evaluate_static_avg, run_evaluate_static
    
    # strftime format: 4-digit year, day-of-year, hour (e.g. 2019002 + hour 14 -> '201900214').
    goes_date_format = '%Y%j%H'
    goes16_directory = '/arcdata/goes/grb/goes16'  # /year/date/abi/L1b/RadC
    # Root of the CLAVRx GOES output; one subdirectory per day named with dir_fmt.
    clavrx_dir = '/ships19/cloud/scratch/ICING/'
    #clavrx_dir = '/data/Personal/rink/clavrx/'
    # Root of the CLAVRx VIIRS output; one subdirectory per day-of-year ('%j').
    clavrx_viirs_dir = '/apollo/cloud/scratch/Satellite_Output/NASA-SNPP_VIIRS/global/2019_DNB_for_Rink_wDBfix/level2_h5/'
    clavrx_test_dir = '/data/Personal/rink/clavrx/'
    # Per-day directory name, e.g. '2019_01_02_002' (year_month_day_day-of-year).
    dir_fmt = '%Y_%m_%d_%j'
    # dir_list = [f.path for f in os.scandir('.') if f.is_dir()]
    # Module-level datasource caches keyed by day-directory string.
    # ds_dct: CLAVRx datasources (GOES keys use dir_fmt, VIIRS keys use '%j', so they do not collide).
    ds_dct = {}
    # goes_ds_dct: GOES L1B datasources, keyed by dir_fmt strings.
    goes_ds_dct = {}
    
    
    # --- CLAVRx Radiometric parameters and metadata ------------------------------------------------
    l1b_ds_list = ['temp_10_4um_nom', 'temp_11_0um_nom', 'temp_12_0um_nom', 'temp_13_3um_nom', 'temp_3_75um_nom',
                   'temp_6_2um_nom', 'temp_6_7um_nom', 'temp_7_3um_nom', 'temp_8_5um_nom', 'temp_9_7um_nom',
                   'refl_0_47um_nom', 'refl_0_65um_nom', 'refl_0_86um_nom', 'refl_1_38um_nom', 'refl_1_60um_nom']
    # All L1B fields are written as 32-bit floats and screened with their 'actual_range' attribute.
    l1b_ds_types = ['f4' for ds in l1b_ds_list]
    # Raw fill values: brightness temperatures (temp_*) use -32767, reflectances (refl_*) use -32768.
    # Derived from the names so this list cannot fall out of step with l1b_ds_list.
    l1b_ds_fill = [-32767 if ds.startswith('temp_') else -32768 for ds in l1b_ds_list]
    l1b_ds_range = ['actual_range' for ds in l1b_ds_list]

    # --- CLAVRx L2 parameters and metadata
    ds_list = ['cld_height_acha', 'cld_geo_thick', 'cld_press_acha', 'sensor_zenith_angle', 'supercooled_prob_acha',
               'supercooled_cloud_fraction', 'cld_temp_acha', 'cld_opd_acha', 'solar_zenith_angle',
               'cld_reff_acha', 'cld_reff_dcomp', 'cld_reff_dcomp_1', 'cld_reff_dcomp_2', 'cld_reff_dcomp_3',
               'cld_opd_dcomp', 'cld_opd_dcomp_1', 'cld_opd_dcomp_2', 'cld_opd_dcomp_3', 'cld_cwp_dcomp', 'iwc_dcomp',
               'lwc_dcomp', 'cld_emiss_acha', 'conv_cloud_fraction', 'cloud_type', 'cloud_phase', 'cloud_mask']
    # Integer-coded (categorical) L2 products: stored as int8 with fill -128 and no range screening.
    # Everything else is float32 with fill -32768 and 'actual_range' screening. Deriving the metadata
    # from the names keeps it aligned with ds_list when fields are added or removed.
    _l2_categorical_names = ('cloud_type', 'cloud_phase', 'cloud_mask')
    ds_types = ['i1' if ds in _l2_categorical_names else 'f4' for ds in ds_list]
    ds_fill = [-128 if ds in _l2_categorical_names else -32768 for ds in ds_list]
    ds_range = [None if ds in _l2_categorical_names else 'actual_range' for ds in ds_list]
    # --------------------------------------------
    # --- CLAVRx VIIRS L2 parameters and metadata
    # ds_list = ['cld_height_acha', 'cld_geo_thick', 'cld_press_acha', 'sensor_zenith_angle', 'supercooled_prob_acha',
    #            'supercooled_cloud_fraction', 'cld_temp_acha', 'cld_opd_acha', 'solar_zenith_angle',
    #            'cld_reff_acha', 'cld_reff_dcomp', 'cld_reff_dcomp_1', 'cld_reff_dcomp_2', 'cld_reff_dcomp_3',
    #            'cld_opd_dcomp', 'cld_opd_dcomp_1', 'cld_opd_dcomp_2', 'cld_opd_dcomp_3', 'cld_cwp_dcomp', 'iwc_dcomp',
    #            'lwc_dcomp', 'cld_emiss_acha', 'conv_cloud_fraction', 'cld_opd_nlcomp', 'cld_reff_nlcomp', 'cloud_type', 'cloud_phase', 'cloud_mask']
    # ds_types = ['f4' for i in range(25)] + ['i1' for i in range(3)]
    # ds_fill = [-32768 for i in range(25)] + [-128 for i in range(3)]
    # ds_range = ['actual_range' for i in range(25)] + [None for i in range(3)]
    # --------------------------------------------
    
    
    # An example file for accessing and copying metadata.
    # create_file() copies 'standard_name', 'long_name' and 'units' for every dataset it writes
    # from this file, so it must contain each of those datasets.
    a_clvr_file = homedir+'data/clavrx/clavrx_OR_ABI-L1b-RadC-M3C01_G16_s20190020002186.level2.nc'
    # VIIRS
    #a_clvr_file = homedir+'data/clavrx/clavrx_snpp_viirs.A2019071.0000.001.2019071061610.uwssec_B00038187.level2.h5'

    # Location of files for tile/FOV extraction.
    # NOTE: these globs run at import time; each list is simply empty when nothing matches.
    # icing_files = [f for f in glob.glob('/data/Personal/rink/icing_ml/icing_2*_DAY.h5')]
    icing_files = [f for f in glob.glob('/data/Personal/rink/icing_ml/icing_l1b_2*_DAY.h5')]
    # icing_files = [f for f in glob.glob('/data/Personal/rink/icing_ml/icing_l1b_2*_ANY.h5')]
    # Not populated anywhere in the visible part of this file.
    icing_l1b_files = []

    # no_icing_files = [f for f in glob.glob('/data/Personal/rink/icing_ml/no_icing_2*_DAY.h5')]
    no_icing_files = [f for f in glob.glob('/data/Personal/rink/icing_ml/no_icing_l1b_2*_DAY.h5')]
    # no_icing_files = [f for f in glob.glob('/data/Personal/rink/icing_ml/no_icing_l1b_2*_ANY.h5')]
    # Not populated anywhere in the visible part of this file.
    no_icing_l1b_files = []


    # CLAVRx L2 parameters presumably used as model inputs during the day — TODO confirm against training code.
    train_params_day = ['cld_height_acha', 'cld_geo_thick', 'supercooled_cloud_fraction', 'cld_temp_acha', 'cld_press_acha',
                        'solar_zenith_angle', 'cld_reff_dcomp', 'cld_opd_dcomp', 'cld_cwp_dcomp', 'iwc_dcomp', 'lwc_dcomp',
                        'cloud_phase', 'cloud_mask']

    # Night counterpart: uses ACHA optical depth / effective radius instead of the DCOMP products.
    train_params_night = ['cld_height_acha', 'cld_geo_thick', 'supercooled_cloud_fraction', 'cld_temp_acha',
                          'cld_press_acha', 'cld_reff_acha', 'cld_opd_acha', 'cloud_phase', 'cloud_mask']
    
    
    def setup(pirep_file=homedir+'data/pirep/pireps_20180101_20200331.csv'):
        """Parse a PIREP CSV file into its three report dictionaries.

        :param pirep_file: path to the PIREP CSV file
        :return: tuple (ice_dict, no_ice_dict, neg_ice_dict) as produced by pirep_icing
        """
        icing, no_icing, negative_icing = pirep_icing(pirep_file)
        return (icing, no_icing, negative_icing)
    
    
    def get_clavrx_datasource(timestamp, platform):
        """Return the cached CLAVRx datasource for the day containing ``timestamp``.

        :param timestamp: report time passed through to the platform-specific getter
        :param platform: 'GOES' or 'VIIRS'
        :return: the datasource, or None when ``platform`` is not recognised
        """
        if platform == 'GOES':
            return get_clavrx_datasource_goes(timestamp)
        if platform == 'VIIRS':
            return get_clavrx_datasource_viirs(timestamp)
        return None
    
    
    def get_clavrx_datasource_goes(timestamp):
        """Return (creating and caching on first use) the CLAVRx GOES datasource for a day.

        The day directory is ``clavrx_dir + <dir_fmt date> + '/'``; the datasource is cached in
        ``ds_dct`` under the same date string.
        """
        dt_obj, _ = get_time_tuple_utc(timestamp)
        day_key = dt_obj.strftime(dir_fmt)
        datasource = ds_dct.get(day_key)
        if datasource is None:
            datasource = ds_dct[day_key] = CLAVRx(clavrx_dir + day_key + '/')
        return datasource
    
    
    def get_clavrx_datasource_viirs(timestamp):
        """Return (creating and caching on first use) the CLAVRx VIIRS datasource for a day.

        VIIRS output is organised by day-of-year only (``'%j'``), so the directory is
        ``clavrx_viirs_dir + <day-of-year> + '/'`` and the cache key in ``ds_dct`` is that string.
        """
        dt_obj, _ = get_time_tuple_utc(timestamp)
        doy_key = dt_obj.strftime('%j')
        datasource = ds_dct.get(doy_key)
        if datasource is None:
            datasource = ds_dct[doy_key] = CLAVRx_VIIRS(clavrx_viirs_dir + doy_key + '/')
        return datasource
    
    
    def get_goes_datasource(timestamp):
        """Return (creating and caching on first use) the GOES-16 L1B CONUS datasource for a day.

        Files are looked up under ``<goes16_directory>/<year>/<dir_fmt date>/abi/L1b/RadC/``;
        the datasource is cached in ``goes_ds_dct`` keyed by the dir_fmt date string.
        """
        dt_obj, _ = get_time_tuple_utc(timestamp)
        day_dir = dt_obj.strftime(dir_fmt)
        year_dir = str(dt_obj.timetuple().tm_year)
        radc_path = '/'.join([goes16_directory, year_dir, day_dir, 'abi', 'L1b', 'RadC']) + '/'

        datasource = goes_ds_dct.get(day_dir)
        if datasource is None:
            datasource = goes_ds_dct[day_dir] = GOESL1B(radc_path)
        return datasource
    
    
    def get_grid_values(h5f, grid_name, j_c, i_c, half_width, num_j=None, num_i=None, scale_factor_name='scale_factor', add_offset_name='add_offset',
                        fill_value_name='_FillValue', range_name='actual_range', fill_value=None):
        """Read a window of a 2D dataset and convert it to physical values (NaN = invalid).

        Window selection:
          - ``half_width`` given: a (2*half_width+1) square centred on (j_c, i_c); returns None
            if any part of it would fall outside the dataset.
          - ``half_width`` None: rows j_c .. j_c+num_j and columns i_c .. i_c+num_i (inclusive),
            with no bounds check.

        Processing order: raw ``fill_value`` -> NaN, multiply by the scale-factor attribute,
        add the offset attribute, then either NaN values outside the [low, high] pair stored in
        the ``range_name`` attribute or, only when ``range_name`` is None, NaN values equal to
        the ``fill_value_name`` attribute.

        :param h5f: mapping such as an open h5py.File; ``h5f[grid_name]`` must have ``.shape``
            (2D), ``.attrs`` and support 2D slicing
        :param grid_name: dataset name
        :param j_c: row index (centre, or start when half_width is None)
        :param i_c: column index (centre, or start when half_width is None)
        :param half_width: window half width in pixels, or None to use num_j / num_i
        :param num_j: extra rows after j_c when half_width is None
        :param num_i: extra columns after i_c when half_width is None
        :param scale_factor_name: scale attribute name, or None to skip scaling
        :param add_offset_name: offset attribute name, or None to skip the offset
        :param fill_value_name: fill attribute name (used only when range_name is None), or None
        :param range_name: valid-range attribute name, or None
        :param fill_value: raw (unscaled) fill value to mask before scaling, or None
        :return: numpy array of the window, or None if a centred window is out of bounds
        :raises GenericException: if the dataset has no attributes or a requested attribute is missing
        """
        hfds = h5f[grid_name]
        attrs = hfds.attrs
        if attrs is None:
            raise GenericException('No attributes object for: '+grid_name)

        ylen, xlen = hfds.shape

        if half_width is not None:
            j_l = j_c - half_width
            i_l = i_c - half_width
            # Slice ends are exclusive: a window ending exactly at ylen / xlen is still fully
            # inside the dataset, so only strictly larger ends are out of bounds.
            j_r = j_c + half_width + 1
            i_r = i_c + half_width + 1
            if j_l < 0 or i_l < 0 or j_r > ylen or i_r > xlen:
                return None
        else:
            j_l = j_c
            j_r = j_c + num_j + 1
            i_l = i_c
            i_r = i_c + num_i + 1

        grd_vals = hfds[j_l:j_r, i_l:i_r]
        if fill_value is not None:
            grd_vals = np.where(grd_vals == fill_value, np.nan, grd_vals)

        if scale_factor_name is not None:
            grd_vals = grd_vals * _get_scalar_attr(attrs, scale_factor_name, grid_name)

        if add_offset_name is not None:
            grd_vals = grd_vals + _get_scalar_attr(attrs, add_offset_name, grid_name)

        if range_name is not None:
            attr = attrs.get(range_name)
            if attr is None:
                raise GenericException('Attribute: '+range_name+' not found for dataset: '+grid_name)
            low = attr[0]
            high = attr[1]
            grd_vals = np.where(grd_vals < low, np.nan, grd_vals)
            grd_vals = np.where(grd_vals > high, np.nan, grd_vals)
        elif fill_value_name is not None:
            attr_fill = _get_scalar_attr(attrs, fill_value_name, grid_name)
            grd_vals = np.where(grd_vals == attr_fill, np.nan, grd_vals)

        return grd_vals


    def _get_scalar_attr(attrs, attr_name, grid_name):
        """Return attribute ``attr_name`` as a scalar (its first element if stored as an array).

        :raises GenericException: if the attribute is missing
        """
        attr = attrs.get(attr_name)
        if attr is None:
            raise GenericException('Attribute: '+attr_name+' not found for dataset: '+grid_name)
        return attr if np.isscalar(attr) else attr[0]
    
    
    def create_file(filename, data_dct, ds_list, ds_types, lon_c, lat_c, time_s, fl_alt_s, icing_intensity, unq_ids):
        """Write per-PIREP grid cut-outs plus PIREP metadata to a new HDF5 file.

        Every ``data_dct[name]`` for ``name`` in ``ds_list`` is written with dtype ``ds_types[i]``
        and dimension labels (time, y, x), so each is expected to be a (time, y, x) array whose
        first axis lines up with the 1D PIREP arrays. The 'standard_name', 'long_name' and
        'units' attributes of each gridded dataset are copied from the template ``a_clvr_file``.

        :param filename: output path; opened with mode 'w', so an existing file is overwritten
        :param data_dct: dict dataset name -> array
        :param ds_list: dataset names to write (each must also exist in a_clvr_file)
        :param ds_types: dtype strings, parallel to ds_list
        :param lon_c: PIREP longitudes
        :param lat_c: PIREP latitudes
        :param time_s: PIREP times, seconds since 1970-01-01
        :param fl_alt_s: PIREP altitudes (written with units 'm')
        :param icing_intensity: PIREP icing intensity codes, or None to omit that dataset
        :param unq_ids: PIREP unique ids
        """
        # Context managers close both files even if a write or attribute copy fails part way.
        with h5py.File(a_clvr_file, 'r') as h5f_expl, h5py.File(filename, 'w') as h5f:
            for idx, ds_name in enumerate(ds_list):
                data = data_dct[ds_name]
                h5f.create_dataset(ds_name, data=data, dtype=ds_types[idx])

            lon_ds = h5f.create_dataset('longitude', data=lon_c, dtype='f4')
            lon_ds.dims[0].label = 'time'
            lon_ds.attrs.create('units', data='degrees_east')
            lon_ds.attrs.create('long_name', data='PIREP longitude')

            lat_ds = h5f.create_dataset('latitude', data=lat_c, dtype='f4')
            lat_ds.dims[0].label = 'time'
            lat_ds.attrs.create('units', data='degrees_north')
            lat_ds.attrs.create('long_name', data='PIREP latitude')

            time_ds = h5f.create_dataset('time', data=time_s)
            time_ds.dims[0].label = 'time'
            time_ds.attrs.create('units', data='seconds since 1970-1-1 00:00:00')
            time_ds.attrs.create('long_name', data='PIREP time')

            ice_alt_ds = h5f.create_dataset('icing_altitude', data=fl_alt_s, dtype='f4')
            ice_alt_ds.dims[0].label = 'time'
            ice_alt_ds.attrs.create('units', data='m')
            ice_alt_ds.attrs.create('long_name', data='PIREP altitude')

            if icing_intensity is not None:
                icing_int_ds = h5f.create_dataset('icing_intensity', data=icing_intensity, dtype='i4')
                icing_int_ds.attrs.create('long_name', data='From PIREP. 0:No intensity report, 1:Trace, 2:Light, 3:Light Moderate, 4:Moderate, 5:Moderate Severe, 6:Severe')

            unq_ids_ds = h5f.create_dataset('unique_id', data=unq_ids, dtype='i4')
            unq_ids_ds.attrs.create('long_name', data='ID mapping to PIREP icing dictionary: see pireps.py')

            # copy relevant attributes from the template file
            for ds_name in ds_list:
                h5f_ds = h5f[ds_name]
                h5f_ds.attrs.create('standard_name', data=h5f_expl[ds_name].attrs.get('standard_name'))
                h5f_ds.attrs.create('long_name', data=h5f_expl[ds_name].attrs.get('long_name'))
                h5f_ds.attrs.create('units', data=h5f_expl[ds_name].attrs.get('units'))

                h5f_ds.dims[0].label = 'time'
                h5f_ds.dims[1].label = 'y'
                h5f_ds.dims[2].label = 'x'
    
    
    def run(pirep_dct, platform, outfile=None, outfile_l1b=None, dt_str_start=None, dt_str_end=None):
        """Extract CLAVRx L2 and L1B cut-outs around each PIREP and write them to HDF5 files.

        For each report that navigates onto the matching CLAVRx file, a +/-20 pixel window is
        read for every dataset in ``ds_list`` and ``l1b_ds_list``. A report is kept only when
        every dataset yields a complete window. Consecutive reports that map to the same pixel
        are skipped.

        :param pirep_dct: dict mapping report time (compared against UTC epoch seconds) to a
            list of (lat, lon, flight_level, intensity, unique_id, report_string) tuples
        :param platform: 'GOES' or 'VIIRS' (selects the CLAVRx datasource)
        :param outfile: L2 output path (time range is added to the name), or None to skip
        :param outfile_l1b: L1B output path (time range is added to the name), or None to skip
        :param dt_str_start: optional inclusive start 'YYYY-mm-dd_HH:MM' (UTC); used only
            together with dt_str_end
        :param dt_str_end: optional inclusive end 'YYYY-mm-dd_HH:MM' (UTC)
        :raises GenericException: if only some datasets of a group yield a window ('weirdness')
        """
        time_keys = list(pirep_dct.keys())
        l1b_grd_dct = {name: [] for name in l1b_ds_list}
        ds_grd_dct = {name: [] for name in ds_list}

        t_start = None
        t_end = None
        if (dt_str_start is not None) and (dt_str_end is not None):
            t_start = datetime.datetime.strptime(dt_str_start, '%Y-%m-%d_%H:%M').replace(tzinfo=timezone.utc).timestamp()
            t_end = datetime.datetime.strptime(dt_str_end, '%Y-%m-%d_%H:%M').replace(tzinfo=timezone.utc).timestamp()

        # Reusable one-element buffers for the navigation call.
        lon_s = np.zeros(1)
        lat_s = np.zeros(1)
        last_clvr_file = None
        last_h5f = None
        nav = None

        lon_c = []
        lat_c = []
        time_s = []
        fl_alt_s = []
        ice_int_s = []
        unq_ids = []

        try:
            for ts in time_keys:
                if t_start is not None and (ts < t_start or ts > t_end):
                    continue

                try:
                    clvr_ds = get_clavrx_datasource(ts, platform)
                except Exception:
                    print('run: Problem retrieving Datasource')
                    continue

                clvr_file = clvr_ds.get_file(ts)[0]
                if clvr_file is None:
                    continue

                if clvr_file != last_clvr_file:
                    # Reset first: otherwise a failed open would either hit an unbound name or
                    # close the previous (still cached) file that last_h5f refers to.
                    h5f = None
                    try:
                        h5f = h5py.File(clvr_file, 'r')
                        nav = clvr_ds.get_navigation(h5f)
                    except Exception:
                        if h5f is not None:
                            h5f.close()
                        print('Problem with file: ', clvr_file)
                        continue
                    if last_h5f is not None:
                        last_h5f.close()
                    last_h5f = h5f
                    last_clvr_file = clvr_file
                else:
                    h5f = last_h5f

                cc = ll = -1
                for tup in pirep_dct[ts]:
                    lat, lon, fl, I, uid, rpt_str = tup
                    lat_s[0] = lat
                    lon_s[0] = lon

                    cc_a, ll_a = nav.earth_to_lc_s(lon_s, lat_s)
                    if cc_a[0] < 0 or ll_a[0] < 0:  # non-navigable, skip
                        continue

                    if cc_a[0] == cc and ll_a[0] == ll:  # time adjacent duplicate, skip
                        continue
                    cc = cc_a[0]
                    ll = ll_a[0]

                    ds_vals = []
                    for didx, ds_name in enumerate(ds_list):
                        gvals = get_grid_values(h5f, ds_name, ll, cc, 20, fill_value_name=None, range_name=ds_range[didx], fill_value=ds_fill[didx])
                        if gvals is not None:
                            ds_vals.append((ds_name, gvals))

                    l1b_vals = []
                    for didx, ds_name in enumerate(l1b_ds_list):
                        gvals = get_grid_values(h5f, ds_name, ll, cc, 20, fill_value_name=None, range_name=l1b_ds_range[didx], fill_value=l1b_ds_fill[didx])
                        if gvals is not None:
                            l1b_vals.append((ds_name, gvals))

                    if ds_vals and len(ds_vals) != len(ds_list):
                        raise GenericException('weirdness')
                    if l1b_vals and len(l1b_vals) != len(l1b_ds_list):
                        raise GenericException('weirdness')

                    # Commit grids only together with their metadata so every output array stays
                    # aligned along the time axis.
                    if len(ds_vals) == len(ds_list) and len(l1b_vals) == len(l1b_ds_list):
                        for ds_name, gvals in ds_vals:
                            ds_grd_dct[ds_name].append(gvals)
                        for ds_name, gvals in l1b_vals:
                            l1b_grd_dct[ds_name].append(gvals)
                        lon_c.append(lon_s[0])
                        lat_c.append(lat_s[0])
                        time_s.append(ts)
                        fl_alt_s.append(fl)
                        ice_int_s.append(I)
                        unq_ids.append(uid)
        finally:
            if last_h5f is not None:
                last_h5f.close()

        if len(time_s) == 0:
            return

        t_start = time_s[0]
        t_end = time_s[len(time_s)-1]

        data_dct = {}
        for ds_name in ds_list:
            data_dct[ds_name] = np.array(ds_grd_dct[ds_name])
        lon_c = np.array(lon_c)
        lat_c = np.array(lat_c)
        time_s = np.array(time_s)
        fl_alt_s = np.array(fl_alt_s)
        ice_int_s = np.array(ice_int_s)
        unq_ids = np.array(unq_ids)

        if outfile is not None:
            outfile = add_time_range_to_filename(outfile, t_start, t_end)
            create_file(outfile, data_dct, ds_list, ds_types, lon_c, lat_c, time_s, fl_alt_s, ice_int_s, unq_ids)

        data_dct = {}
        for ds_name in l1b_ds_list:
            data_dct[ds_name] = np.array(l1b_grd_dct[ds_name])

        if outfile_l1b is not None:
            outfile_l1b = add_time_range_to_filename(outfile_l1b, t_start, t_end)
            create_file(outfile_l1b, data_dct, l1b_ds_list, l1b_ds_types, lon_c, lat_c, time_s, fl_alt_s, ice_int_s, unq_ids)
    
    
    def pirep_info(pirep_dct):
        """Flatten a PIREP dictionary into parallel numpy arrays.

        :param pirep_dct: dict mapping report time to a list of
            (lat, lon, flight_level, intensity, unique_id, report_string) tuples
        :return: (flight_levels, intensities, lats, lons), in dictionary/list order
        """
        rows = []
        for reports in pirep_dct.values():
            for lat, lon, fl, intensity, _uid, _rpt_str in reports:
                rows.append((fl, intensity, lat, lon))

        flt_lvls = np.array([row[0] for row in rows])
        intensities = np.array([row[1] for row in rows])
        lats = np.array([row[2] for row in rows])
        lons = np.array([row[3] for row in rows])
        return flt_lvls, intensities, lats, lons
    
    
    def analyze(ice_dct, no_ice_dct):
        """Print scan-start strings of GOES files selected from the icing PIREP times.

        Each report time is resolved to its GOES L1B file (consecutive repeats dropped, lookup
        failures ignored). The selection is every icing scan also used by a no-icing report,
        plus up to 8000 other icing scans chosen by a seeded (42) shuffle. The '_sYYYYJJJHHMM'
        part of each file name is printed without the '_s', in time order.
        """
        def goes_files_for(pirep_dct):
            # Map report times to GOES files, skipping consecutive duplicates and failed lookups.
            files = []
            times = []
            prev_file = None
            for ts in list(pirep_dct.keys()):
                try:
                    goes_file, t_0, _ = get_goes_datasource(ts).get_file(ts)
                    if goes_file is not None and goes_file != prev_file:
                        files.append(goes_file)
                        times.append(t_0)
                        prev_file = goes_file
                except Exception:
                    continue
            return files, np.array(times)

        ice_files, ice_times = goes_files_for(ice_dct)
        _, no_ice_times = goes_files_for(no_ice_dct)

        common_times, _, common_ice_idxs = np.intersect1d(no_ice_times, ice_times, return_indices=True)

        # Icing-only scans: reproducible shuffle, then cap at 8000.
        extra_idxs = np.setxor1d(common_ice_idxs, np.arange(len(ice_times)))
        np.random.seed(42)
        np.random.shuffle(extra_idxs)
        extra_idxs = extra_idxs[0:8000]

        files = [ice_files[i] for i in common_ice_idxs] + [ice_files[i] for i in extra_idxs]
        times = np.array(common_times.tolist() + [ice_times[i] for i in extra_idxs])

        for order_idx in np.argsort(times):
            base_name = os.path.split(files[order_idx])[1]
            match = re.search('_s\\d{11}', base_name)
            print(match.group()[2:])
    
    
    # This mostly reduces some categories for a degree of class balancing and removes no intensity reports
    def process(ice_dct, no_ice_dct, neg_ice_dct):
        """Subsample the PIREP dictionaries for a degree of class balancing.

        - Icing: reports with intensity 1, 3, 4, 5 or 6 are kept; intensity 2 is reduced to a
          random 70%; any other intensity value (e.g. 0 = no intensity) is dropped.
        - No-icing: every 10th report of a seeded shuffle is kept.
        - Negative icing: one key entry is taken per report, the first 12000 entries of a seeded
          shuffle are deduplicated, and all reports at each surviving time are kept.

        Each random step re-seeds with ``np.random.seed(42)`` so the selection is reproducible.
        Progress counts are printed.

        :param ice_dct: dict time -> list of report tuples; element 3 is the intensity code
        :param no_ice_dct: dict time -> list of no-icing report tuples
        :param neg_ice_dct: dict time -> list of negative-icing report tuples
        :return: (new_ice_dct, new_no_ice_dct, new_neg_ice_dct); keys are numpy scalars
        """
        new_ice_dct = {}
        new_no_ice_dct = {}
        new_neg_ice_dct = {}

        # Per intensity class: time keys and the index of the report within that key's list.
        ice_keys_5_6 = []
        ice_tidx_5_6 = []
        ice_keys_1 = []
        ice_tidx_1 = []
        ice_keys_4 = []
        ice_tidx_4 = []
        ice_keys_3 = []
        ice_tidx_3 = []
        ice_keys_2 = []
        ice_tidx_2 = []

        print('num keys ice, no_ice, neg_ice: ', len(ice_dct), len(no_ice_dct), len(neg_ice_dct))
        no_intensity_cnt = 0
        num_ice_reports = 0

        for ts in list(ice_dct.keys()):
            rpts = ice_dct[ts]
            for idx, tup in enumerate(rpts):
                num_ice_reports += 1
                if tup[3] == 5 or tup[3] == 6:
                    ice_keys_5_6.append(ts)
                    ice_tidx_5_6.append(idx)
                elif tup[3] == 1:
                    ice_keys_1.append(ts)
                    ice_tidx_1.append(idx)
                elif tup[3] == 4:
                    ice_keys_4.append(ts)
                    ice_tidx_4.append(idx)
                elif tup[3] == 3:
                    ice_keys_3.append(ts)
                    ice_tidx_3.append(idx)
                elif tup[3] == 2:
                    ice_keys_2.append(ts)
                    ice_tidx_2.append(idx)
                else:
                    no_intensity_cnt += 1

        no_ice_keys = []
        no_ice_tidx = []
        for ts in list(no_ice_dct.keys()):
            rpts = no_ice_dct[ts]
            for idx, _ in enumerate(rpts):
                no_ice_keys.append(ts)
                no_ice_tidx.append(idx)

        # One entry per negative report, so busier times are more likely to be sampled.
        neg_ice_keys = []
        for ts in list(neg_ice_dct.keys()):
            rpts = neg_ice_dct[ts]
            for _ in rpts:
                neg_ice_keys.append(ts)

        print('num ice reports, no ice, neg ice: ', num_ice_reports, len(no_ice_keys), len(neg_ice_keys))
        print('------------------------------------------------')

        # Report-index arrays get an explicit integer dtype: np.array([]) is float64, and one
        # empty class would otherwise promote the concatenated indices to float, which cannot
        # index a Python list.
        ice_keys_5_6 = np.array(ice_keys_5_6)
        ice_tidx_5_6 = np.array(ice_tidx_5_6, dtype=np.int64)
        print('5_6: ', ice_keys_5_6.shape[0])

        ice_keys_4 = np.array(ice_keys_4)
        ice_tidx_4 = np.array(ice_tidx_4, dtype=np.int64)
        print('4: ', ice_keys_4.shape[0])

        ice_keys_3 = np.array(ice_keys_3)
        ice_tidx_3 = np.array(ice_tidx_3, dtype=np.int64)
        print('3: ', ice_keys_3.shape[0])

        ice_keys_2 = np.array(ice_keys_2)
        ice_tidx_2 = np.array(ice_tidx_2, dtype=np.int64)
        print('2: ', ice_keys_2.shape[0])
        np.random.seed(42)
        ridxs = np.random.permutation(np.arange(ice_keys_2.shape[0]))
        ice_keys_2 = ice_keys_2[ridxs]
        ice_tidx_2 = ice_tidx_2[ridxs]
        num = int(ice_keys_2.shape[0] * 0.7)
        ice_keys_2 = ice_keys_2[0:num]
        ice_tidx_2 = ice_tidx_2[0:num]
        print('2: reduced: ', ice_tidx_2.shape)

        ice_keys_1 = np.array(ice_keys_1)
        ice_tidx_1 = np.array(ice_tidx_1, dtype=np.int64)
        print('1: ', ice_keys_1.shape[0])
        print('0: ', no_intensity_cnt)

        ice_keys = np.concatenate([ice_keys_1, ice_keys_2, ice_keys_3, ice_keys_4, ice_keys_5_6])
        ice_tidx = np.concatenate([ice_tidx_1, ice_tidx_2, ice_tidx_3, ice_tidx_4, ice_tidx_5_6])
        print('icing total reduced: ', ice_tidx.shape)

        sidxs = np.argsort(ice_keys)
        ice_keys = ice_keys[sidxs]
        ice_tidx = ice_tidx[sidxs]

        for idx, key in enumerate(ice_keys):
            rpts = ice_dct[key]
            tup = rpts[ice_tidx[idx]]

            n_rpts = new_ice_dct.get(key)
            if n_rpts is None:
                n_rpts = []
                new_ice_dct[key] = n_rpts
            n_rpts.append(tup)

        # -----------------------------------------------------
        no_ice_keys = np.array(no_ice_keys)
        no_ice_tidx = np.array(no_ice_tidx, dtype=np.int64)
        print('no ice total: ', no_ice_keys.shape[0])
        np.random.seed(42)
        ridxs = np.random.permutation(np.arange(no_ice_keys.shape[0]))
        no_ice_keys = no_ice_keys[ridxs]
        no_ice_tidx = no_ice_tidx[ridxs]
        no_ice_keys = no_ice_keys[::10]
        no_ice_tidx = no_ice_tidx[::10]
        print('no ice reduced: ', no_ice_keys.shape[0])

        sidxs = np.argsort(no_ice_keys)
        no_ice_keys = no_ice_keys[sidxs]
        no_ice_tidx = no_ice_tidx[sidxs]

        for idx, key in enumerate(no_ice_keys):
            rpts = no_ice_dct[key]
            tup = rpts[no_ice_tidx[idx]]

            n_rpts = new_no_ice_dct.get(key)
            if n_rpts is None:
                n_rpts = []
                new_no_ice_dct[key] = n_rpts
            n_rpts.append(tup)
        # -------------------------------------------------

        neg_ice_keys = np.array(neg_ice_keys)
        print('neg ice total: ', neg_ice_keys.shape[0])
        np.random.seed(42)
        np.random.shuffle(neg_ice_keys)
        neg_ice_keys = neg_ice_keys[0:12000]
        uniq_sorted_neg_ice = np.unique(neg_ice_keys)
        print('neg ice reduced: ', uniq_sorted_neg_ice.shape)

        for key in uniq_sorted_neg_ice:
            new_neg_ice_dct[key] = neg_ice_dct[key]

        return new_ice_dct, new_no_ice_dct, new_neg_ice_dct
    
    
    def analyze2(filename, filename_l1b):
        """Show histograms of ACHA and DCOMP cloud optical depth from a PIREP cut-out file.

        Only the central [10:30, 10:30] window of each cut-out is used. The ACHA histogram uses
        every non-NaN ACHA value; the DCOMP histogram uses pixels where both ACHA optical depth
        and solar zenith angle are non-NaN.

        :param filename: CLAVRx L2 cut-out file (as written by create_file)
        :param filename_l1b: matching L1B cut-out file
        """
        # Read everything inside context managers so the files are closed even on a missing
        # dataset, and before plt.show() blocks.
        with h5py.File(filename, 'r') as f:
            # icing_alt, cld_top_hgt and cld_phase are read but not used below; the reads are
            # kept so a file missing these datasets still fails as before.
            icing_alt = f['icing_altitude'][:]
            cld_top_hgt = f['cld_height_acha'][:, 10:30, 10:30]
            cld_phase = f['cloud_phase'][:, 10:30, 10:30]
            cld_opd_dc = f['cld_opd_dcomp'][:, 10:30, 10:30]
            cld_opd = f['cld_opd_acha'][:, 10:30, 10:30]
            solzen = f['solar_zenith_angle'][:, 10:30, 10:30]

        with h5py.File(filename_l1b, 'r') as f_l1b:
            # Currently unused, see note above.
            bt_11um = f_l1b['temp_11_0um_nom'][:, 10:30, 10:30]

        cld_opd = cld_opd.flatten()
        cld_opd_dc = cld_opd_dc.flatten()
        solzen = solzen.flatten()

        keep1 = np.invert(np.isnan(cld_opd))
        keep2 = np.invert(np.isnan(solzen))
        keep = keep1 & keep2
        # NOTE(review): cld_opd is filtered by its own NaN mask only, not by `keep` like the
        # other arrays — confirm this asymmetry is intended.
        cld_opd = cld_opd[np.invert(np.isnan(cld_opd))]
        cld_opd_dc = cld_opd_dc[keep]
        solzen = solzen[keep]

        plt.hist(cld_opd, bins=20)
        plt.show()
        plt.hist(cld_opd_dc, bins=20)
        plt.show()
    
    
    # --------------------------------------------
    # Central sub-window (rows y_a:y_b, columns x_a:x_b) of each stored per-PIREP cut-out,
    # used by run_daynight and run_qc; nx_x_ny is its pixel count.
    x_a, x_b = 10, 30
    y_a, y_b = x_a, x_b
    nx = ny = x_b - x_a
    nx_x_ny = nx * ny
    
    
    def run_daynight(filename, filename_l1b, day_night='ANY'):
        """Write the day-time or night-time subset of a pair of PIREP cut-out files.

        Observations are selected with is_day / is_night on the solar zenith angle over the
        central window [y_a:y_b, x_a:x_b]. Any ``day_night`` value other than 'DAY' or 'NIGHT'
        keeps every observation. Results are written to ``<base>_<day_night><ext>`` next to
        each input file.

        :param filename: CLAVRx L2 cut-out file (as written by create_file)
        :param filename_l1b: matching L1B cut-out file
        :param day_night: 'DAY', 'NIGHT' or 'ANY'
        """
        # Context managers close both inputs even if reading or writing fails.
        with h5py.File(filename, 'r') as f, h5py.File(filename_l1b, 'r') as f_l1b:
            solzen = f['solar_zenith_angle'][:, y_a:y_b, x_a:x_b]
            num_obs = solzen.shape[0]

            idxs = []
            for i in range(num_obs):
                if day_night == 'NIGHT':
                    if is_night(solzen[i,]):
                        idxs.append(i)
                elif day_night == 'DAY':
                    if is_day(solzen[i,]):
                        idxs.append(i)
                else:
                    idxs.append(i)

            # Explicit integer dtype: np.array([]) would be float64, which is not a valid index array.
            keep_idxs = np.array(idxs, dtype=np.int64)

            data_dct = {}
            for ds_name in ds_list:
                data_dct[ds_name] = f[ds_name][keep_idxs,]

            lon_c = f['longitude'][keep_idxs]
            lat_c = f['latitude'][keep_idxs]
            time_s = f['time'][keep_idxs]
            fl_alt_s = f['icing_altitude'][keep_idxs]
            ice_int_s = f['icing_intensity'][keep_idxs]
            unq_ids = f['unique_id'][keep_idxs]

            path, fname = os.path.split(filename)
            fbase, fext = os.path.splitext(fname)
            outfile = path + '/' + fbase + '_' + day_night + fext

            create_file(outfile, data_dct, ds_list, ds_types, lon_c, lat_c, time_s, fl_alt_s, ice_int_s, unq_ids)

            data_dct = {}
            for ds_name in l1b_ds_list:
                data_dct[ds_name] = f_l1b[ds_name][keep_idxs]

            path, fname = os.path.split(filename_l1b)
            fbase, fext = os.path.splitext(fname)
            outfile_l1b = path + '/' + fbase + '_' + day_night + fext

            create_file(outfile_l1b, data_dct, l1b_ds_list, l1b_ds_types, lon_c, lat_c, time_s, fl_alt_s, ice_int_s, unq_ids)
    
    
    def run_qc(filename, filename_l1b, day_night='ANY', pass_thresh_frac=0.20, icing=True):
        """Quality-control PIREP cut-outs and write the passing subset to new files.

        The central window [y_a:y_b, x_a:x_b] of each cut-out is checked with
        apply_qc_icing_pireps (``icing=True``) or apply_qc_no_icing_pireps (``icing=False``).
        An observation is kept when the fraction of passing pixels among the tested pixels is
        strictly greater than ``pass_thresh_frac``. The same rule applies to both QC flavours.
        Results are written to ``<base>_QC_<day_night><ext>`` next to each input file.

        :param filename: CLAVRx L2 cut-out file (as written by create_file)
        :param filename_l1b: matching L1B cut-out file
        :param day_night: 'DAY', 'NIGHT' or 'ANY'; 'DAY' tests cld_opd_dcomp, otherwise cld_opd_acha
        :param pass_thresh_frac: exclusive lower bound on the passing-pixel fraction
        :param icing: choose the icing (True) or no-icing (False) QC function
        """
        # Context managers close both inputs even if QC or writing fails.
        with h5py.File(filename, 'r') as f, h5py.File(filename_l1b, 'r') as f_l1b:
            icing_alt = f['icing_altitude'][:]
            cld_top_hgt = f['cld_height_acha'][:, y_a:y_b, x_a:x_b]
            cld_phase = f['cloud_phase'][:, y_a:y_b, x_a:x_b]

            if day_night == 'DAY':
                cld_opd = f['cld_opd_dcomp'][:, y_a:y_b, x_a:x_b]
            else:
                cld_opd = f['cld_opd_acha'][:, y_a:y_b, x_a:x_b]

            cld_mask = f['cloud_mask'][:, y_a:y_b, x_a:x_b]
            sol_zen = f['solar_zenith_angle'][:, y_a:y_b, x_a:x_b]
            sat_zen = f['sensor_zenith_angle'][:, y_a:y_b, x_a:x_b]

            bt_11um = f_l1b['temp_11_0um_nom'][:, y_a:y_b, x_a:x_b]

            print('num pireps all: ', len(icing_alt))

            if icing:
                mask, idxs, num_tested = apply_qc_icing_pireps(icing_alt, cld_top_hgt, cld_phase, cld_opd, cld_mask, bt_11um, sol_zen, sat_zen, day_night=day_night)
            else:
                mask, idxs, num_tested = apply_qc_no_icing_pireps(icing_alt, cld_top_hgt, cld_phase, cld_opd, cld_mask, bt_11um, sol_zen, sat_zen, day_night=day_night)
            print('num pireps, day_night: ', len(mask), day_night)

            keep_idxs = []
            for i, pix_mask in enumerate(mask):
                # frac = np.sum(mask[i]) / nx_x_ny
                # NOTE(review): assumes num_tested[i] > 0 — confirm in the QC functions.
                frac = np.sum(pix_mask) / num_tested[i]
                if frac > pass_thresh_frac:
                    keep_idxs.append(idxs[i])

            print('num valid pireps: ', len(keep_idxs))
            # Explicit integer dtype: np.array([]) would be float64, which is not a valid index array.
            keep_idxs = np.array(keep_idxs, dtype=np.int64)

            data_dct = {}
            for ds_name in ds_list:
                data_dct[ds_name] = f[ds_name][keep_idxs,]

            lon_c = f['longitude'][keep_idxs]
            lat_c = f['latitude'][keep_idxs]
            time_s = f['time'][keep_idxs]
            fl_alt_s = f['icing_altitude'][keep_idxs]
            ice_int_s = f['icing_intensity'][keep_idxs]
            unq_ids = f['unique_id'][keep_idxs]

            path, fname = os.path.split(filename)
            fbase, fext = os.path.splitext(fname)
            outfile = path + '/' + fbase + '_' + 'QC' + '_' + day_night + fext

            create_file(outfile, data_dct, ds_list, ds_types, lon_c, lat_c, time_s, fl_alt_s, ice_int_s, unq_ids)

            data_dct = {}
            for ds_name in l1b_ds_list:
                data_dct[ds_name] = f_l1b[ds_name][keep_idxs]

            path, fname = os.path.split(filename_l1b)
            fbase, fext = os.path.splitext(fname)
            outfile_l1b = path + '/' + fbase + '_' + 'QC' + '_' + day_night + fext

            create_file(outfile_l1b, data_dct, l1b_ds_list, l1b_ds_types, lon_c, lat_c, time_s, fl_alt_s, ice_int_s, unq_ids)
    
    
    def apply_qc_icing_pireps(icing_alt, cld_top_hgt, cld_phase, cld_opd, cld_mask, bt_11um, solzen, satzen, day_night='ANY'):
        """Quality-control icing PIREPs against the co-located CLAVRx cloud retrievals.

        Every per-pixel argument is reshaped to ``(num_obs, -1)`` so that row ``i``
        holds the pixels around PIREP ``i``. An observation is skipped when
        ``check_oblique`` rejects its satellite zenith or, depending on
        ``day_night``, when ``is_day``/``is_night`` rejects its solar zenith.

        A pixel is a candidate when it is cloudy (cloud mask 2 or 3) and has
        non-NaN cloud-top height, 11um brightness temperature and optical depth.
        A candidate passes when it also satisfies every test below.

        Args:
            icing_alt: PIREP icing altitude per observation, compared directly with
                cld_top_hgt (meters, per the ``closeness`` constant).
            cld_top_hgt: per-pixel cloud-top height.
            cld_phase: per-pixel cloud phase. The tests key on the value 4
                (presumably ice in the CLAVRx convention -- confirm).
            cld_opd: per-pixel cloud optical depth.
            cld_mask: per-pixel cloud mask; 2 and 3 count as cloudy.
            bt_11um: per-pixel 11um brightness temperature.
            solzen: solar zenith angle per observation.
            satzen: satellite zenith angle per observation.
            day_night: 'DAY', 'NIGHT' or 'ANY'. It selects the optical-depth
                thresholds (DAY: thick 20 / thin 1; otherwise thick 2 / thin 0.1)
                and which observations are tested.

        Returns:
            (mask, idxs, num_tested): for each tested observation, the boolean
            per-pixel mask of pixels passing all tests, the observation index,
            and the number of candidate pixels.

        Raises:
            ValueError: if day_night is not 'DAY', 'NIGHT' or 'ANY'.
        """
        if day_night == 'DAY':
            opd_thick_threshold = 20
            opd_thin_threshold = 1
        elif day_night in ('NIGHT', 'ANY'):
            opd_thick_threshold = 2
            opd_thin_threshold = 0.1
        else:
            # Previously this fell through and failed later with an UnboundLocalError.
            raise ValueError("day_night must be 'DAY', 'NIGHT' or 'ANY', got: " + repr(day_night))

        closeness = 100.0  # meters
        num_obs = len(icing_alt)
        cld_mask = cld_mask.reshape((num_obs, -1))
        cld_top_hgt = cld_top_hgt.reshape((num_obs, -1))
        cld_phase = cld_phase.reshape((num_obs, -1))
        cld_opd = cld_opd.reshape((num_obs, -1))
        bt_11um = bt_11um.reshape((num_obs, -1))

        mask = []
        idxs = []
        num_tested = []
        for i in range(num_obs):
            if not check_oblique(satzen[i,]):
                continue
            if day_night == 'NIGHT' and not is_night(solzen[i,]):
                continue
            if day_night == 'DAY' and not is_day(solzen[i,]):
                continue
            if day_night == 'ANY' and not (is_day(solzen[i,]) or is_night(solzen[i,])):
                continue

            top = cld_top_hgt[i,]
            opd = cld_opd[i,]
            bt = bt_11um[i,]
            alt = icing_alt[i]
            phase4 = cld_phase[i,] == 4

            # Candidate pixels: cloudy, with valid height, 11um BT and optical depth.
            keep = (np.logical_or(cld_mask[i,] == 2, cld_mask[i,] == 3)
                    & ~np.isnan(top) & ~np.isnan(bt) & ~np.isnan(opd))
            num_keep = np.sum(keep)
            if num_keep == 0:
                continue

            above = top > alt
            # Test1: the cloud top must be above the reported icing altitude.
            keep = keep & above
            # Test2: reject phase-4 pixels whose top is within +/- closeness of the altitude.
            keep = keep & ~(phase4 & (top + closeness > alt) & (top - closeness < alt))
            # Test4: reject optically thin phase-4 clouds topping out above the altitude.
            keep = keep & ~(phase4 & (opd < opd_thin_threshold) & above)
            # Test5 / Test6: the 11um BT must lie within [228.0, 270.0] (presumably Kelvin).
            keep = keep & ~(bt > 270.0) & ~(bt < 228.0)
            # Test3: require an optically thick phase-4 cloud above the altitude.
            keep = keep & (opd >= opd_thick_threshold) & phase4 & above

            mask.append(keep)
            idxs.append(i)
            num_tested.append(num_keep)

        return mask, idxs, num_tested
    
    
    def apply_qc_no_icing_pireps(icing_alt, cld_top_hgt, cld_phase, cld_opd, cld_mask, bt_11um, solzen, satzen, day_night='ANY'):
        """Quality-control no-icing PIREPs against the co-located CLAVRx cloud retrievals.

        Every per-pixel argument is reshaped to ``(num_obs, -1)`` so that row ``i``
        holds the pixels around PIREP ``i``. An observation is skipped when
        ``check_oblique`` rejects its satellite zenith or, depending on
        ``day_night``, when ``is_day``/``is_night`` rejects its solar zenith.

        A pixel is a candidate when it is cloudy (cloud mask 2 or 3) and has
        non-NaN cloud-top height, 11um brightness temperature and optical depth.
        A candidate passes when its cloud top is above the PIREP altitude and it
        is NOT an optically thick phase-4 cloud above that altitude.

        Args:
            icing_alt: PIREP altitude per observation, compared directly with
                cld_top_hgt.
            cld_top_hgt: per-pixel cloud-top height.
            cld_phase: per-pixel cloud phase. The test keys on the value 4
                (presumably ice in the CLAVRx convention -- confirm).
            cld_opd: per-pixel cloud optical depth.
            cld_mask: per-pixel cloud mask; 2 and 3 count as cloudy.
            bt_11um: per-pixel 11um brightness temperature (only NaN-checked here).
            solzen: solar zenith angle per observation.
            satzen: satellite zenith angle per observation.
            day_night: 'DAY', 'NIGHT' or 'ANY'. It selects the thick optical-depth
                threshold (DAY: 20, otherwise 2) and which observations are tested.

        Returns:
            (mask, idxs, num_tested): for each tested observation, the boolean
            per-pixel mask of pixels passing all tests, the observation index,
            and the number of candidate pixels.

        Raises:
            ValueError: if day_night is not 'DAY', 'NIGHT' or 'ANY'.
        """
        if day_night == 'DAY':
            opd_thick_threshold = 20
        elif day_night in ('NIGHT', 'ANY'):
            opd_thick_threshold = 2
        else:
            # Previously this fell through and failed later with an UnboundLocalError.
            raise ValueError("day_night must be 'DAY', 'NIGHT' or 'ANY', got: " + repr(day_night))

        num_obs = len(icing_alt)
        cld_mask = cld_mask.reshape((num_obs, -1))
        cld_top_hgt = cld_top_hgt.reshape((num_obs, -1))
        cld_phase = cld_phase.reshape((num_obs, -1))
        cld_opd = cld_opd.reshape((num_obs, -1))
        bt_11um = bt_11um.reshape((num_obs, -1))

        mask = []
        idxs = []
        num_tested = []
        for i in range(num_obs):
            if not check_oblique(satzen[i,]):
                continue
            if day_night == 'NIGHT' and not is_night(solzen[i,]):
                continue
            if day_night == 'DAY' and not is_day(solzen[i,]):
                continue
            if day_night == 'ANY' and not (is_day(solzen[i,]) or is_night(solzen[i,])):
                continue

            top = cld_top_hgt[i,]
            opd = cld_opd[i,]

            # Candidate pixels: cloudy, with valid height, 11um BT and optical depth.
            keep = (np.logical_or(cld_mask[i,] == 2, cld_mask[i,] == 3)
                    & ~np.isnan(top) & ~np.isnan(bt_11um[i,]) & ~np.isnan(opd))
            num_keep = np.sum(keep)
            if num_keep == 0:
                continue

            above = top > icing_alt[i]
            # Test1: the cloud top must be above the reported altitude.
            keep = keep & above
            # Test3 (inverted): reject optically thick phase-4 cloud above the altitude.
            keep = keep & ~((opd >= opd_thick_threshold) & (cld_phase[i,] == 4) & above)

            mask.append(keep)
            idxs.append(i)
            num_tested.append(num_keep)

        return mask, idxs, num_tested
    
    
    def _fov_extract_read_files(file_list, params, max_per_obs, icing):
        """Sample up to ``max_per_obs`` cold cloudy pixels per PIREP from each HDF5 file.

        Only the central 20x20 window ``[10:30, 10:30]`` of each observation is
        used. Candidate pixels are cloudy (cloud_mask 2 or 3) with a non-NaN
        cld_temp_acha below 273.0 (presumably Kelvin, i.e. below freezing). The
        candidates are shuffled with ``np.random.shuffle`` and the first
        ``max_per_obs`` are kept.

        Args:
            file_list: paths of PIREP HDF5 files.
            params: names of the per-pixel datasets to sample.
            max_per_obs: maximum number of pixels kept per observation.
            icing: True for icing files (intensity read from 'icing_intensity');
                False for no-icing files (intensity set to -1).

        Returns:
            (data_dct, columns): data_dct maps each param to a 1D array of sampled
            pixel values. columns is [intensity, time, longitude, latitude,
            altitude], each repeating the PIREP's value once per sampled pixel.
        """
        data_lst = {ds_name: [] for ds_name in params}
        col_lst = [[], [], [], [], []]  # intensity, time, longitude, latitude, altitude
        window_idxs = np.arange(400)  # flat indexes into the 20x20 window

        for fname in file_list:
            with h5py.File(fname, 'r') as f:
                times = f['time'][:]
                num_obs = len(times)
                intensity = f['icing_intensity'][:num_obs] if icing else np.full(num_obs, -1)
                cols = (intensity, times, f['longitude'][:num_obs], f['latitude'][:num_obs],
                        f['icing_altitude'][:num_obs])

                cld_mask = f['cloud_mask'][:num_obs, 10:30, 10:30].reshape((num_obs, 400))
                cld_top_temp = f['cld_temp_acha'][:num_obs, 10:30, 10:30].reshape((num_obs, 400))

                # Pick the pixels first: one shuffle per observation, in file order.
                picks = []
                for i in range(num_obs):
                    cold_cloudy = np.logical_or(cld_mask[i] == 2, cld_mask[i] == 3) & ~np.isnan(cld_top_temp[i])
                    cold_cloudy = np.where(cold_cloudy, cld_top_temp[i] < 273.0, False)
                    k_idxs = window_idxs[cold_cloudy]
                    np.random.shuffle(k_idxs)
                    picks.append(k_idxs[:max_per_obs])

                # Then read each parameter once per file rather than once per observation.
                for ds_name in params:
                    windows = f[ds_name][:num_obs, 10:30, 10:30].reshape((num_obs, 400))
                    data_lst[ds_name].extend(windows[i, k_idxs] for i, k_idxs in enumerate(picks))

                counts = [len(k_idxs) for k_idxs in picks]
                for col, vals in zip(col_lst, cols):
                    col.append(np.repeat(vals, counts))
            print(fname)

        data_dct = {ds_name: np.concatenate(data_lst[ds_name]) for ds_name in params}
        return data_dct, [np.concatenate(col) for col in col_lst]


    def fov_extract(trnfile='/home/rink/fovs_l1b_train.h5', tstfile='/home/rink/fovs_l1b_test.h5', L1B_or_L2='L1B', split=0.2):
        """Build pixel-level (FOV) train/test HDF5 files from the icing and no-icing PIREP files.

        Up to 20 pixels per icing PIREP (``icing_files``) and 10 per no-icing PIREP
        (``no_icing_files``) are sampled by ``_fov_extract_read_files``. The samples
        are merged and sorted by time, then split chronologically: the first
        ``1 - split`` fraction goes to ``trnfile`` and the rest to ``tstfile``.

        Args:
            trnfile: output path of the training file.
            tstfile: output path of the test file.
            L1B_or_L2: 'L1B' to sample radiometric params (l1b_ds_list) or 'L2' to
                sample cloud products (ds_list).
            split: fraction of the (time-ordered) samples assigned to the test file.

        Raises:
            ValueError: if L1B_or_L2 is neither 'L1B' nor 'L2'.
        """
        if L1B_or_L2 == 'L1B':
            params, param_types = l1b_ds_list, l1b_ds_types
        elif L1B_or_L2 == 'L2':
            params, param_types = ds_list, ds_types
        else:
            raise ValueError("L1B_or_L2 must be 'L1B' or 'L2', got: " + repr(L1B_or_L2))

        # columns: [intensity, time, longitude, latitude, altitude] -- the positional order write_file expects.
        ice_data, ice_cols = _fov_extract_read_files(icing_files, params, 20, True)
        no_ice_data, no_ice_cols = _fov_extract_read_files(no_icing_files, params, 10, False)

        data_dct = {ds_name: np.concatenate([ice_data[ds_name], no_ice_data[ds_name]]) for ds_name in params}
        columns = [np.concatenate([a, b]) for a, b in zip(ice_cols, no_ice_cols)]

        # Sort chronologically so that the split below is a split in time.
        order = np.argsort(columns[1])
        data_dct = {ds_name: dat[order] for ds_name, dat in data_dct.items()}
        columns = [col[order] for col in columns]

        splt_idx = int(columns[0].shape[0] * (1 - split))
        for outfile, part in ((trnfile, slice(0, splt_idx)), (tstfile, slice(splt_idx, None))):
            # Pass the altitude column too: write_file's icing_alt parameter was missing before.
            write_file(outfile, params, param_types,
                       {ds_name: dat[part] for ds_name, dat in data_dct.items()},
                       *[col[part] for col in columns])
    
    
    def _tile_extract_read_files(file_list, params, cld_mask_name, L1B_or_L2, icing):
        """Read the central 16x16 tile ``[12:28, 12:28]`` of each parameter for every PIREP in ``file_list``.

        For L2 input, floating-point parameters are set to NaN at pixels that are
        not cloudy (``cld_mask_name`` value 2 or 3). Integer parameters cannot hold
        NaN and are returned unmasked.

        Args:
            file_list: paths of PIREP HDF5 files.
            params: names of the tiled datasets to read.
            cld_mask_name: dataset used as the cloud mask (read for L2 only).
            L1B_or_L2: 'L1B' or 'L2'.
            icing: True for icing files (intensity read from 'icing_intensity');
                False for no-icing files (intensity -1).

        Returns:
            (data_dct, columns): data_dct maps each param to an array of shape
            ``(num_obs_total, 16, 16)``. columns is [intensity, time, longitude,
            latitude, altitude] as 1D arrays.
        """
        data_lst = {ds_name: [] for ds_name in params}
        col_lst = [[], [], [], [], []]  # intensity, time, longitude, latitude, altitude

        for fname in file_list:
            with h5py.File(fname, 'r') as f:
                times = f['time'][:]
                num_obs = len(times)
                intensity = f['icing_intensity'][:num_obs] if icing else np.full(num_obs, -1)
                cols = (intensity, times, f['longitude'][:num_obs], f['latitude'][:num_obs],
                        f['icing_altitude'][:num_obs])
                for col, vals in zip(col_lst, cols):
                    col.append(vals)

                if L1B_or_L2 == 'L2':
                    cld_msk = f[cld_mask_name][:num_obs, 12:28, 12:28]
                    cloudy = np.logical_or(cld_msk == 2, cld_msk == 3)
                for ds_name in params:
                    dat = f[ds_name][:num_obs, 12:28, 12:28]
                    if L1B_or_L2 == 'L2' and np.issubdtype(dat.dtype, np.floating):
                        # The original discarded this np.where result, so no masking ever happened.
                        dat = np.where(cloudy, dat, np.nan)
                    data_lst[ds_name].append(dat)
            print(fname)

        data_dct = {ds_name: np.concatenate(data_lst[ds_name]) for ds_name in params}
        return data_dct, [np.concatenate(col) for col in col_lst]


    def _tile_extract_take(data_dct, columns, idxs):
        """Return (data_dct, columns) with every tile array and 1D column indexed by ``idxs``."""
        return {ds_name: dat[idxs] for ds_name, dat in data_dct.items()}, [col[idxs] for col in columns]


    def _tile_extract_augment(data_dct, columns):
        """Create flipped/rotated copies of the light-moderate-or-stronger icing tiles.

        Every sample whose intensity (``columns[0]``) is 3, 4, 5 or 6 contributes
        three new samples, in this order: left-right flip, up-down flip, and a
        90-degree rotation of its 2D tile. Each copy carries the sample's
        metadata columns.
        """
        sel = np.isin(columns[0], (3, 4, 5, 6))
        aug_dct = {}
        for ds_name, tiles in data_dct.items():
            tiles = tiles[sel]
            # Batched equivalents of np.fliplr / np.flipud / np.rot90 applied per tile.
            aug = np.stack([np.flip(tiles, axis=2), np.flip(tiles, axis=1), np.rot90(tiles, axes=(1, 2))], axis=1)
            aug_dct[ds_name] = aug.reshape((-1,) + aug.shape[2:])
        return aug_dct, [np.repeat(col[sel], 3) for col in columns]


    def tile_extract(trnfile='/home/rink/tiles_l1b_train.h5', tstfile='/home/rink/tiles_l1b_test.h5', L1B_or_L2='L1B',
                     cld_mask_name='cloud_mask', augment=False, do_split=True):
        """Build tile-level train/test HDF5 files from the icing and no-icing PIREP files.

        Each PIREP in ``icing_files`` / ``no_icing_files`` contributes the central
        16x16 tile of every parameter. Samples are sorted by time. When
        ``do_split`` is True they are split into train/test with ``split_data``;
        otherwise all samples go to ``trnfile`` and ``tstfile`` is not written.

        Args:
            trnfile: output path of the training file.
            tstfile: output path of the test file (used only when do_split).
            L1B_or_L2: 'L1B' for radiometric params (l1b_ds_list) or 'L2' for cloud
                products (ds_list). For L2, non-cloudy pixels of floating-point
                params are set to NaN.
            cld_mask_name: dataset used as the cloud mask for L2 masking.
            augment: if True, training samples with intensity 3-6 get three extra
                flipped/rotated copies.
            do_split: whether to produce a separate test file.

        Raises:
            ValueError: if L1B_or_L2 is neither 'L1B' nor 'L2'.
        """
        if L1B_or_L2 == 'L1B':
            params, param_types = l1b_ds_list, l1b_ds_types
        elif L1B_or_L2 == 'L2':
            params, param_types = ds_list, ds_types
        else:
            raise ValueError("L1B_or_L2 must be 'L1B' or 'L2', got: " + repr(L1B_or_L2))

        # columns: [intensity, time, longitude, latitude, altitude] -- the positional order write_file expects.
        time_col = 1
        ice_data, ice_cols = _tile_extract_read_files(icing_files, params, cld_mask_name, L1B_or_L2, True)
        no_ice_data, no_ice_cols = _tile_extract_read_files(no_icing_files, params, cld_mask_name, L1B_or_L2, False)

        data_dct = {ds_name: np.concatenate([ice_data[ds_name], no_ice_data[ds_name]]) for ds_name in params}
        columns = [np.concatenate([a, b]) for a, b in zip(ice_cols, no_ice_cols)]
        # split_data uses searchsorted, so the samples must be time-ordered.
        data_dct, columns = _tile_extract_take(data_dct, columns, np.argsort(columns[time_col]))

        if do_split:
            trn_idxs, tst_idxs = split_data(columns[time_col])
        else:
            trn_idxs, tst_idxs = np.arange(columns[0].shape[0]), None

        trn_data_dct, trn_columns = _tile_extract_take(data_dct, columns, trn_idxs)
        if augment:
            aug_data_dct, aug_columns = _tile_extract_augment(trn_data_dct, trn_columns)
            trn_data_dct = {ds_name: np.concatenate([trn_data_dct[ds_name], aug_data_dct[ds_name]]) for ds_name in params}
            trn_columns = [np.concatenate([a, b]) for a, b in zip(trn_columns, aug_columns)]

        trn_data_dct, trn_columns = _tile_extract_take(trn_data_dct, trn_columns, np.argsort(trn_columns[time_col]))
        write_file(trnfile, params, param_types, trn_data_dct, *trn_columns)

        if do_split:
            tst_data_dct, tst_columns = _tile_extract_take(data_dct, columns, tst_idxs)
            tst_data_dct, tst_columns = _tile_extract_take(tst_data_dct, tst_columns, np.argsort(tst_columns[time_col]))
            write_file(tstfile, params, param_types, tst_data_dct, *tst_columns)
    
    
    def write_file(outfile, params, param_types, data_dct, icing_intensity, icing_times, icing_lons, icing_lats, icing_alt=None):
        """Write PIREP training/test samples and their metadata to a new HDF5 file.

        Each name in ``params`` is written from ``data_dct`` with the dtype at the
        same position in ``param_types``. The PIREP metadata datasets follow:
        icing_intensity, time, longitude, latitude and, when ``icing_alt`` is not
        None, flight_altitude. For every param, the standard_name, long_name,
        units, actual_range and flag_values attributes are copied from the
        example CLAVRx file ``a_clvr_file`` when that attribute exists there.

        Args:
            outfile: path of the HDF5 file to create (overwritten if it exists).
            params: dataset names to write.
            param_types: dtypes matching ``params`` position by position.
            data_dct: maps each param name to its sample array.
            icing_intensity: PIREP intensity per sample (-1 = no icing).
            icing_times: PIREP time per sample (seconds since 1970-01-01).
            icing_lons: PIREP longitude per sample.
            icing_lats: PIREP latitude per sample.
            icing_alt: optional PIREP altitude per sample (meters). Defaulting to
                None keeps older 8-argument callers working.
        """
        with h5py.File(a_clvr_file, 'r') as h5f_expl, h5py.File(outfile, 'w') as h5f_out:
            for idx, ds_name in enumerate(params):
                h5f_out.create_dataset(ds_name, data=data_dct[ds_name], dtype=param_types[idx])

            icing_int_ds = h5f_out.create_dataset('icing_intensity', data=icing_intensity, dtype='i4')
            icing_int_ds.attrs.create('long_name', data='From PIREP. -1:No Icing, 1:Trace, 2:Light, 3:Light Moderate, 4:Moderate, 5:Moderate Severe, 6:Severe')

            time_ds = h5f_out.create_dataset('time', data=icing_times, dtype='f4')
            time_ds.attrs.create('units', data='seconds since 1970-1-1 00:00:00')
            time_ds.attrs.create('long_name', data='PIREP time')

            lon_ds = h5f_out.create_dataset('longitude', data=icing_lons, dtype='f4')
            lon_ds.attrs.create('units', data='degrees_east')
            lon_ds.attrs.create('long_name', data='PIREP longitude')

            lat_ds = h5f_out.create_dataset('latitude', data=icing_lats, dtype='f4')
            lat_ds.attrs.create('units', data='degrees_north')
            lat_ds.attrs.create('long_name', data='PIREP latitude')

            if icing_alt is not None:
                alt_ds = h5f_out.create_dataset('flight_altitude', data=icing_alt, dtype='f4')
                alt_ds.attrs.create('units', data='meter')
                alt_ds.attrs.create('long_name', data='PIREP altitude')

            # Copy the relevant attributes. A missing attribute is skipped instead of
            # calling attrs.create(..., data=None), which h5py cannot store.
            for ds_name in params:
                src_attrs = h5f_expl[ds_name].attrs
                dst_attrs = h5f_out[ds_name].attrs
                for attr_name in ('standard_name', 'long_name', 'units', 'actual_range', 'flag_values'):
                    attr = src_attrs.get(attr_name)
                    if attr is not None:
                        dst_attrs.create(attr_name, data=attr)
    
    
    def run_mean_std(check_cloudy=False):
        """Compute normalization statistics for CLAVRx level-2 parameters and pickle them.

        Statistics are computed separately over all icing files
        (/data/Personal/rink/icing/icing_2*.h5) and all no-icing files
        (/data/Personal/rink/icing/no_icing_2*.h5). The two groups are combined
        with weight 1 for icing and ``cnt_ni / cnt_i`` for no-icing, where the
        counts are non-NaN values. This gives the pooled mean exactly. std, lo and
        hi are the same count-weighted average of the per-group values, which is
        an approximation, not the pooled std/min/max.

        Args:
            check_cloudy: if True, only pixels whose cloud_mask is 2 or 3 contribute.

        Returns:
            dict mapping parameter name to (mean, std, lo, hi). The same dict is
            pickled to /home/rink/data/icing_ml/mean_std_lo_hi.pkl.
        """
        from contextlib import ExitStack

        ds_list = ['cld_height_acha', 'cld_geo_thick', 'cld_press_acha',
                   'supercooled_cloud_fraction', 'cld_temp_acha', 'cld_opd_acha',
                   'cld_reff_acha', 'cld_reff_dcomp', 'cld_reff_dcomp_1', 'cld_reff_dcomp_2', 'cld_reff_dcomp_3',
                   'cld_opd_dcomp', 'cld_opd_dcomp_1', 'cld_opd_dcomp_2', 'cld_opd_dcomp_3', 'cld_cwp_dcomp', 'iwc_dcomp',
                   'lwc_dcomp', 'cld_emiss_acha', 'conv_cloud_fraction']

        def _flat_concat(h5f_lst, dname):
            """Concatenate the flattened contents of dataset dname across h5f_lst."""
            return np.concatenate([h5f[dname][:,].flatten() for h5f in h5f_lst])

        def _group_stats(h5f_lst, cloudy, dname):
            """Return (mean, std, lo, hi, non-NaN count) of dname, restricted to cloudy if given."""
            data = _flat_concat(h5f_lst, dname)
            if cloudy is not None:
                data = data[cloudy]
            mean = np.nanmean(data)
            return mean, np.nanstd(data - mean), np.nanmin(data), np.nanmax(data), np.sum(~np.isnan(data))

        mean_std_dct = {}
        # ExitStack closes every opened file even if a later open or read fails.
        with ExitStack() as stack:
            # The two groups are processed independently. The original paired them by
            # list index, which failed or dropped files when the counts differed.
            ice_h5f_lst = [stack.enter_context(h5py.File(f, 'r'))
                           for f in glob.glob('/data/Personal/rink/icing/icing_2*.h5')]
            no_ice_h5f_lst = [stack.enter_context(h5py.File(f, 'r'))
                              for f in glob.glob('/data/Personal/rink/icing/no_icing_2*.h5')]

            cloudy_i = cloudy_ni = None
            if check_cloudy:
                cld_msk_i = _flat_concat(ice_h5f_lst, 'cloud_mask')
                cld_msk_ni = _flat_concat(no_ice_h5f_lst, 'cloud_mask')
                cloudy_i = np.logical_or(cld_msk_i == 2, cld_msk_i == 3)
                cloudy_ni = np.logical_or(cld_msk_ni == 2, cld_msk_ni == 3)

            for dname in ds_list:
                mean_i, std_i, lo_i, hi_i, cnt_i = _group_stats(ice_h5f_lst, cloudy_i, dname)
                mean_ni, std_ni, lo_ni, hi_ni, cnt_ni = _group_stats(no_ice_h5f_lst, cloudy_ni, dname)

                no_icing_to_icing_ratio = cnt_ni/cnt_i
                norm = no_icing_to_icing_ratio + 1
                mean = (mean_i + no_icing_to_icing_ratio*mean_ni)/norm
                std = (std_i + no_icing_to_icing_ratio*std_ni)/norm
                lo = (lo_i + no_icing_to_icing_ratio*lo_ni)/norm
                hi = (hi_i + no_icing_to_icing_ratio*hi_ni)/norm

                print(dname,': (', mean, mean_i, mean_ni, ') (', std, std_i, std_ni, ') ratio: ', no_icing_to_icing_ratio)
                print(dname,': (', lo, lo_i, lo_ni, ') (', hi, hi_i, hi_ni, ') ratio: ', no_icing_to_icing_ratio)

                mean_std_dct[dname] = (mean, std, lo, hi)

        with open('/home/rink/data/icing_ml/mean_std_lo_hi.pkl', 'wb') as f:
            pickle.dump(mean_std_dct, f)

        return mean_std_dct
    
    
    def run_mean_std_2(check_cloudy=False, no_icing_to_icing_ratio=5, params=train_params_day):
        """Compute (mean, std) of selected CLAVRx L2 parameters over all fov*.h5 files and pickle them.

        Files are found with glob('/Users/tomrink/data/icing/fov*.h5'). Each
        parameter is flattened and concatenated across files. If check_cloudy is
        set, only pixels whose cloud_mask is 2 or 3 are kept. The mean and std
        (NaN-aware) are printed and stored as ``{param: (mean, std)}`` in
        /Users/tomrink/data/icing/fovs_mean_std_day.pkl. Nothing is returned.

        NOTE(review): the ``params`` and ``no_icing_to_icing_ratio`` arguments are
        ignored. params is always replaced by the hard-coded list below.
        """
        # Overrides the params argument (see NOTE above).
        params = ['cld_height_acha', 'cld_geo_thick', 'supercooled_cloud_fraction', 'cld_temp_acha', 'cld_press_acha',
                  'cld_reff_dcomp', 'cld_opd_dcomp', 'cld_cwp_dcomp', 'iwc_dcomp', 'lwc_dcomp']

        files = [h5py.File(path, 'r') for path in glob.glob('/Users/tomrink/data/icing/fov*.h5')]

        if check_cloudy:
            flat_mask = np.concatenate([h5f['cloud_mask'][:,].flatten() for h5f in files])

        stats = {}
        for dname in params:
            values = np.concatenate([h5f[dname][:,].flatten() for h5f in files])
            if check_cloudy:
                values = values[np.logical_or(flat_mask == 2, flat_mask == 3)]

            center = np.nanmean(values)
            values -= center
            spread = np.nanstd(values)

            print(dname,': ', center, spread)
            stats[dname] = (center, spread)

        for h5f in files:
            h5f.close()

        with open('/Users/tomrink/data/icing/fovs_mean_std_day.pkl', 'wb') as out:
            pickle.dump(stats, out)
    
    
    def run_mean_std_3(train_file_path, check_cloudy=False, params=train_params_day):
        """Compute (mean, std, lo, hi) of the L1B radiometric parameters in one file and pickle them.

        Each parameter in train_file_path is flattened. If check_cloudy is set,
        only pixels whose cloud_mask is 2 or 3 are kept. NaN-aware min, max, mean
        and std are printed and stored as ``{param: (mean, std, lo, hi)}`` in
        /Users/tomrink/data/icing/mean_std_lo_hi_test.pkl. Nothing is returned.

        NOTE(review): the ``params`` argument is ignored. It is always replaced by
        the hard-coded L1B channel list below.
        """
        # Overrides the params argument (see NOTE above).
        params = ['temp_10_4um_nom', 'temp_11_0um_nom', 'temp_12_0um_nom', 'temp_13_3um_nom', 'temp_3_75um_nom',
                  'temp_6_2um_nom', 'temp_6_7um_nom', 'temp_7_3um_nom', 'temp_8_5um_nom', 'temp_9_7um_nom',
                  'refl_0_47um_nom', 'refl_0_65um_nom', 'refl_0_86um_nom', 'refl_1_38um_nom', 'refl_1_60um_nom']

        stats = {}
        h5f = h5py.File(train_file_path, 'r')

        if check_cloudy:
            flat_mask = h5f['cloud_mask'][:].flatten()

        for dname in params:
            values = h5f[dname][:,].flatten()
            if check_cloudy:
                values = values[np.logical_or(flat_mask == 2, flat_mask == 3)]

            lo, hi = np.nanmin(values), np.nanmax(values)
            center = np.nanmean(values)
            values -= center
            spread = np.nanstd(values)

            print(dname,': ', center, spread, lo, hi)
            stats[dname] = (center, spread, lo, hi)

        h5f.close()

        with open('/Users/tomrink/data/icing/mean_std_lo_hi_test.pkl', 'wb') as out:
            pickle.dump(stats, out)
    
    
    def split_data(times, time_ranges=None, buffer_seconds=10800):
        """Split time-sorted samples into train and test index arrays using hold-out time windows.

        Samples whose time falls in ``[start, end)`` of any window form the test
        set. The training set is every other sample, except those in
        ``[start - buffer_seconds, start)`` or ``[end, end + buffer_seconds)`` of
        any window. That gap keeps samples adjacent in time to the test windows
        out of training.

        Args:
            times: 1D array of sample timestamps sorted ascending (required by
                ``np.searchsorted``), in the same units as the window bounds.
            time_ranges: list of ``[start, end]`` timestamp pairs. Defaults to the
                first week (day 1 00:00 to day 7 23:59, via ``get_timestamp``) of
                Jan, Mar, May, Jul, Sep and Nov of 2018 and 2019.
            buffer_seconds: width of the excluded gap on each side of every
                window. Defaults to 10800 (3 hours).

        Returns:
            (train_idxs, test_idxs): integer index arrays into ``times``.
        """
        if time_ranges is None:
            time_ranges = [[get_timestamp(f'{year}-{month:02d}-01_00:00'), get_timestamp(f'{year}-{month:02d}-07_23:59')]
                           for year in (2018, 2019) for month in (1, 3, 5, 7, 9, 11)]
        time_ranges = list(time_ranges)

        times = np.asarray(times)
        all_idxs = np.arange(times.shape[0])
        no_idxs = np.empty(0, dtype=np.intp)  # keeps concatenate valid when there are no windows

        def _idx_span(lo, hi):
            """Indexes of samples with lo <= time < hi (times must be sorted)."""
            start, stop = np.searchsorted(times, [lo, hi])
            return np.arange(start, stop)

        test_idxs = np.concatenate([no_idxs] + [_idx_span(t_a, t_b) for t_a, t_b in time_ranges])
        gap_idxs = np.concatenate([no_idxs] + [span for t_a, t_b in time_ranges
                                               for span in (_idx_span(t_a - buffer_seconds, t_a),
                                                            _idx_span(t_b, t_b + buffer_seconds))])

        train_idxs = all_idxs[~np.isin(all_idxs, test_idxs) & ~np.isin(all_idxs, gap_idxs)]
        return train_idxs, test_idxs
    
    
    def normalize(data, param, mean_std_dict, add_noise=False, noise_scale=1.0, seed=None):
        """Standardize ``data`` using the stored mean/std for ``param``.

        Computes ``(data - mean) / std`` from the ``(mean, std, lo, hi)`` tuple stored
        under ``param``, optionally adds Gaussian noise, then replaces NaNs with 0.
        The work is done on a flattened copy, so the caller's array is not modified.

        :param data: numpy array of any shape. Integer/bool input is promoted to
            float64 (the in-place float arithmetic would otherwise raise a casting error);
            floating and complex input keeps its dtype.
        :param param: key into ``mean_std_dict``.
        :param mean_std_dict: maps parameter name -> ``(mean, std, lo, hi)``; lo/hi are unused here.
        :param add_noise: if True, add ``N(0, noise_scale)`` noise after standardizing.
        :param noise_scale: standard deviation of the added noise.
        :param seed: if not None, reseeds NumPy's *global* RNG (``np.random.seed``) before
            drawing the noise.
        :return: the normalized array with the original shape, or ``data`` itself,
            unchanged, when ``param`` has no entry in ``mean_std_dict``.
        """
        stats = mean_std_dict.get(param)
        if stats is None:
            return data

        shape = data.shape
        data = data.flatten()  # flatten() copies, so the in-place ops below don't touch the input
        if not np.issubdtype(data.dtype, np.inexact):
            data = data.astype(np.float64)

        mean, std, _lo, _hi = stats
        data -= mean
        data /= std

        if add_noise:
            if seed is not None:
                np.random.seed(seed)
            data += np.random.normal(loc=0, scale=noise_scale, size=data.size)

        # Missing values map to 0, i.e. the mean after standardization.
        data[np.isnan(data)] = 0

        return np.reshape(data, shape)
    
    
    # Regular 1-degree grid points used by spatial_filter: 361 longitudes spanning
    # -180..180 and 181 latitudes spanning -90..90, both endpoints included.
    lon_space = np.linspace(start=-180, stop=180, num=361)
    lat_space = np.linspace(start=-90, stop=90, num=181)
    
    
    def spatial_filter(icing_dict):
        """Count icing reports per cell of the 1-degree grid (lat_space x lon_space).

        :param icing_dict: dict whose values are lists of report tuples; the first two
            items of each tuple are read as (lat, lon).
        :return: integer array of shape (len(lat_space), len(lon_space)) with the
            number of reports per cell.

        A report is placed in cell (i, j) with i = searchsorted(lat_space, lat) and
        j = searchsorted(lon_space, lon), using NumPy's default side='left'. So column j
        collects lon_space[j-1] < lon <= lon_space[j], and column 0 only collects an exact
        -180. Reports whose index lands past the last grid point (a value above the top
        edge, or NaN) are skipped.
        """
        n_lat = lat_space.shape[0]
        n_lon = lon_space.shape[0]
        counts = np.zeros((n_lat, n_lon), dtype=int)

        for reports in icing_dict.values():
            for report in reports:
                col = np.searchsorted(lon_space, report[1])
                row = np.searchsorted(lat_space, report[0])
                # searchsorted never returns a negative index, so only the upper end needs a check.
                if col >= n_lon or row >= n_lat:
                    continue
                counts[row, col] += 1

        return counts
    
    
    # def spatial_filter2(icing_dict):
    #     keys = icing_dict.keys()
    #     grd_x_hi = lon_space.shape[0] - 1
    #     grd_y_hi = lat_space.shape[0] - 1
    #
    #     grd_bins = np.full((lat_space.shape[0], lon_space.shape[0]), 0)
    #
    #     for key in keys:
    #             tup = icing_dict.get(key)
    #             lat = tup[0]
    #             lon = tup[1]
    #
    #             lon_idx = np.searchsorted(lon_space, lon)
    #             lat_idx = np.searchsorted(lat_space, lat)
    #
    #             if lon_idx < 0 or lon_idx > grd_x_hi:
    #                 continue
    #             if lat_idx < 0 or lat_idx > grd_y_hi:
    #                 continue
    #
    #             grd_bins[lat_idx, lon_idx] += 1
    #
    #     return grd_bins
    
    
    def time_filter(icing_dct, dt_str_0=None, dt_str_1=None, format_code='%Y-%m-%d_%H:%M'):
        """Flatten the icing reports whose timestamp lies in [dt_str_0, dt_str_1).

        :param icing_dct: dict mapping a timestamp (seconds since the epoch) to a list of
            report tuples whose first two items are read as (lat, lon).
        :param dt_str_0: optional start datetime string, parsed with ``format_code`` and
            taken as UTC; inclusive. None means no lower limit.
        :param dt_str_1: optional end datetime string in the same format; exclusive.
            None means no upper limit.
        :param format_code: ``datetime.strptime`` format, default '%Y-%m-%d_%H:%M'
            (e.g. '2021-09-20_00:00').
        :return: four parallel lists ``(times, lons, lats, reports)``, one entry per kept report.
        """
        def to_epoch(dt_str, default):
            if dt_str is None:
                return default
            parsed = datetime.datetime.strptime(dt_str, format_code)
            return parsed.replace(tzinfo=timezone.utc).timestamp()

        lower = to_epoch(dt_str_0, 0)
        upper = to_epoch(dt_str_1, np.finfo(np.float64).max)

        times_out, lons_out, lats_out, reports_out = [], [], [], []
        for stamp, reports in icing_dct.items():
            if not lower <= stamp < upper:
                continue
            for report in reports:
                reports_out.append(report)
                times_out.append(stamp)
                lats_out.append(report[0])
                lons_out.append(report[1])

        return times_out, lons_out, lats_out, reports_out
    
    
    def time_filter_2(times, dt_str_0=None, dt_str_1=None, format_code='%Y-%m-%d_%H:%M'):
        """Select the entries of ``times`` that fall in [dt_str_0, dt_str_1).

        :param times: iterable of timestamps (seconds since the epoch).
        :param dt_str_0: optional start datetime string, parsed with ``format_code`` and
            taken as UTC; inclusive. None means no lower limit.
        :param dt_str_1: optional end datetime string in the same format; exclusive.
            None means no upper limit.
        :param format_code: ``datetime.strptime`` format, default '%Y-%m-%d_%H:%M'.
        :return: ``(keep_times, keep_idxs)``, the kept timestamps and their positions
            in ``times``, both in input order.
        """
        def parse(dt_str):
            return datetime.datetime.strptime(dt_str, format_code).replace(tzinfo=timezone.utc).timestamp()

        lower = 0 if dt_str_0 is None else parse(dt_str_0)
        upper = np.finfo(np.float64).max if dt_str_1 is None else parse(dt_str_1)

        hits = [(pos, stamp) for pos, stamp in enumerate(times) if lower <= stamp < upper]
        keep_times = [stamp for _, stamp in hits]
        keep_idxs = [pos for pos, _ in hits]
        return keep_times, keep_idxs
    
    
    def time_filter_3(icing_dct, ts_0, ts_1, alt_lo=None, alt_hi=None):
        """Flatten icing reports in a time window and, optionally, an altitude band.

        :param icing_dct: dict mapping a timestamp to a list of report tuples; each
            tuple is read as (lat, lon, flight_altitude, ...).
        :param ts_0: window start (inclusive), same units as the dict keys.
        :param ts_1: window end (exclusive).
        :param alt_lo: lower altitude bound (exclusive). None means no lower bound.
        :param alt_hi: upper altitude bound (inclusive). None means no upper bound.
        :return: four parallel lists ``(keep_times, keep_lons, keep_lats, keep_reports)``.

        NOTE(review): the band here is (alt_lo, alt_hi], while run_make_images filters
        its ``obs_*`` arrays with [lo, hi). Confirm which convention is intended.
        """
        keep_reports = []
        keep_times = []
        keep_lons = []
        keep_lats = []

        for ts, rpts in icing_dct.items():
            if not (ts_0 <= ts < ts_1):
                continue
            for tup in rpts:
                falt = tup[2]
                # A bound left as None imposes no limit. The "not (in range)" form rejects
                # a NaN altitude whenever a bound is given.
                if alt_lo is not None and not (falt > alt_lo):
                    continue
                if alt_hi is not None and not (falt <= alt_hi):
                    continue
                keep_reports.append(tup)
                keep_times.append(ts)
                keep_lats.append(tup[0])
                keep_lons.append(tup[1])

        return keep_times, keep_lons, keep_lats, keep_reports
    
    
    def analyze_moon_phase(icing_dict):
        """Print how many icing reports fall on dates whose moon phase is in (30, 330) degrees.

        The phase comes from Skyfield's ``almanac.moon_phase`` (0 deg = new moon,
        180 deg = full moon), evaluated at 00:00 UTC of each report's calendar date. It
        is recomputed only when the date differs from the previous key's date. Prints
        ``len(icing_dict)`` followed by the count; returns None.

        :param icing_dict: dict keyed by report time, in any form accepted by
            ``get_time_tuple_utc``.

        Note: ``api.load('de421.bsp')`` may download the ephemeris file if it is not cached.
        """
        timescale = api.load.timescale()
        ephemeris = api.load('de421.bsp')

        current_date = None
        phase = None
        n_counted = 0

        for key in list(icing_dict):
            _, tt = get_time_tuple_utc(key)
            report_date = datetime.date(tt.tm_year, tt.tm_mon, tt.tm_mday)
            if report_date != current_date:
                # New calendar date: recompute the phase at 00:00 UTC of that date.
                when = timescale.utc(tt.tm_year, tt.tm_mon, tt.tm_mday)
                phase = almanac.moon_phase(ephemeris, when)
                current_date = report_date
            if 30 < phase.degrees < 330:
                n_counted += 1

        print(len(icing_dict), n_counted)
    
    
    def tiles_info(filename):
        """Print how many tiles in an HDF5 tile file fall in each icing-intensity class.

        Reads the 'icing_intensity' dataset and prints the count of -1 values
        ('No Icing'), of values > 0 ('Icing'), and of each intensity 1 through 6.

        :param filename: path to the HDF5 file.
        """
        # Read the dataset into memory and close the file right away.
        with h5py.File(filename, 'r') as h5f:
            iint = h5f['icing_intensity'][:]

        print('No Icing: ', np.sum(iint == -1))
        print('Icing:    ', np.sum(iint > 0))
        for intensity in range(1, 7):
            print('Icing {}:  '.format(intensity), np.sum(iint == intensity))
    
    
    def _flight_level_histogram(levels):
        """Count flight-level codes 0..4 in ``levels``; element k is the count of code k.

        Fixed unit-width edges 0, 1, ..., 5 keep each bin aligned with its level code no
        matter which levels are present. Values outside 0..4 (altitudes that matched no
        band) and NaNs are not counted.
        """
        return np.histogram(levels, bins=np.arange(6))[0]


    def _print_level_means(nda, cld_dz, cld_hgt, first_mask, second_mask):
        """For each level 0..4, print nanmean(cld_dz) and nanmean(cld_hgt) for two sample subsets.

        Prints a 'level k: ' header, then one line for ``first_mask`` and one for
        ``second_mask``. Consecutive levels are separated by a dashed line.
        """
        for level in range(5):
            if level > 0:
                print('------------')
            print('level {}: '.format(level))
            at_level = nda == level
            for mask in (first_mask, second_mask):
                sel = at_level & mask
                print(np.nanmean(cld_dz[sel]), np.nanmean(cld_hgt[sel]))


    def analyze(preds_file, test_file='/Users/tomrink/data/icing_ml/tiles_202109240000_202111212359_l2_test_v3_DAY.h5'):
        """Print diagnostics of icing labels (and, optionally, predictions) by flight level.

        Reads 'flight_altitude', 'icing_intensity', 'cld_height_acha' and 'cld_geo_thick'
        from ``test_file``. Intensities are collapsed to a binary label: -1 and 0 become
        0 (no icing) and every other value becomes 1 (icing). Flight altitudes are binned
        into level codes 0..4 with the band edges 0/2000/4000/6000/8000/15000. The function
        prints per-level counts, plus the mean cloud geometric thickness and cloud height
        per level and label.

        :param preds_file: pickle holding ``(labels, prob_avg, cm_avg)``; a prediction is
            ``prob_avg > 0.5``. If None, only the label diagnostics are printed.
            NOTE: unpickling can execute arbitrary code, so only pass trusted files.
        :param test_file: HDF5 tile file.
        """
        with h5py.File(test_file, 'r') as h5f:
            nda = h5f['flight_altitude'][:]
            iint = h5f['icing_intensity'][:]
            cld_hgt = h5f['cld_height_acha'][:]
            cld_dz = h5f['cld_geo_thick'][:]

        iint = np.where(iint == -1, 0, iint)
        iint = np.where(iint != 0, 1, iint)

        # Replace altitudes with level codes in place. Doing this sequentially is safe:
        # every assigned code (0..4) is below 2000, so no later band can re-match it.
        nda[np.logical_and(nda >= 0, nda < 2000)] = 0
        nda[np.logical_and(nda >= 2000, nda < 4000)] = 1
        nda[np.logical_and(nda >= 4000, nda < 6000)] = 2
        nda[np.logical_and(nda >= 6000, nda < 8000)] = 3
        nda[np.logical_and(nda >= 8000, nda < 15000)] = 4

        no_ice = iint == 0
        ice = iint == 1

        print(*(np.sum(nda == level) for level in range(5)))
        print('No icing:       ', _flight_level_histogram(nda[no_ice]))
        print('---------------------------')
        print('Icing:          ', _flight_level_histogram(nda[ice]))
        print('---------------------------')

        print('No Icing(Negative): mean cld_dz, cld_hgt')
        print('Icing(Positive):         ",          "')
        _print_level_means(nda, cld_dz, cld_hgt, no_ice, ice)
        print('-----------------------------')
        print('----------------------------')

        if preds_file is None:
            return

        with open(preds_file, 'rb') as f:
            labels, prob_avg, _ = pickle.load(f)

        preds = np.where(prob_avg > 0.5, 1, 0)

        true_ice = (labels == 1) & (preds == 1)
        false_ice = (labels == 0) & (preds == 1)

        print('Total (Positive/Icing Prediction: ')
        print('True icing:                         ', _flight_level_histogram(nda[true_ice]))
        print('-------------------------')
        print('False icing (False Positive/Alarm): ', _flight_level_histogram(nda[false_ice]))
        print('By flight level:')
        print('No Icing(Negative): mean cld_dz, cld_hgt')
        print('Icing(Positive):         ",          "')
        _print_level_means(nda, cld_dz, cld_hgt, false_ice, true_ice)
        print('-------------')
        print('-------------')

        true_no_ice = (labels == 0) & (preds == 0)
        false_no_ice = (labels == 1) & (preds == 0)

        print('Total (Negative/No Icing Prediction: ')
        print('True no icing:                             ', _flight_level_histogram(nda[true_no_ice]))
        print('-------------------------')
        print('* False no icing (False Negative/Miss) *:  ', _flight_level_histogram(nda[false_no_ice]))
        print('-------------------------')

        _print_level_means(nda, cld_dz, cld_hgt, false_no_ice, true_no_ice)
    
    
    def get_training_parameters(day_night='DAY', l1b_andor_l2='BOTH'):
        """Return the list of CLAVRx dataset names used as model inputs.

        :param day_night: 'DAY' selects the daytime set, which includes the reflectance
            channels and DCOMP products. Any other value selects the night set: thermal
            channels plus ACHA-based optical depth and effective radius.
        :param l1b_andor_l2: 'BOTH' (L1b names followed by L2 names), 'l1b', or 'l2'.
        :return: a new list of dataset names.
        :raises ValueError: if ``l1b_andor_l2`` is not one of the accepted options.
        """
        if day_night == 'DAY':
            train_params_l2 = ['cld_height_acha', 'cld_geo_thick', 'cld_temp_acha', 'cld_press_acha', 'supercooled_cloud_fraction',
                               'cld_emiss_acha', 'conv_cloud_fraction', 'cld_reff_dcomp', 'cld_opd_dcomp', 'iwc_dcomp', 'lwc_dcomp']

            train_params_l1b = ['temp_10_4um_nom', 'temp_11_0um_nom', 'temp_12_0um_nom', 'temp_13_3um_nom', 'temp_3_75um_nom',
                                'temp_6_2um_nom', 'temp_6_7um_nom', 'temp_7_3um_nom', 'temp_8_5um_nom', 'temp_9_7um_nom',
                                'refl_0_47um_nom', 'refl_0_65um_nom', 'refl_0_86um_nom', 'refl_1_38um_nom', 'refl_1_60um_nom']
        else:
            train_params_l2 = ['cld_height_acha', 'cld_geo_thick', 'cld_temp_acha', 'cld_press_acha', 'supercooled_cloud_fraction',
                               'cld_emiss_acha', 'conv_cloud_fraction', 'cld_reff_acha', 'cld_opd_acha']

            train_params_l1b = ['temp_10_4um_nom', 'temp_11_0um_nom', 'temp_12_0um_nom', 'temp_13_3um_nom', 'temp_3_75um_nom',
                                'temp_6_2um_nom', 'temp_6_7um_nom', 'temp_7_3um_nom', 'temp_8_5um_nom', 'temp_9_7um_nom']

        if l1b_andor_l2 == 'BOTH':
            return train_params_l1b + train_params_l2
        if l1b_andor_l2 == 'l1b':
            return train_params_l1b
        if l1b_andor_l2 == 'l2':
            return train_params_l2
        raise ValueError("l1b_andor_l2 must be 'BOTH', 'l1b' or 'l2', got {!r}".format(l1b_andor_l2))
    
    
    # Flight-level index -> [lower, upper] altitude bounds used to select observations.
    # Units are not stated here; presumably the same as the 'flight_altitude' values
    # that analyze() bins with the same 0/2000/4000/6000/8000/15000 edges. TODO confirm.
    flt_level_ranges = {
        0: [0.0, 2000.0],
        1: [2000.0, 4000.0],
        2: [4000.0, 6000.0],
        3: [6000.0, 8000.0],
        4: [8000.0, 15000.0],
    }
    
    
    def run_make_images(clvrx_dir='/Users/tomrink/data/clavrx/RadC/', ckpt_dir_s_path='/Users/tomrink/tf_model/', prob_thresh=0.5, satellite='GOES16', domain='CONUS',
                        extent=None,
                        pirep_file='/Users/tomrink/data/pirep/pireps_202109200000_202109232359.csv',
                        obs_lons=None, obs_lats=None, obs_times=None, obs_alt=None, flight_level=None,
                        use_flight_altitude=False, day_night='DAY', l1b_andor_l2='l2'):
        """Run the icing model on every CLAVRx file in ``clvrx_dir`` and render one image per file.

        For each file the model (``run_evaluate_static_avg``) is evaluated over the full
        domain. ``make_icing_image`` then plots the predicted icing locations together
        with validation observations from within +/- 30 minutes of the image time. The
        observations come from the icing reports of ``pirep_file`` if it is given,
        otherwise from the ``obs_*`` arrays if ``obs_times`` is given, otherwise none.

        :param extent: map extent passed to ``make_icing_image``; None means [-105, -70, 15, 50].
        :param flight_level: optional key into ``flt_level_ranges`` that restricts the
            observations' altitude band; None means 0..15000.
        :param obs_lons, obs_lats, obs_times, obs_alt: parallel numpy arrays of observations,
            used only when ``pirep_file`` is None.

        NOTE(review): PIREPs are altitude-filtered by ``time_filter_3`` as (lo, hi], while the
        ``obs_*`` path below uses [lo, hi). Confirm which convention is intended.
        """
        if extent is None:
            # Build the default per call instead of sharing one mutable list across calls.
            extent = [-105, -70, 15, 50]

        if pirep_file is not None:
            ice_dict, _, _ = setup(pirep_file)

        if satellite == 'H08':
            clvrx_ds = CLAVRx_H08(clvrx_dir)
        else:
            clvrx_ds = CLAVRx(clvrx_dir)
        clvrx_files = clvrx_ds.flist

        alt_lo, alt_hi = 0.0, 15000.0
        if flight_level is not None:
            alt_lo, alt_hi = flt_level_ranges[flight_level]

        train_params = get_training_parameters(day_night=day_night, l1b_andor_l2=l1b_andor_l2)

        for fname in clvrx_files:
            # 'with' guarantees the file is closed even if prediction or plotting raises.
            with h5py.File(fname, 'r') as h5f:
                dto = clvrx_ds.get_datetime(fname)
                ts = dto.timestamp()
                clvrx_str_time = dto.strftime('%Y-%m-%d_%H:%M')

                data_dct, ll, cc = make_for_full_domain_predict(h5f, name_list=train_params, satellite=satellite, domain=domain)

                # Observation matching window: image time +/- 30 minutes.
                dto, _ = get_time_tuple_utc(ts)
                ts_0 = (dto - datetime.timedelta(minutes=30)).timestamp()
                ts_1 = (dto + datetime.timedelta(minutes=30)).timestamp()

                if pirep_file is not None:
                    _, keep_lons, keep_lats, _ = time_filter_3(ice_dict, ts_0, ts_1, alt_lo, alt_hi)
                elif obs_times is not None:
                    keep = ((obs_times >= ts_0) & (obs_times < ts_1) &
                            (obs_alt >= alt_lo) & (obs_alt < alt_hi))
                    keep_lons = obs_lons[keep]
                    keep_lats = obs_lats[keep]
                else:
                    keep_lons = None
                    keep_lats = None

                ice_lons, ice_lats, _ = run_evaluate_static_avg(data_dct, ll, cc, ckpt_dir_s_path=ckpt_dir_s_path,
                                                                flight_level=flight_level, prob_thresh=prob_thresh,
                                                                satellite=satellite, domain=domain,
                                                                use_flight_altitude=use_flight_altitude)

                make_icing_image(h5f, None, ice_lons, ice_lats, clvrx_str_time, satellite, domain,
                                 ice_lons_vld=keep_lons, ice_lats_vld=keep_lats, extent=extent)

            print('Done: ', clvrx_str_time)
    
    
    def run_icing_predict(clvrx_dir='/Users/tomrink/data/clavrx/RadC/', output_dir=homedir, model_path=None,
                          prob_thresh=0.5, satellite='GOES16', domain='CONUS', day_night='DAY',
                          l1b_andor_l2='BOTH', use_flight_altitude=True):
        """Predict icing for every CLAVRx file in ``clvrx_dir`` and write one output file per input.

        For each file the model is evaluated per flight level (``run_evaluate_static``).
        Pixels outside the valid region are set to -1 in both the prediction and the
        probability grids before ``write_icing_file`` is called. A pixel is invalid when
        |latitude| >= 63, satellite zenith >= 70, or (when ``day_night == 'DAY'``) solar
        zenith >= 80, or when any of these values is NaN.

        :param model_path: model/checkpoint directory; defaults to ``model_path_day`` or
            ``model_path_night`` depending on ``day_night``.
        :param l1b_andor_l2: input-parameter selection passed to ``get_training_parameters``.
        """
        if model_path is None:
            model_path = model_path_day if day_night == 'DAY' else model_path_night

        train_params = get_training_parameters(day_night=day_night, l1b_andor_l2=l1b_andor_l2)

        if satellite == 'H08':
            clvrx_ds = CLAVRx_H08(clvrx_dir)
        else:
            clvrx_ds = CLAVRx(clvrx_dir)
        clvrx_files = clvrx_ds.flist

        for fidx, fname in enumerate(clvrx_files):
            # 'with' guarantees the file is closed even if prediction or writing raises.
            with h5py.File(fname, 'r') as h5f:
                dto = clvrx_ds.get_datetime(fname)
                clvrx_str_time = dto.strftime('%Y-%m-%d_%H:%M')

                data_dct, ll, cc = make_for_full_domain_predict(h5f, name_list=train_params, satellite=satellite, domain=domain)
                solzen, satzen = make_for_full_domain_predict2(h5f, satellite=satellite, domain=domain)
                if fidx == 0:
                    # Navigation is computed once from the first file and reused for the rest;
                    # assumes every file in the directory shares the same grid. TODO confirm.
                    num_elems = len(cc)
                    num_lines = len(ll)
                    nav = get_navigation(satellite, domain)
                    lons_2d, lats_2d, x_rad, y_rad = get_lon_lat_2d_mesh(nav, ll, cc)

                preds_2d_dct, probs_2d_dct = run_evaluate_static(data_dct, num_lines, num_elems, day_night=day_night,
                                                                 ckpt_dir_s_path=model_path, prob_thresh=prob_thresh,
                                                                 use_flight_altitude=use_flight_altitude)

                # The valid-pixel mask is the same for every flight level, so build it once.
                # (An 'or' of the two latitude tests would be true everywhere; both must hold.)
                keep = np.logical_and(lats_2d > -63.0, lats_2d < 63.0)
                keep = np.where(keep, satzen < 70, False)
                if day_night == 'DAY':
                    keep = np.where(keep, solzen < 80, False)

                # np.where returns a new array, so store the masked grids back into the dicts.
                for flvl in list(preds_2d_dct.keys()):
                    preds_2d_dct[flvl] = np.where(keep, preds_2d_dct[flvl], -1)
                    probs_2d_dct[flvl] = np.where(keep, probs_2d_dct[flvl], -1.0)

                write_icing_file(clvrx_str_time, output_dir, preds_2d_dct, probs_2d_dct, x_rad, y_rad, lons_2d, lats_2d, cc, ll)

                print('Done: ', clvrx_str_time)