import numpy as np
import xarray as xr
# from util.setup import home_dir
from util.lon_lat_grid import earth_to_indexs
from util.util import haversine_np
from util.gfs_reader import get_vert_profile_s
# import cartopy
from cartopy import *
import metpy.calc as mpcalc
import cartopy.crs as ccrs
from metpy.units import units
import h5py
import time
# from scipy.interpolate import interp1d
from util.util import minimize_quadratic
from netCDF4 import Dataset
import datetime
from datetime import timezone
from util.util import LatLonTuple
import re

# -- AMV intercompare stuff ------------------------------------------
# --------------------------------------------------------------------

# Previous AMV column layout, kept for reference (before the 'id' and 'hmet'
# columns were added):
#amv_hdr_list = ['lat', 'lon', 'tbox', 'sbox', 'spd', 'dir', 'pres', 'lowl', 'mspd', 'mdir',
#                'alb', 'corr', 'tmet', 'perr', 'qi', 'cqi', 'qif']
# Column layout of the 2-D AMV arrays built by get_amv_nd (one row per AMV).
amv_hdr_list = ['id', 'lat', 'lon', 'tbox', 'sbox', 'spd', 'dir', 'pres', 'lowl', 'mspd', 'mdir',
                'alb', 'corr', 'tmet', 'perr', 'hmet', 'qi', 'cqi', 'qif']

# Column layout of the per-location raob profile arrays built by get_raob_dict
# and get_raob_dict_cdf (one row per pressure level).
raob_hdr_list = ['pres', 'temp', 'dir', 'spd']

# Column indices derived from the layouts above so callers can index by name.
raob_spd_idx = raob_hdr_list.index('spd')
raob_dir_idx = raob_hdr_list.index('dir')
amv_lon_idx = amv_hdr_list.index('lon')
amv_lat_idx = amv_hdr_list.index('lat')
amv_pres_idx = amv_hdr_list.index('pres')
raob_pres_idx = raob_hdr_list.index('pres')
amv_spd_idx = amv_hdr_list.index('spd')
amv_dir_idx = amv_hdr_list.index('dir')
amv_prs_idx = amv_hdr_list.index('pres')  # same value as amv_pres_idx; this spelling is the one used below
amv_qi_idx = amv_hdr_list.index('qi')
amv_cqi_idx = amv_hdr_list.index('cqi')
amv_qif_idx = amv_hdr_list.index('qif')

# Processing-center codes: run_best_fit_all iterates over these, and they are
# the keys of the dict returned by get_amv_winds_match.
amv_centers_list = ['EUM', 'BRZ', 'JMA', 'KMA', 'NOA', 'NWC']


def get_amv_nd(filename, delimiter=';'):
    """Read a delimited AMV intercomparison file into a 2-D float array.

    The first line is treated as a header and skipped; blank lines are
    ignored. Each remaining line must have at least 19 ``delimiter``-separated
    numeric fields. Rows are dropped when pres is outside [10, 1020], lat
    outside [-90, 90], lon outside [-180, 180], spd outside [0, 150] or dir
    outside [0, 360].

    Args:
        filename: path of the text file to read.
        delimiter: field separator (default ';').

    Returns:
        ndarray of shape (n_rows, 19), columns in the order
        id, lat, lon, tbox, sbox, spd, dir, pres, lowl, mspd, mdir, alb, corr,
        tmet, perr, hmet, qi, cqi, qif (matching the module's amv_hdr_list).
        An input with no surviving rows yields shape (0, 19) so column
        indexing by callers still works.
    """
    # Source token index for each output column. The reader takes qi from
    # token 16, qif from token 17 and cqi from token 18, but emits them as
    # qi, cqi, qif -- hence the swapped tail.
    # NOTE(review): confirm this token order against the CSV header.
    src_cols = (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 18, 17)

    rows = []
    with open(filename) as file:
        next(file, None)  # skip the header line
        for line in file:
            if not line.strip():
                continue  # tolerate blank / trailing lines
            toks = line.split(delimiter)
            row = tuple(float(toks[c]) for c in src_cols)

            lat, lon, spd, wdir, pres = row[1], row[2], row[5], row[6], row[7]
            if pres < 10 or pres > 1020:
                continue
            if lat < -90 or lat > 90:
                continue
            if lon < -180 or lon > 180:
                continue
            if spd < 0 or spd > 150:
                continue
            if wdir < 0 or wdir > 360:
                continue

            rows.append(row)

    # reshape keeps the result 2-D even when no rows survived.
    return np.array(rows, dtype=float).reshape(-1, len(src_cols))


def filter_amvs(amvs, qitype=amv_cqi_idx, qival=50, lat_range=(-61, 61)):
    """Subset AMV rows by latitude band and minimum quality value.

    Args:
        amvs: 2-D AMV array with columns in amv_hdr_list order.
        qitype: column index of the quality field to threshold, or None to
            skip the quality filter.
        qival: minimum (inclusive) quality value to keep.
        lat_range: (min, max) latitude bounds, both exclusive, or None to
            skip the latitude filter. A tuple default avoids the shared
            mutable-list default; any indexable pair is accepted.

    Returns:
        The surviving rows, or None when a filter leaves no rows.
    """
    if lat_range is not None:
        lats = amvs[:, amv_lat_idx]
        amvs = amvs[np.logical_and(lats > lat_range[0], lats < lat_range[1]), :]
        if amvs.shape[0] == 0:
            return None

    if qitype is None:
        return amvs

    amvs = amvs[amvs[:, qitype] >= qival, :]
    return amvs if amvs.shape[0] > 0 else None


def get_raob_dict_cdf(pathname):
    """Read raob wind profiles from a NetCDF-3 file into a per-station dict.

    Mandatory levels (prMan/tpMan/wsMan/wdMan) and significant wind levels
    (prSigW/tpSigW/wsSigW/wdSigW) are merged per station. A level is kept
    only when its pressure, speed and direction are all unmasked (temperature
    is not checked). Stations with latitude outside (-90, 90), and stations
    left with no valid levels, are skipped.

    Args:
        pathname: path to the NetCDF file.

    Returns:
        dict mapping ``LatLonTuple(lat, lon)`` to ``[raob, station_id]``, where
        ``raob`` is an (n_levels, 4) array with columns in raob_hdr_list order
        (pres, temp, dir, spd) sorted by decreasing pressure, and
        ``station_id`` is the station's ``wmoStat`` value.
    """
    rtgrp = Dataset(pathname, 'r', format='NETCDF3')
    try:
        sta_lats_var = rtgrp['staLat']
        sta_lats_var.set_always_mask(False)
        sta_lats_var.set_auto_mask(False)

        sta_lons_var = rtgrp['staLon']
        sta_lons_var.set_always_mask(False)
        sta_lons_var.set_auto_mask(False)

        sta_lats = sta_lats_var[:]
        sta_lons = sta_lons_var[:]
        sta_ids = rtgrp['wmoStat'][:]

        # Masking is off for lat/lon, so this range test is what drops
        # invalid stations (presumably fill values) -- confirm fill value.
        sta_msk = np.logical_and(sta_lats > -90.0, sta_lats < 90.0)
        sta_lats = sta_lats[sta_msk]
        sta_lons = sta_lons[sta_msk]
        sta_ids = sta_ids[sta_msk]

        def read_levels(name):
            # Full (station, level) variable, restricted to the kept stations.
            return rtgrp[name][:, :][sta_msk]

        man_levs, man_temp, man_spd, man_dir = (
            read_levels(n) for n in ('prMan', 'tpMan', 'wsMan', 'wdMan'))
        sig_levs, sig_temp, sig_spd, sig_dir = (
            read_levels(n) for n in ('prSigW', 'tpSigW', 'wsSigW', 'wdSigW'))
    finally:
        # Always release the file, even if a variable is missing.
        rtgrp.close()

    raob_dct = {}
    for k in range(sta_lats.shape[0]):
        # getmaskarray yields a full boolean array even when a row has no
        # masked values (where ``._mask`` would be the scalar nomask).
        vld_man = ~(np.ma.getmaskarray(man_levs[k]) | np.ma.getmaskarray(man_spd[k])
                    | np.ma.getmaskarray(man_dir[k]))
        vld_sig = ~(np.ma.getmaskarray(sig_levs[k]) | np.ma.getmaskarray(sig_spd[k])
                    | np.ma.getmaskarray(sig_dir[k]))

        # Columns in raob_hdr_list order: pres, temp, dir, spd.
        columns = [
            np.concatenate([np.ma.getdata(man_var[k])[vld_man], np.ma.getdata(sig_var[k])[vld_sig]])
            for man_var, sig_var in ((man_levs, sig_levs), (man_temp, sig_temp),
                                     (man_dir, sig_dir), (man_spd, sig_spd))
        ]

        levs = columns[0]
        if levs.size == 0:
            continue

        # Decreasing pressure (surface first).
        order = np.argsort(levs)[::-1]
        raob = np.stack([col[order] for col in columns], axis=1)

        raob_dct[LatLonTuple(lat=sta_lats[k], lon=sta_lons[k])] = [raob, sta_ids[k]]

    return raob_dct


