import datetime, os
from datetime import timezone
import glob
import numpy as np
import xarray as xr
from netCDF4 import Dataset
from aeolus.geos_nav import get_navigation
from aeolus.datasource import get_datasource, CLAVRx_CALIPSO, CLAVRx
from util.util import haversine_np, LatLonTuple, add_time_range_to_filename
from amv.intercompare import best_fit, bin_data_by, get_press_bin_ranges, spd_dir_from_uv, uv_from_spd_dir, \
    direction_difference, run_best_fit_gfs
import math
from metpy.units import units
from util.gfs_reader import get_vert_profile_s
from amv.intercompare import get_raob_dict_cdf
from util.line_plot import do_plot
import pickle


# Nominal duration of a single AMV file, in minutes.
amv_file_duration = 60
# Half-width of the search box centered on a profile location (FGF coordinates).
half_width = 50
# Image dimensions: number of elements and number of lines.
num_elems, num_lines = 5424, 5424


def get_amvs(amv_ds, timestamp, filepath=None):
    """Read the valid AMVs of one AMV file as a 2-D array with one row per AMV.

    Columns are, in order: latitude, longitude, element, line, then the values
    of ``amv_ds.get_parameters()``. When the data source has no element variable
    (``amv_ds.elem_name is None``), element/line are computed from lon/lat with
    the data source's navigation and inserted as columns 2 and 3.

    An AMV is kept only if none of the requested variables is masked and, when
    ``amv_ds.get_qc_params()`` names a QC variable, ``amv_ds.filter`` accepts it.

    Args:
        amv_ds: AMV data source.
        timestamp: time used to look up the AMV file when ``filepath`` is None.
        filepath: explicit AMV file; overrides the lookup by ``timestamp``.

    Returns:
        ndarray of shape (num_valid_amvs, num_columns), or None if no file was found.
    """
    if filepath is None:
        filepath, _, _ = amv_ds.get_file(timestamp)
    if filepath is None:
        return None

    # TODO: Need to generalize this better
    amv_params = [amv_ds.lat_name, amv_ds.lon_name, amv_ds.elem_name, amv_ds.line_name] + amv_ds.get_parameters()

    param_s = []
    vld = None
    with Dataset(filepath) as ds:
        for name in amv_params:
            if name is None:
                # e.g. element/line variables not provided by this data source
                continue
            param = ds[name][:]
            # getmaskarray always returns a full boolean array. ``param.mask`` may be the
            # scalar ``nomask``, which would turn ``param_nd[:, vld]`` into an axis insert.
            valid = np.invert(np.ma.getmaskarray(param))
            vld = valid if vld is None else np.logical_and(vld, valid)
            param_s.append(np.ma.getdata(param))

        # filter
        qc_name = amv_ds.get_qc_params()
        if qc_name is not None:
            qc_param = np.ma.getdata(ds[qc_name][:])
            vld = np.logical_and(vld, amv_ds.filter(qc_param))

    param_nd = np.vstack(param_s)[:, vld]
    param_nd = np.transpose(param_nd, axes=[1, 0])

    if amv_ds.elem_name is None:
        nav = amv_ds.get_navigation()
        # navigation takes (lon, lat): column 1 is longitude, column 0 latitude
        cc, ll = nav.earth_to_lc_s(param_nd[:, 1], param_nd[:, 0])
        param_nd = np.insert(param_nd, 2, cc, axis=1)
        param_nd = np.insert(param_nd, 3, ll, axis=1)

    return param_nd


def match_amvs_to_raobs(raob_dict, raob_time, amv_ds, time_window=10, filepath=None):
    """For each RAOB location, collect the AMVs that fall inside a search box around it.

    Args:
        raob_dict: RAOB key -> (profile, station_id). Keys expose ``.lat``/``.lon``.
            Entries whose profile is None are skipped.
        raob_time: nominal time of the RAOBs in ``raob_dict``.
        amv_ds: AMV data source.
        time_window: (minutes) to search for an AMV file around ``raob_time``.
        filepath: AMV file to use. If None (default), determined from ``raob_time``
            and the AMV data source.

    Returns:
        (match_dict, filepath). ``match_dict`` maps each RAOB key that has at least
        one AMV in its search box to a 2-D array with one column per AMV. Its rows
        are: lon, lat, element, line, then the ``amv_ds.get_parameters()`` values
        (a 'V_3D' parameter contributes two rows). Returns (None, None) if no AMV
        file was found.
    """
    nav = amv_ds.get_navigation()
    amv_params = amv_ds.get_parameters()

    if filepath is None:
        filepath, _, _ = amv_ds.get_file(raob_time, window=time_window)
    if filepath is None:
        return None, None

    with Dataset(filepath) as nc4f:
        amv_lons = np.ma.getdata(nc4f[amv_ds.lon_name][:])
        amv_lats = np.ma.getdata(nc4f[amv_ds.lat_name][:])
        if amv_ds.elem_name is None or amv_ds.line_name is None:
            cc, ll = nav.earth_to_lc_s(amv_lons, amv_lats)
        else:
            cc = np.ma.getdata(nc4f[amv_ds.elem_name][:])
            ll = np.ma.getdata(nc4f[amv_ds.line_name][:])

        param_s = [amv_lons, amv_lats, cc, ll]
        vld = None
        for param in amv_params:
            data = nc4f[param][:]
            # Full boolean mask even when nothing is masked (``data.mask`` may be scalar).
            valid = np.invert(np.ma.getmaskarray(data))
            values = np.ma.getdata(data)
            if param == 'V_3D':
                # Two component rows are emitted; an AMV is valid only if both are,
                # so the validity stays 1-D (one entry per AMV).
                valid = np.logical_and(valid[0], valid[1])
                param_s.append(values[0])
                param_s.append(values[1])
            else:
                param_s.append(values)
            vld = valid if vld is None else np.logical_and(vld, valid)

        # filter
        qc_name = amv_ds.get_qc_params()
        if qc_name is not None:
            good = amv_ds.filter(np.ma.getdata(nc4f[qc_name][:]))
            # vld = np.logical_and(vld, good)
            # NOTE(review): the QC result replaces, rather than combines with, the
            # fill-value mask computed above - confirm this is intended.
            vld = good

    if vld is None:
        # no parameters and no QC variable: keep every AMV
        vld = np.ones(amv_lons.shape[0], dtype=bool)

    param_nd = np.vstack(param_s)[:, vld]
    cc = param_nd[2, :]
    ll = param_nd[3, :]

    match_dict = {}
    for key, (raob, sta_id) in raob_dict.items():
        if raob is None:
            continue

        cc_raob, ll_raob = nav.earth_to_lc(key.lon, key.lat)
        if cc_raob is None or ll_raob is None:
            continue

        c_rng, l_rng = get_search_box(cc_raob, ll_raob)

        # strict inequalities: AMVs on the box edge are excluded
        in_box = (cc > c_rng[0]) & (cc < c_rng[1]) & (ll > l_rng[0]) & (ll < l_rng[1])
        if not np.any(in_box):
            continue
        match_dict[key] = param_nd[:, in_box]
    print('Done: ', filepath)

    return match_dict, filepath


def create_file2(filename, raob_to_amv_dct, raob_dct, amv_files):
    """Write RAOB profiles and their matched AMVs to a NetCDF4 file.

    Args:
        filename: output path, opened in 'w' mode.
        raob_to_amv_dct: RAOB key -> 2-D AMV array with one column per AMV. Row i is
            written to the i-th variable of ``amv_files.get_out_parameters()``. Rows
            0/1 are passed to ``haversine_np`` as lon/lat for ``dist_to_raob``.
        raob_dct: RAOB key -> 2-D profile array with one row per level. Columns 0..3
            are written as pressure, temperature, azimuth and speed.
            NOTE(review): ``run_best_fit`` unpacks ``raob_dct`` values as
            (profile, station_id) tuples, while this function expects bare arrays.
            Confirm which layout callers pass here.
        amv_files: AMV data source providing ``get_out_parameters()`` and
            ``get_meta_dict()`` (param -> (units or None, netCDF type)).

    Keys must expose ``.lat`` and ``.lon``. All ``time`` values are written as 0.0
    (see TODO).
    """
    keys = list(raob_to_amv_dct.keys())

    num_amvs = [raob_to_amv_dct.get(key).shape[1] for key in keys]
    num_levs = [raob_dct.get(key).shape[0] for key in keys]
    # TODO: need a time
    times = [0.0] * len(keys)

    out_params = amv_files.get_out_parameters()
    meta_dict = amv_files.get_meta_dict()

    with Dataset(filename, 'w', format='NETCDF4') as rootgrp:
        rootgrp.createDimension('amvs', size=sum(num_amvs))
        rootgrp.createDimension('raobs', size=sum(num_levs))
        rootgrp.createDimension('num_raobs', size=len(keys))

        nc4_vars = []
        for param in out_params:
            u, t = meta_dict.get(param)
            var = rootgrp.createVariable(param, t, ['amvs'])
            if u is not None:
                var.units = u
            nc4_vars.append(var)

        dist = rootgrp.createVariable('dist_to_raob', 'f4', ['amvs'])
        dist.units = 'km'

        num_amvs_per_raob = rootgrp.createVariable('num_amvs_per_raob', 'i4', ['num_raobs'])
        num_levs_per_raob = rootgrp.createVariable('num_levs_per_raob', 'i4', ['num_raobs'])
        prof_time = rootgrp.createVariable('time', 'f4', ['num_raobs'])
        prof_time.units = 'seconds since 1970-01-1 00:00:00'
        # ---- Profile variables ---------------
        prf_lon = rootgrp.createVariable('raob_longitude', 'f4', ['num_raobs'])
        prf_lon.units = 'degrees east'
        prf_lat = rootgrp.createVariable('raob_latitude', 'f4', ['num_raobs'])
        prf_lat.units = 'degrees north'

        prf_azm = rootgrp.createVariable('prof_azm', 'f4', ['raobs'])
        prf_azm.units = 'degree'
        prf_spd = rootgrp.createVariable('prof_spd', 'f4', ['raobs'])
        prf_spd.units = 'm s-1'
        prf_prs = rootgrp.createVariable('prof_pres', 'f4', ['raobs'])
        prf_prs.units = 'hPa'
        prf_tmp = rootgrp.createVariable('prof_temp', 'f4', ['raobs'])
        prf_tmp.units = 'K'
        # --------------------------------------

        # i_a / i_c: running offsets into the flattened 'amvs' / 'raobs' dimensions
        i_a = 0
        i_c = 0
        for key, n_amv, n_lev in zip(keys, num_amvs, num_levs):
            prof = raob_dct.get(key)
            prf_azm[i_c:i_c + n_lev] = prof[:, 2]
            prf_spd[i_c:i_c + n_lev] = prof[:, 3]
            prf_prs[i_c:i_c + n_lev] = prof[:, 0]
            prf_tmp[i_c:i_c + n_lev] = prof[:, 1]
            i_c += n_lev

            param_nd = raob_to_amv_dct.get(key)
            for pidx, var in enumerate(nc4_vars):
                var[i_a:i_a + n_amv] = param_nd[pidx, :]
            dist[i_a:i_a + n_amv] = haversine_np(key.lon, key.lat, param_nd[0, :], param_nd[1, :])
            i_a += n_amv

        # One write per variable. The original tail-slice writes were quadratic.
        prf_lat[:] = [key.lat for key in keys]
        prf_lon[:] = [key.lon for key in keys]
        num_amvs_per_raob[:] = num_amvs
        num_levs_per_raob[:] = num_levs
        prof_time[:] = times


def run_best_fit_driver(output_path, amv_dir, source, raob_path, gfs_path, product_dir=None, product=None):
    """Match AMVs to RAOBs file by file, compute best fits, and write the results.

    For every file of the RAOB data source, the driver:
      - finds the AMVs near each RAOB (``match_amvs_to_raobs``);
      - samples GFS wind profiles at the RAOB locations;
      - runs ``run_best_fit``;
      - when ``product_dir`` is given, samples the product at the matched AMVs.

    A per-time file ``<output_path>amv_bestfit.nc`` (time added to the name) is
    written by ``create_bestfit_file``. All results are then concatenated and
    pickled to ``<output_path>amv_bestfit.pkl`` (time range added to the name) as
    ``(amvs, bfs, prd_or_None, bfs_gfs, raob_matches)``.

    A time is skipped when it has no RAOB locations, no AMV file, no GFS file,
    or (when a product was requested) no product file. If nothing was matched
    at all, no pickle is written.

    ``output_path`` is used as a plain string prefix, so it should end with a
    path separator when it names a directory.
    """
    amv_ds = get_datasource(amv_dir, source)
    gfs_files = get_datasource(gfs_path, 'GFS')
    raob_ds = get_datasource(raob_path, 'RAOB')
    prd_files = None
    if product_dir is not None:
        prd_files = get_datasource(product_dir, product)

    amvs_list = []
    bfs_list = []
    rb_list = []
    bfs_gfs_list = []
    prd_list = []

    ts_first, ts_last = None, None

    for k, raob_filename in enumerate(raob_ds.flist):
        raob_dct = get_raob_dict_cdf(raob_filename)
        raob_locs = list(raob_dct.keys())
        if not raob_locs:
            continue
        ts = raob_ds.ftimes[k, 0]
        m_d, amv_filename = match_amvs_to_raobs(raob_dct, ts, amv_ds)
        if m_d is None:
            continue

        gfs_file = gfs_files.get_file(ts)[0]
        if gfs_file is None:
            continue

        if ts_first is None:
            ts_first = ts
        ts_last = ts

        gfs_at_raob_dct = _gfs_winds_at_raob_locs(gfs_file, raob_locs)
        bf_dct = run_best_fit(m_d, raob_dct, gfs_at_raob_dct)

        prd_dct = None
        if prd_files is not None:
            prd_dct = get_product_at_locs(m_d, ts, prd_files)
            if prd_dct is None:
                continue

        outfile = add_time_range_to_filename(output_path + 'amv_bestfit.nc', ts, None)
        create_bestfit_file(outfile, m_d, raob_dct, ts, gfs_at_raob_dct, bf_dct, prd_dct,
                            amv_ds.get_parameters(), raob_filename, amv_filename)

        # run_best_fit values: (amvs, best fits to RAOB, RAOB matches, best fits to GFS)
        for amvs_k, bfs_k, rbm_k, bfs_gfs_k in bf_dct.values():
            amvs_list.append(amvs_k)
            bfs_list.append(bfs_k)
            rb_list.append(rbm_k)
            bfs_gfs_list.append(bfs_gfs_k)
        if prd_dct is not None:
            prd_list.extend(prd_dct.values())

    if not amvs_list:
        # np.concatenate raises on an empty list
        print('run_best_fit_driver: no AMV/RAOB matches found; no pickle file written')
        return

    # aggregate and dump to a pickle file
    amvs = np.concatenate(amvs_list)
    bfs = np.concatenate(bfs_list)
    rbm = np.concatenate(rb_list)
    bfs_gfs = np.concatenate(bfs_gfs_list)
    prd = np.concatenate(prd_list) if prd_list else None
    tup = (amvs, bfs, prd, bfs_gfs, rbm)

    outfile = add_time_range_to_filename(output_path + 'amv_bestfit.pkl', ts_first, ts_last)
    with open(outfile, 'wb') as f:
        pickle.dump(tup, f)


