Skip to content
Snippets Groups Projects
resample.py 3.98 KiB
Newer Older
tomrink's avatar
tomrink committed
import numpy as np
from scipy.interpolate import RegularGridInterpolator, griddata
from cartopy.crs import LambertAzimuthalEqualArea
import cartopy.crs as ccrs
from netCDF4 import Dataset
import xarray as xr

# Resampling method names, passed through to the scipy interpolators used below.
linear, cubic, nearest = 'linear', 'cubic', 'nearest'


def fill_missing(fld_a, fld_b, mask, rng=None):
    """
    Overwrite the masked entries of two fields, in place, with uniform random values.

    Both fields receive independent draws from a uniform distribution over the combined
    (NaN-ignoring) value range of the two fields, computed before any entry is overwritten.

    :param fld_a: first numpy array; modified in place
    :param fld_b: second numpy array; modified in place
    :param mask: boolean mask (or index) selecting the entries to overwrite; must be valid for both arrays
    :param rng: optional numpy random Generator/RandomState to draw from (for reproducibility).
                Defaults to the global ``np.random`` state.
    :return: None
    """
    if rng is None:
        rng = np.random

    min_val = min(np.nanmin(fld_a), np.nanmin(fld_b))
    max_val = max(np.nanmax(fld_a), np.nanmax(fld_b))
    # Draw each field's values with that field's own shape, so ``mask`` selects entries from the
    # random array exactly as it selects them from the field being filled.
    random_values_a = rng.uniform(min_val, max_val, size=fld_a.shape)
    random_values_b = rng.uniform(min_val, max_val, size=fld_b.shape)

    fld_a[mask] = random_values_a[mask]
    fld_b[mask] = random_values_b[mask]


def get_projection(cartopy_map_name, cen_lat, cen_lon):
    """
    Build a Cartopy projection centred on the given location.

    :param cartopy_map_name: one of "LambertAzimuthalEqualArea", "AlbersEqualArea" or "Sinusoidal"
    :param cen_lat: central latitude (not used for "Sinusoidal", which takes only a central longitude)
    :param cen_lon: central longitude
    :return: the constructed Cartopy CRS
    :raises ValueError: if cartopy_map_name is not one of the supported names
    """
    # Builders are lazy so only the requested projection is ever constructed.
    builders = {
        "LambertAzimuthalEqualArea": lambda: ccrs.LambertAzimuthalEqualArea(
            central_longitude=cen_lon, central_latitude=cen_lat
        ),
        "AlbersEqualArea": lambda: ccrs.AlbersEqualArea(
            central_longitude=cen_lon, central_latitude=cen_lat
        ),
        "Sinusoidal": lambda: ccrs.Sinusoidal(central_longitude=cen_lon),
    }

    build = builders.get(cartopy_map_name)
    if build is None:
        raise ValueError("Projection: " + cartopy_map_name + " is not supported")
    return build()


def resample_reg_grid(scalar_field, y, x, y_s, x_s, method='linear'):
    """
    Interpolate a field defined on a regular (y, x) grid onto another regular grid.

    :param scalar_field: 2D array indexed as [y, x]
    :param y: 1D source y coordinates
    :param x: 1D source x coordinates
    :param y_s: 1D target y coordinates
    :param x_s: 1D target x coordinates
    :param method: interpolation method passed to scipy's RegularGridInterpolator
    :return: array of shape (len(y_s), len(x_s)); targets outside the source grid are NaN
    """
    interpolator = RegularGridInterpolator((y, x), scalar_field, method=method, bounds_error=False)

    # 'ij' indexing yields arrays shaped (ny_s, nx_s) with rows along y_s.
    y_grid, x_grid = np.meshgrid(y_s, x_s, indexing='ij')
    query_points = np.column_stack((y_grid.ravel(), x_grid.ravel()))

    values = interpolator(query_points)
    return values.reshape(y_s.shape[0], x_s.shape[0])