def get_raob_dict(filename):
    """Read a whitespace-delimited raob text file into per-location profiles.

    The first line is a header and is skipped; blank lines are ignored. Every
    other line must start with ``lat lon pres temp dir spd``. Rows are dropped
    when pres is outside [0, 1050], lat outside [-90, 90], lon outside
    [-180, 180], spd outside [0, 150] or dir outside [0, 360].

    Consecutive kept rows sharing the same (lat, lon) form one profile.

    Args:
        filename: path of the text file to read.

    Returns:
        dict mapping (lat, lon) tuples to (n_levels, 4) arrays with columns in
        raob_hdr_list order (pres, temp, dir, spd), levels in file order.
    """
    raob_dict = {}
    prev_loc = None
    levels = None

    with open(filename) as file:
        next(file, None)  # skip the header line
        for line in file:
            toks = line.split()
            if not toks:
                continue  # tolerate blank / trailing lines

            lat = float(toks[0])
            lon = float(toks[1])
            pres = float(toks[2])
            temp = float(toks[3])
            wdir = float(toks[4])
            spd = float(toks[5])

            if pres < 0 or pres > 1050:
                continue
            if lat < -90 or lat > 90:
                continue
            if lon < -180 or lon > 180:
                continue
            if spd < 0 or spd > 150:
                continue
            if wdir < 0 or wdir > 360:
                continue

            loc = (lat, lon)
            if loc != prev_loc:
                # NOTE(review): a location that reappears after a different one
                # starts a fresh list that replaces its earlier run (behavior
                # kept from the original) -- confirm runs should not be merged.
                levels = []
                raob_dict[loc] = levels
                prev_loc = loc
            levels.append((pres, temp, wdir, spd))

    return raob_dict_to_nd(raob_dict)


def raob_dict_to_nd(raob_dict):
    """Return a new dict with each location's list of level tuples as an ndarray.

    Args:
        raob_dict: mapping of location key -> sequence of per-level tuples.

    Returns:
        dict with the same keys (same order) and ``np.array`` values.
    """
    return {loc: np.array(levels) for loc, levels in raob_dict.items()}


def amv_bulk_stats(amvs, qitype=amv_cqi_idx, qival=50):
    """Print bulk statistics for an AMV array.

    Prints the lat and lon ranges of all AMVs. For the "good" AMVs (column
    ``qitype`` >= ``qival``) it then prints:

    - the count;
    - speed and pressure min/max/mean;
    - the percentage in the low (p >= 700), mid (400 < p < 700) and
      high (p <= 400) layers;
    - per-layer speed/pressure min/max/mean.

    Empty selections print a short note instead of raising.

    Args:
        amvs: 2-D AMV array with columns in amv_hdr_list order.
        qitype: column index of the quality field to threshold.
        qival: minimum (inclusive) quality value counted as good.
    """
    if amvs.shape[0] == 0:
        print('no amvs')
        return

    lats = amvs[:, amv_hdr_list.index('lat')]
    print(lats.min(), lats.max())

    lons = amvs[:, amv_hdr_list.index('lon')]
    print(lons.min(), lons.max())

    good = amvs[:, qitype] >= qival
    num_good = np.sum(good)
    good_amvs = amvs[good, :]

    print('------------------------------')
    print('num good amvs: ', num_good)
    if num_good == 0:
        return

    spd = good_amvs[:, amv_hdr_list.index('spd')]
    prs = good_amvs[:, amv_hdr_list.index('pres')]
    _print_min_max_mean('spd min/max/mean: ', spd)
    _print_min_max_mean('pres min/max/mean: ', prs)

    low = prs >= 700
    mid = np.logical_and(prs < 700, prs > 400)
    hgh = prs <= 400

    print('% low: ', 100.0*(np.sum(low)/num_good))
    print('% mid: ', 100.0*(np.sum(mid)/num_good))
    print('% hgh: ', 100.0*(np.sum(hgh)/num_good))
    print('---------------------------')

    _print_min_max_mean('Low Spd min/max/mean: ', spd[low])
    _print_min_max_mean('Low Press min/max/mean: ', prs[low])
    print('---------------------------')

    _print_min_max_mean('Mid Spd min/max/mean: ', spd[mid])
    _print_min_max_mean('Mid Press min/max/mean: ', prs[mid])
    print('---------------------------')

    _print_min_max_mean('Hgh Spd min/max/mean: ', spd[hgh])
    _print_min_max_mean('Hgh Press min/max/mean: ', prs[hgh])

    print('----------------------------')


def _print_min_max_mean(label, values):
    """Print ``label`` with min/max/mean of ``values``, or a note if empty."""
    if values.size == 0:
        print(label, 'no samples')
    else:
        print(label, values.min(), values.max(), values.mean())


# Largest |AMV pressure - matched level pressure| accepted by amv_raob_diff and
# amv_gfs_diff when pairing an AMV with a profile level.
press_diff_threshold = 50.0  # mb
# Neighbor search radius. The visible functions take their own dist_threshold
# parameter, which shadows this module-level value.
dist_threshold = 150.0  # km
# Approximate kilometres per degree of latitude (not referenced in the visible code).
km_per_deg = 111.0


def get_raob_neighbors(lats, lons, raob_dct, dist_threshold=150):
    """Find the raob locations within ``dist_threshold`` of each (lat, lon) point.

    Distances come from haversine_np (units as returned by that helper; the
    module-level dist_threshold comment says km).

    Args:
        lats, lons: 1-D arrays of point coordinates.
        raob_dct: dict keyed by raob (lat, lon) tuples.
        dist_threshold: strict upper bound on distance.

    Returns:
        dict mapping point index -> ``(keys, dists)``, with the neighboring
        raob keys sorted nearest first and their matching sorted distances,
        or ``None`` when no raob is close enough.
    """
    raob_locs = list(raob_dct.keys())
    raob_lats = np.array([loc[0] for loc in raob_locs])
    raob_lons = np.array([loc[1] for loc in raob_locs])

    neighbors = {}
    for pt in range(lats.shape[0]):
        dists = haversine_np(raob_lons, raob_lats, lons[pt], lats[pt])
        close = np.nonzero(dists < dist_threshold)[0]

        if close.shape[0] == 0:
            neighbors[pt] = None
            continue

        close = close[np.argsort(dists[close])]
        neighbors[pt] = ([raob_locs[i] for i in close], dists[close])

    return neighbors


def get_neighbors_to_raob(lats, lons, raob_dct, dist_threshold=150):
    """Find the (lat, lon) points within ``dist_threshold`` of each raob.

    Args:
        lats, lons: 1-D arrays of point coordinates.
        raob_dct: dict keyed by raob (lat, lon) tuples.
        dist_threshold: strict upper bound on haversine_np distance.

    Returns:
        dict mapping each raob key -> ``(point_indices, dists)`` (unsorted,
        in point order), or ``None`` when no point is close enough.
    """
    result = {}
    for loc in list(raob_dct.keys()):
        dists = haversine_np(lons, lats, loc[1], loc[0])
        inside = np.nonzero(dists < dist_threshold)[0]
        result[loc] = (list(inside), dists[inside]) if inside.shape[0] else None
    return result


def amv_raob_diff_2(amvs, raob_dct):
    """Run amv_raob_diff on the columns of a 2-D AMV array.

    Args:
        amvs: 2-D AMV array with columns in amv_hdr_list order.
        raob_dct: raob profiles keyed by (lat, lon), as from get_raob_dict.

    Returns:
        ``(spd_diffs, dir_diffs)`` as returned by amv_raob_diff.
    """
    return amv_raob_diff(
        amvs[:, amv_lat_idx],
        amvs[:, amv_lon_idx],
        amvs[:, amv_spd_idx],
        amvs[:, amv_dir_idx],
        amvs[:, amv_prs_idx],
        raob_dct,
    )