def _gfs_winds_at_raob_locs(gfs_file, raob_locs):
    """Return {loc: (gfs_spd, gfs_dir, gfs_press)} sampled from ``gfs_file`` at each RAOB location.

    Profiles are taken at the nearest grid point. Speed, direction and pressure
    are reversed along the vertical axis relative to the file's order.
    ``raob_locs`` are ``LatLonTuple`` keys.
    """
    lon_idx = LatLonTuple._fields.index('lon')
    lat_idx = LatLonTuple._fields.index('lat')
    locs = np.array(raob_locs)

    # .values loads the data into memory, so the dataset can be closed afterwards.
    with xr.open_dataset(gfs_file) as xr_dataset:
        gfs_press = xr_dataset['pressure levels'].values[::-1]
        uv_wind = get_vert_profile_s(xr_dataset, ['u-wind', 'v-wind'], locs[:, lon_idx], locs[:, lat_idx],
                                     method='nearest').values

    wspd, wdir = spd_dir_from_uv(uv_wind[0, :, :], uv_wind[1, :, :])
    gfs_spd = wspd.magnitude[:, ::-1]
    gfs_dir = wdir.magnitude[:, ::-1]

    return {loc: (gfs_spd[i], gfs_dir[i], gfs_press) for i, loc in enumerate(raob_locs)}


def run_best_fit_gfs_driver(time_s, amv_dir, amv_source, gfs_path, product_dir, product_type):
    """Best-fit AMVs against GFS for each requested time and collect the results.

    Args:
        time_s: iterable of times to process.
        amv_dir, amv_source: AMV data source location and type.
        gfs_path: GFS data source location.
        product_dir, product_type: product data source. When ``product_dir`` is
            None, no product values are sampled.

    Returns:
        A 4-tuple (amvs, bfs, prd, times):
          - amvs: concatenated ``get_amvs`` rows;
          - bfs: ``run_best_fit_gfs`` results;
          - prd: product values at the AMV locations, or None when no product
            was requested;
          - times: the processed times.
        All four entries are None when no time could be processed.
        Times without an AMV file, a GFS file or (if requested) a product file
        are skipped.
    """
    amv_files = get_datasource(amv_dir, amv_source)
    gfs_files = get_datasource(gfs_path, 'GFS')
    prd_files = None
    if product_dir is not None:
        prd_files = get_datasource(product_dir, product_type)

    out_time_s = []
    amvs_list = []
    bfs_list = []
    prd_list = []

    for ts in time_s:
        amvs = get_amvs(amv_files, ts)
        if amvs is None:
            continue
        gfs_file = gfs_files.get_file(ts)[0]
        if gfs_file is None:
            continue
        bfs = run_best_fit_gfs(amvs, gfs_file, amv_lat_idx=0, amv_lon_idx=1, amv_prs_idx=4, amv_spd_idx=5,
                               amv_dir_idx=6)

        if prd_files is not None:
            # get_amvs puts latitude in column 0 and longitude in column 1 (matching
            # amv_lat_idx/amv_lon_idx above).
            alats = amvs[:, 0]
            alons = amvs[:, 1]
            prds = get_product_at_lat_lons(prd_files, ts, alons, alats, filepath=None)
            if prds is None:
                continue
            prd_list.append(prds)

        out_time_s.append(ts)
        amvs_list.append(amvs)
        bfs_list.append(np.array(bfs))

    if not out_time_s:
        return None, None, None, None

    out_time_s = np.array(out_time_s)
    amvs = np.concatenate(amvs_list)
    bfs = np.concatenate(bfs_list)
    prd = np.concatenate(prd_list) if prd_list else None
    return amvs, bfs, prd, out_time_s


def get_product_at_locs(raob_to_amv_dct, ts, files, filepath=None):
    """Sample product fields at the pixel of every matched AMV.

    Args:
        raob_to_amv_dct: RAOB key -> 2-D AMV array with one column per AMV. Rows 0/1
            are passed to the navigation as lon/lat.
        ts: time used to look up the product file when ``filepath`` is None.
        files: product data source (navigation, parameter names, file lookup).
        filepath: explicit product file; overrides the lookup by ``ts``.

    Returns:
        RAOB key -> float64 array of shape (num_amvs, num_params), or None if no
        product file was found. Masks are ignored: raw stored values are returned.
    """
    nav = files.get_navigation()
    the_params = files.get_parameters()
    num_params = len(the_params)

    if filepath is None:
        filepath, _, _ = files.get_file(ts)
        if filepath is None:
            return None

    # Load every field once, then release the file.
    # Presumably each field is 2-D (line, element); TODO confirm.
    with Dataset(filepath) as ds:
        param_nd = np.stack([np.ma.getdata(ds[pstr][:]) for pstr in the_params])

    m_dct = {}
    for key, amvs in raob_to_amv_dct.items():
        num_amvs = amvs.shape[1]
        cc, ll = nav.earth_to_lc_s(amvs[0, :], amvs[1, :])

        aaa = np.zeros((num_params, num_amvs), dtype=np.float64)
        if num_amvs > 0:
            # Pointwise gather: aaa[:, k] = param_nd[:, ll[k], cc[k]] for every k.
            aaa[:] = param_nd[:, ll, cc]

        m_dct[key] = np.transpose(aaa, axes=[1, 0])

    return m_dct


def get_product_at_locs1x1(raob_to_amv_dct, ts, files, filepath=None):
    """Sample product variables at each matched AMV's pixel, reading one value at a time.

    Unlike ``get_product_at_locs``, fields are not loaded in full. Each value is
    read individually from the netCDF variable.

    Args:
        raob_to_amv_dct: RAOB key -> 2-D AMV array with one column per AMV. Rows 0/1
            are passed to the navigation as lon/lat.
        ts: time used to look up the product file when ``filepath`` is None.
        files: product data source (navigation, parameter names, file lookup).
        filepath: explicit product file; overrides the lookup by ``ts``.

    Returns:
        RAOB key -> float64 array of shape (num_amvs, num_params), or None if no
        product file was found.
    """
    nav = files.get_navigation()
    the_params = files.get_parameters()
    num_params = len(the_params)

    if filepath is None:
        filepath, _, _ = files.get_file(ts)
        if filepath is None:
            # consistent with get_product_at_locs; the original went on to Dataset(None)
            return None

    m_dct = {}
    with Dataset(filepath) as ds:
        var_s = [ds[pstr] for pstr in the_params]

        for key, amvs in raob_to_amv_dct.items():
            num_amvs = amvs.shape[1]
            cc, ll = nav.earth_to_lc_s(amvs[0, :], amvs[1, :])

            aaa = np.zeros((num_params, num_amvs), dtype=np.float64)
            # Element-wise reads are kept on purpose: netCDF4-python applies integer-array
            # indices orthogonally (outer product), not pointwise like NumPy.
            for vidx, var in enumerate(var_s):
                for k in range(num_amvs):
                    aaa[vidx, k] = var[ll[k], cc[k]].data

            m_dct[key] = np.transpose(aaa, axes=[1, 0])

    return m_dct


def get_product_at_lat_lons(files, ts, lons, lats, filepath=None):
    """Sample product variables at the pixels of the given lon/lat points.

    Args:
        files: product data source (navigation, parameter names, file lookup).
        ts: time used to look up the product file when ``filepath`` is None.
        lons, lats: 1-D arrays of point longitudes and latitudes.
        filepath: explicit product file; overrides the lookup by ``ts``.

    Returns:
        float64 array of shape (num_points, num_params), or None if no product
        file was found.
    """
    nav = files.get_navigation()
    the_params = files.get_parameters()

    if filepath is None:
        filepath, _, _ = files.get_file(ts)
        if filepath is None:
            return None

    num = lons.shape[0]
    cc, ll = nav.earth_to_lc_s(lons, lats)

    aaa = np.zeros((len(the_params), num), dtype=np.float64)
    # The original never closed this Dataset.
    with Dataset(filepath) as ds:
        for vidx, pstr in enumerate(the_params):
            var = ds[pstr]
            # Element-wise reads: netCDF4 integer-array indexing is orthogonal, not pointwise.
            for k in range(num):
                aaa[vidx, k] = var[ll[k], cc[k]].data

    return np.transpose(aaa, axes=[1, 0])


def run_best_fit(raob_to_amv_dct, raob_dct, gfs_at_raob_dct, max_raob_prs_diff=60.0):
    """Compute best fits and nearest-level RAOB matches for every matched AMV.

    Args:
        raob_to_amv_dct: RAOB key -> 2-D AMV array with one column per AMV. Rows
            0, 1, 4, 5, 6 are read as lon, lat, pressure, speed and direction.
        raob_dct: RAOB key -> (profile, station_id). Profile columns 0, 2, 3 are
            read as pressure, direction and speed.
        gfs_at_raob_dct: RAOB key -> (gfs_spd, gfs_dir, gfs_press), or None to
            skip the GFS best fit.
        max_raob_prs_diff: largest |AMV - RAOB level| pressure difference accepted
            for a nearest-level RAOB match. Uses the same units as the pressures;
            the default is 60.0.

    Returns:
        RAOB key -> 4-tuple:
          - the AMV array transposed to one row per AMV;
          - array of ``best_fit`` results against the RAOB;
          - array of RAOB matches: (u, v, prs, 0), or (spd, dir, prs, -9) when
            rejected;
          - array of ``best_fit`` results against GFS (empty when skipped).
    """
    out_dct = {}
    do_gfs_best_fit = gfs_at_raob_dct is not None

    for key, amvs in raob_to_amv_dct.items():
        raob, _sta_id = raob_dct.get(key)
        raob_prs = raob[:, 0]
        raob_spd = raob[:, 3]
        raob_dir = raob[:, 2]
        if do_gfs_best_fit:
            # only look up GFS when it was supplied (it may be None)
            gfs_spd, gfs_dir, gfs_press = gfs_at_raob_dct.get(key)

        bf_list = []
        raob_match_list = []
        bf_gfs_list = []
        for i in range(amvs.shape[1]):
            amv_lon = amvs[0, i]
            amv_lat = amvs[1, i]
            amv_prs = amvs[4, i]
            amv_spd = amvs[5, i]
            amv_dir = amvs[6, i]

            bf_list.append(best_fit(amv_spd, amv_dir, amv_prs, amv_lat, amv_lon, raob_spd, raob_dir, raob_prs))

            # nearest RAOB level in pressure
            pdiff = np.abs(amv_prs - raob_prs)
            lev_idx = np.argmin(pdiff)
            r_spd = raob_spd[lev_idx]
            r_dir = raob_dir[lev_idx]
            r_prs = raob_prs[lev_idx]
            if r_spd < 0 or r_dir < 0 or pdiff[lev_idx] > max_raob_prs_diff:
                # Negative speed/direction (presumably missing) or the level is too
                # far away: keep the raw values and flag -9.
                raob_match_list.append((r_spd, r_dir, r_prs, -9))
            else:
                u, v = uv_from_spd_dir(r_spd, r_dir)
                raob_match_list.append((u.magnitude, v.magnitude, r_prs, 0))

            if do_gfs_best_fit:
                bf_gfs_list.append(best_fit(amv_spd, amv_dir, amv_prs, amv_lat, amv_lon, gfs_spd, gfs_dir, gfs_press))

        out_dct[key] = (np.transpose(amvs, axes=[1, 0]), np.array(bf_list), np.array(raob_match_list),
                        np.array(bf_gfs_list))

    return out_dct


# def analyze_bestfit_product_file(file):
#     xr_ds = xr.load_dataset(file)


