# Skip to content
# Snippets Groups Projects
# util.py 55.87 KiB
import metpy
import numpy as np
import xarray as xr
import datetime
from datetime import timezone
from metpy.units import units
from metpy.calc import thickness_hydrostatic
from collections import namedtuple
import os
import h5py
import pickle
from netCDF4 import Dataset
from util.setup import ancillary_path
from scipy.interpolate import RectBivariateSpline, interp2d
from scipy.ndimage import gaussian_filter
from scipy.signal import medfilt2d

LatLonTuple = namedtuple('LatLonTuple', ['lat', 'lon'])

homedir = os.path.expanduser('~') + '/'

# --- CLAVRx Radiometric parameters and metadata ------------------------------------------------
# Ordering matters: the first 10 names are brightness temperatures and the last 5 are
# reflectances; the fill-value table below is split on that boundary.
l1b_ds_list = ['temp_10_4um_nom', 'temp_11_0um_nom', 'temp_12_0um_nom', 'temp_13_3um_nom', 'temp_3_75um_nom',
               'temp_6_2um_nom', 'temp_6_7um_nom', 'temp_7_3um_nom', 'temp_8_5um_nom', 'temp_9_7um_nom',
               'refl_0_47um_nom', 'refl_0_65um_nom', 'refl_0_86um_nom', 'refl_1_38um_nom', 'refl_1_60um_nom']

l1b_ds_types = {ds: 'f4' for ds in l1b_ds_list}
l1b_ds_fill = {ds: -32767 for ds in l1b_ds_list[:10]}
l1b_ds_fill.update({ds: -32768 for ds in l1b_ds_list[10:]})
l1b_ds_range = {ds: 'actual_range' for ds in l1b_ds_list}

# --- CLAVRx L2 parameters and metadata
# Ordering matters here too: the first 23 names are float ('f4') datasets, the last 3
# (cloud_type, cloud_phase, cloud_mask) are byte ('i1') categorical datasets.
ds_list = ['cld_height_acha', 'cld_geo_thick', 'cld_press_acha', 'sensor_zenith_angle', 'supercooled_prob_acha',
           'supercooled_cloud_fraction', 'cld_temp_acha', 'cld_opd_acha', 'solar_zenith_angle',
           'cld_reff_acha', 'cld_reff_dcomp', 'cld_reff_dcomp_1', 'cld_reff_dcomp_2', 'cld_reff_dcomp_3',
           'cld_opd_dcomp', 'cld_opd_dcomp_1', 'cld_opd_dcomp_2', 'cld_opd_dcomp_3', 'cld_cwp_dcomp', 'iwc_dcomp',
           'lwc_dcomp', 'cld_emiss_acha', 'conv_cloud_fraction', 'cloud_type', 'cloud_phase', 'cloud_mask']

ds_types = {ds: 'f4' for ds in ds_list[:23]}
ds_types.update({ds: 'i1' for ds in ds_list[23:]})
ds_fill = {ds: -32768 for ds in ds_list[:23]}
ds_fill.update({ds: -128 for ds in ds_list[23:]})
ds_range = {ds: 'actual_range' for ds in ds_list[:23]}
# The categorical datasets have no range attribute. The previous code indexed ds_list[i]
# instead of ds_list[i+23] here, which wiped the range attribute of the first three float
# datasets and left the categorical ones without an entry.
ds_range.update({ds: None for ds in ds_list[23:]})

# Fold the L1b tables into the combined lookup tables.
ds_types.update(l1b_ds_types)
ds_fill.update(l1b_ds_fill)
ds_range.update(l1b_ds_range)

# Extra datasets that are not part of either list above.
ds_types.update({'temp_3_9um_nom': 'f4'})
ds_types.update({'cloud_fraction': 'f4'})
ds_fill.update({'temp_3_9um_nom': -32767})
ds_fill.update({'cloud_fraction': -32768})
ds_range.update({'temp_3_9um_nom': 'actual_range'})
ds_range.update({'cloud_fraction': 'actual_range'})


class MyGenericException(Exception):
    """Simple exception that carries its text in the ``message`` attribute."""

    def __init__(self, message):
        super().__init__(message)
        self.message = message


def make_tf_callable_generator(the_generator):
    """Wrap ``the_generator`` in an object that returns it when called with no arguments.

    The returned object also exposes the wrapped generator as its ``gen`` attribute.
    """

    class MyCallable:
        """Zero-argument callable holding a reference to one generator."""

        def __init__(self, gen):
            self.gen = gen

        def __call__(self):
            # Always hands back the same object; it is not re-created per call.
            return self.gen

    return MyCallable(the_generator)


def get_fill_attrs(name):
    """Return ``(fill_value, fill_value_attr_name)`` for dataset ``name``.

    When the module-level ``ds_fill`` table has a non-None entry for ``name`` the result is
    ``(value, None)``; otherwise it is ``(None, '_FillValue')`` so the caller reads the fill
    value from the dataset attribute of that name.
    """
    known_fill = ds_fill.get(name)
    if known_fill is None:
        return None, '_FillValue'
    return known_fill, None


class GenericException(Exception):
    """General-purpose error for this module; the text is kept in ``message``."""

    def __init__(self, message):
        super().__init__(message)
        self.message = message


class EarlyStop:
    """Early-stopping monitor based on a moving average of the most recent values.

    ``check_stop`` returns True once the moving average has failed to drop below its
    best (lowest) value for more than ``patience`` consecutive calls.
    """

    def __init__(self, window_length=3, patience=5):
        self.patience = patience
        # Best (lowest) window average seen so far; starts at the largest float32.
        self.min = np.finfo(np.single).max
        self.cnt = 0
        # Number of consecutive checks without an improvement.
        self.cnt_wait = 0
        # NaN marks slots that have not been filled yet.
        self.window = np.full(window_length, np.nan, dtype=np.single)

    def check_stop(self, value):
        """Push ``value`` into the window and report whether to stop."""
        # Slide the window one slot to the left and store the newest value last.
        self.window[:-1] = self.window[1:]
        self.window[-1] = value

        # No decision until the window holds only real values.
        if np.isnan(self.window).any():
            return False

        window_mean = self.window.mean()
        if window_mean < self.min:
            self.min = window_mean
            self.cnt_wait = 0
            return False

        self.cnt_wait += 1
        return bool(self.cnt_wait > self.patience)


def get_time_tuple_utc(timestamp):
    """Convert a POSIX ``timestamp`` to ``(aware UTC datetime, its time tuple)``."""
    utc_dt = datetime.datetime.fromtimestamp(timestamp, tz=timezone.utc)
    return utc_dt, utc_dt.timetuple()


def get_datetime_obj(dt_str, format_code='%Y-%m-%d_%H:%M'):
    """Parse ``dt_str`` with ``format_code`` and return it as a UTC-aware datetime."""
    parsed = datetime.datetime.strptime(dt_str, format_code)
    # The parsed value is naive; label it as UTC without shifting the clock time.
    return parsed.replace(tzinfo=timezone.utc)


def get_timestamp(dt_str, format_code='%Y-%m-%d_%H:%M'):
    """Parse ``dt_str`` as a UTC time and return its POSIX timestamp in seconds."""
    utc_dt = datetime.datetime.strptime(dt_str, format_code).replace(tzinfo=timezone.utc)
    return utc_dt.timestamp()


def add_time_range_to_filename(pathname, tstart, tend=None):
    """Insert a UTC hour stamp (and optionally an end stamp) before the file extension.

    ``pathname``: path whose final component is renamed; the directory part is kept.
    ``tstart`` / ``tend``: POSIX timestamps, formatted in UTC as ``YYYYMMDDHH``.

    Returns ``<dir>/<stem>_<start>[_<end>]<ext>``. The directory and file name are
    combined with ``os.path.join``, so a bare file name stays relative instead of being
    turned into ``/<name>`` and a root-level file does not get a doubled separator.
    """
    directory, filename = os.path.split(pathname)
    stem, ext = os.path.splitext(filename)

    def _hour_stamp(ts):
        # UTC on purpose, so the stamp does not depend on the local time zone.
        return datetime.datetime.fromtimestamp(ts, timezone.utc).strftime('%Y%m%d%H')

    parts = [stem, _hour_stamp(tstart)]
    if tend is not None:
        parts.append(_hour_stamp(tend))

    return os.path.join(directory, '_'.join(parts) + ext)


def haversine_np(lon1, lat1, lon2, lat2, earth_radius=6367.0):
    """
    Great-circle distance between two points given in decimal degrees.

    (lon1, lat1) must be broadcastable with (lon2, lat2). The result is in the same
    length unit as earth_radius.
    """
    lon1_r = np.radians(lon1)
    lat1_r = np.radians(lat1)
    lon2_r = np.radians(lon2)
    lat2_r = np.radians(lat2)

    delta_lon = lon2_r - lon1_r
    delta_lat = lat2_r - lat1_r

    # Haversine term: squared half-chord length between the two points.
    hav = np.sin(delta_lat/2.0)**2 + np.cos(lat1_r) * np.cos(lat2_r) * np.sin(delta_lon/2.0)**2

    central_angle = 2.0 * np.arcsin(np.sqrt(hav))
    return earth_radius * central_angle


def bin_data_by(a, b, bin_ranges):
    """Split ``a`` into bins according to the matching values in ``b``.

    ``bin_ranges`` is a sequence of ``[low, high]`` pairs; bin k receives the elements of
    ``a`` where ``low <= b < high``. Returns one array per range.
    """
    return [a[(b >= rng[0]) & (b < rng[1])] for rng in bin_ranges]


def bin_data_by_edges(a, b, edges):
    """Split ``a`` into ``len(edges) - 1`` bins using consecutive ``edges`` applied to ``b``.

    Bin k receives the elements of ``a`` where ``edges[k] <= b < edges[k+1]``.
    """
    return [a[(b >= low) & (b < high)] for low, high in zip(edges[:-1], edges[1:])]


def get_bin_ranges(lop, hip, bin_size=100):
    """Return consecutive ``[low, high]`` ranges of width ``bin_size`` starting at ``lop``.

    The number of ranges is ``int((hip - lop) / bin_size)``, so a trailing partial bin is
    dropped.
    """
    num_bins = int((hip - lop)/bin_size)
    return [[lop + k*bin_size, lop + k*bin_size + bin_size] for k in range(num_bins)]


# t must be monotonic increasing
def get_breaks(t, threshold):
    """Locate gaps in ``t``: positions i where ``t[i+1] - t[i] > threshold``.

    Returns the ``np.nonzero`` result (a tuple of index arrays), not a bare array.
    """
    gaps = t[1:] - t[:-1]
    return np.nonzero(gaps > threshold)


