# util.py
import datetime
import os
import re
from datetime import timezone

import numpy as np
import xarray as xr
import rasterio
from PIL import Image
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import h5py
from util.util import get_grid_values_all
from util.gfs_reader import *
from util.geos_nav import GEOSNavigation
from aeolus.datasource import GFSfiles

# Source of GFS model files, looked up by timestamp in `extract`. The directory
# defaults to the original developer path; set the CONTRAIL_GFS_DIR environment
# variable to point at a different location without editing this file.
gfs_files = GFSfiles(os.environ.get('CONTRAIL_GFS_DIR', '/Users/tomrink/data/contrail/gfs/'))
# GEOSNavigation needs to be updated to support GOES-18
# geos_nav = GEOSNavigation()


def load_image(image_path):
    """Load an image together with the UTC timestamp encoded in its path.

    The path must contain a ``_YYYYMMDD_HHMMSS`` token (for example
    ``.../scene_20230415_183000.png``). The first such token found anywhere
    in ``image_path`` is interpreted as a UTC date/time.

    Parameters
    ----------
    image_path : str
        Path to the image file; it is read with ``matplotlib.image.imread``.

    Returns
    -------
    img : array
        The image as returned by ``mpimg.imread``.
    ts : float
        POSIX timestamp (seconds since the epoch) of the date/time in the path,
        interpreted as UTC.

    Raises
    ------
    ValueError
        If ``image_path`` contains no ``_YYYYMMDD_HHMMSS`` token, or the token
        is not a valid calendar date/time.
    """
    match = re.search(r'_\d{8}_\d{6}', image_path)
    if match is None:
        # Fail with a clear message instead of letting strptime(None, ...)
        # raise an unrelated TypeError.
        raise ValueError(
            f'no _YYYYMMDD_HHMMSS date/time found in image path: {image_path!r}')

    # The file-name time carries no zone information; treat it as UTC so the
    # timestamp does not depend on the machine's local time zone.
    dto = datetime.datetime.strptime(match.group(), '_%Y%m%d_%H%M%S').replace(tzinfo=timezone.utc)
    ts = dto.timestamp()

    img = mpimg.imread(image_path)

    return img, ts


def get_contrail_mask_image(image, thresh=0.157):
    """Binarize an image into a 0/1 contrail mask.

    Parameters
    ----------
    image : numpy.ndarray
        Image values to threshold.
    thresh : float, optional
        Pixels strictly greater than this become 1. All others become 0,
        including NaN, because NaN comparisons are False.
        NOTE(review): 0.157 is close to 40/255, presumably a cut on an 8-bit
        mask rescaled to [0, 1] -- confirm.

    Returns
    -------
    numpy.ndarray
        Integer array of 0s and 1s with the same shape as ``image``.
    """
    is_contrail = image > thresh
    return np.where(is_contrail, 1, 0)


def _pressure_bin_label(bin_num, bins):
    """Return a printable pressure range for an ``np.digitize`` bin number."""
    # np.digitize yields 0 below the first edge and len(bins) at/above the last.
    if bin_num == 0:
        return f'< {bins[0]}'
    if bin_num >= len(bins):
        return f'>= {bins[-1]}'
    return f'{bins[bin_num - 1]}-{bins[bin_num]}'