def amv_raob_diff(amv_lat, amv_lon, amv_spd, amv_dir, amv_press, raob_dct):
    """Compare AMVs with nearby raobs and print summary statistics.

    For each AMV, the neighboring raobs from get_raob_neighbors (its default
    search radius) are visited nearest first. For each raob, the level
    closest in pressure to the AMV is taken; if that level is within
    press_diff_threshold, one sample is recorded and the remaining raobs for
    this AMV are skipped. At most one sample is produced per AMV.

    Args:
        amv_lat, amv_lon: 1-D arrays of AMV locations.
        amv_spd, amv_dir: 1-D arrays of AMV speed and direction (degrees).
        amv_press: 1-D array of AMV pressures.
        raob_dct: dict mapping raob (lat, lon) to a 2-D array with columns in
            raob_hdr_list order (as built by get_raob_dict).

    Returns:
        (spd_diffs, dir_diffs): arrays of AMV-minus-raob speed differences and
        signed direction differences (see direction_difference), one entry
        per sample.

    Printed statistics: speed MAD/bias/RMS, direction bias, vector-difference
    bias/RMS and pressure bias/RMS. Each "rms" is sqrt(mean**2 + std**2),
    i.e. the root mean square of the differences. With no samples, the means
    are NaN (NumPy emits RuntimeWarnings).
    """
    num_amvs = amv_lat.shape[0]

    # u/v components from speed and direction: u = -spd*sin(dir), v = -spd*cos(dir).
    amv_u = -amv_spd * np.sin((np.pi/180.0) * amv_dir)
    amv_v = -amv_spd * np.cos((np.pi/180.0) * amv_dir)

    raob_nbor_dct = get_raob_neighbors(amv_lat, amv_lon, raob_dct)

    dir_diffs = []
    spd_diffs = []
    u_diffs = []
    v_diffs = []
    prs_diffs = []

    for j in range(num_amvs):  # collect at most one sample per AMV
        tup = raob_nbor_dct[j]
        if tup is not None:
            nbor_keys = tup[0]
            # nbor_dsts = tup[1] future use maybe

            # Neighbors are sorted nearest first; the first one with a level
            # within press_diff_threshold is used (see break below).
            for ll_tup in nbor_keys:
                raob = raob_dct.get(ll_tup)

                # Raob level nearest in pressure to the AMV.
                pdiff = amv_press[j] - raob[:, raob_pres_idx]
                lev_idx = np.argmin(np.abs(pdiff))
                if np.abs(pdiff[lev_idx]) > press_diff_threshold:
                    continue

                aspd = amv_spd[j]
                adir = amv_dir[j]
                rdir = raob[lev_idx, raob_dir_idx]
                rspd = raob[lev_idx, raob_spd_idx]

                ddiff = direction_difference(adir, rdir)
                dir_diffs.append(ddiff)

                spd_diffs.append(aspd - rspd)

                au = amv_u[j]
                av = amv_v[j]

                ru = -rspd * np.sin((np.pi / 180.0) * rdir)
                rv = -rspd * np.cos((np.pi / 180.0) * rdir)

                udiff = au - ru
                vdiff = av - rv

                u_diffs.append(udiff)
                v_diffs.append(vdiff)
                prs_diffs.append(pdiff[lev_idx])

                break  # Use first (closest) raob that passed the pressure check only

    num = len(u_diffs)
    dir_diffs = np.array(dir_diffs)
    spd_diffs = np.array(spd_diffs)
    prs_diffs = np.array(prs_diffs)
    u_diffs = np.array(u_diffs)
    v_diffs = np.array(v_diffs)

    # spd_bias = np.sqrt(np.sum(u_diffs)**2 + np.sum(v_diffs)**2) / num
    # spd_rms = np.sqrt((np.sqrt(np.sum(np.array(u_diffs) ** 2)/num))**2 + (np.sqrt(np.sum(np.array(v_diffs) ** 2)/num))**2)
    # dir_bias = np.average(dir_diffs)
    print('num amvs: ', num_amvs)
    print('num samples: ', num)

    # mean**2 + std**2 (population std) equals the mean of squared differences.
    spd_mean = np.mean(spd_diffs)
    spd_std = np.std(spd_diffs)
    print('spd MAD/bias/rms: ', np.average(np.abs(spd_diffs)), spd_mean, np.sqrt(spd_mean**2 + spd_std**2))
    print('-----------------')

    dir_mean = np.mean(dir_diffs)
    print('dir bias: ', dir_mean)
    print('-----------------------')

    # Magnitude of the vector (u, v) difference per sample.
    vd = np.sqrt(u_diffs**2 + v_diffs**2)
    vd_mean = np.mean(vd)
    vd_std = np.std(vd)
    print('VD bias/rms: ', vd_mean, np.sqrt(vd_mean**2 + vd_std**2))
    print('--------------')

    pd_std = np.std(prs_diffs)
    pd_mean = np.mean(prs_diffs)
    print('press bias/rms ', pd_mean, np.sqrt(pd_mean**2 + pd_std**2))
    print('------------------------------------------')

    return spd_diffs, dir_diffs


def run_best_fit(amvs, raob_dct, dist_threshold=200, min_num_levs=20):
    """Compute a raob best-fit pressure for each AMV that has a usable neighbor.

    For each AMV, neighboring raobs within ``dist_threshold`` are tried nearest
    first. Raobs with fewer than ``min_num_levs`` levels are skipped. The first
    best_fit result whose flag is 0 is kept.

    Args:
        amvs: 2-D AMV array with columns in amv_hdr_list order.
        raob_dct: raob profiles keyed by (lat, lon), as from get_raob_dict.
        dist_threshold: neighbor search radius passed to get_raob_neighbors.
        min_num_levs: minimum number of raob levels required.

    Returns:
        list of ``(amv_index, u, v, bfit_prs, flag)`` tuples, one per AMV with
        a successful fit.
    """
    neighbors = get_raob_neighbors(amvs[:, amv_lat_idx], amvs[:, amv_lon_idx], raob_dct,
                                   dist_threshold=dist_threshold)

    results = []
    for j in range(amvs.shape[0]):
        entry = neighbors[j]
        if entry is None:
            continue

        keys, dists = entry[0], entry[1]
        for pos in np.argsort(dists).tolist():
            raob = raob_dct.get(keys[pos])
            if raob.shape[0] < min_num_levs:
                continue

            fit = best_fit(amvs[j, amv_spd_idx], amvs[j, amv_dir_idx], amvs[j, amv_prs_idx],
                           amvs[j, amv_lat_idx], amvs[j, amv_lon_idx],
                           raob[:, raob_spd_idx], raob[:, raob_dir_idx], raob[:, raob_pres_idx])
            if fit[3] == 0:
                results.append((j, fit[0], fit[1], fit[2], fit[3]))
                break

    return results


def run_amv_clavr_compare(amvs, clvr_filename, num_amvs=5000):
    """Time a lon/lat -> grid-index lookup for the leading AMVs.

    Opens ``clvr_filename`` with h5py but, in this version, reads nothing from
    it (the commented-out block shows how its lon/lat grid was once loaded).
    It then converts up to ``num_amvs`` AMV locations with earth_to_indexs
    and prints the elapsed wall time in milliseconds.

    Args:
        amvs: 2-D AMV array with columns in amv_hdr_list order.
        clvr_filename: path to an HDF5 file (presumably CLAVR-x output).
        num_amvs: number of leading AMVs to look up. The default, 5000, was
            previously hard-coded. Fewer are used when ``amvs`` is shorter,
            instead of raising IndexError.
    """
    # Context manager guarantees the HDF5 file is closed.
    with h5py.File(clvr_filename, 'r') as h5file:
        # Keep this for reference
        # lons = h5file['longitude'][0:2500, :]
        # lats = h5file['latitude'][0:2500, :]
        # lons = np.where(lons < 0, lons + 360, lons)
        # lons = np.where(lons == -639, np.nan, lons)
        # lats = np.where(lats == -999.0, np.nan, lats)
        # ll_grid = LonLatGrid(lons, lats)

        t1 = int(round(time.time() * 1000))

        a_lons = np.array(amvs[:num_amvs, amv_lon_idx])
        a_lats = np.array(amvs[:num_amvs, amv_lat_idx])
        # NOTE(review): the meaning of 5500 is defined by earth_to_indexs -- confirm.
        earth_to_indexs(a_lons, a_lats, 5500)

        print('search time: ', (int(round(time.time() * 1000)) - t1))
        print('--------------')


def run_best_fit_gfs_all(amvs, gfs_filename):
    """Run run_best_fit_gfs on each entry of an indexable collection of AMV arrays.

    Args:
        amvs: sequence of 2-D AMV arrays, accessed as ``amvs[0] .. amvs[len-1]``.
        gfs_filename: GFS file passed through to run_best_fit_gfs.

    Returns:
        list of per-entry results, in index order.
    """
    return [run_best_fit_gfs(amvs[idx], gfs_filename) for idx in range(len(amvs))]


