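    # amv_raob.py
    # Utilities for matching AMVs (atmospheric motion vectors) against RAOB
    # (radiosonde) observations, and for extracting collocated GOES-16 ABI L1b
    # image chips and GFS model profiles for the matched winds.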
    import datetime
    from datetime import timezone
    import bisect
    import glob
    import numpy as np
    import subprocess
    import xarray as xr
    from util.goesr_l1b import PugFile, PugL1bTools
    from util.setup import home_dir
    from pysolar.solar import get_altitude
    from scipy.interpolate import interp2d
    from metpy import *
    import os  # needed for the os.path calls below
    import random
    import re
    
    
    goes16_directory = '/arcdata/goes/grb/goes16'  # /year/date/abi/L1b/RadC
    goes_date_format = '%Y%j%H'
    dt_str_pat = r'_s\d{11}'  # raw string to avoid an invalid escape sequence warning
    dir_fmt = '%Y_%m_%d_%j'
    
    gfs_directory = '/apollo/cloud/Ancil_Data/clavrx_ancil_data/dynamic/gfs/'
    gfs_date_format = '%y%m%d'
    
    h4_to_h5_path = home_dir + '/h4toh5convert'
    
    data_dir = '/scratch/long/rink'
    converted_file_dir = data_dir + '/gfs_h5'
    CACHE_GFS = True
    
    goes_cache_dir = data_dir + '/goes16'
    CACHE_GOES = True
    COPY_GOES = True
    
    fmt = '%Y%j%H'
    
    matchup_dict_glbl = None
    
    geoloc_2km = None
    geoloc_1km = None
    geoloc_hkm = None
    
    TRIPLET = False  # if True, also extract chips from the preceding/following scans
    CONV3D = False   # treated the same as TRIPLET when selecting files
    
    H08_lat_range = [-78, 78]
    H08_lon_range = [62, 218]
    
    
    amv_hdr_list = ['lat', 'lon', 'tbox', 'sbox', 'spd', 'dir', 'pres', 'lowl', 'mspd', 'mdir',
                    'alb', 'corr', 'tmet', 'perr', 'qi', 'cqi', 'qif']
    
    raob_hdr_list = ['pres', 'temp', 'dir', 'spd']
    
    raob_spd_idx = raob_hdr_list.index('spd')
    raob_dir_idx = raob_hdr_list.index('dir')
    amv_lon_idx = amv_hdr_list.index('lon')
    amv_lat_idx = amv_hdr_list.index('lat')
    amv_pres_idx = amv_hdr_list.index('pres')
    raob_pres_idx = raob_hdr_list.index('pres')
    amv_spd_idx = amv_hdr_list.index('spd')
    amv_dir_idx = amv_hdr_list.index('dir')
    amv_prs_idx = amv_pres_idx  # duplicate alias; same column as amv_pres_idx
    amv_qi_idx = amv_hdr_list.index('qi')
    amv_cqi_idx = amv_hdr_list.index('cqi')
    amv_qif_idx = amv_hdr_list.index('qif')
    
    
    def get_matchups(filename):
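        # Parse a whitespace-delimited matchup text file (first line is a header).
        # Most columns are stored as value*100; rows with a QIF score below 50 are
        # discarded. Returns a list of per-observation tuples.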
        header = None
        data = []
        with open(filename) as file:
            for line in file:
                if header is None:
                    header = line
                    continue
                toks = line.split()
    
                dto = datetime.datetime.strptime(toks[1], fmt).replace(tzinfo=timezone.utc)
                timestmp = dto.timestamp()
    
                lat = float(toks[2])/100
                lon = float(toks[3])/100
                lbfp = float(toks[4])/100
                amvpre = float(toks[5])/100
                spd = float(toks[6])/100
                wind_dir = float(toks[7])/100  # renamed from 'dir': avoid shadowing the builtin
                shear = float(toks[8])/100
                satzen = float(toks[9])/100
                qif = float(toks[11])/100
                phase = float(toks[13])/100
                ctype = float(toks[14])/100  # renamed from 'type': avoid shadowing the builtin

                if qif < 50:  # discard low-quality winds
                    continue

                tup = (timestmp, lat, lon, lbfp, amvpre, spd, wind_dir, shear, satzen, qif, phase, ctype)
                data.append(tup)

        return data
    
    
    def merge(matchups, time_to_matchups):
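        # Group matchup tuples by their timestamp (tuple element 0), merging into
        # an existing time->list dict if one is given. Times are kept sorted, but
        # the returned dict is re-keyed by sequential index.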
        if matchups is None:
            return time_to_matchups
    
        if time_to_matchups is None:
            tup = matchups[0]
            time = tup[0]
            tup_s = []
            time_to_matchups = {time : tup_s}
    
        time_s = list(time_to_matchups.keys())
        tup_s = list(time_to_matchups.values())
    
        for tup in matchups:
            time = tup[0]
            try:
                idx = time_s.index(time)
                tup_s[idx].append(tup)
            except ValueError:
                idx = bisect.bisect(time_s, time)
                time_s.insert(idx, time)
                lst = []
                lst.append(tup)
                tup_s.insert(idx, lst)
    
        # Re-key by sequential index (time_s stays sorted); downstream code
        # treats the keys as opaque identifiers.
        time_to_matchups = {}
        for idx in range(len(time_s)):
            time_to_matchups[idx] = tup_s[idx]
    
        return time_to_matchups
    
    
    def matchup_dict_to_nd(matchup_dict, shuffle=True, reduce=None):
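        # Convert each key's list of matchup tuples to a 2-D ndarray, optionally
        # shuffled and/or subsampled by taking every 'reduce'-th row.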
        time_keys = matchup_dict.keys()
        time_to_nd = {}
    
        for key in time_keys:
            obs_lst = matchup_dict.get(key)
            nda = np.array(obs_lst)
            if shuffle:
                np.random.shuffle(nda)
            if reduce is not None:
                nda = nda[::reduce, :]
            time_to_nd[key] = nda
    
        return time_to_nd
    
    
    def get_num_samples(matchup_dict, keys=None):
        cnt = 0
    
        if keys is None:
            time_keys = matchup_dict.keys()
        else:
            time_keys = keys
    
        for key in time_keys:
            cnt += matchup_dict[key].shape[0]
        return cnt
    
    
    # CONUS
    lon_space = np.linspace(-140, -40, 401)
    lat_space = np.linspace(15, 55, 161)
    
    
    def spatial_filter(matchup_dict, shuffle=True, reduce=None):
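        # Thin each time key's observations to at most one per 0.25 degree grid
        # cell over the CONUS domain defined by lon_space/lat_space, then
        # optionally shuffle and subsample what remains.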
        keys = matchup_dict.keys()
        filterd_dict = {}
        grd_x_hi = lon_space.shape[0] - 1
        grd_y_hi = lat_space.shape[0] - 1
    
        for key in keys:
    
            grd_bins = np.full((lat_space.shape[0], lon_space.shape[0]), False)
    
            obs = matchup_dict.get(key)
            lons = obs[:, 2]
            lats = obs[:, 1]
    
            lon_idxs = np.searchsorted(lon_space, lons)
            lat_idxs = np.searchsorted(lat_space, lats)
    
            keep_obs = []
            for i in range(len(lon_idxs)):
                lon_idx = lon_idxs[i]
                lat_idx = lat_idxs[i]
                if lon_idx < 0 or lon_idx > grd_x_hi:
                    continue
                if lat_idx < 0 or lat_idx > grd_y_hi:
                    continue
                if not grd_bins[lat_idx, lon_idx]:
                    grd_bins[lat_idx, lon_idx] = True
                    keep_obs.append(obs[i])
    
            if len(keep_obs) == 0:
                continue
    
            keep_obs = np.array(keep_obs)
            if shuffle:
                np.random.seed(seed=42)
                shuffle_idxs = np.random.permutation(keep_obs.shape[0])
                keep_obs = keep_obs[shuffle_idxs]
    
            if reduce is not None:
                keep_obs = keep_obs[::reduce, :]
    
            filterd_dict[key] = keep_obs
    
        return filterd_dict
    
    
    def get_matchup_dict(filename, merge_with=None, filter_by_space=True, reduce=None, skip=None):
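        # Full pipeline: parse the matchup file, group by time (optionally merging
        # with an existing dict), optionally keep every 'skip'-th time key, convert
        # to ndarrays, and optionally thin spatially.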
        matchup = get_matchups(filename)
        matchup_dict = merge(matchup, merge_with)
    
        if skip is not None:
            skp_dct = {}
            skp_keys = list(matchup_dict.keys())
            random.shuffle(skp_keys)
            skp_keys = skp_keys[::skip]
            for key in skp_keys:
                skp_dct[key] = matchup_dict[key]
            matchup_dict = skp_dct
    
        matchup_dict = matchup_dict_to_nd(matchup_dict)
    
        if filter_by_space:
            matchup_dict = spatial_filter(matchup_dict, reduce=reduce)
    
        return matchup_dict
    
    
    NUM_SECTIONS = 4
    
    
    def split_matchup(matchup_dict, perc=0.2):
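        # Split by time key: roughly perc of the keys (every skip-th one) go to
        # the test dict, the remainder to the train dict.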
        time_keys = matchup_dict.keys()
        num_datetimes = len(time_keys)
        keys_list = list(time_keys)
    
        num_test = int(num_datetimes * perc)
    
        skip = int(num_datetimes / num_test)
    
        tst_keys = keys_list[::skip]
        train_dict = {}
        test_dict = {}
    
        tst_set = set(tst_keys)
        trn_set = (set(keys_list)).difference(tst_set)
        trn_keys = list(trn_set)
    
        for key in tst_keys:
            test_dict[key] = matchup_dict.get(key)
    
        for key in trn_keys:
            train_dict[key] = matchup_dict.get(key)
    
        return train_dict, test_dict
    
    
    def flatten(matchup_dict, shuffle=True):
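        # Stack all per-time ndarrays into a single 2-D array, optionally shuffled
        # with a fixed seed for reproducibility.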
        time_keys = matchup_dict.keys()
    
        dat_list = []
        cnt = 0
        for key in time_keys:
            dat = matchup_dict[key]
            num_dat = dat.shape[0]
            dat_list.append(dat)
            cnt += num_dat
    
        data = np.vstack(dat_list)
    
        if shuffle:
            np.random.seed(seed=42)
            shuffle_idxs = np.random.permutation(cnt)
            data = data[shuffle_idxs]
    
        return data
    
    
    def shuffle_dict(in_dict):  # Outer most level only
        shuffled_dict = {}
        keys = in_dict.keys()
        num_times = len(keys)
        np.random.seed(seed=42)
        shuffle_idxs = np.random.permutation(num_times)
    
        key_list = list(keys)
    
        for idx in shuffle_idxs:
            key = key_list[idx]
            shuffled_dict[key] = in_dict[key]
    
        return shuffled_dict
    
    
    def filter_by_type(matchup_dict, error=None, phase=None, ctype=None):
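        # Keep AMVs whose |AMV pressure - best-fit pressure| (columns 4 and 3)
        # exceeds 'error', and/or whose cloud phase (column 10) or cloud type
        # (column 11) matches; phase takes precedence over ctype.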
        time_keys = list(matchup_dict.keys())
        filtered_dict = {}
    
        for key in time_keys:
            amvs = matchup_dict[key]
            keep = None
            if error is not None:
                mae = np.abs(amvs[:, 4] - amvs[:, 3])
                keep = mae > error
                amvs = amvs[keep, :]
            if phase is not None:
                keep = amvs[:, 10] == phase
                amvs = amvs[keep, :]
            elif ctype is not None:
                keep = amvs[:, 11] == ctype
                amvs = amvs[keep, :]
    
            if keep is not None:
                if np.sum(keep) == 0:
                    continue
    
            filtered_dict[key] = amvs
    
        return filtered_dict
    
    
    def analyze(matchup_dict):
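        # Print counts and MAE/bias/std of AMV minus best-fit pressure differences,
        # overall and stratified by cloud phase (column 10) and cloud type (column 11).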
        time_keys = list(matchup_dict.keys())
    
        hgt_diff = []
        hgt_diff_i = []
        hgt_diff_w = []
        hgt_diff_scw = []
        hgt_diff_mphz = []
        hgts = []
        hgts_i = []
        hgts_w = []
        hgts_scw = []
        hgts_mphz = []
        num_amvs = 0
        num_ice = 0
        num_water = 0
        num_sc_water = 0
        num_mphz = 0
    
        hgt_diff_tice = []
        hgts_tice = []
        num_tice = 0
    
        hgt_diff_cirrus = []
        hgts_cirrus = []
        num_cirrus = 0
        hgt_diff_ovrlap = []
        hgts_ovrlap = []
        num_ovrlap = 0
    
        for key in time_keys:
            amvs = matchup_dict[key]
            # num_amvs += amvs.shape[0]
            d = amvs[:, 4] - amvs[:, 3]
    
            # keep = np.abs(d) > 70
            keep = np.abs(d) >= 0
            d = d[keep]
            amvs = amvs[keep, :]
            num_amvs += amvs.shape[0]
    
            hgt_diff.append(d)
            hgts.append(amvs[:, 3])
    
            ice = amvs[:, 10] == 4
            hgt_diff_i.append(d[ice])
            hgts_i.append(amvs[ice, 3])
            num_ice += np.sum(ice)
    
            water = amvs[:, 10] == 1
            hgt_diff_w.append(d[water])
            hgts_w.append(amvs[water, 3])
            num_water += np.sum(water)
    
            water = amvs[:, 10] == 2
            hgt_diff_scw.append(d[water])
            hgts_scw.append(amvs[water, 3])
            num_sc_water += np.sum(water)
    
            mphz = amvs[:, 10] == 3
            hgt_diff_mphz.append(d[mphz])
            hgts_mphz.append(amvs[mphz, 3])
            num_mphz += np.sum(mphz)
    
            tice = amvs[:, 11] == 5
            hgt_diff_tice.append(d[tice])
            hgts_tice.append(amvs[tice, 3])
            num_tice += np.sum(tice)
    
            cirrus = amvs[:, 11] == 6
            hgt_diff_cirrus.append(d[cirrus])
            hgts_cirrus.append(amvs[cirrus, 3])
            num_cirrus += np.sum(cirrus)
    
            ovrlap = amvs[:, 11] == 7
            hgt_diff_ovrlap.append(d[ovrlap])
            hgts_ovrlap.append(amvs[ovrlap, 3])
            num_ovrlap += np.sum(ovrlap)
    
        hgt_diff = np.concatenate(hgt_diff)
        hgt_diff_i = np.concatenate(hgt_diff_i)
        hgt_diff_w = np.concatenate(hgt_diff_w)
        hgt_diff_scw = np.concatenate(hgt_diff_scw)
        hgt_diff_tice = np.concatenate(hgt_diff_tice)
        hgt_diff_cirrus = np.concatenate(hgt_diff_cirrus)
        hgt_diff_mphz = np.concatenate(hgt_diff_mphz)
        hgt_diff_ovrlap = np.concatenate(hgt_diff_ovrlap)
    
        hgts = np.concatenate(hgts)
        hgts_i = np.concatenate(hgts_i)
        hgts_w = np.concatenate(hgts_w)
        hgts_scw = np.concatenate(hgts_scw)
        hgts_mphz = np.concatenate(hgts_mphz)
        hgts_tice = np.concatenate(hgts_tice)
        hgts_cirrus = np.concatenate(hgts_cirrus)
        hgts_ovrlap = np.concatenate(hgts_ovrlap)
    
        print('total: ', num_amvs)
        print('ice: ', num_ice, num_ice/num_amvs)
        print('water: ', num_water, num_water/num_amvs)
        print('sc water: ', num_sc_water, num_sc_water/num_amvs)
        print('mixed phase: ', num_mphz, num_mphz/num_amvs)
        print('tice: ', num_tice, num_tice/num_amvs)
        print('cirrus: ', num_cirrus, num_cirrus/num_amvs)
        print('num ovrlap: ', num_ovrlap, num_ovrlap/num_amvs)
    
        print('all MAE/BIAS/STD', np.average(np.abs(hgt_diff)), np.average(hgt_diff), np.std(hgt_diff))
        print('ice MAE/BIAS/STD', np.average(np.abs(hgt_diff_i)), np.average(hgt_diff_i), np.std(hgt_diff_i))
        print('water MAE/BIAS/STD', np.average(np.abs(hgt_diff_w)), np.average(hgt_diff_w), np.std(hgt_diff_w))
        print('sc water MAE/BIAS/STD', np.average(np.abs(hgt_diff_scw)), np.average(hgt_diff_scw), np.std(hgt_diff_scw))
        print('tice MAE/BIAS/STD', np.average(np.abs(hgt_diff_tice)), np.average(hgt_diff_tice), np.std(hgt_diff_tice))
        print('cirrus MAE/BIAS/STD', np.average(np.abs(hgt_diff_cirrus)), np.average(hgt_diff_cirrus), np.std(hgt_diff_cirrus))
        print('mphz MAE/BIAS/STD', np.average(np.abs(hgt_diff_mphz)), np.average(hgt_diff_mphz), np.std(hgt_diff_mphz))
        print('ovrlap MAE/BIAS/STD', np.average(np.abs(hgt_diff_ovrlap)), np.average(hgt_diff_ovrlap), np.std(hgt_diff_ovrlap))
    
        return hgts, hgt_diff, hgt_diff_i, hgt_diff_w, hgt_diff_scw, hgt_diff_tice, hgt_diff_cirrus, hgt_diff_mphz, hgt_diff_ovrlap
    
    
    def get_time_tuple_utc(timestamp):
        dt_obj = datetime.datetime.fromtimestamp(timestamp, timezone.utc)
        return dt_obj, dt_obj.timetuple()
    
    
    def get_bounding_goes16_files(timestamp, ch_str, file_time_span=5):
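        # Scan the GOES-16 CONUS L1b archive within +/- 20 minutes of 'timestamp'
        # for the given channel and return (file, time) pairs for the scan that
        # brackets the timestamp plus its left and right temporal neighbors.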
        dt_obj, time_tup = get_time_tuple_utc(timestamp)
    
        dt_obj_0 = dt_obj - datetime.timedelta(minutes=20)
        dt_obj_1 = dt_obj + datetime.timedelta(minutes=20)
    
        yr_dir_0 = str(dt_obj_0.timetuple().tm_year)
        date_dir_0 = dt_obj_0.strftime(dir_fmt)
        date_str_0 = dt_obj_0.strftime(goes_date_format)
    
        yr_dir_1 = str(dt_obj_1.timetuple().tm_year)
        date_dir_1 = dt_obj_1.strftime(dir_fmt)
        date_str_1 = dt_obj_1.strftime(goes_date_format)
    
        files_path_0 = goes16_directory + '/' + yr_dir_0 + '/' + date_dir_0 + '/abi' + '/L1b' + '/RadC'
        files_path_1 = goes16_directory + '/' + yr_dir_1 + '/' + date_dir_1 + '/abi' + '/L1b' + '/RadC'
    
        flist_0 = glob.glob(files_path_0 + '/OR_ABI-L1b-RadC-??C' + ch_str + '_G16_s' + date_str_0 + '*.nc')
        flist_1 = []
        if date_str_0 != date_str_1:
            flist_1 = glob.glob(files_path_1 + '/OR_ABI-L1b-RadC-??C' + ch_str + '_G16_s' + date_str_1 + '*.nc')
    
        flist = flist_0 + flist_1
        if len(flist) == 0:
            return None, None, None, None, None, None
    
        ftimes = []
        for pname in flist:
            fname = os.path.split(pname)[1]
            tstr = re.search(dt_str_pat, fname).group()[2:]
            dto = datetime.datetime.strptime(tstr, goes_date_format+'%M').replace(tzinfo=timezone.utc)
            ftimes.append(dto.timestamp())
    
        tarr = np.array(ftimes)
        sidxs = tarr.argsort()
    
        farr = np.array(flist)
        farr = farr[sidxs]
        ftimes = tarr[sidxs]
    
        ftimes_n = ftimes + file_time_span * 60
        iC = -1
        for k, t in enumerate(ftimes):
            if t <= timestamp < ftimes_n[k]:
                iC = k
                break
    
        if iC >= 0:
            iL = iC - 1
            iR = iC + 1
        else:
            return None, None, None, None, None, None
    
        fList = farr.tolist()
    
        if iL < 0 or iR >= len(ftimes):
            return None, None, fList[iC], ftimes[iC], None, None
        else:
            return fList[iL], ftimes[iL], fList[iC], ftimes[iC], fList[iR], ftimes[iR]
    
    
    def get_bounding_gfs_files(timestamp):
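        # Find the 12-hour GFS forecast files whose valid times bound 'timestamp';
        # returns (file_left, time_left, file_right, time_right).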
        dt_obj, time_tup = get_time_tuple_utc(timestamp)
        dt_obj = dt_obj - datetime.timedelta(hours=12)
        date_str = dt_obj.strftime(gfs_date_format)
        dt_obj0 = datetime.datetime.strptime(date_str, gfs_date_format).replace(tzinfo=timezone.utc)
        dt_obj1 = dt_obj0 + datetime.timedelta(days=1)
        date_str_1 = dt_obj1.strftime(gfs_date_format)
    
        flist_0 = glob.glob(gfs_directory+'gfs.'+date_str+'??_F012.hdf')
        flist_1 = glob.glob(gfs_directory+'gfs.'+date_str_1+'00_F012.hdf')
        if len(flist_0) == 0:
            return None, None, None, None
        filelist = flist_0
        if len(flist_1) > 0:
            filelist = flist_0 + flist_1
    
        ftimes = []
        for pname in filelist:  # TODO: make better with regular expressions (someday)
            fname = os.path.split(pname)[1]
            toks = fname.split('.')
            tstr = toks[1].split('_')[0]
            dto = datetime.datetime.strptime(tstr, gfs_date_format+'%H').replace(tzinfo=timezone.utc)
            dto = dto + datetime.timedelta(hours=12)
            ftimes.append(dto.timestamp())
    
        tarr = np.array(ftimes)
        sidxs = tarr.argsort()
    
        farr = np.array(filelist)
        farr = farr[sidxs]
        ftimes = tarr[sidxs]
        idxs = np.arange(len(filelist))
    
        above = ftimes >= timestamp
        if not above.any():
            return None, None, None, None
    
        below = ftimes <= timestamp
        if not below.any():
            return None, None, None, None
        iL = idxs[below].max()
        iR = iL + 1
    
        fList = farr.tolist()
    
        if timestamp == ftimes[iL]:
            return fList[iL], ftimes[iL], None, None
        else:
            return fList[iL], ftimes[iL], fList[iR], ftimes[iR]
    
    
    def get_profile(xr_dataset, fld_name, lons, lats, lon360=True, do_norm=False):
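        # Linearly interpolate a GFS 3-D field to the given (lat, lon) points at
        # every pressure level; longitudes are converted to 0-360 to match the grid.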
        if lon360:
            lons = np.where(lons < 0, lons + 360, lons) # convert -180,180 to 0,360
    
        plevs = xr_dataset['pressure levels']
        lat_coords = np.linspace(-90, 90, xr_dataset.fakeDim1.size)
        lon_coords = np.linspace(0, 359.5, xr_dataset.fakeDim2.size)
    
        fld = xr_dataset[fld_name]
        fld = fld.assign_coords(fakeDim2=lon_coords, fakeDim1=lat_coords, fakeDim0=plevs)
    
        dim2 = xr.DataArray(lons, dims='k')
        dim1 = xr.DataArray(lats, dims='k')
    
        intrp_fld = fld.interp(fakeDim1=dim1, fakeDim2=dim2, fakeDim0=plevs, method='linear')
        intrp_fld = intrp_fld.values
        if do_norm:
            intrp_fld = normalize(intrp_fld, fld_name)
    
        return intrp_fld
    
    
    def get_profile_multi(xr_dataset, fld_name_s, lons, lats, lon360=True):
        if lon360:
            lons = np.where(lons < 0, lons + 360, lons) # convert -180,180 to 0,360
    
        plevs = xr_dataset['pressure levels']
        lat_coords = np.linspace(-90, 90, xr_dataset.fakeDim1.size)
        lon_coords = np.linspace(0, 359.5, xr_dataset.fakeDim2.size)
    
        fld_s = []
        for fld_name in fld_name_s:
            fld = xr_dataset[fld_name]
            fld = fld.assign_coords(fakeDim2=lon_coords, fakeDim1=lat_coords, fakeDim0=plevs)
            fld_s.append(fld)
    
        fld = xr.concat(fld_s, 'fld_dim')
    
        dim2 = xr.DataArray(lons, dims='k')
        dim1 = xr.DataArray(lats, dims='k')
    
        intrp_fld = fld.interp(fakeDim1=dim1, fakeDim2=dim2, fakeDim0=plevs, method='linear')
    
        return intrp_fld
    
    
    def get_scalar(xr_dataset, fld_name, lons, lats, lon360=True):
        if lon360:
            lons = np.where(lons < 0, lons + 360, lons) # convert -180,180 to 0,360
    
        lat_coords = np.linspace(-90, 90, xr_dataset.fakeDim1.size)
        lon_coords = np.linspace(0, 359.5, xr_dataset.fakeDim2.size)
    
        fld = xr_dataset[fld_name]
        fld = fld.assign_coords(fakeDim2=lon_coords, fakeDim1=lat_coords)
    
        dim1 = xr.DataArray(lats, dims='k')
        dim2 = xr.DataArray(lons, dims='k')
    
        intrp_fld = fld.interp(fakeDim1=dim1, fakeDim2=dim2, method='linear')
    
        return intrp_fld
    
    
    def convert_file(fpath_hdf4):
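        # Convert a GFS HDF4 file to HDF5 with the external h4toh5convert tool,
        # reusing a previously converted copy if present.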
        fname_h4 = os.path.split(fpath_hdf4)[1]
        fname_h5 = os.path.splitext(fname_h4)[0] + '.h5'
        fpath_h5 = converted_file_dir + '/' + fname_h5
    
        # check if we have it already
        if os.path.exists(fpath_h5):
            return fpath_h5
    
        subprocess.call([h4_to_h5_path, fpath_hdf4, fpath_h5])
    
        return fpath_h5
    
    
    def get_interpolated_profile(ds_0, ds_1, time0, time1, fld_name, time, lons, lats, do_norm=False):
    
        # get_profile returns plain ndarrays, so wrap them as DataArrays to let
        # xarray concatenate and interpolate along a time coordinate
        prof_0 = xr.DataArray(get_profile(ds_0, fld_name, lons, lats))
        prof_1 = xr.DataArray(get_profile(ds_1, fld_name, lons, lats))

        prof = xr.concat([prof_0, prof_1], 'time')
        prof = prof.assign_coords(time=[time0, time1])
    
        intrp_prof = prof.interp(time=time, method='linear')
    
        intrp_prof = intrp_prof.values
    
        if do_norm:
            intrp_prof = normalize(intrp_prof, fld_name)
    
        return intrp_prof
    
    
    def get_interpolated_profile_multi(ds_0, ds_1, time0, time1, fld_name_s, time, lons, lats, do_norm_s):
    
        prof_0 = get_profile_multi(ds_0, fld_name_s, lons, lats)
        prof_1 = get_profile_multi(ds_1, fld_name_s, lons, lats)
    
        prof = xr.concat([prof_0, prof_1], 'time')
        prof = prof.assign_coords(time=[time0, time1])
    
        intrp_prof = prof.interp(time=time, method='linear')
    
        intrp_prof = intrp_prof.values
    
        intrp_prof = normalize_multi(intrp_prof, fld_name_s, do_norm_s)
    
        return intrp_prof
    
    
    def get_interpolated_scalar(ds_0, ds_1, time0, time1, fld_name, time, lons, lats, do_norm=False):
    
        vals_0 = get_scalar(ds_0, fld_name, lons, lats)
        vals_1 = get_scalar(ds_1, fld_name, lons, lats)
    
        vals = xr.concat([vals_0, vals_1], 'time')
        vals = vals.assign_coords(time=[time0, time1])
    
        intrp_vals = vals.interp(time=time, method='linear')
    
        intrp_vals = intrp_vals.values
    
        if do_norm:
            intrp_vals = normalize(intrp_vals, fld_name)
        return intrp_vals
    
    
    path_to_reader_dict = {}
    
    MAX_FILES_OPEN = 50
    
    
    def get_reader(path):
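        # Simple cache of open PugL1bTools readers; emptied wholesale once more
        # than MAX_FILES_OPEN paths are held.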
        if len(path_to_reader_dict) > MAX_FILES_OPEN:
            path_to_reader_dict.clear()
    
        rdr = path_to_reader_dict.get(path)
        if rdr is None:
            rdr = PugL1bTools(path)
            path_to_reader_dict[path] = rdr
        return rdr
    
    
    def clear_readers():
        path_to_reader_dict.clear()
    
    
    def get_images(lons, lats, timestamp, channel_list, half_width, step, do_norm=False, daynight='ANY'):
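        # Extract BT/reflectance image chips centered on each (lon, lat) from the
        # GOES-16 file bracketing 'timestamp' (plus temporal neighbors if TRIPLET
        # or CONV3D). Chips failing the day/night test or falling off the grid are
        # filled with a dummy ramp image. Returns (data, data_l, data_r, idxs).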
        global geoloc_2km, geoloc_1km, geoloc_hkm
    
        dt_obj = datetime.datetime.fromtimestamp(timestamp, timezone.utc)
        data = []
        data_l = []
        data_r = []
        idxs = []
    
        l_s = []
        c_s = []
    
        for ch_idx, ch in enumerate(channel_list):
            img_width = int(2 * half_width[ch_idx] / step[ch_idx])
            dum = np.arange(img_width * img_width, dtype=np.float32)
            dum = dum.reshape((img_width, img_width))
    
            fL, tL, fC, tC, fR, tR = get_bounding_goes16_files(timestamp, ch)
            if (fL is None) or (fR is None) or (fC is None):
                return None, None, None, None
    
            # copy from archive to local drive
            local_path = fC
            if COPY_GOES:
                local_path = goes_cache_dir + '/' + os.path.split(fC)[1]
                if not os.path.exists(local_path):
                    subprocess.call(['cp', fC, goes_cache_dir])
                    subprocess.call(['chmod', 'u+w', local_path])
    
            # Navigation, convert to BT/REFL
            try:
                pug_l1b_c = get_reader(local_path)
            except Exception as exc:
                print(exc)
                return None, None, None, None
    
            bt_or_refl = None
            if pug_l1b_c.bt_or_refl == 'bt':
                bt_or_refl = pug_l1b_c.bt
            elif pug_l1b_c.bt_or_refl == 'refl':
                bt_or_refl = pug_l1b_c.refl
    
            if TRIPLET or CONV3D:
                # copy from archive to local drive
                local_path_l = fL
                if COPY_GOES:
                    local_path_l = goes_cache_dir + '/' + os.path.split(fL)[1]
                    if not os.path.exists(local_path_l):
                        subprocess.call(['cp', fL, goes_cache_dir])
                        subprocess.call(['chmod', 'u+w', local_path_l])
    
                # Navigation, convert to BT/REFL
                try:
                    pug_l1b_l = get_reader(local_path_l)
                except Exception as exc:
                    print(exc)
                    return None, None, None, None
    
                if pug_l1b_l.bt_or_refl == 'bt':
                    bt_or_refl_l = pug_l1b_l.bt
                elif pug_l1b_l.bt_or_refl == 'refl':
                    bt_or_refl_l = pug_l1b_l.refl
    
                # copy from archive to local drive
                local_path_r = fR
                if COPY_GOES:
                    local_path_r = goes_cache_dir + '/' + os.path.split(fR)[1]
                    if not os.path.exists(local_path_r):
                        subprocess.call(['cp', fR, goes_cache_dir])
                        subprocess.call(['chmod', 'u+w', local_path_r])
    
                # Navigation, convert to BT/REFL
                try:
                    pug_l1b_r = get_reader(local_path_r)
                except Exception as exc:
                    print(exc)
                    return None, None, None, None
    
                if pug_l1b_r.bt_or_refl == 'bt':
                    bt_or_refl_r = pug_l1b_r.bt
                elif pug_l1b_r.bt_or_refl == 'refl':
                    bt_or_refl_r = pug_l1b_r.refl
    
            if daynight != 'ANY':
                if step[ch_idx] == 1:
                    if geoloc_2km is None:
                        geoloc_2km = pug_l1b_c.geo
                    glons = geoloc_2km.lon
                    glats = geoloc_2km.lat
                elif step[ch_idx] == 4:
                    if geoloc_hkm is None:
                        geoloc_hkm = pug_l1b_c.geo
                    glons = geoloc_hkm.lon
                    glats = geoloc_hkm.lat
                elif step[ch_idx] == 2:
                    if geoloc_1km is None:
                        geoloc_1km = pug_l1b_c.geo
                    glons = geoloc_1km.lon
                    glats = geoloc_1km.lat
    
                alt_fnc = get_solar_alt_func(pug_l1b_c, glons, glats, dt_obj)
    
            images = []
            images_l = []
            images_r = []
            for idx, lon in enumerate(lons):
                lat = lats[idx]
                if ch_idx == 0:
                    l, c = pug_l1b_c.lat_lon_to_l_c(lat, lon)
                    l_s.append(l)
                    c_s.append(c)
                else:
                    l = l_s[idx]
                    c = c_s[idx]
                if daynight == 'NIGHT':
                    salt = alt_fnc(l, c)
                    if salt[0] > 3.0:
                        images.append(dum)
                        continue
                elif daynight == 'DAY':
                    salt = alt_fnc(l, c)
                    if salt[0] < 3.0:
                        images.append(dum)
                        continue
    
                l_a = l - half_width[ch_idx]
                l_b = l + half_width[ch_idx]
                c_a = c - half_width[ch_idx]
                c_b = c + half_width[ch_idx]
                if (l_a >= 0 and l_b < pug_l1b_c.shape[0]) and (c_a >= 0 and c_b < pug_l1b_c.shape[1]):
                    img = bt_or_refl[l_a:l_b:step[ch_idx], c_a:c_b:step[ch_idx]]
                    if do_norm:
                        img = prepare_image(img, ch)
                    images.append(img)
                    if ch_idx == 0:
                        idxs.append(idx)
                    if TRIPLET or CONV3D:
                        img = bt_or_refl_l[l_a:l_b:step[ch_idx], c_a:c_b:step[ch_idx]]
                        if do_norm:
                            img = prepare_image(img, ch)
                        images_l.append(img)
    
                        img = bt_or_refl_r[l_a:l_b:step[ch_idx], c_a:c_b:step[ch_idx]]
                        if do_norm:
                            img = prepare_image(img, ch)
                        images_r.append(img)
                else:
                    images.append(dum)
                    if TRIPLET or CONV3D:
                        images_l.append(dum)
                        images_r.append(dum)
    
            images = np.array(images)
            data.append(images)
            if TRIPLET or CONV3D:
                images_l = np.array(images_l)
                images_r = np.array(images_r)
                data_l.append(images_l)
                data_r.append(images_r)
    
            # remove file from local cache
            if COPY_GOES:
                if not CACHE_GOES:
                    subprocess.call(['rm', local_path])
                    if TRIPLET or CONV3D:
                        subprocess.call(['rm', local_path_l])
                        subprocess.call(['rm', local_path_r])
    
        return np.array(data), np.array(data_l), np.array(data_r), np.array(idxs)
    
    
    def copy_abi(matchup_dict, channel_list):
        for key in matchup_dict.keys():
            obs = matchup_dict.get(key)
            timestamp = obs[0][0]
            for ch_idx, ch in enumerate(channel_list):
                f0, t0, f1, t1, f2, t2 = get_bounding_goes16_files(timestamp, ch)
                if (f0 is None) or (f1 is None):
                    continue

                # copy from archive to local drive
                local_path = goes_cache_dir + '/' + os.path.split(f0)[1]
                if not os.path.exists(local_path):
                    subprocess.call(['cp', f0, goes_cache_dir])
                    subprocess.call(['chmod', 'u+w', local_path])
    
    
    def prepare_image(image, ch):
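        # Standardize a chip with the channel's precomputed mean/std and zero out
        # pixels outside the channel's valid range.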
        from deeplearning.cloudheight import abi_valid_range, abi_mean, abi_std
        vld_rng = abi_valid_range.get(ch)
    
        valid = np.logical_and(image >= vld_rng[0], image <= vld_rng[1])
        mean = abi_mean.get(ch)
        std = abi_std.get(ch)
    
        image -= mean
        image /= std
    
        not_valid = np.invert(valid)
        image[not_valid] = 0
    
        return image
    
    
    def normalize(data, param):
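        # Same standardization for a GFS field, looked up by parameter name; data
        # is returned unchanged when no mean/std entry exists.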
        from deeplearning.cloudheight import mean_std_dict, valid_range_dict
    
        if mean_std_dict.get(param) is None:
            return data
    
        mean, std = mean_std_dict.get(param)
        data -= mean
        data /= std
    
        vld_rng = valid_range_dict.get(param)
        if vld_rng is None:
            return data
    
        shape = data.shape
        data = data.flatten()
        valid = np.logical_and(data >= vld_rng[0], data <= vld_rng[1])
        not_valid = np.invert(valid)
        data[not_valid] = 0
        data = np.reshape(data, shape)
    
        return data
    
    
    def normalize_multi(data, param_s, do_norm_s):
        from deeplearning.cloudheight import mean_std_dict, valid_range_dict
    
        for idx, param in enumerate(param_s):
            if not do_norm_s[idx]:
                continue
    
            if mean_std_dict.get(param) is None:
                continue
    
            mean, std = mean_std_dict.get(param)
            data[idx] -= mean
            data[idx] /= std
    
            vld_rng = valid_range_dict.get(param)
            if vld_rng is None:
                continue
    
            tmp = data[idx]
    
            shape = tmp.shape
            tmp = tmp.flatten()
            valid = np.logical_and(tmp >= vld_rng[0], tmp <= vld_rng[1])
            not_valid = np.invert(valid)
            tmp[not_valid] = 0
            tmp = np.reshape(tmp, shape)
    
            data[idx] = tmp
    
        return data
    
    
    def get_abi_mean_std(matchup_dict, ch_str, valid_range, step=100):
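        # Estimate the mean and standard deviation of valid BT/reflectance values
        # for one channel, subsampling every 'step'-th pixel of each matchup scene.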
    
        keys = matchup_dict.keys()
        image_s = []
    
        for key in keys:
            obs = matchup_dict.get(key)
            if obs is None:
                continue
            timestamp = obs[0][0]
            dt_obj, ttup = get_time_tuple_utc(timestamp)
            print(ttup.tm_year, ttup.tm_mon, ttup.tm_mday)
            f0, t0, f1, t1, f2, t2 = get_bounding_goes16_files(timestamp, ch_str)
            if (f0 is None) or (f1 is None) or (f2 is None):
                continue
    
            # copy from archive to local drive
            local_path = goes_cache_dir + '/' + os.path.split(f0)[1]
            if not os.path.exists(local_path):
                subprocess.call(['cp', f0, goes_cache_dir])
                subprocess.call(['chmod', 'u+w', local_path])
    
            # Navigation, convert to BT/REFL
            try:
                pug = PugFile(local_path, 'Rad')
                pug_l1b = PugL1bTools(local_path)
            except Exception as exc:
                print(exc)
                continue
    
            if ch_str == '02':
                corner_lons = []
                corner_lats = []
                elems = pug.shape[1]
                lines = pug.shape[0]
                lons = pug.geo.lon
                lats = pug.geo.lat
                lon_a = lons[0,0]
                lat_a = lats[0,0]
                corner_lons.append(lon_a)
                corner_lats.append(lat_a)
                lon_b = lons[0,elems-1]
                lat_b = lats[0,elems-1]
                corner_lons.append(lon_b)
                corner_lats.append(lat_b)
                lon_c = lons[lines-1, 0]
                lat_c = lats[lines-1, 0]
                corner_lons.append(lon_c)
                corner_lats.append(lat_c)
                lon_d = lons[lines-1, elems-1]
                lat_d = lats[lines-1, elems-1]
                corner_lons.append(lon_d)
                corner_lats.append(lat_d)
    
                for i in range(4):
                    lon = corner_lons[i]
                    lat = corner_lats[i]
                    alt = get_altitude(lat, lon, dt_obj)
                    if alt < 10:
                        print(ttup.tm_year, ttup.tm_mon, ttup.tm_mday)
                        continue
    
            bt_or_refl = None
            if pug.bt_or_refl == 'bt':
                bt_or_refl = pug_l1b.bt
            elif pug.bt_or_refl == 'refl':
                bt_or_refl = pug_l1b.refl
    
            bt_or_refl = bt_or_refl[::step, ::step]
            valid = np.logical_and(bt_or_refl >= valid_range[0], bt_or_refl <= valid_range[1])
            bt_or_refl = bt_or_refl[valid]
            image_s.append(bt_or_refl)
    
            # remove file from local cache
            if not CACHE_GOES:
                subprocess.call(['rm', local_path])
    
        if len(image_s) == 0:
            return None, None
    
        values = np.concatenate(image_s)
        mean = np.mean(values)
        values -= mean
        std = np.std(values)
    
        return mean, std
    
    
    NUM_VERT_LEVS = 26
    
    
    def get_gfs_mean_std(matchup_dict, fld_name, valid_range, step=2):
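        # Per-pressure-level mean and standard deviation of a GFS 3-D field over
        # the matchup times; returns (NUM_VERT_LEVS, 1)-shaped arrays.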
    
        keys = matchup_dict.keys()
    
        fld_at_levels = [[] for i in range(NUM_VERT_LEVS)]
    
        for key in keys:
            obs = matchup_dict.get(key)
            if obs is None:
                continue
            timestamp = obs[0][0]
            _, ttup = get_time_tuple_utc(timestamp)
            print(ttup.tm_year, ttup.tm_mon, ttup.tm_mday)
            gfs_0, time_0, gfs_1, time_1 = get_bounding_gfs_files(timestamp)
            if gfs_0 is None:
                continue
            gfs_0 = convert_file(gfs_0)
    
            try:
                xr_dataset = xr.open_dataset(gfs_0)
            except Exception as exc:
                print(exc)
                return None, None  # callers unpack (mean, std)
    
            fld = xr_dataset[fld_name]
            fld_3d = fld.values
    
            for i in range(NUM_VERT_LEVS):
                fld_i = fld_3d[::step, ::step, i]
                valid = np.logical_and(fld_i >= valid_range[0], fld_i <= valid_range[1])
                fld_i = fld_i[valid]
                fld_at_levels[i].append(fld_i)
    
        mean = []
        std = []
        for i in range(NUM_VERT_LEVS):
            fld_at_levels[i] = (np.concatenate(fld_at_levels[i])).flatten()
            m = np.mean(fld_at_levels[i])
            fld_at_levels[i] -= m
            mean.append(m)
            std.append(np.std(fld_at_levels[i]))
    
        mean = np.array(mean)
        std = np.array(std)
    
        mean = np.reshape(mean, (NUM_VERT_LEVS, 1))
        std = np.reshape(std, (NUM_VERT_LEVS, 1))
    
        return mean, std
    
    
    def get_gfs_mean_std_2d(matchup_dict, fld_name, valid_range, step=2):
    
        keys = matchup_dict.keys()
    
        fld_s = []
    
        for key in keys:
            obs = matchup_dict.get(key)
            if obs is None:
                continue
            timestamp = obs[0][0]
            _, ttup = get_time_tuple_utc(timestamp)
            print(ttup.tm_year, ttup.tm_mon, ttup.tm_mday)
            gfs_0, time_0, gfs_1, time_1 = get_bounding_gfs_files(timestamp)
            if gfs_0 is None:
                continue
            gfs_0 = convert_file(gfs_0)
    
            try:
                xr_dataset = xr.open_dataset(gfs_0)
            except Exception as exc:
                print(exc)
                return None, None  # callers unpack (mean, std)
    
            fld = xr_dataset[fld_name]
            fld_2d = fld.values
    
            fld_i = fld_2d[::step, ::step]
            valid = np.logical_and(fld_i >= valid_range[0], fld_i <= valid_range[1])
            fld_i = fld_i[valid]
            fld_s.append(fld_i)
    
        fld_s = (np.concatenate(fld_s)).flatten()
    
        mean = np.mean(fld_s)
        fld_s -= mean
        std = np.std(fld_s)
    
        return mean, std
    
    
    PROC_BATCH_SIZE = 10
    BATCH_SIZE = 256
    
    
    def run_mean_std(filename, matchup_dict=None):
        if matchup_dict is None:
            matchup = get_matchups(filename)
            matchup_dict = merge(matchup, None)
            matchup_dict = matchup_dict_to_nd(matchup_dict)
        train_dict, valid_dict = split_matchup(matchup_dict, perc=0.1)
    
        channels = ['14', '08', '02']
        valid_range = [[150, 350], [150, 350], [0, 100]]
        step = [100, 100, 400]
    
        f = open('/home/rink/stats.txt', 'w')
    
        for ch_idx, ch in enumerate(channels):
            mean, std = get_abi_mean_std(valid_dict, ch, valid_range[ch_idx], step[ch_idx])
            print(ch, ': ', mean, std)
            f.write('channel: %s, %f, %f\n' % (ch, mean, std))
    
        mean, std = get_gfs_mean_std(valid_dict, 'temperature', [150, 350])
        f.write('gfs temperature: \n')
        for i in range(NUM_VERT_LEVS):
            f.write('%f, %f\n' % (mean[i], std[i]))
    
        f.close()
    
        return matchup_dict
    
    
    def get_solar_alt_func(pug, glons, glats, dt_obj):
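        # Compute solar altitude on a coarse grid (every 100th line/column) and
        # return a 2-D cubic interpolant over (line, column) image coordinates.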
        elems = pug.shape[1]
        lines = pug.shape[0]
    
        gx = np.arange(elems)[::100]
        gy = np.arange(lines)[::100]
    
        glons = glons[gy]
        glons = glons[:, gx]
        glats = glats[gy]
        glats = glats[:, gx]
    
        glons = glons.flatten()
        glats = glats.flatten()
    
        slr_alt = []
        for i in range(len(glons)):
            alt = get_altitude(glats[i], glons[i], dt_obj)
            slr_alt.append(alt)
        nd_alt = np.array(slr_alt)
        alt_fnc = interp2d(gy, gx, nd_alt, kind='cubic')
    
        return alt_fnc
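

    # Minimal usage sketch (hypothetical file path; assumes the archive layout
    # configured at the top of this module):
    #   matchup_dict = get_matchup_dict('/path/to/amv_raob_matchups.txt', reduce=2)
    #   train_dict, test_dict = split_matchup(matchup_dict, perc=0.2)
    #   analyze(train_dict)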