import datetime
from datetime import timezone
import glob
import os
import numpy as np
import xarray as xr
from util.util import value_to_index

# Directory that is searched for GFS files.  Alternate location kept for reference:
# gfs_directory = '/apollo/cloud/Ancil_Data/clavrx_ancil_data/dynamic/gfs/'
gfs_directory = '/Users/tomrink/data/gfs/'
# strftime/strptime pattern for the date part of a GFS file name
# (two-digit year, month, day).
gfs_date_format = '%y%m%d'

# When True, incoming longitudes are forced to (0, 360) to match GFS.
lon360 = True

# GFS half degree resolution: grid size and the matching coordinate axes.
NX = 720
NY = 361
lat_coords = np.linspace(-90.0, 90.0, num=NY)
lon_coords = np.linspace(0.0, 359.5, num=NX)

# Pressure levels of the vertical axis (presumably hPa -- confirm against the data).
plevs = np.array([
    10.0, 20.0, 30.0, 50.0, 70.0,
    100.0, 150.0, 200.0, 250.0, 300.0,
    350.0, 400.0, 450.0, 500.0, 550.0,
    600.0, 650.0, 700.0, 750.0, 800.0,
    850.0, 900.0, 925.0, 950.0, 975.0,
    1000.0,
])
NZ = len(plevs)


class MyGenericException(Exception):
    """General-purpose exception that carries a message.

    The message is stored on ``self.message`` and is also forwarded to
    ``Exception`` so that ``str(exc)`` and ``exc.args`` reflect it, including
    when the exception is built with ``message=...`` as a keyword argument.
    """

    def __init__(self, message):
        # Without this call a keyword-constructed instance has empty ``args``
        # and therefore an empty ``str()``.
        super().__init__(message)
        self.message = message

def get_timestamp(filename):
    """Return the POSIX timestamp derived from a GFS file name.

    The name is split on '.', and the part of the second token that precedes
    the first '_' is parsed with ``gfs_date_format`` followed by a two-digit
    hour, interpreted as UTC.  Twelve hours are then added (presumably the
    forecast offset of the ``F012`` files -- confirm).

    Args:
        filename: File name (no directory), e.g. ``gfs.<date><hour>_F012.h5``.

    Returns:
        Seconds since the epoch as a float.
    """
    stamp = filename.split('.')[1].split('_')[0]
    parsed = datetime.datetime.strptime(stamp, gfs_date_format + '%H')
    shifted = parsed.replace(tzinfo=timezone.utc) + datetime.timedelta(hours=12)
    return shifted.timestamp()


def get_time_tuple_utc(timestamp):
    """Convert a POSIX timestamp to UTC.

    Args:
        timestamp: Seconds since the epoch.

    Returns:
        A ``(datetime, time_tuple)`` pair: the timezone-aware UTC datetime and
        the ``time.struct_time`` produced by its ``timetuple()``.
    """
    utc_dt = datetime.datetime.fromtimestamp(timestamp, tz=timezone.utc)
    struct = utc_dt.timetuple()
    return utc_dt, struct


def get_bounding_gfs_files(timestamp):
    """Locate the two GFS files whose times bracket ``timestamp``.

    Files matching ``gfs.<date>??_F012.h5`` are searched for in
    ``gfs_directory`` over three consecutive days centred on the UTC date of
    ``timestamp + 12 hours``.  The time of each file is obtained from
    ``get_timestamp``.

    Args:
        timestamp: POSIX timestamp (seconds since the epoch).

    Returns:
        ``(left_file, left_time, right_file, right_time)``: two files that are
        adjacent in time with ``left_time <= timestamp <= right_time``.
        ``(None, None, None, None)`` is returned when no files are found or
        when no such pair exists.
    """
    dt_obj, _ = get_time_tuple_utc(timestamp)
    # NOTE(review): the 12 hour shift presumably mirrors the offset applied in
    # get_timestamp -- confirm.
    center_day = (dt_obj + datetime.timedelta(hours=12)).date()
    one_day = datetime.timedelta(days=1)

    filelist = []
    for day in (center_day - one_day, center_day, center_day + one_day):
        filelist += glob.glob(gfs_directory+'gfs.'+day.strftime(gfs_date_format)+'??_F012.h5')
    if not filelist:
        return None, None, None, None

    # TODO: make better with regular expressions (someday)
    ftimes = np.array([get_timestamp(os.path.split(pname)[1]) for pname in filelist])
    order = ftimes.argsort()
    ftimes = ftimes[order]
    fnames = [filelist[i] for i in order]

    below = ftimes <= timestamp
    if not (ftimes >= timestamp).any() or not below.any():
        return None, None, None, None

    iL = int(np.flatnonzero(below).max())
    iR = iL + 1
    if iR == len(fnames):
        # timestamp coincides with the latest file time, so there is no later
        # file; fall back to the previous adjacent pair instead of indexing
        # past the end of the list.
        if iL == 0:
            return None, None, None, None
        iL, iR = iL - 1, iL

    return fnames[iL], ftimes[iL], fnames[iR], ftimes[iR]