def run_best_fit_gfs(amvs, gfs_filename, amv_lat_idx=amv_lat_idx, amv_lon_idx=amv_lon_idx, amv_prs_idx=amv_prs_idx,
                     amv_spd_idx=amv_spd_idx, amv_dir_idx=amv_dir_idx, method='nearest'):
    """Compute a GFS best-fit pressure for every AMV.

    GFS u/v profiles are sampled at each AMV location with get_vert_profile_s
    and converted to speed/direction. Pressure and wind profiles are reversed
    along the level axis before fitting. This is presumably so that level 0
    is the highest pressure, since best_fit treats ``fcst_prs[0]`` as the
    lowest search level; confirm against the file's level order.

    Args:
        amvs: 2-D AMV array.
        gfs_filename: path to a GFS file readable by xarray providing
            'pressure levels', 'u-wind' and 'v-wind'.
        amv_*_idx: column indices into ``amvs`` (default: amv_hdr_list layout).
        method: sampling method passed to get_vert_profile_s.

    Returns:
        list with one best_fit result per AMV, including unsuccessful fits
        (non-zero flag).
    """
    # Read what is needed, then close the file (it was previously left open).
    with xr.open_dataset(gfs_filename) as xr_dataset:
        gfs_press = xr_dataset['pressure levels'].data[::-1]
        uv_wind = get_vert_profile_s(xr_dataset, ['u-wind', 'v-wind'], amvs[:, amv_lon_idx], amvs[:, amv_lat_idx],
                                     method=method)
        uv_wind = uv_wind.data

    wspd, wdir = spd_dir_from_uv(uv_wind[0, :, :], uv_wind[1, :, :])
    wspd = wspd.magnitude[:, ::-1]
    wdir = wdir.magnitude[:, ::-1]

    bestfits = []
    for j in range(amvs.shape[0]):
        bf = best_fit(amvs[j, amv_spd_idx], amvs[j, amv_dir_idx], amvs[j, amv_prs_idx],
                      amvs[j, amv_lat_idx], amvs[j, amv_lon_idx],
                      wspd[j], wdir[j], gfs_press)
        bestfits.append(bf)

    return bestfits


def raob_gfs_diff(raob_dct, gfs_filename, r_press=900.0):
    """Compare raob winds with GFS winds at one pressure level and print statistics.

    For each raob, the level closest to ``r_press`` is taken. GFS profiles at
    the raob locations are read with get_profile_multi, and the GFS level
    closest to the same ``r_press`` is used. The GFS level was previously
    hard-coded to 500 hPa, so it disagreed with the raob level.

    Args:
        raob_dct: dict mapping (lat, lon) to a raob array with columns in
            raob_hdr_list order (as built by get_raob_dict).
        gfs_filename: path to a GFS file readable by xarray providing
            'pressure levels', 'u-wind' and 'v-wind'.
        r_press: target pressure level for both data sources (default 900).

    Returns:
        None. The function prints each raob/GFS pair, then speed
        MAD/bias/RMS, direction bias and vector-difference bias/RMS of
        GFS minus raob.
    """
    locs = list(raob_dct.keys())
    n_raobs = len(locs)

    rb_lats = []
    rb_lons = []
    rb_spd = []
    rb_dir = []
    for loc in locs:
        prof = raob_dct[loc]
        lev_idx = np.argmin(np.abs(r_press - prof[:, raob_pres_idx]))
        rb_lats.append(loc[0])
        rb_lons.append(loc[1])
        rb_spd.append(prof[lev_idx, raob_spd_idx])
        rb_dir.append(prof[lev_idx, raob_dir_idx])

    rb_lons = np.array(rb_lons)
    rb_lats = np.array(rb_lats)
    rb_spd = np.array(rb_spd)
    rb_dir = np.array(rb_dir)

    # Read what is needed, then close the file (it was previously left open).
    with xr.open_dataset(gfs_filename) as xr_dataset:
        gfs_press = xr_dataset['pressure levels'].data
        g_lev_idx = int(np.argmin(np.abs(r_press - gfs_press)))
        uv_wind = get_profile_multi(xr_dataset, ['u-wind', 'v-wind'], rb_lons, rb_lats)
        uv_wind = uv_wind.data

    wspd, wdir = spd_dir_from_uv(uv_wind[0, :, :], uv_wind[1, :, :])
    g_spd = wspd.magnitude[:, g_lev_idx]
    g_dir = wdir.magnitude[:, g_lev_idx]

    for rspd, gspd, rdir, gdir in zip(rb_spd, g_spd, rb_dir, g_dir):
        print(rspd, gspd, ':', rdir, gdir)

    deg2rad = np.pi / 180.0
    dir_diffs = direction_difference(g_dir, rb_dir)
    spd_diffs = g_spd - rb_spd
    # u = -spd*sin(dir), v = -spd*cos(dir), as elsewhere in this module.
    u_diffs = (-g_spd * np.sin(deg2rad * g_dir)) - (-rb_spd * np.sin(deg2rad * rb_dir))
    v_diffs = (-g_spd * np.cos(deg2rad * g_dir)) - (-rb_spd * np.cos(deg2rad * rb_dir))

    num = len(u_diffs)
    print('num raobs: ', n_raobs)
    print('num samples: ', num)

    # sqrt(mean**2 + std**2) is the RMS of the differences.
    spd_mean = np.mean(spd_diffs)
    spd_std = np.std(spd_diffs)
    print('spd MAD/bias/rms: ', np.average(np.abs(spd_diffs)), spd_mean, np.sqrt(spd_mean**2 + spd_std**2))
    print('-----------------')

    dir_mean = np.mean(dir_diffs)
    print('dir bias: ', dir_mean)
    print('-----------------------')

    vd = np.sqrt(u_diffs**2 + v_diffs**2)
    vd_mean = np.mean(vd)
    vd_std = np.std(vd)
    print('VD bias/rms: ', vd_mean, np.sqrt(vd_mean**2 + vd_std**2))
    print('--------------')

    return None


def amv_gfs_diff(amvs, gfs_filename):
    """Compare AMVs with co-located GFS winds and print summary statistics.

    GFS u/v profiles are read at every AMV location with get_profile_multi.
    For each AMV, the GFS level nearest in pressure is used if it lies within
    press_diff_threshold; otherwise the AMV is skipped.

    Args:
        amvs: 2-D AMV array with columns in amv_hdr_list order.
        gfs_filename: path to a GFS file readable by xarray providing
            'pressure levels', 'u-wind' and 'v-wind'.

    Returns:
        None. The function prints sample counts, speed MAD/bias/RMS,
        direction bias, vector-difference bias/RMS and pressure bias/RMS
        (AMV minus GFS).
    """
    num_amvs = amvs.shape[0]

    # Read what is needed, then close the file (it was previously left open).
    with xr.open_dataset(gfs_filename) as xr_dataset:
        # Plain array rather than a DataArray, so per-AMV arithmetic below
        # yields NumPy scalars instead of 0-d DataArrays.
        gfs_press = xr_dataset['pressure levels'].data
        uv_wind = get_profile_multi(xr_dataset, ['u-wind', 'v-wind'], amvs[:, amv_lon_idx], amvs[:, amv_lat_idx])
        uv_wind = uv_wind.data

    wspd, wdir = spd_dir_from_uv(uv_wind[0, :, :], uv_wind[1, :, :])
    wspd = wspd.magnitude
    wdir = wdir.magnitude

    deg2rad = np.pi / 180.0
    amv_u = -amvs[:, amv_spd_idx] * np.sin(deg2rad * amvs[:, amv_dir_idx])
    amv_v = -amvs[:, amv_spd_idx] * np.cos(deg2rad * amvs[:, amv_dir_idx])

    dir_diffs = []
    spd_diffs = []
    u_diffs = []
    v_diffs = []
    prs_diffs = []

    for j in range(num_amvs):
        aspd = amvs[j, amv_spd_idx]
        adir = amvs[j, amv_dir_idx]
        aprs = amvs[j, amv_prs_idx]

        # GFS level nearest in pressure to the AMV.
        pdiff = aprs - gfs_press
        lev_idx = np.argmin(np.abs(pdiff))
        if np.abs(pdiff[lev_idx]) > press_diff_threshold:
            continue

        rspd = wspd[j, lev_idx]
        rdir = wdir[j, lev_idx]

        dir_diffs.append(direction_difference(adir, rdir))
        spd_diffs.append(aspd - rspd)

        ru = -rspd * np.sin(deg2rad * rdir)
        rv = -rspd * np.cos(deg2rad * rdir)

        u_diffs.append(amv_u[j] - ru)
        v_diffs.append(amv_v[j] - rv)
        prs_diffs.append(pdiff[lev_idx])

    num = len(u_diffs)
    dir_diffs = np.array(dir_diffs)
    spd_diffs = np.array(spd_diffs)
    prs_diffs = np.array(prs_diffs)
    u_diffs = np.array(u_diffs)
    v_diffs = np.array(v_diffs)

    print('num amvs: ', num_amvs)
    print('num samples: ', num)

    # sqrt(mean**2 + std**2) is the RMS of the differences.
    spd_mean = np.mean(spd_diffs)
    spd_std = np.std(spd_diffs)
    print('spd MAD/bias/rms: ', np.average(np.abs(spd_diffs)), spd_mean, np.sqrt(spd_mean**2 + spd_std**2))
    print('-----------------')

    dir_mean = np.mean(dir_diffs)
    print('dir bias: ', dir_mean)
    print('-----------------------')

    vd = np.sqrt(u_diffs**2 + v_diffs**2)
    vd_mean = np.mean(vd)
    vd_std = np.std(vd)
    print('VD bias/rms: ', vd_mean, np.sqrt(vd_mean**2 + vd_std**2))
    print('--------------')

    pd_std = np.std(prs_diffs)
    pd_mean = np.mean(prs_diffs)
    print('press bias/rms ', pd_mean, np.sqrt(pd_mean**2 + pd_std**2))
    print('------------------------------------------')

    return None


