# gfs_reader.py
import datetime
from datetime import timezone
import glob
import os
import numpy as np
import xarray as xr
# from util.util import value_to_index, homedir
from metpy.units import units

# Alternate archive location (disabled):
#   '/ships22/cloud/Ancil_Data/clavrx_ancil_data/dynamic/gfs/'
# GFS files are looked up under the current user's home directory.
homedir = os.path.expanduser('~') + '/'
gfs_directory = homedir + 'data/gfs/'

# strftime/strptime pattern of the date part of a GFS file name (2-digit year).
gfs_date_format = '%y%m%d'

# When True, negative incoming longitudes are shifted by +360 so they fall on
# the (0, 360) longitude axis used by GFS.
lon360 = True

# Half-degree global grid: 720 longitudes by 361 latitudes.
NX, NY = 720, 361
lon_coords = np.linspace(0, 359.5, num=NX)
lat_coords = np.linspace(-90, 90, num=NY)

# Pressure levels of the vertical axis, in ascending order.
plevs = np.array([
    10.0, 20.0, 30.0, 50.0, 70.0, 100.0, 150.0, 200.0, 250.0, 300.0,
    350.0, 400.0, 450.0, 500.0, 550.0, 600.0, 650.0, 700.0, 750.0, 800.0,
    850.0, 900.0, 925.0, 950.0, 975.0, 1000.0,
])
NZ = len(plevs)


class MyGenericException(Exception):
    """General-purpose error that carries a human-readable message.

    Attributes:
        message: The text passed to the constructor.
    """

    def __init__(self, message):
        # Forward the message to Exception so that ``str(exc)`` and
        # ``exc.args`` are populated even when ``message`` is passed by
        # keyword (the base class only records positional arguments itself).
        super().__init__(message)
        self.message = message

def value_to_index(nda, value):
    """Return the index of the element of ``nda`` closest to ``value``.

    Args:
        nda: 1-D array-like of numbers (ndarray, list or tuple).
        value: Scalar to look up.

    Returns:
        NumPy integer index of the nearest element. When several elements are
        equally close, the lowest index is returned (``np.argmin`` behaviour).

    Raises:
        ValueError: If ``nda`` is empty (raised by ``np.argmin``).
    """
    # asarray leaves ndarrays untouched and lets plain sequences be used too.
    distance = np.abs(np.asarray(nda) - value)
    return np.argmin(distance)


def get_timestamp(filename):
    """Return the POSIX timestamp derived from a GFS file name.

    The second dot-separated token of ``filename``, up to its first
    underscore, is parsed as ``gfs_date_format`` followed by an hour and
    interpreted as UTC. Twelve hours are then added before converting to a
    timestamp.
    NOTE(review): the 12 h shift presumably corresponds to the ``F012``
    forecast files this module globs for -- confirm.

    Args:
        filename: Base file name such as ``gfs.<date><hour>_F012.h5``.

    Returns:
        Seconds since the epoch as a float.
    """
    date_hour = filename.split('.')[1].split('_')[0]
    parsed = datetime.datetime.strptime(date_hour, gfs_date_format + '%H')
    shifted = parsed.replace(tzinfo=timezone.utc) + datetime.timedelta(hours=12)
    return shifted.timestamp()


def get_time_tuple_utc(timestamp):
    """Convert a POSIX timestamp to a UTC datetime and its time tuple.

    Args:
        timestamp: Seconds since the epoch.

    Returns:
        A pair ``(dt, tt)`` where ``dt`` is a timezone-aware UTC
        ``datetime`` and ``tt`` is ``dt.timetuple()``.
    """
    utc_dt = datetime.datetime.fromtimestamp(timestamp, tz=timezone.utc)
    time_tuple = utc_dt.timetuple()
    return utc_dt, time_tuple