def analyze2(amvs, bfs, raob_match, bfs_gfs, amv_prod):
    """Print AMV statistics and their agreement with RAOB and GFS references.

    Prints overall AMV speed/pressure statistics and a low (>= 700), mid
    (400 < p < 700) and high (<= 400) pressure-layer breakdown. It then runs four
    comparisons, each printing pressure, speed, direction and vector-difference
    statistics plus a table binned by AMV pressure (bins from
    ``get_press_bin_ranges(100, 1000, bin_size=100.0)``):
      1. AMVs vs. their level of best fit (LBF) to RAOB (``bfs``),
      2. AMVs vs. their LBF to GFS (``bfs_gfs``),
      3. AMVs vs. the nearest-level RAOB wind (``raob_match``),
      4. LBF to RAOB vs. LBF to GFS, for AMVs valid in both.

    Args:
        amvs: 2-D array with one row per AMV. Columns 4, 5, 6 are read as
            pressure, speed (m/s) and direction.
        bfs: best fits to RAOB, one row per AMV, columns (u, v, pressure, flag).
            Only rows with flag == 0 are used.
        raob_match: nearest-level RAOB matches as built by ``run_best_fit``, one row
            per AMV, columns (u, v, pressure, flag). Only rows with flag == 0 are used.
        bfs_gfs: best fits to GFS, same layout as ``bfs``.
        amv_prod: unused.
    """
    good_amvs = amvs
    num_good = good_amvs.shape[0]
    didx = 6
    sidx = 5
    pidx = 4

    print('Number of AMVs: {0:d}'.format(num_good))
    if num_good == 0:
        # min()/max() below raise on empty arrays
        return

    amv_spd = good_amvs[:, sidx]
    amv_prs = good_amvs[:, pidx]
    print('spd min/max/mean: {0:.2f}  {1:.2f}  {2:.2f}'.format(amv_spd.min(), amv_spd.max(), np.average(amv_spd)))
    print('pres min/max/mean: {0:.2f}  {1:.2f}  {2:.2f}'.format(amv_prs.min(), amv_prs.max(), np.average(amv_prs)))

    low = amv_prs >= 700
    mid = np.logical_and(amv_prs < 700, amv_prs > 400)
    hgh = amv_prs <= 400

    print('% low: {0:.2f}'.format(100.0*(np.sum(low)/num_good)))
    print('% mid: {0:.2f}'.format(100.0*(np.sum(mid)/num_good)))
    print('% hgh: {0:.2f}'.format(100.0*(np.sum(hgh)/num_good)))

    for tag, in_layer in (('Low', low), ('Mid', mid), ('Hgh', hgh)):
        print('---------------------------')
        if not np.any(in_layer):
            # An empty layer would make .min()/.max() raise and abort the whole report.
            print('{0}: no AMVs in this layer'.format(tag))
            continue
        layer_spd = good_amvs[in_layer, sidx]
        layer_prs = good_amvs[in_layer, pidx]
        print('{0} Spd min/max/mean: {1:.2f}  {2:.2f}  {3:.2f}'.format(tag, layer_spd.min(), layer_spd.max(), layer_spd.mean()))
        print('{0} Press min/max/mean: {1:.2f}  {2:.2f}  {3:.2f}'.format(tag, layer_prs.min(), layer_prs.max(), layer_prs.mean()))

    bin_ranges = get_press_bin_ranges(100, 1000, bin_size=100.0)

    # 1. and 2.: comparison to Level of Best Fit (LBF) for RAOB, then GFS.
    _analyze2_vs_reference(good_amvs, bfs, bfs[:, 3] == 0, 'Number of good best fits to RAOB: ',
                           bin_ranges, pidx, sidx, didx)
    _analyze2_vs_reference(good_amvs, bfs_gfs, bfs_gfs[:, 3] == 0, 'Number of good best fits to GFS: ',
                           bin_ranges, pidx, sidx, didx)

    # 3.: direct comparison to the nearest RAOB level. For flag == 0 rows, run_best_fit
    # stores (u, v) in columns 0/1, not speed/direction, so they are converted
    # exactly like the LBF columns.
    _analyze2_vs_reference(good_amvs, raob_match, raob_match[:, 3] == 0, 'Number of good raob matches: ',
                           bin_ranges, pidx, sidx, didx)

    # 4.: LBF to RAOB vs. LBF to GFS, for AMVs with a valid fit to both.
    keep = np.logical_and(bfs[:, 3] == 0, bfs_gfs[:, 3] == 0)
    rb_u = bfs[keep, 0]
    rb_v = bfs[keep, 1]
    gfs_u = bfs_gfs[keep, 0]
    gfs_v = bfs_gfs[keep, 1]
    rb_spd, rb_dir = spd_dir_from_uv(rb_u, rb_v)
    gfs_spd, gfs_dir = spd_dir_from_uv(gfs_u, gfs_v)
    _analyze2_report('Number of good best fits to both RAOB and GFS: ',
                     good_amvs[keep, pidx],
                     bfs[keep, 2] - bfs_gfs[keep, 2],
                     (rb_spd - gfs_spd).magnitude,
                     direction_difference(rb_dir.magnitude, gfs_dir.magnitude),
                     np.sqrt((rb_u - gfs_u)**2 + (rb_v - gfs_v)**2),
                     bin_ranges)


def _analyze2_vs_reference(good_amvs, ref, keep, label, bin_ranges, pidx, sidx, didx):
    """Compare AMVs with a reference wind array (columns u, v, pressure, flag) on rows ``keep``."""
    amv_p = good_amvs[keep, pidx]
    amv_spd = good_amvs[keep, sidx]
    amv_dir = good_amvs[keep, didx]
    ref_u = ref[keep, 0]
    ref_v = ref[keep, 1]

    prs_diff = amv_p - ref[keep, 2]
    ref_spd, ref_dir = spd_dir_from_uv(ref_u, ref_v)
    spd_diff = (amv_spd * units('m/s') - ref_spd).magnitude
    dir_diff = direction_difference(amv_dir, ref_dir.magnitude)

    amv_u, amv_v = uv_from_spd_dir(amv_spd, amv_dir)
    u_diffs = amv_u - (ref_u * units('m/s'))
    v_diffs = amv_v - (ref_v * units('m/s'))
    vd = np.sqrt(u_diffs**2 + v_diffs**2)

    _analyze2_report(label, amv_p, prs_diff, spd_diff, dir_diff, vd, bin_ranges)


def _analyze2_report(label, amv_p, prs_diff, spd_diff, dir_diff, vd, bin_ranges):
    """Print summary stats and a per-pressure-bin table; differences are binned by ``amv_p``."""
    print('********************************************************')
    print(label, prs_diff.shape[0])
    print('press, MAD: {0:.2f}'.format(np.average(np.abs(prs_diff))))
    print('press, bias: {0:.2f}'.format(np.average(prs_diff)))
    pd_mean = np.mean(prs_diff)
    pd_std = np.std(prs_diff)
    print('press bias/rms: {0:.2f}  {1:.2f} '.format(pd_mean, np.sqrt(pd_mean**2 + pd_std**2)))
    print('------------------------------------------')

    print('spd, MAD: {0:.2f}'.format(np.average(np.abs(spd_diff))))
    print('spd, bias: {0:.2f}'.format(np.average(spd_diff)))
    spd_mean = np.mean(spd_diff)
    spd_std = np.std(spd_diff)
    print('spd MAD/bias/rms: {0:.2f}  {1:.2f}  {2:.2f}'.format(np.average(np.abs(spd_diff)), spd_mean, np.sqrt(spd_mean**2 + spd_std**2)))
    print('-----------------')

    print('dir, MAD: {0:.2f}'.format(np.average(np.abs(dir_diff))))
    print('dir bias: {0:.2f}'.format(np.average(dir_diff)))
    print('-------------------------------------')

    # vd may carry pint units; formatting is then applied to the Quantity, as before
    vd_mean = np.mean(vd)
    vd_std = np.std(vd)
    print('VD bias/rms: {0:.2f}  {1:.2f}'.format(vd_mean, np.sqrt(vd_mean**2 + vd_std**2)))
    print('******************************************************')

    bin_pres = bin_data_by(prs_diff, amv_p, bin_ranges)
    bin_spd = bin_data_by(spd_diff, amv_p, bin_ranges)
    bin_dir = bin_data_by(dir_diff, amv_p, bin_ranges)

    print('level   num cases   hgt MAD/bias    spd MAD/bias     dir MAD/bias')
    print('-------------------------------------------------------------------')
    for rng, b_prs, b_spd, b_dir in zip(bin_ranges, bin_pres, bin_spd, bin_dir):
        print('{0:d}        {1:d}       {2:.2f}/{3:.2f}       {4:.2f}/{5:.2f}        {6:.2f}/{7:.2f}'
              .format(int(np.average(rng)), b_prs.shape[0], np.average(np.abs(b_prs)), np.average(b_prs),
                      np.average(np.abs(b_spd)), np.average(b_spd), np.average(np.abs(b_dir)), np.average(b_dir)))


def compare_amvs_bestfit_driver(amvs, bfs, prd, bfs_gfs, rbm, bin_size=200):
    """Compare AMVs with their best fits for "thin", "thick" and all cases.

    Cases are split on ``prd[:, 2]`` (presumably an optical-depth-like quantity -- confirm):
    thin is ``0 < value < 1`` and thick is ``1 <= value < 6``. ``compare_amvs_bestfit`` is then
    run on the thin subset, the thick subset and the complete, unfiltered input, in that order.

    :param amvs: 2D array, one row per AMV, aligned row-for-row with ``bfs`` and ``prd``.
    :param bfs: 2D array of best-fit results, one row per AMV.
    :param prd: 2D array of per-AMV product values; only column 2 is used here.
    :param bfs_gfs: not used; kept for call compatibility.
    :param rbm: not used; kept for call compatibility.
    :param bin_size: forwarded to ``compare_amvs_bestfit``.
    :return: (bin_ranges, values_prs, values_spd, values_dir), where each ``values_*`` is a
        list of the per-bin differences for [thin, thick, all].
    """
    thin = np.logical_and(prd[:, 2] > 0, prd[:, 2] < 1)
    thick = np.logical_and(prd[:, 2] >= 1, prd[:, 2] < 6)

    subsets = [
        (amvs[thin, :], bfs[thin, :]),
        (amvs[thick, :], bfs[thick, :]),
        (amvs, bfs),
    ]

    values_prs = []
    values_spd = []
    values_dir = []
    for sub_amvs, sub_bfs in subsets:
        bin_ranges, bin_pres, bin_spd, bin_dir = compare_amvs_bestfit(sub_amvs, sub_bfs, bin_size=bin_size)
        values_prs.append(bin_pres)
        values_spd.append(bin_spd)
        values_dir.append(bin_dir)

    return bin_ranges, values_prs, values_spd, values_dir


def compare_amvs_bestfit_driver2(amvs, bfs, prd, bfs_gfs, rbm, bin_size=200):
    """Compare AMVs with their best fits for "ice", "not ice" and all cases.

    Rows with ``prd[:, 0]`` equal to 0 or 5 are excluded from both the ice subset
    (``prd[:, 0] == 4``) and the not-ice subset (any other value). The third comparison
    uses the complete, unfiltered ``amvs``/``bfs``, including the excluded rows.

    :param amvs: 2D array, one row per AMV, aligned row-for-row with ``bfs`` and ``prd``.
    :param bfs: 2D array of best-fit results, one row per AMV.
    :param prd: 2D array of per-AMV product values; only column 0 is used here.
    :param bfs_gfs: not used.
    :param rbm: not used.
    :param bin_size: forwarded to ``compare_amvs_bestfit``.
    :return: (bin_ranges, values_prs, values_spd, values_dir), where each ``values_*`` is a
        list of the per-bin differences for [ice, not ice, all].
    """
    category = prd[:, 0]
    usable = np.logical_and(category != 5, category != 0)
    ice_rows = np.logical_and(usable, category == 4)
    other_rows = np.logical_and(usable, category != 4)

    subsets = [
        (amvs[ice_rows, :], bfs[ice_rows, :]),
        (amvs[other_rows, :], bfs[other_rows, :]),
        (amvs, bfs),
    ]
    results = [compare_amvs_bestfit(sub_amvs, sub_bfs, bin_size=bin_size) for sub_amvs, sub_bfs in subsets]
    print('***************************************')

    # bin_ranges is taken from the final (all-AMVs) comparison
    bin_ranges = results[-1][0]
    values_prs = [res[1] for res in results]
    values_spd = [res[2] for res in results]
    values_dir = [res[3] for res in results]
    return bin_ranges, values_prs, values_spd, values_dir


def _min_max_mean_str(values):
    """Format min/max/mean of a 1D array as '{:.2f}  {:.2f}  {:.2f}', or 'n/a' if it is empty."""
    if values.size == 0:
        return 'n/a'
    return '{0:.2f}  {1:.2f}  {2:.2f}'.format(values.min(), values.max(), values.mean())


