import bisect
import datetime
import glob
import os
import random
import re
import subprocess
from datetime import timezone

import numpy as np
import xarray as xr
from pysolar.solar import get_altitude
from scipy.interpolate import interp2d

from util.goesr_l1b import PugFile, PugL1bTools
from util.setup import home_dir


goes16_directory = '/arcdata/goes/grb/goes16'  # /year/date/abi/L1b/RadC
goes_date_format = '%Y%j%H'
dt_str_pat = r'_s\d{11}'  # start-time token in ABI filenames: _sYYYYJJJHHMM
dir_fmt = '%Y_%m_%d_%j'

gfs_directory = '/apollo/cloud/Ancil_Data/clavrx_ancil_data/dynamic/gfs/'
gfs_name_pattern = r'gfs\.(\d{8})_F\d{3}\.(?:hdf|h5)'  # captures yymmddhh
gfs_date_format = '%y%m%d'

h4_to_h5_path = home_dir + '/h4toh5convert'

data_dir = '/scratch/long/rink'
converted_file_dir = data_dir + '/gfs_h5'
CACHE_GFS = True

goes_cache_dir = data_dir + '/goes16'
CACHE_GOES = True
COPY_GOES = True

fmt = '%Y%j%H'

matchup_dict_glbl = None

geoloc_2km = None
geoloc_1km = None
geoloc_hkm = None

TRIPLET = False
CONV3D = False

H08_lat_range = [-78, 78]
H08_lon_range = [62, 218]


amv_hdr_list = ['lat', 'lon', 'tbox', 'sbox', 'spd', 'dir', 'pres', 'lowl', 'mspd', 'mdir',
                'alb', 'corr', 'tmet', 'perr', 'qi', 'cqi', 'qif']

raob_hdr_list = ['pres', 'temp', 'dir', 'spd']

raob_spd_idx = raob_hdr_list.index('spd')
raob_dir_idx = raob_hdr_list.index('dir')
amv_lon_idx = amv_hdr_list.index('lon')
amv_lat_idx = amv_hdr_list.index('lat')
amv_pres_idx = amv_hdr_list.index('pres')
raob_pres_idx = raob_hdr_list.index('pres')
amv_spd_idx = amv_hdr_list.index('spd')
amv_dir_idx = amv_hdr_list.index('dir')
amv_prs_idx = amv_hdr_list.index('pres')
amv_qi_idx = amv_hdr_list.index('qi')
amv_cqi_idx = amv_hdr_list.index('cqi')
amv_qif_idx = amv_hdr_list.index('qif')


def get_matchups(filename):
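    """Parse an AMV/RAOB matchup text file, skipping the header line.

    Fields are whitespace-delimited integers scaled by 100; rows with a
    QI-forecast (qif) below 50 are dropped. Returns a list of tuples:
    (timestamp, lat, lon, lbfp, amvpre, spd, wind_dir, shear, satzen,
    qif, phase, cld_type).
    """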
    header = None
    data = []
    with open(filename) as file:
        for line in file:
            if header is None:
                header = line
                continue
            toks = line.split()

            dto = datetime.datetime.strptime(toks[1], fmt).replace(tzinfo=timezone.utc)
            timestmp = dto.timestamp()

            lat = float(toks[2])/100
            lon = float(toks[3])/100
            lbfp = float(toks[4])/100
            amvpre = float(toks[5])/100
            spd = float(toks[6])/100
            wind_dir = float(toks[7])/100
            shear = float(toks[8])/100
            satzen = float(toks[9])/100
            qif = float(toks[11])/100
            phase = float(toks[13])/100
            cld_type = float(toks[14])/100

            if qif < 50:
                continue

            tup = (timestmp, lat, lon, lbfp, amvpre, spd, wind_dir, shear,
                   satzen, qif, phase, cld_type)
            data.append(tup)

    return data


def merge(matchups, time_to_matchups):
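    """Merge a list of matchup tuples into `time_to_matchups`, grouping by
    the timestamp in position 0 and returning a dict of time-ordered lists
    keyed by sequential index."""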
    if matchups is None:
        return time_to_matchups

    if time_to_matchups is None:
        # seed with the first matchup's time and an empty list; the loop
        # below appends every tuple, including the first one
        time_to_matchups = {matchups[0][0]: []}

    time_s = list(time_to_matchups.keys())
    tup_s = list(time_to_matchups.values())

    for tup in matchups:
        time = tup[0]
        try:
            idx = time_s.index(time)
            tup_s[idx].append(tup)
        except ValueError:
            idx = bisect.bisect(time_s, time)
            time_s.insert(idx, time)
            lst = []
            lst.append(tup)
            tup_s.insert(idx, lst)

    # rebuild keyed by sequential index in ascending time order; each tuple
    # still carries its timestamp in position 0
    time_to_matchups = {}
    for idx in range(len(time_s)):
        time_to_matchups[idx] = tup_s[idx]

    return time_to_matchups


def matchup_dict_to_nd(matchup_dict, shuffle=True, reduce=None):
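    """Convert each time key's list of matchup tuples to a 2-D numpy array,
    optionally shuffling rows and keeping every `reduce`-th row."""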
    time_keys = matchup_dict.keys()
    time_to_nd = {}

    for key in time_keys:
        obs_lst = matchup_dict.get(key)
        nda = np.array(obs_lst)
        if shuffle:
            np.random.shuffle(nda)
        if reduce is not None:
            nda = nda[::reduce, :]
        time_to_nd[key] = nda

    return time_to_nd


def get_num_samples(matchup_dict, keys=None):
    cnt = 0

    if keys is None:
        time_keys = matchup_dict.keys()
    else:
        time_keys = keys

    for key in time_keys:
        cnt += matchup_dict[key].shape[0]
    return cnt


# CONUS 0.25-degree sampling grid used by spatial_filter
lon_space = np.linspace(-140, -40, 401)
lat_space = np.linspace(15, 55, 161)


def spatial_filter(matchup_dict, shuffle=True, reduce=None):
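    """Thin each time key's observations to at most one per cell of the
    0.25-degree CONUS grid defined by lon_space/lat_space above."""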
    keys = matchup_dict.keys()
    filterd_dict = {}
    grd_x_hi = lon_space.shape[0] - 1
    grd_y_hi = lat_space.shape[0] - 1

    for key in keys:

        grd_bins = np.full((lat_space.shape[0], lon_space.shape[0]), False)

        obs = matchup_dict.get(key)
        lons = obs[:, 2]
        lats = obs[:, 1]

        lon_idxs = np.searchsorted(lon_space, lons)
        lat_idxs = np.searchsorted(lat_space, lats)

        keep_obs = []
        for i in range(len(lon_idxs)):
            lon_idx = lon_idxs[i]
            lat_idx = lat_idxs[i]
            if lon_idx < 0 or lon_idx > grd_x_hi:
                continue
            if lat_idx < 0 or lat_idx > grd_y_hi:
                continue
            if not grd_bins[lat_idx, lon_idx]:
                grd_bins[lat_idx, lon_idx] = True
                keep_obs.append(obs[i])

        if len(keep_obs) == 0:
            continue

        keep_obs = np.array(keep_obs)
        if shuffle:
            np.random.seed(seed=42)
            shuffle_idxs = np.random.permutation(keep_obs.shape[0])
            keep_obs = keep_obs[shuffle_idxs]

        if reduce is not None:
            keep_obs = keep_obs[::reduce, :]

        filterd_dict[key] = keep_obs

    return filterd_dict


def get_matchup_dict(filename, merge_with=None, filter_by_space=True, reduce=None, skip=None):
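    """Parse a matchup file into a dict of time key -> observation array,
    optionally merging into `merge_with`, keeping every `skip`-th time key,
    and applying the CONUS spatial filter with optional row reduction."""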
    matchup = get_matchups(filename)
    matchup_dict = merge(matchup, merge_with)

    if skip is not None:
        skp_dct = {}
        skp_keys = list(matchup_dict.keys())
        random.shuffle(skp_keys)
        skp_keys = skp_keys[::skip]
        for key in skp_keys:
            skp_dct[key] = matchup_dict[key]
        matchup_dict = skp_dct

    matchup_dict = matchup_dict_to_nd(matchup_dict)

    if filter_by_space:
        matchup_dict = spatial_filter(matchup_dict, reduce=reduce)

    return matchup_dict


NUM_SECTIONS = 4


def split_matchup(matchup_dict, perc=0.2):
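    """Split a matchup dict by time key: roughly `perc` of the keys (every
    `skip`-th one) go to the test dict, the remainder to the train dict."""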
    time_keys = matchup_dict.keys()
    num_datetimes = len(time_keys)
    keys_list = list(time_keys)

    # guard against zero test keys (and the resulting division by zero)
    num_test = max(1, int(num_datetimes * perc))

    skip = max(1, int(num_datetimes / num_test))

    tst_keys = keys_list[::skip]
    train_dict = {}
    test_dict = {}

    tst_set = set(tst_keys)
    trn_set = (set(keys_list)).difference(tst_set)
    trn_keys = list(trn_set)

    for key in tst_keys:
        test_dict[key] = matchup_dict.get(key)

    for key in trn_keys:
        train_dict[key] = matchup_dict.get(key)

    return train_dict, test_dict


def flatten(matchup_dict, shuffle=True):
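    """Stack every time key's observation array into one 2-D array,
    optionally shuffling the rows with a fixed seed."""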
    time_keys = matchup_dict.keys()

    dat_list = []
    cnt = 0
    for key in time_keys:
        dat = matchup_dict[key]
        num_dat = dat.shape[0]
        dat_list.append(dat)
        cnt += num_dat

    data = np.vstack(dat_list)

    if shuffle:
        np.random.seed(seed=42)
        shuffle_idxs = np.random.permutation(cnt)
        data = data[shuffle_idxs]

    return data


def shuffle_dict(in_dict):  # Outermost level only
    shuffled_dict = {}
    keys = in_dict.keys()
    num_times = len(keys)
    np.random.seed(seed=42)
    shuffle_idxs = np.random.permutation(num_times)

    key_list = list(keys)

    for idx in shuffle_idxs:
        key = key_list[idx]
        shuffled_dict[key] = in_dict[key]

    return shuffled_dict


def filter_by_type(matchup_dict, error=None, phase=None, ctype=None):
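    """Filter each time key's AMVs: keep rows whose |AMV - best-fit pressure|
    exceeds `error`, then rows matching cloud `phase` (column 10), or failing
    that cloud `ctype` (column 11). Keys left empty are dropped."""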
    time_keys = list(matchup_dict.keys())
    filtered_dict = {}

    for key in time_keys:
        amvs = matchup_dict[key]
        keep = None
        if error is not None:
            mae = np.abs(amvs[:, 4] - amvs[:, 3])
            keep = mae > error
            amvs = amvs[keep, :]
        if phase is not None:
            keep = amvs[:, 10] == phase
            amvs = amvs[keep, :]
        elif ctype is not None:
            keep = amvs[:, 11] == ctype
            amvs = amvs[keep, :]

        if keep is not None:
            if np.sum(keep) == 0:
                continue

        filtered_dict[key] = amvs

    return filtered_dict


def analyze(matchup_dict):
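    """Accumulate AMV-minus-best-fit pressure differences by cloud phase
    (column 10) and cloud type (column 11), print counts and MAE/BIAS/STD
    per category, and return the heights and per-category differences."""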
    time_keys = list(matchup_dict.keys())

    hgt_diff = []
    hgt_diff_i = []
    hgt_diff_w = []
    hgt_diff_scw = []
    hgt_diff_mphz = []
    hgts = []
    hgts_i = []
    hgts_w = []
    hgts_scw = []
    hgts_mphz = []
    num_amvs = 0
    num_ice = 0
    num_water = 0
    num_sc_water = 0
    num_mphz = 0

    hgt_diff_tice = []
    hgts_tice = []
    num_tice = 0

    hgt_diff_cirrus = []
    hgts_cirrus = []
    num_cirrus = 0
    hgt_diff_ovrlap = []
    hgts_ovrlap = []
    num_ovrlap = 0

    for key in time_keys:
        amvs = matchup_dict[key]
        # num_amvs += amvs.shape[0]
        d = amvs[:, 4] - amvs[:, 3]

        # keep = np.abs(d) > 70
        keep = np.abs(d) >= 0
        d = d[keep]
        amvs = amvs[keep, :]
        num_amvs += amvs.shape[0]

        hgt_diff.append(d)
        hgts.append(amvs[:, 3])

        ice = amvs[:, 10] == 4
        hgt_diff_i.append(d[ice])
        hgts_i.append(amvs[ice, 3])
        num_ice += np.sum(ice)

        water = amvs[:, 10] == 1
        hgt_diff_w.append(d[water])
        hgts_w.append(amvs[water, 3])
        num_water += np.sum(water)

        water = amvs[:, 10] == 2
        hgt_diff_scw.append(d[water])
        hgts_scw.append(amvs[water, 3])
        num_sc_water += np.sum(water)

        mphz = amvs[:, 10] == 3
        hgt_diff_mphz.append(d[mphz])
        hgts_mphz.append(amvs[mphz, 3])
        num_mphz += np.sum(mphz)

        tice = amvs[:, 11] == 5
        hgt_diff_tice.append(d[tice])
        hgts_tice.append(amvs[tice, 3])
        num_tice += np.sum(tice)

        cirrus = amvs[:, 11] == 6
        hgt_diff_cirrus.append(d[cirrus])
        hgts_cirrus.append(amvs[cirrus, 3])
        num_cirrus += np.sum(cirrus)

        ovrlap = amvs[:, 11] == 7
        hgt_diff_ovrlap.append(d[ovrlap])
        hgts_ovrlap.append(amvs[ovrlap, 3])
        num_ovrlap += np.sum(ovrlap)

    hgt_diff = np.concatenate(hgt_diff)
    hgt_diff_i = np.concatenate(hgt_diff_i)
    hgt_diff_w = np.concatenate(hgt_diff_w)
    hgt_diff_scw = np.concatenate(hgt_diff_scw)
    hgt_diff_tice = np.concatenate(hgt_diff_tice)
    hgt_diff_cirrus = np.concatenate(hgt_diff_cirrus)
    hgt_diff_mphz = np.concatenate(hgt_diff_mphz)
    hgt_diff_ovrlap = np.concatenate(hgt_diff_ovrlap)

    hgts = np.concatenate(hgts)
    hgts_i = np.concatenate(hgts_i)
    hgts_w = np.concatenate(hgts_w)
    hgts_scw = np.concatenate(hgts_scw)
    hgts_tice = np.concatenate(hgts_tice)
    hgts_cirrus = np.concatenate(hgts_cirrus)
    hgts_ovrlap = np.concatenate(hgts_ovrlap)

    print('total: ', num_amvs)
    print('ice: ', num_ice, num_ice/num_amvs)
    print('water: ', num_water, num_water/num_amvs)
    print('sc water: ', num_sc_water, num_sc_water/num_amvs)
    print('mixed phase: ', num_mphz, num_mphz/num_amvs)
    print('tice: ', num_tice, num_tice/num_amvs)
    print('cirrus: ', num_cirrus, num_cirrus/num_amvs)
    print('num ovrlap: ', num_ovrlap, num_ovrlap/num_amvs)

    print('all MAE/BIAS/STD', np.average(np.abs(hgt_diff)), np.average(hgt_diff), np.std(hgt_diff))
    print('ice MAE/BIAS/STD', np.average(np.abs(hgt_diff_i)), np.average(hgt_diff_i), np.std(hgt_diff_i))
    print('water MAE/BIAS/STD', np.average(np.abs(hgt_diff_w)), np.average(hgt_diff_w), np.std(hgt_diff_w))
    print('sc water MAE/BIAS/STD', np.average(np.abs(hgt_diff_scw)), np.average(hgt_diff_scw), np.std(hgt_diff_scw))
    print('tice MAE/BIAS/STD', np.average(np.abs(hgt_diff_tice)), np.average(hgt_diff_tice), np.std(hgt_diff_tice))
    print('cirrus MAE/BIAS/STD', np.average(np.abs(hgt_diff_cirrus)), np.average(hgt_diff_cirrus), np.std(hgt_diff_cirrus))
    print('mphz MAE/BIAS/STD', np.average(np.abs(hgt_diff_mphz)), np.average(hgt_diff_mphz), np.std(hgt_diff_mphz))
    print('ovrlap MAE/BIAS/STD', np.average(np.abs(hgt_diff_ovrlap)), np.average(hgt_diff_ovrlap), np.std(hgt_diff_ovrlap))

    return hgts, hgt_diff, hgt_diff_i, hgt_diff_w, hgt_diff_scw, hgt_diff_tice, hgt_diff_cirrus, hgt_diff_mphz, hgt_diff_ovrlap


def get_time_tuple_utc(timestamp):
    dt_obj = datetime.datetime.fromtimestamp(timestamp, timezone.utc)
    return dt_obj, dt_obj.timetuple()


def get_bounding_goes16_files(timestamp, ch_str, file_time_span=5):
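    """Find the GOES-16 CONUS L1b granule for channel `ch_str` whose scan
    interval contains `timestamp`, plus its left/right temporal neighbors.

    Returns (fL, tL, fC, tC, fR, tR); entries are None when the granule or
    a neighbor cannot be found.
    """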
    dt_obj, time_tup = get_time_tuple_utc(timestamp)

    dt_obj_0 = dt_obj - datetime.timedelta(minutes=20)
    dt_obj_1 = dt_obj + datetime.timedelta(minutes=20)

    yr_dir_0 = str(dt_obj_0.timetuple().tm_year)
    date_dir_0 = dt_obj_0.strftime(dir_fmt)
    date_str_0 = dt_obj_0.strftime(goes_date_format)

    yr_dir_1 = str(dt_obj_1.timetuple().tm_year)
    date_dir_1 = dt_obj_1.strftime(dir_fmt)
    date_str_1 = dt_obj_1.strftime(goes_date_format)

    files_path_0 = goes16_directory + '/' + yr_dir_0 + '/' + date_dir_0 + '/abi' + '/L1b' + '/RadC'
    files_path_1 = goes16_directory + '/' + yr_dir_1 + '/' + date_dir_1 + '/abi' + '/L1b' + '/RadC'

    flist_0 = glob.glob(files_path_0 + '/OR_ABI-L1b-RadC-??C' + ch_str + '_G16_s' + date_str_0 + '*.nc')
    flist_1 = []
    if date_str_0 != date_str_1:
        flist_1 = glob.glob(files_path_1 + '/OR_ABI-L1b-RadC-??C' + ch_str + '_G16_s' + date_str_1 + '*.nc')

    flist = flist_0 + flist_1
    if len(flist) == 0:
        return None, None, None, None, None, None

    ftimes = []
    for pname in flist:
        fname = os.path.split(pname)[1]
        tstr = re.search(dt_str_pat, fname).group()[2:]
        dto = datetime.datetime.strptime(tstr, goes_date_format+'%M').replace(tzinfo=timezone.utc)
        ftimes.append(dto.timestamp())

    tarr = np.array(ftimes)
    sidxs = tarr.argsort()

    farr = np.array(flist)
    farr = farr[sidxs]
    ftimes = tarr[sidxs]

    ftimes_n = ftimes + file_time_span * 60
    iC = -1
    for k, t in enumerate(ftimes):
        if t <= timestamp < ftimes_n[k]:
            iC = k
            break

    if iC >= 0:
        iL = iC - 1
        iR = iC + 1
    else:
        return None, None, None, None, None, None

    fList = farr.tolist()

    if iL < 0 or iR >= len(ftimes):
        return None, None, fList[iC], ftimes[iC], None, None
    else:
        return fList[iL], ftimes[iL], fList[iC], ftimes[iC], fList[iR], ftimes[iR]


def get_bounding_gfs_files(timestamp):
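    """Find the 12-hour GFS forecast files whose valid times bracket
    `timestamp`. Returns (file_0, time_0, file_1, time_1); the right-hand
    pair is None on an exact time match."""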
    dt_obj, time_tup = get_time_tuple_utc(timestamp)
    dt_obj = dt_obj - datetime.timedelta(hours=12)
    date_str = dt_obj.strftime(gfs_date_format)
    dt_obj0 = datetime.datetime.strptime(date_str, gfs_date_format).replace(tzinfo=timezone.utc)
    dt_obj1 = dt_obj0 + datetime.timedelta(days=1)
    date_str_1 = dt_obj1.strftime(gfs_date_format)

    flist_0 = glob.glob(gfs_directory+'gfs.'+date_str+'??_F012.hdf')
    flist_1 = glob.glob(gfs_directory+'gfs.'+date_str_1+'00_F012.hdf')
    if len(flist_0) == 0:
        return None, None, None, None
    filelist = flist_0
    if len(flist_1) > 0:
        filelist = flist_0 + flist_1

    ftimes = []
    for pname in filelist:
        match = re.search(gfs_name_pattern, pname)
        tstr = match[1]
        dto = datetime.datetime.strptime(tstr, gfs_date_format+'%H').replace(tzinfo=timezone.utc)
        dto = dto + datetime.timedelta(hours=12)
        ftimes.append(dto.timestamp())

    tarr = np.array(ftimes)
    sidxs = tarr.argsort()

    farr = np.array(filelist)
    farr = farr[sidxs]
    ftimes = tarr[sidxs]
    idxs = np.arange(len(filelist))

    above = ftimes >= timestamp
    if not above.any():
        return None, None, None, None

    below = ftimes <= timestamp
    if not below.any():
        return None, None, None, None
    iL = idxs[below].max()
    iR = iL + 1

    fList = farr.tolist()

    if timestamp == ftimes[iL]:
        return fList[iL], ftimes[iL], None, None
    else:
        return fList[iL], ftimes[iL], fList[iR], ftimes[iR]


def get_profile(xr_dataset, fld_name, lons, lats, lon360=True, do_norm=False):
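    """Interpolate a GFS profile field to the given lon/lat points at every
    pressure level. Longitudes are mapped to 0-360 when lon360 is True;
    returns a numpy array with one profile per point."""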
    if lon360:
        lons = np.where(lons < 0, lons + 360, lons) # convert -180,180 to 0,360

    plevs = xr_dataset['pressure levels']
    lat_coords = np.linspace(-90, 90, xr_dataset.fakeDim1.size)
    lon_coords = np.linspace(0, 359.5, xr_dataset.fakeDim2.size)

    fld = xr_dataset[fld_name]
    fld = fld.assign_coords(fakeDim2=lon_coords, fakeDim1=lat_coords, fakeDim0=plevs)

    dim2 = xr.DataArray(lons, dims='k')
    dim1 = xr.DataArray(lats, dims='k')

    intrp_fld = fld.interp(fakeDim1=dim1, fakeDim2=dim2, fakeDim0=plevs, method='linear')
    intrp_fld = intrp_fld.values
    if do_norm:
        intrp_fld = normalize(intrp_fld, fld_name)

    return intrp_fld


def get_profile_multi(xr_dataset, fld_name_s, lons, lats, lon360=True):
    if lon360:
        lons = np.where(lons < 0, lons + 360, lons) # convert -180,180 to 0,360

    plevs = xr_dataset['pressure levels']
    lat_coords = np.linspace(-90, 90, xr_dataset.fakeDim1.size)
    lon_coords = np.linspace(0, 359.5, xr_dataset.fakeDim2.size)

    fld_s = []
    for fld_name in fld_name_s:
        fld = xr_dataset[fld_name]
        fld = fld.assign_coords(fakeDim2=lon_coords, fakeDim1=lat_coords, fakeDim0=plevs)
        fld_s.append(fld)

    fld = xr.concat(fld_s, 'fld_dim')

    dim2 = xr.DataArray(lons, dims='k')
    dim1 = xr.DataArray(lats, dims='k')

    intrp_fld = fld.interp(fakeDim1=dim1, fakeDim2=dim2, fakeDim0=plevs, method='linear')

    return intrp_fld


def get_scalar(xr_dataset, fld_name, lons, lats, lon360=True):
    if lon360:
        lons = np.where(lons < 0, lons + 360, lons) # convert -180,180 to 0,360

    lat_coords = np.linspace(-90, 90, xr_dataset.fakeDim1.size)
    lon_coords = np.linspace(0, 359.5, xr_dataset.fakeDim2.size)

    fld = xr_dataset[fld_name]
    fld = fld.assign_coords(fakeDim2=lon_coords, fakeDim1=lat_coords)

    dim1 = xr.DataArray(lats, dims='k')
    dim2 = xr.DataArray(lons, dims='k')

    intrp_fld = fld.interp(fakeDim1=dim1, fakeDim2=dim2, method='linear')

    return intrp_fld


def convert_file(fpath_hdf4):
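    """Convert a GFS HDF4 file to HDF5 under converted_file_dir with the
    h4toh5convert tool, reusing an existing conversion if present."""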
    fname_h4 = os.path.split(fpath_hdf4)[1]
    fname_h5 = os.path.splitext(fname_h4)[0] + '.h5'
    fpath_h5 = converted_file_dir + '/' + fname_h5

    # check if we have it already
    if os.path.exists(fpath_h5):
        return fpath_h5

    subprocess.call([h4_to_h5_path, fpath_hdf4, fpath_h5])

    return fpath_h5


def get_interpolated_profile(ds_0, ds_1, time0, time1, fld_name, time, lons, lats, do_norm=False):

    prof_0 = get_profile(ds_0, fld_name, lons, lats)
    prof_1 = get_profile(ds_1, fld_name, lons, lats)

    # get_profile returns plain numpy arrays, so interpolate linearly in
    # time by hand rather than via xr.concat (which needs xarray objects)
    w = (time - time0) / (time1 - time0)
    intrp_prof = prof_0 + w * (prof_1 - prof_0)

    if do_norm:
        intrp_prof = normalize(intrp_prof, fld_name)

    return intrp_prof


def get_interpolated_profile_multi(ds_0, ds_1, time0, time1, fld_name_s, time, lons, lats, do_norm_s):

    prof_0 = get_profile_multi(ds_0, fld_name_s, lons, lats)
    prof_1 = get_profile_multi(ds_1, fld_name_s, lons, lats)

    prof = xr.concat([prof_0, prof_1], 'time')
    prof = prof.assign_coords(time=[time0, time1])

    intrp_prof = prof.interp(time=time, method='linear')

    intrp_prof = intrp_prof.values

    intrp_prof = normalize_multi(intrp_prof, fld_name_s, do_norm_s)

    return intrp_prof


def get_interpolated_scalar(ds_0, ds_1, time0, time1, fld_name, time, lons, lats, do_norm=False):

    vals_0 = get_scalar(ds_0, fld_name, lons, lats)
    vals_1 = get_scalar(ds_1, fld_name, lons, lats)

    vals = xr.concat([vals_0, vals_1], 'time')
    vals = vals.assign_coords(time=[time0, time1])

    intrp_vals = vals.interp(time=time, method='linear')

    intrp_vals = intrp_vals.values

    if do_norm:
        intrp_vals = normalize(intrp_vals, fld_name)
    return intrp_vals


path_to_reader_dict = {}

MAX_FILES_OPEN = 50


def get_reader(path):
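    """Return a cached PugL1bTools reader for `path`, clearing the whole
    cache first when more than MAX_FILES_OPEN readers are held."""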
    if len(path_to_reader_dict) > MAX_FILES_OPEN:
        path_to_reader_dict.clear()

    rdr = path_to_reader_dict.get(path)
    if rdr is None:
        rdr = PugL1bTools(path)
        path_to_reader_dict[path] = rdr
    return rdr


def clear_readers():
    path_to_reader_dict.clear()


def get_images(lons, lats, timestamp, channel_list, half_width, step, do_norm=False, daynight='ANY'):
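    """Extract image chips centered on each (lon, lat) from the GOES-16 L1b
    granule bracketing `timestamp`, one set per channel.

    Each chip spans 2*half_width[ch] pixels subsampled by step[ch]; chips
    failing the day/night test (solar altitude threshold of 3 degrees) or
    falling outside the granule get a placeholder. With TRIPLET or CONV3D
    set, the temporally adjacent granules are chipped as well. Returns
    (data, data_l, data_r, idxs), or (None, None, None, None) on failure.
    """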
    global geoloc_2km, geoloc_1km, geoloc_hkm

    dt_obj = datetime.datetime.fromtimestamp(timestamp, timezone.utc)
    data = []
    data_l = []
    data_r = []
    idxs = []

    l_s = []
    c_s = []

    for ch_idx, ch in enumerate(channel_list):
        img_width = int(2 * half_width[ch_idx] / step[ch_idx])
        dum = np.arange(img_width * img_width, dtype=np.float32)
        dum = dum.reshape((img_width, img_width))

        fL, tL, fC, tC, fR, tR = get_bounding_goes16_files(timestamp, ch)
        if (fL is None) or (fR is None) or (fC is None):
            return None, None, None, None

        # copy from archive to local drive
        local_path = fC
        if COPY_GOES:
            local_path = goes_cache_dir + '/' + os.path.split(fC)[1]
            if not os.path.exists(local_path):
                subprocess.call(['cp', fC, goes_cache_dir])
                subprocess.call(['chmod', 'u+w', local_path])

        # Navigation, convert to BT/REFL
        try:
            pug_l1b_c = get_reader(local_path)
        except Exception as exc:
            print(exc)
            return None, None, None, None

        bt_or_refl = None
        if pug_l1b_c.bt_or_refl == 'bt':
            bt_or_refl = pug_l1b_c.bt
        elif pug_l1b_c.bt_or_refl == 'refl':
            bt_or_refl = pug_l1b_c.refl

        if TRIPLET or CONV3D:
            # copy from archive to local drive
            local_path_l = fL
            if COPY_GOES:
                local_path_l = goes_cache_dir + '/' + os.path.split(fL)[1]
                if not os.path.exists(local_path_l):
                    subprocess.call(['cp', fL, goes_cache_dir])
                    subprocess.call(['chmod', 'u+w', local_path_l])

            # Navigation, convert to BT/REFL
            try:
                pug_l1b_l = get_reader(local_path_l)
            except Exception as exc:
                print(exc)
                return None, None, None, None

            if pug_l1b_l.bt_or_refl == 'bt':
                bt_or_refl_l = pug_l1b_l.bt
            elif pug_l1b_l.bt_or_refl == 'refl':
                bt_or_refl_l = pug_l1b_l.refl

            # copy from archive to local drive
            local_path_r = fR
            if COPY_GOES:
                local_path_r = goes_cache_dir + '/' + os.path.split(fR)[1]
                if not os.path.exists(local_path_r):
                    subprocess.call(['cp', fR, goes_cache_dir])
                    subprocess.call(['chmod', 'u+w', local_path_r])

            # Navigation, convert to BT/REFL
            try:
                pug_l1b_r = get_reader(local_path_r)
            except Exception as exc:
                print(exc)
                return None, None, None, None

            if pug_l1b_r.bt_or_refl == 'bt':
                bt_or_refl_r = pug_l1b_r.bt
            elif pug_l1b_r.bt_or_refl == 'refl':
                bt_or_refl_r = pug_l1b_r.refl

        if daynight != 'ANY':
            if step[ch_idx] == 1:
                if geoloc_2km is None:
                    geoloc_2km = pug_l1b_c.geo
                glons = geoloc_2km.lon
                glats = geoloc_2km.lat
            elif step[ch_idx] == 4:
                if geoloc_hkm is None:
                    geoloc_hkm = pug_l1b_c.geo
                glons = geoloc_hkm.lon
                glats = geoloc_hkm.lat
            elif step[ch_idx] == 2:
                if geoloc_1km is None:
                    geoloc_1km = pug_l1b_c.geo
                glons = geoloc_1km.lon
                glats = geoloc_1km.lat

            alt_fnc = get_solar_alt_func(pug_l1b_c, glons, glats, dt_obj)

        images = []
        images_l = []
        images_r = []
        for idx, lon in enumerate(lons):
            lat = lats[idx]
            if ch_idx == 0:
                l, c = pug_l1b_c.lat_lon_to_l_c(lat, lon)
                l_s.append(l)
                c_s.append(c)
            else:
                l = l_s[idx]
                c = c_s[idx]
            if daynight == 'NIGHT':
                salt = alt_fnc(l, c)
                if salt[0] > 3.0:
                    images.append(dum)
                    continue
            elif daynight == 'DAY':
                salt = alt_fnc(l, c)
                if salt[0] < 3.0:
                    images.append(dum)
                    continue

            l_a = l - half_width[ch_idx]
            l_b = l + half_width[ch_idx]
            c_a = c - half_width[ch_idx]
            c_b = c + half_width[ch_idx]
            if (l_a >= 0 and l_b < pug_l1b_c.shape[0]) and (c_a >= 0 and c_b < pug_l1b_c.shape[1]):
                img = bt_or_refl[l_a:l_b:step[ch_idx], c_a:c_b:step[ch_idx]]
                if do_norm:
                    img = prepare_image(img, ch)
                images.append(img)
                if ch_idx == 0:
                    idxs.append(idx)
                if TRIPLET or CONV3D:
                    img = bt_or_refl_l[l_a:l_b:step[ch_idx], c_a:c_b:step[ch_idx]]
                    if do_norm:
                        img = prepare_image(img, ch)
                    images_l.append(img)

                    img = bt_or_refl_r[l_a:l_b:step[ch_idx], c_a:c_b:step[ch_idx]]
                    if do_norm:
                        img = prepare_image(img, ch)
                    images_r.append(img)
            else:
                images.append(dum)
                if TRIPLET or CONV3D:
                    images_l.append(dum)
                    images_r.append(dum)

        images = np.array(images)
        data.append(images)
        if TRIPLET or CONV3D:
            images_l = np.array(images_l)
            images_r = np.array(images_r)
            data_l.append(images_l)
            data_r.append(images_r)

        # remove file from local cache
        if COPY_GOES:
            if not CACHE_GOES:
                subprocess.call(['rm', local_path])
                if TRIPLET or CONV3D:
                    subprocess.call(['rm', local_path_l])
                    subprocess.call(['rm', local_path_r])

    return np.array(data), np.array(data_l), np.array(data_r), np.array(idxs)


def copy_abi(matchup_dict, channel_list):
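    """Pre-stage the GOES-16 granule at each matchup time from the archive
    into the local cache directory, for every channel in channel_list."""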
    for key in matchup_dict.keys():
        obs = matchup_dict.get(key)
        timestamp = obs[0][0]
        for ch_idx, ch in enumerate(channel_list):
            fL, tL, fC, tC, fR, tR = get_bounding_goes16_files(timestamp, ch)
            if fC is None:
                continue

            # copy the center granule from archive to local drive
            local_path = goes_cache_dir + '/' + os.path.split(fC)[1]
            if not os.path.exists(local_path):
                subprocess.call(['cp', fC, goes_cache_dir])
                subprocess.call(['chmod', 'u+w', local_path])


def prepare_image(image, ch):
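    """Standardize an ABI image chip to zero mean/unit std for channel `ch`
    and zero out pixels outside the channel's valid range."""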
    from deeplearning.cloudheight import abi_valid_range, abi_mean, abi_std
    vld_rng = abi_valid_range.get(ch)

    valid = np.logical_and(image >= vld_rng[0], image <= vld_rng[1])
    mean = abi_mean.get(ch)
    std = abi_std.get(ch)

    # work on a float copy: `image` may be a view into the cached reader's
    # BT/reflectance array, which must not be modified in place
    image = image.astype(np.float32, copy=True)
    image -= mean
    image /= std

    not_valid = np.invert(valid)
    image[not_valid] = 0

    return image


def normalize(data, param):
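    """Standardize `data` using the field's (mean, std) when defined, then
    zero values outside the field's valid range when one is defined."""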
    from deeplearning.cloudheight import mean_std_dict, valid_range_dict

    if mean_std_dict.get(param) is None:
        return data

    mean, std = mean_std_dict.get(param)
    data -= mean
    data /= std

    vld_rng = valid_range_dict.get(param)
    if vld_rng is None:
        return data

    shape = data.shape
    data = data.flatten()
    valid = np.logical_and(data >= vld_rng[0], data <= vld_rng[1])
    not_valid = np.invert(valid)
    data[not_valid] = 0
    data = np.reshape(data, shape)

    return data


def normalize_multi(data, param_s, do_norm_s):
    from deeplearning.cloudheight import mean_std_dict, valid_range_dict

    for idx, param in enumerate(param_s):
        if not do_norm_s[idx]:
            continue

        if mean_std_dict.get(param) is None:
            continue

        mean, std = mean_std_dict.get(param)
        data[idx] -= mean
        data[idx] /= std

        vld_rng = valid_range_dict.get(param)
        if vld_rng is None:
            continue

        tmp = data[idx]

        shape = tmp.shape
        tmp = tmp.flatten()
        valid = np.logical_and(tmp >= vld_rng[0], tmp <= vld_rng[1])
        not_valid = np.invert(valid)
        tmp[not_valid] = 0
        tmp = np.reshape(tmp, shape)

        data[idx] = tmp

    return data


def get_abi_mean_std(matchup_dict, ch_str, valid_range, step=100):
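    """Compute mean and std of BT/reflectance for ABI channel `ch_str` over
    the matchup times, sampling every `step`-th pixel and keeping values
    inside `valid_range`. Returns (None, None) when no granules are found."""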

    keys = matchup_dict.keys()
    image_s = []

    for key in keys:
        obs = matchup_dict.get(key)
        if obs is None:
            continue
        timestamp = obs[0][0]
        dt_obj, ttup = get_time_tuple_utc(timestamp)
        print(ttup.tm_year, ttup.tm_mon, ttup.tm_mday)
        f0, t0, f1, t1, f2, t2 = get_bounding_goes16_files(timestamp, ch_str)
        if (f0 is None) or (f1 is None) or (f2 is None):
            continue

        # copy from archive to local drive
        local_path = goes_cache_dir + '/' + os.path.split(f0)[1]
        if not os.path.exists(local_path):
            subprocess.call(['cp', f0, goes_cache_dir])
            subprocess.call(['chmod', 'u+w', local_path])

        # Navigation, convert to BT/REFL
        try:
            pug = PugFile(local_path, 'Rad')
            pug_l1b = PugL1bTools(local_path)
        except Exception as exc:
            print(exc)
            continue

        if ch_str == '02':
            # visible channel: require daylight (solar altitude >= 10 deg)
            # at all four granule corners, otherwise skip this granule
            elems = pug.shape[1]
            lines = pug.shape[0]
            lons = pug.geo.lon
            lats = pug.geo.lat
            corner_lats = [lats[0, 0], lats[0, elems-1], lats[lines-1, 0], lats[lines-1, elems-1]]
            corner_lons = [lons[0, 0], lons[0, elems-1], lons[lines-1, 0], lons[lines-1, elems-1]]

            if any(get_altitude(lat, lon, dt_obj) < 10
                   for lat, lon in zip(corner_lats, corner_lons)):
                print(ttup.tm_year, ttup.tm_mon, ttup.tm_mday)
                continue

        bt_or_refl = None
        if pug.bt_or_refl == 'bt':
            bt_or_refl = pug_l1b.bt
        elif pug.bt_or_refl == 'refl':
            bt_or_refl = pug_l1b.refl

        bt_or_refl = bt_or_refl[::step, ::step]
        valid = np.logical_and(bt_or_refl >= valid_range[0], bt_or_refl <= valid_range[1])
        bt_or_refl = bt_or_refl[valid]
        image_s.append(bt_or_refl)

        # remove file from local cache
        if not CACHE_GOES:
            subprocess.call(['rm', local_path])

    if len(image_s) == 0:
        return None, None

    values = np.concatenate(image_s)
    mean = np.mean(values)
    values -= mean
    std = np.std(values)

    return mean, std


NUM_VERT_LEVS = 26


def get_gfs_mean_std(matchup_dict, fld_name, valid_range, step=2):
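    """Compute per-pressure-level mean and std of a GFS 3-D field over the
    matchup times, subsampling horizontally by `step`. Returns arrays of
    shape (NUM_VERT_LEVS, 1)."""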

    keys = matchup_dict.keys()

    fld_at_levels = [[] for i in range(NUM_VERT_LEVS)]

    for key in keys:
        obs = matchup_dict.get(key)
        if obs is None:
            continue
        timestamp = obs[0][0]
        _, ttup = get_time_tuple_utc(timestamp)
        print(ttup.tm_year, ttup.tm_mon, ttup.tm_mday)
        gfs_0, time_0, gfs_1, time_1 = get_bounding_gfs_files(timestamp)
        if gfs_0 is None:
            continue
        gfs_0 = convert_file(gfs_0)

        try:
            xr_dataset = xr.open_dataset(gfs_0)
        except Exception as exc:
            print(exc)
            return None, None

        fld = xr_dataset[fld_name]
        fld_3d = fld.values

        for i in range(NUM_VERT_LEVS):
            fld_i = fld_3d[::step, ::step, i]
            valid = np.logical_and(fld_i >= valid_range[0], fld_i <= valid_range[1])
            fld_i = fld_i[valid]
            fld_at_levels[i].append(fld_i)

    mean = []
    std = []
    for i in range(NUM_VERT_LEVS):
        fld_at_levels[i] = (np.concatenate(fld_at_levels[i])).flatten()
        m = np.mean(fld_at_levels[i])
        fld_at_levels[i] -= m
        mean.append(m)
        std.append(np.std(fld_at_levels[i]))

    mean = np.array(mean)
    std = np.array(std)

    mean = np.reshape(mean, (NUM_VERT_LEVS, 1))
    std = np.reshape(std, (NUM_VERT_LEVS, 1))

    return mean, std


def get_gfs_mean_std_2d(matchup_dict, fld_name, valid_range, step=2):

    keys = matchup_dict.keys()

    fld_s = []

    for key in keys:
        obs = matchup_dict.get(key)
        if obs is None:
            continue
        timestamp = obs[0][0]
        _, ttup = get_time_tuple_utc(timestamp)
        print(ttup.tm_year, ttup.tm_mon, ttup.tm_mday)
        gfs_0, time_0, gfs_1, time_1 = get_bounding_gfs_files(timestamp)
        if gfs_0 is None:
            continue
        gfs_0 = convert_file(gfs_0)

        try:
            xr_dataset = xr.open_dataset(gfs_0)
        except Exception as exc:
            print(exc)
            return None, None

        fld = xr_dataset[fld_name]
        fld_2d = fld.values

        fld_i = fld_2d[::step, ::step]
        valid = np.logical_and(fld_i >= valid_range[0], fld_i <= valid_range[1])
        fld_i = fld_i[valid]
        fld_s.append(fld_i)

    fld_s = (np.concatenate(fld_s)).flatten()

    mean = np.mean(fld_s)
    fld_s -= mean
    std = np.std(fld_s)

    return mean, std


PROC_BATCH_SIZE = 10
BATCH_SIZE = 256


def run_mean_std(filename, matchup_dict=None):
    if matchup_dict is None:
        matchup = get_matchups(filename)
        matchup_dict = merge(matchup, None)
        matchup_dict = matchup_dict_to_nd(matchup_dict)
    train_dict, valid_dict = split_matchup(matchup_dict, perc=0.1)

    channels = ['14', '08', '02']
    valid_range = [[150, 350], [150, 350], [0, 100]]
    step = [100, 100, 400]

    f = open(home_dir + '/stats.txt', 'w')

    for ch_idx, ch in enumerate(channels):
        mean, std = get_abi_mean_std(valid_dict, channels[ch_idx], valid_range[ch_idx], step[ch_idx])
        print(ch, ': ', mean, std)
        if mean is not None:
            f.write('channel: %s, %f, %f\n' % (ch, mean, std))

    mean, std = get_gfs_mean_std(valid_dict, 'temperature', [150, 350])
    f.write('gfs temperature: \n')
    if mean is not None:
        for i in range(NUM_VERT_LEVS):
            f.write('%f, %f\n' % (mean[i], std[i]))

    f.close()

    return matchup_dict


def get_solar_alt_func(pug, glons, glats, dt_obj):
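    """Build a cubic interpolant of solar altitude over the granule, sampled
    every 100 lines/elements; call the result as alt_fnc(line, element)."""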
    elems = pug.shape[1]
    lines = pug.shape[0]

    gx = np.arange(elems)[::100]
    gy = np.arange(lines)[::100]

    glons = glons[gy]
    glons = glons[:, gx]
    glats = glats[gy]
    glats = glats[:, gx]

    glons = glons.flatten()
    glats = glats.flatten()

    slr_alt = []
    for i in range(len(glons)):
        alt = get_altitude(glats[i], glons[i], dt_obj)
        slr_alt.append(alt)
    nd_alt = np.array(slr_alt)
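    # note: scipy's interp2d is deprecated and removed in SciPy >= 1.14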
    alt_fnc = interp2d(gy, gx, nd_alt, kind='cubic')

    return alt_fnc
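

# Minimal usage sketch (the matchup file path is hypothetical; any AMV/RAOB
# matchup text file in the format parsed by get_matchups will work):
#
#   matchup_dict = get_matchup_dict('/path/to/matchups.txt', reduce=2)
#   train_dict, test_dict = split_matchup(matchup_dict, perc=0.2)
#   analyze(train_dict)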