def get_horz_layer(xr_dataset, fld_name, press, lon_range=None, lat_range=None):
    """Extract one field on one pressure level, optionally subset in lon/lat.

    ``lon_range`` and ``lat_range`` are handled independently: an axis whose
    range is omitted keeps its full extent.

    Args:
        xr_dataset: Dataset indexable by field name.
        fld_name: Name of the field to extract.
        press: Pressure used to select a level via ``value_to_index(plevs, press)``
            (presumably the nearest level -- confirm in util.util).
        lon_range: Optional ``(lon_lo, lon_hi)``.  Negative values are shifted
            by 360 when ``lon360`` is set.
        lat_range: Optional ``(lat_lo, lat_hi)``.

    Returns:
        The selected slab with a leading ``channel`` dimension labelled with
        ``fld_name`` and with ``fakeDim1``/``fakeDim2`` coordinates set from
        ``lat_coords``/``lon_coords``.
    """
    p_idx = value_to_index(plevs, press)
    fld = xr_dataset[fld_name]

    # Full grid unless a range is supplied for the axis.
    x_lo, x_hi = 0, NX
    y_lo, y_hi = 0, NY

    if lon_range is not None:
        lon_lo, lon_hi = lon_range[0], lon_range[1]
        if lon360:
            if lon_lo < 0:
                lon_lo += 360
            if lon_hi < 0:
                lon_hi += 360
        x_lo = value_to_index(lon_coords, lon_lo)
        x_hi = value_to_index(lon_coords, lon_hi)

    if lat_range is not None:
        y_lo = value_to_index(lat_coords, lat_range[0])
        y_hi = value_to_index(lat_coords, lat_range[1])

    sub_fld = fld[y_lo:y_hi, x_lo:x_hi, p_idx]
    sub_fld = sub_fld.expand_dims('channel')

    sub_fld = sub_fld.assign_coords(channel=[fld_name], fakeDim2=lon_coords[x_lo:x_hi], fakeDim1=lat_coords[y_lo:y_hi])

    return sub_fld


def get_horz_layer_s(xr_dataset, fld_names, press, lon_range=None, lat_range=None):
    """Extract several fields on one pressure level, optionally subset in lon/lat.

    ``lon_range`` and ``lat_range`` are handled independently: an axis whose
    range is omitted keeps its full extent.

    Args:
        xr_dataset: Dataset indexable by field name.
        fld_names: Names of the fields to extract.
        press: Pressure used to select a level via ``value_to_index(plevs, press)``
            (presumably the nearest level -- confirm in util.util).
        lon_range: Optional ``(lon_lo, lon_hi)``.  Negative values are shifted
            by 360 when ``lon360`` is set.
        lat_range: Optional ``(lat_lo, lat_hi)``.

    Returns:
        The per-field slabs concatenated along a ``channel`` dimension
        labelled with ``fld_names``, with ``fakeDim1``/``fakeDim2``
        coordinates set from ``lat_coords``/``lon_coords``.
    """
    p_idx = value_to_index(plevs, press)

    # Full grid unless a range is supplied for the axis.
    x_lo, x_hi = 0, NX
    y_lo, y_hi = 0, NY

    if lon_range is not None:
        lon_lo, lon_hi = lon_range[0], lon_range[1]
        if lon360:
            if lon_lo < 0:
                lon_lo += 360
            if lon_hi < 0:
                lon_hi += 360
        x_lo = value_to_index(lon_coords, lon_lo)
        x_hi = value_to_index(lon_coords, lon_hi)

    if lat_range is not None:
        y_lo = value_to_index(lat_coords, lat_range[0])
        y_hi = value_to_index(lat_coords, lat_range[1])

    sub_fld_s = [xr_dataset[fld_name][y_lo:y_hi, x_lo:x_hi, p_idx] for fld_name in fld_names]

    sub_fld = xr.concat(sub_fld_s, 'channel')
    sub_fld = sub_fld.assign_coords(channel=fld_names, fakeDim2=lon_coords[x_lo:x_hi], fakeDim1=lat_coords[y_lo:y_hi])

    return sub_fld