def get_bounding_gfs_files(timestamp):
    """Find the two GFS files whose times bracket ``timestamp``.

    Files named ``gfs.<date>??_F012.h5`` are searched in ``gfs_directory``
    for three consecutive days: the UTC date of ``timestamp`` plus 12 hours,
    the day before it and the day after it. Each file's time is obtained
    with ``get_timestamp``.

    Args:
        timestamp: POSIX timestamp (seconds since the epoch).

    Returns:
        ``(left_file, left_time, right_file, right_time)`` where the left
        file is the latest one with time <= ``timestamp`` and the right file
        is the next one in time order. If ``timestamp`` coincides with the
        newest file found, the bracket is shifted back by one so that two
        distinct files are still returned. ``(None, None, None, None)`` is
        returned when no files are found, when ``timestamp`` is not covered
        by the files found, or when only a single file is available.
    """
    not_found = (None, None, None, None)

    dt_obj, _ = get_time_tuple_utc(timestamp)
    center = dt_obj + datetime.timedelta(hours=12)
    one_day = datetime.timedelta(days=1)

    # Collect candidates for the previous, central and next day.
    filelist = []
    for day in (center - one_day, center, center + one_day):
        date_str = day.strftime(gfs_date_format)
        filelist += glob.glob(gfs_directory+'gfs.'+date_str+'??_F012.h5')
    if not filelist:
        return not_found

    # TODO: make better with regular expressions (someday)
    ftimes = np.array([get_timestamp(os.path.split(pname)[1]) for pname in filelist])
    order = ftimes.argsort()
    ftimes = ftimes[order]
    files = [filelist[i] for i in order]

    below = ftimes <= timestamp
    above = ftimes >= timestamp
    if not above.any() or not below.any():
        return not_found

    idxs = np.arange(len(files))
    i_left = idxs[below].max()
    i_right = i_left + 1
    if i_right >= len(files):
        # timestamp equals the newest file time: there is no later file, so
        # use the preceding file as the left bound instead of indexing past
        # the end of the list.
        if i_left == 0:
            return not_found
        i_left, i_right = i_left - 1, i_left

    return files[i_left], ftimes[i_left], files[i_right], ftimes[i_right]


def get_horz_layer(xr_dataset, fld_name, press, lon_range=None, lat_range=None, unit_str='m/s'):
    """Extract a horizontal slice of one field at the level nearest ``press``.

    The field's values are indexed as ``[latitude, longitude, pressure]``
    using the module grid (``lat_coords``, ``lon_coords``, ``plevs``).

    Args:
        xr_dataset: Dataset indexable by ``fld_name``.
        fld_name: Name of the field to read.
        press: Pressure; the nearest entry of ``plevs`` is used.
        lon_range: Optional ``(lon_lo, lon_hi)``. Negative values are shifted
            by +360 when ``lon360`` is set. Each bound is snapped to the
            nearest grid longitude.
        lat_range: Optional ``(lat_lo, lat_hi)``, snapped to the nearest grid
            latitudes. Handled independently of ``lon_range``.
        unit_str: Value stored in the ``units`` attribute of the result
            (defaults to ``'m/s'``).

    Returns:
        2-D ``xr.DataArray`` with dims ``('latitude', 'longitude')``. The
        upper bound of each range is exclusive (plain slice semantics), and a
        range whose low index exceeds its high index (e.g. a longitude range
        crossing 0 degrees) yields an empty slice.
    """
    p_idx = value_to_index(plevs, press)

    x_lo, x_hi = 0, NX
    y_lo, y_hi = 0, NY

    if lon_range is not None:
        lon_lo = lon_range[0]
        lon_hi = lon_range[1]
        if lon360:
            if lon_lo < 0:
                lon_lo += 360
            if lon_hi < 0:
                lon_hi += 360
        x_lo = value_to_index(lon_coords, lon_lo)
        x_hi = value_to_index(lon_coords, lon_hi)

    # Latitude bounds no longer depend on lon_range being supplied.
    if lat_range is not None:
        y_lo = value_to_index(lat_coords, lat_range[0])
        y_hi = value_to_index(lat_coords, lat_range[1])

    nda = xr_dataset[fld_name].values
    sub_nda = nda[y_lo:y_hi, x_lo:x_hi, p_idx]
    xra = xr.DataArray(sub_nda, dims=['latitude', 'longitude'],
                       coords={"latitude": lat_coords[y_lo:y_hi], "longitude": lon_coords[x_lo:x_hi]},
                       attrs={"description": fld_name, "units": unit_str})

    return xra


