Skip to content
Snippets Groups Projects
Select Git revision
  • 48cf4e1320f87c075788e2e149ce4346bd9b7fd9
  • master default protected
  • use_flight_altitude
  • distribute
4 results

intercompare.py

Blame
  • user avatar
    tomrink authored
    c7c84f08
    History
    intercompare.py 47.40 KiB
    import numpy as np
    import xarray as xr
    # from util.setup import home_dir
    from util.lon_lat_grid import earth_to_indexs
    from util.util import haversine_np
    from util.gfs_reader import get_vert_profile_s
    # import cartopy
    from cartopy import *
    import metpy.calc as mpcalc
    import cartopy.crs as ccrs
    from metpy.units import units
    import h5py
    import time
    # from scipy.interpolate import interp1d
    from util.util import minimize_quadratic
    from netCDF4 import Dataset
    import datetime
    from datetime import timezone
    from util.util import LatLonTuple
    import re
    
    # -- AMV intercompare stuff ------------------------------------------
    # --------------------------------------------------------------------
    
    # Alternate column layout without the 'id' and 'hmet' columns, kept for reference:
    #amv_hdr_list = ['lat', 'lon', 'tbox', 'sbox', 'spd', 'dir', 'pres', 'lowl', 'mspd', 'mdir',
    #                'alb', 'corr', 'tmet', 'perr', 'qi', 'cqi', 'qif']
    # Column order of the 2-D AMV arrays produced by get_amv_nd(); the *_idx constants
    # below are the column positions looked up by name from this list.
    amv_hdr_list = ['id', 'lat', 'lon', 'tbox', 'sbox', 'spd', 'dir', 'pres', 'lowl', 'mspd', 'mdir',
                    'alb', 'corr', 'tmet', 'perr', 'hmet', 'qi', 'cqi', 'qif']
    amv_lon_idx = amv_hdr_list.index('lon')
    amv_lat_idx = amv_hdr_list.index('lat')
    amv_pres_idx = amv_hdr_list.index('pres')
    amv_spd_idx = amv_hdr_list.index('spd')
    amv_dir_idx = amv_hdr_list.index('dir')
    # Same column as amv_pres_idx; both names resolve to the 'pres' column.
    amv_prs_idx = amv_hdr_list.index('pres')
    amv_qi_idx = amv_hdr_list.index('qi')
    amv_cqi_idx = amv_hdr_list.index('cqi')
    amv_qif_idx = amv_hdr_list.index('qif')
    
    # Column order of the per-station RAOB profile arrays and the positions used to index them.
    raob_hdr_list = ['pres', 'temp', 'dir', 'spd']
    raob_spd_idx = raob_hdr_list.index('spd')
    raob_dir_idx = raob_hdr_list.index('dir')
    raob_pres_idx = raob_hdr_list.index('pres')
    
    # Keys iterated by run_best_fit_all() when looking up per-center AMV arrays.
    amv_centers_list = ['EUM', 'BRZ', 'JMA', 'KMA', 'NOA', 'NWC']
    
    
    def get_amv_nd(filename, delimiter=';'):
        """Read a delimited AMV text file into a 2-D float array.

        The first line of the file is treated as a header and skipped. Every
        other non-blank line must provide at least 19 numeric fields. In the
        file the last three fields are read as qi, qif, cqi; rows of the
        returned array follow the amv_hdr_list column order
        (..., 'qi', 'cqi', 'qif').

        Longitudes greater than 180 are shifted by -360. A row is dropped when
        pressure is outside [10, 1020], latitude outside [-90, 90], longitude
        outside [-180, 180], speed outside [0, 150] or direction outside
        [0, 360].

        Parameters
        ----------
        filename : path of the text file to read.
        delimiter : field separator, ';' by default.

        Returns
        -------
        numpy.ndarray of shape (num_rows, 19). When no row is accepted the
        result has shape (0, 19) so that column indexing still works.

        Raises
        ------
        ValueError if a field cannot be parsed as a float.
        IndexError if a non-blank data line has fewer than 19 fields.
        """
        n_cols = 19
        rows = []

        with open(filename) as file:
            next(file, None)  # the first line is the header
            for line in file:
                if not line.strip():
                    continue  # tolerate blank lines (e.g. a trailing newline)

                toks = line.split(delimiter)
                vals = [float(toks[k]) for k in range(n_cols)]
                (amv_id, lat, lon, tbox, sbox, spd, wdir, pres, lowl, mspd, mdir,
                 alb, corr, tmet, perr, hmet, qi, qif, cqi) = vals

                if lon > 180:
                    lon -= 360

                if pres < 10 or pres > 1020:
                    continue
                if lat < -90 or lat > 90:
                    continue
                if lon < -180 or lon > 180:
                    continue
                if spd < 0 or spd > 150:
                    continue
                if wdir < 0 or wdir > 360:
                    continue

                # Emit in amv_hdr_list order: note cqi precedes qif here.
                rows.append((amv_id, lat, lon, tbox, sbox, spd, wdir, pres, lowl, mspd, mdir,
                             alb, corr, tmet, perr, hmet, qi, cqi, qif))

        if not rows:
            return np.empty((0, n_cols))

        return np.array(rows)
    
    
    def filter_amvs(amvs, qitype=amv_cqi_idx, qival=50, lat_range=[-61, 61]):
        """Select AMV rows by latitude band and quality indicator.

        Parameters
        ----------
        amvs : 2-D array whose columns follow amv_hdr_list.
        qitype : column index of the quality indicator to test, or None to
            skip the quality test.
        qival : minimum accepted quality value (inclusive).
        lat_range : (low, high) latitude bounds, both exclusive, or None to
            skip the latitude test.

        Returns
        -------
        The rows that pass, or None when a test leaves no rows.
        """
        if lat_range is not None:
            lats = amvs[:, amv_lat_idx]
            in_band = np.logical_and(lats > lat_range[0], lats < lat_range[1])
            if not np.any(in_band):
                return None
            amvs = amvs[in_band, :]

        # No quality test requested: the latitude-filtered rows are the result.
        if qitype is None:
            return amvs

        passes_qi = amvs[:, qitype] >= qival
        if not np.any(passes_qi):
            return None
        return amvs[passes_qi, :]
    
    
    def get_raob_dict_cdf(pathname):
        """Read a RAOB NetCDF file into a dict keyed by station location.

        Station latitude/longitude/id come from 'staLat', 'staLon' and
        'wmoStat'; mandatory-level data from 'prMan', 'tpMan', 'wsMan',
        'wdMan'; significant-level data from 'prSigW', 'tpSigW', 'wsSigW',
        'wdSigW'. Only stations with -90 < latitude < 90 are kept, and only
        levels whose level value, speed and direction are all unmasked.

        Parameters
        ----------
        pathname : path of the NetCDF file.

        Returns
        -------
        dict mapping LatLonTuple(lat, lon) -> [profile, station_id], where
        profile is a 2-D array with columns in raob_hdr_list order
        (pres, temp, dir, spd), sorted by descending level value. Stations
        with no valid level are omitted.
        """
        # The context manager closes the file even when a variable is missing
        # or a read fails part-way through.
        with Dataset(pathname, 'r', format='NETCDF3') as rtgrp:
            wmo_stat_id_var = rtgrp['wmoStat']

            # Disable masking so the raw lat/lon values can be range-tested below.
            sta_lats_var = rtgrp['staLat']
            sta_lats_var.set_always_mask(False)
            sta_lats_var.set_auto_mask(False)

            sta_lons_var = rtgrp['staLon']
            sta_lons_var.set_always_mask(False)
            sta_lons_var.set_auto_mask(False)

            sta_lats = sta_lats_var[:]
            sta_lons = sta_lons_var[:]
            sta_ids = wmo_stat_id_var[:]

            # mandatory levels
            man_levs = rtgrp['prMan'][:, :]
            num_man_levs = rtgrp['numMand'][:]._data  # read but not used further below
            man_temp = rtgrp['tpMan'][:, :]
            man_spd = rtgrp['wsMan'][:, :]
            man_dir = rtgrp['wdMan'][:, :]

            # significant levels
            sig_levs = rtgrp['prSigW'][:, :]
            num_sig_levs = rtgrp['numSigW'][:]._data  # read but not used further below
            sig_temp = rtgrp['tpSigW'][:, :]
            sig_spd = rtgrp['wsSigW'][:, :]
            sig_dir = rtgrp['wdSigW'][:, :]

        sta_msk = np.logical_and(sta_lats > -90.0, sta_lats < 90.0)
        num_sta = np.sum(sta_msk)
        sta_lats = sta_lats[sta_msk]
        sta_lons = sta_lons[sta_msk]
        sta_ids = sta_ids[sta_msk]

        man_levs = man_levs[sta_msk,]
        num_man_levs = num_man_levs[sta_msk]

        sig_levs = sig_levs[sta_msk,]
        num_sig_levs = num_sig_levs[sta_msk]

        man_temp = man_temp[sta_msk,]
        man_spd = man_spd[sta_msk,]
        man_dir = man_dir[sta_msk,]

        sig_temp = sig_temp[sta_msk,]
        sig_spd = sig_spd[sta_msk,]
        sig_dir = sig_dir[sta_msk,]

        raob_dct = {}
        for k in range(num_sta):
            lon = sta_lons[k]
            lat = sta_lats[k]
            sta_id = sta_ids[k]

            # A level is valid only when level value, speed and direction are all unmasked.
            vld_man = np.invert(man_levs[k]._mask)
            vld_man = np.logical_and(vld_man, np.invert(man_spd[k]._mask))
            vld_man = np.logical_and(vld_man, np.invert(man_dir[k]._mask))

            vld_sig = np.invert(sig_levs[k]._mask)
            vld_sig = np.logical_and(vld_sig, np.invert(sig_spd[k]._mask))
            vld_sig = np.logical_and(vld_sig, np.invert(sig_dir[k]._mask))

            vld_man_levs = man_levs[k, vld_man]
            vld_man_spd = man_spd[k, vld_man]
            vld_man_dir = man_dir[k, vld_man]
            vld_man_temp = man_temp[k, vld_man]

            s_idxs = np.argsort(vld_man_levs)
            s_man_levs = vld_man_levs[s_idxs]
            s_man_spd = vld_man_spd[s_idxs]
            s_man_dir = vld_man_dir[s_idxs]
            s_man_temp = vld_man_temp[s_idxs]

            vld_sig_levs = sig_levs[k, vld_sig]
            vld_sig_spd = sig_spd[k, vld_sig]
            vld_sig_dir = sig_dir[k, vld_sig]
            vld_sig_temp = sig_temp[k, vld_sig]

            # Merge mandatory and significant levels into one profile.
            all_levs = np.append(s_man_levs, vld_sig_levs)
            all_temp = np.append(s_man_temp, vld_sig_temp)
            all_spd = np.append(s_man_spd, vld_sig_spd)
            all_dir = np.append(s_man_dir, vld_sig_dir)

            if len(all_levs) == 0:
                continue

            # Sort ascending, then reverse so the profile runs from the largest level value down.
            s_idxs = np.argsort(all_levs)[::-1]
            all_levs = all_levs[s_idxs]
            all_temp = all_temp[s_idxs]
            all_spd = all_spd[s_idxs]
            all_dir = all_dir[s_idxs]

            raob = np.stack([all_levs._data, all_temp._data, all_dir._data, all_spd._data], axis=1)

            raob_dct[LatLonTuple(lat=lat, lon=lon)] = [raob, sta_id]

        return raob_dct
    
    
    def get_raob_dict(filename):
        """Read a whitespace-delimited RAOB text file into per-location arrays.

        The first line is treated as a header and skipped. Each remaining
        non-blank line supplies lat, lon, pres, temp, dir, spd in that order.
        A row is dropped when pressure is outside [0, 1050], latitude outside
        [-90, 90], longitude outside [-180, 180], speed outside [0, 150] or
        direction outside [0, 360].

        Parameters
        ----------
        filename : path of the text file to read.

        Returns
        -------
        dict mapping a plain (lat, lon) tuple to a 2-D array with columns
        (pres, temp, dir, spd), as built by raob_dict_to_nd().
        """
        data = []

        with open(filename) as file:
            next(file, None)  # the first line is the header
            for line in file:
                toks = line.split()
                if not toks:
                    continue  # tolerate blank lines (e.g. a trailing newline)

                lat = float(toks[0])
                lon = float(toks[1])
                pres = float(toks[2])
                temp = float(toks[3])
                wdir = float(toks[4])
                spd = float(toks[5])

                if pres < 0 or pres > 1050:
                    continue
                if lat < -90 or lat > 90:
                    continue
                if lon < -180 or lon > 180:
                    continue
                if spd < 0 or spd > 150:
                    continue
                if wdir < 0 or wdir > 360:
                    continue

                data.append((lat, lon, pres, temp, wdir, spd))

        # Group consecutive rows that share a location.
        # NOTE(review): a location that reappears after a different one starts a
        # fresh list and replaces the earlier entry, so the file is presumably
        # expected to keep each station's rows contiguous.
        raob_dict = {}
        key = None
        for lat, lon, pres, temp, wdir, spd in data:
            loc = (lat, lon)
            if loc != key:
                key = loc
                raob_dict[key] = []
            raob_dict[key].append((pres, temp, wdir, spd))

        return raob_dict_to_nd(raob_dict)
    
    
    def raob_dict_to_nd(raob_dict):
        """Return a new dict with each location's observation list converted to a NumPy array.

        Keys and their order are those of raob_dict; the input is not modified.
        """
        return {loc: np.array(obs_lst) for loc, obs_lst in raob_dict.items()}
    
    
    def amv_bulk_stats(amvs, qitype=amv_cqi_idx, qival=50):
        """Print summary statistics for an AMV array.

        Prints the latitude and longitude extent of all rows, then, for rows
        whose quality column `qitype` is >= `qival`: the count, speed and
        pressure min/max/mean, the percentage in each pressure layer
        (low: pres >= 700, mid: 400 < pres < 700, high: pres <= 400) and the
        speed/pressure min/max/mean per layer.

        Nothing is returned. A layer with no rows makes the per-layer
        min()/max() calls fail on an empty array.
        """
        lat_col = amvs[:, amv_hdr_list.index('lat')]
        print(lat_col.min(), lat_col.max())

        lon_col = amvs[:, amv_hdr_list.index('lon')]
        print(lon_col.min(), lon_col.max())

        is_good = amvs[:, qitype] >= qival
        num_good = np.sum(is_good)
        good_amvs = amvs[is_good, :]

        print('------------------------------')
        print('num good amvs: ', num_good)

        sidx = amv_hdr_list.index('spd')
        good_spd = good_amvs[:, sidx]
        print('spd min/max/mean: ', good_spd.min(), good_spd.max(), np.average(good_spd))

        pidx = amv_hdr_list.index('pres')
        good_prs = good_amvs[:, pidx]
        print('pres min/max/mean: ', good_prs.min(), good_prs.max(), np.average(good_prs))

        # Pressure-layer masks over the good rows.
        low = good_prs >= 700
        mid = np.logical_and(good_prs < 700, good_prs > 400)
        hgh = good_prs <= 400

        n_low, n_mid, n_hgh = np.sum(low), np.sum(mid), np.sum(hgh)

        print('% low: ', 100.0*(n_low/num_good))
        print('% mid: ', 100.0*(n_mid/num_good))
        print('% hgh: ', 100.0*(n_hgh/num_good))
        print('---------------------------')

        low_spd = good_amvs[low, sidx]
        low_prs = good_amvs[low, pidx]
        print('Low Spd min/max/mean: ', low_spd.min(), low_spd.max(), low_spd.mean())
        print('Low Press min/max/mean: ', low_prs.min(), low_prs.max(), low_prs.mean())
        print('---------------------------')

        mid_spd = good_amvs[mid, sidx]
        mid_prs = good_amvs[mid, pidx]
        print('Mid Spd min/max/mean: ', mid_spd.min(), mid_spd.max(), mid_spd.mean())
        print('Mid Press min/max/mean: ', mid_prs.min(), mid_prs.max(), mid_prs.mean())
        print('---------------------------')

        hgh_spd = good_amvs[hgh, sidx]
        hgh_prs = good_amvs[hgh, pidx]
        print('Hgh Spd min/max/mean: ', hgh_spd.min(), hgh_spd.max(), hgh_spd.mean())
        print('Hgh Press min/max/mean: ', hgh_prs.min(), hgh_prs.max(), hgh_prs.mean())

        print('----------------------------')
    
    
    # Largest |AMV pressure - matched level pressure| accepted by amv_raob_diff() and amv_gfs_diff().
    press_diff_threshold = 50.0  # mb
    # NOTE(review): the functions in the reviewed portion of this file take their own
    # dist_threshold parameter rather than reading this module-level value.
    dist_threshold = 150.0  # km
    # NOTE(review): not referenced in the reviewed portion of this file; presumably km per degree of latitude.
    km_per_deg = 111.0
    
    
    # Find raobs within threshold distance for each (lat, lon) point
    def get_raob_neighbors(lats, lons, raob_dct, dist_threshold=150):
        """Find, for each input point, the raob locations closer than dist_threshold.

        Parameters
        ----------
        lats, lons : 1-D arrays of point coordinates (same length).
        raob_dct : dict whose keys are (lat, lon)-indexable tuples.
        dist_threshold : maximum haversine_np distance (exclusive); presumably
            km, matching the module-level threshold comment.

        Returns
        -------
        dict mapping point index j -> (keys, distances), both ordered from
        nearest to farthest, or None when no raob is within the threshold.
        """
        keys = list(raob_dct.keys())
        rlats = np.array([ll_tup[0] for ll_tup in keys])
        rlons = np.array([ll_tup[1] for ll_tup in keys])

        nbor_dct = {}
        for j in range(lats.shape[0]):
            dist = haversine_np(rlons, rlats, lons[j], lats[j])
            ridxs = np.nonzero(dist < dist_threshold)[0]

            if ridxs.shape[0] == 0:
                nbor_dct[j] = None
                continue

            # Order the in-range raobs by increasing distance.
            near = dist[ridxs]
            order = np.argsort(near)
            nbor_dct[j] = ([keys[i] for i in ridxs[order]], near[order])

        return nbor_dct
    
    
    # Find points (lat, lon) within threshold distance to each raob
    def get_neighbors_to_raob(lats, lons, raob_dct, dist_threshold=150):
        """Find, for each raob location, the input points closer than dist_threshold.

        Parameters
        ----------
        lats, lons : 1-D arrays of point coordinates (same length).
        raob_dct : dict whose keys are (lat, lon)-indexable tuples.
        dist_threshold : maximum haversine_np distance (exclusive).

        Returns
        -------
        dict mapping each raob key -> (point_indices, distances) in the
        original point order (not sorted by distance), or None when no point
        is within the threshold.
        """
        nbor_dct = {}

        for ll_tup in list(raob_dct.keys()):
            rlat, rlon = ll_tup[0], ll_tup[1]
            dist = haversine_np(lons, lats, rlon, rlat)
            idxs = np.nonzero(dist < dist_threshold)[0]

            if idxs.shape[0] == 0:
                nbor_dct[ll_tup] = None
            else:
                nbor_dct[ll_tup] = (list(idxs), dist[idxs])

        return nbor_dct
    
    
    def amv_raob_diff_2(amvs, raob_dct):
        """Run amv_raob_diff() on a 2-D AMV array laid out per amv_hdr_list.

        Returns the (spd_diffs, dir_diffs) pair produced by amv_raob_diff().
        """
        spd_diffs, dir_diffs = amv_raob_diff(
            amvs[:, amv_lat_idx],
            amvs[:, amv_lon_idx],
            amvs[:, amv_spd_idx],
            amvs[:, amv_dir_idx],
            amvs[:, amv_prs_idx],
            raob_dct,
        )
        return spd_diffs, dir_diffs
    
    
    def amv_raob_diff(amv_lat, amv_lon, amv_spd, amv_dir, amv_press, raob_dct):
        """Compare AMVs against their nearest usable raob and print difference statistics.

        For each AMV the raobs returned by get_raob_neighbors() (default
        distance threshold) are tried nearest first. The raob level closest in
        pressure to the AMV is used, provided it lies within
        press_diff_threshold; the first raob that qualifies is the only one
        used for that AMV.

        Parameters
        ----------
        amv_lat, amv_lon, amv_spd, amv_dir, amv_press : 1-D arrays, one value per AMV.
        raob_dct : dict mapping location -> 2-D profile array with columns per raob_hdr_list.

        Returns
        -------
        (spd_diffs, dir_diffs) : arrays of AMV-minus-raob speed differences
        and direction_difference() values for the matched samples.
        """
        deg2rad = np.pi / 180.0
        num_amvs = amv_lat.shape[0]

        amv_u = -amv_spd * np.sin(deg2rad * amv_dir)
        amv_v = -amv_spd * np.cos(deg2rad * amv_dir)

        raob_nbor_dct = get_raob_neighbors(amv_lat, amv_lon, raob_dct)

        dir_diffs, spd_diffs, u_diffs, v_diffs, prs_diffs = [], [], [], [], []

        for j in range(num_amvs):
            nbors = raob_nbor_dct[j]
            if nbors is None:
                continue

            # nbors[0] holds the raob keys, nearest first; nbors[1] (distances) is unused here.
            for ll_tup in nbors[0]:
                raob = raob_dct.get(ll_tup)

                pdiff = amv_press[j] - raob[:, raob_pres_idx]
                lev_idx = np.argmin(np.abs(pdiff))
                if np.abs(pdiff[lev_idx]) > press_diff_threshold:
                    continue  # no level close enough in pressure; try the next raob

                rdir = raob[lev_idx, raob_dir_idx]
                rspd = raob[lev_idx, raob_spd_idx]

                dir_diffs.append(direction_difference(amv_dir[j], rdir))
                spd_diffs.append(amv_spd[j] - rspd)

                ru = -rspd * np.sin(deg2rad * rdir)
                rv = -rspd * np.cos(deg2rad * rdir)
                u_diffs.append(amv_u[j] - ru)
                v_diffs.append(amv_v[j] - rv)
                prs_diffs.append(pdiff[lev_idx])

                break  # only the first qualifying (closest) raob is used

        num = len(u_diffs)
        dir_diffs = np.array(dir_diffs)
        spd_diffs = np.array(spd_diffs)
        prs_diffs = np.array(prs_diffs)
        u_diffs = np.array(u_diffs)
        v_diffs = np.array(v_diffs)

        print('num amvs: ', num_amvs)
        print('num samples: ', num)

        # rms is reported as sqrt(bias**2 + std**2) throughout.
        spd_mean = np.mean(spd_diffs)
        spd_std = np.std(spd_diffs)
        spd_mad = np.average(np.abs(spd_diffs))
        print('spd MAD/bias/rms: ', spd_mad, spd_mean, np.sqrt(spd_mean**2 + spd_std**2))
        print('-----------------')

        print('dir bias: ', np.mean(dir_diffs))
        print('-----------------------')

        vec_diff = np.sqrt(u_diffs**2 + v_diffs**2)
        vd_mean = np.mean(vec_diff)
        vd_std = np.std(vec_diff)
        print('VD bias/rms: ', vd_mean, np.sqrt(vd_mean**2 + vd_std**2))
        print('--------------')

        pd_std = np.std(prs_diffs)
        pd_mean = np.mean(prs_diffs)
        print('press bias/rms ', pd_mean, np.sqrt(pd_mean**2 + pd_std**2))
        print('------------------------------------------')

        return spd_diffs, dir_diffs
    
    
    def run_best_fit(amvs, raob_dct, dist_threshold=200, min_num_levs=20):
        """Run best_fit() for each AMV against its nearby raob profiles.

        For every AMV the raobs within dist_threshold are tried in order of
        increasing distance, skipping profiles with fewer than min_num_levs
        levels. The first profile for which the fourth element of the
        best_fit() result equals 0 is recorded (presumably a success flag --
        confirm against best_fit), and the search for that AMV stops.

        Returns
        -------
        list of tuples (amv_index, bf[0], bf[1], bf[2], bf[3]).
        """
        raob_nbor_dct = get_raob_neighbors(amvs[:, amv_lat_idx], amvs[:, amv_lon_idx], raob_dct,
                                           dist_threshold=dist_threshold)
        bestfits = []

        for j in range(amvs.shape[0]):
            nbors = raob_nbor_dct[j]
            if nbors is None:
                continue

            nbor_keys, nbor_dsts = nbors[0], nbors[1]

            aspd = amvs[j, amv_spd_idx]
            adir = amvs[j, amv_dir_idx]
            aprs = amvs[j, amv_prs_idx]
            alat = amvs[j, amv_lat_idx]
            alon = amvs[j, amv_lon_idx]

            # Visit neighbours from nearest to farthest.
            for k in np.argsort(nbor_dsts).tolist():
                raob = raob_dct.get(nbor_keys[k])
                if raob.shape[0] < min_num_levs:
                    continue

                fit = best_fit(aspd, adir, aprs, alat, alon,
                               raob[:, raob_spd_idx], raob[:, raob_dir_idx], raob[:, raob_pres_idx])
                record = (j, fit[0], fit[1], fit[2], fit[3])
                if record[4] == 0:
                    bestfits.append(record)
                    break

        return bestfits
    
    
    def run_amv_clavr_compare(amvs, clvr_filename):
        """Time an earth_to_indexs() lookup for up to the first 5000 AMV locations.

        The CLAVR-x HDF5 file is opened (so a missing or unreadable file still
        raises) but its contents are not read here. The elapsed time in
        milliseconds, covering both the coordinate extraction and the lookup,
        is printed. Nothing is returned.

        Parameters
        ----------
        amvs : 2-D AMV array laid out per amv_hdr_list.
        clvr_filename : path of the HDF5 file.
        """
        max_pts = 5000

        # Use a context manager so the HDF5 handle is released on exit.
        with h5py.File(clvr_filename, 'r'):
            # Keep this for reference
            # lons = h5file['longitude'][0:2500, :]
            # lats = h5file['latitude'][0:2500, :]
            # lons = np.where(lons < 0, lons + 360, lons)
            # lons = np.where(lons == -639, np.nan, lons)
            # lats = np.where(lats == -999.0, np.nan, lats)
            # ll_grid = LonLatGrid(lons, lats)

            t1 = int(round(time.time() * 1000))

            # Never read past the end of the array when fewer than max_pts AMVs are given.
            num_pts = min(max_pts, amvs.shape[0])
            a_lons = np.array(amvs[:num_pts, amv_lon_idx])
            a_lats = np.array(amvs[:num_pts, amv_lat_idx])

            earth_to_indexs(a_lons, a_lats, 5500)
            print('search time: ', (int(round(time.time() * 1000)) - t1))
            print('--------------')
        print('--------------')
    
    
    def run_best_fit_gfs_all(amvs, gfs_filename):
        """Run run_best_fit_gfs() on amvs[0] .. amvs[len(amvs) - 1] and return the results as a list."""
        return [run_best_fit_gfs(amvs[k], gfs_filename) for k in range(len(amvs))]
    
    
    def run_best_fit_gfs(amvs, gfs_filename, amv_lat_idx=amv_lat_idx, amv_lon_idx=amv_lon_idx, amv_prs_idx=amv_prs_idx,
                         amv_spd_idx=amv_spd_idx, amv_dir_idx=amv_dir_idx, method='nearest'):
        """Run best_fit() for each AMV against the GFS wind profile at its location.

        Parameters
        ----------
        amvs : 2-D AMV array.
        gfs_filename : path of the GFS file opened with xarray.
        amv_lat_idx, amv_lon_idx, amv_prs_idx, amv_spd_idx, amv_dir_idx :
            column positions within `amvs`; default to the module-level layout.
        method : forwarded to get_vert_profile_s().

        Returns
        -------
        list with one best_fit() result per AMV, in row order.
        """
        # The context manager closes the dataset once the needed arrays are extracted.
        with xr.open_dataset(gfs_filename) as xr_dataset:
            # Pressure levels and both wind arrays are reversed along the level
            # axis so they stay aligned with each other.
            gfs_press = xr_dataset['pressure levels'].data[::-1]

            uv_wind = get_vert_profile_s(xr_dataset, ['u-wind', 'v-wind'], amvs[:, amv_lon_idx],
                                         amvs[:, amv_lat_idx], method=method)
            uv_wind = uv_wind.data
            wspd, wdir = spd_dir_from_uv(uv_wind[0, :, :], uv_wind[1, :, :])
            wspd = wspd.magnitude[:, ::-1]
            wdir = wdir.magnitude[:, ::-1]

        bestfits = []
        for j in range(amvs.shape[0]):
            aspd = amvs[j, amv_spd_idx]
            adir = amvs[j, amv_dir_idx]
            aprs = amvs[j, amv_prs_idx]
            alat = amvs[j, amv_lat_idx]
            alon = amvs[j, amv_lon_idx]

            # Every result is kept, whatever best_fit() reports for this AMV.
            bestfits.append(best_fit(aspd, adir, aprs, alat, alon, wspd[j], wdir[j], gfs_press))

        return bestfits
    
    
    def raob_gfs_diff(raob_dct, gfs_filename, r_press=900.0):
        """Compare raob winds with GFS winds at one pressure level and print statistics.

        For each raob the level closest to `r_press` is taken; the GFS wind is
        taken at the GFS level closest to the same `r_press`, so both sources
        are sampled at the same nominal pressure.

        Parameters
        ----------
        raob_dct : dict mapping (lat, lon)-indexable keys -> 2-D profile array
            with columns per raob_hdr_list.
        gfs_filename : path of the GFS file opened with xarray.
        r_press : pressure level used for both raob and GFS (default 900.0).

        Returns
        -------
        None; the statistics are printed.
        """
        deg2rad = np.pi / 180.0
        keys = list(raob_dct.keys())
        n_raobs = len(keys)

        rb_lats, rb_lons, rb_spd, rb_dir = [], [], [], []
        for ll_tup in keys:
            prof = raob_dct[ll_tup]
            lev_idx = np.argmin(np.abs(r_press - prof[:, raob_pres_idx]))
            rb_lats.append(ll_tup[0])
            rb_lons.append(ll_tup[1])
            rb_spd.append(prof[lev_idx, raob_spd_idx])
            rb_dir.append(prof[lev_idx, raob_dir_idx])

        rb_lons = np.array(rb_lons)
        rb_lats = np.array(rb_lats)
        rb_spd = np.array(rb_spd)
        rb_dir = np.array(rb_dir)

        # The context manager closes the dataset once the needed arrays are extracted.
        with xr.open_dataset(gfs_filename) as xr_dataset:
            gfs_press = xr_dataset['pressure levels'].data
            # Sample GFS at the same level as the raobs (the level was previously
            # hard-coded to 500 here while the raobs used 900).
            g_lev_idx = np.argmin(np.abs(r_press - gfs_press))

            uv_wind = get_profile_multi(xr_dataset, ['u-wind', 'v-wind'], rb_lons, rb_lats)
            uv_wind = uv_wind.data
            wspd, wdir = spd_dir_from_uv(uv_wind[0, :, :], uv_wind[1, :, :])
            g_wspd = wspd.magnitude
            g_wdir = wdir.magnitude

        dir_diffs, spd_diffs, u_diffs, v_diffs = [], [], [], []

        for j in range(n_raobs):
            rspd = rb_spd[j]
            rdir = rb_dir[j]

            gspd = g_wspd[j, g_lev_idx]
            gdir = g_wdir[j, g_lev_idx]
            print(rspd, gspd, ':', rdir, gdir)

            dir_diffs.append(direction_difference(gdir, rdir))
            spd_diffs.append(gspd - rspd)

            gu = -gspd * np.sin(deg2rad * gdir)
            gv = -gspd * np.cos(deg2rad * gdir)

            ru = -rspd * np.sin(deg2rad * rdir)
            rv = -rspd * np.cos(deg2rad * rdir)

            u_diffs.append(gu - ru)
            v_diffs.append(gv - rv)

        num = len(u_diffs)
        dir_diffs = np.array(dir_diffs)
        spd_diffs = np.array(spd_diffs)
        u_diffs = np.array(u_diffs)
        v_diffs = np.array(v_diffs)

        print('num rabos: ', n_raobs)
        print('num samples: ', num)

        # rms is reported as sqrt(bias**2 + std**2) throughout.
        spd_mean = np.mean(spd_diffs)
        spd_std = np.std(spd_diffs)
        print('spd MAD/bias/rms: ', np.average(np.abs(spd_diffs)), spd_mean, np.sqrt(spd_mean**2 + spd_std**2))
        print('-----------------')

        dir_mean = np.mean(dir_diffs)
        print('dir bias: ', dir_mean)
        print('-----------------------')

        vd = np.sqrt(u_diffs**2 + v_diffs**2)
        vd_mean = np.mean(vd)
        vd_std = np.std(vd)
        print('VD bias/rms: ', vd_mean, np.sqrt(vd_mean**2 + vd_std**2))
        print('--------------')

        return None
    
    
    def amv_gfs_diff(amvs, gfs_filename):
        """Compare AMVs with GFS winds at the nearest GFS pressure level and print statistics.

        For each AMV the GFS level closest in pressure is used; AMVs whose
        closest level is farther than press_diff_threshold are skipped.

        Parameters
        ----------
        amvs : 2-D AMV array laid out per amv_hdr_list.
        gfs_filename : path of the GFS file opened with xarray.

        Returns
        -------
        None; the statistics are printed.
        """
        deg2rad = np.pi / 180.0
        num_amvs = amvs.shape[0]

        # The context manager closes the dataset once the needed arrays are extracted.
        with xr.open_dataset(gfs_filename) as xr_dataset:
            # Work on a plain NumPy array (as run_best_fit_gfs does) so the
            # per-AMV differencing below does not depend on the open dataset.
            gfs_press = xr_dataset['pressure levels'].data

            uv_wind = get_profile_multi(xr_dataset, ['u-wind', 'v-wind'], amvs[:, amv_lon_idx], amvs[:, amv_lat_idx])
            uv_wind = uv_wind.data
            wspd, wdir = spd_dir_from_uv(uv_wind[0, :, :], uv_wind[1, :, :])
            wspd = wspd.magnitude
            wdir = wdir.magnitude

        amv_u = -amvs[:, amv_spd_idx] * np.sin(deg2rad * amvs[:, amv_dir_idx])
        amv_v = -amvs[:, amv_spd_idx] * np.cos(deg2rad * amvs[:, amv_dir_idx])

        dir_diffs, spd_diffs, u_diffs, v_diffs, prs_diffs = [], [], [], [], []

        for j in range(num_amvs):
            aspd = amvs[j, amv_spd_idx]
            adir = amvs[j, amv_dir_idx]
            aprs = amvs[j, amv_prs_idx]

            pdiff = aprs - gfs_press
            lev_idx = np.argmin(np.abs(pdiff))
            if np.abs(pdiff[lev_idx]) > press_diff_threshold:
                continue  # no GFS level close enough in pressure

            rspd = wspd[j, lev_idx]
            rdir = wdir[j, lev_idx]

            dir_diffs.append(direction_difference(adir, rdir))
            spd_diffs.append(aspd - rspd)

            ru = -rspd * np.sin(deg2rad * rdir)
            rv = -rspd * np.cos(deg2rad * rdir)

            u_diffs.append(amv_u[j] - ru)
            v_diffs.append(amv_v[j] - rv)
            prs_diffs.append(pdiff[lev_idx])

        num = len(u_diffs)
        dir_diffs = np.array(dir_diffs)
        spd_diffs = np.array(spd_diffs)
        prs_diffs = np.array(prs_diffs)
        u_diffs = np.array(u_diffs)
        v_diffs = np.array(v_diffs)

        print('num amvs: ', num_amvs)
        print('num samples: ', num)

        # rms is reported as sqrt(bias**2 + std**2) throughout.
        spd_mean = np.mean(spd_diffs)
        spd_std = np.std(spd_diffs)
        print('spd MAD/bias/rms: ', np.average(np.abs(spd_diffs)), spd_mean, np.sqrt(spd_mean**2 + spd_std**2))
        print('-----------------')

        dir_mean = np.mean(dir_diffs)
        print('dir bias: ', dir_mean)
        print('-----------------------')

        vd = np.sqrt(u_diffs**2 + v_diffs**2)
        vd_mean = np.mean(vd)
        vd_std = np.std(vd)
        print('VD bias/rms: ', vd_mean, np.sqrt(vd_mean**2 + vd_std**2))
        print('--------------')

        pd_std = np.std(prs_diffs)
        pd_mean = np.mean(prs_diffs)
        print('press bias/rms ', pd_mean, np.sqrt(pd_mean**2 + pd_std**2))
        print('------------------------------------------')

        return None
    
    
    # ddiff = np.abs(adir - rdir)
    # if ddiff > 180:
    #     ddiff = 360 - ddiff
    #     if rdir < adir:
    #         ddiff = -1 * ddiff
    # else:
    #     if rdir > adir:
    #         ddiff = -1 * ddiff
    def direction_difference(dir_a, dir_b):
        """Return the signed difference dir_a - dir_b in degrees, taken the short way around the circle.

        Works element-wise on scalars or arrays. The magnitude never exceeds
        180; the sign is positive when dir_a lies clockwise of dir_b along the
        shorter arc (e.g. 10 vs 350 gives +20, 350 vs 10 gives -20).
        """
        gap = np.abs(dir_a - dir_b)

        # A raw gap above 180 means the shorter arc crosses the 0/360 boundary.
        wraps = gap > 180
        magnitude = np.where(wraps, 360 - gap, gap)

        # Across the boundary the sign flips relative to the plain comparison.
        negate = np.where(wraps, dir_b < dir_a, dir_b > dir_a)

        return np.where(negate, -1 * magnitude, magnitude)
    
    
    def get_amv_winds_match(dist_threshold=50, qitype=amv_cqi_idx, qival=50, lat_range=[-61, 61]):
        """Collocate the AMVs of all six centers against the EUM AMVs.

        Each center's file is read with get_amv_nd() and filtered with
        filter_amvs(). An EUM AMV is kept only when every other center (tested
        in the order NWC, BRZ, JMA, NOA, KMA) has its nearest AMV closer than
        dist_threshold; the nearest AMV of each center is then recorded.

        Parameters
        ----------
        dist_threshold : maximum haversine_np distance (exclusive) for a match.
        qitype, qival, lat_range : forwarded to filter_amvs().

        Returns
        -------
        dict with keys 'EUM', 'BRZ', 'KMA', 'JMA', 'NOA', 'NWC'; each value is
        the 2-D array of that center's matched rows, row i of every array
        belonging to the same match. With no matches every array has zero rows.

        Raises
        ------
        ValueError if filtering leaves a center with no AMVs.
        """
        paths = (
            ('EUM', '/Users/tomrink/data/amv_intercompare/EUM321_csv.csv'),
            ('BRZ', '/Users/tomrink/data/amv_intercompare/BRZCPTECfin_121_csv.csv'),
            ('JMA', '/Users/tomrink/data/amv_intercompare/JMA421_csv.csv'),
            ('KMA', '/Users/tomrink/data/amv_intercompare/KMA521NI_csv.csv'),
            ('NOA', '/Users/tomrink/data/amv_intercompare/NOA621_csv.csv'),
            ('NWC', '/Users/tomrink/data/amv_intercompare/NWC721_csv.csv'),
        )

        #----- match all to EUM
        raw = {center: get_amv_nd(path) for center, path in paths}

        amvs_by_center = {}
        for center, _ in paths:
            kept = filter_amvs(raw[center], qitype=qitype, qival=qival, lat_range=lat_range)
            if kept is None:
                # filter_amvs() signals "nothing left" with None; fail with a clear message.
                raise ValueError('no AMVs left after filtering for center ' + center)
            amvs_by_center[center] = kept

        lat_idx = amv_hdr_list.index('lat')
        lon_idx = amv_hdr_list.index('lon')

        amvs_eum = amvs_by_center['EUM']
        num_eum = amvs_eum.shape[0]
        print('num EUM amvs: ', num_eum)

        match_order = ('NWC', 'BRZ', 'JMA', 'NOA', 'KMA')
        matched = {center: [] for center, _ in paths}

        for i in range(num_eum):
            eum_lon = amvs_eum[i, lon_idx]
            eum_lat = amvs_eum[i, lat_idx]

            nearest = {}
            for center in match_order:
                cand = amvs_by_center[center]
                dist = haversine_np(cand[:, lon_idx], cand[:, lat_idx], eum_lon, eum_lat)
                i_near = np.argmin(dist)
                if not dist[i_near] < dist_threshold:
                    break  # this center has nothing close enough: no match for EUM AMV i
                nearest[center] = i_near
            else:
                # Every center had a neighbour within the threshold.
                matched['EUM'].append(i)
                for center, i_near in nearest.items():
                    matched[center].append(i_near)

        print('total num matches: ', len(matched['EUM']))

        # dtype=int keeps the index arrays usable even when there are no matches.
        amvs_dict = {}
        for center in ('EUM', 'BRZ', 'KMA', 'JMA', 'NOA', 'NWC'):
            idxs = np.array(matched[center], dtype=int)
            amvs_dict[center] = amvs_by_center[center][idxs, :]

        return amvs_dict
    
    
    def run_best_fit_all(amvs_dict, raobs_pathname, dist_threshold=200, min_num_levs=20):
        """Run the RAOB best-fit for every center in ``amv_centers_list``.

        The RAOB dictionary is loaded once with ``get_raob_dict`` and shared by
        all centers.

        Args:
            amvs_dict: mapping keyed by center name; read with ``.get``, so a
                missing center hands ``None`` to ``run_best_fit``.
            raobs_pathname: passed unchanged to ``get_raob_dict``.
            dist_threshold: forwarded to ``run_best_fit``.
            min_num_levs: forwarded to ``run_best_fit``.

        Returns:
            dict mapping each center key to the value returned by
            ``run_best_fit`` for that center.
        """
        raobs = get_raob_dict(raobs_pathname)
        print('got raobs')

        results = {}
        for center in amv_centers_list:
            results[center] = run_best_fit(amvs_dict.get(center), raobs,
                                           dist_threshold=dist_threshold, min_num_levs=min_num_levs)
            print('done: ', center)

        return results
    
    
    def best_fit(amv_spd, amv_dir, amv_prs, amv_lat, amv_lon, fcst_spd, fcst_dir, fcst_prs, verbose=False):

        """Finds the background model best fit pressure associated with the AMV.
           The model best-fit pressure is the height (in pressure units) where the
           vector difference between the observed AMV and model background is a
           minimum.  The search is limited to model levels strictly within
           +/- 150 hPa of the AMV pressure and no higher than 50 hPa; the minimum
           is then refined with a parabolic fit through the neighbouring levels.
           This calculation may only work approximately 1/3 of the time.

           Reference:
           Salonen et al (2012), "Characterising AMV height assignment error by
           comparing best-fit pressure statistics from the Met Office and ECMWF
           System."  Proceedings of the 11th International Winds Workshop,
           Auckland, New Zealand, 20-24 February 2012.

           Input:
           amv_spd -  AMV speed m/s
           amv_dir -  AMV direction deg
           amv_prs -  AMV pressure hPa
           amv_lat, amv_lon - AMV location, used only in the verbose messages
           fcst_spd - (level) forecast speed m/s, NumPy array
           fcst_dir - (level) forecast direction (deg), NumPy array
           fcst_prs - (level) forecast pressure (hPa), NumPy array; level 0 is
                      assumed to be the surface and the last level the model top
           verbose -  print diagnostics when True

           Output, tuple (bfit_u, bfit_v, bfit_prs, flag):
           bfit_u - AMV best fit U component m/s, -9999.0 when not well constrained
           bfit_v - AMV best fit V component m/s, -9999.0 when not well constrained
           bfit_prs - AMV best fit pressure hPa, -9999.0 when not well constrained
           flag - 0 found, 1 not constrained, 2 vec diff minimum not met, 3 failed to find suitable fcst pressure match

           History:
           10/2012 - Steve Wanzong - Created in Fortran
           10/2013 - Sharon Nebuda - rewritten for python
           10/2021 - Tom Rink - adapted from previous version to leverage more of NumPy
        """

        undef = -9999.0
        flag = 3
        bf_data = (undef, undef, undef, flag)

        fcst_num_levels = fcst_prs.shape[0]

        PressDiff = 150.0       # pressure above and below AMV to look for fit
        TopPress = 50.0         # highest level to allow search
        BotPress = fcst_prs[0]  # lowest level to allow search

        if amv_prs < TopPress:
            if verbose:
                print('AMV location lat,lon,prs ({0},{1},{2}) is higher than pressure {3}'.format(amv_lat, amv_lon, amv_prs, TopPress))
            return bf_data

        # Calculate the pressure +/- 150 hPa from the AMV pressure.
        PressMax = min((amv_prs + PressDiff), BotPress)
        PressMin = max((amv_prs - PressDiff), TopPress)

        # 1d array of indices to consider for best fit location
        kk = np.where((fcst_prs < PressMax) & (fcst_prs > PressMin))
        if len(kk[0]) == 0:
            if verbose:
                print('AMV location lat,lon,prs ({0},{1},{2}) failed to find fcst prs around AMV'.format(amv_lat, amv_lon, amv_prs))
            return bf_data

        # Diagnostic field: Find the forecast/model minimum speed and maximum speed within PressDiff of the AMV.
        if verbose:
            SatwindMinSpeed = min(fcst_spd[kk])
            SatwindMaxSpeed = max(fcst_spd[kk])

        # Compute U and V for both AMVs and forecast (meteorological "from" direction).
        amv_rad = np.deg2rad(amv_dir)
        fcst_rad = np.deg2rad(fcst_dir)
        amv_uwind = -amv_spd * np.sin(amv_rad)
        amv_vwind = -amv_spd * np.cos(amv_rad)
        fcst_uwind = -fcst_spd * np.sin(fcst_rad)
        fcst_vwind = -fcst_spd * np.cos(fcst_rad)

        # Calculate the vector difference between the AMV and forecast/model background at all levels.
        VecDiff = np.sqrt((amv_uwind - fcst_uwind) ** 2 + (amv_vwind - fcst_vwind) ** 2)

        # Find the model level of best-fit pressure, from the minimum vector difference.
        # The lookup is done on the windowed values only, so a level outside the
        # search window that happens to have the same vector difference can never
        # be selected.
        window_diff = VecDiff[kk]
        MinVecDiff = min(window_diff)
        in_window = np.where(window_diff == MinVecDiff)[0]
        if in_window.size == 0:  # e.g. the minimum is NaN
            if verbose:
                print('AMV location lat,lon,prs ({0},{1},{2}) failed to find the min vector difference in layers around AMV'.format(amv_lat, amv_lon, amv_prs))
            return bf_data

        imin = kk[0][in_window[0]]  # index of min vec diff in absolute coordinates

        # Use a parabolic fit to find the best-fit pressure.
        # p2 - Minimized model pressure at level imin (hPa)
        # v2 - Minimized vector difference at level imin (m/s)
        # p1 - 1 pressure level lower in atmosphere than p2
        # p3 - 1 pressure level above in atmosphere than p2
        # v1 - Vector difference 1 pressure level lower than p2
        # v3 - Vector difference 1 pressure level above than p2

        p2 = fcst_prs[imin]
        v2 = VecDiff[imin]
        SatwindBestFitPress = p2
        SatwindBestFitU = fcst_uwind[imin]
        SatwindBestFitV = fcst_vwind[imin]

        # assumes fcst data level 0 at surface and (fcst_num_levels-1) at model top
        # no neighbours on both sides at the bottom or top model level: keep the level value
        if imin == 0 or imin == fcst_num_levels - 1:
            if verbose:
                print('AMV location lat,lon,alt ({0},{1},{2}) min vector difference at top or bottom'.format(amv_lat, amv_lon, amv_prs))
        else:
            p3 = fcst_prs[imin+1]
            p1 = fcst_prs[imin-1]
            v3 = VecDiff[imin+1]
            v1 = VecDiff[imin-1]

            if p3 >= TopPress and v1 != v2 and v2 != v3:  # check not collinear, above allowed region
                numer = ((p2 - p1) * (p2 - p1) * (v2 - v3)) - ((p2 - p3) * (p2 - p3) * (v2 - v1))
                denom = ((p2 - p1) * (v2 - v3)) - ((p2 - p3) * (v2 - v1))
                if denom != 0:  # a zero denominator has no parabola vertex; keep p2
                    SatwindBestFitPress = p2 - (0.5 * (numer / denom))
                # Written as a negated range test so that a NaN result is rejected too.
                if not (p3 <= SatwindBestFitPress <= p1):
                    if (verbose):
                        print('Best Fit not found between two pressure layers')
                        print('SatwindBestFitPress {0} p1 {1} p2 {2} p3 {3} imin {4}'.format(SatwindBestFitPress, p1, p2, p3, imin))
                    SatwindBestFitPress = p2
            else:
                if verbose:
                    print('Top of parabolic fit window above TopPress, or co-linear vector difference profile')
                    print('SatwindBestFitPress {0} p1 {1} p2 {2} p3 {3} imin {4}'.format(SatwindBestFitPress, p1, p2, p3, imin))
                    print('Vector difference profile p1 {0} p2 {1} p3 {2}'.format(v1, v2, v3))

            # Find best fit U and V by linear interpolation.
            if p2 != SatwindBestFitPress:
                if p2 < SatwindBestFitPress:
                    LevBelow = imin - 1
                    LevAbove = imin
                    Prop = (SatwindBestFitPress - p1) / (p2 - p1)
                else:
                    LevBelow = imin
                    LevAbove = imin + 1
                    Prop = (SatwindBestFitPress - p2) / (p3 - p2)

                SatwindBestFitU = fcst_uwind[LevBelow] * (1.0 - Prop) + fcst_uwind[LevAbove] * Prop
                SatwindBestFitV = fcst_vwind[LevBelow] * (1.0 - Prop) + fcst_vwind[LevAbove] * Prop

        # best fit success and good constraints -----
        SatwindGoodConstraint = 1
        flag = 0

        # Check for overall good agreement
        # NOTE(review): this measures the forecast profile against the best-fit
        # wind (not the AMV), so the minimum is zero whenever the best-fit wind
        # sits exactly on a model level -- confirm this is the intended test.
        vdiff_best = np.sqrt((fcst_uwind - SatwindBestFitU) ** 2 + (fcst_vwind - SatwindBestFitV) ** 2)
        min_vdiff_best = min(vdiff_best)
        if min_vdiff_best > 4.0:
            SatwindGoodConstraint = 0
            flag = 2

        # Check for a substantial secondary local minimum or if local minimum is too broad:
        # any level more than 100 hPa away that is within 2 m/s of the minimum fails the fit.
        if SatwindGoodConstraint == 1:
            mm = np.where(fcst_prs > (SatwindBestFitPress + 100))[0]
            nn = np.where(fcst_prs < (SatwindBestFitPress - 100))[0]
            if (np.sum(vdiff_best[mm] < (min_vdiff_best + 2.0)) + np.sum(vdiff_best[nn] < (min_vdiff_best + 2.0))) > 0:
                SatwindGoodConstraint = 0
                flag = 1

        if SatwindGoodConstraint == 1:
            bfit_prs = SatwindBestFitPress
            bfit_u = SatwindBestFitU
            bfit_v = SatwindBestFitV
        else:
            bfit_prs = undef
            bfit_u = undef
            bfit_v = undef

        if verbose:
            print('*** AMV best-fit *********************************')
            print('AMV -> p/minspd/maxspd: {0} {1} {2}'.format(amv_prs,SatwindMinSpeed,SatwindMaxSpeed))
            # print('Bestfit -> p1,p2,p3,v1,v2,v3: {0} {1} {2} {3} {4} {5}'.format(p1,p2,p3,v1,v2,v3))
            print('Bestfit -> pbest,bfu,bfv,amvu,amvv,bgu,bgv: {0} {1} {2} {3} {4} {5} {6}'.format(
            SatwindBestFitPress,SatwindBestFitU,SatwindBestFitV,amv_uwind,amv_vwind,fcst_uwind[imin],fcst_vwind[imin]))
            print('Good Constraint: {0}'.format(SatwindGoodConstraint))
            print('Minimum Vector Difference: {0}'.format(min_vdiff_best))
            print('Vector Difference Profile: ')
            print(vdiff_best)
            print('Pressure Profile: ')
            print(fcst_prs)

            if (abs(SatwindBestFitU - amv_uwind) > 4.0) or (abs(SatwindBestFitV - amv_vwind) > 4.0):
                print('U Diff: {0}'.format(abs(SatwindBestFitU - amv_uwind)))
                print('V Diff: {0}'.format(abs(SatwindBestFitV - amv_vwind)))

        bf_data = (bfit_u, bfit_v, bfit_prs, flag)

        return bf_data
    
    
    def best_fit_altitude(amv_spd, amv_dir, amv_alt, amv_lat, amv_lon, fcst_spd, fcst_dir, fcst_alt,
                          amv_uwind=None, amv_vwind=None,
                          fcst_uwind=None, fcst_vwind=None,
                          alt_top=25000.0, alt_bot=0.0, bf_half_width=1000.0, constraint_half_width=1000.0, verbose=False):
        """Find the model best-fit altitude for one AMV.

        Altitude analogue of ``best_fit``: the best-fit altitude is where the
        vector difference between the AMV and the model profile is smallest,
        searched over model levels strictly inside
        ``amv_alt +/- bf_half_width`` (clipped to ``[alt_bot, alt_top]``) and
        refined with ``minimize_quadratic`` through the neighbouring levels.

        Args:
            amv_spd, amv_dir: AMV speed and direction (deg); only used when
                ``amv_uwind`` is None.
            amv_alt: AMV altitude, same units as ``fcst_alt`` (presumably
                metres, judging by the defaults -- confirm with callers).
            amv_lat, amv_lon: AMV location, used only in verbose messages.
            fcst_spd, fcst_dir: model speed and direction (deg) per level; only
                used when ``fcst_uwind`` is None.
            fcst_alt: model altitude per level (NumPy array).
            amv_uwind, amv_vwind: optional precomputed AMV components; both are
                computed from speed/direction when ``amv_uwind`` is None.
            fcst_uwind, fcst_vwind: optional precomputed model components; when
                given, ``fcst_spd`` is recomputed from them.
            alt_top: AMVs above this altitude are rejected; also caps the window.
            alt_bot: lower cap of the search window.
            bf_half_width: half width of the search window around ``amv_alt``.
            constraint_half_width: levels farther than this from the best-fit
                altitude are checked for a competing minimum.
            verbose: print diagnostics when True.

        Returns:
            Tuple ``(u, v, alt, flag)``.  ``flag`` is 0 for a good fit, 1 when
            a level outside ``constraint_half_width`` is within 2 m/s of the
            minimum, 2 when the minimum exceeds 4 m/s, and 3 when no usable
            level was found.  ``u``, ``v`` and ``alt`` are NaN unless flag is 0.
        """
        fcst_num_levels = fcst_alt.shape[0]

        flag = 3
        bf_tup = (np.nan, np.nan, np.nan, flag)
        if amv_alt > alt_top:
            if verbose:
                print('AMV location lat,lon,alt ({0},{1},{2}) is higher than max altitude {3}'.format(amv_lat, amv_lon, amv_alt, alt_top))
            return bf_tup

        # Calculate the height +/- bf_half_width from the AMV height
        # (the lower bound must use the AMV altitude, not its latitude).
        alt_max = min((amv_alt + bf_half_width), alt_top)
        alt_min = max((amv_alt - bf_half_width), alt_bot)

        # 1d array of indices to consider for best fit height
        kk = np.where((fcst_alt > alt_min) & (fcst_alt < alt_max))
        if kk[0].size == 0:
            if verbose:
                print('AMV location lat,lon,alt ({0},{1},{2}) failed to find fcst alt around AMV'.format(amv_lat,amv_lon,amv_alt))
            return bf_tup

        # Compute U and V for both AMVs and forecast
        if amv_uwind is None:
            amv_uwind = -amv_spd * np.sin((np.pi/180.0)*amv_dir)
            amv_vwind = -amv_spd * np.cos((np.pi/180.0)*amv_dir)
        if fcst_uwind is None:
            fcst_uwind = -fcst_spd * np.sin((np.pi/180.0)*fcst_dir)
            fcst_vwind = -fcst_spd * np.cos((np.pi/180.0)*fcst_dir)
        else:
            fcst_spd = np.sqrt(fcst_uwind**2 + fcst_vwind**2)

        # Diagnostic field: Find the model minimum speed and maximum speed within bf_half_width of the AMV.
        if verbose:
            sat_wind_min_spd = min(fcst_spd[kk])
            sat_wind_max_spd = max(fcst_spd[kk])

        # Calculate the vector difference between the AMV and model background at all levels.
        vec_diff = np.sqrt((amv_uwind - fcst_uwind) ** 2 + (amv_vwind - fcst_vwind) ** 2)

        # Find the model level of best-fit altitude, from the minimum vector difference.
        # The lookup is restricted to the windowed values so that an equal value
        # at a level outside the search window cannot be selected.
        window_diff = vec_diff[kk]
        min_vec_diff = min(window_diff)
        in_window = np.where(window_diff == min_vec_diff)[0]
        if in_window.size == 0:  # e.g. the minimum is NaN
            if verbose:
                print('AMV location lat,lon,alt ({0},{1},{2}) failed to find min vector difference in layers around AMV'.format(amv_lat,amv_lon,amv_alt))
            return bf_tup

        imin = kk[0][in_window[0]]  # index into the full profile

        a2 = fcst_alt[imin]
        v2 = vec_diff[imin]

        sat_wind_best_fit_alt = a2
        sat_wind_best_fit_u = fcst_uwind[imin]
        sat_wind_best_fit_v = fcst_vwind[imin]

        if imin == 0 or imin == (fcst_num_levels - 1):
            if verbose:
                print('AMV location lat,lon,alt ({0},{1},{2}) min vector difference at top/bottom'.format(amv_lat,amv_lon,amv_alt))
        else:
            a1 = fcst_alt[imin-1]
            v1 = vec_diff[imin-1]
            a3 = fcst_alt[imin+1]
            v3 = vec_diff[imin+1]

            # check not co-linear and below alt_max
            if a3 < alt_max and v1 != v2 and v2 != v3:
                sat_wind_best_fit_alt = minimize_quadratic(a1, a2, a3, v1, v2, v3)
                # Negated range test so that a NaN result is rejected as well.
                if not (a1 <= sat_wind_best_fit_alt <= a3):
                    sat_wind_best_fit_alt = a2

            # Find best fit U and V by linear interpolation:
            if a2 != sat_wind_best_fit_alt:
                if a2 < sat_wind_best_fit_alt:
                    lev_below = imin
                    lev_above = imin + 1
                    prop = (sat_wind_best_fit_alt - a2) / (a3 - a2)
                else:
                    lev_below = imin - 1
                    lev_above = imin
                    prop = (sat_wind_best_fit_alt - a1) / (a2 - a1)

                sat_wind_best_fit_u = fcst_uwind[lev_below] * (1.0 - prop) + fcst_uwind[lev_above] * prop
                sat_wind_best_fit_v = fcst_vwind[lev_below] * (1.0 - prop) + fcst_vwind[lev_above] * prop

        # check constraints
        good_constraint = 1
        flag = 0
        # Check for overall good agreement: want vdiff less than 4 m/s
        vdiff = np.sqrt((fcst_uwind - sat_wind_best_fit_u) ** 2 + (fcst_vwind - sat_wind_best_fit_v) ** 2)
        min_vdiff = min(vdiff)
        if min_vdiff > 4.0:
            good_constraint = 0
            flag = 2

        # Check for a substantial secondary local minimum or if local minimum is too broad
        if good_constraint == 1:
            mm = np.where(fcst_alt > (sat_wind_best_fit_alt + constraint_half_width))[0]
            nn = np.where(fcst_alt < (sat_wind_best_fit_alt - constraint_half_width))[0]
            if (np.sum(vdiff[mm] < (min_vdiff + 2.0)) + np.sum(vdiff[nn] < (min_vdiff + 2.0))) > 0:
                good_constraint = 0
                flag = 1

        if good_constraint == 1:
            bf_tup = (sat_wind_best_fit_u, sat_wind_best_fit_v, sat_wind_best_fit_alt, flag)
        else:
            bf_tup = (np.nan, np.nan, np.nan, flag)

        if verbose:
            print('*** AMV best-fit ***')
            print('AMV -> p/minspd/maxspd: {0} {1} {2}'.format(amv_alt,sat_wind_min_spd,sat_wind_max_spd))
            # print('Bestfit -> p1,p2,p3,v1,v2,v3: {0} {1} {2} {3} {4} {5}'.format(p1,p2,p3,v1,v2,v3))
            print('Bestfit -> pbest,bfu,bfv,amvu,amvv,bgu,bgv: {0} {1} {2} {3} {4} {5} {6}'.format(
            sat_wind_best_fit_alt,sat_wind_best_fit_u,sat_wind_best_fit_v,amv_uwind,amv_vwind,fcst_uwind[imin],fcst_vwind[imin]))
            print('Good Constraint: {0}'.format(good_constraint))
            print('Minimum Vector Difference: {0}'.format(vec_diff[imin]))
            print('Vector Difference Profile: ')
            print(vdiff)
            print('Altitude Profile: ')
            print(fcst_alt)

            if (abs(sat_wind_best_fit_u - amv_uwind) > 4.0) or (abs(sat_wind_best_fit_v - amv_vwind) > 4.0):
                print('U Diff: {0}'.format(abs(sat_wind_best_fit_u - amv_uwind)))
                print('V Diff: {0}'.format(abs(sat_wind_best_fit_v - amv_vwind)))

        return bf_tup
    
    
    def get_press_bin_ranges(lop, hip, bin_size=100):
        """Return consecutive ``[low, high]`` bin edges covering ``lop`` to ``hip``.

        Only whole bins are produced: the number of bins is
        ``int((hip - lop) / bin_size)``, so any remainder at the upper end is
        left uncovered.

        Args:
            lop: lower edge of the first bin.
            hip: upper limit of the range.
            bin_size: width of each bin.

        Returns:
            list of two-element lists ``[low, high]``, lowest bin first.
        """
        count = int((hip - lop) / bin_size)
        return [[lop + k*bin_size, lop + k*bin_size + bin_size] for k in range(count)]
    
    
    def analyze_best_fit_all(amvs_dct, bfs_dct, bin_size=200):
        """Compute per-bin MAD and bias of the best-fit differences for every center.

        For each key in ``amv_centers_list`` the binned differences returned by
        ``analyze_best_fit_single`` are reduced with ``np.average`` (bias) and
        ``np.average(np.abs(...))`` (MAD).

        Args:
            amvs_dct: mapping of center key to AMV array (read with ``.get``).
            bfs_dct: mapping of center key to best-fit results (read with ``.get``).
            bin_size: forwarded to ``analyze_best_fit_single``.

        Returns:
            ``(x_values, pres_mad, pres_bias, spd_mad, spd_bias, dir_mad, dir_bias)``
            where ``x_values`` comes from the last center processed and each of
            the other items is a list holding one per-bin list per center.
        """
        pres_mad, pres_bias = [], []
        spd_mad, spd_bias = [], []
        dir_mad, dir_bias = [], []

        for center in amv_centers_list:
            x_values, bin_pres, _, bin_spd, _, bin_dir, _ = analyze_best_fit_single(
                amvs_dct.get(center), bfs_dct.get(center), bin_size=bin_size)

            bins = range(len(x_values))

            pres_mad.append([np.average(np.abs(bin_pres[k])) for k in bins])
            pres_bias.append([np.average(bin_pres[k]) for k in bins])
            spd_mad.append([np.average(np.abs(bin_spd[k])) for k in bins])
            spd_bias.append([np.average(bin_spd[k]) for k in bins])
            dir_mad.append([np.average(np.abs(bin_dir[k])) for k in bins])
            dir_bias.append([np.average(bin_dir[k]) for k in bins])

        return x_values, pres_mad, pres_bias, spd_mad, spd_bias, dir_mad, dir_bias
    
    
    def analyze_best_fit_single(amvs, bfs, bin_size=200):
        """Print AMV-minus-best-fit statistics and bin the differences by AMV pressure.

        ``bfs`` is converted to an array whose columns are read as: 0 = row
        index into ``amvs``, 1 and 2 = best-fit u and v, 3 = best-fit pressure.
        Pressure, speed and direction differences are printed as MAD / bias /
        rms summaries and then grouped into pressure bins spanning 50 to 1050
        in steps of ``bin_size``.

        Returns:
            ``(x_values, bin_pres, num_pres, bin_spd, num_spd, bin_dir, num_dir)``:
            the bin centers, the binned pressure / speed / direction
            differences, and the number of samples in each bin.
        """
        bf_arr = np.array(bfs)
        rows = bf_arr[:, 0].astype(np.int32)

        # --- pressure -------------------------------------------------------
        amv_pres = amvs[rows, amv_pres_idx]
        best_pres = bf_arr[:, 3]
        pres_diff = amv_pres - best_pres
        print('********************************************************')
        print('num of best fits: ', best_pres.shape[0])
        print('press, MAD: ', np.average(np.abs(pres_diff)))
        print('press, bias: ', np.average(pres_diff))
        pres_std = np.std(pres_diff)
        pres_mean = np.mean(pres_diff)
        print('press bias/rms ', pres_mean, np.sqrt(pres_mean**2 + pres_std**2))
        print('------------------------------------------')

        edges = get_press_bin_ranges(50, 1050, bin_size=bin_size)
        bin_pres = bin_data_by(pres_diff, amv_pres, edges)

        # --- speed ----------------------------------------------------------
        raw_spd = amvs[rows, amv_spd_idx]
        raw_dir = amvs[rows, amv_dir_idx]
        best_spd, best_dir = spd_dir_from_uv(bf_arr[:, 1], bf_arr[:, 2])

        spd_diff = raw_spd * units('m/s') - best_spd
        print('spd, MAD: ', np.average(np.abs(spd_diff)))
        print('spd, bias: ', np.average(spd_diff))
        spd_mean = np.mean(spd_diff)
        spd_std = np.std(spd_diff)
        print('spd MAD/bias/rms: ', np.average(np.abs(spd_diff)), spd_mean, np.sqrt(spd_mean**2 + spd_std**2))
        print('-----------------')
        bin_spd = bin_data_by(spd_diff, amv_pres, edges)

        # --- direction ------------------------------------------------------
        dir_diff = direction_difference(raw_dir * units.deg, best_dir)
        print('dir, MAD: ', np.average(np.abs(dir_diff)))
        print('dir bias: ', np.average(dir_diff))
        print('-------------------------------------')
        bin_dir = bin_data_by(dir_diff, amv_pres, edges)

        # --- vector difference ----------------------------------------------
        amv_u, amv_v = uv_from_spd_dir(raw_spd, raw_dir)
        du = amv_u - (bf_arr[:, 1] * units('m/s'))
        dv = amv_v - (bf_arr[:, 2] * units('m/s'))

        vec = np.sqrt(du**2 + dv**2)
        vec_mean = np.mean(vec)
        vec_std = np.std(vec)
        print('VD bias/rms: ', vec_mean, np.sqrt(vec_mean**2 + vec_std**2))
        print('------------------------------------------')

        # Bin centers and per-bin sample counts.
        nbins = len(edges)
        x_values = [np.average(edges[k]) for k in range(nbins)]
        num_pres = [bin_pres[k].shape[0] for k in range(nbins)]
        num_spd = [bin_spd[k].shape[0] for k in range(nbins)]
        num_dir = [bin_dir[k].shape[0] for k in range(nbins)]

        return x_values, bin_pres, num_pres, bin_spd, num_spd, bin_dir, num_dir
    
    
    def bin_data_by(a, b, bin_ranges):
        """Group the values of ``a`` by which bin the matching value of ``b`` falls in.

        A sample belongs to a bin when ``low <= b < high`` (upper edge exclusive).

        Args:
            a: array of values to group.
            b: array used to decide the bin of each element of ``a``.
            bin_ranges: sequence of ``[low, high]`` pairs.

        Returns:
            list with one array of selected ``a`` values per bin, in bin order.
        """
        return [a[(b >= edges[0]) & (b < edges[1])] for edges in bin_ranges]
    
    
    # earth relative
    def uv_from_spd_dir(speed, direction, speed_unit=None, dir_unit=None):
        """Convert wind speed and direction to u, v components via MetPy.

        Args:
            speed: wind speed magnitude(s), multiplied by ``speed_unit``.
            direction: wind direction magnitude(s), multiplied by ``dir_unit``.
            speed_unit: unit for ``speed``; ``units('m/s')`` when None.
            dir_unit: unit for ``direction``; ``units.deg`` when None.

        Returns:
            ``(u, v)`` as returned by ``mpcalc.wind_components``.
        """
        spd_unit = units('m/s') if speed_unit is None else speed_unit
        ang_unit = units.deg if dir_unit is None else dir_unit

        return mpcalc.wind_components(speed * spd_unit, direction * ang_unit)
    
    
    def spd_dir_from_uv(u, v, speed_unit=None):
        """Convert u, v wind components to speed and direction via MetPy.

        Args:
            u: u-component magnitude(s), multiplied by ``speed_unit``.
            v: v-component magnitude(s), multiplied by ``speed_unit``.
            speed_unit: unit for both components; ``units('m/s')`` when None.

        Returns:
            ``(speed, direction)`` from ``mpcalc.wind_speed`` and
            ``mpcalc.wind_direction``.
        """
        unit = units('m/s') if speed_unit is None else speed_unit
        u_q = u * unit
        v_q = v * unit

        wind_dir = mpcalc.wind_direction(u_q, v_q)
        wind_spd = mpcalc.wind_speed(u_q, v_q)

        return wind_spd, wind_dir
    
    
    # earth relative to projection
    def uv_proj_from_uv_earth(u_earth, v_earth, lons, lats, proj, units=None):
        """Transform earth-relative wind components into a map projection.

        The vectors are transformed with ``proj.transform_vectors`` using
        ``ccrs.PlateCarree()`` as the source coordinate system.

        Args:
            u_earth: earth-relative u components.
            v_earth: earth-relative v components.
            lons: longitudes of the vectors.
            lats: latitudes of the vectors.
            proj: target projection object providing ``transform_vectors``.
            units: accepted for backward compatibility and ignored; the
                transform does not use it.

        Returns:
            ``(u_proj, v_proj)`` as returned by ``proj.transform_vectors``.
        """
        # The ``units`` parameter shadows the module-level MetPy unit registry,
        # so the previous default handling (``units = units('m/s')`` when None)
        # tried to call None and raised TypeError on every default call.  The
        # value was never used afterwards, so it is simply not touched here.
        u_proj, v_proj = proj.transform_vectors(ccrs.PlateCarree(), lons, lats, u_earth, v_earth)

        return u_proj, v_proj
    
    
    # wind direction conversion
    def azimuth_to_met(direction):
        """Rotate direction(s) by 180 degrees, wrapping values of 360 or more.

        The input is not modified: the previous in-place ``direction += 180``
        left the caller's array shifted by 180 as a side effect.

        Args:
            direction: scalar or array-like of directions in degrees.

        Returns:
            NumPy array holding ``direction + 180``, with 360 subtracted from
            every element that would otherwise be 360 or greater.
        """
        shifted = np.asarray(direction) + 180
        return np.where(shifted < 360, shifted, shifted - 360.0)