def get_time_interpolated_layer(xr_dataset_s, time_s, time, fld_name, press, lon_range=None, lat_range=None, method='linear'):
    """Interpolate a horizontal layer of one field in time.

    A layer is taken from every dataset with ``get_horz_layer``, the layers
    are stacked along a ``time`` dimension labelled with ``time_s``, and the
    stack is interpolated to ``time``.

    Args:
        xr_dataset_s: Datasets, one per entry of ``time_s``.
        time_s: Time coordinate values for the datasets.
        time: Time(s) to interpolate to.
        fld_name: Name of the field.
        press: Pressure selecting the level.
        lon_range: Optional longitude range passed to ``get_horz_layer``.
        lat_range: Optional latitude range passed to ``get_horz_layer``.
        method: Interpolation method passed to ``interp``.

    Returns:
        The time-interpolated layer.
    """
    layers = [
        get_horz_layer(ds, fld_name, press, lon_range=lon_range, lat_range=lat_range)
        for ds in xr_dataset_s
    ]
    stacked = xr.concat(layers, 'time').assign_coords(time=time_s)
    return stacked.interp(time=time, method=method)


def get_time_interpolated_layer_s(xr_dataset_s, time_s, time, fld_name_s, press, lon_range=None, lat_range=None, method='linear'):
    """Interpolate a horizontal layer of several fields in time.

    A multi-field layer is taken from every dataset with
    ``get_horz_layer_s``, the layers are stacked along a ``time`` dimension
    labelled with ``time_s``, and the stack is interpolated to ``time``.

    Args:
        xr_dataset_s: Datasets, one per entry of ``time_s``.
        time_s: Time coordinate values for the datasets.
        time: Time(s) to interpolate to.
        fld_name_s: Names of the fields.
        press: Pressure selecting the level.
        lon_range: Optional longitude range passed to ``get_horz_layer_s``.
        lat_range: Optional latitude range passed to ``get_horz_layer_s``.
        method: Interpolation method passed to ``interp``.

    Returns:
        The time-interpolated layer.
    """
    layer_s = [
        get_horz_layer_s(ds, fld_name_s, press, lon_range=lon_range, lat_range=lat_range)
        for ds in xr_dataset_s
    ]

    lyr = xr.concat(layer_s, 'time')
    lyr = lyr.assign_coords(time=time_s)

    # Honour the caller's choice; this was previously fixed to 'linear'.
    intrp_lyr = lyr.interp(time=time, method=method)

    return intrp_lyr


def get_vert_profile(xr_dataset, fld_name, lons, lats, method='linear'):
    """Interpolate one field to a set of points on all pressure levels.

    Args:
        xr_dataset: Dataset indexable by field name.
        fld_name: Name of the field.
        lons: Longitudes of the points; must support ``lons < 0`` element-wise
            (e.g. a numpy array).  Negative values are shifted by 360 when
            ``lon360`` is set.
        lats: Latitudes of the points, paired with ``lons``.
        method: Interpolation method passed to ``interp``.

    Returns:
        The field interpolated to the points (along a new ``k`` dimension)
        at the levels in ``plevs``.
    """
    if lon360:  # convert -180/+180 to 0,360
        lons = np.where(lons < 0, lons + 360, lons)

    gridded = xr_dataset[fld_name].assign_coords(fakeDim2=lon_coords, fakeDim1=lat_coords, fakeDim0=plevs)

    # Sharing the 'k' dimension makes interp pair each lat with its lon.
    lat_pts = xr.DataArray(lats, dims='k')
    lon_pts = xr.DataArray(lons, dims='k')

    return gridded.interp(fakeDim1=lat_pts, fakeDim2=lon_pts, fakeDim0=plevs, method=method)