# Could be fully vectorized with numpy if this ever becomes a bottleneck.
def get_indexes_within_threshold(ts, value, threshold):
    """Find the entries of ``ts`` that lie within ``threshold`` of ``value`` (inclusive).

    An entry ``ts[i]`` matches when ``ts[i] - threshold <= value <= ts[i] + threshold``.
    ``threshold`` is in the same units as ``ts`` (seconds for timestamps).

    Returns ``(indexes, values)`` as two parallel lists.
    """
    matches = [(k, v) for k, v in enumerate(ts) if (v - threshold) <= value <= (v + threshold)]
    idx_s = [k for k, _ in matches]
    t_s = [v for _, v in matches]
    return idx_s, t_s


def pressure_to_altitude(pres, temp, prof_pres, prof_temp, sfc_pres=None, sfc_temp=None, sfc_elev=0):
    """Hydrostatic height of pressure level ``pres`` above the bottom of the profile or the surface.

    pres: target pressure; hectopascal units are attached before the metpy call.
    temp: temperature at ``pres`` (kelvin units are attached); interpolated from the
        profile when None.
    prof_pres, prof_temp: profile arrays; ``prof_pres`` must be strictly increasing.
    sfc_pres, sfc_temp: optional surface pressure/temperature used as the lower bound of
        the layer instead of the profile bottom.
    sfc_elev: surface elevation (meter units are attached), added to the thickness.

    Returns the magnitude of ``thickness_hydrostatic(...) + sfc_elev``, or -999.0 when
    ``pres`` is greater than ``sfc_pres`` (target level below the surface).

    Raises GenericException when the profile is not monotonic increasing or ``pres`` is
    smaller than the first profile pressure.
    """
    if not np.all(np.diff(prof_pres) > 0):
        raise GenericException("target pressure profile must be monotonic increasing")

    if pres < prof_pres[0]:
        raise GenericException("target pressure less than top of pressure profile")

    if temp is None:
        temp = np.interp(pres, prof_pres, prof_temp)

    # Index of the first profile level with pressure greater than pres.
    i_top = np.argmax(np.extract(prof_pres <= pres, prof_pres)) + 1

    pres_s = prof_pres.tolist()
    temp_s = prof_temp.tolist()

    # New profile: the target level on top, followed by every level beneath it.
    pres_s = [pres] + pres_s[i_top:]
    temp_s = [temp] + temp_s[i_top:]

    if sfc_pres is not None:
        if pres > sfc_pres:  # incoming pressure below surface
            return -999.0

        # From here on prof_pres/prof_temp refer to the truncated profile built above.
        prof_pres = np.array(pres_s)
        prof_temp = np.array(temp_s)
        i_bot = prof_pres.shape[0] - 1

        if sfc_pres > prof_pres[i_bot]:  # surface below profile bottom
            # NOTE(review): sfc_temp is appended as-is here, so a None sfc_temp ends up
            # in the temperature list on this branch - confirm callers always supply it.
            pres_s = pres_s + [sfc_pres]
            temp_s = temp_s + [sfc_temp]
        else:
            # Keep only the levels above the surface, then close the profile with it.
            idx = np.argmax(np.extract(prof_pres < sfc_pres, prof_pres))
            if sfc_temp is None:
                sfc_temp = np.interp(sfc_pres, prof_pres, prof_temp)
            pres_s = prof_pres.tolist()
            temp_s = prof_temp.tolist()
            pres_s = pres_s[0:idx+1] + [sfc_pres]
            temp_s = temp_s[0:idx+1] + [sfc_temp]

    prof_pres = np.array(pres_s)
    prof_temp = np.array(temp_s)

    # Reverse so the arrays run from the highest pressure (bottom) to the target level.
    prof_pres = prof_pres[::-1]
    prof_temp = prof_temp[::-1]

    prof_pres = prof_pres * units.hectopascal
    prof_temp = prof_temp * units.kelvin
    sfc_elev = sfc_elev * units.meter

    z = thickness_hydrostatic(prof_pres, prof_temp) + sfc_elev

    return z.magnitude


# http://fourier.eng.hmc.edu/e176/lectures/NM/node25.html
def minimize_quadratic(xa, xb, xc, ya, yb, yc):
    """Return the abscissa of the vertex of the parabola through (xa, ya), (xb, yb), (xc, yc).

    No guard is applied for a zero denominator (e.g. collinear points).
    """
    dx_cb = xc - xb
    dx_ba = xb - xa
    dy_ab = ya - yb
    dy_cb = yc - yb

    numerator = dy_ab*dx_cb*dx_cb - dy_cb*dx_ba*dx_ba
    denominator = dy_ab*dx_cb + dy_cb*dx_ba
    return xb + 0.5*(numerator / denominator)


# nda must be 1d
def value_to_index(nda, value):
    """Return the index of the element of ``nda`` closest to ``value`` (first one on ties)."""
    return np.argmin(np.abs(nda - value))


def find_bin_index(nda, value_s):
    """For each value return the largest index i with ``nda[i] <= value``, or -1.

    A value gets -1 when it is smaller than every element of ``nda`` or when no element
    of ``nda`` is strictly greater than it (so a value equal to the largest element also
    gets -1). Returns an int32 array with one entry per element of ``value_s``.
    """
    positions = np.arange(nda.shape[0])
    left_idx = np.full(value_s.shape[0], -1.0)

    for k, v in enumerate(value_s):
        at_or_above = v >= nda
        strictly_below = v < nda
        # Only values bracketed on both sides by nda fall inside a bin.
        if at_or_above.any() and strictly_below.any():
            left_idx[k] = positions[at_or_above].max()

    return left_idx.astype(np.int32)


# array solzen must be degrees, missing values must NaN. For small roughly 50x50km regions only
def is_day(solzen, test_angle=80.0):
    """Return True only if there is at least one non-NaN angle and all of them are <= ``test_angle``."""
    angles = solzen.flatten()
    valid = angles[~np.isnan(angles)]
    if valid.size == 0:
        return False
    return bool(np.count_nonzero(valid <= test_angle) == valid.size)


# array solzen must be degrees, missing values must NaN. For small roughly 50x50km regions only
def is_night(solzen, test_angle=100.0):
    """Return True only if there is at least one non-NaN angle and all of them are >= ``test_angle``."""
    angles = solzen.flatten()
    valid = angles[~np.isnan(angles)]
    if valid.size == 0:
        return False
    return bool(np.count_nonzero(valid >= test_angle) == valid.size)


def check_oblique(satzen, test_angle=70.0):
    """Return True only if there is at least one non-NaN angle and all of them are <= ``test_angle``.

    Note the sense of the result: True means no valid pixel exceeds ``test_angle``.
    """
    angles = satzen.flatten()
    valid = angles[~np.isnan(angles)]
    if valid.size == 0:
        return False
    return bool(np.count_nonzero(valid <= test_angle) == valid.size)


def get_median(tile_2d):
    """Return the median of all elements of ``tile_2d``, ignoring NaNs."""
    return np.nanmedian(tile_2d.ravel())


def get_grid_values(h5f, grid_name, j_c, i_c, half_width, num_j=None, num_i=None, scale_factor_name='scale_factor', add_offset_name='add_offset',
                    fill_value_name='_FillValue', valid_range_name='valid_range', actual_range_name='actual_range', fill_value=None):
    """Read a rectangular window of dataset ``grid_name`` and apply its packing attributes.

    Window selection:
      * ``half_width`` given: rows ``j_c-half_width .. j_c+half_width`` and the matching
        columns around ``i_c``. Returns None when the window starts before index 0 or
        its exclusive end index is >= the dataset length in either dimension.
      * ``half_width`` is None: rows ``j_c .. j_c+num_j`` and columns ``i_c .. i_c+num_i``
        (both end points included in the slice request).

    Processing order: fill value -> NaN, values outside ``valid_range`` -> NaN, multiply
    by the scale factor, add the offset, values outside ``actual_range`` -> NaN. Each step
    is skipped when its attribute name is None or the attribute is missing. The explicit
    ``fill_value`` argument is only used when ``fill_value_name`` is None.

    Raises GenericException when the dataset exposes no attributes object.
    """

    def _first(attr_val):
        # Attributes may be stored as scalars or as one-element sequences.
        return attr_val if np.isscalar(attr_val) else attr_val[0]

    def _blank_outside(vals, bounds):
        # Replace everything outside [bounds[0], bounds[1]] with NaN.
        vals = np.where(vals < bounds[0], np.nan, vals)
        return np.where(vals > bounds[1], np.nan, vals)

    dataset = h5f[grid_name]
    attrs = dataset.attrs
    if attrs is None:
        raise GenericException('No attributes object for: '+grid_name)

    ylen, xlen = dataset.shape

    if half_width is None:
        j_l, j_r = j_c, j_c + num_j + 1
        i_l, i_r = i_c, i_c + num_i + 1
    else:
        j_l, i_l = j_c - half_width, i_c - half_width
        if j_l < 0 or i_l < 0:
            return None

        j_r, i_r = j_c + half_width + 1, i_c + half_width + 1
        if j_r >= ylen or i_r >= xlen:
            return None

    grd_vals = dataset[j_l:j_r, i_l:i_r]

    if fill_value_name is not None:
        attr = attrs.get(fill_value_name)
        if attr is not None:
            fill_value = _first(attr)
            grd_vals = np.where(grd_vals == fill_value, np.nan, grd_vals)
    elif fill_value is not None:
        grd_vals = np.where(grd_vals == fill_value, np.nan, grd_vals)

    if valid_range_name is not None:
        attr = attrs.get(valid_range_name)
        if attr is not None:
            grd_vals = _blank_outside(grd_vals, attr)

    if scale_factor_name is not None:
        attr = attrs.get(scale_factor_name)
        if attr is not None:
            grd_vals = grd_vals * _first(attr)

    if add_offset_name is not None:
        attr = attrs.get(add_offset_name)
        if attr is not None:
            grd_vals = grd_vals + _first(attr)

    if actual_range_name is not None:
        attr = attrs.get(actual_range_name)
        if attr is not None:
            grd_vals = _blank_outside(grd_vals, attr)

    return grd_vals