def get_horz_layer_s(xr_dataset, fld_names, press, lon_range=None, lat_range=None):
    """Extract horizontal slices of several fields at the level nearest ``press``.

    Each field is sliced as ``[latitude, longitude, pressure]`` and the
    slices are concatenated along a new ``channel`` dimension.

    Args:
        xr_dataset: Dataset indexable by each name in ``fld_names``.
        fld_names: Sequence of field names; becomes the ``channel`` coordinate.
        press: Pressure; the nearest entry of ``plevs`` is used.
        lon_range: Optional ``(lon_lo, lon_hi)``. Negative values are shifted
            by +360 when ``lon360`` is set. Each bound is snapped to the
            nearest grid longitude.
        lat_range: Optional ``(lat_lo, lat_hi)``, snapped to the nearest grid
            latitudes. Handled independently of ``lon_range``.

    Returns:
        ``xr.DataArray`` with a ``channel`` dimension and coordinates assigned
        to ``fakeDim1`` (latitude) and ``fakeDim2`` (longitude). The upper
        bound of each range is exclusive (plain slice semantics).
    """
    p_idx = value_to_index(plevs, press)

    x_lo, x_hi = 0, NX
    y_lo, y_hi = 0, NY

    if lon_range is not None:
        lon_lo = lon_range[0]
        lon_hi = lon_range[1]
        if lon360:
            if lon_lo < 0:
                lon_lo += 360
            if lon_hi < 0:
                lon_hi += 360
        x_lo = value_to_index(lon_coords, lon_lo)
        x_hi = value_to_index(lon_coords, lon_hi)

    # Latitude bounds no longer depend on lon_range being supplied.
    if lat_range is not None:
        y_lo = value_to_index(lat_coords, lat_range[0])
        y_hi = value_to_index(lat_coords, lat_range[1])

    sub_fld_s = [xr_dataset[fld_name][y_lo:y_hi, x_lo:x_hi, p_idx] for fld_name in fld_names]

    sub_fld = xr.concat(sub_fld_s, 'channel')
    sub_fld = sub_fld.assign_coords(channel=fld_names, fakeDim2=lon_coords[x_lo:x_hi], fakeDim1=lat_coords[y_lo:y_hi])

    return sub_fld


def get_time_interpolated_layer(xr_dataset_s, time_s, time, fld_name, press, lon_range=None, lat_range=None, method='linear'):
    """Interpolate a horizontal layer of one field to ``time``.

    A layer is extracted from every dataset with ``get_horz_layer``, the
    layers are stacked along a new ``time`` dimension labelled with
    ``time_s``, and the stack is interpolated to ``time``.

    Args:
        xr_dataset_s: Iterable of datasets, one per entry of ``time_s``.
        time_s: Time coordinate values for the stacked layers.
        time: Target time(s) passed to ``interp``.
        fld_name: Name of the field to read.
        press: Pressure of the requested level.
        lon_range: Optional longitude bounds forwarded to ``get_horz_layer``.
        lat_range: Optional latitude bounds forwarded to ``get_horz_layer``.
        method: Interpolation method used along the time axis.

    Returns:
        The interpolated ``xr.DataArray``.
    """
    stacked = xr.concat(
        [get_horz_layer(dataset, fld_name, press, lon_range=lon_range, lat_range=lat_range)
         for dataset in xr_dataset_s],
        'time',
    )
    stacked = stacked.assign_coords(time=time_s)
    return stacked.interp(time=time, method=method)


def get_time_interpolated_layer_s(xr_dataset_s, time_s, time, fld_name_s, press, lon_range=None, lat_range=None, method='linear'):
    """Interpolate horizontal layers of several fields to ``time``.

    Layers are extracted from every dataset with ``get_horz_layer_s``,
    stacked along a new ``time`` dimension labelled with ``time_s``, and
    interpolated to ``time``.

    Args:
        xr_dataset_s: Iterable of datasets, one per entry of ``time_s``.
        time_s: Time coordinate values for the stacked layers.
        time: Target time(s) passed to ``interp``.
        fld_name_s: Sequence of field names.
        press: Pressure of the requested level.
        lon_range: Optional longitude bounds forwarded to ``get_horz_layer_s``.
        lat_range: Optional latitude bounds forwarded to ``get_horz_layer_s``.
        method: Interpolation method used along the time axis.

    Returns:
        The interpolated ``xr.DataArray``.
    """
    layer_s = [
        get_horz_layer_s(ds, fld_name_s, press, lon_range=lon_range, lat_range=lat_range)
        for ds in xr_dataset_s
    ]

    lyr = xr.concat(layer_s, 'time')
    lyr = lyr.assign_coords(time=time_s)

    # Honour the caller's method instead of always using 'linear'.
    intrp_lyr = lyr.interp(time=time, method=method)

    return intrp_lyr


def get_vert_profile(xr_dataset, fld_name, lons, lats, method='linear'):
    """Interpolate one field to a set of (lon, lat) points on all levels.

    The module grid (``lon_coords``, ``lat_coords``, ``plevs``) is attached
    to ``fakeDim2``, ``fakeDim1`` and ``fakeDim0`` of the field, which is
    then interpolated pointwise along a shared ``k`` dimension.

    Args:
        xr_dataset: Dataset indexable by ``fld_name``.
        fld_name: Name of the field to read.
        lons: Array of longitudes; negative values are shifted by +360 when
            ``lon360`` is set.
        lats: Array of latitudes, paired element-wise with ``lons``.
        method: Interpolation method passed to ``interp``.

    Returns:
        The interpolated ``xr.DataArray``.
    """
    if lon360:
        # Map -180/+180 longitudes onto 0..360.
        lons = np.where(lons < 0, lons + 360, lons)

    gridded = xr_dataset[fld_name].assign_coords(fakeDim2=lon_coords, fakeDim1=lat_coords, fakeDim0=plevs)

    lon_pts = xr.DataArray(lons, dims='k')
    lat_pts = xr.DataArray(lats, dims='k')

    return gridded.interp(fakeDim1=lat_pts, fakeDim2=lon_pts, fakeDim0=plevs, method=method)


