diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..b3c55e7c6d766efc4ecfa043acf546d3c7dd16ee
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2022 TJ Vandal
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/data/download_g5nr.py b/data/download_g5nr.py
new file mode 100644
index 0000000000000000000000000000000000000000..725bf11e5ae2358923c7bc1f0f3a26527936bed9
--- /dev/null
+++ b/data/download_g5nr.py
@@ -0,0 +1,45 @@
+import os
+
+import urllib3
+import pandas
+
+import argparse
+
+parser = argparse.ArgumentParser()
+parser.add_argument('output', metavar='G5NR_7km/', type=str,
+                    help='Directory in which to save the downloaded data')
+
+args = parser.parse_args()
+
+def url_to_pandas(url):
+    http = urllib3.PoolManager()
+    r = http.request('GET', url)
+    return pandas.read_html(r.data)[-2]
+
+path = 'https://portal.nccs.nasa.gov/datashare/gmao_obsteam/osse_for_wisc/'
+to_directory = args.output
+
+if not os.path.exists(to_directory):
+    os.makedirs(to_directory)
+
+df = url_to_pandas(path)
+files = []
+
+for name in df['Name']:
+    try:
+        int(name[:8])
+    except (ValueError, TypeError):  # skip listing entries that are not date-named folders
+        continue
+
+    folder_path = os.path.join(path, name)
+    df_folder = url_to_pandas(folder_path)
+    for filename in df_folder['Name']:
+        if isinstance(filename, str):
+            if 'nc4' == filename[-3:]:
+                filepath = os.path.join(folder_path, filename)
+                print(f"Filepath: {filepath}")
+                if os.path.exists(os.path.join(to_directory, filename)):  # skip files already downloaded
+                    continue
+
+                wget_cmd = f"wget -O {to_directory}/{filename} {filepath}"
+                os.system(wget_cmd)
\ No newline at end of file
diff --git a/data/g5nr_to_zarr.py b/data/g5nr_to_zarr.py
new file mode 100644
index 0000000000000000000000000000000000000000..723ec0b49fcff711ba9d38122a87edbac73432a2
--- /dev/null
+++ b/data/g5nr_to_zarr.py
@@ -0,0 +1,29 @@
+import xarray as xr
+import os, sys
+
+sys.path.append('..')
+
+from nexai.datasets.g5nr import g5nr
+from dask.distributed import Client
+
+
+# Set data variables
+label_directory = '/nobackupp10/tvandal/nex-ai-opticalflow/data/G5NR_7km/'
+label_files = g5nr.get_files(label_directory, 'test').reset_index()
+
+output_directory = '/nobackupp10/tvandal/nex-ai-opticalflow/data/G5NR_7km_zarr/'
+print(label_files)
+
+N = len(label_files)
+for i, row in label_files.iterrows():
+    fname = os.path.basename(row['U'])
+    output_zarr = os.path.join(output_directory, fname.split('.')[2] + '.zarr')
+    if os.path.exists(output_zarr):
+        continue
+
+    U = xr.open_mfdataset(row['U'])
+    V = xr.open_mfdataset(row['V'])
+    ds = xr.open_mfdataset(row['QV']).merge(U).merge(V)
+    ds = 
ds.chunk(dict(time=1, lat=500, lon=500, lev=1)) + ds.to_zarr(output_zarr) + print(i/N, output_zarr) \ No newline at end of file diff --git a/model_weights/windflow.raft.pth.tar b/model_weights/windflow.raft.pth.tar new file mode 100644 index 0000000000000000000000000000000000000000..0dbb63486c37cb23116fbb78033812c6beb5f3d6 Binary files /dev/null and b/model_weights/windflow.raft.pth.tar differ diff --git a/pipelines/eco1280/eco_to_flows.py b/pipelines/eco1280/eco_to_flows.py new file mode 100644 index 0000000000000000000000000000000000000000..ad864a88d690681b8d0c8b22c0f68de958f513e9 --- /dev/null +++ b/pipelines/eco1280/eco_to_flows.py @@ -0,0 +1,141 @@ +import os, sys +import glob +import argparse +import time + +# T.Rink, MacBook OSX doesn't support MPI +# from mpi4py import MPI + +# MPI_COMM = MPI.COMM_WORLD +# MPI_RANK = MPI_COMM.Get_rank() +# MPI_SIZE = MPI_COMM.Get_size() +MPI_RANK = 0 + +# T.Rink, MacBook GPU (AMD Radeon Pro 5300M) doesn't support CUDA +# N_GPUS = 2 + +# rank_gpu = MPI_RANK % N_GPUS +# os.environ['CUDA_VISIBLE_DEVICES'] = f'{rank_gpu}' +os.environ['CUDA_VISIBLE_DEVICES'] = f'{0.0}' + +import numpy as np +import pandas as pd +import xarray as xr +import matplotlib.pyplot as plt + +import torch + +sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))) + +# from windflow.datasets.g5nr import g5nr +from windflow.inference.inference_flows import FlowRunner +from windflow.datasets.utils import cartesian_to_speed_eco + +parser = argparse.ArgumentParser() +parser.add_argument("--model_name", default="raft", type=str) +# parser.add_argument("--checkpoint", default="models/raft-size_512/checkpoint.pth.tar", type=str) +parser.add_argument("--checkpoint", default="model_weights/windflow.raft.pth.tar", type=str) +parser.add_argument("--data_directory", default="data/ECO1280/", type=str) +parser.add_argument("--output", default="data/ECO_flows/raft/", type=str) +parser.add_argument("--n_files", default=None, type=int) +parser.add_argument("--batch_size", default=2, type=int) + +args = parser.parse_args() + +# Set model information +model_name = args.model_name + +# Set data variables +data_directory = args.data_directory +to_directory = args.output + +if (MPI_RANK == 0) and (not os.path.exists(to_directory)): + os.makedirs(to_directory) + +# get file list +#files = g5nr.get_files(data_directory, 'test') +#print("Number of OSSE test files", len(files)) + +pattern = os.path.join(data_directory, '*.nc4') +files = sorted(glob.glob(pattern)) + +variables = ['gp', 'uv'] +var_files = {v: sorted([f for f in files if f'{v}_' in f]) for v in variables} +var_files = pd.DataFrame(var_files) + +if args.n_files is None: + N_files = len(files) +else: + N_files = args.n_files + +# N_per_rank = N_files // MPI_SIZE +# files = files.iloc[N_per_rank * MPI_RANK: N_per_rank * (MPI_RANK+1) + 1] + +# load model +tile_size = 512 +overlap = 128 + +if model_name in ['pwc-net-rmsvd', 'pwc-net-l1']: + runner = FlowRunner('pwc-net', + tile_size=tile_size, + overlap=overlap, + batch_size=args.batch_size) +else: + runner = FlowRunner(model_name.replace('-guided','').replace('-unflow', ''), + tile_size=tile_size, + overlap=overlap, + batch_size=args.batch_size, + device=torch.device('cpu')) + +runner.load_checkpoint(args.checkpoint) + +stats = [] + +# iterate and perform inference, compute test statistics +ds1 = xr.open_mfdataset(files.iloc[0].values, engine='netcdf4') +for i in range(1, files.shape[0]): + # open dataset + + ds2 = xr.open_mfdataset(files.iloc[i].values, 
engine='netcdf4')
+
+    f = os.path.basename(files.iloc[i-1]['uv'])  # get ds1 file
+    to_flow_file = os.path.join(to_directory, f.replace('uv_', '_WindFlow_').replace('.nc', '.zarr'))
+    if os.path.exists(to_flow_file):
+        ds1 = ds2.copy()
+        continue
+
+    # t = ds1['time'].values
+
+    output_ds = xr.zeros_like(ds1)
+    del output_ds['gp_newP']
+
+    U_flows = np.zeros(ds1.ugrd_newP.shape)
+    V_flows = np.zeros(ds1.vgrd_newP.shape)
+
+    t0 = time.time()
+
+    # use a separate index name so the file-loop variable `i` is not shadowed
+    for lev_idx, lev in enumerate(ds1.lev_p):
+        qv1_lev = ds1.sel(lev_p=lev)['gp_newP'].values[0]
+        qv2_lev = ds2.sel(lev_p=lev)['gp_newP'].values[0]
+        _, flows_lev = runner.forward(qv1_lev, qv2_lev)
+        U_flows[:, lev_idx] = flows_lev[0]
+        V_flows[:, lev_idx] = flows_lev[1]
+
+    output_ds['ugrd_newP'] = output_ds['ugrd_newP'] + U_flows
+    output_ds['vgrd_newP'] = output_ds['vgrd_newP'] + V_flows
+
+    output_ds = cartesian_to_speed_eco(output_ds)
+
+    output_ds.attrs['Source'] = 'NEX'
+    output_ds.attrs['Title'] = 'Optical Flow Feature Tracking'
+    output_ds.attrs['Contact'] = 'vandal@baeri.org'
+    output_ds.attrs['History'] = 'G5NR outputs from GEOS-5 by gmao processed by NEX optical flow.'
+    output_ds.attrs['Model'] = model_name
+    output_ds.attrs['Pytorch_Checkpoint'] = args.checkpoint
+
+    # output_ds.to_netcdf(to_flow_file)
+    output_ds.to_zarr(to_flow_file)
+    # print(f'Wrote to file {to_flow_file}')
+    print(f"Wrote to file: {to_flow_file} -- Processing time {time.time()-t0} (seconds)")
+
+    ds1 = ds2.copy()
diff --git a/pipelines/g5nr/g5nr_to_flows.py b/pipelines/g5nr/g5nr_to_flows.py
new file mode 100644
index 0000000000000000000000000000000000000000..04a1d57521024cd584d98b98fc0d5c06800c3eaf
--- /dev/null
+++ b/pipelines/g5nr/g5nr_to_flows.py
@@ -0,0 +1,132 @@
+import os, sys
+import glob
+import argparse
+import time
+
+# T.Rink, MacBook OSX doesn't support MPI
+# from mpi4py import MPI
+
+# MPI_COMM = MPI.COMM_WORLD
+# MPI_RANK = MPI_COMM.Get_rank()
+# MPI_SIZE = MPI_COMM.Get_size()
+MPI_RANK = 0  # single-process default since MPI is disabled above
+
+# T.Rink, MacBook GPU (AMD Radeon Pro 5300M) doesn't support CUDA
+# N_GPUS = 2
+
+# rank_gpu = MPI_RANK % N_GPUS
+# os.environ['CUDA_VISIBLE_DEVICES'] = f'{rank_gpu}'
+os.environ['CUDA_VISIBLE_DEVICES'] = '0'
+
+import numpy as np
+import pandas as pd
+import xarray as xr
+import matplotlib.pyplot as plt
+
+import torch
+
+sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
+
+from windflow.datasets.g5nr import g5nr
+from windflow.inference.inference_flows import FlowRunner
+from windflow.datasets.utils import cartesian_to_speed
+
+parser = argparse.ArgumentParser()
+parser.add_argument("--model_name", default="raft", type=str)
+parser.add_argument("--checkpoint", default="models/raft-size_512/checkpoint.pth.tar", type=str)
+parser.add_argument("--data_directory", default="data/G5NR/", type=str)
+parser.add_argument("--output", default="data/G5NR_Flows/raft/", type=str)
+parser.add_argument("--n_files", default=None, type=int)
+parser.add_argument("--batch_size", default=2, type=int)
+
+args = parser.parse_args()
+
+# Set model information
+model_name = args.model_name
+
+# Set data variables
+data_directory = args.data_directory
+to_directory = args.output
+
+if (MPI_RANK == 0) and (not os.path.exists(to_directory)):
+    os.makedirs(to_directory)
+
+# get file list
+files = g5nr.get_files(data_directory, 'test')
+print("Number of OSSE test files", len(files))
+
+if args.n_files is None:
+    N_files = len(files)
+else:
+    N_files = args.n_files
+
+# N_per_rank = N_files // MPI_SIZE
+# files = files.iloc[N_per_rank * MPI_RANK: N_per_rank * (MPI_RANK+1) + 1]
+
+# load model
+tile_size = 512
+overlap = 128
+
+if model_name in ['pwc-net-rmsvd', 'pwc-net-l1']:
+    runner = FlowRunner('pwc-net',
+                        tile_size=tile_size,
+                        overlap=overlap,
+                        batch_size=args.batch_size)
+else:
+    runner = FlowRunner(model_name.replace('-guided','').replace('-unflow', ''),
+                        tile_size=tile_size,
+                        overlap=overlap,
+                        batch_size=args.batch_size)
+
+runner.load_checkpoint(args.checkpoint)
+
+stats = []
+
+# iterate and perform inference, compute test statistics
+ds1 = xr.open_mfdataset(files.iloc[0].values, engine='netcdf4')
+for i in range(1, files.shape[0]):
+    # open dataset
+
+    ds2 = xr.open_mfdataset(files.iloc[i].values, engine='netcdf4')
+
+    f = os.path.basename(files.iloc[i-1]['U'])  # get ds1 file
+    to_flow_file = os.path.join(to_directory, f.replace('_U_', '_WindFlow_').replace('.nc', '.zarr'))
+    if os.path.exists(to_flow_file):
+        ds1 = ds2.copy()
+        continue
+
+    t = ds1['time'].values
+
+    output_ds = xr.zeros_like(ds1)
+    del output_ds['QV'], output_ds['tpw']
+
+    U_flows = np.zeros(ds1.U.shape)
+    V_flows = np.zeros(ds1.V.shape)
+
+    t0 = time.time()
+
+    # use a separate index name so the file-loop variable `i` is not shadowed
+    for lev_idx, lev in enumerate(ds1.lev):
+        qv1_lev = ds1.sel(lev=lev)['QV'].values[0]
+        qv2_lev = ds2.sel(lev=lev)['QV'].values[0]
+        _, flows_lev = runner.forward(qv1_lev, qv2_lev)
+        U_flows[0, lev_idx] = flows_lev[0]
+        V_flows[0, lev_idx] = flows_lev[1]
+
+    output_ds['U'] = output_ds['U'] + U_flows
+    output_ds['V'] = output_ds['V'] + V_flows
+
+    output_ds = cartesian_to_speed(output_ds)
+
+    output_ds.attrs['Source'] = 'NEX'
+    output_ds.attrs['Title'] = 'Optical Flow Feature Tracking'
+    output_ds.attrs['Contact'] = 'vandal@baeri.org'
+    output_ds.attrs['History'] = 'G5NR outputs from GEOS-5 by gmao processed by NEX optical flow.'
+    output_ds.attrs['Model'] = model_name
+    output_ds.attrs['Pytorch_Checkpoint'] = args.checkpoint
+
+    # output_ds.to_netcdf(to_flow_file)
+    output_ds.to_zarr(to_flow_file)
+    # print(f'Wrote to file {to_flow_file}')
+    print(f"Wrote to file: {to_flow_file} -- Processing time {time.time()-t0} (seconds)")
+
+    ds1 = ds2.copy()
diff --git a/pipelines/geo/geo_flows_to_zarr.py b/pipelines/geo/geo_flows_to_zarr.py
new file mode 100644
index 0000000000000000000000000000000000000000..272e67090fd28f055ad85c6d0e4e1c135cb1187b
--- /dev/null
+++ b/pipelines/geo/geo_flows_to_zarr.py
@@ -0,0 +1,128 @@
+'''
+Author: TJ Vandal
+
+Process NOAA L1b Files to Flow files between a range of dates
+
+Inputs: 10-minute NOAA L1b files, Optical Flow Model
+Output: 10-minute wind vectors
+'''
+
+import os, sys
+import time
+
+# from mpi4py import MPI
+#
+# comm = MPI.COMM_WORLD
+# rank = comm.Get_rank()
+# size = comm.Get_size()
+# name = MPI.Get_processor_name()
+rank = 0  # single-process defaults since MPI is disabled above
+size = 1
+name = ''
+
+import torch
+
+# sys.stdout.write(f"Rank {rank} Size {size} Name {name} CUDA: {os.environ['CUDA_VISIBLE_DEVICES']} Devices count: {torch.cuda.device_count()}\n")
+# sys.stdout.write(f'number of devices: {torch.cuda.device_count()}\n')
+# sys.stdout.flush()
+
+device = 0  # 'cuda:0' # torch.cuda.current_device() # 'cuda:0' + str(rank % N_GPUS)
+# torch.cuda.set_device(device)
+
+import glob
+import argparse
+import datetime as dt
+
+import xarray as xr
+import numpy as np
+
+
+sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
+
+from windflow.datasets.geostationary import goesr, stats
+from windflow.inference import inference_flows
+
+parser = argparse.ArgumentParser()
+parser.add_argument("--model_name", default="raft", type=str)
+parser.add_argument("--checkpoint",
default="/nobackupp10/tvandal/windflow/models/raft-size_512/checkpoint.pth.tar", type=str) +parser.add_argument('--start_date', type=lambda s: dt.datetime.strptime(s, '%Y-%m-%d'), + default='2021-01-01') +parser.add_argument('--end_date', type=lambda s: dt.datetime.strptime(s, '%Y-%m-%d'), default='2021-01-02') +parser.add_argument('--timestep', type=int, default=10) +parser.add_argument('--band', type=int, default=10) +parser.add_argument('--data_path', type=str, default='/nex/datapool/geonex/public/GOES16/NOAA-L1B/') +parser.add_argument('--spatial', type=str, default=None) +parser.add_argument('--upsample', type=float, default=None) +parser.add_argument('--batch_size', type=int, default=1) +parser.add_argument('--product', type=str, default='ABI-L1b-RadF') +parser.add_argument('--output_path', type=str, default='data/L1B-WindFlow/') +args = parser.parse_args() + + +curr_date = args.start_date +dates = [] +while curr_date <= args.end_date: + dates.append(curr_date) + #curr_date = curr_date + dt.timedelta(minutes=args.timestep) + curr_date = curr_date + dt.timedelta(hours=1) + +rank_idxs = np.arange(rank, len(dates)-1, size) + +inference = inference_flows.GeoFlows(args.model_name.replace('-guided', ''), + args.data_path, + overlap=128, + tile_size=512, + channels=[args.band], + product=args.product, + device=device, + batch_size=args.batch_size, + timestep=args.timestep, + upsample_data=args.upsample, + spatial=args.spatial) +inference.load_checkpoint(args.checkpoint) + +for idx in rank_idxs: + t = dates[idx] + + filename = f'GeoNEX_AMV_Flows-{args.product}-B{args.band:02g}-{t.year}{t.month:02g}{t.day:02g}_{t.hour:02g}{t.minute:02g}.{args.model_name}.zarr' + output_file_path = os.path.join(args.output_path, args.product, f'{t.year}', + f'{t.timetuple().tm_yday:03g}') + if not os.path.exists(output_file_path): + os.makedirs(output_file_path) + + output_file = os.path.join(output_file_path, filename) + if os.path.exists(output_file): + continue + + sys.stdout.write(f"Processing time {t}\n") + sys.stdout.flush() + + #data, flows = inference.flows_by_time(t) + try: + t0 = time.time() + data = inference.flows_by_time(t, reproject=True) + + ''' + img1 = data['Rad'].values + img1[img1 == 0] = np.nan + flows[0][img1 != img1] = np.nan + flows[1][img1 != img1] = np.nan + ''' + + ds_out = xr.Dataset(dict(U=data['U'], V=data['V'])) + ds_out['band'] = args.band + ds_out['lat'] = data['lat'] + ds_out['lon'] = data['lon'] + ds_out['t'] = t + #ds_out.attrs['input_dataset_name'] = data.attrs['dataset_name'] + + #ds_out.to_netcdf(output_file) + ds_out.chunk({'lat': 1000, 'lon': 1000}).to_zarr(output_file) + print(f"Wrote to file: {output_file} -- Processing time {time.time()-t0} (seconds)") + sys.stdout.write(output_file + '\n') + sys.stdout.flush() + + del ds_out, data + except Exception as err: + print('Exception Raised', err, t) diff --git a/train.py b/train.py new file mode 100644 index 0000000000000000000000000000000000000000..e3ed95c4f1bd37599d5f16645545222274cf4976 --- /dev/null +++ b/train.py @@ -0,0 +1,167 @@ +import sys +import os +import time +import argparse + +from torch.utils.tensorboard import SummaryWriter # there is a bug with summarywriter importing +from collections import OrderedDict + +import numpy as np +import torch +torch.cuda.empty_cache() + +import torch.nn as nn +from torch import optim +from torch.utils import data +from torch.nn.parallel import DistributedDataParallel as DDP +import torch.distributed as dist +import torch.multiprocessing as mp + 
+sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+from windflow.datasets import get_dataset
+from windflow.networks.models import get_flow_model
+from windflow.train.trainers import get_flow_trainer
+
+os.environ['PYTHONWARNINGS'] = 'ignore:semaphore_tracker:UserWarning'
+
+
+def setup(rank, world_size, port):
+    '''
+    Setup multi-gpu processing group
+    Args:
+        rank: current rank
+        world_size: number of processes
+        port: which port to connect to
+    Returns:
+        None
+    '''
+    os.environ['MASTER_ADDR'] = 'localhost'
+    os.environ['MASTER_PORT'] = f'{port}'
+
+    # initialize the process group
+    dist.init_process_group("gloo", rank=rank, world_size=world_size)
+
+
+def cleanup():
+    dist.destroy_process_group()
+
+
+def train_net(params, rank=0):
+    print('CUDA is available? ', torch.cuda.is_available())
+    # set device
+    # if not device:
+    #     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+    # device = rank  # % N_DEVICES
+    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+    print(f'in train_net rank {rank} on device {device}')
+
+    if params['ngpus'] > 1:
+        distribute = True
+    else:
+        distribute = False
+
+    dataset_train, dataset_valid = get_dataset(params['dataset'], params['data_path'],
+                                               scale_factor=params['scale_input'],
+                                               frames=params['input_frames'])
+
+    data_params = {'batch_size': params['batch_size'] // params['ngpus'], 'shuffle': True,
+                   'num_workers': 8, 'pin_memory': True}
+    training_generator = data.DataLoader(dataset_train, **data_params)
+    val_generator = data.DataLoader(dataset_valid, **data_params)
+
+    model = get_flow_model(params['model_name'], small=False)
+    if distribute:
+        model = DDP(model.to(device), device_ids=[device], find_unused_parameters=True)
+
+    trainer = get_flow_trainer(model, params['model_name'],
+                               params['model_path'],
+                               distribute=distribute,
+                               rank=rank,
+                               lr=params['lr'],
+                               loss=params['loss'])
+
+    trainer.model.to(device)
+    trainer.load_checkpoint()
+
+    print(f'Start Training {params["model_name"]} on {params["dataset"]}')
+
+    while trainer.global_step < params['max_iterations']:
+        running_loss = 0.0
+        t0 = time.time()
+        for batch_idx, (images, flows) in enumerate(training_generator):
+            images = images.to(device)
+            flows = flows.to(device)
+
+            log = False
+            if (trainer.global_step % params['log_step'] == 0):
+                log = True
+
+            train_loss = trainer.step(images, flows,
+                                      log=log, train=True)
+            if np.isinf(train_loss.cpu().detach().numpy()):
+                print(train_loss, images.cpu().detach().numpy(), flows.cpu().detach().numpy().flatten())  # dump the offending batch
+                return
+
+            if log:
+                images, flows = next(iter(val_generator))
+                valid_loss = trainer.step(images.to(device), flows.to(device),
+                                          log=log, train=False)
+                print(f'Rank {trainer.rank} @ Step: {trainer.global_step-1}, Training Loss: {train_loss:.3f}, Valid Loss: {valid_loss:.3f}')
+
+            if (trainer.global_step % params['checkpoint_step'] == 0) and (rank == 0):
+                trainer.save_checkpoint()
+
+    return  # validation loss is only logged, nothing to return
+
+
+def manual_experiment(args):
+    train_net(vars(args))
+
+
+def train_net_mp(rank, world_size, port, params):
+    '''
+    Setup and train on node
+    '''
+    setup(rank, world_size, port)
+    train_net(params, rank=rank)
+
+
+def run_training(args, world_size, port):
+    # params['batch_size'] = params['batch_size'] // world_size
+    if world_size > 1:
+        mp.spawn(train_net_mp,
+                 args=(world_size, port, vars(args)),
+                 nprocs=world_size,
+                 join=True)
+        cleanup()
+    else:
+        train_net(vars(args))
+
+
+if __name__ == "__main__":
+    # Feb 1 2020, Band 1 hyper-parameter search
+    # 
{'w': 0.6490224421024322, 's': 0.22545622639358046, 'batch_size': 128} + parser = argparse.ArgumentParser() + parser.add_argument("--model_path", default="models/raft-size_512/", type=str) + parser.add_argument("--dataset", default="g5nr", type=str) + parser.add_argument("--data_path", default="data/G5NR_patches/", type=str) + parser.add_argument("--input_frames", default=2, type=int) + parser.add_argument("--model_name", default="raft", type=str) + # parser.add_argument("--gpus", default="0", type=str) + parser.add_argument("--batch_size", default=4, type=int) + parser.add_argument("--lr", default=1e-4, type=float) + parser.add_argument("--scale_input", default=None, type=float, help='Bilinear interpolation on input image.') + parser.add_argument('--max_iterations',type=int, default=2000000, help='Number of training iterations') + parser.add_argument("--log_step", default=1000, type=int) + parser.add_argument("--checkpoint_step", default=5000, type=int) + parser.add_argument("--ngpus", default=1, type=int) + parser.add_argument("--port", default=9009, type=int) + parser.add_argument("--loss", default='L1', type=str) + + args = parser.parse_args() + print('torch.cuda.device_count: ', torch.cuda.device_count()) + if torch.cuda.device_count() < args.ngpus: + print(f"Cannot running training because {args.ngpus} are not available.") + + run_training(args, args.ngpus, args.port) diff --git a/windflow/__init__.py b/windflow/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/windflow/datasets/__init__.py b/windflow/datasets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..cb6f9ed4d36f0ec9b43b5cc3b9aa4e1ed338aad5 --- /dev/null +++ b/windflow/datasets/__init__.py @@ -0,0 +1,2 @@ +from . import utils +from .datasets import * diff --git a/windflow/datasets/datasets.py b/windflow/datasets/datasets.py new file mode 100644 index 0000000000000000000000000000000000000000..9953d76b5f58dea5eafccf31f63b4dd3db58580f --- /dev/null +++ b/windflow/datasets/datasets.py @@ -0,0 +1,23 @@ +import os + +from .g5nr import G5NRFlows +from .geostationary import L1bPatches + +def get_dataset(dataset_name, data_path, size=512, scale_factor=None, frames=2): + # Load dataset + if dataset_name in ['gmao_osse_7km', 'g5nr', 'g5nr_7km']: + dataset_train = G5NRFlows(os.path.join(data_path, 'train'), + size=size, scale_factor=scale_factor, + augment=True, + frames=frames) + dataset_valid = G5NRFlows(os.path.join(data_path, 'valid'), + size=size, scale_factor=scale_factor, + frames=frames) + elif dataset_name.lower() in ['goes', 'goes16']: + # Load Geostationary Observations + dataset_train = L1bPatches(data_path, time_step=10, # minutes + size=size, mode='train') + dataset_valid = L1bPatches(data_path, time_step=10, + size=size, mode='valid') + + return dataset_train, dataset_valid diff --git a/windflow/datasets/flow_transforms.py b/windflow/datasets/flow_transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..ac399377aa4864e72bf7e4439f4e36d72b58db92 --- /dev/null +++ b/windflow/datasets/flow_transforms.py @@ -0,0 +1,626 @@ +# Copyright (c) 2019 Henrique Morimitsu +# This file is released under the "MIT License". Please see the LICENSE file +# included in this package. +# +# Many of the functions are modifications from the FlowNetPytorch package +# available at https://github.com/ClementPinard/FlowNetPytorch, but adapted +# to be implemented with PyTorch functions. 
+ +import numpy as np +import torch +import torch.nn.functional as F +import warnings +from abc import abstractmethod +from typing import Any, Sequence, Tuple + + +""" Common tranformation functions for augmenting data for training deep +optical flow models. The functions most functions take two 4D torch.Tensor +inputs: (1) a batch of images, (2) a batch of flows. The tensors are assumed +to have a shape [N, C, H, W], where: + N = batch size, + C = number of channels (typically 3 for images and 2 for flows), + H = height, + W = width. +""" + + +class BaseTransform(object): + @abstractmethod + def __call__(self, + images: Any, + flows: Any) -> Tuple[Any, Any]: + pass + + +class ToTensor(BaseTransform): + """ Converts a 4D numpy.ndarray or a list of 3D numpy.ndarrays into a + 4D torch.Tensor. Unlike the torchvision implementation, no normalization is + applied to the values of the inputs. + + Args: + images_order: str: optional, default 'HWC' + Must be one of {'CHW', 'HWC'}. Indicates whether the input images + have the channels first or last. + flows_order: str: optional, default 'HWC' + Must be one of {'CHW', 'HWC'}. Indicates whether the input flows + have the channels first or last. + fp16: bool: optional, default False + If True, the tensors use have-precision floating point. + device: str or torch.device: optional, default 'cpu' + Name of the torch device where the tensors will be put in. + """ + def __init__(self, + images_order: str = 'HWC', + flows_order: str = 'HWC', + fp16: bool = False, + device: Any = 'cpu') -> None: + self.images_order = images_order.upper() + self.flows_order = flows_order.upper() + self.fp16 = fp16 + self.device = device + + def __call__(self, + images: Any, + flows: Any) -> Tuple[torch.Tensor, torch.Tensor]: + if isinstance(images, list) or isinstance(images, tuple): + images = np.stack(images, axis=0) + if isinstance(flows, list) or isinstance(flows, tuple): + flows = np.stack(flows, axis=0) + images = torch.from_numpy(images) + flows = torch.from_numpy(flows) + if self.images_order == 'HWC': + images = images.permute(0, 3, 1, 2) + if self.flows_order == 'HWC': + flows = flows.permute(0, 3, 1, 2) + if self.fp16: + images = images.half() + flows = flows.half() + else: + images = images.float() + flows = flows.float() + images = images.to(device=self.device) + flows = flows.to(device=self.device) + return images, flows + + +class Compose(BaseTransform): + """ Similar to torchvision Compose. Applies a series of transforms from + the input list in sequence. + + Args: + transforms_list: Sequence[BaseTransform]: + A sequence of transforms to be applied. + """ + def __init__(self, + transforms_list: Sequence[BaseTransform]) -> None: + self.transforms_list = transforms_list + + def __call__(self, + images: torch.Tensor, + flows: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + for t in self.transforms_list: + images, flows = t(images, flows) + return images, flows + + +class Normalize(BaseTransform): + """ Normalizes the images according to the equation: + images = (images - means) / stdevs. + + Args: + means: Sequence[float]: + A list with the values of the means of each of the channels of the + image. + stdevs: Sequence[float]: + A list with the values of the standard deviations of each of the + channels of the image. 
+ """ + def __init__(self, + means: Sequence[float], + stdevs: Sequence[float]) -> None: + self.means = torch.from_numpy(np.array(means)).reshape(1, -1, 1, 1) + self.stdevs = torch.from_numpy(np.array(stdevs)).reshape(1, -1, 1, 1) + self.has_converted = False + + def __call__(self, + images: torch.Tensor, + flows: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + if not self.has_converted: + self.means = self.means.type(images.dtype).to(images.device) + self.stdevs = self.stdevs.type(images.dtype).to(images.device) + self.has_converted = True + images -= self.means + images /= self.stdevs + return images, flows + + +class Crop(BaseTransform): + """ Crops a region at the given position of the image according to the + given size. + + Args: + coordinates: Tuple[int, int]: + The (y, x) position of the origin of the crop. The coordinates + values depend on the choice of the parameters is_center. If + is_center == True, then coordinates should be the center point of + the region to be cropped. Otherwise, it should be the top-left + corner of the region. + size: Tuple[int, int]: + The size (height, width) of the region to be cropped. + is_center: bool: optional, default True + Determines how the coordinates will be used. See the comments about + coordinates. + """ + def __init__(self, + coordinates: Tuple[int, int], + size: Tuple[int, int], + is_center: bool = True) -> None: + y, x = coordinates + assert y >= 0 and x >= 0, 'The coordinates must be positive.' + h, w = size + assert h > 0 and w > 0, 'The size must be larger than zero.' + self.coordinates = coordinates + self.size = size + self.is_center = is_center + + def __call__(self, + images: torch.Tensor, + flows: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + _, _, h, w = images.shape + ch, cw = self.size + cy, cx = self.coordinates + if self.is_center: + ch2 = ch // 2 + cw2 = cw // 2 + y1 = cy - ch2 + x1 = cx - cw2 + y2 = cy + ch - ch2 + x2 = cx + cw - cw2 + else: + y1 = cy + x1 = cx + y2 = y1 + ch + x2 = x1 + cw + if y1 < 0: + warnings.warn('y1 < 0 when cropping. It will be set as y1 = 0.', + RuntimeWarning) + y1 = 0 + if x1 < 0: + warnings.warn('x1 < 0 when cropping. It will be set as x1 = 0.', + RuntimeWarning) + x1 = 0 + if y2 > h: + warnings.warn( + 'y2 > image_height when cropping. ' + 'It will be set as y2 = image_height.', RuntimeWarning) + y2 = h + if x2 > w: + warnings.warn( + 'x2 > image_width when cropping. ' + 'It will be set as x2 = image_width.', RuntimeWarning) + x2 = w + images = images[:, :, y1:y2, x1:x2] + flows = flows[:, :, y1:y2, x1:x2] + return images, flows + + + +class CenterCrop(BaseTransform): + """ Crops a region in the center of the image according to the given size. + + Args: + size: Tuple[int, int]: + The size (height, width) of the region to be cropped. + """ + def __init__(self, + size: Tuple[int, int]) -> None: + self.size = size + + def __call__(self, + images: torch.Tensor, + flows: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + _, _, h, w = images.shape + th, tw = self.size + if w == tw and h == th: + return images, flows + + x1 = (w - tw) // 2 + y1 = (h - th) // 2 + images = images[:, :, y1:y1+th, x1:x1+tw] + flows = flows[:, :, y1:y1+th, x1:x1+tw] + return images, flows + + +class Resize(BaseTransform): + """ Resizes the inputs to a new given size. + + Args: + size: Tuple[int, int]: + The new size (height, width) of the inputs. 
+ """ + def __init__(self, + size: Tuple[int, int]) -> None: + self.size = size + + def __call__(self, + images: torch.Tensor, + flows: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + yscale = float(self.size[0]) / images.shape[2] + xscale = float(self.size[1]) / images.shape[3] + images = F.interpolate( + images, size=self.size, mode='bilinear', align_corners=False) + flows = F.interpolate( + flows, size=self.size, mode='bilinear', align_corners=False) + flows[:, 0] *= yscale + flows[:, 1] *= xscale + return images, flows + + +class RandomHorizontalFlip(BaseTransform): + """ Randomly flips the entire batch horizontally with a probability of 0.5. + """ + def __call__(self, + images: torch.Tensor, + flows: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + if np.random.rand() < 0.5: + images = torch.flip(images, dims=[3]) + flows = torch.flip(flows, dims=[3]) + flows[:, 0] *= -1 + return images, flows + + +class RandomVerticalFlip(BaseTransform): + """ Randomly flips the entire batch vertically with a probability of 0.5. + """ + def __call__(self, + images: torch.Tensor, + flows: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + if np.random.rand() < 0.5: + images = torch.flip(images, dims=[2]) + flows = torch.flip(flows, dims=[2]) + flows[:, 1] *= -1 + return images, flows + + +class RandomAdditiveColor(BaseTransform): + """ Adds a random value to the image. The value is sampled from a normal + distribution. + + Args: + stdev: float: + The standard deviation value for the normal distribution. + independent: bool: optional, default False + if True, one different value is sampled for each image of the + batch. If False, the same value is added to all images. + """ + def __init__(self, + stdev: float, + independent: bool = False) -> None: + self.stdev = stdev + self.independent = independent + + def __call__(self, + images: torch.Tensor, + flows: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + if self.independent: + add_vals = torch.normal(0.0, self.stdev, [images.shape[0]]) + else: + add_vals = torch.normal(0.0, self.stdev, [1]) + add_vals = add_vals.repeat(images.shape[0]) + add_vals = add_vals.type(images.dtype) + add_vals = add_vals.to(images.device) + add_vals = add_vals.reshape(-1, 1, 1, 1) + images += add_vals + return images, flows + + +class RandomMultiplicativeColor(BaseTransform): + """ Multiplies the image by a random value. The value is sampled from an + uniform distribution according to the given boundaries. + + Args: + range_min: float: + The minimum value of the uniform distribution. + range_max: float: + The maximum value of the uniform distribution. + independent: bool: optional, default False + if True, one different value is sampled for each image of the + batch. If False, all images are multiplied by the same value. 
+ """ + def __init__(self, + range_min: float, + range_max: float, + independent: bool = False) -> None: + self.range_min = range_min + self.range_max = range_max + self.independent = independent + self.interval = range_max - range_min + + def __call__(self, + images: torch.Tensor, + flows: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + if self.independent: + mult_vals = torch.rand([images.shape[0]]) + else: + mult_vals = torch.rand([1]) + mult_vals = mult_vals.repeat(images.shape[0]) + mult_vals = mult_vals.type(images.dtype) + mult_vals = mult_vals.to(images.device) + mult_vals *= self.interval + mult_vals += self.range_min + mult_vals = mult_vals.reshape(-1, 1, 1, 1) + images *= mult_vals + return images, flows + + +class RandomGammaColor(BaseTransform): + """ Applies a power transform to the image by a random value. The value is + sampled from an uniform distribution according to the given boundaries. + + Args: + range_min: float: + The minimum value of the uniform distribution. + range_max: float: + The maximum value of the uniform distribution. + pixel_min: float: + The minimum valid value of a pixel of the image. + pixel_max: float: + The maximum valid value of a pixel of the image. + independent: bool: optional, default False + if True, one different value is sampled for each image of the + batch. If False, all images are multiplied by the same value. + """ + def __init__(self, + range_min: float, + range_max: float, + pixel_min: float, + pixel_max: float, + independent: bool = False) -> None: + self.range_min = range_min + self.range_max = range_max + self.pixel_min = pixel_min + self.pixel_max = pixel_max + self.independent = independent + self.interval = range_max - range_min + self.pixel_interval = pixel_max - pixel_min + + def __call__(self, + images: torch.Tensor, + flows: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + if self.independent: + expo_vals = torch.rand([images.shape[0]]) + else: + expo_vals = torch.rand([1]) + expo_vals = expo_vals.repeat(images.shape[0]) + expo_vals = expo_vals.type(images.dtype) + expo_vals = expo_vals.to(images.device) + expo_vals *= self.interval + expo_vals += self.range_min + expo_vals = torch.pow(expo_vals, -1.0) + expo_vals = expo_vals.reshape(-1, 1, 1, 1) + images = ( + torch.pow((images-self.pixel_min)/self.pixel_interval, expo_vals) * + self.pixel_interval + self.pixel_min) + return images, flows + + +class AddGaussianNoise(BaseTransform): + """ Adds noise to the image by summing each pixel to a value sampled from + a normal distribution. + + Args: + stdev: float: + The standard deviation value for the normal distribution. + """ + def __init__(self, + stdev: float) -> None: + self.stdev = stdev + + def __call__(self, + images: torch.Tensor, + flows: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + noise = torch.normal(0.0, self.stdev, images.shape) + noise = noise.type(images.dtype) + noise = noise.to(images.device) + images += noise + return images, flows + + +class RandomScale(BaseTransform): + """ Rescale the inputs to a new size according to the given boundaries. + + Args: + range_min: float: + The minimum value of the uniform distribution. + range_max: float: + The maximum value of the uniform distribution. 
+ """ + def __init__(self, + range_min: float, + range_max: float) -> None: + self.range_min = range_min + self.range_max = range_max + self.interval = range_max - range_min + + def __call__(self, + images: torch.Tensor, + flows: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + scale = torch.rand([1]) + scale = scale.type(flows.dtype) + scale = scale.to(flows.device) + scale *= self.interval + scale += self.range_min + new_size = (int(scale*images.shape[2]), int(scale*images.shape[3])) + images = F.interpolate( + images, size=new_size, mode='bilinear', align_corners=False) + flows = F.interpolate( + flows, size=new_size, mode='bilinear', align_corners=False) + flows *= scale + return images, flows + + +class RandomCrop(BaseTransform): + """ Crops the inputs at a random location according to a given size. + + Args: + size: Tuple[int, int]: + The size [height, width] of the crop. + """ + def __init__(self, + size: Tuple[int, int]) -> None: + self.size = size + + def __call__(self, + images: torch.Tensor, + flows: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + _, _, h, w = images.shape + th, tw = self.size + if w == tw and h == th: + return images, flows + + x1 = np.random.randint(0, w - tw) + y1 = np.random.randint(0, h - th) + images = images[:, :, y1:y1+th, x1:x1+tw] + flows = flows[:, :, y1:y1+th, x1:x1+tw] + return images, flows + + +class RandomTranslate(BaseTransform): + """ Creates a translation between images by applying a random alternated + crop on the sequence of inputs. A translation value t is randomly selected + first. Then, the first image is cropped by a box translated by t. The + second image will be cropped by a reversed translation -t. The third will + be cropped by t again, and so on... + + Args: + translation: Tuple[int, int] + The maximum absolute translation values (in pixels) in the + vertical and horizontal axis respectively. + """ + def __init__(self, + translation: Tuple[int, int]) -> None: + self.translation = translation + + def __call__(self, + images: torch.Tensor, + flows: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + _, _, h, w = images.shape + th, tw = self.translation + tw = np.random.randint(-tw, tw) + th = np.random.randint(-th, th) + if tw == 0 and th == 0: + return images, flows + + new_images = images[:, :, abs(th):, abs(tw):].clone() + new_flows = flows[:, :, abs(th):, abs(tw):].clone() + + x1, x2 = max(0, tw), min(w+tw, w) + y1, y2 = max(0, th), min(h+th, h) + new_images[::2] = images[::2, :, y1:y2, x1:x2] + new_flows[::2] = flows[::2, :, y1:y2, x1:x2] + new_flows[::2, 0] += tw + new_flows[::2, 1] += th + + x1, x2 = max(0, -tw), min(w-tw, w) + y1, y2 = max(0, -th), min(h-th, h) + new_images[1::2] = images[1::2, :, y1:y2, x1:x2] + new_flows[1::2] = flows[1::2, :, y1:y2, x1:x2] + new_flows[1::2, 0] -= tw + new_flows[1::2, 1] -= th + + return new_images, new_flows + + +class RandomRotate(BaseTransform): + """ Applies random rotation to the inputs. The inputs are rotated around + the center of the image. First all inputs are rotated by the same random + major `angle`. Then, another random angle a is sampled according to `diff_angle`. + The first image will be rotated by a. The second image will be rotated by a + reversed angle -a. The third will be rotated by a again, and so on... + + Args: + angle: float: + The maximum absolute value to sample the major angle from. + diff_angle: float: optional, default 0 + The maximum absolute value to sample the angle difference between + consecutive images. 
+ """ + def __init__(self, + angle: float, + diff_angle: float = 0): + self.angle = angle + self.diff_angle = diff_angle + + def __call__(self, + images: torch.Tensor, + flows: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + major_angle = np.random.uniform(-self.angle, self.angle) + inter_angle = np.random.uniform(-self.diff_angle, self.diff_angle) + + _, _, h, w = images.shape + + def generate_rotation_grid(rot_angle, batch_size): + vy, vx = torch.meshgrid(torch.arange(h), torch.arange(w)) + vx = vx.type(flows.dtype) + vy = vy.type(flows.dtype) + vx = vx.to(flows.device) + vy = vy.to(flows.device) + vx -= (w-1.0)/2.0 + vy -= (h-1.0)/2.0 + angle_rad = rot_angle*2*np.pi / 360 + rotx = np.cos(angle_rad)*vx - np.sin(angle_rad)*vy + roty = np.sin(angle_rad)*vx + np.cos(angle_rad)*vy + rotx = rotx / ((w-1)/2) + roty = roty / ((h-1)/2) + rot_grid = torch.stack((rotx, roty), dim=2)[None] + rot_grid = rot_grid.repeat(batch_size, 1, 1, 1) + return rot_grid + + def generate_rotation_matrix(rot_angle, batch_size): + vx, vy = torch.meshgrid(torch.arange(h), torch.arange(w)) + vx = vx.type(flows.dtype) + vy = vy.type(flows.dtype) + vx = vx.to(flows.device) + vy = vy.to(flows.device) + rotx = (vx - h/2.0) * (rot_angle*np.pi/180.0) + roty = -(vy - w/2.0) * (rot_angle*np.pi/180.0) + rot_mat = torch.stack((rotx, roty), dim=0)[None] + rot_mat = rot_mat.repeat(batch_size, 1, 1, 1) + return rot_mat + + def rotate_flow(flow, rot_angle): + angle_rad = rot_angle*2*np.pi / 360 + rot_flow = flow.clone() + rot_flow[:, 0] = ( + np.cos(angle_rad)*flow[:, 0] + np.sin(angle_rad)*flow[:, 1]) + rot_flow[:, 1] = ( + -np.sin(angle_rad)*flow[:, 0] + np.cos(angle_rad)*flow[:, 1]) + return rot_flow + + angle = major_angle - inter_angle / 2 + num_images = images[::2].shape[0] + rot_grid = generate_rotation_grid(angle, num_images) + mn = images.min() + mx = images.max() + + images[::2] = F.grid_sample(images[::2], rot_grid, mode='bilinear', align_corners=True) + + num_flows = flows[::2].shape[0] + rot_mat = generate_rotation_matrix(inter_angle, num_flows) + flows[::2] += rot_mat + flows[::2] = F.grid_sample( + flows[::2], rot_grid[:num_flows], mode='bilinear', align_corners=True) + flows[::2] = rotate_flow(flows[::2], angle) + + angle = major_angle + inter_angle / 2 + num_images = images[1::2].shape[0] + rot_grid = generate_rotation_grid(angle, num_images) + images[1::2] = F.grid_sample(images[1::2], rot_grid, mode='bilinear', align_corners=True) + num_flows = flows[1::2].shape[0] + flows[1::2] -= rot_mat + flows[1::2] = F.grid_sample( + flows[1::2], rot_grid[:num_flows], mode='bilinear', align_corners=True) + flows[1::2] = rotate_flow(flows[1::2], angle) + + return images, flows diff --git a/windflow/datasets/g5nr/__init__.py b/windflow/datasets/g5nr/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b50680b15f72599a8e7f555eb1375539f3ccb343 --- /dev/null +++ b/windflow/datasets/g5nr/__init__.py @@ -0,0 +1 @@ +from .g5nr_loader import G5NRFlows diff --git a/windflow/datasets/g5nr/g5nr.py b/windflow/datasets/g5nr/g5nr.py new file mode 100644 index 0000000000000000000000000000000000000000..a3e05461f1eed5e7efb290b8ed733efe0e0abe0e --- /dev/null +++ b/windflow/datasets/g5nr/g5nr.py @@ -0,0 +1,91 @@ +''' +Author: TJ Vandal + +GMAO G5NR files from Will McCarty in June 2021 +Script to generating training data for optical flow. 
+''' + +import os +import glob +import xarray as xr +import numpy as np +import pandas as pd + +from mpi4py import MPI + +comm = MPI.COMM_WORLD +MPI_RANK = comm.Get_rank() +MPI_SIZE = comm.Get_size() + +def get_files(folder, mode): + pattern = os.path.join(folder, '*.nc4') + files = sorted(glob.glob(pattern)) + + variables = ['QV', 'U', 'V'] + var_files = {v: sorted([f for f in files if f'_{v}_' in f]) for v in variables} + var_files = pd.DataFrame(var_files) + + N_files = len(var_files) + if mode == 'train': + var_files = var_files.iloc[0:int(N_files*0.7)] + elif mode == 'valid': + var_files = var_files.iloc[int(N_files*0.7):int(N_files*0.8)] + elif mode == 'test': + var_files = var_files.iloc[int(N_files*0.8):] + + return var_files + +def make_training_data(input_folder, output_folder, mode='train', patch_size=560, step=300): + files = get_files(data_folder, mode) + + for i in np.arange(MPI_RANK, len(files), MPI_SIZE): + data = xr.open_mfdataset(files.values[i]) + data = data.sel(lat=slice(-80, 80)) + n_lats = data.lat.shape[0] + n_lons = data.lon.shape[0] + n_lev = data.lev.shape[0] + + time = data.time.values[0] + timestamp = (time.astype('uint64') / 1e6).astype('uint32') + + for lev in data.lev.values: + for lat_idx in np.arange(0, n_lats-patch_size, step): + for lon_idx in np.arange(0, n_lons-patch_size, step): + + sub_ds = data.isel(lat=slice(lat_idx, lat_idx+patch_size), + lon=slice(lon_idx, lon_idx+patch_size))\ + .sel(lev=lev) + lat_0 = sub_ds.lat.values[0] + lon_0 = sub_ds.lon.values[0] + sub_folder = os.path.join(output_folder, f'{int(lev)}_{lat_0}_{lon_0}') + sub_file = os.path.join(sub_folder, str(timestamp).zfill(11) + '.nc4') + if os.path.exists(sub_file): + continue + + try: + if not os.path.exists(sub_folder): + os.makedirs(sub_folder) + except IOError: + pass + + sub_ds.to_netcdf(sub_file) + print(f"Rank: {MPI_RANK}, Saved patch to file {sub_file}") + +if __name__ == '__main__': + + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument('data', metavar='G5NR_7km/', type=str, + help='Where to save your data?') + parser.add_argument('output', metavar='G5NR_7km_training/', type=str, + help='Where to save your data?') + args = parser.parse_args() + + + data_folder = args.data + training_data_folder = args.output + make_training_data(data_folder, training_data_folder+'train', mode='train') + make_training_data(data_folder, training_data_folder+'valid', mode='valid') + + diff --git a/windflow/datasets/g5nr/g5nr_loader.py b/windflow/datasets/g5nr/g5nr_loader.py new file mode 100644 index 0000000000000000000000000000000000000000..e2058656e0d9b73988a3293b07b457269c5b137a --- /dev/null +++ b/windflow/datasets/g5nr/g5nr_loader.py @@ -0,0 +1,109 @@ +import os +import glob +import numpy as np +import xarray as xr + +import torch +from torch.utils import data +import cv2 + +from .. 
import flow_transforms +from ..preprocess import image_histogram_equalization + +class TileFlows(data.Dataset): + def __init__(self, directory, scale_factor=None, size=None, + augment=False, frames=2, convert_cartesian=False): + ''' + 7-km, 30-min simulation + Units: m/s + Scale to cartesian coordinates: 1800 (seconds) / 7000 (m) + ''' + self.directory = directory + self.files = np.array(glob.glob(os.path.join(self.directory, '*.nc4'))) + + filenames = np.array([int(os.path.basename(f.replace('.nc4',''))) for f in self.files]) + filesort = np.argsort(filenames) + self.files = self.files[filesort] + + self.size = size + self.scale_factor = scale_factor + self.augment = augment + self.frames = frames + self.convert_cartesian = convert_cartesian + + self.rand_transform = flow_transforms.Compose([ + flow_transforms.ToTensor(images_order='CHW', flows_order='CHW'), # Must come before other transforms + #flow_transforms.RandomRotate(15), # rotate fills in a lot of the image with zeros + flow_transforms.RandomCrop((size, size)), + flow_transforms.RandomHorizontalFlip(), + flow_transforms.RandomVerticalFlip(), + #flow_transforms.RandomAdditiveColor(0.01), + #flow_transforms.RandomMultiplicativeColor(0.9, 1.1) + ]) + + def __len__(self): + return len(self.files)-self.frames+1 + + def __getitem__(self, idx): + try: + ds = xr.open_mfdataset(self.files[idx:idx+self.frames]) + h, w = ds.U.isel(time=0).shape + diff = ds.time.values[1] - ds.time.values[0] + if diff != np.timedelta64(30,'m'): + #print("Difference in time not 30 minutes", ds.time) + return self.__getitem__((idx+1) % len(self)) + + if self.scale_factor: + h, w = int(h * self.scale_factor), int(w * self.scale_factor) + new_lats = np.linspace(ds.lat.values[0], ds.lat.values[-1], h) + new_lons = np.linspace(ds.lon.values[0], ds.lon.values[-1], w) + ds = ds.interp(lat=new_lats, lon=new_lons) + + qv = ds['QV'].values + u = ds['U'].values + v = ds['V'].values + except (KeyError, xr.MergeError) as err: + print("Error in G5NR Loader __getitem__", err) + return self.__getitem__((idx+self.frames) % len(self)) + except (AttributeError) as err: + print(f"Error: {err}") + print(f"Files", self.files[idx:idx+self.frames]) + return self.__getitem__((idx+self.frames) % len(self)) + + uv = np.concatenate([u[:,np.newaxis], v[:,np.newaxis]], 1) #/ 7000. * (30 * 60) + uv[~np.isfinite(uv)] = 0. + qv[~np.isfinite(qv)] = 0. 
+ + # pixel size differs by latitude + # see haversine formula -- compute in lon direction, lats are constant + + if self.convert_cartesian: + #uv = uv * CARTESIAN_SCALE + + lat_rad = np.radians(ds.lat.values) + lon_rad = np.radians(ds.lon.values) + a = np.cos(lat_rad)**2 * np.sin((lon_rad[1]-lon_rad[0])/2)**2 + d = 2 * 6378.137 * np.arcsin(a**0.5) + size_per_pixel = np.repeat(np.expand_dims(d, -1), len(lon_rad), axis=1) # kms + + uv = uv / size_per_pixel / 1000 * 1800 + + qv = image_histogram_equalization(qv) + + images = [q[np.newaxis] for q in qv] + flows = [_uv for _uv in uv] + images_tensor, flow_tensor = self.rand_transform(images, flows) + return images_tensor, flow_tensor + + +class G5NRFlows(data.ConcatDataset): + def __init__(self, directory, scale_factor=None, size=None, augment=False, frames=2, convert_cartesian=True): + self.directory = directory + tiles = os.listdir(self.directory) + tile_paths = [os.path.join(self.directory, t) for t in tiles] + data.ConcatDataset.__init__(self, [TileFlows(p, scale_factor=scale_factor, size=size, augment=augment, frames=frames, convert_cartesian=convert_cartesian) for p in tile_paths]) + + +if __name__ == '__main__': + dataset = G5NRFlows('/nobackupp10/tvandal/nex-ai-opticalflow/data/G5NR_7km_training/train/', size=128) + print(dataset[0][0].shape) diff --git a/windflow/datasets/geostationary/__init__.py b/windflow/datasets/geostationary/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..365315906f3a049e9561971fbddba33c460c474a --- /dev/null +++ b/windflow/datasets/geostationary/__init__.py @@ -0,0 +1,2 @@ +#from .schema import BandSchema1000, BandSchema256 +from .pytorch_dataloader import L1bResized, L1bPatches diff --git a/windflow/datasets/geostationary/geonexl1g.py b/windflow/datasets/geostationary/geonexl1g.py new file mode 100755 index 0000000000000000000000000000000000000000..9e2745ef2b4ae98fdba14d3aaa210cad4a18effb --- /dev/null +++ b/windflow/datasets/geostationary/geonexl1g.py @@ -0,0 +1,296 @@ +import os, sys +import numpy as np +from pyhdf.SD import SD, SDC +from scipy import ndimage +import glob +import pandas as pd +import xarray as xr +from joblib import Parallel, delayed + + +''' +# Basic parameters +lat_0 = 60 +lon_0 = -180 +res_x = 0.01 # 0.02 for the 2km grid +res_y = 0.01 # 0.02 for the 2km grid +tile_xdim = 600 # 300 for the 2km grid +tile_ydim = 600 # 300 for the 2km grid + +# Input information +hid = 14 # 0 - 59 +vid = 5 # 0 - 19 +x = 0 # column/sample, 0-(tile_xdim-1) +y = 0 # row/line, 0-(tile_ydim-1) + +# Output formula +lat_ulcnr = lat_0 - (vid*tile_ydim + y)*res_y # upper-left corner latitude +lon_ulcnr = lon_0 + (hid*tile_xdim + x)*res_y # upper-left corner longitude +lat_cntr = lat_ulcnr - 0.5*res_y # center latitude +lon_cntr = lon_ulcnr + 0.5*res_x # center longitude +''' + +def get_tile_coords(tile, resolution_km=1.): + hid = int(tile[1:3]) + vid = int(tile[4:6]) + + lat_0 = 60 + lon_0 = -180 + res_x = 0.01 * resolution_km + res_y = 0.01 * resolution_km + tile_xdim = int(600 / resolution_km) + tile_ydim = int(600 / resolution_km) + + lat_ulcnr = lat_0 - vid*tile_ydim*res_y # upper-left corner latitude + lon_ulcnr = lon_0 + hid*tile_xdim*res_y # upper-left corner longitude + lat_cntr = lat_ulcnr - 0.5*res_y # center latitude + lon_cntr = lon_ulcnr + 0.5*res_x + lats = np.linspace(lat_0 - vid*tile_ydim*res_y, lat_0 - (vid+1)*tile_ydim*res_y+res_y, tile_ydim, endpoint=True) + lons = np.linspace(lon_0 + hid*tile_xdim*res_x, lon_0 + (hid+1)*tile_xdim*res_x-res_x, tile_xdim, 
endpoint=True) + return lats, lons + +class L1GFile(object): + ''' + Reads a single L1B file at a common resolution. Channels are bilinearly interpolated to the defined resolution. + Args: + file: Filepath to L1b + bands (optional): List of bands, default=list(range(1,17)) + resolution_km (optional): Resolution in km for common grid, default=2 + ''' + def __init__(self, file, bands=list((range(1,17))), + resolution_km=2.): + self.file = file + self.bands = bands + self.resolution_km = resolution_km + self.resolution_size = int(600. / resolution_km) + self.reflective_bands = list(range(1,7)) + self.emissive_bands = list(range(7,17)) + + def load(self): + fp = SD(self.file, SDC.READ) + data_array = np.zeros((self.resolution_size, self.resolution_size, len(self.bands))) + for i, b in enumerate(self.bands): + b_obj = fp.select('BAND%02i' % b) + attrs = b_obj.attributes() + scale_factor = attrs['Scale_Factor'] + #fill_value = attrs['_FillValue'] ## this fill value seems to be wrong in l1g + fill_value = 32768. + arr = b_obj.get()[:].astype(np.float32) + offset = attrs['Offset_Constant'] + arr[arr == fill_value] = np.nan + arr *= scale_factor + arr += offset + if arr.shape[0] != self.resolution_size: + arr = ndimage.interpolation.zoom(arr, self.resolution_size/arr.shape[0], order=1) + data_array[:,:,i] = arr + #self.data_array = data_array + return data_array + + def solar(self): + fp = SD(self.file, SDC.READ) + sa = fp.select('Solar_Azimuth').get()[:] + sz = fp.select('Solar_Zenith').get()[:] + sa = ndimage.interpolation.zoom(sa, self.resolution_size/sa.shape[0], order=1) + sz = ndimage.interpolation.zoom(sz, self.resolution_size/sz.shape[0], order=1) + return sa*0.01, sz*0.01 + + def load_xarray(self): + data = self.load() + fp = SD(self.file, SDC.READ) + tile_path = os.path.join("/".join(self.file.split("/")[:-3]), 'constant') + tile_file = glob.glob(tile_path + '/*.hdf')[0] + sat, sensor, date, time, _, tile, _ = os.path.basename(self.file).replace(".hdf","").split("_") + lats, lons = get_tile_coords(tile, self.resolution_km) + sa, sz = self.solar() + #print(tile, 'lats', len(lats)) + + da = xr.DataArray(data, dims=('lat', 'lon', 'band'), coords=dict(lat=lats, lon=lons, band=self.bands)) + zenith = xr.DataArray(sz, dims=('lat', 'lon'), coords=dict(lat=lats, lon=lons)) + azimuth = xr.DataArray(sa, dims=('lat', 'lon'), coords=dict(lat=lats, lon=lons)) + return xr.Dataset({'data': da, 'zenith': zenith, 'azimuth': azimuth}) + +class GeoNEXL1G(object): + ''' + Get information on L1G data directory, available tiles, years, and files + file lists are locally cached to future reading as retrieving file lists + can be time consuming. 
+ Args: + data_directory: directory of the L1G product + sensor: (G16,G17,H8) + /nex/datapool/geonex/public/ + ''' + def __init__(self, data_directory, sensor): + self.data_directory = data_directory + self.sensor = sensor + self.sat = os.path.basename(os.path.dirname(os.path.dirname(data_directory))) + + def tiles(self): + tile_pattern = os.path.join(self.data_directory, 'h*v*') + tile_folders = glob.glob(tile_pattern) + tiles = [os.path.basename(t) for t in tile_folders] + return tiles + + def years(self): + tile = self.tiles()[0] + years = os.listdir(os.path.join(self.data_directory, tile)) + years = [int(y) for y in years if y[0] == '2'] + return years + + def hours(self): + return list(range(0,24)) + + def files(self, tile=None, year=None, dayofyear=None, cachedir='.tmp'): + ''' + Args: + tile (optional): Tile from GeoNEX grid + year (optional): Year of files to get + dayofyear (optional): Day of year + cachedir (optional): Cache filelist in directory + Returns: + pd.DataFrame of filelist with year, dayofyear, hour, minute, tile, file, h, and v + ''' + if tile == None: + tile = '*' + if year == None: + year = '*' + else: + year = str(year) + if dayofyear == None: + dayofyear = '*' + else: + dayofyear = '%03i' % dayofyear + + cache_file = f'{cachedir}/filelist/{self.sat}_{self.sensor}_{tile}_{year}_{dayofyear}.pkl' + if os.path.exists(cache_file): + return pd.read_pickle(cache_file) + + + file_pattern = os.path.join(self.data_directory, '%s/%s/%s/*.hdf' % (tile, year, dayofyear)) + files = glob.glob(file_pattern) + fileinfo = [] + for f in files: + root = os.path.dirname(f) + rl = root.split('/') + doy = int(rl[-1]) + y = int(rl[-2]) + + fl = os.path.basename(f).split('_') + hour = int(fl[3][:2]) + minute = int(fl[3][2:]) + tile = fl[5] + h = int(tile[1:3]) + v = int(tile[4:6]) + + fileinfo.append(dict(year=y, dayofyear=doy, hour=hour, + minute=minute, file=f, tile=tile, h=h, v=v)) + fileinfo = pd.DataFrame(fileinfo) + if not os.path.exists(os.path.dirname(cache_file)): + os.makedirs(os.path.dirname(cache_file)) + fileinfo.to_pickle(cache_file) + return fileinfo + +class L1GPaired(object): + def __init__(self, data_path1, data_path2, sensor1, sensor2): + ''' + Retrieve lists of files that match temporally and spatially using two GeoNEXL1b objects + + Args: + data_path1: First L1G data directory + data_path2: Second L1G data directory + sensor1: First sensor + sensor2: Second sensor + ''' + self.data_path1 = data_path1 + self.data_path2 = data_path2 + + self.sensor1 = sensor1 + self.sensor2 = sensor2 + + self.data1 = GeoNEXL1G(data_path1, sensor1) + self.data2 = GeoNEXL1G(data_path2, sensor2) + + def tiles(self): + tiles1 = self.data1.tiles() + tiles2 = self.data2.tiles() + intersect = np.intersect1d(tiles1, tiles2) + return intersect.tolist() + + def day_pairs(self, year): + tile = self.tiles()[0] + path1 = os.path.join(self.data_path1, tile, str(year)) + return [int(day) for day in os.listdir(path1)] + + def files(self, tile=None, year=None, dayofyear=None, how='inner', cachedir='.tmp'): + ''' + Get filelists for each data directory and join in space-time + + Args: + tile (optional): Tile from GeoNEX grid + year (optional): Year of files to get + dayofyear (optional): Day of year + how (optional): Pandas method to join DataFrames, default='inner' + cachedir (optional): Cache filelist in directory + Returns: + pd.DataFrame of filelist with year, dayofyear, hour, minute, tile, file, h, and v + ''' + files1 = self.data1.files(tile=tile, year=year, + dayofyear=dayofyear, 
cachedir=cachedir) + files2 = self.data2.files(tile=tile, year=year, + dayofyear=dayofyear, cachedir=cachedir) + get_timestamp = lambda x: '_'.join(os.path.basename(x).split('_')[2:4]) + + files1['timestamp'] = files1['file'].apply(get_timestamp) + files1 = files1.set_index(['timestamp', 'tile']) + + files2['timestamp'] = files2['file'].apply(get_timestamp) + files2 = files2.set_index(['timestamp', 'tile']) + joined = files1.join(files2, lsuffix='1', rsuffix='2', how=how) + return joined + + +def _read_file_to_xarray(f, bands, resolution_km): + return L1GFile(f, resolution_km=resolution_km, bands=bands).load_xarray() + + +class MOSAIC(object): + def __init__(self, data_path, sensor, h_list=None, v_list=None): + self.data_path = data_path + self.sensor = sensor + self.geo = GeoNEXL1G(self.data_path, self.sensor) + self.h_list = h_list + self.v_list = v_list + + def read_data(self, year, dayofyear, hour, minute, + bands=list(range(1,17)), resolution_km=2.): + files = self.geo.files(year=year, dayofyear=dayofyear) + files = files[files.hour == hour] + files = files[files.minute == minute].reset_index() + piv = files.pivot('v', 'h') + files = piv['file'] + + if self.geo.sensor == 'H8': + files = pd.concat([files.loc[:,44:59], files.loc[:,0:3]], 1) + elif self.geo.sensor == 'G17': + files = pd.concat([files.loc[:,57:59], files.loc[:,0:16]], 1) + elif self.geo.sensor == 'G16': + files = files.loc[:,7:26] + + #files = files.iloc[5:15,5:15] + all_data = [] + if self.h_list is None: + h_list = files.columns.values + if self.v_list is None: + v_list = files.index.values + + for v in self.v_list: + jobs = [] # joblib does not seem to add performance + for h in self.h_list: + f = files.loc[v, h] + if isinstance(f, str): + jobs.append(_read_file_to_xarray(f, bands, resolution_km)) + #jobs.append(delayed(_read_file_to_xarray)(f, bands)) + all_data.append(jobs) + #all_data.append(Parallel(n_jobs=4)(jobs)) + + ds = xr.concat([xr.concat(d, 'lon') for d in all_data], 'lat') + return ds \ No newline at end of file diff --git a/windflow/datasets/geostationary/goesr.py b/windflow/datasets/geostationary/goesr.py new file mode 100644 index 0000000000000000000000000000000000000000..0972570827985d173f24c071eb6bde75bc6de0e6 --- /dev/null +++ b/windflow/datasets/geostationary/goesr.py @@ -0,0 +1,346 @@ +import os, sys +import glob +import xarray as xr +import pandas as pd +import numpy as np +import dask as da +import datetime as dt +import matplotlib.pyplot as plt +#from mpl_toolkits.basemap import Basemap +import pyproj +import pyresample +from .. 
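The directory helpers compose as in the sketch below; the data directories and tiles are hypothetical, and the classes are used as defined above:

# Sketch only: directories and tiles are hypothetical.
g16 = GeoNEXL1G('/nex/datapool/geonex/public/GOES16/GEONEX-L1G/', 'G16')
print(g16.tiles()[:5], g16.years())
index = g16.files(tile='h14v03', year=2019, dayofyear=150)    # cached pandas DataFrame

paired = L1GPaired('/path/to/GOES16-L1G/', '/path/to/GOES17-L1G/', 'G16', 'G17')
matched = paired.files(tile='h10v03', year=2019, dayofyear=150, how='inner')
print(matched[['file1', 'file2']].head())    # temporally and spatially matched granules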
import utils + + +def get_filename_metadata(f): + channel = int(f.split('_')[1][-2:]) + spatial = f.split('-')[2] + t1 = f.split('_')[3] + year = int(t1[1:5]) + dayofyear = int(t1[5:8]) + hour = int(t1[8:10]) + minute = int(t1[10:12]) + second = int(t1[12:15]) + return dict(channel=channel, year=year, dayofyear=dayofyear, hour=hour, + minute=minute, second=second, spatial=spatial) + +class L1bBand(object): + def __init__(self, fpath): + self.fpath = fpath + meta = get_filename_metadata(self.fpath) + self.band = meta['channel'] + self.year = meta['year'] + self.dayofyear = meta['dayofyear'] + self.hour = meta['hour'] + self.minute = meta['minute'] + self.second = meta['second'] + self.spatial = meta['spatial'] + self.datetime = dt.datetime(self.year, 1, 1, self.hour, self.minute, self.second//10) +\ + dt.timedelta(days=self.dayofyear-1) + + def open_dataset(self, rescale=True, force=False, chunks=None): + if (not hasattr(self, 'data')) or force: + ds = xr.open_dataset(self.fpath, chunks=chunks) + ds = ds.where(ds.DQF.isin([0, 1])) + band = ds.band_id[0] + # normalize radiance + if rescale: + radiance = ds['Rad'] + if band <= 6: + ds['Rad'] = ds['Rad'] * ds.kappa0 + else: + fk1 = ds.planck_fk1.values + fk2 = ds.planck_fk2.values + bc1 = ds.planck_bc1.values + bc2 = ds.planck_bc2.values + tmp = fk1 / ds["Rad"] + 1 + tmp = np.where(tmp > 0, tmp, 1) + T = (fk2/(np.log(tmp))-bc1)/bc2 + radiance.values = T + ds['Rad'] = radiance + self.data = ds + return self.data + + def plot(self, ax=None, cmap=None, norm=None): + if not hasattr(self, 'data'): + self.open_dataset() + + # Satellite height + sat_h = self.data['goes_imager_projection'].perspective_point_height + # Satellite longitude + sat_lon = self.data['goes_imager_projection'].longitude_of_projection_origin + + # The geostationary projection + x = self.data['x'].values * sat_h + y = self.data['y'].values * sat_h + if ax is None: + fig = plt.figure(figsize=[10,10]) + ax = fig.add_subplot(111) + + m = Basemap(projection='geos', lon_0=sat_lon, resolution='i', + rsphere=(6378137.00,6356752.3142), + llcrnrx=x.min(),llcrnry=y.min(), + urcrnrx=x.max(),urcrnry=y.max(), + ax=ax) + m.drawcoastlines() + m.drawcountries() + m.drawstates() + ax.set_title('GOES-16 -- Band {}'.format(self.band), fontweight='semibold', loc='left') + ax.set_title('%s' % self.datetime.strftime('%H:%M UTC %d %B %Y'), loc='right') + return m.imshow(self.data['Rad'].values[::-1], cmap=cmap, norm=norm) + #return m + + def plot_infrared(self, ax=None, colortable='ir_drgb_r', colorbar=False, cmin=180, cmax=270): + from metpy.plots import colortables + ir_norm, ir_cmap = colortables.get_with_range('WVCIMSS', cmin, cmax) + # Use a colortable/colormap available from MetPy + # ir_norm, ir_cmap = colortables.get_with_range(colortable, 190, 350) + im = self.plot(ax, cmap=ir_cmap, norm=ir_norm) + if colorbar: + plt.colorbar(im, pad=0.01, aspect=50, ax=ax, shrink=0.85, + ticks=range(cmin,cmax,10), label='Temperature (K)') + return im + + def interp(self, scale): + if not hasattr(self, 'data'): + self.open_dataset() + + new_x = np.linspace(self.data.x.values[0], self.data.x.values[-1], + int(self.data.x.values.shape[0] * scale)) + new_y = np.linspace(self.data.y.values[0], self.data.y.values[-1], + int(self.data.y.values.shape[0] * scale)) + self.data = self.data.interp(x=new_x, y=new_y) + + def latlon(self): + if not hasattr(self, 'lats'): + if not hasattr(self, 'data'): + self.open_dataset() + # Satellite height + sat_h = self.data['goes_imager_projection'].perspective_point_height + # 
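For reference, the rescaling in open_dataset follows the standard GOES-R L1b conversions: reflective bands (1-6) are multiplied by kappa0 to get reflectance factors, and emissive bands are converted from radiance to brightness temperature with the Planck coefficients stored in the file. A standalone sketch of the same arithmetic (the netCDF path is hypothetical):

# Sketch: L1b radiance rescaling outside the class; the file path is hypothetical.
import numpy as np
import xarray as xr

ds = xr.open_dataset('/path/to/OR_ABI-L1b-RadF-M6C13_G16_s20191501200000.nc')
rad = ds['Rad']
band = int(ds.band_id[0])
if band <= 6:
    out = rad * ds.kappa0                                # radiance -> reflectance factor
else:
    fk1, fk2 = ds.planck_fk1, ds.planck_fk2
    bc1, bc2 = ds.planck_bc1, ds.planck_bc2
    out = (fk2 / np.log(fk1 / rad + 1.0) - bc1) / bc2    # radiance -> brightness temperature (K)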
Satellite longitude + sat_lon = self.data['goes_imager_projection'].longitude_of_projection_origin + sat_sweep= self.data['goes_imager_projection'].sweep_angle_axis + p = pyproj.Proj(proj='geos', h=sat_h, lon_0=sat_lon, sweep=sat_sweep) + X = self.data['x'].values * sat_h + Y = self.data['y'].values * sat_h + XX, YY = np.meshgrid(X, Y) + lons, lats = p(XX, YY, inverse=True) + self.lats = lats + self.lons = lons + NANs = np.isnan(self.data['Rad'].values) + self.lats[NANs] = np.nan + self.lons[NANs] = np.nan + self.lats[~np.isfinite(self.lats)] = np.nan + self.lons[~np.isfinite(self.lons)] = np.nan + return self.lats, self.lons + + def latlon_lookup(self, lat, lon): + self.latlon() + if (lat > self.lats.min()) and (lat < self.lats.max()) and (lon > self.lons.min()) and (lon < self.lons.max()): + dist = ((self.lats - lat)**2 + (self.lons - lon)**2)**0.5 + ix, iy = np.unravel_index(np.argmin(dist, axis=None), dist.shape) + return ix, iy + + + def reproject_to_latlon(self): + data = self.open_dataset() + lats, lons = self.latlon() + sat_lon = data['goes_imager_projection'].longitude_of_projection_origin + + if self.band in [1,3,4,5,6]: + s = 0.01 + elif self.band == 2: + s = 0.005 + else: + s = 0.02 + + lat_min = np.around(max(np.nanmin(lats), -60), 3) + lat_max = np.around(min(np.nanmax(lats), 60), 3) + lon_min = np.around(max(np.nanmin(lons), sat_lon-60), 3) + lon_max = np.around(min(np.nanmax(lons), sat_lon+60), 3) + lats_new = np.arange(lat_min, lat_max, s) + lons_new = np.arange(lon_min, lon_max, s) + lons_new, lats_new = np.meshgrid(lons_new, lats_new) + + source_def = pyresample.geometry.SwathDefinition(lats=lats, lons=lons) + target_def = pyresample.geometry.GridDefinition(lons=lons_new, lats=lats_new) + result = pyresample.kd_tree.resample_nearest(source_def, + data['Rad'].values, + target_def, + radius_of_influence=50000, + epsilon=0.05) + data_new = xr.DataArray(result, coords=dict(lat=lats_new[:,0], lon=lons_new[0,:]), dims=('lat', 'lon')) + return xr.Dataset(dict(Rad=data_new)) + + +class GroupBandTemporal(object): + def __init__(self, data_list): + self.data = data_list + self.open_dataset() + + def open_dataset(self): + for b in self.data: + if not hasattr(b, 'data'): + b.open_dataset() + + def get_radiances(self, indices=None): + self.open_dataset() + xequal = np.array_equal(self.data[0].data.x.values, + self.data[-1].data.x.values) + yequal = np.array_equal(self.data[0].data.y.values, + self.data[-1].data.y.values) + if (not xequal) or (not yequal): + return + + if indices is None: + indices = range(len(self.data)) + data = xr.concat([self.data[i].data['Rad'] for i in indices], 'time') + return data + + def get_radiance_patches(self, patch_size): + data = self.get_radiances() + if data is None: + return + + data = utils.block_array(data, 2, + size=patch_size, + stride=patch_size) + data = utils.block_array(data, 2, + size=patch_size, + stride=patch_size) + data= data.reshape(-1, len(self.data), + patch_size, + patch_size) + return data + + def add(self, band): + self.data.append(band) + + def __len__(self): + return len(self.data) + + def timeseries(self, ix, iy): + self.open_dataset() + data = np.array([b.data['Rad'].isel(x=ix, y=iy).values for b in self.data]) + x = [b.datetime for b in self.data] + return x, data + + def timeseries_latlon(self, lat, lon): + print("Lat: {}, Lon: {}".format(lat, lon)) + self.open_dataset() + indices = [b.latlon_lookup(lat, lon) for b in self.data] + data = [] + x = [] + for i, b in enumerate(self.data): + if indices[i] is None: + print("Data 
bounds: ({}, {}), ({}, {})".format(b.lats.min(), b.lats.max(), b.lons.min(), + b.lons.max())) + continue + data.append(b.data['Rad'].isel(x=indices[i][0], y=indices[i][1]).values) + x.append(b.datetime) + data = np.array(data) + x = np.array(x) + return x, data + +class GOESL1b(object): + def __init__(self, product='ABI-L1b-RadF', + channels=range(1,17), + data_directory='/nex/datapool/geonex/public/GOES16/NOAA-L1B/'): + self.product = product + self.channels = channels + self.data_directory = os.path.join(data_directory, product) + + def local_files(self, year=None, dayofyear=None, hour=None, spatial=None): + data = [] + base_dir = self.data_directory + + if year is None: + year = "*" + else: + year = '%04i' % year + if dayofyear is None: + dayofyear = "*" + else: + dayofyear = '%03i' % dayofyear + if hour is None: + hour = "*" + else: + hour = '%02i' % hour + + for c in self.channels: + path_pattern = os.path.join(self.data_directory, year, dayofyear, hour, + 'OR_ABI-L1b-*C%02i_*.nc' % c) + channel_files = glob.glob(path_pattern) + for f in channel_files: + + meta = get_filename_metadata(os.path.basename(f)) + meta['file'] = f + data.append(meta) + + data = pd.DataFrame(data) + if (len(data) > 0) and (spatial is not None): + data = data[data['spatial'] == spatial] + + if len(data) > 0: + new_index = ['year', 'dayofyear', 'hour', 'minute', 'second', 'spatial'] + data = data.set_index(new_index) + data = data.pivot(columns='channel').drop_duplicates() + + return data + + def snapshot_file(self, year, dayofyear, hour, minute, spatial=None): + t = dt.datetime(year, 1, 1, hour, minute) + dt.timedelta(days=dayofyear-1) + t1 = t + dt.timedelta(hours=1) + curr_time = t + files = [] + while curr_time <= t1: + files.append(self.local_files(curr_time.year, curr_time.timetuple().tm_yday, + curr_time.hour, spatial=spatial)) + curr_time = curr_time + dt.timedelta(hours=1) + files = pd.concat(files) + + if spatial is None: + spatial = files.index.get_level_values('spatial')[0] + + files = files.reset_index() + + def to_dt(row): + tt = dt.datetime(row['year'], 1, 1, row['hour'], row['minute'], row['second']/10.) 
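A sketch of indexing the local archive with GOESL1b and opening a single band; the data directory is hypothetical and the classes are used as defined in this module:

# Sketch only: data_directory is hypothetical.
l1b = GOESL1b(product='ABI-L1b-RadF', channels=[13],
              data_directory='/path/to/GOES16/NOAA-L1B/')
index = l1b.local_files(year=2019, dayofyear=150, hour=12)   # indexed by time and spatial coverage
fpath = index['file'][13].iloc[0]                            # first channel-13 granule
band = L1bBand(fpath)
ds = band.open_dataset()        # DQF-masked, rescaled to reflectance or brightness temperature
lats, lons = band.latlon()      # geodetic coordinates of every pixel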
+ dt.timedelta(days=row['dayofyear'].values[0]-1) + return tt + + times = files.apply(to_dt, axis=1) + + idx = np.argmin(np.abs(times-t)) + return files.iloc[idx]['file'] + + def snapshot_files(self, year, dayofyear, hour, minute, spatial=None): + hour_files = self.local_files(year, dayofyear, hour, spatial=spatial) + + minutes = hour_files.index.get_level_values('minute') + if spatial is None: + spatial = hour_files.index.get_level_values('spatial')[0] + + idx = np.argmin(np.abs(minutes-minute)) + + minute_sel = minutes[idx] + return hour_files.loc[year, dayofyear, hour, minute_sel]['file'] + + + def open_snapshot(self, year, dayofyear, hour, minute, spatial=None): + snapshot_files = self.snapshot_files(year, dayofyear, hour, minute, spatial=spatial) + rads = [] + regrid = utils.regrid_1km + for c in self.channels: + band_file = snapshot_files[c].iloc[0] + l1b = L1bBand(band_file) + data = l1b.open_dataset() + rad_c = data['Rad'] + rad_c = rad_c.expand_dims(dict(band=[c])) + rad_c_regrid = regrid(rad_c, c) + rads.append(rad_c_regrid) + rads[-1] = rads[-1].assign_coords(x=rads[0].x.values, + y=rads[0].y.values) + + rad = xr.concat(rads, 'band') + #rad = rad.swapaxes(0, 1) # put time in front + return rad diff --git a/windflow/datasets/geostationary/noaa_amv.py b/windflow/datasets/geostationary/noaa_amv.py new file mode 100644 index 0000000000000000000000000000000000000000..ebf91789590dab50bac36ed523a1e8c8812ac7ff --- /dev/null +++ b/windflow/datasets/geostationary/noaa_amv.py @@ -0,0 +1,16 @@ +import xarray as xr +import numpy as np +import datetime as dt + +class NOAAAMV(object): + def __init__(self, file): + self.file = file + self.dataset = xr.open_dataset(self.file) + self.dataset = self.dataset.isel(nMeasures=np.where(np.isfinite(self.dataset.wind_speed.values))[0]) + self.time = dt.datetime(2000, 1, 1, 12) + dt.timedelta(seconds=self.dataset['time_bounds'].values[0], minutes=0) + self.wind_speeds = self.dataset.wind_speed.values + self.direction = self.dataset.wind_direction.values + self.lats = self.dataset['lat'].values + self.lons = self.dataset['lon'].values + self.vs = np.sin(self.direction / 180 * np.pi) * self.wind_speeds + self.us = np.cos(self.direction / 180 * np.pi) * self.wind_speeds diff --git a/windflow/datasets/geostationary/pytorch_dataloader.py b/windflow/datasets/geostationary/pytorch_dataloader.py new file mode 100644 index 0000000000000000000000000000000000000000..1f525b5352d0f731fa9f685e0e8e9a0104b3a015 --- /dev/null +++ b/windflow/datasets/geostationary/pytorch_dataloader.py @@ -0,0 +1,124 @@ +import os +import xarray as xr +import numpy as np + +import torch +from torch.utils import data +from torchvision import transforms + +from . import goesr +from . import stats + +from .. 
import flow_transforms +from ..preprocess import image_histogram_equalization + +class L1bPatches(data.Dataset): + def __init__(self, data_directory, time_step=1, size=512, bands=[9,], mode='train'): + self.patch_folders = [os.path.join(data_directory, f) for f in os.listdir(data_directory)] + + self.patches = [L1bPatch(f, time_step=time_step, size=size, bands=bands, mode=mode) for f in self.patch_folders] + self.patches = data.ConcatDataset(self.patches) + self.N = len(self.patches) + + def __len__(self): + return self.N + + def __getitem__(self, idx): + return self.patches[idx] + +class L1bPatch(data.Dataset): + def __init__(self, data_directory, time_step=1, size=512, bands=[9,], mode='train'): + self.files = sorted([os.path.join(data_directory, f) for f in os.listdir(data_directory)]) + + N = len(self.files) + if mode == 'train': + self.files = self.files[:int(N*0.8)] + elif mode == 'valid': + self.files = self.files[int(N*0.8):int(N*0.9)] + elif mode == 'test': + self.files = self.files[int(N*0.9):] + + self.N = len(self.files)-time_step + self.time_step = time_step + self.bands = bands + self.rand_transform = transforms.Compose([ + transforms.RandomCrop((size, size)), + transforms.RandomHorizontalFlip(), + transforms.RandomVerticalFlip(), + ]) + + def __len__(self): + return self.N + + def __getitem__(self, idx): + sample_files = [self.files[idx], self.files[idx+self.time_step]] + N_samples = len(sample_files) + N_bands = len(self.bands) + + ds = [xr.open_dataset(f) for f in sample_files] + + x = np.concatenate([d['Rad'].sel(band=self.bands).values for d in ds], 0) + x[~np.isfinite(x)] = 0. + + mask = torch.Tensor((x[0] == x[0]).copy()).float() + + x = image_histogram_equalization(x) + x = self.rand_transform(torch.Tensor(x)) + x = x.unsqueeze(1) + #x = [x[i:i+N_bands] for i in range(N_samples)] + return x, mask + +class L1bResized(data.Dataset): + def __init__(self, data_directory, time_step=1, new_size = (1024, 1024), jitter=6, band=9): + self.files = sorted([os.path.join(data_directory, f) for f in os.listdir(data_directory)]) + self.N = len(self.files)-time_step + self.time_step = time_step + self.jitter = jitter + self.new_size = new_size + jitter_size = (new_size[0] + jitter, new_size[1] + jitter) + self.transform = transforms.Compose([transforms.ToPILImage(), + transforms.Resize(jitter_size), + transforms.ToTensor()]) + self.mu, self.sd = stats.get_sensor_stats("ABI") + self.mu = self.mu[band-1] + self.sd = self.sd[band-1] + + def __len__(self): + return self.N + + def __getitem__(self, idx): + sample_files = [self.files[idx], self.files[idx+self.time_step]] + print(sample_files) + samples = [] + for f in sample_files: + if (f[-3:] == '.nc') or (f[-4:] == '.nc4'): + data = xr.open_dataset(f) + x = data['Rad'].values.astype(np.float32) + x = np.copy(x) + elif f[-3:] == 'npy': + x = np.load(f).astype(np.float32) + x = (x - self.mu) / self.sd + x[~np.isfinite(x)] = 0. 
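The patch datasets above plug into a standard PyTorch DataLoader; a minimal sketch with a hypothetical patch directory:

# Sketch only: the patch directory is hypothetical.
from torch.utils.data import DataLoader

train_set = L1bPatches('/path/to/l1b_patches/', time_step=1, size=512, bands=[9], mode='train')
loader = DataLoader(train_set, batch_size=4, shuffle=True, num_workers=4)
frames, mask = next(iter(loader))
# frames: (batch, 2, 1, 512, 512) image pairs after histogram equalization and random crop/flips
# mask:   per-pixel mask for the first frame (same spatial size as the uncropped tile)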
+ x = self.transform(x) + samples.append(x) + + #if np.random.uniform() > 0.5: + # samples = [torch.flipud(s) for s in samples] + #if np.random.uniform() > 0.5: + # samples = [torch.fliplr(s) for s in samples] + #if np.random.uniform() > 0.5: + # samples = [torch.rot90(s, 1, [1, 2]) for s in samples] + + if self.jitter > 0: + ix = np.random.choice(range(self.jitter)) + iy = np.random.choice(range(self.jitter)) + samples = [s[:,ix:ix+self.new_size[0],iy:iy+self.new_size[1]] for s in samples] + + #factor1 = np.random.uniform(0.9,1.1) + #factor2 = np.random.uniform(0.9,1.1) + #samples = [s*factor1 for s in samples] + #samples[1] = samples[1]*factor2 + mask = (samples[0] != 0).float() + return samples, mask + + diff --git a/windflow/datasets/geostationary/schema.py b/windflow/datasets/geostationary/schema.py new file mode 100644 index 0000000000000000000000000000000000000000..40590c0be2688a328782bb10ecfd40d0125749ea --- /dev/null +++ b/windflow/datasets/geostationary/schema.py @@ -0,0 +1,28 @@ +from petastorm.unischema import Unischema, UnischemaField +from petastorm.codecs import ScalarCodec, NdarrayCodec +from pyspark.sql.types import IntegerType, StringType, FloatType + +import numpy as np + +BandSchema1000 = Unischema('BandSchema1000', [ + UnischemaField('index', np.int32, (), ScalarCodec(IntegerType()), False), + UnischemaField('timestamp', np.int32, (), ScalarCodec(IntegerType()), False), + UnischemaField('spatial', np.string_, (), ScalarCodec(StringType()), False), + UnischemaField('file', np.string_, (), ScalarCodec(StringType()), False), + UnischemaField('data', np.float32, (1000, 1000), NdarrayCodec(), False), + UnischemaField('x_ul', np.float32, (), ScalarCodec(FloatType()), False), + UnischemaField('y_ul', np.float32, (), ScalarCodec(FloatType()), False), + UnischemaField('group_id', np.float32, (), ScalarCodec(FloatType()), False), +]) + +BandSchema256 = Unischema('BandSchema256', [ + UnischemaField('index', np.int32, (), ScalarCodec(IntegerType()), False), + UnischemaField('timestamp', np.int32, (), ScalarCodec(IntegerType()), False), + UnischemaField('spatial', np.string_, (), ScalarCodec(StringType()), False), + UnischemaField('file', np.string_, (), ScalarCodec(StringType()), False), + UnischemaField('data', np.float32, (256, 256), NdarrayCodec(), False), + UnischemaField('x_ul', np.float32, (), ScalarCodec(FloatType()), False), + UnischemaField('y_ul', np.float32, (), ScalarCodec(FloatType()), False), + UnischemaField('group_id', np.int32, (), ScalarCodec(IntegerType()), False), + UnischemaField('patch_id', np.int32, (), ScalarCodec(IntegerType()), False), +]) diff --git a/windflow/datasets/geostationary/stats.py b/windflow/datasets/geostationary/stats.py new file mode 100644 index 0000000000000000000000000000000000000000..980b2d39f213f923aced9b5f85c1136627233e35 --- /dev/null +++ b/windflow/datasets/geostationary/stats.py @@ -0,0 +1,24 @@ +def get_sensor_stats(sensor): + ''' + Values computed by scripts/get_data_stats.py + inputs + sensor: (AHI, ABI, G16, G17, H8) + output + mean + standard deviation + ''' + if sensor in ['AHI', 'H8', 'H8_15']: + # computed from himawari + mu = (0.829, 0.512, 0.621, 0.663, 0.427, 0.333, 285.251, 235.759, 244.547, + 252.391, 272.476, 251.321, 274.869, 274.522, 272.611, 259.913) + sd = (0.292, 0.257, 0.325, 0.362, 0.283, 0.218, 6.243, 2.628, 3.472, 4.165, + 7.574, 4.078, 7.973, 8.122, 7.839, 5.408) + elif sensor in ['ABI', 'G16', 'G17']: + mu = (0.829, 0.621, 0.663, 0.077, 0.427, 0.333, 285.251, 235.759, 244.547, + 252.391, 272.476, 251.321, 
274.869, 274.522, 272.611, 259.913) + sd = (0.292, 0.325, 0.362, 0.113, 0.283, 0.218, 6.243, 2.628, 3.472, 4.165, + 7.574, 4.078, 7.973, 8.122, 7.839, 5.408) + elif sensor in ['AHI12']: + mu = (0.06, 0.08, 0.15, 0.30, 0.32, 0.23) + sd = (0.03, 0.1, 0.2, 0.3, 0.3, 0.3) + return mu, sd \ No newline at end of file diff --git a/windflow/datasets/igrs/README.md b/windflow/datasets/igrs/README.md new file mode 100644 index 0000000000000000000000000000000000000000..d714fdd7c686417a09cdeba648ad06c503e1a046 --- /dev/null +++ b/windflow/datasets/igrs/README.md @@ -0,0 +1,8 @@ +## NOAA Integrated Global Radiosonde Archive (IGRA) + + +https://www.ncdc.noaa.gov/data-access/weather-balloon/integrated-global-radiosonde-archive + +ftp://ftp.ncdc.noaa.gov/pub/data/igra + + diff --git a/windflow/datasets/igrs/build_database.py b/windflow/datasets/igrs/build_database.py new file mode 100644 index 0000000000000000000000000000000000000000..4ec5db2003dd5a681c269504dabe7a7cc0087267 --- /dev/null +++ b/windflow/datasets/igrs/build_database.py @@ -0,0 +1,95 @@ +import os, sys +import glob + +import pandas as pd + +import sqlite3 + +igra_directory = '/nobackupp10/tvandal/data/IGRA/' + +# ['data-y2d', 'igra2-readme.txt', 'igra2-country-list.txt', 'igra2-station-list.txt', 'status.txt', +# 'igra2-list-format.txt', 'igra2-data-format.txt', 'igra2-us-states.txt', 'igra2-readme.txt.1'] + +#con = sqlite3.connect("data/portal_mammals.sqlite") +# Load the data into a DataFrame +#surveys_df = pd.read_sql_query("SELECT * from surveys", con) + +# Select only data for 2002 +#surveys2002 = surveys_df[surveys_df.year == 2002] + +# Write the new DataFrame to a new SQLite table +#surveys2002.to_sql("surveys2002", con, if_exists="replace") +#con.close() + +class RawIGRA: + def __init__(self, directory): + self.dir = directory + + + def stations(self): + path = os.path.join(self.dir, 'igra2-station-list.txt') + #names = ['id', 'lat', 'lon', 'elevation', 'state', 'name', 'first_year', 'last_year', 'n_obs'] + #df = pd.read_table(path, sep='\t') + data = [] + with open(path, 'r') as reader: + for line in reader.readlines(): + station = {'id': line[:11], 'lat': float(line[12:20]), 'lon': float(line[21:30]), + 'elevation': float(line[31:37]), 'state': line[38:40].strip(), 'name': line[41:71].strip(), + 'first_year': int(line[72:76]), 'last_year': int(line[77:81]), + 'n_obs': int(line[82:88])} + data.append(station) + data = pd.DataFrame(data) + return data + + def station_data(self, station): + path = os.path.join(self.dir, f'data-y2d/{station}-data.txt') + if not os.path.exists(path): + print(f"Cannot find data for station {station}") + return + print(f"Found data for stations {station}") + data = [] + + with open(path, 'r') as reader: + for i, line in enumerate(reader.readlines()): + if line[0] == '#': # header row + record_header = {'station': station, 'year': int(line[13:17]), 'month': int(line[18:20]), + 'day': int(line[21:23]), 'hour': int(line[24:26]), + 'reltime': int(line[27:31]), 'num_levels': int(line[32:36]), + 'p_src': line[37:45].strip(), 'np_src': line[46:54].strip(), + 'lat': int(line[55:62])/10000., 'lon': int(line[63:71])/10000., + } + #print(record_header) + else: + record = {'level_type1': int(line[0]), 'level_type2': int(line[1]), 'elapsed_time': int(line[3:8]), + 'pressure': int(line[9:15]), 'pflag': line[15], 'geopotential_height': int(line[16:21]), + 'zflag': line[21], 'temp': int(line[22:27]), 'tflag': line[27], 'rh': int(line[28:33]), + 'dpdp': int(line[34:39]), 'wind_direction': int(line[40:45]), 'wind_speed': 
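get_sensor_stats indexes bands from 1, so band b is normalized with mu[b-1] and sd[b-1] (this is how L1bResized uses it); a small sketch:

# Sketch: normalize a band-13 ABI array with the tabulated statistics.
import numpy as np

mu, sd = get_sensor_stats('ABI')
band = 13
x = 260. + 20. * np.random.rand(256, 256)        # stand-in brightness temperatures (K)
x_norm = (x - mu[band - 1]) / sd[band - 1]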
int(line[46:51]) + } + record.update(record_header) + data.append(record) + + data = pd.DataFrame(data) + return data + + +def write_to_database(): + conn = sqlite3.connect(os.path.join(igra_directory, 'igra.db')) + + igra = RawIGRA(igra_directory) + stations = igra.stations() + stations.to_sql("stations", conn, if_exists="replace") + alldata = [] + for i, row in stations.iterrows(): + station_data = igra.station_data(row.id) + if station_data is not None: + station_data.to_sql("records", conn, if_exists="append", index=False) + conn.commit() + +def query_records(): + conn = sqlite3.connect(os.path.join(igra_directory, 'igra.db')) + df = pd.read_sql_query("SELECT * from records", conn) + print(df) + +if __name__ == '__main__': + write_to_database() + query_records() \ No newline at end of file diff --git a/windflow/datasets/igrs/igrs_to_dataframe.py b/windflow/datasets/igrs/igrs_to_dataframe.py new file mode 100644 index 0000000000000000000000000000000000000000..fccca09dbbee198bfa9d0ad172cf88c97e4e9f6d --- /dev/null +++ b/windflow/datasets/igrs/igrs_to_dataframe.py @@ -0,0 +1,85 @@ +import os, sys +import glob + +import pandas as pd + +import sqlite3 + +igra_directory = '/nobackupp10/tvandal/data/IGRA/' + +# ['data-y2d', 'igra2-readme.txt', 'igra2-country-list.txt', 'igra2-station-list.txt', 'status.txt', +# 'igra2-list-format.txt', 'igra2-data-format.txt', 'igra2-us-states.txt', 'igra2-readme.txt.1'] + + +class RawIGRA: + def __init__(self, directory): + self.dir = directory + + + def stations(self): + path = os.path.join(self.dir, 'igra2-station-list.txt') + #names = ['id', 'lat', 'lon', 'elevation', 'state', 'name', 'first_year', 'last_year', 'n_obs'] + #df = pd.read_table(path, sep='\t') + data = [] + with open(path, 'r') as reader: + for line in reader.readlines(): + station = {'id': line[:11], 'lat': float(line[12:20]), 'lon': float(line[21:30]), + 'elevation': float(line[31:37]), 'state': line[38:40].strip(), 'name': line[41:71].strip(), + 'first_year': int(line[72:76]), 'last_year': int(line[77:81]), + 'n_obs': int(line[82:88])} + data.append(station) + data = pd.DataFrame(data) + return data + + def station_data(self, station): + path = os.path.join(self.dir, f'data-y2d/{station}-data.txt') + if not os.path.exists(path): + print(f"Cannot find data for station {station}") + return + print(f"Found data for stations {station}") + data = [] + + with open(path, 'r') as reader: + for i, line in enumerate(reader.readlines()): + if line[0] == '#': # header row + record_header = {'station': station, 'year': int(line[13:17]), 'month': int(line[18:20]), + 'day': int(line[21:23]), 'hour': int(line[24:26]), + 'reltime': int(line[27:31]), 'num_levels': int(line[32:36]), + 'p_src': line[37:45].strip(), 'np_src': line[46:54].strip(), + 'lat': int(line[55:62])/10000., 'lon': int(line[63:71])/10000., + } + #print(record_header) + else: + record = {'level_type1': int(line[0]), 'level_type2': int(line[1]), 'elapsed_time': int(line[3:8]), + 'pressure': int(line[9:15]), 'pflag': line[15], 'geopotential_height': int(line[16:21]), + 'zflag': line[21], 'temp': int(line[22:27]), 'tflag': line[27], 'rh': int(line[28:33]), + 'dpdp': int(line[34:39]), 'wind_direction': int(line[40:45]), 'wind_speed': int(line[46:51]) + } + record.update(record_header) + data.append(record) + + data = pd.DataFrame(data) + return data + + +def write_to_database(): + conn = sqlite3.connect(os.path.join(igra_directory, 'igra.db')) + + igra = RawIGRA(igra_directory) + stations = igra.stations() + stations.to_sql("stations", conn, 
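Once the SQLite database is built, records can be filtered in SQL rather than loading the whole table; a sketch with a hypothetical station id:

# Sketch only: the station id is hypothetical.
import sqlite3
import pandas as pd

conn = sqlite3.connect('/nobackupp10/tvandal/data/IGRA/igra.db')
query = """SELECT station, year, month, day, hour, pressure, wind_direction, wind_speed
           FROM records
           WHERE station = 'USM00072520' AND pressure BETWEEN 20000 AND 70000"""
df = pd.read_sql_query(query, conn)
conn.close()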
if_exists="replace") + alldata = [] + for i, row in stations.iterrows(): + station_data = igra.station_data(row.id) + if station_data is not None: + station_data.to_sql("records", conn, if_exists="append", index=False) + conn.commit() + +def query_records(): + conn = sqlite3.connect(os.path.join(igra_directory, 'igra.db')) + df = pd.read_sql_query("SELECT * from records", conn) + print(df) + +if __name__ == '__main__': + write_to_database() + query_records() \ No newline at end of file diff --git a/windflow/datasets/preprocess.py b/windflow/datasets/preprocess.py new file mode 100644 index 0000000000000000000000000000000000000000..8875c0ca66d38c4f890c2d005e38ce4f1fdd6f12 --- /dev/null +++ b/windflow/datasets/preprocess.py @@ -0,0 +1,20 @@ +import numpy as np + +def image_histogram_equalization(image, number_bins=500): + # from http://www.janeriksolem.net/2009/06/histogram-equalization-with-python-and.html + # get image histogram + image_histogram, bins = np.histogram(image.flatten(), number_bins, density=True) + cdf = image_histogram.cumsum() # cumulative distribution function + cdf = cdf / cdf[-1] # normalize + + # use linear interpolation of cdf to find new pixel values + image_equalized = np.interp(image.flatten(), bins[:-1], cdf) + + image_equalized[~np.isfinite(image_equalized)] = 0. + + return image_equalized.reshape(image.shape)#, cdf + +def normalize_qv(qv): + qv_log = np.log(qv * 1e4 + 0.001) + qv_log_norm = (qv_log - 1.06) / 2.15 + return qv_log_norm \ No newline at end of file diff --git a/windflow/datasets/utils.py b/windflow/datasets/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..8ab0ce58589a66da9d193aa44d7200ab3db7d57d --- /dev/null +++ b/windflow/datasets/utils.py @@ -0,0 +1,120 @@ +import xarray as xr +import numpy as np +import scipy.interpolate +import dask as da + +def interp_dim(x, scale): + x0, xlast = x[0], x[-1] + newlength = int(len(x) * scale) + y = np.linspace(x0, xlast, num=newlength, endpoint=False) + return y + +def blocks(data, width=352): + #n = data.t.shape[0] + w = data.x.shape[0] + h = data.y.shape[0] + d = data.band.shape[0] + + hs = np.arange(0, h, width) + ws = np.arange(0, w, width) + blocks = [] + for hindex in hs: + if hindex+width > h: + hindex = h - width + + for windex in ws: + if windex+width > w: + windex = w - width + blocks.append(data.sel(y=data.y.values[hindex:hindex+width], + x=data.x.values[windex:windex+width])) + return blocks + +def block_dask_array(arr, axis, size=128, stride=128): + arr = da.array.swapaxes(arr, axis, 0) + n = arr.shape[0] + stack = [] + for j in range(0, n, stride): + j = min(j, n - size) + stack.append(arr[j:j+size]) + stack = da.array.stack(stack) + stack = da.array.swapaxes(stack, axis+1, 1) + return stack + +def block_array(arr, axis, size=128, stride=128): + arr = np.swapaxes(arr, axis, 0) + n = arr.shape[0] + stack = [] + for j in range(0, n, stride): + j = min(j, n - size) + stack.append(arr[np.newaxis,j:j+size]) + stack = np.concatenate(stack, 0) + stack = np.swapaxes(stack, axis+1, 1) + return stack + +def xarray_to_block_list(arr, dim, size=128, stride=128): + n = arr[dim].shape[0] + stack = [] + for j in range(0, n, stride): + j = min(j, n - size) + stack.append(arr.isel({dim: np.arange(j,j+size)})) + return stack + +def interp(da, scale, fillna=False): + xnew = interp_dim(da['x'].values, scale) + ynew = interp_dim(da['y'].values, scale) + newcoords = dict(x=xnew, y=ynew) + return da.interp(newcoords) + +def regrid_2km(da, band): + if band == 2: + return interp(da, 1. 
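A quick sketch of the helpers defined above: image_histogram_equalization maps radiances to approximately uniform values in [0, 1], and block_array tiles an array into fixed-size patches along one axis (shapes shown for a two-frame 512x512 stack):

# Sketch: equalize a two-frame stack and cut it into 128-pixel patch pairs.
import numpy as np

frames = np.random.gamma(2.0, 2.0, size=(2, 512, 512)).astype(np.float32)  # stand-in radiances
eq = image_histogram_equalization(frames)                      # same shape, values in [0, 1]

patches = block_array(eq, axis=1, size=128, stride=128)        # -> (4, 2, 128, 512)
patches = block_array(patches, axis=3, size=128, stride=128)   # -> (4, 4, 2, 128, 128)
patch_pairs = patches.reshape(-1, 2, 128, 128)                 # 16 patch pairs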
/ 4, fillna=False) + elif band in [1, 3, 5]: + return interp(da, 1. / 2, fillna=False) + return da + +def regrid_1km(da, band): + if band == 2: #(0.5 km) + return interp(da, 1./2, fillna=False) + elif band not in [1, 3, 5]: # 2km + return interp(da, 2., fillna=False) + return da + +def regrid_500m(da, band): + if band == 2: # 500m + return da + elif band in [1, 3, 5]: # 1km + return interp(da, 2., fillna=False) + return interp(da, 4., fillna=False) # 2km + + +def cartesian_to_speed(da): + lat_rad = np.radians(da.lat.values) + lon_rad = np.radians(da.lon.values) + a = np.cos(lat_rad)**2 * np.sin((lon_rad[1]-lon_rad[0])/2)**2 + d = 2 * 6378.137 * np.arcsin(a**0.5) + size_per_pixel = np.repeat(np.expand_dims(d, -1), len(lon_rad), axis=1) # km + da['U'] = da['U'] * size_per_pixel * 1000 / 1800 + da['V'] = da['V'] * size_per_pixel * 1000 / 1800 + return da + + +def cartesian_to_speed_eco(da): + lat_rad = np.radians(da.lat_0.values) + lon_rad = np.radians(da.lon_0.values) + a = np.cos(lat_rad)**2 * np.sin((lon_rad[1]-lon_rad[0])/2)**2 + d = 2 * 6378.137 * np.arcsin(a**0.5) + size_per_pixel = np.repeat(np.expand_dims(d, -1), len(lon_rad), axis=1) # km + da['ugrd_newP'] = da['ugrd_newP'] * size_per_pixel * 1000 / 10800 + da['vgrd_newP'] = da['vgrd_newP'] * size_per_pixel * 1000 / 10800 + return da + + +def speed_to_cartesian(da): + lat_rad = np.radians(da.lat.values) + lon_rad = np.radians(da.lon.values) + a = np.cos(lat_rad)**2 * np.sin((lon_rad[1]-lon_rad[0])/2)**2 + d = 2 * 6378.137 * np.arcsin(a**0.5) + size_per_pixel = np.repeat(np.expand_dims(d, -1), len(lon_rad), axis=1) # km + da['U'] = da['U'] / size_per_pixel / 1000 * 1800 / 0.9 + da['V'] = da['V'] / size_per_pixel / 1000 * 1800 / 0.9 + return da \ No newline at end of file diff --git a/windflow/external_packages/__init__.py b/windflow/external_packages/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/windflow/external_packages/__pycache__/__init__.cpython-39.pyc b/windflow/external_packages/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ca0f05cc7c666751f26bafdc177bec6bef98b6f5 Binary files /dev/null and b/windflow/external_packages/__pycache__/__init__.cpython-39.pyc differ diff --git a/windflow/external_packages/channelnorm_package/__init__.py b/windflow/external_packages/channelnorm_package/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/windflow/external_packages/channelnorm_package/channelnorm.py b/windflow/external_packages/channelnorm_package/channelnorm.py new file mode 100755 index 0000000000000000000000000000000000000000..76301af51f3d2d08e68c2e092e94f43a3e98cf16 --- /dev/null +++ b/windflow/external_packages/channelnorm_package/channelnorm.py @@ -0,0 +1,39 @@ +from torch.autograd import Function, Variable +from torch.nn.modules.module import Module +import channelnorm_cuda + +class ChannelNormFunction(Function): + + @staticmethod + def forward(ctx, input1, norm_deg=2): + assert input1.is_contiguous() + b, _, h, w = input1.size() + output = input1.new(b, 1, h, w).zero_() + + channelnorm_cuda.forward(input1, output, norm_deg) + ctx.save_for_backward(input1, output) + ctx.norm_deg = norm_deg + + return output + + @staticmethod + def backward(ctx, grad_output): + input1, output = ctx.saved_tensors + + grad_input1 = Variable(input1.new(input1.size()).zero_()) + + channelnorm_cuda.backward(input1, output,
grad_output.data, + grad_input1.data, ctx.norm_deg) + + return grad_input1, None + + +class ChannelNorm(Module): + + def __init__(self, norm_deg=2): + super(ChannelNorm, self).__init__() + self.norm_deg = norm_deg + + def forward(self, input1): + return ChannelNormFunction.apply(input1, self.norm_deg) + diff --git a/windflow/external_packages/channelnorm_package/channelnorm_cuda.cc b/windflow/external_packages/channelnorm_package/channelnorm_cuda.cc new file mode 100755 index 0000000000000000000000000000000000000000..69d82eb184e97b2eefa9810ad156d1104cf84745 --- /dev/null +++ b/windflow/external_packages/channelnorm_package/channelnorm_cuda.cc @@ -0,0 +1,31 @@ +#include <torch/torch.h> +#include <ATen/ATen.h> + +#include "channelnorm_kernel.cuh" + +int channelnorm_cuda_forward( + at::Tensor& input1, + at::Tensor& output, + int norm_deg) { + + channelnorm_kernel_forward(input1, output, norm_deg); + return 1; +} + + +int channelnorm_cuda_backward( + at::Tensor& input1, + at::Tensor& output, + at::Tensor& gradOutput, + at::Tensor& gradInput1, + int norm_deg) { + + channelnorm_kernel_backward(input1, output, gradOutput, gradInput1, norm_deg); + return 1; +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("forward", &channelnorm_cuda_forward, "Channel norm forward (CUDA)"); + m.def("backward", &channelnorm_cuda_backward, "Channel norm backward (CUDA)"); +} + diff --git a/windflow/external_packages/channelnorm_package/channelnorm_kernel.cu b/windflow/external_packages/channelnorm_package/channelnorm_kernel.cu new file mode 100755 index 0000000000000000000000000000000000000000..99ace6855a61373443a6ddff7c7858eb474e9e48 --- /dev/null +++ b/windflow/external_packages/channelnorm_package/channelnorm_kernel.cu @@ -0,0 +1,177 @@ +#include <ATen/ATen.h> +#include <ATen/Context.h> +#include <ATen/cuda/CUDAContext.h> + +#include "channelnorm_kernel.cuh" + +#define CUDA_NUM_THREADS 512 + +#define DIM0(TENSOR) ((TENSOR).x) +#define DIM1(TENSOR) ((TENSOR).y) +#define DIM2(TENSOR) ((TENSOR).z) +#define DIM3(TENSOR) ((TENSOR).w) + +#define DIM3_INDEX(TENSOR, xx, yy, zz, ww) ((TENSOR)[((xx) * (TENSOR##_stride.x)) + ((yy) * (TENSOR##_stride.y)) + ((zz) * (TENSOR##_stride.z)) + ((ww) * (TENSOR##_stride.w))]) + +using at::Half; + +template <typename scalar_t> +__global__ void kernel_channelnorm_update_output( + const int n, + const scalar_t* __restrict__ input1, + const long4 input1_size, + const long4 input1_stride, + scalar_t* __restrict__ output, + const long4 output_size, + const long4 output_stride, + int norm_deg) { + + int index = blockIdx.x * blockDim.x + threadIdx.x; + + if (index >= n) { + return; + } + + int dim_b = DIM0(output_size); + int dim_c = DIM1(output_size); + int dim_h = DIM2(output_size); + int dim_w = DIM3(output_size); + int dim_chw = dim_c * dim_h * dim_w; + + int b = ( index / dim_chw ) % dim_b; + int y = ( index / dim_w ) % dim_h; + int x = ( index ) % dim_w; + + int i1dim_c = DIM1(input1_size); + int i1dim_h = DIM2(input1_size); + int i1dim_w = DIM3(input1_size); + int i1dim_chw = i1dim_c * i1dim_h * i1dim_w; + int i1dim_hw = i1dim_h * i1dim_w; + + float result = 0.0; + + for (int c = 0; c < i1dim_c; ++c) { + int i1Index = b * i1dim_chw + c * i1dim_hw + y * i1dim_w + x; + scalar_t val = input1[i1Index]; + result += static_cast<float>(val * val); + } + result = sqrt(result); + output[index] = static_cast<scalar_t>(result); +} + + +template <typename scalar_t> +__global__ void kernel_channelnorm_backward_input1( + const int n, + const scalar_t* __restrict__ input1, const long4 
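The forward kernel above accumulates squared channel values and takes a square root, so ChannelNorm is effectively an L2 norm over the channel axis (norm_deg is passed through but not otherwise used). A pure-PyTorch reference, handy for checking the built extension:

# Pure-PyTorch reference for ChannelNorm (L2 norm over channels); for testing, not speed.
import torch

def channelnorm_reference(x: torch.Tensor) -> torch.Tensor:
    # x: (B, C, H, W) -> (B, 1, H, W)
    return x.pow(2).sum(dim=1, keepdim=True).sqrt()

x = torch.randn(2, 8, 64, 64)
out_ref = channelnorm_reference(x)
# With the extension built: out_cuda = ChannelNorm()(x.cuda().contiguous())
# and torch.allclose(out_ref.cuda(), out_cuda, atol=1e-5) should hold.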
input1_size, const long4 input1_stride, + const scalar_t* __restrict__ output, const long4 output_size, const long4 output_stride, + const scalar_t* __restrict__ gradOutput, const long4 gradOutput_size, const long4 gradOutput_stride, + scalar_t* __restrict__ gradInput, const long4 gradInput_size, const long4 gradInput_stride, + int norm_deg) { + + int index = blockIdx.x * blockDim.x + threadIdx.x; + + if (index >= n) { + return; + } + + float val = 0.0; + + int dim_b = DIM0(gradInput_size); + int dim_c = DIM1(gradInput_size); + int dim_h = DIM2(gradInput_size); + int dim_w = DIM3(gradInput_size); + int dim_chw = dim_c * dim_h * dim_w; + int dim_hw = dim_h * dim_w; + + int b = ( index / dim_chw ) % dim_b; + int y = ( index / dim_w ) % dim_h; + int x = ( index ) % dim_w; + + + int outIndex = b * dim_hw + y * dim_w + x; + val = static_cast<float>(gradOutput[outIndex]) * static_cast<float>(input1[index]) / (static_cast<float>(output[outIndex])+1e-9); + gradInput[index] = static_cast<scalar_t>(val); + +} + +void channelnorm_kernel_forward( + at::Tensor& input1, + at::Tensor& output, + int norm_deg) { + + const long4 input1_size = make_long4(input1.size(0), input1.size(1), input1.size(2), input1.size(3)); + const long4 input1_stride = make_long4(input1.stride(0), input1.stride(1), input1.stride(2), input1.stride(3)); + + const long4 output_size = make_long4(output.size(0), output.size(1), output.size(2), output.size(3)); + const long4 output_stride = make_long4(output.stride(0), output.stride(1), output.stride(2), output.stride(3)); + + int n = output.numel(); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF(input1.type(), "channelnorm_forward", ([&] { + + kernel_channelnorm_update_output<scalar_t><<< (n + CUDA_NUM_THREADS - 1)/CUDA_NUM_THREADS, CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream() >>>( +//at::globalContext().getCurrentCUDAStream() >>>( + n, + input1.data<scalar_t>(), + input1_size, + input1_stride, + output.data<scalar_t>(), + output_size, + output_stride, + norm_deg); + + })); + + // TODO: ATen-equivalent check + + // THCudaCheck(cudaGetLastError()); +} + +void channelnorm_kernel_backward( + at::Tensor& input1, + at::Tensor& output, + at::Tensor& gradOutput, + at::Tensor& gradInput1, + int norm_deg) { + + const long4 input1_size = make_long4(input1.size(0), input1.size(1), input1.size(2), input1.size(3)); + const long4 input1_stride = make_long4(input1.stride(0), input1.stride(1), input1.stride(2), input1.stride(3)); + + const long4 output_size = make_long4(output.size(0), output.size(1), output.size(2), output.size(3)); + const long4 output_stride = make_long4(output.stride(0), output.stride(1), output.stride(2), output.stride(3)); + + const long4 gradOutput_size = make_long4(gradOutput.size(0), gradOutput.size(1), gradOutput.size(2), gradOutput.size(3)); + const long4 gradOutput_stride = make_long4(gradOutput.stride(0), gradOutput.stride(1), gradOutput.stride(2), gradOutput.stride(3)); + + const long4 gradInput1_size = make_long4(gradInput1.size(0), gradInput1.size(1), gradInput1.size(2), gradInput1.size(3)); + const long4 gradInput1_stride = make_long4(gradInput1.stride(0), gradInput1.stride(1), gradInput1.stride(2), gradInput1.stride(3)); + + int n = gradInput1.numel(); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF(input1.type(), "channelnorm_backward_input1", ([&] { + + kernel_channelnorm_backward_input1<scalar_t><<< (n + CUDA_NUM_THREADS - 1)/CUDA_NUM_THREADS, CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream() >>>( +//at::globalContext().getCurrentCUDAStream() >>>( + n, + 
input1.data<scalar_t>(), + input1_size, + input1_stride, + output.data<scalar_t>(), + output_size, + output_stride, + gradOutput.data<scalar_t>(), + gradOutput_size, + gradOutput_stride, + gradInput1.data<scalar_t>(), + gradInput1_size, + gradInput1_stride, + norm_deg + ); + + })); + + // TODO: Add ATen-equivalent check + +// THCudaCheck(cudaGetLastError()); +} diff --git a/windflow/external_packages/channelnorm_package/channelnorm_kernel.cuh b/windflow/external_packages/channelnorm_package/channelnorm_kernel.cuh new file mode 100755 index 0000000000000000000000000000000000000000..3e6223f7fe60feb4bf9e4f66c3d849b84c89dcda --- /dev/null +++ b/windflow/external_packages/channelnorm_package/channelnorm_kernel.cuh @@ -0,0 +1,16 @@ +#pragma once + +#include <ATen/ATen.h> + +void channelnorm_kernel_forward( + at::Tensor& input1, + at::Tensor& output, + int norm_deg); + + +void channelnorm_kernel_backward( + at::Tensor& input1, + at::Tensor& output, + at::Tensor& gradOutput, + at::Tensor& gradInput1, + int norm_deg); diff --git a/windflow/external_packages/channelnorm_package/setup.py b/windflow/external_packages/channelnorm_package/setup.py new file mode 100755 index 0000000000000000000000000000000000000000..17ecacf5450630a412e0517c462c4ffa5a532718 --- /dev/null +++ b/windflow/external_packages/channelnorm_package/setup.py @@ -0,0 +1,29 @@ +#!/usr/bin/env python3 +import os +import torch + +from setuptools import setup +from torch.utils.cpp_extension import BuildExtension, CUDAExtension + +#cxx_args = ['-std=c++11'] +cxx_args = ['-std=c++14'] + +nvcc_args = [ + '-gencode', 'arch=compute_52,code=sm_52', + '-gencode', 'arch=compute_60,code=sm_60', + '-gencode', 'arch=compute_61,code=sm_61', + '-gencode', 'arch=compute_70,code=sm_70', + '-gencode', 'arch=compute_70,code=compute_70' +] + +setup( + name='channelnorm_cuda', + ext_modules=[ + CUDAExtension('channelnorm_cuda', [ + 'channelnorm_cuda.cc', + 'channelnorm_kernel.cu' + ], extra_compile_args={'cxx': cxx_args, 'nvcc': nvcc_args}) + ], + cmdclass={ + 'build_ext': BuildExtension + }) diff --git a/windflow/external_packages/correlation_package/__init__.py b/windflow/external_packages/correlation_package/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/windflow/external_packages/correlation_package/__pycache__/__init__.cpython-39.pyc b/windflow/external_packages/correlation_package/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ab63cf9a7d0b5c3ae0b6d4a71213570ea225014f Binary files /dev/null and b/windflow/external_packages/correlation_package/__pycache__/__init__.cpython-39.pyc differ diff --git a/windflow/external_packages/correlation_package/__pycache__/correlation.cpython-39.pyc b/windflow/external_packages/correlation_package/__pycache__/correlation.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2bc4894787bb408baceb0c7c27ca549a85fc72c3 Binary files /dev/null and b/windflow/external_packages/correlation_package/__pycache__/correlation.cpython-39.pyc differ diff --git a/windflow/external_packages/correlation_package/correlation.py b/windflow/external_packages/correlation_package/correlation.py new file mode 100755 index 0000000000000000000000000000000000000000..80a8b09c088c360e152230af583779d771d1bca1 --- /dev/null +++ b/windflow/external_packages/correlation_package/correlation.py @@ -0,0 +1,62 @@ +import torch +from torch.nn.modules.module import Module +from 
torch.autograd import Function +import correlation_cuda + +class CorrelationFunction(Function): + + def __init__(self, pad_size=3, kernel_size=3, max_displacement=20, stride1=1, stride2=2, corr_multiply=1): + super(CorrelationFunction, self).__init__() + self.pad_size = pad_size + self.kernel_size = kernel_size + self.max_displacement = max_displacement + self.stride1 = stride1 + self.stride2 = stride2 + self.corr_multiply = corr_multiply + # self.out_channel = ((max_displacement/stride2)*2 + 1) * ((max_displacement/stride2)*2 + 1) + + def forward(self, input1, input2): + self.save_for_backward(input1, input2) + + with torch.cuda.device_of(input1): + rbot1 = input1.new() + rbot2 = input2.new() + output = input1.new() + + correlation_cuda.forward(input1, input2, rbot1, rbot2, output, + self.pad_size, self.kernel_size, self.max_displacement,self.stride1, self.stride2, self.corr_multiply) + + return output + + def backward(self, grad_output): + input1, input2 = self.saved_tensors + + with torch.cuda.device_of(input1): + rbot1 = input1.new() + rbot2 = input2.new() + + grad_input1 = input1.new() + grad_input2 = input2.new() + + correlation_cuda.backward(input1, input2, rbot1, rbot2, grad_output, grad_input1, grad_input2, + self.pad_size, self.kernel_size, self.max_displacement,self.stride1, self.stride2, self.corr_multiply) + + return grad_input1, grad_input2 + + +class Correlation(Module): + def __init__(self, pad_size=0, kernel_size=0, max_displacement=0, stride1=1, stride2=2, corr_multiply=1): + super(Correlation, self).__init__() + self.pad_size = pad_size + self.kernel_size = kernel_size + self.max_displacement = max_displacement + self.stride1 = stride1 + self.stride2 = stride2 + self.corr_multiply = corr_multiply + + def forward(self, input1, input2): + + result = CorrelationFunction(self.pad_size, self.kernel_size, self.max_displacement,self.stride1, self.stride2, self.corr_multiply)(input1, input2) + + return result + diff --git a/windflow/external_packages/correlation_package/correlation_cuda.cc b/windflow/external_packages/correlation_package/correlation_cuda.cc new file mode 100755 index 0000000000000000000000000000000000000000..feccd65295fa90a22564b08fc80464a76361a1aa --- /dev/null +++ b/windflow/external_packages/correlation_package/correlation_cuda.cc @@ -0,0 +1,173 @@ +#include <torch/torch.h> +#include <ATen/ATen.h> +#include <ATen/Context.h> +#include <ATen/cuda/CUDAContext.h> +#include <stdio.h> +#include <iostream> + +#include "correlation_cuda_kernel.cuh" + +int correlation_forward_cuda(at::Tensor& input1, at::Tensor& input2, at::Tensor& rInput1, at::Tensor& rInput2, at::Tensor& output, + int pad_size, + int kernel_size, + int max_displacement, + int stride1, + int stride2, + int corr_type_multiply) +{ + + int batchSize = input1.size(0); + + int nInputChannels = input1.size(1); + int inputHeight = input1.size(2); + int inputWidth = input1.size(3); + + int kernel_radius = (kernel_size - 1) / 2; + int border_radius = kernel_radius + max_displacement; + + int paddedInputHeight = inputHeight + 2 * pad_size; + int paddedInputWidth = inputWidth + 2 * pad_size; + + int nOutputChannels = ((max_displacement/stride2)*2 + 1) * ((max_displacement/stride2)*2 + 1); + + int outputHeight = ceil(static_cast<float>(paddedInputHeight - 2 * border_radius) / static_cast<float>(stride1)); + int outputwidth = ceil(static_cast<float>(paddedInputWidth - 2 * border_radius) / static_cast<float>(stride1)); + + rInput1.resize_({batchSize, paddedInputHeight, paddedInputWidth, nInputChannels}); + 
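The correlation layer builds a cost volume: for each pixel it averages, over channels, the product between features of the first image and displaced features of the second, with one output channel per displacement. A naive PyTorch reference for the common FlowNetC configuration (kernel_size=1, stride1=1, pad_size=max_displacement), useful only for sanity-checking the CUDA extension on small inputs:

# Naive reference for the correlation cost volume (kernel_size=1, stride1=1, pad_size=max_displacement).
import torch
import torch.nn.functional as F

def correlation_reference(f1, f2, max_displacement=20, stride2=2):
    B, C, H, W = f1.shape
    pad = max_displacement
    f2p = F.pad(f2, (pad, pad, pad, pad))               # zero padding, as in the CUDA kernel
    d = max_displacement // stride2
    out = []
    for dy in range(-d, d + 1):                         # displacement order matches the tc indexing
        for dx in range(-d, d + 1):
            f2s = f2p[:, :, pad + dy * stride2: pad + dy * stride2 + H,
                            pad + dx * stride2: pad + dx * stride2 + W]
            out.append((f1 * f2s).mean(dim=1, keepdim=True))   # divide by nelems = C
    return torch.cat(out, dim=1)                        # (B, (2*d + 1)**2, H, W)

# e.g. correlation_reference(torch.randn(1, 64, 32, 32), torch.randn(1, 64, 32, 32)).shape
# gives (1, 441, 32, 32) for the default displacement settings.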
rInput2.resize_({batchSize, paddedInputHeight, paddedInputWidth, nInputChannels}); + output.resize_({batchSize, nOutputChannels, outputHeight, outputwidth}); + + rInput1.fill_(0); + rInput2.fill_(0); + output.fill_(0); + + int success = correlation_forward_cuda_kernel( + output, + output.size(0), + output.size(1), + output.size(2), + output.size(3), + output.stride(0), + output.stride(1), + output.stride(2), + output.stride(3), + input1, + input1.size(1), + input1.size(2), + input1.size(3), + input1.stride(0), + input1.stride(1), + input1.stride(2), + input1.stride(3), + input2, + input2.size(1), + input2.stride(0), + input2.stride(1), + input2.stride(2), + input2.stride(3), + rInput1, + rInput2, + pad_size, + kernel_size, + max_displacement, + stride1, + stride2, + corr_type_multiply, + at::cuda::getCurrentCUDAStream() + //at::globalContext().getCurrentCUDAStream() + ); + + //check for errors + if (!success) { + AT_ERROR("CUDA call failed"); + } + + return 1; + +} + +int correlation_backward_cuda(at::Tensor& input1, at::Tensor& input2, at::Tensor& rInput1, at::Tensor& rInput2, at::Tensor& gradOutput, + at::Tensor& gradInput1, at::Tensor& gradInput2, + int pad_size, + int kernel_size, + int max_displacement, + int stride1, + int stride2, + int corr_type_multiply) +{ + + int batchSize = input1.size(0); + int nInputChannels = input1.size(1); + int paddedInputHeight = input1.size(2)+ 2 * pad_size; + int paddedInputWidth = input1.size(3)+ 2 * pad_size; + + int height = input1.size(2); + int width = input1.size(3); + + rInput1.resize_({batchSize, paddedInputHeight, paddedInputWidth, nInputChannels}); + rInput2.resize_({batchSize, paddedInputHeight, paddedInputWidth, nInputChannels}); + gradInput1.resize_({batchSize, nInputChannels, height, width}); + gradInput2.resize_({batchSize, nInputChannels, height, width}); + + rInput1.fill_(0); + rInput2.fill_(0); + gradInput1.fill_(0); + gradInput2.fill_(0); + + int success = correlation_backward_cuda_kernel(gradOutput, + gradOutput.size(0), + gradOutput.size(1), + gradOutput.size(2), + gradOutput.size(3), + gradOutput.stride(0), + gradOutput.stride(1), + gradOutput.stride(2), + gradOutput.stride(3), + input1, + input1.size(1), + input1.size(2), + input1.size(3), + input1.stride(0), + input1.stride(1), + input1.stride(2), + input1.stride(3), + input2, + input2.stride(0), + input2.stride(1), + input2.stride(2), + input2.stride(3), + gradInput1, + gradInput1.stride(0), + gradInput1.stride(1), + gradInput1.stride(2), + gradInput1.stride(3), + gradInput2, + gradInput2.size(1), + gradInput2.stride(0), + gradInput2.stride(1), + gradInput2.stride(2), + gradInput2.stride(3), + rInput1, + rInput2, + pad_size, + kernel_size, + max_displacement, + stride1, + stride2, + corr_type_multiply, + at::cuda::getCurrentCUDAStream() + //at::globalContext().getCurrentCUDAStream() + ); + + if (!success) { + AT_ERROR("CUDA call failed"); + } + + return 1; +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("forward", &correlation_forward_cuda, "Correlation forward (CUDA)"); + m.def("backward", &correlation_backward_cuda, "Correlation backward (CUDA)"); +} + diff --git a/windflow/external_packages/correlation_package/correlation_cuda_kernel.cu b/windflow/external_packages/correlation_package/correlation_cuda_kernel.cu new file mode 100755 index 0000000000000000000000000000000000000000..eaf86fc129137d055de7400916567c6669b45c19 --- /dev/null +++ b/windflow/external_packages/correlation_package/correlation_cuda_kernel.cu @@ -0,0 +1,564 @@ +#include <stdio.h> + +#include 
"correlation_cuda_kernel.cuh" + +#define CUDA_NUM_THREADS 1024 +#define THREADS_PER_BLOCK 32 +#define FULL_MASK 0xffffffff + +#include <ATen/ATen.h> +#include <ATen/NativeFunctions.h> +#include <ATen/Dispatch.h> +#include <ATen/cuda/CUDAApplyUtils.cuh> + +using at::Half; + +template<typename scalar_t> +__forceinline__ __device__ scalar_t warpReduceSum(scalar_t val) { + for (int offset = 16; offset > 0; offset /= 2) + val += __shfl_down_sync(FULL_MASK, val, offset); + return val; +} + +template<typename scalar_t> +__forceinline__ __device__ scalar_t blockReduceSum(scalar_t val) { + + static __shared__ scalar_t shared[32]; + int lane = threadIdx.x % warpSize; + int wid = threadIdx.x / warpSize; + + val = warpReduceSum(val); + + if (lane == 0) + shared[wid] = val; + + __syncthreads(); + + val = (threadIdx.x < blockDim.x / warpSize) ? shared[lane] : 0; + + if (wid == 0) + val = warpReduceSum(val); + + return val; +} + + +template <typename scalar_t> +__global__ void channels_first(const scalar_t* __restrict__ input, scalar_t* rinput, int channels, int height, int width, int pad_size) +{ + + // n (batch size), c (num of channels), y (height), x (width) + int n = blockIdx.x; + int y = blockIdx.y; + int x = blockIdx.z; + + int ch_off = threadIdx.x; + scalar_t value; + + int dimcyx = channels * height * width; + int dimyx = height * width; + + int p_dimx = (width + 2 * pad_size); + int p_dimy = (height + 2 * pad_size); + int p_dimyxc = channels * p_dimy * p_dimx; + int p_dimxc = p_dimx * channels; + + for (int c = ch_off; c < channels; c += THREADS_PER_BLOCK) { + value = input[n * dimcyx + c * dimyx + y * width + x]; + rinput[n * p_dimyxc + (y + pad_size) * p_dimxc + (x + pad_size) * channels + c] = value; + } +} + + +template<typename scalar_t> +__global__ void correlation_forward(scalar_t* __restrict__ output, const int nOutputChannels, + const int outputHeight, const int outputWidth, const scalar_t* __restrict__ rInput1, + const int nInputChannels, const int inputHeight, const int inputWidth, + const scalar_t* __restrict__ rInput2, const int pad_size, const int kernel_size, + const int max_displacement, const int stride1, const int stride2) { + + int32_t pInputWidth = inputWidth + 2 * pad_size; + int32_t pInputHeight = inputHeight + 2 * pad_size; + + int32_t kernel_rad = (kernel_size - 1) / 2; + + int32_t displacement_rad = max_displacement / stride2; + + int32_t displacement_size = 2 * displacement_rad + 1; + + int32_t n = blockIdx.x; + int32_t y1 = blockIdx.y * stride1 + max_displacement; + int32_t x1 = blockIdx.z * stride1 + max_displacement; + int32_t c = threadIdx.x; + + int32_t pdimyxc = pInputHeight * pInputWidth * nInputChannels; + + int32_t pdimxc = pInputWidth * nInputChannels; + + int32_t pdimc = nInputChannels; + + int32_t tdimcyx = nOutputChannels * outputHeight * outputWidth; + int32_t tdimyx = outputHeight * outputWidth; + int32_t tdimx = outputWidth; + + int32_t nelems = kernel_size * kernel_size * pdimc; + + // element-wise product along channel axis + for (int tj = -displacement_rad; tj <= displacement_rad; ++tj) { + for (int ti = -displacement_rad; ti <= displacement_rad; ++ti) { + int x2 = x1 + ti * stride2; + int y2 = y1 + tj * stride2; + + float acc0 = 0.0f; + + for (int j = -kernel_rad; j <= kernel_rad; ++j) { + for (int i = -kernel_rad; i <= kernel_rad; ++i) { + // THREADS_PER_BLOCK + #pragma unroll + for (int ch = c; ch < pdimc; ch += blockDim.x) { + + int indx1 = n * pdimyxc + (y1 + j) * pdimxc + + (x1 + i) * pdimc + ch; + int indx2 = n * pdimyxc + (y2 + j) * pdimxc + 
+ (x2 + i) * pdimc + ch; + acc0 += static_cast<float>(rInput1[indx1] * rInput2[indx2]); + } + } + } + + if (blockDim.x == warpSize) { + __syncwarp(); + acc0 = warpReduceSum(acc0); + } else { + __syncthreads(); + acc0 = blockReduceSum(acc0); + } + + if (threadIdx.x == 0) { + + int tc = (tj + displacement_rad) * displacement_size + + (ti + displacement_rad); + const int tindx = n * tdimcyx + tc * tdimyx + blockIdx.y * tdimx + + blockIdx.z; + output[tindx] = static_cast<scalar_t>(acc0 / nelems); + } + } + } +} + + +template <typename scalar_t> +__global__ void correlation_backward_input1(int item, scalar_t* gradInput1, int nInputChannels, int inputHeight, int inputWidth, + const scalar_t* __restrict__ gradOutput, int nOutputChannels, int outputHeight, int outputWidth, + const scalar_t* __restrict__ rInput2, + int pad_size, + int kernel_size, + int max_displacement, + int stride1, + int stride2) + { + // n (batch size), c (num of channels), y (height), x (width) + + int n = item; + int y = blockIdx.x * stride1 + pad_size; + int x = blockIdx.y * stride1 + pad_size; + int c = blockIdx.z; + int tch_off = threadIdx.x; + + int kernel_rad = (kernel_size - 1) / 2; + int displacement_rad = max_displacement / stride2; + int displacement_size = 2 * displacement_rad + 1; + + int xmin = (x - kernel_rad - max_displacement) / stride1; + int ymin = (y - kernel_rad - max_displacement) / stride1; + + int xmax = (x + kernel_rad - max_displacement) / stride1; + int ymax = (y + kernel_rad - max_displacement) / stride1; + + if (xmax < 0 || ymax < 0 || xmin >= outputWidth || ymin >= outputHeight) { + // assumes gradInput1 is pre-allocated and zero filled + return; + } + + if (xmin > xmax || ymin > ymax) { + // assumes gradInput1 is pre-allocated and zero filled + return; + } + + xmin = max(0,xmin); + xmax = min(outputWidth-1,xmax); + + ymin = max(0,ymin); + ymax = min(outputHeight-1,ymax); + + int pInputWidth = inputWidth + 2 * pad_size; + int pInputHeight = inputHeight + 2 * pad_size; + + int pdimyxc = pInputHeight * pInputWidth * nInputChannels; + int pdimxc = pInputWidth * nInputChannels; + int pdimc = nInputChannels; + + int tdimcyx = nOutputChannels * outputHeight * outputWidth; + int tdimyx = outputHeight * outputWidth; + int tdimx = outputWidth; + + int odimcyx = nInputChannels * inputHeight* inputWidth; + int odimyx = inputHeight * inputWidth; + int odimx = inputWidth; + + scalar_t nelems = kernel_size * kernel_size * nInputChannels; + + __shared__ scalar_t prod_sum[THREADS_PER_BLOCK]; + prod_sum[tch_off] = 0; + + for (int tc = tch_off; tc < nOutputChannels; tc += THREADS_PER_BLOCK) { + + int i2 = (tc % displacement_size - displacement_rad) * stride2; + int j2 = (tc / displacement_size - displacement_rad) * stride2; + + int indx2 = n * pdimyxc + (y + j2)* pdimxc + (x + i2) * pdimc + c; + + scalar_t val2 = rInput2[indx2]; + + for (int j = ymin; j <= ymax; ++j) { + for (int i = xmin; i <= xmax; ++i) { + int tindx = n * tdimcyx + tc * tdimyx + j * tdimx + i; + prod_sum[tch_off] += gradOutput[tindx] * val2; + } + } + } + __syncthreads(); + + if(tch_off == 0) { + scalar_t reduce_sum = 0; + for(int idx = 0; idx < THREADS_PER_BLOCK; idx++) { + reduce_sum += prod_sum[idx]; + } + const int indx1 = n * odimcyx + c * odimyx + (y - pad_size) * odimx + (x - pad_size); + gradInput1[indx1] = reduce_sum / nelems; + } + +} + +template <typename scalar_t> +__global__ void correlation_backward_input2(int item, scalar_t* gradInput2, int nInputChannels, int inputHeight, int inputWidth, + const scalar_t* __restrict__ gradOutput, 
int nOutputChannels, int outputHeight, int outputWidth, + const scalar_t* __restrict__ rInput1, + int pad_size, + int kernel_size, + int max_displacement, + int stride1, + int stride2) +{ + // n (batch size), c (num of channels), y (height), x (width) + + int n = item; + int y = blockIdx.x * stride1 + pad_size; + int x = blockIdx.y * stride1 + pad_size; + int c = blockIdx.z; + + int tch_off = threadIdx.x; + + int kernel_rad = (kernel_size - 1) / 2; + int displacement_rad = max_displacement / stride2; + int displacement_size = 2 * displacement_rad + 1; + + int pInputWidth = inputWidth + 2 * pad_size; + int pInputHeight = inputHeight + 2 * pad_size; + + int pdimyxc = pInputHeight * pInputWidth * nInputChannels; + int pdimxc = pInputWidth * nInputChannels; + int pdimc = nInputChannels; + + int tdimcyx = nOutputChannels * outputHeight * outputWidth; + int tdimyx = outputHeight * outputWidth; + int tdimx = outputWidth; + + int odimcyx = nInputChannels * inputHeight* inputWidth; + int odimyx = inputHeight * inputWidth; + int odimx = inputWidth; + + scalar_t nelems = kernel_size * kernel_size * nInputChannels; + + __shared__ scalar_t prod_sum[THREADS_PER_BLOCK]; + prod_sum[tch_off] = 0; + + for (int tc = tch_off; tc < nOutputChannels; tc += THREADS_PER_BLOCK) { + int i2 = (tc % displacement_size - displacement_rad) * stride2; + int j2 = (tc / displacement_size - displacement_rad) * stride2; + + int xmin = (x - kernel_rad - max_displacement - i2) / stride1; + int ymin = (y - kernel_rad - max_displacement - j2) / stride1; + + int xmax = (x + kernel_rad - max_displacement - i2) / stride1; + int ymax = (y + kernel_rad - max_displacement - j2) / stride1; + + if (xmax < 0 || ymax < 0 || xmin >= outputWidth || ymin >= outputHeight) { + // assumes gradInput2 is pre-allocated and zero filled + continue; + } + + if (xmin > xmax || ymin > ymax) { + // assumes gradInput2 is pre-allocated and zero filled + continue; + } + + xmin = max(0,xmin); + xmax = min(outputWidth-1,xmax); + + ymin = max(0,ymin); + ymax = min(outputHeight-1,ymax); + + int indx1 = n * pdimyxc + (y - j2)* pdimxc + (x - i2) * pdimc + c; + scalar_t val1 = rInput1[indx1]; + + for (int j = ymin; j <= ymax; ++j) { + for (int i = xmin; i <= xmax; ++i) { + int tindx = n * tdimcyx + tc * tdimyx + j * tdimx + i; + prod_sum[tch_off] += gradOutput[tindx] * val1; + } + } + } + + __syncthreads(); + + if(tch_off == 0) { + scalar_t reduce_sum = 0; + for(int idx = 0; idx < THREADS_PER_BLOCK; idx++) { + reduce_sum += prod_sum[idx]; + } + const int indx2 = n * odimcyx + c * odimyx + (y - pad_size) * odimx + (x - pad_size); + gradInput2[indx2] = reduce_sum / nelems; + } + +} + +int correlation_forward_cuda_kernel(at::Tensor& output, + int ob, + int oc, + int oh, + int ow, + int osb, + int osc, + int osh, + int osw, + + at::Tensor& input1, + int ic, + int ih, + int iw, + int isb, + int isc, + int ish, + int isw, + + at::Tensor& input2, + int gc, + int gsb, + int gsc, + int gsh, + int gsw, + + at::Tensor& rInput1, + at::Tensor& rInput2, + int pad_size, + int kernel_size, + int max_displacement, + int stride1, + int stride2, + int corr_type_multiply, + cudaStream_t stream) +{ + + int batchSize = ob; + + int nInputChannels = ic; + int inputWidth = iw; + int inputHeight = ih; + + int nOutputChannels = oc; + int outputWidth = ow; + int outputHeight = oh; + + dim3 blocks_grid(batchSize, inputHeight, inputWidth); + dim3 threads_block(THREADS_PER_BLOCK); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF(input1.type(), "channels_first_fwd_1", ([&] { + + 
channels_first<scalar_t><<<blocks_grid,threads_block, 0, stream>>>( + input1.data<scalar_t>(), rInput1.data<scalar_t>(), nInputChannels, inputHeight, inputWidth, pad_size); + + })); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF(input2.type(), "channels_first_fwd_2", ([&] { + + channels_first<scalar_t><<<blocks_grid,threads_block, 0, stream>>> ( + input2.data<scalar_t>(), rInput2.data<scalar_t>(), nInputChannels, inputHeight, inputWidth, pad_size); + + })); + + dim3 threadsPerBlock(THREADS_PER_BLOCK); + dim3 totalBlocksCorr(batchSize, outputHeight, outputWidth); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF(input1.type(), "correlation_forward", ([&] { + + correlation_forward<scalar_t><<<totalBlocksCorr, threadsPerBlock, 0, stream>>> + (output.data<scalar_t>(), nOutputChannels, outputHeight, outputWidth, + rInput1.data<scalar_t>(), nInputChannels, inputHeight, inputWidth, + rInput2.data<scalar_t>(), + pad_size, + kernel_size, + max_displacement, + stride1, + stride2); + + })); + + cudaError_t err = cudaGetLastError(); + + + // check for errors + if (err != cudaSuccess) { + printf("error in correlation_forward_cuda_kernel: %s\n", cudaGetErrorString(err)); + return 0; + } + + return 1; +} + + +int correlation_backward_cuda_kernel( + at::Tensor& gradOutput, + int gob, + int goc, + int goh, + int gow, + int gosb, + int gosc, + int gosh, + int gosw, + + at::Tensor& input1, + int ic, + int ih, + int iw, + int isb, + int isc, + int ish, + int isw, + + at::Tensor& input2, + int gsb, + int gsc, + int gsh, + int gsw, + + at::Tensor& gradInput1, + int gisb, + int gisc, + int gish, + int gisw, + + at::Tensor& gradInput2, + int ggc, + int ggsb, + int ggsc, + int ggsh, + int ggsw, + + at::Tensor& rInput1, + at::Tensor& rInput2, + int pad_size, + int kernel_size, + int max_displacement, + int stride1, + int stride2, + int corr_type_multiply, + cudaStream_t stream) +{ + + int batchSize = gob; + int num = batchSize; + + int nInputChannels = ic; + int inputWidth = iw; + int inputHeight = ih; + + int nOutputChannels = goc; + int outputWidth = gow; + int outputHeight = goh; + + dim3 blocks_grid(batchSize, inputHeight, inputWidth); + dim3 threads_block(THREADS_PER_BLOCK); + + + AT_DISPATCH_FLOATING_TYPES_AND_HALF(input1.type(), "lltm_forward_cuda", ([&] { + + channels_first<scalar_t><<<blocks_grid, threads_block, 0, stream>>>( + input1.data<scalar_t>(), + rInput1.data<scalar_t>(), + nInputChannels, + inputHeight, + inputWidth, + pad_size + ); + })); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF(input2.type(), "lltm_forward_cuda", ([&] { + + channels_first<scalar_t><<<blocks_grid, threads_block, 0, stream>>>( + input2.data<scalar_t>(), + rInput2.data<scalar_t>(), + nInputChannels, + inputHeight, + inputWidth, + pad_size + ); + })); + + dim3 threadsPerBlock(THREADS_PER_BLOCK); + dim3 totalBlocksCorr(inputHeight, inputWidth, nInputChannels); + + for (int n = 0; n < num; ++n) { + + AT_DISPATCH_FLOATING_TYPES_AND_HALF(input2.type(), "lltm_forward_cuda", ([&] { + + + correlation_backward_input1<scalar_t><<<totalBlocksCorr, threadsPerBlock, 0, stream>>> ( + n, gradInput1.data<scalar_t>(), nInputChannels, inputHeight, inputWidth, + gradOutput.data<scalar_t>(), nOutputChannels, outputHeight, outputWidth, + rInput2.data<scalar_t>(), + pad_size, + kernel_size, + max_displacement, + stride1, + stride2); + })); + } + + for(int n = 0; n < batchSize; n++) { + + AT_DISPATCH_FLOATING_TYPES_AND_HALF(rInput1.type(), "lltm_forward_cuda", ([&] { + + correlation_backward_input2<scalar_t><<<totalBlocksCorr, threadsPerBlock, 0, stream>>>( + n, 
gradInput2.data<scalar_t>(), nInputChannels, inputHeight, inputWidth, + gradOutput.data<scalar_t>(), nOutputChannels, outputHeight, outputWidth, + rInput1.data<scalar_t>(), + pad_size, + kernel_size, + max_displacement, + stride1, + stride2); + + })); + } + + // check for errors + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) { + printf("error in correlation_backward_cuda_kernel: %s\n", cudaGetErrorString(err)); + return 0; + } + + return 1; +} diff --git a/windflow/external_packages/correlation_package/correlation_cuda_kernel.cuh b/windflow/external_packages/correlation_package/correlation_cuda_kernel.cuh new file mode 100755 index 0000000000000000000000000000000000000000..1586d3af6bc184bfea8482a991a6625f865f02b3 --- /dev/null +++ b/windflow/external_packages/correlation_package/correlation_cuda_kernel.cuh @@ -0,0 +1,91 @@ +#pragma once + +#include <ATen/ATen.h> +#include <ATen/Context.h> +#include <cuda_runtime.h> + +int correlation_forward_cuda_kernel(at::Tensor& output, + int ob, + int oc, + int oh, + int ow, + int osb, + int osc, + int osh, + int osw, + + at::Tensor& input1, + int ic, + int ih, + int iw, + int isb, + int isc, + int ish, + int isw, + + at::Tensor& input2, + int gc, + int gsb, + int gsc, + int gsh, + int gsw, + + at::Tensor& rInput1, + at::Tensor& rInput2, + int pad_size, + int kernel_size, + int max_displacement, + int stride1, + int stride2, + int corr_type_multiply, + cudaStream_t stream); + + +int correlation_backward_cuda_kernel( + at::Tensor& gradOutput, + int gob, + int goc, + int goh, + int gow, + int gosb, + int gosc, + int gosh, + int gosw, + + at::Tensor& input1, + int ic, + int ih, + int iw, + int isb, + int isc, + int ish, + int isw, + + at::Tensor& input2, + int gsb, + int gsc, + int gsh, + int gsw, + + at::Tensor& gradInput1, + int gisb, + int gisc, + int gish, + int gisw, + + at::Tensor& gradInput2, + int ggc, + int ggsb, + int ggsc, + int ggsh, + int ggsw, + + at::Tensor& rInput1, + at::Tensor& rInput2, + int pad_size, + int kernel_size, + int max_displacement, + int stride1, + int stride2, + int corr_type_multiply, + cudaStream_t stream); diff --git a/windflow/external_packages/correlation_package/cuda/deform_conv_cuda.cu b/windflow/external_packages/correlation_package/cuda/deform_conv_cuda.cu new file mode 100644 index 0000000000000000000000000000000000000000..087adfe9feb631c05d2872967394b218af946abf --- /dev/null +++ b/windflow/external_packages/correlation_package/cuda/deform_conv_cuda.cu @@ -0,0 +1,691 @@ +// modify from +// https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda.c + +#include <ATen/ATen.h> +#include <ATen/cuda/CUDAContext.h> + +#include <THC/THC.h> +#include <THC/THCDeviceUtils.cuh> + +#include <vector> +#include <iostream> +#include <cmath> + + +void deformable_im2col(const at::Tensor data_im, const at::Tensor data_offset, + const int channels, const int height, const int width, + const int ksize_h, const int ksize_w, const int pad_h, + const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int parallel_imgs, const int deformable_group, + at::Tensor data_col); + +void deformable_col2im(const at::Tensor data_col, const at::Tensor data_offset, + const int channels, const int height, const int width, + const int ksize_h, const int ksize_w, const int pad_h, + const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int parallel_imgs, const int 
deformable_group, + at::Tensor grad_im); + +void deformable_col2im_coord( + const at::Tensor data_col, const at::Tensor data_im, + const at::Tensor data_offset, const int channels, const int height, + const int width, const int ksize_h, const int ksize_w, const int pad_h, + const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, const int parallel_imgs, + const int deformable_group, at::Tensor grad_offset); + +void modulated_deformable_im2col_cuda( + const at::Tensor data_im, const at::Tensor data_offset, + const at::Tensor data_mask, const int batch_size, const int channels, + const int height_im, const int width_im, const int height_col, + const int width_col, const int kernel_h, const int kenerl_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, const int deformable_group, + at::Tensor data_col); + +void modulated_deformable_col2im_cuda( + const at::Tensor data_col, const at::Tensor data_offset, + const at::Tensor data_mask, const int batch_size, const int channels, + const int height_im, const int width_im, const int height_col, + const int width_col, const int kernel_h, const int kenerl_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, const int deformable_group, + at::Tensor grad_im); + +void modulated_deformable_col2im_coord_cuda( + const at::Tensor data_col, const at::Tensor data_im, + const at::Tensor data_offset, const at::Tensor data_mask, + const int batch_size, const int channels, const int height_im, + const int width_im, const int height_col, const int width_col, + const int kernel_h, const int kenerl_w, const int pad_h, const int pad_w, + const int stride_h, const int stride_w, const int dilation_h, + const int dilation_w, const int deformable_group, at::Tensor grad_offset, + at::Tensor grad_mask); + +void shape_check(at::Tensor input, at::Tensor offset, at::Tensor *gradOutput, + at::Tensor weight, int kH, int kW, int dH, int dW, int padH, + int padW, int dilationH, int dilationW, int group, + int deformable_group) +{ + TORCH_CHECK(weight.ndimension() == 4, + "4D weight tensor (nOutputPlane,nInputPlane,kH,kW) expected, " + "but got: %s", + weight.ndimension()); + + TORCH_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous"); + + TORCH_CHECK(kW > 0 && kH > 0, + "kernel size should be greater than zero, but got kH: %d kW: %d", kH, + kW); + + TORCH_CHECK((weight.size(2) == kH && weight.size(3) == kW), + "kernel size should be consistent with weight, ", + "but got kH: %d kW: %d weight.size(2): %d, weight.size(3): %d", kH, + kW, weight.size(2), weight.size(3)); + + TORCH_CHECK(dW > 0 && dH > 0, + "stride should be greater than zero, but got dH: %d dW: %d", dH, dW); + + TORCH_CHECK( + dilationW > 0 && dilationH > 0, + "dilation should be greater than 0, but got dilationH: %d dilationW: %d", + dilationH, dilationW); + + int ndim = input.ndimension(); + int dimf = 0; + int dimh = 1; + int dimw = 2; + + if (ndim == 4) { + dimf++; + dimh++; + dimw++; + } + + TORCH_CHECK(ndim == 3 || ndim == 4, "3D or 4D input tensor expected but got: %s", + ndim); + + long nInputPlane = weight.size(1) * group; + long inputHeight = input.size(dimh); + long inputWidth = input.size(dimw); + long nOutputPlane = weight.size(0); + long outputHeight = + (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1; + long outputWidth = + (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1; 
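+  // Standard convolution output size, e.g. a 64x64 input with pad 1, a 3x3 kernel,
+  // dilation 1 and stride 1 gives (64 + 2*1 - 3) / 1 + 1 = 64.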
+ + TORCH_CHECK(nInputPlane % deformable_group == 0, + "input channels must divide deformable group size"); + + if (outputWidth < 1 || outputHeight < 1) + AT_ERROR( + "Given input size: (%ld x %ld x %ld). " + "Calculated output size: (%ld x %ld x %ld). Output size is too small", + nInputPlane, inputHeight, inputWidth, nOutputPlane, outputHeight, + outputWidth); + + TORCH_CHECK(input.size(1) == nInputPlane, + "invalid number of input planes, expected: %d, but got: %d", + nInputPlane, input.size(1)); + + TORCH_CHECK((inputHeight >= kH && inputWidth >= kW), + "input image is smaller than kernel"); + + TORCH_CHECK((offset.size(2) == outputHeight && offset.size(3) == outputWidth), + "invalid spatial size of offset, expected height: %d width: %d, but " + "got height: %d width: %d", + outputHeight, outputWidth, offset.size(2), offset.size(3)); + + TORCH_CHECK((offset.size(1) == deformable_group * 2 * kH * kW), + "invalid number of channels of offset"); + + if (gradOutput != NULL) { + TORCH_CHECK(gradOutput->size(dimf) == nOutputPlane, + "invalid number of gradOutput planes, expected: %d, but got: %d", + nOutputPlane, gradOutput->size(dimf)); + + TORCH_CHECK((gradOutput->size(dimh) == outputHeight && + gradOutput->size(dimw) == outputWidth), + "invalid size of gradOutput, expected height: %d width: %d , but " + "got height: %d width: %d", + outputHeight, outputWidth, gradOutput->size(dimh), + gradOutput->size(dimw)); + } +} + +int deform_conv_forward_cuda(at::Tensor input, at::Tensor weight, + at::Tensor offset, at::Tensor output, + at::Tensor columns, at::Tensor ones, int kW, + int kH, int dW, int dH, int padW, int padH, + int dilationW, int dilationH, int group, + int deformable_group, int im2col_step) +{ + // todo: resize columns to include im2col: done + // todo: add im2col_step as input + // todo: add new output buffer and transpose it to output (or directly + // transpose output) todo: possibly change data indexing because of + // parallel_imgs + + shape_check(input, offset, NULL, weight, kH, kW, dH, dW, padH, padW, + dilationH, dilationW, group, deformable_group); + + input = input.contiguous(); + offset = offset.contiguous(); + weight = weight.contiguous(); + + int batch = 1; + if (input.ndimension() == 3) { + // Force batch + batch = 0; + input.unsqueeze_(0); + offset.unsqueeze_(0); + } + + // todo: assert batchsize dividable by im2col_step + + long batchSize = input.size(0); + long nInputPlane = input.size(1); + long inputHeight = input.size(2); + long inputWidth = input.size(3); + + long nOutputPlane = weight.size(0); + + long outputWidth = + (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1; + long outputHeight = + (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1; + + TORCH_CHECK((offset.size(0) == batchSize), "invalid batch size of offset"); + + output = output.view({batchSize / im2col_step, im2col_step, nOutputPlane, + outputHeight, outputWidth}); + columns = at::zeros( + {nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth}, + input.options()); + + if (ones.ndimension() != 2 || + ones.size(0) * ones.size(1) < outputHeight * outputWidth) { + ones = at::ones({outputHeight, outputWidth}, input.options()); + } + + input = input.view({batchSize / im2col_step, im2col_step, nInputPlane, + inputHeight, inputWidth}); + offset = + offset.view({batchSize / im2col_step, im2col_step, + deformable_group * 2 * kH * kW, outputHeight, outputWidth}); + + at::Tensor output_buffer = + at::zeros({batchSize / im2col_step, nOutputPlane, + im2col_step * outputHeight, 
outputWidth}, + output.options()); + + output_buffer = output_buffer.view( + {output_buffer.size(0), group, output_buffer.size(1) / group, + output_buffer.size(2), output_buffer.size(3)}); + + for (int elt = 0; elt < batchSize / im2col_step; elt++) { + deformable_im2col(input[elt], offset[elt], nInputPlane, inputHeight, + inputWidth, kH, kW, padH, padW, dH, dW, dilationH, + dilationW, im2col_step, deformable_group, columns); + + columns = columns.view({group, columns.size(0) / group, columns.size(1)}); + weight = weight.view({group, weight.size(0) / group, weight.size(1), + weight.size(2), weight.size(3)}); + + for (int g = 0; g < group; g++) { + output_buffer[elt][g] = output_buffer[elt][g] + .flatten(1) + .addmm_(weight[g].flatten(1), columns[g]) + .view_as(output_buffer[elt][g]); + } + } + + output_buffer = output_buffer.view( + {output_buffer.size(0), output_buffer.size(1) * output_buffer.size(2), + output_buffer.size(3), output_buffer.size(4)}); + + output_buffer = output_buffer.view({batchSize / im2col_step, nOutputPlane, + im2col_step, outputHeight, outputWidth}); + output_buffer.transpose_(1, 2); + output.copy_(output_buffer); + output = output.view({batchSize, nOutputPlane, outputHeight, outputWidth}); + + input = input.view({batchSize, nInputPlane, inputHeight, inputWidth}); + offset = offset.view( + {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth}); + + if (batch == 0) { + output = output.view({nOutputPlane, outputHeight, outputWidth}); + input = input.view({nInputPlane, inputHeight, inputWidth}); + offset = offset.view({offset.size(1), offset.size(2), offset.size(3)}); + } + + return 1; +} + +int deform_conv_backward_input_cuda(at::Tensor input, at::Tensor offset, + at::Tensor gradOutput, at::Tensor gradInput, + at::Tensor gradOffset, at::Tensor weight, + at::Tensor columns, int kW, int kH, int dW, + int dH, int padW, int padH, int dilationW, + int dilationH, int group, + int deformable_group, int im2col_step) +{ + shape_check(input, offset, &gradOutput, weight, kH, kW, dH, dW, padH, padW, + dilationH, dilationW, group, deformable_group); + + input = input.contiguous(); + offset = offset.contiguous(); + gradOutput = gradOutput.contiguous(); + weight = weight.contiguous(); + + int batch = 1; + + if (input.ndimension() == 3) { + // Force batch + batch = 0; + input = input.view({1, input.size(0), input.size(1), input.size(2)}); + offset = offset.view({1, offset.size(0), offset.size(1), offset.size(2)}); + gradOutput = gradOutput.view( + {1, gradOutput.size(0), gradOutput.size(1), gradOutput.size(2)}); + } + + long batchSize = input.size(0); + long nInputPlane = input.size(1); + long inputHeight = input.size(2); + long inputWidth = input.size(3); + + long nOutputPlane = weight.size(0); + + long outputWidth = + (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1; + long outputHeight = + (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1; + + TORCH_CHECK((offset.size(0) == batchSize), 3, "invalid batch size of offset"); + gradInput = gradInput.view({batchSize, nInputPlane, inputHeight, inputWidth}); + columns = at::zeros( + {nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth}, + input.options()); + + // change order of grad output + gradOutput = gradOutput.view({batchSize / im2col_step, im2col_step, + nOutputPlane, outputHeight, outputWidth}); + gradOutput.transpose_(1, 2); + + gradInput = gradInput.view({batchSize / im2col_step, im2col_step, nInputPlane, + inputHeight, inputWidth}); + input = input.view({batchSize / 
im2col_step, im2col_step, nInputPlane, + inputHeight, inputWidth}); + gradOffset = gradOffset.view({batchSize / im2col_step, im2col_step, + deformable_group * 2 * kH * kW, outputHeight, + outputWidth}); + offset = + offset.view({batchSize / im2col_step, im2col_step, + deformable_group * 2 * kH * kW, outputHeight, outputWidth}); + + for (int elt = 0; elt < batchSize / im2col_step; elt++) { + // divide into groups + columns = columns.view({group, columns.size(0) / group, columns.size(1)}); + weight = weight.view({group, weight.size(0) / group, weight.size(1), + weight.size(2), weight.size(3)}); + gradOutput = gradOutput.view( + {gradOutput.size(0), group, gradOutput.size(1) / group, + gradOutput.size(2), gradOutput.size(3), gradOutput.size(4)}); + + for (int g = 0; g < group; g++) { + columns[g] = columns[g].addmm_(weight[g].flatten(1).transpose(0, 1), + gradOutput[elt][g].flatten(1), 0.0f, 1.0f); + } + + columns = + columns.view({columns.size(0) * columns.size(1), columns.size(2)}); + gradOutput = gradOutput.view( + {gradOutput.size(0), gradOutput.size(1) * gradOutput.size(2), + gradOutput.size(3), gradOutput.size(4), gradOutput.size(5)}); + + deformable_col2im_coord(columns, input[elt], offset[elt], nInputPlane, + inputHeight, inputWidth, kH, kW, padH, padW, dH, dW, + dilationH, dilationW, im2col_step, deformable_group, + gradOffset[elt]); + + deformable_col2im(columns, offset[elt], nInputPlane, inputHeight, + inputWidth, kH, kW, padH, padW, dH, dW, dilationH, + dilationW, im2col_step, deformable_group, gradInput[elt]); + } + + gradOutput.transpose_(1, 2); + gradOutput = + gradOutput.view({batchSize, nOutputPlane, outputHeight, outputWidth}); + + gradInput = gradInput.view({batchSize, nInputPlane, inputHeight, inputWidth}); + input = input.view({batchSize, nInputPlane, inputHeight, inputWidth}); + gradOffset = gradOffset.view( + {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth}); + offset = offset.view( + {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth}); + + if (batch == 0) { + gradOutput = gradOutput.view({nOutputPlane, outputHeight, outputWidth}); + input = input.view({nInputPlane, inputHeight, inputWidth}); + gradInput = gradInput.view({nInputPlane, inputHeight, inputWidth}); + offset = offset.view({offset.size(1), offset.size(2), offset.size(3)}); + gradOffset = + gradOffset.view({offset.size(1), offset.size(2), offset.size(3)}); + } + + return 1; +} + +int deform_conv_backward_parameters_cuda( + at::Tensor input, at::Tensor offset, at::Tensor gradOutput, + at::Tensor gradWeight, // at::Tensor gradBias, + at::Tensor columns, at::Tensor ones, int kW, int kH, int dW, int dH, + int padW, int padH, int dilationW, int dilationH, int group, + int deformable_group, float scale, int im2col_step) +{ + // todo: transpose and reshape outGrad + // todo: reshape columns + // todo: add im2col_step as input + + shape_check(input, offset, &gradOutput, gradWeight, kH, kW, dH, dW, padH, + padW, dilationH, dilationW, group, deformable_group); + + input = input.contiguous(); + offset = offset.contiguous(); + gradOutput = gradOutput.contiguous(); + + int batch = 1; + + if (input.ndimension() == 3) { + // Force batch + batch = 0; + input = input.view( + at::IntList({1, input.size(0), input.size(1), input.size(2)})); + gradOutput = gradOutput.view( + {1, gradOutput.size(0), gradOutput.size(1), gradOutput.size(2)}); + } + + long batchSize = input.size(0); + long nInputPlane = input.size(1); + long inputHeight = input.size(2); + long inputWidth = input.size(3); + + 
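+  // Note: gradWeight is accumulated rather than overwritten in the loop below; each
+  // im2col_step chunk adds grad_output * columns^T scaled by `scale` (addmm_ with beta = 1, alpha = scale).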
long nOutputPlane = gradWeight.size(0); + + long outputWidth = + (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1; + long outputHeight = + (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1; + + TORCH_CHECK((offset.size(0) == batchSize), "invalid batch size of offset"); + + columns = at::zeros( + {nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth}, + input.options()); + + gradOutput = gradOutput.view({batchSize / im2col_step, im2col_step, + nOutputPlane, outputHeight, outputWidth}); + gradOutput.transpose_(1, 2); + + at::Tensor gradOutputBuffer = at::zeros_like(gradOutput); + gradOutputBuffer = + gradOutputBuffer.view({batchSize / im2col_step, nOutputPlane, im2col_step, + outputHeight, outputWidth}); + gradOutputBuffer.copy_(gradOutput); + gradOutputBuffer = + gradOutputBuffer.view({batchSize / im2col_step, nOutputPlane, + im2col_step * outputHeight, outputWidth}); + + gradOutput.transpose_(1, 2); + gradOutput = + gradOutput.view({batchSize, nOutputPlane, outputHeight, outputWidth}); + + input = input.view({batchSize / im2col_step, im2col_step, nInputPlane, + inputHeight, inputWidth}); + offset = + offset.view({batchSize / im2col_step, im2col_step, + deformable_group * 2 * kH * kW, outputHeight, outputWidth}); + + for (int elt = 0; elt < batchSize / im2col_step; elt++) { + deformable_im2col(input[elt], offset[elt], nInputPlane, inputHeight, + inputWidth, kH, kW, padH, padW, dH, dW, dilationH, + dilationW, im2col_step, deformable_group, columns); + + // divide into group + gradOutputBuffer = gradOutputBuffer.view( + {gradOutputBuffer.size(0), group, gradOutputBuffer.size(1) / group, + gradOutputBuffer.size(2), gradOutputBuffer.size(3)}); + columns = columns.view({group, columns.size(0) / group, columns.size(1)}); + gradWeight = + gradWeight.view({group, gradWeight.size(0) / group, gradWeight.size(1), + gradWeight.size(2), gradWeight.size(3)}); + + for (int g = 0; g < group; g++) { + gradWeight[g] = gradWeight[g] + .flatten(1) + .addmm_(gradOutputBuffer[elt][g].flatten(1), + columns[g].transpose(1, 0), 1.0, scale) + .view_as(gradWeight[g]); + } + gradOutputBuffer = gradOutputBuffer.view( + {gradOutputBuffer.size(0), + gradOutputBuffer.size(1) * gradOutputBuffer.size(2), + gradOutputBuffer.size(3), gradOutputBuffer.size(4)}); + columns = + columns.view({columns.size(0) * columns.size(1), columns.size(2)}); + gradWeight = gradWeight.view({gradWeight.size(0) * gradWeight.size(1), + gradWeight.size(2), gradWeight.size(3), + gradWeight.size(4)}); + } + + input = input.view({batchSize, nInputPlane, inputHeight, inputWidth}); + offset = offset.view( + {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth}); + + if (batch == 0) { + gradOutput = gradOutput.view({nOutputPlane, outputHeight, outputWidth}); + input = input.view({nInputPlane, inputHeight, inputWidth}); + } + + return 1; +} + +void modulated_deform_conv_cuda_forward( + at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones, + at::Tensor offset, at::Tensor mask, at::Tensor output, at::Tensor columns, + int kernel_h, int kernel_w, const int stride_h, const int stride_w, + const int pad_h, const int pad_w, const int dilation_h, + const int dilation_w, const int group, const int deformable_group, + const bool with_bias) +{ + TORCH_CHECK(input.is_contiguous(), "input tensor has to be contiguous"); + TORCH_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous"); + + const int batch = input.size(0); + const int channels = input.size(1); + const int height = 
input.size(2); + const int width = input.size(3); + + const int channels_out = weight.size(0); + const int channels_kernel = weight.size(1); + const int kernel_h_ = weight.size(2); + const int kernel_w_ = weight.size(3); + + if (kernel_h_ != kernel_h || kernel_w_ != kernel_w) + AT_ERROR("Input shape and kernel shape wont match: (%d x %d vs %d x %d).", + kernel_h_, kernel_w, kernel_h_, kernel_w_); + if (channels != channels_kernel * group) + AT_ERROR("Input shape and kernel channels wont match: (%d vs %d).", + channels, channels_kernel * group); + + const int height_out = + (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1; + const int width_out = + (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1; + + if (ones.ndimension() != 2 || + ones.size(0) * ones.size(1) < height_out * width_out) { + // Resize plane and fill with ones... + ones = at::ones({height_out, width_out}, input.options()); + } + + // resize output + output = output.view({batch, channels_out, height_out, width_out}).zero_(); + // resize temporary columns + columns = + at::zeros({channels * kernel_h * kernel_w, 1 * height_out * width_out}, + input.options()); + + output = output.view({output.size(0), group, output.size(1) / group, + output.size(2), output.size(3)}); + + for (int b = 0; b < batch; b++) { + modulated_deformable_im2col_cuda( + input[b], offset[b], mask[b], 1, channels, height, width, height_out, + width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, + dilation_h, dilation_w, deformable_group, columns); + + // divide into group + weight = weight.view({group, weight.size(0) / group, weight.size(1), + weight.size(2), weight.size(3)}); + columns = columns.view({group, columns.size(0) / group, columns.size(1)}); + + for (int g = 0; g < group; g++) { + output[b][g] = output[b][g] + .flatten(1) + .addmm_(weight[g].flatten(1), columns[g]) + .view_as(output[b][g]); + } + + weight = weight.view({weight.size(0) * weight.size(1), weight.size(2), + weight.size(3), weight.size(4)}); + columns = + columns.view({columns.size(0) * columns.size(1), columns.size(2)}); + } + + output = output.view({output.size(0), output.size(1) * output.size(2), + output.size(3), output.size(4)}); + + if (with_bias) { + output += bias.view({1, bias.size(0), 1, 1}); + } +} + +void modulated_deform_conv_cuda_backward( + at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones, + at::Tensor offset, at::Tensor mask, at::Tensor columns, + at::Tensor grad_input, at::Tensor grad_weight, at::Tensor grad_bias, + at::Tensor grad_offset, at::Tensor grad_mask, at::Tensor grad_output, + int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h, + int pad_w, int dilation_h, int dilation_w, int group, int deformable_group, + const bool with_bias) +{ + TORCH_CHECK(input.is_contiguous(), "input tensor has to be contiguous"); + TORCH_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous"); + + const int batch = input.size(0); + const int channels = input.size(1); + const int height = input.size(2); + const int width = input.size(3); + + const int channels_kernel = weight.size(1); + const int kernel_h_ = weight.size(2); + const int kernel_w_ = weight.size(3); + if (kernel_h_ != kernel_h || kernel_w_ != kernel_w) + AT_ERROR("Input shape and kernel shape wont match: (%d x %d vs %d x %d).", + kernel_h_, kernel_w, kernel_h_, kernel_w_); + if (channels != channels_kernel * group) + AT_ERROR("Input shape and kernel channels wont match: (%d vs %d).", + channels, channels_kernel * group); 
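+  // The per-sample loop below fills grad_input, grad_offset, grad_mask, grad_weight
+  // and (if with_bias) grad_bias in place, one batch element at a time.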
+ + const int height_out = + (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1; + const int width_out = + (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1; + + if (ones.ndimension() != 2 || + ones.size(0) * ones.size(1) < height_out * width_out) { + // Resize plane and fill with ones... + ones = at::ones({height_out, width_out}, input.options()); + } + + grad_input = grad_input.view({batch, channels, height, width}); + columns = at::zeros({channels * kernel_h * kernel_w, height_out * width_out}, + input.options()); + + grad_output = + grad_output.view({grad_output.size(0), group, grad_output.size(1) / group, + grad_output.size(2), grad_output.size(3)}); + + for (int b = 0; b < batch; b++) { + // divide int group + columns = columns.view({group, columns.size(0) / group, columns.size(1)}); + weight = weight.view({group, weight.size(0) / group, weight.size(1), + weight.size(2), weight.size(3)}); + + for (int g = 0; g < group; g++) { + columns[g].addmm_(weight[g].flatten(1).transpose(0, 1), + grad_output[b][g].flatten(1), 0.0f, 1.0f); + } + + columns = + columns.view({columns.size(0) * columns.size(1), columns.size(2)}); + weight = weight.view({weight.size(0) * weight.size(1), weight.size(2), + weight.size(3), weight.size(4)}); + + // gradient w.r.t. input coordinate data + modulated_deformable_col2im_coord_cuda( + columns, input[b], offset[b], mask[b], 1, channels, height, width, + height_out, width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, + stride_w, dilation_h, dilation_w, deformable_group, grad_offset[b], + grad_mask[b]); + // gradient w.r.t. input data + modulated_deformable_col2im_cuda( + columns, offset[b], mask[b], 1, channels, height, width, height_out, + width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, + dilation_h, dilation_w, deformable_group, grad_input[b]); + + // gradient w.r.t. 
weight, dWeight should accumulate across the batch and + // group + modulated_deformable_im2col_cuda( + input[b], offset[b], mask[b], 1, channels, height, width, height_out, + width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, + dilation_h, dilation_w, deformable_group, columns); + + columns = columns.view({group, columns.size(0) / group, columns.size(1)}); + grad_weight = grad_weight.view({group, grad_weight.size(0) / group, + grad_weight.size(1), grad_weight.size(2), + grad_weight.size(3)}); + if (with_bias) + grad_bias = grad_bias.view({group, grad_bias.size(0) / group}); + + for (int g = 0; g < group; g++) { + grad_weight[g] = + grad_weight[g] + .flatten(1) + .addmm_(grad_output[b][g].flatten(1), columns[g].transpose(0, 1)) + .view_as(grad_weight[g]); + if (with_bias) { + grad_bias[g] = + grad_bias[g] + .view({-1, 1}) + .addmm_(grad_output[b][g].flatten(1), ones.view({-1, 1})) + .view(-1); + } + } + + columns = + columns.view({columns.size(0) * columns.size(1), columns.size(2)}); + grad_weight = grad_weight.view({grad_weight.size(0) * grad_weight.size(1), + grad_weight.size(2), grad_weight.size(3), + grad_weight.size(4)}); + if (with_bias) + grad_bias = grad_bias.view({grad_bias.size(0) * grad_bias.size(1)}); + } + grad_output = grad_output.view({grad_output.size(0) * grad_output.size(1), + grad_output.size(2), grad_output.size(3), + grad_output.size(4)}); +} diff --git a/windflow/external_packages/correlation_package/cuda/deform_pool_cuda.cu b/windflow/external_packages/correlation_package/cuda/deform_pool_cuda.cu new file mode 100644 index 0000000000000000000000000000000000000000..0be4cb866a95f338e486cdf08e80d901cdee7e88 --- /dev/null +++ b/windflow/external_packages/correlation_package/cuda/deform_pool_cuda.cu @@ -0,0 +1,87 @@ +// modify from +// https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/modulated_dcn_cuda.c + +// based on +// author: Charles Shang +// https://github.com/torch/cunn/blob/master/lib/THCUNN/generic/SpatialConvolutionMM.cu + +#include <ATen/ATen.h> +#include <ATen/cuda/CUDAContext.h> + +#include <THC/THC.h> +#include <THC/THCDeviceUtils.cuh> + +#include <vector> +#include <iostream> +#include <cmath> + + +void DeformablePSROIPoolForward( + const at::Tensor data, const at::Tensor bbox, const at::Tensor trans, + at::Tensor out, at::Tensor top_count, const int batch, const int channels, + const int height, const int width, const int num_bbox, + const int channels_trans, const int no_trans, const float spatial_scale, + const int output_dim, const int group_size, const int pooled_size, + const int part_size, const int sample_per_part, const float trans_std); + +void DeformablePSROIPoolBackwardAcc( + const at::Tensor out_grad, const at::Tensor data, const at::Tensor bbox, + const at::Tensor trans, const at::Tensor top_count, at::Tensor in_grad, + at::Tensor trans_grad, const int batch, const int channels, + const int height, const int width, const int num_bbox, + const int channels_trans, const int no_trans, const float spatial_scale, + const int output_dim, const int group_size, const int pooled_size, + const int part_size, const int sample_per_part, const float trans_std); + +void deform_psroi_pooling_cuda_forward( + at::Tensor input, at::Tensor bbox, at::Tensor trans, at::Tensor out, + at::Tensor top_count, const int no_trans, const float spatial_scale, + const int output_dim, const int group_size, const int pooled_size, + const int part_size, const int sample_per_part, const float trans_std) +{ + 
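+  // Host-side wrapper: checks that the input is contiguous, derives the pooling shapes,
+  // and forwards the tensors to the DeformablePSROIPoolForward launcher.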
TORCH_CHECK(input.is_contiguous(), "input tensor has to be contiguous"); + + const int batch = input.size(0); + const int channels = input.size(1); + const int height = input.size(2); + const int width = input.size(3); + const int channels_trans = no_trans ? 2 : trans.size(1); + + const int num_bbox = bbox.size(0); + if (num_bbox != out.size(0)) + AT_ERROR("Output shape and bbox number wont match: (%d vs %d).", + out.size(0), num_bbox); + + DeformablePSROIPoolForward( + input, bbox, trans, out, top_count, batch, channels, height, width, + num_bbox, channels_trans, no_trans, spatial_scale, output_dim, group_size, + pooled_size, part_size, sample_per_part, trans_std); +} + +void deform_psroi_pooling_cuda_backward( + at::Tensor out_grad, at::Tensor input, at::Tensor bbox, at::Tensor trans, + at::Tensor top_count, at::Tensor input_grad, at::Tensor trans_grad, + const int no_trans, const float spatial_scale, const int output_dim, + const int group_size, const int pooled_size, const int part_size, + const int sample_per_part, const float trans_std) +{ + TORCH_CHECK(out_grad.is_contiguous(), "out_grad tensor has to be contiguous"); + TORCH_CHECK(input.is_contiguous(), "input tensor has to be contiguous"); + + const int batch = input.size(0); + const int channels = input.size(1); + const int height = input.size(2); + const int width = input.size(3); + const int channels_trans = no_trans ? 2 : trans.size(1); + + const int num_bbox = bbox.size(0); + if (num_bbox != out_grad.size(0)) + AT_ERROR("Output shape and bbox number wont match: (%d vs %d).", + out_grad.size(0), num_bbox); + + DeformablePSROIPoolBackwardAcc( + out_grad, input, bbox, trans, top_count, input_grad, trans_grad, batch, + channels, height, width, num_bbox, channels_trans, no_trans, + spatial_scale, output_dim, group_size, pooled_size, part_size, + sample_per_part, trans_std); +} diff --git a/windflow/external_packages/correlation_package/setup.py b/windflow/external_packages/correlation_package/setup.py new file mode 100755 index 0000000000000000000000000000000000000000..b852d94cdbacd3e6bdae451eb1be903523dee376 --- /dev/null +++ b/windflow/external_packages/correlation_package/setup.py @@ -0,0 +1,29 @@ +#!/usr/bin/env python3 +import os +import torch + +from setuptools import setup, find_packages +from torch.utils.cpp_extension import BuildExtension, CUDAExtension + +cxx_args = ['-std=c++11', '-std=c++14'] + +nvcc_args = [ + '-gencode', 'arch=compute_50,code=sm_50', + '-gencode', 'arch=compute_52,code=sm_52', + '-gencode', 'arch=compute_60,code=sm_60', + '-gencode', 'arch=compute_61,code=sm_61', + '-gencode', 'arch=compute_70,code=sm_70', + '-gencode', 'arch=compute_70,code=compute_70' +] + +setup( + name='correlation_cuda', + ext_modules=[ + CUDAExtension('correlation_cuda', [ + 'correlation_cuda.cc', + 'correlation_cuda_kernel.cu' + ], extra_compile_args={'cxx': cxx_args, 'nvcc': nvcc_args}) + ], + cmdclass={ + 'build_ext': BuildExtension + }) diff --git a/windflow/external_packages/resample2d_package/__init__.py b/windflow/external_packages/resample2d_package/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/windflow/external_packages/resample2d_package/resample2d.py b/windflow/external_packages/resample2d_package/resample2d.py new file mode 100755 index 0000000000000000000000000000000000000000..92ea0d01f3cd41c778d60a8b9073d5884a74bce0 --- /dev/null +++ b/windflow/external_packages/resample2d_package/resample2d.py @@ -0,0 +1,49 @@ +from 
torch.nn.modules.module import Module +from torch.autograd import Function, Variable +import resample2d_cuda + +class Resample2dFunction(Function): + + @staticmethod + def forward(ctx, input1, input2, kernel_size=1, bilinear= True): + assert input1.is_contiguous() + assert input2.is_contiguous() + + ctx.save_for_backward(input1, input2) + ctx.kernel_size = kernel_size + ctx.bilinear = bilinear + + _, d, _, _ = input1.size() + b, _, h, w = input2.size() + output = input1.new(b, d, h, w).zero_() + + resample2d_cuda.forward(input1, input2, output, kernel_size, bilinear) + + return output + + @staticmethod + def backward(ctx, grad_output): + grad_output = grad_output.contiguous() + assert grad_output.is_contiguous() + + input1, input2 = ctx.saved_tensors + + grad_input1 = Variable(input1.new(input1.size()).zero_()) + grad_input2 = Variable(input1.new(input2.size()).zero_()) + + resample2d_cuda.backward(input1, input2, grad_output.data, + grad_input1.data, grad_input2.data, + ctx.kernel_size, ctx.bilinear) + + return grad_input1, grad_input2, None, None + +class Resample2d(Module): + + def __init__(self, kernel_size=1, bilinear = True): + super(Resample2d, self).__init__() + self.kernel_size = kernel_size + self.bilinear = bilinear + + def forward(self, input1, input2): + input1_c = input1.contiguous() + return Resample2dFunction.apply(input1_c, input2, self.kernel_size, self.bilinear) diff --git a/windflow/external_packages/resample2d_package/resample2d_cuda.cc b/windflow/external_packages/resample2d_package/resample2d_cuda.cc new file mode 100755 index 0000000000000000000000000000000000000000..75cc6260dc6177f4b597a6d821d3a1deae0c21eb --- /dev/null +++ b/windflow/external_packages/resample2d_package/resample2d_cuda.cc @@ -0,0 +1,32 @@ +#include <ATen/ATen.h> +#include <torch/torch.h> + +#include "resample2d_kernel.cuh" + +int resample2d_cuda_forward( + at::Tensor& input1, + at::Tensor& input2, + at::Tensor& output, + int kernel_size, bool bilinear) { + resample2d_kernel_forward(input1, input2, output, kernel_size, bilinear); + return 1; +} + +int resample2d_cuda_backward( + at::Tensor& input1, + at::Tensor& input2, + at::Tensor& gradOutput, + at::Tensor& gradInput1, + at::Tensor& gradInput2, + int kernel_size, bool bilinear) { + resample2d_kernel_backward(input1, input2, gradOutput, gradInput1, gradInput2, kernel_size, bilinear); + return 1; +} + + + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("forward", &resample2d_cuda_forward, "Resample2D forward (CUDA)"); + m.def("backward", &resample2d_cuda_backward, "Resample2D backward (CUDA)"); +} + diff --git a/windflow/external_packages/resample2d_package/resample2d_kernel.cu b/windflow/external_packages/resample2d_package/resample2d_kernel.cu new file mode 100755 index 0000000000000000000000000000000000000000..a67430b0d78731a1abf18417726fd3da830bbb3e --- /dev/null +++ b/windflow/external_packages/resample2d_package/resample2d_kernel.cu @@ -0,0 +1,323 @@ +#include <ATen/ATen.h> +#include <ATen/Context.h> +#include <ATen/cuda/CUDAContext.h> + +#define CUDA_NUM_THREADS 512 +#define THREADS_PER_BLOCK 64 + +#define DIM0(TENSOR) ((TENSOR).x) +#define DIM1(TENSOR) ((TENSOR).y) +#define DIM2(TENSOR) ((TENSOR).z) +#define DIM3(TENSOR) ((TENSOR).w) + +#define DIM3_INDEX(TENSOR, xx, yy, zz, ww) ((TENSOR)[((xx) * (TENSOR##_stride.x)) + ((yy) * (TENSOR##_stride.y)) + ((zz) * (TENSOR##_stride.z)) + ((ww) * (TENSOR##_stride.w))]) + +template <typename scalar_t> +__global__ void kernel_resample2d_update_output(const int n, + const scalar_t* __restrict__ 
input1, const long4 input1_size, const long4 input1_stride, + const scalar_t* __restrict__ input2, const long4 input2_size, const long4 input2_stride, + scalar_t* __restrict__ output, const long4 output_size, const long4 output_stride, int kernel_size, bool bilinear) { + int index = blockIdx.x * blockDim.x + threadIdx.x; + + if (index >= n) { + return; + } + + scalar_t val = 0.0f; + + int dim_b = DIM0(output_size); + int dim_c = DIM1(output_size); + int dim_h = DIM2(output_size); + int dim_w = DIM3(output_size); + int dim_chw = dim_c * dim_h * dim_w; + int dim_hw = dim_h * dim_w; + + int b = ( index / dim_chw ) % dim_b; + int c = ( index / dim_hw ) % dim_c; + int y = ( index / dim_w ) % dim_h; + int x = ( index ) % dim_w; + + scalar_t dx = DIM3_INDEX(input2, b, 0, y, x); + scalar_t dy = DIM3_INDEX(input2, b, 1, y, x); + + scalar_t xf = static_cast<scalar_t>(x) + dx; + scalar_t yf = static_cast<scalar_t>(y) + dy; + scalar_t alpha = xf - floor(xf); // alpha + scalar_t beta = yf - floor(yf); // beta + + if (bilinear) { + int xL = max(min( int (floor(xf)), dim_w-1), 0); + int xR = max(min( int (floor(xf)+1), dim_w -1), 0); + int yT = max(min( int (floor(yf)), dim_h-1), 0); + int yB = max(min( int (floor(yf)+1), dim_h-1), 0); + + for (int fy = 0; fy < kernel_size; fy += 1) { + for (int fx = 0; fx < kernel_size; fx += 1) { + val += static_cast<float>((1. - alpha)*(1. - beta) * DIM3_INDEX(input1, b, c, yT + fy, xL + fx)); + val += static_cast<float>((alpha)*(1. - beta) * DIM3_INDEX(input1, b, c, yT + fy, xR + fx)); + val += static_cast<float>((1. - alpha)*(beta) * DIM3_INDEX(input1, b, c, yB + fy, xL + fx)); + val += static_cast<float>((alpha)*(beta) * DIM3_INDEX(input1, b, c, yB + fy, xR + fx)); + } + } + + output[index] = val; + } + else { + int xN = max(min( int (floor(xf + 0.5)), dim_w - 1), 0); + int yN = max(min( int (floor(yf + 0.5)), dim_h - 1), 0); + + output[index] = static_cast<float> ( DIM3_INDEX(input1, b, c, yN, xN) ); + } + +} + + +template <typename scalar_t> +__global__ void kernel_resample2d_backward_input1( + const int n, const scalar_t* __restrict__ input1, const long4 input1_size, const long4 input1_stride, + const scalar_t* __restrict__ input2, const long4 input2_size, const long4 input2_stride, + const scalar_t* __restrict__ gradOutput, const long4 gradOutput_size, const long4 gradOutput_stride, + scalar_t* __restrict__ gradInput, const long4 gradInput_size, const long4 gradInput_stride, int kernel_size, bool bilinear) { + + int index = blockIdx.x * blockDim.x + threadIdx.x; + + if (index >= n) { + return; + } + + int dim_b = DIM0(gradOutput_size); + int dim_c = DIM1(gradOutput_size); + int dim_h = DIM2(gradOutput_size); + int dim_w = DIM3(gradOutput_size); + int dim_chw = dim_c * dim_h * dim_w; + int dim_hw = dim_h * dim_w; + + int b = ( index / dim_chw ) % dim_b; + int c = ( index / dim_hw ) % dim_c; + int y = ( index / dim_w ) % dim_h; + int x = ( index ) % dim_w; + + scalar_t dx = DIM3_INDEX(input2, b, 0, y, x); + scalar_t dy = DIM3_INDEX(input2, b, 1, y, x); + + scalar_t xf = static_cast<scalar_t>(x) + dx; + scalar_t yf = static_cast<scalar_t>(y) + dy; + scalar_t alpha = xf - int(xf); // alpha + scalar_t beta = yf - int(yf); // beta + + int idim_h = DIM2(input1_size); + int idim_w = DIM3(input1_size); + + int xL = max(min( int (floor(xf)), idim_w-1), 0); + int xR = max(min( int (floor(xf)+1), idim_w -1), 0); + int yT = max(min( int (floor(yf)), idim_h-1), 0); + int yB = max(min( int (floor(yf)+1), idim_h-1), 0); + + for (int fy = 0; fy < kernel_size; fy += 1) { + for 
(int fx = 0; fx < kernel_size; fx += 1) { + atomicAdd(&DIM3_INDEX(gradInput, b, c, (yT + fy), (xL + fx)), (1-alpha)*(1-beta) * DIM3_INDEX(gradOutput, b, c, y, x)); + atomicAdd(&DIM3_INDEX(gradInput, b, c, (yT + fy), (xR + fx)), (alpha)*(1-beta) * DIM3_INDEX(gradOutput, b, c, y, x)); + atomicAdd(&DIM3_INDEX(gradInput, b, c, (yB + fy), (xL + fx)), (1-alpha)*(beta) * DIM3_INDEX(gradOutput, b, c, y, x)); + atomicAdd(&DIM3_INDEX(gradInput, b, c, (yB + fy), (xR + fx)), (alpha)*(beta) * DIM3_INDEX(gradOutput, b, c, y, x)); + } + } + +} + +template <typename scalar_t> +__global__ void kernel_resample2d_backward_input2( + const int n, const scalar_t* __restrict__ input1, const long4 input1_size, const long4 input1_stride, + const scalar_t* __restrict__ input2, const long4 input2_size, const long4 input2_stride, + const scalar_t* __restrict__ gradOutput, const long4 gradOutput_size, const long4 gradOutput_stride, + scalar_t* __restrict__ gradInput, const long4 gradInput_size, const long4 gradInput_stride, int kernel_size, bool bilinear) { + + int index = blockIdx.x * blockDim.x + threadIdx.x; + + if (index >= n) { + return; + } + + scalar_t output = 0.0; + int kernel_rad = (kernel_size - 1)/2; + + int dim_b = DIM0(gradInput_size); + int dim_c = DIM1(gradInput_size); + int dim_h = DIM2(gradInput_size); + int dim_w = DIM3(gradInput_size); + int dim_chw = dim_c * dim_h * dim_w; + int dim_hw = dim_h * dim_w; + + int b = ( index / dim_chw ) % dim_b; + int c = ( index / dim_hw ) % dim_c; + int y = ( index / dim_w ) % dim_h; + int x = ( index ) % dim_w; + + int odim_c = DIM1(gradOutput_size); + + scalar_t dx = DIM3_INDEX(input2, b, 0, y, x); + scalar_t dy = DIM3_INDEX(input2, b, 1, y, x); + + scalar_t xf = static_cast<scalar_t>(x) + dx; + scalar_t yf = static_cast<scalar_t>(y) + dy; + + int xL = max(min( int (floor(xf)), dim_w-1), 0); + int xR = max(min( int (floor(xf)+1), dim_w -1), 0); + int yT = max(min( int (floor(yf)), dim_h-1), 0); + int yB = max(min( int (floor(yf)+1), dim_h-1), 0); + + if (c % 2) { + float gamma = 1 - (xf - floor(xf)); // alpha + for (int i = 0; i <= 2*kernel_rad; ++i) { + for (int j = 0; j <= 2*kernel_rad; ++j) { + for (int ch = 0; ch < odim_c; ++ch) { + output += (gamma) * DIM3_INDEX(gradOutput, b, ch, y, x) * DIM3_INDEX(input1, b, ch, (yB + j), (xL + i)); + output -= (gamma) * DIM3_INDEX(gradOutput, b, ch, y, x) * DIM3_INDEX(input1, b, ch, (yT + j), (xL + i)); + output += (1-gamma) * DIM3_INDEX(gradOutput, b, ch, y, x) * DIM3_INDEX(input1, b, ch, (yB + j), (xR + i)); + output -= (1-gamma) * DIM3_INDEX(gradOutput, b, ch, y, x) * DIM3_INDEX(input1, b, ch, (yT + j), (xR + i)); + } + } + } + } + else { + float gamma = 1 - (yf - floor(yf)); // alpha + for (int i = 0; i <= 2*kernel_rad; ++i) { + for (int j = 0; j <= 2*kernel_rad; ++j) { + for (int ch = 0; ch < odim_c; ++ch) { + output += (gamma) * DIM3_INDEX(gradOutput, b, ch, y, x) * DIM3_INDEX(input1, b, ch, (yT + j), (xR + i)); + output -= (gamma) * DIM3_INDEX(gradOutput, b, ch, y, x) * DIM3_INDEX(input1, b, ch, (yT + j), (xL + i)); + output += (1-gamma) * DIM3_INDEX(gradOutput, b, ch, y, x) * DIM3_INDEX(input1, b, ch, (yB + j), (xR + i)); + output -= (1-gamma) * DIM3_INDEX(gradOutput, b, ch, y, x) * DIM3_INDEX(input1, b, ch, (yB + j), (xL + i)); + } + } + } + + } + + gradInput[index] = output; + +} + +void resample2d_kernel_forward( + at::Tensor& input1, + at::Tensor& input2, + at::Tensor& output, + int kernel_size, + bool bilinear) { + + int n = output.numel(); + + const long4 input1_size = make_long4(input1.size(0), 
input1.size(1), input1.size(2), input1.size(3)); + const long4 input1_stride = make_long4(input1.stride(0), input1.stride(1), input1.stride(2), input1.stride(3)); + + const long4 input2_size = make_long4(input2.size(0), input2.size(1), input2.size(2), input2.size(3)); + const long4 input2_stride = make_long4(input2.stride(0), input2.stride(1), input2.stride(2), input2.stride(3)); + + const long4 output_size = make_long4(output.size(0), output.size(1), output.size(2), output.size(3)); + const long4 output_stride = make_long4(output.stride(0), output.stride(1), output.stride(2), output.stride(3)); + + // TODO: when atomicAdd gets resolved, change to AT_DISPATCH_FLOATING_TYPES_AND_HALF +// AT_DISPATCH_FLOATING_TYPES(input1.type(), "resample_forward_kernel", ([&] { + + kernel_resample2d_update_output<float><<< (n + CUDA_NUM_THREADS - 1)/CUDA_NUM_THREADS, CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream() >>>( +//at::globalContext().getCurrentCUDAStream() >>>( + n, + input1.data<float>(), + input1_size, + input1_stride, + input2.data<float>(), + input2_size, + input2_stride, + output.data<float>(), + output_size, + output_stride, + kernel_size, + bilinear); + +// })); + + // TODO: ATen-equivalent check + + // THCudaCheck(cudaGetLastError()); + +} + +void resample2d_kernel_backward( + at::Tensor& input1, + at::Tensor& input2, + at::Tensor& gradOutput, + at::Tensor& gradInput1, + at::Tensor& gradInput2, + int kernel_size, + bool bilinear) { + + int n = gradOutput.numel(); + + const long4 input1_size = make_long4(input1.size(0), input1.size(1), input1.size(2), input1.size(3)); + const long4 input1_stride = make_long4(input1.stride(0), input1.stride(1), input1.stride(2), input1.stride(3)); + + const long4 input2_size = make_long4(input2.size(0), input2.size(1), input2.size(2), input2.size(3)); + const long4 input2_stride = make_long4(input2.stride(0), input2.stride(1), input2.stride(2), input2.stride(3)); + + const long4 gradOutput_size = make_long4(gradOutput.size(0), gradOutput.size(1), gradOutput.size(2), gradOutput.size(3)); + const long4 gradOutput_stride = make_long4(gradOutput.stride(0), gradOutput.stride(1), gradOutput.stride(2), gradOutput.stride(3)); + + const long4 gradInput1_size = make_long4(gradInput1.size(0), gradInput1.size(1), gradInput1.size(2), gradInput1.size(3)); + const long4 gradInput1_stride = make_long4(gradInput1.stride(0), gradInput1.stride(1), gradInput1.stride(2), gradInput1.stride(3)); + +// AT_DISPATCH_FLOATING_TYPES(input1.type(), "resample_backward_input1", ([&] { + + kernel_resample2d_backward_input1<float><<< (n + CUDA_NUM_THREADS - 1)/CUDA_NUM_THREADS, CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream() >>>( +//at::globalContext().getCurrentCUDAStream() >>>( + n, + input1.data<float>(), + input1_size, + input1_stride, + input2.data<float>(), + input2_size, + input2_stride, + gradOutput.data<float>(), + gradOutput_size, + gradOutput_stride, + gradInput1.data<float>(), + gradInput1_size, + gradInput1_stride, + kernel_size, + bilinear + ); + +// })); + + const long4 gradInput2_size = make_long4(gradInput2.size(0), gradInput2.size(1), gradInput2.size(2), gradInput2.size(3)); + const long4 gradInput2_stride = make_long4(gradInput2.stride(0), gradInput2.stride(1), gradInput2.stride(2), gradInput2.stride(3)); + + n = gradInput2.numel(); + +// AT_DISPATCH_FLOATING_TYPES(gradInput2.type(), "resample_backward_input2", ([&] { + + + kernel_resample2d_backward_input2<float><<< (n + CUDA_NUM_THREADS - 1)/CUDA_NUM_THREADS, CUDA_NUM_THREADS, 0, 
at::cuda::getCurrentCUDAStream() >>>( +//at::globalContext().getCurrentCUDAStream() >>>( + n, + input1.data<float>(), + input1_size, + input1_stride, + input2.data<float>(), + input2_size, + input2_stride, + gradOutput.data<float>(), + gradOutput_size, + gradOutput_stride, + gradInput2.data<float>(), + gradInput2_size, + gradInput2_stride, + kernel_size, + bilinear + ); + +// })); + + // TODO: Use the ATen equivalent to get last error + + // THCudaCheck(cudaGetLastError()); + +} diff --git a/windflow/external_packages/resample2d_package/resample2d_kernel.cuh b/windflow/external_packages/resample2d_package/resample2d_kernel.cuh new file mode 100755 index 0000000000000000000000000000000000000000..a25951599231fd3dcfa409228c34738e9f039cdb --- /dev/null +++ b/windflow/external_packages/resample2d_package/resample2d_kernel.cuh @@ -0,0 +1,19 @@ +#pragma once + +#include <ATen/ATen.h> + +void resample2d_kernel_forward( + at::Tensor& input1, + at::Tensor& input2, + at::Tensor& output, + int kernel_size, + bool bilinear); + +void resample2d_kernel_backward( + at::Tensor& input1, + at::Tensor& input2, + at::Tensor& gradOutput, + at::Tensor& gradInput1, + at::Tensor& gradInput2, + int kernel_size, + bool bilinear); \ No newline at end of file diff --git a/windflow/external_packages/resample2d_package/setup.py b/windflow/external_packages/resample2d_package/setup.py new file mode 100755 index 0000000000000000000000000000000000000000..6720026f58231f9fbd468429df6a210d88bbbe03 --- /dev/null +++ b/windflow/external_packages/resample2d_package/setup.py @@ -0,0 +1,30 @@ +#!/usr/bin/env python3 +import os +import torch + +from setuptools import setup +from torch.utils.cpp_extension import BuildExtension, CUDAExtension + +#cxx_args = ['-std=c++11'] +cxx_args = ['-std=c++14'] + +nvcc_args = [ + '-gencode', 'arch=compute_50,code=sm_50', + '-gencode', 'arch=compute_52,code=sm_52', + '-gencode', 'arch=compute_60,code=sm_60', + '-gencode', 'arch=compute_61,code=sm_61', + '-gencode', 'arch=compute_70,code=sm_70', + '-gencode', 'arch=compute_70,code=compute_70' +] + +setup( + name='resample2d_cuda', + ext_modules=[ + CUDAExtension('resample2d_cuda', [ + 'resample2d_cuda.cc', + 'resample2d_kernel.cu' + ], extra_compile_args={'cxx': cxx_args, 'nvcc': nvcc_args}) + ], + cmdclass={ + 'build_ext': BuildExtension + }) diff --git a/windflow/inference/__pycache__/inference_flows.cpython-39.pyc b/windflow/inference/__pycache__/inference_flows.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d80a1ed9993ef81eafed5c093853f03fd66b2ca5 Binary files /dev/null and b/windflow/inference/__pycache__/inference_flows.cpython-39.pyc differ diff --git a/windflow/inference/inference_flows.py b/windflow/inference/inference_flows.py new file mode 100644 index 0000000000000000000000000000000000000000..e047603f0fb33c0acc458651bb053f4fe5cca9d5 --- /dev/null +++ b/windflow/inference/inference_flows.py @@ -0,0 +1,324 @@ +import sys +import os +os.environ["CUDA_VISIBLE_DEVICES"] = "0" + +import time +import glob + +import numpy as np +import pandas as pd +import torch +from torch import nn +import xarray as xr +import matplotlib.pyplot as plt +import datetime as dt +import scipy.stats + +from ..datasets.geostationary import goesr, stats +from ..networks.models import get_flow_model +from ..utils import plotting +from ..datasets.preprocess import image_histogram_equalization +from ..datasets.utils import cartesian_to_speed + +def split_array(arr, tile_size, overlap): + c, h, w = arr.shape + patches = [] + indicies 
= [] + for i in range(0, h, tile_size-overlap): + for j in range(0, w, tile_size-overlap): + i = min(i, h - tile_size) + j = min(j, w - tile_size) + indicies.append([i, j]) + patches.append(arr[np.newaxis,:,i:i+tile_size, j:j+tile_size]) + indices = np.array(indicies) + patches = np.concatenate(patches, axis=0) + return patches, indicies + +def reassemble_split_array(arr, upperleft, shape, trim=0): + assert len(arr) > 0 + + # perform inference on patches + height, width = arr.shape[2:4] + counter = np.zeros(shape) + out_sum = np.zeros(shape) + for i, x in enumerate(arr): + ix, iy = upperleft[i] + if trim > 0: + counter[:,ix+trim:ix+height-trim,iy+trim:iy+width-trim] += 1 + out_sum[:,ix+trim:ix+height-trim,iy+trim:iy+width-trim] += x[:,trim:-trim,trim:-trim] + else: + counter[:,ix:ix+height,iy:iy+width] += 1 + out_sum[:,ix:ix+height,iy:iy+width] += x + + out = out_sum / counter + return out + +def reassemble_with_2d_gaussian(arr, upperleft, shape, trim=0): + assert len(arr) > 0 + + # perform inference on patches + height, width = arr.shape[2:4] + counter = np.zeros(shape) + out_sum = np.zeros(shape) + + nsig = 3 + x = np.linspace(-nsig, nsig, height+1-trim*2) + kern1d = np.diff(scipy.stats.norm.cdf(x)) + kern2d = np.outer(kern1d, kern1d) + pdf = kern2d/kern2d.sum() + + for i, x in enumerate(arr): + ix, iy = upperleft[i] + if trim > 0: + counter[:,ix+trim:ix+height-trim,iy+trim:iy+width-trim] += pdf + out_sum[:,ix+trim:ix+height-trim,iy+trim:iy+width-trim] += x[:,trim:-trim,trim:-trim] * pdf + else: + counter[:,ix:ix+height,iy:iy+width] += pdf + out_sum[:,ix:ix+height,iy:iy+width] += x * pdf + + out = out_sum / counter + return out + + +def write_to_netcdf(filename, ua, va, lat, lon, ts, + pressure=None): + + ua = xr.DataArray(ua.astype(np.float32), coords=[('pressure', pressure),('lat', lat), ('lon', lon)]) + va = xr.DataArray(va.astype(np.float32), coords=[('pressure', pressure),('lat', lat), ('lon', lon)]) + newds = xr.Dataset(dict(ua=ua, va=va, time=ts, pressure=pressure)) + newds.to_netcdf(filename) + print("Dataset written to file: {}".format(filename)) + + +def inference(model, X0, X1, + tile_size=512, + overlap=128, + upsample_flow_factor=None, + batch_size=32): + c, h, w = X0.shape + trim = 0 #overlap // 4 + + if isinstance(upsample_flow_factor, int): + upsample_flow = nn.Upsample(scale_factor=upsample_flow_factor, mode='bilinear') + else: + upsample_flow = None + + #if isinstance(upsample_input, float) + + f_sum = np.zeros((1, 2, h, w)) + f_counter = np.zeros((1, 1, h, w)) + + x0_patches, upperleft = split_array(X0, tile_size, overlap) + x1_patches, _ = split_array(X1, tile_size, overlap) + + x0_patches = torch.from_numpy(x0_patches).float() + x1_patches = torch.from_numpy(x1_patches).float() + pred = [] + for batch in range(0, x1_patches.shape[0], batch_size): + x0_batch = x0_patches[batch:batch+batch_size].cuda() + x1_batch = x1_patches[batch:batch+batch_size].cuda() + + # testing 2x upsampling before inference + #upsample_2x = nn.Upsample(scale_factor=2, mode='bilinear') + #x0_batch = upsample_2x(x0_batch) + #x1_batch = upsample_2x(x1_batch) + + #try: + model_out = model(x0_batch, x1_batch, test_mode=True)[0] + #except TypeError: + #model_out = model(x0, x1)[0] + # model_out = model(torch.cat([x0_batch, x1_batch], 1))[0] + + if upsample_flow: + model_out = upsample_flow(model_out) #* upsample_flow_factor + #model_out = upsample_2x(model_out) / 2 + + pred.append(model_out.cpu().detach().numpy()) + + pred = np.concatenate(pred, 0) + #UV = reassemble_split_array(pred, upperleft, 
(2, h, w), trim=trim) + UV = reassemble_with_2d_gaussian(pred, upperleft, (2, h, w), trim=trim) + + return UV + + +class FlowRunner(object): + ''' + Operator to perform inference on general inputs and flow models + ''' + def __init__(self, model_name, + tile_size=512, + overlap=384, + batch_size=8, + upsample_input=None, + device='cuda:0'): + self.model = get_flow_model(model_name, small=False) + self.model_name = model_name + self.tile_size = tile_size + self.batch_size = batch_size + self.overlap = overlap + + # upsample input 2x can improve feature tracking at inference time + #if upsample_input_factor is not None: + # self.upsample_input = nn.Upsample(scale_factor=upsample_input_factor, mode='bilinear') + #else: + # upsample_input_factor = 1 + + if self.model_name.lower() in ['flownets', 'pwcnet', 'pwc-net', 'maskflownet']: + self.upsample_flow_factor = 4 #/ upsample_input_factor + else: + self.upsample_flow_factor = None #1 / upsample_input_factor + + #self.model.to(device) + self.model = self.model.cuda() + self.model = torch.nn.DataParallel(self.model) + #self.model = torch.nn.DataParallel(self.model, device_ids=[int(list(device)[-1])]) + #self.model = torch.nn.DataParallel(self.model, device_ids=[0,]) + + def load_checkpoint(self, checkpoint_file): + checkpoint = torch.load(checkpoint_file) + self.global_step = checkpoint['global_step'] + #for key in checkpoint['model']: + # if 'module' in key: + # new_key = key.replace('module.', '') + # checkpoint[new_key] = checkpoint[key] + # del checkpoint[key] + + try: + self.model.module.load_state_dict(checkpoint['model']) + except: + self.model.load_state_dict(checkpoint['model']) + + print(f"Loaded checkpoint: {checkpoint_file}") + + def preprocess(self, x): + x[~np.isfinite(x)] = 0. + x = image_histogram_equalization(x) + return x + + def forward(self, img1, img2): + mask = (img1 == img1) + mask[~mask] = np.nan + + img1_norm = self.preprocess(img1) + img2_norm = self.preprocess(img2) + + flows = inference(self.model, img1_norm[np.newaxis], img2_norm[np.newaxis], + tile_size=self.tile_size, overlap=self.overlap, + upsample_flow_factor=self.upsample_flow_factor, + batch_size=self.batch_size) + return img1, flows * mask[np.newaxis] + + + +class GeoFlows(FlowRunner): + ''' + Object to manage optical flow prediction for geostationary L1b observations + ''' + def __init__(self, model_name, + data_directory, + product='ABI-L1b-RadF', + upsample_data=None, + channels=[10], + timestep=10, + spatial=None, + **kwargs): + FlowRunner.__init__(self, model_name, **kwargs) + self.data_directory = data_directory + self.timestep = timestep + self.upsample_data = upsample_data + self.spatial = spatial + + self.goes = goesr.GOESL1b(product=product, channels=channels, + data_directory=data_directory) + self.channels = channels + + def flows_by_time(self, t, reproject=False): + t2 = t + dt.timedelta(minutes=self.timestep) + file1 = self.goes.snapshot_file(t.year, t.timetuple().tm_yday, t.hour, + t.minute, spatial=self.spatial).values[0] + file2 = self.goes.snapshot_file(t2.year, t2.timetuple().tm_yday, t2.hour, + t2.minute, spatial=self.spatial).values[0] + + obj1 = goesr.L1bBand(file1) + obj2 = goesr.L1bBand(file2) + + if self.upsample_data is not None: + obj1.interp(self.upsample_data) + obj2.interp(self.upsample_data) + + lats, lons = obj1.latlon() + + if reproject: + data1 = obj1.reproject_to_latlon() + data2 = obj2.reproject_to_latlon() + else: + data1 = obj1.open_dataset() + data2 = obj2.open_dataset() + + img1 = data1['Rad'].values + img2 = 
data2['Rad'].values + + flows = self.forward(img1, img2)[1] + + if reproject: + data1['U'] = xr.DataArray(flows[0], dims=['lat', 'lon'], + coords=dict(lat=data1.lat.values, lon=data1.lon.values)) + data1['V']= xr.DataArray(flows[1], dims=['lat', 'lon'], + coords=dict(lat=data1.lat.values, lon=data1.lon.values)) + else: + data1['lat'] = xr.DataArray(lats, dims=['y', 'x'], + coords=dict(y=data1.y.values, x=data1.x.values)) + data1['lon'] = xr.DataArray(lons, dims=['y', 'x'], + coords=dict(y=data1.y.values, x=data1.x.values)) + + data1['U'] = xr.DataArray(flows[0], dims=['y', 'x'], + coords=dict(y=data1.y.values, x=data1.x.values)) + data1['V']= xr.DataArray(flows[1], dims=['y', 'x'], + coords=dict(y=data1.y.values, x=data1.x.values)) + + data1 = cartesian_to_speed(data1) + data1['U'] *= 2000 / (self.timestep * 60) * 0.8 # m/s on 2km grid -- 0.8 scales to remove bias + data1['V'] *= 2000 / (self.timestep * 60) * 0.8 # m/s on 2km grid + return data1 + + + '''def flows_iterate(self, start_year, start_dayofyear, start_hour, steps=6*24): + + files = self.goes.local_files(start_year, start_dayofyear).reset_index() + files = files[files['hour'] >= start_hour] + + curr_date = dt.datetime(start_year, 1, 1, 1) + dt.timedelta(days=start_dayofyear-1) + while len(files) < steps: + curr_date += dt.timedelta(days=1) + more_files = self.goes.local_files(curr_date.year, curr_date.timetuple().tm_yday).reset_index() + files = pd.concat([files, more_files]) + + band1 = goesr.L1bBand(files['file'].values[0,0]) + data1 = band1.open_dataset() + img1 = data1['Rad'].values + + for idx in range(1, steps): + band2 = goesr.L1bBand(files['file'].values[idx,0]) + data2 = band2.open_dataset() + img2 = data2['Rad'].values + + flows = self.forward(img1, img2)[1] + + yield band1, flows + + band1 = band2 + data1 = data2 + img1 = img2 + ''' + +if __name__ == "__main__": + + data_directory = '/nex/datapool/geonex/public/GOES16/NOAA-L1B/' + model = FlowNetS.FlowNetS(None, 1) + checkpoint_path = '/nobackupp10/tvandal/nex-ai-opticalflow/scripts/experiments/gmao_osse/flownets-size_512-lognorm/checkpoint.flownet.pth.tar' + model.training = False + + inference = inference_flows.GeoFlows(model, data_directory) + + diff --git a/windflow/inference/inference_tools.py b/windflow/inference/inference_tools.py new file mode 100644 index 0000000000000000000000000000000000000000..e040d27cc867c59e1be3bb76cd7cb06e2d71ad74 --- /dev/null +++ b/windflow/inference/inference_tools.py @@ -0,0 +1,242 @@ +import sys +import os + +import xarray as xr +from .utils import blocks +import torch +import torch.nn as nn +import torchvision + +from slomo import flownet as fl +from slomo import unet + +import numpy as np + +device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + +def block_predictions_to_dataarray(predictions, block): + block_predictions = np.concatenate(predictions, 0) + block_predictions[block_predictions < 0] = 0 + block_predictions[block_predictions > 1] = 1 + + N_pred = block_predictions.shape[0] + T = np.arange(0,N_pred) + da = xr.DataArray(block_predictions,#[:,:,shave:-shave,shave:-shave], + coords=[T, block.band.values, + block.y.values,#[shave:-shave], + block.x.values,],#][shave:-shave]], + dims=['t', 'band', 'y', 'x']) + + return da + +def merge_and_average_dataarrays(dataarrays): + ds = xr.merge([xr.Dataset({k: d}) for k, d in enumerate(dataarrays)]) + das = [] + for b in range(0,len(dataarrays)): + das.append(ds[b]) + + return xr.concat(das).mean('concat_dims', skipna=True) + +# Split 3D numpy array into patches 
with overlap +def split_array(arr, tile_size=128, overlap=16): + ''' + Split a 3D numpy array into patches for inference + (Channels, Height, Width) + Args: + tile_size: width and height of patches to return + overlap: number of pixels to overlap between patches + Returns: + dict(patches, upper_left): patches and indices of original array + ''' + arr = arr[np.newaxis] + width, height = arr.shape[2:4] + arrs = dict(patches=[], upper_left=[]) + for i in range(0, width, tile_size - overlap): + for j in range(0, height, tile_size - overlap): + i = min(i, width - tile_size) + j = min(j, height - tile_size) + arrs['patches'].append(arr[:,:,i:i+tile_size,j:j+tile_size]) + arrs['upper_left'].append([[i,j]]) + arrs['patches'] = np.concatenate(arrs['patches']) + arrs['upper_left'] = np.concatenate(arrs['upper_left']) + return arrs['patches'], arrs['upper_left'] + +## Load and return flownet, interpnet, and warper +def load_models(n_channels, model_path, multivariate=False, + nn_model=unet.UNetSmall): + + model_filename = os.path.join(model_path, 'best.flownet.pth.tar') + if multivariate: + flownet = fl.SloMoFlowNetMV(n_channels, model=nn_model)#.cuda() + interpnet = fl.SloMoInterpNetMV(n_channels, model=nn_model)#.cuda() + else: + flownet = fl.SloMoFlowNet(n_channels, model=nn_model)#.cuda() + interpnet = fl.SloMoInterpNet(n_channels, model=nn_model)#.cuda() + + warper = fl.FlowWarper() + + flownet = nn.DataParallel(flownet) + interpnet = nn.DataParallel(interpnet) + warper = nn.DataParallel(warper) + + flownet = flownet.to(device) + interpnet = interpnet.to(device) + warper = warper.to(device) + + def load_checkpoint(flownet, interpnet): + checkpoint = torch.load(model_filename) + flownet.load_state_dict(checkpoint['flownet_state_dict']) + interpnet.load_state_dict(checkpoint['interpnet_state_dict']) + + epoch = 0 + if os.path.isfile(model_filename): + print("loading checkpoint %s" % model_filename) + checkpoint = torch.load(model_filename) + flownet.load_state_dict(checkpoint['flownet_state_dict']) + interpnet.load_state_dict(checkpoint['interpnet_state_dict']) + epoch = checkpoint['epoch'] + print("=> loaded checkpoint '{}' (epoch {})" + .format(model_filename, epoch)) + else: + print("=> no checkpoint found at '{}'".format(model_filename)) + return flownet, interpnet + + flownet.train() + interpnet.train() + flownet, interpnet = load_checkpoint(flownet, interpnet) + return flownet, interpnet, warper + +## Perform interpolation for a single timepoint t + +def single_inference(X0, X1, t, flownet, interpnet, + multivariate): + X0_arr_torch = torch.from_numpy(X0) + X1_arr_torch = torch.from_numpy(X1) + + # nans to 0 + X0_arr_torch[np.isnan(X0_arr_torch)] = 0. + X1_arr_torch[np.isnan(X1_arr_torch)] = 0. 
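The split_array helper above tiles a (channels, height, width) array into overlapping patches and records each patch's upper-left offset; at reassembly time the per-patch outputs are accumulated into a canvas together with a hit counter and then divided, so pixels covered by several tiles are averaged. The standalone NumPy sketch below mirrors that tile-and-average bookkeeping under a hypothetical name (tile_and_average is not part of this package), with an identity function standing in for model inference, just to show that the scheme reconstructs the input exactly.

# Minimal sketch of the overlap-tiling pattern used by split_array and the
# reassembly helpers above (NumPy only; fn stands in for model inference).
import numpy as np

def tile_and_average(x, tile=128, overlap=16, fn=lambda p: p):
    c, h, w = x.shape
    out = np.zeros(x.shape, dtype=np.float64)
    count = np.zeros((1, h, w), dtype=np.float64)
    for i in range(0, h, tile - overlap):
        for j in range(0, w, tile - overlap):
            i, j = min(i, h - tile), min(j, w - tile)   # clamp so border tiles stay inside the array
            out[:, i:i + tile, j:j + tile] += fn(x[:, i:i + tile, j:j + tile])
            count[:, i:i + tile, j:j + tile] += 1.0
    return out / count                                   # overlapping pixels are averaged

x = np.random.rand(1, 512, 768)
assert np.allclose(tile_and_average(x), x)               # identity "model" reproduces the input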
+ + X0_arr_torch = torch.unsqueeze(X0_arr_torch, 0).to(device, dtype=torch.float) + X1_arr_torch = torch.unsqueeze(X1_arr_torch, 0).to(device, dtype=torch.float) + + f = flownet(X0_arr_torch, X1_arr_torch) + n_channels = X0_arr_torch.shape[1] + + if multivariate: + f_01 = f[:,:2*n_channels] + f_10 = f[:,2*n_channels:] + else: + f_01 = f[:,:2] + f_10 = f[:,2:] + + predicted_frames = [] + + I_t, g0, g1, V_t0, V_t1, delta_f_t0, delta_f_t1 = interpnet(X0_arr_torch, X1_arr_torch, f_01, f_10, t) + predicted_frames.append(I_t.cpu().detach().numpy()) + + out = {'f_01': f_01, 'f_10': f_10, 'I_t': I_t, + 'V_t0': V_t0, 'V_t1': V_t1, 'delta_f_t0': delta_f_t0, + 'delta_f_t1': delta_f_t0} + + for k in out.keys(): + out[k] = out[k][0].cpu().detach().numpy() + + #torch.cuda.empty_cache() + return out + +## Split interpolation for a single timepoint t +def single_inference_split(X0, X1, t, flownet, interpnet, + multivariate, block_size=128, overlap=16, + discard=0): + X0_split, upper_left_idxs = split_array(X0, block_size, overlap) + X1_split, _ = split_array(X1, block_size, overlap) + + assert len(X0_split) > 0 + + # perform inference on patches + height, width = X0.shape[1:3] + #counter = np.zeros((1, height, width)) + counter = np.zeros((1, height-discard*2, width-discard*2)) + res_sum = {} + for i, (x0, x1) in enumerate(zip(X0_split, X1_split)): + ix, iy = upper_left_idxs[i] + res_i = single_inference(x0, x1, t, flownet, interpnet, + multivariate) + keys = res_i.keys() + if i == 0: + res_sum = {k: np.zeros((res_i[k].shape[0], height-discard*2, width-discard*2)) for k in keys} + + for var in keys: + if discard > 0: + res_i[var] = res_i[var][:,discard:-discard,discard:-discard] + res_sum[var][:,ix:ix+block_size-discard*2,iy:iy+block_size-discard*2] += res_i[var] + counter[:,ix:ix+block_size-discard*2,iy:iy+block_size-discard*2] += 1. + + out = {} + for var in res_sum.keys(): + out[var] = res_sum[var] / counter + + return out + + +def _inference(X0, X1, flownet, interpnet, warper, + multivariate, T=4, block_size=None): + ''' + Given two consecutive frames, interpolation T between them + using flownet and interpnet + Returns: + Interpolated Frames + ''' + + X0_arr_torch = torch.from_numpy(X0.values) + X1_arr_torch = torch.from_numpy(X1.values) + + # nans to 0 + X0_arr_torch[np.isnan(X0_arr_torch)] = 0. + X1_arr_torch[np.isnan(X1_arr_torch)] = 0. + + X0_arr_torch = torch.unsqueeze(X0_arr_torch, 0).to(device, dtype=torch.float) + X1_arr_torch = torch.unsqueeze(X1_arr_torch, 0).to(device, dtype=torch.float) + + f = flownet(X0_arr_torch, X1_arr_torch) + n_channels = X0_arr_torch.shape[1] + + if multivariate: + f_01 = f[:,:2*n_channels] + f_10 = f[:,2*n_channels:] + else: + f_01 = f[:,:2] + f_10 = f[:,2:] + + predicted_frames = [] + + for j in range(1,T+1): + t = 1. 
* j / (T+1) + I_t, g0, g1, V_t0, V_t1, delta_f_t0, delta_f_t1 = interpnet(X0_arr_torch, X1_arr_torch, f_01, f_10, t) + predicted_frames.append(I_t.cpu().detach().numpy()) + + torch.cuda.empty_cache() + return predicted_frames + +## Perform interpolation over multiple time steps +def inference(X0, X1, flownet, interpnet, warper, + multivariate, T=4, block_size=352): + ''' + Given two consecutive frames, interpolation T between them + using flownet and interpnet + Returns: + Interpolated Frames + ''' + # Get a list of dataarrays chunked + X0_blocks = blocks(X0, width=block_size) + X1_blocks = blocks(X1, width=block_size) + + interpolated_blocks = [] + for x0, x1 in zip(X0_blocks, X1_blocks): + predicted_frames = _inference(x0, x1, flownet, + interpnet, warper, + multivariate, T) + predicted_frames = [x0.values[np.newaxis]] + predicted_frames + [x1.values[np.newaxis]] + interpolated_blocks += [block_predictions_to_dataarray(predicted_frames, x0)] + return merge_and_average_dataarrays(interpolated_blocks) diff --git a/windflow/inference/utils.py b/windflow/inference/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..51fba9e3af95f80bd45482dbdfd05c0882cee366 --- /dev/null +++ b/windflow/inference/utils.py @@ -0,0 +1,100 @@ +import xarray as xr +import numpy as np +import cv2 + +import scipy.interpolate + +def fillmiss(x): + if x.ndim != 2: + raise ValueError("X have only 2 dimensions.") + mask = ~np.isnan(x) + xx, yy = np.meshgrid(np.arange(x.shape[1]), np.arange(x.shape[0])) + xym = np.vstack( (np.ravel(xx[mask]), np.ravel(yy[mask])) ).T + data0 = np.ravel(x[mask]) + interp0 = scipy.interpolate.NearestNDInterpolator(xym, data0) + result0 = interp0(np.ravel(xx), np.ravel(yy)).reshape(xx.shape) + return result0 + +def interp_dim(x, scale): + x0, xlast = x[0], x[-1] + newlength = int(len(x) * scale) + y = np.linspace(x0, xlast, num=newlength, endpoint=False) + return y + +def interp_tensor(X, scale, fill=True, how=cv2.INTER_NEAREST): + nlt = int(X.shape[1]*scale) + nln = int(X.shape[2]*scale) + newshape = (X.shape[0], nlt, nln) + scaled_tensor = np.empty(newshape) + for j, im in enumerate(X): + # fill im with nearest neighbor + if fill: + #im = fillmiss(im) + im[np.isnan(im)] = 0 + + scaled_tensor[j] = cv2.resize(im, (newshape[2], newshape[1]), + interpolation=how) + return scaled_tensor + +def interp_da(da, scale, how=cv2.INTER_LINEAR): + """ + Assume da is of dimensions ('time','lat', 'lon') + """ + tensor = da.values + + # interpolate lat and lons + latnew = interp_dim(da[da.dims[1]].values, scale) + lonnew = interp_dim(da[da.dims[2]].values, scale) + + # lets store our interpolated data + scaled_tensor = interp_tensor(tensor, scale, fill=True, how=how) + + if latnew.shape[0] != scaled_tensor.shape[1]: + raise ValueError("New shape is shitty") + # intialize a new dataarray + return xr.DataArray(scaled_tensor, coords=[da[da.dims[0]].values, latnew, lonnew], + dims=da.dims) + +def interp_da2d(da, scale, fillna=False, how=cv2.INTER_NEAREST): + """ + Assume da is of dimensions ('time','lat', 'lon') + """ + # lets store our interpolated data + newshape = (int(da.shape[0]*scale),int(da.shape[1]*scale)) + im = da.values + scaled_tensor = np.empty(newshape) + # fill im with nearest neighbor + if fillna: + filled = fillmiss(im) + else: + filled = im + scaled_tensor = cv2.resize(filled, dsize=(0,0), fx=scale, fy=scale, + interpolation=how) + + # interpolate lat and lons + latnew = interp_dim(da[da.dims[1]].values, scale) + lonnew = interp_dim(da[da.dims[0]].values, scale) + + # 
intialize a new dataarray + return xr.DataArray(scaled_tensor, coords=[lonnew, latnew], + dims=da.dims) + +def blocks(data, width=352): + #n = data.t.shape[0] + w = data.x.shape[0] + h = data.y.shape[0] + d = data.band.shape[0] + + hs = np.arange(0, h, width) + ws = np.arange(0, w, width) + blocks = [] + for hindex in hs: + if hindex+width > h: + hindex = h - width + + for windex in ws: + if windex+width > w: + windex = w - width + blocks.append(data.sel(y=data.y.values[hindex:hindex+width], + x=data.x.values[windex:windex+width])) + return blocks diff --git a/windflow/losses/__init__.py b/windflow/losses/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f37cff02a0b650e96239e8f7041b71ad9ed115c2 --- /dev/null +++ b/windflow/losses/__init__.py @@ -0,0 +1,2 @@ +from .ssim import SSIM, ssim +from .losses import SmoothnessLoss, CharbonnierLoss, L1Loss, L1, L2Loss, RMSVDLoss \ No newline at end of file diff --git a/windflow/losses/losses.py b/windflow/losses/losses.py new file mode 100644 index 0000000000000000000000000000000000000000..3457b0b12106d1cdb95305ba9acb0ea53ba49d01 --- /dev/null +++ b/windflow/losses/losses.py @@ -0,0 +1,85 @@ +import torch +import torch.nn as nn + +class SmoothnessLoss(nn.Module): + def __init__(self): + super(SmoothnessLoss, self).__init__() + + def forward(self, flows, mask=None): + gradx = torch.abs(flows[:,:,:,:-1] - flows[:,:,:,1:]) + grady = torch.abs(flows[:,:,:-1,:] - flows[:,:,1:,:]) + normalize = 1. + if isinstance(mask, torch.Tensor): + gradx *= mask[:,:,:,:-1] + grady *= mask[:,:,:-1,:] + normalize = torch.mean(mask) + loss_smooth = (torch.mean(gradx) + torch.mean(grady)) / normalize + return loss_smooth + +class CharbonnierLoss(nn.Module): + def __init__(self, alpha, eps=1e-6): + super(CharbonnierLoss, self).__init__() + self.alpha = alpha + self.eps = eps + + def forward(self, I0, I1, mask=None): + err = ((I0 - I1)**2 + self.eps)**self.alpha + norm = 1. 
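For reference, the two regularization terms defined in losses.py reduce to a few tensor expressions: CharbonnierLoss penalizes the photometric residual with ((I0 - I1)^2 + eps)^alpha averaged over pixels (optionally masked and renormalized), and SmoothnessLoss averages the absolute horizontal and vertical differences of the flow field. The short torch sketch below restates both for the mask=None case; the alpha value and the 0.1 smoothness weight are illustrative placeholders, not defaults taken from this repository.

# Plain-torch restatement of the Charbonnier and smoothness terms above (mask=None case).
import torch

alpha, eps = 0.45, 1e-6                                        # illustrative values only
I0 = torch.rand(2, 1, 64, 64)                                  # warped frame
I1 = torch.rand(2, 1, 64, 64)                                  # target frame
flows = torch.randn(2, 2, 64, 64)                              # predicted (u, v) field

charbonnier = torch.mean(((I0 - I1) ** 2 + eps) ** alpha)

gradx = torch.abs(flows[:, :, :, :-1] - flows[:, :, :, 1:])    # horizontal flow differences
grady = torch.abs(flows[:, :, :-1, :] - flows[:, :, 1:, :])    # vertical flow differences
smoothness = torch.mean(gradx) + torch.mean(grady)

total = charbonnier + 0.1 * smoothness                         # 0.1 weighting chosen only for illustration
print(float(charbonnier), float(smoothness), float(total))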
+ if isinstance(mask, torch.Tensor): + err *= mask + norm = torch.mean(mask) + return torch.mean(err) / norm + +def EPE(input_flow, target_flow): + return torch.norm(target_flow-input_flow,p=2,dim=1).mean() + +class L1(nn.Module): + def __init__(self): + super(L1, self).__init__() + + def forward(self, output, target): + lossvalue = torch.abs(output - target).mean() + return lossvalue + +class L2(nn.Module): + def __init__(self): + super(L2, self).__init__() + def forward(self, output, target): + lossvalue = torch.norm(output-target,p=2,dim=1).mean() + return lossvalue + +class L1Loss(nn.Module): + def __init__(self, args): + super(L1Loss, self).__init__() + self.args = args + self.loss = L1() + self.loss_labels = ['L1', 'EPE'] + + def forward(self, output, target): + lossvalue = self.loss(output, target) + #epevalue = EPE(output, target) + return lossvalue + #return [lossvalue, epevalue] + +class L2Loss(nn.Module): + def __init__(self, args): + super(L2Loss, self).__init__() + self.args = args + self.loss = L2() + + def forward(self, output, target): + lossvalue = self.loss(output, target) + #epevalue = EPE(output, target) + return lossvalue #+ epevalue + + +class RMSVDLoss(nn.Module): + def __init__(self): + super(RMSVDLoss, self).__init__() + + def forward(self, output, target): + # output and target shape (N, 2, H, W) + u_err = (output[:,0] - target[:,0])**2 + v_err = (output[:,1] - target[:,1])**2 + rmsvd = (u_err.mean() + v_err.mean())**0.5 + return rmsvd diff --git a/windflow/losses/ssim.py b/windflow/losses/ssim.py new file mode 100644 index 0000000000000000000000000000000000000000..b1e7d6cf32205922db285a9c837ec35d271e615e --- /dev/null +++ b/windflow/losses/ssim.py @@ -0,0 +1,73 @@ +import torch +import torch.nn.functional as F +from torch.autograd import Variable +import numpy as np +from math import exp + +def gaussian(window_size, sigma): + gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)]) + return gauss/gauss.sum() + +def create_window(window_size, channel): + _1D_window = gaussian(window_size, 1.5).unsqueeze(1) + _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) + window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous()) + return window + +def _ssim(img1, img2, window, window_size, channel, size_average = True): + mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel) + mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel) + + mu1_sq = mu1.pow(2) + mu2_sq = mu2.pow(2) + mu1_mu2 = mu1*mu2 + + sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq + sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq + sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2 + + C1 = 0.01**2 + C2 = 0.03**2 + + ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2)) + + if size_average: + return ssim_map.mean() + else: + return ssim_map.mean(1).mean(1).mean(1) + +class SSIM(torch.nn.Module): + def __init__(self, window_size = 11, size_average = True): + super(SSIM, self).__init__() + self.window_size = window_size + self.size_average = size_average + self.channel = 1 + self.window = create_window(window_size, self.channel) + + def forward(self, img1, img2): + (_, channel, _, _) = img1.size() + + if channel == self.channel and self.window.data.type() == img1.data.type(): + window = self.window + 
else: + window = create_window(self.window_size, channel) + + if img1.is_cuda: + window = window.cuda(img1.get_device()) + window = window.type_as(img1) + + self.window = window + self.channel = channel + + + return _ssim(img1, img2, window, self.window_size, channel, self.size_average) + +def ssim(img1, img2, window_size = 11, size_average = True): + (_, channel, _, _) = img1.size() + window = create_window(window_size, channel) + + if img1.is_cuda: + window = window.cuda(img1.get_device()) + window = window.type_as(img1) + + return _ssim(img1, img2, window, window_size, channel, size_average) \ No newline at end of file diff --git a/windflow/networks/PWCNet.py b/windflow/networks/PWCNet.py new file mode 100644 index 0000000000000000000000000000000000000000..5b5531de356a9c353ca64c1097afa8e6287fe3c6 --- /dev/null +++ b/windflow/networks/PWCNet.py @@ -0,0 +1,1146 @@ +""" +implementation of the PWC-DC network for optical flow estimation by Sun et al., 2018 + +Jinwei Gu and Zhile Ren + +""" + +import torch +import torch.nn as nn +from torch.autograd import Variable +import os +os.environ['PYTHON_EGG_CACHE'] = 'tmp/' # a writable directory +#from correlation_package.modules.corr import Correlation +from spatial_correlation_sampler import SpatialCorrelationSampler + +import numpy as np + + +__all__ = [ + 'pwc_dc_net', 'pwc_dc_net_old' + ] + +def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1): + return nn.Sequential( + nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, + padding=padding, dilation=dilation, bias=True), + nn.LeakyReLU(0.1)) + +def predict_flow(in_planes): + return nn.Conv2d(in_planes,2,kernel_size=3,stride=1,padding=1,bias=True) + +def deconv(in_planes, out_planes, kernel_size=4, stride=2, padding=1): + return nn.ConvTranspose2d(in_planes, out_planes, kernel_size, stride, padding, bias=True) + +class CorrelationLayerSimple(nn.Module): + def __init__(self, md=4): + super(CorrelationLayerSimple,self).__init__() + self.md = md + self.model = SpatialCorrelationSampler(kernel_size=1, patch_size=md*2+1, + stride=1, dilation_patch=1, padding=0) + def forward(self, x1, x2): + output = self.model(x1, x2) + b, ph, pw, h, w = output.size() + return output.view(b, ph * pw, h, w) + +class PWCDCNetSmall(nn.Module): + """ + PWC-DC net. add dilation convolution and densenet connections + + """ + def __init__(self, md=4): + """ + input: md --- maximum displacement (for correlation. 
default: 4), after warpping + + """ + super(PWCDCNetSmall, self).__init__() + + self.conv1a = conv(1, 32, kernel_size=3, stride=2) + self.conv1aa = conv(32, 32, kernel_size=3, stride=1) + self.conv1b = conv(32, 32, kernel_size=3, stride=1) + self.conv2a = conv(32, 32, kernel_size=3, stride=2) + self.conv2aa = conv(32, 32, kernel_size=3, stride=1) + self.conv2b = conv(32, 32, kernel_size=3, stride=1) + self.conv3a = conv(32, 32, kernel_size=3, stride=2) + self.conv3aa = conv(32, 32, kernel_size=3, stride=1) + self.conv3b = conv(32, 32, kernel_size=3, stride=1) + self.conv4a = conv(32, 32, kernel_size=3, stride=2) + self.conv4aa = conv(32, 32, kernel_size=3, stride=1) + self.conv4b = conv(32, 32, kernel_size=3, stride=1) + self.conv5a = conv(32, 32, kernel_size=3, stride=2) + self.conv5aa = conv(32, 32, kernel_size=3, stride=1) + self.conv5b = conv(32, 32, kernel_size=3, stride=1) + + # self.corr = Correlation(pad_size=md, kernel_size=1, max_displacement=md, stride1=1, stride2=1, corr_multiply=1) + + self.corr = CorrelationLayerSimple(md=md) + + self.leakyRELU = nn.LeakyReLU(0.1) + + nd = (2*md+1)**2 + #dd = np.cumsum([128,128,96,64,32]) + dd = np.cumsum([32,32,32,32,32]) + + od = nd + self.conv5_0 = conv(od, 32, kernel_size=3, stride=1) + self.conv5_1 = conv(od+dd[0], 32, kernel_size=3, stride=1) + self.conv5_2 = conv(od+dd[1], 32, kernel_size=3, stride=1) + self.conv5_3 = conv(od+dd[2], 32, kernel_size=3, stride=1) + self.conv5_4 = conv(od+dd[3], 32, kernel_size=3, stride=1) + self.predict_flow5 = predict_flow(od+dd[4]) + self.deconv5 = deconv(2, 2, kernel_size=4, stride=2, padding=1) + self.upfeat5 = deconv(od+dd[4], 2, kernel_size=4, stride=2, padding=1) + + od = nd+32+4 + self.conv4_0 = conv(od, 32, kernel_size=3, stride=1) + self.conv4_1 = conv(od+dd[0],32, kernel_size=3, stride=1) + self.conv4_2 = conv(od+dd[1],32, kernel_size=3, stride=1) + self.conv4_3 = conv(od+dd[2],32, kernel_size=3, stride=1) + self.conv4_4 = conv(od+dd[3],32, kernel_size=3, stride=1) + self.predict_flow4 = predict_flow(od+dd[4]) + self.deconv4 = deconv(2, 2, kernel_size=4, stride=2, padding=1) + self.upfeat4 = deconv(od+dd[4], 2, kernel_size=4, stride=2, padding=1) + + od = nd+32+4 + self.conv3_0 = conv(od, 32, kernel_size=3, stride=1) + self.conv3_1 = conv(od+dd[0],32, kernel_size=3, stride=1) + self.conv3_2 = conv(od+dd[1],32, kernel_size=3, stride=1) + self.conv3_3 = conv(od+dd[2],32, kernel_size=3, stride=1) + self.conv3_4 = conv(od+dd[3],32, kernel_size=3, stride=1) + self.predict_flow3 = predict_flow(od+dd[4]) + self.deconv3 = deconv(2, 2, kernel_size=4, stride=2, padding=1) + self.upfeat3 = deconv(od+dd[4], 2, kernel_size=4, stride=2, padding=1) + + od = nd+32+4 + self.conv2_0 = conv(od, 32, kernel_size=3, stride=1) + self.conv2_1 = conv(od+dd[0],32, kernel_size=3, stride=1) + self.conv2_2 = conv(od+dd[1],32, kernel_size=3, stride=1) + self.conv2_3 = conv(od+dd[2],32, kernel_size=3, stride=1) + self.conv2_4 = conv(od+dd[3],32, kernel_size=3, stride=1) + self.predict_flow2 = predict_flow(od+dd[4]) + self.deconv2 = deconv(2, 2, kernel_size=4, stride=2, padding=1) + + self.dc_conv1 = conv(od+dd[4], 32, kernel_size=3, stride=1, padding=1, dilation=1) + self.dc_conv2 = conv(32, 32, kernel_size=3, stride=1, padding=2, dilation=2) + self.dc_conv3 = conv(32, 32, kernel_size=3, stride=1, padding=4, dilation=4) + self.dc_conv4 = conv(32, 32, kernel_size=3, stride=1, padding=8, dilation=8) + self.dc_conv5 = conv(32, 32, kernel_size=3, stride=1, padding=16, dilation=16) + self.dc_conv6 = conv(32, 32, kernel_size=3, 
stride=1, padding=1, dilation=1) + self.dc_conv7 = predict_flow(32) + + for m in self.modules(): + if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): + nn.init.kaiming_normal_(m.weight.data, mode='fan_in') + if m.bias is not None: + m.bias.data.zero_() + + + def warp(self, x, flo): + """ + warp an image/tensor (im2) back to im1, according to the optical flow + + x: [B, C, H, W] (im2) + flo: [B, 2, H, W] flow + + """ + B, C, H, W = x.size() + # mesh grid + xx = torch.arange(0, W).view(1,-1).repeat(H,1) + yy = torch.arange(0, H).view(-1,1).repeat(1,W) + xx = xx.view(1,1,H,W).repeat(B,1,1,1) + yy = yy.view(1,1,H,W).repeat(B,1,1,1) + grid = torch.cat((xx,yy),1).float() + + if x.is_cuda: + grid = grid.to(x.get_device()) + vgrid = Variable(grid) + flo + + # scale grid to [-1,1] + vgrid[:,0,:,:] = 2.0*vgrid[:,0,:,:].clone() / max(W-1,1)-1.0 + vgrid[:,1,:,:] = 2.0*vgrid[:,1,:,:].clone() / max(H-1,1)-1.0 + + vgrid = vgrid.permute(0,2,3,1) + output = nn.functional.grid_sample(x, vgrid) + mask = torch.autograd.Variable(torch.ones(x.size())).cuda(x.get_device()) + mask = nn.functional.grid_sample(mask, vgrid) + + # if W==128: + # np.save('mask.npy', mask.cpu().data.numpy()) + # np.save('warp.npy', output.cpu().data.numpy()) + + mask[mask<0.9999] = 0 + mask[mask>0] = 1 + + return output*mask + + + def forward(self, im1, im2): + #im1 = x[:,:3,:,:] + #im2 = x[:,3:,:,:] + + c11 = self.conv1b(self.conv1aa(self.conv1a(im1))) + c21 = self.conv1b(self.conv1aa(self.conv1a(im2))) + c12 = self.conv2b(self.conv2aa(self.conv2a(c11))) + c22 = self.conv2b(self.conv2aa(self.conv2a(c21))) + c13 = self.conv3b(self.conv3aa(self.conv3a(c12))) + c23 = self.conv3b(self.conv3aa(self.conv3a(c22))) + c14 = self.conv4b(self.conv4aa(self.conv4a(c13))) + c24 = self.conv4b(self.conv4aa(self.conv4a(c23))) + c15 = self.conv5b(self.conv5aa(self.conv5a(c14))) + c25 = self.conv5b(self.conv5aa(self.conv5a(c24))) + #c16 = self.conv6b(self.conv6a(self.conv6aa(c15))) + #c26 = self.conv6b(self.conv6a(self.conv6aa(c25))) + + corr5 = self.corr(c15, c25) + corr5 = self.leakyRELU(corr5) + + # x = torch.cat((self.conv6_0(corr6), corr6),1) + # x = torch.cat((self.conv6_1(x), x),1) + # x = torch.cat((self.conv6_2(x), x),1) + # x = torch.cat((self.conv6_3(x), x),1) + # x = torch.cat((self.conv6_4(x), x),1) + # flow6 = self.predict_flow6(x) + # up_flow6 = self.deconv6(flow6) + # up_feat6 = self.upfeat6(x) + + # warp5 = self.warp(c25, up_flow6*0.625) + # corr5 = self.corr(c15, warp5) + # corr5 = self.leakyRELU(corr5) + # x = torch.cat((corr5, c15, up_flow6, up_feat6), 1) + + x = torch.cat((self.conv5_0(corr5), corr5), 1) + #x = torch.cat((self.conv5_0(x), x),1) + x = torch.cat((self.conv5_1(x), x),1) + x = torch.cat((self.conv5_2(x), x),1) + x = torch.cat((self.conv5_3(x), x),1) + x = torch.cat((self.conv5_4(x), x),1) + flow5 = self.predict_flow5(x) + up_flow5 = self.deconv5(flow5) + up_feat5 = self.upfeat5(x) + + warp4 = self.warp(c24, up_flow5*1.25) + corr4 = self.corr(c14, warp4) + corr4 = self.leakyRELU(corr4) + x = torch.cat((corr4, c14, up_flow5, up_feat5), 1) + x = torch.cat((self.conv4_0(x), x),1) + x = torch.cat((self.conv4_1(x), x),1) + x = torch.cat((self.conv4_2(x), x),1) + x = torch.cat((self.conv4_3(x), x),1) + x = torch.cat((self.conv4_4(x), x),1) + flow4 = self.predict_flow4(x) + up_flow4 = self.deconv4(flow4) + up_feat4 = self.upfeat4(x) + + + warp3 = self.warp(c23, up_flow4*2.5) + corr3 = self.corr(c13, warp3) + corr3 = self.leakyRELU(corr3) + + + x = torch.cat((corr3, c13, up_flow4, up_feat4), 1) + x = 
torch.cat((self.conv3_0(x), x),1) + x = torch.cat((self.conv3_1(x), x),1) + x = torch.cat((self.conv3_2(x), x),1) + x = torch.cat((self.conv3_3(x), x),1) + x = torch.cat((self.conv3_4(x), x),1) + flow3 = self.predict_flow3(x) + up_flow3 = self.deconv3(flow3) + up_feat3 = self.upfeat3(x) + + + warp2 = self.warp(c22, up_flow3*5.0) + corr2 = self.corr(c12, warp2) + corr2 = self.leakyRELU(corr2) + x = torch.cat((corr2, c12, up_flow3, up_feat3), 1) + x = torch.cat((self.conv2_0(x), x),1) + x = torch.cat((self.conv2_1(x), x),1) + x = torch.cat((self.conv2_2(x), x),1) + x = torch.cat((self.conv2_3(x), x),1) + x = torch.cat((self.conv2_4(x), x),1) + flow2 = self.predict_flow2(x) + + x = self.dc_conv4(self.dc_conv3(self.dc_conv2(self.dc_conv1(x)))) + flow2 = flow2 + self.dc_conv7(self.dc_conv6(self.dc_conv5(x))) + + if self.training: + return flow2, flow3, flow4, flow5 + else: + return flow2 + + +class PWCDCNet(nn.Module): + """ + PWC-DC net. add dilation convolution and densenet connections + + """ + def __init__(self, md=4): + """ + input: md --- maximum displacement (for correlation. default: 4), after warpping + + """ + super(PWCDCNet,self).__init__() + + self.conv1a = conv(1, 16, kernel_size=3, stride=2) + self.conv1aa = conv(16, 16, kernel_size=3, stride=1) + self.conv1b = conv(16, 16, kernel_size=3, stride=1) + self.conv2a = conv(16, 32, kernel_size=3, stride=2) + self.conv2aa = conv(32, 32, kernel_size=3, stride=1) + self.conv2b = conv(32, 32, kernel_size=3, stride=1) + self.conv3a = conv(32, 64, kernel_size=3, stride=2) + self.conv3aa = conv(64, 64, kernel_size=3, stride=1) + self.conv3b = conv(64, 64, kernel_size=3, stride=1) + self.conv4a = conv(64, 96, kernel_size=3, stride=2) + self.conv4aa = conv(96, 96, kernel_size=3, stride=1) + self.conv4b = conv(96, 96, kernel_size=3, stride=1) + self.conv5a = conv(96, 128, kernel_size=3, stride=2) + self.conv5aa = conv(128,128, kernel_size=3, stride=1) + self.conv5b = conv(128,128, kernel_size=3, stride=1) + self.conv6aa = conv(128,196, kernel_size=3, stride=2) + self.conv6a = conv(196,196, kernel_size=3, stride=1) + self.conv6b = conv(196,196, kernel_size=3, stride=1) + + # self.corr = Correlation(pad_size=md, kernel_size=1, max_displacement=md, stride1=1, stride2=1, corr_multiply=1) + + self.corr = CorrelationLayerSimple(md=md) + + self.leakyRELU = nn.LeakyReLU(0.1) + + nd = (2*md+1)**2 + dd = np.cumsum([128,128,96,64,32]) + + od = nd + self.conv6_0 = conv(od, 128, kernel_size=3, stride=1) + self.conv6_1 = conv(od+dd[0],128, kernel_size=3, stride=1) + self.conv6_2 = conv(od+dd[1],96, kernel_size=3, stride=1) + self.conv6_3 = conv(od+dd[2],64, kernel_size=3, stride=1) + self.conv6_4 = conv(od+dd[3],32, kernel_size=3, stride=1) + self.predict_flow6 = predict_flow(od+dd[4]) + self.deconv6 = deconv(2, 2, kernel_size=4, stride=2, padding=1) + self.upfeat6 = deconv(od+dd[4], 2, kernel_size=4, stride=2, padding=1) + + od = nd+128+4 + self.conv5_0 = conv(od, 128, kernel_size=3, stride=1) + self.conv5_1 = conv(od+dd[0],128, kernel_size=3, stride=1) + self.conv5_2 = conv(od+dd[1],96, kernel_size=3, stride=1) + self.conv5_3 = conv(od+dd[2],64, kernel_size=3, stride=1) + self.conv5_4 = conv(od+dd[3],32, kernel_size=3, stride=1) + self.predict_flow5 = predict_flow(od+dd[4]) + self.deconv5 = deconv(2, 2, kernel_size=4, stride=2, padding=1) + self.upfeat5 = deconv(od+dd[4], 2, kernel_size=4, stride=2, padding=1) + + od = nd+96+4 + self.conv4_0 = conv(od, 128, kernel_size=3, stride=1) + self.conv4_1 = conv(od+dd[0],128, kernel_size=3, stride=1) + self.conv4_2 
= conv(od+dd[1],96, kernel_size=3, stride=1) + self.conv4_3 = conv(od+dd[2],64, kernel_size=3, stride=1) + self.conv4_4 = conv(od+dd[3],32, kernel_size=3, stride=1) + self.predict_flow4 = predict_flow(od+dd[4]) + self.deconv4 = deconv(2, 2, kernel_size=4, stride=2, padding=1) + self.upfeat4 = deconv(od+dd[4], 2, kernel_size=4, stride=2, padding=1) + + od = nd+64+4 + self.conv3_0 = conv(od, 128, kernel_size=3, stride=1) + self.conv3_1 = conv(od+dd[0],128, kernel_size=3, stride=1) + self.conv3_2 = conv(od+dd[1],96, kernel_size=3, stride=1) + self.conv3_3 = conv(od+dd[2],64, kernel_size=3, stride=1) + self.conv3_4 = conv(od+dd[3],32, kernel_size=3, stride=1) + self.predict_flow3 = predict_flow(od+dd[4]) + self.deconv3 = deconv(2, 2, kernel_size=4, stride=2, padding=1) + self.upfeat3 = deconv(od+dd[4], 2, kernel_size=4, stride=2, padding=1) + + od = nd+32+4 + self.conv2_0 = conv(od, 128, kernel_size=3, stride=1) + self.conv2_1 = conv(od+dd[0],128, kernel_size=3, stride=1) + self.conv2_2 = conv(od+dd[1],96, kernel_size=3, stride=1) + self.conv2_3 = conv(od+dd[2],64, kernel_size=3, stride=1) + self.conv2_4 = conv(od+dd[3],32, kernel_size=3, stride=1) + self.predict_flow2 = predict_flow(od+dd[4]) + self.deconv2 = deconv(2, 2, kernel_size=4, stride=2, padding=1) + + self.dc_conv1 = conv(od+dd[4], 128, kernel_size=3, stride=1, padding=1, dilation=1) + self.dc_conv2 = conv(128, 128, kernel_size=3, stride=1, padding=2, dilation=2) + self.dc_conv3 = conv(128, 128, kernel_size=3, stride=1, padding=4, dilation=4) + self.dc_conv4 = conv(128, 96, kernel_size=3, stride=1, padding=8, dilation=8) + self.dc_conv5 = conv(96, 64, kernel_size=3, stride=1, padding=16, dilation=16) + self.dc_conv6 = conv(64, 32, kernel_size=3, stride=1, padding=1, dilation=1) + self.dc_conv7 = predict_flow(32) + + for m in self.modules(): + if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): + nn.init.kaiming_normal_(m.weight.data, mode='fan_in') + if m.bias is not None: + m.bias.data.zero_() + + + def warp(self, x, flo): + """ + warp an image/tensor (im2) back to im1, according to the optical flow + + x: [B, C, H, W] (im2) + flo: [B, 2, H, W] flow + + """ + B, C, H, W = x.size() + # mesh grid + xx = torch.arange(0, W).view(1,-1).repeat(H,1) + yy = torch.arange(0, H).view(-1,1).repeat(1,W) + xx = xx.view(1,1,H,W).repeat(B,1,1,1) + yy = yy.view(1,1,H,W).repeat(B,1,1,1) + grid = torch.cat((xx,yy),1).float() + + if x.is_cuda: + grid = grid.cuda(x.get_device()) + vgrid = Variable(grid) + flo + + # scale grid to [-1,1] + vgrid[:,0,:,:] = 2.0*vgrid[:,0,:,:].clone() / max(W-1,1)-1.0 + vgrid[:,1,:,:] = 2.0*vgrid[:,1,:,:].clone() / max(H-1,1)-1.0 + + vgrid = vgrid.permute(0,2,3,1) + output = nn.functional.grid_sample(x, vgrid) + mask = torch.autograd.Variable(torch.ones(x.size())).cuda(x.get_device()) + mask = nn.functional.grid_sample(mask, vgrid) + + # if W==128: + # np.save('mask.npy', mask.cpu().data.numpy()) + # np.save('warp.npy', output.cpu().data.numpy()) + + mask[mask<0.9999] = 0 + mask[mask>0] = 1 + + return output #*mask + + + #def forward(self, im1, im2): + def forward(self, x): + N, C, H, W = x.shape + im1 = x[:,:C//2,:,:] + im2 = x[:,C//2:,:,:] + + c11 = self.conv1b(self.conv1aa(self.conv1a(im1))) + c21 = self.conv1b(self.conv1aa(self.conv1a(im2))) + c12 = self.conv2b(self.conv2aa(self.conv2a(c11))) + c22 = self.conv2b(self.conv2aa(self.conv2a(c21))) + c13 = self.conv3b(self.conv3aa(self.conv3a(c12))) + c23 = self.conv3b(self.conv3aa(self.conv3a(c22))) + c14 = self.conv4b(self.conv4aa(self.conv4a(c13))) + c24 = 
self.conv4b(self.conv4aa(self.conv4a(c23))) + c15 = self.conv5b(self.conv5aa(self.conv5a(c14))) + c25 = self.conv5b(self.conv5aa(self.conv5a(c24))) + c16 = self.conv6b(self.conv6a(self.conv6aa(c15))) + c26 = self.conv6b(self.conv6a(self.conv6aa(c25))) + + + corr6 = self.corr(c16, c26) + corr6 = self.leakyRELU(corr6) + + + x = torch.cat((self.conv6_0(corr6), corr6),1) + x = torch.cat((self.conv6_1(x), x),1) + x = torch.cat((self.conv6_2(x), x),1) + x = torch.cat((self.conv6_3(x), x),1) + x = torch.cat((self.conv6_4(x), x),1) + flow6 = self.predict_flow6(x) + up_flow6 = self.deconv6(flow6) + up_feat6 = self.upfeat6(x) + + + warp5 = self.warp(c25, up_flow6*0.625) + corr5 = self.corr(c15, warp5) + corr5 = self.leakyRELU(corr5) + x = torch.cat((corr5, c15, up_flow6, up_feat6), 1) + x = torch.cat((self.conv5_0(x), x),1) + x = torch.cat((self.conv5_1(x), x),1) + x = torch.cat((self.conv5_2(x), x),1) + x = torch.cat((self.conv5_3(x), x),1) + x = torch.cat((self.conv5_4(x), x),1) + flow5 = self.predict_flow5(x) + up_flow5 = self.deconv5(flow5) + up_feat5 = self.upfeat5(x) + + warp4 = self.warp(c24, up_flow5*1.25) + corr4 = self.corr(c14, warp4) + corr4 = self.leakyRELU(corr4) + x = torch.cat((corr4, c14, up_flow5, up_feat5), 1) + x = torch.cat((self.conv4_0(x), x),1) + x = torch.cat((self.conv4_1(x), x),1) + x = torch.cat((self.conv4_2(x), x),1) + x = torch.cat((self.conv4_3(x), x),1) + x = torch.cat((self.conv4_4(x), x),1) + flow4 = self.predict_flow4(x) + up_flow4 = self.deconv4(flow4) + up_feat4 = self.upfeat4(x) + + + warp3 = self.warp(c23, up_flow4*2.5) + corr3 = self.corr(c13, warp3) + corr3 = self.leakyRELU(corr3) + + + x = torch.cat((corr3, c13, up_flow4, up_feat4), 1) + x = torch.cat((self.conv3_0(x), x),1) + x = torch.cat((self.conv3_1(x), x),1) + x = torch.cat((self.conv3_2(x), x),1) + x = torch.cat((self.conv3_3(x), x),1) + x = torch.cat((self.conv3_4(x), x),1) + flow3 = self.predict_flow3(x) + up_flow3 = self.deconv3(flow3) + up_feat3 = self.upfeat3(x) + + + warp2 = self.warp(c22, up_flow3*5.0) + corr2 = self.corr(c12, warp2) + corr2 = self.leakyRELU(corr2) + x = torch.cat((corr2, c12, up_flow3, up_feat3), 1) + x = torch.cat((self.conv2_0(x), x),1) + x = torch.cat((self.conv2_1(x), x),1) + x = torch.cat((self.conv2_2(x), x),1) + x = torch.cat((self.conv2_3(x), x),1) + x = torch.cat((self.conv2_4(x), x),1) + flow2 = self.predict_flow2(x) + + x = self.dc_conv4(self.dc_conv3(self.dc_conv2(self.dc_conv1(x)))) + flow2 = flow2 + self.dc_conv7(self.dc_conv6(self.dc_conv5(x))) + + if self.training: + return flow2,flow3,flow4,flow5,flow6 + else: + return flow2 + + +class PWCDCNetFullRes(nn.Module): + """ + PWC-DC net. add dilation convolution and densenet connections + + """ + def __init__(self, md=4): + """ + input: md --- maximum displacement (for correlation. 
default: 4), after warpping + + """ + super(PWCDCNetFullRes,self).__init__() + + #self.conv1a = conv(1, 16, kernel_size=3, stride=2) + self.conv1a = conv(1, 16, kernel_size=3, stride=1) + self.conv1aa = conv(16, 16, kernel_size=3, stride=1) + self.conv1b = conv(16, 16, kernel_size=3, stride=1) + self.conv2a = conv(16, 32, kernel_size=3, stride=2) + self.conv2aa = conv(32, 32, kernel_size=3, stride=1) + self.conv2b = conv(32, 32, kernel_size=3, stride=1) + self.conv3a = conv(32, 64, kernel_size=3, stride=2) + self.conv3aa = conv(64, 64, kernel_size=3, stride=1) + self.conv3b = conv(64, 64, kernel_size=3, stride=1) + self.conv4a = conv(64, 96, kernel_size=3, stride=2) + self.conv4aa = conv(96, 96, kernel_size=3, stride=1) + self.conv4b = conv(96, 96, kernel_size=3, stride=1) + self.conv5a = conv(96, 128, kernel_size=3, stride=2) + self.conv5aa = conv(128,128, kernel_size=3, stride=1) + self.conv5b = conv(128,128, kernel_size=3, stride=1) + self.conv6aa = conv(128,196, kernel_size=3, stride=2) + self.conv6a = conv(196,196, kernel_size=3, stride=1) + self.conv6b = conv(196,196, kernel_size=3, stride=1) + + # self.corr = Correlation(pad_size=md, kernel_size=1, max_displacement=md, stride1=1, stride2=1, corr_multiply=1) + + self.corr = CorrelationLayerSimple(md=md) + + self.leakyRELU = nn.LeakyReLU(0.1) + + nd = (2*md+1)**2 + dd = np.cumsum([128,128,96,64,32]) + + od = nd + self.conv6_0 = conv(od, 128, kernel_size=3, stride=1) + self.conv6_1 = conv(od+dd[0],128, kernel_size=3, stride=1) + self.conv6_2 = conv(od+dd[1],96, kernel_size=3, stride=1) + self.conv6_3 = conv(od+dd[2],64, kernel_size=3, stride=1) + self.conv6_4 = conv(od+dd[3],32, kernel_size=3, stride=1) + self.predict_flow6 = predict_flow(od+dd[4]) + self.deconv6 = deconv(2, 2, kernel_size=4, stride=2, padding=1) + self.upfeat6 = deconv(od+dd[4], 2, kernel_size=4, stride=2, padding=1) + + od = nd+128+4 + self.conv5_0 = conv(od, 128, kernel_size=3, stride=1) + self.conv5_1 = conv(od+dd[0],128, kernel_size=3, stride=1) + self.conv5_2 = conv(od+dd[1],96, kernel_size=3, stride=1) + self.conv5_3 = conv(od+dd[2],64, kernel_size=3, stride=1) + self.conv5_4 = conv(od+dd[3],32, kernel_size=3, stride=1) + self.predict_flow5 = predict_flow(od+dd[4]) + self.deconv5 = deconv(2, 2, kernel_size=4, stride=2, padding=1) + self.upfeat5 = deconv(od+dd[4], 2, kernel_size=4, stride=2, padding=1) + + od = nd+96+4 + self.conv4_0 = conv(od, 128, kernel_size=3, stride=1) + self.conv4_1 = conv(od+dd[0],128, kernel_size=3, stride=1) + self.conv4_2 = conv(od+dd[1],96, kernel_size=3, stride=1) + self.conv4_3 = conv(od+dd[2],64, kernel_size=3, stride=1) + self.conv4_4 = conv(od+dd[3],32, kernel_size=3, stride=1) + self.predict_flow4 = predict_flow(od+dd[4]) + self.deconv4 = deconv(2, 2, kernel_size=4, stride=2, padding=1) + self.upfeat4 = deconv(od+dd[4], 2, kernel_size=4, stride=2, padding=1) + + od = nd+64+4 + self.conv3_0 = conv(od, 128, kernel_size=3, stride=1) + self.conv3_1 = conv(od+dd[0],128, kernel_size=3, stride=1) + self.conv3_2 = conv(od+dd[1],96, kernel_size=3, stride=1) + self.conv3_3 = conv(od+dd[2],64, kernel_size=3, stride=1) + self.conv3_4 = conv(od+dd[3],32, kernel_size=3, stride=1) + self.predict_flow3 = predict_flow(od+dd[4]) + self.deconv3 = deconv(2, 2, kernel_size=4, stride=2, padding=1) + self.upfeat3 = deconv(od+dd[4], 2, kernel_size=4, stride=2, padding=1) + + od = nd+32+4 + self.conv2_0 = conv(od, 128, kernel_size=3, stride=1) + self.conv2_1 = conv(od+dd[0],128, kernel_size=3, stride=1) + self.conv2_2 = conv(od+dd[1],96, kernel_size=3, 
stride=1) + self.conv2_3 = conv(od+dd[2],64, kernel_size=3, stride=1) + self.conv2_4 = conv(od+dd[3],32, kernel_size=3, stride=1) + self.predict_flow2 = predict_flow(od+dd[4]) + self.deconv2 = deconv(2, 2, kernel_size=4, stride=2, padding=1) + self.upfeat2 = deconv(od+dd[4], 2, kernel_size=4, stride=2, padding=1) + + od = nd+16+4 + self.conv1_0 = conv(od, 128, kernel_size=3, stride=1) + self.conv1_1 = conv(od+dd[0],128, kernel_size=3, stride=1) + self.conv1_2 = conv(od+dd[1],96, kernel_size=3, stride=1) + self.conv1_3 = conv(od+dd[2],64, kernel_size=3, stride=1) + self.conv1_4 = conv(od+dd[3],32, kernel_size=3, stride=1) + self.predict_flow1 = predict_flow(od+dd[4]) + self.deconv1 = deconv(2, 2, kernel_size=4, stride=2, padding=1) + self.upfeat1 = deconv(od+dd[4], 2, kernel_size=4, stride=2, padding=1) + + + self.dc_conv1 = conv(od+dd[4], 128, kernel_size=3, stride=1, padding=1, dilation=1) + self.dc_conv2 = conv(128, 128, kernel_size=3, stride=1, padding=2, dilation=2) + self.dc_conv3 = conv(128, 128, kernel_size=3, stride=1, padding=4, dilation=4) + self.dc_conv4 = conv(128, 96, kernel_size=3, stride=1, padding=8, dilation=8) + self.dc_conv5 = conv(96, 64, kernel_size=3, stride=1, padding=16, dilation=16) + self.dc_conv6 = conv(64, 32, kernel_size=3, stride=1, padding=1, dilation=1) + self.dc_conv7 = predict_flow(32) + + for m in self.modules(): + if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): + nn.init.kaiming_normal_(m.weight.data, mode='fan_in') + if m.bias is not None: + m.bias.data.zero_() + + + def warp(self, x, flo): + """ + warp an image/tensor (im2) back to im1, according to the optical flow + + x: [B, C, H, W] (im2) + flo: [B, 2, H, W] flow + + """ + B, C, H, W = x.size() + # mesh grid + xx = torch.arange(0, W).view(1,-1).repeat(H,1) + yy = torch.arange(0, H).view(-1,1).repeat(1,W) + xx = xx.view(1,1,H,W).repeat(B,1,1,1) + yy = yy.view(1,1,H,W).repeat(B,1,1,1) + grid = torch.cat((xx,yy),1).float() + + if x.is_cuda: + grid = grid.cuda(x.get_device()) + vgrid = Variable(grid) + flo + + # scale grid to [-1,1] + vgrid[:,0,:,:] = 2.0*vgrid[:,0,:,:].clone() / max(W-1,1)-1.0 + vgrid[:,1,:,:] = 2.0*vgrid[:,1,:,:].clone() / max(H-1,1)-1.0 + + vgrid = vgrid.permute(0,2,3,1) + output = nn.functional.grid_sample(x, vgrid) + mask = torch.autograd.Variable(torch.ones(x.size())).cuda(x.get_device()) + mask = nn.functional.grid_sample(mask, vgrid) + + # if W==128: + # np.save('mask.npy', mask.cpu().data.numpy()) + # np.save('warp.npy', output.cpu().data.numpy()) + + mask[mask<0.9999] = 0 + mask[mask>0] = 1 + + return output #*mask + + + def forward(self, im1, im2): + #im1 = x[:,:3,:,:] + #im2 = x[:,3:,:,:] + + c11 = self.conv1b(self.conv1aa(self.conv1a(im1))) + c21 = self.conv1b(self.conv1aa(self.conv1a(im2))) + c12 = self.conv2b(self.conv2aa(self.conv2a(c11))) + c22 = self.conv2b(self.conv2aa(self.conv2a(c21))) + c13 = self.conv3b(self.conv3aa(self.conv3a(c12))) + c23 = self.conv3b(self.conv3aa(self.conv3a(c22))) + c14 = self.conv4b(self.conv4aa(self.conv4a(c13))) + c24 = self.conv4b(self.conv4aa(self.conv4a(c23))) + c15 = self.conv5b(self.conv5aa(self.conv5a(c14))) + c25 = self.conv5b(self.conv5aa(self.conv5a(c24))) + c16 = self.conv6b(self.conv6a(self.conv6aa(c15))) + c26 = self.conv6b(self.conv6a(self.conv6aa(c25))) + + + corr6 = self.corr(c16, c26) + corr6 = self.leakyRELU(corr6) + + + x = torch.cat((self.conv6_0(corr6), corr6),1) + x = torch.cat((self.conv6_1(x), x),1) + x = torch.cat((self.conv6_2(x), x),1) + x = torch.cat((self.conv6_3(x), x),1) + x = 
torch.cat((self.conv6_4(x), x),1) + flow6 = self.predict_flow6(x) + up_flow6 = self.deconv6(flow6) + up_feat6 = self.upfeat6(x) + + + warp5 = self.warp(c25, up_flow6*0.625) + corr5 = self.corr(c15, warp5) + corr5 = self.leakyRELU(corr5) + x = torch.cat((corr5, c15, up_flow6, up_feat6), 1) + x = torch.cat((self.conv5_0(x), x),1) + x = torch.cat((self.conv5_1(x), x),1) + x = torch.cat((self.conv5_2(x), x),1) + x = torch.cat((self.conv5_3(x), x),1) + x = torch.cat((self.conv5_4(x), x),1) + flow5 = self.predict_flow5(x) + up_flow5 = self.deconv5(flow5) + up_feat5 = self.upfeat5(x) + + warp4 = self.warp(c24, up_flow5*1.25) + corr4 = self.corr(c14, warp4) + corr4 = self.leakyRELU(corr4) + x = torch.cat((corr4, c14, up_flow5, up_feat5), 1) + x = torch.cat((self.conv4_0(x), x),1) + x = torch.cat((self.conv4_1(x), x),1) + x = torch.cat((self.conv4_2(x), x),1) + x = torch.cat((self.conv4_3(x), x),1) + x = torch.cat((self.conv4_4(x), x),1) + flow4 = self.predict_flow4(x) + up_flow4 = self.deconv4(flow4) + up_feat4 = self.upfeat4(x) + + + warp3 = self.warp(c23, up_flow4*2.5) + corr3 = self.corr(c13, warp3) + corr3 = self.leakyRELU(corr3) + x = torch.cat((corr3, c13, up_flow4, up_feat4), 1) + x = torch.cat((self.conv3_0(x), x),1) + x = torch.cat((self.conv3_1(x), x),1) + x = torch.cat((self.conv3_2(x), x),1) + x = torch.cat((self.conv3_3(x), x),1) + x = torch.cat((self.conv3_4(x), x),1) + flow3 = self.predict_flow3(x) + up_flow3 = self.deconv3(flow3) + up_feat3 = self.upfeat3(x) + + + warp2 = self.warp(c22, up_flow3*5.0) + corr2 = self.corr(c12, warp2) + corr2 = self.leakyRELU(corr2) + x = torch.cat((corr2, c12, up_flow3, up_feat3), 1) + x = torch.cat((self.conv2_0(x), x),1) + x = torch.cat((self.conv2_1(x), x),1) + x = torch.cat((self.conv2_2(x), x),1) + x = torch.cat((self.conv2_3(x), x),1) + x = torch.cat((self.conv2_4(x), x),1) + flow2 = self.predict_flow2(x) + up_flow2 = self.deconv2(flow2) + up_feat2 = self.upfeat2(x) + + warp1 = self.warp(c21, up_flow2*10.0) + corr1 = self.corr(c11, warp1) + corr1 = self.leakyRELU(corr1) + x = torch.cat((corr1, c11, up_flow2, up_feat2), 1) + x = torch.cat((self.conv1_0(x), x),1) + x = torch.cat((self.conv1_1(x), x),1) + x = torch.cat((self.conv1_2(x), x),1) + x = torch.cat((self.conv1_3(x), x),1) + x = torch.cat((self.conv1_4(x), x),1) + flow1 = self.predict_flow1(x) + + x = self.dc_conv4(self.dc_conv3(self.dc_conv2(self.dc_conv1(x)))) + flow1 = flow1 + self.dc_conv7(self.dc_conv6(self.dc_conv5(x))) + + if self.training: + return flow1,flow2,flow3,flow4,flow5,flow6 + else: + return flow1 + + +class PWCNetSmallFullRes(nn.Module): + """ + PWC-DC net. add dilation convolution and densenet connections + + """ + def __init__(self, md=4): + """ + input: md --- maximum displacement (for correlation. 
default: 4), after warpping + + """ + super(PWCNetSmallFullRes,self).__init__() + + #self.conv1a = conv(1, 16, kernel_size=3, stride=2) + self.conv1a = conv(2, 16, kernel_size=3, stride=1) + self.conv1aa = conv(16, 16, kernel_size=3, stride=1) + self.conv1b = conv(16, 16, kernel_size=3, stride=1) + self.conv2a = conv(16, 32, kernel_size=3, stride=2) + self.conv2aa = conv(32, 32, kernel_size=3, stride=1) + self.conv2b = conv(32, 32, kernel_size=3, stride=1) + self.conv3a = conv(32, 64, kernel_size=3, stride=2) + self.conv3aa = conv(64, 64, kernel_size=3, stride=1) + self.conv3b = conv(64, 64, kernel_size=3, stride=1) + + # self.corr = Correlation(pad_size=md, kernel_size=1, max_displacement=md, stride1=1, stride2=1, corr_multiply=1) + # self.corr = CorrelationLayerSimple(md=md) + + self.leakyRELU = nn.LeakyReLU(0.1) + + nd = 64 #(2*md+1)**2 + dd = np.cumsum([64,32]) + + od = nd + self.conv3_0 = conv(od, 64, kernel_size=3, stride=1) + self.conv3_1 = conv(od+dd[0], 32, kernel_size=3, stride=1) + self.predict_flow3 = predict_flow(od+dd[1]) + self.deconv3 = deconv(2, 2, kernel_size=4, stride=2, padding=1) + self.upfeat3 = deconv(od+dd[1], 2, kernel_size=4, stride=2, padding=1) + + od = nd+32+4 + self.conv2_0 = conv(od, 64, kernel_size=3, stride=1) + self.conv2_1 = conv(od+dd[0],32, kernel_size=3, stride=1) + self.predict_flow2 = predict_flow(od+dd[1]) + self.deconv2 = deconv(2, 2, kernel_size=4, stride=2, padding=1) + self.upfeat2 = deconv(od+dd[1], 2, kernel_size=4, stride=2, padding=1) + + od = nd+16+4 + self.conv1_0 = conv(od, 64, kernel_size=3, stride=1) + self.conv1_1 = conv(od+dd[0], 32, kernel_size=3, stride=1) + self.predict_flow1 = predict_flow(od+dd[1]) + self.deconv1 = deconv(2, 2, kernel_size=4, stride=2, padding=1) + self.upfeat1 = deconv(od+dd[1], 2, kernel_size=4, stride=2, padding=1) + + for m in self.modules(): + if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): + nn.init.kaiming_normal_(m.weight.data, mode='fan_in') + if m.bias is not None: + m.bias.data.zero_() + + def warp(self, x, flo): + """ + warp an image/tensor (im2) back to im1, according to the optical flow + + x: [B, C, H, W] (im2) + flo: [B, 2, H, W] flow + + """ + B, C, H, W = x.size() + # mesh grid + xx = torch.arange(0, W).view(1,-1).repeat(H,1) + yy = torch.arange(0, H).view(-1,1).repeat(1,W) + xx = xx.view(1,1,H,W).repeat(B,1,1,1) + yy = yy.view(1,1,H,W).repeat(B,1,1,1) + grid = torch.cat((xx,yy),1).float() + + if x.is_cuda: + grid = grid.cuda(x.get_device()) + vgrid = Variable(grid) + flo + + # scale grid to [-1,1] + vgrid[:,0,:,:] = 2.0*vgrid[:,0,:,:].clone() / max(W-1,1)-1.0 + vgrid[:,1,:,:] = 2.0*vgrid[:,1,:,:].clone() / max(H-1,1)-1.0 + + vgrid = vgrid.permute(0,2,3,1) + output = nn.functional.grid_sample(x, vgrid) + mask = torch.autograd.Variable(torch.ones(x.size())).cuda(x.get_device()) + mask = nn.functional.grid_sample(mask, vgrid) + + # if W==128: + # np.save('mask.npy', mask.cpu().data.numpy()) + # np.save('warp.npy', output.cpu().data.numpy()) + + mask[mask<0.9999] = 0 + mask[mask>0] = 1 + + return output #*mask + + + def forward(self, im1, im2): + #im1 = x[:,:3,:,:] + #im2 = x[:,3:,:,:] + x = torch.cat([im1, im2], 1) + + c11 = self.conv1b(self.conv1aa(self.conv1a(x))) + c12 = self.conv2b(self.conv2aa(self.conv2a(c11))) + c13 = self.conv3b(self.conv3aa(self.conv3a(c12))) + + x = torch.cat((self.conv3_0(c13), c13),1) + x = torch.cat((self.conv3_1(x), x),1) + flow3 = self.predict_flow3(x) + up_flow3 = self.deconv3(flow3) + up_feat3 = self.upfeat3(x) + + x = torch.cat((c12, up_flow3, 
up_feat3), 1) + x = torch.cat((self.conv2_0(x), x),1) + x = torch.cat((self.conv2_1(x), x),1) + flow2 = self.predict_flow2(x) + up_flow2 = self.deconv2(flow2) + up_feat2 = self.upfeat2(x) + + x = torch.cat((c11, up_flow2, up_feat2), 1) + x = torch.cat((self.conv1_0(x), x),1) + x = torch.cat((self.conv1_1(x), x),1) + flow1 = self.predict_flow1(x) + + if self.training: + return flow1,flow2,flow3 + else: + return flow1 + + +class PWCDCNet_old(nn.Module): + """ + PWC-DC net. add dilation convolution and densenet connections + + """ + def __init__(self, md=4): + """ + input: md --- maximum displacement (for correlation. default: 4), after warpping + + """ + super(PWCDCNet_old,self).__init__() + + self.conv1a = conv(3, 16, kernel_size=3, stride=2) + self.conv1b = conv(16, 16, kernel_size=3, stride=1) + self.conv2a = conv(16, 32, kernel_size=3, stride=2) + self.conv2b = conv(32, 32, kernel_size=3, stride=1) + self.conv3a = conv(32, 64, kernel_size=3, stride=2) + self.conv3b = conv(64, 64, kernel_size=3, stride=1) + self.conv4a = conv(64, 96, kernel_size=3, stride=2) + self.conv4b = conv(96, 96, kernel_size=3, stride=1) + self.conv5a = conv(96, 128, kernel_size=3, stride=2) + self.conv5b = conv(128,128, kernel_size=3, stride=1) + self.conv6a = conv(128,196, kernel_size=3, stride=2) + self.conv6b = conv(196,196, kernel_size=3, stride=1) + + self.corr = Correlation(pad_size=md, kernel_size=1, max_displacement=md, stride1=1, stride2=1, corr_multiply=1) + self.leakyRELU = nn.LeakyReLU(0.1) + + nd = (2*md+1)**2 + dd = np.cumsum([128,128,96,64,32]) + + od = nd + self.conv6_0 = conv(od, 128, kernel_size=3, stride=1) + self.conv6_1 = conv(od+dd[0],128, kernel_size=3, stride=1) + self.conv6_2 = conv(od+dd[1],96, kernel_size=3, stride=1) + self.conv6_3 = conv(od+dd[2],64, kernel_size=3, stride=1) + self.conv6_4 = conv(od+dd[3],32, kernel_size=3, stride=1) + self.predict_flow6 = predict_flow(od+dd[4]) + self.deconv6 = deconv(2, 2, kernel_size=4, stride=2, padding=1) + self.upfeat6 = deconv(od+dd[4], 2, kernel_size=4, stride=2, padding=1) + + od = nd+128+4 + self.conv5_0 = conv(od, 128, kernel_size=3, stride=1) + self.conv5_1 = conv(od+dd[0],128, kernel_size=3, stride=1) + self.conv5_2 = conv(od+dd[1],96, kernel_size=3, stride=1) + self.conv5_3 = conv(od+dd[2],64, kernel_size=3, stride=1) + self.conv5_4 = conv(od+dd[3],32, kernel_size=3, stride=1) + self.predict_flow5 = predict_flow(od+dd[4]) + self.deconv5 = deconv(2, 2, kernel_size=4, stride=2, padding=1) + self.upfeat5 = deconv(od+dd[4], 2, kernel_size=4, stride=2, padding=1) + + od = nd+96+4 + self.conv4_0 = conv(od, 128, kernel_size=3, stride=1) + self.conv4_1 = conv(od+dd[0],128, kernel_size=3, stride=1) + self.conv4_2 = conv(od+dd[1],96, kernel_size=3, stride=1) + self.conv4_3 = conv(od+dd[2],64, kernel_size=3, stride=1) + self.conv4_4 = conv(od+dd[3],32, kernel_size=3, stride=1) + self.predict_flow4 = predict_flow(od+dd[4]) + self.deconv4 = deconv(2, 2, kernel_size=4, stride=2, padding=1) + self.upfeat4 = deconv(od+dd[4], 2, kernel_size=4, stride=2, padding=1) + + od = nd+64+4 + self.conv3_0 = conv(od, 128, kernel_size=3, stride=1) + self.conv3_1 = conv(od+dd[0],128, kernel_size=3, stride=1) + self.conv3_2 = conv(od+dd[1],96, kernel_size=3, stride=1) + self.conv3_3 = conv(od+dd[2],64, kernel_size=3, stride=1) + self.conv3_4 = conv(od+dd[3],32, kernel_size=3, stride=1) + self.predict_flow3 = predict_flow(od+dd[4]) + self.deconv3 = deconv(2, 2, kernel_size=4, stride=2, padding=1) + self.upfeat3 = deconv(od+dd[4], 2, kernel_size=4, stride=2, padding=1) + 
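+        # Level-2 decoder input channels: cost volume (nd = (2*md+1)**2 = 81 with the
+        # default md=4) + 32 encoder channels from conv2b + 2 upsampled-flow and 2
+        # upsampled-feature channels; the dense connections concatenate each conv
+        # output onto x, hence the cumulative offsets in dd.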
+ od = nd+32+4 + self.conv2_0 = conv(od, 128, kernel_size=3, stride=1) + self.conv2_1 = conv(od+dd[0],128, kernel_size=3, stride=1) + self.conv2_2 = conv(od+dd[1],96, kernel_size=3, stride=1) + self.conv2_3 = conv(od+dd[2],64, kernel_size=3, stride=1) + self.conv2_4 = conv(od+dd[3],32, kernel_size=3, stride=1) + self.predict_flow2 = predict_flow(od+dd[4]) + self.deconv2 = deconv(2, 2, kernel_size=4, stride=2, padding=1) + + self.dc_conv1 = conv(od+dd[4], 128, kernel_size=3, stride=1, padding=1, dilation=1) + self.dc_conv2 = conv(128, 128, kernel_size=3, stride=1, padding=2, dilation=2) + self.dc_conv3 = conv(128, 128, kernel_size=3, stride=1, padding=4, dilation=4) + self.dc_conv4 = conv(128, 96, kernel_size=3, stride=1, padding=8, dilation=8) + self.dc_conv5 = conv(96, 64, kernel_size=3, stride=1, padding=16, dilation=16) + self.dc_conv6 = conv(64, 32, kernel_size=3, stride=1, padding=1, dilation=1) + self.dc_conv7 = predict_flow(32) + + for m in self.modules(): + if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): + nn.init.kaiming_normal_(m.weight.data, mode='fan_in') + if m.bias is not None: + m.bias.data.zero_() + + + def warp(self, x, flo): + """ + warp an image/tensor (im2) back to im1, according to the optical flow + + x: [B, C, H, W] (im2) + flo: [B, 2, H, W] flow + + """ + B, C, H, W = x.size() + # mesh grid + xx = torch.arange(0, W).view(1,-1).repeat(H,1) + yy = torch.arange(0, H).view(-1,1).repeat(1,W) + xx = xx.view(1,1,H,W).repeat(B,1,1,1) + yy = yy.view(1,1,H,W).repeat(B,1,1,1) + grid = torch.cat((xx,yy),1).float() + + if x.is_cuda: + grid = grid.cuda() + vgrid = Variable(grid) + flo + + # scale grid to [-1,1] + vgrid[:,0,:,:] = 2.0*vgrid[:,0,:,:].clone() / max(W-1,1)-1.0 + vgrid[:,1,:,:] = 2.0*vgrid[:,1,:,:].clone() / max(H-1,1)-1.0 + + vgrid = vgrid.permute(0,2,3,1) + output = nn.functional.grid_sample(x, vgrid) + mask = torch.autograd.Variable(torch.ones(x.size())).cuda() + mask = nn.functional.grid_sample(mask, vgrid) + + mask[mask<0.999] = 0 + mask[mask>0] = 1 + + return output*mask + + + def forward(self,x): + im1 = x[:,:3,:,:] + im2 = x[:,3:,:,:] + + c11 = self.conv1b(self.conv1a(im1)) + c21 = self.conv1b(self.conv1a(im2)) + c12 = self.conv2b(self.conv2a(c11)) + c22 = self.conv2b(self.conv2a(c21)) + c13 = self.conv3b(self.conv3a(c12)) + c23 = self.conv3b(self.conv3a(c22)) + c14 = self.conv4b(self.conv4a(c13)) + c24 = self.conv4b(self.conv4a(c23)) + c15 = self.conv5b(self.conv5a(c14)) + c25 = self.conv5b(self.conv5a(c24)) + c16 = self.conv6b(self.conv6a(c15)) + c26 = self.conv6b(self.conv6a(c25)) + + corr6 = self.corr(c16, c26) + corr6 = self.leakyRELU(corr6) + x = torch.cat((corr6, self.conv6_0(corr6)),1) + x = torch.cat((self.conv6_1(x), x),1) + x = torch.cat((x, self.conv6_2(x)),1) + x = torch.cat((x, self.conv6_3(x)),1) + x = torch.cat((x, self.conv6_4(x)),1) + flow6 = self.predict_flow6(x) + up_flow6 = self.deconv6(flow6) + up_feat6 = self.upfeat6(x) + + warp5 = self.warp(c25, up_flow6*0.625) + corr5 = self.corr(c15, warp5) + corr5 = self.leakyRELU(corr5) + x = torch.cat((corr5, c15, up_flow6, up_feat6), 1) + x = torch.cat((x, self.conv5_0(x)),1) + x = torch.cat((self.conv5_1(x), x),1) + x = torch.cat((x, self.conv5_2(x)),1) + x = torch.cat((x, self.conv5_3(x)),1) + x = torch.cat((x, self.conv5_4(x)),1) + flow5 = self.predict_flow5(x) + up_flow5 = self.deconv5(flow5) + up_feat5 = self.upfeat5(x) + + warp4 = self.warp(c24, up_flow5*1.25) + corr4 = self.corr(c14, warp4) + corr4 = self.leakyRELU(corr4) + x = torch.cat((corr4, c14, up_flow5, up_feat5), 
1) + x = torch.cat((x, self.conv4_0(x)),1) + x = torch.cat((self.conv4_1(x), x),1) + x = torch.cat((x, self.conv4_2(x)),1) + x = torch.cat((x, self.conv4_3(x)),1) + x = torch.cat((x, self.conv4_4(x)),1) + flow4 = self.predict_flow4(x) + up_flow4 = self.deconv4(flow4) + up_feat4 = self.upfeat4(x) + + warp3 = self.warp(c23, up_flow4*2.5) + corr3 = self.corr(c13, warp3) + corr3 = self.leakyRELU(corr3) + x = torch.cat((corr3, c13, up_flow4, up_feat4), 1) + x = torch.cat((x, self.conv3_0(x)),1) + x = torch.cat((self.conv3_1(x), x),1) + x = torch.cat((x, self.conv3_2(x)),1) + x = torch.cat((x, self.conv3_3(x)),1) + x = torch.cat((x, self.conv3_4(x)),1) + flow3 = self.predict_flow3(x) + up_flow3 = self.deconv3(flow3) + up_feat3 = self.upfeat3(x) + + warp2 = self.warp(c22, up_flow3*5.0) + corr2 = self.corr(c12, warp2) + corr2 = self.leakyRELU(corr2) + x = torch.cat((corr2, c12, up_flow3, up_feat3), 1) + x = torch.cat((x, self.conv2_0(x)),1) + x = torch.cat((self.conv2_1(x), x),1) + x = torch.cat((x, self.conv2_2(x)),1) + x = torch.cat((x, self.conv2_3(x)),1) + x = torch.cat((x, self.conv2_4(x)),1) + flow2 = self.predict_flow2(x) + + x = self.dc_conv4(self.dc_conv3(self.dc_conv2(self.dc_conv1(x)))) + flow2 = flow2 + self.dc_conv7(self.dc_conv6(self.dc_conv5(x))) + + if self.training: + return flow2,flow3,flow4,flow5,flow6 + else: + return flow2 + + + + + +def pwc_dc_net(path=None): + + model = PWCDCNet() + if path is not None: + data = torch.load(path) + if 'state_dict' in data.keys(): + model.load_state_dict(data['state_dict']) + else: + model.load_state_dict(data) + return model + + + + +def pwc_dc_net_old(path=None): + + model = PWCDCNet_old() + if path is not None: + data = torch.load(path) + if 'state_dict' in data.keys(): + model.load_state_dict(data['state_dict']) + else: + model.load_state_dict(data) + return model diff --git a/windflow/networks/__init__.py b/windflow/networks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b8aad493343bf538715a691bc64f6a825b5b19d7 --- /dev/null +++ b/windflow/networks/__init__.py @@ -0,0 +1,4 @@ +from .raft import raft +from .raft.raft import RAFT +# from .flownet import * +# from .maskflownet import MaskFlownet diff --git a/windflow/networks/flownet/FlowNet2.py b/windflow/networks/flownet/FlowNet2.py new file mode 100755 index 0000000000000000000000000000000000000000..45ac9f8c9198d29b7d92f01c2006e8bb31c1996d --- /dev/null +++ b/windflow/networks/flownet/FlowNet2.py @@ -0,0 +1,501 @@ +import torch +import torch.nn as nn +from torch.nn import init + +import math +import numpy as np + +from ...external_packages.channelnorm_package.channelnorm import ChannelNorm +from ...external_packages.resample2d_package.resample2d import Resample2d + +from . 
import FlowNetC, FlowNetS, FlowNetSD, FlowNetFusion + +from .submodules import * + + +'Parameter count = 162,518,834' + +class FlowNet2(nn.Module): + + def __init__(self, input_channels=1, batchNorm=False, div_flow = 20., fp16=False): + super(FlowNet2,self).__init__() + self.batchNorm = batchNorm + self.div_flow = div_flow + self.rgb_max = 1 #args.rgb_max + self.input_channels = input_channels + + self.channelnorm = ChannelNorm() + + # First Block (FlowNetC) + self.flownetc = FlowNetC.FlowNetC(input_channels=self.input_channels, + batchNorm=self.batchNorm, fp16=fp16) + self.upsample1 = nn.Upsample(scale_factor=4, mode='bilinear') + + if fp16: + self.resample1 = nn.Sequential( + tofp32(), + Resample2d(), + tofp16()) + else: + self.resample1 = Resample2d() + + # Block (FlowNetS1) + s_input_channels = input_channels*4 + 2 + self.flownets_1 = FlowNetS.FlowNetS(input_channels=s_input_channels, + batchNorm=self.batchNorm) + self.upsample2 = nn.Upsample(scale_factor=4, mode='bilinear') + if fp16: + self.resample2 = nn.Sequential( + tofp32(), + Resample2d(), + tofp16()) + else: + self.resample2 = Resample2d() + + + # Block (FlowNetS2) + self.flownets_2 = FlowNetS.FlowNetS(input_channels=s_input_channels, + batchNorm=self.batchNorm) + + # Block (FlowNetSD) + self.flownets_d = FlowNetSD.FlowNetSD(input_channels=input_channels*2, + batchNorm=self.batchNorm) + self.upsample3 = nn.Upsample(scale_factor=4, mode='nearest') + self.upsample4 = nn.Upsample(scale_factor=4, mode='nearest') + + if fp16: + self.resample3 = nn.Sequential( + tofp32(), + Resample2d(), + tofp16()) + else: + self.resample3 = Resample2d() + + if fp16: + self.resample4 = nn.Sequential( + tofp32(), + Resample2d(), + tofp16()) + else: + self.resample4 = Resample2d() + + # Block (FLowNetFusion) + fusion_input_channels = self.input_channels + 8 + self.flownetfusion = FlowNetFusion.FlowNetFusion(input_channels=fusion_input_channels, + batchNorm=self.batchNorm) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + if m.bias is not None: + init.uniform_(m.bias) + init.xavier_uniform_(m.weight) + + if isinstance(m, nn.ConvTranspose2d): + if m.bias is not None: + init.uniform_(m.bias) + init.xavier_uniform_(m.weight) + # init_deconv_bilinear(m.weight) + + def init_deconv_bilinear(self, weight): + f_shape = weight.size() + heigh, width = f_shape[-2], f_shape[-1] + f = np.ceil(width/2.0) + c = (2 * f - 1 - f % 2) / (2.0 * f) + bilinear = np.zeros([heigh, width]) + for x in range(width): + for y in range(heigh): + value = (1 - abs(x / f - c)) * (1 - abs(y / f - c)) + bilinear[x, y] = value + min_dim = min(f_shape[0], f_shape[1]) + weight.data.fill_(0.) 
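+        # Copy the separable bilinear kernel onto the diagonal in/out channel pairs
+        # of the zeroed weight, i.e. a bilinear-upsampling initialisation for the deconv.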
+ for i in range(min_dim): + weight.data[i,i,:,:] = torch.from_numpy(bilinear) + return + + def forward(self, x): + #rgb_mean = inputs.contiguous().view(inputs.size()[:2]+(-1,)).mean(dim=-1).view(inputs.size()[:2] + (1,1,1,)) + + #x = inputs #(inputs - rgb_mean) / self.rgb_max + #x1 = x[:,:,0,:,:] + #x2 = x[:,:,1,:,:] + #x = torch.cat((x1,x2), dim = 1) + + # flownetc + flownetc_flow2 = self.flownetc(x)[0] + flownetc_flow = self.upsample1(flownetc_flow2*self.div_flow) + + # warp img1 to img0; magnitude of diff between img0 and and warped_img1, + resampled_img1 = self.resample1(x[:,self.input_channels:,:,:], flownetc_flow) + diff_img0 = x[:,:self.input_channels,:,:] - resampled_img1 + norm_diff_img0 = self.channelnorm(diff_img0) + + # concat img0, img1, img1->img0, flow, diff-mag ; + concat1 = torch.cat((x, resampled_img1, flownetc_flow/self.div_flow, norm_diff_img0), dim=1) + + # flownets1 + flownets1_flow2 = self.flownets_1(concat1)[0] + flownets1_flow = self.upsample2(flownets1_flow2*self.div_flow) + + # warp img1 to img0 using flownets1; magnitude of diff between img0 and and warped_img1 + resampled_img1 = self.resample2(x[:,self.input_channels:,:,:], flownets1_flow) + diff_img0 = x[:,:self.input_channels,:,:] - resampled_img1 + norm_diff_img0 = self.channelnorm(diff_img0) + + # concat img0, img1, img1->img0, flow, diff-mag + concat2 = torch.cat((x, resampled_img1, flownets1_flow/self.div_flow, norm_diff_img0), dim=1) + + # flownets2 + flownets2_flow2 = self.flownets_2(concat2)[0] + flownets2_flow = self.upsample4(flownets2_flow2 * self.div_flow) + norm_flownets2_flow = self.channelnorm(flownets2_flow) + + diff_flownets2_flow = self.resample4(x[:,self.input_channels:,:,:], flownets2_flow) + # if not diff_flownets2_flow.volatile: + # diff_flownets2_flow.register_hook(save_grad(self.args.grads, 'diff_flownets2_flow')) + + diff_flownets2_img1 = self.channelnorm((x[:,:self.input_channels,:,:]-diff_flownets2_flow)) + # if not diff_flownets2_img1.volatile: + # diff_flownets2_img1.register_hook(save_grad(self.args.grads, 'diff_flownets2_img1')) + + # flownetsd + flownetsd_flow2 = self.flownets_d(x)[0] + flownetsd_flow = self.upsample3(flownetsd_flow2 / self.div_flow) + norm_flownetsd_flow = self.channelnorm(flownetsd_flow) + + diff_flownetsd_flow = self.resample3(x[:,self.input_channels:,:,:], flownetsd_flow) + # if not diff_flownetsd_flow.volatile: + # diff_flownetsd_flow.register_hook(save_grad(self.args.grads, 'diff_flownetsd_flow')) + + diff_flownetsd_img1 = self.channelnorm((x[:,:self.input_channels,:,:]-diff_flownetsd_flow)) + # if not diff_flownetsd_img1.volatile: + # diff_flownetsd_img1.register_hook(save_grad(self.args.grads, 'diff_flownetsd_img1')) + + # concat img1 flownetsd, flownets2, norm_flownetsd, norm_flownets2, diff_flownetsd_img1, diff_flownets2_img1 + concat3 = torch.cat((x[:,:self.input_channels,:,:], flownetsd_flow, flownets2_flow, norm_flownetsd_flow, norm_flownets2_flow, diff_flownetsd_img1, diff_flownets2_img1), dim=1) + flownetfusion_flow = self.flownetfusion(concat3) + + # if not flownetfusion_flow.volatile: + # flownetfusion_flow.register_hook(save_grad(self.args.grads, 'flownetfusion_flow')) + + return flownetfusion_flow, + +class FlowNet2C(FlowNetC.FlowNetC): + def __init__(self, batchNorm=False, div_flow=20): + super(FlowNet2C,self).__init__(batchNorm=batchNorm, div_flow=20) + self.rgb_max = 1 #args.rgb_max + + def forward(self, inputs): + rgb_mean = inputs.contiguous().view(inputs.size()[:2]+(-1,)).mean(dim=-1).view(inputs.size()[:2] + (1,1,1,)) + + x = inputs 
#(inputs - rgb_mean) / self.rgb_max + x1 = x[:,:,0,:,:] + x2 = x[:,:,1,:,:] + + # FlownetC top input stream + out_conv1a = self.conv1(x1) + out_conv2a = self.conv2(out_conv1a) + out_conv3a = self.conv3(out_conv2a) + + # FlownetC bottom input stream + out_conv1b = self.conv1(x2) + + out_conv2b = self.conv2(out_conv1b) + out_conv3b = self.conv3(out_conv2b) + + # Merge streams + out_corr = self.corr(out_conv3a, out_conv3b) # False + out_corr = self.corr_activation(out_corr) + + # Redirect top input stream and concatenate + out_conv_redir = self.conv_redir(out_conv3a) + + in_conv3_1 = torch.cat((out_conv_redir, out_corr), 1) + + # Merged conv layers + out_conv3_1 = self.conv3_1(in_conv3_1) + + out_conv4 = self.conv4_1(self.conv4(out_conv3_1)) + + out_conv5 = self.conv5_1(self.conv5(out_conv4)) + out_conv6 = self.conv6_1(self.conv6(out_conv5)) + + flow6 = self.predict_flow6(out_conv6) + flow6_up = self.upsampled_flow6_to_5(flow6) + out_deconv5 = self.deconv5(out_conv6) + + concat5 = torch.cat((out_conv5,out_deconv5,flow6_up),1) + + flow5 = self.predict_flow5(concat5) + flow5_up = self.upsampled_flow5_to_4(flow5) + out_deconv4 = self.deconv4(concat5) + concat4 = torch.cat((out_conv4,out_deconv4,flow5_up),1) + + flow4 = self.predict_flow4(concat4) + flow4_up = self.upsampled_flow4_to_3(flow4) + out_deconv3 = self.deconv3(concat4) + concat3 = torch.cat((out_conv3_1,out_deconv3,flow4_up),1) + + flow3 = self.predict_flow3(concat3) + flow3_up = self.upsampled_flow3_to_2(flow3) + out_deconv2 = self.deconv2(concat3) + concat2 = torch.cat((out_conv2a,out_deconv2,flow3_up),1) + + flow2 = self.predict_flow2(concat2) + + if self.training: + return flow2,flow3,flow4,flow5,flow6 + else: + return self.upsample1(flow2*self.div_flow) + +class FlowNet2S(FlowNetS.FlowNetS): + def __init__(self, batchNorm=False, div_flow=20): + super(FlowNet2S,self).__init__(input_channels = 6, batchNorm=batchNorm) + self.rgb_max = 1 #args.rgb_max + self.div_flow = div_flow + + def forward(self, inputs): + rgb_mean = inputs.contiguous().view(inputs.size()[:2]+(-1,)).mean(dim=-1).view(inputs.size()[:2] + (1,1,1,)) + x = (inputs - rgb_mean) / self.rgb_max + x = torch.cat( (x[:,:,0,:,:], x[:,:,1,:,:]), dim = 1) + + out_conv1 = self.conv1(x) + + out_conv2 = self.conv2(out_conv1) + out_conv3 = self.conv3_1(self.conv3(out_conv2)) + out_conv4 = self.conv4_1(self.conv4(out_conv3)) + out_conv5 = self.conv5_1(self.conv5(out_conv4)) + out_conv6 = self.conv6_1(self.conv6(out_conv5)) + + flow6 = self.predict_flow6(out_conv6) + flow6_up = self.upsampled_flow6_to_5(flow6) + out_deconv5 = self.deconv5(out_conv6) + + concat5 = torch.cat((out_conv5,out_deconv5,flow6_up),1) + flow5 = self.predict_flow5(concat5) + flow5_up = self.upsampled_flow5_to_4(flow5) + out_deconv4 = self.deconv4(concat5) + + concat4 = torch.cat((out_conv4,out_deconv4,flow5_up),1) + flow4 = self.predict_flow4(concat4) + flow4_up = self.upsampled_flow4_to_3(flow4) + out_deconv3 = self.deconv3(concat4) + + concat3 = torch.cat((out_conv3,out_deconv3,flow4_up),1) + flow3 = self.predict_flow3(concat3) + flow3_up = self.upsampled_flow3_to_2(flow3) + out_deconv2 = self.deconv2(concat3) + + concat2 = torch.cat((out_conv2,out_deconv2,flow3_up),1) + flow2 = self.predict_flow2(concat2) + + if self.training: + return flow2,flow3,flow4,flow5,flow6 + else: + return self.upsample1(flow2*self.div_flow) + +class FlowNet2SD(FlowNetSD.FlowNetSD): + def __init__(self, batchNorm=False, div_flow=20): + super(FlowNet2SD,self).__init__(batchNorm=batchNorm) + self.rgb_max = 1 #args.rgb_max + 
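+        # div_flow rescales the quarter-resolution flow before the final 4x upsample
+        # in forward() (eval mode only; training returns the raw flow pyramid).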
self.div_flow = div_flow + + def forward(self, inputs): + rgb_mean = inputs.contiguous().view(inputs.size()[:2]+(-1,)).mean(dim=-1).view(inputs.size()[:2] + (1,1,1,)) + x = inputs #(inputs - rgb_mean) / self.rgb_max + x = torch.cat( (x[:,:,0,:,:], x[:,:,1,:,:]), dim = 1) + + out_conv0 = self.conv0(x) + out_conv1 = self.conv1_1(self.conv1(out_conv0)) + out_conv2 = self.conv2_1(self.conv2(out_conv1)) + + out_conv3 = self.conv3_1(self.conv3(out_conv2)) + out_conv4 = self.conv4_1(self.conv4(out_conv3)) + out_conv5 = self.conv5_1(self.conv5(out_conv4)) + out_conv6 = self.conv6_1(self.conv6(out_conv5)) + + flow6 = self.predict_flow6(out_conv6) + flow6_up = self.upsampled_flow6_to_5(flow6) + out_deconv5 = self.deconv5(out_conv6) + + concat5 = torch.cat((out_conv5,out_deconv5,flow6_up),1) + out_interconv5 = self.inter_conv5(concat5) + flow5 = self.predict_flow5(out_interconv5) + + flow5_up = self.upsampled_flow5_to_4(flow5) + out_deconv4 = self.deconv4(concat5) + + concat4 = torch.cat((out_conv4,out_deconv4,flow5_up),1) + out_interconv4 = self.inter_conv4(concat4) + flow4 = self.predict_flow4(out_interconv4) + flow4_up = self.upsampled_flow4_to_3(flow4) + out_deconv3 = self.deconv3(concat4) + + concat3 = torch.cat((out_conv3,out_deconv3,flow4_up),1) + out_interconv3 = self.inter_conv3(concat3) + flow3 = self.predict_flow3(out_interconv3) + flow3_up = self.upsampled_flow3_to_2(flow3) + out_deconv2 = self.deconv2(concat3) + + concat2 = torch.cat((out_conv2,out_deconv2,flow3_up),1) + out_interconv2 = self.inter_conv2(concat2) + flow2 = self.predict_flow2(out_interconv2) + + if self.training: + return flow2,flow3,flow4,flow5,flow6 + else: + return self.upsample1(flow2*self.div_flow) + +class FlowNet2CS(nn.Module): + + def __init__(self, batchNorm=False, div_flow = 20., fp16=False): + super(FlowNet2CS,self).__init__() + self.batchNorm = batchNorm + self.div_flow = div_flow + self.rgb_max = 1 #args.rgb_max + self.args = args + + self.channelnorm = ChannelNorm() + + # First Block (FlowNetC) + self.flownetc = FlowNetC.FlowNetC(batchNorm=self.batchNorm, fp16=fp16) + self.upsample1 = nn.Upsample(scale_factor=4, mode='bilinear') + + if fp16: + self.resample1 = nn.Sequential( + tofp32(), + Resample2d(), + tofp16()) + else: + self.resample1 = Resample2d() + + # Block (FlowNetS1) + self.flownets_1 = FlowNetS.FlowNetS(batchNorm=self.batchNorm) + self.upsample2 = nn.Upsample(scale_factor=4, mode='bilinear') + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + if m.bias is not None: + init.uniform(m.bias) + init.xavier_uniform(m.weight) + + if isinstance(m, nn.ConvTranspose2d): + if m.bias is not None: + init.uniform(m.bias) + init.xavier_uniform(m.weight) + # init_deconv_bilinear(m.weight) + + def forward(self, inputs): + rgb_mean = inputs.contiguous().view(inputs.size()[:2]+(-1,)).mean(dim=-1).view(inputs.size()[:2] + (1,1,1,)) + + x = inputs #(inputs - rgb_mean) / self.rgb_max + x1 = x[:,:,0,:,:] + x2 = x[:,:,1,:,:] + x = torch.cat((x1,x2), dim = 1) + + # flownetc + flownetc_flow2 = self.flownetc(x)[0] + flownetc_flow = self.upsample1(flownetc_flow2*self.div_flow) + + # warp img1 to img0; magnitude of diff between img0 and and warped_img1, + resampled_img1 = self.resample1(x[:,3:,:,:], flownetc_flow) + diff_img0 = x[:,:3,:,:] - resampled_img1 + norm_diff_img0 = self.channelnorm(diff_img0) + + # concat img0, img1, img1->img0, flow, diff-mag ; + concat1 = torch.cat((x, resampled_img1, flownetc_flow/self.div_flow, norm_diff_img0), dim=1) + + # flownets1 + flownets1_flow2 = self.flownets_1(concat1)[0] + 
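+        # FlowNetS refines the FlowNetC estimate; rescale by div_flow and upsample
+        # the quarter-resolution flow (factor 4) back towards the input resolution.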
flownets1_flow = self.upsample2(flownets1_flow2*self.div_flow) + + return flownets1_flow + +class FlowNet2CSS(nn.Module): + + def __init__(self, batchNorm=False, div_flow = 20., fp16=False): + super(FlowNet2CSS,self).__init__() + self.batchNorm = batchNorm + self.div_flow = div_flow + self.rgb_max = 1 #args.rgb_max + self.args = args + + self.channelnorm = ChannelNorm() + + # First Block (FlowNetC) + self.flownetc = FlowNetC.FlowNetC(args, batchNorm=self.batchNorm) + self.upsample1 = nn.Upsample(scale_factor=4, mode='bilinear') + + if fp16: + self.resample1 = nn.Sequential( + tofp32(), + Resample2d(), + tofp16()) + else: + self.resample1 = Resample2d() + + # Block (FlowNetS1) + self.flownets_1 = FlowNetS.FlowNetS(args, batchNorm=self.batchNorm) + self.upsample2 = nn.Upsample(scale_factor=4, mode='bilinear') + if fp16: + self.resample2 = nn.Sequential( + tofp32(), + Resample2d(), + tofp16()) + else: + self.resample2 = Resample2d() + + + # Block (FlowNetS2) + self.flownets_2 = FlowNetS.FlowNetS(args, batchNorm=self.batchNorm) + self.upsample3 = nn.Upsample(scale_factor=4, mode='nearest') + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + if m.bias is not None: + init.uniform(m.bias) + init.xavier_uniform(m.weight) + + if isinstance(m, nn.ConvTranspose2d): + if m.bias is not None: + init.uniform(m.bias) + init.xavier_uniform(m.weight) + # init_deconv_bilinear(m.weight) + + def forward(self, inputs): + rgb_mean = inputs.contiguous().view(inputs.size()[:2]+(-1,)).mean(dim=-1).view(inputs.size()[:2] + (1,1,1,)) + + x = inputs #(inputs - rgb_mean) / self.rgb_max + x1 = x[:,:,0,:,:] + x2 = x[:,:,1,:,:] + x = torch.cat((x1,x2), dim = 1) + + # flownetc + flownetc_flow2 = self.flownetc(x)[0] + flownetc_flow = self.upsample1(flownetc_flow2*self.div_flow) + + # warp img1 to img0; magnitude of diff between img0 and and warped_img1, + resampled_img1 = self.resample1(x[:,3:,:,:], flownetc_flow) + diff_img0 = x[:,:3,:,:] - resampled_img1 + norm_diff_img0 = self.channelnorm(diff_img0) + + # concat img0, img1, img1->img0, flow, diff-mag ; + concat1 = torch.cat((x, resampled_img1, flownetc_flow/self.div_flow, norm_diff_img0), dim=1) + + # flownets1 + flownets1_flow2 = self.flownets_1(concat1)[0] + flownets1_flow = self.upsample2(flownets1_flow2*self.div_flow) + + # warp img1 to img0 using flownets1; magnitude of diff between img0 and and warped_img1 + resampled_img1 = self.resample2(x[:,3:,:,:], flownets1_flow) + diff_img0 = x[:,:3,:,:] - resampled_img1 + norm_diff_img0 = self.channelnorm(diff_img0) + + # concat img0, img1, img1->img0, flow, diff-mag + concat2 = torch.cat((x, resampled_img1, flownets1_flow/self.div_flow, norm_diff_img0), dim=1) + + # flownets2 + flownets2_flow2 = self.flownets_2(concat2)[0] + flownets2_flow = self.upsample3(flownets2_flow2 * self.div_flow) + + return flownets2_flow + diff --git a/windflow/networks/flownet/FlowNetC.py b/windflow/networks/flownet/FlowNetC.py new file mode 100755 index 0000000000000000000000000000000000000000..3aee083ef3174543fbb6eca5ca65340b8693ccd8 --- /dev/null +++ b/windflow/networks/flownet/FlowNetC.py @@ -0,0 +1,130 @@ +import torch +import torch.nn as nn +from torch.nn import init + +import math +import numpy as np + +from ...external_packages.correlation_package.correlation import Correlation + +from .submodules import * +'Parameter count , 39,175,298 ' + +class FlowNetC(nn.Module): + def __init__(self, input_channels=1, batchNorm=True, div_flow = 20, fp16=False): + super(FlowNetC,self).__init__() + + self.batchNorm = batchNorm + self.div_flow 
= div_flow + self.input_channels = input_channels + + self.conv1 = conv(self.batchNorm, input_channels, 64, kernel_size=7, stride=2) + self.conv2 = conv(self.batchNorm, 64, 128, kernel_size=5, stride=2) + self.conv3 = conv(self.batchNorm, 128, 256, kernel_size=5, stride=2) + self.conv_redir = conv(self.batchNorm, 256, 32, kernel_size=1, stride=1) + + if fp16: + self.corr = nn.Sequential( + tofp32(), + Correlation(pad_size=20, kernel_size=1, max_displacement=20, stride1=1, stride2=2, corr_multiply=1), + tofp16()) + else: + self.corr = Correlation(pad_size=20, kernel_size=1, max_displacement=20, stride1=1, stride2=2, corr_multiply=1) + + self.corr_activation = nn.LeakyReLU(0.1,inplace=True) + self.conv3_1 = conv(self.batchNorm, 473, 256) + self.conv4 = conv(self.batchNorm, 256, 512, stride=2) + self.conv4_1 = conv(self.batchNorm, 512, 512) + self.conv5 = conv(self.batchNorm, 512, 512, stride=2) + self.conv5_1 = conv(self.batchNorm, 512, 512) + self.conv6 = conv(self.batchNorm, 512, 1024, stride=2) + self.conv6_1 = conv(self.batchNorm,1024, 1024) + + self.deconv5 = deconv(1024,512) + self.deconv4 = deconv(1026,256) + self.deconv3 = deconv(770,128) + self.deconv2 = deconv(386,64) + + self.predict_flow6 = predict_flow(1024) + self.predict_flow5 = predict_flow(1026) + self.predict_flow4 = predict_flow(770) + self.predict_flow3 = predict_flow(386) + self.predict_flow2 = predict_flow(194) + + self.upsampled_flow6_to_5 = nn.ConvTranspose2d(2, 2, 4, 2, 1, bias=True) + self.upsampled_flow5_to_4 = nn.ConvTranspose2d(2, 2, 4, 2, 1, bias=True) + self.upsampled_flow4_to_3 = nn.ConvTranspose2d(2, 2, 4, 2, 1, bias=True) + self.upsampled_flow3_to_2 = nn.ConvTranspose2d(2, 2, 4, 2, 1, bias=True) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + if m.bias is not None: + init.uniform_(m.bias) + init.xavier_uniform_(m.weight) + + if isinstance(m, nn.ConvTranspose2d): + if m.bias is not None: + init.uniform_(m.bias) + init.xavier_uniform_(m.weight) + # init_deconv_bilinear(m.weight) + self.upsample1 = nn.Upsample(scale_factor=4, mode='bilinear') + + def forward(self, x): + x1 = x[:,0:self.input_channels,:,:] + x2 = x[:,self.input_channels:,:,:] + + out_conv1a = self.conv1(x1) + out_conv2a = self.conv2(out_conv1a) + out_conv3a = self.conv3(out_conv2a) + + # FlownetC bottom input stream + out_conv1b = self.conv1(x2) + + out_conv2b = self.conv2(out_conv1b) + out_conv3b = self.conv3(out_conv2b) + + # Merge streams + #out_corr = self.corr(out_conv3a, out_conv3b) # False + out_corr = self.corr(out_conv3a, out_conv3b) # False + out_corr = self.corr_activation(out_corr) + + # Redirect top input stream and concatenate + out_conv_redir = self.conv_redir(out_conv3a) + + in_conv3_1 = torch.cat((out_conv_redir, out_corr), 1) + + # Merged conv layers + out_conv3_1 = self.conv3_1(in_conv3_1) + + out_conv4 = self.conv4_1(self.conv4(out_conv3_1)) + + out_conv5 = self.conv5_1(self.conv5(out_conv4)) + out_conv6 = self.conv6_1(self.conv6(out_conv5)) + + flow6 = self.predict_flow6(out_conv6) + flow6_up = self.upsampled_flow6_to_5(flow6) + out_deconv5 = self.deconv5(out_conv6) + + concat5 = torch.cat((out_conv5,out_deconv5,flow6_up),1) + + flow5 = self.predict_flow5(concat5) + flow5_up = self.upsampled_flow5_to_4(flow5) + out_deconv4 = self.deconv4(concat5) + concat4 = torch.cat((out_conv4,out_deconv4,flow5_up),1) + + flow4 = self.predict_flow4(concat4) + flow4_up = self.upsampled_flow4_to_3(flow4) + out_deconv3 = self.deconv3(concat4) + concat3 = torch.cat((out_conv3_1,out_deconv3,flow4_up),1) + + flow3 = 
self.predict_flow3(concat3) + flow3_up = self.upsampled_flow3_to_2(flow3) + out_deconv2 = self.deconv2(concat3) + concat2 = torch.cat((out_conv2a,out_deconv2,flow3_up),1) + + flow2 = self.predict_flow2(concat2) + + if self.training: + return flow2,flow3,flow4,flow5,flow6 + else: + return flow2, diff --git a/windflow/networks/flownet/FlowNetFusion.py b/windflow/networks/flownet/FlowNetFusion.py new file mode 100755 index 0000000000000000000000000000000000000000..81b079ae1bb919affc5780b9dfcdc763936fbc5a --- /dev/null +++ b/windflow/networks/flownet/FlowNetFusion.py @@ -0,0 +1,67 @@ +import torch +import torch.nn as nn +from torch.nn import init + +import math +import numpy as np + +from .submodules import * +'Parameter count = 581,226' + +class FlowNetFusion(nn.Module): + def __init__(self, input_channels=11, batchNorm=True): + super(FlowNetFusion,self).__init__() + + self.batchNorm = batchNorm + self.conv0 = conv(self.batchNorm, input_channels, 64) + self.conv1 = conv(self.batchNorm, 64, 64, stride=2) + self.conv1_1 = conv(self.batchNorm, 64, 128) + self.conv2 = conv(self.batchNorm, 128, 128, stride=2) + self.conv2_1 = conv(self.batchNorm, 128, 128) + + self.deconv1 = deconv(128,32) + self.deconv0 = deconv(162,16) + + self.inter_conv1 = i_conv(self.batchNorm, 162, 32) + self.inter_conv0 = i_conv(self.batchNorm, 82, 16) + + self.predict_flow2 = predict_flow(128) + self.predict_flow1 = predict_flow(32) + self.predict_flow0 = predict_flow(16) + + self.upsampled_flow2_to_1 = nn.ConvTranspose2d(2, 2, 4, 2, 1) + self.upsampled_flow1_to_0 = nn.ConvTranspose2d(2, 2, 4, 2, 1) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + if m.bias is not None: + init.uniform_(m.bias) + init.xavier_uniform_(m.weight) + + if isinstance(m, nn.ConvTranspose2d): + if m.bias is not None: + init.uniform_(m.bias) + init.xavier_uniform_(m.weight) + # init_deconv_bilinear(m.weight) + + def forward(self, x): + out_conv0 = self.conv0(x) + out_conv1 = self.conv1_1(self.conv1(out_conv0)) + out_conv2 = self.conv2_1(self.conv2(out_conv1)) + + flow2 = self.predict_flow2(out_conv2) + flow2_up = self.upsampled_flow2_to_1(flow2) + out_deconv1 = self.deconv1(out_conv2) + + concat1 = torch.cat((out_conv1,out_deconv1,flow2_up),1) + out_interconv1 = self.inter_conv1(concat1) + flow1 = self.predict_flow1(out_interconv1) + flow1_up = self.upsampled_flow1_to_0(flow1) + out_deconv0 = self.deconv0(concat1) + + concat0 = torch.cat((out_conv0,out_deconv0,flow1_up),1) + out_interconv0 = self.inter_conv0(concat0) + flow0 = self.predict_flow0(out_interconv0) + + return flow0 + diff --git a/windflow/networks/flownet/FlowNetHR.py b/windflow/networks/flownet/FlowNetHR.py new file mode 100644 index 0000000000000000000000000000000000000000..d5eb9f9512851c896f0fc8aee7a20db9fcc149f8 --- /dev/null +++ b/windflow/networks/flownet/FlowNetHR.py @@ -0,0 +1,92 @@ +''' +Portions of this code copyright 2017, Clement Pinard +''' + +import torch +import torch.nn as nn +from torch.nn import init + +import math +import numpy as np + +from .flownet_submodules import * + +#def deconv(in_planes, out_planes): +# return nn.Sequential( +# nn.Conv2d(in_planes, out_planes*4, kernel_size=1, bias=True), +# nn.PixelShuffle(2), +# nn.Conv2d(out_planes, out_planes, kernel_size=3, padding=1, bias=True), +# nn.LeakyReLU(0.1,inplace=True) +# ) + +class FlowNetHR(nn.Module): + def __init__(self, args, input_channels = 1, batchNorm=True, partial=False): + super(FlowNetHR, self).__init__() + + self.batchNorm = batchNorm + self.conv0 = conv(self.batchNorm, 
input_channels*2, 32, kernel_size=7, partial=partial) + self.conv0_1 = conv(self.batchNorm, 32, 32, partial=partial) + self.conv1 = conv(self.batchNorm, 32, 64, kernel_size=5, stride=2, partial=partial) + self.conv1_1 = conv(self.batchNorm, 64, 64, partial=partial) + self.conv2 = conv(self.batchNorm, 64, 128, kernel_size=5, stride=2, partial=partial) + self.conv2_1 = conv(self.batchNorm, 128, 128, partial=partial) + self.conv3 = conv(self.batchNorm, 128, 256, kernel_size=5, stride=2, partial=partial) + self.conv3_1 = conv(self.batchNorm, 256, 256, partial=partial) + + self.deconv2 = deconv(256, 64) + self.deconv1 = deconv(194, 32) + self.deconv0 = deconv(98, 16) + + self.predict_flow3 = predict_flow(256, partial=partial) + self.predict_flow2 = predict_flow(194, partial=partial) + self.predict_flow1 = predict_flow(98, partial=partial) + self.predict_flow0 = predict_flow(50, partial=partial) + + self.upsampled_flow3_to_2 = nn.ConvTranspose2d(2, 2, 4, 2, 1, bias=False) + self.upsampled_flow2_to_1 = nn.ConvTranspose2d(2, 2, 4, 2, 1, bias=False) + #self.upsampled_flow3_to_2 = nn.UpsamplingBilinear2d(scale_factor=2) + #self.upsampled_flow2_to_1 = nn.UpsamplingBilinear2d(scale_factor=2) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + if m.bias is not None: + init.uniform_(m.bias) + init.xavier_uniform_(m.weight) + + if isinstance(m, nn.ConvTranspose2d): + if m.bias is not None: + init.uniform_(m.bias) + init.xavier_uniform_(m.weight) + # init_deconv_bilinear(m.weight) + self.upsample_bilinear = nn.Upsample(scale_factor=2, mode='bilinear') + + def forward(self, x1, x2): + x = torch.cat([x1,x2],1) + + out_conv0 = self.conv0_1(self.conv0(x)) + out_conv1 = self.conv1_1(self.conv1(out_conv0)) # 64 + out_conv2 = self.conv2_1(self.conv2(out_conv1)) # 128 + out_conv3 = self.conv3_1(self.conv3(out_conv2)) # 256 + + flow3 = self.predict_flow3(out_conv3) # 2 + flow3_up = self.upsample_bilinear(flow3) # 2 + out_deconv2 = self.deconv2(out_conv3) # 64 + + concat = torch.cat((out_conv2, flow3_up, out_deconv2), 1) # 128 + 64 + 2 = 194 + flow2 = self.predict_flow2(concat) + flow2_up = self.upsample_bilinear(flow2) + out_deconv1 = self.deconv1(concat) + + concat = torch.cat((out_conv1, flow2_up, out_deconv1), 1) # 64 + 32 + 2 = 98 + flow1 = self.predict_flow1(concat) + flow1_up = self.upsample_bilinear(flow1) + out_deconv0 = self.deconv0(concat) + + concat = torch.cat((out_conv0, flow1_up, out_deconv0), 1) # 32 + 16 + 2 = 50 + flow0 = self.predict_flow0(concat) + + if self.training: + return flow0,flow1,flow2,flow3 + else: + return flow0, + diff --git a/windflow/networks/flownet/FlowNetPartial.py b/windflow/networks/flownet/FlowNetPartial.py new file mode 100644 index 0000000000000000000000000000000000000000..92145c8035ef8977cb683d1c829d8b9dd25ae811 --- /dev/null +++ b/windflow/networks/flownet/FlowNetPartial.py @@ -0,0 +1,95 @@ +''' +Portions of this code copyright 2017, Clement Pinard +''' + +import torch +import torch.nn as nn +from torch.nn import init + +import math +import numpy as np + +from .flownet_submodules import * +from .layers import partial + +def conv_partial(batchNorm, in_planes, out_planes, kernel_size=3, stride=1, partial=True): + if batchNorm: + return nn.Sequential( + partial.Conv2d_partial(in_planes, out_planes, kernel_size=kernel_size, stride=stride, bias=False, partial=partial), + nn.BatchNorm2d(out_planes), + nn.LeakyReLU(0.1,inplace=True) + ) + else: + return nn.Sequential( + partial.Conv2d_partial(in_planes, out_planes, kernel_size=kernel_size, stride=stride, 
partial=partial, bias=True), + nn.LeakyReLU(0.1,inplace=True) + ) + + +class FlowNetPartial(nn.Module): + def __init__(self, args, input_channels = 1, batchNorm=False, partial=True): + super(FlowNetHR, self).__init__() + self.batchNorm = batchNorm + self.conv0 = conv_partial(self.batchNorm, input_channels*2, 32, kernel_size=7, partial=True) + self.conv0_1 = conv_partial(self.batchNorm, 32, 32, partial=True) + self.conv1 = conv_partial(self.batchNorm, 32, 64, kernel_size=5, stride=2, partial=True) + self.conv1_1 = conv_partial(self.batchNorm, 64, 64, partial=True) + self.conv2 = conv_partial(self.batchNorm, 64, 128, kernel_size=5, stride=2, partial=True) + self.conv2_1 = conv_partial(self.batchNorm, 128, 128, partial=True) + self.conv3 = conv_partial(self.batchNorm, 128, 256, kernel_size=5, stride=2, partial=True) + self.conv3_1 = conv_partial(self.batchNorm, 256, 256, partial=True) + + self.deconv2 = deconv(256, 64) + self.deconv1 = deconv(194, 32) + self.deconv0 = deconv(98, 16) + + self.predict_flow3 = predict_flow(256) + self.predict_flow2 = predict_flow(194) + self.predict_flow1 = predict_flow(98) + self.predict_flow0 = predict_flow(50) + + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + if m.bias is not None: + init.uniform_(m.bias) + init.xavier_uniform_(m.weight) + + if isinstance(m, nn.ConvTranspose2d): + if m.bias is not None: + init.uniform_(m.bias) + init.xavier_uniform_(m.weight) + # init_deconv_bilinear(m.weight) + self.upsample_bilinear = nn.Upsample(scale_factor=2, mode='bilinear') + + def forward(self, x1, x2): + x = torch.cat([x1,x2],1) + + out_conv0 = self.conv0_1(self.conv0(x)) + out_conv1 = self.conv1_1(self.conv1(out_conv0)) # 64 + out_conv2 = self.conv2_1(self.conv2(out_conv1)) # 128 + out_conv3 = self.conv3_1(self.conv3(out_conv2)) # 256 + + + flow3 = self.predict_flow3(out_conv3) # 2 + flow3_up = self.upsample_bilinear(flow3) # 2 + out_deconv2 = self.deconv2(out_conv3) # 64 + + concat = torch.cat((out_conv2, flow3_up, out_deconv2), 1) # 128 + 64 + 2 = 194 + flow2 = self.predict_flow2(concat) + flow2_up = self.upsample_bilinear(flow2) + out_deconv1 = self.deconv1(concat) + + concat = torch.cat((out_conv1, flow2_up, out_deconv1), 1) # 64 + 32 + 2 = 98 + flow1 = self.predict_flow1(concat) + flow1_up = self.upsample_bilinear(flow1) + out_deconv0 = self.deconv0(concat) + + concat = torch.cat((out_conv0, flow1_up, out_deconv0), 1) # 32 + 16 + 2 = 50 + flow0 = self.predict_flow0(concat) + + if self.training: + return flow0,flow1,flow2,flow3 + else: + return flow0, + diff --git a/windflow/networks/flownet/FlowNetS.py b/windflow/networks/flownet/FlowNetS.py new file mode 100755 index 0000000000000000000000000000000000000000..9288cddb5474d3deb89440ea1bc3eb6c65798836 --- /dev/null +++ b/windflow/networks/flownet/FlowNetS.py @@ -0,0 +1,95 @@ +''' +Portions of this code copyright 2017, Clement Pinard +''' + +import torch +import torch.nn as nn +from torch.nn import init + +import math +import numpy as np + +from .submodules import * +'Parameter count : 38,676,504 ' + +class FlowNetS(nn.Module): + def __init__(self, input_channels = 12, batchNorm=True): + super(FlowNetS,self).__init__() + + self.batchNorm = batchNorm + self.conv1 = conv(self.batchNorm, input_channels, 64, kernel_size=7, stride=2) + self.conv2 = conv(self.batchNorm, 64, 128, kernel_size=5, stride=2) + self.conv3 = conv(self.batchNorm, 128, 256, kernel_size=5, stride=2) + self.conv3_1 = conv(self.batchNorm, 256, 256) + self.conv4 = conv(self.batchNorm, 256, 512, stride=2) + self.conv4_1 = 
conv(self.batchNorm, 512, 512) + self.conv5 = conv(self.batchNorm, 512, 512, stride=2) + self.conv5_1 = conv(self.batchNorm, 512, 512) + self.conv6 = conv(self.batchNorm, 512, 1024, stride=2) + self.conv6_1 = conv(self.batchNorm,1024, 1024) + + self.deconv5 = deconv(1024,512) + self.deconv4 = deconv(1026,256) + self.deconv3 = deconv(770,128) + self.deconv2 = deconv(386,64) + + self.predict_flow6 = predict_flow(1024) + self.predict_flow5 = predict_flow(1026) + self.predict_flow4 = predict_flow(770) + self.predict_flow3 = predict_flow(386) + self.predict_flow2 = predict_flow(194) + + self.upsampled_flow6_to_5 = nn.ConvTranspose2d(2, 2, 4, 2, 1, bias=False) + self.upsampled_flow5_to_4 = nn.ConvTranspose2d(2, 2, 4, 2, 1, bias=False) + self.upsampled_flow4_to_3 = nn.ConvTranspose2d(2, 2, 4, 2, 1, bias=False) + self.upsampled_flow3_to_2 = nn.ConvTranspose2d(2, 2, 4, 2, 1, bias=False) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + if m.bias is not None: + init.uniform_(m.bias) + init.xavier_uniform_(m.weight) + + if isinstance(m, nn.ConvTranspose2d): + if m.bias is not None: + init.uniform_(m.bias) + init.xavier_uniform_(m.weight) + # init_deconv_bilinear(m.weight) + self.upsample1 = nn.Upsample(scale_factor=4, mode='bilinear') + + def forward(self, x): + out_conv1 = self.conv1(x) + + out_conv2 = self.conv2(out_conv1) + out_conv3 = self.conv3_1(self.conv3(out_conv2)) + out_conv4 = self.conv4_1(self.conv4(out_conv3)) + out_conv5 = self.conv5_1(self.conv5(out_conv4)) + out_conv6 = self.conv6_1(self.conv6(out_conv5)) + + flow6 = self.predict_flow6(out_conv6) + flow6_up = self.upsampled_flow6_to_5(flow6) + out_deconv5 = self.deconv5(out_conv6) + + concat5 = torch.cat((out_conv5,out_deconv5,flow6_up),1) + flow5 = self.predict_flow5(concat5) + flow5_up = self.upsampled_flow5_to_4(flow5) + out_deconv4 = self.deconv4(concat5) + + concat4 = torch.cat((out_conv4,out_deconv4,flow5_up),1) + flow4 = self.predict_flow4(concat4) + flow4_up = self.upsampled_flow4_to_3(flow4) + out_deconv3 = self.deconv3(concat4) + + concat3 = torch.cat((out_conv3,out_deconv3,flow4_up),1) + flow3 = self.predict_flow3(concat3) + flow3_up = self.upsampled_flow3_to_2(flow3) + out_deconv2 = self.deconv2(concat3) + + concat2 = torch.cat((out_conv2,out_deconv2,flow3_up),1) + flow2 = self.predict_flow2(concat2) + + if self.training: + return flow2,flow3,flow4,flow5,flow6 + else: + return flow2, + diff --git a/windflow/networks/flownet/FlowNetSD.py b/windflow/networks/flownet/FlowNetSD.py new file mode 100755 index 0000000000000000000000000000000000000000..0f9eff0d9d269b39bafa307194bf2e7879d089d1 --- /dev/null +++ b/windflow/networks/flownet/FlowNetSD.py @@ -0,0 +1,106 @@ +import torch +import torch.nn as nn +from torch.nn import init + +import math +import numpy as np + +from .submodules import * +'Parameter count = 45,371,666' + +class FlowNetSD(nn.Module): + def __init__(self, input_channels=2, batchNorm=True): + super(FlowNetSD,self).__init__() + + self.batchNorm = batchNorm + self.conv0 = conv(self.batchNorm, input_channels, 64) + self.conv1 = conv(self.batchNorm, 64, 64, stride=2) + self.conv1_1 = conv(self.batchNorm, 64, 128) + self.conv2 = conv(self.batchNorm, 128, 128, stride=2) + self.conv2_1 = conv(self.batchNorm, 128, 128) + self.conv3 = conv(self.batchNorm, 128, 256, stride=2) + self.conv3_1 = conv(self.batchNorm, 256, 256) + self.conv4 = conv(self.batchNorm, 256, 512, stride=2) + self.conv4_1 = conv(self.batchNorm, 512, 512) + self.conv5 = conv(self.batchNorm, 512, 512, stride=2) + self.conv5_1 = 
conv(self.batchNorm, 512, 512) + self.conv6 = conv(self.batchNorm, 512, 1024, stride=2) + self.conv6_1 = conv(self.batchNorm,1024, 1024) + + self.deconv5 = deconv(1024,512) + self.deconv4 = deconv(1026,256) + self.deconv3 = deconv(770,128) + self.deconv2 = deconv(386,64) + + self.inter_conv5 = i_conv(self.batchNorm, 1026, 512) + self.inter_conv4 = i_conv(self.batchNorm, 770, 256) + self.inter_conv3 = i_conv(self.batchNorm, 386, 128) + self.inter_conv2 = i_conv(self.batchNorm, 194, 64) + + self.predict_flow6 = predict_flow(1024) + self.predict_flow5 = predict_flow(512) + self.predict_flow4 = predict_flow(256) + self.predict_flow3 = predict_flow(128) + self.predict_flow2 = predict_flow(64) + + self.upsampled_flow6_to_5 = nn.ConvTranspose2d(2, 2, 4, 2, 1) + self.upsampled_flow5_to_4 = nn.ConvTranspose2d(2, 2, 4, 2, 1) + self.upsampled_flow4_to_3 = nn.ConvTranspose2d(2, 2, 4, 2, 1) + self.upsampled_flow3_to_2 = nn.ConvTranspose2d(2, 2, 4, 2, 1) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + if m.bias is not None: + init.uniform_(m.bias) + init.xavier_uniform_(m.weight) + + if isinstance(m, nn.ConvTranspose2d): + if m.bias is not None: + init.uniform_(m.bias) + init.xavier_uniform_(m.weight) + # init_deconv_bilinear(m.weight) + self.upsample1 = nn.Upsample(scale_factor=4, mode='bilinear') + + + + def forward(self, x): + out_conv0 = self.conv0(x) + out_conv1 = self.conv1_1(self.conv1(out_conv0)) + out_conv2 = self.conv2_1(self.conv2(out_conv1)) + + out_conv3 = self.conv3_1(self.conv3(out_conv2)) + out_conv4 = self.conv4_1(self.conv4(out_conv3)) + out_conv5 = self.conv5_1(self.conv5(out_conv4)) + out_conv6 = self.conv6_1(self.conv6(out_conv5)) + + flow6 = self.predict_flow6(out_conv6) + flow6_up = self.upsampled_flow6_to_5(flow6) + out_deconv5 = self.deconv5(out_conv6) + + concat5 = torch.cat((out_conv5,out_deconv5,flow6_up),1) + out_interconv5 = self.inter_conv5(concat5) + flow5 = self.predict_flow5(out_interconv5) + + flow5_up = self.upsampled_flow5_to_4(flow5) + out_deconv4 = self.deconv4(concat5) + + concat4 = torch.cat((out_conv4,out_deconv4,flow5_up),1) + out_interconv4 = self.inter_conv4(concat4) + flow4 = self.predict_flow4(out_interconv4) + flow4_up = self.upsampled_flow4_to_3(flow4) + out_deconv3 = self.deconv3(concat4) + + concat3 = torch.cat((out_conv3,out_deconv3,flow4_up),1) + out_interconv3 = self.inter_conv3(concat3) + flow3 = self.predict_flow3(out_interconv3) + flow3_up = self.upsampled_flow3_to_2(flow3) + out_deconv2 = self.deconv2(concat3) + + concat2 = torch.cat((out_conv2,out_deconv2,flow3_up),1) + out_interconv2 = self.inter_conv2(concat2) + flow2 = self.predict_flow2(out_interconv2) + + if self.training: + return flow2,flow3,flow4,flow5,flow6 + else: + return flow2, diff --git a/windflow/networks/flownet/__init__.py b/windflow/networks/flownet/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..bbe2eadb45c8e9cebe36ba1c595eb209afef653c --- /dev/null +++ b/windflow/networks/flownet/__init__.py @@ -0,0 +1 @@ +from . 
import FlowNetS, FlowNetC, FlowNetSD, FlowNetFusion, FlowNet2 \ No newline at end of file diff --git a/windflow/networks/flownet/submodules.py b/windflow/networks/flownet/submodules.py new file mode 100755 index 0000000000000000000000000000000000000000..dd6aac3d581faa86968d1522edf41b2a8b7fc1a0 --- /dev/null +++ b/windflow/networks/flownet/submodules.py @@ -0,0 +1,92 @@ +# freda (todo) : + +import torch.nn as nn +import torch +import numpy as np + +def conv(batchNorm, in_planes, out_planes, kernel_size=3, stride=1): + if batchNorm: + return nn.Sequential( + nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=(kernel_size-1)//2, bias=False), + nn.BatchNorm2d(out_planes), + nn.LeakyReLU(0.1,inplace=True) + ) + else: + return nn.Sequential( + nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=(kernel_size-1)//2, bias=True), + nn.LeakyReLU(0.1,inplace=True) + ) + +def i_conv(batchNorm, in_planes, out_planes, kernel_size=3, stride=1, bias = True): + if batchNorm: + return nn.Sequential( + nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=(kernel_size-1)//2, bias=bias), + nn.BatchNorm2d(out_planes), + ) + else: + return nn.Sequential( + nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=(kernel_size-1)//2, bias=bias), + ) + +def predict_flow(in_planes): + return nn.Conv2d(in_planes,2,kernel_size=3,stride=1,padding=1,bias=True) + +def deconv(in_planes, out_planes): + return nn.Sequential( + nn.ConvTranspose2d(in_planes, out_planes, kernel_size=4, stride=2, padding=1, bias=True), + nn.LeakyReLU(0.1,inplace=True) + ) + +class tofp16(nn.Module): + def __init__(self): + super(tofp16, self).__init__() + + def forward(self, input): + return input.half() + + +class tofp32(nn.Module): + def __init__(self): + super(tofp32, self).__init__() + + def forward(self, input): + return input.float() + + +def init_deconv_bilinear(weight): + f_shape = weight.size() + heigh, width = f_shape[-2], f_shape[-1] + f = np.ceil(width/2.0) + c = (2 * f - 1 - f % 2) / (2.0 * f) + bilinear = np.zeros([heigh, width]) + for x in range(width): + for y in range(heigh): + value = (1 - abs(x / f - c)) * (1 - abs(y / f - c)) + bilinear[x, y] = value + weight.data.fill_(0.) 
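+    # Copy the bilinear kernel into every (out, in) channel pair of the zeroed
+    # transposed-conv weight (bilinear-upsampling initialisation).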
+ for i in range(f_shape[0]): + for j in range(f_shape[1]): + weight.data[i,j,:,:] = torch.from_numpy(bilinear) + + +def save_grad(grads, name): + def hook(grad): + grads[name] = grad + return hook + +''' +def save_grad(grads, name): + def hook(grad): + grads[name] = grad + return hook +import torch +from channelnorm_package.modules.channelnorm import ChannelNorm +model = ChannelNorm().cuda() +grads = {} +a = 100*torch.autograd.Variable(torch.randn((1,3,5,5)).cuda(), requires_grad=True) +a.register_hook(save_grad(grads, 'a')) +b = model(a) +y = torch.mean(b) +y.backward() + +''' diff --git a/windflow/networks/layers/partial.py b/windflow/networks/layers/partial.py new file mode 100644 index 0000000000000000000000000000000000000000..90cccd8b0caa4cdbbde75e517c0e2bb1778e1423 --- /dev/null +++ b/windflow/networks/layers/partial.py @@ -0,0 +1,42 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Conv2d_partial(nn.Conv2d): + def __init__(self, in_channels, out_channels, kernel_size, stride=1, + padding=0, dilation=1, groups=1, bias=True, partial=False): + super(Conv2d_partial, self).__init__( + in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias) + + self.partial = partial + + def forward(self, input): + if self.partial: + self.padding = 0 + + pad_val = (self.kernel_size[0] - 1) // 2 + if pad_val > 0: + if (self.kernel_size[0] - self.stride[0]) % 2 == 0: + pad_top = pad_val + pad_bottom = pad_val + pad_left = pad_val + pad_right = pad_val + else: + pad_top = pad_val + pad_bottom = self.kernel_size[0] - self.stride[0] - pad_top + pad_left = pad_val + pad_right = self.kernel_size[0] - self.stride[0] - pad_left + + p0 = torch.ones_like(input) + p0 = p0.sum() + + input = F.pad(input, (pad_left, pad_right, pad_top, pad_bottom) , mode='constant', value=0) + + p1 = torch.ones_like(input) + p1 = p1.sum() + + ratio = torch.div(p1, p0 + 1e-8) + input = torch.mul(input, ratio) + + return F.conv2d(input, self.weight, self.bias, self.stride, + self.padding, self.dilation, self.groups) diff --git a/windflow/networks/maskflownet.py b/windflow/networks/maskflownet.py new file mode 100644 index 0000000000000000000000000000000000000000..2c57c427a01a33a1b34736ca65a2abc685a568a4 --- /dev/null +++ b/windflow/networks/maskflownet.py @@ -0,0 +1,724 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.autograd import Variable +import numpy as np +from torchvision import ops +# ops.DeformConv2d() + +#from .correlation_package.correlation import Correlation +from spatial_correlation_sampler import SpatialCorrelationSampler + + +def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1, activation=True): + if activation: + return nn.Sequential( + nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, + padding=padding, dilation=dilation, bias=True), + nn.LeakyReLU(0.1)) + else: + return nn.Sequential( + nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, + padding=padding, dilation=dilation, bias=True)) + + +def predict_flow(in_planes): + return nn.Conv2d(in_planes, 2, kernel_size=3, stride=1, padding=1, bias=True) + + +def predict_mask(in_planes): + return nn.Conv2d(in_planes, 1, kernel_size=3, stride=1, padding=1, bias=True) + + +def deconv(in_planes, out_planes, kernel_size=4, stride=2, padding=1): + return nn.ConvTranspose2d(in_planes, out_planes, kernel_size, stride, padding, bias=True) + + +def deformable_conv(in_planes, out_planes, kernel_size=3, strides=1, 
padding=1, use_bias=True): + return ops.DeformConv2d(in_planes, out_planes, kernel_size, strides, padding, bias=use_bias) + + +def upsample_kernel2d(w, device): + c = w // 2 + kernel = 1 - torch.abs(c - torch.arange(w, dtype=torch.float32, device=device)) / (c + 1) + kernel = kernel.repeat(w).view(w,-1) * kernel.unsqueeze(1) + return kernel.view(1, 1, w, w) + + +def downsample_kernel2d(w, device): + kernel = ((w + 1) - torch.abs(w - torch.arange(w * 2 + 1, dtype=torch.float32, device=device))) / (2 * w + 1) + kernel = kernel.repeat(w).view(w,-1) * kernel.unsqueeze(1) + return kernel.view(1, 1, w * 2 + 1, w * 2 + 1) + +class CorrelationLayerSimple(nn.Module): + def __init__(self, md=4): + super(CorrelationLayerSimple,self).__init__() + self.md = md + self.model = SpatialCorrelationSampler(kernel_size=1, patch_size=md*2+1, + stride=1, dilation_patch=1, padding=0) + def forward(self, x1, x2): + output = self.model(x1, x2) + b, ph, pw, h, w = output.size() + return output.view(b, ph * pw, h, w) + +def Upsample(img, factor): + if factor == 1: + return img + B, C, H, W = img.shape + batch_img = img.view(B*C, 1, H, W) + batch_img = F.pad(batch_img, [0, 1, 0, 1], mode='replicate') + kernel = upsample_kernel2d(factor * 2 - 1, img.device) + upsamp_img = F.conv_transpose2d(batch_img, kernel, stride=factor, padding=(factor-1)) + upsamp_img = upsamp_img[:, :, : -1, :-1] + _, _, H_up, W_up = upsamp_img.shape + return upsamp_img.view(B, C, H_up, W_up) + + +def Downsample(img, factor): + if factor == 1: + return img + B, C, H, W = img.shape + batch_img = img.view(B*C, 1, H, W) + kernel = downsample_kernel2d(factor // 2, img.device) + upsamp_img = F.conv2d(batch_img, kernel, stride=factor, padding=factor//2) + upsamp_nom = F.conv2d(torch.ones_like(batch_img), kernel, stride=factor, padding=factor//2) + _, _, H_up, W_up = upsamp_img.shape + upsamp_img = upsamp_img.view(B, C, H_up, W_up) + upsamp_nom = upsamp_nom.view(B, C, H_up, W_up) + return upsamp_img / upsamp_nom + + + +class MaskFlownet_S(nn.Module): + """ + PWC-DC net. add dilation convolution and densenet connections + """ + def __init__(self, in_ch=1, **kwargs): + """ + input: md --- maximum displacement (for correlation. default: 4), after warpping + """ + super(MaskFlownet_S, self).__init__() + self.scale = 1. # 20 * config.network.flow_multiplier.get(1.) 
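+        # md is the correlation search radius; CorrelationLayerSimple returns a
+        # (2*md+1)**2 = 81-channel cost volume (nd below).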
+ md = 4 + self.md = md + self.in_ch = in_ch + self.strides = [64, 32, 16, 8, 4] + self.deform_bias = True #config.network.deform_bias.get(True) + self.upfeat_ch = [16, 16, 16, 16] #config.network.upfeat_ch.get([16, 16, 16, 16]) + + self.conv1a = conv(in_ch, 16, kernel_size=3, stride=2) + self.conv1b = conv(16, 16, kernel_size=3, stride=1) + self.conv1c = conv(16, 16, kernel_size=3, stride=1) + self.conv2a = conv(16, 32, kernel_size=3, stride=2) + self.conv2b = conv(32, 32, kernel_size=3, stride=1) + self.conv2c = conv(32, 32, kernel_size=3, stride=1) + self.conv3a = conv(32, 64, kernel_size=3, stride=2) + self.conv3b = conv(64, 64, kernel_size=3, stride=1) + self.conv3c = conv(64, 64, kernel_size=3, stride=1) + self.conv4a = conv(64, 96, kernel_size=3, stride=2) + self.conv4b = conv(96, 96, kernel_size=3, stride=1) + self.conv4c = conv(96, 96, kernel_size=3, stride=1) + self.conv5a = conv(96, 128, kernel_size=3, stride=2) + self.conv5b = conv(128, 128, kernel_size=3, stride=1) + self.conv5c = conv(128, 128, kernel_size=3, stride=1) + self.conv6a = conv(128, 196, kernel_size=3, stride=2) + self.conv6b = conv(196, 196, kernel_size=3, stride=1) + self.conv6c = conv(196, 196, kernel_size=3, stride=1) + + #self.corr = Correlation(pad_size=md, kernel_size=1, max_displacement=md, stride1=1, stride2=1, corr_multiply=1) + self.corr = CorrelationLayerSimple(md=md) + + self.leakyRELU = nn.LeakyReLU(0.1) + + nd = (2*md+1)**2 + dd = np.cumsum([128,128,96,64,32]) + + od = nd + self.conv6_0 = conv(od, 128, kernel_size=3, stride=1) + self.conv6_1 = conv(od+dd[0],128, kernel_size=3, stride=1) + self.conv6_2 = conv(od+dd[1],96, kernel_size=3, stride=1) + self.conv6_3 = conv(od+dd[2],64, kernel_size=3, stride=1) + self.conv6_4 = conv(od+dd[3],32, kernel_size=3, stride=1) + self.pred_flow6 = predict_flow(od + dd[4]) + self.pred_mask6 = predict_mask(od + dd[4]) + self.upfeat5 = deconv(od+dd[4], self.upfeat_ch[0], kernel_size=4, stride=2, padding=1) + # self.deconv6 = deconv(2, 2, kernel_size=4, stride=2, padding=1) + # self.upfeat6 = deconv(od+dd[4], 2, kernel_size=4, stride=2, padding=1) + + # od = nd+128+4 + od = nd+128+18 + self.conv5_0 = conv(od, 128, kernel_size=3, stride=1) + self.conv5_1 = conv(od+dd[0],128, kernel_size=3, stride=1) + self.conv5_2 = conv(od+dd[1],96, kernel_size=3, stride=1) + self.conv5_3 = conv(od+dd[2],64, kernel_size=3, stride=1) + self.conv5_4 = conv(od+dd[3],32, kernel_size=3, stride=1) + self.pred_flow5 = predict_flow(od + dd[4]) + self.pred_mask5 = predict_mask(od + dd[4]) + self.upfeat4 = deconv(od+dd[4], self.upfeat_ch[1], kernel_size=4, stride=2, padding=1) + # self.deconv5 = deconv(2, 2, kernel_size=4, stride=2, padding=1) + # self.upfeat5 = deconv(od+dd[4], 2, kernel_size=4, stride=2, padding=1) + + # od = nd+96+4 + od = nd+96+18 + self.conv4_0 = conv(od, 128, kernel_size=3, stride=1) + self.conv4_1 = conv(od+dd[0],128, kernel_size=3, stride=1) + self.conv4_2 = conv(od+dd[1],96, kernel_size=3, stride=1) + self.conv4_3 = conv(od+dd[2],64, kernel_size=3, stride=1) + self.conv4_4 = conv(od+dd[3],32, kernel_size=3, stride=1) + self.pred_flow4 = predict_flow(od + dd[4]) + self.pred_mask4 = predict_mask(od + dd[4]) + self.upfeat3 = deconv(od+dd[4], self.upfeat_ch[2], kernel_size=4, stride=2, padding=1) + # self.deconv4 = deconv(2, 2, kernel_size=4, stride=2, padding=1) + # self.upfeat4 = deconv(od+dd[4], 2, kernel_size=4, stride=2, padding=1) + + # od = nd+64+4 + od = nd+64+18 + self.conv3_0 = conv(od, 128, kernel_size=3, stride=1) + self.conv3_1 = conv(od+dd[0],128, 
kernel_size=3, stride=1) + self.conv3_2 = conv(od+dd[1],96, kernel_size=3, stride=1) + self.conv3_3 = conv(od+dd[2],64, kernel_size=3, stride=1) + self.conv3_4 = conv(od+dd[3],32, kernel_size=3, stride=1) + self.pred_flow3 = predict_flow(od + dd[4]) + self.pred_mask3 = predict_mask(od + dd[4]) + self.upfeat2 = deconv(od+dd[4], self.upfeat_ch[3], kernel_size=4, stride=2, padding=1) + # self.deconv3 = deconv(2, 2, kernel_size=4, stride=2, padding=1) + # self.upfeat3 = deconv(od+dd[4], 2, kernel_size=4, stride=2, padding=1) + + # od = nd+32+4 + od = nd+32+18 + self.conv2_0 = conv(od, 128, kernel_size=3, stride=1) + self.conv2_1 = conv(od+dd[0],128, kernel_size=3, stride=1) + self.conv2_2 = conv(od+dd[1],96, kernel_size=3, stride=1) + self.conv2_3 = conv(od+dd[2],64, kernel_size=3, stride=1) + self.conv2_4 = conv(od+dd[3],32, kernel_size=3, stride=1) + self.pred_flow2 = predict_flow(od + dd[4]) + + self.dc_conv1 = conv(od+dd[4], 128, kernel_size=3, stride=1, padding=1, dilation=1) + self.dc_conv2 = conv(128, 128, kernel_size=3, stride=1, padding=2, dilation=2) + self.dc_conv3 = conv(128, 128, kernel_size=3, stride=1, padding=4, dilation=4) + self.dc_conv4 = conv(128, 96, kernel_size=3, stride=1, padding=8, dilation=8) + self.dc_conv5 = conv(96, 64, kernel_size=3, stride=1, padding=16, dilation=16) + self.dc_conv6 = conv(64, 32, kernel_size=3, stride=1, padding=1, dilation=1) + self.dc_conv7 = predict_flow(32) + + # self.upfeat5 = deconv() + + self.deform5 = deformable_conv(128, 128) + self.deform4 = deformable_conv(96, 96) + self.deform3 = deformable_conv(64, 64) + self.deform2 = deformable_conv(32, 32) + + self.conv5f = conv(16, 128, kernel_size=3, stride=1, padding=1, activation=False) + self.conv4f = conv(16, 96, kernel_size=3, stride=1, padding=1, activation=False) + self.conv3f = conv(16, 64, kernel_size=3, stride=1, padding=1, activation=False) + self.conv2f = conv(16, 32, kernel_size=3, stride=1, padding=1, activation=False) + + for m in self.modules(): + if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): + nn.init.kaiming_normal(m.weight.data, mode='fan_in') + if m.bias is not None: + m.bias.data.zero_() + + + def warp(self, x, flo): + """ + warp an image/tensor (im2) back to im1, according to the optical flow + x: [B, C, H, W] (im2) + flo: [B, 2, H, W] flow + """ + B, C, H, W = x.size() + # mesh grid + xx = torch.arange(0, W).view(1,-1).repeat(H,1) + yy = torch.arange(0, H).view(-1,1).repeat(1,W) + xx = xx.view(1,1,H,W).repeat(B,1,1,1) + yy = yy.view(1,1,H,W).repeat(B,1,1,1) + grid = torch.cat((xx,yy),1).float() + + device = x.device + grid = grid.to(device) + # vgrid = Variable(grid) + flo + vgrid = Variable(grid) + torch.flip(flo, [1]) + + # scale grid to [-1,1] + vgrid[:,0,:,:] = 2.0*vgrid[:,0,:,:].clone() / max(W-1,1)-1.0 + vgrid[:,1,:,:] = 2.0*vgrid[:,1,:,:].clone() / max(H-1,1)-1.0 + + vgrid = vgrid.permute(0,2,3,1) + # vgrid = vgrid.permute(0,2,3,1).clamp(-1.1, 1.1) + output = nn.functional.grid_sample(x, vgrid, align_corners=True) + mask = torch.autograd.Variable(torch.ones(x.size())).to(device) + mask = nn.functional.grid_sample(mask, vgrid, align_corners=True) + + # if W==128: + # np.save('mask.npy', mask.cpu().data.numpy()) + # np.save('warp.npy', output.cpu().data.numpy()) + + mask[mask<0.9999] = 0 + mask[mask>0] = 1 + + return output*mask + + + def forward(self, x): + im1 = x[:,:self.in_ch,:,:] + im2 = x[:,self.in_ch:,:,:] + + c11 = self.conv1c(self.conv1b(self.conv1a(im1))) + c21 = self.conv1c(self.conv1b(self.conv1a(im2))) + c12 = 
self.conv2c(self.conv2b(self.conv2a(c11))) + c22 = self.conv2c(self.conv2b(self.conv2a(c21))) + c13 = self.conv3c(self.conv3b(self.conv3a(c12))) + c23 = self.conv3c(self.conv3b(self.conv3a(c22))) + c14 = self.conv4c(self.conv4b(self.conv4a(c13))) + c24 = self.conv4c(self.conv4b(self.conv4a(c23))) + c15 = self.conv5c(self.conv5b(self.conv5a(c14))) + c25 = self.conv5c(self.conv5b(self.conv5a(c24))) + c16 = self.conv6c(self.conv6b(self.conv6a(c15))) + c26 = self.conv6c(self.conv6b(self.conv6a(c25))) + + + corr6 = self.corr(c16, c26) + corr6 = self.leakyRELU(corr6) + + + x = torch.cat((self.conv6_0(corr6), corr6),1) + x = torch.cat((self.conv6_1(x), x),1) + x = torch.cat((self.conv6_2(x), x),1) + x = torch.cat((self.conv6_3(x), x),1) + x = torch.cat((self.conv6_4(x), x),1) + flow6 = self.pred_flow6(x) + mask6 = self.pred_mask6(x) + + feat5 = self.leakyRELU(self.upfeat5(x)) + flow5 = Upsample(flow6, 2) + mask5 = Upsample(mask6, 2) + warp5 = (flow5*self.scale/self.strides[1]).unsqueeze(1) + warp5 = torch.repeat_interleave(warp5, 9, 1) + S1, S2, S3, S4, S5 = warp5.shape + warp5 = warp5.view(S1, S2*S3, S4, S5) + warp5 = self.deform5(c25, warp5) + tradeoff5 = feat5 + warp5 = (warp5 * F.sigmoid(mask5)) + self.conv5f(tradeoff5) + warp5 = self.leakyRELU(warp5) + corr5 = self.corr(c15, warp5) + corr5 = self.leakyRELU(corr5) + x = torch.cat((corr5, c15, feat5, flow5), 1) + x = torch.cat((self.conv5_0(x), x),1) + x = torch.cat((self.conv5_1(x), x),1) + x = torch.cat((self.conv5_2(x), x),1) + x = torch.cat((self.conv5_3(x), x),1) + x = torch.cat((self.conv5_4(x), x),1) + flow5 = flow5 + self.pred_flow5(x) + mask5 = self.pred_mask5(x) + + feat4 = self.leakyRELU(self.upfeat4(x)) + flow4 = Upsample(flow5, 2) + mask4 = Upsample(mask5, 2) + warp4 = (flow4*self.scale/self.strides[2]).unsqueeze(1) + warp4 = torch.repeat_interleave(warp4, 9, 1) + S1, S2, S3, S4, S5 = warp4.shape + warp4 = warp4.view(S1, S2*S3, S4, S5) + warp4 = self.deform4(c24, warp4) + tradeoff4 = feat4 + warp4 = (warp4 * F.sigmoid(mask4)) + self.conv4f(tradeoff4) + warp4 = self.leakyRELU(warp4) + corr4 = self.corr(c14, warp4) + corr4 = self.leakyRELU(corr4) + x = torch.cat((corr4, c14, feat4, flow4), 1) + x = torch.cat((self.conv4_0(x), x),1) + x = torch.cat((self.conv4_1(x), x),1) + x = torch.cat((self.conv4_2(x), x),1) + x = torch.cat((self.conv4_3(x), x),1) + x = torch.cat((self.conv4_4(x), x),1) + flow4 = flow4 + self.pred_flow4(x) + mask4 = self.pred_mask4(x) + + feat3 = self.leakyRELU(self.upfeat3(x)) + flow3 = Upsample(flow4, 2) + mask3 = Upsample(mask4, 2) + warp3 = (flow3*self.scale/self.strides[3]).unsqueeze(1) + warp3 = torch.repeat_interleave(warp3, 9, 1) + S1, S2, S3, S4, S5 = warp3.shape + warp3 = warp3.view(S1, S2*S3, S4, S5) + warp3 = self.deform3(c23, warp3) + tradeoff3 = feat3 + warp3 = (warp3 * F.sigmoid(mask3)) + self.conv3f(tradeoff3) + warp3 = self.leakyRELU(warp3) + corr3 = self.corr(c13, warp3) + corr3 = self.leakyRELU(corr3) + x = torch.cat((corr3, c13, feat3, flow3), 1) + x = torch.cat((self.conv3_0(x), x),1) + x = torch.cat((self.conv3_1(x), x),1) + x = torch.cat((self.conv3_2(x), x),1) + x = torch.cat((self.conv3_3(x), x),1) + x = torch.cat((self.conv3_4(x), x),1) + flow3 = flow3 + self.pred_flow3(x) + mask3 = self.pred_mask3(x) + + feat2 = self.leakyRELU(self.upfeat2(x)) + flow2 = Upsample(flow3, 2) + mask2 = Upsample(mask3, 2) + warp2 = (flow2*self.scale/self.strides[4]).unsqueeze(1) + warp2 = torch.repeat_interleave(warp2, 9, 1) + S1, S2, S3, S4, S5 = warp2.shape + warp2 = warp2.view(S1, S2*S3, S4, S5) + warp2 = 
self.deform2(c22, warp2) + tradeoff2 = feat2 + warp2 = (warp2 * F.sigmoid(mask2)) + self.conv2f(tradeoff2) + warp2 = self.leakyRELU(warp2) + corr2 = self.corr(c12, warp2) + corr2 = self.leakyRELU(corr2) + x = torch.cat((corr2, c12, feat2, flow2), 1) + x = torch.cat((self.conv2_0(x), x),1) + x = torch.cat((self.conv2_1(x), x),1) + x = torch.cat((self.conv2_2(x), x),1) + x = torch.cat((self.conv2_3(x), x),1) + x = torch.cat((self.conv2_4(x), x),1) + flow2 = flow2 + self.pred_flow2(x) + + x = self.dc_conv4(self.dc_conv3(self.dc_conv2(self.dc_conv1(x)))) + flow2 = flow2 + self.dc_conv7(self.dc_conv6(self.dc_conv5(x))) + + predictions = [flow * self.scale for flow in [flow6, flow5, flow4, flow3, flow2]] + occlusion_masks = [] + occlusion_masks.append(F.sigmoid(mask2)) + c1s = [c11, c12, c13, c14, c15, c16] + c2s = [c21, c12, c13, c24, c25, c26] + flows = [flow6, flow5, flow4, flow3, flow2] + mask0 = Upsample(mask2, 4) + mask0 = F.sigmoid(mask0) - 0.5 + c30 = im1 + c40 = self.warp(im2, Upsample(flow2, 4)*self.scale) + c30 = torch.cat((c30, torch.zeros_like(mask0)), 1) + c40 = torch.cat((c40, mask0), 1) + srcs = [c1s, c2s, flows, c30, c40] + return predictions, occlusion_masks, srcs + + +class MaskFlownet(nn.Module): + def __init__(self, in_ch=1, **kwargs): + super(MaskFlownet, self).__init__(**kwargs) + self.strides = [64, 32, 16, 8, 4] + self.md = 2 + self.scale = 1. #config.network.flow_multiplier.get(1.) + self.deform_bias = True #config.network.deform_bias.get(True) + self.upfeat_ch = [16, 16, 16, 16] #config.network.upfeat_ch.get([16, 16, 16, 16]) + self.in_ch = in_ch + self.MaskFlownet_S = MaskFlownet_S(in_ch=in_ch) + self.activate = nn.LeakyReLU(0.1) + + self.conv1x = conv(self.in_ch+1, 16, stride=2) # image + + self.conv1y = conv(16, 16, stride=1) + self.conv1z = conv(16, 16, stride=1) + self.conv2x = conv(16, 32, stride=2) + self.conv2y = conv(32, 32, stride=1) + self.conv2z = conv(32, 32, stride=1) + self.conv3x = conv(32, 64, stride=2) + self.conv3y = conv(64, 64, stride=1) + self.conv3z = conv(64, 64, stride=1) + self.conv4x = conv(64, 96, stride=2) + self.conv4y = conv(96, 96, stride=1) + self.conv4z = conv(96, 96, stride=1) + self.conv5x = conv(96, 128, stride=2) + self.conv5y = conv(128, 128, stride=1) + self.conv5z = conv(128, 128, stride=1) + self.conv6x = conv(128, 196, stride=2) + self.conv6y = conv(196, 196, stride=1) + self.conv6z = conv(196, 196, stride=1) + + self.leakyRELU = nn.LeakyReLU(0.1) + #self.corr = Correlation(pad_size=self.md, kernel_size=1, max_displacement=self.md, stride1=1, stride2=1, corr_multiply=1) + self.corr = CorrelationLayerSimple(md=self.md) + + + nd = (2*self.md+1)**2 + dd = np.cumsum([128,128,96,64,32]) + + od = nd+nd+2 + self.conv6_0 = conv(od, 128, kernel_size=3, stride=1) + self.conv6_1 = conv(od+dd[0], 128, kernel_size=3, stride=1) + self.conv6_2 = conv(od+dd[1], 96, kernel_size=3, stride=1) + self.conv6_3 = conv(od+dd[2], 64, kernel_size=3, stride=1) + self.conv6_4 = conv(od+dd[3], 32, kernel_size=3, stride=1) + self.pred_flow6 = predict_flow(od + dd[4]) + self.upfeat5 = deconv(od+dd[4], self.upfeat_ch[0], kernel_size=4, stride=2, padding=1) + + # od = nd+128+4 + od = nd+nd+128+16+2+2 + self.conv5_0 = conv(od, 128, kernel_size=3, stride=1) + self.conv5_1 = conv(od+dd[0],128, kernel_size=3, stride=1) + self.conv5_2 = conv(od+dd[1],96, kernel_size=3, stride=1) + self.conv5_3 = conv(od+dd[2],64, kernel_size=3, stride=1) + self.conv5_4 = conv(od+dd[3],32, kernel_size=3, stride=1) + self.pred_flow5 = predict_flow(od + dd[4]) + self.upfeat4 = 
deconv(od+dd[4], self.upfeat_ch[1], kernel_size=4, stride=2, padding=1) + # self.deconv5 = deconv(2, 2, kernel_size=4, stride=2, padding=1) + # self.upfeat5 = deconv(od+dd[4], 2, kernel_size=4, stride=2, padding=1) + + # od = nd+96+4 + # od = nd+96+18 + od = nd+nd+96+16+2+2 + self.conv4_0 = conv(od, 128, kernel_size=3, stride=1) + self.conv4_1 = conv(od+dd[0],128, kernel_size=3, stride=1) + self.conv4_2 = conv(od+dd[1],96, kernel_size=3, stride=1) + self.conv4_3 = conv(od+dd[2],64, kernel_size=3, stride=1) + self.conv4_4 = conv(od+dd[3],32, kernel_size=3, stride=1) + self.pred_flow4 = predict_flow(od + dd[4]) + self.upfeat3 = deconv(od+dd[4], self.upfeat_ch[2], kernel_size=4, stride=2, padding=1) + # self.deconv4 = deconv(2, 2, kernel_size=4, stride=2, padding=1) + # self.upfeat4 = deconv(od+dd[4], 2, kernel_size=4, stride=2, padding=1) + + # od = nd+64+4 + # od = nd+64+18 + od = nd+nd+64+16+2+2 + self.conv3_0 = conv(od, 128, kernel_size=3, stride=1) + self.conv3_1 = conv(od+dd[0],128, kernel_size=3, stride=1) + self.conv3_2 = conv(od+dd[1],96, kernel_size=3, stride=1) + self.conv3_3 = conv(od+dd[2],64, kernel_size=3, stride=1) + self.conv3_4 = conv(od+dd[3],32, kernel_size=3, stride=1) + self.pred_flow3 = predict_flow(od + dd[4]) + self.upfeat2 = deconv(od+dd[4], self.upfeat_ch[3], kernel_size=4, stride=2, padding=1) + # self.deconv3 = deconv(2, 2, kernel_size=4, stride=2, padding=1) + # self.upfeat3 = deconv(od+dd[4], 2, kernel_size=4, stride=2, padding=1) + + # od = nd+32+4 + # od = nd+32+18 + od = nd+nd+32+16+2+2 + self.conv2_0 = conv(od, 128, kernel_size=3, stride=1) + self.conv2_1 = conv(od+dd[0],128, kernel_size=3, stride=1) + self.conv2_2 = conv(od+dd[1],96, kernel_size=3, stride=1) + self.conv2_3 = conv(od+dd[2],64, kernel_size=3, stride=1) + self.conv2_4 = conv(od+dd[3],32, kernel_size=3, stride=1) + self.pred_flow2 = predict_flow(od + dd[4]) + + self.dc_conv1 = conv(od+dd[4], 128, kernel_size=3, stride=1, padding=1, dilation=1) + self.dc_conv2 = conv(128, 128, kernel_size=3, stride=1, padding=2, dilation=2) + self.dc_conv3 = conv(128, 128, kernel_size=3, stride=1, padding=4, dilation=4) + self.dc_conv4 = conv(128, 96, kernel_size=3, stride=1, padding=8, dilation=8) + self.dc_conv5 = conv(96, 64, kernel_size=3, stride=1, padding=16, dilation=16) + self.dc_conv6 = conv(64, 32, kernel_size=3, stride=1, padding=1, dilation=1) + self.dc_conv7 = predict_flow(32) + + # self.upfeat5 = deconv() + + self.deform6 = deformable_conv(196, 196) + self.deform5 = deformable_conv(128, 128) + self.deform4 = deformable_conv(96, 96) + self.deform3 = deformable_conv(64, 64) + self.deform2 = deformable_conv(32, 32) + + def warp(self, x, flo): + """ + warp an image/tensor (im2) back to im1, according to the optical flow + x: [B, C, H, W] (im2) + flo: [B, 2, H, W] flow + """ + B, C, H, W = x.size() + # mesh grid + xx = torch.arange(0, W).view(1,-1).repeat(H,1) + yy = torch.arange(0, H).view(-1,1).repeat(1,W) + xx = xx.view(1,1,H,W).repeat(B,1,1,1) + yy = yy.view(1,1,H,W).repeat(B,1,1,1) + grid = torch.cat((xx,yy),1).float() + + if x.is_cuda: + grid = grid.cuda() + # vgrid = Variable(grid) + flo + vgrid = Variable(grid) + torch.flip(flo, [1]) + + # scale grid to [-1,1] + vgrid[:,0,:,:] = 2.0*vgrid[:,0,:,:].clone() / max(W-1,1)-1.0 + vgrid[:,1,:,:] = 2.0*vgrid[:,1,:,:].clone() / max(H-1,1)-1.0 + + # vgrid = vgrid.permute(0,2,3,1) + vgrid = vgrid.permute(0,2,3,1).clamp(-1.1, 1.1) + output = nn.functional.grid_sample(x, vgrid, align_corners=True) + mask = torch.autograd.Variable(torch.ones(x.size())).cuda() 
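Two notes on this second warp(): it pins the validity mask to CUDA (torch.ones(...).cuda()), so unlike MaskFlownet_S.warp above it assumes GPU execution; and the torch.flip(flo, [1]) works because the mesh grid is stacked in (x, y) order while the flow tensors appear to store the vertical displacement in channel 0 and the horizontal displacement in channel 1. A standalone sketch of that sampling convention, not part of the module (sizes are arbitrary):

import torch
import torch.nn.functional as F

im2 = torch.arange(16.).view(1, 1, 4, 4)
flo = torch.zeros(1, 2, 4, 4)
flo[:, 1] = 1.0                                  # +1 pixel horizontal displacement
xx = torch.arange(4.).view(1, -1).repeat(4, 1)   # x (column) coordinates
yy = torch.arange(4.).view(-1, 1).repeat(1, 4)   # y (row) coordinates
grid = torch.stack((xx, yy), 0).unsqueeze(0)     # (1, 2, 4, 4), channels in (x, y) order
vgrid = grid + torch.flip(flo, [1])              # flip puts the flow into (dx, dy) order
vgrid = 2.0 * vgrid / 3.0 - 1.0                  # normalize to [-1, 1] (here H = W = 4)
warped = F.grid_sample(im2, vgrid.permute(0, 2, 3, 1), align_corners=True)
# warped[0, 0, y, x] == im2[0, 0, y, x + 1] for x < 3; zeros past the border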
+ mask = nn.functional.grid_sample(mask, vgrid, align_corners=True) + + # if W==128: + # np.save('mask.npy', mask.cpu().data.numpy()) + # np.save('warp.npy', output.cpu().data.numpy()) + + mask[mask<0.9999] = 0 + mask[mask>0] = 1 + + return output*mask + + def forward(self, x): + im1 = x[:,:self.in_ch,:,:] + im2 = x[:,self.in_ch:,:,:] + + _, _, srcs = self.MaskFlownet_S(x) + c1s, c2s, flows, c30, c40 = srcs + c11, c12, c13, c14, c15, c16 = c1s + c21, c22, c23, c24, c25, c26 = c2s + + c31 = self.conv1z(self.conv1y(self.conv1x(c30))) + c32 = self.conv2z(self.conv2y(self.conv2x(c31))) + c33 = self.conv3z(self.conv3y(self.conv3x(c32))) + c34 = self.conv4z(self.conv4y(self.conv4x(c33))) + c35 = self.conv5z(self.conv5y(self.conv5x(c34))) + c36 = self.conv6z(self.conv6y(self.conv6x(c35))) + + c41 = self.conv1z(self.conv1y(self.conv1x(c40))) + c42 = self.conv2z(self.conv2y(self.conv2x(c41))) + c43 = self.conv3z(self.conv3y(self.conv3x(c42))) + c44 = self.conv4z(self.conv4y(self.conv4x(c43))) + c45 = self.conv5z(self.conv5y(self.conv5x(c44))) + c46 = self.conv6z(self.conv6y(self.conv6x(c45))) + + flow6 = flows[0] + + warp6u = (flow6*self.scale/self.strides[0]).unsqueeze(1) + warp6u = torch.repeat_interleave(warp6u, 9, 1) + S1, S2, S3, S4, S5 = warp6u.shape + warp6u = warp6u.view(S1, S2*S3, S4, S5) + warp6u = self.deform6(c26, warp6u) + warp6u = self.leakyRELU(warp6u) + corr6u = self.leakyRELU(self.corr(c16, warp6u)) + warp6v = c46 + corr6v = self.leakyRELU(self.corr(c36, warp6v)) + x = torch.cat((corr6u, corr6v, flow6),1) + x = torch.cat((self.conv6_0(x), x),1) + x = torch.cat((self.conv6_1(x), x),1) + x = torch.cat((self.conv6_2(x), x),1) + x = torch.cat((self.conv6_3(x), x),1) + x = torch.cat((self.conv6_4(x), x),1) + flow6 = flow6 + self.pred_flow6(x) + + feat5 = self.leakyRELU(self.upfeat5(x)) + flow5 = Upsample(flow6, 2) + warp5u = (flow5*self.scale/self.strides[1]).unsqueeze(1) + warp5u = torch.repeat_interleave(warp5u, 9, 1) + S1, S2, S3, S4, S5 = warp5u.shape + warp5u = warp5u.view(S1, S2*S3, S4, S5) + warp5u = self.deform5(c25, warp5u) + warp5u = self.leakyRELU(warp5u) + corr5u = self.leakyRELU(self.corr(c15, warp5u)) + warp5v = c45 + corr5v = self.leakyRELU(self.corr(c35, warp5v)) + x = torch.cat((c15, feat5, corr5u, corr5v, flow5, flows[1]),1) + x = torch.cat((self.conv5_0(x), x),1) + x = torch.cat((self.conv5_1(x), x),1) + x = torch.cat((self.conv5_2(x), x),1) + x = torch.cat((self.conv5_3(x), x),1) + x = torch.cat((self.conv5_4(x), x),1) + flow5 = flow5 + self.pred_flow5(x) + + feat4 = self.leakyRELU(self.upfeat4(x)) + flow4 = Upsample(flow5, 2) + warp4u = (flow4*self.scale/self.strides[2]).unsqueeze(1) + warp4u = torch.repeat_interleave(warp4u, 9, 1) + S1, S2, S3, S4, S5 = warp4u.shape + warp4u = warp4u.view(S1, S2*S3, S4, S5) + warp4u = self.deform4(c24, warp4u) + warp4u = self.leakyRELU(warp4u) + corr4u = self.leakyRELU(self.corr(c14, warp4u)) + warp4v = c44 + corr4v = self.leakyRELU(self.corr(c34, warp4v)) + x = torch.cat((c14, feat4, corr4u, corr4v, flow4, flows[2]),1) + x = torch.cat((self.conv4_0(x), x),1) + x = torch.cat((self.conv4_1(x), x),1) + x = torch.cat((self.conv4_2(x), x),1) + x = torch.cat((self.conv4_3(x), x),1) + x = torch.cat((self.conv4_4(x), x),1) + flow4 = flow4 + self.pred_flow4(x) + + feat3 = self.leakyRELU(self.upfeat3(x)) + flow3 = Upsample(flow4, 2) + warp3u = (flow3*self.scale/self.strides[3]).unsqueeze(1) + warp3u = torch.repeat_interleave(warp3u, 9, 1) + S1, S2, S3, S4, S5 = warp3u.shape + warp3u = warp3u.view(S1, S2*S3, S4, S5) + warp3u = 
self.deform3(c23, warp3u) + warp3u = self.leakyRELU(warp3u) + corr3u = self.leakyRELU(self.corr(c13, warp3u)) + warp3v = c43 + corr3v = self.leakyRELU(self.corr(c33, warp3v)) + x = torch.cat((c13, feat3, corr3u, corr3v, flow3, flows[3]),1) + x = torch.cat((self.conv3_0(x), x),1) + x = torch.cat((self.conv3_1(x), x),1) + x = torch.cat((self.conv3_2(x), x),1) + x = torch.cat((self.conv3_3(x), x),1) + x = torch.cat((self.conv3_4(x), x),1) + flow3 = flow3 + self.pred_flow3(x) + + feat2 = self.leakyRELU(self.upfeat2(x)) + flow2 = Upsample(flow3, 2) + warp2u = (flow2*self.scale/self.strides[4]).unsqueeze(1) + warp2u = torch.repeat_interleave(warp2u, 9, 1) + S1, S2, S3, S4, S5 = warp2u.shape + warp2u = warp2u.view(S1, S2*S3, S4, S5) + warp2u = self.deform2(c22, warp2u) + warp2u = self.leakyRELU(warp2u) + corr2u = self.leakyRELU(self.corr(c12, warp2u)) + warp2v = c42 + corr2v = self.leakyRELU(self.corr(c32, warp2v)) + x = torch.cat((c12, feat2, corr2u, corr2v, flow2, flows[4]),1) + x = torch.cat((self.conv2_0(x), x),1) + x = torch.cat((self.conv2_1(x), x),1) + x = torch.cat((self.conv2_2(x), x),1) + x = torch.cat((self.conv2_3(x), x),1) + x = torch.cat((self.conv2_4(x), x),1) + flow2 = flow2 + self.pred_flow2(x) + + x = self.dc_conv4(self.dc_conv3(self.dc_conv2(self.dc_conv1(x)))) + flow2 = flow2 + self.dc_conv7(self.dc_conv6(self.dc_conv5(x))) + + #preds = [flow * self.scale for flow in [flow6, flow5, flow4, flow3, flow2]] + preds = [flow * self.scale for flow in [flow2, flow3, flow4, flow5, flow6]] + visuals = [] + visuals.append(flow2[:,:1]) + + return preds + #return preds, visuals, [] + + +class EpeLoss(nn.Module): + def __init__(self, eps = 0): + super(EpeLoss, self).__init__() + self.eps = eps + + def forward(self, pred, label): + loss = ((pred - label).pow(2).sum(1) + self.eps).sqrt() + return loss.view(loss.shape[0], -1).mean(1) + + +class EpeLossWithMask(nn.Module): + def __init__(self, eps=1e-8, q=None): + super(EpeLossWithMask, self).__init__() + self.eps = eps + self.q = q + + def forward(self, pred, label, mask): + if self.q is not None: + loss = ((pred - label).abs().sum(1) + self.eps) ** self.q + else: + loss = ((pred - label).pow(2).sum(1) + self.eps).sqrt() + loss = loss * mask.squeeze(1) + loss = loss.view(loss.shape[0], -1).sum(1) / mask.view(mask.shape[0], -1).sum(1) + return loss + + +class MultiscaleEpe(nn.Module): + def __init__(self, scales, weights, match, eps = 1e-8, q = None): + super(MultiscaleEpe, self).__init__() + + self.scales = scales + self.weights = weights + self.match = match + self.eps = eps + self.q = q + + def forward(self, flow, mask, *predictions): + losses = 0 + if self.match == 'upsampling': + for p, w, s in zip(predictions, self.weights, self.scales): + losses += EpeLossWithMask(eps=self.eps, q=self.q)(Upsample(p, s), flow, mask) * w + elif self.match == 'downsampling': + for p, w, s in zip(predictions, self.weights, self.scales): + losses += EpeLossWithMask(eps=self.eps, q=self.q)(p, Downsample(flow, s), Downsample(mask, s)) * w + else: + raise NotImplementedError + return losses diff --git a/windflow/networks/models.py b/windflow/networks/models.py new file mode 100644 index 0000000000000000000000000000000000000000..395e30f2b0cdc2c359197cca365c7cb946a65e4d --- /dev/null +++ b/windflow/networks/models.py @@ -0,0 +1,18 @@ +from . 
import warper, RAFT + + +def get_flow_model(model_name, small=True): + # select base model + model_name = model_name.lower() + if model_name == 'raft': + raft_args ={'small': small, + 'lr': 1e-5, + 'mixed_precision': True, + 'dropout': 0.0, + 'corr_levels': 4, + 'corr_radius': 4} + model = RAFT(raft_args) + else: + model = None + + return model diff --git a/windflow/networks/raft/__init__.py b/windflow/networks/raft/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/windflow/networks/raft/corr.py b/windflow/networks/raft/corr.py new file mode 100644 index 0000000000000000000000000000000000000000..ae260b992728197063fbe0f68b87b20459fba077 --- /dev/null +++ b/windflow/networks/raft/corr.py @@ -0,0 +1,91 @@ +import torch +import torch.nn.functional as F +from .utils.utils import bilinear_sampler, coords_grid + +try: + import alt_cuda_corr +except: + # alt_cuda_corr is not compiled + pass + + +class CorrBlock: + def __init__(self, fmap1, fmap2, num_levels=4, radius=4): + self.num_levels = num_levels + self.radius = radius + self.corr_pyramid = [] + + # all pairs correlation + corr = CorrBlock.corr(fmap1, fmap2) + + batch, h1, w1, dim, h2, w2 = corr.shape + corr = corr.reshape(batch*h1*w1, dim, h2, w2) + + self.corr_pyramid.append(corr) + for i in range(self.num_levels-1): + corr = F.avg_pool2d(corr, 2, stride=2) + self.corr_pyramid.append(corr) + + def __call__(self, coords): + r = self.radius + coords = coords.permute(0, 2, 3, 1) + batch, h1, w1, _ = coords.shape + + out_pyramid = [] + for i in range(self.num_levels): + corr = self.corr_pyramid[i].float() + dx = torch.linspace(-r, r, 2*r+1) + dy = torch.linspace(-r, r, 2*r+1) + delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to(coords.device) + + centroid_lvl = coords.reshape(batch*h1*w1, 1, 1, 2) / 2**i + delta_lvl = delta.view(1, 2*r+1, 2*r+1, 2) + coords_lvl = centroid_lvl + delta_lvl + + corr = bilinear_sampler(corr, coords_lvl) + corr = corr.view(batch, h1, w1, -1) + out_pyramid.append(corr) + + out = torch.cat(out_pyramid, dim=-1) + return out.permute(0, 3, 1, 2).contiguous().float() + + @staticmethod + def corr(fmap1, fmap2): + batch, dim, ht, wd = fmap1.shape + fmap1 = fmap1.view(batch, dim, ht*wd) + fmap2 = fmap2.view(batch, dim, ht*wd) + + corr = torch.matmul(fmap1.transpose(1,2), fmap2) + corr = corr.view(batch, ht, wd, 1, ht, wd) + return corr / torch.sqrt(torch.tensor(dim).float()) + + +class AlternateCorrBlock: + def __init__(self, fmap1, fmap2, num_levels=4, radius=4): + self.num_levels = num_levels + self.radius = radius + + self.pyramid = [(fmap1, fmap2)] + for i in range(self.num_levels): + fmap1 = F.avg_pool2d(fmap1, 2, stride=2) + fmap2 = F.avg_pool2d(fmap2, 2, stride=2) + self.pyramid.append((fmap1, fmap2)) + + def __call__(self, coords): + coords = coords.permute(0, 2, 3, 1) + B, H, W, _ = coords.shape + dim = self.pyramid[0][0].shape[1] + + corr_list = [] + for i in range(self.num_levels): + r = self.radius + fmap1_i = self.pyramid[0][0].permute(0, 2, 3, 1).contiguous() + fmap2_i = self.pyramid[i][1].permute(0, 2, 3, 1).contiguous() + + coords_i = (coords / 2**i).reshape(B, 1, H, W, 2).contiguous() + corr, = alt_cuda_corr.forward(fmap1_i, fmap2_i, coords_i, r) + corr_list.append(corr.squeeze(1)) + + corr = torch.stack(corr_list, dim=1) + corr = corr.reshape(B, -1, H, W) + return corr / torch.sqrt(torch.tensor(dim).float()) diff --git a/windflow/networks/raft/datasets.py b/windflow/networks/raft/datasets.py new file mode 100644 index 
0000000000000000000000000000000000000000..3411fdacfb900024005e8997d07c600e963a95ca --- /dev/null +++ b/windflow/networks/raft/datasets.py @@ -0,0 +1,235 @@ +# Data loading based on https://github.com/NVIDIA/flownet2-pytorch + +import numpy as np +import torch +import torch.utils.data as data +import torch.nn.functional as F + +import os +import math +import random +from glob import glob +import os.path as osp + +from utils import frame_utils +from utils.augmentor import FlowAugmentor, SparseFlowAugmentor + + +class FlowDataset(data.Dataset): + def __init__(self, aug_params=None, sparse=False): + self.augmentor = None + self.sparse = sparse + if aug_params is not None: + if sparse: + self.augmentor = SparseFlowAugmentor(**aug_params) + else: + self.augmentor = FlowAugmentor(**aug_params) + + self.is_test = False + self.init_seed = False + self.flow_list = [] + self.image_list = [] + self.extra_info = [] + + def __getitem__(self, index): + + if self.is_test: + img1 = frame_utils.read_gen(self.image_list[index][0]) + img2 = frame_utils.read_gen(self.image_list[index][1]) + img1 = np.array(img1).astype(np.uint8)[..., :3] + img2 = np.array(img2).astype(np.uint8)[..., :3] + img1 = torch.from_numpy(img1).permute(2, 0, 1).float() + img2 = torch.from_numpy(img2).permute(2, 0, 1).float() + return img1, img2, self.extra_info[index] + + if not self.init_seed: + worker_info = torch.utils.data.get_worker_info() + if worker_info is not None: + torch.manual_seed(worker_info.id) + np.random.seed(worker_info.id) + random.seed(worker_info.id) + self.init_seed = True + + index = index % len(self.image_list) + valid = None + if self.sparse: + flow, valid = frame_utils.readFlowKITTI(self.flow_list[index]) + else: + flow = frame_utils.read_gen(self.flow_list[index]) + + img1 = frame_utils.read_gen(self.image_list[index][0]) + img2 = frame_utils.read_gen(self.image_list[index][1]) + + flow = np.array(flow).astype(np.float32) + img1 = np.array(img1).astype(np.uint8) + img2 = np.array(img2).astype(np.uint8) + + # grayscale images + if len(img1.shape) == 2: + img1 = np.tile(img1[...,None], (1, 1, 3)) + img2 = np.tile(img2[...,None], (1, 1, 3)) + else: + img1 = img1[..., :3] + img2 = img2[..., :3] + + if self.augmentor is not None: + if self.sparse: + img1, img2, flow, valid = self.augmentor(img1, img2, flow, valid) + else: + img1, img2, flow = self.augmentor(img1, img2, flow) + + img1 = torch.from_numpy(img1).permute(2, 0, 1).float() + img2 = torch.from_numpy(img2).permute(2, 0, 1).float() + flow = torch.from_numpy(flow).permute(2, 0, 1).float() + + if valid is not None: + valid = torch.from_numpy(valid) + else: + valid = (flow[0].abs() < 1000) & (flow[1].abs() < 1000) + + return img1, img2, flow, valid.float() + + + def __rmul__(self, v): + self.flow_list = v * self.flow_list + self.image_list = v * self.image_list + return self + + def __len__(self): + return len(self.image_list) + + +class MpiSintel(FlowDataset): + def __init__(self, aug_params=None, split='training', root='datasets/Sintel', dstype='clean'): + super(MpiSintel, self).__init__(aug_params) + flow_root = osp.join(root, split, 'flow') + image_root = osp.join(root, split, dstype) + + if split == 'test': + self.is_test = True + + for scene in os.listdir(image_root): + image_list = sorted(glob(osp.join(image_root, scene, '*.png'))) + for i in range(len(image_list)-1): + self.image_list += [ [image_list[i], image_list[i+1]] ] + self.extra_info += [ (scene, i) ] # scene and frame_id + + if split != 'test': + self.flow_list += 
sorted(glob(osp.join(flow_root, scene, '*.flo'))) + + +class FlyingChairs(FlowDataset): + def __init__(self, aug_params=None, split='train', root='datasets/FlyingChairs_release/data'): + super(FlyingChairs, self).__init__(aug_params) + + images = sorted(glob(osp.join(root, '*.ppm'))) + flows = sorted(glob(osp.join(root, '*.flo'))) + assert (len(images)//2 == len(flows)) + + split_list = np.loadtxt('chairs_split.txt', dtype=np.int32) + for i in range(len(flows)): + xid = split_list[i] + if (split=='training' and xid==1) or (split=='validation' and xid==2): + self.flow_list += [ flows[i] ] + self.image_list += [ [images[2*i], images[2*i+1]] ] + + +class FlyingThings3D(FlowDataset): + def __init__(self, aug_params=None, root='datasets/FlyingThings3D', dstype='frames_cleanpass'): + super(FlyingThings3D, self).__init__(aug_params) + + for cam in ['left']: + for direction in ['into_future', 'into_past']: + image_dirs = sorted(glob(osp.join(root, dstype, 'TRAIN/*/*'))) + image_dirs = sorted([osp.join(f, cam) for f in image_dirs]) + + flow_dirs = sorted(glob(osp.join(root, 'optical_flow/TRAIN/*/*'))) + flow_dirs = sorted([osp.join(f, direction, cam) for f in flow_dirs]) + + for idir, fdir in zip(image_dirs, flow_dirs): + images = sorted(glob(osp.join(idir, '*.png')) ) + flows = sorted(glob(osp.join(fdir, '*.pfm')) ) + for i in range(len(flows)-1): + if direction == 'into_future': + self.image_list += [ [images[i], images[i+1]] ] + self.flow_list += [ flows[i] ] + elif direction == 'into_past': + self.image_list += [ [images[i+1], images[i]] ] + self.flow_list += [ flows[i+1] ] + + +class KITTI(FlowDataset): + def __init__(self, aug_params=None, split='training', root='datasets/KITTI'): + super(KITTI, self).__init__(aug_params, sparse=True) + if split == 'testing': + self.is_test = True + + root = osp.join(root, split) + images1 = sorted(glob(osp.join(root, 'image_2/*_10.png'))) + images2 = sorted(glob(osp.join(root, 'image_2/*_11.png'))) + + for img1, img2 in zip(images1, images2): + frame_id = img1.split('/')[-1] + self.extra_info += [ [frame_id] ] + self.image_list += [ [img1, img2] ] + + if split == 'training': + self.flow_list = sorted(glob(osp.join(root, 'flow_occ/*_10.png'))) + + +class HD1K(FlowDataset): + def __init__(self, aug_params=None, root='datasets/HD1k'): + super(HD1K, self).__init__(aug_params, sparse=True) + + seq_ix = 0 + while 1: + flows = sorted(glob(os.path.join(root, 'hd1k_flow_gt', 'flow_occ/%06d_*.png' % seq_ix))) + images = sorted(glob(os.path.join(root, 'hd1k_input', 'image_2/%06d_*.png' % seq_ix))) + + if len(flows) == 0: + break + + for i in range(len(flows)-1): + self.flow_list += [flows[i]] + self.image_list += [ [images[i], images[i+1]] ] + + seq_ix += 1 + + +def fetch_dataloader(args, TRAIN_DS='C+T+K+S+H'): + """ Create the data loader for the corresponding trainign set """ + + if args.stage == 'chairs': + aug_params = {'crop_size': args.image_size, 'min_scale': -0.1, 'max_scale': 1.0, 'do_flip': True} + train_dataset = FlyingChairs(aug_params, split='training') + + elif args.stage == 'things': + aug_params = {'crop_size': args.image_size, 'min_scale': -0.4, 'max_scale': 0.8, 'do_flip': True} + clean_dataset = FlyingThings3D(aug_params, dstype='frames_cleanpass') + final_dataset = FlyingThings3D(aug_params, dstype='frames_finalpass') + train_dataset = clean_dataset + final_dataset + + elif args.stage == 'sintel': + aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.6, 'do_flip': True} + things = FlyingThings3D(aug_params, 
dstype='frames_cleanpass') + sintel_clean = MpiSintel(aug_params, split='training', dstype='clean') + sintel_final = MpiSintel(aug_params, split='training', dstype='final') + + if TRAIN_DS == 'C+T+K+S+H': + kitti = KITTI({'crop_size': args.image_size, 'min_scale': -0.3, 'max_scale': 0.5, 'do_flip': True}) + hd1k = HD1K({'crop_size': args.image_size, 'min_scale': -0.5, 'max_scale': 0.2, 'do_flip': True}) + train_dataset = 100*sintel_clean + 100*sintel_final + 200*kitti + 5*hd1k + things + + elif TRAIN_DS == 'C+T+K/S': + train_dataset = 100*sintel_clean + 100*sintel_final + things + + elif args.stage == 'kitti': + aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.4, 'do_flip': False} + train_dataset = KITTI(aug_params, split='training') + + train_loader = data.DataLoader(train_dataset, batch_size=args.batch_size, + pin_memory=False, shuffle=True, num_workers=4, drop_last=True) + + print('Training with %d image pairs' % len(train_dataset)) + return train_loader + diff --git a/windflow/networks/raft/extractor.py b/windflow/networks/raft/extractor.py new file mode 100644 index 0000000000000000000000000000000000000000..312f2b473fcbcf60722157add2db15deed3e5a94 --- /dev/null +++ b/windflow/networks/raft/extractor.py @@ -0,0 +1,267 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class ResidualBlock(nn.Module): + def __init__(self, in_planes, planes, norm_fn='group', stride=1): + super(ResidualBlock, self).__init__() + + self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, stride=stride) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1) + self.relu = nn.ReLU(inplace=True) + + num_groups = planes // 8 + + if norm_fn == 'group': + self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + if not stride == 1: + self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + + elif norm_fn == 'batch': + self.norm1 = nn.BatchNorm2d(planes) + self.norm2 = nn.BatchNorm2d(planes) + if not stride == 1: + self.norm3 = nn.BatchNorm2d(planes) + + elif norm_fn == 'instance': + self.norm1 = nn.InstanceNorm2d(planes) + self.norm2 = nn.InstanceNorm2d(planes) + if not stride == 1: + self.norm3 = nn.InstanceNorm2d(planes) + + elif norm_fn == 'none': + self.norm1 = nn.Sequential() + self.norm2 = nn.Sequential() + if not stride == 1: + self.norm3 = nn.Sequential() + + if stride == 1: + self.downsample = None + + else: + self.downsample = nn.Sequential( + nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3) + + + def forward(self, x): + y = x + y = self.relu(self.norm1(self.conv1(y))) + y = self.relu(self.norm2(self.conv2(y))) + + if self.downsample is not None: + x = self.downsample(x) + + return self.relu(x+y) + + + +class BottleneckBlock(nn.Module): + def __init__(self, in_planes, planes, norm_fn='group', stride=1): + super(BottleneckBlock, self).__init__() + + self.conv1 = nn.Conv2d(in_planes, planes//4, kernel_size=1, padding=0) + self.conv2 = nn.Conv2d(planes//4, planes//4, kernel_size=3, padding=1, stride=stride) + self.conv3 = nn.Conv2d(planes//4, planes, kernel_size=1, padding=0) + self.relu = nn.ReLU(inplace=True) + + num_groups = planes // 8 + + if norm_fn == 'group': + self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4) + self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4) + self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + if not stride == 1: + 
self.norm4 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + + elif norm_fn == 'batch': + self.norm1 = nn.BatchNorm2d(planes//4) + self.norm2 = nn.BatchNorm2d(planes//4) + self.norm3 = nn.BatchNorm2d(planes) + if not stride == 1: + self.norm4 = nn.BatchNorm2d(planes) + + elif norm_fn == 'instance': + self.norm1 = nn.InstanceNorm2d(planes//4) + self.norm2 = nn.InstanceNorm2d(planes//4) + self.norm3 = nn.InstanceNorm2d(planes) + if not stride == 1: + self.norm4 = nn.InstanceNorm2d(planes) + + elif norm_fn == 'none': + self.norm1 = nn.Sequential() + self.norm2 = nn.Sequential() + self.norm3 = nn.Sequential() + if not stride == 1: + self.norm4 = nn.Sequential() + + if stride == 1: + self.downsample = None + + else: + self.downsample = nn.Sequential( + nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm4) + + + def forward(self, x): + y = x + y = self.relu(self.norm1(self.conv1(y))) + y = self.relu(self.norm2(self.conv2(y))) + y = self.relu(self.norm3(self.conv3(y))) + + if self.downsample is not None: + x = self.downsample(x) + + return self.relu(x+y) + +class BasicEncoder(nn.Module): + def __init__(self, input_dim=1, output_dim=128, norm_fn='batch', dropout=0.0): + super(BasicEncoder, self).__init__() + self.norm_fn = norm_fn + + if self.norm_fn == 'group': + self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64) + + elif self.norm_fn == 'batch': + self.norm1 = nn.BatchNorm2d(64) + + elif self.norm_fn == 'instance': + self.norm1 = nn.InstanceNorm2d(64) + + elif self.norm_fn == 'none': + self.norm1 = nn.Sequential() + + self.conv1 = nn.Conv2d(input_dim, 64, kernel_size=7, stride=2, padding=3) + self.relu1 = nn.ReLU(inplace=True) + + self.in_planes = 64 + self.layer1 = self._make_layer(64, stride=1) + self.layer2 = self._make_layer(96, stride=2) + self.layer3 = self._make_layer(128, stride=2) + + # output convolution + self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1) + + self.dropout = None + if dropout > 0: + self.dropout = nn.Dropout2d(p=dropout) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): + if m.weight is not None: + nn.init.constant_(m.weight, 1) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def _make_layer(self, dim, stride=1): + layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride) + layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1) + layers = (layer1, layer2) + + self.in_planes = dim + return nn.Sequential(*layers) + + + def forward(self, x): + + # if input is list, combine batch dimension + is_list = isinstance(x, tuple) or isinstance(x, list) + if is_list: + batch_dim = x[0].shape[0] + x = torch.cat(x, dim=0) + + x = self.conv1(x) + x = self.norm1(x) + x = self.relu1(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + + x = self.conv2(x) + + if self.training and self.dropout is not None: + x = self.dropout(x) + + if is_list: + x = torch.split(x, [batch_dim, batch_dim], dim=0) + + return x + + +class SmallEncoder(nn.Module): + def __init__(self, input_dim=1, output_dim=128, norm_fn='batch', dropout=0.0): + super(SmallEncoder, self).__init__() + self.norm_fn = norm_fn + + if self.norm_fn == 'group': + self.norm1 = nn.GroupNorm(num_groups=8, num_channels=32) + + elif self.norm_fn == 'batch': + self.norm1 = nn.BatchNorm2d(32) + + elif self.norm_fn == 'instance': + self.norm1 = nn.InstanceNorm2d(32) + + elif self.norm_fn == 'none': + 
self.norm1 = nn.Sequential() + + self.conv1 = nn.Conv2d(input_dim, 32, kernel_size=7, stride=2, padding=3) + self.relu1 = nn.ReLU(inplace=True) + + self.in_planes = 32 + self.layer1 = self._make_layer(32, stride=1) + self.layer2 = self._make_layer(64, stride=2) + self.layer3 = self._make_layer(96, stride=2) + + self.dropout = None + if dropout > 0: + self.dropout = nn.Dropout2d(p=dropout) + + self.conv2 = nn.Conv2d(96, output_dim, kernel_size=1) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): + if m.weight is not None: + nn.init.constant_(m.weight, 1) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def _make_layer(self, dim, stride=1): + layer1 = BottleneckBlock(self.in_planes, dim, self.norm_fn, stride=stride) + layer2 = BottleneckBlock(dim, dim, self.norm_fn, stride=1) + layers = (layer1, layer2) + + self.in_planes = dim + return nn.Sequential(*layers) + + + def forward(self, x): + + # if input is list, combine batch dimension + is_list = isinstance(x, tuple) or isinstance(x, list) + if is_list: + batch_dim = x[0].shape[0] + x = torch.cat(x, dim=0) + + x = self.conv1(x) + x = self.norm1(x) + x = self.relu1(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.conv2(x) + + if self.training and self.dropout is not None: + x = self.dropout(x) + + if is_list: + x = torch.split(x, [batch_dim, batch_dim], dim=0) + + return x diff --git a/windflow/networks/raft/raft.py b/windflow/networks/raft/raft.py new file mode 100644 index 0000000000000000000000000000000000000000..1f7973d246324c2aa92f740d63a14d9c7b0da091 --- /dev/null +++ b/windflow/networks/raft/raft.py @@ -0,0 +1,143 @@ +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .update import BasicUpdateBlock, SmallUpdateBlock +from .extractor import BasicEncoder, SmallEncoder +from .corr import CorrBlock, AlternateCorrBlock +from .utils.utils import bilinear_sampler, coords_grid, upflow8 + +try: + autocast = torch.cuda.amp.autocast +except: + # dummy autocast for PyTorch < 1.6 + class autocast: + def __init__(self, enabled): + pass + def __enter__(self): + pass + def __exit__(self, *args): + pass + + +class RAFT(nn.Module): + def __init__(self, args): + super(RAFT, self).__init__() + self.args = args + + if args['small']: + self.hidden_dim = hdim = 96 + self.context_dim = cdim = 64 + args['corr_levels'] = 4 + args['corr_radius'] = 3 + + else: + self.hidden_dim = hdim = 128 + self.context_dim = cdim = 128 + args['corr_levels'] = 4 + args['corr_radius'] = 4 + + if 'dropout' not in self.args: + self.args['dropout'] = 0. 
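The feature and context networks instantiated a few lines below are the encoders defined in extractor.py above: they downsample by a factor of 8 and, when given an (image1, image2) pair, stack it along the batch dimension and split the result back into two feature maps. A quick shape sketch, assuming SmallEncoder is in scope (tensor sizes are arbitrary):

import torch

fnet = SmallEncoder(input_dim=1, output_dim=128, norm_fn='instance')
img1 = torch.randn(2, 1, 64, 64)
img2 = torch.randn(2, 1, 64, 64)
f1, f2 = fnet([img1, img2])       # pair in, pair of feature maps out
print(f1.shape)                   # torch.Size([2, 128, 8, 8]), 1/8 resolution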
+ + if 'alternate_corr' not in self.args: + self.args['alternate_corr'] = False + + # feature network, context network, and update block + if args['small']: + self.fnet = SmallEncoder(input_dim=1, output_dim=128, norm_fn='instance', dropout=args['dropout']) + self.cnet = SmallEncoder(input_dim=1, output_dim=hdim+cdim, norm_fn='none', dropout=args['dropout']) + self.update_block = SmallUpdateBlock(self.args, hidden_dim=hdim) + + else: + self.fnet = BasicEncoder(input_dim=1, output_dim=256, norm_fn='instance', dropout=args['dropout']) + self.cnet = BasicEncoder(input_dim=1, output_dim=hdim+cdim, norm_fn='batch', dropout=args['dropout']) + self.update_block = BasicUpdateBlock(self.args, hidden_dim=hdim) + + def freeze_bn(self): + for m in self.modules(): + if isinstance(m, nn.BatchNorm2d): + m.eval() + + def initialize_flow(self, img): + """ Flow is represented as difference between two coordinate grids flow = coords1 - coords0""" + N, C, H, W = img.shape + coords0 = coords_grid(N, H//8, W//8).to(img.device) + coords1 = coords_grid(N, H//8, W//8).to(img.device) + + # optical flow computed as difference: flow = coords1 - coords0 + return coords0, coords1 + + def upsample_flow(self, flow, mask): + """ Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination """ + N, _, H, W = flow.shape + mask = mask.view(N, 1, 9, 8, 8, H, W) + mask = torch.softmax(mask, dim=2) + + up_flow = F.unfold(8 * flow, [3,3], padding=1) + up_flow = up_flow.view(N, 2, 9, 1, 1, H, W) + + up_flow = torch.sum(mask * up_flow, dim=2) + up_flow = up_flow.permute(0, 1, 4, 2, 5, 3) + return up_flow.reshape(N, 2, 8*H, 8*W) + + + def forward(self, image1, image2, iters=12, flow_init=None, upsample=True, test_mode=False): + """ Estimate optical flow between pair of frames """ + + #image1 = 2 * (image1 / 255.0) - 1.0 + #image2 = 2 * (image2 / 255.0) - 1.0 + + #image1 = image1.contiguous() + #image2 = image2.contiguous() + + hdim = self.hidden_dim + cdim = self.context_dim + + # run the feature network + with autocast(enabled=self.args['mixed_precision']): + fmap1, fmap2 = self.fnet([image1, image2]) + + fmap1 = fmap1.float() + fmap2 = fmap2.float() + if self.args['alternate_corr']: + corr_fn = AlternateCorrBlock(fmap1, fmap2, radius=self.args['corr_radius']) + else: + corr_fn = CorrBlock(fmap1, fmap2, radius=self.args['corr_radius']) + + # run the context network + with autocast(enabled=self.args['mixed_precision']): + cnet = self.cnet(image1) + net, inp = torch.split(cnet, [hdim, cdim], dim=1) + net = torch.tanh(net) + inp = torch.relu(inp) + + coords0, coords1 = self.initialize_flow(image1) + if flow_init is not None: + coords1 = coords1 + flow_init + + flow_predictions = [] + for itr in range(iters): + coords1 = coords1.detach() + corr = corr_fn(coords1.float()) # index correlation volume + + flow = coords1 - coords0 + with autocast(enabled=self.args['mixed_precision']): + net, up_mask, delta_flow = self.update_block(net, inp, corr, flow) + + # F(t+1) = F(t) + \Delta(t) + coords1 = coords1 + delta_flow + + # upsample predictions + if up_mask is None: + flow_up = upflow8(coords1 - coords0) + else: + flow_up = self.upsample_flow(coords1 - coords0, up_mask) + + flow_predictions.append(flow_up) + + if test_mode: + return flow_up, + + return flow_predictions diff --git a/windflow/networks/raft/update.py b/windflow/networks/raft/update.py new file mode 100644 index 0000000000000000000000000000000000000000..68c92a4922bf0e45f16e55a5deb15c0271cb2fee --- /dev/null +++ b/windflow/networks/raft/update.py @@ -0,0 +1,139 @@ 
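Taken together, the RAFT module above consumes a pair of single-channel frames (both encoders are built with input_dim=1), refines the flow iteratively at 1/8 resolution, and returns the list of upsampled flow estimates from each iteration, or just the final one in test mode. A usage sketch; it assumes the windflow package laid out in this diff is importable, the spatial size is divisible by 8, and mixed precision is turned off so the snippet also runs on CPU:

import torch
from windflow.networks.raft.raft import RAFT

args = {'small': True, 'mixed_precision': False, 'dropout': 0.0,
        'corr_levels': 4, 'corr_radius': 3, 'alternate_corr': False}
model = RAFT(args).eval()

img1 = torch.randn(1, 1, 128, 128)   # single-channel frames
img2 = torch.randn(1, 1, 128, 128)
with torch.no_grad():
    flow_up, = model(img1, img2, iters=12, test_mode=True)
print(flow_up.shape)                 # torch.Size([1, 2, 128, 128])

get_flow_model('raft') in models.py wraps the same constructor, with mixed precision enabled by default.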
+import torch +import torch.nn as nn +import torch.nn.functional as F + + +class FlowHead(nn.Module): + def __init__(self, input_dim=128, hidden_dim=256): + super(FlowHead, self).__init__() + self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1) + self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1) + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + return self.conv2(self.relu(self.conv1(x))) + +class ConvGRU(nn.Module): + def __init__(self, hidden_dim=128, input_dim=192+128): + super(ConvGRU, self).__init__() + self.convz = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1) + self.convr = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1) + self.convq = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1) + + def forward(self, h, x): + hx = torch.cat([h, x], dim=1) + + z = torch.sigmoid(self.convz(hx)) + r = torch.sigmoid(self.convr(hx)) + q = torch.tanh(self.convq(torch.cat([r*h, x], dim=1))) + + h = (1-z) * h + z * q + return h + +class SepConvGRU(nn.Module): + def __init__(self, hidden_dim=128, input_dim=192+128): + super(SepConvGRU, self).__init__() + self.convz1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2)) + self.convr1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2)) + self.convq1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2)) + + self.convz2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0)) + self.convr2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0)) + self.convq2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0)) + + + def forward(self, h, x): + # horizontal + hx = torch.cat([h, x], dim=1) + z = torch.sigmoid(self.convz1(hx)) + r = torch.sigmoid(self.convr1(hx)) + q = torch.tanh(self.convq1(torch.cat([r*h, x], dim=1))) + h = (1-z) * h + z * q + + # vertical + hx = torch.cat([h, x], dim=1) + z = torch.sigmoid(self.convz2(hx)) + r = torch.sigmoid(self.convr2(hx)) + q = torch.tanh(self.convq2(torch.cat([r*h, x], dim=1))) + h = (1-z) * h + z * q + + return h + +class SmallMotionEncoder(nn.Module): + def __init__(self, args): + super(SmallMotionEncoder, self).__init__() + cor_planes = args['corr_levels'] * (2*args['corr_radius'] + 1)**2 + self.convc1 = nn.Conv2d(cor_planes, 96, 1, padding=0) + self.convf1 = nn.Conv2d(2, 64, 7, padding=3) + self.convf2 = nn.Conv2d(64, 32, 3, padding=1) + self.conv = nn.Conv2d(128, 80, 3, padding=1) + + def forward(self, flow, corr): + cor = F.relu(self.convc1(corr)) + flo = F.relu(self.convf1(flow)) + flo = F.relu(self.convf2(flo)) + cor_flo = torch.cat([cor, flo], dim=1) + out = F.relu(self.conv(cor_flo)) + return torch.cat([out, flow], dim=1) + +class BasicMotionEncoder(nn.Module): + def __init__(self, args): + super(BasicMotionEncoder, self).__init__() + cor_planes = args['corr_levels'] * (2*args['corr_radius'] + 1)**2 + self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0) + self.convc2 = nn.Conv2d(256, 192, 3, padding=1) + self.convf1 = nn.Conv2d(2, 128, 7, padding=3) + self.convf2 = nn.Conv2d(128, 64, 3, padding=1) + self.conv = nn.Conv2d(64+192, 128-2, 3, padding=1) + + def forward(self, flow, corr): + cor = F.relu(self.convc1(corr)) + cor = F.relu(self.convc2(cor)) + flo = F.relu(self.convf1(flow)) + flo = F.relu(self.convf2(flo)) + + cor_flo = torch.cat([cor, flo], dim=1) + out = F.relu(self.conv(cor_flo)) + return torch.cat([out, flow], dim=1) + +class SmallUpdateBlock(nn.Module): + def __init__(self, args, hidden_dim=96): + super(SmallUpdateBlock, self).__init__() + self.encoder = 
SmallMotionEncoder(args) + self.gru = ConvGRU(hidden_dim=hidden_dim, input_dim=82+64) + self.flow_head = FlowHead(hidden_dim, hidden_dim=128) + + def forward(self, net, inp, corr, flow): + motion_features = self.encoder(flow, corr) + inp = torch.cat([inp, motion_features], dim=1) + net = self.gru(net, inp) + delta_flow = self.flow_head(net) + + return net, None, delta_flow + +class BasicUpdateBlock(nn.Module): + def __init__(self, args, hidden_dim=128, input_dim=128): + super(BasicUpdateBlock, self).__init__() + self.args = args + self.encoder = BasicMotionEncoder(args) + self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128+hidden_dim) + self.flow_head = FlowHead(hidden_dim, hidden_dim=256) + + self.mask = nn.Sequential( + nn.Conv2d(128, 256, 3, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(256, 64*9, 1, padding=0)) + + def forward(self, net, inp, corr, flow, upsample=True): + motion_features = self.encoder(flow, corr) + inp = torch.cat([inp, motion_features], dim=1) + + net = self.gru(net, inp) + delta_flow = self.flow_head(net) + + # scale mask to balence gradients + mask = .25 * self.mask(net) + return net, mask, delta_flow + + + diff --git a/windflow/networks/raft/utils/__init__.py b/windflow/networks/raft/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/windflow/networks/raft/utils/augmentor.py b/windflow/networks/raft/utils/augmentor.py new file mode 100644 index 0000000000000000000000000000000000000000..e81c4f2b5c16c31c0ae236d744f299d430228a04 --- /dev/null +++ b/windflow/networks/raft/utils/augmentor.py @@ -0,0 +1,246 @@ +import numpy as np +import random +import math +from PIL import Image + +import cv2 +cv2.setNumThreads(0) +cv2.ocl.setUseOpenCL(False) + +import torch +from torchvision.transforms import ColorJitter +import torch.nn.functional as F + + +class FlowAugmentor: + def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=True): + + # spatial augmentation params + self.crop_size = crop_size + self.min_scale = min_scale + self.max_scale = max_scale + self.spatial_aug_prob = 0.8 + self.stretch_prob = 0.8 + self.max_stretch = 0.2 + + # flip augmentation params + self.do_flip = do_flip + self.h_flip_prob = 0.5 + self.v_flip_prob = 0.1 + + # photometric augmentation params + self.photo_aug = ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.5/3.14) + self.asymmetric_color_aug_prob = 0.2 + self.eraser_aug_prob = 0.5 + + def color_transform(self, img1, img2): + """ Photometric augmentation """ + + # asymmetric + if np.random.rand() < self.asymmetric_color_aug_prob: + img1 = np.array(self.photo_aug(Image.fromarray(img1)), dtype=np.uint8) + img2 = np.array(self.photo_aug(Image.fromarray(img2)), dtype=np.uint8) + + # symmetric + else: + image_stack = np.concatenate([img1, img2], axis=0) + image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8) + img1, img2 = np.split(image_stack, 2, axis=0) + + return img1, img2 + + def eraser_transform(self, img1, img2, bounds=[50, 100]): + """ Occlusion augmentation """ + + ht, wd = img1.shape[:2] + if np.random.rand() < self.eraser_aug_prob: + mean_color = np.mean(img2.reshape(-1, 3), axis=0) + for _ in range(np.random.randint(1, 3)): + x0 = np.random.randint(0, wd) + y0 = np.random.randint(0, ht) + dx = np.random.randint(bounds[0], bounds[1]) + dy = np.random.randint(bounds[0], bounds[1]) + img2[y0:y0+dy, x0:x0+dx, :] = mean_color + + return img1, img2 + + def spatial_transform(self, img1, 
img2, flow): + # randomly sample scale + ht, wd = img1.shape[:2] + min_scale = np.maximum( + (self.crop_size[0] + 8) / float(ht), + (self.crop_size[1] + 8) / float(wd)) + + scale = 2 ** np.random.uniform(self.min_scale, self.max_scale) + scale_x = scale + scale_y = scale + if np.random.rand() < self.stretch_prob: + scale_x *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch) + scale_y *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch) + + scale_x = np.clip(scale_x, min_scale, None) + scale_y = np.clip(scale_y, min_scale, None) + + if np.random.rand() < self.spatial_aug_prob: + # rescale the images + img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) + img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) + flow = cv2.resize(flow, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) + flow = flow * [scale_x, scale_y] + + if self.do_flip: + if np.random.rand() < self.h_flip_prob: # h-flip + img1 = img1[:, ::-1] + img2 = img2[:, ::-1] + flow = flow[:, ::-1] * [-1.0, 1.0] + + if np.random.rand() < self.v_flip_prob: # v-flip + img1 = img1[::-1, :] + img2 = img2[::-1, :] + flow = flow[::-1, :] * [1.0, -1.0] + + y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0]) + x0 = np.random.randint(0, img1.shape[1] - self.crop_size[1]) + + img1 = img1[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] + img2 = img2[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] + flow = flow[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] + + return img1, img2, flow + + def __call__(self, img1, img2, flow): + img1, img2 = self.color_transform(img1, img2) + img1, img2 = self.eraser_transform(img1, img2) + img1, img2, flow = self.spatial_transform(img1, img2, flow) + + img1 = np.ascontiguousarray(img1) + img2 = np.ascontiguousarray(img2) + flow = np.ascontiguousarray(flow) + + return img1, img2, flow + +class SparseFlowAugmentor: + def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=False): + # spatial augmentation params + self.crop_size = crop_size + self.min_scale = min_scale + self.max_scale = max_scale + self.spatial_aug_prob = 0.8 + self.stretch_prob = 0.8 + self.max_stretch = 0.2 + + # flip augmentation params + self.do_flip = do_flip + self.h_flip_prob = 0.5 + self.v_flip_prob = 0.1 + + # photometric augmentation params + self.photo_aug = ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.3/3.14) + self.asymmetric_color_aug_prob = 0.2 + self.eraser_aug_prob = 0.5 + + def color_transform(self, img1, img2): + image_stack = np.concatenate([img1, img2], axis=0) + image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8) + img1, img2 = np.split(image_stack, 2, axis=0) + return img1, img2 + + def eraser_transform(self, img1, img2): + ht, wd = img1.shape[:2] + if np.random.rand() < self.eraser_aug_prob: + mean_color = np.mean(img2.reshape(-1, 3), axis=0) + for _ in range(np.random.randint(1, 3)): + x0 = np.random.randint(0, wd) + y0 = np.random.randint(0, ht) + dx = np.random.randint(50, 100) + dy = np.random.randint(50, 100) + img2[y0:y0+dy, x0:x0+dx, :] = mean_color + + return img1, img2 + + def resize_sparse_flow_map(self, flow, valid, fx=1.0, fy=1.0): + ht, wd = flow.shape[:2] + coords = np.meshgrid(np.arange(wd), np.arange(ht)) + coords = np.stack(coords, axis=-1) + + coords = coords.reshape(-1, 2).astype(np.float32) + flow = flow.reshape(-1, 2).astype(np.float32) + valid = valid.reshape(-1).astype(np.float32) + + coords0 = coords[valid>=1] + flow0 = 
flow[valid>=1] + + ht1 = int(round(ht * fy)) + wd1 = int(round(wd * fx)) + + coords1 = coords0 * [fx, fy] + flow1 = flow0 * [fx, fy] + + xx = np.round(coords1[:,0]).astype(np.int32) + yy = np.round(coords1[:,1]).astype(np.int32) + + v = (xx > 0) & (xx < wd1) & (yy > 0) & (yy < ht1) + xx = xx[v] + yy = yy[v] + flow1 = flow1[v] + + flow_img = np.zeros([ht1, wd1, 2], dtype=np.float32) + valid_img = np.zeros([ht1, wd1], dtype=np.int32) + + flow_img[yy, xx] = flow1 + valid_img[yy, xx] = 1 + + return flow_img, valid_img + + def spatial_transform(self, img1, img2, flow, valid): + # randomly sample scale + + ht, wd = img1.shape[:2] + min_scale = np.maximum( + (self.crop_size[0] + 1) / float(ht), + (self.crop_size[1] + 1) / float(wd)) + + scale = 2 ** np.random.uniform(self.min_scale, self.max_scale) + scale_x = np.clip(scale, min_scale, None) + scale_y = np.clip(scale, min_scale, None) + + if np.random.rand() < self.spatial_aug_prob: + # rescale the images + img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) + img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) + flow, valid = self.resize_sparse_flow_map(flow, valid, fx=scale_x, fy=scale_y) + + if self.do_flip: + if np.random.rand() < 0.5: # h-flip + img1 = img1[:, ::-1] + img2 = img2[:, ::-1] + flow = flow[:, ::-1] * [-1.0, 1.0] + valid = valid[:, ::-1] + + margin_y = 20 + margin_x = 50 + + y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0] + margin_y) + x0 = np.random.randint(-margin_x, img1.shape[1] - self.crop_size[1] + margin_x) + + y0 = np.clip(y0, 0, img1.shape[0] - self.crop_size[0]) + x0 = np.clip(x0, 0, img1.shape[1] - self.crop_size[1]) + + img1 = img1[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] + img2 = img2[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] + flow = flow[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] + valid = valid[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] + return img1, img2, flow, valid + + + def __call__(self, img1, img2, flow, valid): + img1, img2 = self.color_transform(img1, img2) + img1, img2 = self.eraser_transform(img1, img2) + img1, img2, flow, valid = self.spatial_transform(img1, img2, flow, valid) + + img1 = np.ascontiguousarray(img1) + img2 = np.ascontiguousarray(img2) + flow = np.ascontiguousarray(flow) + valid = np.ascontiguousarray(valid) + + return img1, img2, flow, valid diff --git a/windflow/networks/raft/utils/flow_viz.py b/windflow/networks/raft/utils/flow_viz.py new file mode 100644 index 0000000000000000000000000000000000000000..dcee65e89b91b07ee0496aeb4c7e7436abf99641 --- /dev/null +++ b/windflow/networks/raft/utils/flow_viz.py @@ -0,0 +1,132 @@ +# Flow visualization code used from https://github.com/tomrunia/OpticalFlow_Visualization + + +# MIT License +# +# Copyright (c) 2018 Tom Runia +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to conditions. +# +# Author: Tom Runia +# Date Created: 2018-08-03 + +import numpy as np + +def make_colorwheel(): + """ + Generates a color wheel for optical flow visualization as presented in: + Baker et al. 
"A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007) + URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf + + Code follows the original C++ source code of Daniel Scharstein. + Code follows the the Matlab source code of Deqing Sun. + + Returns: + np.ndarray: Color wheel + """ + + RY = 15 + YG = 6 + GC = 4 + CB = 11 + BM = 13 + MR = 6 + + ncols = RY + YG + GC + CB + BM + MR + colorwheel = np.zeros((ncols, 3)) + col = 0 + + # RY + colorwheel[0:RY, 0] = 255 + colorwheel[0:RY, 1] = np.floor(255*np.arange(0,RY)/RY) + col = col+RY + # YG + colorwheel[col:col+YG, 0] = 255 - np.floor(255*np.arange(0,YG)/YG) + colorwheel[col:col+YG, 1] = 255 + col = col+YG + # GC + colorwheel[col:col+GC, 1] = 255 + colorwheel[col:col+GC, 2] = np.floor(255*np.arange(0,GC)/GC) + col = col+GC + # CB + colorwheel[col:col+CB, 1] = 255 - np.floor(255*np.arange(CB)/CB) + colorwheel[col:col+CB, 2] = 255 + col = col+CB + # BM + colorwheel[col:col+BM, 2] = 255 + colorwheel[col:col+BM, 0] = np.floor(255*np.arange(0,BM)/BM) + col = col+BM + # MR + colorwheel[col:col+MR, 2] = 255 - np.floor(255*np.arange(MR)/MR) + colorwheel[col:col+MR, 0] = 255 + return colorwheel + + +def flow_uv_to_colors(u, v, convert_to_bgr=False): + """ + Applies the flow color wheel to (possibly clipped) flow components u and v. + + According to the C++ source code of Daniel Scharstein + According to the Matlab source code of Deqing Sun + + Args: + u (np.ndarray): Input horizontal flow of shape [H,W] + v (np.ndarray): Input vertical flow of shape [H,W] + convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False. + + Returns: + np.ndarray: Flow visualization image of shape [H,W,3] + """ + flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8) + colorwheel = make_colorwheel() # shape [55x3] + ncols = colorwheel.shape[0] + rad = np.sqrt(np.square(u) + np.square(v)) + a = np.arctan2(-v, -u)/np.pi + fk = (a+1) / 2*(ncols-1) + k0 = np.floor(fk).astype(np.int32) + k1 = k0 + 1 + k1[k1 == ncols] = 0 + f = fk - k0 + for i in range(colorwheel.shape[1]): + tmp = colorwheel[:,i] + col0 = tmp[k0] / 255.0 + col1 = tmp[k1] / 255.0 + col = (1-f)*col0 + f*col1 + idx = (rad <= 1) + col[idx] = 1 - rad[idx] * (1-col[idx]) + col[~idx] = col[~idx] * 0.75 # out of range + # Note the 2-i => BGR instead of RGB + ch_idx = 2-i if convert_to_bgr else i + flow_image[:,:,ch_idx] = np.floor(255 * col) + return flow_image + + +def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False): + """ + Expects a two dimensional flow image of shape. + + Args: + flow_uv (np.ndarray): Flow UV image of shape [H,W,2] + clip_flow (float, optional): Clip maximum of flow values. Defaults to None. + convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False. 
+ + Returns: + np.ndarray: Flow visualization image of shape [H,W,3] + """ + assert flow_uv.ndim == 3, 'input flow must have three dimensions' + assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]' + if clip_flow is not None: + flow_uv = np.clip(flow_uv, 0, clip_flow) + u = flow_uv[:,:,0] + v = flow_uv[:,:,1] + rad = np.sqrt(np.square(u) + np.square(v)) + rad_max = np.max(rad) + epsilon = 1e-5 + u = u / (rad_max + epsilon) + v = v / (rad_max + epsilon) + return flow_uv_to_colors(u, v, convert_to_bgr) \ No newline at end of file diff --git a/windflow/networks/raft/utils/frame_utils.py b/windflow/networks/raft/utils/frame_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..6c491135efaffc25bd61ec3ecde99d236f5deb12 --- /dev/null +++ b/windflow/networks/raft/utils/frame_utils.py @@ -0,0 +1,137 @@ +import numpy as np +from PIL import Image +from os.path import * +import re + +import cv2 +cv2.setNumThreads(0) +cv2.ocl.setUseOpenCL(False) + +TAG_CHAR = np.array([202021.25], np.float32) + +def readFlow(fn): + """ Read .flo file in Middlebury format""" + # Code adapted from: + # http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy + + # WARNING: this will work on little-endian architectures (eg Intel x86) only! + # print 'fn = %s'%(fn) + with open(fn, 'rb') as f: + magic = np.fromfile(f, np.float32, count=1) + if 202021.25 != magic: + print('Magic number incorrect. Invalid .flo file') + return None + else: + w = np.fromfile(f, np.int32, count=1) + h = np.fromfile(f, np.int32, count=1) + # print 'Reading %d x %d flo file\n' % (w, h) + data = np.fromfile(f, np.float32, count=2*int(w)*int(h)) + # Reshape data into 3D array (columns, rows, bands) + # The reshape here is for visualization, the original code is (w,h,2) + return np.resize(data, (int(h), int(w), 2)) + +def readPFM(file): + file = open(file, 'rb') + + color = None + width = None + height = None + scale = None + endian = None + + header = file.readline().rstrip() + if header == b'PF': + color = True + elif header == b'Pf': + color = False + else: + raise Exception('Not a PFM file.') + + dim_match = re.match(rb'^(\d+)\s(\d+)\s$', file.readline()) + if dim_match: + width, height = map(int, dim_match.groups()) + else: + raise Exception('Malformed PFM header.') + + scale = float(file.readline().rstrip()) + if scale < 0: # little-endian + endian = '<' + scale = -scale + else: + endian = '>' # big-endian + + data = np.fromfile(file, endian + 'f') + shape = (height, width, 3) if color else (height, width) + + data = np.reshape(data, shape) + data = np.flipud(data) + return data + +def writeFlow(filename,uv,v=None): + """ Write optical flow to file. + + If v is None, uv is assumed to contain both u and v channels, + stacked in depth. + Original code by Deqing Sun, adapted from Daniel Scharstein. 
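+    Example (illustrative): writeFlow('flow.flo', uv) with uv stacking u and v as an (H, W, 2) array.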
+ """ + nBands = 2 + + if v is None: + assert(uv.ndim == 3) + assert(uv.shape[2] == 2) + u = uv[:,:,0] + v = uv[:,:,1] + else: + u = uv + + assert(u.shape == v.shape) + height,width = u.shape + f = open(filename,'wb') + # write the header + f.write(TAG_CHAR) + np.array(width).astype(np.int32).tofile(f) + np.array(height).astype(np.int32).tofile(f) + # arrange into matrix form + tmp = np.zeros((height, width*nBands)) + tmp[:,np.arange(width)*2] = u + tmp[:,np.arange(width)*2 + 1] = v + tmp.astype(np.float32).tofile(f) + f.close() + + +def readFlowKITTI(filename): + flow = cv2.imread(filename, cv2.IMREAD_ANYDEPTH|cv2.IMREAD_COLOR) + flow = flow[:,:,::-1].astype(np.float32) + flow, valid = flow[:, :, :2], flow[:, :, 2] + flow = (flow - 2**15) / 64.0 + return flow, valid + +def readDispKITTI(filename): + disp = cv2.imread(filename, cv2.IMREAD_ANYDEPTH) / 256.0 + valid = disp > 0.0 + flow = np.stack([-disp, np.zeros_like(disp)], -1) + return flow, valid + + +def writeFlowKITTI(filename, uv): + uv = 64.0 * uv + 2**15 + valid = np.ones([uv.shape[0], uv.shape[1], 1]) + uv = np.concatenate([uv, valid], axis=-1).astype(np.uint16) + cv2.imwrite(filename, uv[..., ::-1]) + + +def read_gen(file_name, pil=False): + ext = splitext(file_name)[-1] + if ext == '.png' or ext == '.jpeg' or ext == '.ppm' or ext == '.jpg': + return Image.open(file_name) + elif ext == '.bin' or ext == '.raw': + return np.load(file_name) + elif ext == '.flo': + return readFlow(file_name).astype(np.float32) + elif ext == '.pfm': + flow = readPFM(file_name).astype(np.float32) + if len(flow.shape) == 2: + return flow + else: + return flow[:, :, :-1] + return [] \ No newline at end of file diff --git a/windflow/networks/raft/utils/utils.py b/windflow/networks/raft/utils/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..5f32d281c1c46353a0a2bf36b0550adb74125c65 --- /dev/null +++ b/windflow/networks/raft/utils/utils.py @@ -0,0 +1,82 @@ +import torch +import torch.nn.functional as F +import numpy as np +from scipy import interpolate + + +class InputPadder: + """ Pads images such that dimensions are divisible by 8 """ + def __init__(self, dims, mode='sintel'): + self.ht, self.wd = dims[-2:] + pad_ht = (((self.ht // 8) + 1) * 8 - self.ht) % 8 + pad_wd = (((self.wd // 8) + 1) * 8 - self.wd) % 8 + if mode == 'sintel': + self._pad = [pad_wd//2, pad_wd - pad_wd//2, pad_ht//2, pad_ht - pad_ht//2] + else: + self._pad = [pad_wd//2, pad_wd - pad_wd//2, 0, pad_ht] + + def pad(self, *inputs): + return [F.pad(x, self._pad, mode='replicate') for x in inputs] + + def unpad(self,x): + ht, wd = x.shape[-2:] + c = [self._pad[2], ht-self._pad[3], self._pad[0], wd-self._pad[1]] + return x[..., c[0]:c[1], c[2]:c[3]] + +def forward_interpolate(flow): + flow = flow.detach().cpu().numpy() + dx, dy = flow[0], flow[1] + + ht, wd = dx.shape + x0, y0 = np.meshgrid(np.arange(wd), np.arange(ht)) + + x1 = x0 + dx + y1 = y0 + dy + + x1 = x1.reshape(-1) + y1 = y1.reshape(-1) + dx = dx.reshape(-1) + dy = dy.reshape(-1) + + valid = (x1 > 0) & (x1 < wd) & (y1 > 0) & (y1 < ht) + x1 = x1[valid] + y1 = y1[valid] + dx = dx[valid] + dy = dy[valid] + + flow_x = interpolate.griddata( + (x1, y1), dx, (x0, y0), method='nearest', fill_value=0) + + flow_y = interpolate.griddata( + (x1, y1), dy, (x0, y0), method='nearest', fill_value=0) + + flow = np.stack([flow_x, flow_y], axis=0) + return torch.from_numpy(flow).float() + + +def bilinear_sampler(img, coords, mode='bilinear', mask=False): + """ Wrapper for grid_sample, uses pixel coordinates """ + H, W = 
img.shape[-2:] + xgrid, ygrid = coords.split([1,1], dim=-1) + xgrid = 2*xgrid/(W-1) - 1 + ygrid = 2*ygrid/(H-1) - 1 + + grid = torch.cat([xgrid, ygrid], dim=-1) + img = F.grid_sample(img, grid, align_corners=True) + + if mask: + mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1) + return img, mask.float() + + return img + + +def coords_grid(batch, ht, wd): + coords = torch.meshgrid(torch.arange(ht), torch.arange(wd)) + coords = torch.stack(coords[::-1], dim=0).float() + return coords[None].repeat(batch, 1, 1, 1) + + +def upflow8(flow, mode='bilinear'): + new_size = (8 * flow.shape[2], 8 * flow.shape[3]) + return 8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True) diff --git a/windflow/networks/utils.py b/windflow/networks/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..141799a72d384b4ba838f86ac95026aa54199fb8 --- /dev/null +++ b/windflow/networks/utils.py @@ -0,0 +1,25 @@ +import numpy as np + +def split_array(arr, tile_size=128, overlap=16): + ''' + Split a 3D numpy array into patches for inference + (Channels, Height, Width) + Args: + tile_size: width and height of patches to return + overlap: number of pixels to overlap between patches + Returns: + dict(patches, upper_left): patches and indices of original array + ''' + arr = arr[np.newaxis] + width, height = arr.shape[2:4] + arrs = dict(patches=[], upper_left=[]) + for i in range(0, width, tile_size - overlap): + for j in range(0, height, tile_size - overlap): + i = min(i, width - tile_size) + j = min(j, height - tile_size) + arrs['patches'].append(arr[:,:,i:i+tile_size,j:j+tile_size]) + arrs['upper_left'].append([[i,j]]) + arrs['patches'] = np.concatenate(arrs['patches']) + arrs['upper_left'] = np.concatenate(arrs['upper_left']) + return arrs['patches'], arrs['upper_left'] + diff --git a/windflow/networks/warper.py b/windflow/networks/warper.py new file mode 100644 index 0000000000000000000000000000000000000000..958c21258cccbd693a5022cd97f237e479530cb8 --- /dev/null +++ b/windflow/networks/warper.py @@ -0,0 +1,43 @@ +import numpy as np +import torch +import torch.nn as nn + +class BackwardWarp(nn.Module): + def __init__(self): + super(BackwardWarp, self).__init__() + #self.device = device + # end + + def forward(self, tensorInput, tensorFlow): + if hasattr(self, 'tensorGrid') == False or self.tensorGrid.size(0) != tensorFlow.size(0) or self.tensorGrid.size(2) != tensorFlow.size(2) or self.tensorGrid.size(3) != tensorFlow.size(3): + tensorHorizontal = torch.linspace(-1.0, 1.0, tensorFlow.size(3)).view(1, 1, 1, tensorFlow.size(3)).expand(tensorFlow.size(0), -1, tensorFlow.size(2), -1) + tensorVertical = torch.linspace(-1.0, 1.0, tensorFlow.size(2)).view(1, 1, tensorFlow.size(2), 1).expand(tensorFlow.size(0), -1, -1, tensorFlow.size(3)) + + self.tensorGrid = torch.cat([ tensorHorizontal, tensorVertical ], 1).cuda() + # end + + tensorFlow = torch.cat([ tensorFlow[:, 0:1, :, :] / ((tensorInput.size(3) - 1.0) / 2.0), + tensorFlow[:, 1:2, :, :] / ((tensorInput.size(2) - 1.0) / 2.0)], 1) + out = torch.nn.functional.grid_sample(input=tensorInput, grid=(self.tensorGrid + tensorFlow).permute(0, 2, 3, 1), + mode='bilinear', padding_mode='border') + return out + # end + +class BackwardWarpMV(nn.Module): + def __init__(self): + super(BackwardWarpMV, self).__init__() + self.warper = BackwardWarp() + + def forward(self, tensorInput, tensorFlow): + num_channels = tensorFlow.size(1)//2 + U = tensorFlow[:,:num_channels] + V = tensorFlow[:,num_channels:] + out = [] + for j in 
range(num_channels): + #uv = tensorFlow[:,[j,j+num_channels],:,:] + uv = torch.cat([U[:,j:j+1], V[:,j:j+1]], 1) + out_j = self.warper(tensorInput[:,j:j+1], uv) + out.append(out_j) + + return torch.cat(out, 1) + diff --git a/windflow/train/__init__.py b/windflow/train/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/windflow/train/base.py b/windflow/train/base.py new file mode 100644 index 0000000000000000000000000000000000000000..637316d6d04889c414c62c0a5684fbf6418daaf5 --- /dev/null +++ b/windflow/train/base.py @@ -0,0 +1,115 @@ +import os, sys + +import torch +import torch.nn as nn +from torch import optim +from torch.utils import data +from torch.nn.parallel import DistributedDataParallel as DDP +import torch.distributed as dist +import torch.multiprocessing as mp +from torch.utils.tensorboard import SummaryWriter # there is a bug with summarywriter importing + +import torchvision + +def scale_image(x): + xmn = torch.min(x) + xmx = torch.max(x) + return (x - xmn) / (xmx - xmn) + +class BaseTrainer(nn.Module): + def __init__(self, model, + model_name, + model_path, + lr=1e-4, + device=None, + distribute=False, + rank=0): + super(BaseTrainer, self).__init__() + self.model = model + self.model_name = model_name + self.model_path = model_path + self.lr = lr + self.device = device + self.distribute = distribute + self.rank = rank + + # NEW + self.scaler = torch.cuda.amp.GradScaler() + + self.checkpoint_filepath = os.path.join(model_path, 'checkpoint.pth.tar') + if (rank == 0) and (not os.path.exists(model_path)): + os.makedirs(model_path) + + self.global_step = 0 + self._set_optimizer() + self._set_summary_writer() + + + def _set_optimizer(self): + # set optimizer + #if self.rank == 0: + self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr, weight_decay=1e-4) + + + def _set_summary_writer(self): + self.tfwriter_train = SummaryWriter(os.path.join(self.model_path, 'train', 'tfsummary')) + self.tfwriter_valid = SummaryWriter(os.path.join(self.model_path, 'valid', 'tfsummary')) + + def load_checkpoint(self): + filename = self.checkpoint_filepath + if os.path.isfile(filename): + print("loading checkpoint %s" % filename) + checkpoint = torch.load(filename) + self.global_step = checkpoint['global_step'] + try: + self.model.module.load_state_dict(checkpoint['model']) + except: + self.model.load_state_dict(checkpoint['model']) + + self.optimizer.load_state_dict(checkpoint['optimizer']) + print("=> loaded checkpoint '{}' (Step {})" + .format(filename, self.global_step)) + else: + print("=> no checkpoint found at '{}'".format(filename)) + + def save_checkpoint(self): + if self.distribute: + state = {'global_step': self.global_step, + 'model': self.model.module.state_dict(), + 'optimizer': self.optimizer.state_dict()} + else: + state = {'global_step': self.global_step, + 'model': self.model.state_dict(), + 'optimizer': self.optimizer.state_dict()} + torch.save(state, self.checkpoint_filepath) + + def log_tensorboard(self): + pass + + def get_tfwriter(self, train): + if train: + return self.tfwriter_train + else: + return self.tfwriter_valid + + def log_scalar(self, x, name, train=True): + tfwriter = self.get_tfwriter(train) + tfwriter.add_scalar(name, x, self.global_step) + + def log_image_grid(self, img, name, train=True, N=4): + ''' + img of shape (N, C, H, W) + ''' + tfwriter = self.get_tfwriter(train) + img_grid = torchvision.utils.make_grid(img[:N]) + tfwriter.add_image(name, 
scale_image(img_grid), self.global_step) + + def log_flow_grid(self, flows, name, train=True, N=4): + tfwriter = self.get_tfwriter(train) + U_grid = torchvision.utils.make_grid(flows[:N,:1]) + V_grid = torchvision.utils.make_grid(flows[:N,1:]) + intensity = (U_grid ** 2 + V_grid ** 2)**0.5 + tfwriter.add_image(f'{name}/U', scale_image(U_grid), self.global_step) + tfwriter.add_image(f'{name}/V', scale_image(V_grid), self.global_step) + tfwriter.add_image(f'{name}/intensity', scale_image(intensity), self.global_step) + \ No newline at end of file diff --git a/windflow/train/flownet_trainer.py b/windflow/train/flownet_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..18a7f5899d82bcafdbd8b9dd0a362337d1644741 --- /dev/null +++ b/windflow/train/flownet_trainer.py @@ -0,0 +1,118 @@ +import os +import numpy as np + +import torch +import torch.nn as nn +from torch import optim +import torchvision + +from .base import BaseTrainer +from ..networks import FlowNetS +from ..losses import CharbonnierLoss, L1Loss, L2Loss, RMSVDLoss + +from torch.utils.tensorboard import SummaryWriter + +class MultiFrameTrainer(BaseTrainer): + def __init__(self, + model, + model_name, + model_path, + lr=1e-4, + device=None, + distribute=False, + rank=0): + BaseTrainer.__init__(self, model, model_name, model_path, lr=lr, + device=device, distribute=distribute, rank=rank) + + # set loss functions + self.photo_loss = L2Loss(None) + + def step(self, inputs, labels, log=False, train=True): + ''' + Inputs shape: (N, T, C, H, W) + Flows shape: (N, T, 2, H, W) + ''' + + N_pairs = len(inputs)-1 + # for every pair, compute + +class FlownetTrainer(BaseTrainer): + def __init__(self, + model, + model_name, + model_path, + lr=1e-4, + device=None, + distribute=False, + rank=0, + loss='L2'): + BaseTrainer.__init__(self, model, model_name, model_path, lr=lr, + device=device, distribute=distribute, rank=rank) + + # set loss functions + if loss == 'L2': + self.photo_loss = L2Loss(None) + elif loss.lower() == 'rmsvd': + print("RMSVD Loss") + self.photo_loss = RMSVDLoss() + elif loss == 'CHAR': + self.photo_loss = CharbonnierLoss(alpha=0.50) + elif loss == 'L1': + self.photo_loss = L1Loss(None) + + #self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, 'min', + #factor=0.5, patience=100) + #def step(self, I0, I1, UV, log=False, train=True): + def step(self, inputs, labels, log=False, train=True): + ''' + Inputs shape: (N, T, C, H, W) + Flows shape: (N, T, 2, H, W) + ''' + I0 = inputs[:,0] + I1 = inputs[:,1] + x = torch.cat([I0, I1], 1) + + #with torch.cuda.amp.autocast(): + #flows = self.model(I0, I1) + flows = self.model(x) + #if not isinstance(flows, list): + # flows = [flows] + + UV_l = labels[:,0] + losses = [] + weights = [0.32, 0.16, 0.08, 0.04, 0.02, 0.005] + + for layer, flow_l in enumerate(flows): + H = flow_l.shape[2] + scale_factor = flow_l.shape[2] / UV_l.shape[2] + UV_l = torch.nn.functional.interpolate(UV_l, scale_factor=scale_factor, + mode='bilinear', recompute_scale_factor=True) + losses.append(weights[layer] * self.photo_loss(flow_l, UV_l)) + if log: + self.log_flow_grid(flow_l, f'flow_level_{layer}', train=train) + + total_loss = sum(losses) + + if train: + self.optimizer.zero_grad() + total_loss.backward() + #self.scaler.scale(total_loss).backward() + self.optimizer.step() + #self.scaler.step(self.optimizer) + #self.scaler.update() + #else: + # self.scheduler.step(total_loss) + + if log and (self.rank == 0): + self.log_scalar(total_loss, "total_loss", train=train) + 
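# log example input frames, the label flow field, and per-level losses/flow predictions for TensorBoard inspection + 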
self.log_image_grid(I0, "data/I0", train=train) + self.log_image_grid(I1, "data/I1", train=train) + self.log_flow_grid(labels[:,0], "label", train=train) + for i in range(len(losses)): + self.log_scalar(losses[i], f"losses/level_{i}", train=train) + self.log_flow_grid(flows[i], f"level_{i}", train=train) + + if train: + self.global_step += 1 + + return total_loss diff --git a/windflow/train/raft_trainer.py b/windflow/train/raft_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..f2486db6263550fecb07febf66cc04cadf81753d --- /dev/null +++ b/windflow/train/raft_trainer.py @@ -0,0 +1,158 @@ +import os +import numpy as np + +import torch +import torch.nn as nn +from torch import optim +#from torch.nn.parallel import DistributedDataParallel as DDP +import torchvision + +from .base import BaseTrainer +from ..networks import raft, warper +from ..losses import CharbonnierLoss, L1Loss, L2Loss + +from torch.utils.tensorboard import SummaryWriter + +MAX_FLOW = 200 + +def sequence_loss(flow_preds, flow_gt, gamma=0.8, max_flow=MAX_FLOW): + """ Loss function defined over sequence of flow predictions """ + + n_predictions = len(flow_preds) + flow_loss = 0.0 + + # exlude invalid pixels and extremely large diplacements + mag = torch.sum(flow_gt**2, dim=1).sqrt() + #valid = (valid >= 0.5) & (mag < max_flow) + valid = (mag < max_flow) + for i in range(n_predictions): + i_weight = gamma**(n_predictions - i - 1) + i_loss = (flow_preds[i] - flow_gt).abs() + flow_loss += i_weight * (valid[:, None] * i_loss).mean() + #flow_loss += (i_weight * i_loss).mean() + + return flow_loss + + +def unsupervised_sequence_loss(I0, I1, flow_pred, gamma=0.8): + """Loss function unsupervised warping I1 to I0 with predicted flows""" + warp = warper.BackwardWarp() + loss = CharbonnierLoss(alpha=0.50) + + n_predictions = len(flow_pred) + flow_loss = 0.0 + + for i in range(n_predictions): + I0_i = warp(I1, flow_pred[i]) + i_weight = gamma**(n_predictions - i - 1) + i_loss = loss(I0, I0_i) + flow_loss += (i_weight * i_loss).mean() + + return flow_loss + + +class RAFTTrainer(BaseTrainer): + def __init__(self, model, model_path, iters=24, + device=None, lr=1e-5, distribute=False, + clip=1., rank=0): + BaseTrainer.__init__(self, model, 'raft', model_path, lr=lr, + device=device, distribute=distribute, rank=rank) + self.iters = iters + self.clip = clip + self.optimizer = optim.AdamW(model.parameters(), lr=lr, + weight_decay=1e-4, eps=1e-8) + + self.scheduler = optim.lr_scheduler.OneCycleLR(self.optimizer, + max_lr=lr, + total_steps=500000, + pct_start=0.05, + cycle_momentum=False, + anneal_strategy='linear') + + + #def step(self, I0, I1, flow_gt, flow_init=None, log=False, train=True): + def step(self, inputs, labels, flow_init=None, log=False, train=True): + if train: + tfwriter = self.tfwriter_train + else: + tfwriter = self.tfwriter_valid + + with torch.cuda.amp.autocast(): + I0 = inputs[:,0] + I1 = inputs[:,1] + flow_predictions = self.model(I0, I1, iters=self.iters, flow_init=flow_init) + loss = sequence_loss(flow_predictions, labels[:,0]) + + if train: + self.optimizer.zero_grad() + loss.backward() + torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.clip) + + self.optimizer.step() + #self.scaler.step(self.optimizer) + #self.scalar.update() + self.scheduler.step() + + + if log and (self.rank == 0): + self.log_scalar(loss, 'total_loss', train) + self.log_flow_grid(labels[:,0], 'label', train) + self.log_image_grid(I0, 'data/I0', train) + self.log_image_grid(I1, 'data/I1', train) + + for i in 
range(0, self.iters, 2): + self.log_flow_grid(flow_predictions[i], f'flows/iter_{i}', train) + if train: + self.global_step = self.global_step + 1 + return loss + + +class RAFTGuided(BaseTrainer): + def __init__(self, model, model_path, iters=10, lambda_obs=0.1, + device=None, lr=1e-4, distribute=False, clip=1., rank=0): + BaseTrainer.__init__(self, model, 'raft', model_path, lr=lr, + device=device, distribute=distribute, rank=rank) + self.iters = iters + self.clip = clip + self.lambda_obs = lambda_obs + + def step(self, I0_obs, I1_obs, I0_phys, I1_phys, flow_phys, train=True, log=None): + if train: + tfwriter = self.tfwriter_train + else: + tfwriter = self.tfwriter_valid + + with torch.cuda.amp.autocast(): + flow_predictions_phys = self.model(I0_phys, I1_phys, iters=self.iters) + loss_guide = sequence_loss(flow_predictions_phys, flow_phys) + + #flow_predictions_obs = self.model(I0_obs, I1_obs, iters=self.iters) + #loss_obs = unsupervised_sequence_loss(I0_obs, I1_obs, flow_predictions_obs) + print('loss', loss_guide) + loss = loss_guide # + loss_obs * self.lambda_obs + + if train: + self.optimizer.zero_grad() + loss.backward() + torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.clip) + + self.optimizer.step() + #self.scaler.step(self.optimizer) + #self.scalar.update() + + + if log and (self.rank == 0): + self.log_scalar(loss_guide, 'loss_guide', train) + #self.log_scalar(loss_obs, 'loss_obs', train)  # loss_obs is disabled above + self.log_scalar(loss, 'total_loss', train) + self.log_flow_grid(flow_phys, 'label', train) + self.log_image_grid(I0_phys, 'data/I0', train) + self.log_image_grid(I1_phys, 'data/I1', train) + + for i in range(0, self.iters, 2): + self.log_flow_grid(flow_predictions_phys[i], f'flows/iter_{i}', train) + + if train: + self.global_step = self.global_step + 1 + + return loss diff --git a/windflow/train/trainers.py b/windflow/train/trainers.py new file mode 100644 index 0000000000000000000000000000000000000000..97d3b8076516f94af7e78745d386ff0621382cd9 --- /dev/null +++ b/windflow/train/trainers.py @@ -0,0 +1,16 @@ +from .raft_trainer import RAFTTrainer, RAFTGuided +from .flownet_trainer import FlownetTrainer  # MultiFrameTrainer is not used here + + +def get_flow_trainer(model, model_name, model_path, distribute=False, rank=0, lr=1e-4, loss='L2'): + # Select trainer + if model_name.lower() == 'raft': + trainer = RAFTTrainer(model, model_path, distribute=distribute, + rank=rank, lr=lr, iters=24) + elif model_name.lower() in ['flownets', 'pwc-net', 'flownet2', 'flownetc', 'flownetsd', 'pwcnet', 'maskflownet']: + trainer = FlownetTrainer(model, model_name, model_path, + distribute=distribute, rank=rank, lr=lr, + loss=loss) + # trainer = MultiFrameTrainer(model, model_name, model_path, + # distribute=distribute, rank=rank, lr=lr) + return trainer diff --git a/windflow/utils/__init__.py b/windflow/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/windflow/utils/plotting.py b/windflow/utils/plotting.py new file mode 100644 index 0000000000000000000000000000000000000000..e1e20563e3797de12d5935626c23b337d31027f5 --- /dev/null +++ b/windflow/utils/plotting.py @@ -0,0 +1,98 @@ +import matplotlib.pyplot as plt +import torch +import torch.nn +import numpy as np +from mpl_toolkits.basemap import Basemap +from metpy.plots import colortables +import cv2 + +def downsample(x, pool=25): + x = torch.from_numpy(x[np.newaxis]) + xlr = torch.nn.AvgPool2d(pool, stride=pool)(x) + return xlr.detach().numpy()[0] + +def flow_quiver_plot(u, v, x=None, y=None, 
ax=None, + down=25, vmin=None, vmax=None, latlon=False, + size=10, cmap='jet', colorbar=False): + + intensity = (u**2 + v**2) ** 0.5 + + u_l = downsample(u, down) + v_l = downsample(v, down) + + + intensity_l = ( u_l ** 2 + v_l**2 ) ** 0.5 + #intensity_l = ( u_l ** 2 + v_l**2 ) ** 0.25 + + u_l = u_l #/ intensity_l + v_l = v_l #/ intensity_l + + if (x is None) or (y is None): + x = np.arange(0, u_l.shape[1]) * down + down/2. + y = np.arange(0, v_l.shape[0]) * down + down/2. + X, Y = np.meshgrid(x, y) + else: + x = downsample(x, down) + y = downsample(y, down) + X, Y = x, y + + #if X.shape[0] != u_l.shape[0]: + #Y = cv2.resize(Y, dsize=v_l.shape, interpolation=cv2.INTER_LINEAR) + #X = cv2.resize(X, dsize=u_l.shape, interpolation=cv2.INTER_LINEAR) + + + if not ax: + ratio = 1.*u.shape[0] / u.shape[1] + hi = int(ratio * size) + wi = int(size) + fig = plt.figure(figsize=(wi,hi), frameon=False) + ax = fig.add_axes([0, 0, 1, 1]) + ax.axis('off') + + if vmax is None: + vmax = np.nanmax(intensity) + if vmin is None: + vmin = np.nanmin(intensity) + + im1 = ax.imshow(intensity, origin='upper', cmap=cmap, vmax=vmax, vmin=vmin) + if colorbar: + plt.colorbar(im1, label='Wind Speed (m/s)', aspect=50, pad=0.01, shrink=0.85) + + #ax.quiver(X, Y, v_l, u_l, pivot='middle', + # width=0.001, headwidth=2, headlength=2) + scale = 150 + try: + ax.quiver(X, Y, v_l, u_l, latlon=latlon, + scale_units='inches', scale=scale) + except: + ax.quiver(X, Y, v_l, u_l, + scale_units='inches', scale=scale) + + if hasattr(ax, 'xaxis'): + ax.xaxis.set_ticks([]) + ax.yaxis.set_ticks([]) + ax.set_aspect('equal') + return ax + +def plot_infrared(data, x, y, sat_lon, ax=None, cmin=190, cmax=350): + # The geostationary projection + if ax is None: + fig = plt.figure(figsize=[8,8], frameon=False) + ax = fig.add_subplot(111) + + m = Basemap(projection='geos', lon_0=sat_lon, resolution='i', + rsphere=(6378137.00,6356752.3142), + llcrnrx=x.min(),llcrnry=y.min(), + urcrnrx=x.max(),urcrnry=y.max(), + ax=ax) + m.drawcoastlines() + m.drawcountries() + m.drawstates() + ax.axis('off') + + # Use a colortable/colormap available from MetPy + ir_norm, ir_cmap = colortables.get_with_range('ir_drgb_r', cmin, cmax) + # Plot the data using imshow + im = m.imshow(data, origin='upper', cmap=ir_cmap, norm=ir_norm) + plt.colorbar(im, pad=0, aspect=50, ticks=range(cmin,cmax,10), shrink=0.85) + return ax \ No newline at end of file diff --git a/windflow/utils/preprocess.py b/windflow/utils/preprocess.py new file mode 100644 index 0000000000000000000000000000000000000000..6b45d90c0c498a3bf2c8ae3cefb3c24d21f5cab5 --- /dev/null +++ b/windflow/utils/preprocess.py @@ -0,0 +1,15 @@ +import numpy as np + +def image_histogram_equalization(image, number_bins=500): + # from http://www.janeriksolem.net/2009/06/histogram-equalization-with-python-and.html + # get image histogram + image_histogram, bins = np.histogram(image.flatten(), number_bins, density=True) + cdf = image_histogram.cumsum() # cumulative distribution function + cdf = cdf / cdf[-1] # normalize + + # use linear interpolation of cdf to find new pixel values + image_equalized = np.interp(image.flatten(), bins[:-1], cdf) + + image_equalized[~np.isfinite(image_equalized)] = 0. 
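+    # non-finite values (e.g. from NaN input pixels) were zeroed above; reshape back to the original image dimensions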
+ + return image_equalized.reshape(image.shape)#, cdf \ No newline at end of file diff --git a/windflow/utils/resample_eco1280_to_geos5nr.py b/windflow/utils/resample_eco1280_to_geos5nr.py new file mode 100644 index 0000000000000000000000000000000000000000..f0cc21eda2ab9f4e28d33c3dc3ae6be91c219e4a --- /dev/null +++ b/windflow/utils/resample_eco1280_to_geos5nr.py @@ -0,0 +1,75 @@ +from scipy.interpolate import RegularGridInterpolator + +import numpy as np +from netCDF4 import Dataset + +# resample methods: +linear = 'linear' +cubic = 'cubic' +nearest = 'nearest' + +eco_x = np.linspace(0.0, 2.0, 3600) +eco_y = np.linspace(0.0, 1.0, 1801) +g5_x = np.linspace(0.0, 2.0, 5760) +g5_y = np.linspace(0.0, 1.0, 2881) + +g5_scale_x = 5760 / 3600 +g5_scale_y = 2881 / 1801 + + +def time_linear_interp_grids(grid_a, grid_b, time_a, time_b, time_i): + grd_shape = grid_a.shape + + grid_a = grid_a.flatten() + grid_b = grid_b.flatten() + + del_time = time_b - time_a + + w_a = 1.0 - (time_i - time_a) / del_time + w_b = (time_i - time_a) / del_time + + grid_i = w_a * grid_a + w_b * grid_b + grid_i = np.reshape(grid_i, grd_shape) + + return grid_i + + +def upsample(scalar_field, ej_a=0, ej_b=1801, ei_a=0, ei_b=3600): + gj_a = int(g5_scale_y * ej_a) + gj_b = int(g5_scale_y * ej_b) + gi_a = int(g5_scale_x * ei_a) + gi_b = int(g5_scale_x * ei_b) + + # match N-S orientation to GEOS5-NR + scalar_field = scalar_field[::-1, :] + + scalar_field = scalar_field[ej_a:ej_b, ei_a:ei_b] + print('input dims: ', scalar_field.shape[0], scalar_field.shape[1]) + + intrp = RegularGridInterpolator((eco_y[ej_a:ej_b], eco_x[ei_a:ei_b]), scalar_field, method=linear, bounds_error=False) + + g5_y_s = g5_y[gj_a:gj_b] + g5_x_s = g5_x[gi_a:gi_b] + print('output dims: ', g5_y_s.shape[0], g5_x_s.shape[0]) + + xg, yg = np.meshgrid(g5_x_s, g5_y_s, indexing='xy') + yg, xg = yg.flatten(), xg.flatten() + pts = np.array([yg, xg]) + t_pts = np.transpose(pts) + + return np.reshape(intrp(t_pts), (g5_y_s.shape[0], g5_x_s.shape[0])) + + +def test(eco_file, g5_file, ej_a=0, ej_b=1801, ei_a=0, ei_b=3600): + rtgrp = Dataset(eco_file, 'r', format='NETCDF3') + gp_var = rtgrp['gp_newP'] + eco_gp = gp_var[:, :, 0].copy() + print('gp_newP range/shape: ', np.nanmin(eco_gp), np.nanmax(eco_gp), eco_gp.shape) + + rsm_img = upsample(eco_gp, ej_a=ej_a, ej_b=ej_b, ei_a=ei_a, ei_b=ei_b) + + return rsm_img, eco_gp + + + +
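A minimal usage sketch for the ECO1280-to-GEOS-5 NR resampling helpers above, assuming the module path windflow/utils/resample_eco1280_to_geos5nr.py introduced in this diff; the NetCDF path and the output filename are placeholders, and g5_file is passed as None because test() does not currently read it.

from windflow.utils.resample_eco1280_to_geos5nr import test
import numpy as np

eco_file = '/path/to/eco1280_sample.nc'         # placeholder ECO1280 NetCDF file containing a 'gp_newP' variable
rsm_img, eco_gp = test(eco_file, g5_file=None)  # g5_file is accepted but unused by test()
print('ECO1280 grid:', eco_gp.shape)            # expected (1801, 3600) on the ECO1280 grid
print('GEOS-5 NR grid:', rsm_img.shape)         # approximately (2881, 5760) after upsampling to the GEOS-5 NR grid
np.save('gp_newP_on_g5nr_grid.npy', rsm_img)    # illustrative: persist the regridded field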