def direction_difference(dir_a, dir_b):
    """Signed angular difference ``dir_a - dir_b`` in degrees, wrapped around 360.

    The magnitude is ``|a - b|``, replaced by ``360 - |a - b|`` when that
    exceeds 180. The result is negative when going from ``dir_b`` to
    ``dir_a`` is counter-clockwise. An exact 180-degree separation keeps the
    sign of the unwrapped difference, so (180, 0) -> 180 and (0, 180) -> -180.
    Works element-wise on arrays and always returns an ndarray.

    >>> float(direction_difference(10, 350))
    20.0
    >>> float(direction_difference(350, 10))
    -20.0
    """
    delta = np.abs(dir_a - dir_b)
    wraps = delta > 180
    magnitude = np.where(wraps, 360 - delta, delta)
    # Wrapping across north flips which operand ordering implies a negative sign.
    negative = np.where(wraps, dir_b < dir_a, dir_b > dir_a)
    return np.where(negative, -1 * magnitude, magnitude)


def get_amv_winds_match(dist_threshold=50, qitype=amv_cqi_idx, qival=50, lat_range=(-61, 61),
                        data_dir='/Users/tomrink/data/amv_intercompare/'):
    """Co-locate AMVs from all six centers using EUM as the reference.

    Each center's file is read with get_amv_nd and filtered with filter_amvs.
    For every EUM AMV, the nearest AMV of each other center is found. The EUM
    AMV is kept only when every nearest distance is below ``dist_threshold``
    (haversine_np units).

    Args:
        dist_threshold: strict distance bound for a match.
        qitype, qival, lat_range: passed through to filter_amvs.
        data_dir: directory holding the per-center CSV files. The default is
            the path that was previously hard-coded.

    Returns:
        dict keyed 'EUM', 'BRZ', 'KMA', 'JMA', 'NOA', 'NWC'. Row i of every
        array belongs to the same matched EUM AMV (all arrays have zero rows
        when nothing matched).

    Raises:
        ValueError: if filtering leaves no AMVs for some center.
    """
    import os

    file_names = {
        'EUM': 'EUM321_csv.csv',
        'BRZ': 'BRZCPTECfin_121_csv.csv',
        'JMA': 'JMA421_csv.csv',
        'KMA': 'KMA521NI_csv.csv',
        'NOA': 'NOA621_csv.csv',
        'NWC': 'NWC721_csv.csv',
    }

    amvs = {}
    for center, file_name in file_names.items():
        filtered = filter_amvs(get_amv_nd(os.path.join(data_dir, file_name)),
                               qitype=qitype, qival=qival, lat_range=lat_range)
        if filtered is None:
            raise ValueError('no AMVs left after filtering for center ' + center)
        amvs[center] = filtered

    ref = amvs['EUM']
    num_eum = ref.shape[0]
    print('num EUM amvs: ', num_eum)

    # Centers are checked in this order; the first miss ends the search.
    match_order = ['NWC', 'BRZ', 'JMA', 'NOA', 'KMA']
    matched = {center: [] for center in ['EUM'] + match_order}

    for i in range(num_eum):
        ref_lon = ref[i, amv_lon_idx]
        ref_lat = ref[i, amv_lat_idx]
        nearest = {}
        for center in match_order:
            other = amvs[center]
            dist = haversine_np(other[:, amv_lon_idx], other[:, amv_lat_idx], ref_lon, ref_lat)
            j = np.argmin(dist)
            if not dist[j] < dist_threshold:
                break
            nearest[center] = j
        else:
            # Every center had a close enough neighbor.
            matched['EUM'].append(i)
            for center, j in nearest.items():
                matched[center].append(j)

    print('total num matches: ', len(matched['EUM']))

    amvs_dict = {}
    for center in ['EUM', 'BRZ', 'KMA', 'JMA', 'NOA', 'NWC']:
        # Integer dtype so an empty match list still indexes correctly.
        idxs = np.array(matched[center], dtype=int)
        amvs_dict[center] = amvs[center][idxs, :]

    return amvs_dict


def run_best_fit_all(amvs_dict, raobs_pathname, dist_threshold=200, min_num_levs=20):
    """Run run_best_fit for every center in amv_centers_list.

    Args:
        amvs_dict: dict of per-center AMV arrays (e.g. from get_amv_winds_match).
        raobs_pathname: raob text file read with get_raob_dict.
        dist_threshold, min_num_levs: passed through to run_best_fit.

    Returns:
        dict mapping center code -> list of best-fit tuples.
    """
    raobs = get_raob_dict(raobs_pathname)
    print('got raobs')

    results = {}
    for center in amv_centers_list:
        fits = run_best_fit(amvs_dict.get(center), raobs,
                            dist_threshold=dist_threshold, min_num_levs=min_num_levs)
        print('done: ', center)
        results[center] = fits

    return results