def get_vert_profile_s(xr_dataset, fld_name_s, lons, lats, method='linear'):
    """Interpolate several fields to a set of (lon, lat) points on all levels.

    Each field gets the module grid attached to ``fakeDim2``, ``fakeDim1``
    and ``fakeDim0``; the fields are concatenated along ``fld_dim`` and
    interpolated pointwise along a shared ``k`` dimension.

    Args:
        xr_dataset: Dataset indexable by each name in ``fld_name_s``.
        fld_name_s: Sequence of field names.
        lons: Array of longitudes; negative values are shifted by +360 when
            ``lon360`` is set.
        lats: Array of latitudes, paired element-wise with ``lons``.
        method: Interpolation method passed to ``interp``.

    Returns:
        The interpolated ``xr.DataArray``.
    """
    if lon360:
        # Map -180,+180 longitudes onto 0..360.
        lons = np.where(lons < 0, lons + 360, lons)

    gridded_fields = [
        xr_dataset[name].assign_coords(fakeDim2=lon_coords, fakeDim1=lat_coords, fakeDim0=plevs)
        for name in fld_name_s
    ]
    stacked = xr.concat(gridded_fields, 'fld_dim')

    lon_pts = xr.DataArray(lons, dims='k')
    lat_pts = xr.DataArray(lats, dims='k')

    return stacked.interp(fakeDim1=lat_pts, fakeDim2=lon_pts, fakeDim0=plevs, method=method)


def get_point(xr_dataset, fld_name, lons, lats, pres_s=None, method='nearest'):
    """Sample one field at a set of points.

    Latitude and longitude axes are built from the sizes of the dataset's
    ``fakeDim1`` and ``fakeDim2`` (spanning -90..90 and 0..359.5). When
    ``pres_s`` is given, ``plevs`` is attached to ``fakeDim0`` and the field
    is also sampled in pressure; otherwise only the two horizontal
    dimensions are interpolated.

    Args:
        xr_dataset: Dataset indexable by ``fld_name``.
        fld_name: Name of the field to read.
        lons: Array of longitudes; negative values are shifted by +360 when
            ``lon360`` is set.
        lats: Array of latitudes, paired element-wise with ``lons``.
        pres_s: Optional pressures, paired element-wise with the points.
        method: Interpolation method passed to ``interp``.

    Returns:
        The interpolated ``xr.DataArray`` along the shared ``k`` dimension.
    """
    if lon360:
        # Map -180/+180 longitudes onto 0..360.
        lons = np.where(lons < 0, lons + 360, lons)

    lat_axis = np.linspace(-90, 90, xr_dataset.fakeDim1.size)
    lon_axis = np.linspace(0, 359.5, xr_dataset.fakeDim2.size)

    with_pressure = pres_s is not None

    grid_coords = {'fakeDim2': lon_axis, 'fakeDim1': lat_axis}
    if with_pressure:
        grid_coords['fakeDim0'] = plevs
    gridded = xr_dataset[fld_name].assign_coords(**grid_coords)

    targets = {}
    if with_pressure:
        targets['fakeDim0'] = xr.DataArray(pres_s, dims='k')
    targets['fakeDim1'] = xr.DataArray(lats, dims='k')
    targets['fakeDim2'] = xr.DataArray(lons, dims='k')

    return gridded.interp(method=method, **targets)