def get_grid_values_all(h5f, grid_name, scale_factor_name='scale_factor', add_offset_name='add_offset',
                        fill_value_name='_FillValue', valid_range_name='valid_range', actual_range_name='actual_range', fill_value=None, stride=None):
    """Read the whole dataset ``grid_name`` (optionally subsampled) and apply its packing attributes.

    ``stride``: when given, every ``stride``-th row and column is read.

    Processing order: fill value -> NaN, values outside ``valid_range`` -> NaN, multiply
    by the scale factor, add the offset, values outside ``actual_range`` -> NaN. Each step
    is skipped when its attribute name is None or the attribute is missing. The explicit
    ``fill_value`` argument is only used when ``fill_value_name`` is None.

    Raises GenericException when the dataset exposes no attributes object.
    """

    def _first(attr_val):
        # Attributes may be stored as scalars or as one-element sequences.
        return attr_val if np.isscalar(attr_val) else attr_val[0]

    def _blank_outside(vals, bounds):
        # Replace everything outside [bounds[0], bounds[1]] with NaN.
        vals = np.where(vals < bounds[0], np.nan, vals)
        return np.where(vals > bounds[1], np.nan, vals)

    dataset = h5f[grid_name]
    attrs = dataset.attrs

    if attrs is None:
        raise GenericException('No attributes object for: '+grid_name)

    grd_vals = dataset[:,] if stride is None else dataset[::stride, ::stride]

    if fill_value_name is not None:
        attr = attrs.get(fill_value_name)
        if attr is not None:
            fill_value = _first(attr)
            grd_vals = np.where(grd_vals == fill_value, np.nan, grd_vals)
    elif fill_value is not None:
        grd_vals = np.where(grd_vals == fill_value, np.nan, grd_vals)

    if valid_range_name is not None:
        attr = attrs.get(valid_range_name)
        if attr is not None:
            grd_vals = _blank_outside(grd_vals, attr)

    if scale_factor_name is not None:
        attr = attrs.get(scale_factor_name)
        if attr is not None:
            grd_vals = grd_vals * _first(attr)

    if add_offset_name is not None:
        attr = attrs.get(add_offset_name)
        if attr is not None:
            grd_vals = grd_vals + _first(attr)

    if actual_range_name is not None:
        attr = attrs.get(actual_range_name)
        if attr is not None:
            grd_vals = _blank_outside(grd_vals, attr)

    return grd_vals


def make_times(dt_str_0, dt_str_1=None, format_code='%Y-%m-%d_%H:%M', num_steps=None, days=None, hours=None, minutes=None, seconds=None):
    """Build evenly spaced UTC times starting at ``dt_str_0``.

    dt_str_0: start datetime string, parsed with ``format_code`` (default YYYY-MM-DD_HH:MM).
    dt_str_1: stop datetime string; when given, ``num_steps`` is computed from it as
        ``int((stop - start) / increment)``.
    num_steps: number of increments to take; ``dt_str_1`` and ``num_steps`` cannot both
        be None.
    days, hours, minutes, seconds: size of one increment. The first one that is not None
        wins, checked in that order.

    Returns two lists of length ``num_steps + 1`` (usable as histogram edges): the aware
    UTC datetime objects and their POSIX timestamps.
    """
    # Apply the coarser units last so that days > hours > minutes > seconds in priority.
    inc = seconds
    for amount, seconds_per_unit in ((minutes, 60), (hours, 3600), (days, 86400)):
        if amount is not None:
            inc = seconds_per_unit*amount

    start = datetime.datetime.strptime(dt_str_0, format_code).replace(tzinfo=timezone.utc)

    if dt_str_1 is not None:
        stop = datetime.datetime.strptime(dt_str_1, format_code).replace(tzinfo=timezone.utc)
        num_steps = int((stop.timestamp() - start.timestamp())/inc)

    dt_obj_s = [start]
    for _ in range(num_steps):
        dt_obj_s.append(dt_obj_s[-1] + datetime.timedelta(seconds=inc))

    ts_s = [dt_obj.timestamp() for dt_obj in dt_obj_s]

    return dt_obj_s, ts_s


def normalize(data, param, mean_std_dict, copy=True):
    """Standardize ``data`` with the mean and std stored for ``param``; NaNs become 0.

    ``mean_std_dict[param]`` must be a 4-item ``(mean, std, lo, hi)`` entry. The result is
    always a new array of the same shape (``flatten`` copies regardless of ``copy``).

    Raises MyGenericException when ``param`` is not in the dictionary.
    """
    stats = mean_std_dict.get(param)
    if stats is None:
        raise MyGenericException(param + ': not found in dictionary')

    if copy:
        data = data.copy()

    orig_shape = data.shape
    flat = data.flatten()

    mean, std, _, _ = stats
    flat -= mean
    flat /= std

    # Missing values are mapped to 0, i.e. to the mean in normalized space.
    flat[np.isnan(flat)] = 0

    return flat.reshape(orig_shape)


def denormalize(data, param, mean_std_dict, copy=True):
    """Undo ``normalize``: multiply by the std and add the mean stored for ``param``.

    ``mean_std_dict[param]`` must be a 4-item ``(mean, std, lo, hi)`` entry. NaNs are left
    untouched. The result is always a new array of the same shape.

    Raises MyGenericException when ``param`` is not in the dictionary.
    """
    if copy:
        data = data.copy()

    stats = mean_std_dict.get(param)
    if stats is None:
        raise MyGenericException(param + ': not found in dictionary')

    orig_shape = data.shape
    flat = data.flatten()

    mean, std, _, _ = stats
    flat *= std
    flat += mean

    return flat.reshape(orig_shape)


def scale(data, param, mean_std_dict, copy=True):
    """Min-max scale ``data`` with the ``lo``/``hi`` stored for ``param``; NaNs become 0.

    ``mean_std_dict[param]`` must be a 4-item ``(mean, std, lo, hi)`` entry. The result is
    ``(data - lo) / (hi - lo)`` as a new array of the same shape.

    Raises MyGenericException when ``param`` is not in the dictionary.
    """
    if copy:
        data = data.copy()

    stats = mean_std_dict.get(param)
    if stats is None:
        raise MyGenericException(param + ': not found in dictionary')

    orig_shape = data.shape
    flat = data.flatten()

    _, _, lo, hi = stats
    flat -= lo
    flat /= (hi - lo)

    flat[np.isnan(flat)] = 0

    return flat.reshape(orig_shape)


def scale2(data, lo, hi, copy=True):
    """Min-max scale ``data`` to ``(data - lo) / (hi - lo)``; NaNs become 0.

    Returns a new array with the same shape as ``data``.
    """
    if copy:
        data = data.copy()

    orig_shape = data.shape
    flat = data.flatten()

    flat -= lo
    flat /= (hi - lo)

    flat[np.isnan(flat)] = 0

    return flat.reshape(orig_shape)


def descale(data, param, mean_std_dict, copy=True):
    """Undo ``scale``: return ``data * (hi - lo) + lo`` using the entry for ``param``; NaNs become 0.

    ``mean_std_dict[param]`` must be a 4-item ``(mean, std, lo, hi)`` entry. The result is
    a new array of the same shape.

    Raises MyGenericException when ``param`` is not in the dictionary.
    """
    if copy:
        data = data.copy()

    stats = mean_std_dict.get(param)
    if stats is None:
        raise MyGenericException(param + ': not found in dictionary')

    orig_shape = data.shape
    flat = data.flatten()

    _, _, lo, hi = stats
    flat *= (hi - lo)
    flat += lo

    # NaNs are zeroed after the rescale, so they end up as 0 rather than as lo.
    flat[np.isnan(flat)] = 0

    return flat.reshape(orig_shape)


def add_noise(data, noise_scale=0.01, seed=None, copy=True):
    """Return ``data`` plus zero-mean Gaussian noise with standard deviation ``noise_scale``.

    When ``seed`` is given, numpy's global random generator is re-seeded with it first.
    The result is a new array with the same shape as ``data``.
    """
    if copy:
        data = data.copy()

    orig_shape = data.shape
    noisy = data.flatten()

    if seed is not None:
        np.random.seed(seed)
    noisy += np.random.normal(loc=0, scale=noise_scale, size=noisy.size)

    return noisy.reshape(orig_shape)


# Pre-built GEOS projection objects, loaded once at import time from the ancillary
# directory. NOTE: unpickling can execute arbitrary code, so these files must only ever
# come from a trusted location.
geos_goes16_fd = None
geos_goes16_conus = None
geos_h08_fd = None

# Context managers guarantee the files are closed even if unpickling fails.
with open(ancillary_path+'geos_crs_goes16_FD.pkl', 'rb') as f:
    geos_goes16_fd = pickle.load(f)

with open(ancillary_path+'geos_crs_goes16_CONUS.pkl', 'rb') as f:
    geos_goes16_conus = pickle.load(f)

with open(ancillary_path+'geos_crs_H08_FD.pkl', 'rb') as f:
    geos_h08_fd = pickle.load(f)


def get_cartopy_crs(satellite, domain):
    """Return the projection object and grid geometry for a satellite/domain pair.

    Supported: ``'GOES16'`` with domain ``'FD'`` or ``'CONUS'``; ``'H08'`` and ``'H09'``
    (domain is ignored, both use the H08 full-disk projection object).

    Returns ``(geos, xlen, xmin, xmax, ylen, ymin, ymax)``: the pre-loaded projection
    object, the grid size and the x/y extents.
    NOTE(review): the H08/H09 extents are roughly 1000x smaller than the GOES16 ones,
    which looks like kilometres versus metres - confirm before mixing them.

    Raises ValueError for an unsupported satellite or GOES16 domain (previously this fell
    through to an UnboundLocalError at the return statement).
    """
    if satellite == 'GOES16':
        if domain == 'FD':
            geos = geos_goes16_fd
            xlen = 5424
            xmin = -5433893.0
            xmax = 5433893.0
            ylen = 5424
            ymin = -5433893.0
            ymax = 5433893.0
        elif domain == 'CONUS':
            geos = geos_goes16_conus
            xlen = 2500
            xmin = -3626269.5
            xmax = 1381770.0
            ylen = 1500
            ymin = 1584175.9
            ymax = 4588198.0
        else:
            raise ValueError('unsupported domain for GOES16: ' + str(domain))
    elif satellite in ('H08', 'H09'):
        # H09 reuses the H08 full-disk geometry.
        geos = geos_h08_fd
        xlen = 5500
        xmin = -5498.99990119
        xmax = 5498.99990119
        ylen = 5500
        ymin = -5498.99990119
        ymax = 5498.99990119
    else:
        raise ValueError('unsupported satellite: ' + str(satellite))

    return geos, xlen, xmin, xmax, ylen, ymin, ymax


def concat_dict_s(t_dct_0, t_dct_1):
    """Merge ``t_dct_1`` into ``t_dct_0``, keeping ``t_dct_0``'s value for shared keys.

    Both arguments are modified in place: keys present in both dictionaries are removed
    from ``t_dct_1``, and the remaining entries are added to ``t_dct_0``, which is
    returned.

    Keys are compared with ordinary dictionary lookups. The previous implementation
    converted the keys to numpy arrays, which coerced mixed-type keys to strings and
    flattened tuple keys, so the later ``pop`` could fail with a spurious KeyError.
    """
    # Materialize the shared keys first: a dict must not change size while iterated.
    shared_keys = [key for key in t_dct_1 if key in t_dct_0]
    for key in shared_keys:
        del t_dct_1[key]

    t_dct_0.update(t_dct_1)

    return t_dct_0


rho_water = 1000000.0  # g m^-3
rho_ice = 917000.0  # g m^-3

