from icing.pireps import pirep_icing
import numpy as np
import pickle
import os
from util.util import get_time_tuple_utc, GenericException, add_time_range_to_filename
from aeolus.datasource import CLAVRx, GOESL1B
from util.geos_nav import GEOSNavigation
import h5py
import re
import datetime
from datetime import timezone
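
# Build PIREP-matched GOES-16 / CLAVR-x training samples for flight icing:
# parse pilot reports, locate coincident CLAVR-x L2 and GOES L1b files,
# extract fixed-size grids centered on each report, and write HDF5 sample files.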

goes_date_format = '%Y%j%H'
goes16_directory = '/arcdata/goes/grb/goes16'  # /year/date/abi/L1b/RadC
#clavrx_dir = '/apollo/cloud/scratch/ICING/'
clavrx_dir = '/ships19/cloud/scratch/ICING/'
dir_fmt = '%Y_%m_%d_%j'
# dir_list = [f.path for f in os.scandir('.') if f.is_dir()]
ds_dct = {}
goes_ds_dct = {}
#pirep_file = '/home/rink/data/pireps/pireps_2019010000_2019063023.csv'
pirep_file = '/home/rink/data/pireps/pireps_20180101_20200331.csv'

l1b_ds_list = ['temp_10_4um_nom', 'temp_11_0um_nom', 'temp_12_0um_nom', 'temp_13_3um_nom', 'temp_3_75um_nom',
               'temp_6_2um_nom', 'temp_6_7um_nom', 'temp_7_3um_nom', 'temp_8_5um_nom', 'temp_9_7um_nom',
               'refl_0_47um_nom', 'refl_0_55um_nom', 'refl_0_65um_nom', 'refl_0_86um_nom', 'refl_1_38um_nom',
               'refl_1_60um_nom']
l1b_ds_types = ['f4'] * len(l1b_ds_list)

ds_list = ['cld_height_acha', 'cld_geo_thick', 'cld_press_acha', 'sensor_zenith_angle', 'supercooled_prob_acha',
           'supercooled_cloud_fraction', 'cld_temp_acha', 'cld_opd_acha', 'solar_zenith_angle',
           'cld_reff_acha', 'cld_reff_dcomp', 'cld_reff_dcomp_1', 'cld_reff_dcomp_2', 'cld_reff_dcomp_3',
           'cld_opd_dcomp', 'cld_opd_dcomp_1', 'cld_opd_dcomp_2', 'cld_opd_dcomp_3', 'cld_cwp_dcomp', 'iwc_dcomp',
           'lwc_dcomp', 'cloud_type', 'cloud_phase', 'cloud_mask']
ds_types = ['f4'] * 21 + ['i4'] * 3  # all float except cloud_type, cloud_phase, cloud_mask


a_clvr_file = '/home/rink/data/clavrx/clavrx_OR_ABI-L1b-RadC-M3C01_G16_s20190020002186.level2.nc'


def setup():
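    """Parse the PIREP CSV into (ice, no_ice, neg_ice) report dictionaries keyed by time."""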
    ice_dict, no_ice_dict, neg_ice_dict = pirep_icing(pirep_file)
    return ice_dict, no_ice_dict, neg_ice_dict
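
# Typical driver sequence (a sketch, not from the source; output paths are illustrative):
#   ice_dct, no_ice_dct, neg_ice_dct = setup()
#   ice_dct, no_ice_dct, neg_ice_dct = process_2(ice_dct, no_ice_dct, neg_ice_dct)
#   run(ice_dct, outfile='/path/to/icing_l2.h5', outfile_l1b='/path/to/icing_l1b.h5',
#       dt_str_start='2019-01-01_00:00', dt_str_end='2019-06-30_23:59')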


def get_clavrx_datasource(timestamp):
    dt_obj, time_tup = get_time_tuple_utc(timestamp)
    date_dir_str = dt_obj.strftime(dir_fmt)
    ds = ds_dct.get(date_dir_str)
    if ds is None:
        ds = CLAVRx(clavrx_dir + date_dir_str + '/')
        ds_dct[date_dir_str] = ds
    return ds


def get_goes_datasource(timestamp):
    dt_obj, time_tup = get_time_tuple_utc(timestamp)

    yr_dir = str(dt_obj.timetuple().tm_year)
    date_dir = dt_obj.strftime(dir_fmt)

    files_path = goes16_directory + '/' + yr_dir + '/' + date_dir + '/abi' + '/L1b' + '/RadC/'
    ds = goes_ds_dct.get(date_dir)
    if ds is None:
        ds = GOESL1B(files_path)
        goes_ds_dct[date_dir] = ds
    return ds


def get_grid_values(h5f, grid_name, j_c, i_c, half_width, scale_factor_name='scale_factor', add_offset_name='add_offset'):
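    """Extract a (2*half_width+1) x (2*half_width+1) window centered at (j_c, i_c)
    from dataset grid_name in h5f. Fill values (-999, -32768) become NaN, and
    scale_factor/add_offset attributes are applied when present. Returns None if
    the window does not fit inside the grid."""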
    hfds = h5f[grid_name]
    attrs = hfds.attrs
    ylen, xlen = hfds.shape

    j_l = j_c-half_width
    i_l = i_c-half_width
    if j_l < 0 or i_l < 0:
        return None

    j_r = j_c+half_width+1
    i_r = i_c+half_width+1
    if j_r > ylen or i_r > xlen:  # slice end is exclusive, so > (not >=) is the correct bound
        return None

    grd_vals = hfds[j_l:j_r, i_l:i_r]
    # Replace fill values with NaN before applying scale/offset
    grd_vals = np.where(grd_vals == -999, np.nan, grd_vals)
    grd_vals = np.where(grd_vals == -32768, np.nan, grd_vals)

    if attrs is None:
        return grd_vals

    if scale_factor_name is not None:
        scale_factor = attrs.get(scale_factor_name)
        if scale_factor is not None:  # guard: attrs.get returns None for a missing attribute
            grd_vals = grd_vals * scale_factor[0]

    if add_offset_name is not None:
        add_offset = attrs.get(add_offset_name)
        if add_offset is not None:
            grd_vals = grd_vals + add_offset[0]

    return grd_vals


def create_file(filename, data_dct, ds_list, ds_types, lon_c, lat_c, time_s, fl_alt_s, icing_intensity, unq_ids):
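    """Write matched PIREP samples (extracted grids plus per-report location, time,
    altitude, intensity, and unique id) to filename, copying dataset attributes
    from the CLAVR-x template file a_clvr_file."""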
    h5f_expl = h5py.File(a_clvr_file, 'r')
    h5f = h5py.File(filename, 'w')

    for idx, ds_name in enumerate(ds_list):
        data = data_dct[ds_name]
        h5f.create_dataset(ds_name, data=data, dtype=ds_types[idx])

    lon_ds = h5f.create_dataset('longitude', data=lon_c, dtype='f4')
    lon_ds.dims[0].label = 'time'
    lon_ds.attrs.create('units', data='degrees_east')
    lon_ds.attrs.create('long_name', data='PIREP longitude')

    lat_ds = h5f.create_dataset('latitude', data=lat_c, dtype='f4')
    lat_ds.dims[0].label = 'time'
    lat_ds.attrs.create('units', data='degrees_north')
    lat_ds.attrs.create('long_name', data='PIREP latitude')

    time_ds = h5f.create_dataset('time', data=time_s)
    time_ds.dims[0].label = 'time'
    time_ds.attrs.create('units', data='seconds since 1970-1-1 00:00:00')
    time_ds.attrs.create('long_name', data='PIREP time')

    ice_alt_ds = h5f.create_dataset('icing_altitude', data=fl_alt_s, dtype='f4')
    ice_alt_ds.dims[0].label = 'time'
    ice_alt_ds.attrs.create('units', data='m')
    ice_alt_ds.attrs.create('long_name', data='PIREP altitude')

    if icing_intensity is not None:
        icing_int_ds = h5f.create_dataset('icing_intensity', data=icing_intensity, dtype='i4')
        icing_int_ds.attrs.create('long_name', data='From PIREP. 0:No intensity report, 1:Trace, 2:Light, 3:Light Moderate, 4:Moderate, 5:Moderate Severe, 6:Severe')

    unq_ids_ds = h5f.create_dataset('unique_id', data=unq_ids, dtype='i4')
    unq_ids_ds.attrs.create('long_name', data='ID mapping to PIREP icing dictionary: see pireps.py')

    # copy relevant attributes
    for ds_name in ds_list:
        h5f_ds = h5f[ds_name]
        h5f_ds.attrs.create('standard_name', data=h5f_expl[ds_name].attrs.get('standard_name'))
        h5f_ds.attrs.create('long_name', data=h5f_expl[ds_name].attrs.get('long_name'))
        h5f_ds.attrs.create('units', data=h5f_expl[ds_name].attrs.get('units'))

        h5f_ds.dims[0].label = 'time'
        h5f_ds.dims[1].label = 'y'
        h5f_ds.dims[2].label = 'x'

    h5f.close()
    h5f_expl.close()


def run(pirep_dct, outfile=None, outfile_l1b=None, dt_str_start=None, dt_str_end=None, reduce=False):
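    """Match each PIREP in pirep_dct (keyed by time) to a CLAVR-x file, extract
    41x41 pixel windows for every dataset in ds_list and l1b_ds_list, and write
    the results with create_file. Optional dt_str_start/dt_str_end
    ('%Y-%m-%d_%H:%M', UTC) bound the time range; reduce keeps only the first
    report per time step."""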
    time_keys = list(pirep_dct.keys())
    l1b_grd_dct = {name: [] for name in l1b_ds_list}
    ds_grd_dct = {name: [] for name in ds_list}

    t_start = None
    t_end = None
    if (dt_str_start is not None) and (dt_str_end is not None):
        dto = datetime.datetime.strptime(dt_str_start, '%Y-%m-%d_%H:%M').replace(tzinfo=timezone.utc)
        t_start = dto.timestamp()

        dto = datetime.datetime.strptime(dt_str_end, '%Y-%m-%d_%H:%M').replace(tzinfo=timezone.utc)
        t_end = dto.timestamp()

    nav = GEOSNavigation(sub_lon=-75.0, CFAC=5.6E-05, COFF=-0.101332, LFAC=-5.6E-05, LOFF=0.128212, num_elems=2500, num_lines=1500)

    lon_s = np.zeros(1)
    lat_s = np.zeros(1)
    last_clvr_file = None
    last_h5f = None

    lon_c = []
    lat_c = []
    time_s = []
    fl_alt_s = []
    ice_int_s = []
    unq_ids = []

    for time in time_keys:
        if t_start is not None:
            if time < t_start:
                continue
            if time > t_end:
                continue

        try:
            clvr_ds = get_clavrx_datasource(time)
        except Exception:
            continue

        clvr_file = clvr_ds.get_file(time)[0]
        if clvr_file is None:
            continue

        if clvr_file != last_clvr_file:
            try:
                h5f = h5py.File(clvr_file, 'r')
            except Exception:
                # h5py.File raised, so no new handle was opened; nothing to close here
                print('Problem with file: ', clvr_file)
                continue
            if last_h5f is not None:
                last_h5f.close()
            last_h5f = h5f
            last_clvr_file = clvr_file
        else:
            h5f = last_h5f

        reports = pirep_dct[time]
        for tup in reports:
            lat, lon, fl, ice_int, uid, rpt_str = tup
            lat_s[0] = lat
            lon_s[0] = lon

            cc, ll = nav.earth_to_lc_s(lon_s, lat_s)
            if cc[0] < 0:
                continue

            cnt_a = 0
            for ds_name in ds_list:
                gvals = get_grid_values(h5f, ds_name, ll[0], cc[0], 20)
                if gvals is not None:
                    ds_grd_dct[ds_name].append(gvals)
                    cnt_a += 1

            cnt_b = 0
            for ds_name in l1b_ds_list:
                gvals = get_grid_values(h5f, ds_name, ll[0], cc[0], 20)
                if gvals is not None:
                    l1b_grd_dct[ds_name].append(gvals)
                    cnt_b += 1

            if cnt_a > 0 and cnt_a != len(ds_list):
                raise GenericException('partial extraction: got {} of {} L2 datasets'.format(cnt_a, len(ds_list)))
            if cnt_b > 0 and cnt_b != len(l1b_ds_list):
                raise GenericException('partial extraction: got {} of {} L1b datasets'.format(cnt_b, len(l1b_ds_list)))

            if cnt_a == len(ds_list) and cnt_b == len(l1b_ds_list):
                lon_c.append(lon_s[0])
                lat_c.append(lat_s[0])
                time_s.append(time)
                fl_alt_s.append(fl)
                ice_int_s.append(ice_int)
                unq_ids.append(uid)

            if reduce:
                break  # keep only the first report for each time step

    # Close the last open CLAVR-x file before assembling output
    if last_h5f is not None:
        last_h5f.close()

    if len(time_s) == 0:
        return

    t_start = time_s[0]
    t_end = time_s[-1]

    data_dct = {}
    for ds_name in ds_list:
        data_dct[ds_name] = np.array(ds_grd_dct[ds_name])
    lon_c = np.array(lon_c)
    lat_c = np.array(lat_c)
    time_s = np.array(time_s)
    fl_alt_s = np.array(fl_alt_s)
    ice_int_s = np.array(ice_int_s)
    unq_ids = np.array(unq_ids)

    if outfile is not None:
        outfile = add_time_range_to_filename(outfile, t_start, t_end)
        create_file(outfile, data_dct, ds_list, ds_types, lon_c, lat_c, time_s, fl_alt_s, ice_int_s, unq_ids)

    data_dct = {}
    for ds_name in l1b_ds_list:
        data_dct[ds_name] = np.array(l1b_grd_dct[ds_name])

    if outfile_l1b is not None:
        outfile_l1b = add_time_range_to_filename(outfile_l1b, t_start, t_end)
        create_file(outfile_l1b, data_dct, l1b_ds_list, l1b_ds_types, lon_c, lat_c, time_s, fl_alt_s, ice_int_s, unq_ids)


def analyze(ice_dct, no_ice_dct):
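    """Collect the GOES L1b files matched to icing and no-icing report times and
    print, in time order, the start-time string from each selected filename."""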

    last_file = None
    ice_files = []
    ice_times = []
    for ts in list(ice_dct.keys()):
        try:
            ds = get_goes_datasource(ts)
            goes_file, t_0, _ = ds.get_file(ts)
            if goes_file is not None and goes_file != last_file:
                ice_files.append(goes_file)
                ice_times.append(t_0)
                last_file = goes_file
        except Exception:
            continue

    last_file = None
    no_ice_files = []
    no_ice_times = []
    for ts in list(no_ice_dct.keys()):
        try:
            ds = get_goes_datasource(ts)
            goes_file, t_0, _ = ds.get_file(ts)
            if goes_file is not None and goes_file != last_file:
                no_ice_files.append(goes_file)
                no_ice_times.append(t_0)
                last_file = goes_file
        except Exception:
            continue

    ice_times = np.array(ice_times)
    no_ice_times = np.array(no_ice_times)

    # GOES times common to both the ice and no-ice sets
    itrsct_vals, comm1, comm2 = np.intersect1d(no_ice_times, ice_times, return_indices=True)

    ice_indexes = np.arange(len(ice_times))

    # ice-only indexes (not shared with the no-ice set); take a random subset
    ucomm2 = np.setxor1d(comm2, ice_indexes)
    np.random.seed(42)
    np.random.shuffle(ucomm2)
    ucomm2 = ucomm2[0:8000]

    files_comm = []
    for i in comm2:
        files_comm.append(ice_files[i])

    files_extra = []
    times_extra = []
    for i in ucomm2:
        files_extra.append(ice_files[i])
        times_extra.append(ice_times[i])

    files = files_comm + files_extra
    times = itrsct_vals.tolist() + times_extra
    times = np.array(times)

    sidxs = np.argsort(times)
    for i in sidxs:
        filename = os.path.split(files[i])[1]
        so = re.search(r'_s\d{11}', filename)
        dt_str = so.group()
        print(dt_str[2:])


def process_2(ice_dct, no_ice_dct, neg_ice_dct):
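    """Subsample the report dictionaries by icing intensity without requiring GOES
    file matches: keep all intensity 1, 3, 4, and 5/6 reports, and cap intensity 2
    at 30000 keys, no-ice at 50000, and negative-ice at 5000."""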
    new_ice_dct = {}
    new_no_ice_dct = {}
    new_neg_ice_dct = {}

    ice_keys_5_6 = []
    ice_keys_1 = []
    ice_keys_4 = []
    ice_keys_3 = []
    ice_keys_2 = []

    print('num keys ice, no_ice, neg_ice: ', len(ice_dct), len(no_ice_dct), len(neg_ice_dct))
    no_intensity_cnt = 0

    for ts in list(ice_dct.keys()):
        rpts = ice_dct[ts]
        for tup in rpts:
            if tup[3] == 5 or tup[3] == 6:
                ice_keys_5_6.append(ts)
            elif tup[3] == 1:
                ice_keys_1.append(ts)
            elif tup[3] == 4:
                ice_keys_4.append(ts)
            elif tup[3] == 3:
                ice_keys_3.append(ts)
            elif tup[3] == 2:
                ice_keys_2.append(ts)
            else:
                no_intensity_cnt += 1

    no_ice_keys = []
    for ts in list(no_ice_dct.keys()):
        rpts = no_ice_dct[ts]
        for tup in rpts:
            no_ice_keys.append(ts)

    neg_ice_keys = []
    for ts in list(neg_ice_dct.keys()):
        rpts = neg_ice_dct[ts]
        for tup in rpts:
            neg_ice_keys.append(ts)

    ice_keys_5_6 = np.array(ice_keys_5_6)
    print('5_6: ', ice_keys_5_6.shape)

    ice_keys_4 = np.array(ice_keys_4)
    print('4: ', ice_keys_4.shape)

    ice_keys_3 = np.array(ice_keys_3)
    print('3: ', ice_keys_3.shape)

    ice_keys_2 = np.array(ice_keys_2)
    print('2: ', ice_keys_2.shape)
    np.random.seed(42)
    np.random.shuffle(ice_keys_2)
    ice_keys_2 = ice_keys_2[0:30000]

    ice_keys_1 = np.array(ice_keys_1)
    print('1: ', ice_keys_1.shape)
    print('no intensity: ', no_intensity_cnt)

    ice_keys = np.concatenate([ice_keys_5_6, ice_keys_1, ice_keys_2, ice_keys_3, ice_keys_4])
    uniq_sorted_keys = np.unique(ice_keys)
    print('ice: ', ice_keys.shape, uniq_sorted_keys.shape)

    uniq_sorted_keys = uniq_sorted_keys.tolist()
    for key in uniq_sorted_keys:
        new_ice_dct[key] = ice_dct[key]

    no_ice_keys = np.array(no_ice_keys)
    print('no ice total: ', no_ice_keys.shape)
    np.random.seed(42)
    np.random.shuffle(no_ice_keys)
    no_ice_keys = no_ice_keys[0:50000]
    uniq_sorted_no_ice = np.unique(no_ice_keys)
    print('no ice: ', no_ice_keys.shape, uniq_sorted_no_ice.shape)

    uniq_sorted_no_ice = uniq_sorted_no_ice.tolist()
    for key in uniq_sorted_no_ice:
        new_no_ice_dct[key] = no_ice_dct[key]

    neg_ice_keys = np.array(neg_ice_keys)
    print('neg ice total: ', neg_ice_keys.shape)
    np.random.seed(42)
    np.random.shuffle(neg_ice_keys)
    neg_ice_keys = neg_ice_keys[0:5000]
    uniq_sorted_neg_ice = np.unique(neg_ice_keys)
    print('neg ice: ', neg_ice_keys.shape, uniq_sorted_neg_ice.shape)

    for key in uniq_sorted_neg_ice:
        new_neg_ice_dct[key] = neg_ice_dct[key]

    return new_ice_dct, new_no_ice_dct, new_neg_ice_dct


def process_1(ice_dct, no_ice_dct, neg_ice_dct):
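    """Like process_2, but only counts reports whose time maps to a new GOES L1b
    file; subsamples intensity-2, no-ice, and negative-ice keys to 30000, 50000,
    and 5000 respectively."""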
    new_ice_dct = {}
    new_no_ice_dct = {}
    new_neg_ice_dct = {}

    last_file = None
    ice_files_5_6 = []
    ice_times_5_6 = []
    ice_keys_5_6 = []
    ice_files_1 = []
    ice_times_1 = []
    ice_keys_1 = []
    ice_files_4 = []
    ice_times_4 = []
    ice_keys_4 = []
    ice_files_3 = []
    ice_times_3 = []
    ice_keys_3 = []
    ice_files_2 = []
    ice_times_2 = []
    ice_keys_2 = []
    print('num keys ice, no_ice, neg_ice: ', len(ice_dct), len(no_ice_dct), len(neg_ice_dct))

    for ts in list(ice_dct.keys()):
        try:
            ds = get_goes_datasource(ts)
            goes_file, t_0, _ = ds.get_file(ts)
            if goes_file is not None and goes_file != last_file:
                rpts = ice_dct[ts]
                for tup in rpts:
                    if tup[3] == 5 or tup[3] == 6:
                        ice_files_5_6.append(goes_file)
                        ice_times_5_6.append(t_0)
                        ice_keys_5_6.append(ts)
                    elif tup[3] == 1:
                        ice_files_1.append(goes_file)
                        ice_times_1.append(t_0)
                        ice_keys_1.append(ts)
                    elif tup[3] == 4:
                        ice_files_4.append(goes_file)
                        ice_times_4.append(t_0)
                        ice_keys_4.append(ts)
                    elif tup[3] == 3:
                        ice_files_3.append(goes_file)
                        ice_times_3.append(t_0)
                        ice_keys_3.append(ts)
                    else:  # intensity 2, plus reports with no intensity value
                        ice_files_2.append(goes_file)
                        ice_times_2.append(t_0)
                        ice_keys_2.append(ts)
                last_file = goes_file
        except Exception:
            continue

    last_file = None
    no_ice_files = []
    no_ice_times = []
    no_ice_keys = []
    for ts in list(no_ice_dct.keys()):
        try:
            ds = get_goes_datasource(ts)
            goes_file, t_0, _ = ds.get_file(ts)
            if goes_file is not None and goes_file != last_file:
                rpts = no_ice_dct[ts]
                for tup in rpts:
                    no_ice_files.append(goes_file)
                    no_ice_times.append(t_0)
                    no_ice_keys.append(ts)
                last_file = goes_file
        except Exception:
            continue

    last_file = None
    neg_ice_files = []
    neg_ice_times = []
    neg_ice_keys = []
    for ts in list(neg_ice_dct.keys()):
        try:
            ds = get_goes_datasource(ts)
            goes_file, t_0, _ = ds.get_file(ts)
            if goes_file is not None and goes_file != last_file:
                rpts = neg_ice_dct[ts]
                for tup in rpts:
                    neg_ice_files.append(goes_file)
                    neg_ice_times.append(t_0)
                    neg_ice_keys.append(ts)
                last_file = goes_file
        except Exception:
            continue

    ice_times_5_6 = np.array(ice_times_5_6)
    ice_keys_5_6 = np.array(ice_keys_5_6)
    print('5_6: ', ice_times_5_6.shape)

    ice_times_4 = np.array(ice_times_4)
    ice_keys_4 = np.array(ice_keys_4)
    print('4: ', ice_times_4.shape)

    ice_times_3 = np.array(ice_times_3)
    ice_keys_3 = np.array(ice_keys_3)
    print('3: ', ice_times_3.shape)

    ice_times_2 = np.array(ice_times_2)
    ice_keys_2 = np.array(ice_keys_2)
    print('2: ', ice_times_2.shape)
    np.random.seed(42)
    np.random.shuffle(ice_times_2)
    np.random.seed(42)
    np.random.shuffle(ice_keys_2)
    ice_keys_2 = ice_keys_2[0:30000]

    ice_times_1 = np.array(ice_times_1)
    ice_keys_1 = np.array(ice_keys_1)
    print('1: ', ice_times_1.shape)

    ice_times = np.concatenate([ice_times_5_6, ice_times_1, ice_times_2, ice_times_3, ice_times_4])
    ice_keys = np.concatenate([ice_keys_5_6, ice_keys_1, ice_keys_2, ice_keys_3, ice_keys_4])
    uniq_sorted = np.unique(ice_times)
    uniq_sorted_keys = np.unique(ice_keys)
    print(ice_times.shape, uniq_sorted.shape)
    print(ice_keys.shape, uniq_sorted_keys.shape)

    uniq_sorted_keys = uniq_sorted_keys.tolist()
    for key in uniq_sorted_keys:
        new_ice_dct[key] = ice_dct[key]

    no_ice_times = np.array(no_ice_times)
    neg_ice_times = np.array(neg_ice_times)
    print('no ice: ', no_ice_times.shape)
    print('neg ice: ', neg_ice_times.shape)
    no_ice_keys = np.array(no_ice_keys)
    np.random.seed(42)
    np.random.shuffle(no_ice_keys)
    no_ice_keys = no_ice_keys[0:50000]
    uniq_sorted_no_ice = np.unique(no_ice_keys)
    print(no_ice_keys.shape, uniq_sorted_no_ice.shape)

    uniq_sorted_no_ice = uniq_sorted_no_ice.tolist()
    for key in uniq_sorted_no_ice:
        new_no_ice_dct[key] = no_ice_dct[key]

    neg_ice_keys = np.array(neg_ice_keys)
    np.random.seed(42)
    np.random.shuffle(neg_ice_keys)
    neg_ice_keys = neg_ice_keys[0:5000]
    uniq_sorted_neg_ice = np.unique(neg_ice_keys)
    print(neg_ice_keys.shape, uniq_sorted_neg_ice.shape)

    for key in uniq_sorted_neg_ice:
        new_neg_ice_dct[key] = neg_ice_dct[key]

    return new_ice_dct, new_no_ice_dct, new_neg_ice_dct


def run_qc(filename, filename_l1b, outfile, outfile_l1b):
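    """Screen a previously written sample file pair with apply_qc_icing_pireps,
    keep PIREPs where more than 20% of the center 20x20 window passes, print
    diagnostic histograms, and write the filtered L2 and L1b files."""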
    f = h5py.File(filename, 'r')
    icing_alt = f['icing_altitude'][:]
    cld_top_hgt = f['cld_height_acha'][:, 10:30, 10:30]
    cld_phase = f['cloud_phase'][:, 10:30, 10:30]
    cld_opd = f['cld_opd_acha'][:, 10:30, 10:30]
    cld_opd_dc = f['cld_opd_dcomp'][:, 10:30, 10:30]
    cld_mask = f['cloud_mask'][:, 10:30, 10:30]

    f_l1b = h5py.File(filename_l1b, 'r')
    bt_11um = f_l1b['temp_11_0um_nom'][:, 10:30, 10:30]

    print('num pireps: ', len(icing_alt))

    mask = apply_qc_icing_pireps(icing_alt, cld_top_hgt, cld_phase, cld_opd, cld_mask, bt_11um)

    bts = []
    phs = []
    opd = []
    opd_dc = []
    keep_idxs = []
    for i in range(len(mask)):
        # Require > 20% of the 20x20 (400 pixel) window to pass QC
        if (np.sum(mask[i]) / 400) > 0.20:
            bts.append((bt_11um[i,].flatten())[mask[i]])
            phs.append((cld_phase[i,].flatten())[mask[i]])
            opd.append((cld_opd[i,].flatten())[mask[i]])
            # opd_dc.append((cld_opd_dc[i,].flatten())[mask[i]])
            keep_idxs.append(i)
    print('num valid pireps: ', len(bts))

    bts = np.concatenate(bts)
    phs = np.concatenate(phs)
    opd = np.concatenate(opd)
    #opd_dc = np.concatenate(opd_dc)

    print(np.histogram(bts, bins=10))
    print(np.histogram(opd, bins=10))
    #print(np.histogram(opd_dc, bins=10))
    print(np.histogram(phs, bins=6))

    keep_idxs = np.array(keep_idxs)

    data_dct = {}
    for ds_name in ds_list:
        data_dct[ds_name] = f[ds_name][keep_idxs]

    lon_c = f['longitude'][keep_idxs]
    lat_c = f['latitude'][keep_idxs]
    time_s = f['time'][keep_idxs]
    fl_alt_s = f['icing_altitude'][keep_idxs]
    ice_int_s = f['icing_intensity'][keep_idxs]
    unq_ids = f['unique_id'][keep_idxs]

    create_file(outfile, data_dct, ds_list, ds_types, lon_c, lat_c, time_s, fl_alt_s, ice_int_s, unq_ids)

    data_dct = {}
    for ds_name in l1b_ds_list:
        data_dct[ds_name] = f_l1b[ds_name][keep_idxs]

    create_file(outfile_l1b, data_dct, l1b_ds_list, l1b_ds_types, lon_c, lat_c, time_s, fl_alt_s, ice_int_s, unq_ids)

    f.close()
    f_l1b.close()

    return mask


def apply_qc_icing_pireps(icing_alt, cld_top_hgt, cld_phase, cld_opd, cld_mask, bt_11um):
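    """Build a per-PIREP boolean pixel mask: cloudy, valid-data pixels that also
    pass the cloud-top height, phase, optical depth, and 11um brightness
    temperature screens below. Returns a list of flattened masks."""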
    opd_threshold = 2
    closeness = 100.0  # meters
    num_obs = len(icing_alt)
    cld_mask = cld_mask.reshape((num_obs, -1))
    cld_top_hgt = cld_top_hgt.reshape((num_obs, -1))
    cld_phase = cld_phase.reshape((num_obs, -1))
    cld_opd = cld_opd.reshape((num_obs, -1))
    bt_11um = bt_11um.reshape((num_obs, -1))

    skip = False  # set True to apply only the basic cloudy/valid-data screen below
    mask = []
    for i in range(num_obs):
        keep_0 = np.logical_or(cld_mask[i,] == 2, cld_mask[i,] == 3)  # cloudy
        keep_1 = np.invert(np.isnan(cld_top_hgt[i,]))
        keep_2 = np.invert(np.isnan(bt_11um[i,]))
        keep_3 = np.invert(np.isnan(cld_opd[i,]))
        keep = keep_0 & keep_1 & keep_2 & keep_3
        if skip:
            mask.append(keep)
            continue

        # 1) Cloud top must be above the reported icing altitude
        keep = np.where(keep, cld_top_hgt[i,] > icing_alt[i], False)

        # 2) Reject phase-4 cloud whose top is within 'closeness' meters of the icing altitude
        keep = np.where(keep,
            np.invert((cld_phase[i,] == 4) &
                np.logical_and(cld_top_hgt[i,]+closeness > icing_alt[i], cld_top_hgt[i,]-closeness < icing_alt[i])),
                     False)

        # 3) Require opd >= opd_threshold for phase-4 cloud above the icing altitude
        keep = np.where(keep, (cld_opd[i,] >= opd_threshold) & (cld_phase[i,] == 4) & (cld_top_hgt[i,] > icing_alt[i]), False)

        # 4) Reject optically thin (opd < 0.1) phase-4 cloud above the icing altitude
        keep = np.where(keep, np.invert((cld_phase[i,] == 4) & (cld_opd[i,] < 0.1) & (cld_top_hgt[i,] > icing_alt[i])), False)

        # 5) Keep only 11um brightness temperatures in [228 K, 270 K]
        keep = np.where(keep, np.invert(bt_11um[i,] > 270.0), False)

        keep = np.where(keep, np.invert(bt_11um[i,] < 228.0), False)

        mask.append(keep)

    return mask
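

if __name__ == '__main__':
    # Minimal end-to-end sketch (an assumed driver, not from the source; output
    # paths are illustrative): build the report dictionaries, subsample them,
    # then extract grids and write the sample files.
    ice_dct, no_ice_dct, neg_ice_dct = setup()
    ice_dct, no_ice_dct, neg_ice_dct = process_2(ice_dct, no_ice_dct, neg_ice_dct)
    run(ice_dct, outfile='/tmp/icing_l2.h5', outfile_l1b='/tmp/icing_l1b.h5')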