import re
import datetime
from datetime import timezone

import h5py
import matplotlib.image as mpimg
import numpy as np
import pandas as pd
import xarray as xr
from metpy.calc import shearing_deformation, static_stability, wind_speed, first_derivative
from metpy.units import units

from aeolus.datasource import GFSfiles
from util.geos_nav import GEOSNavigation
from util.gfs_reader import get_volume, volume_np_to_xr
from util.util import get_grid_values_all

gfs_files = GFSfiles('/Users/tomrink/data/contrail/gfs/')

# GEOSNavigation needs to be updated to support GOES-18
# geos_nav = GEOSNavigation()

# Full-disk image dimensions (GOES ABI at 2 km resolution)
num_lines = 5424
num_elems = 5424


def load_image(image_path):
    # Extract the date/time string from the image filepath
    datetime_regex = r'_\d{8}_\d{6}'
    match = re.search(datetime_regex, image_path)
    if match is None:
        raise ValueError('no _YYYYMMDD_HHMMSS timestamp in image filepath: ' + image_path)
    datetime_string = match.group()
    dto = datetime.datetime.strptime(datetime_string, '_%Y%m%d_%H%M%S').replace(tzinfo=timezone.utc)
    ts = dto.timestamp()

    img = mpimg.imread(image_path)

    return img, ts


def get_contrail_mask_image(image, thresh=0.157):
    # Binarize the contrail probability image at `thresh`
    return np.where(image > thresh, 1, 0)


def extract(mask_image, image_ts, clavrx_path):
    # The GFS file closest in time to image_ts
    gfs_file, _, _ = gfs_files.get_file(image_ts)
    xr_dataset = xr.open_dataset(gfs_file)

    clvrx_h5f = h5py.File(clavrx_path, 'r')
    cloud_top_press = get_grid_values_all(clvrx_h5f, 'cld_press_acha').flatten()
    clvrx_lons = get_grid_values_all(clvrx_h5f, 'longitude').flatten()
    clvrx_lats = get_grid_values_all(clvrx_h5f, 'latitude').flatten()

    contrail_mask = (mask_image == 1).flatten()
    contrail_press = cloud_top_press[contrail_mask]
    contrail_lons, contrail_lats = clvrx_lons[contrail_mask], clvrx_lats[contrail_mask]

    # Use only the contrail obs where we have a valid CLAVRx cloud-top pressure
    keep = ~np.isnan(contrail_press)
    contrail_press = contrail_press[keep]
    contrail_lons = contrail_lons[keep]
    contrail_lats = contrail_lats[keep]
    # Flat full-image pixel indexes of the kept obs, needed to map per-level
    # results back onto the image grid at the end
    contrail_flat_idxs = np.where(contrail_mask)[0][keep]

    # Indexes of contrail_press for individual bins. np.digitize returns
    # len(press_bins) for values at or above the last edge (950 hPa), which
    # would index past the end of press_bins below, so clip to the last bin.
    press_bins = np.arange(100, 1000, 50)
    binned_indexes = np.digitize(contrail_press, press_bins)
    binned_indexes = np.clip(binned_indexes, 0, len(press_bins) - 1)

    # Store the indexes in a dictionary where the key is the bin number and the
    # value is the array of indexes
    bins_dict = {}
    for bin_num in np.unique(binned_indexes):
        bins_dict[bin_num] = np.where(binned_indexes == bin_num)[0]

    # This section computes model parameters in bulk for the region, then pulls
    # the value at each contrail pixel
    lon_range = [np.min(contrail_lons), np.max(contrail_lons)]
    lat_range = [np.min(contrail_lats), np.max(contrail_lats)]

    uwind_3d = get_volume(xr_dataset, 'u-wind', 'm s-1', lon_range=lon_range, lat_range=lat_range)
    vwind_3d = get_volume(xr_dataset, 'v-wind', 'm s-1', lon_range=lon_range, lat_range=lat_range)
    temp_3d = get_volume(xr_dataset, 'temperature', 'degK', lon_range=lon_range, lat_range=lat_range)
    rh_3d = get_volume(xr_dataset, 'rh', '%', lon_range=lon_range, lat_range=lat_range)

    uwind_3d = uwind_3d.transpose('Pressure', 'Latitude', 'Longitude')
    vwind_3d = vwind_3d.transpose('Pressure', 'Latitude', 'Longitude')
    temp_3d = temp_3d.transpose('Pressure', 'Latitude', 'Longitude')
    rh_3d = rh_3d.transpose('Pressure', 'Latitude', 'Longitude')

    horz_shear_3d = shearing_deformation(uwind_3d, vwind_3d)
    static_3d = static_stability(temp_3d.coords['Pressure'] * units.hPa, temp_3d)
    horz_wind_spd_3d = wind_speed(uwind_3d, vwind_3d)

    # This one's a bit more work: `first_derivative` only returns an ndarray with no units, so we use the
    # helper function to create a DataArray and add units via MetPy's pint support
    vert_shear_3d = first_derivative(horz_wind_spd_3d, axis=0, x=temp_3d.coords['Pressure'])
    vert_shear_3d = volume_np_to_xr(vert_shear_3d, ['Pressure', 'Latitude', 'Longitude'],
                                    lon_range=lon_range, lat_range=lat_range)
    vert_shear_3d = vert_shear_3d / units.hPa

    # Form a new Dataset for the interpolation below
    da_dct = {'horz_shear_3d': horz_shear_3d, 'static_3d': static_3d,
              'horz_wind_spd_3d': horz_wind_spd_3d, 'vert_shear_3d': vert_shear_3d,
              'temp_3d': temp_3d, 'rh_3d': rh_3d}
    xr_ds = xr.Dataset(da_dct)

    all_list = []
    # Keys are labeled by the bin's upper pressure edge (hPa)
    levels_dict = {press_bins[key]: [] for key in bins_dict.keys()}
    levels_dict_cidx = {press_bins[key]: [] for key in bins_dict.keys()}

    for key in bins_dict.keys():
        press_level = press_bins[key]
        print('working on pressure level: ', press_level)

        for c_idx in bins_dict[key]:
            press = contrail_press[c_idx]
            lat = contrail_lats[c_idx]
            lon = contrail_lons[c_idx]
            if lon < 0:  # Match the GFS 0-360 longitude convention
                lon += 360.0

            intrp_ds = xr_ds.interp(Pressure=press, Longitude=lon, Latitude=lat, method='nearest')

            horz_shear_value, static_value, horz_wind_spd_value, vert_shear_value, temp_value, rh_value = (
                intrp_ds['horz_shear_3d'].item(0), intrp_ds['static_3d'].item(0),
                intrp_ds['horz_wind_spd_3d'].item(0), intrp_ds['vert_shear_3d'].item(0),
                intrp_ds['temp_3d'].item(0), intrp_ds['rh_3d'].item(0))

            levels_dict_cidx[press_level].append(c_idx)
            levels_dict[press_level].append((press, lat, lon, temp_value, rh_value, horz_shear_value,
                                             static_value, horz_wind_spd_value, vert_shear_value))
            all_list.append((press_level, press, lat, lon, temp_value, rh_value, horz_shear_value,
                             static_value, horz_wind_spd_value, vert_shear_value))

    # Create a pandas DataFrame for each list of tuples in levels_dict
    levels_dict_df = {}
    for k, v in levels_dict.items():
        print('pressure level, number of contrail points: ', k, len(v))
        df = pd.DataFrame(v, columns=['pressure', 'lat', 'lon', 'temperature', 'relative_humidity',
                                      'horz_shear_deform', 'static_stability', 'horz_wind_speed',
                                      'vert_wind_shear'])
        levels_dict_df[k] = df

    # Create a contrail location image for each pressure layer. The c_idx values
    # index the compacted (masked, NaN-filtered) arrays, so map them back to
    # full-image pixel positions via contrail_flat_idxs.
    levels_dict_image = {}
    for k, v in levels_dict_cidx.items():
        cidx_nda = np.array(v)
        flat_image = np.zeros(num_lines * num_elems)
        flat_image[contrail_flat_idxs[cidx_nda]] = 1
        levels_dict_image[k] = flat_image.reshape(num_lines, num_elems)

    # Create a DataFrame for all tuples
    all_df = pd.DataFrame(all_list, columns=['pressure_level', 'pressure', 'lat', 'lon', 'temperature',
                                             'relative_humidity', 'horz_shear_deform', 'static_stability',
                                             'horz_wind_speed', 'vert_wind_shear'])

    clvrx_h5f.close()
    xr_dataset.close()

    return all_df


def analyze(data_frame, column, value):
    # Rows where `column` has the given value
    result_df = data_frame[data_frame[column] == value]
    # Mean and standard deviation of the numeric columns for those rows
    mean = result_df.mean(numeric_only=True)
    stddev = result_df.std(numeric_only=True)
    print('Mean:', mean)
    print('Std deviation:', stddev)
    return mean, stddev
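
# ---------------------------------------------------------------------------
# Minimal usage sketch (an assumption, not part of the original module): the
# file paths below are hypothetical placeholders for a contrail probability
# image and its matching CLAVRx granule, and 250 hPa is just one of the
# press_bins labels. It shows how the functions above chain together.
if __name__ == '__main__':
    image_path = '/Users/tomrink/data/contrail/contrail_prob_20230601_180000.png'  # hypothetical
    clavrx_path = '/Users/tomrink/data/contrail/clavrx_granule.h5'  # hypothetical

    prob_image, image_ts = load_image(image_path)
    mask_image = get_contrail_mask_image(prob_image)

    # GFS-derived parameters at every contrail pixel, one row per pixel
    contrail_df = extract(mask_image, image_ts, clavrx_path)

    # Summarize the parameters for contrails binned at the 250 hPa level
    analyze(contrail_df, 'pressure_level', 250)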