def get_vert_profile_s(xr_dataset, fld_name_s, lons, lats, method='linear'):
    """Interpolate several fields to a set of points on all pressure levels.

    Args:
        xr_dataset: Dataset indexable by field name.
        fld_name_s: Names of the fields.
        lons: Longitudes of the points; must support ``lons < 0`` element-wise
            (e.g. a numpy array).  Negative values are shifted by 360 when
            ``lon360`` is set.
        lats: Latitudes of the points, paired with ``lons``.
        method: Interpolation method passed to ``interp``.

    Returns:
        The fields, concatenated along ``fld_dim``, interpolated to the points
        (along a new ``k`` dimension) at the levels in ``plevs``.
    """
    if lon360:  # convert -180,+180 to 0,360
        lons = np.where(lons < 0, lons + 360, lons)

    per_field = [
        xr_dataset[name].assign_coords(fakeDim2=lon_coords, fakeDim1=lat_coords, fakeDim0=plevs)
        for name in fld_name_s
    ]
    stacked = xr.concat(per_field, 'fld_dim')

    # Sharing the 'k' dimension makes interp pair each lat with its lon.
    lat_pts = xr.DataArray(lats, dims='k')
    lon_pts = xr.DataArray(lons, dims='k')

    return stacked.interp(fakeDim1=lat_pts, fakeDim2=lon_pts, fakeDim0=plevs, method=method)


def get_point(xr_dataset, fld_name, lons, lats):
    """Linearly interpolate one field to a set of lon/lat points.

    The coordinate axes are built from the dataset's own ``fakeDim1`` and
    ``fakeDim2`` sizes rather than from the module-level grid, so the grid
    spacing follows the dataset.

    Args:
        xr_dataset: Dataset indexable by field name, with ``fakeDim1`` and
            ``fakeDim2`` dimensions.
        fld_name: Name of the field.
        lons: Longitudes of the points; must support ``lons < 0`` element-wise
            (e.g. a numpy array).  Negative values are shifted by 360 when
            ``lon360`` is set.
        lats: Latitudes of the points, paired with ``lons``.

    Returns:
        The field interpolated to the points along a new ``k`` dimension.
    """
    if lon360:  # convert -180/+180 to 0,360
        lons = np.where(lons < 0, lons + 360, lons)

    n_lat = xr_dataset.fakeDim1.size
    n_lon = xr_dataset.fakeDim2.size
    grid_lats = np.linspace(-90, 90, n_lat)
    # Evenly spaced longitudes over [0, 360) for any column count.  For 720
    # columns this is 0, 0.5, ..., 359.5; the previous fixed end point of
    # 359.5 was only correct for that size.
    # NOTE(review): assumes a regular global grid starting at 0 -- confirm.
    grid_lons = np.linspace(0, 360, n_lon, endpoint=False)

    fld = xr_dataset[fld_name]
    fld = fld.assign_coords(fakeDim2=grid_lons, fakeDim1=grid_lats)

    dim1 = xr.DataArray(lats, dims='k')
    dim2 = xr.DataArray(lons, dims='k')

    intrp_fld = fld.interp(fakeDim1=dim1, fakeDim2=dim2, method='linear')

    return intrp_fld


def get_time_interpolated_vert_profile(xr_dataset_s, time_s, fld_name, time, lons, lats, method='linear'):
    """Interpolate vertical profiles of one field in time.

    A profile is taken from every dataset with ``get_vert_profile`` (using
    its default spatial method), the profiles are stacked along a ``time``
    dimension labelled with ``time_s``, and the stack is interpolated to
    ``time``.

    Args:
        xr_dataset_s: Datasets, one per entry of ``time_s``.
        time_s: Time coordinate values for the datasets.
        fld_name: Name of the field.
        time: Time(s) to interpolate to.
        lons: Longitudes of the points.
        lats: Latitudes of the points.
        method: Method for the interpolation in time only.

    Returns:
        The interpolated values as a plain array (``.values``).
    """
    profiles = [get_vert_profile(ds, fld_name, lons, lats) for ds in xr_dataset_s]
    stacked = xr.concat(profiles, 'time').assign_coords(time=time_s)
    return stacked.interp(time=time, method=method).values


def get_time_interpolated_vert_profile_s(xr_dataset_s, time_s, fld_name_s, time, lons, lats, method='linear'):
    """Interpolate vertical profiles of several fields in time.

    Multi-field profiles are taken from every dataset with
    ``get_vert_profile_s`` (using its default spatial method), stacked along
    a ``time`` dimension labelled with ``time_s``, and interpolated to
    ``time``.

    Args:
        xr_dataset_s: Datasets, one per entry of ``time_s``.
        time_s: Time coordinate values for the datasets.
        fld_name_s: Names of the fields.
        time: Time(s) to interpolate to.
        lons: Longitudes of the points.
        lats: Latitudes of the points.
        method: Method for the interpolation in time only.

    Returns:
        The interpolated values as a plain array (``.values``).
    """
    profiles = [get_vert_profile_s(ds, fld_name_s, lons, lats) for ds in xr_dataset_s]
    stacked = xr.concat(profiles, 'time').assign_coords(time=time_s)
    return stacked.interp(time=time, method=method).values