def get_point_s(xr_dataset, fld_name_s, lons, lats, pres_s=None, method='nearest'):
    """Sample several fields at a set of points.

    Latitude and longitude axes are built from the sizes of the dataset's
    ``fakeDim1`` and ``fakeDim2`` (spanning -90..90 and 0..359.5). The
    fields are concatenated along ``fld_dim``. When ``pres_s`` is given,
    ``plevs`` is attached to ``fakeDim0`` and the fields are also sampled in
    pressure; otherwise only the two horizontal dimensions are interpolated.

    Args:
        xr_dataset: Dataset indexable by each name in ``fld_name_s``.
        fld_name_s: Sequence of field names.
        lons: Array of longitudes; negative values are shifted by +360 when
            ``lon360`` is set.
        lats: Array of latitudes, paired element-wise with ``lons``.
        pres_s: Optional pressures, paired element-wise with the points.
        method: Interpolation method passed to ``interp``.

    Returns:
        The interpolated ``xr.DataArray`` along the shared ``k`` dimension.
    """
    if lon360:
        # Map -180/+180 longitudes onto 0..360.
        lons = np.where(lons < 0, lons + 360, lons)

    lat_axis = np.linspace(-90, 90, xr_dataset.fakeDim1.size)
    lon_axis = np.linspace(0, 359.5, xr_dataset.fakeDim2.size)

    with_pressure = pres_s is not None

    grid_coords = {'fakeDim2': lon_axis, 'fakeDim1': lat_axis}
    if with_pressure:
        grid_coords['fakeDim0'] = plevs

    stacked = xr.concat(
        [xr_dataset[name].assign_coords(**grid_coords) for name in fld_name_s],
        'fld_dim',
    )

    targets = {}
    if with_pressure:
        targets['fakeDim0'] = xr.DataArray(pres_s, dims='k')
    targets['fakeDim1'] = xr.DataArray(lats, dims='k')
    targets['fakeDim2'] = xr.DataArray(lons, dims='k')

    return stacked.interp(method=method, **targets)


def get_time_interpolated_vert_profile(xr_dataset_s, time_s, fld_name, time, lons, lats, method='linear'):
    """Interpolate vertical profiles of one field to ``time``.

    A profile set is computed for every dataset with ``get_vert_profile``,
    the results are stacked along a new ``time`` dimension labelled with
    ``time_s``, and the stack is interpolated to ``time``.

    Args:
        xr_dataset_s: Iterable of datasets, one per entry of ``time_s``.
        time_s: Time coordinate values for the stacked profiles.
        fld_name: Name of the field to read.
        time: Target time(s) passed to ``interp``.
        lons: Array of longitudes of the profile locations.
        lats: Array of latitudes of the profile locations.
        method: Interpolation method, used both spatially and in time.

    Returns:
        The interpolated values as a NumPy array (``.values``).
    """
    profiles = [get_vert_profile(dataset, fld_name, lons, lats, method=method) for dataset in xr_dataset_s]

    stacked = xr.concat(profiles, 'time').assign_coords(time=time_s)

    return stacked.interp(time=time, method=method).values


def get_time_interpolated_vert_profile_s(xr_dataset_s, time_s, fld_name_s, time, lons, lats, method='linear'):
    """Interpolate vertical profiles of several fields to ``time``.

    Profiles are computed for every dataset with ``get_vert_profile_s``,
    stacked along a new ``time`` dimension labelled with ``time_s``, and
    interpolated to ``time``.

    Args:
        xr_dataset_s: Iterable of datasets, one per entry of ``time_s``.
        time_s: Time coordinate values for the stacked profiles.
        fld_name_s: Sequence of field names.
        time: Target time(s) passed to ``interp``.
        lons: Array of longitudes of the profile locations.
        lats: Array of latitudes of the profile locations.
        method: Interpolation method, used both spatially and in time.

    Returns:
        The interpolated values as a NumPy array (``.values``).
    """
    profiles = [get_vert_profile_s(dataset, fld_name_s, lons, lats, method=method) for dataset in xr_dataset_s]

    stacked = xr.concat(profiles, 'time').assign_coords(time=time_s)

    return stacked.interp(time=time, method=method).values


def get_time_interpolated_point(ds_0, ds_1, time0, time1, fld_name, time, lons, lats, method='linear'):
    """Interpolate point samples of one field between two datasets in time.

    The field is sampled at the points in both datasets with ``get_point``;
    the two results are stacked along a ``time`` dimension labelled
    ``[time0, time1]`` and interpolated to ``time``.

    Args:
        ds_0: Dataset valid at ``time0``.
        ds_1: Dataset valid at ``time1``.
        time0: Time coordinate of ``ds_0``.
        time1: Time coordinate of ``ds_1``.
        fld_name: Name of the field to read.
        time: Target time(s) passed to ``interp``.
        lons: Array of longitudes of the points.
        lats: Array of latitudes of the points.
        method: Interpolation method used along the time axis only.

    Returns:
        The interpolated values as a NumPy array (``.values``).
    """
    # NOTE(review): the spatial sampling uses get_point's default method,
    # not ``method`` -- confirm this is intended.
    samples = [get_point(dataset, fld_name, lons, lats) for dataset in (ds_0, ds_1)]

    stacked = xr.concat(samples, 'time').assign_coords(time=[time0, time1])

    return stacked.interp(time=time, method=method).values