def best_fit(amv_spd, amv_dir, amv_prs, amv_lat, amv_lon, fcst_spd, fcst_dir, fcst_prs, verbose=False):

    """Finds the background model best fit pressure associated with the AMV.
       The model best-fit pressure is the height (in pressure units) where the
       vector difference between the observed AMV and model background is a
       minimum.  This calculation may only work approximately 1/3 of the time.

       Reference:
       Salonen et al (2012), "Characterising AMV height assignment error by
       comparing best-fit pressure statistics from the Met Office and ECMWF
       System."  Proceedings of the 11th International Winds Workshop,
       Auckland, New Zealand, 20-24 February 2012.

       Input:
       amv_spd -  AMV speed m/s
       amv_dir -  AMV direction deg
       amv_prs -  AMV pressure hPa
       amv_lat, amv_lon - AMV location, only used in verbose diagnostics
       fcst_spd - (level) forecast speed m/s
       fcst_dir - (level) forecast direction (deg)
       fcst_prs - (level) forecast pressure (hPa); level 0 is assumed to be the
                  bottom of the profile (its pressure bounds the search) and
                  the last level the model top
       verbose  - print diagnostics

       Returns the tuple (bfit_u, bfit_v, bfit_prs, flag):
       bfit_u   - AMV best fit U component m/s, undef (-9999.0) if not constrained
       bfit_v   - AMV best fit V component m/s, undef if not constrained
       bfit_prs - AMV best fit pressure hPa, undef if not constrained
       flag     - 0 found, 1 not constrained (secondary or broad minimum),
                  2 vec diff minimum not met, 3 failed to find suitable fcst pressure match

       History:
       10/2012 - Steve Wanzong - Created in Fortran
       10/2013 - Sharon Nebuda - rewritten for python
       10/2021 - Tom Rink - adapted from previous version to leverage more of NumPy
    """

    undef = -9999.0
    flag = 3
    bf_data = (undef, undef, undef, flag)

    fcst_num_levels = fcst_prs.shape[0]

    PressDiff = 150.0       # pressure above and below AMV to look for fit
    TopPress = 50.0         # highest level to allow search
    BotPress = fcst_prs[0]  # lowest level to allow search

    if amv_prs < TopPress:
        if verbose:
            print('AMV location lat,lon,prs ({0},{1},{2}) is higher than pressure {3}'.format(amv_lat, amv_lon, amv_prs, TopPress))
        return bf_data

    # Search window: +/- PressDiff around the AMV pressure, clipped to [TopPress, BotPress].
    PressMax = min((amv_prs + PressDiff), BotPress)
    PressMin = max((amv_prs - PressDiff), TopPress)

    # 1d array of indices to consider for best fit location (window bounds are exclusive).
    kk = np.where((fcst_prs < PressMax) & (fcst_prs > PressMin))
    if len(kk[0]) == 0:
        if verbose:
            print('AMV location lat,lon,prs ({0},{1},{2}) failed to find fcst prs around AMV'.format(amv_lat, amv_lon, amv_prs))
        return bf_data

    # Diagnostic field: forecast minimum and maximum speed within the search window.
    if verbose:
        SatwindMinSpeed = min(fcst_spd[kk])
        SatwindMaxSpeed = max(fcst_spd[kk])

    # Compute U and V for both the AMV and the forecast (u = -spd*sin(dir), v = -spd*cos(dir)).
    dr = 0.0174533  # pi/180
    amv_uwind = -amv_spd * np.sin(dr*amv_dir)
    amv_vwind = -amv_spd * np.cos(dr*amv_dir)
    fcst_uwind = -fcst_spd * np.sin(dr*fcst_dir)
    fcst_vwind = -fcst_spd * np.cos(dr*fcst_dir)

    # Vector difference between the AMV and the forecast/model background at all levels.
    VecDiff = np.sqrt((amv_uwind - fcst_uwind) ** 2 + (amv_vwind - fcst_vwind) ** 2)

    # Model level of best fit: the minimum vector difference *within the search window*.
    # Looking the window minimum up again over the whole profile could return a level
    # outside the window when the same value also occurs elsewhere, so take the argmin
    # over the window directly. NaNs are ignored; an all-NaN window counts as a failure.
    window_vec_diff = VecDiff[kk]
    if np.all(np.isnan(window_vec_diff)):
        if verbose:
            print('AMV location lat,lon,prs ({0},{1},{2}) failed to find the min vector difference in layers around AMV'.format(amv_lat, amv_lon, amv_prs))
        return bf_data

    imin = int(kk[0][np.nanargmin(window_vec_diff)])  # index of min vec diff in absolute coordinates

    # Use a parabolic fit to find the best-fit pressure.
    # p2 - Minimized model pressure at level imin (hPa)
    # v2 - Minimized vector difference at level imin (m/s)
    # p1 - 1 pressure level lower in atmosphere than p2
    # p3 - 1 pressure level above in atmosphere than p2
    # v1 - Vector difference 1 pressure level lower than p2
    # v3 - Vector difference 1 pressure level above than p2

    p2 = fcst_prs[imin]
    v2 = VecDiff[imin]
    SatwindBestFitPress = p2
    SatwindBestFitU = fcst_uwind[imin]
    SatwindBestFitV = fcst_vwind[imin]

    # assumes fcst data level 0 at surface and (fcst_num_levels-1) at model top;
    # no neighbouring levels exist for a parabolic fit at either end.
    if imin == 0 or imin == fcst_num_levels - 1:
        if verbose:
            print('AMV location lat,lon,alt ({0},{1},{2}) min vector difference at top or bottom'.format(amv_lat, amv_lon, amv_prs))
    else:
        p3 = fcst_prs[imin+1]
        p1 = fcst_prs[imin-1]
        v3 = VecDiff[imin+1]
        v1 = VecDiff[imin-1]

        if p3 >= TopPress and v1 != v2 and v2 != v3:  # check not collinear, above allowed region
            # Vertex of the parabola through (p1,v1), (p2,v2), (p3,v3).
            SatwindBestFitPress = p2 - (0.5 *
            ((((p2 - p1) * (p2 - p1) * (v2 - v3)) - ((p2 - p3) * (p2 - p3) * (v2 - v1))) /
            (((p2 - p1) * (v2 - v3)) - ((p2 - p3) * (v2 - v1)))))
            # Reject a vertex that falls outside the bracketing levels [p3, p1].
            if (SatwindBestFitPress < p3) or (SatwindBestFitPress > p1):
                if (verbose):
                    print('Best Fit not found between two pressure layers')
                    print('SatwindBestFitPress {0} p1 {1} p2 {2} p3 {3} imin {4}'.format(SatwindBestFitPress, p1, p2, p3, imin))
                SatwindBestFitPress = p2
        else:
            if verbose:
                print('Top of parabolic fit window above TopPress, or co-linear vector difference profile')
                print('SatwindBestFitPress {0} p1 {1} p2 {2} p3 {3} imin {4}'.format(SatwindBestFitPress, p1, p2, p3, imin))
                print('Vector difference profile p1 {0} p2 {1} p3 {2}'.format(v1, v2, v3))

        # Find best fit U and V by linear interpolation between the two levels that
        # bracket the best-fit pressure.
        if p2 != SatwindBestFitPress:
            if p2 < SatwindBestFitPress:
                LevBelow = imin - 1
                LevAbove = imin
                Prop = (SatwindBestFitPress - p1) / (p2 - p1)
            else:
                LevBelow = imin
                LevAbove = imin + 1
                Prop = (SatwindBestFitPress - p2) / (p3 - p2)

            SatwindBestFitU = fcst_uwind[LevBelow] * (1.0 - Prop) + fcst_uwind[LevAbove] * Prop
            SatwindBestFitV = fcst_vwind[LevBelow] * (1.0 - Prop) + fcst_vwind[LevAbove] * Prop

    # best fit success and good constraints -----
    SatwindGoodConstraint = 1
    flag = 0

    # Check for overall good agreement.
    # NOTE(review): vdiff_best measures the forecast profile against the best-fit
    # (interpolated forecast) wind, so its minimum is exactly 0 whenever no
    # interpolation was applied and this 4 m/s test can rarely fail. Salonen et al.
    # describe the test on the AMV-minus-background difference (VecDiff) - confirm
    # which is intended before relying on flag 2.
    vdiff_best = np.sqrt((fcst_uwind - SatwindBestFitU) ** 2 + (fcst_vwind - SatwindBestFitV) ** 2)
    min_vdiff_best = min(vdiff_best)
    if min_vdiff_best > 4.0:
        SatwindGoodConstraint = 0
        flag = 2

    # Check for a substantial secondary local minimum or if local minimum is too broad:
    # any level more than 100 hPa from the best fit within 2 m/s of the minimum.
    if SatwindGoodConstraint == 1:
        mm = np.where(fcst_prs > (SatwindBestFitPress + 100))[0]
        nn = np.where(fcst_prs < (SatwindBestFitPress - 100))[0]
        if (np.sum(vdiff_best[mm] < (min_vdiff_best + 2.0)) + np.sum(vdiff_best[nn] < (min_vdiff_best + 2.0))) > 0:
            SatwindGoodConstraint = 0
            flag = 1

    if SatwindGoodConstraint == 1:
        bfit_prs = SatwindBestFitPress
        bfit_u = SatwindBestFitU
        bfit_v = SatwindBestFitV
    else:
        bfit_prs = undef
        bfit_u = undef
        bfit_v = undef

    if verbose:
        print('*** AMV best-fit *********************************')
        print('AMV -> p/minspd/maxspd: {0} {1} {2}'.format(amv_prs,SatwindMinSpeed,SatwindMaxSpeed))
        print('Bestfit -> pbest,bfu,bfv,amvu,amvv,bgu,bgv: {0} {1} {2} {3} {4} {5} {6}'.format(
        SatwindBestFitPress,SatwindBestFitU,SatwindBestFitV,amv_uwind,amv_vwind,fcst_uwind[imin],fcst_vwind[imin]))
        print('Good Constraint: {0}'.format(SatwindGoodConstraint))
        print('Minimum Vector Difference: {0}'.format(min_vdiff_best))
        print('Vector Difference Profile: ')
        print(vdiff_best)
        print('Pressure Profile: ')
        print(fcst_prs)

        if (abs(SatwindBestFitU - amv_uwind) > 4.0) or (abs(SatwindBestFitV - amv_vwind) > 4.0):
            print('U Diff: {0}'.format(abs(SatwindBestFitU - amv_uwind)))
            print('V Diff: {0}'.format(abs(SatwindBestFitV - amv_vwind)))

    bf_data = (bfit_u, bfit_v, bfit_prs, flag)

    return bf_data


