import metpy
import numpy as np
import xarray as xr
import datetime
from datetime import timezone
from metpy.units import units
from metpy.calc import thickness_hydrostatic
from collections import namedtuple
import os
import h5py
import pickle
from netCDF4 import Dataset
from util.setup import ancillary_path
from scipy.interpolate import RectBivariateSpline, interp2d
from scipy.ndimage import gaussian_filter
from scipy.signal import medfilt2d
from cartopy.crs import Geostationary, Globe

LatLonTuple = namedtuple('LatLonTuple', ['lat', 'lon'])
FGFTuple = namedtuple('FGFTuple', ['line', 'elem'])

homedir = os.path.expanduser('~') + '/'

# --- CLAVRx Radiometric parameters and metadata ------------------------------------------------
l1b_ds_list = ['temp_10_4um_nom', 'temp_11_0um_nom', 'temp_12_0um_nom', 'temp_13_3um_nom', 'temp_3_75um_nom',
               'temp_6_2um_nom', 'temp_6_7um_nom', 'temp_7_3um_nom', 'temp_8_5um_nom', 'temp_9_7um_nom',
               'refl_0_47um_nom', 'refl_0_65um_nom', 'refl_0_86um_nom', 'refl_1_38um_nom', 'refl_1_60um_nom']

l1b_ds_types = {ds: 'f4' for ds in l1b_ds_list}
l1b_ds_fill = {l1b_ds_list[i]: -32767 for i in range(10)}
l1b_ds_fill.update({l1b_ds_list[i+10]: -32768 for i in range(5)})
l1b_ds_range = {ds: 'actual_range' for ds in l1b_ds_list}

# --- CLAVRx L2 parameters and metadata
ds_list = ['cld_height_acha', 'cld_geo_thick', 'cld_press_acha', 'sensor_zenith_angle', 'supercooled_prob_acha',
           'supercooled_cloud_fraction', 'cld_temp_acha', 'cld_opd_acha', 'solar_zenith_angle',
           'cld_reff_acha', 'cld_reff_dcomp', 'cld_reff_dcomp_1', 'cld_reff_dcomp_2', 'cld_reff_dcomp_3',
           'cld_opd_dcomp', 'cld_opd_dcomp_1', 'cld_opd_dcomp_2', 'cld_opd_dcomp_3', 'cld_cwp_dcomp', 'iwc_dcomp',
           'lwc_dcomp', 'cld_emiss_acha', 'conv_cloud_fraction', 'cloud_type', 'cloud_phase', 'cloud_mask']

ds_types = {ds_list[i]: 'f4' for i in range(23)}
ds_types.update({ds_list[i+23]: 'i1' for i in range(3)})
ds_fill = {ds_list[i]: -32768 for i in range(23)}
ds_fill.update({ds_list[i+23]: -128 for i in range(3)})
ds_range = {ds_list[i]: 'actual_range' for i in range(23)}
ds_range.update({ds_list[i+23]: None for i in range(3)})

ds_types.update(l1b_ds_types)
ds_fill.update(l1b_ds_fill)
ds_range.update(l1b_ds_range)

ds_types.update({'temp_3_9um_nom': 'f4'})
ds_types.update({'cloud_fraction': 'f4'})
ds_fill.update({'temp_3_9um_nom': -32767})
ds_fill.update({'cloud_fraction': -32768})
ds_range.update({'temp_3_9um_nom': 'actual_range'})
ds_range.update({'cloud_fraction': 'actual_range'})


class MyGenericException(Exception):
    def __init__(self, message):
        self.message = message


def make_tf_callable_generator(the_generator):
    class MyCallable:
        def __init__(self, gen):
            self.gen = gen

        def __call__(self):
            return self.gen

    the_callable = MyCallable(the_generator)
    return the_callable


def get_fill_attrs(name):
    if name in ds_fill:
        v = ds_fill[name]
        if v is None:
            return None, '_FillValue'
        else:
            return v, None
    else:
        return None, '_FillValue'


class GenericException(Exception):
    def __init__(self, message):
        self.message = message


class EarlyStop:
    def __init__(self, window_length=3, patience=5):
        self.patience = patience
        self.min = np.finfo(np.single).max
        self.cnt = 0
        self.cnt_wait = 0
        self.window = np.zeros(window_length, dtype=np.single)
        self.window.fill(np.nan)

    def check_stop(self, value):
        self.window[:-1] = self.window[1:]
        self.window[-1] = value

        if np.any(np.isnan(self.window)):
            return False

        ave = np.mean(self.window)
        if ave < self.min:
            self.min = ave
            self.cnt_wait = 0
            return False
        else:
            self.cnt_wait += 1

        if self.cnt_wait > self.patience:
            return True
        else:
            return False
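

# Minimal usage sketch for EarlyStop: feed a synthetic loss sequence and stop
# once the moving average fails to improve for more than `patience` consecutive
# checks. The loss values below are illustrative only.
def _demo_early_stop():
    stopper = EarlyStop(window_length=3, patience=2)
    losses = [1.0, 0.8, 0.7, 0.65, 0.66, 0.67, 0.68, 0.69, 0.7]
    for step, loss in enumerate(losses):
        if stopper.check_stop(loss):
            print('stopping at step', step)
            break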


def get_time_tuple_utc(timestamp):
    dt_obj = datetime.datetime.fromtimestamp(timestamp, timezone.utc)
    return dt_obj, dt_obj.timetuple()


def get_datetime_obj(dt_str, format_code='%Y-%m-%d_%H:%M'):
    dto = datetime.datetime.strptime(dt_str, format_code).replace(tzinfo=timezone.utc)
    return dto


def get_timestamp(dt_str, format_code='%Y-%m-%d_%H:%M'):
    dto = datetime.datetime.strptime(dt_str, format_code).replace(tzinfo=timezone.utc)
    ts = dto.timestamp()
    return ts


def add_time_range_to_filename(pathname, tstart, tend=None):
    filename = os.path.split(pathname)[1]
    w_o_ext, ext = os.path.splitext(filename)

    dt_obj, _ = get_time_tuple_utc(tstart)
    str_start = dt_obj.strftime('%Y%m%d%H')
    filename = w_o_ext+'_'+str_start

    if tend is not None:
        dt_obj, _ = get_time_tuple_utc(tend)
        str_end = dt_obj.strftime('%Y%m%d%H')
        filename = filename+'_'+str_end
    filename = filename+ext

    path = os.path.split(pathname)[0]
    path = path+'/'+filename
    return path


def haversine_np(lon1, lat1, lon2, lat2, earth_radius=6367.0):
    """
    Calculate the great circle distance between two points
    on the earth (specified in decimal degrees)

    (lon1, lat1) must be broadcastable with (lon2, lat2).
    """

    lon1, lat1, lon2, lat2 = map(np.radians, [lon1, lat1, lon2, lat2])

    dlon = lon2 - lon1
    dlat = lat2 - lat1

    a = np.sin(dlat/2.0)**2 + np.cos(lat1) * np.cos(lat2) * np.sin(dlon/2.0)**2

    c = 2.0 * np.arcsin(np.sqrt(a))
    km = earth_radius * c
    return km
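

# Quick sanity check for haversine_np: London to Paris is roughly 340 km.
# Arguments are (lon1, lat1, lon2, lat2) in decimal degrees.
def _demo_haversine():
    km = haversine_np(-0.1278, 51.5074, 2.3522, 48.8566)
    print('London-Paris: %.0f km' % km)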


def bin_data_by(a, b, bin_ranges):
    nbins = len(bin_ranges)
    binned_data = []

    for i in range(nbins):
        rng = bin_ranges[i]
        idxs = (b >= rng[0]) & (b < rng[1])
        binned_data.append(a[idxs])

    return binned_data


def bin_data_by_edges(a, b, edges):
    nbins = len(edges) - 1
    binned_data = []

    for i in range(nbins):
        idxs = (b >= edges[i]) & (b < edges[i+1])
        binned_data.append(a[idxs])

    return binned_data


def get_bin_ranges(lop, hip, bin_size=100):
    bin_ranges = []
    delp = hip - lop
    nbins = int(delp/bin_size)

    for i in range(nbins):
        rng = [lop + i*bin_size, lop + i*bin_size + bin_size]
        bin_ranges.append(rng)

    return bin_ranges


# t must be monotonically increasing
def get_breaks(t, threshold):
    t_0 = t[0:t.shape[0]-1]
    t_1 = t[1:t.shape[0]]
    d = t_1 - t_0
    idxs = np.nonzero(d > threshold)
    return idxs


# return indexes i of ts where ts[i] - threshold <= value <= ts[i] + threshold
# could be fully vectorized with numpy if that ever becomes necessary
# threshold units: seconds
def get_indexes_within_threshold(ts, value, threshold):
    idx_s = []
    t_s = []
    for k, v in enumerate(ts):
        if (ts[k] - threshold) <= value <= (ts[k] + threshold):
            idx_s.append(k)
            t_s.append(v)
    return idx_s, t_s


def pressure_to_altitude(pres, temp, prof_pres, prof_temp, sfc_pres=None, sfc_temp=None, sfc_elev=0):
    if not np.all(np.diff(prof_pres) > 0):
        raise GenericException("target pressure profile must be monotonic increasing")

    if pres < prof_pres[0]:
        raise GenericException("target pressure less than top of pressure profile")

    if temp is None:
        temp = np.interp(pres, prof_pres, prof_temp)

    i_top = np.argmax(np.extract(prof_pres <= pres, prof_pres)) + 1

    pres_s = prof_pres.tolist()
    temp_s = prof_temp.tolist()

    pres_s = [pres] + pres_s[i_top:]
    temp_s = [temp] + temp_s[i_top:]

    if sfc_pres is not None:
        if pres > sfc_pres:  # incoming pressure below surface
            return -999.0

        prof_pres = np.array(pres_s)
        prof_temp = np.array(temp_s)
        i_bot = prof_pres.shape[0] - 1

        if sfc_pres > prof_pres[i_bot]:  # surface below profile bottom
            pres_s = pres_s + [sfc_pres]
            temp_s = temp_s + [sfc_temp]
        else:
            idx = np.argmax(np.extract(prof_pres < sfc_pres, prof_pres))
            if sfc_temp is None:
                sfc_temp = np.interp(sfc_pres, prof_pres, prof_temp)
            pres_s = prof_pres.tolist()
            temp_s = prof_temp.tolist()
            pres_s = pres_s[0:idx+1] + [sfc_pres]
            temp_s = temp_s[0:idx+1] + [sfc_temp]

    prof_pres = np.array(pres_s)
    prof_temp = np.array(temp_s)

    prof_pres = prof_pres[::-1]
    prof_temp = prof_temp[::-1]

    prof_pres = prof_pres * units.hectopascal
    prof_temp = prof_temp * units.kelvin
    sfc_elev = sfc_elev * units.meter

    z = thickness_hydrostatic(prof_pres, prof_temp) + sfc_elev

    return z.magnitude
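

# Usage sketch for pressure_to_altitude with a hypothetical sounding:
# pressures (hPa) increase toward the surface, temperatures are in Kelvin.
# Prints the approximate altitude (m) of the 600 hPa level.
def _demo_pressure_to_altitude():
    prof_pres = np.array([100.0, 300.0, 500.0, 700.0, 850.0, 1000.0])
    prof_temp = np.array([210.0, 230.0, 255.0, 270.0, 280.0, 288.0])
    z = pressure_to_altitude(600.0, None, prof_pres, prof_temp,
                             sfc_pres=1000.0, sfc_temp=288.0, sfc_elev=0)
    print('600 hPa altitude: %.1f m' % z)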


# http://fourier.eng.hmc.edu/e176/lectures/NM/node25.html
def minimize_quadratic(xa, xb, xc, ya, yb, yc):
    x_m = xb + 0.5*(((ya-yb)*(xc-xb)*(xc-xb) - (yc-yb)*(xb-xa)*(xb-xa)) / ((ya-yb)*(xc-xb) + (yc-yb)*(xb-xa)))
    return x_m
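

# Worked example: for y = (x - 2)^2 sampled at x = 0, 1, 4, the fitted
# parabola's minimum is recovered exactly at x = 2.
def _demo_minimize_quadratic():
    f = lambda x: (x - 2.0) ** 2
    print(minimize_quadratic(0.0, 1.0, 4.0, f(0.0), f(1.0), f(4.0)))  # -> 2.0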


# Return index of nda closest to value. nda must be 1d
def value_to_index(nda, value):
    diff = np.abs(nda - value)
    idx = np.argmin(diff)
    return idx


def find_bin_index(nda, value_s):
    idxs = np.arange(nda.shape[0])
    iL_s = np.full(value_s.shape[0], -1)

    for k, v in enumerate(value_s):
        above = v >= nda
        if not above.any():
            continue

        below = v < nda
        if not below.any():
            continue

        iL = idxs[above].max()
        iL_s[k] = iL

    return iL_s.astype(np.int32)
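

# Small example for find_bin_index: nda holds ascending bin edges; values
# outside [nda[0], nda[-1]) map to -1, others to the index of the left edge.
def _demo_find_bin_index():
    edges = np.array([0.0, 10.0, 20.0, 30.0])
    values = np.array([5.0, 25.0, -1.0, 35.0])
    print(find_bin_index(edges, values))  # -> [ 0  2 -1 -1]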


# array solzen must be in degrees, missing values must be NaN. For small (roughly 50x50 km) regions only
def is_day(solzen, test_angle=80.0):
    solzen = solzen.flatten()
    solzen = solzen[np.invert(np.isnan(solzen))]
    return len(solzen) > 0 and np.all(solzen <= test_angle)


# array solzen must be in degrees, missing values must be NaN. For small (roughly 50x50 km) regions only
def is_night(solzen, test_angle=100.0):
    solzen = solzen.flatten()
    solzen = solzen[np.invert(np.isnan(solzen))]
    return len(solzen) > 0 and np.all(solzen >= test_angle)


def check_oblique(satzen, test_angle=70.0):
    satzen = satzen.flatten()
    satzen = satzen[np.invert(np.isnan(satzen))]
    return len(satzen) > 0 and np.all(satzen <= test_angle)


def get_median(tile_2d):
    tile = tile_2d.flatten()
    return np.nanmedian(tile)


def get_grid_values(h5f, grid_name, j_c, i_c, half_width, num_j=None, num_i=None, scale_factor_name='scale_factor', add_offset_name='add_offset',
                    fill_value_name='_FillValue', valid_range_name='valid_range', actual_range_name='actual_range', fill_value=None):
    hfds = h5f[grid_name]
    attrs = hfds.attrs
    if attrs is None:
        raise GenericException('No attributes object for: '+grid_name)

    ylen, xlen = hfds.shape

    if half_width is not None:
        j_l = j_c-half_width
        i_l = i_c-half_width
        if j_l < 0 or i_l < 0:
            return None

        j_r = j_c+half_width+1
        i_r = i_c+half_width+1
        if j_r >= ylen or i_r >= xlen:
            return None
    else:
        j_l = j_c
        j_r = j_c + num_j + 1
        i_l = i_c
        i_r = i_c + num_i + 1

    grd_vals = hfds[j_l:j_r, i_l:i_r]

    if fill_value_name is not None:
        attr = attrs.get(fill_value_name)
        if attr is not None:
            if np.isscalar(attr):
                fill_value = attr
            else:
                fill_value = attr[0]
            grd_vals = np.where(grd_vals == fill_value, np.nan, grd_vals)
    elif fill_value is not None:
        grd_vals = np.where(grd_vals == fill_value, np.nan, grd_vals)

    if valid_range_name is not None:
        attr = attrs.get(valid_range_name)
        if attr is not None:
            low = attr[0]
            high = attr[1]
            grd_vals = np.where(grd_vals < low, np.nan, grd_vals)
            grd_vals = np.where(grd_vals > high, np.nan, grd_vals)

    if scale_factor_name is not None:
        attr = attrs.get(scale_factor_name)
        if attr is not None:
            if np.isscalar(attr):
                scale_factor = attr
            else:
                scale_factor = attr[0]
            grd_vals = grd_vals * scale_factor

    if add_offset_name is not None:
        attr = attrs.get(add_offset_name)
        if attr is not None:
            if np.isscalar(attr):
                add_offset = attr
            else:
                add_offset = attr[0]
            grd_vals = grd_vals + add_offset

    if actual_range_name is not None:
        attr = attrs.get(actual_range_name)
        if attr is not None:
            low = attr[0]
            high = attr[1]
            grd_vals = np.where(grd_vals < low, np.nan, grd_vals)
            grd_vals = np.where(grd_vals > high, np.nan, grd_vals)

    return grd_vals


def get_grid_values_all(h5f, grid_name, scale_factor_name='scale_factor', add_offset_name='add_offset',
                        fill_value_name='_FillValue', valid_range_name='valid_range', actual_range_name='actual_range', fill_value=None, stride=None):
    hfds = h5f[grid_name]
    attrs = hfds.attrs

    if attrs is None:
        raise GenericException('No attributes object for: '+grid_name)

    if stride is None:
        grd_vals = hfds[:,]
    else:
        grd_vals = hfds[::stride, ::stride]

    if fill_value_name is not None:
        attr = attrs.get(fill_value_name)
        if attr is not None:
            if np.isscalar(attr):
                fill_value = attr
            else:
                fill_value = attr[0]
            grd_vals = np.where(grd_vals == fill_value, np.nan, grd_vals)
    elif fill_value is not None:
        grd_vals = np.where(grd_vals == fill_value, np.nan, grd_vals)

    if valid_range_name is not None:
        attr = attrs.get(valid_range_name)
        if attr is not None:
            low = attr[0]
            high = attr[1]
            grd_vals = np.where(grd_vals < low, np.nan, grd_vals)
            grd_vals = np.where(grd_vals > high, np.nan, grd_vals)

    if scale_factor_name is not None:
        attr = attrs.get(scale_factor_name)
        if attr is not None:
            if np.isscalar(attr):
                scale_factor = attr
            else:
                scale_factor = attr[0]
            grd_vals = grd_vals * scale_factor

    if add_offset_name is not None:
        attr = attrs.get(add_offset_name)
        if attr is not None:
            if np.isscalar(attr):
                add_offset = attr
            else:
                add_offset = attr[0]
            grd_vals = grd_vals + add_offset

    if actual_range_name is not None:
        attr = attrs.get(actual_range_name)
        if attr is not None:
            low = attr[0]
            high = attr[1]
            grd_vals = np.where(grd_vals < low, np.nan, grd_vals)
            grd_vals = np.where(grd_vals > high, np.nan, grd_vals)

    return grd_vals
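

# Usage sketch for get_grid_values_all. The file name and dataset name are
# hypothetical; any CLAVRx-style HDF5 file with CF attributes should work.
def _demo_get_grid_values_all():
    with h5py.File(homedir + 'clavrx_sample.level2.hdf', 'r') as h5f:
        temps = get_grid_values_all(h5f, 'cld_temp_acha', stride=4)
        print(temps.shape, np.nanmin(temps), np.nanmax(temps))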


# dt_str_0: start datetime string in format YYYY-MM-DD_HH:MM
# dt_str_1: stop datetime string; if not None, num_steps is computed from it
# format_code: default '%Y-%m-%d_%H:%M'
# the step increment is specified via days, hours, minutes or seconds
# dt_str_1 and num_steps cannot both be None
# returns two lists of length num_steps+1, datetime objects and timestamps (edges of a numpy histogram)
def make_times(dt_str_0, dt_str_1=None, format_code='%Y-%m-%d_%H:%M', num_steps=None, days=None, hours=None, minutes=None, seconds=None):
    if days is not None:
        inc = 86400*days
    elif hours is not None:
        inc = 3600*hours
    elif minutes is not None:
        inc = 60*minutes
    else:
        inc = seconds

    dt_obj_s = []
    ts_s = []
    dto_0 = datetime.datetime.strptime(dt_str_0, format_code).replace(tzinfo=timezone.utc)
    ts_0 = dto_0.timestamp()

    if dt_str_1 is not None:
        dto_1 = datetime.datetime.strptime(dt_str_1, format_code).replace(tzinfo=timezone.utc)
        ts_1 = dto_1.timestamp()
        num_steps = int((ts_1 - ts_0)/inc)

    dt_obj_s.append(dto_0)
    ts_s.append(ts_0)
    dto_last = dto_0
    for k in range(num_steps):
        dt_obj = dto_last + datetime.timedelta(seconds=inc)
        dt_obj_s.append(dt_obj)
        ts_s.append(dt_obj.timestamp())
        dto_last = dt_obj

    return dt_obj_s, ts_s
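

# Example: six-hourly time edges spanning a day. With num_steps=4 the lists
# contain num_steps + 1 = 5 datetimes/timestamps.
def _demo_make_times():
    dt_objs, timestamps = make_times('2021-01-01_00:00', num_steps=4, hours=6)
    for dto in dt_objs:
        print(dto.isoformat())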


def normalize(data, param, mean_std_dict, copy=True):

    if mean_std_dict.get(param) is None:
        raise MyGenericException(param + ': not found in dictionary')

    if copy:
        data = data.copy()

    shape = data.shape
    data = data.flatten()

    mean, std, lo, hi = mean_std_dict.get(param)
    data -= mean
    data /= std

    not_valid = np.isnan(data)
    data[not_valid] = 0

    data = np.reshape(data, shape)

    return data


def denormalize(data, param, mean_std_dict, copy=True):
    if copy:
        data = data.copy()

    if mean_std_dict.get(param) is None:
        raise MyGenericException(param + ': not found in dictionary')

    shape = data.shape
    data = data.flatten()

    mean, std, lo, hi = mean_std_dict.get(param)
    data *= std
    data += mean

    data = np.reshape(data, shape)

    return data


def scale(data, param, mean_std_dict, copy=True):
    if copy:
        data = data.copy()

    if mean_std_dict.get(param) is None:
        raise MyGenericException(param + ': not found in dictionary')

    shape = data.shape
    data = data.flatten()

    _, _, lo, hi = mean_std_dict.get(param)

    data -= lo
    data /= (hi - lo)

    not_valid = np.isnan(data)
    data[not_valid] = 0

    data = np.reshape(data, shape)

    return data


def scale2(data, lo, hi, copy=True):
    if copy:
        data = data.copy()

    shape = data.shape
    data = data.flatten()

    data -= lo
    data /= (hi - lo)

    not_valid = np.isnan(data)
    data[not_valid] = 0

    data = np.reshape(data, shape)

    return data


def descale2(data, lo, hi, copy=True):
    if copy:
        data = data.copy()

    shape = data.shape
    data = data.flatten()

    data *= (hi - lo)
    data += lo

    not_valid = np.isnan(data)
    data[not_valid] = 0

    data = np.reshape(data, shape)

    return data


def descale(data, param, mean_std_dict, copy=True):
    if copy:
        data = data.copy()

    if mean_std_dict.get(param) is None:
        raise MyGenericException(param + ': not found in dictionary')

    shape = data.shape
    data = data.flatten()

    _, _, lo, hi = mean_std_dict.get(param)

    data *= (hi - lo)
    data += lo

    not_valid = np.isnan(data)
    data[not_valid] = 0

    data = np.reshape(data, shape)

    return data
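

# Round-trip sketch for normalize/denormalize. The parameter name and its
# (mean, std, lo, hi) tuple are hypothetical; the round trip is exact only
# for valid (non-NaN) inputs, since NaNs are zero-filled on the forward pass.
def _demo_normalize_roundtrip():
    mean_std_dct = {'my_param': (250.0, 20.0, 180.0, 320.0)}
    data = np.array([[230.0, 250.0], [270.0, 290.0]])
    normed = normalize(data, 'my_param', mean_std_dct)
    print(denormalize(normed, 'my_param', mean_std_dct))  # -> original data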


def add_noise(data, noise_scale=0.01, seed=None, copy=True):
    if copy:
        data = data.copy()

    shape = data.shape
    data = data.flatten()

    if seed is not None:
        np.random.seed(seed)
    rnd = np.random.normal(loc=0, scale=noise_scale, size=data.size)
    data += rnd

    data = np.reshape(data, shape)

    return data


crs_goes16_fd = Geostationary(central_longitude=-75.0, satellite_height=35786023.0, sweep_axis='x',
                               globe=Globe(ellipse=None, semimajor_axis=6378137.0, semiminor_axis=6356752.31414,
                                           inverse_flattening=298.2572221))

crs_goes16_conus = Geostationary(central_longitude=-75.0, satellite_height=35786023.0, sweep_axis='x',
                                  globe=Globe(ellipse=None, semimajor_axis=6378137.0, semiminor_axis=6356752.31414,
                                              inverse_flattening=298.2572221))

crs_h08_fd = Geostationary(central_longitude=140.7, satellite_height=35785.863, sweep_axis='y',
                            globe=Globe(ellipse=None, semimajor_axis=6378.137, semiminor_axis=6356.7523,
                                        inverse_flattening=298.25702))


def get_cartopy_crs(satellite, domain):
    if satellite == 'GOES16':
        if domain == 'FD':
            geos = crs_goes16_fd
            xlen = 5424
            xmin = -5433893.0
            xmax = 5433893.0
            ylen = 5424
            ymin = -5433893.0
            ymax = 5433893.0
        elif domain == 'CONUS':
            geos = crs_goes16_conus
            xlen = 2500
            xmin = -3626269.5
            xmax = 1381770.0
            ylen = 1500
            ymin = 1584175.9
            ymax = 4588198.0
    elif satellite == 'H08' or satellite == 'H09':
        geos = crs_h08_fd
        xlen = 5500
        xmin = -5498.99990119
        xmax = 5498.99990119
        ylen = 5500
        ymin = -5498.99990119
        ymax = 5498.99990119
    else:
        raise GenericException('unsupported satellite: ' + satellite)

    return geos, xlen, xmin, xmax, ylen, ymin, ymax
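

# Plotting sketch for get_cartopy_crs (assumes matplotlib is installed):
def _demo_cartopy_crs():
    import matplotlib.pyplot as plt
    geos, xlen, xmin, xmax, ylen, ymin, ymax = get_cartopy_crs('GOES16', 'CONUS')
    ax = plt.axes(projection=geos)
    ax.set_extent([xmin, xmax, ymin, ymax], crs=geos)
    ax.coastlines()
    plt.show()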


def concat_dict_s(t_dct_0, t_dct_1):
    keys_0 = list(t_dct_0.keys())
    nda_0 = np.array(keys_0)

    keys_1 = list(t_dct_1.keys())
    nda_1 = np.array(keys_1)

    comm_keys, comm0, comm1 = np.intersect1d(nda_0, nda_1, return_indices=True)

    comm_keys = comm_keys.tolist()

    for key in comm_keys:
        t_dct_1.pop(key)
    t_dct_0.update(t_dct_1)

    return t_dct_0


rho_water = 1000000.0  # g m^-3
rho_ice = 917000.0  # g m^-3

# Reference (CLAVR-x Fortran) for the cloud water path computation:
# real(kind=real4), parameter:: Rho_Water = 1.0    ! g/m^3
# real(kind=real4), parameter:: Rho_Ice = 0.917    ! g/m^3
#
# !--- compute cloud water path
# if (Iphase == 0) then
#     Cwp_Dcomp(Elem_Idx, Line_Idx) = 0.55 * Tau * Reff * Rho_Water
#     Lwp_Dcomp(Elem_Idx, Line_Idx) = 0.55 * Tau * Reff * Rho_Water
# else
#     Cwp_Dcomp(Elem_Idx, Line_Idx) = 0.667 * Tau * Reff * Rho_Ice
#     Iwp_Dcomp(Elem_Idx, Line_Idx) = 0.667 * Tau * Reff * Rho_Ice
# endif


def compute_lwc_iwc(iphase, reff, opd, geo_dz):
    xy_shape = iphase.shape

    iphase = iphase.flatten()
    keep_0 = np.invert(np.isnan(iphase))

    reff = reff.flatten()
    keep_1 = np.invert(np.isnan(reff))

    opd = opd.flatten()
    keep_2 = np.invert(np.isnan(opd))

    geo_dz = geo_dz.flatten()
    keep_3 = np.logical_and(np.invert(np.isnan(geo_dz)), geo_dz > 1.0)

    keep = keep_0 & keep_1 & keep_2 & keep_3

    lwp_dcomp = np.zeros(reff.shape[0])
    iwp_dcomp = np.zeros(reff.shape[0])
    lwp_dcomp[:] = np.nan
    iwp_dcomp[:] = np.nan

    ice = (iphase == 1) & keep
    no_ice = (iphase != 1) & keep

    # compute ice/liquid water path, g m-2
    reff *= 1.0e-06  # convert microns to meters

    iwp_dcomp[ice] = 0.667 * opd[ice] * rho_ice * reff[ice]
    lwp_dcomp[no_ice] = 0.55 * opd[no_ice] * rho_water * reff[no_ice]

    iwp_dcomp /= geo_dz
    lwp_dcomp /= geo_dz

    lwp_dcomp = lwp_dcomp.reshape(xy_shape)
    iwp_dcomp = iwp_dcomp.reshape(xy_shape)

    return lwp_dcomp, iwp_dcomp
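

# Tiny example for compute_lwc_iwc: one liquid (phase 0) and one ice
# (phase 1) pixel; reff in microns, opd unitless, geo_dz in meters.
def _demo_compute_lwc_iwc():
    iphase = np.array([[0.0, 1.0]])
    reff = np.array([[10.0, 30.0]])
    opd = np.array([[5.0, 2.0]])
    geo_dz = np.array([[1000.0, 2000.0]])
    lwc, iwc = compute_lwc_iwc(iphase, reff, opd, geo_dz)
    print('LWC:', lwc, 'IWC:', iwc)  # g m^-3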


# Example GOES file to retrieve GEOS parameters in MetPy form (CONUS)
exmp_file_conus = '/Users/tomrink/data/OR_ABI-L1b-RadC-M6C14_G16_s20193140811215_e20193140813588_c20193140814070.nc'
# Full Disk
exmp_file_fd = '/Users/tomrink/data/OR_ABI-L1b-RadF-M6C16_G16_s20212521800223_e20212521809542_c20212521809596.nc'


def get_cf_nav_parameters(satellite='GOES16', domain='FD'):
    param_dct = None

    if satellite == 'H08':  # We presently only have FD
        param_dct = {'semi_major_axis': 6378.137,
                     'semi_minor_axis': 6356.7523,
                     'perspective_point_height': 35785.863,
                     'latitude_of_projection_origin': 0.0,
                     'longitude_of_projection_origin': 140.7,
                     'inverse_flattening': 298.257,
                     'sweep_angle_axis': 'y',
                     'x_scale_factor': 5.58879902955962e-05,
                     'x_add_offset': -0.153719917308037,
                     'y_scale_factor': -5.58879902955962e-05,
                     'y_add_offset': 0.153719917308037}
    elif satellite == 'H09':
        param_dct = {'semi_major_axis': 6378.137,
                     'semi_minor_axis': 6356.7523,
                     'perspective_point_height': 35785.863,
                     'latitude_of_projection_origin': 0.0,
                     'longitude_of_projection_origin': 140.7,
                     'inverse_flattening': 298.257,
                     'sweep_angle_axis': 'y',
                     'x_scale_factor': 5.58879902955962e-05,
                     'x_add_offset': -0.153719917308037,
                     'y_scale_factor': -5.58879902955962e-05,
                     'y_add_offset': 0.153719917308037}
    elif satellite == 'GOES16':
        if domain == 'CONUS':
            param_dct = {'semi_major_axis': 6378137.0,
                         'semi_minor_axis': 6356752.31414,
                         'perspective_point_height': 35786023.0,
                         'latitude_of_projection_origin': 0.0,
                         'longitude_of_projection_origin': -75,
                         'inverse_flattening': 298.257,
                         'sweep_angle_axis': 'x',
                         'x_scale_factor': 5.6E-05,
                         'x_add_offset': -0.101332,
                         'y_scale_factor': -5.6E-05,
                         'y_add_offset': 0.128212}
        elif domain == 'FD':
            param_dct = {'semi_major_axis': 6378137.0,
                         'semi_minor_axis': 6356752.31414,
                         'perspective_point_height': 35786023.0,
                         'latitude_of_projection_origin': 0.0,
                         'longitude_of_projection_origin': -75,
                         'inverse_flattening': 298.257,
                         'sweep_angle_axis': 'x',
                         'x_scale_factor': 5.6E-05,
                         'x_add_offset': -0.151844,
                         'y_scale_factor': -5.6E-05,
                         'y_add_offset': 0.151844}

    return param_dct


def write_cld_prods_file_nc4(clvrx_str_time, outfile_name, cloud_fraction, cloud_frac_opd,
                            x, y, elems, lines, satellite='GOES16', domain='CONUS',
                            has_time=False):

    rootgrp = Dataset(outfile_name, 'w', format='NETCDF4')

    rootgrp.setncattr('Conventions', 'CF-1.7')

    dim_0_name = 'x'
    dim_1_name = 'y'
    time_dim_name = 'time'
    geo_coords = 'time y x'

    dim_0 = rootgrp.createDimension(dim_0_name, size=x.shape[0])
    dim_1 = rootgrp.createDimension(dim_1_name, size=y.shape[0])
    dim_time = rootgrp.createDimension(time_dim_name, size=1)

    tvar = rootgrp.createVariable('time', 'f8', time_dim_name)
    tvar[0] = get_timestamp(clvrx_str_time)
    tvar.units = 'seconds since 1970-01-01 00:00:00'

    if not has_time:
        var_dim_list = [dim_1_name, dim_0_name]
    else:
        var_dim_list = [time_dim_name, dim_1_name, dim_0_name]

    cld_frac_ds = rootgrp.createVariable('cloud_fraction', 'i1', var_dim_list)
    cld_frac_ds.setncattr('long_name', 'FCNN inferred cloud fraction')
    cld_frac_ds.setncattr('coordinates', geo_coords)
    cld_frac_ds.setncattr('grid_mapping', 'Projection')
    cld_frac_ds.setncattr('missing', -1)
    if has_time:
        cloud_fraction = cloud_fraction.reshape((1, y.shape[0], x.shape[0]))
    cld_frac_ds[:, ] = cloud_fraction

    cld_frac_opd_ds = rootgrp.createVariable('cldy_fraction_opd', 'f4', var_dim_list)
    cld_frac_opd_ds.setncattr('long_name', 'FCNN inferred fractional OPD')
    cld_frac_opd_ds.setncattr('coordinates', geo_coords)
    cld_frac_opd_ds.setncattr('grid_mapping', 'Projection')
    cld_frac_opd_ds.setncattr('missing', -1.0)
    if has_time:
        cloud_frac_opd = cloud_frac_opd.reshape((1, y.shape[0], x.shape[0]))
    cld_frac_opd_ds[:, ] = cloud_frac_opd

    cf_nav_dct = get_cf_nav_parameters(satellite, domain)

    if satellite == 'H08':
        long_name = 'Himawari Imagery Projection'
    elif satellite == 'H09':
        long_name = 'Himawari Imagery Projection'
    elif satellite == 'GOES16':
        long_name = 'GOES-16/17 Imagery Projection'

    proj_ds = rootgrp.createVariable('Projection', 'b')
    proj_ds.setncattr('long_name', long_name)
    proj_ds.setncattr('grid_mapping_name', 'geostationary')
    proj_ds.setncattr('sweep_angle_axis', cf_nav_dct['sweep_angle_axis'])
    proj_ds.setncattr('semi_major_axis', cf_nav_dct['semi_major_axis'])
    proj_ds.setncattr('semi_minor_axis', cf_nav_dct['semi_minor_axis'])
    proj_ds.setncattr('inverse_flattening', cf_nav_dct['inverse_flattening'])
    proj_ds.setncattr('perspective_point_height', cf_nav_dct['perspective_point_height'])
    proj_ds.setncattr('latitude_of_projection_origin', cf_nav_dct['latitude_of_projection_origin'])
    proj_ds.setncattr('longitude_of_projection_origin', cf_nav_dct['longitude_of_projection_origin'])

    if x is not None:
        x_ds = rootgrp.createVariable(dim_0_name, 'f8', [dim_0_name])
        x_ds.units = 'rad'
        x_ds.setncattr('axis', 'X')
        x_ds.setncattr('standard_name', 'projection_x_coordinate')
        x_ds.setncattr('long_name', 'fixed grid viewing angle')
        x_ds.setncattr('scale_factor', cf_nav_dct['x_scale_factor'])
        x_ds.setncattr('add_offset', cf_nav_dct['x_add_offset'])
        x_ds[:] = x

        y_ds = rootgrp.createVariable(dim_1_name, 'f8', [dim_1_name])
        y_ds.units = 'rad'
        y_ds.setncattr('axis', 'Y')
        y_ds.setncattr('standard_name', 'projection_y_coordinate')
        y_ds.setncattr('long_name', 'fixed grid viewing angle')
        y_ds.setncattr('scale_factor', cf_nav_dct['y_scale_factor'])
        y_ds.setncattr('add_offset', cf_nav_dct['y_add_offset'])
        y_ds[:] = y

    if elems is not None:
        elem_ds = rootgrp.createVariable('elems', 'i2', [dim_0_name])
        elem_ds[:] = elems
        line_ds = rootgrp.createVariable('lines', 'i2', [dim_1_name])
        line_ds[:] = lines

    rootgrp.close()


def downscale_2x(original, smoothing=False, samples_axis_first=False):
    if smoothing:  # pre-smooth to reduce aliasing before block averaging
        original = gaussian_filter(original, sigma=0.5)
    if not samples_axis_first:
        lr = np.nanmean(np.array([original[0::2, 0::2],
                                  original[1::2, 0::2],
                                  original[0::2, 1::2],
                                  original[1::2, 1::2]]), axis=0).squeeze()
    else:
        lr = np.nanmean(np.array([original[:, 0::2, 0::2],
                                  original[:, 1::2, 0::2],
                                  original[:, 0::2, 1::2],
                                  original[:, 1::2, 1::2]]), axis=0).squeeze()
    return lr
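

# Shape check for downscale_2x: a 4x4 array block-averages to 2x2; NaNs are
# ignored within each 2x2 block via nanmean.
def _demo_downscale_2x():
    a = np.arange(16, dtype=np.float32).reshape(4, 4)
    a[0, 0] = np.nan
    print(downscale_2x(a))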


def resample(x, y, z, x_new, y_new):
    z_intrp = []
    for k in range(z.shape[0]):
        z_k = z[k, :, :]
        f = RectBivariateSpline(x, y, z_k)
        z_intrp.append(f(x_new, y_new))
    return np.stack(z_intrp)


def resample_one(x, y, z, x_new, y_new):
    f = RectBivariateSpline(x, y, z)
    return f(x_new, y_new)


# NOTE: scipy.interpolate.interp2d is deprecated as of SciPy 1.10 (removed in
# SciPy 1.14); RegularGridInterpolator or RectBivariateSpline are suggested
# replacements.
def resample_2d_linear(x, y, z, x_new, y_new):
    z_intrp = []
    for k in range(z.shape[0]):
        z_k = z[k, :, :]
        f = interp2d(x, y, z_k)
        z_intrp.append(f(x_new, y_new))
    return np.stack(z_intrp)


def resample_2d_linear_one(x, y, z, x_new, y_new):
    f = interp2d(x, y, z)
    return f(x_new, y_new)


# Gaussian filter suitable for a model training data pipeline
# z: input array with dimensions [BATCH_SIZE, Y, X]
# sigma: standard deviation of the Gaussian kernel
# returns stacked 2d arrays with the same dimensions as the input
def smooth_2d(z, sigma=1.0):
    z_smoothed = []
    for j in range(z.shape[0]):
        z_j = z[j, :, :]
        z_smoothed.append(gaussian_filter(z_j, sigma=sigma))
    return np.stack(z_smoothed)


# For [Y, X], see above
def smooth_2d_single(z, sigma=1.0):
    return gaussian_filter(z, sigma=sigma)


def median_filter_2d(z, kernel_size=3):
    z_filtered = []
    for j in range(z.shape[0]):
        z_j = z[j, :, :]
        z_filtered.append(medfilt2d(z_j, kernel_size=kernel_size))
    return np.stack(z_filtered)


def median_filter_2d_single(z, kernel_size=3):
    return medfilt2d(z, kernel_size=kernel_size)
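

# Quick demo: both filters preserve the input shape; the batch is random.
def _demo_filters():
    batch = np.random.rand(3, 8, 8)
    print(smooth_2d(batch, sigma=1.0).shape)             # (3, 8, 8)
    print(median_filter_2d(batch, kernel_size=3).shape)  # (3, 8, 8)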