def get_time_interpolated_point(ds_0, ds_1, time0, time1, fld_name, time, lons, lats):
    """Linearly interpolate point values of one field between two datasets.

    ``get_point`` is evaluated on both datasets, the two results are stacked
    along a ``time`` dimension labelled ``[time0, time1]``, and the stack is
    linearly interpolated to ``time``.

    Args:
        ds_0: Dataset at ``time0``.
        ds_1: Dataset at ``time1``.
        time0: Time coordinate of ``ds_0``.
        time1: Time coordinate of ``ds_1``.
        fld_name: Name of the field.
        time: Time(s) to interpolate to.
        lons: Longitudes of the points.
        lats: Latitudes of the points.

    Returns:
        The interpolated values as a plain array (``.values``).
    """
    endpoints = [get_point(ds, fld_name, lons, lats) for ds in (ds_0, ds_1)]
    stacked = xr.concat(endpoints, 'time').assign_coords(time=[time0, time1])
    return stacked.interp(time=time, method='linear').values


def get_time_interpolated_voxel(xr_dataset_s, time_s, time, fld_name, lon, lat, press, x_width=5, y_width=5, z_width=3, method='linear'):
    """Interpolate a voxel of one field in time.

    A voxel is taken from every dataset with ``get_voxel``, the voxels are
    stacked along a ``time`` dimension labelled with ``time_s``, and the
    stack is interpolated to ``time``.

    Args:
        xr_dataset_s: Datasets, one per entry of ``time_s``.
        time_s: Time coordinate values for the datasets.
        time: Time(s) to interpolate to.
        fld_name: Name of the field.
        lon: Longitude of the voxel centre.
        lat: Latitude of the voxel centre.
        press: Pressure of the voxel centre.
        x_width: Voxel width along longitude, in grid points.
        y_width: Voxel width along latitude, in grid points.
        z_width: Voxel width along pressure, in levels.
        method: Interpolation method passed to ``interp``.

    Returns:
        The time-interpolated voxel, or ``None`` when ``get_voxel`` returns
        ``None`` for any dataset (the voxel does not fit inside the grid).
    """
    vox_s = []
    for ds in xr_dataset_s:
        vox = get_voxel(ds, fld_name, lon, lat, press, x_width=x_width, y_width=y_width, z_width=z_width)
        if vox is None:
            # Propagate "no voxel" instead of handing None to xr.concat.
            return None
        vox_s.append(vox)

    vox = xr.concat(vox_s, 'time')
    vox = vox.assign_coords(time=time_s)

    intrp_vox = vox.interp(time=time, method=method)

    return intrp_vox


def get_voxel(xr_dataset, fld_name, lon, lat, press, x_width=5, y_width=5, z_width=3):
    """Extract a small 3-D box of one field centred on a point.

    The centre indices come from ``value_to_index`` on ``lon_coords``,
    ``lat_coords`` and ``plevs`` (presumably the nearest grid index --
    confirm in util.util).  The box extends ``int(width / 2)`` points on each
    side of the centre along every axis.

    Args:
        xr_dataset: Dataset indexable by field name.
        fld_name: Name of the field.
        lon: Longitude of the centre; shifted by 360 if negative and
            ``lon360`` is set.
        lat: Latitude of the centre.
        press: Pressure of the centre.
        x_width: Box width along longitude, in grid points.
        y_width: Box width along latitude, in grid points.
        z_width: Box width along pressure, in levels.

    Returns:
        The box with a leading ``channel`` dimension labelled with
        ``fld_name`` and coordinates set from the module-level axes, or
        ``None`` when the box would extend past any edge of the grid.
    """
    if lon360 and lon < 0:
        lon += 360

    fld = xr_dataset[fld_name]

    p_c = value_to_index(plevs, press)
    x_c = value_to_index(lon_coords, lon)
    y_c = value_to_index(lat_coords, lat)
    y_h = int(y_width / 2)
    x_h = int(x_width / 2)
    p_h = int(z_width / 2)

    y_start = y_c - y_h
    x_start = x_c - x_h
    z_start = p_c - p_h
    if y_start < 0 or x_start < 0 or z_start < 0:
        return None

    # Stop indices are exclusive, so a stop equal to the axis length is still
    # inside the grid; only reject boxes that go beyond it.
    y_stop = y_c + y_h + 1
    x_stop = x_c + x_h + 1
    z_stop = p_c + p_h + 1
    if y_stop > NY or x_stop > NX or z_stop > NZ:
        return None

    sub_fld = fld[y_start:y_stop, x_start:x_stop, z_start:z_stop]

    sub_fld = sub_fld.expand_dims('channel')
    sub_fld = sub_fld.assign_coords(channel=[fld_name], fakeDim2=lon_coords[x_start:x_stop],
                                    fakeDim1=lat_coords[y_start:y_stop], fakeDim0=plevs[z_start:z_stop])

    return sub_fld