def best_fit_altitude(amv_spd, amv_dir, amv_alt, amv_lat, amv_lon, fcst_spd, fcst_dir, fcst_alt,
                      amv_uwind=None, amv_vwind=None,
                      fcst_uwind=None, fcst_vwind=None,
                      alt_top=25000.0, alt_bot=0.0, bf_half_width=1000.0, constraint_half_width=1000.0, verbose=False):
    """Find the model best-fit altitude for an AMV (height-coordinate variant of best_fit).

    The best-fit level is where the vector difference between the AMV and the
    forecast profile is smallest within +/- bf_half_width of the AMV altitude.
    It is refined with minimize_quadratic on the three levels around the minimum
    when possible.

    Args:
        amv_spd, amv_dir: AMV speed and direction (deg); only used when amv_uwind is None.
        amv_alt: AMV altitude, in the same units as fcst_alt (the defaults suggest
            metres - confirm).
        amv_lat, amv_lon: AMV location, only used in verbose diagnostics.
        fcst_spd, fcst_dir: forecast speed/direction per level; fcst_dir is only
            used when fcst_uwind is None.
        fcst_alt: forecast altitude per level. It is assumed to increase with index
            (level imin+1 above imin).
        amv_uwind, amv_vwind, fcst_uwind, fcst_vwind: optional precomputed wind
            components. When fcst_uwind is given, the forecast speed is recomputed from them.
        alt_top, alt_bot: limits of the allowed search region.
        bf_half_width: half width of the search window around amv_alt.
        constraint_half_width: distance beyond which a secondary near-minimum
            marks the fit as unconstrained.
        verbose: print diagnostics.

    Returns:
        (bf_u, bf_v, bf_alt, flag); the first three are NaN unless flag == 0.
        flag: 0 found, 1 secondary/broad minimum, 2 minimum vector difference
        above 4, 3 no usable forecast level in the search window.
    """
    fcst_num_levels = fcst_alt.shape[0]

    flag = 3
    bf_tup = (np.nan, np.nan, np.nan, flag)
    if amv_alt > alt_top:
        if verbose:
            print('AMV location lat,lon,alt ({0},{1},{2}) is higher than max altitude {3}'.format(amv_lat, amv_lon, amv_alt, alt_top))
        return bf_tup

    # Search window: +/- bf_half_width around the AMV altitude, clipped to [alt_bot, alt_top].
    # (The lower bound previously used amv_lat by mistake.)
    alt_max = min((amv_alt + bf_half_width), alt_top)
    alt_min = max((amv_alt - bf_half_width), alt_bot)

    # 1d array of indices to consider for best fit height (window bounds are exclusive)
    kk = np.where((fcst_alt > alt_min) & (fcst_alt < alt_max))
    if kk[0].size == 0:
        if verbose:
            print('AMV location lat,lon,alt ({0},{1},{2}) failed to find fcst alt around AMV'.format(amv_lat,amv_lon,amv_alt))
        return bf_tup

    # Compute U and V for both AMVs and forecast, unless supplied by the caller.
    if amv_uwind is None:
        amv_uwind = -amv_spd * np.sin((np.pi/180.0)*amv_dir)
        amv_vwind = -amv_spd * np.cos((np.pi/180.0)*amv_dir)
    if fcst_uwind is None:
        fcst_uwind = -fcst_spd * np.sin((np.pi/180.0)*fcst_dir)
        fcst_vwind = -fcst_spd * np.cos((np.pi/180.0)*fcst_dir)
    else:
        fcst_spd = np.sqrt(fcst_uwind**2 + fcst_vwind**2)

    # Diagnostic field: model minimum and maximum speed within the search window.
    if verbose:
        sat_wind_min_spd = min(fcst_spd[kk])
        sat_wind_max_spd = max(fcst_spd[kk])

    # Calculate the vector difference between the AMV and model background at all levels.
    vec_diff = np.sqrt((amv_uwind - fcst_uwind) ** 2 + (amv_vwind - fcst_vwind) ** 2)

    # Model level of best fit: minimum vector difference *within the search window*.
    # Looking the window minimum up over the whole profile could pick a level outside
    # the window when the same value recurs, so take the argmin over the window.
    # NaNs are ignored; an all-NaN window counts as a failure.
    window_vec_diff = vec_diff[kk]
    if np.all(np.isnan(window_vec_diff)):
        if verbose:
            print('AMV location lat,lon,alt ({0},{1},{2}) failed to find min vector difference in layers around AMV'.format(amv_lat,amv_lon,amv_alt))
        return bf_tup

    imin = int(kk[0][np.nanargmin(window_vec_diff)])

    a2 = fcst_alt[imin]
    v2 = vec_diff[imin]

    sat_wind_best_fit_alt = a2
    sat_wind_best_fit_u = fcst_uwind[imin]
    sat_wind_best_fit_v = fcst_vwind[imin]

    # At the first/last level there are no neighbours for a quadratic refinement.
    if imin == 0 or imin == (fcst_num_levels - 1):
        if verbose:
            print('AMV location lat,lon,alt ({0},{1},{2}) min vector difference at top/bottom'.format(amv_lat,amv_lon,amv_alt))
    else:
        a1 = fcst_alt[imin-1]
        v1 = vec_diff[imin-1]
        a3 = fcst_alt[imin+1]
        v3 = vec_diff[imin+1]

        # check not co-linear and below alt_max
        if a3 < alt_max and v1 != v2 and v2 != v3:
            sat_wind_best_fit_alt = minimize_quadratic(a1, a2, a3, v1, v2, v3)
            # Reject a refined altitude outside the bracketing levels [a1, a3].
            if sat_wind_best_fit_alt < a1 or sat_wind_best_fit_alt > a3:
                sat_wind_best_fit_alt = a2

        # Find best fit U and V by linear interpolation between the bracketing levels:
        if a2 != sat_wind_best_fit_alt:
            if a2 < sat_wind_best_fit_alt:
                lev_below = imin
                lev_above = imin + 1
                prop = (sat_wind_best_fit_alt - a2) / (a3 - a2)
            else:
                lev_below = imin - 1
                lev_above = imin
                prop = (sat_wind_best_fit_alt - a1) / (a2 - a1)

            sat_wind_best_fit_u = fcst_uwind[lev_below] * (1.0 - prop) + fcst_uwind[lev_above] * prop
            sat_wind_best_fit_v = fcst_vwind[lev_below] * (1.0 - prop) + fcst_vwind[lev_above] * prop

    # check constraints
    good_constraint = 1
    flag = 0
    # Check for overall good agreement: want vdiff less than 4 m/s.
    # NOTE(review): vdiff compares the forecast profile with the best-fit (forecast)
    # wind, not with the AMV, so its minimum is 0 whenever no interpolation was
    # applied - confirm this is the intended constraint.
    vdiff = np.sqrt((fcst_uwind - sat_wind_best_fit_u) ** 2 + (fcst_vwind - sat_wind_best_fit_v) ** 2)
    min_vdiff = min(vdiff)
    if min_vdiff > 4.0:
        good_constraint = 0
        flag = 2

    # Check for a substantial secondary local minimum or if local minimum is too broad
    if good_constraint == 1:
        mm = np.where(fcst_alt > (sat_wind_best_fit_alt + constraint_half_width))[0]
        nn = np.where(fcst_alt < (sat_wind_best_fit_alt - constraint_half_width))[0]
        if (np.sum(vdiff[mm] < (min_vdiff + 2.0)) + np.sum(vdiff[nn] < (min_vdiff + 2.0))) > 0:
            good_constraint = 0
            flag = 1

    if good_constraint == 1:
        bf_tup = (sat_wind_best_fit_u, sat_wind_best_fit_v, sat_wind_best_fit_alt, flag)
    else:
        bf_tup = (np.nan, np.nan, np.nan, flag)

    if verbose:
        print('*** AMV best-fit ***')
        print('AMV -> p/minspd/maxspd: {0} {1} {2}'.format(amv_alt,sat_wind_min_spd,sat_wind_max_spd))
        print('Bestfit -> pbest,bfu,bfv,amvu,amvv,bgu,bgv: {0} {1} {2} {3} {4} {5} {6}'.format(
        sat_wind_best_fit_alt,sat_wind_best_fit_u,sat_wind_best_fit_v,amv_uwind,amv_vwind,fcst_uwind[imin],fcst_vwind[imin]))
        print('Good Constraint: {0}'.format(good_constraint))
        print('Minimum Vector Difference: {0}'.format(vec_diff[imin]))
        print('Vector Difference Profile: ')
        print(vdiff)
        print('Altitude Profile: ')
        print(fcst_alt)

        if (abs(sat_wind_best_fit_u - amv_uwind) > 4.0) or (abs(sat_wind_best_fit_v - amv_vwind) > 4.0):
            print('U Diff: {0}'.format(abs(sat_wind_best_fit_u - amv_uwind)))
            print('V Diff: {0}'.format(abs(sat_wind_best_fit_v - amv_vwind)))

    return bf_tup


def get_press_bin_ranges(lop, hip, bin_size=100):
    """Split the span starting at ``lop`` into consecutive ``[start, end]`` bins.

    Only whole bins are produced. The count is ``int((hip - lop) / bin_size)``, so
    a partial bin at the top of the range is dropped, and an empty or inverted
    range gives ``[]``.
    """
    num_bins = int((hip - lop) / bin_size)
    starts = (lop + k*bin_size for k in range(num_bins))
    return [[start, start + bin_size] for start in starts]


def analyze_best_fit_all(amvs_dct, bfs_dct, bin_size=200):
    """Per-center, per-pressure-bin MAD and bias of the best-fit differences.

    For each center in amv_centers_list (in order), this runs analyze_best_fit_single
    on that center's AMVs and best fits. It then reduces every bin to its mean
    absolute value (MAD) and its mean (bias) for the pressure, speed and direction
    differences.

    Returns:
        x_values: bin centers reported for the last center processed.
        pres_mad, pres_bias, spd_mad, spd_bias, dir_mad, dir_bias: one list per
            center, each holding one value per bin.
    """
    pres_mad, pres_bias = [], []
    spd_mad, spd_bias = [], []
    dir_mad, dir_bias = [], []

    for center in amv_centers_list:
        x_values, bin_pres, _, bin_spd, _, bin_dir, _ = analyze_best_fit_single(
            amvs_dct.get(center), bfs_dct.get(center), bin_size=bin_size)
        bin_idxs = range(len(x_values))

        pres_mad.append([np.average(np.abs(bin_pres[k])) for k in bin_idxs])
        pres_bias.append([np.average(bin_pres[k]) for k in bin_idxs])
        spd_mad.append([np.average(np.abs(bin_spd[k])) for k in bin_idxs])
        spd_bias.append([np.average(bin_spd[k]) for k in bin_idxs])
        dir_mad.append([np.average(np.abs(bin_dir[k])) for k in bin_idxs])
        dir_bias.append([np.average(bin_dir[k]) for k in bin_idxs])

    return x_values, pres_mad, pres_bias, spd_mad, spd_bias, dir_mad, dir_bias