def compare_amvs_bestfit(amvs, bfs, bin_size=200):
    """Print AMV vs. best-fit (BF) statistics and bin the differences by AMV pressure.

    Column layout used here: ``amvs[:, 4]`` pressure, ``amvs[:, 5]`` speed, ``amvs[:, 6]``
    direction; ``bfs[:, 0]``/``bfs[:, 1]`` best-fit u/v (m/s), ``bfs[:, 2]`` best-fit pressure
    and ``bfs[:, 3]`` a flag where 0 marks a usable best fit. Only rows with a usable best fit
    enter the difference statistics.

    :param amvs: 2D array, one row per AMV.
    :param bfs: 2D array aligned row-for-row with ``amvs``.
    :param bin_size: currently unused -- the fixed pressure bins below are always used.
    :return: (bin_ranges, bin_pres, bin_spd, bin_dir): the [low, high] pressure bins (presumably
        hPa) and, per bin, the AMV - BF pressure differences, AMV - BF speed differences (m/s)
        and ``direction_difference(AMV, BF)`` values.
    """
    good_amvs = amvs
    num_good = good_amvs.shape[0]
    didx = 6
    sidx = 5
    pidx = 4

    print('Number of AMVs: {0:d}'.format(num_good))
    print('spd min/max/mean: ' + _min_max_mean_str(good_amvs[:, sidx]))
    print('pres min/max/mean: ' + _min_max_mean_str(good_amvs[:, pidx]))

    low = good_amvs[:, pidx] >= 700
    mid = np.logical_and(good_amvs[:, pidx] < 700, good_amvs[:, pidx] > 400)
    hgh = good_amvs[:, pidx] <= 400

    print('% low: {0:.2f}'.format(100.0*(np.sum(low)/num_good)))
    print('% mid: {0:.2f}'.format(100.0*(np.sum(mid)/num_good)))
    print('% hgh: {0:.2f}'.format(100.0*(np.sum(hgh)/num_good)))
    print('---------------------------')

    # A subset can legitimately be empty (e.g. no low-level AMVs); .min() would raise on it.
    print('Low Spd min/max/mean: ' + _min_max_mean_str(good_amvs[low, sidx]))
    print('Low Press min/max/mean: ' + _min_max_mean_str(good_amvs[low, pidx]))
    print('---------------------------')

    print('Mid Spd min/max/mean: ' + _min_max_mean_str(good_amvs[mid, sidx]))
    print('Mid Press min/max/mean: ' + _min_max_mean_str(good_amvs[mid, pidx]))
    print('---------------------------')

    print('Hgh Spd min/max/mean: ' + _min_max_mean_str(good_amvs[hgh, sidx]))
    print('Hgh Press min/max/mean: ' + _min_max_mean_str(good_amvs[hgh, pidx]))

    # Comparison to Level of Best Fit (LBF) ------------------------------------------------------------------
    # --------------------------------------------------------------------------------------------------------
    # Fixed, uneven pressure bins; bin_size is not applied.
    bin_ranges = [[50, 300], [300, 400], [400, 500], [500, 600], [600, 700], [700, 800], [800, 1000]]

    keep_idxs = bfs[:, 3] == 0

    amv_p = good_amvs[keep_idxs, pidx]
    bf_p = bfs[keep_idxs, 2]
    diff = amv_p - bf_p
    mad = np.average(np.abs(diff))
    bias = np.average(diff)
    print('********************************************************')
    print('Number of good best fits: ', bf_p.shape[0])
    print('press, MAD: {0:.2f}'.format(mad))
    print('press, bias: {0:.2f}'.format(bias))
    pd_std = np.std(diff)
    pd_mean = np.mean(diff)
    print('press bias/rms: {0:.2f}  {1:.2f} '.format(pd_mean, np.sqrt(pd_mean**2 + pd_std**2)))
    print('------------------------------------------')

    bin_pres = bin_data_by(diff, amv_p, bin_ranges)

    amv_spd = good_amvs[keep_idxs, sidx]
    amv_dir = good_amvs[keep_idxs, didx]
    bf_spd, bf_dir = spd_dir_from_uv(bfs[keep_idxs, 0], bfs[keep_idxs, 1])

    # bf_spd carries pint units; attach m/s to the AMV speed so the subtraction is unit-consistent
    diff = (amv_spd * units('m/s') - bf_spd).magnitude
    spd_mad = np.average(np.abs(diff))
    spd_bias = np.average(diff)
    print('spd, MAD: {0:.2f}'.format(spd_mad))
    print('spd, bias: {0:.2f}'.format(spd_bias))
    spd_mean = np.mean(diff)
    spd_std = np.std(diff)
    print('spd MAD/bias/rms: {0:.2f}  {1:.2f}  {2:.2f}'.format(spd_mad, spd_mean, np.sqrt(spd_mean**2 + spd_std**2)))
    print('-----------------')
    bin_spd = bin_data_by(diff, amv_p, bin_ranges)

    dir_diff = direction_difference(amv_dir, bf_dir.magnitude)
    print('dir, MAD: {0:.2f}'.format(np.average(np.abs(dir_diff))))
    print('dir bias: {0:.2f}'.format(np.average(dir_diff)))
    print('-------------------------------------')
    bin_dir = bin_data_by(dir_diff, amv_p, bin_ranges)

    # vector difference magnitude between AMV and best-fit winds
    amv_u, amv_v = uv_from_spd_dir(good_amvs[keep_idxs, sidx], good_amvs[keep_idxs, didx])
    u_diffs = amv_u - (bfs[keep_idxs, 0] * units('m/s'))
    v_diffs = amv_v - (bfs[keep_idxs, 1] * units('m/s'))

    vd = np.sqrt(u_diffs**2 + v_diffs**2)
    vd_mean = np.mean(vd)
    vd_std = np.std(vd)
    print('VD bias/rms: {0:.2f}  {1:.2f}'.format(vd_mean, np.sqrt(vd_mean**2 + vd_std**2)))
    print('******************************************************')

    print('level   num cases   hgt MAD/bias    spd MAD/bias     dir MAD/bias')
    print('-------------------------------------------------------------------')
    for i, rng in enumerate(bin_ranges):
        print('{0:d}        {1:d}       {2:.2f}/{3:.2f}       {4:.2f}/{5:.2f}        {6:.2f}/{7:.2f}'
              .format(int(np.average(rng)), bin_pres[i].shape[0], np.average(np.abs(bin_pres[i])), np.average(bin_pres[i]),
                      np.average(np.abs(bin_spd[i])), np.average(bin_spd[i]), np.average(np.abs(bin_dir[i])), np.average(bin_dir[i])))

    return bin_ranges, bin_pres, bin_spd, bin_dir


def make_plot(bin_ranges, bin_values, title='ACHA - RAOB BestFit', colors=None, line_labels=None):
    """Plot the per-bin mean (bias) of the first two series in ``bin_values``.

    :param bin_ranges: sequence of [low, high] bins; bin midpoints are the plot coordinates.
    :param bin_values: sequence of series (at least two), each a list of per-bin arrays
        aligned with ``bin_ranges``, e.g. the ``values_*`` lists from the driver functions.
        Only the first two series are plotted.
    :param title: plot title.
    :param colors: line colors; defaults to ['blue', 'red', 'black'].
    :param line_labels: line labels; defaults to ['ICE', 'NOT ICE', 'ALL'].
    """
    # None sentinels avoid sharing one mutable default list across calls.
    if colors is None:
        colors = ['blue', 'red', 'black']
    if line_labels is None:
        line_labels = ['ICE', 'NOT ICE', 'ALL']

    first, second = bin_values[0], bin_values[1]
    x_values = [np.average(rng) for rng in bin_ranges]
    bias_r = [np.average(first[i]) for i in range(len(bin_ranges))]
    bias_g = [np.average(second[i]) for i in range(len(bin_ranges))]

    do_plot(x_values, [bias_r, bias_g], line_labels, colors, title=title, x_axis_label='BIAS (hPa)',
            y_axis_label='hPa', invert=True, flip=True)


def make_plot2(bin_pres_r, bin_ranges):
    """Plot the per-bin mean absolute pressure difference (MAE) of a single series.

    :param bin_pres_r: list of per-bin arrays of pressure differences, aligned with ``bin_ranges``.
    :param bin_ranges: sequence of [low, high] pressure bins; bin midpoints are the plot coordinates.
    """
    x_values = [np.average(rng) for rng in bin_ranges]
    pres_mad_r = [np.average(np.abs(bin_pres_r[i])) for i in range(len(bin_ranges))]

    do_plot(x_values, [pres_mad_r], ['RAOB'], ['blue'], title='ACHA - BestFit', x_axis_label='MAE', y_axis_label='hPa', invert=True, flip=True)


def make_plot3(amvs_list, bfs_list):
    """Plot one pressure-MAD-per-bin line for each (amvs, bfs) pair, e.g. one line per day.

    Each pair is run through ``compare_amvs_bestfit`` (which also prints its statistics).
    The bin midpoints of the last comparison are used as the plot coordinates.
    Nothing is plotted when ``amvs_list`` is empty.

    :param amvs_list: list of 2D AMV arrays.
    :param bfs_list: list of 2D best-fit arrays, aligned one-to-one with ``amvs_list``.
    """
    y_values = []
    x_values = None
    for amvs, bfs in zip(amvs_list, bfs_list):
        bin_ranges, bin_pres, _, _ = compare_amvs_bestfit(amvs, bfs, bin_size=100)
        x_values = [np.average(rng) for rng in bin_ranges]
        y_values.append([np.average(np.abs(bin_pres[i])) for i in range(len(bin_ranges))])

    if x_values is None:
        return

    num = len(y_values)
    do_plot(x_values, y_values, [None] * num, [None] * num, title='Daily ACHA - BestFit RAOB', x_axis_label='MAD', y_axis_label='hPa', invert=True, flip=True)


def get_aeolus_time_dict(filename, lon360=False, do_sort=True):
    """Read an AEOLUS S4 NOAA text output file into a dict keyed by profile time.

    The file is a sequence of profiles. Each profile starts with a header line::

        year month day hour minute second lon lat nlevs

    followed by ``nlevs`` level lines whose whitespace-separated fields 1-3 (hhh, hht, hhb)
    are multiplied by 1000 (presumably km -> m; confirm against the product spec), and
    whose fields 5 and 6 are azimuth and wind speed. Blank lines between profiles are skipped.

    :param filename: full path to the file, e.g. '/home/user/filename'.
    :param lon360: if True, negative longitudes get +360; otherwise longitudes above 180 get -360.
    :param do_sort: if True, the returned dict's keys are in ascending time order.
    :return: dict mapping a UTC POSIX timestamp (float) to a list of profiles; each profile is
        a list of level tuples ``(lat, lon, hhh, hht, hhb, azimuth, wind_speed)``. Several
        profiles may share one timestamp.
    """
    time_dict = {}

    with open(filename) as fh:
        while True:
            prof_hdr = fh.readline()
            if not prof_hdr:
                break
            toks = prof_hdr.split()
            if not toks:
                # tolerate blank lines, e.g. a trailing newline at the end of the file
                continue

            yr, mon, dy, hr, mn, ss = (int(float(tok)) for tok in toks[:6])
            lon = float(toks[6])
            lat = float(toks[7])
            nlevs = int(toks[8])

            if lon360:
                if lon < 0:
                    lon += 360.0
            elif lon > 180.0:
                lon -= 360.0

            timestamp = datetime.datetime(yr, mon, dy, hr, mn, ss, tzinfo=timezone.utc).timestamp()

            prof = []
            time_dict.setdefault(timestamp, []).append(prof)

            for _ in range(nlevs):
                lvl_toks = fh.readline().split()
                hhh = float(lvl_toks[1]) * 1000.0
                hht = float(lvl_toks[2]) * 1000.0
                hhb = float(lvl_toks[3]) * 1000.0
                azm = float(lvl_toks[5])
                ws = float(lvl_toks[6])
                prof.append((lat, lon, hhh, hht, hhb, azm, ws))

    if do_sort:
        time_dict = {key: time_dict[key] for key in sorted(time_dict)}

    return time_dict


def time_dict_to_nd_0(time_dict):
    """Replace, in place, each timestamp's profile list with its first profile as a numpy array.

    Only the first profile recorded at a timestamp is kept. Entries whose value is None
    are left untouched.

    :param time_dict: dict timestamp -> list of profiles (each a list of level tuples).
    :return: the same dict object.
    """
    for key, profiles in list(time_dict.items()):
        if profiles is None:
            continue
        first_profile = profiles[0]  # only one profile per timestamp is kept
        time_dict[key] = np.array(first_profile)
    return time_dict


def time_dict_to_nd(time_dict):
    """Wrap, in place, every profile in ``time_dict`` as a labelled ``xr.DataArray``.

    Each profile (a sequence of level tuples) is stacked into a 2D array with dims
    ('num_levels', 'num_params'); the 'num_params' coordinate names the seven fields
    latitude, longitude, hhh, hht, hhb, azimuth, wind_speed. Timestamps whose value is
    None are skipped.

    :param time_dict: dict timestamp -> list of profiles; those lists are modified in place.
    :return: the same dict object.
    """
    field_coords = {'num_params': ['latitude', 'longitude', 'hhh', 'hht', 'hhb', 'azimuth', 'wind_speed']}
    level_dims = ['num_levels', 'num_params']

    for key in list(time_dict):
        profiles = time_dict[key]
        if profiles is None:
            continue
        for pos, levels in enumerate(profiles):
            profiles[pos] = xr.DataArray(np.stack(levels), coords=field_coords, dims=level_dims)

    return time_dict


def time_dict_to_nd_2(time_dict):
    """Stack, in place, each timestamp's sequence of rows into a single 2D numpy array.

    Entries whose value is None are left untouched.

    :param time_dict: dict timestamp -> sequence of equal-length rows (or None).
    :return: the same dict object.
    """
    for key, rows in list(time_dict.items()):
        if rows is None:
            continue
        time_dict[key] = np.stack(rows)
    return time_dict


def concat(t_dct_0, t_dct_1):
    """Merge ``t_dct_1`` into ``t_dct_0`` without overwriting keys already in ``t_dct_0``.

    Keys present in both dicts keep ``t_dct_0``'s value; keys only in ``t_dct_1`` are
    appended in ``t_dct_1``'s iteration order. ``t_dct_1`` itself is not modified.

    :param t_dct_0: destination dict, updated in place.
    :param t_dct_1: source dict.
    :return: ``t_dct_0``.
    """
    for key, val in t_dct_1.items():
        if key not in t_dct_0:
            t_dct_0[key] = val
    return t_dct_0


def get_aeolus_time_dict_s(files_path, lon360=False, do_sort=True, chan='mie'):
    """Read all daily AEOLUS files matching a prefix and merge them into one time dict.

    Files are matched with the glob ``files_path + chan + '1day.out.*'``. This is plain string
    concatenation, so a directory path needs a trailing separator. The third '.'-separated field
    of each file name is parsed as a '%Y-%m-%d' UTC date. Files are read with
    ``get_aeolus_time_dict`` in date order and merged with ``concat``. When a timestamp appears
    in more than one file, the profiles from the earliest-dated file are kept.

    :param files_path: path prefix, e.g. '/home/user/aeolus/'.
    :param lon360: forwarded to ``get_aeolus_time_dict``.
    :param do_sort: forwarded to ``get_aeolus_time_dict``.
    :param chan: channel prefix of the file names, e.g. 'mie'.
    :return: dict timestamp -> list of profiles; empty if no file matches.
    """
    path_by_time = {}
    for pathname in glob.glob(files_path + chan + '1day.out.*'):
        dstr = os.path.basename(pathname).split('.')[2]
        dto = datetime.datetime.strptime(dstr, '%Y-%m-%d').replace(tzinfo=timezone.utc)
        path_by_time[dto.timestamp()] = pathname

    t_dct = {}
    for ts in sorted(path_by_time):
        concat(t_dct, get_aeolus_time_dict(path_by_time[ts], lon360=lon360, do_sort=do_sort))

    return t_dct


def time_dict_to_cld_layers(time_dict):
    """Collapse the first profile at each timestamp into a list of cloud layers.

    Level tuples are indexed as in ``get_aeolus_time_dict`` output: [0] lat, [1] lon,
    [3] level top height, [4] level bottom height. Consecutive levels are merged into one
    layer while each level's top is within 10.0 (same units as the heights) of the previous
    level's bottom; a larger gap closes the current layer and starts a new one.

    Only ``prof_s[0]`` is used when several profiles share a timestamp. Timestamps whose
    first profile has no levels are omitted from the result.

    :param time_dict: dict timestamp -> list of profiles (each a sequence of level tuples).
    :return: dict timestamp -> list of (lat, lon, layer_top, layer_bottom) tuples.
    """
    time_dict_layers = {}

    for key, prof_s in time_dict.items():
        prof = prof_s[0]
        if len(prof) == 0:
            # no levels: nothing to build a layer from
            continue

        top = prof[0][3]
        last_bot = prof[0][4]
        layers = []
        for tup in prof[1:]:
            if math.fabs(last_bot - tup[3]) > 10.0:
                # gap between the previous bottom and this top: close the current layer
                layers.append((tup[0], tup[1], top, last_bot))
                top = tup[3]
            last_bot = tup[4]

        last = prof[-1]
        layers.append((last[0], last[1], top, last[4]))
        time_dict_layers[key] = layers

    return time_dict_layers


