# util.py
import re
import datetime
from datetime import timezone
import numpy as np
import xarray as xr
import pandas as pd
import rasterio
from PIL import Image
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import h5py
from util.util import get_grid_values_all
from util.gfs_reader import get_volume, volume_np_to_xr
from util.geos_nav import GEOSNavigation
from aeolus.datasource import GFSfiles
import cartopy.crs as ccrs
from metpy.calc import shearing_deformation, static_stability, wind_speed, first_derivative
from metpy.units import units
# Catalog of local GFS model files; used to find the file nearest an image time.
gfs_files = GFSfiles('/Users/tomrink/data/contrail/gfs/')
# GEOSNavigation needs to be updated to support GOES-18
# geos_nav = GEOSNavigation()
# Full-disk ABI image dimensions (lines x elements) assumed by extract().
num_lines = 5424
num_elems = 5424
def load_image(image_path):
    """Load an image and the UTC observation timestamp encoded in its filename.

    Parameters
    ----------
    image_path : str
        Path whose filename contains a '_YYYYMMDD_HHMMSS' date-time token.

    Returns
    -------
    (ndarray, float)
        The image array (via matplotlib.image.imread) and the POSIX
        timestamp parsed from the filename (interpreted as UTC).

    Raises
    ------
    ValueError
        If no date-time token is present in ``image_path``.
    """
    # Extract date time string from image filepath (raw string for the regex)
    datetime_regex = r'_\d{8}_\d{6}'
    match = re.search(datetime_regex, image_path)
    if match is None:
        # Original code fell through here and raised a NameError on `ts`;
        # fail loudly with a clear message instead.
        raise ValueError('no _YYYYMMDD_HHMMSS date-time token in: ' + image_path)
    dto = datetime.datetime.strptime(match.group(), '_%Y%m%d_%H%M%S').replace(tzinfo=timezone.utc)
    ts = dto.timestamp()
    img = mpimg.imread(image_path)
    return img, ts
def get_contrail_mask_image(image, thresh=0.157):
    """Binarize a contrail-probability image.

    Pixels strictly greater than ``thresh`` become 1; all others become 0.

    Parameters
    ----------
    image : ndarray
        Array of per-pixel contrail probabilities/scores.
    thresh : float
        Decision threshold (default 0.157).

    Returns
    -------
    ndarray
        Integer 0/1 mask with the same shape as ``image``.
    """
    mask = (image > thresh).astype(int)
    return mask