def analyze_best_fit_single(amvs, bfs, bin_size=200):
    """Print best-fit difference statistics for one AMV set and bin them by AMV pressure.

    Args:
        amvs: 2D array of AMV records; rows are selected by the indices in bfs and
            columns by the module-level amv_*_idx constants (pressure, speed, direction).
        bfs: best-fit rows (converted with np.array). Column 0 is used as an
            integer row index into ``amvs``. Columns 1 and 2 are the best-fit U and V
            (given m/s units here). Column 3 is the best-fit pressure.
        bin_size: width of the bins spanning 50 to 1050 (presumably hPa - confirm).

    Returns:
        x_values: bin centers.
        bin_pres: per-bin AMV-minus-best-fit pressure differences.
        num_pres: sample count per pressure bin.
        bin_spd: per-bin AMV-minus-best-fit speed differences (unit-bearing, m/s).
        num_spd: sample count per speed bin.
        bin_dir: per-bin direction differences from direction_difference(amv_dir, bf_dir).
        num_dir: sample count per direction bin.

    Side effects: prints MAD / bias / RMS summaries to stdout.
    """
    bfs = np.array(bfs)

    # Column 0 of each best-fit row points back at the AMV row it was computed for.
    keep_idxs = bfs[:, 0]
    keep_idxs = keep_idxs.astype(np.int32)

    # Pressure: AMV pressure minus best-fit pressure.
    amv_p = amvs[keep_idxs, amv_pres_idx]
    bf_p = bfs[:, 3]
    diff = amv_p - bf_p
    mad = np.average(np.abs(diff))
    bias = np.average(diff)
    print('********************************************************')
    print('num of best fits: ', bf_p.shape[0])
    print('press, MAD: ', mad)
    print('press, bias: ', bias)
    pd_std = np.std(diff)
    pd_mean = np.mean(diff)
    # sqrt(mean**2 + std**2) is the RMS of diff (np.std uses ddof=0).
    print('press bias/rms ', pd_mean, np.sqrt(pd_mean**2 + pd_std**2))
    print('------------------------------------------')

    # All three difference types are binned by the AMV's own pressure.
    bin_ranges = get_press_bin_ranges(50, 1050, bin_size=bin_size)
    bin_pres = bin_data_by(diff, amv_p, bin_ranges)

    amv_spd = amvs[keep_idxs, amv_spd_idx]
    amv_dir = amvs[keep_idxs, amv_dir_idx]
    bf_spd, bf_dir = spd_dir_from_uv(bfs[:, 1], bfs[:, 2])

    # Speed: AMV speed minus best-fit speed; units attached so the subtraction matches bf_spd.
    diff = amv_spd * units('m/s') - bf_spd
    spd_mad = np.average(np.abs(diff))
    spd_bias = np.average(diff)
    print('spd, MAD: ', spd_mad)
    print('spd, bias: ', spd_bias)
    spd_mean = np.mean(diff)
    spd_std = np.std(diff)
    print('spd MAD/bias/rms: ', np.average(np.abs(diff)), spd_mean, np.sqrt(spd_mean**2 + spd_std**2))
    print('-----------------')
    bin_spd = bin_data_by(diff, amv_p, bin_ranges)

    # Direction: difference computed by direction_difference (defined elsewhere;
    # presumably handles the 0/360 wrap - confirm).
    amv_dir = amv_dir * units.deg
    dir_diff = direction_difference(amv_dir, bf_dir)
    print('dir, MAD: ', np.average(np.abs(dir_diff)))
    print('dir bias: ', np.average(dir_diff))
    print('-------------------------------------')
    bin_dir = bin_data_by(dir_diff, amv_p, bin_ranges)

    # Vector difference magnitude between AMV and best-fit winds (summary only, not binned).
    amv_u, amv_v = uv_from_spd_dir(amvs[keep_idxs, amv_spd_idx], amvs[keep_idxs, amv_dir_idx])
    u_diffs = amv_u - (bfs[:, 1] * units('m/s'))
    v_diffs = amv_v - (bfs[:, 2] * units('m/s'))

    vd = np.sqrt(u_diffs**2 + v_diffs**2)
    vd_mean = np.mean(vd)
    vd_std = np.std(vd)
    print('VD bias/rms: ', vd_mean, np.sqrt(vd_mean**2 + vd_std**2))
    print('------------------------------------------')

    # Bin centers and per-bin sample counts.
    x_values = []
    num_pres = []
    num_spd = []
    num_dir = []
    for i in range(len(bin_ranges)):
        x_values.append(np.average(bin_ranges[i]))
        num_pres.append(bin_pres[i].shape[0])
        num_spd.append(bin_spd[i].shape[0])
        num_dir.append(bin_dir[i].shape[0])

    return x_values, bin_pres, num_pres, bin_spd, num_spd, bin_dir, num_dir


def bin_data_by(a, b, bin_ranges):
    """Group the values of ``a`` into bins defined on the parallel array ``b``.

    For every ``rng`` in ``bin_ranges``, this selects the elements of ``a`` whose
    matching ``b`` value satisfies ``rng[0] <= b < rng[1]`` (half-open bins). It
    returns one array per bin, in the order of ``bin_ranges``.
    """
    return [a[(b >= rng[0]) & (b < rng[1])] for rng in bin_ranges]


# earth relative
def uv_from_spd_dir(speed, direction, speed_unit=None, dir_unit=None):
    """Convert wind speed and direction to U/V components using MetPy.

    ``speed`` and ``direction`` are plain magnitudes. ``speed_unit`` (default m/s)
    and ``dir_unit`` (default degrees) are attached before calling
    ``mpcalc.wind_components``. Returns ``(u, v)`` as unit-bearing quantities.
    """
    spd_unit = units('m/s') if speed_unit is None else speed_unit
    wdir_unit = units.deg if dir_unit is None else dir_unit

    u_comp, v_comp = mpcalc.wind_components(speed * spd_unit, direction * wdir_unit)
    return u_comp, v_comp


def spd_dir_from_uv(u, v, speed_unit=None):
    """Convert U/V wind components to ``(speed, direction)`` using MetPy.

    ``u`` and ``v`` are plain magnitudes; ``speed_unit`` (default m/s) is attached
    to both. The direction comes from ``mpcalc.wind_direction`` and the speed from
    ``mpcalc.wind_speed``, both as unit-bearing quantities.
    """
    unit = units('m/s') if speed_unit is None else speed_unit
    u_q = u * unit
    v_q = v * unit

    wind_dir = mpcalc.wind_direction(u_q, v_q)
    wind_spd = mpcalc.wind_speed(u_q, v_q)
    return wind_spd, wind_dir


# earth relative to projection
def uv_proj_from_uv_earth(u_earth, v_earth, lons, lats, proj, units=None):
    """Transform earth-relative (PlateCarree) vector components into ``proj``'s frame.

    Args:
        u_earth, v_earth: earth-relative vector components at (lons, lats).
        lons, lats: longitudes and latitudes of the vectors.
        proj: target cartopy CRS; its ``transform_vectors`` performs the rotation.
        units: unused; kept only so existing callers keep working. The previous
            default path evaluated ``units('m/s')`` on this parameter, which shadows
            the module-level metpy ``units`` registry. That made ``None('m/s')``
            raise TypeError, and the result was never used anyway.

    Returns:
        (u_proj, v_proj) as returned by ``proj.transform_vectors``.
    """
    u_proj, v_proj = proj.transform_vectors(ccrs.PlateCarree(), lons, lats, u_earth, v_earth)

    return u_proj, v_proj


# wind direction conversion
def azimuth_to_met(direction):
    """Add 180 degrees to a direction and wrap the result below 360.

    The name indicates this converts an azimuth (direction the vector points
    toward) into a meteorological direction (where the wind comes from).
    Inputs in [0, 360) map into [0, 360); other inputs are not normalised.

    Accepts scalars, lists or arrays and returns a new NumPy array. The caller's
    array is no longer modified: the previous ``direction += 180`` mutated a
    NumPy argument in place.
    """
    flipped = np.asarray(direction) + 180
    return np.where(flipped < 360, flipped, flipped - 360.0)