def resample(scalar_field, y_d, x_d, y_t, x_t, method='linear', fill_value=np.nan):
    """
    Interpolate scattered data onto target locations with scipy's griddata.

    :param scalar_field: data values, same number of elements as y_d/x_d (any shape; flattened)
    :param y_d: y coordinates of the data points (any shape; flattened)
    :param x_d: x coordinates of the data points (any shape; flattened)
    :param y_t: y coordinates of the targets; the output takes this array's shape
    :param x_t: x coordinates of the targets, same number of elements as y_t
    :param method: 'linear', 'nearest' or 'cubic' (see scipy.interpolate.griddata)
    :param fill_value: value for targets outside the data's convex hull (ignored by 'nearest')
    :return: interpolated values with the shape of y_t
    """
    y_src = np.asarray(y_d).ravel()
    x_src = np.asarray(x_d).ravel()
    values = np.asarray(scalar_field).ravel()

    # Projected coordinates may be inf/NaN (points a projection cannot represent). Such points
    # break the triangulation griddata builds, so drop them. NaN *data values* are kept and
    # still propagate as before.
    valid = np.isfinite(y_src) & np.isfinite(x_src)
    if not valid.all():
        y_src, x_src, values = y_src[valid], x_src[valid], values[valid]

    t_shape = np.shape(y_t)
    fld_repro = griddata((y_src, x_src), values, (np.ravel(y_t), np.ravel(x_t)),
                         method=method, fill_value=fill_value)
    return fld_repro.reshape(t_shape)


def reproject(fld_2d, lat_2d, lon_2d, proj, target_grid=None, grid_spacing=15000, method=linear):
    """
    Reproject a scalar field given on a lat/lon grid onto a regular grid in ``proj`` coordinates.

    :param fld_2d: the 2D scalar field to reproject
    :param lat_2d: 2D latitude of the scalar field domain (degrees)
    :param lon_2d: 2D longitude of the scalar field domain (degrees)
    :param proj: the target map projection (a Cartopy CRS, e.g. as returned by get_projection)
    :param target_grid: the resampling target (y_map, x_map), where y_map and x_map are 2D arrays in
           projection coordinates. If None, a regular grid spanning the projected data extent is
           created. The target grid is always returned.
    :param grid_spacing: distance between automatically generated target grid points, in projection
           units (meters for the projections built by get_projection). Ignored when target_grid is given.
    :param method: resampling method: 'linear', 'nearest', 'cubic'
    :return: (reprojected 2D scalar field, (y_map_2d, x_map_2d)), the target grid as 2D arrays
    :raises ValueError: if target_grid is None and no input point projects to finite coordinates
    """
    # Keep only the projected x and y components.
    data_xy = proj.transform_points(ccrs.PlateCarree(), lon_2d, lat_2d)[..., :2]

    if target_grid is None:
        # Generate a regular 2D grid spanning the projected data extent with grid_spacing.
        # Points the projection cannot represent come back non-finite. Exclude them so the
        # extent (and hence np.arange) stays finite.
        finite_xy = data_xy[np.all(np.isfinite(data_xy), axis=-1)]
        if finite_xy.size == 0:
            raise ValueError("None of the lat/lon points could be projected with the given projection")
        x_min, y_min = finite_xy.min(axis=0)
        x_max, y_max = finite_xy.max(axis=0)
        x_map = np.arange(x_min, x_max, grid_spacing)
        y_map = np.arange(y_min, y_max, grid_spacing)
        x_map_2d, y_map_2d = np.meshgrid(x_map, y_map)
    else:
        y_map_2d, x_map_2d = target_grid

    fld_reproj = resample(fld_2d, data_xy[..., 1], data_xy[..., 0], y_map_2d, x_map_2d, method=method)

    return fld_reproj, (y_map_2d, x_map_2d)


def bisect_great_circle(lon_a, lat_a, lon_b, lat_b):
    """
    Return the midpoint of the great-circle arc between points A and B.

    Inputs and outputs are in degrees. Scalars or broadcastable numpy arrays are accepted.
    The returned longitude is not wrapped into [-180, 180].

    :return: (lon_c, lat_c) of the midpoint
    """
    lam_1, phi_1, lam_2, phi_2 = (np.radians(v) for v in (lon_a, lat_a, lon_b, lat_b))

    delta_lam = lam_2 - lam_1
    cos_phi_2 = np.cos(phi_2)
    # Intermediate terms of the standard spherical midpoint formula.
    bx = cos_phi_2 * np.cos(delta_lam)
    by = cos_phi_2 * np.sin(delta_lam)

    denom_x = np.cos(phi_1) + bx
    phi_m = np.arctan2(np.sin(phi_1) + np.sin(phi_2), np.sqrt(denom_x ** 2 + by ** 2))
    lam_m = lam_1 + np.arctan2(by, denom_x)

    return np.degrees(lam_m), np.degrees(phi_m)