# Reference CLAVR-x Fortran, kept for comparison. Note that the Fortran constants are
# 1.0 and 0.917, whereas the values above are in g m^-3 and are used with the effective
# radius converted to meters.
#
# real(kind=real4), parameter:: Rho_Water = 1.0    !g / m ^ 3
# real(kind=real4), parameter:: Rho_Ice = 0.917    !g / m ^ 3
#
# !--- compute cloud water path
# if (Iphase == 0) then
#     Cwp_Dcomp(Elem_Idx, Line_Idx) = 0.55 * Tau * Reff * Rho_Water
#     Lwp_Dcomp(Elem_Idx, Line_Idx) = 0.55 * Tau * Reff * Rho_Water
# else
#     Cwp_Dcomp(Elem_Idx, Line_Idx) = 0.667 * Tau * Reff * Rho_Ice
#     Iwp_Dcomp(Elem_Idx, Line_Idx) = 0.667 * Tau * Reff * Rho_Ice
# endif


def compute_lwc_iwc(iphase, reff, opd, geo_dz):
    """Compute liquid and ice water content from phase, effective radius, optical depth and thickness.

    iphase: phase flag array; pixels equal to 1 are treated as ice, everything else as
        liquid.
    reff: effective radius in microns (converted to meters internally).
    opd: optical depth.
    geo_dz: geometric thickness; pixels with thickness <= 1.0 are rejected.

    All inputs must have the same number of elements; outputs take the shape of
    ``iphase``. A pixel is used only when none of its four inputs is NaN and its
    thickness is > 1.0. The water path (g m^-2) is divided by ``geo_dz``.

    Returns ``(lwc, iwc)``: liquid pixels are NaN in ``iwc``, ice pixels are NaN in
    ``lwc``, and rejected pixels are NaN in both.
    """
    xy_shape = iphase.shape

    iphase = iphase.flatten()
    reff = reff.flatten()
    opd = opd.flatten()
    geo_dz = geo_dz.flatten()

    keep = ~np.isnan(iphase) & ~np.isnan(reff) & ~np.isnan(opd) & ~np.isnan(geo_dz) & (geo_dz > 1.0)

    # The comparisons must be parenthesized: '&' binds tighter than '==' / '!=', so the
    # previous 'iphase == 1 & keep' was evaluated as 'iphase == (1 & keep)' and gave
    # rejected pixels a phase class (and therefore a value).
    ice = (iphase == 1) & keep
    no_ice = (iphase != 1) & keep

    reff_m = reff * 1.0e-06  # convert microns to meters

    lwc = np.full(reff.shape[0], np.nan)
    iwc = np.full(reff.shape[0], np.nan)

    # Water path in g m^-2, divided by the layer thickness to get water content.
    iwc[ice] = 0.667 * opd[ice] * rho_ice * reff_m[ice] / geo_dz[ice]
    lwc[no_ice] = 0.55 * opd[no_ice] * rho_water * reff_m[no_ice] / geo_dz[no_ice]

    return lwc.reshape(xy_shape), iwc.reshape(xy_shape)


# Example GOES file to retrieve GEOS parameters in MetPy form (CONUS)
# NOTE: these are absolute paths on one developer's machine; they are only referenced
# by the commented-out snippet below.
exmp_file_conus = '/Users/tomrink/data/OR_ABI-L1b-RadC-M6C14_G16_s20193140811215_e20193140813588_c20193140814070.nc'
# Full Disk
exmp_file_fd = '/Users/tomrink/data/OR_ABI-L1b-RadF-M6C16_G16_s20212521800223_e20212521809542_c20212521809596.nc'

# keep for reference
# if domain == 'CONUS':
#     exmpl_ds = xr.open_dataset(exmp_file_conus)
# elif domain == 'FD':
#     exmpl_ds = xr.open_dataset(exmp_file_fd)
# mdat = exmpl_ds.metpy.parse_cf('Rad')
# geos = mdat.metpy.cartopy_crs
# xlen = mdat.x.values.size
# ylen = mdat.y.values.size
# exmpl_ds.close()

# Taiwan domain:
# lon, lat = 120.955098, 23.834310
# elem, line = (1789, 1505)
# # UR from Taiwan
# lon, lat = 135.0, 35.0
# elem_ur, line_ur = (2499, 995)
# Start column/row and size of the Taiwan window; the functions in this module use these
# values in place of the full grid for H08 (and, in most of them, H09).
taiwan_i0 = 1079
taiwan_j0 = 995
taiwan_lenx = 1420
taiwan_leny = 1020
# geos.transform_point(135.0, 35.0, ccrs.PlateCarree(), False)
# geos.transform_point(106.61, 13.97, ccrs.PlateCarree(), False)
# NOTE: the variable name below is misspelled ("taiwain"); it is left unchanged because
# other modules may import it under this name.
taiwain_extent = [-3342, -502, 1470, 3510]  # GEOS coordinates, not line, elem


# ------------ This code will not be needed when we implement a Fully Convolutional CNN -----------------------------------
# Generate and return tiles of name_list parameters
def make_for_full_domain_predict(h5f, name_list=None, satellite='GOES16', domain='FD', res_fac=1):
    """Cut every dataset in ``name_list`` into 16x16 tiles covering the domain.

    h5f: open file object passed through to ``get_grid_values``.
    name_list: dataset names to read (must be iterable; the None default is not usable).
    satellite, domain: select the grid size via ``get_cartopy_crs``; for 'H08' and 'H09'
        the Taiwan window constants are used instead.
    res_fac: the tile stride is ``int(16 / res_fac)``, so values > 1 give overlapping tiles.

    Returns ``(grd_dct, ll, cc)``: a dict mapping each name to its tiles stacked along a
    new leading axis (row-major tile order), plus the lists of tile start rows and start
    columns in full-grid coordinates.

    Raises GenericException('weirdness') if only some of the datasets could be read.
    """
    tile_w = 16
    tile_h = 16
    step_x = int(tile_w / res_fac)
    step_y = int(tile_h / res_fac)

    def _tile_count(length, step, width):
        # Number of tile start positions along one axis, reduced when the remainder past
        # the last start is narrower than a tile.
        count = int(length/step) - 1
        remainder = length - (count * step)
        if remainder < width:
            count -= int((width - remainder)/step)
        return count

    _, xlen, _, _, ylen, _, _ = get_cartopy_crs(satellite, domain)
    if satellite == 'H08' or satellite == 'H09':
        xlen, ylen = taiwan_lenx, taiwan_leny
        col_0, row_0 = taiwan_i0, taiwan_j0
    else:
        col_0, row_0 = 0, 0

    full_grids = {name: None for name in name_list}

    num_read = 0
    for ds_name in name_list:
        fill_value, fill_value_name = get_fill_attrs(ds_name)
        vals = get_grid_values(h5f, ds_name, row_0, col_0, None, num_j=ylen, num_i=xlen, fill_value_name=fill_value_name, fill_value=fill_value)
        if vals is not None:
            full_grids[ds_name] = vals
            num_read += 1

    if num_read > 0 and num_read != len(name_list):
        raise GenericException('weirdness')

    tile_lists = {name: [] for name in name_list}

    n_x = _tile_count(xlen, step_x, tile_w)
    n_y = _tile_count(ylen, step_y, tile_h)

    # Tile origins relative to the window that was read ...
    row_starts = [j * step_y for j in range(n_y)]
    col_starts = [i * step_x for i in range(n_x)]
    # ... and the same origins in full-grid coordinates, returned to the caller.
    ll = [row_0 + r for r in row_starts]
    cc = [col_0 + c for c in col_starts]

    for ds_name in name_list:
        grid = full_grids[ds_name]
        tile_lists[ds_name].extend(grid[r:r + tile_h, c:c + tile_w] for r in row_starts for c in col_starts)

    grd_dct = {name: None for name in name_list}
    for ds_name in name_list:
        grd_dct[ds_name] = np.stack(tile_lists[ds_name])

    return grd_dct, ll, cc


def make_for_full_domain_predict_viirs_clavrx(h5f, name_list=None, res_fac=1, day_night='DAY', use_dnb=False):
    """Cut every dataset in ``name_list`` of a VIIRS CLAVR-x file into 16x16 tiles.

    h5f: open file object; the grid size is taken from its
        'scan_lines_along_track_direction' and 'pixel_elements_along_scan_direction'
        datasets.
    name_list: dataset names to read (must be iterable; the None default is not usable).
    res_fac: the tile stride is ``int(16 / res_fac)``.
    day_night, use_dnb: when ``use_dnb`` is true and ``day_night`` is 'NIGHT' or 'AUTO',
        'cld_reff_dcomp' and 'cld_opd_dcomp' are read from their '*_nlcomp' counterparts
        (but still stored under the requested names).

    Returns ``(grd_dct, ll, cc, lats, lons)``: the stacked tiles per name, the tile start
    rows and columns, and the latitude/longitude sampled at every tile origin.

    Raises GenericException('weirdness') if only some of the datasets could be read.
    """
    tile_w = 16
    tile_h = 16
    col_0 = 0
    row_0 = 0
    step_x = int(tile_w / res_fac)
    step_y = int(tile_h / res_fac)

    def _tile_count(length, step, width):
        # Number of tile start positions along one axis, reduced when the remainder past
        # the last start is narrower than a tile.
        count = int(length/step) - 1
        remainder = length - (count * step)
        if remainder < width:
            count -= int((width - remainder)/step)
        return count

    ylen = h5f['scan_lines_along_track_direction'].shape[0]
    xlen = h5f['pixel_elements_along_scan_direction'].shape[0]

    use_nl_comp = bool((day_night == 'NIGHT' or day_night == 'AUTO') and use_dnb)
    nlcomp_names = {'cld_reff_dcomp': 'cld_reff_nlcomp', 'cld_opd_dcomp': 'cld_opd_nlcomp'}

    full_grids = {name: None for name in name_list}

    num_read = 0
    for ds_name in name_list:
        # Name actually read from the file; results stay keyed by the requested name.
        name = nlcomp_names.get(ds_name, ds_name) if use_nl_comp else ds_name

        fill_value, fill_value_name = get_fill_attrs(name)
        vals = get_grid_values(h5f, name, row_0, col_0, None, num_j=ylen, num_i=xlen, fill_value_name=fill_value_name, fill_value=fill_value)
        if vals is not None:
            full_grids[ds_name] = vals
            num_read += 1

    if num_read > 0 and num_read != len(name_list):
        raise GenericException('weirdness')

    # TODO: need to investigate discrepencies with compute_lwc_iwc
    # if use_nl_comp:
    #     cld_phase = get_grid_values(h5f, 'cloud_phase', j_0, i_0, None, num_j=ylen, num_i=xlen)
    #     cld_dz = get_grid_values(h5f, 'cld_geo_thick', j_0, i_0, None, num_j=ylen, num_i=xlen)
    #     reff = grd_dct['cld_reff_dcomp']
    #     opd = grd_dct['cld_opd_dcomp']
    #
    #     lwc_nlcomp, iwc_nlcomp = compute_lwc_iwc(cld_phase, reff, opd, cld_dz)
    #     grd_dct['iwc_dcomp'] = iwc_nlcomp
    #     grd_dct['lwc_dcomp'] = lwc_nlcomp

    tile_lists = {name: [] for name in name_list}

    n_x = _tile_count(xlen, step_x, tile_w)
    n_y = _tile_count(ylen, step_y, tile_h)

    row_starts = [j * step_y for j in range(n_y)]
    col_starts = [i * step_x for i in range(n_x)]
    ll = [row_0 + r for r in row_starts]
    cc = [col_0 + c for c in col_starts]

    for ds_name in name_list:
        grid = full_grids[ds_name]
        tile_lists[ds_name].extend(grid[r:r + tile_h, c:c + tile_w] for r in row_starts for c in col_starts)

    grd_dct = {name: None for name in name_list}
    for ds_name in name_list:
        grd_dct[ds_name] = np.stack(tile_lists[ds_name])

    lats = get_grid_values(h5f, 'latitude', row_0, col_0, None, num_j=ylen, num_i=xlen)
    lons = get_grid_values(h5f, 'longitude', row_0, col_0, None, num_j=ylen, num_i=xlen)

    # Sample the navigation arrays at every (row start, column start) combination.
    ll_2d, cc_2d = np.meshgrid(ll, cc, indexing='ij')
    lats = lats[ll_2d, cc_2d]
    lons = lons[ll_2d, cc_2d]

    return grd_dct, ll, cc, lats, lons