def get_cloud_layers_dict(filename, lon360=False):
    """Read one AEOLUS file and return its cloud layers as numpy arrays keyed by timestamp.

    Pipeline: ``get_aeolus_time_dict`` -> ``time_dict_to_cld_layers`` -> ``time_dict_to_nd_2``.
    Each value is a 2D array whose rows are (lat, lon, layer_top, layer_bottom).
    """
    return time_dict_to_nd_2(time_dict_to_cld_layers(get_aeolus_time_dict(filename, lon360=lon360)))


def get_cloud_layers_dict_s(aeolus_files_dir, lon360=False):
    """Read all daily 'mie' AEOLUS files under a prefix and return their cloud layers.

    Pipeline: ``get_aeolus_time_dict_s`` (sorted, chan='mie') -> ``time_dict_to_cld_layers``
    -> ``time_dict_to_nd_2``. Each value is a 2D array of (lat, lon, layer_top, layer_bottom) rows.
    """
    merged = get_aeolus_time_dict_s(aeolus_files_dir, lon360=lon360, do_sort=True, chan='mie')
    layers = time_dict_to_cld_layers(merged)
    return time_dict_to_nd_2(layers)


def match_aeolus_to_clavrx(aeolus_dict, clvrx_files):
    """Extract a 9x9 window of CLAVRx parameters around each AEOLUS profile location.

    For each timestamp in ``aeolus_dict`` that lies in a CLAVRx file, every profile's
    first-level (lat, lon) is navigated with ``get_navigation().earth_to_lc``. Each CLAVRx
    parameter is then read in the 9x9 window centred on that (elem, line). Profiles that
    do not navigate, or whose window would leave the full-disk grid, are skipped.

    :param aeolus_dict: dict timestamp -> list of profiles as xarray DataArrays
        (``prof.values[0, 0]`` latitude, ``prof.values[0, 1]`` longitude), e.g. from ``time_dict_to_nd``.
    :param clvrx_files: CLAVRx datasource providing ``get_parameters()`` and ``get_file_containing_time()``.
    :return: dict timestamp -> list of (elem, line, file_index, profile, params), with ``params``
        of shape (num_params, 9, 9). A timestamp that has a file but no usable profile maps to [].
    """
    nav = get_navigation()
    clvrx_params = clvrx_files.get_parameters()
    match_dict = {}

    last_f_idx = -1
    dataset = None
    try:
        for key in aeolus_dict:
            fname, ftime, f_idx = clvrx_files.get_file_containing_time(key)
            if f_idx is None:
                continue

            prof_s = aeolus_dict.get(key)
            if prof_s is None:
                continue

            match_dict[key] = []

            if f_idx != last_f_idx:
                # Windows are read eagerly, so only the current file needs to stay open.
                if dataset is not None:
                    dataset.close()
                    dataset = None
                dataset = Dataset(fname)
                last_f_idx = f_idx

            for prof in prof_s:
                profv = prof.values
                lat = profv[0, 0]
                lon = profv[0, 1]

                cc, ll = nav.earth_to_lc(lon, lat)
                if cc is None or ll is None:
                    continue
                # window covers [cc-4, cc+4] x [ll-4, ll+4]; all indices must be inside the grid
                if cc - 4 < 0 or ll - 4 < 0 or cc + 5 > num_elems or ll + 5 > num_lines:
                    continue

                # NOTE(review): variables are indexed as [elem, line]; confirm this matches the
                # dimension order of the CLAVRx variables.
                data = [dataset[param][cc-4:cc+5, ll-4:ll+5] for param in clvrx_params]
                param_nd = np.stack(data, axis=0)
                match_dict[key].append((cc, ll, f_idx, prof, param_nd))
    finally:
        if dataset is not None:
            dataset.close()

    return match_dict


_CLAVRX_EXAMPLE_FILE = '/home/rink/data/clavrx/clavrx_OR_ABI-L1b-RadF-M6C01_G16_s20192930000343.level2.nc'


def create_aeolus_clavrx_match_file(match_dct, filename, clvrx_params, clvrx_files, example_file=_CLAVRX_EXAMPLE_FILE):
    """Write AEOLUS/CLAVRx matches to a NetCDF4 file, one record per matched profile.

    Records are written in ``match_dct`` iteration order. Profile levels beyond a profile's
    own level count are left at the variable's fill value. The dtype, fill value and
    attributes of each CLAVRx parameter are copied from the same-named variable in
    ``example_file``.

    :param match_dct: dict timestamp -> list of (elem, line, file_index, profile, params) tuples,
        in the format returned by ``match_aeolus_to_clavrx``. ``profile`` rows are levels with
        columns (lat, lon, hhh, hht, hhb, azimuth, wind_speed); ``params`` has shape
        (len(clvrx_params), 9, 9).
    :param filename: output path; an existing file is overwritten.
    :param clvrx_params: names of the CLAVRx variables to write; each must exist in ``example_file``.
    :param clvrx_files: CLAVRx file names, stored in the 'clvrx_file_names' variable.
    :param example_file: CLAVRx level2 file used as the template for variable dtypes/attributes.
    """
    grd_x_len = 9
    grd_y_len = 9
    num_aparams = 7

    # Flatten to one record per matched profile.
    records = [(key, tup) for key, tup_s in match_dct.items() for tup in tup_s]
    profs = [np.asarray(tup[3]) for _, tup in records]
    num_aprofs = len(records)
    max_num_alevels = max((prof.shape[0] for prof in profs), default=0)

    atimes = np.array([key for key, _ in records])
    alats = np.array([prof[0, 0] for prof in profs])
    alons = np.array([prof[0, 1] for prof in profs])
    elems = np.array([tup[0] for _, tup in records])
    lines = np.array([tup[1] for _, tup in records])

    clvrx_files = np.array(clvrx_files, dtype='object')

    with Dataset(example_file, 'r') as rg_exmpl, Dataset(filename, 'w', format='NETCDF4') as rootgrp:
        rootgrp.createDimension('num_aeolus_params', size=num_aparams)
        rootgrp.createDimension('max_num_aeolus_levels', size=max_num_alevels)
        rootgrp.createDimension('num_aeolus_profs', size=num_aprofs)
        rootgrp.createDimension('grd_x_len', size=grd_x_len)
        rootgrp.createDimension('grd_y_len', size=grd_y_len)
        rootgrp.createDimension('num_clvrx_files', size=clvrx_files.shape[0])

        # f8: POSIX seconds (~1.6e9) exceed float32's ~7 significant digits (minute-scale error).
        prf_time = rootgrp.createVariable('time', 'f8', ['num_aeolus_profs'])
        clvrx_file_names = rootgrp.createVariable('clvrx_file_names', str, ['num_clvrx_files'])

        # ---- Profile variables ---------------
        prf_lon = rootgrp.createVariable('prof_longitude', 'f4', ['num_aeolus_profs'])
        prf_lon.units = 'degrees east'
        prf_lat = rootgrp.createVariable('prof_latitude', 'f4', ['num_aeolus_profs'])
        prf_lat.units = 'degrees north'
        prf_time.units = 'seconds since 1970-01-01 00:00:00'
        prf_elem = rootgrp.createVariable('FD_elem', 'f4', ['num_aeolus_profs'])
        prf_line = rootgrp.createVariable('FD_line', 'f4', ['num_aeolus_profs'])
        prf_azm = rootgrp.createVariable('prof_azm', 'f4', ['num_aeolus_profs', 'max_num_aeolus_levels'])
        prf_azm.units = 'degree'
        prf_spd = rootgrp.createVariable('prof_spd', 'f4', ['num_aeolus_profs', 'max_num_aeolus_levels'])
        prf_spd.units = 'm s-1'
        prf_hht = rootgrp.createVariable('prof_hht', 'f4', ['num_aeolus_profs', 'max_num_aeolus_levels'])
        prf_hht.units = 'meter'
        prf_hhb = rootgrp.createVariable('prof_hhb', 'f4', ['num_aeolus_profs', 'max_num_aeolus_levels'])
        prf_hhb.units = 'meter'

        # ----- Product variables ----------------
        nc4_vars = []
        var_s = rg_exmpl.variables
        for param in clvrx_params:
            v = var_s[param]
            attr_s = v.ncattrs()
            if '_FillValue' in attr_s:
                var = rootgrp.createVariable(param, v.dtype, ['num_aeolus_profs', 'grd_y_len', 'grd_x_len'], fill_value=v.getncattr('_FillValue'))
            else:
                var = rootgrp.createVariable(param, v.dtype, ['num_aeolus_profs', 'grd_y_len', 'grd_x_len'])
            # copy attributes from example to new output variable of the same name
            # (_FillValue can only be set at creation time)
            for attr in attr_s:
                if attr != '_FillValue':
                    var.setncattr(attr, v.getncattr(attr))
            nc4_vars.append(var)

        # Write data to file  ---------------------
        prf_lon[:] = alons
        prf_lat[:] = alats
        prf_elem[:] = elems
        prf_line[:] = lines
        prf_time[:] = atimes

        clvrx_file_names[:] = clvrx_files

        for idx, ((_, tup), prof) in enumerate(zip(records, profs)):
            nlevs = prof.shape[0]
            prf_spd[idx, :nlevs] = prof[:, 6]
            prf_azm[idx, :nlevs] = prof[:, 5]
            prf_hht[idx, :nlevs] = prof[:, 3]
            prf_hhb[idx, :nlevs] = prof[:, 4]

            param_nd = tup[4]
            for pidx, var in enumerate(nc4_vars):
                var[idx, :, :] = param_nd[pidx, :, :]


def get_search_box(cc, ll):
    """Return element and line ranges of a box reaching ``half_width`` pixels either side of (cc, ll).

    Each range is a two-item list [start, end]. A start below 0 is raised to 0. An end at or
    beyond the grid size (``num_elems`` for elements, ``num_lines`` for lines) is lowered to
    size - 1.

    :return: (elem_range, line_range)
    """
    def clip(center, size):
        start = center - half_width
        end = center + half_width
        return [0 if start < 0 else start, size - 1 if end >= size else end]

    return clip(cc, num_elems), clip(ll, num_lines)


def match_amvs_to_aeolus_fast(aeolus_dict, amv_files_path, amv_source='OPS', band='14', amv_files=None):
    """Collect the AMVs inside a search box around each AEOLUS profile.

    For every timestamp in ``aeolus_dict`` that lies in an AMV file, each profile's first-level
    (lat, lon) is navigated to (elem, line). All AMVs whose element and line lie strictly
    inside ``get_search_box`` of that point are gathered. Each AMV file is read once;
    consecutive timestamps in the same file reuse it.

    :param aeolus_dict: dict timestamp -> list of profiles as xarray DataArrays
        (``prof.values[0, 0]`` latitude, ``prof.values[0, 1]`` longitude).
    :param amv_files_path: directory containing AMV files, e.g. '/home/user/amvdir/'.
    :param amv_source: passed to ``get_datasource`` when ``amv_files`` is None.
    :param band: passed to ``get_datasource`` when ``amv_files`` is None.
    :param amv_files: optional pre-built AMV datasource.
    :return: dict timestamp -> list of (elem, line, file_index, profile, amvs). ``amvs`` is an
        xarray DataArray with dims ('num_params', 'num_amvs'), where 'num_params' holds the
        lon/lat/elem/line names followed by the datasource parameters. A timestamp that has a
        file but no matches maps to [].
    """
    if amv_files is None:
        amv_files = get_datasource(amv_files_path, amv_source, band=band)
    nav = amv_files.get_navigation()
    amv_params = amv_files.get_parameters()
    all_params = [amv_files.lon_name, amv_files.lat_name, amv_files.elem_name, amv_files.line_name] + amv_params
    match_dict = {}
    coords = {'num_params': all_params}
    dims = ['num_params', 'num_amvs']

    last_f_idx = -1
    cc = ll = amv_nd = None
    for key in aeolus_dict:
        fname, ftime, f_idx = amv_files.get_file_containing_time(key)
        if f_idx is None:
            continue

        profs = aeolus_dict.get(key)
        if profs is None:
            continue

        match_dict[key] = []

        if f_idx != last_f_idx:
            last_f_idx = f_idx
            with Dataset(fname) as ds:
                amv_lons = ds[amv_files.lon_name][:]
                amv_lats = ds[amv_files.lat_name][:]
                if amv_files.elem_name is not None:
                    cc = ds[amv_files.elem_name][:]
                    ll = ds[amv_files.line_name][:]
                else:
                    # no stored image coordinates: navigate the AMV locations instead
                    cc, ll = nav.earth_to_lc_s(amv_lons, amv_lats)

                param_s = [amv_lons, amv_lats, cc, ll]
                for param in amv_params:
                    if param == 'V_3D':
                        # NOTE(review): 'V_3D' adds two rows but only one label in ``coords``;
                        # the DataArray construction below looks like it would fail -- confirm.
                        param_s.append(ds[param][:, 0])
                        param_s.append(ds[param][:, 1])
                    else:
                        param_s.append(ds[param][:])

            # (num_params, num_amvs) for the whole file; built once per file, not once per profile
            amv_nd = np.vstack(param_s)

        for prof in profs:
            profv = prof.values
            lat = profv[0, 0]
            lon = profv[0, 1]

            cc_prf, ll_prf = nav.earth_to_lc(lon, lat)
            if cc_prf is None or ll_prf is None:
                continue

            c_rng, l_rng = get_search_box(cc_prf, ll_prf)

            in_cc = np.logical_and(cc > c_rng[0], cc < c_rng[1])
            in_ll = np.logical_and(ll > l_rng[0], ll < l_rng[1])
            in_box = np.logical_and(in_cc, in_ll)

            if np.sum(in_box) == 0:
                continue

            amvs_da = xr.DataArray(amv_nd[:, in_box], coords=coords, dims=dims)
            match_dict[key].append((cc_prf, ll_prf, f_idx, prof, amvs_da))

    return match_dict