def extract(mask_image, image_ts, clavrx_path):
    """Sample GFS-derived dynamical fields at every contrail pixel.

    For each pixel flagged in ``mask_image``, looks up cloud-top pressure and
    lat/lon from the CLAVR-x file, bins the pixels by pressure, computes bulk
    3-D model fields (shear deformation, static stability, wind speed,
    vertical wind shear, temperature, RH) over the pixels' lat/lon bounding
    box, and nearest-neighbor interpolates each field to every pixel.

    Parameters
    ----------
    mask_image : ndarray
        Binary contrail mask (1 == contrail); flattened length must match the
        CLAVR-x grids (presumably num_lines x num_elems — TODO confirm).
    image_ts : float
        POSIX timestamp of the image; selects the closest GFS file.
    clavrx_path : str
        Path to the CLAVR-x HDF5 file with 'cld_press_acha', 'longitude',
        'latitude' datasets.

    Returns
    -------
    pandas.DataFrame
        One row per valid contrail pixel: pressure level/value, location,
        and the interpolated model parameters.
    """
    # The GFS file closest to image_ts
    gfs_file, _, _ = gfs_files.get_file(image_ts)
    xr_dataset = xr.open_dataset(gfs_file)
    # NOTE(review): clvrx_h5f is never closed — consider a `with` block.
    clvrx_h5f = h5py.File(clavrx_path, 'r')
    cloud_top_press = get_grid_values_all(clvrx_h5f, 'cld_press_acha').flatten()
    clvrx_lons = get_grid_values_all(clvrx_h5f, 'longitude').flatten()
    clvrx_lats = get_grid_values_all(clvrx_h5f, 'latitude').flatten()
    # Flattened boolean selector for contrail pixels
    contrail_mask = (mask_image == 1).flatten()
    contrail_press = cloud_top_press[contrail_mask]
    contrail_lons, contrail_lats = clvrx_lons[contrail_mask], clvrx_lats[contrail_mask]
    # Use only the contrail obs where we have valid CLAVRx
    keep = np.invert(np.isnan(contrail_press))
    contrail_press = contrail_press[keep]
    contrail_lons = contrail_lons[keep]
    contrail_lats = contrail_lats[keep]
    # Indexes of contrail_press for individual bins (50 hPa edges, 100..950)
    press_bins = np.arange(100, 1000, 50)
    binned_indexes = np.digitize(contrail_press, press_bins)
    # Store the indexes in a dictionary where the key is the bin number and
    # value is the list of indexes.
    # NOTE(review): np.digitize can return len(press_bins) for pressures
    # >= 950 hPa, making press_bins[key] below raise IndexError — confirm
    # upstream filtering guarantees press < 950.
    bins_dict = {}
    for bin_num in np.unique(binned_indexes):
        bins_dict[bin_num] = np.where(binned_indexes == bin_num)[0]
    # This section will compute model parameters in bulk for the region then
    # pull for each contrail pixel.
    lon_range = [np.min(contrail_lons), np.max(contrail_lons)]
    lat_range = [np.min(contrail_lats), np.max(contrail_lats)]
    uwind_3d = get_volume(xr_dataset, 'u-wind', 'm s-1', lon_range=lon_range, lat_range=lat_range)
    vwind_3d = get_volume(xr_dataset, 'v-wind', 'm s-1', lon_range=lon_range, lat_range=lat_range)
    temp_3d = get_volume(xr_dataset, 'temperature', 'degK', lon_range=lon_range, lat_range=lat_range)
    rh_3d = get_volume(xr_dataset, 'rh', '%', lon_range=lon_range, lat_range=lat_range)
    # Put the vertical (Pressure) axis first so axis=0 derivatives below are
    # taken along pressure.
    uwind_3d = uwind_3d.transpose('Pressure', 'Latitude', 'Longitude')
    vwind_3d = vwind_3d.transpose('Pressure', 'Latitude', 'Longitude')
    temp_3d = temp_3d.transpose('Pressure', 'Latitude', 'Longitude')
    rh_3d = rh_3d.transpose('Pressure', 'Latitude', 'Longitude')
    horz_shear_3d = shearing_deformation(uwind_3d, vwind_3d)
    static_3d = static_stability(temp_3d.coords['Pressure'] * units.hPa, temp_3d)
    horz_wind_spd_3d = wind_speed(uwind_3d, vwind_3d)
    # This one's a bit more work: `first_derivative` only returns a ndarray
    # with no units, so we use the helper function to create a DataArray and
    # add units via metpy's pint support.
    vert_shear_3d = first_derivative(horz_wind_spd_3d, axis=0, x=temp_3d.coords['Pressure'])
    vert_shear_3d = volume_np_to_xr(vert_shear_3d, ['Pressure', 'Latitude', 'Longitude'], lon_range=lon_range, lat_range=lat_range)
    vert_shear_3d = vert_shear_3d / units.hPa
    # Form a new Dataset for interpolation below
    da_dct = {'horz_shear_3d': horz_shear_3d, 'static_3d': static_3d, 'horz_wind_spd_3d': horz_wind_spd_3d,
              'vert_shear_3d': vert_shear_3d, 'temp_3d': temp_3d, 'rh_3d': rh_3d}
    xr_ds = xr.Dataset(da_dct)
    all_list = []
    # Per-pressure-level accumulators: parameter tuples, and contrail indexes
    levels_dict = {press_bins[key]: [] for key in bins_dict.keys()}
    levels_dict_cidx = {press_bins[key]: [] for key in bins_dict.keys()}
    for key in bins_dict.keys():
        press_level = press_bins[key]
        print('working on pressure level: ', press_level)
        for c_idx in bins_dict[key]:
            press = contrail_press[c_idx]
            lat = contrail_lats[c_idx]
            lon = contrail_lons[c_idx]
            if lon < 0:  # Match GFS convention
                lon += 360.0
            # Nearest-neighbor "interpolation" at the pixel's 3-D location
            intrp_ds = xr_ds.interp(Pressure=press, Longitude=lon, Latitude=lat, method='nearest')
            horz_shear_value, static_value, horz_wind_spd_value, vert_shear_value, temp_value, rh_value = (
                intrp_ds['horz_shear_3d'].item(0), intrp_ds['static_3d'].item(0), intrp_ds['horz_wind_spd_3d'].item(0),
                intrp_ds['vert_shear_3d'].item(0), intrp_ds['temp_3d'].item(0), intrp_ds['rh_3d'].item(0))
            levels_dict_cidx[press_level].append(c_idx)
            levels_dict[press_level].append((press, lat, lon, temp_value, rh_value, horz_shear_value, static_value, horz_wind_spd_value, vert_shear_value))
            all_list.append((press_level, press, lat, lon, temp_value, rh_value, horz_shear_value, static_value, horz_wind_spd_value, vert_shear_value))
    # Create pandas DataFrame for each list of tuples in levels_dict
    # NOTE(review): levels_dict_df is built but never returned — dead code
    # unless kept deliberately for debugging.
    levels_dict_df = {}
    for k, v in levels_dict.items():
        print('pressure level, number of contrail points: ', k, len(v))
        df = pd.DataFrame(v,
                          columns=["pressure", "lat", "lon", "temperature", "relative_humidity",
                                   "horz_shear_deform", "static_stability", "horz_wind_speed", "vert_wind_shear"])
        levels_dict_df[k] = df
    # Create a contrail location image for each pressure layer
    # NOTE(review): also built but never returned (see above).
    levels_dict_image = {}
    for k, v in levels_dict_cidx.items():
        cidx_nda = np.array(v)
        flat_image = np.zeros(num_lines*num_elems)
        flat_image[cidx_nda] = 1
        levels_dict_image[k] = flat_image.reshape(num_lines, num_elems)
    # Create a DataFrame for all tuples
    all_df = pd.DataFrame(all_list,
                          columns=["pressure_level", "pressure", "lat", "lon", "temperature", "relative_humidity",
                                   "horz_shear_deform", "static_stability", "horz_wind_speed", "vert_wind_shear"])
    xr_dataset.close()
    return all_df
def analyze(dataFrame, column, value):
    """Print and return per-column mean and std for matching rows.

    Selects the rows of ``dataFrame`` where ``column`` equals ``value`` and
    computes column-wise statistics over the numeric columns.

    Parameters
    ----------
    dataFrame : pandas.DataFrame
        Data to analyze (e.g. the output of extract()).
    column : str
        Column name to filter on.
    value : object
        Value the column must equal for a row to be included.

    Returns
    -------
    (pandas.Series, pandas.Series)
        Mean and sample standard deviation of each numeric column.
    """
    result_df = dataFrame[dataFrame[column] == value]  # rows where column == value
    # numeric_only=True: modern pandas raises on non-numeric columns otherwise;
    # identical result for all-numeric frames like extract()'s output.
    mean = result_df.mean(numeric_only=True)
    stddev = result_df.std(numeric_only=True)
    print("Mean:", mean)
    print("Std deviation:", stddev)
    # Previously the statistics were only printed; return them so callers
    # can actually use the result (backward-compatible addition).
    return mean, stddev