def make_for_full_domain_predict2(h5f, satellite='GOES16', domain='FD', res_fac=1):
    """Return the solar and sensor zenith angles sampled on the tile-origin grid.

    h5f: open file object passed through to ``get_grid_values``.
    satellite, domain: select the grid size via ``get_cartopy_crs``; for 'H08' and 'H09'
        the Taiwan window constants are used instead. H09 was previously not handled
        here, unlike in the other functions of this module that read the same window,
        so its angles came from a different region than the tiles.
    res_fac: the sampling stride is ``int(16 / res_fac)``.

    Returns ``(solzen, satzen)``, each subsampled with the tile stride.
    """
    w_x = 16
    w_y = 16
    i_0 = 0
    j_0 = 0
    s_x = int(w_x / res_fac)
    s_y = int(w_y / res_fac)

    geos, xlen, xmin, xmax, ylen, ymin, ymax = get_cartopy_crs(satellite, domain)
    if satellite in ('H08', 'H09'):
        xlen = taiwan_lenx
        ylen = taiwan_leny
        i_0 = taiwan_i0
        j_0 = taiwan_j0

    n_x = int(xlen/s_x)
    n_y = int(ylen/s_y)

    solzen = get_grid_values(h5f, 'solar_zenith_angle', j_0, i_0, None, num_j=ylen, num_i=xlen)
    satzen = get_grid_values(h5f, 'sensor_zenith_angle', j_0, i_0, None, num_j=ylen, num_i=xlen)
    # Keep one value per tile origin; the last stride position is dropped.
    solzen = solzen[0:(n_y-1)*s_y:s_y, 0:(n_x-1)*s_x:s_x]
    satzen = satzen[0:(n_y-1)*s_y:s_y, 0:(n_x-1)*s_x:s_x]

    return solzen, satzen
# -------------------------------------------------------------------------------------------


def prepare_evaluate(h5f, name_list, satellite='GOES16', domain='FD', res_fac=1, offset=0):
    """Read the full grids in ``name_list`` plus subsampled zenith angles for evaluation.

    h5f: open file object passed through to ``get_grid_values``.
    name_list: dataset names to read.
    satellite, domain: select the grid size via ``get_cartopy_crs``; for 'H08' and 'H09'
        the Taiwan window constants are used instead.
    res_fac: the stride is ``int(16 / res_fac)``.
    offset: added to every entry of the returned row/column start lists.

    Returns ``(grd_dct, solzen, satzen, ll, cc)``: the full (untiled) grid per name, the
    zenith angles subsampled with the stride, and the lists of row and column starts.

    Raises GenericException('weirdness') if only some of the datasets could be read.
    """
    tile_w = 16
    tile_h = 16
    step_x = int(tile_w / res_fac)
    step_y = int(tile_h / res_fac)

    def _tile_count(length, step, width):
        # Number of tile start positions along one axis, reduced when the remainder past
        # the last start is narrower than a tile.
        count = int(length/step) - 1
        remainder = length - (count * step)
        if remainder < width:
            count -= int((width - remainder)/step)
        return count

    _, xlen, _, _, ylen, _, _ = get_cartopy_crs(satellite, domain)
    if satellite == 'H08' or satellite == 'H09':
        xlen, ylen = taiwan_lenx, taiwan_leny
        col_0, row_0 = taiwan_i0, taiwan_j0
    else:
        col_0, row_0 = 0, 0

    n_x = _tile_count(xlen, step_x, tile_w)
    n_y = _tile_count(ylen, step_y, tile_h)

    ll = [(offset+row_0) + j*step_y for j in range(n_y)]
    cc = [(offset+col_0) + i*step_x for i in range(n_x)]

    full_grids = {name: [] for name in name_list}

    num_read = 0
    for ds_name in name_list:
        fill_value, fill_value_name = get_fill_attrs(ds_name)
        vals = get_grid_values(h5f, ds_name, row_0, col_0, None, num_j=ylen, num_i=xlen, fill_value_name=fill_value_name, fill_value=fill_value)
        if vals is not None:
            full_grids[ds_name] = vals
            num_read += 1

    if num_read > 0 and num_read != len(name_list):
        raise GenericException('weirdness')

    row_sel = slice(0, (n_y-1)*step_y, step_y)
    col_sel = slice(0, (n_x-1)*step_x, step_x)
    solzen = get_grid_values(h5f, 'solar_zenith_angle', row_0, col_0, None, num_j=ylen, num_i=xlen)
    satzen = get_grid_values(h5f, 'sensor_zenith_angle', row_0, col_0, None, num_j=ylen, num_i=xlen)
    solzen = solzen[row_sel, col_sel]
    satzen = satzen[row_sel, col_sel]

    # np.stack over a 2-d grid re-stacks its rows, i.e. yields a copy of the same grid.
    grd_dct = {name: None for name in name_list}
    for ds_name in name_list:
        grd_dct[ds_name] = np.stack(full_grids[ds_name])

    return grd_dct, solzen, satzen, ll, cc


# Flight-level index -> 'low_high' range label.
flt_level_ranges_str = {
    0: '0_2000',
    1: '2000_4000',
    2: '4000_6000',
    3: '6000_8000',
    4: '8000_15000',
}

# flt_level_ranges_str = {k: None for k in range(1)}
# flt_level_ranges_str[0] = 'column'


def get_cf_nav_parameters(satellite='GOES16', domain='FD'):
    """Return the geostationary navigation constants for a satellite/domain.

    satellite: 'H08', 'H09' or 'GOES16'.
    domain: 'CONUS' or 'FD'; only consulted for 'GOES16'.

    Returns a freshly built dict with the ellipsoid axes, perspective point
    height, projection origin, inverse flattening, sweep axis and the x/y
    scale_factor / add_offset values, or None when the satellite (or the
    GOES16 domain) is not one of the recognised values.
    """
    # H08 and H09 share one parameter set; we presently only have FD for them,
    # so the domain argument is not looked at.
    if satellite == 'H08' or satellite == 'H09':
        scale = 5.58879902955962e-05
        offset = 0.153719917308037
        return {'semi_major_axis': 6378.137,
                'semi_minor_axis': 6356.7523,
                'perspective_point_height': 35785.863,
                'latitude_of_projection_origin': 0.0,
                'longitude_of_projection_origin': 140.7,
                'inverse_flattening': 298.257,
                'sweep_angle_axis': 'y',
                'x_scale_factor': scale,
                'x_add_offset': -offset,
                'y_scale_factor': -scale,
                'y_add_offset': offset}

    if satellite == 'GOES16':
        # The two GOES16 domains differ only in their add_offset values.
        if domain == 'CONUS':
            x_offset, y_offset = -0.101332, 0.128212
        elif domain == 'FD':
            x_offset, y_offset = -0.151844, 0.151844
        else:
            return None

        return {'semi_major_axis': 6378137.0,
                'semi_minor_axis': 6356752.31414,
                'perspective_point_height': 35786023.0,
                'latitude_of_projection_origin': 0.0,
                'longitude_of_projection_origin': -75,
                'inverse_flattening': 298.257,
                'sweep_angle_axis': 'x',
                'x_scale_factor': 5.6E-05,
                'x_add_offset': x_offset,
                'y_scale_factor': -5.6E-05,
                'y_add_offset': y_offset}

    return None