def match_calipso_clavrx_to_amvs(calipso_clavrx_path, calipso_clavrx_file, amv_files_path, amv_source='OPS', band='14', amv_files=None):
    """Collect the AMVs inside a search box around each point of a CALIPSO/CLAVRx collocation file.

    The collocation file's nominal time (``CLAVRx_CALIPSO.get_datetime``) selects one AMV file.
    Each collocated point's 'x'/'y' values are used as the element/line centre of
    ``get_search_box``. AMVs strictly inside that box are gathered together with four
    CALIPSO/ACHA height values at the point.

    :param calipso_clavrx_path: directory passed to ``CLAVRx_CALIPSO``.
    :param calipso_clavrx_file: the collocation file to read.
    :param amv_files_path: directory containing AMV files, e.g. '/home/user/amvdir/'.
    :param amv_source: passed to ``get_datasource`` when ``amv_files`` is None.
    :param band: passed to ``get_datasource`` when ``amv_files`` is None.
    :param amv_files: optional pre-built AMV datasource.
    :return: None if no AMV file contains the nominal time. Otherwise a dict with the single key
        nominal-time mapping to a list of (x, y, file_index, heights, amvs). ``heights`` is a
        DataArray over the four CALIPSO/CLAVRx parameters; ``amvs`` is a DataArray with dims
        ('num_params', 'num_amvs').
    """
    if amv_files is None:
        amv_files = get_datasource(amv_files_path, amv_source, band=band)
    nav = amv_files.get_navigation()
    amv_params = amv_files.get_parameters()
    all_params = [amv_files.lon_name, amv_files.lat_name, amv_files.elem_name, amv_files.line_name] + amv_params
    match_dict = {}
    coords = {'num_params': all_params}
    dims = ['num_params', 'num_amvs']

    calipso_clavrx_params = ['closest_calipso_top_alt', 'closest_calipso_base_alt', 'cld_height_acha', 'cld_height_base_acha']
    coords_a = {'num_calipso_clavrx_params': calipso_clavrx_params}
    dims_a = ['num_calipso_clavrx_params']

    calipso_clavrx_ds = CLAVRx_CALIPSO(calipso_clavrx_path)
    nom_time = calipso_clavrx_ds.get_datetime(calipso_clavrx_file).timestamp()

    # Context managers close both files on every path, including the early return.
    with Dataset(calipso_clavrx_file) as calipso_clavrx_nc4:
        fname, ftime, f_idx = amv_files.get_file_containing_time(nom_time)
        if f_idx is None:
            return None

        match_dict[nom_time] = []

        with Dataset(fname) as amv_ds:
            amv_lons = amv_ds[amv_files.lon_name][:]
            amv_lats = amv_ds[amv_files.lat_name][:]
            if amv_files.elem_name is not None:
                amv_cc = amv_ds[amv_files.elem_name][:]
                amv_ll = amv_ds[amv_files.line_name][:]
            else:
                amv_cc, amv_ll = nav.earth_to_lc_s(amv_lons, amv_lats)

            param_s = [amv_lons, amv_lats, amv_cc, amv_ll]
            for param in amv_params:
                if param == 'V_3D':
                    param_s.append(amv_ds[param][:, 0])
                    param_s.append(amv_ds[param][:, 1])
                else:
                    param_s.append(amv_ds[param][:])

        calipso_clavrx_nd = np.vstack([calipso_clavrx_nc4[pname][:] for pname in calipso_clavrx_params])
        xs = calipso_clavrx_nc4['x'][:]
        ys = calipso_clavrx_nc4['y'][:]

    # Loop-invariant: (num_params, num_amvs) for the whole AMV file.
    amv_nd = np.vstack(param_s)

    for idx, (cc, ll) in enumerate(zip(xs, ys)):
        c_rng, l_rng = get_search_box(cc, ll)

        in_cc = np.logical_and(amv_cc > c_rng[0], amv_cc < c_rng[1])
        in_ll = np.logical_and(amv_ll > l_rng[0], amv_ll < l_rng[1])
        in_box = np.logical_and(in_cc, in_ll)

        if np.sum(in_box) == 0:
            continue

        amvs_da = xr.DataArray(amv_nd[:, in_box], coords=coords, dims=dims)
        data_da = xr.DataArray(calipso_clavrx_nd[:, idx], coords=coords_a, dims=dims_a)

        match_dict[nom_time].append((cc, ll, f_idx, data_da, amvs_da))

    return match_dict


# aeolus_dict: time -> profiles
# amv_files_path: directory containing AMVs, '/home/user/amvdir/'
# return dict: aeolus time -> list of xarray.DataArray (num_params, num_amvs), one per profile that has
#   at least one AMV inside its search box; 'num_params' is labeled
#   [lon_name, lat_name, elem_name, line_name] + amv_files.get_parameters()
def match_amvs_to_aeolus(aeolus_dict, amv_files_path, amv_source='OPS', band='14', amv_files=None):
    """Collect, for every AEOLUS profile, the AMVs inside the FGF search box centered on it.

    Args:
        aeolus_dict: maps a profile time to a sequence of 2D profile arrays; row 0 of each
            profile holds (lat, lon) in columns 0 and 1.
        amv_files_path: directory holding the AMV product files (only used when
            ``amv_files`` is None).
        amv_source: AMV product identifier passed to ``get_datasource``.
        band: channel passed to ``get_datasource``.
        amv_files: an already constructed AMV datasource; built from the arguments above if None.

    Returns:
        dict mapping each time that has both an AMV file and profiles to a (possibly empty)
        list of ``xarray.DataArray`` with dims ('num_params', 'num_amvs').
    """
    if amv_files is None:
        amv_files = get_datasource(amv_files_path, amv_source, band=band)
    nav = amv_files.get_navigation()
    amv_params = amv_files.get_parameters()
    all_params = [amv_files.lon_name, amv_files.lat_name, amv_files.elem_name, amv_files.line_name] + amv_params
    # NOTE(review): a 'V_3D' parameter contributes two rows (components 0 and 1) below but only one
    # label here, so building the DataArray would fail for it -- confirm whether it can occur.
    coords = {'num_params': all_params}
    dims = ['num_params', 'num_amvs']
    match_dict = {}

    opened = []  # every Dataset opened here; all are closed in the finally clause
    ds = None
    cc = ll = amv_lons = amv_lats = None
    last_f_idx = -1
    try:
        for key in list(aeolus_dict.keys()):
            fname, ftime, f_idx = amv_files.get_file_containing_time(key)
            if f_idx is None:
                continue

            profs = aeolus_dict.get(key)
            if profs is None:
                continue

            match_dict[key] = []

            # Only (re)load the AMV arrays when this time falls in a different file than the last one.
            if f_idx != last_f_idx:
                last_f_idx = f_idx
                ds = Dataset(fname)
                opened.append(ds)

                amv_lons = ds[amv_files.lon_name][:]
                amv_lats = ds[amv_files.lat_name][:]
                if amv_files.elem_name is None:
                    # No FGF element/line variables in this product: navigate the AMV lon/lat instead.
                    cc, ll = nav.earth_to_lc_s(amv_lons, amv_lats)
                else:
                    cc = ds[amv_files.elem_name][:]
                    ll = ds[amv_files.line_name][:]

            for prof in profs:
                lat = prof[0, 0]
                lon = prof[0, 1]
                cc_prf, ll_prf = nav.earth_to_lc(lon, lat)
                if cc_prf is None or ll_prf is None:
                    # profile location could not be navigated to element/line
                    continue
                c_rng, l_rng = get_search_box(cc_prf, ll_prf)

                in_cc = np.logical_and(cc > c_rng[0], cc < c_rng[1])
                in_ll = np.logical_and(ll > l_rng[0], ll < l_rng[1])
                in_box = np.logical_and(in_cc, in_ll)

                if np.sum(in_box) == 0:
                    continue

                param_s = [amv_lons[in_box], amv_lats[in_box], cc[in_box], ll[in_box]]
                for param in amv_params:
                    if param == 'V_3D':
                        param_s.append(ds[param][in_box, 0])
                        param_s.append(ds[param][in_box, 1])
                    else:
                        param_s.append(ds[param][in_box])

                amvs_da = xr.DataArray(np.vstack(param_s), coords=coords, dims=dims)
                match_dict[key].append(amvs_da)
    finally:
        for opened_ds in opened:
            opened_ds.close()

    return match_dict


# full path as string filename to create, '/home/user/newfilename'
# aeolus_to_amv_dct: output from match_amvs_to_aeolus
# aeolus_dct: output from get_aeolus_time_dict
# amv_files: container representing specific AMV product info
def create_file(filename, aeolus_to_amv_dct, aeolus_dct, amv_files, cld_lyr=False):
    """Write AEOLUS profiles and their matched AMVs to a NetCDF4 file in a ragged layout.

    The AMVs of all profiles are concatenated along the 'amvs' dimension and the profile
    levels along the 'profs' dimension. 'num_amvs_per_prof' and 'num_levs_per_prof' hold the
    per-profile counts needed to split them again.

    Args:
        filename: path of the file to create (overwritten).
        aeolus_to_amv_dct: time -> 2D array (num_params, num_amvs). Rows follow
            ``amv_files.get_out_parameters()``; rows 0 and 1 are passed to ``haversine_np`` as
            the AMV lon/lat.
        aeolus_dct: time -> 2D profile array (num_levels, cols). Row 0 holds (lat, lon) in
            columns 0/1. Layer top/bottom are read from columns 3/4 (azimuth 5, speed 6), or
            from columns 2/3 when ``cld_lyr`` is True.
        amv_files: AMV datasource providing ``get_out_parameters`` and ``get_meta_dict``.
        cld_lyr: profiles carry cloud layers only (no wind azimuth/speed variables written).
    """
    keys = list(aeolus_to_amv_dct.keys())
    num_amvs = [aeolus_to_amv_dct.get(key).shape[1] for key in keys]
    num_levs = [aeolus_dct.get(key).shape[0] for key in keys]
    times = list(keys)

    out_params = amv_files.get_out_parameters()
    meta_dict = amv_files.get_meta_dict()

    with Dataset(filename, 'w', format='NETCDF4') as rootgrp:
        rootgrp.createDimension('amvs', size=sum(num_amvs))
        rootgrp.createDimension('profs', size=sum(num_levs))
        rootgrp.createDimension('num_aeolus_profs', size=len(keys))

        nc4_vars = []
        for param in out_params:
            u, t = meta_dict.get(param)
            var = rootgrp.createVariable(param, t, ['amvs'])
            if u is not None:
                var.units = u
            nc4_vars.append(var)

        dist = rootgrp.createVariable('dist_to_prof', 'f4', ['amvs'])
        dist.units = 'km'

        num_amvs_per_prof = rootgrp.createVariable('num_amvs_per_prof', 'i4', ['num_aeolus_profs'])
        num_levs_per_prof = rootgrp.createVariable('num_levs_per_prof', 'i4', ['num_aeolus_profs'])
        # float64: epoch seconds (~1.6e9) stored as float32 are only resolved to ~128 s
        prof_time = rootgrp.createVariable('time', 'f8', ['num_aeolus_profs'])
        prof_time.units = 'seconds since 1970-01-01 00:00:00'

        # ---- Profile variables ---------------
        prf_lon = rootgrp.createVariable('prof_longitude', 'f4', ['num_aeolus_profs'])
        prf_lon.units = 'degrees east'
        prf_lat = rootgrp.createVariable('prof_latitude', 'f4', ['num_aeolus_profs'])
        prf_lat.units = 'degrees north'

        if not cld_lyr:
            prf_azm = rootgrp.createVariable('prof_azm', 'f4', ['profs'])
            prf_azm.units = 'degree'
            prf_spd = rootgrp.createVariable('prof_spd', 'f4', ['profs'])
            prf_spd.units = 'm s-1'
        prf_hht = rootgrp.createVariable('prof_hht', 'f4', ['profs'])
        prf_hht.units = 'meter'
        prf_hhb = rootgrp.createVariable('prof_hhb', 'f4', ['profs'])
        prf_hhb.units = 'meter'
        # --------------------------------------

        # Cloud-layer profiles have no wind columns, so layer top/bottom sit one column earlier.
        hht_col, hhb_col = (2, 3) if cld_lyr else (3, 4)

        i_a = 0  # start of this profile's AMVs along 'amvs'
        i_c = 0  # start of this profile's levels along 'profs'
        for idx, key in enumerate(keys):
            i_b = i_a + num_amvs[idx]
            i_d = i_c + num_levs[idx]

            prof = aeolus_dct.get(key)
            prf_hht[i_c:i_d] = prof[:, hht_col]
            prf_hhb[i_c:i_d] = prof[:, hhb_col]
            if not cld_lyr:
                prf_azm[i_c:i_d] = prof[:, 5]
                prf_spd[i_c:i_d] = prof[:, 6]

            plat = prof[0, 0]
            plon = prof[0, 1]
            prf_lat[idx] = plat
            prf_lon[idx] = plon

            param_nd = aeolus_to_amv_dct.get(key)
            for pidx, var in enumerate(nc4_vars):
                var[i_a:i_b] = param_nd[pidx, :]
            dist[i_a:i_b] = haversine_np(plon, plat, param_nd[0, :], param_nd[1, :])

            i_a = i_b
            i_c = i_d

        num_amvs_per_prof[:] = num_amvs
        num_levs_per_prof[:] = num_levs
        prof_time[:] = times


def _bestfit_new_var(rootgrp, name, dtype, dims, fill=None, **attrs):
    """Create variable *name*, set *attrs* in the given order, then prefill it with *fill* if not None."""
    var = rootgrp.createVariable(name, dtype, dims)
    for attr_name, attr_val in attrs.items():
        var.setncattr(attr_name, attr_val)
    if fill is not None:
        var[:,] = fill
    return var