def get_time_interpolated_point_s(xr_dataset_s, time_s, fld_name_s, time, lons, lats, method='linear'):
    """Interpolate point samples of several fields to ``time``.

    The fields are sampled at the points in every dataset with
    ``get_point_s``, the results are stacked along a new ``time`` dimension
    labelled with ``time_s``, and the stack is interpolated to ``time``.

    Args:
        xr_dataset_s: Iterable of datasets, one per entry of ``time_s``.
        time_s: Time coordinate values for the stacked samples.
        fld_name_s: Sequence of field names.
        time: Target time(s) passed to ``interp``.
        lons: Array of longitudes of the points.
        lats: Array of latitudes of the points.
        method: Interpolation method, used both spatially and in time.

    Returns:
        The interpolated values as a NumPy array (``.values``).
    """
    samples = [get_point_s(dataset, fld_name_s, lons, lats, method=method) for dataset in xr_dataset_s]

    stacked = xr.concat(samples, 'time').assign_coords(time=time_s)

    return stacked.interp(time=time, method=method).values


def get_time_interpolated_voxel(xr_dataset_s, time_s, time, fld_name, lon, lat, press, x_width=3, y_width=3, z_width=3, method='linear'):
    """Interpolate a voxel of one field to ``time``.

    A voxel is extracted from every dataset with ``get_voxel``, the voxels
    are stacked along a new ``time`` dimension labelled with ``time_s``, and
    the stack is interpolated to ``time``.

    Args:
        xr_dataset_s: Iterable of datasets, one per entry of ``time_s``.
        time_s: Time coordinate values for the stacked voxels.
        time: Target time(s) passed to ``interp``.
        fld_name: Name of the field to read.
        lon: Longitude of the voxel centre.
        lat: Latitude of the voxel centre.
        press: Pressure of the voxel centre.
        x_width: Voxel width along longitude, in grid cells.
        y_width: Voxel width along latitude, in grid cells.
        z_width: Voxel width along pressure, in levels.
        method: Interpolation method used along the time axis.

    Returns:
        The interpolated ``xr.DataArray``, or ``None`` if ``get_voxel``
        returned ``None`` for any dataset (voxel not fully inside the grid).
    """
    vox_s = []
    for ds in xr_dataset_s:
        vox = get_voxel(ds, fld_name, lon, lat, press, x_width=x_width, y_width=y_width, z_width=z_width)
        if vox is None:
            # Nothing to interpolate; concatenating None would raise.
            return None
        vox_s.append(vox)

    vox = xr.concat(vox_s, 'time')
    vox = vox.assign_coords(time=time_s)

    intrp_vox = vox.interp(time=time, method=method)

    return intrp_vox


def get_voxel(xr_dataset, fld_name, lon, lat, press, x_width=3, y_width=3, z_width=3):
    """Extract a small box of one field centred on the nearest grid point.

    The centre indices are the grid points of ``lon_coords``, ``lat_coords``
    and ``plevs`` nearest to ``lon``, ``lat`` and ``press``. The box extends
    ``int(width / 2)`` cells on each side of the centre, so it spans
    ``2 * int(width / 2) + 1`` cells per axis. The field is sliced as
    ``[latitude, longitude, pressure]``.

    Args:
        xr_dataset: Dataset indexable by ``fld_name``.
        fld_name: Name of the field to read.
        lon: Longitude of the centre; shifted by +360 if negative and
            ``lon360`` is set.
        lat: Latitude of the centre.
        press: Pressure of the centre.
        x_width: Box width along longitude, in grid cells.
        y_width: Box width along latitude, in grid cells.
        z_width: Box width along pressure, in levels.

    Returns:
        ``xr.DataArray`` with an added ``channel`` dimension labelled
        ``[fld_name]`` and coordinates on ``fakeDim2`` (longitude),
        ``fakeDim1`` (latitude) and ``fakeDim0`` (pressure), or ``None`` when
        the box would extend beyond the grid on any axis (no longitude
        wrap-around is attempted).
    """
    if lon360 and lon < 0:
        lon += 360

    fld = xr_dataset[fld_name]

    p_c = value_to_index(plevs, press)
    x_c = value_to_index(lon_coords, lon)
    y_c = value_to_index(lat_coords, lat)
    y_h = int(y_width / 2)
    x_h = int(x_width / 2)
    p_h = int(z_width / 2)

    y_start = y_c - y_h
    x_start = x_c - x_h
    z_start = p_c - p_h
    if y_start < 0 or x_start < 0 or z_start < 0:
        return None

    # Stop indices are exclusive, so a stop equal to the axis length is still
    # inside the grid; only larger values fall off the edge.
    y_stop = y_c + y_h + 1
    x_stop = x_c + x_h + 1
    z_stop = p_c + p_h + 1
    if y_stop > NY or x_stop > NX or z_stop > NZ:
        return None

    sub_fld = fld[y_start:y_stop, x_start:x_stop, z_start:z_stop]

    sub_fld = sub_fld.expand_dims('channel')
    sub_fld = sub_fld.assign_coords(channel=[fld_name], fakeDim2=lon_coords[x_start:x_stop],
                                    fakeDim1=lat_coords[y_start:y_stop], fakeDim0=plevs[z_start:z_stop])

    return sub_fld