def write_icing_file(clvrx_str_time, output_dir, preds_dct, probs_dct, x, y, lons, lats, elems, lines):
    """Write per-level icing predictions and probabilities to an HDF5 file.

    The file is created at output_dir + 'icing_prediction_' + clvrx_str_time + '.h5'
    (plain string concatenation, so output_dir must end with a path separator).

    clvrx_str_time: time string embedded in the output file name.
    output_dir: directory prefix of the output file.
    preds_dct: flight-level key -> prediction grid (stored as 'i2'); the keys
        must also be keys of flt_level_ranges_str.
    probs_dct: probability grids (stored as 'f4'), indexed by the keys of preds_dct.
    x, y: projection coordinate vectors; written only when x is not None.
    lons, lats: longitude / latitude grids (stored as 'f4').
    elems, lines: element / line index vectors; written only when elems is not None.

    Besides the per-level datasets, the maximum probability over the levels
    and the position of that maximum along the stacked level axis are written.
    The 'Projection' dataset carries hard-coded Himawari navigation attributes.
    """
    outfile_name = output_dir + 'icing_prediction_'+clvrx_str_time+'.h5'

    dim_0_name = 'x'
    dim_1_name = 'y'

    def add_grid_dataset(h5f, name, data, dtype, missing):
        # All gridded products share the same attributes and dimension labels.
        ds = h5f.create_dataset(name, data=data, dtype=dtype)
        ds.attrs.create('coordinates', data='y x')
        ds.attrs.create('grid_mapping', data='Projection')
        ds.attrs.create('missing', data=missing)
        ds.dims[0].label = dim_0_name
        ds.dims[1].label = dim_1_name

    flt_lvls = list(preds_dct.keys())

    # The context manager guarantees the file is closed even when one of the
    # writes below raises.
    with h5py.File(outfile_name, 'w') as h5f_out:
        for flvl in flt_lvls:
            add_grid_dataset(h5f_out, 'icing_prediction_level_'+flt_level_ranges_str[flvl],
                             preds_dct[flvl], 'i2', -1)

        prob_s = []
        for flvl in flt_lvls:
            probs = probs_dct[flvl]
            prob_s.append(probs)
            add_grid_dataset(h5f_out, 'icing_probability_level_'+flt_level_ranges_str[flvl],
                             probs, 'f4', -1.0)

        # Stack the levels on a trailing axis so the column maximum and the
        # index of the level holding it can be reduced over axis 2.
        prob_s = np.stack(prob_s, axis=-1)
        max_prob = np.max(prob_s, axis=2)
        add_grid_dataset(h5f_out, 'max_icing_probability_column', max_prob, 'f4', -1.0)

        max_lvl = np.argmax(prob_s, axis=2)
        add_grid_dataset(h5f_out, 'max_icing_probability_level', max_lvl, 'i2', -1)

        lon_ds = h5f_out.create_dataset('longitude', data=lons, dtype='f4')
        lon_ds.attrs.create('units', data='degrees_east')
        lon_ds.attrs.create('long_name', data='icing prediction longitude')
        lon_ds.dims[0].label = dim_0_name
        lon_ds.dims[1].label = dim_1_name

        lat_ds = h5f_out.create_dataset('latitude', data=lats, dtype='f4')
        lat_ds.attrs.create('units', data='degrees_north')
        lat_ds.attrs.create('long_name', data='icing prediction latitude')
        lat_ds.dims[0].label = dim_0_name
        lat_ds.dims[1].label = dim_1_name

        # Scalar placeholder dataset whose attributes describe the projection.
        proj_ds = h5f_out.create_dataset('Projection', data=0, dtype='b')
        proj_attrs = [('long_name', 'Himawari Imagery Projection'),
                      ('grid_mapping_name', 'geostationary'),
                      ('sweep_angle_axis', 'y'),
                      ('units', 'rad'),
                      ('semi_major_axis', 6378.137),
                      ('semi_minor_axis', 6356.7523),
                      ('inverse_flattening', 298.257),
                      ('perspective_point_height', 35785.863),
                      ('latitude_of_projection_origin', 0.0),
                      ('longitude_of_projection_origin', 140.7),
                      ('CFAC', 20466275),
                      ('LFAC', 20466275),
                      ('COFF', 2750.5),
                      ('LOFF', 2750.5)]
        for attr_name, attr_value in proj_attrs:
            proj_ds.attrs.create(attr_name, data=attr_value)

        if x is not None:
            x_ds = h5f_out.create_dataset('x', data=x, dtype='f8')
            x_ds.dims[0].label = dim_0_name
            x_ds.attrs.create('units', data='rad')
            x_ds.attrs.create('standard_name', data='projection_x_coordinate')
            x_ds.attrs.create('long_name', data='GOES PUG W-E fixed grid viewing angle')
            x_ds.attrs.create('scale_factor', data=5.58879902955962e-05)
            x_ds.attrs.create('add_offset', data=-0.153719917308037)
            x_ds.attrs.create('CFAC', data=20466275)
            x_ds.attrs.create('COFF', data=2750.5)

            y_ds = h5f_out.create_dataset('y', data=y, dtype='f8')
            y_ds.dims[0].label = dim_1_name
            y_ds.attrs.create('units', data='rad')
            y_ds.attrs.create('standard_name', data='projection_y_coordinate')
            y_ds.attrs.create('long_name', data='GOES PUG S-N fixed grid viewing angle')
            y_ds.attrs.create('scale_factor', data=-5.58879902955962e-05)
            y_ds.attrs.create('add_offset', data=0.153719917308037)
            y_ds.attrs.create('LFAC', data=20466275)
            y_ds.attrs.create('LOFF', data=2750.5)

        if elems is not None:
            elem_ds = h5f_out.create_dataset('elems', data=elems, dtype='i2')
            elem_ds.dims[0].label = dim_0_name
            line_ds = h5f_out.create_dataset('lines', data=lines, dtype='i2')
            line_ds.dims[0].label = dim_1_name


def write_icing_file_nc4(clvrx_str_time, output_dir, preds_dct, probs_dct,
                         x, y, lons, lats, elems, lines, satellite='GOES16', domain='CONUS',
                         has_time=False, use_nan=False, prob_thresh=0.5, bt_10_4=None):
    """Write per-level icing predictions and probabilities to a NetCDF4 file.

    The file is created at output_dir + 'icing_prediction_' + clvrx_str_time + '.nc'
    (plain string concatenation, so output_dir must end with a path separator).

    clvrx_str_time: time string; used in the file name and passed to
        get_timestamp() for the 'time' variable.
    output_dir: directory prefix of the output file.
    preds_dct: flight-level key -> prediction grid (stored as 'i2'); the keys
        must also be keys of flt_level_ranges_str.
    probs_dct: probability grids (stored as 'f4'), indexed by the keys of preds_dct.
    x, y: projection coordinate vectors. Their lengths size the 'x' and 'y'
        dimensions, so they are required even though writing the coordinate
        variables themselves is guarded by `x is not None`.
    lons, lats: longitude / latitude grids on the (y, x) dimensions.
    elems, lines: element / line index vectors; written only when elems is not None.
    satellite, domain: select the navigation constants via get_cf_nav_parameters().
    has_time: when True, gridded variables get a leading length-1 'time' dimension.
    use_nan: when True, probabilities below prob_thresh are stored as NaN.
    prob_thresh: threshold below which a probability counts as "no icing".
    bt_10_4: optional grid written as variable 'bt_10_4'.

    Raises ValueError, before any file is created, when satellite/domain is
    not supported.
    """
    # Resolve everything that can fail on bad arguments first, so that an
    # unsupported satellite/domain does not leave a half-written file behind.
    if satellite in ('H08', 'H09'):
        long_name = 'Himawari Imagery Projection'
    elif satellite == 'GOES16':
        long_name = 'GOES-16/17 Imagery Projection'
    else:
        raise ValueError('unsupported satellite: ' + str(satellite))

    cf_nav_dct = get_cf_nav_parameters(satellite, domain)
    if cf_nav_dct is None:
        raise ValueError('no navigation parameters for satellite ' + str(satellite) +
                         ', domain ' + str(domain))

    outfile_name = output_dir + 'icing_prediction_'+clvrx_str_time+'.nc'

    dim_0_name = 'x'
    dim_1_name = 'y'
    time_dim_name = 'time'
    geo_coords = 'time y x'

    n_x = x.shape[0]
    n_y = y.shape[0]

    if not has_time:
        var_dim_list = [dim_1_name, dim_0_name]
    else:
        var_dim_list = [time_dim_name, dim_1_name, dim_0_name]

    prob_missing = np.nan if use_nan else -1.0

    def tag_grid_var(var, missing=None):
        # Attributes shared by every gridded variable.
        var.setncattr('coordinates', geo_coords)
        var.setncattr('grid_mapping', 'Projection')
        if missing is not None:
            var.setncattr('missing', missing)

    flt_lvls = list(preds_dct.keys())

    # The context manager closes the dataset even if a write below raises.
    with Dataset(outfile_name, 'w', format='NETCDF4') as rootgrp:
        rootgrp.setncattr('Conventions', 'CF-1.7')

        rootgrp.createDimension(dim_0_name, size=n_x)
        rootgrp.createDimension(dim_1_name, size=n_y)
        rootgrp.createDimension(time_dim_name, size=1)

        tvar = rootgrp.createVariable('time', 'f8', time_dim_name)
        tvar[0] = get_timestamp(clvrx_str_time)
        tvar.units = 'seconds since 1970-01-01 00:00:00'

        for flvl in flt_lvls:
            preds = preds_dct[flvl]

            icing_pred_ds = rootgrp.createVariable('icing_prediction_level_'+flt_level_ranges_str[flvl], 'i2', var_dim_list)
            tag_grid_var(icing_pred_ds, -1)
            if has_time:
                preds = preds.reshape((1, n_y, n_x))
            icing_pred_ds[:,] = preds

        prob_s = []
        for flvl in flt_lvls:
            probs = probs_dct[flvl]
            prob_s.append(probs)

            icing_prob_ds = rootgrp.createVariable('icing_probability_level_'+flt_level_ranges_str[flvl], 'f4', var_dim_list)
            tag_grid_var(icing_prob_ds, prob_missing)
            if has_time:
                probs = probs.reshape((1, n_y, n_x))
            if use_nan:
                probs = np.where(probs < prob_thresh, np.nan, probs)
            icing_prob_ds[:,] = probs

        # Levels go on a trailing axis so the column reductions run over axis 2.
        prob_s = np.stack(prob_s, axis=-1)
        max_prob = np.max(prob_s, axis=2)
        if use_nan:
            max_prob = np.where(max_prob < prob_thresh, np.nan, max_prob)
        if has_time:
            max_prob = max_prob.reshape(1, n_y, n_x)

        icing_prob_ds = rootgrp.createVariable('max_icing_probability_column', 'f4', var_dim_list)
        tag_grid_var(icing_prob_ds, prob_missing)
        icing_prob_ds[:,] = max_prob

        # Index of the most probable level, or -1 where every level is below
        # the threshold.
        prob_s = np.where(prob_s < prob_thresh, -1.0, prob_s)
        max_lvl = np.where(np.all(prob_s == -1, axis=2), -1, np.argmax(prob_s, axis=2))
        # The -1 marker is kept even when use_nan is set: this variable is
        # int16, NaN has no integer representation, and the result of casting
        # it is undefined and can collide with a valid level index. -1 also
        # matches the variable's 'missing' attribute.
        if has_time:
            max_lvl = max_lvl.reshape((1, n_y, n_x))

        icing_pred_ds = rootgrp.createVariable('max_icing_probability_level', 'i2', var_dim_list)
        tag_grid_var(icing_pred_ds, -1)
        icing_pred_ds[:,] = max_lvl

        if bt_10_4 is not None:
            bt_ds = rootgrp.createVariable('bt_10_4', 'f4', var_dim_list)
            tag_grid_var(bt_ds)
            bt_ds[:,] = bt_10_4

        lon_ds = rootgrp.createVariable('longitude', 'f4', [dim_1_name, dim_0_name])
        lon_ds.units = 'degrees_east'
        lon_ds[:,] = lons

        lat_ds = rootgrp.createVariable('latitude', 'f4', [dim_1_name, dim_0_name])
        lat_ds.units = 'degrees_north'
        lat_ds[:,] = lats

        # Scalar placeholder variable whose attributes describe the projection.
        proj_ds = rootgrp.createVariable('Projection', 'b')
        proj_ds.setncattr('long_name', long_name)
        proj_ds.setncattr('grid_mapping_name', 'geostationary')
        for key in ('sweep_angle_axis', 'semi_major_axis', 'semi_minor_axis', 'inverse_flattening',
                    'perspective_point_height', 'latitude_of_projection_origin',
                    'longitude_of_projection_origin'):
            proj_ds.setncattr(key, cf_nav_dct[key])

        if x is not None:
            x_ds = rootgrp.createVariable(dim_0_name, 'f8', [dim_0_name])
            x_ds.units = 'rad'
            x_ds.setncattr('axis', 'X')
            x_ds.setncattr('standard_name', 'projection_x_coordinate')
            x_ds.setncattr('long_name', 'fixed grid viewing angle')
            x_ds.setncattr('scale_factor', cf_nav_dct['x_scale_factor'])
            x_ds.setncattr('add_offset', cf_nav_dct['x_add_offset'])
            x_ds[:] = x

            y_ds = rootgrp.createVariable(dim_1_name, 'f8', [dim_1_name])
            y_ds.units = 'rad'
            y_ds.setncattr('axis', 'Y')
            y_ds.setncattr('standard_name', 'projection_y_coordinate')
            y_ds.setncattr('long_name', 'fixed grid viewing angle')
            y_ds.setncattr('scale_factor', cf_nav_dct['y_scale_factor'])
            y_ds.setncattr('add_offset', cf_nav_dct['y_add_offset'])
            y_ds[:] = y

        if elems is not None:
            elem_ds = rootgrp.createVariable('elems', 'i2', [dim_0_name])
            elem_ds[:] = elems
            line_ds = rootgrp.createVariable('lines', 'i2', [dim_1_name])
            line_ds[:] = lines