def extract(mask_image, image_ts, clavrx_path):
    """Sample GFS u/v winds at the cloud-top pressure of each contrail pixel.

    Processing steps:

    1. Select contrail pixels (``mask_image == 1``) on the CLAVR-x grid.
    2. Read their ACHA cloud-top pressure and lat/lon from ``clavrx_path``,
       and drop pixels whose pressure is NaN.
    3. Group the remaining pixels into pressure bins with edges
       100, 200, ..., 900 (presumably hPa -- confirm).
    4. For each pixel, call ``get_voxel_s`` on the GFS dataset. Non-``None``
       results are collected per bin.
    5. Call ``get_volume`` once for the lat/lon bounding box of all contrail
       pixels.

    Parameters
    ----------
    mask_image : numpy.ndarray
        Contrail mask; pixels equal to 1 are contrails. It is flattened and
        used to index the flattened CLAVR-x fields, so it presumably must have
        the same shape as the CLAVR-x grid -- TODO confirm.
    image_ts : float
        Image time as a POSIX timestamp; selects the GFS file through
        ``gfs_files.get_file``.
    clavrx_path : str
        Path to the CLAVR-x HDF5 file providing ``cld_press_acha``,
        ``longitude`` and ``latitude``.

    Returns
    -------
    None
        Work in progress: the per-pixel wind voxels and the bulk wind volume
        are computed but not returned yet.
    """
    bins = np.arange(100, 1000, 100)

    gfs_file, _, _ = gfs_files.get_file(image_ts)
    # Context managers close both files even if a helper raises; the HDF5
    # file was previously never closed.
    with xr.open_dataset(gfs_file) as xr_dataset:
        with h5py.File(clavrx_path, 'r') as clvrx_h5f:
            # .flatten() returns copies, so the HDF5 file can be closed after this.
            cloud_top_press = get_grid_values_all(clvrx_h5f, 'cld_press_acha').flatten()
            clvrx_lons = get_grid_values_all(clvrx_h5f, 'longitude').flatten()
            clvrx_lats = get_grid_values_all(clvrx_h5f, 'latitude').flatten()

        contrail_idxs = (mask_image == 1).flatten()
        print('number of contrail pixels: ', np.sum(contrail_idxs))

        # Assuming GOES FD for now -------------------
        # elems, lines = np.meshgrid(np.arange(5424), np.arange(5424))
        # lines, elems = lines.flatten(), elems.flatten()
        # See note above regarding support for GOES-18
        # contrail_lines, contrail_elems = lines[contrail_idxs], elems[contrail_idxs]
        # contrail_lons, contrail_lats = geos_nav.lc_to_earth(contrail_elems, contrail_lines)

        contrail_press = cloud_top_press[contrail_idxs]
        contrail_lons = clvrx_lons[contrail_idxs]
        contrail_lats = clvrx_lats[contrail_idxs]

        # Drop pixels without a retrieved cloud-top pressure.
        keep = ~np.isnan(contrail_press)
        contrail_press = contrail_press[keep]
        contrail_lons = contrail_lons[keep]
        contrail_lats = contrail_lats[keep]

        if contrail_press.size == 0:
            # Nothing to sample, and np.min/np.max below would raise on empty arrays.
            print('no contrail pixels with a valid cloud-top pressure')
            return

        # Bin number i means bins[i-1] <= p < bins[i]; 0 and len(bins) are the
        # open-ended bins below/above the edges.
        binned_indexes = np.digitize(contrail_press, bins)

        # Map each occupied bin number to the positions of its pixels.
        bins_dict = {bin_num: np.flatnonzero(binned_indexes == bin_num)
                     for bin_num in np.unique(binned_indexes)}

        # MetPy for shearing deformation, static stability, vertical shear?, others...
        # Point-by-point computation of model parameters for each contrail pixel.
        voxel_dict = {bin_num: [] for bin_num in bins_dict}
        for bin_num, pixel_idxs in bins_dict.items():
            print('working on pressure level: ', _pressure_bin_label(bin_num, bins))
            for c_idx in pixel_idxs:
                wind_3d = get_voxel_s(xr_dataset, ['u-wind', 'v-wind'],
                                      contrail_lons[c_idx], contrail_lats[c_idx],
                                      contrail_press[c_idx])
                if wind_3d is not None:
                    voxel_dict[bin_num].append(wind_3d)

        # Bulk computation of model parameters over the contrail bounding box,
        # intended to be sampled per contrail pixel afterwards.
        lon_range = [np.min(contrail_lons), np.max(contrail_lons)]
        lat_range = [np.min(contrail_lats), np.max(contrail_lats)]

        # NOTE(review): like voxel_dict, this result is not used or returned yet.
        wind_volume = get_volume(xr_dataset, ['u-wind', 'v-wind'],
                                 lon_range=lon_range, lat_range=lat_range)