def get_time_interpolated_voxel_s(xr_dataset_s, time_s, time, fld_name_s, lon, lat, press, x_width=3, y_width=3, z_width=3, method='linear'):
    """Interpolate a voxel of several fields to ``time``.

    A multi-field voxel is extracted from every dataset with
    ``get_voxel_s``, the voxels are stacked along a new ``time`` dimension
    labelled with ``time_s``, and the stack is interpolated to ``time``.

    Args:
        xr_dataset_s: Iterable of datasets, one per entry of ``time_s``.
        time_s: Time coordinate values for the stacked voxels.
        time: Target time(s) passed to ``interp``.
        fld_name_s: Sequence of field names.
        lon: Longitude of the voxel centre.
        lat: Latitude of the voxel centre.
        press: Pressure of the voxel centre.
        x_width: Voxel width along longitude, in grid cells.
        y_width: Voxel width along latitude, in grid cells.
        z_width: Voxel width along pressure, in levels.
        method: Interpolation method used along the time axis.

    Returns:
        The interpolated ``xr.DataArray``, or ``None`` if ``get_voxel_s``
        returned ``None`` for any dataset (voxel not fully inside the grid).
    """
    vox_s = []
    for ds in xr_dataset_s:
        vox = get_voxel_s(ds, fld_name_s, lon, lat, press, x_width=x_width, y_width=y_width, z_width=z_width)
        if vox is None:
            # Nothing to interpolate; concatenating None would raise.
            return None
        vox_s.append(vox)

    vox = xr.concat(vox_s, 'time')
    vox = vox.assign_coords(time=time_s)

    intrp_vox = vox.interp(time=time, method=method)

    return intrp_vox


def get_voxel_s(xr_dataset, fld_name_s, lon, lat, press, x_width=3, y_width=3, z_width=3):
    """Extract a small box of several fields centred on the nearest grid point.

    The centre indices are the grid points of ``lon_coords``, ``lat_coords``
    and ``plevs`` nearest to ``lon``, ``lat`` and ``press``. The box extends
    ``int(width / 2)`` cells on each side of the centre, so it spans
    ``2 * int(width / 2) + 1`` cells per axis. Each field is sliced as
    ``[latitude, longitude, pressure]`` and the slices are concatenated
    along a new ``channel`` dimension.

    Args:
        xr_dataset: Dataset indexable by each name in ``fld_name_s``.
        fld_name_s: Sequence of field names; becomes the ``channel``
            coordinate.
        lon: Longitude of the centre; shifted by +360 if negative and
            ``lon360`` is set.
        lat: Latitude of the centre.
        press: Pressure of the centre.
        x_width: Box width along longitude, in grid cells.
        y_width: Box width along latitude, in grid cells.
        z_width: Box width along pressure, in levels.

    Returns:
        ``xr.DataArray`` with coordinates on ``channel``, ``fakeDim2``
        (longitude), ``fakeDim1`` (latitude) and ``fakeDim0`` (pressure), or
        ``None`` when the box would extend beyond the grid on any axis (no
        longitude wrap-around is attempted).
    """
    if lon360 and lon < 0:
        lon += 360

    p_c = value_to_index(plevs, press)
    x_c = value_to_index(lon_coords, lon)
    y_c = value_to_index(lat_coords, lat)
    y_h = int(y_width / 2)
    x_h = int(x_width / 2)
    p_h = int(z_width / 2)

    y_start = y_c - y_h
    x_start = x_c - x_h
    z_start = p_c - p_h
    if y_start < 0 or x_start < 0 or z_start < 0:
        return None

    # Stop indices are exclusive, so a stop equal to the axis length is still
    # inside the grid; only larger values fall off the edge.
    y_stop = y_c + y_h + 1
    x_stop = x_c + x_h + 1
    z_stop = p_c + p_h + 1
    if y_stop > NY or x_stop > NX or z_stop > NZ:
        return None

    sub_fld_s = [
        xr_dataset[name][y_start:y_stop, x_start:x_stop, z_start:z_stop]
        for name in fld_name_s
    ]
    sub_fld = xr.concat(sub_fld_s, 'channel')

    sub_fld = sub_fld.assign_coords(channel=fld_name_s, fakeDim2=lon_coords[x_start:x_stop],
                                    fakeDim1=lat_coords[y_start:y_stop], fakeDim0=plevs[z_start:z_stop])

    return sub_fld


