From f92fcc58d353733e92ea73fc41b9d7d523cb83f2 Mon Sep 17 00:00:00 2001
From: Coda Phillips <coda.phillips@wisc.edu>
Date: Mon, 3 Jun 2024 19:10:56 +0000
Subject: [PATCH] Add l2bc_utils.py
---
l2bc_utils.py | 378 ++++++++++++++++++++++++++++++++++++++++++++++++++
1 file changed, 378 insertions(+)
create mode 100644 l2bc_utils.py
diff --git a/l2bc_utils.py b/l2bc_utils.py
new file mode 100644
index 0000000..618625f
--- /dev/null
+++ b/l2bc_utils.py
@@ -0,0 +1,378 @@
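+"""
+Utilities for building and applying a diurnal/seasonal correction to
+PATMOS-x level-2bc cloud probability files: scanning the archive,
+streaming per-file features over a local socket, accumulating
+(hour, season) and yearly statistics, and plotting per-platform means.
+"""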
+import netCDF4
+import numpy as np
+from tqdm import tqdm
+from pathlib import Path
+import pandas as pd
+from datetime import datetime, timedelta
+import multiprocessing as mp
+import socket
+import io
+import time
+import threading
+from queue import Queue, Empty
+import random
+import warnings
+from collections import defaultdict
+import xarray as xr
+
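+# default location of the local PATMOS-x level-2bc archive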
+ROOT = Path('/media/coda-drive/patmosx_l2bc')
+
+
+def scan_files(root=ROOT):
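+    """
+    Recursively find .nc files under `root` and parse platform, node, date,
+    and variable key out of each filename into a DataFrame.
+    """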
+ root = Path(root)
+ df = []
+ with tqdm(root.rglob('*.nc')) as bar:
+ for f in bar:
+ # patmosx_v06r00_NOAA-09_asc_d19880428_c20210810.cloud_probability.nc
+ stem, key, _ = f.name.split('.')
+ _,_,platform,node,date,_ = stem.split('_')
+ date = datetime.strptime(date, 'd%Y%m%d')
+ df.append({'platform': platform, 'node': node, 'date': date, 'key': key, 'path': f})
+ df = pd.DataFrame(df)
+ return df
+
+
+def get_files(root=ROOT):
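+    """
+    Return a DataFrame indexed by (platform, node, date) with one column per
+    variable key, holding the path of each file.
+    """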
+ root = Path(root)
+ df = scan_files(root)
+ if df.empty:
+ raise ValueError(f'No .nc files found in {root}')
+ files = df.pivot(index=['platform','node','date'], columns='key', values='path').sort_index(level='date')
+ return files
+
+
+def _read_worker(q, out_q):
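+    """
+    Background reader: pull (variable, path) tasks until the queue stays empty
+    for one second, read the first time slice of each variable, and push the
+    arrays to out_q in task order.
+    """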
+ while True:
+ try:
+ args = q.get(timeout=1)
+ key, f = args
+ nc = netCDF4.Dataset(f)
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ data = nc.variables[key][0]
+ out_q.put(data)
+ nc.close()
+ except Empty:
+ return
+
+def clear_queue(q):
+ while True:
+ try:
+ q.get_nowait()
+ except Empty:
+ return
+
+def fast_iter_nc_var(files, *args):
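+    """
+    Yield ((platform, date, node), [arrays...]) for each row of `files`,
+    with the arrays in the order the variables were requested. Reads run on
+    a background thread so file I/O overlaps with the consumer's processing.
+    """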
+ args = list(args)
+ file_subset = files[args].dropna()
+ task_queue = Queue()
+ out_queue = Queue()
+ thread = threading.Thread(target=_read_worker, args=(task_queue, out_queue))
+ tasks = []
+ for (platform, node, date), paths in file_subset.iterrows():
+ key = platform, date, node
+ tasks.append((key, args))
+ for k in args:
+ task_queue.put((k, paths[k]))
+ thread.start()
+
+ try:
+ for key,task in tasks:
+ results = [out_queue.get() for _ in task]
+ yield key, results
+ thread.join()
+ except:
+ clear_queue(task_queue)
+ thread.join()
+ raise
+
+
+def read_file(args, from_phase=False):
+ """
+ read a single netcdf file
+ return table of features (x,y,frac_year,year,local_hour,cloud_probability)
+ """
+ (platform, node, date), row = args
+ if from_phase:
+ nc = netCDF4.Dataset(row['ice_cloud_probability'])
+ ice_v = nc.variables['ice_cloud_probability'][0]
+ nc.close()
+ nc = netCDF4.Dataset(row['water_cloud_probability'])
+ water_v = nc.variables['water_cloud_probability'][0]
+ nc.close()
+ v = ice_v + water_v
+ else:
+ nc = netCDF4.Dataset(row['cloud_probability'])
+ v = nc.variables['cloud_probability'][0]
+ nc.close()
+ nc = netCDF4.Dataset(row['scan_line_time'])
+ slt = nc.variables['scan_line_time'][0]
+ lon = nc.variables['longitude'][:]
+ nc.close()
+ y,x = np.meshgrid(np.arange(v.shape[0]), np.arange(v.shape[1]), indexing='ij', copy=True)
+ frac_year = (date - datetime(date.year,1,1)).total_seconds() / (365.25*24*60*60)
+ year = date.year
+ utc_offset = lon / 360 * 24
+ local_hour = (slt + utc_offset) % 24
+ mask = ~(np.ma.getmaskarray(v) | np.isnan(v) | np.ma.getmaskarray(slt) | np.isnan(slt)).ravel()
+ features = {}
+ features['x'] = x.ravel()[mask]
+ features['y'] = y.ravel()[mask]
+ features['frac_year'] = np.full(x.size, frac_year, dtype=np.float32)[mask]
+ features['year'] = np.full(x.size, year, dtype=np.float32)[mask]
+ features['year'] += frac_year
+ features['local_hour'] = local_hour.ravel()[mask].data
+ features['cloud_probability'] = v.ravel()[mask].data
+ return features
+
+
+def send_tv(batch):
+ """
+ send a batch of files to the server
+ """
+ HOST = '127.0.0.1'
+ PORT = 50007
+ payload = io.BytesIO()
+ #np.savez_compressed(payload, **batch)
+ np.savez(payload, **batch)
+ with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+ while True:
+ try:
+ s.connect((HOST, PORT))
+ s.sendall(payload.getbuffer())
+ return
+ except ConnectionRefusedError:
+ time.sleep(.1)
+ except ConnectionResetError:
+ time.sleep(random.random()*.1)
+
+
+def get_features_tv():
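+    """
+    Server side of the feature pipeline: listen on localhost:50007, accept one
+    connection per batch, and yield the decoded npz payload as a dict of
+    arrays. An empty payload terminates the generator.
+    """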
+ HOST = 'localhost'
+ PORT = 50007
+ with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+ s.bind((HOST, PORT))
+ s.listen()
+ while True:
+ payload = io.BytesIO()
+ conn, addr = s.accept()
+ with conn:
+ while True:
+ data = conn.recv(2**16)
+ if not data:
+ break
+ payload.write(data)
+ payload.seek(0)
+ npz = np.load(payload)
+ features = {k: npz[k] for k in npz.files}
+ if not features:
+ return
+ yield features
+
+
+def read_and_send(args):
+ """
+ read a single file and send it to the server
+ """
+ features = read_file(args)
+ send_tv(features)
+
+
+def compute_stats_index(batch, coords):
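+    """
+    Map each sample's (local_hour, frac_year, y, x) and (year, y, x) onto flat
+    indices into the (hour, season, 180, 360) and (year, 180, 360)
+    accumulators, using the bin edges supplied in `coords`.
+    """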
+ hour_coords = coords['hour']
+ season_coords = coords['season']
+ year_coords = coords['year']
+
+ hour_season_shape = (len(hour_coords), len(season_coords), 180, 360)
+ year_anomaly_shape = (len(year_coords), 180, 360)
+
+ hour_idx = np.digitize(batch['local_hour'], hour_coords, right=True) - 1
+ season_idx = np.digitize(batch['frac_year'], season_coords, right=True) - 1
+ hour_season_idx = np.ravel_multi_index((hour_idx, season_idx, batch['y'], batch['x']), hour_season_shape, mode='clip')
+ year_idx = np.digitize(batch['year'], year_coords, right=True).clip(1, len(year_coords)) - 1
+ year_anomaly_idx = np.ravel_multi_index((year_idx, batch['y'], batch['x']), year_anomaly_shape, mode='clip')
+ return hour_season_idx, year_anomaly_idx
+
+
+def update_stats(stats, hour_season_idx, year_anomaly_idx, v, mask):
+    """
+    Accumulate one batch into the running sums and counts in place.
+    np.add.at handles the repeated indices that plain fancy-index += would
+    silently drop.
+    """
+    np.add.at(stats['hour_season_sum'].ravel(), hour_season_idx, v)
+    np.add.at(stats['hour_season_count'].ravel(), hour_season_idx, mask)
+    np.add.at(stats['year_anomaly_sum'].ravel(), year_anomaly_idx, v)
+    np.add.at(stats['year_anomaly_count'].ravel(), year_anomaly_idx, mask)
+    return stats
+
+
+def init_stats(coords):
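+    """
+    Allocate zeroed sum/count accumulators of shape (hour, season, 180, 360)
+    and (year, 180, 360) on the 1-degree global grid.
+    """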
+
+ hour_coords = coords['hour']
+ season_coords = coords['season']
+ year_coords = coords['year']
+
+ zeros = np.zeros
+ stats = {}
+ stats['hour_season_sum'] = zeros((len(hour_coords), len(season_coords), 180, 360), dtype=np.float32)
+ stats['hour_season_count'] = zeros((len(hour_coords), len(season_coords), 180, 360), dtype=np.int32)
+ stats['year_anomaly_sum'] = zeros((len(year_coords),180,360), dtype=np.float32)
+ stats['year_anomaly_count'] = zeros((len(year_coords),180,360), dtype=np.int32)
+ return stats
+
+
+def filter_szas(v, sza, slt, ANGLE_CUTOFF=30):
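+    """
+    Mask pixels whose solar zenith angle exceeds ANGLE_CUTOFF degrees so only
+    well-illuminated observations contribute; returns (v, slt) so it can be
+    used directly as the `transform` hook below.
+    """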
+ mask = sza > ANGLE_CUTOFF
+ v[mask] = np.ma.masked
+ return v, slt
+
+
+def compute_hour_season_stats(files, *required_vars, transform=filter_szas):
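+    """
+    Accumulate per (month, local hour, lat, lon) sums and counts of the
+    screened variable over all files, using `transform` to mask pixels and
+    extract (value, scan_line_time) from the requested variables.
+    """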
+ nc = netCDF4.Dataset(files.iloc[0].dropna().values[0])
+ lon = nc.variables['longitude'][:]
+ utc_offset = lon / 360 * 24
+ nc.close()
+ season_hour_sum = np.zeros((12,24,180,360), dtype=np.float64)
+ season_hour_count = np.zeros((12,24,180,360), dtype=np.int32)
+ i,j = np.meshgrid(np.arange(180), np.arange(360), indexing='ij')
+
+ files = files[list(required_vars)].dropna()
+ with tqdm(fast_iter_nc_var(files, *required_vars), total=len(files)) as bar:
+ for (platform, date, node), data in bar:
+ v, slt = transform(*data)
+ idx = compute_season_hour_index(date, slt, utc_offset, i, j)
+ season_hour_sum.ravel()[idx] += v.filled(0).ravel()
+ season_hour_count.ravel()[idx] += 1-np.ma.getmaskarray(v.ravel())
+ return season_hour_sum, season_hour_count
+
+
+def compute_season_hour_index(date, slt, utc_offset, i, j):
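+    """
+    Return flat indices into the (12, 24, 180, 360) month/hour/lat/lon
+    accumulator for one scene, with local hour taken from scan-line time plus
+    the longitude-based UTC offset.
+    """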
+ local_hour = (slt + utc_offset)
+ hour_idx = np.floor(local_hour).astype(np.int32) % 24
+ season_idx = np.full(hour_idx.shape, date.month-1, dtype=np.int32)
+ idx = np.ravel_multi_index((season_idx.ravel(), hour_idx.ravel(), i.ravel(), j.ravel()), (12,24,180,360))
+ return idx
+
+
+def frac_month(date):
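+    # fraction of the current month elapsed at `date`, in [0, 1)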
+ since_month_begin = date - date.replace(day=1)
+ month_duration = (date.replace(day=1)+timedelta(days=32)).replace(day=1) - date.replace(day=1)
+ return since_month_begin.total_seconds() / month_duration.total_seconds()
+
+def interp_month_coords(date):
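+    """
+    Return the two 0-based month indices bracketing `date` and the
+    interpolation weight of the second one, treating each monthly value as
+    centered mid-month.
+    """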
+ month_alpha = frac_month(date)
+    # monthly values are centered mid-month, so:
+    #   alpha < .5: interpolate from the previous month to this one
+    #               (remap 0 -> .5 and .5 -> 1)
+    #   alpha > .5: interpolate from this month to the next
+    #               (remap .5 -> 0 and 1 -> .5)
+ if month_alpha > .5:
+ month0 = date.month-1
+ month1 = date.month % 12
+ month_alpha -= .5
+ else:
+ month0 = (date.month-2) % 12
+ month1 = date.month-1
+ month_alpha = month_alpha - .5 + 1
+
+ return month0, month1, month_alpha
+
+def interp_season_hour_offset(season_hour_offset, local_hour, date):
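+    """
+    Interpolate the (month, hour, lat, lon) offset climatology linearly in
+    month (via interp_month_coords) and in local hour for one scene.
+    """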
+ i,j = np.meshgrid(np.arange(180), np.arange(360), indexing='ij')
+ hour_idx = np.floor(local_hour).astype(np.int32) % 24
+
+ month0, month1, month_alpha = interp_month_coords(date)
+
+ hour_alpha = local_hour.ravel() % 1
+
+ hour_offset = (1-month_alpha)*season_hour_offset[month0] + month_alpha*season_hour_offset[month1]
+
+ idx0 = np.ravel_multi_index((hour_idx.ravel(), i.ravel(), j.ravel()), hour_offset.shape)
+ idx1 = np.ravel_multi_index((hour_idx.ravel()+1, i.ravel(), j.ravel()), hour_offset.shape, mode='wrap')
+
+ offset = (1-hour_alpha)*hour_offset.ravel()[idx0] + hour_alpha*hour_offset.ravel()[idx1]
+ return offset.reshape(local_hour.shape)
+
+
+
+def correct_hour_season(files, season_hour_offset, *required_vars, floor_date=None, transform=filter_szas):
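+    """
+    Subtract the interpolated diurnal/seasonal offset from each scene and
+    accumulate the residual sums/counts per (platform, floor_date(date), node);
+    floor_date defaults to the start of the calendar month.
+    """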
+ def init_stats():
+ return {'sum':np.zeros((180,360), dtype=np.float64), 'count':np.zeros((180,360), dtype=np.int32)}
+ stats = defaultdict(init_stats)
+
+ if floor_date is None:
+ floor_date = lambda x: x.replace(day=1)
+
+    nc = netCDF4.Dataset(files.iloc[0].dropna().values[0])
+ lon = nc.variables['longitude'][:]
+ utc_offset = lon / 360 * 24
+ nc.close()
+ files = files[list(required_vars)].dropna()
+ with tqdm(fast_iter_nc_var(files, *required_vars), total=len(files)) as bar:
+ for (platform, date, node), data in bar:
+ key = platform, floor_date(date), node
+ v, slt = transform(*data)
+
+ local_hour = (slt + utc_offset) - .5
+ offset = interp_season_hour_offset(season_hour_offset, local_hour, date)
+
+ v_full = (v - offset.reshape(v.shape))
+
+ stats[key]['sum'] += v_full.filled(0)
+ stats[key]['count'] += (1-np.ma.getmaskarray(v_full))
+
+ stats = dict(stats.items())
+ return stats
+
+def get_mean_from_stats(stat, lat, lon, min_count=1):
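+    """
+    Turn the per-group sum/count dictionary into an xarray DataArray of means
+    on (group, lat, lon), masking cells with fewer than `min_count`
+    observations and filling gaps along the group axis.
+    """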
+ mean = []
+ platforms = []
+ start_times = []
+ nodes = []
+ with tqdm(stat) as bar:
+ for k in bar:
+ (platform, start_time, node) = k
+ mean.append((stat[k]['sum'] / np.ma.masked_less(stat[k]['count'], min_count)).astype(np.float32))
+ platforms.append(platform)
+ start_times.append(start_time)
+ nodes.append(node)
+ mean = np.ma.stack(mean)
+
+ mean = xr.DataArray(mean, dims=['group','lat','lon'],
+ coords={
+ 'platform':(['group'], platforms),
+ 'start_time':(['group'], start_times),
+ 'node':(['group'], nodes),
+ 'lat':(['lat'], lat),
+ 'lon':(['lon'], lon)
+ }
+ )
+ mean = mean.ffill('group').bfill('group')
+ return mean
+
+
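+# fixed plot color per satellite platform, indexing matplotlib's default
+# property cycle via 'CN' color strings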
+PLATFORM_COLORS = {'noaa-06': 'C0',
+ 'noaa-07': 'C1',
+ 'noaa-08': 'C2',
+ 'noaa-09': 'C3',
+ 'noaa-10': 'C4',
+ 'noaa-11': 'C5',
+ 'noaa-12': 'C6',
+ 'noaa-14': 'C7',
+ 'noaa-15': 'C8',
+ 'noaa-16': 'C9',
+ 'noaa-17': 'C10',
+ 'noaa-18': 'C11',
+ 'metop-a': 'C12',
+ 'noaa-19': 'C13',
+ 'metop-b': 'C14'}
+
+def plot_mean(mean):
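+    """
+    Plot the global-mean anomaly time series for each platform/node group,
+    plus the mean over all platforms at each start_time.
+    """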
+ import matplotlib.pyplot as plt
+ global_mean = mean.mean(dim=['lat','lon']).set_index(group=['platform','start_time','node']).to_pandas()
+ for (platform,node), s in global_mean.groupby(['platform','node']):
+ s = s.droplevel(['platform','node'])
+ s.plot(label=f'{platform} {node}', color=PLATFORM_COLORS[platform.lower()])
+ global_mean.groupby('start_time').mean().plot(label='mean', color='k')
+ plt.xlabel('')
+ plt.ylabel('Lower Cloud Coverage Anomaly (ACHA)')
+ plt.grid()
--
GitLab