def create_bestfit_file(output_filename, match_dct, raob_dct, raob_ts, gfs_at_raob_dct, bf_dct, prd_dct, amv_params,
                        raob_filename, amv_file_name, example_filename=None):
    """Write RAOB profiles, their matched AMVs and best-fit / closest-level winds to NetCDF4.

    Per-AMV variables have dims (time=1, num_raob_profs, max_num_amvs) and are padded with NaN
    (floats) or -9999 (amv_elem/amv_line). RAOB profile variables have dims
    (time=1, num_raob_profs, max_num_raob_levels).

    Args:
        output_filename: path of the file to create.
        match_dct: only its keys are used; each key exposes ``.lon`` and ``.lat``.
        raob_dct: key -> (raob, station_id); raob columns 0, 2, 3 are written as pressure
            level, direction and speed.
        raob_ts: value of the single 'time' entry (seconds since 1970-01-01).
        gfs_at_raob_dct, prd_dct, raob_filename, amv_file_name: not used; kept for callers.
        bf_dct: key -> tuple (amvs, bf_raob, cp_raob, bf_gfs, ...). amvs is (num_amvs, 4 + len(amv_params))
            with columns lon, lat, elem, line, then *amv_params*. The other three are
            (num_amvs, >=3) arrays whose columns 0, 1, 2 are u, v and pressure.
        amv_params: AMV variable names; dtype, fill value and attributes come from the example file.
        example_filename: AMV product file used as attribute template; defaults to the
            historical hard-coded OPS file.
    """
    if example_filename is None:
        example_filename = '/ships19/cloud/scratch/AMV_FAST_BIAS/OPS_AMVS/OR_ABI-L2-DMWF-M6C14_G16_s20201050000166_e20201050009474_c20201050023226.nc'

    num_aparams = 4
    keys = list(match_dct.keys())
    num_aprofs = len(keys)

    # scan to get per-profile metadata and the sizes of the padded dimensions
    alons = []
    alats = []
    ids = []
    num_amvs_per_prof = []
    max_num_alevels = 0
    max_num_amvs = 0
    for key in keys:
        raob, sta_id = raob_dct.get(key)
        max_num_alevels = max(max_num_alevels, raob.shape[0])
        alons.append(key.lon)
        alats.append(key.lat)
        ids.append(sta_id)
        n_amvs = bf_dct.get(key)[0].shape[0]
        num_amvs_per_prof.append(n_amvs)
        max_num_amvs = max(max_num_amvs, n_amvs)

    prof_dims = ['time', 'num_raob_profs', 'max_num_raob_levels']
    amv_dims = ['time', 'num_raob_profs', 'max_num_amvs']
    # Template attributes not copied: _FillValue is set at creation; the others presumably point
    # at variables of the template file rather than this output.
    skip_attrs = ('_FillValue', 'coordinates', 'grid_mapping', 'cell_methods')

    with Dataset(example_filename, 'r') as rg_exmpl, Dataset(output_filename, 'w', format='NETCDF4') as rootgrp:
        rootgrp.createDimension('num_raob_params', size=num_aparams)
        rootgrp.createDimension('max_num_raob_levels', size=max_num_alevels)
        rootgrp.createDimension('max_num_amvs', size=max_num_amvs)
        rootgrp.createDimension('num_raob_profs', size=num_aprofs)
        rootgrp.createDimension('time', size=1)

        # float64: epoch seconds stored as float32 are only resolved to ~128 s
        time_var = _bestfit_new_var(rootgrp, 'time', 'f8', ['time'],
                                    units='seconds since 1970-01-01 00:00:00', calendar='standard')

        # ---- Profile variables ---------------
        prf_lon = _bestfit_new_var(rootgrp, 'longitude', 'f4', ['time', 'num_raob_profs'],
                                   units='degrees_east', long_name='raob longitude')
        prf_lat = _bestfit_new_var(rootgrp, 'latitude', 'f4', ['time', 'num_raob_profs'],
                                   units='degrees_north', long_name='raob latitude')
        prf_id = _bestfit_new_var(rootgrp, 'station_id', 'i4', ['time', 'num_raob_profs'],
                                  long_name='WMO station identifier')
        num_amvs_var = rootgrp.createVariable('num_amvs_per_profile', 'i4', ['time', 'num_raob_profs'])

        prf_dir = _bestfit_new_var(rootgrp, 'raob_dir', 'f4', prof_dims, fill=np.nan, units='degree')
        prf_spd = _bestfit_new_var(rootgrp, 'raob_spd', 'f4', prof_dims, fill=np.nan, units='m s-1')
        prf_lvl = _bestfit_new_var(rootgrp, 'raob_levels', 'f4', prof_dims, fill=np.nan, units='hPa')

        # ----- Product variables ----------------
        # Filled column by column from bf_dct[key][0]: lon, lat, elem, line, then amv_params.
        nc4_vars = [
            _bestfit_new_var(rootgrp, 'amv_lon', 'f4', amv_dims, fill=np.nan, units='degrees_east'),
            _bestfit_new_var(rootgrp, 'amv_lat', 'f4', amv_dims, fill=np.nan, units='degrees_north'),
            _bestfit_new_var(rootgrp, 'amv_elem', 'i4', amv_dims, fill=-9999,
                             long_name='FGF x coordinate Full Disk GOES16'),
            _bestfit_new_var(rootgrp, 'amv_line', 'i4', amv_dims, fill=-9999,
                             long_name='FGF y coordinate Full Disk GOES16'),
        ]
        dist_to_raob = _bestfit_new_var(rootgrp, 'dist_to_raob', 'f4', amv_dims, fill=np.nan, units='km')

        var_s = rg_exmpl.variables
        for param in amv_params:
            v = var_s[param]
            attr_s = v.ncattrs()
            if '_FillValue' in attr_s:
                var = rootgrp.createVariable(param, v.dtype, amv_dims, fill_value=v.getncattr('_FillValue'))
            else:
                var = rootgrp.createVariable(param, v.dtype, amv_dims)
            # copy attributes from example to new output variable of the same name
            for attr in attr_s:
                if attr not in skip_attrs:
                    var.setncattr(attr, v.getncattr(attr))
            nc4_vars.append(var)

        # Bestfit variables ------------------
        var_u = _bestfit_new_var(rootgrp, 'bf_raob_u', 'f4', amv_dims, fill=np.nan,
                                 standard_name='eastward_wind',
                                 long_name='wind at bestfit level with raob as background profile',
                                 units='m s-1')
        var_v = _bestfit_new_var(rootgrp, 'bf_raob_v', 'f4', amv_dims, fill=np.nan,
                                 standard_name='northward_wind',
                                 long_name='wind at bestfit level with raob as background profile',
                                 units='m s-1')
        var_p = _bestfit_new_var(rootgrp, 'bf_raob_press', 'f4', amv_dims, fill=np.nan,
                                 standard_name='air_pressure',
                                 long_name='pressure level of bestfit with raob as background profile',
                                 units='hPa')
        # NOTE(review): the flag variables are created but never written below.
        rootgrp.createVariable('bf_raob_flag', 'i4', amv_dims)

        var_u_g = _bestfit_new_var(rootgrp, 'bf_gfs_u', 'f4', amv_dims, fill=np.nan,
                                   standard_name='eastward_wind',
                                   long_name='wind at bestfit level with GFS as background profile',
                                   units='m s-1')
        var_v_g = _bestfit_new_var(rootgrp, 'bf_gfs_v', 'f4', amv_dims, fill=np.nan,
                                   standard_name='northward_wind',
                                   long_name='wind at bestfit level with GFS as background profile',
                                   units='m s-1')
        var_p_g = _bestfit_new_var(rootgrp, 'bf_gfs_press', 'f4', amv_dims, fill=np.nan,
                                   standard_name='air_pressure',
                                   long_name='pressure level of bestfit with GFS as background profile',
                                   units='hPa')
        rootgrp.createVariable('bf_gfs_flag', 'i4', amv_dims)

        var_u_cp = _bestfit_new_var(rootgrp, 'closest_press_u', 'f4', amv_dims, fill=np.nan,
                                    standard_name='eastward_wind',
                                    long_name='wind at closest raob pressure level to amv within +/- 60mb',
                                    units='m s-1')
        var_v_cp = _bestfit_new_var(rootgrp, 'closest_press_v', 'f4', amv_dims, fill=np.nan,
                                    standard_name='northward_wind',
                                    long_name='wind at closest raob pressure level to amv within +/- 60mb',
                                    units='m s-1')
        var_p_cp = _bestfit_new_var(rootgrp, 'closest_press', 'f4', amv_dims, fill=np.nan,
                                    standard_name='air_pressure',
                                    long_name='closest raob pressure level to amv within +/- 60mb',
                                    units='hPa')

        # Write data to file  ---------------------
        time_var[0] = raob_ts
        prf_lon[0, :] = np.array(alons)
        prf_lat[0, :] = np.array(alats)
        prf_id[0, :] = ids
        num_amvs_var[0, :] = np.array(num_amvs_per_prof)

        for idx, key in enumerate(keys):
            raob, sta_id = raob_dct.get(key)
            nlevs = raob.shape[0]
            bf_tup = bf_dct.get(key)
            amvs = bf_tup[0]
            n_amvs = amvs.shape[0]

            prf_lvl[0, idx, 0:nlevs] = raob[:, 0]
            prf_dir[0, idx, 0:nlevs] = raob[:, 2]
            prf_spd[0, idx, 0:nlevs] = raob[:, 3]

            for pidx, var in enumerate(nc4_vars):
                var[0, idx, 0:n_amvs] = amvs[:, pidx]

            # bf_tup[1]: bestfit vs raob, bf_tup[2]: closest raob level, bf_tup[3]: bestfit vs GFS
            for u_var, v_var, p_var, wind in ((var_u, var_v, var_p, bf_tup[1]),
                                              (var_u_cp, var_v_cp, var_p_cp, bf_tup[2]),
                                              (var_u_g, var_v_g, var_p_g, bf_tup[3])):
                u_var[0, idx, 0:n_amvs] = wind[:, 0]
                v_var[0, idx, 0:n_amvs] = wind[:, 1]
                p_var[0, idx, 0:n_amvs] = wind[:, 2]

            dist_to_raob[0, idx, 0:n_amvs] = haversine_np(key.lon, key.lat, amvs[:, 0], amvs[:, 1])


def create_amv_to_aeolus_file(match_dct, filename, amv_params, amv_file_s, example_filename=None):
    """Write AEOLUS profiles and the AMVs matched to each of them to a padded NetCDF4 file.

    Args:
        match_dct: time -> list of tuples (elem, line, amv_file_index, prof, amvs).
            - prof: 2D profile array; row 0 holds (lat, lon) in columns 0/1, and columns 3, 4, 5, 6
              are written as prof_hht, prof_hhb, prof_azm and prof_spd.
            - amvs: (num_params, num_amvs); rows 4 onward follow *amv_params*.
        filename: path of the file to create.
        amv_params: AMV variable names to write; dtype, fill value and attributes come from the
            example file.
        amv_file_s: AMV file names, written to 'amv_file_names' (indexed by 'amv_file_index').
        example_filename: AMV product file used as attribute template; defaults to the
            historical hard-coded FMWK2 file.
    """
    if example_filename is None:
        example_filename = '/ships19/cloud/scratch/4TH_AMV_INTERCOMPARISON/FMWK2_AMV/GOES16_ABI_2KM_FD_2019293_0020_34_WINDS_AMV_EN-14CT.nc'

    num_aparams = 7
    amv_file_s = np.array(amv_file_s, dtype='object')

    # flatten match_dct into per-profile lists and find the sizes of the padded dimensions
    alons = []
    alats = []
    atimes = []
    elems = []
    lines = []
    fidxs = []
    max_num_alevels = 0
    max_num_amvs = 0
    keys = list(match_dct.keys())
    for key in keys:
        for tup in match_dct.get(key):
            prof = tup[3]
            alats.append(prof[0, 0])
            alons.append(prof[0, 1])
            atimes.append(key)
            elems.append(tup[0])
            lines.append(tup[1])
            fidxs.append(tup[2])
            max_num_alevels = max(max_num_alevels, prof.shape[0])
            max_num_amvs = max(max_num_amvs, tup[4].shape[1])
    num_aprofs = len(atimes)

    lev_dims = ['num_aeolus_profs', 'max_num_aeolus_levels']
    # Template attributes not copied: _FillValue is set at creation; the others presumably point
    # at variables of the template file rather than this output.
    skip_attrs = ('_FillValue', 'coordinates', 'grid_mapping', 'cell_methods')

    with Dataset(example_filename, 'r') as rg_exmpl, Dataset(filename, 'w', format='NETCDF4') as rootgrp:
        rootgrp.createDimension('num_aeolus_params', size=num_aparams)
        rootgrp.createDimension('max_num_aeolus_levels', size=max_num_alevels)
        rootgrp.createDimension('max_num_amvs', size=max_num_amvs)
        rootgrp.createDimension('num_aeolus_profs', size=num_aprofs)
        rootgrp.createDimension('num_amv_files', size=amv_file_s.shape[0])

        # float64: epoch seconds stored as float32 are only resolved to ~128 s
        prf_time = rootgrp.createVariable('time', 'f8', ['num_aeolus_profs'])
        amv_file_names = rootgrp.createVariable('amv_file_names', str, ['num_amv_files'])

        # ---- Profile variables ---------------
        prf_lon = rootgrp.createVariable('prof_longitude', 'f4', ['num_aeolus_profs'])
        prf_lon.units = 'degrees east'
        prf_lat = rootgrp.createVariable('prof_latitude', 'f4', ['num_aeolus_profs'])
        prf_lat.units = 'degrees north'
        prf_time.units = 'seconds since 1970-01-01 00:00:00'
        prf_elem = rootgrp.createVariable('FD_elem', 'f4', ['num_aeolus_profs'])
        prf_line = rootgrp.createVariable('FD_line', 'f4', ['num_aeolus_profs'])
        prf_fidx = rootgrp.createVariable('amv_file_index', 'i4', ['num_aeolus_profs'])
        prf_azm = rootgrp.createVariable('prof_azm', 'f4', lev_dims)
        prf_azm.units = 'degree'
        prf_spd = rootgrp.createVariable('prof_spd', 'f4', lev_dims)
        prf_spd.units = 'm s-1'
        prf_hht = rootgrp.createVariable('prof_hht', 'f4', lev_dims)
        prf_hht.units = 'meter'
        prf_hhb = rootgrp.createVariable('prof_hhb', 'f4', lev_dims)
        prf_hhb.units = 'meter'

        # ----- Product variables ----------------
        nc4_vars = []
        var_s = rg_exmpl.variables
        for param in amv_params:
            v = var_s[param]
            attr_s = v.ncattrs()
            if '_FillValue' in attr_s:
                var = rootgrp.createVariable(param, v.dtype, ['num_aeolus_profs', 'max_num_amvs'],
                                             fill_value=v.getncattr('_FillValue'))
            else:
                var = rootgrp.createVariable(param, v.dtype, ['num_aeolus_profs', 'max_num_amvs'])
            # copy attributes from example to new output variable of the same name
            for attr in attr_s:
                if attr not in skip_attrs:
                    var.setncattr(attr, v.getncattr(attr))
            nc4_vars.append(var)

        # Write data to file  ---------------------
        prf_lon[:] = np.array(alons)
        prf_lat[:] = np.array(alats)
        prf_elem[:] = np.array(elems)
        prf_line[:] = np.array(lines)
        prf_time[:] = np.array(atimes)
        prf_fidx[:] = np.array(fidxs)

        amv_file_names[:] = amv_file_s

        idx = 0
        for key in keys:
            for tup in match_dct.get(key):
                prof = tup[3]
                param_nd = tup[4]

                # one slice write per variable instead of one netCDF write per level
                nlevs = prof.shape[0]
                prf_spd[idx, 0:nlevs] = prof[:, 6]
                prf_azm[idx, 0:nlevs] = prof[:, 5]
                prf_hht[idx, 0:nlevs] = prof[:, 3]
                prf_hhb[idx, 0:nlevs] = prof[:, 4]

                # rows 0-3 of param_nd are lon, lat, elem, line; the product parameters follow
                for pidx, var in enumerate(nc4_vars):
                    nda = param_nd[pidx + 4,]
                    var[idx, 0:nda.shape[0],] = nda
                idx += 1