def get_volume(xr_dataset, fld_name, unit_str, press_range=None, lon_range=None, lat_range=None):
    """Extract a 3-D sub-volume of one field as a labelled DataArray.

    The field's values are sliced as ``[latitude, longitude, pressure]``.
    Every range is optional and handled independently; a missing range
    selects the whole axis. Range bounds are snapped to the nearest grid
    value and the upper bound is exclusive (plain slice semantics).

    Args:
        xr_dataset: Dataset indexable by ``fld_name``.
        fld_name: Name of the field to read; also stored as ``description``.
        unit_str: Value stored in the ``units`` attribute.
        press_range: Optional ``(p_lo, p_hi)`` pressures.
        lon_range: Optional ``(lon_lo, lon_hi)``; negative values are shifted
            by +360 when ``lon360`` is set.
        lat_range: Optional ``(lat_lo, lat_hi)``.

    Returns:
        ``xr.DataArray`` with dims ``('Latitude', 'Longitude', 'Pressure')``.
    """
    # Default to the full extent of every axis.
    y_lo, x_lo, z_lo = 0, 0, 0
    y_hi, x_hi, z_hi = NY, NX, NZ

    if press_range is not None:
        z_lo = value_to_index(plevs, press_range[0])
        z_hi = value_to_index(plevs, press_range[1])

    if lat_range is not None:
        y_lo = value_to_index(lat_coords, lat_range[0])
        y_hi = value_to_index(lat_coords, lat_range[1])

    if lon_range is not None:
        west, east = lon_range[0], lon_range[1]
        if lon360:
            west = west + 360 if west < 0 else west
            east = east + 360 if east < 0 else east
        x_lo = value_to_index(lon_coords, west)
        x_hi = value_to_index(lon_coords, east)

    values = xr_dataset[fld_name].values[y_lo:y_hi, x_lo:x_hi, z_lo:z_hi]

    return xr.DataArray(
        values,
        dims=['Latitude', 'Longitude', 'Pressure'],
        coords={
            "Latitude": lat_coords[y_lo:y_hi],
            "Longitude": lon_coords[x_lo:x_hi],
            "Pressure": plevs[z_lo:z_hi],
        },
        attrs={"description": fld_name, "units": unit_str},
    )


def volume_np_to_xr(nda, dims, press_range=None, lon_range=None, lat_range=None):
    """Wrap an array in a DataArray labelled with the module grid.

    Coordinates named ``Latitude``, ``Longitude`` and ``Pressure`` are taken
    from ``lat_coords``, ``lon_coords`` and ``plevs``, restricted to the
    given ranges. Every range is optional and handled independently; a
    missing range selects the whole axis. Range bounds are snapped to the
    nearest grid value and the upper bound is exclusive.

    Args:
        nda: Array to wrap; it is not sliced here.
        dims: Dimension names passed to ``xr.DataArray``.
        press_range: Optional ``(p_lo, p_hi)`` pressures.
        lon_range: Optional ``(lon_lo, lon_hi)``; negative values are shifted
            by +360 when ``lon360`` is set.
        lat_range: Optional ``(lat_lo, lat_hi)``.

    Returns:
        The resulting ``xr.DataArray``.
    """
    # Default to the full extent of every axis.
    y_lo, x_lo, z_lo = 0, 0, 0
    y_hi, x_hi, z_hi = NY, NX, NZ

    if press_range is not None:
        z_lo = value_to_index(plevs, press_range[0])
        z_hi = value_to_index(plevs, press_range[1])

    if lat_range is not None:
        y_lo = value_to_index(lat_coords, lat_range[0])
        y_hi = value_to_index(lat_coords, lat_range[1])

    if lon_range is not None:
        west, east = lon_range[0], lon_range[1]
        if lon360:
            west = west + 360 if west < 0 else west
            east = east + 360 if east < 0 else east
        x_lo = value_to_index(lon_coords, west)
        x_hi = value_to_index(lon_coords, east)

    grid = {
        "Latitude": lat_coords[y_lo:y_hi],
        "Longitude": lon_coords[x_lo:x_hi],
        "Pressure": plevs[z_lo:z_hi],
    }
    return xr.DataArray(nda, dims=dims, coords=grid)