def write_icing_file_nc4_viirs(clvrx_str_time, output_dir, preds_dct, probs_dct, lons, lats,
                               has_time=False, use_nan=False, prob_thresh=0.5, bt_10_4=None):
    """Write per-level icing predictions and probabilities to a NetCDF4 file
    that is geolocated by longitude/latitude grids only.

    The file is created at output_dir + 'icing_prediction_' + clvrx_str_time + '.nc'
    (plain string concatenation, so output_dir must end with a path separator).

    clvrx_str_time: time string; used in the file name and passed to
        get_timestamp() for the 'time' variable.
    output_dir: directory prefix of the output file.
    preds_dct: flight-level key -> prediction grid (stored as 'i2'); the keys
        must also be keys of flt_level_ranges_str.
    probs_dct: probability grids (stored as 'f4'), indexed by the keys of preds_dct.
    lons, lats: 2-D longitude / latitude grids; lons.shape sizes the
        ('y', 'x') dimensions.
    has_time: when True, gridded variables get a leading length-1 'time' dimension.
    use_nan: when True, probabilities below prob_thresh are stored as NaN.
    prob_thresh: threshold below which a probability counts as "no icing".
    bt_10_4: optional grid written as variable 'bt_10_4'.
    """
    outfile_name = output_dir + 'icing_prediction_'+clvrx_str_time+'.nc'

    dim_0_name = 'x'
    dim_1_name = 'y'
    time_dim_name = 'time'
    geo_coords = 'longitude latitude'

    dim_1_len, dim_0_len = lons.shape

    if not has_time:
        var_dim_list = [dim_1_name, dim_0_name]
    else:
        var_dim_list = [time_dim_name, dim_1_name, dim_0_name]

    prob_missing = np.nan if use_nan else -1.0

    def tag_grid_var(var, missing=None):
        # Attributes shared by every gridded variable.
        var.setncattr('coordinates', geo_coords)
        var.setncattr('grid_mapping', 'Projection')
        if missing is not None:
            var.setncattr('missing', missing)

    flt_lvls = list(preds_dct.keys())

    # The context manager closes the dataset even if a write below raises.
    with Dataset(outfile_name, 'w', format='NETCDF4') as rootgrp:
        rootgrp.setncattr('Conventions', 'CF-1.7')

        rootgrp.createDimension(dim_0_name, size=dim_0_len)
        rootgrp.createDimension(dim_1_name, size=dim_1_len)
        rootgrp.createDimension(time_dim_name, size=1)

        tvar = rootgrp.createVariable('time', 'f8', time_dim_name)
        tvar[0] = get_timestamp(clvrx_str_time)
        tvar.units = 'seconds since 1970-01-01 00:00:00'

        for flvl in flt_lvls:
            preds = preds_dct[flvl]

            icing_pred_ds = rootgrp.createVariable('icing_prediction_level_'+flt_level_ranges_str[flvl], 'i2', var_dim_list)
            tag_grid_var(icing_pred_ds, -1)
            if has_time:
                preds = preds.reshape((1, dim_1_len, dim_0_len))
            icing_pred_ds[:,] = preds

        prob_s = []
        for flvl in flt_lvls:
            probs = probs_dct[flvl]
            prob_s.append(probs)

            icing_prob_ds = rootgrp.createVariable('icing_probability_level_'+flt_level_ranges_str[flvl], 'f4', var_dim_list)
            tag_grid_var(icing_prob_ds, prob_missing)
            if has_time:
                probs = probs.reshape((1, dim_1_len, dim_0_len))
            if use_nan:
                probs = np.where(probs < prob_thresh, np.nan, probs)
            icing_prob_ds[:,] = probs

        # Levels go on a trailing axis so the column reductions run over axis 2.
        prob_s = np.stack(prob_s, axis=-1)
        max_prob = np.max(prob_s, axis=2)
        if use_nan:
            max_prob = np.where(max_prob < prob_thresh, np.nan, max_prob)
        if has_time:
            max_prob = max_prob.reshape(1, dim_1_len, dim_0_len)

        icing_prob_ds = rootgrp.createVariable('max_icing_probability_column', 'f4', var_dim_list)
        tag_grid_var(icing_prob_ds, prob_missing)
        icing_prob_ds[:,] = max_prob

        # Index of the most probable level, or -1 where every level is below
        # the threshold.
        prob_s = np.where(prob_s < prob_thresh, -1.0, prob_s)
        max_lvl = np.where(np.all(prob_s == -1, axis=2), -1, np.argmax(prob_s, axis=2))
        # The -1 marker is kept even when use_nan is set: this variable is
        # int16, NaN has no integer representation, and the result of casting
        # it is undefined and can collide with a valid level index. -1 also
        # matches the variable's 'missing' attribute.
        if has_time:
            max_lvl = max_lvl.reshape((1, dim_1_len, dim_0_len))

        icing_pred_ds = rootgrp.createVariable('max_icing_probability_level', 'i2', var_dim_list)
        tag_grid_var(icing_pred_ds, -1)
        icing_pred_ds[:,] = max_lvl

        if bt_10_4 is not None:
            bt_ds = rootgrp.createVariable('bt_10_4', 'f4', var_dim_list)
            tag_grid_var(bt_ds)
            bt_ds[:,] = bt_10_4

        lon_ds = rootgrp.createVariable('longitude', 'f4', [dim_1_name, dim_0_name])
        lon_ds.units = 'degrees_east'
        lon_ds[:,] = lons

        lat_ds = rootgrp.createVariable('latitude', 'f4', [dim_1_name, dim_0_name])
        lat_ds.units = 'degrees_north'
        lat_ds[:,] = lats

        # Scalar placeholder variable naming the grid mapping.
        proj_ds = rootgrp.createVariable('Projection', 'b')
        proj_ds.setncattr('grid_mapping_name', 'latitude_longitude')


def write_cld_prods_file_nc4(clvrx_str_time, outfile_name, cloud_fraction, cloud_frac_opd,
                            x, y, elems, lines, satellite='GOES16', domain='CONUS',
                            has_time=False):
    """Write cloud-fraction products to a NetCDF4 file on a geostationary grid.

    clvrx_str_time: time string passed to get_timestamp() for the 'time' variable.
    outfile_name: full path of the file to create.
    cloud_fraction: grid stored as 'i1' in variable 'cloud_fraction'.
    cloud_frac_opd: grid stored as 'f4' in variable 'cldy_fraciion_opd'.
    x, y: projection coordinate vectors. Their lengths size the 'x' and 'y'
        dimensions, so they are required even though writing the coordinate
        variables themselves is guarded by `x is not None`.
    elems, lines: element / line index vectors; written only when elems is not None.
    satellite, domain: select the navigation constants via get_cf_nav_parameters().
    has_time: when True, gridded variables get a leading length-1 'time' dimension.

    Raises ValueError, before any file is created, when satellite/domain is
    not supported.
    """
    # Resolve everything that can fail on bad arguments first, so that an
    # unsupported satellite/domain does not leave a half-written file behind.
    if satellite in ('H08', 'H09'):
        long_name = 'Himawari Imagery Projection'
    elif satellite == 'GOES16':
        long_name = 'GOES-16/17 Imagery Projection'
    else:
        raise ValueError('unsupported satellite: ' + str(satellite))

    cf_nav_dct = get_cf_nav_parameters(satellite, domain)
    if cf_nav_dct is None:
        raise ValueError('no navigation parameters for satellite ' + str(satellite) +
                         ', domain ' + str(domain))

    dim_0_name = 'x'
    dim_1_name = 'y'
    time_dim_name = 'time'
    geo_coords = 'time y x'

    n_x = x.shape[0]
    n_y = y.shape[0]

    if not has_time:
        var_dim_list = [dim_1_name, dim_0_name]
    else:
        var_dim_list = [time_dim_name, dim_1_name, dim_0_name]

    # The context manager closes the dataset even if a write below raises.
    with Dataset(outfile_name, 'w', format='NETCDF4') as rootgrp:
        rootgrp.setncattr('Conventions', 'CF-1.7')

        rootgrp.createDimension(dim_0_name, size=n_x)
        rootgrp.createDimension(dim_1_name, size=n_y)
        rootgrp.createDimension(time_dim_name, size=1)

        tvar = rootgrp.createVariable('time', 'f8', time_dim_name)
        tvar[0] = get_timestamp(clvrx_str_time)
        tvar.units = 'seconds since 1970-01-01 00:00:00'

        cld_frac_ds = rootgrp.createVariable('cloud_fraction', 'i1', var_dim_list)
        cld_frac_ds.setncattr('coordinates', geo_coords)
        cld_frac_ds.setncattr('grid_mapping', 'Projection')
        cld_frac_ds.setncattr('missing', -1)
        if has_time:
            cloud_fraction = cloud_fraction.reshape((1, n_y, n_x))
        cld_frac_ds[:, ] = cloud_fraction

        # NOTE(review): 'cldy_fraciion_opd' looks like a misspelling, but it is
        # the variable name readers of existing files see, so it is kept as is.
        cld_frac_opd_ds = rootgrp.createVariable('cldy_fraciion_opd', 'f4', var_dim_list)
        cld_frac_opd_ds.setncattr('coordinates', geo_coords)
        cld_frac_opd_ds.setncattr('grid_mapping', 'Projection')
        cld_frac_opd_ds.setncattr('missing', -1)
        if has_time:
            cloud_frac_opd = cloud_frac_opd.reshape((1, n_y, n_x))
        cld_frac_opd_ds[:, ] = cloud_frac_opd

        # Scalar placeholder variable whose attributes describe the projection.
        proj_ds = rootgrp.createVariable('Projection', 'b')
        proj_ds.setncattr('long_name', long_name)
        proj_ds.setncattr('grid_mapping_name', 'geostationary')
        for key in ('sweep_angle_axis', 'semi_major_axis', 'semi_minor_axis', 'inverse_flattening',
                    'perspective_point_height', 'latitude_of_projection_origin',
                    'longitude_of_projection_origin'):
            proj_ds.setncattr(key, cf_nav_dct[key])

        if x is not None:
            x_ds = rootgrp.createVariable(dim_0_name, 'f8', [dim_0_name])
            x_ds.units = 'rad'
            x_ds.setncattr('axis', 'X')
            x_ds.setncattr('standard_name', 'projection_x_coordinate')
            x_ds.setncattr('long_name', 'fixed grid viewing angle')
            x_ds.setncattr('scale_factor', cf_nav_dct['x_scale_factor'])
            x_ds.setncattr('add_offset', cf_nav_dct['x_add_offset'])
            x_ds[:] = x

            y_ds = rootgrp.createVariable(dim_1_name, 'f8', [dim_1_name])
            y_ds.units = 'rad'
            y_ds.setncattr('axis', 'Y')
            y_ds.setncattr('standard_name', 'projection_y_coordinate')
            y_ds.setncattr('long_name', 'fixed grid viewing angle')
            y_ds.setncattr('scale_factor', cf_nav_dct['y_scale_factor'])
            y_ds.setncattr('add_offset', cf_nav_dct['y_add_offset'])
            y_ds[:] = y

        if elems is not None:
            elem_ds = rootgrp.createVariable('elems', 'i2', [dim_0_name])
            elem_ds[:] = elems
            line_ds = rootgrp.createVariable('lines', 'i2', [dim_1_name])
            line_ds[:] = lines