def run_clavrx_to_aeolus_match(aeolus_files_dir, clvrx_files_dir, outfile, chan='mie'):
    """Match AEOLUS profiles to CLAVR-x files and write the matches to *outfile*.

    *aeolus_files_dir* may be a directory (read with ``get_aeolus_time_dict_s`` using *chan*)
    or a single file (read with ``get_aeolus_time_dict``).
    """
    reader_out = (get_aeolus_time_dict_s(aeolus_files_dir, chan=chan)
                  if os.path.isdir(aeolus_files_dir)
                  else get_aeolus_time_dict(aeolus_files_dir))
    profiles_by_time = time_dict_to_nd(reader_out)

    clavrx_source = CLAVRx(clvrx_files_dir)
    matches = match_aeolus_to_clavrx(profiles_by_time, clavrx_source)
    create_aeolus_clavrx_match_file(matches, outfile, clavrx_source.get_parameters(), clavrx_source.flist)


# aeolus_files_dir: S4 NOAA txt output files (a directory, or a single file)
# amv_files_dir: G16/17 AMV product files
# outfile: pathname for the Netcdf match file
def match_amv_to_aeolus(aeolus_files_dir, amv_files_dir, outfile, amv_source='OPS', band='14', chan='mie'):
    """Match AEOLUS profiles to AMVs and write the result with ``create_amv_to_aeolus_file``.

    The AMV datasource is built with ``get_datasource``; for amv_source 'CARR' it is also
    given ``file_time_span=60``.
    """
    if os.path.isdir(aeolus_files_dir):
        time_dct = get_aeolus_time_dict_s(aeolus_files_dir, chan=chan)
    else:
        time_dct = get_aeolus_time_dict(aeolus_files_dir)
    time_dct = time_dict_to_nd(time_dct)

    source_kwargs = {'band': band}
    if amv_source == 'CARR':
        source_kwargs['file_time_span'] = 60
    amv_files = get_datasource(amv_files_dir, amv_source, **source_kwargs)

    matches = match_amvs_to_aeolus_fast(time_dct, amv_files_dir, amv_source, band, amv_files)
    create_amv_to_aeolus_file(matches, outfile, amv_files.get_parameters(), amv_files.flist)


# match_file: pathname for the product file
# dt_str_0: start time (YYYY-MM-DD_HH:MM)
# dt_str_1: end time (YYYY-MM-DD_HH:MM)
# amv_var_names: list of amv parameters (see match_file ncdump) to subset
# returns: tuple of three xarray.DataArray
# profs[nprofs, max_num_levs_per_prof, num_of_params], amvs[nprofs, max_num_amvs_per_prof, num_of_params],
# prof_locs[nprofs, (lon, lat, num_amvs, num_levels)]
def subset_by_time(match_file, dt_str_0, dt_str_1, amv_var_names):
    """Extract the profiles, and the AMVs matched to them, whose time lies in [start, end).

    *match_file* uses the ragged layout written by ``create_file``: AMVs and profile levels
    are concatenated along one dimension and split per profile with 'num_amvs_per_prof' /
    'num_levs_per_prof'. The selection is re-padded into dense arrays filled with -999.0.

    The time strings are parsed as UTC and may be given in either order; the window includes
    the earlier bound and excludes the later one.

    Returns:
        (prof_da, amvs_da, prof_locs_da), each indexed by profile time along 'num_profs'.
        When no profile falls in the window they have length 0 along 'num_profs'.
    """
    t_0 = datetime.datetime.strptime(dt_str_0, '%Y-%m-%d_%H:%M').replace(tzinfo=timezone.utc).timestamp()
    t_1 = datetime.datetime.strptime(dt_str_1, '%Y-%m-%d_%H:%M').replace(tzinfo=timezone.utc).timestamp()
    if t_1 < t_0:
        # accept the bounds in either order
        t_0, t_1 = t_1, t_0

    rootgrp = Dataset(match_file, 'r', format='NETCDF4')
    try:
        n_profs = len(rootgrp.dimensions['num_aeolus_profs'])
        n_amvs_per_prof = rootgrp['num_amvs_per_prof'][:]
        n_levs_per_prof = rootgrp['num_levs_per_prof'][:]
        times = rootgrp['time'][:]
        lons = rootgrp['prof_longitude'][:]
        lats = rootgrp['prof_latitude'][:]

        a_nc_vars = [rootgrp[vname] for vname in amv_var_names]

        # profile variables present in the file, with the labels used in the output
        mf_vars = rootgrp.variables
        p_vars = []
        p_var_names = []
        for nc_name, label in (('prof_hhb', 'layer_bot'), ('prof_hht', 'layer_top'),
                               ('prof_spd', 'speed'), ('prof_azm', 'azimuth')):
            if nc_name in mf_vars:
                p_vars.append(rootgrp[nc_name])
                p_var_names.append(label)

        time_idxs = np.arange(n_profs)[np.logical_and(times >= t_0, times < t_1)]
        n_times = time_idxs.shape[0]

        # Start offset of every profile inside the concatenated AMV / level dimensions, so the
        # selected profiles need not be contiguous.
        amv_offsets = np.concatenate(([0], np.cumsum(n_amvs_per_prof)))
        lev_offsets = np.concatenate(([0], np.cumsum(n_levs_per_prof)))

        if n_times > 0:
            mx_namvs = int(np.max(n_amvs_per_prof[time_idxs]))
            mx_nlevs = int(np.max(n_levs_per_prof[time_idxs]))
        else:
            mx_namvs = 0
            mx_nlevs = 0

        amvs = np.full((n_times, mx_namvs, len(a_nc_vars)), -999.0)
        profs = np.full((n_times, mx_nlevs, len(p_vars)), -999.0)

        for idx, t_i in enumerate(time_idxs):
            n_amvs = n_amvs_per_prof[t_i]
            n_levs = n_levs_per_prof[t_i]
            a = amv_offsets[t_i]
            c = lev_offsets[t_i]

            for k, nc_var in enumerate(a_nc_vars):
                amvs[idx, 0:n_amvs, k] = nc_var[a:a + n_amvs]

            for k, nc_var in enumerate(p_vars):
                profs[idx, 0:n_levs, k] = nc_var[c:c + n_levs]
    finally:
        rootgrp.close()

    prof_times = times[time_idxs]
    prof_da = xr.DataArray(profs, coords={'num_profs': prof_times, 'num_params': p_var_names},
                           dims=['num_profs', 'max_num_levels', 'num_params'])

    amvs_da = xr.DataArray(amvs, coords={'num_profs': prof_times, 'num_params': amv_var_names},
                           dims=['num_profs', 'max_num_amvs', 'num_params'])

    prof_locs_da = xr.DataArray(np.column_stack([lons[time_idxs], lats[time_idxs],
                                                 n_amvs_per_prof[time_idxs], n_levs_per_prof[time_idxs]]),
                                coords=[prof_times, ['longitude', 'latitude', 'num_amvs', 'num_levels']],
                                dims=['num_profs', 'space'])

    return prof_da, amvs_da, prof_locs_da


def analyze(prof_da, amvs_da, prof_locs_da, dist_threshold=5.0):
    """Print how often AMV heights fall inside AEOLUS cloud layers.

    Inputs are the three arrays returned by ``subset_by_time``:
      - the AMVs must include the 'dist_to_prof' and 'H_3D' parameters;
      - the profiles must include 'layer_top' and 'layer_bot'.

    Only AMVs with 0 < dist_to_prof < *dist_threshold* and H_3D > 0 are considered.

    Two statistics are printed:
      - single-layer profiles: hits against their first (sorted) layer;
      - all profiles: hits against any of their layers.

    Each "fraction" is the hit count divided by the number of profiles considered, so it can
    exceed 1. It is NaN when no profile qualifies.

    The input arrays are not modified.

    Returns:
        (one_lyr_profs, one_lyr_amvs): the filtered single-layer profiles (layers sorted top
        down) and their AMVs (sorted by distance).
    """
    # Sort on copies: the row-wise assignments below would otherwise reorder the caller's arrays.
    amvs_da = amvs_da.copy()
    prof_da = prof_da.copy()

    # sort amvs by distance to profile
    dst = amvs_da.sel(num_params='dist_to_prof')
    s_i = np.argsort(dst, axis=1).values
    for k in range(amvs_da.shape[0]):
        amvs_da[k] = amvs_da[k, s_i[k]]

    # sort profiles by level highest to lowest
    top = prof_da.sel(num_params='layer_top')
    s_i = np.argsort(top, axis=1).values
    for k in range(prof_da.shape[0]):
        prof_da[k] = prof_da[k, s_i[k, ::-1]]

    # analyze single cloud layer profiles
    one_lyr = prof_locs_da.sel(space='num_levels') == 1
    one_lyr_profs = prof_da[one_lyr]
    one_lyr_amvs = amvs_da[one_lyr]
    print('number of one layer profs: ', one_lyr_profs.shape[0])

    has_vld = (one_lyr_amvs.sel(num_params='H_3D') > 0).values.sum(1) > 0
    one_lyr_amvs = one_lyr_amvs[has_vld]
    one_lyr_profs = one_lyr_profs[has_vld]
    print('number of one layer profs with at least one valid AMV height', one_lyr_profs.shape[0])

    # compare profile highest cloud layer top to AMVs closer than dist_threshold
    dst = one_lyr_amvs.sel(num_params='dist_to_prof')
    close = np.logical_and(dst > 0.0, dst < dist_threshold).values.sum(1) > 0
    one_lyr_amvs = one_lyr_amvs[close]
    one_lyr_profs = one_lyr_profs[close]
    num_one_lyr_profs = one_lyr_profs.shape[0]
    print('number of one layer profs with at least one AMV within threshold: ', num_one_lyr_profs)

    cnt = 0
    prof_bot = one_lyr_profs.sel(num_params='layer_bot')
    prof_top = one_lyr_profs.sel(num_params='layer_top')
    for k in range(num_one_lyr_profs):
        dst = one_lyr_amvs[k, :, ].sel(num_params='dist_to_prof')
        b = np.logical_and(dst > 0.0, dst < dist_threshold)
        h_3d = one_lyr_amvs[k, :].sel(num_params='H_3D')
        h_3d = h_3d[b]
        h_3d = h_3d[h_3d > 0]
        if len(h_3d) > 0:
            in_lyr = np.logical_and(h_3d > prof_bot[k, 0], h_3d < prof_top[k, 0])
            cnt += np.sum(in_lyr)

    print('fraction hits single cloud layer: ',
          cnt / num_one_lyr_profs if num_one_lyr_profs > 0 else np.nan)

    # Do calculations for single and multi-layer combined
    has_vld = (amvs_da.sel(num_params='H_3D') > 0).values.sum(1) > 0
    amvs_da = amvs_da[has_vld]
    prof_da = prof_da[has_vld]
    prof_locs_da = prof_locs_da[has_vld]
    print('number of profs with at least one valid AMV height', prof_da.shape[0])

    # compare profile cloud layers to AMVs closer than dist_threshold
    dst = amvs_da.sel(num_params='dist_to_prof')
    close = np.logical_and(dst > 0.0, dst < dist_threshold).values.sum(1) > 0
    amvs_da = amvs_da[close]
    prof_da = prof_da[close]
    prof_locs_da = prof_locs_da[close]
    num_profs = prof_da.shape[0]
    print('number of profs with at least one AMV within threshold: ', num_profs)

    cnt = 0
    prof_bot = prof_da.sel(num_params='layer_bot')
    prof_top = prof_da.sel(num_params='layer_top')
    for k in range(num_profs):
        dst = amvs_da[k, :, ].sel(num_params='dist_to_prof')
        b = np.logical_and(dst > 0.0, dst < dist_threshold)
        h_3d = amvs_da[k, :].sel(num_params='H_3D')
        h_3d = h_3d[b]
        h_3d = h_3d[h_3d > 0]
        if len(h_3d) > 0:
            # column 3 of prof_locs_da is 'num_levels'
            nlevs = prof_locs_da[k].values.astype(int)[3]
            for j in range(nlevs):
                in_lyr = np.logical_and(h_3d > prof_bot[k, j], h_3d < prof_top[k, j])
                cnt += np.sum(in_lyr)

    print('fraction hits multi layer: ', cnt / num_profs if num_profs > 0 else np.nan)
    return one_lyr_profs, one_lyr_amvs