def get_time_interpolated_voxel_s(xr_dataset_s, time_s, time, fld_name_s, lon, lat, press, x_width=5, y_width=5, z_width=3, method='linear'):
    """Interpolate a voxel of several fields in time.

    A multi-field voxel is taken from every dataset with ``get_voxel_s``,
    the voxels are stacked along a ``time`` dimension labelled with
    ``time_s``, and the stack is interpolated to ``time``.

    Args:
        xr_dataset_s: Datasets, one per entry of ``time_s``.
        time_s: Time coordinate values for the datasets.
        time: Time(s) to interpolate to.
        fld_name_s: Names of the fields.
        lon: Longitude of the voxel centre.
        lat: Latitude of the voxel centre.
        press: Pressure of the voxel centre.
        x_width: Voxel width along longitude, in grid points.
        y_width: Voxel width along latitude, in grid points.
        z_width: Voxel width along pressure, in levels.
        method: Interpolation method passed to ``interp``.

    Returns:
        The time-interpolated voxel, or ``None`` when ``get_voxel_s`` returns
        ``None`` for any dataset (the voxel does not fit inside the grid).
    """
    vox_s = []
    for ds in xr_dataset_s:
        vox = get_voxel_s(ds, fld_name_s, lon, lat, press, x_width=x_width, y_width=y_width, z_width=z_width)
        if vox is None:
            # Propagate "no voxel" instead of handing None to xr.concat.
            return None
        vox_s.append(vox)

    vox = xr.concat(vox_s, 'time')
    vox = vox.assign_coords(time=time_s)

    intrp_vox = vox.interp(time=time, method=method)

    return intrp_vox


def get_voxel_s(xr_dataset, fld_name_s, lon, lat, press, x_width=5, y_width=5, z_width=3):
    """Extract a small 3-D box of several fields centred on a point.

    The centre indices come from ``value_to_index`` on ``lon_coords``,
    ``lat_coords`` and ``plevs`` (presumably the nearest grid index --
    confirm in util.util).  The box extends ``int(width / 2)`` points on each
    side of the centre along every axis.

    Args:
        xr_dataset: Dataset indexable by field name.
        fld_name_s: Names of the fields.
        lon: Longitude of the centre; shifted by 360 if negative and
            ``lon360`` is set.
        lat: Latitude of the centre.
        press: Pressure of the centre.
        x_width: Box width along longitude, in grid points.
        y_width: Box width along latitude, in grid points.
        z_width: Box width along pressure, in levels.

    Returns:
        The per-field boxes concatenated along a ``channel`` dimension
        labelled with ``fld_name_s`` and with coordinates set from the
        module-level axes, or ``None`` when the box would extend past any
        edge of the grid.
    """
    if lon360 and lon < 0:
        lon += 360

    p_c = value_to_index(plevs, press)
    x_c = value_to_index(lon_coords, lon)
    y_c = value_to_index(lat_coords, lat)
    y_h = int(y_width / 2)
    x_h = int(x_width / 2)
    p_h = int(z_width / 2)

    y_start = y_c - y_h
    x_start = x_c - x_h
    z_start = p_c - p_h
    if y_start < 0 or x_start < 0 or z_start < 0:
        return None

    # Stop indices are exclusive, so a stop equal to the axis length is still
    # inside the grid; only reject boxes that go beyond it.
    y_stop = y_c + y_h + 1
    x_stop = x_c + x_h + 1
    z_stop = p_c + p_h + 1
    if y_stop > NY or x_stop > NX or z_stop > NZ:
        return None

    sub_fld_s = [xr_dataset[name][y_start:y_stop, x_start:x_stop, z_start:z_stop] for name in fld_name_s]
    sub_fld = xr.concat(sub_fld_s, 'channel')

    sub_fld = sub_fld.assign_coords(channel=fld_name_s, fakeDim2=lon_coords[x_start:x_stop],
                                    fakeDim1=lat_coords[y_start:y_stop], fakeDim0=plevs[z_start:z_stop])

    return sub_fld