def downscale_2x(original, smoothing=False, samples_axis_first=False):
    """Halve the resolution by NaN-aware averaging of 2x2 pixel blocks.

    original: array whose two pooled axes are the first two, or the second
        and third when samples_axis_first is set.
    smoothing: accepted but currently not used.
    samples_axis_first: True when axis 0 is a leading samples axis that
        must not be pooled.

    Returns the block means with any length-1 axes squeezed out.
    """
    even = slice(0, None, 2)
    odd = slice(1, None, 2)
    # A leading full slice leaves the samples axis untouched.
    lead = (slice(None),) if samples_axis_first else ()

    # The four members of every 2x2 block, addressed as (row, column) phases.
    phases = ((even, even), (odd, even), (even, odd), (odd, odd))
    corners = [original[lead + phase] for phase in phases]
    return np.nanmean(np.array(corners), axis=0).squeeze()


def resample(x, y, z, x_new, y_new):
    """Spline-resample every 2-D slice of z onto a new rectangular grid.

    x, y: coordinates of the input grid; each slice z[k] is passed to
        RectBivariateSpline(x, y, z[k]).
    z: stack of grids indexed along axis 0.
    x_new, y_new: coordinates of the output grid.

    Returns the resampled slices stacked along axis 0.
    """
    return np.stack([RectBivariateSpline(x, y, z[k, :, :])(x_new, y_new)
                     for k in range(z.shape[0])])


def resample_one(x, y, z, x_new, y_new):
    """Spline-resample a single 2-D grid onto a new rectangular grid.

    Fits RectBivariateSpline(x, y, z) and evaluates it on the grid spanned
    by x_new and y_new.
    """
    spline = RectBivariateSpline(x, y, z)
    resampled = spline(x_new, y_new)
    return resampled


def resample_2d_linear(x, y, z, x_new, y_new):
    """Bilinearly resample every 2-D slice of z onto a new rectangular grid.

    scipy.interpolate.interp2d, which this function used to call, has been
    removed from SciPy. The same regular-grid linear interpolation is done
    here with a degree-1 RectBivariateSpline, keeping interp2d's conventions.

    x, y: 1-D, strictly ascending coordinates of the input grid.
    z: stack of grids; each slice z[k] has shape (len(y), len(x)).
    x_new, y_new: coordinates of the output grid. As with interp2d they are
        sorted before evaluation, so the output follows ascending order.

    Returns the resampled slices stacked along axis 0. Each slice has shape
    (len(y_new), len(x_new)), or (len(x_new),) when y_new holds one value.
    """
    x_new = np.sort(np.atleast_1d(x_new))
    y_new = np.sort(np.atleast_1d(y_new))

    z_intrp = []
    for k in range(z.shape[0]):
        # RectBivariateSpline wants z indexed (x, y); interp2d used (y, x),
        # hence the transposes on the way in and out.
        f = RectBivariateSpline(x, y, np.asarray(z[k, :, :]).T, kx=1, ky=1)
        z_k = f(x_new, y_new).T
        if z_k.shape[0] == 1:
            # interp2d dropped the row axis for a single y_new value.
            z_k = z_k[0]
        z_intrp.append(z_k)
    return np.stack(z_intrp)


def resample_2d_linear_one(x, y, z, x_new, y_new):
    """Bilinearly resample a single 2-D grid onto a new rectangular grid.

    scipy.interpolate.interp2d, which this function used to call, has been
    removed from SciPy. The same regular-grid linear interpolation is done
    here with a degree-1 RectBivariateSpline, keeping interp2d's conventions.

    x, y: 1-D, strictly ascending coordinates of the input grid.
    z: grid of shape (len(y), len(x)).
    x_new, y_new: coordinates of the output grid. As with interp2d they are
        sorted before evaluation, so the output follows ascending order.

    Returns an array of shape (len(y_new), len(x_new)), or (len(x_new),)
    when y_new holds one value.
    """
    x_new = np.sort(np.atleast_1d(x_new))
    y_new = np.sort(np.atleast_1d(y_new))

    # RectBivariateSpline wants z indexed (x, y); interp2d used (y, x),
    # hence the transposes on the way in and out.
    f = RectBivariateSpline(x, y, np.asarray(z).T, kx=1, ky=1)
    z_new = f(x_new, y_new).T
    if z_new.shape[0] == 1:
        # interp2d dropped the row axis for a single y_new value.
        z_new = z_new[0]
    return z_new


def smooth_2d(z, sigma=1.0):
    """Gaussian-smooth each 2-D slice of a batch independently.

    Suitable for the model training data pipeline.

    z: input array with dimensions [BATCH_SIZE, Y, X].
    sigma: standard deviation of the Gaussian kernel.

    Returns the smoothed slices stacked back into the input's dimensions.
    """
    return np.stack([gaussian_filter(z[j, :, :], sigma=sigma)
                     for j in range(z.shape[0])])


def smooth_2d_single(z, sigma=1.0):
    """Gaussian-smooth a single [Y, X] array.

    z: input array.
    sigma: standard deviation of the Gaussian kernel.
    """
    smoothed = gaussian_filter(z, sigma=sigma)
    return smoothed


def median_filter_2d(z, kernel_size=3):
    """Median-filter each 2-D slice of a stack independently.

    z: input array indexed as [slice, row, column].
    kernel_size: filter window size handed to scipy.signal.medfilt2d.

    Returns the filtered slices stacked along axis 0.
    """
    return np.stack([medfilt2d(z[j, :, :], kernel_size=kernel_size)
                     for j in range(z.shape[0])])


def median_filter_2d_single(z, kernel_size=3):
    """Median-filter a single 2-D array.

    z: input array.
    kernel_size: filter window size handed to scipy.signal.medfilt2d.
    """
    filtered = medfilt2d(z, kernel_size=kernel_size)
    return filtered


def get_training_parameters(day_night='DAY', l1b_andor_l2='both', satellite='GOES16', use_dnb=False):
    """Return the names of the datasets used as training inputs.

    day_night: 'DAY' selects the daytime lists (which include reflectances
        and DCOMP products); any other value selects the night lists.
    l1b_andor_l2: 'both', 'l1b' or 'l2'; picks which list is returned first.
    satellite: 'GOES16', 'H08' or 'H09'. 'H09' uses the 'H08' channel list,
        in line with how the rest of this module treats the two satellites.
    use_dnb: when exactly True, the night L2 list uses the DCOMP products
        of the daytime list instead of the ACHA ones.

    Returns (train_params, train_params_l1b, train_params_l2), each a newly
    built list; for 'l1b' / 'l2' the first element is the same list object
    as the corresponding second / third element.

    Raises ValueError for an unsupported satellite or l1b_andor_l2 value.
    These cases previously failed with UnboundLocalError.
    """
    if satellite not in ('GOES16', 'H08', 'H09'):
        raise ValueError('unsupported satellite: ' + str(satellite))
    if l1b_andor_l2 not in ('both', 'l1b', 'l2'):
        raise ValueError("l1b_andor_l2 must be 'both', 'l1b' or 'l2', got: " + str(l1b_andor_l2))

    is_day = day_night == 'DAY'

    # L2 products shared by the day and night lists.
    train_params_l2 = ['cld_height_acha', 'cld_geo_thick', 'cld_temp_acha', 'cld_press_acha',
                       'supercooled_cloud_fraction', 'cld_emiss_acha', 'conv_cloud_fraction']
    if is_day or use_dnb is True:
        train_params_l2 += ['cld_reff_dcomp', 'cld_opd_dcomp', 'iwc_dcomp', 'lwc_dcomp']
    else:
        train_params_l2 += ['cld_reff_acha', 'cld_opd_acha']

    # Brightness temperatures are used day and night; reflectances only by day.
    if satellite == 'GOES16':
        train_params_l1b = ['temp_10_4um_nom', 'temp_11_0um_nom', 'temp_12_0um_nom', 'temp_13_3um_nom',
                            'temp_3_75um_nom', 'temp_6_2um_nom', 'temp_6_7um_nom', 'temp_7_3um_nom',
                            'temp_9_7um_nom']
        if is_day:
            train_params_l1b += ['refl_0_47um_nom', 'refl_0_65um_nom', 'refl_0_86um_nom',
                                 'refl_1_38um_nom', 'refl_1_60um_nom']
    else:
        train_params_l1b = ['temp_10_4um_nom', 'temp_12_0um_nom', 'temp_8_5um_nom', 'temp_3_75um_nom']
        if is_day:
            train_params_l1b += ['refl_2_10um_nom', 'refl_1_60um_nom', 'refl_0_86um_nom',
                                 'refl_0_47um_nom']

    if l1b_andor_l2 == 'both':
        train_params = train_params_l1b + train_params_l2
    elif l1b_andor_l2 == 'l1b':
        train_params = train_params_l1b
    else:
        train_params = train_params_l2

    return train_params, train_params_l1b, train_params_l2