diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..b3c55e7c6d766efc4ecfa043acf546d3c7dd16ee
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2022 TJ Vandal
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/data/download_g5nr.py b/data/download_g5nr.py
new file mode 100644
index 0000000000000000000000000000000000000000..725bf11e5ae2358923c7bc1f0f3a26527936bed9
--- /dev/null
+++ b/data/download_g5nr.py
@@ -0,0 +1,45 @@
+import os
+
+import urllib3
+import pandas
+
+import argparse
+
+parser = argparse.ArgumentParser()
+parser.add_argument('output', metavar='G5NR_7km/', type=str,
+                    help='Directory where the downloaded data will be saved')
+
+args = parser.parse_args()
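+# Example: python data/download_g5nr.py G5NR_7km/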
+
+def url_to_pandas(url):
+    http = urllib3.PoolManager()
+    r = http.request('GET', url)
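+    # read_html returns every <table> on the page; the directory listing is
+    # assumed to be the second-to-last table on the NCCS datashare pages.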
+    return pandas.read_html(r.data)[-2]
+
+path = 'https://portal.nccs.nasa.gov/datashare/gmao_obsteam/osse_for_wisc/'
+to_directory = args.output
+
+if not os.path.exists(to_directory):
+    os.makedirs(to_directory)
+
+df = url_to_pandas(path)
+files = []
+
+for name in df['Name']:
+    try:
+        int(name[:8])
+    except (ValueError, TypeError):
+        # skip entries whose names are not date-stamped (YYYYMMDD...) folders
+        continue
+
+    folder_path = os.path.join(path, name)
+    df_folder = url_to_pandas(folder_path)
+    for filename in df_folder['Name']:
+        if isinstance(filename, str):
+            if 'nc4' == filename[-3:]:
+                filepath = os.path.join(folder_path, filename)
+                print(f"Filepath: {filepath}")
+                local_path = os.path.join(to_directory, filename)
+                # skip files that have already been downloaded
+                if os.path.exists(local_path):
+                    continue
+
+                wget_cmd = f"wget -O {local_path} {filepath}"
+                os.system(wget_cmd)
\ No newline at end of file
diff --git a/data/g5nr_to_zarr.py b/data/g5nr_to_zarr.py
new file mode 100644
index 0000000000000000000000000000000000000000..723ec0b49fcff711ba9d38122a87edbac73432a2
--- /dev/null
+++ b/data/g5nr_to_zarr.py
@@ -0,0 +1,29 @@
+import xarray as xr
+import os, sys
+
+sys.path.append('..')
+
+from nexai.datasets.g5nr import g5nr
+from dask.distributed import Client
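+# Optionally start a Dask client to parallelize the chunked zarr writes,
+# e.g.: client = Client(n_workers=4)  # hypothetical worker count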
+
+
+# Set data variables
+label_directory = '/nobackupp10/tvandal/nex-ai-opticalflow/data/G5NR_7km/'
+label_files = g5nr.get_files(label_directory, 'test').reset_index()
+
+output_directory = '/nobackupp10/tvandal/nex-ai-opticalflow/data/G5NR_7km_zarr/'
+print(label_files)
+
+N = len(label_files)
+for i, row in label_files.iterrows():
+    fname = os.path.basename(row['U'])
+    output_zarr = os.path.join(output_directory, fname.split('.')[2] + '.zarr')
+    if os.path.exists(output_zarr):
+        continue
+        
+    U = xr.open_mfdataset(row['U'])
+    V = xr.open_mfdataset(row['V'])
+    ds = xr.open_mfdataset(row['QV']).merge(U).merge(V)
+    ds = ds.chunk(dict(time=1, lat=500, lon=500, lev=1))
+    ds.to_zarr(output_zarr)
+    print(i/N, output_zarr)
\ No newline at end of file
diff --git a/model_weights/windflow.raft.pth.tar b/model_weights/windflow.raft.pth.tar
new file mode 100644
index 0000000000000000000000000000000000000000..0dbb63486c37cb23116fbb78033812c6beb5f3d6
Binary files /dev/null and b/model_weights/windflow.raft.pth.tar differ
diff --git a/pipelines/eco1280/eco_to_flows.py b/pipelines/eco1280/eco_to_flows.py
new file mode 100644
index 0000000000000000000000000000000000000000..ad864a88d690681b8d0c8b22c0f68de958f513e9
--- /dev/null
+++ b/pipelines/eco1280/eco_to_flows.py
@@ -0,0 +1,141 @@
+import os, sys
+import glob
+import argparse
+import time
+
+# T.Rink, MacBook OSX doesn't support MPI
+# from mpi4py import MPI
+
+# MPI_COMM = MPI.COMM_WORLD
+# MPI_RANK = MPI_COMM.Get_rank()
+# MPI_SIZE = MPI_COMM.Get_size()
+MPI_RANK = 0
+
+# T.Rink, MacBook GPU (AMD Radeon Pro 5300M) doesn't support CUDA
+# N_GPUS = 2
+
+# rank_gpu = MPI_RANK % N_GPUS
+# os.environ['CUDA_VISIBLE_DEVICES'] = f'{rank_gpu}'
+os.environ['CUDA_VISIBLE_DEVICES'] = '0'
+
+import numpy as np
+import pandas as pd
+import xarray as xr
+import matplotlib.pyplot as plt
+
+import torch
+
+sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
+
+# from windflow.datasets.g5nr import g5nr
+from windflow.inference.inference_flows import FlowRunner
+from windflow.datasets.utils import cartesian_to_speed_eco
+
+parser = argparse.ArgumentParser()
+parser.add_argument("--model_name", default="raft", type=str)
+# parser.add_argument("--checkpoint", default="models/raft-size_512/checkpoint.pth.tar", type=str)
+parser.add_argument("--checkpoint", default="model_weights/windflow.raft.pth.tar", type=str)
+parser.add_argument("--data_directory", default="data/ECO1280/", type=str)
+parser.add_argument("--output", default="data/ECO_flows/raft/", type=str)
+parser.add_argument("--n_files", default=None, type=int)
+parser.add_argument("--batch_size", default=2, type=int)
+
+args = parser.parse_args()
+
+# Set model information
+model_name = args.model_name
+
+# Set data variables
+data_directory = args.data_directory
+to_directory = args.output
+
+if (MPI_RANK == 0) and (not os.path.exists(to_directory)):
+    os.makedirs(to_directory)
+
+# get file list
+#files = g5nr.get_files(data_directory, 'test')
+#print("Number of OSSE test files", len(files))
+
+pattern = os.path.join(data_directory, '*.nc4')
+files = sorted(glob.glob(pattern))
+
+variables = ['gp', 'uv']
+var_files = {v: sorted([f for f in files if f'{v}_' in f]) for v in variables}
+var_files = pd.DataFrame(var_files)
+
+if args.n_files is None:
+    N_files = len(files)
+else:
+    N_files = args.n_files
+
+# N_per_rank = N_files // MPI_SIZE
+# files = files.iloc[N_per_rank * MPI_RANK: N_per_rank * (MPI_RANK+1) + 1]
+
+# load model
+tile_size = 512
+overlap = 128
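+# the full frame is processed as 512x512 tiles with 128 px overlap so large
+# scenes fit in memory (the overlap presumably smooths seams between tiles)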
+
+if model_name in ['pwc-net-rmsvd', 'pwc-net-l1']:
+    runner = FlowRunner('pwc-net', 
+                        tile_size=tile_size,
+                        overlap=overlap,
+                        batch_size=args.batch_size)
+else:
+    runner = FlowRunner(model_name.replace('-guided','').replace('-unflow', ''),
+                        tile_size=tile_size,
+                        overlap=overlap,
+                        batch_size=args.batch_size,
+                        device=torch.device('cpu'))
+ 
+runner.load_checkpoint(args.checkpoint)
+
+stats = []
+
+# iterate and perform inference, compute test statistics 
+ds1 = xr.open_mfdataset(files.iloc[0].values, engine='netcdf4')
+for i in range(1, files.shape[0]): 
+    # open dataset
+    
+    ds2 = xr.open_mfdataset(files.iloc[i].values, engine='netcdf4')
+ 
+    f = os.path.basename(files.iloc[i-1]['uv'])  # get ds1 file
+    to_flow_file = os.path.join(to_directory, f.replace('uv_', '_WindFlow_').replace('.nc', '.zarr'))
+    if os.path.exists(to_flow_file):
+        ds1 = ds2.copy()
+        continue 
+
+    # t = ds1['time'].values
+
+    output_ds = xr.zeros_like(ds1)
+    del output_ds['gp_newP']
+
+    U_flows = np.zeros(ds1.ugrd_newP.shape)
+    V_flows = np.zeros(ds1.vgrd_newP.shape)
+
+    t0 = time.time()
+    
+    for lev_idx, lev in enumerate(ds1.lev_p):
+        qv1_lev = ds1.sel(lev_p=lev)['gp_newP'].values[0]
+        qv2_lev = ds2.sel(lev_p=lev)['gp_newP'].values[0]
+        _, flows_lev = runner.forward(qv1_lev, qv2_lev)
+        U_flows[:, lev_idx] = flows_lev[0]
+        V_flows[:, lev_idx] = flows_lev[1]
+
+    output_ds['ugrd_newP'] = output_ds['ugrd_newP'] + U_flows
+    output_ds['vgrd_newP'] = output_ds['vgrd_newP'] + V_flows
+
+    output_ds = cartesian_to_speed_eco(output_ds)
+ 
+    output_ds.attrs['Source'] = 'NEX'
+    output_ds.attrs['Title'] = 'Optical Flow Feature Tracking'
+    output_ds.attrs['Contact'] = 'vandal@baeri.org'
+    output_ds.attrs['History'] = 'ECO1280 outputs processed by NEX optical flow.'
+    output_ds.attrs['Model'] = model_name
+    output_ds.attrs['Pytorch_Checkpoint'] = args.checkpoint
+
+    # output_ds.to_netcdf(to_flow_file)
+    output_ds.to_zarr(to_flow_file)
+    # print(f'Wrote to file {to_flow_file}')
+    print(f"Wrote to file: {to_flow_file} -- Processing time {time.time()-t0} (seconds)")
+
+    ds1 = ds2.copy()
diff --git a/pipelines/g5nr/g5nr_to_flows.py b/pipelines/g5nr/g5nr_to_flows.py
new file mode 100644
index 0000000000000000000000000000000000000000..04a1d57521024cd584d98b98fc0d5c06800c3eaf
--- /dev/null
+++ b/pipelines/g5nr/g5nr_to_flows.py
@@ -0,0 +1,131 @@
+import os, sys
+import glob
+import argparse
+import time
+
+# T.Rink, MacBook OSX doesn't support MPI
+# from mpi4py import MPI
+
+# MPI_COMM = MPI.COMM_WORLD
+# MPI_RANK = MPI_COMM.Get_rank()
+# MPI_SIZE = MPI_COMM.Get_size()
+MPI_RANK = 0
+
+# T.Rink, MacBook GPU (AMD Radeon Pro 5300M) doesn't support CUDA
+# N_GPUS = 2
+
+# rank_gpu = MPI_RANK % N_GPUS
+# os.environ['CUDA_VISIBLE_DEVICES'] = f'{rank_gpu}'
+os.environ['CUDA_VISIBLE_DEVICES'] = '0'
+
+import numpy as np
+import pandas as pd
+import xarray as xr
+import matplotlib.pyplot as plt
+
+import torch
+
+sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
+
+from windflow.datasets.g5nr import g5nr
+from windflow.inference.inference_flows import FlowRunner
+from windflow.datasets.utils import cartesian_to_speed
+
+parser = argparse.ArgumentParser()
+parser.add_argument("--model_name", default="raft", type=str)
+parser.add_argument("--checkpoint", default="models/raft-size_512/checkpoint.pth.tar", type=str)
+parser.add_argument("--data_directory", default="data/G5NR/", type=str)
+parser.add_argument("--output", default="data/G5NR_Flows/raft/", type=str)
+parser.add_argument("--n_files", default=None, type=int)
+parser.add_argument("--batch_size", default=2, type=int)
+
+args = parser.parse_args()
+
+# Set model information
+model_name = args.model_name
+
+# Set data variables
+data_directory = args.data_directory
+to_directory = args.output
+
+if (MPI_RANK == 0) and (not os.path.exists(to_directory)):
+    os.makedirs(to_directory)
+
+# get file list
+files = g5nr.get_files(data_directory, 'test')
+print("Number of OSSE test files", len(files))
+
+if args.n_files is None:
+    N_files = len(files)
+else:
+    N_files = args.n_files
+
+# N_per_rank = N_files // MPI_SIZE
+# files = files.iloc[N_per_rank * MPI_RANK: N_per_rank * (MPI_RANK+1) + 1]
+
+# load model
+tile_size = 512
+overlap = 128
+
+if model_name in ['pwc-net-rmsvd', 'pwc-net-l1']:
+    runner = FlowRunner('pwc-net', 
+                        tile_size=tile_size,
+                        overlap=overlap,
+                        batch_size=args.batch_size)
+else:
+    runner = FlowRunner(model_name.replace('-guided','').replace('-unflow', ''),
+                        tile_size=tile_size,
+                        overlap=overlap,
+                        batch_size=args.batch_size)
+ 
+runner.load_checkpoint(args.checkpoint)
+
+stats = []
+
+# iterate and perform inference, compute test statistics 
+ds1 = xr.open_mfdataset(files.iloc[0].values, engine='netcdf4')
+for i in range(1, files.shape[0]): 
+    # open dataset
+    
+    ds2 = xr.open_mfdataset(files.iloc[i].values, engine='netcdf4')
+ 
+    f = os.path.basename(files.iloc[i-1]['U']) # get ds1 file
+    to_flow_file = os.path.join(to_directory, f.replace('_U_', '_WindFlow_').replace('.nc', '.zarr'))
+    if os.path.exists(to_flow_file):
+        ds1 = ds2.copy()
+        continue 
+
+    t = ds1['time'].values
+
+    output_ds = xr.zeros_like(ds1)
+    del output_ds['QV'], output_ds['tpw']
+
+    U_flows = np.zeros(ds1.U.shape)
+    V_flows = np.zeros(ds1.V.shape)
+
+    t0 = time.time()
+    
+    for lev_idx, lev in enumerate(ds1.lev):
+        qv1_lev = ds1.sel(lev=lev)['QV'].values[0]
+        qv2_lev = ds2.sel(lev=lev)['QV'].values[0]
+        _, flows_lev = runner.forward(qv1_lev, qv2_lev)
+        U_flows[0, lev_idx] = flows_lev[0]
+        V_flows[0, lev_idx] = flows_lev[1]
+
+    output_ds['U'] = output_ds['U'] + U_flows
+    output_ds['V'] = output_ds['V'] + V_flows
+
+    output_ds = cartesian_to_speed(output_ds)
+ 
+    output_ds.attrs['Source'] = 'NEX'
+    output_ds.attrs['Title'] = 'Optical Flow Feature Tracking'
+    output_ds.attrs['Contact'] = 'vandal@baeri.org'
+    output_ds.attrs['History'] = 'G5NR outputs from GEOS-5 by gmao processed by NEX optical flow.'
+    output_ds.attrs['Model'] = model_name
+    output_ds.attrs['Pytorch_Checkpoint'] = args.checkpoint
+
+    #output_ds.to_netcdf(to_flow_file)
+    output_ds.to_zarr(to_flow_file)
+    #print(f'Wrote to file {to_flow_file}')
+    print(f"Wrote to file: {to_flow_file} -- Processing time {time.time()-t0} (seconds)")
+
+    ds1 = ds2.copy()
diff --git a/pipelines/geo/geo_flows_to_zarr.py b/pipelines/geo/geo_flows_to_zarr.py
new file mode 100644
index 0000000000000000000000000000000000000000..272e67090fd28f055ad85c6d0e4e1c135cb1187b
--- /dev/null
+++ b/pipelines/geo/geo_flows_to_zarr.py
@@ -0,0 +1,128 @@
+'''
+Author: TJ Vandal
+
+Process NOAA L1b Files to Flow files between a range of dates
+
+Inputs: 10-minute NOAA L1b files, Optical Flow Model
+Output: 10-minute wind vectors
+'''
+
+import os, sys
+import time
+
+# from mpi4py import MPI
+#
+# comm = MPI.COMM_WORLD
+# rank = comm.Get_rank()
+# size = comm.Get_size()
+# name = MPI.Get_processor_name()
+rank = 0
+size = 1
+name = ''
+
+import torch
+
+# sys.stdout.write(f"Rank {rank} Size {size} Name {name} CUDA: {os.environ['CUDA_VISIBLE_DEVICES']} Devices count: {torch.cuda.device_count()}\n")
+# sys.stdout.write(f'number of devices: {torch.cuda.device_count()}\n')
+# sys.stdout.flush()
+
+device = 0  # index of the CUDA device to use (equivalent to 'cuda:0')
+# torch.cuda.set_device(device)
+
+import glob
+import argparse
+import datetime as dt
+
+import xarray as xr
+import numpy as np
+
+
+sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
+
+from windflow.datasets.geostationary import goesr, stats
+from windflow.inference import inference_flows
+
+parser = argparse.ArgumentParser()
+parser.add_argument("--model_name", default="raft", type=str)
+parser.add_argument("--checkpoint",
+                    default="/nobackupp10/tvandal/windflow/models/raft-size_512/checkpoint.pth.tar", type=str)
+parser.add_argument('--start_date', type=lambda s: dt.datetime.strptime(s, '%Y-%m-%d'),
+                    default='2021-01-01')
+parser.add_argument('--end_date', type=lambda s: dt.datetime.strptime(s, '%Y-%m-%d'), default='2021-01-02')
+parser.add_argument('--timestep', type=int, default=10)
+parser.add_argument('--band', type=int, default=10)
+parser.add_argument('--data_path', type=str, default='/nex/datapool/geonex/public/GOES16/NOAA-L1B/')
+parser.add_argument('--spatial', type=str, default=None)
+parser.add_argument('--upsample', type=float, default=None)
+parser.add_argument('--batch_size', type=int, default=1)
+parser.add_argument('--product', type=str, default='ABI-L1b-RadF')
+parser.add_argument('--output_path', type=str, default='data/L1B-WindFlow/')
+args = parser.parse_args()
+
+
+curr_date = args.start_date
+dates = []
+while curr_date <= args.end_date:
+    dates.append(curr_date)
+    #curr_date = curr_date + dt.timedelta(minutes=args.timestep)
+    curr_date = curr_date + dt.timedelta(hours=1)
+
+rank_idxs = np.arange(rank, len(dates)-1, size)
+
+inference = inference_flows.GeoFlows(args.model_name.replace('-guided', ''), 
+                                     args.data_path, 
+                                     overlap=128, 
+                                     tile_size=512,
+                                     channels=[args.band], 
+                                     product=args.product,
+                                     device=device,
+                                     batch_size=args.batch_size, 
+                                     timestep=args.timestep,
+                                     upsample_data=args.upsample, 
+                                     spatial=args.spatial)
+inference.load_checkpoint(args.checkpoint)
+
+for idx in rank_idxs:
+    t = dates[idx]
+
+    filename = f'GeoNEX_AMV_Flows-{args.product}-B{args.band:02g}-{t.year}{t.month:02g}{t.day:02g}_{t.hour:02g}{t.minute:02g}.{args.model_name}.zarr'
+    output_file_path = os.path.join(args.output_path, args.product, f'{t.year}', 
+                               f'{t.timetuple().tm_yday:03g}')
+    if not os.path.exists(output_file_path):
+        os.makedirs(output_file_path)
+
+    output_file = os.path.join(output_file_path, filename)
+    if os.path.exists(output_file):
+        continue
+
+    sys.stdout.write(f"Processing time {t}\n")
+    sys.stdout.flush()
+
+    #data, flows = inference.flows_by_time(t)
+    try:
+        t0 = time.time()
+        data = inference.flows_by_time(t, reproject=True)
+                
+        '''
+        img1 = data['Rad'].values
+        img1[img1 == 0] = np.nan
+        flows[0][img1 != img1] = np.nan
+        flows[1][img1 != img1] = np.nan
+        '''
+
+        ds_out = xr.Dataset(dict(U=data['U'], V=data['V']))
+        ds_out['band'] = args.band
+        ds_out['lat'] = data['lat']
+        ds_out['lon'] = data['lon']
+        ds_out['t'] = t
+        #ds_out.attrs['input_dataset_name'] = data.attrs['dataset_name']
+
+        #ds_out.to_netcdf(output_file)
+        ds_out.chunk({'lat': 1000, 'lon': 1000}).to_zarr(output_file)
+        print(f"Wrote to file: {output_file} -- Processing time {time.time()-t0} (seconds)")
+        sys.stdout.write(output_file + '\n')
+        sys.stdout.flush()
+
+        del ds_out, data
+    except Exception as err:
+        print('Exception Raised', err, t)
diff --git a/train.py b/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..e3ed95c4f1bd37599d5f16645545222274cf4976
--- /dev/null
+++ b/train.py
@@ -0,0 +1,167 @@
+import sys
+import os
+import time
+import argparse
+
+from torch.utils.tensorboard import SummaryWriter  # there is a known bug with importing SummaryWriter
+from collections import OrderedDict
+
+import numpy as np
+import torch
+torch.cuda.empty_cache()
+
+import torch.nn as nn
+from torch import optim
+from torch.utils import data
+from torch.nn.parallel import DistributedDataParallel as DDP
+import torch.distributed as dist
+import torch.multiprocessing as mp
+
+sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+from windflow.datasets import get_dataset
+from windflow.networks.models import get_flow_model
+from windflow.train.trainers import get_flow_trainer
+
+os.environ['PYTHONWARNINGS'] = 'ignore:semaphore_tracker:UserWarning' 
+
+
+def setup(rank, world_size, port):
+    '''
+    Setup multi-gpu processing group
+    Args:
+        rank: current rank
+        world_size: number of processes
+        port: which port to connect to
+    Returns:
+        None
+    '''
+    os.environ['MASTER_ADDR'] = 'localhost'
+    os.environ['MASTER_PORT'] = f'{port}'
+
+    # initialize the process group
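+    # "gloo" works on CPU-only machines; for multi-GPU training the "nccl"
+    # backend is generally faster.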
+    dist.init_process_group("gloo", rank=rank, world_size=world_size)
+
+
+def cleanup():
+    dist.destroy_process_group()
+
+
+def train_net(params, rank=0):
+    print('CUDA is available? ', torch.cuda.is_available())
+    # set device
+    # if not device:
+        # device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+    # device = rank  #% N_DEVICES
+    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+    print(f'in train_net rank {rank} on device {device}')
+
+    if params['ngpus'] > 1:
+        distribute = True
+    else:
+        distribute = False
+
+    dataset_train, dataset_valid = get_dataset(params['dataset'], params['data_path'],
+                                               scale_factor=params['scale_input'],
+                                               frames=params['input_frames'])
+    
+    data_params = {'batch_size': params['batch_size'] // params['ngpus'], 'shuffle': True,
+                   'num_workers': 8, 'pin_memory': True}
+    training_generator = data.DataLoader(dataset_train, **data_params)
+    val_generator = data.DataLoader(dataset_valid, **data_params)
+
+    model = get_flow_model(params['model_name'], small=False)
+    if distribute:
+        model = DDP(model.to(device), device_ids=[device], find_unused_parameters=True)
+    
+    trainer = get_flow_trainer(model, params['model_name'], 
+                               params['model_path'], 
+                               distribute=distribute,
+                               rank=rank,
+                               lr=params['lr'],
+                               loss=params['loss'])
+
+    trainer.model.to(device)
+    trainer.load_checkpoint()
+
+    print(f'Start Training {params["model_name"]} on {params["dataset"]}')
+
+    while trainer.global_step < params['max_iterations']:
+        running_loss = 0.0
+        t0 = time.time()
+        for batch_idx, (images, flows) in enumerate(training_generator):
+            images = images.to(device)
+            flows = flows.to(device)
+
+            log = (trainer.global_step % params['log_step'] == 0)
+
+            train_loss = trainer.step(images, flows, 
+                                      log=log, train=True)
+            if np.isinf(train_loss.cpu().detach().numpy()):
+                print(train_loss, images.cpu().detach().numpy().flatten(), flows.cpu().detach().numpy().flatten())
+                return
+
+            if log:
+                images, flows = next(iter(val_generator))
+                valid_loss = trainer.step(images.to(device), flows.to(device), 
+                                          log=log, train=False)
+                print(f'Rank {trainer.rank} @ Step: {trainer.global_step-1}, Training Loss: {train_loss:.3f}, Valid Loss: {valid_loss:.3f}')
+
+            if (trainer.global_step % params['checkpoint_step'] == 0) and (rank == 0):
+                trainer.save_checkpoint()
+
+
+def manual_experiment(args):
+    train_net(vars(args))
+
+
+def train_net_mp(rank, world_size, port, params):
+    '''
+    Setup and train on node
+    '''
+    setup(rank, world_size, port)
+    train_net(params, rank=rank)
+
+
+def run_training(args, world_size, port):
+    #params['batch_size'] = params['batch_size'] // world_size
+    if world_size > 1:
+        mp.spawn(train_net_mp,
+                 args=(world_size, port, vars(args)),
+                 nprocs=world_size,
+                 join=True)
+        cleanup()
+    else:
+        train_net(vars(args))
+
+
+if __name__ == "__main__":
+    # Feb 1 2020, Band 1 hyper-parameter search
+    # {'w': 0.6490224421024322, 's': 0.22545622639358046, 'batch_size': 128}
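+    # Example invocation (paths assume the defaults below):
+    #   python train.py --dataset g5nr --data_path data/G5NR_patches/ --ngpus 1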
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--model_path", default="models/raft-size_512/", type=str)
+    parser.add_argument("--dataset", default="g5nr", type=str)
+    parser.add_argument("--data_path", default="data/G5NR_patches/",    type=str)
+    parser.add_argument("--input_frames", default=2, type=int)
+    parser.add_argument("--model_name", default="raft", type=str)
+    # parser.add_argument("--gpus", default="0", type=str)
+    parser.add_argument("--batch_size", default=4, type=int)
+    parser.add_argument("--lr", default=1e-4, type=float)
+    parser.add_argument("--scale_input", default=None, type=float, help='Bilinear interpolation on input image.')
+    parser.add_argument('--max_iterations',type=int, default=2000000, help='Number of training iterations')
+    parser.add_argument("--log_step", default=1000, type=int)
+    parser.add_argument("--checkpoint_step", default=5000, type=int)
+    parser.add_argument("--ngpus", default=1, type=int)
+    parser.add_argument("--port", default=9009, type=int)
+    parser.add_argument("--loss", default='L1', type=str)
+
+    args = parser.parse_args()
+    print('torch.cuda.device_count: ', torch.cuda.device_count())
+    if torch.cuda.device_count() < args.ngpus:
+        print(f"Cannot running training because {args.ngpus} are not available.")
+
+    run_training(args, args.ngpus, args.port)
diff --git a/windflow/__init__.py b/windflow/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/windflow/datasets/__init__.py b/windflow/datasets/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..cb6f9ed4d36f0ec9b43b5cc3b9aa4e1ed338aad5
--- /dev/null
+++ b/windflow/datasets/__init__.py
@@ -0,0 +1,2 @@
+from . import utils
+from .datasets import *
diff --git a/windflow/datasets/datasets.py b/windflow/datasets/datasets.py
new file mode 100644
index 0000000000000000000000000000000000000000..9953d76b5f58dea5eafccf31f63b4dd3db58580f
--- /dev/null
+++ b/windflow/datasets/datasets.py
@@ -0,0 +1,23 @@
+import os
+
+from .g5nr import G5NRFlows
+from .geostationary import L1bPatches
+
+def get_dataset(dataset_name, data_path, size=512, scale_factor=None, frames=2):
+    # Load dataset
+    if dataset_name in ['gmao_osse_7km', 'g5nr', 'g5nr_7km']:
+        dataset_train = G5NRFlows(os.path.join(data_path, 'train'), 
+                                  size=size, scale_factor=scale_factor,
+                                  augment=True,
+                                 frames=frames)
+        dataset_valid = G5NRFlows(os.path.join(data_path, 'valid'), 
+                                  size=size, scale_factor=scale_factor,
+                                 frames=frames)
+    elif dataset_name.lower() in ['goes', 'goes16']:
+        # Load Geostationary Observations
+        dataset_train = L1bPatches(data_path, time_step=10, # minutes
+                                  size=size, mode='train')
+        dataset_valid = L1bPatches(data_path, time_step=10,
+                                  size=size, mode='valid')
+    else:
+        raise ValueError(f'Unknown dataset name: {dataset_name}')
+
+    return dataset_train, dataset_valid
diff --git a/windflow/datasets/flow_transforms.py b/windflow/datasets/flow_transforms.py
new file mode 100644
index 0000000000000000000000000000000000000000..ac399377aa4864e72bf7e4439f4e36d72b58db92
--- /dev/null
+++ b/windflow/datasets/flow_transforms.py
@@ -0,0 +1,626 @@
+# Copyright (c) 2019 Henrique Morimitsu
+# This file is released under the "MIT License". Please see the LICENSE file
+# included in this package.
+#
+# Many of the functions are modifications from the FlowNetPytorch package
+# available at https://github.com/ClementPinard/FlowNetPytorch, but adapted
+# to be implemented with PyTorch functions.
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+import warnings
+from abc import abstractmethod
+from typing import Any, Sequence, Tuple
+
+
+""" Common tranformation functions for augmenting data for training deep
+optical flow models. The functions most functions take two 4D torch.Tensor
+inputs: (1) a batch of images, (2) a batch of flows. The tensors are assumed
+to have a shape [N, C, H, W], where:
+    N = batch size,
+    C = number of channels (typically 3 for images and 2 for flows),
+    H = height,
+    W = width.
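+
+Example (a minimal sketch; images_np and flows_np stand for HWC-ordered
+numpy batches):
+    transform = Compose([ToTensor(),
+                         RandomCrop((256, 256)),
+                         RandomHorizontalFlip()])
+    images, flows = transform(images_np, flows_np)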
+"""
+
+
+class BaseTransform(object):
+    @abstractmethod
+    def __call__(self,
+                 images: Any,
+                 flows: Any) -> Tuple[Any, Any]:
+        pass
+
+
+class ToTensor(BaseTransform):
+    """ Converts a 4D numpy.ndarray or a list of 3D numpy.ndarrays into a
+    4D torch.Tensor. Unlike the torchvision implementation, no normalization is
+    applied to the values of the inputs.
+
+    Args:
+        images_order: str: optional, default 'HWC'
+            Must be one of {'CHW', 'HWC'}. Indicates whether the input images
+            have the channels first or last.
+        flows_order: str: optional, default 'HWC'
+            Must be one of {'CHW', 'HWC'}. Indicates whether the input flows
+            have the channels first or last.
+        fp16: bool: optional, default False
+            If True, the tensors are converted to half-precision floating point.
+        device: str or torch.device: optional, default 'cpu'
+            The torch device on which the tensors will be placed.
+    """
+    def __init__(self,
+                 images_order: str = 'HWC',
+                 flows_order: str = 'HWC',
+                 fp16: bool = False,
+                 device: Any = 'cpu') -> None:
+        self.images_order = images_order.upper()
+        self.flows_order = flows_order.upper()
+        self.fp16 = fp16
+        self.device = device
+
+    def __call__(self,
+                 images: Any,
+                 flows: Any) -> Tuple[torch.Tensor, torch.Tensor]:
+        if isinstance(images, list) or isinstance(images, tuple):
+            images = np.stack(images, axis=0)
+        if isinstance(flows, list) or isinstance(flows, tuple):
+            flows = np.stack(flows, axis=0)
+        images = torch.from_numpy(images)
+        flows = torch.from_numpy(flows)
+        if self.images_order == 'HWC':
+            images = images.permute(0, 3, 1, 2)
+        if self.flows_order == 'HWC':
+            flows = flows.permute(0, 3, 1, 2)
+        if self.fp16:
+            images = images.half()
+            flows = flows.half()
+        else:
+            images = images.float()
+            flows = flows.float()
+        images = images.to(device=self.device)
+        flows = flows.to(device=self.device)
+        return images, flows
+
+
+class Compose(BaseTransform):
+    """ Similar to torchvision Compose. Applies a series of transforms from
+    the input list in sequence.
+
+    Args:
+        transforms_list: Sequence[BaseTransform]:
+            A sequence of transforms to be applied.
+    """
+    def __init__(self,
+                 transforms_list: Sequence[BaseTransform]) -> None:
+        self.transforms_list = transforms_list
+
+    def __call__(self,
+                 images: torch.Tensor,
+                 flows: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        for t in self.transforms_list:
+            images, flows = t(images, flows)
+        return images, flows
+
+
+class Normalize(BaseTransform):
+    """ Normalizes the images according to the equation:
+    images = (images - means) / stdevs.
+
+    Args:
+        means: Sequence[float]:
+            A list with the values of the means of each of the channels of the
+            image.
+        stdevs: Sequence[float]:
+            A list with the values of the standard deviations of each of the
+            channels of the image.
+    """
+    def __init__(self,
+                 means: Sequence[float],
+                 stdevs: Sequence[float]) -> None:
+        self.means = torch.from_numpy(np.array(means)).reshape(1, -1, 1, 1)
+        self.stdevs = torch.from_numpy(np.array(stdevs)).reshape(1, -1, 1, 1)
+        self.has_converted = False
+
+    def __call__(self,
+                 images: torch.Tensor,
+                 flows: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        if not self.has_converted:
+            self.means = self.means.type(images.dtype).to(images.device)
+            self.stdevs = self.stdevs.type(images.dtype).to(images.device)
+            self.has_converted = True
+        images -= self.means
+        images /= self.stdevs
+        return images, flows
+
+
+class Crop(BaseTransform):
+    """ Crops a region at the given position of the image according to the
+    given size.
+
+    Args:
+        coordinates: Tuple[int, int]:
+            The (y, x) position of the origin of the crop. The coordinates
+            values depend on the choice of the parameters is_center. If
+            is_center == True, then coordinates should be the center point of
+            the region to be cropped. Otherwise, it should be the top-left
+            corner of the region.
+        size: Tuple[int, int]:
+            The size (height, width) of the region to be cropped.
+        is_center: bool: optional, default True
+            Determines how the coordinates will be used. See the comments about
+            coordinates.
+    """
+    def __init__(self,
+                 coordinates: Tuple[int, int],
+                 size: Tuple[int, int],
+                 is_center: bool = True) -> None:
+        y, x = coordinates
+        assert y >= 0 and x >= 0, 'The coordinates must be positive.'
+        h, w = size
+        assert h > 0 and w > 0, 'The size must be larger than zero.'
+        self.coordinates = coordinates
+        self.size = size
+        self.is_center = is_center
+
+    def __call__(self,
+                 images: torch.Tensor,
+                 flows: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        _, _, h, w = images.shape
+        ch, cw = self.size
+        cy, cx = self.coordinates
+        if self.is_center:
+            ch2 = ch // 2
+            cw2 = cw // 2
+            y1 = cy - ch2
+            x1 = cx - cw2
+            y2 = cy + ch - ch2
+            x2 = cx + cw - cw2
+        else:
+            y1 = cy
+            x1 = cx
+            y2 = y1 + ch
+            x2 = x1 + cw
+        if y1 < 0:
+            warnings.warn('y1 < 0 when cropping. It will be set as y1 = 0.',
+                          RuntimeWarning)
+            y1 = 0
+        if x1 < 0:
+            warnings.warn('x1 < 0 when cropping. It will be set as x1 = 0.',
+                          RuntimeWarning)
+            x1 = 0
+        if y2 > h:
+            warnings.warn(
+                'y2 > image_height when cropping. '
+                'It will be set as y2 = image_height.', RuntimeWarning)
+            y2 = h
+        if x2 > w:
+            warnings.warn(
+                'x2 > image_width when cropping. '
+                'It will be set as x2 = image_width.', RuntimeWarning)
+            x2 = w
+        images = images[:, :, y1:y2, x1:x2]
+        flows = flows[:, :, y1:y2, x1:x2]
+        return images, flows
+
+
+
+class CenterCrop(BaseTransform):
+    """ Crops a region in the center of the image according to the given size.
+
+    Args:
+        size: Tuple[int, int]:
+            The size (height, width) of the region to be cropped.
+    """
+    def __init__(self,
+                 size: Tuple[int, int]) -> None:
+        self.size = size
+
+    def __call__(self,
+                 images: torch.Tensor,
+                 flows: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        _, _, h, w = images.shape
+        th, tw = self.size
+        if w == tw and h == th:
+            return images, flows
+
+        x1 = (w - tw) // 2
+        y1 = (h - th) // 2
+        images = images[:, :, y1:y1+th, x1:x1+tw]
+        flows = flows[:, :, y1:y1+th, x1:x1+tw]
+        return images, flows
+
+
+class Resize(BaseTransform):
+    """ Resizes the inputs to a new given size.
+
+    Args:
+        size: Tuple[int, int]:
+            The new size (height, width) of the inputs.
+    """
+    def __init__(self,
+                 size: Tuple[int, int]) -> None:
+        self.size = size
+
+    def __call__(self,
+                 images: torch.Tensor,
+                 flows: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        yscale = float(self.size[0]) / images.shape[2]
+        xscale = float(self.size[1]) / images.shape[3]
+        images = F.interpolate(
+            images, size=self.size, mode='bilinear', align_corners=False)
+        flows = F.interpolate(
+            flows, size=self.size, mode='bilinear', align_corners=False)
+        flows[:, 0] *= yscale
+        flows[:, 1] *= xscale
+        return images, flows
+
+
+class RandomHorizontalFlip(BaseTransform):
+    """ Randomly flips the entire batch horizontally with a probability of 0.5.
+    """
+    def __call__(self,
+                 images: torch.Tensor,
+                 flows: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        if np.random.rand() < 0.5:
+            images = torch.flip(images, dims=[3])
+            flows = torch.flip(flows, dims=[3])
+            flows[:, 0] *= -1
+        return images, flows 
+
+
+class RandomVerticalFlip(BaseTransform):
+    """ Randomly flips the entire batch vertically with a probability of 0.5.
+    """
+    def __call__(self,
+                 images: torch.Tensor,
+                 flows: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        if np.random.rand() < 0.5:
+            images = torch.flip(images, dims=[2])
+            flows = torch.flip(flows, dims=[2])
+            flows[:, 1] *= -1
+        return images, flows
+
+
+class RandomAdditiveColor(BaseTransform):
+    """ Adds a random value to the image. The value is sampled from a normal
+    distribution.
+
+    Args:
+        stdev: float:
+            The standard deviation value for the normal distribution.
+        independent: bool: optional, default False
+            if True, one different value is sampled for each image of the
+            batch. If False, the same value is added to all images.
+    """
+    def __init__(self,
+                 stdev: float,
+                 independent: bool = False) -> None:
+        self.stdev = stdev
+        self.independent = independent
+
+    def __call__(self,
+                 images: torch.Tensor,
+                 flows: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        if self.independent:
+            add_vals = torch.normal(0.0, self.stdev, [images.shape[0]])
+        else:
+            add_vals = torch.normal(0.0, self.stdev, [1])
+            add_vals = add_vals.repeat(images.shape[0])
+        add_vals = add_vals.type(images.dtype)
+        add_vals = add_vals.to(images.device)
+        add_vals = add_vals.reshape(-1, 1, 1, 1)
+        images += add_vals
+        return images, flows
+
+
+class RandomMultiplicativeColor(BaseTransform):
+    """ Multiplies the image by a random value. The value is sampled from an
+    uniform distribution according to the given boundaries.
+
+    Args:
+        range_min: float:
+            The minimum value of the uniform distribution.
+        range_max: float:
+            The maximum value of the uniform distribution.
+        independent: bool: optional, default False
+            if True, one different value is sampled for each image of the
+            batch. If False, all images are multiplied by the same value.
+    """
+    def __init__(self,
+                 range_min: float,
+                 range_max: float,
+                 independent: bool = False) -> None:
+        self.range_min = range_min
+        self.range_max = range_max
+        self.independent = independent
+        self.interval = range_max - range_min
+
+    def __call__(self,
+                 images: torch.Tensor,
+                 flows: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        if self.independent:
+            mult_vals = torch.rand([images.shape[0]])
+        else:
+            mult_vals = torch.rand([1])
+            mult_vals = mult_vals.repeat(images.shape[0])
+        mult_vals = mult_vals.type(images.dtype)
+        mult_vals = mult_vals.to(images.device)
+        mult_vals *= self.interval
+        mult_vals += self.range_min
+        mult_vals = mult_vals.reshape(-1, 1, 1, 1)
+        images *= mult_vals
+        return images, flows
+
+
+class RandomGammaColor(BaseTransform):
+    """ Applies a power transform to the image by a random value. The value is
+    sampled from an uniform distribution according to the given boundaries.
+
+    Args:
+        range_min: float:
+            The minimum value of the uniform distribution.
+        range_max: float:
+            The maximum value of the uniform distribution.
+        pixel_min: float:
+            The minimum valid value of a pixel of the image.
+        pixel_max: float:
+            The maximum valid value of a pixel of the image.
+        independent: bool: optional, default False
+            if True, one different value is sampled for each image of the
+            batch. If False, all images are multiplied by the same value.
+    """
+    def __init__(self,
+                 range_min: float,
+                 range_max: float,
+                 pixel_min: float,
+                 pixel_max: float,
+                 independent: bool = False) -> None:
+        self.range_min = range_min
+        self.range_max = range_max
+        self.pixel_min = pixel_min
+        self.pixel_max = pixel_max
+        self.independent = independent
+        self.interval = range_max - range_min
+        self.pixel_interval = pixel_max - pixel_min
+
+    def __call__(self,
+                 images: torch.Tensor,
+                 flows: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        if self.independent:
+            expo_vals = torch.rand([images.shape[0]])
+        else:
+            expo_vals = torch.rand([1])
+            expo_vals = expo_vals.repeat(images.shape[0])
+        expo_vals = expo_vals.type(images.dtype)
+        expo_vals = expo_vals.to(images.device)
+        expo_vals *= self.interval
+        expo_vals += self.range_min
+        expo_vals = torch.pow(expo_vals, -1.0)
+        expo_vals = expo_vals.reshape(-1, 1, 1, 1)
+        images = (
+            torch.pow((images-self.pixel_min)/self.pixel_interval, expo_vals) *
+            self.pixel_interval + self.pixel_min)
+        return images, flows
+
+
+class AddGaussianNoise(BaseTransform):
+    """ Adds noise to the image by summing each pixel to a value sampled from
+    a normal distribution.
+
+    Args:
+        stdev: float:
+            The standard deviation value for the normal distribution.
+    """
+    def __init__(self,
+                 stdev: float) -> None:
+        self.stdev = stdev
+
+    def __call__(self,
+                 images: torch.Tensor,
+                 flows: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        noise = torch.normal(0.0, self.stdev, images.shape)
+        noise = noise.type(images.dtype)
+        noise = noise.to(images.device)
+        images += noise
+        return images, flows
+
+
+class RandomScale(BaseTransform):
+    """ Rescale the inputs to a new size according to the given boundaries.
+
+    Args:
+        range_min: float:
+            The minimum value of the uniform distribution.
+        range_max: float:
+            The maximum value of the uniform distribution.
+    """
+    def __init__(self,
+                 range_min: float,
+                 range_max: float) -> None:
+        self.range_min = range_min
+        self.range_max = range_max
+        self.interval = range_max - range_min
+
+    def __call__(self,
+                 images: torch.Tensor,
+                 flows: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        scale = torch.rand([1])
+        scale = scale.type(flows.dtype)
+        scale = scale.to(flows.device)
+        scale *= self.interval
+        scale += self.range_min
+        new_size = (int(scale*images.shape[2]), int(scale*images.shape[3]))
+        images = F.interpolate(
+            images, size=new_size, mode='bilinear', align_corners=False)
+        flows = F.interpolate(
+            flows, size=new_size, mode='bilinear', align_corners=False)
+        flows *= scale
+        return images, flows
+
+
+class RandomCrop(BaseTransform):
+    """ Crops the inputs at a random location according to a given size.
+
+    Args:
+        size: Tuple[int, int]:
+            The size [height, width] of the crop.
+    """
+    def __init__(self,
+                 size: Tuple[int, int]) -> None:
+        self.size = size
+
+    def __call__(self,
+                 images: torch.Tensor,
+                 flows: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        _, _, h, w = images.shape
+        th, tw = self.size
+        if w == tw and h == th:
+            return images, flows
+
+        x1 = np.random.randint(0, w - tw)
+        y1 = np.random.randint(0, h - th)
+        images = images[:, :, y1:y1+th, x1:x1+tw]
+        flows = flows[:, :, y1:y1+th, x1:x1+tw]
+        return images, flows
+
+
+class RandomTranslate(BaseTransform):
+    """ Creates a translation between images by applying a random alternated
+    crop on the sequence of inputs. A translation value t is randomly selected
+    first. Then, the first image is cropped by a box translated by t. The 
+    second image will be cropped by a reversed translation -t. The third will
+    be cropped by t again, and so on...
+
+    Args:
+        translation: Tuple[int, int]
+            The maximum absolute translation values (in pixels) in the
+            vertical and horizontal axis respectively.
+    """
+    def __init__(self,
+                 translation: Tuple[int, int]) -> None:
+        self.translation = translation
+
+    def __call__(self,
+                 images: torch.Tensor,
+                 flows: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        _, _, h, w = images.shape
+        th, tw = self.translation
+        tw = np.random.randint(-tw, tw)
+        th = np.random.randint(-th, th)
+        if tw == 0 and th == 0:
+            return images, flows
+
+        new_images = images[:, :, abs(th):, abs(tw):].clone()
+        new_flows = flows[:, :, abs(th):, abs(tw):].clone()
+
+        x1, x2 = max(0, tw), min(w+tw, w)
+        y1, y2 = max(0, th), min(h+th, h)
+        new_images[::2] = images[::2, :, y1:y2, x1:x2]
+        new_flows[::2] = flows[::2, :, y1:y2, x1:x2]
+        new_flows[::2, 0] += tw
+        new_flows[::2, 1] += th
+
+        x1, x2 = max(0, -tw), min(w-tw, w)
+        y1, y2 = max(0, -th), min(h-th, h)
+        new_images[1::2] = images[1::2, :, y1:y2, x1:x2]
+        new_flows[1::2] = flows[1::2, :, y1:y2, x1:x2]
+        new_flows[1::2, 0] -= tw
+        new_flows[1::2, 1] -= th
+
+        return new_images, new_flows
+
+
+class RandomRotate(BaseTransform):
+    """ Applies random rotation to the inputs. The inputs are rotated around
+    the center of the image. First all inputs are rotated by the same random
+    major `angle`. Then, another random angle a is sampled according to `diff_angle`.
+    The first image will be rotated by a. The second image will be rotated by a
+    reversed angle -a. The third will be rotated by a again, and so on...
+
+    Args:
+        angle: float:
+            The maximum absolute value to sample the major angle from.
+        diff_angle: float: optional, default 0
+            The maximum absolute value to sample the angle difference between
+            consecutive images.
+    """
+    def __init__(self,
+                 angle: float,
+                 diff_angle: float = 0):
+        self.angle = angle
+        self.diff_angle = diff_angle
+
+    def __call__(self,
+                 images: torch.Tensor,
+                 flows: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        major_angle = np.random.uniform(-self.angle, self.angle)
+        inter_angle = np.random.uniform(-self.diff_angle, self.diff_angle)
+
+        _, _, h, w = images.shape
+
+        def generate_rotation_grid(rot_angle, batch_size):
+            vy, vx = torch.meshgrid(torch.arange(h), torch.arange(w))
+            vx = vx.type(flows.dtype)
+            vy = vy.type(flows.dtype)
+            vx = vx.to(flows.device)
+            vy = vy.to(flows.device)
+            vx -= (w-1.0)/2.0
+            vy -= (h-1.0)/2.0
+            angle_rad = rot_angle*2*np.pi / 360
+            rotx = np.cos(angle_rad)*vx - np.sin(angle_rad)*vy
+            roty = np.sin(angle_rad)*vx + np.cos(angle_rad)*vy
+            rotx = rotx / ((w-1)/2)
+            roty = roty / ((h-1)/2)
+            rot_grid = torch.stack((rotx, roty), dim=2)[None]
+            rot_grid = rot_grid.repeat(batch_size, 1, 1, 1)
+            return rot_grid
+
+        def generate_rotation_matrix(rot_angle, batch_size):
+            vx, vy = torch.meshgrid(torch.arange(h), torch.arange(w))
+            vx = vx.type(flows.dtype)
+            vy = vy.type(flows.dtype)
+            vx = vx.to(flows.device)
+            vy = vy.to(flows.device)
+            rotx = (vx - h/2.0) * (rot_angle*np.pi/180.0)
+            roty = -(vy - w/2.0) * (rot_angle*np.pi/180.0)
+            rot_mat = torch.stack((rotx, roty), dim=0)[None]
+            rot_mat = rot_mat.repeat(batch_size, 1, 1, 1)
+            return rot_mat
+
+        def rotate_flow(flow, rot_angle):
+            angle_rad = rot_angle*2*np.pi / 360
+            rot_flow = flow.clone()
+            rot_flow[:, 0] = (
+                np.cos(angle_rad)*flow[:, 0] + np.sin(angle_rad)*flow[:, 1])
+            rot_flow[:, 1] = (
+                -np.sin(angle_rad)*flow[:, 0] + np.cos(angle_rad)*flow[:, 1])
+            return rot_flow
+
+        angle = major_angle - inter_angle / 2
+        num_images = images[::2].shape[0]
+        rot_grid = generate_rotation_grid(angle, num_images)
+        mn = images.min()
+        mx = images.max()
+
+        images[::2] = F.grid_sample(images[::2], rot_grid, mode='bilinear', align_corners=True)
+
+        num_flows = flows[::2].shape[0]
+        rot_mat = generate_rotation_matrix(inter_angle, num_flows)
+        flows[::2] += rot_mat
+        flows[::2] = F.grid_sample(
+            flows[::2], rot_grid[:num_flows], mode='bilinear', align_corners=True)
+        flows[::2] = rotate_flow(flows[::2], angle)
+
+        angle = major_angle + inter_angle / 2
+        num_images = images[1::2].shape[0]
+        rot_grid = generate_rotation_grid(angle, num_images)
+        images[1::2] = F.grid_sample(images[1::2], rot_grid, mode='bilinear', align_corners=True)
+        num_flows = flows[1::2].shape[0]
+        flows[1::2] -= rot_mat
+        flows[1::2] = F.grid_sample(
+            flows[1::2], rot_grid[:num_flows], mode='bilinear', align_corners=True)
+        flows[1::2] = rotate_flow(flows[1::2], angle)
+
+        return images, flows
diff --git a/windflow/datasets/g5nr/__init__.py b/windflow/datasets/g5nr/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b50680b15f72599a8e7f555eb1375539f3ccb343
--- /dev/null
+++ b/windflow/datasets/g5nr/__init__.py
@@ -0,0 +1 @@
+from .g5nr_loader import G5NRFlows
diff --git a/windflow/datasets/g5nr/g5nr.py b/windflow/datasets/g5nr/g5nr.py
new file mode 100644
index 0000000000000000000000000000000000000000..a3e05461f1eed5e7efb290b8ed733efe0e0abe0e
--- /dev/null
+++ b/windflow/datasets/g5nr/g5nr.py
@@ -0,0 +1,91 @@
+'''
+Author: TJ Vandal
+
+GMAO G5NR files from Will McCarty in June 2021
+Script to generate training data for optical flow.
+'''
+
+import os
+import glob
+import xarray as xr
+import numpy as np
+import pandas as pd
+
+from mpi4py import MPI
+
+comm = MPI.COMM_WORLD
+MPI_RANK = comm.Get_rank()
+MPI_SIZE = comm.Get_size()
+
+def get_files(folder, mode):
+    pattern = os.path.join(folder, '*.nc4')
+    files = sorted(glob.glob(pattern))
+
+    variables = ['QV', 'U', 'V']
+    var_files = {v: sorted([f for f in files if f'_{v}_' in f]) for v in variables}
+    var_files = pd.DataFrame(var_files)
+
+    N_files = len(var_files)
+    if mode == 'train':
+        var_files = var_files.iloc[0:int(N_files*0.7)]
+    elif mode == 'valid':
+        var_files = var_files.iloc[int(N_files*0.7):int(N_files*0.8)]
+    elif mode == 'test':
+        var_files = var_files.iloc[int(N_files*0.8):]
+
+    return var_files
+
+def make_training_data(input_folder, output_folder, mode='train', patch_size=560, step=300):
+    files = get_files(input_folder, mode)
+
+    for i in np.arange(MPI_RANK, len(files), MPI_SIZE):
+        data = xr.open_mfdataset(files.values[i])
+        data = data.sel(lat=slice(-80, 80))
+        n_lats = data.lat.shape[0]
+        n_lons = data.lon.shape[0]
+        n_lev = data.lev.shape[0]
+
+        time = data.time.values[0]
+        timestamp = (time.astype('uint64') / 1e6).astype('uint32')
+
+        for lev in data.lev.values:
+            for lat_idx in np.arange(0, n_lats-patch_size, step):
+                for lon_idx in np.arange(0, n_lons-patch_size, step):
+
+                    sub_ds = data.isel(lat=slice(lat_idx, lat_idx+patch_size),
+                                       lon=slice(lon_idx, lon_idx+patch_size))\
+                                     .sel(lev=lev)
+                    lat_0 = sub_ds.lat.values[0]
+                    lon_0 = sub_ds.lon.values[0]
+                    sub_folder = os.path.join(output_folder, f'{int(lev)}_{lat_0}_{lon_0}')
+                    sub_file = os.path.join(sub_folder, str(timestamp).zfill(11) + '.nc4')
+                    if os.path.exists(sub_file):
+                        continue
+
+                    os.makedirs(sub_folder, exist_ok=True)
+
+                    sub_ds.to_netcdf(sub_file)
+                    print(f"Rank: {MPI_RANK}, Saved patch to file {sub_file}")
+
+if __name__ == '__main__':
+    
+    import argparse
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument('data', metavar='G5NR_7km/', type=str,
+                        help='Directory containing the downloaded G5NR files')
+    parser.add_argument('output', metavar='G5NR_7km_training/', type=str,
+                        help='Directory where the training patches will be saved')
+    args = parser.parse_args()
+
+
+    data_folder = args.data
+    training_data_folder = args.output
+    make_training_data(data_folder, os.path.join(training_data_folder, 'train'), mode='train')
+    make_training_data(data_folder, os.path.join(training_data_folder, 'valid'), mode='valid')
+
+
diff --git a/windflow/datasets/g5nr/g5nr_loader.py b/windflow/datasets/g5nr/g5nr_loader.py
new file mode 100644
index 0000000000000000000000000000000000000000..e2058656e0d9b73988a3293b07b457269c5b137a
--- /dev/null
+++ b/windflow/datasets/g5nr/g5nr_loader.py
@@ -0,0 +1,109 @@
+import os
+import glob
+import numpy as np
+import xarray as xr
+
+import torch
+from torch.utils import data
+import cv2
+
+from .. import flow_transforms
+from ..preprocess import image_histogram_equalization
+
+class TileFlows(data.Dataset):
+    def __init__(self, directory, scale_factor=None, size=None, 
+                 augment=False, frames=2, convert_cartesian=False):
+        '''
+        7-km, 30-min simulation
+        Units: m/s
+        Scale to cartesian coordinates: 1800 (seconds) / 7000 (m)
+        '''
+        self.directory = directory
+        self.files = np.array(glob.glob(os.path.join(self.directory, '*.nc4')))
+        
+        filenames = np.array([int(os.path.basename(f.replace('.nc4',''))) for f in self.files])
+        filesort = np.argsort(filenames)
+        self.files = self.files[filesort]
+        
+        self.size = size
+        self.scale_factor = scale_factor
+        self.augment = augment
+        self.frames = frames
+        self.convert_cartesian = convert_cartesian
+        
+        self.rand_transform = flow_transforms.Compose([
+            flow_transforms.ToTensor(images_order='CHW', flows_order='CHW'),  # Must come before other transforms
+            #flow_transforms.RandomRotate(15), # rotate fills in a lot of the image with zeros
+            flow_transforms.RandomCrop((size, size)),
+            flow_transforms.RandomHorizontalFlip(),
+            flow_transforms.RandomVerticalFlip(),
+            #flow_transforms.RandomAdditiveColor(0.01),
+            #flow_transforms.RandomMultiplicativeColor(0.9, 1.1)
+        ])
+
+    def __len__(self):
+        return len(self.files)-self.frames+1
+
+    def __getitem__(self, idx):
+        try:
+            ds = xr.open_mfdataset(self.files[idx:idx+self.frames])
+            h, w = ds.U.isel(time=0).shape
+            diff = ds.time.values[1] - ds.time.values[0]
+            if diff != np.timedelta64(30,'m'):
+                #print("Difference in time not 30 minutes", ds.time)
+                return self.__getitem__((idx+1) % len(self))
+
+            if self.scale_factor:
+                h, w = int(h * self.scale_factor), int(w * self.scale_factor)
+                new_lats = np.linspace(ds.lat.values[0], ds.lat.values[-1], h)
+                new_lons = np.linspace(ds.lon.values[0], ds.lon.values[-1], w)
+                ds = ds.interp(lat=new_lats, lon=new_lons)
+
+            qv = ds['QV'].values
+            u = ds['U'].values
+            v = ds['V'].values
+        except (KeyError, xr.MergeError) as err:
+            print("Error in G5NR Loader __getitem__", err)
+            return self.__getitem__((idx+self.frames) % len(self))
+        except (AttributeError) as err:
+            print(f"Error: {err}")
+            print("Files:", self.files[idx:idx+self.frames])
+            return self.__getitem__((idx+self.frames) % len(self))
+        
+        uv = np.concatenate([u[:,np.newaxis], v[:,np.newaxis]], 1) #/ 7000. * (30 * 60)
+        uv[~np.isfinite(uv)] = 0. 
+        qv[~np.isfinite(qv)] = 0.
+
+        # pixel size differs by latitude
+        # see haversine formula -- compute in lon direction, lats are constant
+
+        if self.convert_cartesian:
+            #uv = uv * CARTESIAN_SCALE
+
+            lat_rad = np.radians(ds.lat.values)
+            lon_rad = np.radians(ds.lon.values)
+            a = np.cos(lat_rad)**2 * np.sin((lon_rad[1]-lon_rad[0])/2)**2
+            d = 2 * 6378.137 * np.arcsin(a**0.5)  # km between neighboring pixels
+            size_per_pixel = np.repeat(np.expand_dims(d, -1), len(lon_rad), axis=1) # km
+
+            # (m/s) / (km/pixel * 1000 m/km) * 1800 s = pixels per 30-min frame
+            uv = uv / size_per_pixel / 1000 * 1800
+        
+        qv = image_histogram_equalization(qv)
+        
+        images = [q[np.newaxis] for q in qv]
+        flows = [_uv for _uv in uv]
+        images_tensor, flow_tensor = self.rand_transform(images, flows)        
+        return images_tensor, flow_tensor
+    
+    
+class G5NRFlows(data.ConcatDataset):
+    def __init__(self, directory, scale_factor=None, size=None, augment=False, frames=2, convert_cartesian=True):
+        self.directory = directory
+        tiles = os.listdir(self.directory)
+        tile_paths = [os.path.join(self.directory, t) for t in tiles]
+        data.ConcatDataset.__init__(self, [TileFlows(p, scale_factor=scale_factor, size=size, augment=augment, frames=frames, convert_cartesian=convert_cartesian) for p in tile_paths])
+        
+        
+if __name__ == '__main__':
+    dataset = G5NRFlows('/nobackupp10/tvandal/nex-ai-opticalflow/data/G5NR_7km_training/train/', size=128)
+    print(dataset[0][0].shape)
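+
+    # a minimal sketch of wrapping the dataset in a DataLoader; the batch size
+    # and worker count are illustrative, not values used for training
+    loader = data.DataLoader(dataset, batch_size=4, shuffle=True, num_workers=0)
+    images, flows = next(iter(loader))
+    print(images.shape, flows.shape)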
diff --git a/windflow/datasets/geostationary/__init__.py b/windflow/datasets/geostationary/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..365315906f3a049e9561971fbddba33c460c474a
--- /dev/null
+++ b/windflow/datasets/geostationary/__init__.py
@@ -0,0 +1,2 @@
+#from .schema import BandSchema1000, BandSchema256
+from .pytorch_dataloader import L1bResized, L1bPatches
diff --git a/windflow/datasets/geostationary/geonexl1g.py b/windflow/datasets/geostationary/geonexl1g.py
new file mode 100755
index 0000000000000000000000000000000000000000..9e2745ef2b4ae98fdba14d3aaa210cad4a18effb
--- /dev/null
+++ b/windflow/datasets/geostationary/geonexl1g.py
@@ -0,0 +1,296 @@
+import os, sys
+import numpy as np
+from pyhdf.SD import SD, SDC
+from scipy import ndimage
+import glob
+import pandas as pd
+import xarray as xr
+from joblib import Parallel, delayed
+
+
+'''
+# Basic parameters
+lat_0 = 60
+lon_0 = -180
+res_x = 0.01                # 0.02 for the 2km grid
+res_y = 0.01                # 0.02 for the 2km grid
+tile_xdim = 600            # 300 for the 2km grid
+tile_ydim = 600            # 300 for the 2km grid
+
+# Input information
+hid = 14                       # 0 - 59
+vid =  5                        # 0 - 19
+x = 0                            # column/sample, 0-(tile_xdim-1) 
+y = 0                            # row/line, 0-(tile_ydim-1) 
+
+# Output formula 
+lat_ulcnr = lat_0 - (vid*tile_ydim + y)*res_y        # upper-left corner latitude
+lon_ulcnr = lon_0 + (hid*tile_xdim + x)*res_x     # upper-left corner longitude
+lat_cntr = lat_ulcnr - 0.5*res_y                            # center latitude
+lon_cntr = lon_ulcnr + 0.5*res_x                        # center longitude
+'''
+
+def get_tile_coords(tile, resolution_km=1.):
+    hid = int(tile[1:3])
+    vid = int(tile[4:6])
+    
+    lat_0 = 60
+    lon_0 = -180
+    res_x = 0.01 * resolution_km
+    res_y = 0.01 * resolution_km
+    tile_xdim = int(600 / resolution_km)
+    tile_ydim = int(600 / resolution_km)
+
+    lat_ulcnr = lat_0 - vid*tile_ydim*res_y               # upper-left corner latitude
+    lon_ulcnr = lon_0 + hid*tile_xdim*res_x               # upper-left corner longitude
+    lat_cntr = lat_ulcnr - 0.5*res_y                      # center latitude
+    lon_cntr = lon_ulcnr + 0.5*res_x                      # center longitude
+    lats = np.linspace(lat_0 - vid*tile_ydim*res_y, lat_0 - (vid+1)*tile_ydim*res_y+res_y, tile_ydim, endpoint=True)
+    lons = np.linspace(lon_0 + hid*tile_xdim*res_x, lon_0 + (hid+1)*tile_xdim*res_x-res_x, tile_xdim, endpoint=True)
+    return lats, lons
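+
+# usage sketch (the tile id is hypothetical):
+#   lats, lons = get_tile_coords('h14v05', resolution_km=1.)
+#   len(lats) == len(lons) == 600 on the 1-km GeoNEX grid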
+
+class L1GFile(object):
+    '''
+    Reads a single L1G file at a common resolution. Channels are bilinearly interpolated to the defined resolution.
+    Args:
+        file: Filepath to the L1G HDF file
+        bands (optional): List of bands, default=list(range(1,17))
+        resolution_km (optional): Resolution in km for common grid, default=2
+    '''
+    def __init__(self, file, bands=list((range(1,17))),
+                 resolution_km=2.):
+        self.file = file
+        self.bands = bands
+        self.resolution_km = resolution_km
+        self.resolution_size = int(600. / resolution_km)
+        self.reflective_bands = list(range(1,7))
+        self.emissive_bands = list(range(7,17))
+
+    def load(self):
+        fp = SD(self.file, SDC.READ)
+        data_array = np.zeros((self.resolution_size, self.resolution_size, len(self.bands)))
+        for i, b in enumerate(self.bands):
+            b_obj = fp.select('BAND%02i' % b)
+            attrs = b_obj.attributes()
+            scale_factor = attrs['Scale_Factor']
+            #fill_value = attrs['_FillValue'] ## this fill value seems to be wrong in l1g
+            fill_value = 32768.
+            arr = b_obj.get()[:].astype(np.float32)
+            offset = attrs['Offset_Constant']
+            arr[arr == fill_value] = np.nan
+            arr *= scale_factor
+            arr += offset
+            if arr.shape[0] != self.resolution_size:
+                arr = ndimage.zoom(arr, self.resolution_size/arr.shape[0], order=1)
+            data_array[:,:,i] = arr
+        #self.data_array = data_array
+        return data_array
+
+    def solar(self):
+        fp = SD(self.file, SDC.READ)
+        sa = fp.select('Solar_Azimuth').get()[:]
+        sz = fp.select('Solar_Zenith').get()[:]
+        sa = ndimage.zoom(sa, self.resolution_size/sa.shape[0], order=1)
+        sz = ndimage.zoom(sz, self.resolution_size/sz.shape[0], order=1)
+        return sa*0.01, sz*0.01
+    
+    def load_xarray(self):
+        data = self.load()
+        fp = SD(self.file, SDC.READ)
+        tile_path = os.path.join("/".join(self.file.split("/")[:-3]), 'constant')
+        tile_file = glob.glob(tile_path + '/*.hdf')[0]
+        sat, sensor, date, time, _, tile, _ = os.path.basename(self.file).replace(".hdf","").split("_")
+        lats, lons = get_tile_coords(tile, self.resolution_km)
+        sa, sz = self.solar()
+        #print(tile, 'lats', len(lats))
+
+        da = xr.DataArray(data, dims=('lat', 'lon', 'band'), coords=dict(lat=lats, lon=lons, band=self.bands))
+        zenith = xr.DataArray(sz, dims=('lat', 'lon'), coords=dict(lat=lats, lon=lons))
+        azimuth = xr.DataArray(sa, dims=('lat', 'lon'), coords=dict(lat=lats, lon=lons))
+        return xr.Dataset({'data': da, 'zenith': zenith, 'azimuth': azimuth})
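+
+# usage sketch (the file path is hypothetical; it follows the L1G naming convention):
+#   f = L1GFile('/path/to/GO16_ABI_20190214_1200_GEONEXL1G_h14v05_02.hdf', bands=[9])
+#   arr = f.load()         # (300, 300, 1) array on the default 2-km grid
+#   ds = f.load_xarray()   # Dataset with 'data', 'zenith', 'azimuth' on lat/lon coords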
+        
+class GeoNEXL1G(object):
+    '''
+    Get information on an L1G data directory: available tiles, years, and files.
+    File lists are cached locally because retrieving them can be time consuming.
+    Args:
+        data_directory: directory of the L1G product,
+                        e.g. /nex/datapool/geonex/public/
+        sensor: (G16, G17, H8)
+    '''
+    def __init__(self, data_directory, sensor):
+        self.data_directory = data_directory
+        self.sensor = sensor
+        self.sat = os.path.basename(os.path.dirname(os.path.dirname(data_directory)))
+
+    def tiles(self):
+        tile_pattern = os.path.join(self.data_directory, 'h*v*')
+        tile_folders = glob.glob(tile_pattern)
+        tiles = [os.path.basename(t) for t in tile_folders]
+        return tiles
+
+    def years(self):
+        tile = self.tiles()[0]
+        years = os.listdir(os.path.join(self.data_directory, tile))
+        years = [int(y) for y in years if y[0] == '2']
+        return years
+
+    def hours(self):
+        return list(range(0,24))
+
+    def files(self, tile=None, year=None, dayofyear=None, cachedir='.tmp'):
+        '''
+        Args:
+            tile (optional): Tile from GeoNEX grid
+            year (optional): Year of files to get
+            dayofyear (optional): Day of year
+            cachedir (optional): Cache filelist in directory
+        Returns:
+            pd.DataFrame of filelist with year, dayofyear, hour, minute, tile, file, h, and v
+        '''
+        if tile is None:
+            tile = '*'
+        if year is None:
+            year = '*'
+        else:
+            year = str(year)
+        if dayofyear is None:
+            dayofyear = '*'
+        else:
+            dayofyear = '%03i' % dayofyear
+
+        cache_file = f'{cachedir}/filelist/{self.sat}_{self.sensor}_{tile}_{year}_{dayofyear}.pkl'
+        if os.path.exists(cache_file):
+            return pd.read_pickle(cache_file)
+
+        
+        file_pattern = os.path.join(self.data_directory, '%s/%s/%s/*.hdf' % (tile, year, dayofyear))
+        files = glob.glob(file_pattern)
+        fileinfo = []
+        for f in files:
+            root = os.path.dirname(f)
+            rl = root.split('/')
+            doy = int(rl[-1])
+            y = int(rl[-2])
+
+            fl = os.path.basename(f).split('_')
+            hour = int(fl[3][:2])
+            minute = int(fl[3][2:])
+            tile = fl[5]
+            h = int(tile[1:3])
+            v = int(tile[4:6])
+            
+            fileinfo.append(dict(year=y, dayofyear=doy, hour=hour,
+                              minute=minute, file=f, tile=tile, h=h, v=v))
+        fileinfo = pd.DataFrame(fileinfo)
+        if not os.path.exists(os.path.dirname(cache_file)):
+            os.makedirs(os.path.dirname(cache_file))
+        fileinfo.to_pickle(cache_file)
+        return fileinfo
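+
+# usage sketch (paths are hypothetical):
+#   geo = GeoNEXL1G('/nex/datapool/geonex/public/GOES16/GEONEX-L1G/', 'G16')
+#   df = geo.files(tile='h14v05', year=2019)   # cached under .tmp/filelist/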
+
+class L1GPaired(object):
+    def __init__(self, data_path1, data_path2, sensor1, sensor2):
+        '''
+        Retrieve lists of files that match temporally and spatially using two GeoNEXL1b objects
+        
+        Args:
+            data_path1: First L1G data directory 
+            data_path2: Second L1G data directory 
+            sensor1: First sensor
+            sensor2: Second sensor
+        '''
+        self.data_path1 = data_path1
+        self.data_path2 = data_path2
+        
+        self.sensor1 = sensor1
+        self.sensor2 = sensor2
+
+        self.data1 = GeoNEXL1G(data_path1, sensor1)
+        self.data2 = GeoNEXL1G(data_path2, sensor2)
+
+    def tiles(self):
+        tiles1 = self.data1.tiles()
+        tiles2 = self.data2.tiles()
+        intersect = np.intersect1d(tiles1, tiles2)
+        return intersect.tolist()
+
+    def day_pairs(self, year):
+        tile = self.tiles()[0]
+        path1 = os.path.join(self.data_path1, tile, str(year))
+        return [int(day) for day in os.listdir(path1)]
+
+    def files(self, tile=None, year=None, dayofyear=None, how='inner', cachedir='.tmp'):
+        '''
+        Get filelists for each data directory and join in space-time 
+        
+        Args:
+            tile (optional): Tile from GeoNEX grid
+            year (optional): Year of files to get
+            dayofyear (optional): Day of year
+            how (optional): Pandas method to join DataFrames, default='inner'
+            cachedir (optional): Cache filelist in directory
+        Returns:
+            pd.DataFrame of filelist with year, dayofyear, hour, minute, tile, file, h, and v
+        '''
+        files1 = self.data1.files(tile=tile, year=year, 
+                                  dayofyear=dayofyear, cachedir=cachedir)
+        files2 = self.data2.files(tile=tile, year=year, 
+                                  dayofyear=dayofyear, cachedir=cachedir)
+        get_timestamp = lambda x: '_'.join(os.path.basename(x).split('_')[2:4])
+
+        files1['timestamp'] = files1['file'].apply(get_timestamp)
+        files1 = files1.set_index(['timestamp', 'tile'])
+
+        files2['timestamp'] = files2['file'].apply(get_timestamp)
+        files2 = files2.set_index(['timestamp', 'tile'])
+        joined = files1.join(files2, lsuffix='1', rsuffix='2', how=how)
+        return joined
+
+
+def _read_file_to_xarray(f, bands, resolution_km):
+    return L1GFile(f, resolution_km=resolution_km, bands=bands).load_xarray()
+
+
+class MOSAIC(object):
+    def __init__(self, data_path, sensor, h_list=None, v_list=None):
+        self.data_path = data_path
+        self.sensor = sensor
+        self.geo = GeoNEXL1G(self.data_path, self.sensor)
+        self.h_list = h_list
+        self.v_list = v_list
+        
+    def read_data(self, year, dayofyear, hour, minute, 
+                  bands=list(range(1,17)), resolution_km=2.):
+        files = self.geo.files(year=year, dayofyear=dayofyear)
+        files = files[files.hour == hour]
+        files = files[files.minute == minute].reset_index()
+        piv = files.pivot(index='v', columns='h')
+        files = piv['file']
+
+        if self.geo.sensor == 'H8':
+            files = pd.concat([files.loc[:, 44:59], files.loc[:, 0:3]], axis=1)
+        elif self.geo.sensor == 'G17':
+            files = pd.concat([files.loc[:, 57:59], files.loc[:, 0:16]], axis=1)
+        elif self.geo.sensor == 'G16':
+            files = files.loc[:, 7:26]
+
+        #files = files.iloc[5:15,5:15]
+        all_data = []
+        if self.h_list is None:
+            self.h_list = files.columns.values
+        if self.v_list is None:
+            self.v_list = files.index.values
+
+        for v in self.v_list:
+            jobs = [] # joblib does not seem to add performance
+            for h in self.h_list:
+                f = files.loc[v, h]
+                if isinstance(f, str):
+                    jobs.append(_read_file_to_xarray(f, bands, resolution_km))
+                    #jobs.append(delayed(_read_file_to_xarray)(f, bands))
+            all_data.append(jobs)    
+            #all_data.append(Parallel(n_jobs=4)(jobs))
+
+        ds = xr.concat([xr.concat(d, 'lon') for d in all_data], 'lat')
+        return ds
\ No newline at end of file
diff --git a/windflow/datasets/geostationary/goesr.py b/windflow/datasets/geostationary/goesr.py
new file mode 100644
index 0000000000000000000000000000000000000000..0972570827985d173f24c071eb6bde75bc6de0e6
--- /dev/null
+++ b/windflow/datasets/geostationary/goesr.py
@@ -0,0 +1,346 @@
+import os, sys
+import glob
+import xarray as xr
+import pandas as pd
+import numpy as np
+import dask as da
+import datetime as dt
+import matplotlib.pyplot as plt
+#from mpl_toolkits.basemap import Basemap
+import pyproj
+import pyresample
+from .. import utils
+
+
+def get_filename_metadata(f):
+    channel = int(f.split('_')[1][-2:])
+    spatial = f.split('-')[2]
+    t1 = f.split('_')[3]
+    year = int(t1[1:5])
+    dayofyear = int(t1[5:8])
+    hour = int(t1[8:10])
+    minute = int(t1[10:12])
+    second = int(t1[12:15])
+    return dict(channel=channel, year=year, dayofyear=dayofyear, hour=hour,
+                minute=minute, second=second, spatial=spatial)
+
+class L1bBand(object):
+    def __init__(self, fpath):
+        self.fpath = fpath
+        meta = get_filename_metadata(self.fpath)
+        self.band = meta['channel']
+        self.year = meta['year']
+        self.dayofyear = meta['dayofyear']
+        self.hour = meta['hour']
+        self.minute = meta['minute']
+        self.second = meta['second']
+        self.spatial = meta['spatial']
+        self.datetime = dt.datetime(self.year, 1, 1, self.hour, self.minute, self.second//10) +\
+                dt.timedelta(days=self.dayofyear-1)
+
+    def open_dataset(self, rescale=True, force=False, chunks=None):
+        if (not hasattr(self, 'data')) or force:
+            ds = xr.open_dataset(self.fpath, chunks=chunks)
+            ds = ds.where(ds.DQF.isin([0, 1]))
+            band = ds.band_id[0]
+            # normalize radiance: reflectance factor for bands 1-6,
+            # brightness temperature (K) for bands 7-16
+            if rescale:
+                radiance = ds['Rad']
+                if band <= 6:
+                    radiance = radiance * ds.kappa0
+                else:
+                    fk1 = ds.planck_fk1.values
+                    fk2 = ds.planck_fk2.values
+                    bc1 = ds.planck_bc1.values
+                    bc2 = ds.planck_bc2.values
+                    tmp = fk1 / ds["Rad"] + 1
+                    tmp = np.where(tmp > 0, tmp, 1)
+                    T = (fk2 / np.log(tmp) - bc1) / bc2
+                    radiance.values = T
+                ds['Rad'] = radiance
+            self.data = ds
+        return self.data
+
+    def plot(self, ax=None, cmap=None, norm=None):
+        if not hasattr(self, 'data'):
+            self.open_dataset()
+
+        # Satellite height
+        sat_h = self.data['goes_imager_projection'].perspective_point_height
+        # Satellite longitude
+        sat_lon = self.data['goes_imager_projection'].longitude_of_projection_origin
+
+        # The geostationary projection
+        x = self.data['x'].values * sat_h
+        y = self.data['y'].values * sat_h
+        if ax is None:
+            fig = plt.figure(figsize=[10,10])
+            ax = fig.add_subplot(111)
+
+        # Basemap is imported lazily so the rest of the module works without it
+        from mpl_toolkits.basemap import Basemap
+        m = Basemap(projection='geos', lon_0=sat_lon, resolution='i',
+                     rsphere=(6378137.00,6356752.3142),
+                     llcrnrx=x.min(),llcrnry=y.min(),
+                     urcrnrx=x.max(),urcrnry=y.max(),
+                     ax=ax)
+        m.drawcoastlines()
+        m.drawcountries()
+        m.drawstates()
+        ax.set_title('GOES-16 -- Band {}'.format(self.band), fontweight='semibold', loc='left')
+        ax.set_title('%s' % self.datetime.strftime('%H:%M UTC %d %B %Y'), loc='right')
+        return m.imshow(self.data['Rad'].values[::-1], cmap=cmap, norm=norm)
+        #return m
+        
+    def plot_infrared(self, ax=None, colortable='WVCIMSS', colorbar=False, cmin=180, cmax=270):
+        from metpy.plots import colortables
+        # use a colortable/colormap available from MetPy
+        ir_norm, ir_cmap = colortables.get_with_range(colortable, cmin, cmax)
+        im = self.plot(ax, cmap=ir_cmap, norm=ir_norm)
+        if colorbar:
+            plt.colorbar(im, pad=0.01, aspect=50, ax=ax, shrink=0.85, 
+                         ticks=range(cmin,cmax,10), label='Temperature (K)')
+        return im
+    
+    def interp(self, scale):
+        if not hasattr(self, 'data'):
+            self.open_dataset()
+
+        new_x = np.linspace(self.data.x.values[0], self.data.x.values[-1], 
+                            int(self.data.x.values.shape[0] * scale))
+        new_y = np.linspace(self.data.y.values[0], self.data.y.values[-1], 
+                            int(self.data.y.values.shape[0] * scale))
+        self.data = self.data.interp(x=new_x, y=new_y)
+        
+    def latlon(self):
+        if not hasattr(self, 'lats'):
+            if not hasattr(self, 'data'):
+                self.open_dataset()
+            # Satellite height
+            sat_h = self.data['goes_imager_projection'].perspective_point_height
+            # Satellite longitude
+            sat_lon = self.data['goes_imager_projection'].longitude_of_projection_origin
+            sat_sweep= self.data['goes_imager_projection'].sweep_angle_axis
+            p = pyproj.Proj(proj='geos', h=sat_h, lon_0=sat_lon, sweep=sat_sweep)
+            X = self.data['x'].values * sat_h
+            Y = self.data['y'].values * sat_h
+            XX, YY = np.meshgrid(X, Y)
+            lons, lats = p(XX, YY, inverse=True)
+            self.lats = lats
+            self.lons = lons
+            NANs = np.isnan(self.data['Rad'].values)
+            self.lats[NANs] = np.nan
+            self.lons[NANs] = np.nan
+            self.lats[~np.isfinite(self.lats)] = np.nan
+            self.lons[~np.isfinite(self.lons)] = np.nan
+        return self.lats, self.lons
+
+    def latlon_lookup(self, lat, lon):
+        self.latlon()
+        if (lat > self.lats.min()) and (lat < self.lats.max()) and (lon > self.lons.min()) and (lon < self.lons.max()):
+            dist = ((self.lats - lat)**2 + (self.lons - lon)**2)**0.5
+            ix, iy = np.unravel_index(np.argmin(dist, axis=None), dist.shape)
+            return ix, iy
+        
+        
+    def reproject_to_latlon(self):
+        data = self.open_dataset()
+        lats, lons = self.latlon()
+        sat_lon = data['goes_imager_projection'].longitude_of_projection_origin
+        
+        if self.band in [1,3,4,5,6]:
+            s = 0.01
+        elif self.band == 2:
+            s = 0.005
+        else:
+            s = 0.02
+        
+        lat_min = np.around(max(np.nanmin(lats), -60), 3)
+        lat_max = np.around(min(np.nanmax(lats), 60), 3)
+        lon_min = np.around(max(np.nanmin(lons), sat_lon-60), 3)
+        lon_max = np.around(min(np.nanmax(lons), sat_lon+60), 3)
+        lats_new = np.arange(lat_min, lat_max, s)
+        lons_new = np.arange(lon_min, lon_max, s)
+        lons_new, lats_new = np.meshgrid(lons_new, lats_new)
+        
+        source_def = pyresample.geometry.SwathDefinition(lats=lats, lons=lons)
+        target_def = pyresample.geometry.GridDefinition(lons=lons_new, lats=lats_new)
+        result = pyresample.kd_tree.resample_nearest(source_def, 
+                                                     data['Rad'].values,
+                                                     target_def, 
+                                                     radius_of_influence=50000, 
+                                                     epsilon=0.05)                                                              
+        data_new = xr.DataArray(result, coords=dict(lat=lats_new[:,0], lon=lons_new[0,:]), dims=('lat', 'lon'))
+        return xr.Dataset(dict(Rad=data_new))
+
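+# usage sketch (the filename is hypothetical but follows the ABI L1b convention):
+#   band = L1bBand('OR_ABI-L1b-RadF-M6C09_G16_s20190451200000_e20190451209000_c20190451209300.nc')
+#   ds = band.open_dataset()            # radiances rescaled per band
+#   grid = band.reproject_to_latlon()   # 'Rad' resampled to a regular lat/lon grid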
+
+class GroupBandTemporal(object):
+    def __init__(self, data_list):
+        self.data = data_list
+        self.open_dataset()
+
+    def open_dataset(self):
+        for b in self.data:
+            if not hasattr(b, 'data'):
+                b.open_dataset()
+
+    def get_radiances(self, indices=None):
+        self.open_dataset()
+        xequal = np.array_equal(self.data[0].data.x.values,
+                       self.data[-1].data.x.values)
+        yequal = np.array_equal(self.data[0].data.y.values,
+                       self.data[-1].data.y.values)
+        if (not xequal) or (not yequal):
+            return
+
+        if indices is None:
+            indices = range(len(self.data))
+        data = xr.concat([self.data[i].data['Rad'] for i in indices], 'time')
+        return data
+
+    def get_radiance_patches(self, patch_size):
+        data = self.get_radiances()
+        if data is None:
+            return
+
+        data = utils.block_array(data, 2,
+                             size=patch_size,
+                             stride=patch_size)
+        data = utils.block_array(data, 2,
+                                 size=patch_size,
+                                 stride=patch_size)
+        data = data.reshape(-1, len(self.data),
+                            patch_size,
+                            patch_size)
+        return data
+
+    def add(self, band):
+        self.data.append(band)
+
+    def __len__(self):
+        return len(self.data)
+
+    def timeseries(self, ix, iy):
+        self.open_dataset()
+        data = np.array([b.data['Rad'].isel(x=ix, y=iy).values for b in self.data])
+        x = [b.datetime for b in self.data]
+        return x, data
+
+    def timeseries_latlon(self, lat, lon):
+        print("Lat: {}, Lon: {}".format(lat, lon))
+        self.open_dataset()
+        indices = [b.latlon_lookup(lat, lon) for b in self.data]
+        data = []
+        x = []
+        for i, b in enumerate(self.data):
+            if indices[i] is None:
+                print("Data bounds: ({}, {}), ({}, {})".format(b.lats.min(), b.lats.max(), b.lons.min(),
+                                               b.lons.max()))
+                continue
+            data.append(b.data['Rad'].isel(x=indices[i][0], y=indices[i][1]).values)
+            x.append(b.datetime)
+        data = np.array(data)
+        x = np.array(x)
+        return x, data
+
+class GOESL1b(object):
+    def __init__(self, product='ABI-L1b-RadF',
+                       channels=range(1,17),
+                       data_directory='/nex/datapool/geonex/public/GOES16/NOAA-L1B/'):
+        self.product = product
+        self.channels = channels
+        self.data_directory = os.path.join(data_directory, product)
+
+    def local_files(self, year=None, dayofyear=None, hour=None, spatial=None):
+        data = []
+        base_dir = self.data_directory
+
+        if year is None:
+            year = "*"
+        else:
+            year = '%04i' % year
+        if dayofyear is None:
+            dayofyear = "*"
+        else:
+            dayofyear = '%03i' % dayofyear
+        if hour is None:
+            hour = "*"
+        else:
+            hour = '%02i' % hour 
+
+        for c in self.channels:
+            path_pattern = os.path.join(self.data_directory, year, dayofyear, hour,
+                                        'OR_ABI-L1b-*C%02i_*.nc' % c)
+            channel_files = glob.glob(path_pattern)
+            for f in channel_files:
+
+                meta = get_filename_metadata(os.path.basename(f))
+                meta['file'] = f
+                data.append(meta)
+
+        data = pd.DataFrame(data)
+        if (len(data) > 0) and (spatial is not None):
+            data = data[data['spatial'] == spatial]
+
+        if len(data) > 0:
+            new_index = ['year', 'dayofyear', 'hour', 'minute', 'second', 'spatial']
+            data = data.set_index(new_index)
+            data = data.pivot(columns='channel').drop_duplicates()
+
+        return data
+
+    def snapshot_file(self, year, dayofyear, hour, minute, spatial=None):
+        t = dt.datetime(year, 1, 1, hour, minute) + dt.timedelta(days=dayofyear-1)
+        t1 = t + dt.timedelta(hours=1)
+        curr_time = t
+        files = []
+        while curr_time <= t1:
+            files.append(self.local_files(curr_time.year, curr_time.timetuple().tm_yday, 
+                                          curr_time.hour, spatial=spatial))
+            curr_time = curr_time + dt.timedelta(hours=1)
+        files = pd.concat(files)
+
+        if spatial is None:
+            spatial = files.index.get_level_values('spatial')[0]
+
+        files = files.reset_index()
+        
+        def to_dt(row):
+            return dt.datetime(int(row['year']), 1, 1, int(row['hour']), int(row['minute']),
+                               int(row['second']) // 10) + dt.timedelta(days=int(row['dayofyear']) - 1)
+
+        times = files.apply(to_dt, axis=1)
+        
+        idx = np.argmin(np.abs(times-t))
+        return files.iloc[idx]['file']        
+
+    def snapshot_files(self, year, dayofyear, hour, minute, spatial=None):
+        hour_files = self.local_files(year, dayofyear, hour, spatial=spatial)
+
+        minutes = hour_files.index.get_level_values('minute')
+        if spatial is None:
+            spatial = hour_files.index.get_level_values('spatial')[0]
+
+        idx = np.argmin(np.abs(minutes-minute))
+        
+        minute_sel = minutes[idx]
+        return hour_files.loc[year, dayofyear, hour, minute_sel]['file']        
+
+    
+    def open_snapshot(self, year, dayofyear, hour, minute, spatial=None):
+        snapshot_files = self.snapshot_files(year, dayofyear, hour, minute, spatial=spatial)
+        rads = []
+        regrid = utils.regrid_1km
+        for c in self.channels:
+            band_file = snapshot_files[c].iloc[0]
+            l1b = L1bBand(band_file)
+            data = l1b.open_dataset()
+            rad_c = data['Rad']
+            rad_c = rad_c.expand_dims(dict(band=[c]))
+            rad_c_regrid = regrid(rad_c, c)
+            rads.append(rad_c_regrid)
+            rads[-1] = rads[-1].assign_coords(x=rads[0].x.values,
+                                              y=rads[0].y.values)
+
+        rad = xr.concat(rads, 'band')
+        #rad = rad.swapaxes(0, 1) # put time in front
+        return rad
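+
+# usage sketch (date/time values are illustrative):
+#   g16 = GOESL1b(channels=[9])
+#   rad = g16.open_snapshot(2019, 45, 12, 0)   # DataArray with a band dimension on a common grid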
diff --git a/windflow/datasets/geostationary/noaa_amv.py b/windflow/datasets/geostationary/noaa_amv.py
new file mode 100644
index 0000000000000000000000000000000000000000..ebf91789590dab50bac36ed523a1e8c8812ac7ff
--- /dev/null
+++ b/windflow/datasets/geostationary/noaa_amv.py
@@ -0,0 +1,16 @@
+import xarray as xr
+import numpy as np
+import datetime as dt
+
+class NOAAAMV(object):
+    def __init__(self, file):
+        self.file = file
+        self.dataset = xr.open_dataset(self.file)
+        self.dataset = self.dataset.isel(nMeasures=np.where(np.isfinite(self.dataset.wind_speed.values))[0])
+        self.time = dt.datetime(2000, 1, 1, 12) + dt.timedelta(seconds=self.dataset['time_bounds'].values[0])
+        self.wind_speeds = self.dataset.wind_speed.values
+        self.direction = self.dataset.wind_direction.values
+        self.lats = self.dataset['lat'].values
+        self.lons = self.dataset['lon'].values
+        self.vs = np.sin(self.direction / 180 * np.pi) * self.wind_speeds
+        self.us = np.cos(self.direction / 180 * np.pi) * self.wind_speeds
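+        # note: us/vs above treat wind_direction as a mathematical angle; the
+        # meteorological convention would be u = -speed*sin(dir), v = -speed*cos(dir)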
diff --git a/windflow/datasets/geostationary/pytorch_dataloader.py b/windflow/datasets/geostationary/pytorch_dataloader.py
new file mode 100644
index 0000000000000000000000000000000000000000..1f525b5352d0f731fa9f685e0e8e9a0104b3a015
--- /dev/null
+++ b/windflow/datasets/geostationary/pytorch_dataloader.py
@@ -0,0 +1,124 @@
+import os
+import xarray as xr
+import numpy as np
+
+import torch
+from torch.utils import data
+from torchvision import transforms
+
+from . import goesr
+from . import stats
+
+from .. import flow_transforms
+from ..preprocess import image_histogram_equalization
+
+class L1bPatches(data.Dataset):
+    def __init__(self, data_directory, time_step=1, size=512, bands=[9,], mode='train'):
+        self.patch_folders = [os.path.join(data_directory, f) for f in os.listdir(data_directory)]
+        
+        self.patches = [L1bPatch(f, time_step=time_step, size=size, bands=bands, mode=mode) for f in self.patch_folders]
+        self.patches = data.ConcatDataset(self.patches)
+        self.N = len(self.patches)
+
+    def __len__(self):
+        return self.N
+
+    def __getitem__(self, idx):
+        return self.patches[idx]
+
+class L1bPatch(data.Dataset):
+    def __init__(self, data_directory, time_step=1, size=512, bands=[9,], mode='train'):
+        self.files = sorted([os.path.join(data_directory, f) for f in os.listdir(data_directory)])
+        
+        N = len(self.files)
+        if mode == 'train':
+            self.files = self.files[:int(N*0.8)]
+        elif mode == 'valid':
+            self.files = self.files[int(N*0.8):int(N*0.9)]
+        elif mode == 'test':
+            self.files = self.files[int(N*0.9):]
+        
+        self.N = len(self.files)-time_step
+        self.time_step = time_step
+        self.bands = bands
+        self.rand_transform = transforms.Compose([
+            transforms.RandomCrop((size, size)),
+            transforms.RandomHorizontalFlip(),
+            transforms.RandomVerticalFlip(),
+        ])
+        
+    def __len__(self):
+        return self.N
+
+    def __getitem__(self, idx):
+        sample_files = [self.files[idx], self.files[idx+self.time_step]]
+        N_samples = len(sample_files)
+        N_bands = len(self.bands)
+        
+        ds = [xr.open_dataset(f) for f in sample_files]
+        
+        x = np.concatenate([d['Rad'].sel(band=self.bands).values for d in ds], 0)
+
+        # build the valid-pixel mask before zero-filling NaNs, otherwise it is all ones
+        mask = torch.Tensor(np.isfinite(x[0]).copy()).float()
+        x[~np.isfinite(x)] = 0.
+        
+        x = image_histogram_equalization(x)
+        x = self.rand_transform(torch.Tensor(x))
+        x = x.unsqueeze(1)
+        #x = [x[i:i+N_bands] for i in range(N_samples)]
+        return x, mask
+
+class L1bResized(data.Dataset):
+    def __init__(self, data_directory, time_step=1, new_size = (1024, 1024), jitter=6, band=9):
+        self.files = sorted([os.path.join(data_directory, f) for f in os.listdir(data_directory)])
+        self.N = len(self.files)-time_step
+        self.time_step = time_step
+        self.jitter = jitter
+        self.new_size = new_size
+        jitter_size = (new_size[0] + jitter, new_size[1] + jitter)
+        self.transform = transforms.Compose([transforms.ToPILImage(),
+                                            transforms.Resize(jitter_size),
+                                            transforms.ToTensor()])
+        self.mu, self.sd = stats.get_sensor_stats("ABI")        
+        self.mu = self.mu[band-1]
+        self.sd = self.sd[band-1]
+        
+    def __len__(self):
+        return self.N
+
+    def __getitem__(self, idx):
+        sample_files = [self.files[idx], self.files[idx+self.time_step]]
+        print(sample_files)
+        samples = []
+        for f in sample_files:
+            if f.endswith('.nc') or f.endswith('.nc4'):
+                data = xr.open_dataset(f)
+                x = np.copy(data['Rad'].values.astype(np.float32))
+            elif f.endswith('.npy'):
+                x = np.load(f).astype(np.float32)
+            else:
+                raise ValueError(f'Unsupported file type: {f}')
+            x = (x - self.mu) / self.sd
+            x[~np.isfinite(x)] = 0.
+            x = self.transform(x)
+            samples.append(x)
+
+        #if np.random.uniform() > 0.5:
+        #    samples = [torch.flipud(s) for s in samples]
+        #if np.random.uniform() > 0.5:
+        #    samples = [torch.fliplr(s) for s in samples]
+        #if np.random.uniform() > 0.5:
+        #    samples = [torch.rot90(s, 1, [1, 2]) for s in samples]        
+    
+        if self.jitter > 0:
+            ix = np.random.choice(range(self.jitter))
+            iy = np.random.choice(range(self.jitter))
+            samples = [s[:,ix:ix+self.new_size[0],iy:iy+self.new_size[1]] for s in samples]
+    
+        #factor1 = np.random.uniform(0.9,1.1)
+        #factor2 = np.random.uniform(0.9,1.1)
+        #samples = [s*factor1 for s in samples]
+        #samples[1] = samples[1]*factor2
+        mask = (samples[0] != 0).float()
+        return samples, mask
+
+
diff --git a/windflow/datasets/geostationary/schema.py b/windflow/datasets/geostationary/schema.py
new file mode 100644
index 0000000000000000000000000000000000000000..40590c0be2688a328782bb10ecfd40d0125749ea
--- /dev/null
+++ b/windflow/datasets/geostationary/schema.py
@@ -0,0 +1,28 @@
+from petastorm.unischema import Unischema, UnischemaField
+from petastorm.codecs import ScalarCodec, NdarrayCodec
+from pyspark.sql.types import IntegerType, StringType, FloatType
+
+import numpy as np
+
+BandSchema1000 = Unischema('BandSchema1000', [
+   UnischemaField('index', np.int32, (), ScalarCodec(IntegerType()), False),
+   UnischemaField('timestamp', np.int32, (), ScalarCodec(IntegerType()), False),
+   UnischemaField('spatial', np.string_, (), ScalarCodec(StringType()), False),
+   UnischemaField('file', np.string_, (), ScalarCodec(StringType()), False),
+   UnischemaField('data', np.float32, (1000, 1000), NdarrayCodec(), False),
+   UnischemaField('x_ul', np.float32, (), ScalarCodec(FloatType()), False),
+   UnischemaField('y_ul', np.float32, (), ScalarCodec(FloatType()), False),
+   UnischemaField('group_id', np.float32, (), ScalarCodec(FloatType()), False),
+])
+
+BandSchema256 = Unischema('BandSchema256', [
+   UnischemaField('index', np.int32, (), ScalarCodec(IntegerType()), False),
+   UnischemaField('timestamp', np.int32, (), ScalarCodec(IntegerType()), False),
+   UnischemaField('spatial', np.string_, (), ScalarCodec(StringType()), False),
+   UnischemaField('file', np.string_, (), ScalarCodec(StringType()), False),
+   UnischemaField('data', np.float32, (256, 256), NdarrayCodec(), False),
+   UnischemaField('x_ul', np.float32, (), ScalarCodec(FloatType()), False),
+   UnischemaField('y_ul', np.float32, (), ScalarCodec(FloatType()), False),
+   UnischemaField('group_id', np.int32, (), ScalarCodec(IntegerType()), False),
+   UnischemaField('patch_id', np.int32, (), ScalarCodec(IntegerType()), False),
+])
diff --git a/windflow/datasets/geostationary/stats.py b/windflow/datasets/geostationary/stats.py
new file mode 100644
index 0000000000000000000000000000000000000000..980b2d39f213f923aced9b5f85c1136627233e35
--- /dev/null
+++ b/windflow/datasets/geostationary/stats.py
@@ -0,0 +1,24 @@
+def get_sensor_stats(sensor):
+    '''
+    Per-band statistics computed by scripts/get_data_stats.py
+    Args:
+        sensor: one of (AHI, ABI, G16, G17, H8, AHI12)
+    Returns:
+        tuple of per-band means, tuple of per-band standard deviations
+    '''
+    if sensor in ['AHI', 'H8', 'H8_15']:
+        # computed from himawari
+        mu = (0.829, 0.512, 0.621, 0.663, 0.427, 0.333, 285.251, 235.759, 244.547,
+              252.391, 272.476, 251.321, 274.869, 274.522, 272.611, 259.913)
+        sd = (0.292, 0.257, 0.325, 0.362, 0.283, 0.218, 6.243, 2.628, 3.472, 4.165,
+              7.574, 4.078, 7.973, 8.122, 7.839, 5.408)
+    elif sensor in ['ABI', 'G16', 'G17']:
+        mu = (0.829, 0.621, 0.663, 0.077, 0.427, 0.333, 285.251, 235.759, 244.547,
+              252.391, 272.476, 251.321, 274.869, 274.522, 272.611, 259.913)
+        sd = (0.292, 0.325, 0.362, 0.113, 0.283, 0.218, 6.243, 2.628, 3.472, 4.165,
+              7.574, 4.078, 7.973, 8.122, 7.839, 5.408)
+    elif sensor in ['AHI12']:
+        mu = (0.06, 0.08, 0.15, 0.30, 0.32, 0.23)
+        sd = (0.03, 0.1, 0.2, 0.3, 0.3, 0.3)
+    else:
+        raise ValueError(f'Unknown sensor: {sensor}')
+    return mu, sd
\ No newline at end of file
diff --git a/windflow/datasets/igrs/README.md b/windflow/datasets/igrs/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..d714fdd7c686417a09cdeba648ad06c503e1a046
--- /dev/null
+++ b/windflow/datasets/igrs/README.md
@@ -0,0 +1,8 @@
+## NOAA Integrated Global Radiosonde Archive (IGRA)
+
+
+https://www.ncdc.noaa.gov/data-access/weather-balloon/integrated-global-radiosonde-archive
+
+ftp://ftp.ncdc.noaa.gov/pub/data/igra
+
+
diff --git a/windflow/datasets/igrs/build_database.py b/windflow/datasets/igrs/build_database.py
new file mode 100644
index 0000000000000000000000000000000000000000..4ec5db2003dd5a681c269504dabe7a7cc0087267
--- /dev/null
+++ b/windflow/datasets/igrs/build_database.py
@@ -0,0 +1,95 @@
+import os, sys
+import glob
+
+import pandas as pd
+
+import sqlite3
+
+igra_directory = '/nobackupp10/tvandal/data/IGRA/'
+
+# ['data-y2d', 'igra2-readme.txt', 'igra2-country-list.txt', 'igra2-station-list.txt', 'status.txt',
+# 'igra2-list-format.txt', 'igra2-data-format.txt', 'igra2-us-states.txt', 'igra2-readme.txt.1']
+
+
+class RawIGRA:
+    def __init__(self, directory):
+        self.dir = directory
+        
+
+    def stations(self):
+        path = os.path.join(self.dir, 'igra2-station-list.txt')
+        #names = ['id', 'lat', 'lon', 'elevation', 'state', 'name', 'first_year', 'last_year', 'n_obs']
+        #df = pd.read_table(path, sep='\t')
+        data = []
+        with open(path, 'r') as reader:
+            for line in reader.readlines():
+                station = {'id': line[:11], 'lat': float(line[12:20]), 'lon': float(line[21:30]),
+                        'elevation': float(line[31:37]), 'state': line[38:40].strip(), 'name': line[41:71].strip(),
+                        'first_year': int(line[72:76]), 'last_year': int(line[77:81]), 
+                        'n_obs': int(line[82:88])}
+                data.append(station)
+        data = pd.DataFrame(data)
+        return data
+    
+    def station_data(self, station):
+        path = os.path.join(self.dir, f'data-y2d/{station}-data.txt')
+        if not os.path.exists(path):
+            print(f"Cannot find data for station {station}")
+            return
+        print(f"Found data for station {station}")
+        data = []
+        
+        with open(path, 'r') as reader:
+            for i, line in enumerate(reader.readlines()):
+                if line[0] == '#': # header row
+                    record_header = {'station': station, 'year': int(line[13:17]), 'month': int(line[18:20]), 
+                          'day': int(line[21:23]), 'hour': int(line[24:26]),
+                          'reltime': int(line[27:31]), 'num_levels': int(line[32:36]),
+                          'p_src': line[37:45].strip(), 'np_src': line[46:54].strip(), 
+                          'lat': int(line[55:62])/10000., 'lon': int(line[63:71])/10000.,
+                         }
+                    #print(record_header)
+                else:
+                    record = {'level_type1': int(line[0]), 'level_type2': int(line[1]), 'elapsed_time': int(line[3:8]),
+                              'pressure': int(line[9:15]), 'pflag': line[15], 'geopotential_height': int(line[16:21]), 
+                              'zflag': line[21], 'temp': int(line[22:27]), 'tflag': line[27], 'rh': int(line[28:33]),
+                              'dpdp': int(line[34:39]), 'wind_direction': int(line[40:45]), 'wind_speed': int(line[46:51])
+                             }
+                    record.update(record_header)
+                    data.append(record)
+               
+        data = pd.DataFrame(data)
+        return data
+    
+    
+def write_to_database():  
+    conn = sqlite3.connect(os.path.join(igra_directory, 'igra.db'))
+    
+    igra = RawIGRA(igra_directory)
+    stations = igra.stations()
+    stations.to_sql("stations", conn, if_exists="replace")
+    for i, row in stations.iterrows():
+        station_data = igra.station_data(row.id)
+        if station_data is not None:
+            station_data.to_sql("records", conn, if_exists="append", index=False)
+            conn.commit()
+    
+def query_records():
+    conn = sqlite3.connect(os.path.join(igra_directory, 'igra.db'))
+    df = pd.read_sql_query("SELECT * from records", conn)
+    print(df)
+    
+if __name__ == '__main__':
+    write_to_database()
+    query_records()
\ No newline at end of file
diff --git a/windflow/datasets/igrs/igrs_to_dataframe.py b/windflow/datasets/igrs/igrs_to_dataframe.py
new file mode 100644
index 0000000000000000000000000000000000000000..fccca09dbbee198bfa9d0ad172cf88c97e4e9f6d
--- /dev/null
+++ b/windflow/datasets/igrs/igrs_to_dataframe.py
@@ -0,0 +1,85 @@
+import os, sys
+import glob
+
+import pandas as pd
+
+import sqlite3
+
+igra_directory = '/nobackupp10/tvandal/data/IGRA/'
+
+# ['data-y2d', 'igra2-readme.txt', 'igra2-country-list.txt', 'igra2-station-list.txt', 'status.txt',
+# 'igra2-list-format.txt', 'igra2-data-format.txt', 'igra2-us-states.txt', 'igra2-readme.txt.1']
+
+
+class RawIGRA:
+    def __init__(self, directory):
+        self.dir = directory
+        
+
+    def stations(self):
+        path = os.path.join(self.dir, 'igra2-station-list.txt')
+        #names = ['id', 'lat', 'lon', 'elevation', 'state', 'name', 'first_year', 'last_year', 'n_obs']
+        #df = pd.read_table(path, sep='\t')
+        data = []
+        with open(path, 'r') as reader:
+            for line in reader.readlines():
+                station = {'id': line[:11], 'lat': float(line[12:20]), 'lon': float(line[21:30]),
+                        'elevation': float(line[31:37]), 'state': line[38:40].strip(), 'name': line[41:71].strip(),
+                        'first_year': int(line[72:76]), 'last_year': int(line[77:81]), 
+                        'n_obs': int(line[82:88])}
+                data.append(station)
+        data = pd.DataFrame(data)
+        return data
+    
+    def station_data(self, station):
+        path = os.path.join(self.dir, f'data-y2d/{station}-data.txt')
+        if not os.path.exists(path):
+            print(f"Cannot find data for station {station}")
+            return
+        print(f"Found data for station {station}")
+        data = []
+        
+        with open(path, 'r') as reader:
+            for i, line in enumerate(reader.readlines()):
+                if line[0] == '#': # header row
+                    record_header = {'station': station, 'year': int(line[13:17]), 'month': int(line[18:20]), 
+                          'day': int(line[21:23]), 'hour': int(line[24:26]),
+                          'reltime': int(line[27:31]), 'num_levels': int(line[32:36]),
+                          'p_src': line[37:45].strip(), 'np_src': line[46:54].strip(), 
+                          'lat': int(line[55:62])/10000., 'lon': int(line[63:71])/10000.,
+                         }
+                    #print(record_header)
+                else:
+                    record = {'level_type1': int(line[0]), 'level_type2': int(line[1]), 'elapsed_time': int(line[3:8]),
+                              'pressure': int(line[9:15]), 'pflag': line[15], 'geopotential_height': int(line[16:21]), 
+                              'zflag': line[21], 'temp': int(line[22:27]), 'tflag': line[27], 'rh': int(line[28:33]),
+                              'dpdp': int(line[34:39]), 'wind_direction': int(line[40:45]), 'wind_speed': int(line[46:51])
+                             }
+                    record.update(record_header)
+                    data.append(record)
+               
+        data = pd.DataFrame(data)
+        return data
+    
+    
+def write_to_database():  
+    conn = sqlite3.connect(os.path.join(igra_directory, 'igra.db'))
+    
+    igra = RawIGRA(igra_directory)
+    stations = igra.stations()
+    stations.to_sql("stations", conn, if_exists="replace")
+    for i, row in stations.iterrows():
+        station_data = igra.station_data(row.id)
+        if station_data is not None:
+            station_data.to_sql("records", conn, if_exists="append", index=False)
+            conn.commit()
+    
+def query_records():
+    conn = sqlite3.connect(os.path.join(igra_directory, 'igra.db'))
+    df = pd.read_sql_query("SELECT * from records", conn)
+    print(df)
+    
+if __name__ == '__main__':
+    write_to_database()
+    query_records()
\ No newline at end of file
diff --git a/windflow/datasets/preprocess.py b/windflow/datasets/preprocess.py
new file mode 100644
index 0000000000000000000000000000000000000000..8875c0ca66d38c4f890c2d005e38ce4f1fdd6f12
--- /dev/null
+++ b/windflow/datasets/preprocess.py
@@ -0,0 +1,20 @@
+import numpy as np
+
+def image_histogram_equalization(image, number_bins=500):
+    # from http://www.janeriksolem.net/2009/06/histogram-equalization-with-python-and.html
+    # get image histogram
+    image_histogram, bins = np.histogram(image.flatten(), number_bins, density=True)
+    cdf = image_histogram.cumsum() # cumulative distribution function
+    cdf = cdf / cdf[-1] # normalize
+
+    # use linear interpolation of cdf to find new pixel values
+    image_equalized = np.interp(image.flatten(), bins[:-1], cdf)
+
+    image_equalized[~np.isfinite(image_equalized)] = 0. 
+    
+    return image_equalized.reshape(image.shape)
+
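+# a minimal sketch of equalizing a synthetic field (values are illustrative):
+#   img = np.random.gamma(2., size=(64, 64))
+#   eq = image_histogram_equalization(img)   # output approximately uniform on [0, 1]
+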
+def normalize_qv(qv):
+    qv_log = np.log(qv * 1e4 + 0.001)
+    qv_log_norm = (qv_log - 1.06) / 2.15
+    return qv_log_norm
\ No newline at end of file
diff --git a/windflow/datasets/utils.py b/windflow/datasets/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..8ab0ce58589a66da9d193aa44d7200ab3db7d57d
--- /dev/null
+++ b/windflow/datasets/utils.py
@@ -0,0 +1,120 @@
+import xarray as xr
+import numpy as np
+import scipy.interpolate
+import dask as da
+
+def interp_dim(x, scale):
+    x0, xlast = x[0], x[-1]
+    newlength = int(len(x) * scale)
+    y = np.linspace(x0, xlast, num=newlength, endpoint=False)
+    return y
+
+def blocks(data, width=352):
+    #n = data.t.shape[0]
+    w = data.x.shape[0]
+    h = data.y.shape[0]
+    d = data.band.shape[0]
+
+    hs = np.arange(0, h, width)
+    ws = np.arange(0, w, width)
+    blocks = []
+    for hindex in hs:
+        if hindex+width > h:
+            hindex = h - width
+
+        for windex in ws:
+            if windex+width > w:
+                windex = w - width
+            blocks.append(data.sel(y=data.y.values[hindex:hindex+width],
+                                   x=data.x.values[windex:windex+width]))
+    return blocks
+
+def block_dask_array(arr, axis, size=128, stride=128):
+    arr = da.array.swapaxes(arr, axis, 0)
+    n = arr.shape[0]
+    stack = []
+    for j in range(0, n, stride):
+        j = min(j, n - size)
+        stack.append(arr[j:j+size])
+    stack = da.array.stack(stack)
+    stack = da.array.swapaxes(stack, axis+1, 1)
+    return stack
+
+def block_array(arr, axis, size=128, stride=128):
+    arr = np.swapaxes(arr, axis, 0)
+    n = arr.shape[0]
+    stack = []
+    for j in range(0, n, stride):
+        j = min(j, n - size)
+        stack.append(arr[np.newaxis,j:j+size])
+    stack = np.concatenate(stack, 0)
+    stack = np.swapaxes(stack, axis+1, 1)
+    return stack
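+
+# shape sketch: blocking a (16, 2, 512, 512) array along axis 2 with size=stride=128
+#   x = np.zeros((16, 2, 512, 512))
+#   block_array(x, 2).shape == (4, 16, 2, 128, 512); apply again on axis 3 for square patches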
+
+def xarray_to_block_list(arr, dim, size=128, stride=128):
+    n = arr[dim].shape[0]
+    stack = []
+    for j in range(0, n, stride):
+        j = min(j, n - size)
+        stack.append(arr.isel({dim: np.arange(j,j+size)}))
+    return stack
+
+def interp(da, scale, fillna=False):
+    xnew = interp_dim(da['x'].values, scale)
+    ynew = interp_dim(da['y'].values, scale)
+    newcoords = dict(x=xnew, y=ynew)
+    return da.interp(newcoords)
+
+def regrid_2km(da, band):
+    if band == 2:
+        return interp(da, 1. / 4, fillna=False)
+    elif band in [1, 3, 5]:
+        return interp(da, 1. / 2, fillna=False)
+    return da
+
+def regrid_1km(da, band):
+    if band == 2: #(0.5 km)
+        return interp(da, 1./2, fillna=False)
+    elif band not in [1, 3, 5]: # 2km
+        return interp(da, 2., fillna=False)
+    return da
+
+def regrid_500m(da, band):
+    if band == 2: # 500m
+        return da
+    elif band in [1, 3, 5]: # 1km
+        return interp(da, 2., fillna=False)
+    return interp(da, 4., fillna=False) # 2km
+
+
+def cartesian_to_speed(da):
+    lat_rad = np.radians(da.lat.values)
+    lon_rad = np.radians(da.lon.values)
+    a = np.cos(lat_rad)**2 * np.sin((lon_rad[1]-lon_rad[0])/2)**2
+    d = 2 * 6378.137 * np.arcsin(a**0.5)
+    size_per_pixel = np.repeat(np.expand_dims(d, -1), len(lon_rad), axis=1) # km
+    da['U'] = da['U'] * size_per_pixel * 1000 / 1800
+    da['V'] = da['V'] * size_per_pixel * 1000 / 1800
+    return da
+
+
+def cartesian_to_speed_eco(da):
+    lat_rad = np.radians(da.lat_0.values)
+    lon_rad = np.radians(da.lon_0.values)
+    a = np.cos(lat_rad)**2 * np.sin((lon_rad[1]-lon_rad[0])/2)**2
+    d = 2 * 6378.137 * np.arcsin(a**0.5)
+    size_per_pixel = np.repeat(np.expand_dims(d, -1), len(lon_rad), axis=1) # km
+    da['ugrd_newP'] = da['ugrd_newP'] * size_per_pixel * 1000 / 10800
+    da['vgrd_newP'] = da['vgrd_newP'] * size_per_pixel * 1000 / 10800
+    return da
+
+
+def speed_to_cartesian(da):
+    lat_rad = np.radians(da.lat.values)
+    lon_rad = np.radians(da.lon.values)
+    a = np.cos(lat_rad)**2 * np.sin((lon_rad[1]-lon_rad[0])/2)**2
+    d = 2 * 6378.137 * np.arcsin(a**0.5)
+    size_per_pixel = np.repeat(np.expand_dims(d, -1), len(lon_rad), axis=1) # km
+    da['U'] = da['U'] / size_per_pixel / 1000 * 1800 / 0.9
+    da['V'] = da['V'] / size_per_pixel / 1000 * 1800 / 0.9
+    return da
\ No newline at end of file
diff --git a/windflow/external_packages/__init__.py b/windflow/external_packages/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/windflow/external_packages/__pycache__/__init__.cpython-39.pyc b/windflow/external_packages/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ca0f05cc7c666751f26bafdc177bec6bef98b6f5
Binary files /dev/null and b/windflow/external_packages/__pycache__/__init__.cpython-39.pyc differ
diff --git a/windflow/external_packages/channelnorm_package/__init__.py b/windflow/external_packages/channelnorm_package/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/windflow/external_packages/channelnorm_package/channelnorm.py b/windflow/external_packages/channelnorm_package/channelnorm.py
new file mode 100755
index 0000000000000000000000000000000000000000..76301af51f3d2d08e68c2e092e94f43a3e98cf16
--- /dev/null
+++ b/windflow/external_packages/channelnorm_package/channelnorm.py
@@ -0,0 +1,39 @@
+from torch.autograd import Function, Variable
+from torch.nn.modules.module import Module
+import channelnorm_cuda
+
+class ChannelNormFunction(Function):
+
+    @staticmethod
+    def forward(ctx, input1, norm_deg=2):
+        assert input1.is_contiguous()
+        b, _, h, w = input1.size()
+        output = input1.new(b, 1, h, w).zero_()
+
+        channelnorm_cuda.forward(input1, output, norm_deg)
+        ctx.save_for_backward(input1, output)
+        ctx.norm_deg = norm_deg
+
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        input1, output = ctx.saved_tensors
+
+        grad_input1 = Variable(input1.new(input1.size()).zero_())
+
+        channelnorm_cuda.backward(input1, output, grad_output.data,
+                                              grad_input1.data, ctx.norm_deg)
+
+        return grad_input1, None
+
+
+class ChannelNorm(Module):
+
+    def __init__(self, norm_deg=2):
+        super(ChannelNorm, self).__init__()
+        self.norm_deg = norm_deg
+
+    def forward(self, input1):
+        return ChannelNormFunction.apply(input1, self.norm_deg)
+
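+# usage sketch (requires the compiled channelnorm_cuda extension and contiguous CUDA tensors):
+#   norm = ChannelNorm()
+#   y = norm(x)   # x: (B, C, H, W) -> y: (B, 1, H, W), per-pixel L2 norm over channels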
diff --git a/windflow/external_packages/channelnorm_package/channelnorm_cuda.cc b/windflow/external_packages/channelnorm_package/channelnorm_cuda.cc
new file mode 100755
index 0000000000000000000000000000000000000000..69d82eb184e97b2eefa9810ad156d1104cf84745
--- /dev/null
+++ b/windflow/external_packages/channelnorm_package/channelnorm_cuda.cc
@@ -0,0 +1,31 @@
+#include <torch/torch.h>
+#include <ATen/ATen.h>
+
+#include "channelnorm_kernel.cuh"
+
+int channelnorm_cuda_forward(
+    at::Tensor& input1, 
+    at::Tensor& output,
+    int norm_deg) {
+
+    channelnorm_kernel_forward(input1, output, norm_deg);
+    return 1;
+}
+
+
+int channelnorm_cuda_backward(
+    at::Tensor& input1, 
+    at::Tensor& output,
+    at::Tensor& gradOutput,
+    at::Tensor& gradInput1,
+    int norm_deg) {
+
+    channelnorm_kernel_backward(input1, output, gradOutput, gradInput1, norm_deg);
+    return 1;
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+  m.def("forward", &channelnorm_cuda_forward, "Channel norm forward (CUDA)");
+  m.def("backward", &channelnorm_cuda_backward, "Channel norm backward (CUDA)");
+}
+
diff --git a/windflow/external_packages/channelnorm_package/channelnorm_kernel.cu b/windflow/external_packages/channelnorm_package/channelnorm_kernel.cu
new file mode 100755
index 0000000000000000000000000000000000000000..99ace6855a61373443a6ddff7c7858eb474e9e48
--- /dev/null
+++ b/windflow/external_packages/channelnorm_package/channelnorm_kernel.cu
@@ -0,0 +1,177 @@
+#include <ATen/ATen.h>
+#include <ATen/Context.h>
+#include <ATen/cuda/CUDAContext.h>
+
+#include "channelnorm_kernel.cuh"
+
+#define CUDA_NUM_THREADS 512 
+
+#define DIM0(TENSOR) ((TENSOR).x)
+#define DIM1(TENSOR) ((TENSOR).y)
+#define DIM2(TENSOR) ((TENSOR).z)
+#define DIM3(TENSOR) ((TENSOR).w)
+
+#define DIM3_INDEX(TENSOR, xx, yy, zz, ww) ((TENSOR)[((xx) * (TENSOR##_stride.x)) + ((yy) * (TENSOR##_stride.y)) + ((zz) * (TENSOR##_stride.z)) + ((ww) * (TENSOR##_stride.w))])
+
+using at::Half;
+
+template <typename scalar_t>
+__global__ void kernel_channelnorm_update_output(
+    const int n, 
+    const scalar_t* __restrict__ input1,
+    const long4 input1_size,
+    const long4 input1_stride,
+    scalar_t* __restrict__ output, 
+    const long4 output_size,
+    const long4 output_stride,
+    int norm_deg) {
+
+    int index = blockIdx.x * blockDim.x + threadIdx.x;
+
+    if (index >= n) {
+        return;
+    }
+
+    int dim_b = DIM0(output_size);
+    int dim_c = DIM1(output_size);
+    int dim_h = DIM2(output_size);
+    int dim_w = DIM3(output_size);
+    int dim_chw = dim_c * dim_h * dim_w;
+
+    int b = ( index / dim_chw ) % dim_b;
+    int y = ( index / dim_w )   % dim_h;
+    int x = ( index          )  % dim_w;
+
+    int i1dim_c = DIM1(input1_size);
+    int i1dim_h = DIM2(input1_size);
+    int i1dim_w = DIM3(input1_size);
+    int i1dim_chw = i1dim_c * i1dim_h * i1dim_w;
+    int i1dim_hw  = i1dim_h * i1dim_w;
+
+    float result = 0.0;
+
+    for (int c = 0; c < i1dim_c; ++c) {
+        int i1Index = b * i1dim_chw + c * i1dim_hw + y * i1dim_w + x;
+        scalar_t val = input1[i1Index];
+        result += static_cast<float>(val * val);
+    }
+    result = sqrt(result);
+    output[index] = static_cast<scalar_t>(result);
+}
+
+
+template <typename scalar_t>
+__global__ void kernel_channelnorm_backward_input1(
+    const int n,
+    const scalar_t* __restrict__ input1, const long4 input1_size, const long4 input1_stride,
+    const scalar_t* __restrict__ output, const long4 output_size, const long4 output_stride, 
+    const scalar_t* __restrict__ gradOutput, const long4 gradOutput_size, const long4 gradOutput_stride,
+    scalar_t* __restrict__ gradInput, const long4 gradInput_size, const long4 gradInput_stride, 
+    int norm_deg) {
+
+    int index = blockIdx.x * blockDim.x + threadIdx.x;
+
+    if (index >= n) {
+        return;
+    }
+
+    float val = 0.0;
+
+    int dim_b = DIM0(gradInput_size);
+    int dim_c = DIM1(gradInput_size);
+    int dim_h = DIM2(gradInput_size);
+    int dim_w = DIM3(gradInput_size);
+    int dim_chw = dim_c * dim_h * dim_w;
+    int dim_hw  = dim_h * dim_w;
+
+    int b = ( index / dim_chw ) % dim_b;
+    int y = ( index / dim_w )   % dim_h;
+    int x = ( index          )  % dim_w;
+
+
+    int outIndex = b * dim_hw + y * dim_w + x;
+    val = static_cast<float>(gradOutput[outIndex]) * static_cast<float>(input1[index]) / (static_cast<float>(output[outIndex])+1e-9);
+    gradInput[index] = static_cast<scalar_t>(val);
+
+}
+
+void channelnorm_kernel_forward(
+    at::Tensor& input1, 
+    at::Tensor& output, 
+    int norm_deg) {
+
+    const long4 input1_size = make_long4(input1.size(0), input1.size(1), input1.size(2), input1.size(3));
+    const long4 input1_stride = make_long4(input1.stride(0), input1.stride(1), input1.stride(2), input1.stride(3));
+
+    const long4 output_size = make_long4(output.size(0), output.size(1), output.size(2), output.size(3));
+    const long4 output_stride = make_long4(output.stride(0), output.stride(1), output.stride(2), output.stride(3));
+
+    int n = output.numel();
+
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input1.type(), "channelnorm_forward", ([&] {
+
+      kernel_channelnorm_update_output<scalar_t><<< (n + CUDA_NUM_THREADS - 1)/CUDA_NUM_THREADS, CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream() >>>(
+          n,
+          input1.data<scalar_t>(), 
+          input1_size,
+          input1_stride, 
+          output.data<scalar_t>(),
+          output_size,
+          output_stride, 
+          norm_deg);
+
+    }));
+
+      // TODO: ATen-equivalent check
+
+     // THCudaCheck(cudaGetLastError());
+}
+
+void channelnorm_kernel_backward(
+    at::Tensor& input1, 
+    at::Tensor& output,
+    at::Tensor& gradOutput, 
+    at::Tensor& gradInput1, 
+    int norm_deg) {
+
+    const long4 input1_size = make_long4(input1.size(0), input1.size(1), input1.size(2), input1.size(3));
+    const long4 input1_stride = make_long4(input1.stride(0), input1.stride(1), input1.stride(2), input1.stride(3));
+
+    const long4 output_size = make_long4(output.size(0), output.size(1), output.size(2), output.size(3));
+    const long4 output_stride = make_long4(output.stride(0), output.stride(1), output.stride(2), output.stride(3));
+
+    const long4 gradOutput_size = make_long4(gradOutput.size(0), gradOutput.size(1), gradOutput.size(2), gradOutput.size(3));
+    const long4 gradOutput_stride = make_long4(gradOutput.stride(0), gradOutput.stride(1), gradOutput.stride(2), gradOutput.stride(3));
+
+    const long4 gradInput1_size = make_long4(gradInput1.size(0), gradInput1.size(1), gradInput1.size(2), gradInput1.size(3));
+    const long4 gradInput1_stride = make_long4(gradInput1.stride(0), gradInput1.stride(1), gradInput1.stride(2), gradInput1.stride(3));
+
+    int n = gradInput1.numel();
+
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input1.type(), "channelnorm_backward_input1", ([&] {
+
+      kernel_channelnorm_backward_input1<scalar_t><<< (n + CUDA_NUM_THREADS - 1)/CUDA_NUM_THREADS, CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream() >>>(
+          n, 
+          input1.data<scalar_t>(),
+          input1_size,
+          input1_stride,
+          output.data<scalar_t>(),
+          output_size,
+          output_stride,
+          gradOutput.data<scalar_t>(),
+          gradOutput_size,
+          gradOutput_stride, 
+          gradInput1.data<scalar_t>(),
+          gradInput1_size,
+          gradInput1_stride,
+          norm_deg
+    );
+
+    }));
+
+    // TODO: Add ATen-equivalent check
+
+//    THCudaCheck(cudaGetLastError());
+}
diff --git a/windflow/external_packages/channelnorm_package/channelnorm_kernel.cuh b/windflow/external_packages/channelnorm_package/channelnorm_kernel.cuh
new file mode 100755
index 0000000000000000000000000000000000000000..3e6223f7fe60feb4bf9e4f66c3d849b84c89dcda
--- /dev/null
+++ b/windflow/external_packages/channelnorm_package/channelnorm_kernel.cuh
@@ -0,0 +1,16 @@
+#pragma once
+
+#include <ATen/ATen.h>
+
+void channelnorm_kernel_forward(
+    at::Tensor& input1,
+    at::Tensor& output, 
+    int norm_deg);
+
+
+void channelnorm_kernel_backward(
+    at::Tensor& input1,
+    at::Tensor& output,
+    at::Tensor& gradOutput,
+    at::Tensor& gradInput1,
+    int norm_deg);
diff --git a/windflow/external_packages/channelnorm_package/setup.py b/windflow/external_packages/channelnorm_package/setup.py
new file mode 100755
index 0000000000000000000000000000000000000000..17ecacf5450630a412e0517c462c4ffa5a532718
--- /dev/null
+++ b/windflow/external_packages/channelnorm_package/setup.py
@@ -0,0 +1,29 @@
+#!/usr/bin/env python3
+import os
+import torch
+
+from setuptools import setup
+from torch.utils.cpp_extension import BuildExtension, CUDAExtension
+
+#cxx_args = ['-std=c++11']
+cxx_args = ['-std=c++14']
+
+nvcc_args = [
+    '-gencode', 'arch=compute_52,code=sm_52',
+    '-gencode', 'arch=compute_60,code=sm_60',
+    '-gencode', 'arch=compute_61,code=sm_61',
+    '-gencode', 'arch=compute_70,code=sm_70',
+    '-gencode', 'arch=compute_70,code=compute_70'
+]
+
+setup(
+    name='channelnorm_cuda',
+    ext_modules=[
+        CUDAExtension('channelnorm_cuda', [
+            'channelnorm_cuda.cc',
+            'channelnorm_kernel.cu'
+        ], extra_compile_args={'cxx': cxx_args, 'nvcc': nvcc_args})
+    ],
+    cmdclass={
+        'build_ext': BuildExtension
+    })
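+
+# Build sketch (run from this directory; assumes nvcc is on PATH and matches
+# the CUDA toolkit that PyTorch was compiled against):
+#
+#   python setup.py install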
diff --git a/windflow/external_packages/correlation_package/__init__.py b/windflow/external_packages/correlation_package/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/windflow/external_packages/correlation_package/correlation.py b/windflow/external_packages/correlation_package/correlation.py
new file mode 100755
index 0000000000000000000000000000000000000000..80a8b09c088c360e152230af583779d771d1bca1
--- /dev/null
+++ b/windflow/external_packages/correlation_package/correlation.py
@@ -0,0 +1,62 @@
+import torch
+from torch.nn.modules.module import Module
+from torch.autograd import Function
+import correlation_cuda
+
+class CorrelationFunction(Function):
+    # New-style (static) autograd Function; the legacy stateful form this file
+    # shipped with is rejected by PyTorch >= 1.3.
+
+    @staticmethod
+    def forward(ctx, input1, input2, pad_size=3, kernel_size=3, max_displacement=20, stride1=1, stride2=2, corr_multiply=1):
+        ctx.save_for_backward(input1, input2)
+        ctx.pad_size = pad_size
+        ctx.kernel_size = kernel_size
+        ctx.max_displacement = max_displacement
+        ctx.stride1 = stride1
+        ctx.stride2 = stride2
+        ctx.corr_multiply = corr_multiply
+
+        with torch.cuda.device_of(input1):
+            rbot1 = input1.new()
+            rbot2 = input2.new()
+            output = input1.new()
+
+            correlation_cuda.forward(input1, input2, rbot1, rbot2, output,
+                pad_size, kernel_size, max_displacement, stride1, stride2, corr_multiply)
+
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        input1, input2 = ctx.saved_tensors
+
+        with torch.cuda.device_of(input1):
+            rbot1 = input1.new()
+            rbot2 = input2.new()
+
+            grad_input1 = input1.new()
+            grad_input2 = input2.new()
+
+            correlation_cuda.backward(input1, input2, rbot1, rbot2, grad_output,
+                grad_input1, grad_input2, ctx.pad_size, ctx.kernel_size,
+                ctx.max_displacement, ctx.stride1, ctx.stride2, ctx.corr_multiply)
+
+        # one gradient per forward argument; the hyperparameters receive None
+        return grad_input1, grad_input2, None, None, None, None, None, None
+
+
+class Correlation(Module):
+    def __init__(self, pad_size=0, kernel_size=0, max_displacement=0, stride1=1, stride2=2, corr_multiply=1):
+        super(Correlation, self).__init__()
+        self.pad_size = pad_size
+        self.kernel_size = kernel_size
+        self.max_displacement = max_displacement
+        self.stride1 = stride1
+        self.stride2 = stride2
+        self.corr_multiply = corr_multiply
+
+    def forward(self, input1, input2):
+        return CorrelationFunction.apply(input1, input2, self.pad_size, self.kernel_size,
+                                         self.max_displacement, self.stride1, self.stride2, self.corr_multiply)
+
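+# Usage sketch with the FlowNet-C hyperparameters (assumes the correlation_cuda
+# extension is built and both inputs are contiguous (B, C, H, W) CUDA tensors):
+#
+#   corr = Correlation(pad_size=20, kernel_size=1, max_displacement=20,
+#                      stride1=1, stride2=2, corr_multiply=1)
+#   cost_volume = corr(feat1, feat2)   # -> (B, 441, H, W)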
diff --git a/windflow/external_packages/correlation_package/correlation_cuda.cc b/windflow/external_packages/correlation_package/correlation_cuda.cc
new file mode 100755
index 0000000000000000000000000000000000000000..feccd65295fa90a22564b08fc80464a76361a1aa
--- /dev/null
+++ b/windflow/external_packages/correlation_package/correlation_cuda.cc
@@ -0,0 +1,173 @@
+#include <torch/torch.h>
+#include <ATen/ATen.h>
+#include <ATen/Context.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <stdio.h>
+#include <iostream>
+
+#include "correlation_cuda_kernel.cuh"
+
+int correlation_forward_cuda(at::Tensor& input1, at::Tensor& input2, at::Tensor& rInput1, at::Tensor& rInput2, at::Tensor& output,
+                       int pad_size,
+                       int kernel_size,
+                       int max_displacement,
+                       int stride1,
+                       int stride2,
+                       int corr_type_multiply)
+{
+
+  int batchSize = input1.size(0);
+
+  int nInputChannels = input1.size(1);
+  int inputHeight = input1.size(2);
+  int inputWidth = input1.size(3);
+
+  int kernel_radius = (kernel_size - 1) / 2;
+  int border_radius = kernel_radius + max_displacement;
+
+  int paddedInputHeight = inputHeight + 2 * pad_size;
+  int paddedInputWidth = inputWidth + 2 * pad_size;
+
+  int nOutputChannels = ((max_displacement/stride2)*2 + 1) * ((max_displacement/stride2)*2 + 1);
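+  // e.g. the FlowNet-C setting (max_displacement=20, stride2=2) gives 21 * 21 = 441 channels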
+
+  int outputHeight = ceil(static_cast<float>(paddedInputHeight - 2 * border_radius) / static_cast<float>(stride1));
+  int outputWidth = ceil(static_cast<float>(paddedInputWidth - 2 * border_radius) / static_cast<float>(stride1));
+
+  rInput1.resize_({batchSize, paddedInputHeight, paddedInputWidth, nInputChannels});
+  rInput2.resize_({batchSize, paddedInputHeight, paddedInputWidth, nInputChannels});
+  output.resize_({batchSize, nOutputChannels, outputHeight, outputWidth});
+
+  rInput1.fill_(0);
+  rInput2.fill_(0);
+  output.fill_(0);
+
+  int success = correlation_forward_cuda_kernel(
+    output,
+    output.size(0), 
+    output.size(1),
+    output.size(2),
+    output.size(3),
+    output.stride(0),
+    output.stride(1),
+    output.stride(2),
+    output.stride(3),
+    input1,
+    input1.size(1),
+    input1.size(2),
+    input1.size(3),
+    input1.stride(0),
+    input1.stride(1),
+    input1.stride(2),
+    input1.stride(3),
+    input2,
+    input2.size(1),
+    input2.stride(0),
+    input2.stride(1),
+    input2.stride(2),
+    input2.stride(3),
+    rInput1,
+    rInput2,
+    pad_size,     
+    kernel_size,
+    max_displacement,
+    stride1,
+    stride2,
+    corr_type_multiply,
+	at::cuda::getCurrentCUDAStream()
+  );
+
+  //check for errors
+  if (!success) {
+    AT_ERROR("CUDA call failed");
+  }
+
+  return 1;
+
+}
+
+int correlation_backward_cuda(at::Tensor& input1, at::Tensor& input2, at::Tensor& rInput1, at::Tensor& rInput2, at::Tensor& gradOutput, 
+                       at::Tensor& gradInput1, at::Tensor& gradInput2,
+                       int pad_size,
+                       int kernel_size,
+                       int max_displacement,
+                       int stride1,
+                       int stride2,
+                       int corr_type_multiply)
+{
+
+  int batchSize = input1.size(0);
+  int nInputChannels = input1.size(1);
+  int paddedInputHeight = input1.size(2)+ 2 * pad_size;
+  int paddedInputWidth = input1.size(3)+ 2 * pad_size;
+
+  int height = input1.size(2);
+  int width = input1.size(3);
+
+  rInput1.resize_({batchSize, paddedInputHeight, paddedInputWidth, nInputChannels});
+  rInput2.resize_({batchSize, paddedInputHeight, paddedInputWidth, nInputChannels});
+  gradInput1.resize_({batchSize, nInputChannels, height, width});
+  gradInput2.resize_({batchSize, nInputChannels, height, width});
+
+  rInput1.fill_(0);
+  rInput2.fill_(0);
+  gradInput1.fill_(0);
+  gradInput2.fill_(0);
+
+  int success = correlation_backward_cuda_kernel(gradOutput,
+                                                gradOutput.size(0),
+                                                gradOutput.size(1),
+                                                gradOutput.size(2),
+                                                gradOutput.size(3),
+                                                gradOutput.stride(0),
+                                                gradOutput.stride(1),
+                                                gradOutput.stride(2),
+                                                gradOutput.stride(3),
+                                                input1,
+                                                input1.size(1),
+                                                input1.size(2),
+                                                input1.size(3),
+                                                input1.stride(0),
+                                                input1.stride(1),
+                                                input1.stride(2),
+                                                input1.stride(3),
+                                                input2,  
+                                                input2.stride(0),
+                                                input2.stride(1),
+                                                input2.stride(2),
+                                                input2.stride(3),
+                                                gradInput1,
+                                                gradInput1.stride(0),
+                                                gradInput1.stride(1),
+                                                gradInput1.stride(2),
+                                                gradInput1.stride(3),
+                                                gradInput2,
+                                                gradInput2.size(1),
+                                                gradInput2.stride(0),
+                                                gradInput2.stride(1),
+                                                gradInput2.stride(2),
+                                                gradInput2.stride(3),
+                                                rInput1,
+                                                rInput2,
+                                                pad_size,
+                                                kernel_size,
+                                                max_displacement,
+                                                stride1, 
+                                                stride2,
+                                                corr_type_multiply,
+												at::cuda::getCurrentCUDAStream()
+                                               );
+
+  if (!success) {
+    AT_ERROR("CUDA call failed");
+  }
+
+  return 1;
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+  m.def("forward", &correlation_forward_cuda, "Correlation forward (CUDA)");
+  m.def("backward", &correlation_backward_cuda, "Correlation backward (CUDA)");
+}
+
diff --git a/windflow/external_packages/correlation_package/correlation_cuda_kernel.cu b/windflow/external_packages/correlation_package/correlation_cuda_kernel.cu
new file mode 100755
index 0000000000000000000000000000000000000000..eaf86fc129137d055de7400916567c6669b45c19
--- /dev/null
+++ b/windflow/external_packages/correlation_package/correlation_cuda_kernel.cu
@@ -0,0 +1,564 @@
+#include <stdio.h>
+
+#include "correlation_cuda_kernel.cuh"
+
+#define CUDA_NUM_THREADS 1024
+#define THREADS_PER_BLOCK 32
+#define FULL_MASK 0xffffffff
+
+#include <ATen/ATen.h>
+#include <ATen/NativeFunctions.h>
+#include <ATen/Dispatch.h>
+#include <ATen/cuda/CUDAApplyUtils.cuh>
+
+using at::Half;
+
+template<typename scalar_t>
+__forceinline__ __device__ scalar_t warpReduceSum(scalar_t val) {
+        for (int offset = 16; offset > 0; offset /= 2)
+                val += __shfl_down_sync(FULL_MASK, val, offset);
+        return val;
+}
+
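+// Block-level reduction in two stages: each warp reduces its lanes with
+// shuffles, lane 0 of every warp stages its partial sum in shared memory,
+// and the first warp then reduces those partials.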
+template<typename scalar_t>
+__forceinline__ __device__ scalar_t blockReduceSum(scalar_t val) {
+
+        static __shared__ scalar_t shared[32];
+        int lane = threadIdx.x % warpSize;
+        int wid = threadIdx.x / warpSize;
+
+        val = warpReduceSum(val);
+
+        if (lane == 0)
+                shared[wid] = val;
+
+        __syncthreads();
+
+        val = (threadIdx.x < blockDim.x / warpSize) ? shared[lane] : 0;
+
+        if (wid == 0)
+                val = warpReduceSum(val);
+
+        return val;
+}
+
+
+template <typename scalar_t>
+__global__ void channels_first(const scalar_t* __restrict__ input, scalar_t* rinput, int channels, int height, int width, int pad_size)
+{
+
+    // n (batch size), c (num of channels), y (height), x (width)
+    int n = blockIdx.x;
+    int y = blockIdx.y;
+    int x = blockIdx.z;
+
+    int ch_off = threadIdx.x;
+    scalar_t value;
+
+    int dimcyx = channels * height * width;
+    int dimyx = height * width;
+
+    int p_dimx = (width + 2 * pad_size);
+    int p_dimy = (height + 2 * pad_size);
+    int p_dimyxc = channels * p_dimy * p_dimx;
+    int p_dimxc = p_dimx * channels;
+
+    for (int c = ch_off; c < channels; c += THREADS_PER_BLOCK) {
+      value = input[n * dimcyx + c * dimyx + y * width + x];
+      rinput[n * p_dimyxc + (y + pad_size) * p_dimxc + (x + pad_size) * channels + c] = value;
+    }
+}
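+// Despite its name, channels_first writes a channels-last (NHWC) copy of the
+// zero-padded input, so the correlation kernels can reduce over a contiguous
+// channel dimension with coalesced loads.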
+
+
+template<typename scalar_t>
+__global__ void correlation_forward(scalar_t* __restrict__ output, const int nOutputChannels,
+                const int outputHeight, const int outputWidth, const scalar_t* __restrict__ rInput1,
+                const int nInputChannels, const int inputHeight, const int inputWidth,
+                const scalar_t* __restrict__ rInput2, const int pad_size, const int kernel_size,
+                const int max_displacement, const int stride1, const int stride2) {
+
+        int32_t pInputWidth = inputWidth + 2 * pad_size;
+        int32_t pInputHeight = inputHeight + 2 * pad_size;
+
+        int32_t kernel_rad = (kernel_size - 1) / 2;
+
+        int32_t displacement_rad = max_displacement / stride2;
+
+        int32_t displacement_size = 2 * displacement_rad + 1;
+
+        int32_t n = blockIdx.x;
+        int32_t y1 = blockIdx.y * stride1 + max_displacement;
+        int32_t x1 = blockIdx.z * stride1 + max_displacement;
+        int32_t c = threadIdx.x;
+
+        int32_t pdimyxc = pInputHeight * pInputWidth * nInputChannels;
+
+        int32_t pdimxc = pInputWidth * nInputChannels;
+
+        int32_t pdimc = nInputChannels;
+
+        int32_t tdimcyx = nOutputChannels * outputHeight * outputWidth;
+        int32_t tdimyx = outputHeight * outputWidth;
+        int32_t tdimx = outputWidth;
+
+        int32_t nelems = kernel_size * kernel_size * pdimc;
+
+        // element-wise product along channel axis
+        for (int tj = -displacement_rad; tj <= displacement_rad; ++tj) {
+                for (int ti = -displacement_rad; ti <= displacement_rad; ++ti) {
+                        int x2 = x1 + ti * stride2;
+                        int y2 = y1 + tj * stride2;
+
+                        float acc0 = 0.0f;
+
+                        for (int j = -kernel_rad; j <= kernel_rad; ++j) {
+                                for (int i = -kernel_rad; i <= kernel_rad; ++i) {
+                                        // THREADS_PER_BLOCK
+                                        #pragma unroll
+                                        for (int ch = c; ch < pdimc; ch += blockDim.x) {
+
+                                                int indx1 = n * pdimyxc + (y1 + j) * pdimxc
+                                                                + (x1 + i) * pdimc + ch;
+                                                int indx2 = n * pdimyxc + (y2 + j) * pdimxc
+                                                                + (x2 + i) * pdimc + ch;
+                                                acc0 += static_cast<float>(rInput1[indx1] * rInput2[indx2]);
+                                        }
+                                }
+                        }
+
+                        if (blockDim.x == warpSize) {
+                            __syncwarp();
+                            acc0 = warpReduceSum(acc0);
+                        } else {
+                            __syncthreads();
+                            acc0 = blockReduceSum(acc0);
+                        }
+
+                        if (threadIdx.x == 0) {
+
+                                int tc = (tj + displacement_rad) * displacement_size
+                                                + (ti + displacement_rad);
+                                const int tindx = n * tdimcyx + tc * tdimyx + blockIdx.y * tdimx
+                                                + blockIdx.z;
+                                output[tindx] = static_cast<scalar_t>(acc0 / nelems);
+                        }
+            }
+        }
+}
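+// Launch layout: one block per (batch, output-y, output-x) triple; the block's
+// threads split the channel sum and are combined by the reductions above
+// before thread 0 writes the mean-normalized correlation value.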
+
+
+template <typename scalar_t>
+__global__ void correlation_backward_input1(int item, scalar_t* gradInput1, int nInputChannels, int inputHeight, int inputWidth, 
+                                            const scalar_t* __restrict__ gradOutput, int nOutputChannels, int outputHeight, int outputWidth, 
+                                            const scalar_t* __restrict__ rInput2, 
+                                            int pad_size,
+                                            int kernel_size,
+                                            int max_displacement,
+                                            int stride1,
+                                            int stride2)
+  {
+    // n (batch size), c (num of channels), y (height), x (width)
+
+    int n = item; 
+    int y = blockIdx.x * stride1 + pad_size;
+    int x = blockIdx.y * stride1 + pad_size;
+    int c = blockIdx.z;
+    int tch_off = threadIdx.x;
+
+    int kernel_rad = (kernel_size - 1) / 2;
+    int displacement_rad = max_displacement / stride2;
+    int displacement_size = 2 * displacement_rad + 1;
+
+    int xmin = (x - kernel_rad - max_displacement) / stride1;
+    int ymin = (y - kernel_rad - max_displacement) / stride1;
+
+    int xmax = (x + kernel_rad - max_displacement) / stride1;
+    int ymax = (y + kernel_rad - max_displacement) / stride1;
+
+    if (xmax < 0 || ymax < 0 || xmin >= outputWidth || ymin >= outputHeight) {
+        // assumes gradInput1 is pre-allocated and zero filled
+      return;
+    }
+
+    if (xmin > xmax || ymin > ymax) {
+        // assumes gradInput1 is pre-allocated and zero filled
+        return;
+    }
+
+    xmin = max(0,xmin);
+    xmax = min(outputWidth-1,xmax);
+
+    ymin = max(0,ymin);
+    ymax = min(outputHeight-1,ymax);
+
+    int pInputWidth = inputWidth + 2 * pad_size;
+    int pInputHeight = inputHeight + 2 * pad_size;
+
+    int pdimyxc = pInputHeight * pInputWidth * nInputChannels;
+    int pdimxc = pInputWidth * nInputChannels;
+    int pdimc = nInputChannels;
+
+    int tdimcyx = nOutputChannels * outputHeight * outputWidth;
+    int tdimyx = outputHeight * outputWidth;
+    int tdimx = outputWidth;
+
+    int odimcyx = nInputChannels * inputHeight* inputWidth;
+    int odimyx = inputHeight * inputWidth;
+    int odimx = inputWidth;
+
+    scalar_t nelems = kernel_size * kernel_size * nInputChannels;
+
+    __shared__ scalar_t prod_sum[THREADS_PER_BLOCK];
+    prod_sum[tch_off] = 0;
+
+    for (int tc = tch_off; tc < nOutputChannels; tc += THREADS_PER_BLOCK) {
+
+      int i2 = (tc % displacement_size - displacement_rad) * stride2;
+      int j2 = (tc / displacement_size - displacement_rad) * stride2;
+
+      int indx2 = n * pdimyxc + (y + j2)* pdimxc + (x + i2) * pdimc + c;
+      
+      scalar_t val2 = rInput2[indx2];
+
+      for (int j = ymin; j <= ymax; ++j) {
+        for (int i = xmin; i <= xmax; ++i) {
+          int tindx = n * tdimcyx + tc * tdimyx + j * tdimx + i;
+          prod_sum[tch_off] += gradOutput[tindx] * val2;
+        }
+      }
+    }
+    __syncthreads();
+
+    if(tch_off == 0) {
+      scalar_t reduce_sum = 0;
+      for(int idx = 0; idx < THREADS_PER_BLOCK; idx++) {
+          reduce_sum += prod_sum[idx];
+      }
+      const int indx1 = n * odimcyx + c * odimyx + (y - pad_size) * odimx + (x - pad_size);
+      gradInput1[indx1] = reduce_sum / nelems;
+    }
+
+}
+
+template <typename scalar_t>
+__global__ void correlation_backward_input2(int item, scalar_t*  gradInput2, int nInputChannels, int inputHeight, int inputWidth,
+                                            const scalar_t* __restrict__ gradOutput, int nOutputChannels, int outputHeight, int outputWidth,
+                                            const scalar_t* __restrict__ rInput1,
+                                            int pad_size,
+                                            int kernel_size,
+                                            int max_displacement,
+                                            int stride1,
+                                            int stride2)
+{
+    // n (batch size), c (num of channels), y (height), x (width)
+
+    int n = item;
+    int y = blockIdx.x * stride1 + pad_size;
+    int x = blockIdx.y * stride1 + pad_size;
+    int c = blockIdx.z;
+
+    int tch_off = threadIdx.x;
+
+    int kernel_rad = (kernel_size - 1) / 2;
+    int displacement_rad = max_displacement / stride2;
+    int displacement_size = 2 * displacement_rad + 1;
+
+    int pInputWidth = inputWidth + 2 * pad_size;
+    int pInputHeight = inputHeight + 2 * pad_size;
+
+    int pdimyxc = pInputHeight * pInputWidth * nInputChannels;
+    int pdimxc = pInputWidth * nInputChannels;
+    int pdimc = nInputChannels;
+
+    int tdimcyx = nOutputChannels * outputHeight * outputWidth;
+    int tdimyx = outputHeight * outputWidth;
+    int tdimx = outputWidth;
+
+    int odimcyx = nInputChannels * inputHeight* inputWidth;
+    int odimyx = inputHeight * inputWidth;
+    int odimx = inputWidth;
+
+    scalar_t nelems = kernel_size * kernel_size * nInputChannels;
+
+    __shared__ scalar_t prod_sum[THREADS_PER_BLOCK];
+    prod_sum[tch_off] = 0;
+
+    for (int tc = tch_off; tc < nOutputChannels; tc += THREADS_PER_BLOCK) {
+      int i2 = (tc % displacement_size - displacement_rad) * stride2;
+      int j2 = (tc / displacement_size - displacement_rad) * stride2;
+
+      int xmin = (x - kernel_rad - max_displacement - i2) / stride1;
+      int ymin = (y - kernel_rad - max_displacement - j2) / stride1;
+
+      int xmax = (x + kernel_rad - max_displacement - i2) / stride1;
+      int ymax = (y + kernel_rad - max_displacement - j2) / stride1;
+
+      if (xmax < 0 || ymax < 0 || xmin >= outputWidth || ymin >= outputHeight) {
+          // assumes gradInput2 is pre-allocated and zero filled
+        continue;
+      }
+
+      if (xmin > xmax || ymin > ymax) {
+          // assumes gradInput2 is pre-allocated and zero filled
+          continue;
+      }
+
+      xmin = max(0,xmin);
+      xmax = min(outputWidth-1,xmax);
+
+      ymin = max(0,ymin);
+      ymax = min(outputHeight-1,ymax);
+      
+      int indx1 = n * pdimyxc + (y - j2)* pdimxc + (x - i2) * pdimc + c;
+      scalar_t val1 = rInput1[indx1];
+
+      for (int j = ymin; j <= ymax; ++j) {
+        for (int i = xmin; i <= xmax; ++i) {
+          int tindx = n * tdimcyx + tc * tdimyx + j * tdimx + i;
+          prod_sum[tch_off] += gradOutput[tindx] * val1;
+        }
+      }
+    }
+
+    __syncthreads();
+
+    if(tch_off == 0) {
+      scalar_t reduce_sum = 0;
+      for(int idx = 0; idx < THREADS_PER_BLOCK; idx++) {
+          reduce_sum += prod_sum[idx];
+      }
+      const int indx2 = n * odimcyx + c * odimyx + (y - pad_size) * odimx + (x - pad_size);
+      gradInput2[indx2] = reduce_sum / nelems;
+    }
+
+}
+
+int correlation_forward_cuda_kernel(at::Tensor& output,
+                                    int ob,
+                                    int oc,
+                                    int oh,
+                                    int ow,
+                                    int osb,
+                                    int osc,
+                                    int osh,
+                                    int osw,
+
+                                    at::Tensor& input1,
+                                    int ic,
+                                    int ih,
+                                    int iw,
+                                    int isb,
+                                    int isc,
+                                    int ish,
+                                    int isw,
+
+                                    at::Tensor& input2,
+                                    int gc,
+                                    int gsb,
+                                    int gsc,
+                                    int gsh,
+                                    int gsw,
+
+                                    at::Tensor& rInput1,
+                                    at::Tensor& rInput2,
+                                    int pad_size,
+                                    int kernel_size,
+                                    int max_displacement,
+                                    int stride1,
+                                    int stride2,
+                                    int corr_type_multiply,
+                                    cudaStream_t stream) 
+{
+
+   int batchSize = ob;
+
+   int nInputChannels = ic;
+   int inputWidth = iw;
+   int inputHeight = ih;
+
+   int nOutputChannels = oc;
+   int outputWidth = ow;
+   int outputHeight = oh;
+
+   dim3 blocks_grid(batchSize, inputHeight, inputWidth);
+   dim3 threads_block(THREADS_PER_BLOCK);
+
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(input1.type(), "channels_first_fwd_1", ([&] {
+
+  channels_first<scalar_t><<<blocks_grid,threads_block, 0, stream>>>(
+      input1.data<scalar_t>(), rInput1.data<scalar_t>(), nInputChannels, inputHeight, inputWidth, pad_size);
+
+  }));
+
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(input2.type(), "channels_first_fwd_2", ([&] {
+
+  channels_first<scalar_t><<<blocks_grid,threads_block, 0, stream>>> (
+      input2.data<scalar_t>(), rInput2.data<scalar_t>(), nInputChannels, inputHeight, inputWidth, pad_size);
+
+  }));
+
+   dim3 threadsPerBlock(THREADS_PER_BLOCK);
+   dim3 totalBlocksCorr(batchSize, outputHeight, outputWidth);
+
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(input1.type(), "correlation_forward", ([&] {
+
+   correlation_forward<scalar_t><<<totalBlocksCorr, threadsPerBlock, 0, stream>>> 
+                        (output.data<scalar_t>(), nOutputChannels, outputHeight, outputWidth,
+                         rInput1.data<scalar_t>(), nInputChannels, inputHeight, inputWidth,
+                         rInput2.data<scalar_t>(),
+                         pad_size,
+                         kernel_size,
+                         max_displacement,
+                         stride1,
+                         stride2);
+
+  }));
+
+  cudaError_t err = cudaGetLastError();
+
+
+  // check for errors
+  if (err != cudaSuccess) {
+    printf("error in correlation_forward_cuda_kernel: %s\n", cudaGetErrorString(err));
+    return 0;
+  }
+
+  return 1;
+}
+
+
+int correlation_backward_cuda_kernel(
+                                    at::Tensor& gradOutput,
+                                    int gob,
+                                    int goc,
+                                    int goh,
+                                    int gow,
+                                    int gosb,
+                                    int gosc,
+                                    int gosh,
+                                    int gosw,
+
+                                    at::Tensor& input1,
+                                    int ic,
+                                    int ih,
+                                    int iw,
+                                    int isb,
+                                    int isc,
+                                    int ish,
+                                    int isw,
+
+                                    at::Tensor& input2,
+                                    int gsb,
+                                    int gsc,
+                                    int gsh,
+                                    int gsw,
+
+                                    at::Tensor& gradInput1,
+                                    int gisb,
+                                    int gisc,
+                                    int gish,
+                                    int gisw,
+
+                                    at::Tensor& gradInput2,
+                                    int ggc,
+                                    int ggsb,
+                                    int ggsc,
+                                    int ggsh,
+                                    int ggsw,
+
+                                    at::Tensor& rInput1,
+                                    at::Tensor& rInput2,
+                                    int pad_size,
+                                    int kernel_size,
+                                    int max_displacement,
+                                    int stride1,
+                                    int stride2,
+                                    int corr_type_multiply,
+                                    cudaStream_t stream)
+{
+
+    int batchSize = gob;
+    int num = batchSize;
+
+    int nInputChannels = ic;
+    int inputWidth = iw;
+    int inputHeight = ih;
+
+    int nOutputChannels = goc;
+    int outputWidth = gow;
+    int outputHeight = goh;
+
+    dim3 blocks_grid(batchSize, inputHeight, inputWidth);
+    dim3 threads_block(THREADS_PER_BLOCK);
+
+
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input1.type(), "channels_first_bwd_1", ([&] {
+
+        channels_first<scalar_t><<<blocks_grid, threads_block, 0, stream>>>(
+            input1.data<scalar_t>(),
+            rInput1.data<scalar_t>(),
+            nInputChannels,
+            inputHeight,
+            inputWidth,
+            pad_size
+        );
+    }));
+
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input2.type(), "channels_first_bwd_2", ([&] {
+
+        channels_first<scalar_t><<<blocks_grid, threads_block, 0, stream>>>(
+            input2.data<scalar_t>(),
+            rInput2.data<scalar_t>(),
+            nInputChannels,
+            inputHeight,
+            inputWidth,
+            pad_size
+        );
+    }));
+
+    dim3 threadsPerBlock(THREADS_PER_BLOCK);
+    dim3 totalBlocksCorr(inputHeight, inputWidth, nInputChannels);
+
+    for (int n = 0; n < num; ++n) {
+
+      AT_DISPATCH_FLOATING_TYPES_AND_HALF(input2.type(), "correlation_backward_input1", ([&] {
+
+
+          correlation_backward_input1<scalar_t><<<totalBlocksCorr, threadsPerBlock, 0, stream>>> (
+              n, gradInput1.data<scalar_t>(), nInputChannels, inputHeight, inputWidth,
+              gradOutput.data<scalar_t>(), nOutputChannels, outputHeight, outputWidth,
+              rInput2.data<scalar_t>(),
+              pad_size,
+              kernel_size,
+              max_displacement,
+              stride1,
+              stride2);
+      }));
+    }
+
+    for(int n = 0; n < batchSize; n++) {
+
+      AT_DISPATCH_FLOATING_TYPES_AND_HALF(rInput1.type(), "correlation_backward_input2", ([&] {
+
+        correlation_backward_input2<scalar_t><<<totalBlocksCorr, threadsPerBlock, 0, stream>>>(
+            n, gradInput2.data<scalar_t>(), nInputChannels, inputHeight, inputWidth,
+            gradOutput.data<scalar_t>(), nOutputChannels, outputHeight, outputWidth,
+            rInput1.data<scalar_t>(),
+            pad_size,
+            kernel_size,
+            max_displacement,
+            stride1,
+            stride2);
+
+        }));
+    }
+
+  // check for errors
+  cudaError_t err = cudaGetLastError();
+  if (err != cudaSuccess) {
+    printf("error in correlation_backward_cuda_kernel: %s\n", cudaGetErrorString(err));
+    return 0;
+  }
+
+  return 1;
+}
diff --git a/windflow/external_packages/correlation_package/correlation_cuda_kernel.cuh b/windflow/external_packages/correlation_package/correlation_cuda_kernel.cuh
new file mode 100755
index 0000000000000000000000000000000000000000..1586d3af6bc184bfea8482a991a6625f865f02b3
--- /dev/null
+++ b/windflow/external_packages/correlation_package/correlation_cuda_kernel.cuh
@@ -0,0 +1,91 @@
+#pragma once
+
+#include <ATen/ATen.h>
+#include <ATen/Context.h>
+#include <cuda_runtime.h>
+
+int correlation_forward_cuda_kernel(at::Tensor& output,
+    int ob,
+    int oc,
+    int oh,
+    int ow,
+    int osb,
+    int osc,
+    int osh,
+    int osw,
+
+    at::Tensor& input1,
+    int ic,
+    int ih,
+    int iw,
+    int isb,
+    int isc,
+    int ish,
+    int isw,
+
+    at::Tensor& input2,
+    int gc,
+    int gsb,
+    int gsc,
+    int gsh,
+    int gsw,
+
+    at::Tensor& rInput1,
+    at::Tensor& rInput2,
+    int pad_size,
+    int kernel_size,
+    int max_displacement,
+    int stride1,
+    int stride2,
+    int corr_type_multiply,
+    cudaStream_t stream);
+
+
+int correlation_backward_cuda_kernel(   
+    at::Tensor& gradOutput,
+    int gob,
+    int goc,
+    int goh,
+    int gow,
+    int gosb,
+    int gosc,
+    int gosh,
+    int gosw,
+
+    at::Tensor& input1,
+    int ic,
+    int ih,
+    int iw,
+    int isb,
+    int isc,
+    int ish,
+    int isw,
+
+    at::Tensor& input2,
+    int gsb,
+    int gsc,
+    int gsh,
+    int gsw,
+
+    at::Tensor& gradInput1, 
+    int gisb,
+    int gisc,
+    int gish,
+    int gisw,
+
+    at::Tensor& gradInput2,
+    int ggc,
+    int ggsb,
+    int ggsc,
+    int ggsh,
+    int ggsw,
+
+    at::Tensor& rInput1,
+    at::Tensor& rInput2,
+    int pad_size,
+    int kernel_size,
+    int max_displacement,
+    int stride1,
+    int stride2,
+    int corr_type_multiply,
+    cudaStream_t stream);
diff --git a/windflow/external_packages/correlation_package/cuda/deform_conv_cuda.cu b/windflow/external_packages/correlation_package/cuda/deform_conv_cuda.cu
new file mode 100644
index 0000000000000000000000000000000000000000..087adfe9feb631c05d2872967394b218af946abf
--- /dev/null
+++ b/windflow/external_packages/correlation_package/cuda/deform_conv_cuda.cu
@@ -0,0 +1,691 @@
+// modify from
+// https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda.c
+
+#include <ATen/ATen.h>
+#include <ATen/cuda/CUDAContext.h>
+
+#include <THC/THC.h>
+#include <THC/THCDeviceUtils.cuh>
+
+#include <vector>
+#include <iostream>
+#include <cmath>
+
+
+void deformable_im2col(const at::Tensor data_im, const at::Tensor data_offset,
+                       const int channels, const int height, const int width,
+                       const int ksize_h, const int ksize_w, const int pad_h,
+                       const int pad_w, const int stride_h, const int stride_w,
+                       const int dilation_h, const int dilation_w,
+                       const int parallel_imgs, const int deformable_group,
+                       at::Tensor data_col);
+
+void deformable_col2im(const at::Tensor data_col, const at::Tensor data_offset,
+                       const int channels, const int height, const int width,
+                       const int ksize_h, const int ksize_w, const int pad_h,
+                       const int pad_w, const int stride_h, const int stride_w,
+                       const int dilation_h, const int dilation_w,
+                       const int parallel_imgs, const int deformable_group,
+                       at::Tensor grad_im);
+
+void deformable_col2im_coord(
+    const at::Tensor data_col, const at::Tensor data_im,
+    const at::Tensor data_offset, const int channels, const int height,
+    const int width, const int ksize_h, const int ksize_w, const int pad_h,
+    const int pad_w, const int stride_h, const int stride_w,
+    const int dilation_h, const int dilation_w, const int parallel_imgs,
+    const int deformable_group, at::Tensor grad_offset);
+
+void modulated_deformable_im2col_cuda(
+    const at::Tensor data_im, const at::Tensor data_offset,
+    const at::Tensor data_mask, const int batch_size, const int channels,
+    const int height_im, const int width_im, const int height_col,
+    const int width_col, const int kernel_h, const int kernel_w,
+    const int pad_h, const int pad_w, const int stride_h, const int stride_w,
+    const int dilation_h, const int dilation_w, const int deformable_group,
+    at::Tensor data_col);
+
+void modulated_deformable_col2im_cuda(
+    const at::Tensor data_col, const at::Tensor data_offset,
+    const at::Tensor data_mask, const int batch_size, const int channels,
+    const int height_im, const int width_im, const int height_col,
+    const int width_col, const int kernel_h, const int kernel_w,
+    const int pad_h, const int pad_w, const int stride_h, const int stride_w,
+    const int dilation_h, const int dilation_w, const int deformable_group,
+    at::Tensor grad_im);
+
+void modulated_deformable_col2im_coord_cuda(
+    const at::Tensor data_col, const at::Tensor data_im,
+    const at::Tensor data_offset, const at::Tensor data_mask,
+    const int batch_size, const int channels, const int height_im,
+    const int width_im, const int height_col, const int width_col,
+    const int kernel_h, const int kernel_w, const int pad_h, const int pad_w,
+    const int stride_h, const int stride_w, const int dilation_h,
+    const int dilation_w, const int deformable_group, at::Tensor grad_offset,
+    at::Tensor grad_mask);
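+// The im2col/col2im helpers above are declared here only; their definitions
+// presumably live in a companion kernel translation unit
+// (deform_conv_cuda_kernel.cu in the upstream repository linked above).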
+
+void shape_check(at::Tensor input, at::Tensor offset, at::Tensor *gradOutput,
+                 at::Tensor weight, int kH, int kW, int dH, int dW, int padH,
+                 int padW, int dilationH, int dilationW, int group,
+                 int deformable_group) 
+{
+  TORCH_CHECK(weight.ndimension() == 4,
+           "4D weight tensor (nOutputPlane,nInputPlane,kH,kW) expected, "
+           "but got: %s",
+           weight.ndimension());
+
+  TORCH_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous");
+
+  TORCH_CHECK(kW > 0 && kH > 0,
+           "kernel size should be greater than zero, but got kH: %d kW: %d", kH,
+           kW);
+
+  TORCH_CHECK((weight.size(2) == kH && weight.size(3) == kW),
+           "kernel size should be consistent with weight, ",
+           "but got kH: %d kW: %d weight.size(2): %d, weight.size(3): %d", kH,
+           kW, weight.size(2), weight.size(3));
+
+  TORCH_CHECK(dW > 0 && dH > 0,
+           "stride should be greater than zero, but got dH: %d dW: %d", dH, dW);
+
+  TORCH_CHECK(
+      dilationW > 0 && dilationH > 0,
+      "dilation should be greater than 0, but got dilationH: %d dilationW: %d",
+      dilationH, dilationW);
+
+  int ndim = input.ndimension();
+  int dimf = 0;
+  int dimh = 1;
+  int dimw = 2;
+
+  if (ndim == 4) {
+    dimf++;
+    dimh++;
+    dimw++;
+  }
+
+  TORCH_CHECK(ndim == 3 || ndim == 4, "3D or 4D input tensor expected but got: %s",
+           ndim);
+
+  long nInputPlane = weight.size(1) * group;
+  long inputHeight = input.size(dimh);
+  long inputWidth = input.size(dimw);
+  long nOutputPlane = weight.size(0);
+  long outputHeight =
+      (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
+  long outputWidth =
+      (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
+
+  TORCH_CHECK(nInputPlane % deformable_group == 0,
+           "input channels must divide deformable group size");
+
+  if (outputWidth < 1 || outputHeight < 1)
+    AT_ERROR(
+        "Given input size: (%ld x %ld x %ld). "
+        "Calculated output size: (%ld x %ld x %ld). Output size is too small",
+        nInputPlane, inputHeight, inputWidth, nOutputPlane, outputHeight,
+        outputWidth);
+
+  TORCH_CHECK(input.size(1) == nInputPlane,
+           "invalid number of input planes, expected: %d, but got: %d",
+           nInputPlane, input.size(1));
+
+  TORCH_CHECK((inputHeight >= kH && inputWidth >= kW),
+           "input image is smaller than kernel");
+
+  TORCH_CHECK((offset.size(2) == outputHeight && offset.size(3) == outputWidth),
+           "invalid spatial size of offset, expected height: %d width: %d, but "
+           "got height: %d width: %d",
+           outputHeight, outputWidth, offset.size(2), offset.size(3));
+
+  TORCH_CHECK((offset.size(1) == deformable_group * 2 * kH * kW),
+           "invalid number of channels of offset");
+
+  if (gradOutput != NULL) {
+    TORCH_CHECK(gradOutput->size(dimf) == nOutputPlane,
+             "invalid number of gradOutput planes, expected: %d, but got: %d",
+             nOutputPlane, gradOutput->size(dimf));
+
+    TORCH_CHECK((gradOutput->size(dimh) == outputHeight &&
+              gradOutput->size(dimw) == outputWidth),
+             "invalid size of gradOutput, expected height: %d width: %d , but "
+             "got height: %d width: %d",
+             outputHeight, outputWidth, gradOutput->size(dimh),
+             gradOutput->size(dimw));
+  }
+}
+
+int deform_conv_forward_cuda(at::Tensor input, at::Tensor weight,
+                             at::Tensor offset, at::Tensor output,
+                             at::Tensor columns, at::Tensor ones, int kW,
+                             int kH, int dW, int dH, int padW, int padH,
+                             int dilationW, int dilationH, int group,
+                             int deformable_group, int im2col_step) 
+{
+  // todo: resize columns to include im2col: done
+  // todo: add im2col_step as input
+  // todo: add new output buffer and transpose it to output (or directly
+  // transpose output) todo: possibly change data indexing because of
+  // parallel_imgs
+
+  shape_check(input, offset, NULL, weight, kH, kW, dH, dW, padH, padW,
+              dilationH, dilationW, group, deformable_group);
+
+  input = input.contiguous();
+  offset = offset.contiguous();
+  weight = weight.contiguous();
+
+  int batch = 1;
+  if (input.ndimension() == 3) {
+    // Force batch
+    batch = 0;
+    input.unsqueeze_(0);
+    offset.unsqueeze_(0);
+  }
+
+  // todo: assert batchsize dividable by im2col_step
+
+  long batchSize = input.size(0);
+  long nInputPlane = input.size(1);
+  long inputHeight = input.size(2);
+  long inputWidth = input.size(3);
+
+  long nOutputPlane = weight.size(0);
+
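+  // Standard convolution output size: floor((in + 2*pad - (dilation*(kernel-1) + 1)) / stride) + 1.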
+  long outputWidth =
+      (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
+  long outputHeight =
+      (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
+
+  TORCH_CHECK((offset.size(0) == batchSize), "invalid batch size of offset");
+
+  output = output.view({batchSize / im2col_step, im2col_step, nOutputPlane,
+                        outputHeight, outputWidth});
+  columns = at::zeros(
+      {nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth},
+      input.options());
+
+  if (ones.ndimension() != 2 ||
+      ones.size(0) * ones.size(1) < outputHeight * outputWidth) {
+    ones = at::ones({outputHeight, outputWidth}, input.options());
+  }
+
+  input = input.view({batchSize / im2col_step, im2col_step, nInputPlane,
+                      inputHeight, inputWidth});
+  offset =
+      offset.view({batchSize / im2col_step, im2col_step,
+                   deformable_group * 2 * kH * kW, outputHeight, outputWidth});
+
+  at::Tensor output_buffer =
+      at::zeros({batchSize / im2col_step, nOutputPlane,
+                 im2col_step * outputHeight, outputWidth},
+                output.options());
+
+  output_buffer = output_buffer.view(
+      {output_buffer.size(0), group, output_buffer.size(1) / group,
+       output_buffer.size(2), output_buffer.size(3)});
+
+  for (int elt = 0; elt < batchSize / im2col_step; elt++) {
+    deformable_im2col(input[elt], offset[elt], nInputPlane, inputHeight,
+                      inputWidth, kH, kW, padH, padW, dH, dW, dilationH,
+                      dilationW, im2col_step, deformable_group, columns);
+
+    columns = columns.view({group, columns.size(0) / group, columns.size(1)});
+    weight = weight.view({group, weight.size(0) / group, weight.size(1),
+                          weight.size(2), weight.size(3)});
+
+    for (int g = 0; g < group; g++) {
+      output_buffer[elt][g] = output_buffer[elt][g]
+                                  .flatten(1)
+                                  .addmm_(weight[g].flatten(1), columns[g])
+                                  .view_as(output_buffer[elt][g]);
+    }
+  }
+
+  output_buffer = output_buffer.view(
+      {output_buffer.size(0), output_buffer.size(1) * output_buffer.size(2),
+       output_buffer.size(3), output_buffer.size(4)});
+
+  output_buffer = output_buffer.view({batchSize / im2col_step, nOutputPlane,
+                                      im2col_step, outputHeight, outputWidth});
+  output_buffer.transpose_(1, 2);
+  output.copy_(output_buffer);
+  output = output.view({batchSize, nOutputPlane, outputHeight, outputWidth});
+
+  input = input.view({batchSize, nInputPlane, inputHeight, inputWidth});
+  offset = offset.view(
+      {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});
+
+  if (batch == 0) {
+    output = output.view({nOutputPlane, outputHeight, outputWidth});
+    input = input.view({nInputPlane, inputHeight, inputWidth});
+    offset = offset.view({offset.size(1), offset.size(2), offset.size(3)});
+  }
+
+  return 1;
+}
+
+int deform_conv_backward_input_cuda(at::Tensor input, at::Tensor offset,
+                                    at::Tensor gradOutput, at::Tensor gradInput,
+                                    at::Tensor gradOffset, at::Tensor weight,
+                                    at::Tensor columns, int kW, int kH, int dW,
+                                    int dH, int padW, int padH, int dilationW,
+                                    int dilationH, int group,
+                                    int deformable_group, int im2col_step) 
+{
+  shape_check(input, offset, &gradOutput, weight, kH, kW, dH, dW, padH, padW,
+              dilationH, dilationW, group, deformable_group);
+
+  input = input.contiguous();
+  offset = offset.contiguous();
+  gradOutput = gradOutput.contiguous();
+  weight = weight.contiguous();
+
+  int batch = 1;
+
+  if (input.ndimension() == 3) {
+    // Force batch
+    batch = 0;
+    input = input.view({1, input.size(0), input.size(1), input.size(2)});
+    offset = offset.view({1, offset.size(0), offset.size(1), offset.size(2)});
+    gradOutput = gradOutput.view(
+        {1, gradOutput.size(0), gradOutput.size(1), gradOutput.size(2)});
+  }
+
+  long batchSize = input.size(0);
+  long nInputPlane = input.size(1);
+  long inputHeight = input.size(2);
+  long inputWidth = input.size(3);
+
+  long nOutputPlane = weight.size(0);
+
+  long outputWidth =
+      (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
+  long outputHeight =
+      (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
+
+  TORCH_CHECK((offset.size(0) == batchSize), "invalid batch size of offset");
+  gradInput = gradInput.view({batchSize, nInputPlane, inputHeight, inputWidth});
+  columns = at::zeros(
+      {nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth},
+      input.options());
+
+  // change order of grad output
+  gradOutput = gradOutput.view({batchSize / im2col_step, im2col_step,
+                                nOutputPlane, outputHeight, outputWidth});
+  gradOutput.transpose_(1, 2);
+
+  gradInput = gradInput.view({batchSize / im2col_step, im2col_step, nInputPlane,
+                              inputHeight, inputWidth});
+  input = input.view({batchSize / im2col_step, im2col_step, nInputPlane,
+                      inputHeight, inputWidth});
+  gradOffset = gradOffset.view({batchSize / im2col_step, im2col_step,
+                                deformable_group * 2 * kH * kW, outputHeight,
+                                outputWidth});
+  offset =
+      offset.view({batchSize / im2col_step, im2col_step,
+                   deformable_group * 2 * kH * kW, outputHeight, outputWidth});
+
+  for (int elt = 0; elt < batchSize / im2col_step; elt++) {
+    // divide into groups
+    columns = columns.view({group, columns.size(0) / group, columns.size(1)});
+    weight = weight.view({group, weight.size(0) / group, weight.size(1),
+                          weight.size(2), weight.size(3)});
+    gradOutput = gradOutput.view(
+        {gradOutput.size(0), group, gradOutput.size(1) / group,
+         gradOutput.size(2), gradOutput.size(3), gradOutput.size(4)});
+
+    for (int g = 0; g < group; g++) {
+      columns[g] = columns[g].addmm_(weight[g].flatten(1).transpose(0, 1),
+                                     gradOutput[elt][g].flatten(1), 0.0f, 1.0f);
+    }
+
+    columns =
+        columns.view({columns.size(0) * columns.size(1), columns.size(2)});
+    gradOutput = gradOutput.view(
+        {gradOutput.size(0), gradOutput.size(1) * gradOutput.size(2),
+         gradOutput.size(3), gradOutput.size(4), gradOutput.size(5)});
+
+    deformable_col2im_coord(columns, input[elt], offset[elt], nInputPlane,
+                            inputHeight, inputWidth, kH, kW, padH, padW, dH, dW,
+                            dilationH, dilationW, im2col_step, deformable_group,
+                            gradOffset[elt]);
+
+    deformable_col2im(columns, offset[elt], nInputPlane, inputHeight,
+                      inputWidth, kH, kW, padH, padW, dH, dW, dilationH,
+                      dilationW, im2col_step, deformable_group, gradInput[elt]);
+  }
+
+  gradOutput.transpose_(1, 2);
+  gradOutput =
+      gradOutput.view({batchSize, nOutputPlane, outputHeight, outputWidth});
+
+  gradInput = gradInput.view({batchSize, nInputPlane, inputHeight, inputWidth});
+  input = input.view({batchSize, nInputPlane, inputHeight, inputWidth});
+  gradOffset = gradOffset.view(
+      {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});
+  offset = offset.view(
+      {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});
+
+  if (batch == 0) {
+    gradOutput = gradOutput.view({nOutputPlane, outputHeight, outputWidth});
+    input = input.view({nInputPlane, inputHeight, inputWidth});
+    gradInput = gradInput.view({nInputPlane, inputHeight, inputWidth});
+    // Reshape gradOffset first, while offset still has its 4-D sizes.
+    gradOffset =
+        gradOffset.view({offset.size(1), offset.size(2), offset.size(3)});
+    offset = offset.view({offset.size(1), offset.size(2), offset.size(3)});
+  }
+
+  return 1;
+}
+
+int deform_conv_backward_parameters_cuda(
+    at::Tensor input, at::Tensor offset, at::Tensor gradOutput,
+    at::Tensor gradWeight,  // at::Tensor gradBias,
+    at::Tensor columns, at::Tensor ones, int kW, int kH, int dW, int dH,
+    int padW, int padH, int dilationW, int dilationH, int group,
+    int deformable_group, float scale, int im2col_step) 
+{
+  // todo: transpose and reshape outGrad
+  // todo: reshape columns
+  // todo: add im2col_step as input
+
+  shape_check(input, offset, &gradOutput, gradWeight, kH, kW, dH, dW, padH,
+              padW, dilationH, dilationW, group, deformable_group);
+
+  input = input.contiguous();
+  offset = offset.contiguous();
+  gradOutput = gradOutput.contiguous();
+
+  int batch = 1;
+
+  if (input.ndimension() == 3) {
+    // Force batch
+    batch = 0;
+    input = input.view(
+        at::IntList({1, input.size(0), input.size(1), input.size(2)}));
+    gradOutput = gradOutput.view(
+        {1, gradOutput.size(0), gradOutput.size(1), gradOutput.size(2)});
+  }
+
+  long batchSize = input.size(0);
+  long nInputPlane = input.size(1);
+  long inputHeight = input.size(2);
+  long inputWidth = input.size(3);
+
+  long nOutputPlane = gradWeight.size(0);
+
+  long outputWidth =
+      (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
+  long outputHeight =
+      (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
+
+  TORCH_CHECK((offset.size(0) == batchSize), "invalid batch size of offset");
+
+  columns = at::zeros(
+      {nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth},
+      input.options());
+
+  gradOutput = gradOutput.view({batchSize / im2col_step, im2col_step,
+                                nOutputPlane, outputHeight, outputWidth});
+  gradOutput.transpose_(1, 2);
+
+  at::Tensor gradOutputBuffer = at::zeros_like(gradOutput);
+  gradOutputBuffer =
+      gradOutputBuffer.view({batchSize / im2col_step, nOutputPlane, im2col_step,
+                             outputHeight, outputWidth});
+  gradOutputBuffer.copy_(gradOutput);
+  gradOutputBuffer =
+      gradOutputBuffer.view({batchSize / im2col_step, nOutputPlane,
+                             im2col_step * outputHeight, outputWidth});
+
+  gradOutput.transpose_(1, 2);
+  gradOutput =
+      gradOutput.view({batchSize, nOutputPlane, outputHeight, outputWidth});
+
+  input = input.view({batchSize / im2col_step, im2col_step, nInputPlane,
+                      inputHeight, inputWidth});
+  offset =
+      offset.view({batchSize / im2col_step, im2col_step,
+                   deformable_group * 2 * kH * kW, outputHeight, outputWidth});
+
+  for (int elt = 0; elt < batchSize / im2col_step; elt++) {
+    deformable_im2col(input[elt], offset[elt], nInputPlane, inputHeight,
+                      inputWidth, kH, kW, padH, padW, dH, dW, dilationH,
+                      dilationW, im2col_step, deformable_group, columns);
+
+    // divide into group
+    gradOutputBuffer = gradOutputBuffer.view(
+        {gradOutputBuffer.size(0), group, gradOutputBuffer.size(1) / group,
+         gradOutputBuffer.size(2), gradOutputBuffer.size(3)});
+    columns = columns.view({group, columns.size(0) / group, columns.size(1)});
+    gradWeight =
+        gradWeight.view({group, gradWeight.size(0) / group, gradWeight.size(1),
+                         gradWeight.size(2), gradWeight.size(3)});
+
+    for (int g = 0; g < group; g++) {
+      gradWeight[g] = gradWeight[g]
+                          .flatten(1)
+                          .addmm_(gradOutputBuffer[elt][g].flatten(1),
+                                  columns[g].transpose(1, 0), 1.0, scale)
+                          .view_as(gradWeight[g]);
+    }
+    gradOutputBuffer = gradOutputBuffer.view(
+        {gradOutputBuffer.size(0),
+         gradOutputBuffer.size(1) * gradOutputBuffer.size(2),
+         gradOutputBuffer.size(3), gradOutputBuffer.size(4)});
+    columns =
+        columns.view({columns.size(0) * columns.size(1), columns.size(2)});
+    gradWeight = gradWeight.view({gradWeight.size(0) * gradWeight.size(1),
+                                  gradWeight.size(2), gradWeight.size(3),
+                                  gradWeight.size(4)});
+  }
+
+  input = input.view({batchSize, nInputPlane, inputHeight, inputWidth});
+  offset = offset.view(
+      {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});
+
+  if (batch == 0) {
+    gradOutput = gradOutput.view({nOutputPlane, outputHeight, outputWidth});
+    input = input.view({nInputPlane, inputHeight, inputWidth});
+  }
+
+  return 1;
+}
+
+void modulated_deform_conv_cuda_forward(
+    at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones,
+    at::Tensor offset, at::Tensor mask, at::Tensor output, at::Tensor columns,
+    int kernel_h, int kernel_w, const int stride_h, const int stride_w,
+    const int pad_h, const int pad_w, const int dilation_h,
+    const int dilation_w, const int group, const int deformable_group,
+    const bool with_bias) 
+{
+  TORCH_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
+  TORCH_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous");
+
+  const int batch = input.size(0);
+  const int channels = input.size(1);
+  const int height = input.size(2);
+  const int width = input.size(3);
+
+  const int channels_out = weight.size(0);
+  const int channels_kernel = weight.size(1);
+  const int kernel_h_ = weight.size(2);
+  const int kernel_w_ = weight.size(3);
+
+  if (kernel_h_ != kernel_h || kernel_w_ != kernel_w)
+    AT_ERROR("Input shape and kernel shape won't match: (%d x %d vs %d x %d).",
+             kernel_h, kernel_w, kernel_h_, kernel_w_);
+  if (channels != channels_kernel * group)
+    AT_ERROR("Input shape and kernel channels won't match: (%d vs %d).",
+             channels, channels_kernel * group);
+
+  const int height_out =
+      (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
+  const int width_out =
+      (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
+
+  if (ones.ndimension() != 2 ||
+      ones.size(0) * ones.size(1) < height_out * width_out) {
+    // Resize plane and fill with ones...
+    ones = at::ones({height_out, width_out}, input.options());
+  }
+
+  // resize output
+  output = output.view({batch, channels_out, height_out, width_out}).zero_();
+  // resize temporary columns
+  columns =
+      at::zeros({channels * kernel_h * kernel_w, 1 * height_out * width_out},
+                input.options());
+
+  output = output.view({output.size(0), group, output.size(1) / group,
+                        output.size(2), output.size(3)});
+
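+  // Per sample: gather mask-modulated deformable samples into `columns`
+  // (im2col), then run one GEMM per group to produce the output.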
+  for (int b = 0; b < batch; b++) {
+    modulated_deformable_im2col_cuda(
+        input[b], offset[b], mask[b], 1, channels, height, width, height_out,
+        width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
+        dilation_h, dilation_w, deformable_group, columns);
+
+    // divide into group
+    weight = weight.view({group, weight.size(0) / group, weight.size(1),
+                          weight.size(2), weight.size(3)});
+    columns = columns.view({group, columns.size(0) / group, columns.size(1)});
+
+    for (int g = 0; g < group; g++) {
+      output[b][g] = output[b][g]
+                         .flatten(1)
+                         .addmm_(weight[g].flatten(1), columns[g])
+                         .view_as(output[b][g]);
+    }
+
+    weight = weight.view({weight.size(0) * weight.size(1), weight.size(2),
+                          weight.size(3), weight.size(4)});
+    columns =
+        columns.view({columns.size(0) * columns.size(1), columns.size(2)});
+  }
+
+  output = output.view({output.size(0), output.size(1) * output.size(2),
+                        output.size(3), output.size(4)});
+
+  if (with_bias) {
+    output += bias.view({1, bias.size(0), 1, 1});
+  }
+}
+
+void modulated_deform_conv_cuda_backward(
+    at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones,
+    at::Tensor offset, at::Tensor mask, at::Tensor columns,
+    at::Tensor grad_input, at::Tensor grad_weight, at::Tensor grad_bias,
+    at::Tensor grad_offset, at::Tensor grad_mask, at::Tensor grad_output,
+    int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h,
+    int pad_w, int dilation_h, int dilation_w, int group, int deformable_group,
+    const bool with_bias) 
+{
+  TORCH_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
+  TORCH_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous");
+
+  const int batch = input.size(0);
+  const int channels = input.size(1);
+  const int height = input.size(2);
+  const int width = input.size(3);
+
+  const int channels_kernel = weight.size(1);
+  const int kernel_h_ = weight.size(2);
+  const int kernel_w_ = weight.size(3);
+  if (kernel_h_ != kernel_h || kernel_w_ != kernel_w)
+    AT_ERROR("Input shape and kernel shape won't match: (%d x %d vs %d x %d).",
+             kernel_h, kernel_w, kernel_h_, kernel_w_);
+  if (channels != channels_kernel * group)
+    AT_ERROR("Input shape and kernel channels won't match: (%d vs %d).",
+             channels, channels_kernel * group);
+
+  const int height_out =
+      (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
+  const int width_out =
+      (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
+
+  if (ones.ndimension() != 2 ||
+      ones.size(0) * ones.size(1) < height_out * width_out) {
+    // Resize plane and fill with ones...
+    ones = at::ones({height_out, width_out}, input.options());
+  }
+
+  grad_input = grad_input.view({batch, channels, height, width});
+  columns = at::zeros({channels * kernel_h * kernel_w, height_out * width_out},
+                      input.options());
+
+  grad_output =
+      grad_output.view({grad_output.size(0), group, grad_output.size(1) / group,
+                        grad_output.size(2), grad_output.size(3)});
+
+  for (int b = 0; b < batch; b++) {
+    // divide into groups
+    columns = columns.view({group, columns.size(0) / group, columns.size(1)});
+    weight = weight.view({group, weight.size(0) / group, weight.size(1),
+                          weight.size(2), weight.size(3)});
+
+    for (int g = 0; g < group; g++) {
+      columns[g].addmm_(weight[g].flatten(1).transpose(0, 1),
+                        grad_output[b][g].flatten(1), 0.0f, 1.0f);
+    }
+
+    columns =
+        columns.view({columns.size(0) * columns.size(1), columns.size(2)});
+    weight = weight.view({weight.size(0) * weight.size(1), weight.size(2),
+                          weight.size(3), weight.size(4)});
+
+    // gradient w.r.t. input coordinate data
+    modulated_deformable_col2im_coord_cuda(
+        columns, input[b], offset[b], mask[b], 1, channels, height, width,
+        height_out, width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h,
+        stride_w, dilation_h, dilation_w, deformable_group, grad_offset[b],
+        grad_mask[b]);
+    // gradient w.r.t. input data
+    modulated_deformable_col2im_cuda(
+        columns, offset[b], mask[b], 1, channels, height, width, height_out,
+        width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
+        dilation_h, dilation_w, deformable_group, grad_input[b]);
+
+    // gradient w.r.t. weight, dWeight should accumulate across the batch and
+    // group
+    modulated_deformable_im2col_cuda(
+        input[b], offset[b], mask[b], 1, channels, height, width, height_out,
+        width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
+        dilation_h, dilation_w, deformable_group, columns);
+
+    columns = columns.view({group, columns.size(0) / group, columns.size(1)});
+    grad_weight = grad_weight.view({group, grad_weight.size(0) / group,
+                                    grad_weight.size(1), grad_weight.size(2),
+                                    grad_weight.size(3)});
+    if (with_bias)
+      grad_bias = grad_bias.view({group, grad_bias.size(0) / group});
+
+    for (int g = 0; g < group; g++) {
+      grad_weight[g] =
+          grad_weight[g]
+              .flatten(1)
+              .addmm_(grad_output[b][g].flatten(1), columns[g].transpose(0, 1))
+              .view_as(grad_weight[g]);
+      if (with_bias) {
+        grad_bias[g] =
+            grad_bias[g]
+                .view({-1, 1})
+                .addmm_(grad_output[b][g].flatten(1), ones.view({-1, 1}))
+                .view(-1);
+      }
+    }
+
+    columns =
+        columns.view({columns.size(0) * columns.size(1), columns.size(2)});
+    grad_weight = grad_weight.view({grad_weight.size(0) * grad_weight.size(1),
+                                    grad_weight.size(2), grad_weight.size(3),
+                                    grad_weight.size(4)});
+    if (with_bias)
+      grad_bias = grad_bias.view({grad_bias.size(0) * grad_bias.size(1)});
+  }
+  grad_output = grad_output.view({grad_output.size(0) * grad_output.size(1),
+                                  grad_output.size(2), grad_output.size(3),
+                                  grad_output.size(4)});
+}
diff --git a/windflow/external_packages/correlation_package/cuda/deform_pool_cuda.cu b/windflow/external_packages/correlation_package/cuda/deform_pool_cuda.cu
new file mode 100644
index 0000000000000000000000000000000000000000..0be4cb866a95f338e486cdf08e80d901cdee7e88
--- /dev/null
+++ b/windflow/external_packages/correlation_package/cuda/deform_pool_cuda.cu
@@ -0,0 +1,87 @@
+// modify from
+// https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/modulated_dcn_cuda.c
+
+// based on
+// author: Charles Shang
+// https://github.com/torch/cunn/blob/master/lib/THCUNN/generic/SpatialConvolutionMM.cu
+
+#include <ATen/ATen.h>
+#include <ATen/cuda/CUDAContext.h>
+
+#include <THC/THC.h>
+#include <THC/THCDeviceUtils.cuh>
+
+#include <vector>
+#include <iostream>
+#include <cmath>
+
+
+void DeformablePSROIPoolForward(
+    const at::Tensor data, const at::Tensor bbox, const at::Tensor trans,
+    at::Tensor out, at::Tensor top_count, const int batch, const int channels,
+    const int height, const int width, const int num_bbox,
+    const int channels_trans, const int no_trans, const float spatial_scale,
+    const int output_dim, const int group_size, const int pooled_size,
+    const int part_size, const int sample_per_part, const float trans_std);
+
+void DeformablePSROIPoolBackwardAcc(
+    const at::Tensor out_grad, const at::Tensor data, const at::Tensor bbox,
+    const at::Tensor trans, const at::Tensor top_count, at::Tensor in_grad,
+    at::Tensor trans_grad, const int batch, const int channels,
+    const int height, const int width, const int num_bbox,
+    const int channels_trans, const int no_trans, const float spatial_scale,
+    const int output_dim, const int group_size, const int pooled_size,
+    const int part_size, const int sample_per_part, const float trans_std);
+
+void deform_psroi_pooling_cuda_forward(
+    at::Tensor input, at::Tensor bbox, at::Tensor trans, at::Tensor out,
+    at::Tensor top_count, const int no_trans, const float spatial_scale,
+    const int output_dim, const int group_size, const int pooled_size,
+    const int part_size, const int sample_per_part, const float trans_std) 
+{
+  TORCH_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
+
+  const int batch = input.size(0);
+  const int channels = input.size(1);
+  const int height = input.size(2);
+  const int width = input.size(3);
+  const int channels_trans = no_trans ? 2 : trans.size(1);
+
+  const int num_bbox = bbox.size(0);
+  if (num_bbox != out.size(0))
+    AT_ERROR("Output shape and bbox number wont match: (%d vs %d).",
+             out.size(0), num_bbox);
+
+  DeformablePSROIPoolForward(
+      input, bbox, trans, out, top_count, batch, channels, height, width,
+      num_bbox, channels_trans, no_trans, spatial_scale, output_dim, group_size,
+      pooled_size, part_size, sample_per_part, trans_std);
+}
+
+void deform_psroi_pooling_cuda_backward(
+    at::Tensor out_grad, at::Tensor input, at::Tensor bbox, at::Tensor trans,
+    at::Tensor top_count, at::Tensor input_grad, at::Tensor trans_grad,
+    const int no_trans, const float spatial_scale, const int output_dim,
+    const int group_size, const int pooled_size, const int part_size,
+    const int sample_per_part, const float trans_std) 
+{
+  TORCH_CHECK(out_grad.is_contiguous(), "out_grad tensor has to be contiguous");
+  TORCH_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
+
+  const int batch = input.size(0);
+  const int channels = input.size(1);
+  const int height = input.size(2);
+  const int width = input.size(3);
+  const int channels_trans = no_trans ? 2 : trans.size(1);
+
+  const int num_bbox = bbox.size(0);
+  if (num_bbox != out_grad.size(0))
+    AT_ERROR("Output shape and bbox number wont match: (%d vs %d).",
+             out_grad.size(0), num_bbox);
+
+  DeformablePSROIPoolBackwardAcc(
+      out_grad, input, bbox, trans, top_count, input_grad, trans_grad, batch,
+      channels, height, width, num_bbox, channels_trans, no_trans,
+      spatial_scale, output_dim, group_size, pooled_size, part_size,
+      sample_per_part, trans_std);
+}
diff --git a/windflow/external_packages/correlation_package/setup.py b/windflow/external_packages/correlation_package/setup.py
new file mode 100755
index 0000000000000000000000000000000000000000..b852d94cdbacd3e6bdae451eb1be903523dee376
--- /dev/null
+++ b/windflow/external_packages/correlation_package/setup.py
@@ -0,0 +1,29 @@
+#!/usr/bin/env python3
+import os
+import torch
+
+from setuptools import setup
+from torch.utils.cpp_extension import BuildExtension, CUDAExtension
+
+cxx_args = ['-std=c++14']
+
+nvcc_args = [
+    '-gencode', 'arch=compute_50,code=sm_50',
+    '-gencode', 'arch=compute_52,code=sm_52',
+    '-gencode', 'arch=compute_60,code=sm_60',
+    '-gencode', 'arch=compute_61,code=sm_61',
+    '-gencode', 'arch=compute_70,code=sm_70',
+    '-gencode', 'arch=compute_70,code=compute_70'
+]
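+# Each -gencode pair emits device code for one architecture (sm_50..sm_70);
+# the final compute_70 entry also embeds PTX so newer GPUs can JIT-compile.
+# Typical build, assuming a CUDA toolkit that matches the installed PyTorch:
+#   python setup.py install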
+
+setup(
+    name='correlation_cuda',
+    ext_modules=[
+        CUDAExtension('correlation_cuda', [
+            'correlation_cuda.cc',
+            'correlation_cuda_kernel.cu'
+        ], extra_compile_args={'cxx': cxx_args, 'nvcc': nvcc_args})
+    ],
+    cmdclass={
+        'build_ext': BuildExtension
+    })
diff --git a/windflow/external_packages/resample2d_package/__init__.py b/windflow/external_packages/resample2d_package/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/windflow/external_packages/resample2d_package/resample2d.py b/windflow/external_packages/resample2d_package/resample2d.py
new file mode 100755
index 0000000000000000000000000000000000000000..92ea0d01f3cd41c778d60a8b9073d5884a74bce0
--- /dev/null
+++ b/windflow/external_packages/resample2d_package/resample2d.py
@@ -0,0 +1,49 @@
+from torch.nn.modules.module import Module
+from torch.autograd import Function, Variable
+import resample2d_cuda
+
+class Resample2dFunction(Function):
+
+    @staticmethod
+    def forward(ctx, input1, input2, kernel_size=1, bilinear=True):
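+        # input1: source tensor (B, C, H, W); input2: flow field (B, 2, h, w)
+        # of per-pixel (dx, dy) displacements used to sample input1.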
+        assert input1.is_contiguous()
+        assert input2.is_contiguous()
+
+        ctx.save_for_backward(input1, input2)
+        ctx.kernel_size = kernel_size
+        ctx.bilinear = bilinear
+
+        _, d, _, _ = input1.size()
+        b, _, h, w = input2.size()
+        output = input1.new(b, d, h, w).zero_()
+
+        resample2d_cuda.forward(input1, input2, output, kernel_size, bilinear)
+
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        grad_output = grad_output.contiguous()
+        assert grad_output.is_contiguous()
+
+        input1, input2 = ctx.saved_tensors
+
+        grad_input1 = Variable(input1.new(input1.size()).zero_())
+        grad_input2 = Variable(input1.new(input2.size()).zero_())
+
+        resample2d_cuda.backward(input1, input2, grad_output.data,
+                                 grad_input1.data, grad_input2.data,
+                                 ctx.kernel_size, ctx.bilinear)
+
+        return grad_input1, grad_input2, None, None
+
+class Resample2d(Module):
+
+    def __init__(self, kernel_size=1, bilinear=True):
+        super(Resample2d, self).__init__()
+        self.kernel_size = kernel_size
+        self.bilinear = bilinear
+
+    def forward(self, input1, input2):
+        input1_c = input1.contiguous()
+        return Resample2dFunction.apply(input1_c, input2, self.kernel_size, self.bilinear)
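+
+# Minimal usage sketch (assumes the resample2d_cuda extension from this
+# package's setup.py is built and that inputs are contiguous CUDA tensors):
+#
+#   warp = Resample2d()
+#   # img: (B, C, H, W) source, flow: (B, 2, H, W) pixel displacements
+#   warped = warp(img, flow)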
diff --git a/windflow/external_packages/resample2d_package/resample2d_cuda.cc b/windflow/external_packages/resample2d_package/resample2d_cuda.cc
new file mode 100755
index 0000000000000000000000000000000000000000..75cc6260dc6177f4b597a6d821d3a1deae0c21eb
--- /dev/null
+++ b/windflow/external_packages/resample2d_package/resample2d_cuda.cc
@@ -0,0 +1,32 @@
+#include <ATen/ATen.h>
+#include <torch/torch.h>
+
+#include "resample2d_kernel.cuh"
+
+int resample2d_cuda_forward(
+    at::Tensor& input1,
+    at::Tensor& input2, 
+    at::Tensor& output,
+    int kernel_size, bool bilinear) {
+    resample2d_kernel_forward(input1, input2, output, kernel_size, bilinear);
+    return 1;
+}
+
+int resample2d_cuda_backward(
+    at::Tensor& input1, 
+    at::Tensor& input2,
+    at::Tensor& gradOutput,
+    at::Tensor& gradInput1, 
+    at::Tensor& gradInput2, 
+    int kernel_size, bool bilinear) {
+    resample2d_kernel_backward(input1, input2, gradOutput, gradInput1, gradInput2, kernel_size, bilinear);
+    return 1;
+}
+
+
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+  m.def("forward", &resample2d_cuda_forward, "Resample2D forward (CUDA)");
+  m.def("backward", &resample2d_cuda_backward, "Resample2D backward (CUDA)");
+}
+
diff --git a/windflow/external_packages/resample2d_package/resample2d_kernel.cu b/windflow/external_packages/resample2d_package/resample2d_kernel.cu
new file mode 100755
index 0000000000000000000000000000000000000000..a67430b0d78731a1abf18417726fd3da830bbb3e
--- /dev/null
+++ b/windflow/external_packages/resample2d_package/resample2d_kernel.cu
@@ -0,0 +1,323 @@
+#include <ATen/ATen.h>
+#include <ATen/Context.h>
+#include <ATen/cuda/CUDAContext.h>
+
+#define CUDA_NUM_THREADS 512 
+#define THREADS_PER_BLOCK 64 
+
+#define DIM0(TENSOR) ((TENSOR).x)
+#define DIM1(TENSOR) ((TENSOR).y)
+#define DIM2(TENSOR) ((TENSOR).z)
+#define DIM3(TENSOR) ((TENSOR).w)
+
+#define DIM3_INDEX(TENSOR, xx, yy, zz, ww) ((TENSOR)[((xx) * (TENSOR##_stride.x)) + ((yy) * (TENSOR##_stride.y)) + ((zz) * (TENSOR##_stride.z)) + ((ww) * (TENSOR##_stride.w))])
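+// DIM3_INDEX resolves a 4-D (b, c, y, x) coordinate through the tensor's
+// strides, so the kernels below also work on non-contiguous layouts.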
+
+template <typename scalar_t>
+__global__ void kernel_resample2d_update_output(const int n, 
+                                               const scalar_t* __restrict__ input1, const long4 input1_size, const long4 input1_stride,
+                                               const scalar_t* __restrict__ input2, const long4 input2_size, const long4 input2_stride, 
+                                               scalar_t* __restrict__ output, const long4 output_size, const long4 output_stride, int kernel_size, bool bilinear) {
+    int index = blockIdx.x * blockDim.x + threadIdx.x;
+
+    if (index >= n) {
+        return;
+    }
+
+    scalar_t val = 0.0f;
+
+    int dim_b = DIM0(output_size);
+    int dim_c = DIM1(output_size);
+    int dim_h = DIM2(output_size);
+    int dim_w = DIM3(output_size);
+    int dim_chw = dim_c * dim_h * dim_w;
+    int dim_hw  = dim_h * dim_w;
+
+    int b = ( index / dim_chw ) % dim_b;
+    int c = ( index / dim_hw )  % dim_c;
+    int y = ( index / dim_w )   % dim_h;
+    int x = ( index          )  % dim_w;
+
+    scalar_t dx = DIM3_INDEX(input2, b, 0, y, x);
+    scalar_t dy = DIM3_INDEX(input2, b, 1, y, x);
+
+    scalar_t xf = static_cast<scalar_t>(x) + dx;
+    scalar_t yf = static_cast<scalar_t>(y) + dy;
+    scalar_t alpha = xf - floor(xf); // alpha
+    scalar_t beta = yf - floor(yf); // beta
+
+    if (bilinear) {
+        int xL = max(min( int (floor(xf)),   dim_w-1), 0);
+        int xR = max(min( int (floor(xf)+1), dim_w-1), 0);
+        int yT = max(min( int (floor(yf)),   dim_h-1), 0);
+        int yB = max(min( int (floor(yf)+1), dim_h-1), 0);
+
+        for (int fy = 0; fy < kernel_size; fy += 1) {
+            for (int fx = 0; fx < kernel_size; fx += 1) {
+                val += static_cast<float>((1. - alpha)*(1. - beta) * DIM3_INDEX(input1, b, c, yT + fy, xL + fx));
+                val += static_cast<float>((alpha)*(1. - beta) * DIM3_INDEX(input1, b, c, yT + fy, xR + fx));
+                val += static_cast<float>((1. - alpha)*(beta) * DIM3_INDEX(input1, b, c, yB + fy, xL + fx));
+                val += static_cast<float>((alpha)*(beta) * DIM3_INDEX(input1, b, c, yB + fy, xR + fx));
+            }
+        }
+
+        output[index] = val;
+    }
+    else {
+        int xN = max(min( int (floor(xf + 0.5)), dim_w - 1), 0);
+        int yN = max(min( int (floor(yf + 0.5)), dim_h - 1), 0);
+
+        output[index] = static_cast<float> ( DIM3_INDEX(input1, b, c, yN, xN) );
+    }
+
+}
+
+
+template <typename scalar_t>
+__global__ void kernel_resample2d_backward_input1(
+    const int n, const scalar_t* __restrict__ input1, const long4 input1_size, const long4 input1_stride,
+    const scalar_t* __restrict__ input2, const long4 input2_size, const long4 input2_stride,
+    const scalar_t* __restrict__ gradOutput, const long4 gradOutput_size, const long4 gradOutput_stride,
+    scalar_t* __restrict__ gradInput, const long4 gradInput_size, const long4 gradInput_stride, int kernel_size, bool bilinear) {
+
+    int index = blockIdx.x * blockDim.x + threadIdx.x;
+
+    if (index >= n) {
+        return;
+    }
+
+    int dim_b = DIM0(gradOutput_size);
+    int dim_c = DIM1(gradOutput_size);
+    int dim_h = DIM2(gradOutput_size);
+    int dim_w = DIM3(gradOutput_size);
+    int dim_chw = dim_c * dim_h * dim_w;
+    int dim_hw  = dim_h * dim_w;
+
+    int b = ( index / dim_chw ) % dim_b;
+    int c = ( index / dim_hw )  % dim_c;
+    int y = ( index / dim_w )   % dim_h;
+    int x = ( index          )  % dim_w;
+
+    scalar_t dx = DIM3_INDEX(input2, b, 0, y, x);
+    scalar_t dy = DIM3_INDEX(input2, b, 1, y, x);
+
+    scalar_t xf = static_cast<scalar_t>(x) + dx;
+    scalar_t yf = static_cast<scalar_t>(y) + dy;
+    scalar_t alpha = xf - floor(xf); // fractional x offset (matches the forward kernel)
+    scalar_t beta = yf - floor(yf);  // fractional y offset
+
+    int idim_h = DIM2(input1_size);
+    int idim_w = DIM3(input1_size);
+
+    int xL = max(min( int (floor(xf)),   idim_w-1), 0);
+    int xR = max(min( int (floor(xf)+1), idim_w-1), 0);
+    int yT = max(min( int (floor(yf)),   idim_h-1), 0);
+    int yB = max(min( int (floor(yf)+1), idim_h-1), 0);
+
+    for (int fy = 0; fy < kernel_size; fy += 1) {
+        for (int fx = 0; fx < kernel_size; fx += 1) {
+            atomicAdd(&DIM3_INDEX(gradInput, b, c, (yT + fy), (xL + fx)), (1-alpha)*(1-beta) * DIM3_INDEX(gradOutput, b, c, y, x));
+            atomicAdd(&DIM3_INDEX(gradInput, b, c, (yT + fy), (xR + fx)),   (alpha)*(1-beta) * DIM3_INDEX(gradOutput, b, c, y, x));
+            atomicAdd(&DIM3_INDEX(gradInput, b, c, (yB + fy), (xL + fx)),   (1-alpha)*(beta) * DIM3_INDEX(gradOutput, b, c, y, x));
+            atomicAdd(&DIM3_INDEX(gradInput, b, c, (yB + fy), (xR + fx)),     (alpha)*(beta) * DIM3_INDEX(gradOutput, b, c, y, x));
+        }
+    }
+
+}
+
+template <typename scalar_t>
+__global__ void kernel_resample2d_backward_input2(
+    const int n, const scalar_t* __restrict__ input1, const long4 input1_size, const long4 input1_stride,
+    const scalar_t* __restrict__ input2, const long4 input2_size, const long4 input2_stride,
+    const scalar_t* __restrict__ gradOutput, const long4 gradOutput_size, const long4 gradOutput_stride,
+    scalar_t* __restrict__ gradInput, const long4 gradInput_size, const long4 gradInput_stride, int kernel_size, bool bilinear) {
+
+    int index = blockIdx.x * blockDim.x + threadIdx.x;
+
+    if (index >= n) {
+        return;
+    }
+
+    scalar_t output = 0.0;
+    int kernel_rad = (kernel_size - 1)/2;
+
+    int dim_b = DIM0(gradInput_size);
+    int dim_c = DIM1(gradInput_size);
+    int dim_h = DIM2(gradInput_size);
+    int dim_w = DIM3(gradInput_size);
+    int dim_chw = dim_c * dim_h * dim_w;
+    int dim_hw  = dim_h * dim_w;
+
+    int b = ( index / dim_chw ) % dim_b;
+    int c = ( index / dim_hw )  % dim_c;
+    int y = ( index / dim_w )   % dim_h;
+    int x = ( index          )  % dim_w;
+
+    int odim_c = DIM1(gradOutput_size);
+
+    scalar_t dx = DIM3_INDEX(input2, b, 0, y, x);
+    scalar_t dy = DIM3_INDEX(input2, b, 1, y, x);
+
+    scalar_t xf = static_cast<scalar_t>(x) + dx;
+    scalar_t yf = static_cast<scalar_t>(y) + dy;
+
+    int xL = max(min( int (floor(xf)),   dim_w-1), 0);
+    int xR = max(min( int (floor(xf)+1), dim_w-1), 0);
+    int yT = max(min( int (floor(yf)),   dim_h-1), 0);
+    int yB = max(min( int (floor(yf)+1), dim_h-1), 0);
+    
+    if (c % 2) {
+        float gamma = 1 - (xf - floor(xf)); // weight toward the left (x) neighbour
+        for (int i = 0; i <= 2*kernel_rad; ++i) {
+            for (int j = 0; j <= 2*kernel_rad; ++j) {
+                for (int ch = 0; ch < odim_c; ++ch) {
+                    output += (gamma) * DIM3_INDEX(gradOutput, b, ch, y, x) * DIM3_INDEX(input1, b, ch, (yB + j), (xL + i));
+                    output -= (gamma) * DIM3_INDEX(gradOutput, b, ch, y, x) * DIM3_INDEX(input1, b, ch, (yT + j), (xL + i));
+                    output += (1-gamma) * DIM3_INDEX(gradOutput, b, ch, y, x) * DIM3_INDEX(input1, b, ch, (yB + j), (xR + i));
+                    output -= (1-gamma) * DIM3_INDEX(gradOutput, b, ch, y, x) * DIM3_INDEX(input1, b, ch, (yT + j), (xR + i));
+                }
+            }
+        }
+    }
+    else {
+        float gamma = 1 - (yf - floor(yf)); // weight toward the top (y) neighbour
+        for (int i = 0; i <= 2*kernel_rad; ++i) {
+            for (int j = 0; j <= 2*kernel_rad; ++j) {
+                for (int ch = 0; ch < odim_c; ++ch) {
+                    output += (gamma) * DIM3_INDEX(gradOutput, b, ch, y, x) * DIM3_INDEX(input1, b, ch, (yT + j), (xR + i));
+                    output -= (gamma) * DIM3_INDEX(gradOutput, b, ch, y, x) * DIM3_INDEX(input1, b, ch, (yT + j), (xL + i));
+                    output += (1-gamma) * DIM3_INDEX(gradOutput, b, ch, y, x) * DIM3_INDEX(input1, b, ch, (yB + j), (xR + i));
+                    output -= (1-gamma) * DIM3_INDEX(gradOutput, b, ch, y, x) * DIM3_INDEX(input1, b, ch, (yB + j), (xL + i));
+                }
+            }
+        }
+
+    }
+
+    gradInput[index] = output;
+
+}
+
+void resample2d_kernel_forward(
+    at::Tensor& input1, 
+    at::Tensor& input2,
+    at::Tensor& output, 
+    int kernel_size,
+    bool bilinear) {
+
+    int n = output.numel();
+
+    const long4 input1_size = make_long4(input1.size(0), input1.size(1), input1.size(2), input1.size(3));
+    const long4 input1_stride = make_long4(input1.stride(0), input1.stride(1), input1.stride(2), input1.stride(3));
+
+    const long4 input2_size = make_long4(input2.size(0), input2.size(1), input2.size(2), input2.size(3));
+    const long4 input2_stride = make_long4(input2.stride(0), input2.stride(1), input2.stride(2), input2.stride(3));
+
+    const long4 output_size = make_long4(output.size(0), output.size(1), output.size(2), output.size(3));
+    const long4 output_stride = make_long4(output.stride(0), output.stride(1), output.stride(2), output.stride(3));
+
+    // TODO: when atomicAdd gets resolved, change to AT_DISPATCH_FLOATING_TYPES_AND_HALF
+//    AT_DISPATCH_FLOATING_TYPES(input1.type(), "resample_forward_kernel", ([&] {
+
+        kernel_resample2d_update_output<float><<< (n + CUDA_NUM_THREADS - 1)/CUDA_NUM_THREADS, CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream() >>>(
+//at::globalContext().getCurrentCUDAStream() >>>(
+            n,
+            input1.data<float>(),
+            input1_size,
+            input1_stride, 
+            input2.data<float>(),
+            input2_size,
+            input2_stride,
+            output.data<float>(),
+            output_size,
+            output_stride,
+            kernel_size,
+            bilinear);
+
+//    }));
+
+        // TODO: ATen-equivalent check
+
+       //    THCudaCheck(cudaGetLastError());
+
+}
+
+void resample2d_kernel_backward(
+    at::Tensor& input1,
+    at::Tensor& input2,
+    at::Tensor& gradOutput,
+    at::Tensor& gradInput1,
+    at::Tensor& gradInput2,
+    int kernel_size,
+    bool bilinear) {
+
+    int n = gradOutput.numel();
+
+    const long4 input1_size = make_long4(input1.size(0), input1.size(1), input1.size(2), input1.size(3));
+    const long4 input1_stride = make_long4(input1.stride(0), input1.stride(1), input1.stride(2), input1.stride(3));
+
+    const long4 input2_size = make_long4(input2.size(0), input2.size(1), input2.size(2), input2.size(3));
+    const long4 input2_stride = make_long4(input2.stride(0), input2.stride(1), input2.stride(2), input2.stride(3));
+
+    const long4 gradOutput_size = make_long4(gradOutput.size(0), gradOutput.size(1), gradOutput.size(2), gradOutput.size(3));
+    const long4 gradOutput_stride = make_long4(gradOutput.stride(0), gradOutput.stride(1), gradOutput.stride(2), gradOutput.stride(3));
+
+    const long4 gradInput1_size = make_long4(gradInput1.size(0), gradInput1.size(1), gradInput1.size(2), gradInput1.size(3));
+    const long4 gradInput1_stride = make_long4(gradInput1.stride(0), gradInput1.stride(1), gradInput1.stride(2), gradInput1.stride(3));
+
+//    AT_DISPATCH_FLOATING_TYPES(input1.type(), "resample_backward_input1", ([&] {
+
+        kernel_resample2d_backward_input1<float><<< (n + CUDA_NUM_THREADS - 1)/CUDA_NUM_THREADS, CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream() >>>(
+//at::globalContext().getCurrentCUDAStream() >>>(
+            n, 
+            input1.data<float>(), 
+            input1_size,
+            input1_stride,
+            input2.data<float>(),
+            input2_size, 
+            input2_stride,
+            gradOutput.data<float>(),
+            gradOutput_size,
+            gradOutput_stride,
+            gradInput1.data<float>(),
+            gradInput1_size,
+            gradInput1_stride, 
+            kernel_size,
+            bilinear
+        );
+
+//    }));
+
+    const long4 gradInput2_size = make_long4(gradInput2.size(0), gradInput2.size(1), gradInput2.size(2), gradInput2.size(3));
+    const long4 gradInput2_stride = make_long4(gradInput2.stride(0), gradInput2.stride(1), gradInput2.stride(2), gradInput2.stride(3));
+
+    n = gradInput2.numel();
+
+//    AT_DISPATCH_FLOATING_TYPES(gradInput2.type(), "resample_backward_input2", ([&] {
+
+
+        kernel_resample2d_backward_input2<float><<< (n + CUDA_NUM_THREADS - 1)/CUDA_NUM_THREADS, CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream() >>>(
+//at::globalContext().getCurrentCUDAStream() >>>(
+            n, 
+            input1.data<float>(), 
+            input1_size, 
+            input1_stride,
+            input2.data<float>(), 
+            input2_size,
+            input2_stride,
+            gradOutput.data<float>(),
+            gradOutput_size,
+            gradOutput_stride,
+            gradInput2.data<float>(),
+            gradInput2_size,
+            gradInput2_stride,
+            kernel_size,
+            bilinear
+       );
+
+//    }));
+
+    // TODO: Use the ATen equivalent to get last error
+
+    //    THCudaCheck(cudaGetLastError());
+
+}
diff --git a/windflow/external_packages/resample2d_package/resample2d_kernel.cuh b/windflow/external_packages/resample2d_package/resample2d_kernel.cuh
new file mode 100755
index 0000000000000000000000000000000000000000..a25951599231fd3dcfa409228c34738e9f039cdb
--- /dev/null
+++ b/windflow/external_packages/resample2d_package/resample2d_kernel.cuh
@@ -0,0 +1,19 @@
+#pragma once
+
+#include <ATen/ATen.h>
+
+void resample2d_kernel_forward(
+    at::Tensor& input1,
+    at::Tensor& input2,
+    at::Tensor& output,
+    int kernel_size,
+    bool bilinear);
+
+void resample2d_kernel_backward(
+    at::Tensor& input1,
+    at::Tensor& input2,
+    at::Tensor& gradOutput,
+    at::Tensor& gradInput1, 
+    at::Tensor& gradInput2, 
+    int kernel_size,
+    bool bilinear);
\ No newline at end of file
diff --git a/windflow/external_packages/resample2d_package/setup.py b/windflow/external_packages/resample2d_package/setup.py
new file mode 100755
index 0000000000000000000000000000000000000000..6720026f58231f9fbd468429df6a210d88bbbe03
--- /dev/null
+++ b/windflow/external_packages/resample2d_package/setup.py
@@ -0,0 +1,30 @@
+#!/usr/bin/env python3
+import os
+import torch
+
+from setuptools import setup
+from torch.utils.cpp_extension import BuildExtension, CUDAExtension
+
+#cxx_args = ['-std=c++11']
+cxx_args = ['-std=c++14']
+
+nvcc_args = [
+    '-gencode', 'arch=compute_50,code=sm_50',
+    '-gencode', 'arch=compute_52,code=sm_52',
+    '-gencode', 'arch=compute_60,code=sm_60',
+    '-gencode', 'arch=compute_61,code=sm_61',
+    '-gencode', 'arch=compute_70,code=sm_70',
+    '-gencode', 'arch=compute_70,code=compute_70'
+]
+
+setup(
+    name='resample2d_cuda',
+    ext_modules=[
+        CUDAExtension('resample2d_cuda', [
+            'resample2d_cuda.cc',
+            'resample2d_kernel.cu'
+        ], extra_compile_args={'cxx': cxx_args, 'nvcc': nvcc_args})
+    ],
+    cmdclass={
+        'build_ext': BuildExtension
+    })
diff --git a/windflow/inference/__pycache__/inference_flows.cpython-39.pyc b/windflow/inference/__pycache__/inference_flows.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d80a1ed9993ef81eafed5c093853f03fd66b2ca5
Binary files /dev/null and b/windflow/inference/__pycache__/inference_flows.cpython-39.pyc differ
diff --git a/windflow/inference/inference_flows.py b/windflow/inference/inference_flows.py
new file mode 100644
index 0000000000000000000000000000000000000000..e047603f0fb33c0acc458651bb053f4fe5cca9d5
--- /dev/null
+++ b/windflow/inference/inference_flows.py
@@ -0,0 +1,324 @@
+import sys
+import os
+os.environ["CUDA_VISIBLE_DEVICES"] = "0"
+
+import time
+import glob
+
+import numpy as np
+import pandas as pd
+import torch
+from torch import nn
+import xarray as xr
+import matplotlib.pyplot as plt
+import datetime as dt
+import scipy.stats
+
+from ..datasets.geostationary import goesr, stats
+from ..networks.models import get_flow_model
+from ..utils import plotting
+from ..datasets.preprocess import image_histogram_equalization
+from ..datasets.utils import cartesian_to_speed
+
+def split_array(arr, tile_size, overlap):
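+    '''
+    Split a (C, H, W) array into overlapping tile_size x tile_size patches.
+    Returns the stacked patches and the upper-left (i, j) index of each patch.
+    '''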
+    c, h, w = arr.shape
+    patches = []
+    indices = []
+    for i in range(0, h, tile_size-overlap):
+        for j in range(0, w, tile_size-overlap):
+            i = min(i, h - tile_size)
+            j = min(j, w - tile_size)
+            indices.append([i, j])
+            patches.append(arr[np.newaxis,:,i:i+tile_size, j:j+tile_size])
+    indices = np.array(indices)
+    patches = np.concatenate(patches, axis=0)
+    return patches, indices
+
+def reassemble_split_array(arr, upperleft, shape, trim=0):    
+    assert len(arr) > 0
+
+    # accumulate overlapping patches and average by coverage count
+    height, width = arr.shape[2:4]
+    counter = np.zeros(shape)
+    out_sum = np.zeros(shape)
+    for i, x in enumerate(arr):
+        ix, iy = upperleft[i]  
+        if trim > 0:
+            counter[:,ix+trim:ix+height-trim,iy+trim:iy+width-trim] += 1
+            out_sum[:,ix+trim:ix+height-trim,iy+trim:iy+width-trim] += x[:,trim:-trim,trim:-trim]    
+        else:
+            counter[:,ix:ix+height,iy:iy+width] += 1
+            out_sum[:,ix:ix+height,iy:iy+width] += x
+
+    out = out_sum / counter
+    return out
+
+def reassemble_with_2d_gaussian(arr, upperleft, shape, trim=0):    
+    assert len(arr) > 0
+
+    # accumulate overlapping patches with Gaussian blending weights
+    height, width = arr.shape[2:4]
+    counter = np.zeros(shape)
+    out_sum = np.zeros(shape)
+    
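+    # Build a 2-D Gaussian (out to +/-3 sigma) so pixels near a patch centre
+    # dominate the blend and seams at patch borders are suppressed.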
+    nsig = 3
+    x = np.linspace(-nsig, nsig, height+1-trim*2)
+    kern1d = np.diff(scipy.stats.norm.cdf(x))
+    kern2d = np.outer(kern1d, kern1d)
+    pdf = kern2d/kern2d.sum()
+    
+    for i, x in enumerate(arr):
+        ix, iy = upperleft[i]  
+        if trim > 0:
+            counter[:,ix+trim:ix+height-trim,iy+trim:iy+width-trim] += pdf
+            out_sum[:,ix+trim:ix+height-trim,iy+trim:iy+width-trim] += x[:,trim:-trim,trim:-trim] * pdf  
+        else:
+            counter[:,ix:ix+height,iy:iy+width] += pdf
+            out_sum[:,ix:ix+height,iy:iy+width] += x * pdf
+
+    out = out_sum / counter
+    return out
+
+
+def write_to_netcdf(filename, ua, va, lat, lon, ts,
+                    pressure=None):
+    
+    ua = xr.DataArray(ua.astype(np.float32), coords=[('pressure', pressure),('lat', lat), ('lon', lon)])
+    va = xr.DataArray(va.astype(np.float32), coords=[('pressure', pressure),('lat', lat), ('lon', lon)])
+    newds = xr.Dataset(dict(ua=ua, va=va, time=ts, pressure=pressure))
+    newds.to_netcdf(filename)
+    print("Dataset written to file: {}".format(filename))
+
+
+def inference(model, X0, X1, 
+              tile_size=512, 
+              overlap=128, 
+              upsample_flow_factor=None,
+              batch_size=32): 
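+    '''
+    Tile a (C, H, W) image pair, run the flow model on each tile pair in
+    batches, and blend the per-tile flows back together with Gaussian weights.
+    '''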
+    c, h, w = X0.shape
+    trim = 0 #overlap // 4
+    
+    if isinstance(upsample_flow_factor, int):
+        upsample_flow = nn.Upsample(scale_factor=upsample_flow_factor, mode='bilinear')
+    else:
+        upsample_flow = None
+        
+    #if isinstance(upsample_input, float)
+    
+    f_sum = np.zeros((1, 2, h, w))
+    f_counter = np.zeros((1, 1, h, w))
+    
+    x0_patches, upperleft = split_array(X0, tile_size, overlap)
+    x1_patches, _ = split_array(X1, tile_size, overlap)
+    
+    x0_patches = torch.from_numpy(x0_patches).float()
+    x1_patches = torch.from_numpy(x1_patches).float()
+    pred = []
+    for batch in range(0, x1_patches.shape[0], batch_size):
+        x0_batch = x0_patches[batch:batch+batch_size].cuda()
+        x1_batch = x1_patches[batch:batch+batch_size].cuda()
+        
+        # testing 2x upsampling before inference
+        #upsample_2x = nn.Upsample(scale_factor=2, mode='bilinear')
+        #x0_batch = upsample_2x(x0_batch)
+        #x1_batch = upsample_2x(x1_batch)
+
+        #try:
+        model_out = model(x0_batch, x1_batch, test_mode=True)[0]
+        #except TypeError:
+            #model_out = model(x0, x1)[0]
+        #    model_out = model(torch.cat([x0_batch, x1_batch], 1))[0] 
+            
+        if upsample_flow:
+            model_out = upsample_flow(model_out) #* upsample_flow_factor
+            #model_out = upsample_2x(model_out) / 2
+            
+        pred.append(model_out.cpu().detach().numpy())
+
+    pred = np.concatenate(pred, 0)   
+    #UV = reassemble_split_array(pred, upperleft, (2, h, w), trim=trim)
+    UV = reassemble_with_2d_gaussian(pred, upperleft, (2, h, w), trim=trim)
+
+    return UV
+
+
+class FlowRunner(object):
+    '''
+    Operator to perform inference on general inputs and flow models
+    '''
+    def __init__(self, model_name, 
+                 tile_size=512, 
+                 overlap=384, 
+                 batch_size=8, 
+                 upsample_input=None,
+                 device='cuda:0'):        
+        self.model = get_flow_model(model_name, small=False)
+        self.model_name = model_name
+        self.tile_size = tile_size
+        self.batch_size = batch_size
+        self.overlap = overlap
+        
+        # upsample input 2x can improve feature tracking at inference time
+        #if upsample_input_factor is not None:
+        #    self.upsample_input = nn.Upsample(scale_factor=upsample_input_factor, mode='bilinear')
+        #else:
+        #    upsample_input_factor = 1
+        
+        if self.model_name.lower() in ['flownets', 'pwcnet', 'pwc-net', 'maskflownet']:
+            self.upsample_flow_factor = 4 #/ upsample_input_factor
+        else:
+            self.upsample_flow_factor = None #1 / upsample_input_factor
+            
+        #self.model.to(device)
+        self.model = self.model.cuda()
+        self.model = torch.nn.DataParallel(self.model)
+        #self.model = torch.nn.DataParallel(self.model, device_ids=[int(list(device)[-1])])
+        #self.model = torch.nn.DataParallel(self.model, device_ids=[0,])
+
+    def load_checkpoint(self, checkpoint_file):
+        checkpoint = torch.load(checkpoint_file)
+        self.global_step = checkpoint['global_step']
+        #for key in checkpoint['model']:
+        #    if 'module' in key:
+        #        new_key = key.replace('module.', '')
+        #        checkpoint[new_key] = checkpoint[key]
+        #        del checkpoint[key]
+
+        try:
+            self.model.module.load_state_dict(checkpoint['model'])
+        except RuntimeError:
+            self.model.load_state_dict(checkpoint['model'])
+                
+        print(f"Loaded checkpoint: {checkpoint_file}")
+
+    def preprocess(self, x):
+        x[~np.isfinite(x)] = 0.
+        x = image_histogram_equalization(x)
+        return x    
+
+    def forward(self, img1, img2):
+        mask = (img1 == img1).astype(np.float32)  # 1.0 where img1 is finite
+        mask[mask == 0] = np.nan                  # NaN where img1 is missing
+
+        img1_norm = self.preprocess(img1)
+        img2_norm = self.preprocess(img2)
+
+        flows = inference(self.model, img1_norm[np.newaxis], img2_norm[np.newaxis], 
+                          tile_size=self.tile_size, overlap=self.overlap, 
+                          upsample_flow_factor=self.upsample_flow_factor,
+                          batch_size=self.batch_size)
+        return img1, flows * mask[np.newaxis]
+
+
+
+class GeoFlows(FlowRunner):
+    '''
+    Object to manage optical flow prediction for geostationary L1b observations
+    '''
+    def __init__(self, model_name, 
+                 data_directory, 
+                 product='ABI-L1b-RadF', 
+                 upsample_data=None,
+                 channels=[10], 
+                 timestep=10, 
+                 spatial=None, 
+                 **kwargs):
+        FlowRunner.__init__(self, model_name, **kwargs)
+        self.data_directory = data_directory
+        self.timestep = timestep
+        self.upsample_data = upsample_data
+        self.spatial = spatial
+
+        self.goes = goesr.GOESL1b(product=product, channels=channels, 
+                                  data_directory=data_directory)
+        self.channels = channels
+
+    def flows_by_time(self, t, reproject=False):
+        t2 = t + dt.timedelta(minutes=self.timestep)
+        file1 = self.goes.snapshot_file(t.year, t.timetuple().tm_yday, t.hour, 
+                                        t.minute, spatial=self.spatial).values[0]
+        file2 = self.goes.snapshot_file(t2.year, t2.timetuple().tm_yday, t2.hour, 
+                                        t2.minute, spatial=self.spatial).values[0]
+
+        obj1 = goesr.L1bBand(file1)
+        obj2 = goesr.L1bBand(file2)
+        
+        if self.upsample_data is not None:
+            obj1.interp(self.upsample_data)
+            obj2.interp(self.upsample_data)
+            
+        lats, lons = obj1.latlon()
+
+        if reproject:
+            data1 = obj1.reproject_to_latlon()
+            data2 = obj2.reproject_to_latlon()
+        else:
+            data1 = obj1.open_dataset()
+            data2 = obj2.open_dataset()
+         
+        img1 = data1['Rad'].values
+        img2 = data2['Rad'].values
+
+        flows = self.forward(img1, img2)[1]
+        
+        if reproject:
+            data1['U'] = xr.DataArray(flows[0], dims=['lat', 'lon'], 
+                                      coords=dict(lat=data1.lat.values, lon=data1.lon.values))
+            data1['V'] = xr.DataArray(flows[1], dims=['lat', 'lon'], 
+                                      coords=dict(lat=data1.lat.values, lon=data1.lon.values))
+        else:
+            data1['lat'] = xr.DataArray(lats, dims=['y', 'x'], 
+                                        coords=dict(y=data1.y.values, x=data1.x.values))
+            data1['lon'] = xr.DataArray(lons, dims=['y', 'x'],
+                                        coords=dict(y=data1.y.values, x=data1.x.values))
+
+            data1['U'] = xr.DataArray(flows[0], dims=['y', 'x'], 
+                    coords=dict(y=data1.y.values, x=data1.x.values))
+            data1['V'] = xr.DataArray(flows[1], dims=['y', 'x'], 
+                    coords=dict(y=data1.y.values, x=data1.x.values))
+
+        data1 = cartesian_to_speed(data1)
+        data1['U'] *= 2000 / (self.timestep * 60) * 0.8 # m/s on 2km grid -- 0.8 scales to remove bias
+        data1['V'] *= 2000 / (self.timestep * 60) * 0.8 # m/s on 2km grid
+        return data1
+    
+    
+    '''def flows_iterate(self, start_year, start_dayofyear, start_hour, steps=6*24):
+        
+        files = self.goes.local_files(start_year, start_dayofyear).reset_index()
+        files = files[files['hour'] >= start_hour]
+        
+        curr_date = dt.datetime(start_year, 1, 1, 1) + dt.timedelta(days=start_dayofyear-1)
+        while len(files) < steps:
+            curr_date += dt.timedelta(days=1)
+            more_files = self.goes.local_files(curr_date.year, curr_date.timetuple().tm_yday).reset_index()
+            files = pd.concat([files, more_files])
+        
+        band1 = goesr.L1bBand(files['file'].values[0,0])
+        data1 = band1.open_dataset()
+        img1 = data1['Rad'].values
+        
+        for idx in range(1, steps):
+            band2 = goesr.L1bBand(files['file'].values[idx,0])
+            data2 = band2.open_dataset()
+            img2 = data2['Rad'].values
+
+            flows = self.forward(img1, img2)[1]
+            
+            yield band1, flows
+            
+            band1 = band2
+            data1 = data2
+            img1 = img2
+    '''      
+         
+if __name__ == "__main__":
+
+    data_directory = '/nex/datapool/geonex/public/GOES16/NOAA-L1B/'
+    model = FlowNetS.FlowNetS(None, 1)
+    checkpoint_path = '/nobackupp10/tvandal/nex-ai-opticalflow/scripts/experiments/gmao_osse/flownets-size_512-lognorm/checkpoint.flownet.pth.tar'
+    model.training = False
+
+    inference = inference_flows.GeoFlows(model, data_directory)
+
+    
diff --git a/windflow/inference/inference_tools.py b/windflow/inference/inference_tools.py
new file mode 100644
index 0000000000000000000000000000000000000000..e040d27cc867c59e1be3bb76cd7cb06e2d71ad74
--- /dev/null
+++ b/windflow/inference/inference_tools.py
@@ -0,0 +1,242 @@
+import sys
+import os
+
+import xarray as xr
+from .utils import blocks
+import torch
+import torch.nn as nn
+import torchvision
+
+from slomo import flownet as fl
+from slomo import unet
+
+import numpy as np
+
+device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+def block_predictions_to_dataarray(predictions, block):
+    block_predictions = np.concatenate(predictions, 0)
+    block_predictions[block_predictions < 0] = 0
+    block_predictions[block_predictions > 1] = 1
+
+    N_pred = block_predictions.shape[0]
+    T = np.arange(0,N_pred)
+    da = xr.DataArray(block_predictions,#[:,:,shave:-shave,shave:-shave],
+              coords=[T, block.band.values,
+                      block.y.values,#[shave:-shave], 
+                      block.x.values,],#][shave:-shave]],
+              dims=['t', 'band', 'y', 'x'])
+
+    return da
+
+def merge_and_average_dataarrays(dataarrays):
+    ds = xr.merge([xr.Dataset({k: d}) for k, d in enumerate(dataarrays)])
+    das = []
+    for b in range(0,len(dataarrays)):
+        das.append(ds[b])
+
+    return xr.concat(das, dim='concat_dims').mean('concat_dims', skipna=True)
+
+# Split 3D numpy array into patches with overlap
+def split_array(arr, tile_size=128, overlap=16):
+    '''
+    Split a 3D numpy array into patches for inference
+    (Channels, Height, Width)
+    Args:
+        tile_size: width and height of patches to return
+        overlap: number of pixels to overlap between patches
+    Returns:
+        dict(patches, upper_left): patches and indices of original array
+    '''
+    arr = arr[np.newaxis]
+    height, width = arr.shape[2:4]
+    arrs = dict(patches=[], upper_left=[])
+    for i in range(0, height, tile_size - overlap):
+        for j in range(0, width, tile_size - overlap):
+            i = min(i, height - tile_size)
+            j = min(j, width - tile_size)
+            arrs['patches'].append(arr[:,:,i:i+tile_size,j:j+tile_size])
+            arrs['upper_left'].append([[i,j]])
+    arrs['patches'] = np.concatenate(arrs['patches'])
+    arrs['upper_left'] = np.concatenate(arrs['upper_left'])
+    return arrs['patches'], arrs['upper_left']
+
+## Load and return flownet, interpnet, and warper
+def load_models(n_channels, model_path, multivariate=False,
+                nn_model=unet.UNetSmall):
+
+    model_filename = os.path.join(model_path, 'best.flownet.pth.tar')
+    if multivariate:
+        flownet = fl.SloMoFlowNetMV(n_channels, model=nn_model)#.cuda()
+        interpnet = fl.SloMoInterpNetMV(n_channels, model=nn_model)#.cuda()
+    else:
+        flownet = fl.SloMoFlowNet(n_channels, model=nn_model)#.cuda()
+        interpnet = fl.SloMoInterpNet(n_channels, model=nn_model)#.cuda()
+
+    warper = fl.FlowWarper()
+
+    flownet = nn.DataParallel(flownet)
+    interpnet = nn.DataParallel(interpnet)
+    warper = nn.DataParallel(warper)
+
+    flownet = flownet.to(device)
+    interpnet = interpnet.to(device)
+    warper = warper.to(device)
+
+    def load_checkpoint(flownet, interpnet):
+        epoch = 0
+        if os.path.isfile(model_filename):
+            print("loading checkpoint %s" % model_filename)
+            checkpoint = torch.load(model_filename)
+            flownet.load_state_dict(checkpoint['flownet_state_dict'])
+            interpnet.load_state_dict(checkpoint['interpnet_state_dict'])
+            epoch = checkpoint['epoch']
+            print("=> loaded checkpoint '{}' (epoch {})"
+                    .format(model_filename, epoch))
+        else:
+            print("=> no checkpoint found at '{}'".format(model_filename))
+        return flownet, interpnet
+
+    flownet.train()
+    interpnet.train()
+    flownet, interpnet = load_checkpoint(flownet, interpnet)
+    return flownet, interpnet, warper
+
+## Perform interpolation for a single timepoint t
+
+def single_inference(X0, X1, t, flownet, interpnet,
+                     multivariate):
+    X0_arr_torch = torch.from_numpy(X0)
+    X1_arr_torch = torch.from_numpy(X1)
+
+    # nans to 0
+    X0_arr_torch[torch.isnan(X0_arr_torch)] = 0.
+    X1_arr_torch[torch.isnan(X1_arr_torch)] = 0.
+
+    X0_arr_torch = torch.unsqueeze(X0_arr_torch, 0).to(device, dtype=torch.float)
+    X1_arr_torch = torch.unsqueeze(X1_arr_torch, 0).to(device, dtype=torch.float)
+
+    f = flownet(X0_arr_torch, X1_arr_torch)
+    n_channels = X0_arr_torch.shape[1]
+
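+    # The network stacks forward (f_01) and backward (f_10) flows along the
+    # channel axis; multivariate models predict one flow pair per channel.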
+    if multivariate:
+        f_01 = f[:,:2*n_channels]
+        f_10 = f[:,2*n_channels:]
+    else:
+        f_01 = f[:,:2]
+        f_10 = f[:,2:]
+
+    predicted_frames = []
+
+    I_t, g0, g1, V_t0, V_t1, delta_f_t0, delta_f_t1 = interpnet(X0_arr_torch, X1_arr_torch, f_01, f_10, t)
+    predicted_frames.append(I_t.cpu().detach().numpy())
+
+    out = {'f_01': f_01, 'f_10': f_10, 'I_t': I_t,
+           'V_t0': V_t0, 'V_t1': V_t1, 'delta_f_t0': delta_f_t0,
+           'delta_f_t1': delta_f_t1}
+
+    for k in out.keys():
+        out[k] = out[k][0].cpu().detach().numpy()
+
+    #torch.cuda.empty_cache()
+    return out
+
+## Split interpolation for a single timepoint t
+def single_inference_split(X0, X1, t, flownet, interpnet,
+                           multivariate, block_size=128, overlap=16,
+                           discard=0):
+    X0_split, upper_left_idxs = split_array(X0, block_size, overlap)
+    X1_split, _ = split_array(X1, block_size, overlap)
+
+    assert len(X0_split) > 0
+
+    # perform inference on patches
+    height, width = X0.shape[1:3]
+    #counter = np.zeros((1, height, width))
+    counter = np.zeros((1, height-discard*2, width-discard*2))
+    res_sum = {}
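+    # each patch prediction is accumulated into res_sum at its upper-left
+    # offset, while counter tracks how many patches cover each pixel; the
+    # final division averages the overlapping regions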
+    for i, (x0, x1) in enumerate(zip(X0_split, X1_split)):
+        ix, iy = upper_left_idxs[i]
+        res_i = single_inference(x0, x1, t, flownet, interpnet,
+                                 multivariate)
+        keys = res_i.keys()
+        if i == 0:
+            res_sum = {k: np.zeros((res_i[k].shape[0], height-discard*2, width-discard*2)) for k in keys}
+
+        for var in keys:
+            if discard > 0:
+                res_i[var] = res_i[var][:,discard:-discard,discard:-discard]
+            res_sum[var][:,ix:ix+block_size-discard*2,iy:iy+block_size-discard*2] += res_i[var]
+            counter[:,ix:ix+block_size-discard*2,iy:iy+block_size-discard*2] += 1.
+
+    out = {}
+    for var in res_sum.keys():
+        out[var] = res_sum[var] / counter
+
+    return out
+
+
+def _inference(X0, X1, flownet, interpnet, warper,
+               multivariate, T=4, block_size=None):
+    '''
+    Given two consecutive frames, interpolate T frames between them
+    using flownet and interpnet.
+    Returns:
+        Interpolated frames
+    '''
+
+    X0_arr_torch = torch.from_numpy(X0.values)
+    X1_arr_torch = torch.from_numpy(X1.values)
+
+    # nans to 0
+    X0_arr_torch[torch.isnan(X0_arr_torch)] = 0.
+    X1_arr_torch[torch.isnan(X1_arr_torch)] = 0.
+
+    X0_arr_torch = torch.unsqueeze(X0_arr_torch, 0).to(device, dtype=torch.float)
+    X1_arr_torch = torch.unsqueeze(X1_arr_torch, 0).to(device, dtype=torch.float)
+
+    f = flownet(X0_arr_torch, X1_arr_torch)
+    n_channels = X0_arr_torch.shape[1]
+
+    if multivariate:
+        f_01 = f[:,:2*n_channels]
+        f_10 = f[:,2*n_channels:]
+    else:
+        f_01 = f[:,:2]
+        f_10 = f[:,2:]
+
+    predicted_frames = []
+
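+    # sample T evenly spaced timepoints in (0, 1); with T=4, t takes the
+    # values 0.2, 0.4, 0.6, and 0.8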
+    for j in range(1,T+1):
+        t = 1. * j / (T+1)
+        I_t, g0, g1, V_t0, V_t1, delta_f_t0, delta_f_t1 = interpnet(X0_arr_torch, X1_arr_torch, f_01, f_10, t)
+        predicted_frames.append(I_t.cpu().detach().numpy())
+
+    torch.cuda.empty_cache()
+    return predicted_frames
+
+## Perform interpolation over multiple time steps
+def inference(X0, X1, flownet, interpnet, warper,
+              multivariate, T=4, block_size=352):
+    '''
+    Given two consecutive frames, interpolate T frames between them
+    using flownet and interpnet.
+    Returns:
+        Interpolated frames
+    '''
+    # Get a list of dataarrays chunked
+    X0_blocks = blocks(X0, width=block_size)
+    X1_blocks = blocks(X1, width=block_size)
+
+    interpolated_blocks = []
+    for x0, x1 in zip(X0_blocks, X1_blocks):
+        predicted_frames = _inference(x0, x1, flownet,
+                                      interpnet, warper,
+                                      multivariate, T)
+        predicted_frames = [x0.values[np.newaxis]] + predicted_frames + [x1.values[np.newaxis]]
+        interpolated_blocks += [block_predictions_to_dataarray(predicted_frames, x0)]
+    return merge_and_average_dataarrays(interpolated_blocks)
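+
+# A rough end-to-end sketch (X0 and X1 are assumed to be xarray DataArrays
+# with 'band', 'y', 'x' dims, as expected by blocks(); names are illustrative):
+#
+#   frames = inference(X0, X1, flownet, interpnet, warper,
+#                      multivariate=False, T=4, block_size=352)
+#   # `frames` interleaves X0, the T interpolated frames, and X1 per block,
+#   # merged and averaged back onto the full grid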
diff --git a/windflow/inference/utils.py b/windflow/inference/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..51fba9e3af95f80bd45482dbdfd05c0882cee366
--- /dev/null
+++ b/windflow/inference/utils.py
@@ -0,0 +1,100 @@
+import xarray as xr
+import numpy as np
+import cv2
+
+import scipy.interpolate
+
+def fillmiss(x):
+    if x.ndim != 2:
+        raise ValueError("X have only 2 dimensions.")
+    mask = ~np.isnan(x)
+    xx, yy = np.meshgrid(np.arange(x.shape[1]), np.arange(x.shape[0]))
+    xym = np.vstack( (np.ravel(xx[mask]), np.ravel(yy[mask])) ).T
+    data0 = np.ravel(x[mask])
+    interp0 = scipy.interpolate.NearestNDInterpolator(xym, data0)
+    result0 = interp0(np.ravel(xx), np.ravel(yy)).reshape(xx.shape)
+    return result0
+
+def interp_dim(x, scale):
+    x0, xlast = x[0], x[-1]
+    newlength = int(len(x) * scale)
+    y = np.linspace(x0, xlast, num=newlength, endpoint=False)
+    return y
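+
+# e.g. interp_dim(np.array([0., 1., 2., 3.]), 2) yields 8 evenly spaced values
+# from 0 up to (but not including) 3, since endpoint=False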
+
+def interp_tensor(X, scale, fill=True, how=cv2.INTER_NEAREST):
+    nlt = int(X.shape[1]*scale)
+    nln = int(X.shape[2]*scale)
+    newshape = (X.shape[0], nlt, nln)
+    scaled_tensor = np.empty(newshape)
+    for j, im in enumerate(X):
+        # fill im with nearest neighbor
+        if fill:
+            #im = fillmiss(im)
+            im[np.isnan(im)] = 0
+
+        scaled_tensor[j] = cv2.resize(im, (newshape[2], newshape[1]),
+                                     interpolation=how)
+    return scaled_tensor
+
+def interp_da(da, scale, how=cv2.INTER_LINEAR):
+    """
+    Assume da is of dimensions ('time','lat', 'lon')
+    """
+    tensor = da.values
+
+    # interpolate lat and lons
+    latnew = interp_dim(da[da.dims[1]].values, scale)
+    lonnew = interp_dim(da[da.dims[2]].values, scale)
+
+    # let's store our interpolated data
+    scaled_tensor = interp_tensor(tensor, scale, fill=True, how=how)
+
+    if latnew.shape[0] != scaled_tensor.shape[1]:
+        raise ValueError("New shape is shitty")
+    # initialize a new dataarray
+    return xr.DataArray(scaled_tensor, coords=[da[da.dims[0]].values, latnew, lonnew],
+                 dims=da.dims)
+
+def interp_da2d(da, scale, fillna=False, how=cv2.INTER_NEAREST):
+    """
+    Assume da is a 2D array of dimensions ('lat', 'lon')
+    """
+    # let's store our interpolated data
+    newshape = (int(da.shape[0]*scale),int(da.shape[1]*scale))
+    im = da.values
+    scaled_tensor = np.empty(newshape)
+    # fill im with nearest neighbor
+    if fillna:
+        filled = fillmiss(im)
+    else:
+        filled = im
+    scaled_tensor = cv2.resize(filled, dsize=(0,0), fx=scale, fy=scale,
+                              interpolation=how)
+
+    # interpolate lat and lons
+    latnew = interp_dim(da[da.dims[1]].values, scale)
+    lonnew = interp_dim(da[da.dims[0]].values, scale)
+
+    # initialize a new dataarray
+    return xr.DataArray(scaled_tensor, coords=[lonnew, latnew],
+                 dims=da.dims)
+
+def blocks(data, width=352):
+    #n = data.t.shape[0]
+    w = data.x.shape[0]
+    h = data.y.shape[0]
+    d = data.band.shape[0]
+
+    hs = np.arange(0, h, width)
+    ws = np.arange(0, w, width)
+    blocks = []
+    for hindex in hs:
+        if hindex+width > h:
+            hindex = h - width
+
+        for windex in ws:
+            if windex+width > w:
+                windex = w - width
+            blocks.append(data.sel(y=data.y.values[hindex:hindex+width],
+                                    x=data.x.values[windex:windex+width]))
+    return blocks
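+
+# A small usage sketch (assumes `da` is an xarray DataArray with 'band', 'y',
+# and 'x' dims; the name is illustrative):
+#
+#   tiles = blocks(da, width=352)
+#   # each tile is a full 352x352 spatial slice; tiles at the right/bottom
+#   # edges are shifted inward, so boundary tiles overlap their neighbors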
diff --git a/windflow/losses/__init__.py b/windflow/losses/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..f37cff02a0b650e96239e8f7041b71ad9ed115c2
--- /dev/null
+++ b/windflow/losses/__init__.py
@@ -0,0 +1,2 @@
+from .ssim import SSIM, ssim
+from .losses import SmoothnessLoss, CharbonnierLoss, L1Loss, L1, L2Loss, RMSVDLoss
\ No newline at end of file
diff --git a/windflow/losses/losses.py b/windflow/losses/losses.py
new file mode 100644
index 0000000000000000000000000000000000000000..3457b0b12106d1cdb95305ba9acb0ea53ba49d01
--- /dev/null
+++ b/windflow/losses/losses.py
@@ -0,0 +1,85 @@
+import torch
+import torch.nn as nn
+
+class SmoothnessLoss(nn.Module):
+    def __init__(self):
+        super(SmoothnessLoss, self).__init__()
+
+    def forward(self, flows, mask=None):
+        gradx = torch.abs(flows[:,:,:,:-1] - flows[:,:,:,1:])
+        grady = torch.abs(flows[:,:,:-1,:] - flows[:,:,1:,:])
+        normalize = 1.
+        if isinstance(mask, torch.Tensor):
+            gradx *= mask[:,:,:,:-1]
+            grady *= mask[:,:,:-1,:]
+            normalize = torch.mean(mask)
+        loss_smooth = (torch.mean(gradx) + torch.mean(grady)) / normalize
+        return loss_smooth 
+
+class CharbonnierLoss(nn.Module):
+    def __init__(self, alpha, eps=1e-6):
+        super(CharbonnierLoss, self).__init__()
+        self.alpha = alpha
+        self.eps = eps
+
+    def forward(self, I0, I1, mask=None):
+        err = ((I0 - I1)**2 + self.eps)**self.alpha
+        norm = 1.
+        if isinstance(mask, torch.Tensor):
+            err *= mask
+            norm = torch.mean(mask)
+        return torch.mean(err) / norm
+    
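+# Average endpoint error: the mean over pixels of the L2 norm of (target - input) flow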
+def EPE(input_flow, target_flow):
+    return torch.norm(target_flow-input_flow,p=2,dim=1).mean()
+
+class L1(nn.Module):
+    def __init__(self):
+        super(L1, self).__init__()
+
+    def forward(self, output, target):
+        lossvalue = torch.abs(output - target).mean()
+        return lossvalue
+
+class L2(nn.Module):
+    def __init__(self):
+        super(L2, self).__init__()
+    def forward(self, output, target):
+        lossvalue = torch.norm(output-target,p=2,dim=1).mean()
+        return lossvalue
+
+class L1Loss(nn.Module):
+    def __init__(self, args):
+        super(L1Loss, self).__init__()
+        self.args = args
+        self.loss = L1()
+        self.loss_labels = ['L1', 'EPE']
+
+    def forward(self, output, target):
+        lossvalue = self.loss(output, target)
+        #epevalue = EPE(output, target)
+        return lossvalue
+        #return [lossvalue, epevalue]
+    
+class L2Loss(nn.Module):
+    def __init__(self, args):
+        super(L2Loss, self).__init__()
+        self.args = args
+        self.loss = L2()
+
+    def forward(self, output, target):
+        lossvalue = self.loss(output, target)
+        #epevalue = EPE(output, target)
+        return lossvalue #+ epevalue
+    
+    
+class RMSVDLoss(nn.Module):
+    def __init__(self):
+        super(RMSVDLoss, self).__init__()
+    
+    def forward(self, output, target):
+        # output and target shape (N, 2, H, W)
+        u_err = (output[:,0] - target[:,0])**2
+        v_err = (output[:,1] - target[:,1])**2
+        rmsvd = (u_err.mean() + v_err.mean())**0.5
+        return rmsvd
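+
+# A minimal sanity-check sketch for RMSVDLoss (illustrative shapes):
+#
+#   pred = torch.zeros(2, 2, 8, 8)     # (N, 2, H, W): u and v flow components
+#   target = torch.ones(2, 2, 8, 8)
+#   loss = RMSVDLoss()(pred, target)   # sqrt(mean(u_err) + mean(v_err)) = sqrt(2)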
diff --git a/windflow/losses/ssim.py b/windflow/losses/ssim.py
new file mode 100644
index 0000000000000000000000000000000000000000..b1e7d6cf32205922db285a9c837ec35d271e615e
--- /dev/null
+++ b/windflow/losses/ssim.py
@@ -0,0 +1,73 @@
+import torch
+import torch.nn.functional as F
+from torch.autograd import Variable
+import numpy as np
+from math import exp
+
+def gaussian(window_size, sigma):
+    gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)])
+    return gauss/gauss.sum()
+
+def create_window(window_size, channel):
+    _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
+    _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
+    window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
+    return window
+
+def _ssim(img1, img2, window, window_size, channel, size_average = True):
+    mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel)
+    mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel)
+
+    mu1_sq = mu1.pow(2)
+    mu2_sq = mu2.pow(2)
+    mu1_mu2 = mu1*mu2
+
+    sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq
+    sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq
+    sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2
+
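+    # stability constants from Wang et al. (2004): C1 = (K1*L)^2, C2 = (K2*L)^2
+    # with K1 = 0.01, K2 = 0.03, and dynamic range L taken as 1 here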
+    C1 = 0.01**2
+    C2 = 0.03**2
+
+    ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2))
+
+    if size_average:
+        return ssim_map.mean()
+    else:
+        return ssim_map.mean(1).mean(1).mean(1)
+
+class SSIM(torch.nn.Module):
+    def __init__(self, window_size = 11, size_average = True):
+        super(SSIM, self).__init__()
+        self.window_size = window_size
+        self.size_average = size_average
+        self.channel = 1
+        self.window = create_window(window_size, self.channel)
+
+    def forward(self, img1, img2):
+        (_, channel, _, _) = img1.size()
+
+        if channel == self.channel and self.window.data.type() == img1.data.type():
+            window = self.window
+        else:
+            window = create_window(self.window_size, channel)
+            
+            if img1.is_cuda:
+                window = window.cuda(img1.get_device())
+            window = window.type_as(img1)
+            
+            self.window = window
+            self.channel = channel
+
+
+        return _ssim(img1, img2, window, self.window_size, channel, self.size_average)
+
+def ssim(img1, img2, window_size = 11, size_average = True):
+    (_, channel, _, _) = img1.size()
+    window = create_window(window_size, channel)
+    
+    if img1.is_cuda:
+        window = window.cuda(img1.get_device())
+    window = window.type_as(img1)
+    
+    return _ssim(img1, img2, window, window_size, channel, size_average)
\ No newline at end of file
diff --git a/windflow/networks/PWCNet.py b/windflow/networks/PWCNet.py
new file mode 100644
index 0000000000000000000000000000000000000000..5b5531de356a9c353ca64c1097afa8e6287fe3c6
--- /dev/null
+++ b/windflow/networks/PWCNet.py
@@ -0,0 +1,1146 @@
+"""
+implementation of the PWC-DC network for optical flow estimation by Sun et al., 2018
+
+Jinwei Gu and Zhile Ren
+
+"""
+
+import torch
+import torch.nn as nn
+from torch.autograd import Variable
+import os
+os.environ['PYTHON_EGG_CACHE'] = 'tmp/' # a writable directory 
+#from correlation_package.modules.corr import Correlation 
+from spatial_correlation_sampler import SpatialCorrelationSampler
+
+import numpy as np
+
+
+__all__ = [
+    'pwc_dc_net', 'pwc_dc_net_old'
+    ]
+
+def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
+    return nn.Sequential(
+            nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
+                        padding=padding, dilation=dilation, bias=True),
+            nn.LeakyReLU(0.1))
+
+def predict_flow(in_planes):
+    return nn.Conv2d(in_planes,2,kernel_size=3,stride=1,padding=1,bias=True)
+
+def deconv(in_planes, out_planes, kernel_size=4, stride=2, padding=1):
+    return nn.ConvTranspose2d(in_planes, out_planes, kernel_size, stride, padding, bias=True)
+
+class CorrelationLayerSimple(nn.Module):
+    def __init__(self, md=4):
+        super(CorrelationLayerSimple,self).__init__()
+        self.md = md
+        self.model = SpatialCorrelationSampler(kernel_size=1, patch_size=md*2+1, 
+                                              stride=1, dilation_patch=1, padding=0)
+    def forward(self, x1, x2):
+        output = self.model(x1, x2)
+        b, ph, pw, h, w = output.size()
+        return output.view(b, ph * pw, h, w)
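+
+    # note: with patch_size = 2*md+1 the sampler produces (2*md+1)**2
+    # correlation channels per pixel, matching `nd` in the networks below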
+    
+class PWCDCNetSmall(nn.Module):
+    """
+    PWC-DC net with dilated convolutions and DenseNet connections.
+
+    """
+    def __init__(self, md=4):
+        """
+        input: md --- maximum displacement for the correlation layer (default: 4), applied after warping
+
+        """
+        super(PWCDCNetSmall, self).__init__()
+
+        self.conv1a  = conv(1,   32, kernel_size=3, stride=2)
+        self.conv1aa = conv(32,  32, kernel_size=3, stride=1)
+        self.conv1b  = conv(32,  32, kernel_size=3, stride=1)
+        self.conv2a  = conv(32,  32, kernel_size=3, stride=2)
+        self.conv2aa = conv(32,  32, kernel_size=3, stride=1)
+        self.conv2b  = conv(32,  32, kernel_size=3, stride=1)
+        self.conv3a  = conv(32,  32, kernel_size=3, stride=2)
+        self.conv3aa = conv(32,  32, kernel_size=3, stride=1)
+        self.conv3b  = conv(32,  32, kernel_size=3, stride=1)
+        self.conv4a  = conv(32,  32, kernel_size=3, stride=2)
+        self.conv4aa = conv(32,  32, kernel_size=3, stride=1)
+        self.conv4b  = conv(32,  32, kernel_size=3, stride=1)
+        self.conv5a  = conv(32,  32, kernel_size=3, stride=2)
+        self.conv5aa = conv(32,  32, kernel_size=3, stride=1)
+        self.conv5b  = conv(32,  32, kernel_size=3, stride=1)
+
+        # self.corr    = Correlation(pad_size=md, kernel_size=1, max_displacement=md, stride1=1, stride2=1, corr_multiply=1)
+                    
+        self.corr = CorrelationLayerSimple(md=md)
+
+        self.leakyRELU = nn.LeakyReLU(0.1)
+        
+        nd = (2*md+1)**2
+        #dd = np.cumsum([128,128,96,64,32])
+        dd = np.cumsum([32,32,32,32,32])
+
+        od = nd
+        self.conv5_0 = conv(od,       32, kernel_size=3, stride=1)
+        self.conv5_1 = conv(od+dd[0], 32, kernel_size=3, stride=1)
+        self.conv5_2 = conv(od+dd[1], 32, kernel_size=3, stride=1)
+        self.conv5_3 = conv(od+dd[2], 32, kernel_size=3, stride=1)
+        self.conv5_4 = conv(od+dd[3], 32, kernel_size=3, stride=1)
+        self.predict_flow5 = predict_flow(od+dd[4]) 
+        self.deconv5 = deconv(2, 2, kernel_size=4, stride=2, padding=1) 
+        self.upfeat5 = deconv(od+dd[4], 2, kernel_size=4, stride=2, padding=1) 
+        
+        od = nd+32+4
+        self.conv4_0 = conv(od,      32, kernel_size=3, stride=1)
+        self.conv4_1 = conv(od+dd[0],32, kernel_size=3, stride=1)
+        self.conv4_2 = conv(od+dd[1],32,  kernel_size=3, stride=1)
+        self.conv4_3 = conv(od+dd[2],32,  kernel_size=3, stride=1)
+        self.conv4_4 = conv(od+dd[3],32,  kernel_size=3, stride=1)
+        self.predict_flow4 = predict_flow(od+dd[4]) 
+        self.deconv4 = deconv(2, 2, kernel_size=4, stride=2, padding=1) 
+        self.upfeat4 = deconv(od+dd[4], 2, kernel_size=4, stride=2, padding=1) 
+        
+        od = nd+32+4
+        self.conv3_0 = conv(od,      32, kernel_size=3, stride=1)
+        self.conv3_1 = conv(od+dd[0],32, kernel_size=3, stride=1)
+        self.conv3_2 = conv(od+dd[1],32,  kernel_size=3, stride=1)
+        self.conv3_3 = conv(od+dd[2],32,  kernel_size=3, stride=1)
+        self.conv3_4 = conv(od+dd[3],32,  kernel_size=3, stride=1)
+        self.predict_flow3 = predict_flow(od+dd[4]) 
+        self.deconv3 = deconv(2, 2, kernel_size=4, stride=2, padding=1) 
+        self.upfeat3 = deconv(od+dd[4], 2, kernel_size=4, stride=2, padding=1) 
+        
+        od = nd+32+4
+        self.conv2_0 = conv(od,      32, kernel_size=3, stride=1)
+        self.conv2_1 = conv(od+dd[0],32, kernel_size=3, stride=1)
+        self.conv2_2 = conv(od+dd[1],32,  kernel_size=3, stride=1)
+        self.conv2_3 = conv(od+dd[2],32,  kernel_size=3, stride=1)
+        self.conv2_4 = conv(od+dd[3],32,  kernel_size=3, stride=1)
+        self.predict_flow2 = predict_flow(od+dd[4]) 
+        self.deconv2 = deconv(2, 2, kernel_size=4, stride=2, padding=1) 
+        
+        self.dc_conv1 = conv(od+dd[4], 32, kernel_size=3, stride=1, padding=1,  dilation=1)
+        self.dc_conv2 = conv(32,      32, kernel_size=3, stride=1, padding=2,  dilation=2)
+        self.dc_conv3 = conv(32,      32, kernel_size=3, stride=1, padding=4,  dilation=4)
+        self.dc_conv4 = conv(32,      32,  kernel_size=3, stride=1, padding=8,  dilation=8)
+        self.dc_conv5 = conv(32,       32,  kernel_size=3, stride=1, padding=16, dilation=16)
+        self.dc_conv6 = conv(32,       32,  kernel_size=3, stride=1, padding=1,  dilation=1)
+        self.dc_conv7 = predict_flow(32)
+
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
+                nn.init.kaiming_normal_(m.weight.data, mode='fan_in')
+                if m.bias is not None:
+                    m.bias.data.zero_()
+
+
+    def warp(self, x, flo):
+        """
+        warp an image/tensor (im2) back to im1, according to the optical flow
+
+        x: [B, C, H, W] (im2)
+        flo: [B, 2, H, W] flow
+
+        """
+        B, C, H, W = x.size()
+        # mesh grid 
+        xx = torch.arange(0, W).view(1,-1).repeat(H,1)
+        yy = torch.arange(0, H).view(-1,1).repeat(1,W)
+        xx = xx.view(1,1,H,W).repeat(B,1,1,1)
+        yy = yy.view(1,1,H,W).repeat(B,1,1,1)
+        grid = torch.cat((xx,yy),1).float()
+
+        if x.is_cuda:
+            grid = grid.to(x.get_device())
+        vgrid = Variable(grid) + flo
+
+        # scale grid to [-1,1] 
+        vgrid[:,0,:,:] = 2.0*vgrid[:,0,:,:].clone() / max(W-1,1)-1.0
+        vgrid[:,1,:,:] = 2.0*vgrid[:,1,:,:].clone() / max(H-1,1)-1.0
+
+        vgrid = vgrid.permute(0,2,3,1)        
+        output = nn.functional.grid_sample(x, vgrid)
+        mask = torch.ones(x.size(), device=x.device)
+        mask = nn.functional.grid_sample(mask, vgrid)
+
+        # if W==128:
+            # np.save('mask.npy', mask.cpu().data.numpy())
+            # np.save('warp.npy', output.cpu().data.numpy())
+        
+        mask[mask<0.9999] = 0
+        mask[mask>0] = 1
+        
+        return output*mask
+
+
+    def forward(self, im1, im2):
+        #im1 = x[:,:3,:,:]
+        #im2 = x[:,3:,:,:]
+        
+        c11 = self.conv1b(self.conv1aa(self.conv1a(im1)))
+        c21 = self.conv1b(self.conv1aa(self.conv1a(im2)))
+        c12 = self.conv2b(self.conv2aa(self.conv2a(c11)))
+        c22 = self.conv2b(self.conv2aa(self.conv2a(c21)))
+        c13 = self.conv3b(self.conv3aa(self.conv3a(c12)))
+        c23 = self.conv3b(self.conv3aa(self.conv3a(c22)))
+        c14 = self.conv4b(self.conv4aa(self.conv4a(c13)))
+        c24 = self.conv4b(self.conv4aa(self.conv4a(c23)))
+        c15 = self.conv5b(self.conv5aa(self.conv5a(c14)))
+        c25 = self.conv5b(self.conv5aa(self.conv5a(c24)))
+        #c16 = self.conv6b(self.conv6a(self.conv6aa(c15)))
+        #c26 = self.conv6b(self.conv6a(self.conv6aa(c25)))
+
+        corr5 = self.corr(c15, c25) 
+        corr5 = self.leakyRELU(corr5)   
+
+       # x = torch.cat((self.conv6_0(corr6), corr6),1)
+       # x = torch.cat((self.conv6_1(x), x),1)
+       # x = torch.cat((self.conv6_2(x), x),1)
+       # x = torch.cat((self.conv6_3(x), x),1)
+       # x = torch.cat((self.conv6_4(x), x),1)
+       # flow6 = self.predict_flow6(x)
+       # up_flow6 = self.deconv6(flow6)
+       # up_feat6 = self.upfeat6(x)
+
+       # warp5 = self.warp(c25, up_flow6*0.625)
+       # corr5 = self.corr(c15, warp5) 
+       # corr5 = self.leakyRELU(corr5)
+       # x = torch.cat((corr5, c15, up_flow6, up_feat6), 1)
+        
+        x = torch.cat((self.conv5_0(corr5), corr5), 1)
+        #x = torch.cat((self.conv5_0(x), x),1)
+        x = torch.cat((self.conv5_1(x), x),1)
+        x = torch.cat((self.conv5_2(x), x),1)
+        x = torch.cat((self.conv5_3(x), x),1)
+        x = torch.cat((self.conv5_4(x), x),1)
+        flow5 = self.predict_flow5(x)
+        up_flow5 = self.deconv5(flow5)
+        up_feat5 = self.upfeat5(x)
+
+        warp4 = self.warp(c24, up_flow5*1.25)
+        corr4 = self.corr(c14, warp4)  
+        corr4 = self.leakyRELU(corr4)
+        x = torch.cat((corr4, c14, up_flow5, up_feat5), 1)
+        x = torch.cat((self.conv4_0(x), x),1)
+        x = torch.cat((self.conv4_1(x), x),1)
+        x = torch.cat((self.conv4_2(x), x),1)
+        x = torch.cat((self.conv4_3(x), x),1)
+        x = torch.cat((self.conv4_4(x), x),1)
+        flow4 = self.predict_flow4(x)
+        up_flow4 = self.deconv4(flow4)
+        up_feat4 = self.upfeat4(x)
+
+
+        warp3 = self.warp(c23, up_flow4*2.5)
+        corr3 = self.corr(c13, warp3) 
+        corr3 = self.leakyRELU(corr3)
+        
+
+        x = torch.cat((corr3, c13, up_flow4, up_feat4), 1)
+        x = torch.cat((self.conv3_0(x), x),1)
+        x = torch.cat((self.conv3_1(x), x),1)
+        x = torch.cat((self.conv3_2(x), x),1)
+        x = torch.cat((self.conv3_3(x), x),1)
+        x = torch.cat((self.conv3_4(x), x),1)
+        flow3 = self.predict_flow3(x)
+        up_flow3 = self.deconv3(flow3)
+        up_feat3 = self.upfeat3(x)
+
+
+        warp2 = self.warp(c22, up_flow3*5.0) 
+        corr2 = self.corr(c12, warp2)
+        corr2 = self.leakyRELU(corr2)
+        x = torch.cat((corr2, c12, up_flow3, up_feat3), 1)
+        x = torch.cat((self.conv2_0(x), x),1)
+        x = torch.cat((self.conv2_1(x), x),1)
+        x = torch.cat((self.conv2_2(x), x),1)
+        x = torch.cat((self.conv2_3(x), x),1)
+        x = torch.cat((self.conv2_4(x), x),1)
+        flow2 = self.predict_flow2(x)
+ 
+        x = self.dc_conv4(self.dc_conv3(self.dc_conv2(self.dc_conv1(x))))
+        flow2 = flow2 + self.dc_conv7(self.dc_conv6(self.dc_conv5(x)))
+        
+        if self.training:
+            return flow2, flow3, flow4, flow5
+        else:
+            return flow2
+
+    
+class PWCDCNet(nn.Module):
+    """
+    PWC-DC net with dilated convolutions and DenseNet connections.
+
+    """
+    def __init__(self, md=4):
+        """
+        input: md --- maximum displacement for the correlation layer (default: 4), applied after warping
+
+        """
+        super(PWCDCNet,self).__init__()
+
+        self.conv1a  = conv(1,   16, kernel_size=3, stride=2)
+        self.conv1aa = conv(16,  16, kernel_size=3, stride=1)
+        self.conv1b  = conv(16,  16, kernel_size=3, stride=1)
+        self.conv2a  = conv(16,  32, kernel_size=3, stride=2)
+        self.conv2aa = conv(32,  32, kernel_size=3, stride=1)
+        self.conv2b  = conv(32,  32, kernel_size=3, stride=1)
+        self.conv3a  = conv(32,  64, kernel_size=3, stride=2)
+        self.conv3aa = conv(64,  64, kernel_size=3, stride=1)
+        self.conv3b  = conv(64,  64, kernel_size=3, stride=1)
+        self.conv4a  = conv(64,  96, kernel_size=3, stride=2)
+        self.conv4aa = conv(96,  96, kernel_size=3, stride=1)
+        self.conv4b  = conv(96,  96, kernel_size=3, stride=1)
+        self.conv5a  = conv(96, 128, kernel_size=3, stride=2)
+        self.conv5aa = conv(128,128, kernel_size=3, stride=1)
+        self.conv5b  = conv(128,128, kernel_size=3, stride=1)
+        self.conv6aa = conv(128,196, kernel_size=3, stride=2)
+        self.conv6a  = conv(196,196, kernel_size=3, stride=1)
+        self.conv6b  = conv(196,196, kernel_size=3, stride=1)
+
+        # self.corr    = Correlation(pad_size=md, kernel_size=1, max_displacement=md, stride1=1, stride2=1, corr_multiply=1)
+                    
+        self.corr = CorrelationLayerSimple(md=md)
+
+        self.leakyRELU = nn.LeakyReLU(0.1)
+        
+        nd = (2*md+1)**2
+        dd = np.cumsum([128,128,96,64,32])
+
+        od = nd
+        self.conv6_0 = conv(od,      128, kernel_size=3, stride=1)
+        self.conv6_1 = conv(od+dd[0],128, kernel_size=3, stride=1)
+        self.conv6_2 = conv(od+dd[1],96,  kernel_size=3, stride=1)
+        self.conv6_3 = conv(od+dd[2],64,  kernel_size=3, stride=1)
+        self.conv6_4 = conv(od+dd[3],32,  kernel_size=3, stride=1)        
+        self.predict_flow6 = predict_flow(od+dd[4])
+        self.deconv6 = deconv(2, 2, kernel_size=4, stride=2, padding=1) 
+        self.upfeat6 = deconv(od+dd[4], 2, kernel_size=4, stride=2, padding=1) 
+        
+        od = nd+128+4
+        self.conv5_0 = conv(od,      128, kernel_size=3, stride=1)
+        self.conv5_1 = conv(od+dd[0],128, kernel_size=3, stride=1)
+        self.conv5_2 = conv(od+dd[1],96,  kernel_size=3, stride=1)
+        self.conv5_3 = conv(od+dd[2],64,  kernel_size=3, stride=1)
+        self.conv5_4 = conv(od+dd[3],32,  kernel_size=3, stride=1)
+        self.predict_flow5 = predict_flow(od+dd[4]) 
+        self.deconv5 = deconv(2, 2, kernel_size=4, stride=2, padding=1) 
+        self.upfeat5 = deconv(od+dd[4], 2, kernel_size=4, stride=2, padding=1) 
+        
+        od = nd+96+4
+        self.conv4_0 = conv(od,      128, kernel_size=3, stride=1)
+        self.conv4_1 = conv(od+dd[0],128, kernel_size=3, stride=1)
+        self.conv4_2 = conv(od+dd[1],96,  kernel_size=3, stride=1)
+        self.conv4_3 = conv(od+dd[2],64,  kernel_size=3, stride=1)
+        self.conv4_4 = conv(od+dd[3],32,  kernel_size=3, stride=1)
+        self.predict_flow4 = predict_flow(od+dd[4]) 
+        self.deconv4 = deconv(2, 2, kernel_size=4, stride=2, padding=1) 
+        self.upfeat4 = deconv(od+dd[4], 2, kernel_size=4, stride=2, padding=1) 
+        
+        od = nd+64+4
+        self.conv3_0 = conv(od,      128, kernel_size=3, stride=1)
+        self.conv3_1 = conv(od+dd[0],128, kernel_size=3, stride=1)
+        self.conv3_2 = conv(od+dd[1],96,  kernel_size=3, stride=1)
+        self.conv3_3 = conv(od+dd[2],64,  kernel_size=3, stride=1)
+        self.conv3_4 = conv(od+dd[3],32,  kernel_size=3, stride=1)
+        self.predict_flow3 = predict_flow(od+dd[4]) 
+        self.deconv3 = deconv(2, 2, kernel_size=4, stride=2, padding=1) 
+        self.upfeat3 = deconv(od+dd[4], 2, kernel_size=4, stride=2, padding=1) 
+        
+        od = nd+32+4
+        self.conv2_0 = conv(od,      128, kernel_size=3, stride=1)
+        self.conv2_1 = conv(od+dd[0],128, kernel_size=3, stride=1)
+        self.conv2_2 = conv(od+dd[1],96,  kernel_size=3, stride=1)
+        self.conv2_3 = conv(od+dd[2],64,  kernel_size=3, stride=1)
+        self.conv2_4 = conv(od+dd[3],32,  kernel_size=3, stride=1)
+        self.predict_flow2 = predict_flow(od+dd[4]) 
+        self.deconv2 = deconv(2, 2, kernel_size=4, stride=2, padding=1) 
+        
+        self.dc_conv1 = conv(od+dd[4], 128, kernel_size=3, stride=1, padding=1,  dilation=1)
+        self.dc_conv2 = conv(128,      128, kernel_size=3, stride=1, padding=2,  dilation=2)
+        self.dc_conv3 = conv(128,      128, kernel_size=3, stride=1, padding=4,  dilation=4)
+        self.dc_conv4 = conv(128,      96,  kernel_size=3, stride=1, padding=8,  dilation=8)
+        self.dc_conv5 = conv(96,       64,  kernel_size=3, stride=1, padding=16, dilation=16)
+        self.dc_conv6 = conv(64,       32,  kernel_size=3, stride=1, padding=1,  dilation=1)
+        self.dc_conv7 = predict_flow(32)
+
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
+                nn.init.kaiming_normal_(m.weight.data, mode='fan_in')
+                if m.bias is not None:
+                    m.bias.data.zero_()
+
+
+    def warp(self, x, flo):
+        """
+        warp an image/tensor (im2) back to im1, according to the optical flow
+
+        x: [B, C, H, W] (im2)
+        flo: [B, 2, H, W] flow
+
+        """
+        B, C, H, W = x.size()
+        # mesh grid 
+        xx = torch.arange(0, W).view(1,-1).repeat(H,1)
+        yy = torch.arange(0, H).view(-1,1).repeat(1,W)
+        xx = xx.view(1,1,H,W).repeat(B,1,1,1)
+        yy = yy.view(1,1,H,W).repeat(B,1,1,1)
+        grid = torch.cat((xx,yy),1).float()
+
+        if x.is_cuda:
+            grid = grid.cuda(x.get_device())
+        vgrid = Variable(grid) + flo
+
+        # scale grid to [-1,1] 
+        vgrid[:,0,:,:] = 2.0*vgrid[:,0,:,:].clone() / max(W-1,1)-1.0
+        vgrid[:,1,:,:] = 2.0*vgrid[:,1,:,:].clone() / max(H-1,1)-1.0
+
+        vgrid = vgrid.permute(0,2,3,1)        
+        output = nn.functional.grid_sample(x, vgrid)
+        mask = torch.ones(x.size(), device=x.device)
+        mask = nn.functional.grid_sample(mask, vgrid)
+
+        # if W==128:
+            # np.save('mask.npy', mask.cpu().data.numpy())
+            # np.save('warp.npy', output.cpu().data.numpy())
+        
+        mask[mask<0.9999] = 0
+        mask[mask>0] = 1
+        
+        return output #*mask
+
+
+    #def forward(self, im1, im2):
+    def forward(self, x):
+        N, C, H, W = x.shape
+        im1 = x[:,:C//2,:,:]
+        im2 = x[:,C//2:,:,:]
+        
+        c11 = self.conv1b(self.conv1aa(self.conv1a(im1)))
+        c21 = self.conv1b(self.conv1aa(self.conv1a(im2)))
+        c12 = self.conv2b(self.conv2aa(self.conv2a(c11)))
+        c22 = self.conv2b(self.conv2aa(self.conv2a(c21)))
+        c13 = self.conv3b(self.conv3aa(self.conv3a(c12)))
+        c23 = self.conv3b(self.conv3aa(self.conv3a(c22)))
+        c14 = self.conv4b(self.conv4aa(self.conv4a(c13)))
+        c24 = self.conv4b(self.conv4aa(self.conv4a(c23)))
+        c15 = self.conv5b(self.conv5aa(self.conv5a(c14)))
+        c25 = self.conv5b(self.conv5aa(self.conv5a(c24)))
+        c16 = self.conv6b(self.conv6a(self.conv6aa(c15)))
+        c26 = self.conv6b(self.conv6a(self.conv6aa(c25)))
+
+
+        corr6 = self.corr(c16, c26) 
+        corr6 = self.leakyRELU(corr6)   
+
+
+        x = torch.cat((self.conv6_0(corr6), corr6),1)
+        x = torch.cat((self.conv6_1(x), x),1)
+        x = torch.cat((self.conv6_2(x), x),1)
+        x = torch.cat((self.conv6_3(x), x),1)
+        x = torch.cat((self.conv6_4(x), x),1)
+        flow6 = self.predict_flow6(x)
+        up_flow6 = self.deconv6(flow6)
+        up_feat6 = self.upfeat6(x)
+
+        
+        warp5 = self.warp(c25, up_flow6*0.625)
+        corr5 = self.corr(c15, warp5) 
+        corr5 = self.leakyRELU(corr5)
+        x = torch.cat((corr5, c15, up_flow6, up_feat6), 1)
+        x = torch.cat((self.conv5_0(x), x),1)
+        x = torch.cat((self.conv5_1(x), x),1)
+        x = torch.cat((self.conv5_2(x), x),1)
+        x = torch.cat((self.conv5_3(x), x),1)
+        x = torch.cat((self.conv5_4(x), x),1)
+        flow5 = self.predict_flow5(x)
+        up_flow5 = self.deconv5(flow5)
+        up_feat5 = self.upfeat5(x)
+
+        warp4 = self.warp(c24, up_flow5*1.25)
+        corr4 = self.corr(c14, warp4)  
+        corr4 = self.leakyRELU(corr4)
+        x = torch.cat((corr4, c14, up_flow5, up_feat5), 1)
+        x = torch.cat((self.conv4_0(x), x),1)
+        x = torch.cat((self.conv4_1(x), x),1)
+        x = torch.cat((self.conv4_2(x), x),1)
+        x = torch.cat((self.conv4_3(x), x),1)
+        x = torch.cat((self.conv4_4(x), x),1)
+        flow4 = self.predict_flow4(x)
+        up_flow4 = self.deconv4(flow4)
+        up_feat4 = self.upfeat4(x)
+
+
+        warp3 = self.warp(c23, up_flow4*2.5)
+        corr3 = self.corr(c13, warp3) 
+        corr3 = self.leakyRELU(corr3)
+        
+
+        x = torch.cat((corr3, c13, up_flow4, up_feat4), 1)
+        x = torch.cat((self.conv3_0(x), x),1)
+        x = torch.cat((self.conv3_1(x), x),1)
+        x = torch.cat((self.conv3_2(x), x),1)
+        x = torch.cat((self.conv3_3(x), x),1)
+        x = torch.cat((self.conv3_4(x), x),1)
+        flow3 = self.predict_flow3(x)
+        up_flow3 = self.deconv3(flow3)
+        up_feat3 = self.upfeat3(x)
+
+
+        warp2 = self.warp(c22, up_flow3*5.0) 
+        corr2 = self.corr(c12, warp2)
+        corr2 = self.leakyRELU(corr2)
+        x = torch.cat((corr2, c12, up_flow3, up_feat3), 1)
+        x = torch.cat((self.conv2_0(x), x),1)
+        x = torch.cat((self.conv2_1(x), x),1)
+        x = torch.cat((self.conv2_2(x), x),1)
+        x = torch.cat((self.conv2_3(x), x),1)
+        x = torch.cat((self.conv2_4(x), x),1)
+        flow2 = self.predict_flow2(x)
+ 
+        x = self.dc_conv4(self.dc_conv3(self.dc_conv2(self.dc_conv1(x))))
+        flow2 = flow2 + self.dc_conv7(self.dc_conv6(self.dc_conv5(x)))
+        
+        if self.training:
+            return flow2,flow3,flow4,flow5,flow6
+        else:
+            return flow2
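+
+# A minimal forward-pass sketch for PWCDCNet (illustrative shapes; the frame
+# pair is stacked along channels, so C must be even and H, W divisible by 64):
+#
+#   net = PWCDCNet(md=4).eval()
+#   pair = torch.randn(1, 2, 448, 448)   # one channel per frame
+#   with torch.no_grad():
+#       flow2 = net(pair)                # quarter-resolution flow, (1, 2, 112, 112)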
+
+
+class PWCDCNetFullRes(nn.Module):
+    """
+    PWC-DC net with dilated convolutions and DenseNet connections.
+
+    """
+    def __init__(self, md=4):
+        """
+        input: md --- maximum displacement for the correlation layer (default: 4), applied after warping
+
+        """
+        super(PWCDCNetFullRes,self).__init__()
+
+        #self.conv1a  = conv(1,   16, kernel_size=3, stride=2)
+        self.conv1a  = conv(1,   16, kernel_size=3, stride=1)
+        self.conv1aa = conv(16,  16, kernel_size=3, stride=1)
+        self.conv1b  = conv(16,  16, kernel_size=3, stride=1)
+        self.conv2a  = conv(16,  32, kernel_size=3, stride=2)
+        self.conv2aa = conv(32,  32, kernel_size=3, stride=1)
+        self.conv2b  = conv(32,  32, kernel_size=3, stride=1)
+        self.conv3a  = conv(32,  64, kernel_size=3, stride=2)
+        self.conv3aa = conv(64,  64, kernel_size=3, stride=1)
+        self.conv3b  = conv(64,  64, kernel_size=3, stride=1)
+        self.conv4a  = conv(64,  96, kernel_size=3, stride=2)
+        self.conv4aa = conv(96,  96, kernel_size=3, stride=1)
+        self.conv4b  = conv(96,  96, kernel_size=3, stride=1)
+        self.conv5a  = conv(96, 128, kernel_size=3, stride=2)
+        self.conv5aa = conv(128,128, kernel_size=3, stride=1)
+        self.conv5b  = conv(128,128, kernel_size=3, stride=1)
+        self.conv6aa = conv(128,196, kernel_size=3, stride=2)
+        self.conv6a  = conv(196,196, kernel_size=3, stride=1)
+        self.conv6b  = conv(196,196, kernel_size=3, stride=1)
+
+        # self.corr    = Correlation(pad_size=md, kernel_size=1, max_displacement=md, stride1=1, stride2=1, corr_multiply=1)
+                    
+        self.corr = CorrelationLayerSimple(md=md)
+
+        self.leakyRELU = nn.LeakyReLU(0.1)
+        
+        nd = (2*md+1)**2
+        dd = np.cumsum([128,128,96,64,32])
+
+        od = nd
+        self.conv6_0 = conv(od,      128, kernel_size=3, stride=1)
+        self.conv6_1 = conv(od+dd[0],128, kernel_size=3, stride=1)
+        self.conv6_2 = conv(od+dd[1],96,  kernel_size=3, stride=1)
+        self.conv6_3 = conv(od+dd[2],64,  kernel_size=3, stride=1)
+        self.conv6_4 = conv(od+dd[3],32,  kernel_size=3, stride=1)        
+        self.predict_flow6 = predict_flow(od+dd[4])
+        self.deconv6 = deconv(2, 2, kernel_size=4, stride=2, padding=1) 
+        self.upfeat6 = deconv(od+dd[4], 2, kernel_size=4, stride=2, padding=1) 
+        
+        od = nd+128+4
+        self.conv5_0 = conv(od,      128, kernel_size=3, stride=1)
+        self.conv5_1 = conv(od+dd[0],128, kernel_size=3, stride=1)
+        self.conv5_2 = conv(od+dd[1],96,  kernel_size=3, stride=1)
+        self.conv5_3 = conv(od+dd[2],64,  kernel_size=3, stride=1)
+        self.conv5_4 = conv(od+dd[3],32,  kernel_size=3, stride=1)
+        self.predict_flow5 = predict_flow(od+dd[4]) 
+        self.deconv5 = deconv(2, 2, kernel_size=4, stride=2, padding=1) 
+        self.upfeat5 = deconv(od+dd[4], 2, kernel_size=4, stride=2, padding=1) 
+        
+        od = nd+96+4
+        self.conv4_0 = conv(od,      128, kernel_size=3, stride=1)
+        self.conv4_1 = conv(od+dd[0],128, kernel_size=3, stride=1)
+        self.conv4_2 = conv(od+dd[1],96,  kernel_size=3, stride=1)
+        self.conv4_3 = conv(od+dd[2],64,  kernel_size=3, stride=1)
+        self.conv4_4 = conv(od+dd[3],32,  kernel_size=3, stride=1)
+        self.predict_flow4 = predict_flow(od+dd[4]) 
+        self.deconv4 = deconv(2, 2, kernel_size=4, stride=2, padding=1) 
+        self.upfeat4 = deconv(od+dd[4], 2, kernel_size=4, stride=2, padding=1) 
+        
+        od = nd+64+4
+        self.conv3_0 = conv(od,      128, kernel_size=3, stride=1)
+        self.conv3_1 = conv(od+dd[0],128, kernel_size=3, stride=1)
+        self.conv3_2 = conv(od+dd[1],96,  kernel_size=3, stride=1)
+        self.conv3_3 = conv(od+dd[2],64,  kernel_size=3, stride=1)
+        self.conv3_4 = conv(od+dd[3],32,  kernel_size=3, stride=1)
+        self.predict_flow3 = predict_flow(od+dd[4]) 
+        self.deconv3 = deconv(2, 2, kernel_size=4, stride=2, padding=1) 
+        self.upfeat3 = deconv(od+dd[4], 2, kernel_size=4, stride=2, padding=1) 
+        
+        od = nd+32+4
+        self.conv2_0 = conv(od,      128, kernel_size=3, stride=1)
+        self.conv2_1 = conv(od+dd[0],128, kernel_size=3, stride=1)
+        self.conv2_2 = conv(od+dd[1],96,  kernel_size=3, stride=1)
+        self.conv2_3 = conv(od+dd[2],64,  kernel_size=3, stride=1)
+        self.conv2_4 = conv(od+dd[3],32,  kernel_size=3, stride=1)
+        self.predict_flow2 = predict_flow(od+dd[4]) 
+        self.deconv2 = deconv(2, 2, kernel_size=4, stride=2, padding=1) 
+        self.upfeat2 = deconv(od+dd[4], 2, kernel_size=4, stride=2, padding=1) 
+        
+        od = nd+16+4
+        self.conv1_0 = conv(od,      128, kernel_size=3, stride=1)
+        self.conv1_1 = conv(od+dd[0],128, kernel_size=3, stride=1)
+        self.conv1_2 = conv(od+dd[1],96,  kernel_size=3, stride=1)
+        self.conv1_3 = conv(od+dd[2],64,  kernel_size=3, stride=1)
+        self.conv1_4 = conv(od+dd[3],32,  kernel_size=3, stride=1)
+        self.predict_flow1 = predict_flow(od+dd[4]) 
+        self.deconv1 = deconv(2, 2, kernel_size=4, stride=2, padding=1) 
+        self.upfeat1 = deconv(od+dd[4], 2, kernel_size=4, stride=2, padding=1)     
+        
+
+        self.dc_conv1 = conv(od+dd[4], 128, kernel_size=3, stride=1, padding=1,  dilation=1)
+        self.dc_conv2 = conv(128,      128, kernel_size=3, stride=1, padding=2,  dilation=2)
+        self.dc_conv3 = conv(128,      128, kernel_size=3, stride=1, padding=4,  dilation=4)
+        self.dc_conv4 = conv(128,      96,  kernel_size=3, stride=1, padding=8,  dilation=8)
+        self.dc_conv5 = conv(96,       64,  kernel_size=3, stride=1, padding=16, dilation=16)
+        self.dc_conv6 = conv(64,       32,  kernel_size=3, stride=1, padding=1,  dilation=1)
+        self.dc_conv7 = predict_flow(32)
+
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
+                nn.init.kaiming_normal_(m.weight.data, mode='fan_in')
+                if m.bias is not None:
+                    m.bias.data.zero_()
+
+
+    def warp(self, x, flo):
+        """
+        warp an image/tensor (im2) back to im1, according to the optical flow
+
+        x: [B, C, H, W] (im2)
+        flo: [B, 2, H, W] flow
+
+        """
+        B, C, H, W = x.size()
+        # mesh grid 
+        xx = torch.arange(0, W).view(1,-1).repeat(H,1)
+        yy = torch.arange(0, H).view(-1,1).repeat(1,W)
+        xx = xx.view(1,1,H,W).repeat(B,1,1,1)
+        yy = yy.view(1,1,H,W).repeat(B,1,1,1)
+        grid = torch.cat((xx,yy),1).float()
+
+        if x.is_cuda:
+            grid = grid.cuda(x.get_device())
+        vgrid = Variable(grid) + flo
+
+        # scale grid to [-1,1] 
+        vgrid[:,0,:,:] = 2.0*vgrid[:,0,:,:].clone() / max(W-1,1)-1.0
+        vgrid[:,1,:,:] = 2.0*vgrid[:,1,:,:].clone() / max(H-1,1)-1.0
+
+        vgrid = vgrid.permute(0,2,3,1)        
+        output = nn.functional.grid_sample(x, vgrid)
+        mask = torch.ones(x.size(), device=x.device)
+        mask = nn.functional.grid_sample(mask, vgrid)
+
+        # if W==128:
+            # np.save('mask.npy', mask.cpu().data.numpy())
+            # np.save('warp.npy', output.cpu().data.numpy())
+        
+        mask[mask<0.9999] = 0
+        mask[mask>0] = 1
+        
+        return output #*mask
+
+
+    def forward(self, im1, im2):
+        #im1 = x[:,:3,:,:]
+        #im2 = x[:,3:,:,:]
+        
+        c11 = self.conv1b(self.conv1aa(self.conv1a(im1)))
+        c21 = self.conv1b(self.conv1aa(self.conv1a(im2)))
+        c12 = self.conv2b(self.conv2aa(self.conv2a(c11)))
+        c22 = self.conv2b(self.conv2aa(self.conv2a(c21)))
+        c13 = self.conv3b(self.conv3aa(self.conv3a(c12)))
+        c23 = self.conv3b(self.conv3aa(self.conv3a(c22)))
+        c14 = self.conv4b(self.conv4aa(self.conv4a(c13)))
+        c24 = self.conv4b(self.conv4aa(self.conv4a(c23)))
+        c15 = self.conv5b(self.conv5aa(self.conv5a(c14)))
+        c25 = self.conv5b(self.conv5aa(self.conv5a(c24)))
+        c16 = self.conv6b(self.conv6a(self.conv6aa(c15)))
+        c26 = self.conv6b(self.conv6a(self.conv6aa(c25)))
+
+
+        corr6 = self.corr(c16, c26) 
+        corr6 = self.leakyRELU(corr6)   
+
+
+        x = torch.cat((self.conv6_0(corr6), corr6),1)
+        x = torch.cat((self.conv6_1(x), x),1)
+        x = torch.cat((self.conv6_2(x), x),1)
+        x = torch.cat((self.conv6_3(x), x),1)
+        x = torch.cat((self.conv6_4(x), x),1)
+        flow6 = self.predict_flow6(x)
+        up_flow6 = self.deconv6(flow6)
+        up_feat6 = self.upfeat6(x)
+
+        
+        warp5 = self.warp(c25, up_flow6*0.625)
+        corr5 = self.corr(c15, warp5) 
+        corr5 = self.leakyRELU(corr5)
+        x = torch.cat((corr5, c15, up_flow6, up_feat6), 1)
+        x = torch.cat((self.conv5_0(x), x),1)
+        x = torch.cat((self.conv5_1(x), x),1)
+        x = torch.cat((self.conv5_2(x), x),1)
+        x = torch.cat((self.conv5_3(x), x),1)
+        x = torch.cat((self.conv5_4(x), x),1)
+        flow5 = self.predict_flow5(x)
+        up_flow5 = self.deconv5(flow5)
+        up_feat5 = self.upfeat5(x)
+
+        warp4 = self.warp(c24, up_flow5*1.25)
+        corr4 = self.corr(c14, warp4)  
+        corr4 = self.leakyRELU(corr4)
+        x = torch.cat((corr4, c14, up_flow5, up_feat5), 1)
+        x = torch.cat((self.conv4_0(x), x),1)
+        x = torch.cat((self.conv4_1(x), x),1)
+        x = torch.cat((self.conv4_2(x), x),1)
+        x = torch.cat((self.conv4_3(x), x),1)
+        x = torch.cat((self.conv4_4(x), x),1)
+        flow4 = self.predict_flow4(x)
+        up_flow4 = self.deconv4(flow4)
+        up_feat4 = self.upfeat4(x)
+
+
+        warp3 = self.warp(c23, up_flow4*2.5)
+        corr3 = self.corr(c13, warp3) 
+        corr3 = self.leakyRELU(corr3)
+        x = torch.cat((corr3, c13, up_flow4, up_feat4), 1)
+        x = torch.cat((self.conv3_0(x), x),1)
+        x = torch.cat((self.conv3_1(x), x),1)
+        x = torch.cat((self.conv3_2(x), x),1)
+        x = torch.cat((self.conv3_3(x), x),1)
+        x = torch.cat((self.conv3_4(x), x),1)
+        flow3 = self.predict_flow3(x)
+        up_flow3 = self.deconv3(flow3)
+        up_feat3 = self.upfeat3(x)
+
+
+        warp2 = self.warp(c22, up_flow3*5.0) 
+        corr2 = self.corr(c12, warp2)
+        corr2 = self.leakyRELU(corr2)
+        x = torch.cat((corr2, c12, up_flow3, up_feat3), 1)
+        x = torch.cat((self.conv2_0(x), x),1)
+        x = torch.cat((self.conv2_1(x), x),1)
+        x = torch.cat((self.conv2_2(x), x),1)
+        x = torch.cat((self.conv2_3(x), x),1)
+        x = torch.cat((self.conv2_4(x), x),1)
+        flow2 = self.predict_flow2(x)
+        up_flow2 = self.deconv2(flow2)
+        up_feat2 = self.upfeat2(x)
+        
+        warp1 = self.warp(c21, up_flow2*10.0) 
+        corr1 = self.corr(c11, warp1)
+        corr1 = self.leakyRELU(corr1)
+        x = torch.cat((corr1, c11, up_flow2, up_feat2), 1)
+        x = torch.cat((self.conv1_0(x), x),1)
+        x = torch.cat((self.conv1_1(x), x),1)
+        x = torch.cat((self.conv1_2(x), x),1)
+        x = torch.cat((self.conv1_3(x), x),1)
+        x = torch.cat((self.conv1_4(x), x),1)
+        flow1 = self.predict_flow1(x)
+ 
+        x = self.dc_conv4(self.dc_conv3(self.dc_conv2(self.dc_conv1(x))))
+        flow1 = flow1 + self.dc_conv7(self.dc_conv6(self.dc_conv5(x)))
+        
+        if self.training:
+            return flow1,flow2,flow3,flow4,flow5,flow6
+        else:
+            return flow1
+
+        
+class PWCNetSmallFullRes(nn.Module):
+    """
+    PWC-DC net with dilated convolutions and DenseNet connections.
+
+    """
+    def __init__(self, md=4):
+        """
+        input: md --- maximum displacement for the correlation layer (default: 4), applied after warping
+
+        """
+        super(PWCNetSmallFullRes,self).__init__()
+
+        #self.conv1a  = conv(1,   16, kernel_size=3, stride=2)
+        self.conv1a  = conv(2,   16, kernel_size=3, stride=1)
+        self.conv1aa = conv(16,  16, kernel_size=3, stride=1)
+        self.conv1b  = conv(16,  16, kernel_size=3, stride=1)
+        self.conv2a  = conv(16,  32, kernel_size=3, stride=2)
+        self.conv2aa = conv(32,  32, kernel_size=3, stride=1)
+        self.conv2b  = conv(32,  32, kernel_size=3, stride=1)
+        self.conv3a  = conv(32,  64, kernel_size=3, stride=2)
+        self.conv3aa = conv(64,  64, kernel_size=3, stride=1)
+        self.conv3b  = conv(64,  64, kernel_size=3, stride=1)
+        
+        # self.corr    = Correlation(pad_size=md, kernel_size=1, max_displacement=md, stride1=1, stride2=1, corr_multiply=1)
+        # self.corr = CorrelationLayerSimple(md=md)
+
+        self.leakyRELU = nn.LeakyReLU(0.1)
+        
+        nd = 64 #(2*md+1)**2
+        dd = np.cumsum([64,32])
+
+        od = nd
+        self.conv3_0 = conv(od,       64, kernel_size=3, stride=1)
+        self.conv3_1 = conv(od+dd[0], 32, kernel_size=3, stride=1)
+        self.predict_flow3 = predict_flow(od+dd[1]) 
+        self.deconv3 = deconv(2, 2, kernel_size=4, stride=2, padding=1) 
+        self.upfeat3 = deconv(od+dd[1], 2, kernel_size=4, stride=2, padding=1)
+        
+        od = nd+32+4
+        self.conv2_0 = conv(od,      64, kernel_size=3, stride=1)
+        self.conv2_1 = conv(od+dd[0],32, kernel_size=3, stride=1)
+        self.predict_flow2 = predict_flow(od+dd[1]) 
+        self.deconv2 = deconv(2, 2, kernel_size=4, stride=2, padding=1) 
+        self.upfeat2 = deconv(od+dd[1], 2, kernel_size=4, stride=2, padding=1) 
+        
+        od = nd+16+4
+        self.conv1_0 = conv(od,      64, kernel_size=3, stride=1)
+        self.conv1_1 = conv(od+dd[0], 32, kernel_size=3, stride=1)
+        self.predict_flow1 = predict_flow(od+dd[1]) 
+        self.deconv1 = deconv(2, 2, kernel_size=4, stride=2, padding=1) 
+        self.upfeat1 = deconv(od+dd[1], 2, kernel_size=4, stride=2, padding=1)     
+        
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
+                nn.init.kaiming_normal_(m.weight.data, mode='fan_in')
+                if m.bias is not None:
+                    m.bias.data.zero_()
+
+    def warp(self, x, flo):
+        """
+        warp an image/tensor (im2) back to im1, according to the optical flow
+
+        x: [B, C, H, W] (im2)
+        flo: [B, 2, H, W] flow
+
+        """
+        B, C, H, W = x.size()
+        # mesh grid 
+        xx = torch.arange(0, W).view(1,-1).repeat(H,1)
+        yy = torch.arange(0, H).view(-1,1).repeat(1,W)
+        xx = xx.view(1,1,H,W).repeat(B,1,1,1)
+        yy = yy.view(1,1,H,W).repeat(B,1,1,1)
+        grid = torch.cat((xx,yy),1).float()
+
+        if x.is_cuda:
+            grid = grid.cuda(x.get_device())
+        vgrid = Variable(grid) + flo
+
+        # scale grid to [-1,1] 
+        vgrid[:,0,:,:] = 2.0*vgrid[:,0,:,:].clone() / max(W-1,1)-1.0
+        vgrid[:,1,:,:] = 2.0*vgrid[:,1,:,:].clone() / max(H-1,1)-1.0
+
+        vgrid = vgrid.permute(0,2,3,1)        
+        output = nn.functional.grid_sample(x, vgrid)
+        mask = torch.ones(x.size(), device=x.device)
+        mask = nn.functional.grid_sample(mask, vgrid)
+
+        # if W==128:
+            # np.save('mask.npy', mask.cpu().data.numpy())
+            # np.save('warp.npy', output.cpu().data.numpy())
+        
+        mask[mask<0.9999] = 0
+        mask[mask>0] = 1
+        
+        return output #*mask
+
+
+    def forward(self, im1, im2):
+        #im1 = x[:,:3,:,:]
+        #im2 = x[:,3:,:,:]
+        x = torch.cat([im1, im2], 1)
+        
+        c11 = self.conv1b(self.conv1aa(self.conv1a(x)))
+        c12 = self.conv2b(self.conv2aa(self.conv2a(c11)))
+        c13 = self.conv3b(self.conv3aa(self.conv3a(c12)))
+
+        x = torch.cat((self.conv3_0(c13), c13),1)
+        x = torch.cat((self.conv3_1(x), x),1)
+        flow3 = self.predict_flow3(x)
+        up_flow3 = self.deconv3(flow3)
+        up_feat3 = self.upfeat3(x)
+        
+        x = torch.cat((c12, up_flow3, up_feat3), 1)
+        x = torch.cat((self.conv2_0(x), x),1)
+        x = torch.cat((self.conv2_1(x), x),1)
+        flow2 = self.predict_flow2(x)
+        up_flow2 = self.deconv2(flow2)
+        up_feat2 = self.upfeat2(x)
+        
+        x = torch.cat((c11, up_flow2, up_feat2), 1)
+        x = torch.cat((self.conv1_0(x), x),1)
+        x = torch.cat((self.conv1_1(x), x),1)
+        flow1 = self.predict_flow1(x)
+        
+        if self.training:
+            return flow1,flow2,flow3
+        else:
+            return flow1
+        
+
+class PWCDCNet_old(nn.Module):
+    """
+    PWC-DC net with dilated convolutions and DenseNet connections.
+
+    """
+    def __init__(self, md=4):
+        """
+        input: md --- maximum displacement for the correlation layer (default: 4), applied after warping
+
+        """
+        super(PWCDCNet_old,self).__init__()
+
+        self.conv1a  = conv(3,   16, kernel_size=3, stride=2)
+        self.conv1b  = conv(16,  16, kernel_size=3, stride=1)
+        self.conv2a  = conv(16,  32, kernel_size=3, stride=2)
+        self.conv2b  = conv(32,  32, kernel_size=3, stride=1)
+        self.conv3a  = conv(32,  64, kernel_size=3, stride=2)
+        self.conv3b  = conv(64,  64, kernel_size=3, stride=1)
+        self.conv4a  = conv(64,  96, kernel_size=3, stride=2)
+        self.conv4b  = conv(96,  96, kernel_size=3, stride=1)
+        self.conv5a  = conv(96, 128, kernel_size=3, stride=2)
+        self.conv5b  = conv(128,128, kernel_size=3, stride=1)
+        self.conv6a  = conv(128,196, kernel_size=3, stride=2)
+        self.conv6b  = conv(196,196, kernel_size=3, stride=1)
+
+        # NOTE: requires the external `correlation_package` Correlation module,
+        # whose import is commented out at the top of this file
+        self.corr    = Correlation(pad_size=md, kernel_size=1, max_displacement=md, stride1=1, stride2=1, corr_multiply=1)
+        self.leakyRELU = nn.LeakyReLU(0.1)
+        
+        nd = (2*md+1)**2
+        dd = np.cumsum([128,128,96,64,32])
+
+        od = nd
+        self.conv6_0 = conv(od,      128, kernel_size=3, stride=1)
+        self.conv6_1 = conv(od+dd[0],128, kernel_size=3, stride=1)
+        self.conv6_2 = conv(od+dd[1],96,  kernel_size=3, stride=1)
+        self.conv6_3 = conv(od+dd[2],64,  kernel_size=3, stride=1)
+        self.conv6_4 = conv(od+dd[3],32,  kernel_size=3, stride=1)        
+        self.predict_flow6 = predict_flow(od+dd[4])
+        self.deconv6 = deconv(2, 2, kernel_size=4, stride=2, padding=1) 
+        self.upfeat6 = deconv(od+dd[4], 2, kernel_size=4, stride=2, padding=1) 
+        
+        od = nd+128+4
+        self.conv5_0 = conv(od,      128, kernel_size=3, stride=1)
+        self.conv5_1 = conv(od+dd[0],128, kernel_size=3, stride=1)
+        self.conv5_2 = conv(od+dd[1],96,  kernel_size=3, stride=1)
+        self.conv5_3 = conv(od+dd[2],64,  kernel_size=3, stride=1)
+        self.conv5_4 = conv(od+dd[3],32,  kernel_size=3, stride=1)
+        self.predict_flow5 = predict_flow(od+dd[4]) 
+        self.deconv5 = deconv(2, 2, kernel_size=4, stride=2, padding=1) 
+        self.upfeat5 = deconv(od+dd[4], 2, kernel_size=4, stride=2, padding=1) 
+        
+        od = nd+96+4
+        self.conv4_0 = conv(od,      128, kernel_size=3, stride=1)
+        self.conv4_1 = conv(od+dd[0],128, kernel_size=3, stride=1)
+        self.conv4_2 = conv(od+dd[1],96,  kernel_size=3, stride=1)
+        self.conv4_3 = conv(od+dd[2],64,  kernel_size=3, stride=1)
+        self.conv4_4 = conv(od+dd[3],32,  kernel_size=3, stride=1)
+        self.predict_flow4 = predict_flow(od+dd[4]) 
+        self.deconv4 = deconv(2, 2, kernel_size=4, stride=2, padding=1) 
+        self.upfeat4 = deconv(od+dd[4], 2, kernel_size=4, stride=2, padding=1) 
+        
+        od = nd+64+4
+        self.conv3_0 = conv(od,      128, kernel_size=3, stride=1)
+        self.conv3_1 = conv(od+dd[0],128, kernel_size=3, stride=1)
+        self.conv3_2 = conv(od+dd[1],96,  kernel_size=3, stride=1)
+        self.conv3_3 = conv(od+dd[2],64,  kernel_size=3, stride=1)
+        self.conv3_4 = conv(od+dd[3],32,  kernel_size=3, stride=1)
+        self.predict_flow3 = predict_flow(od+dd[4]) 
+        self.deconv3 = deconv(2, 2, kernel_size=4, stride=2, padding=1) 
+        self.upfeat3 = deconv(od+dd[4], 2, kernel_size=4, stride=2, padding=1) 
+        
+        od = nd+32+4
+        self.conv2_0 = conv(od,      128, kernel_size=3, stride=1)
+        self.conv2_1 = conv(od+dd[0],128, kernel_size=3, stride=1)
+        self.conv2_2 = conv(od+dd[1],96,  kernel_size=3, stride=1)
+        self.conv2_3 = conv(od+dd[2],64,  kernel_size=3, stride=1)
+        self.conv2_4 = conv(od+dd[3],32,  kernel_size=3, stride=1)
+        self.predict_flow2 = predict_flow(od+dd[4]) 
+        self.deconv2 = deconv(2, 2, kernel_size=4, stride=2, padding=1) 
+        
+        self.dc_conv1 = conv(od+dd[4], 128, kernel_size=3, stride=1, padding=1,  dilation=1)
+        self.dc_conv2 = conv(128,      128, kernel_size=3, stride=1, padding=2,  dilation=2)
+        self.dc_conv3 = conv(128,      128, kernel_size=3, stride=1, padding=4,  dilation=4)
+        self.dc_conv4 = conv(128,      96,  kernel_size=3, stride=1, padding=8,  dilation=8)
+        self.dc_conv5 = conv(96,       64,  kernel_size=3, stride=1, padding=16, dilation=16)
+        self.dc_conv6 = conv(64,       32,  kernel_size=3, stride=1, padding=1,  dilation=1)
+        self.dc_conv7 = predict_flow(32)
+
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
+                nn.init.kaiming_normal_(m.weight.data, mode='fan_in')
+                if m.bias is not None:
+                    m.bias.data.zero_()
+
+
+    def warp(self, x, flo):
+        """
+        warp an image/tensor (im2) back to im1, according to the optical flow
+
+        x: [B, C, H, W] (im2)
+        flo: [B, 2, H, W] flow
+
+        """
+        B, C, H, W = x.size()
+        # mesh grid 
+        xx = torch.arange(0, W).view(1,-1).repeat(H,1)
+        yy = torch.arange(0, H).view(-1,1).repeat(1,W)
+        xx = xx.view(1,1,H,W).repeat(B,1,1,1)
+        yy = yy.view(1,1,H,W).repeat(B,1,1,1)
+        grid = torch.cat((xx,yy),1).float()
+
+        if x.is_cuda:
+            grid = grid.cuda()
+        vgrid = Variable(grid) + flo
+
+        # scale grid to [-1,1] 
+        vgrid[:,0,:,:] = 2.0*vgrid[:,0,:,:].clone() / max(W-1,1)-1.0
+        vgrid[:,1,:,:] = 2.0*vgrid[:,1,:,:].clone() / max(H-1,1)-1.0
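+        # (editor's note) a pixel coordinate p in [0, W-1] maps to
+        # 2p/(W-1) - 1 in [-1, 1], the convention expected by grid_sample.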
+
+        vgrid = vgrid.permute(0,2,3,1)        
+        output = nn.functional.grid_sample(x, vgrid)
+        mask = torch.ones_like(x)  # valid-pixel mask, created on the same device as x
+        mask = nn.functional.grid_sample(mask, vgrid)
+        
+        mask[mask<0.999] = 0
+        mask[mask>0] = 1
+        
+        return output*mask
+
+
+    def forward(self,x):
+        im1 = x[:,:3,:,:]
+        im2 = x[:,3:,:,:]
+        
+        c11 = self.conv1b(self.conv1a(im1))
+        c21 = self.conv1b(self.conv1a(im2))
+        c12 = self.conv2b(self.conv2a(c11))
+        c22 = self.conv2b(self.conv2a(c21))
+        c13 = self.conv3b(self.conv3a(c12))
+        c23 = self.conv3b(self.conv3a(c22))
+        c14 = self.conv4b(self.conv4a(c13))
+        c24 = self.conv4b(self.conv4a(c23))        
+        c15 = self.conv5b(self.conv5a(c14))
+        c25 = self.conv5b(self.conv5a(c24))
+        c16 = self.conv6b(self.conv6a(c15))
+        c26 = self.conv6b(self.conv6a(c25))
+        
+        corr6 = self.corr(c16, c26) 
+        corr6 = self.leakyRELU(corr6)        
+        x = torch.cat((corr6, self.conv6_0(corr6)),1)
+        x = torch.cat((self.conv6_1(x), x),1)
+        x = torch.cat((x, self.conv6_2(x)),1)
+        x = torch.cat((x, self.conv6_3(x)),1)
+        x = torch.cat((x, self.conv6_4(x)),1)
+        flow6 = self.predict_flow6(x)
+        up_flow6 = self.deconv6(flow6)
+        up_feat6 = self.upfeat6(x)
+        
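+        # (editor's note) the warp scale factors below are div_flow / 2**level
+        # under the usual PWC-Net convention div_flow = 20:
+        # 20/32 = 0.625 (level 6), 20/16 = 1.25, 20/8 = 2.5, 20/4 = 5.0.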
+        warp5 = self.warp(c25, up_flow6*0.625)
+        corr5 = self.corr(c15, warp5) 
+        corr5 = self.leakyRELU(corr5)
+        x = torch.cat((corr5, c15, up_flow6, up_feat6), 1)
+        x = torch.cat((x, self.conv5_0(x)),1)
+        x = torch.cat((self.conv5_1(x), x),1)
+        x = torch.cat((x, self.conv5_2(x)),1)
+        x = torch.cat((x, self.conv5_3(x)),1)
+        x = torch.cat((x, self.conv5_4(x)),1)
+        flow5 = self.predict_flow5(x)
+        up_flow5 = self.deconv5(flow5)
+        up_feat5 = self.upfeat5(x)
+        
+        warp4 = self.warp(c24, up_flow5*1.25)
+        corr4 = self.corr(c14, warp4)  
+        corr4 = self.leakyRELU(corr4)
+        x = torch.cat((corr4, c14, up_flow5, up_feat5), 1)
+        x = torch.cat((x, self.conv4_0(x)),1)
+        x = torch.cat((self.conv4_1(x), x),1)
+        x = torch.cat((x, self.conv4_2(x)),1)
+        x = torch.cat((x, self.conv4_3(x)),1)
+        x = torch.cat((x, self.conv4_4(x)),1)
+        flow4 = self.predict_flow4(x)
+        up_flow4 = self.deconv4(flow4)
+        up_feat4 = self.upfeat4(x)
+
+        warp3 = self.warp(c23, up_flow4*2.5)
+        corr3 = self.corr(c13, warp3) 
+        corr3 = self.leakyRELU(corr3)
+        x = torch.cat((corr3, c13, up_flow4, up_feat4), 1)
+        x = torch.cat((x, self.conv3_0(x)),1)
+        x = torch.cat((self.conv3_1(x), x),1)
+        x = torch.cat((x, self.conv3_2(x)),1)
+        x = torch.cat((x, self.conv3_3(x)),1)
+        x = torch.cat((x, self.conv3_4(x)),1)
+        flow3 = self.predict_flow3(x)
+        up_flow3 = self.deconv3(flow3)
+        up_feat3 = self.upfeat3(x)
+        
+        warp2 = self.warp(c22, up_flow3*5.0) 
+        corr2 = self.corr(c12, warp2)
+        corr2 = self.leakyRELU(corr2)
+        x = torch.cat((corr2, c12, up_flow3, up_feat3), 1)
+        x = torch.cat((x, self.conv2_0(x)),1)
+        x = torch.cat((self.conv2_1(x), x),1)
+        x = torch.cat((x, self.conv2_2(x)),1)
+        x = torch.cat((x, self.conv2_3(x)),1)
+        x = torch.cat((x, self.conv2_4(x)),1)
+        flow2 = self.predict_flow2(x)
+ 
+        x = self.dc_conv4(self.dc_conv3(self.dc_conv2(self.dc_conv1(x))))
+        flow2 = flow2 + self.dc_conv7(self.dc_conv6(self.dc_conv5(x)))
+        
+        if self.training:
+            return flow2,flow3,flow4,flow5,flow6
+        else:
+            return flow2
+
+
+
+
+
+def pwc_dc_net(path=None):
+
+    model = PWCDCNet()
+    if path is not None:
+        data = torch.load(path)
+        if 'state_dict' in data.keys():
+            model.load_state_dict(data['state_dict'])
+        else:
+            model.load_state_dict(data)
+    return model
+
+
+
+
+def pwc_dc_net_old(path=None):
+
+    model = PWCDCNet_old()
+    if path is not None:
+        data = torch.load(path)
+        if 'state_dict' in data.keys():
+            model.load_state_dict(data['state_dict'])
+        else:
+            model.load_state_dict(data)
+    return model
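+
+
+if __name__ == '__main__':
+    # Editor's sketch (not part of the original file): shape bookkeeping for
+    # the loader above. PWCDCNet expects two RGB frames stacked channel-wise
+    # (6 channels) with spatial dims divisible by 64, and predicts flow at
+    # 1/4 input resolution. Running it needs the compiled correlation
+    # extension behind self.corr, so treat this as illustrative.
+    net = pwc_dc_net().eval()
+    frames = torch.randn(1, 6, 192, 192)
+    with torch.no_grad():
+        flow = net(frames)
+    print(flow.shape)  # expected: torch.Size([1, 2, 48, 48])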
diff --git a/windflow/networks/__init__.py b/windflow/networks/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b8aad493343bf538715a691bc64f6a825b5b19d7
--- /dev/null
+++ b/windflow/networks/__init__.py
@@ -0,0 +1,4 @@
+from .raft import raft
+from .raft.raft import RAFT
+# from .flownet import *
+# from .maskflownet import MaskFlownet
diff --git a/windflow/networks/flownet/FlowNet2.py b/windflow/networks/flownet/FlowNet2.py
new file mode 100755
index 0000000000000000000000000000000000000000..45ac9f8c9198d29b7d92f01c2006e8bb31c1996d
--- /dev/null
+++ b/windflow/networks/flownet/FlowNet2.py
@@ -0,0 +1,501 @@
+import torch
+import torch.nn as nn
+from torch.nn import init
+
+import math
+import numpy as np
+
+from ...external_packages.channelnorm_package.channelnorm import ChannelNorm
+from ...external_packages.resample2d_package.resample2d import Resample2d
+
+from . import FlowNetC, FlowNetS, FlowNetSD, FlowNetFusion
+
+from .submodules import *
+    
+
+'Parameter count = 162,518,834'
+
+class FlowNet2(nn.Module):
+
+    def __init__(self, input_channels=1, batchNorm=False, div_flow = 20., fp16=False):
+        super(FlowNet2,self).__init__()
+        self.batchNorm = batchNorm
+        self.div_flow = div_flow
+        self.rgb_max = 1 #args.rgb_max
+        self.input_channels = input_channels
+        
+        self.channelnorm = ChannelNorm()
+
+        # First Block (FlowNetC)
+        self.flownetc = FlowNetC.FlowNetC(input_channels=self.input_channels, 
+                                          batchNorm=self.batchNorm, fp16=fp16)
+        self.upsample1 = nn.Upsample(scale_factor=4, mode='bilinear')
+
+        if fp16:
+            self.resample1 = nn.Sequential(
+                            tofp32(), 
+                            Resample2d(),
+                            tofp16()) 
+        else:
+            self.resample1 = Resample2d()
+
+        # Block (FlowNetS1)
+        s_input_channels = input_channels*4 + 2
+        self.flownets_1 = FlowNetS.FlowNetS(input_channels=s_input_channels, 
+                                            batchNorm=self.batchNorm)
+        self.upsample2 = nn.Upsample(scale_factor=4, mode='bilinear')
+        if fp16:
+            self.resample2 = nn.Sequential(
+                            tofp32(), 
+                            Resample2d(),
+                            tofp16()) 
+        else:
+            self.resample2 = Resample2d()
+
+
+        # Block (FlowNetS2)
+        self.flownets_2 = FlowNetS.FlowNetS(input_channels=s_input_channels, 
+                                            batchNorm=self.batchNorm)
+
+        # Block (FlowNetSD)
+        self.flownets_d = FlowNetSD.FlowNetSD(input_channels=input_channels*2, 
+                                              batchNorm=self.batchNorm) 
+        self.upsample3 = nn.Upsample(scale_factor=4, mode='nearest') 
+        self.upsample4 = nn.Upsample(scale_factor=4, mode='nearest') 
+
+        if fp16:
+            self.resample3 = nn.Sequential(
+                            tofp32(), 
+                            Resample2d(),
+                            tofp16()) 
+        else:
+            self.resample3 = Resample2d()
+
+        if fp16:
+            self.resample4 = nn.Sequential(
+                            tofp32(), 
+                            Resample2d(),
+                            tofp16()) 
+        else:
+            self.resample4 = Resample2d()
+
+        # Block (FLowNetFusion)
+        fusion_input_channels = self.input_channels + 8
+        self.flownetfusion = FlowNetFusion.FlowNetFusion(input_channels=fusion_input_channels,
+                                                         batchNorm=self.batchNorm)
+
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                if m.bias is not None:
+                    init.uniform_(m.bias)
+                init.xavier_uniform_(m.weight)
+
+            if isinstance(m, nn.ConvTranspose2d):
+                if m.bias is not None:
+                    init.uniform_(m.bias)
+                init.xavier_uniform_(m.weight)
+                # init_deconv_bilinear(m.weight)
+
+    def init_deconv_bilinear(self, weight):
+        f_shape = weight.size()
+        height, width = f_shape[-2], f_shape[-1]
+        f = np.ceil(width/2.0)
+        c = (2 * f - 1 - f % 2) / (2.0 * f)
+        bilinear = np.zeros([height, width])
+        for x in range(width):
+            for y in range(height):
+                value = (1 - abs(x / f - c)) * (1 - abs(y / f - c))
+                bilinear[y, x] = value
+        min_dim = min(f_shape[0], f_shape[1])
+        weight.data.fill_(0.)
+        for i in range(min_dim):
+            weight.data[i,i,:,:] = torch.from_numpy(bilinear)
+        return 
+
+    def forward(self, x):
+        #rgb_mean = inputs.contiguous().view(inputs.size()[:2]+(-1,)).mean(dim=-1).view(inputs.size()[:2] + (1,1,1,))
+        
+        #x = inputs #(inputs - rgb_mean) / self.rgb_max
+        #x1 = x[:,:,0,:,:]
+        #x2 = x[:,:,1,:,:]
+        #x = torch.cat((x1,x2), dim = 1)
+
+        # flownetc
+        flownetc_flow2 = self.flownetc(x)[0]
+        flownetc_flow = self.upsample1(flownetc_flow2*self.div_flow)
+        
+        # warp img1 to img0; magnitude of diff between img0 and warped_img1
+        resampled_img1 = self.resample1(x[:,self.input_channels:,:,:], flownetc_flow)
+        diff_img0 = x[:,:self.input_channels,:,:] - resampled_img1 
+        norm_diff_img0 = self.channelnorm(diff_img0)
+
+        # concat img0, img1, img1->img0, flow, diff-mag ; 
+        concat1 = torch.cat((x, resampled_img1, flownetc_flow/self.div_flow, norm_diff_img0), dim=1)
+        
+        # flownets1
+        flownets1_flow2 = self.flownets_1(concat1)[0]
+        flownets1_flow = self.upsample2(flownets1_flow2*self.div_flow) 
+
+        # warp img1 to img0 using flownets1; magnitude of diff between img0 and warped_img1
+        resampled_img1 = self.resample2(x[:,self.input_channels:,:,:], flownets1_flow)
+        diff_img0 = x[:,:self.input_channels,:,:] - resampled_img1
+        norm_diff_img0 = self.channelnorm(diff_img0)
+
+        # concat img0, img1, img1->img0, flow, diff-mag
+        concat2 = torch.cat((x, resampled_img1, flownets1_flow/self.div_flow, norm_diff_img0), dim=1)
+
+        # flownets2
+        flownets2_flow2 = self.flownets_2(concat2)[0]
+        flownets2_flow = self.upsample4(flownets2_flow2 * self.div_flow)
+        norm_flownets2_flow = self.channelnorm(flownets2_flow)
+
+        diff_flownets2_flow = self.resample4(x[:,self.input_channels:,:,:], flownets2_flow)
+        # if not diff_flownets2_flow.volatile:
+        #     diff_flownets2_flow.register_hook(save_grad(self.args.grads, 'diff_flownets2_flow'))
+
+        diff_flownets2_img1 = self.channelnorm((x[:,:self.input_channels,:,:]-diff_flownets2_flow))
+        # if not diff_flownets2_img1.volatile:
+        #     diff_flownets2_img1.register_hook(save_grad(self.args.grads, 'diff_flownets2_img1'))
+
+        # flownetsd
+        flownetsd_flow2 = self.flownets_d(x)[0]
+        flownetsd_flow = self.upsample3(flownetsd_flow2 / self.div_flow)
+        norm_flownetsd_flow = self.channelnorm(flownetsd_flow)
+        
+        diff_flownetsd_flow = self.resample3(x[:,self.input_channels:,:,:], flownetsd_flow)
+        # if not diff_flownetsd_flow.volatile:
+        #     diff_flownetsd_flow.register_hook(save_grad(self.args.grads, 'diff_flownetsd_flow'))
+
+        diff_flownetsd_img1 = self.channelnorm((x[:,:self.input_channels,:,:]-diff_flownetsd_flow))
+        # if not diff_flownetsd_img1.volatile:
+        #     diff_flownetsd_img1.register_hook(save_grad(self.args.grads, 'diff_flownetsd_img1'))
+
+        # concat img1 flownetsd, flownets2, norm_flownetsd, norm_flownets2, diff_flownetsd_img1, diff_flownets2_img1
+        concat3 = torch.cat((x[:,:self.input_channels,:,:], flownetsd_flow, flownets2_flow, norm_flownetsd_flow, norm_flownets2_flow, diff_flownetsd_img1, diff_flownets2_img1), dim=1)
+        flownetfusion_flow = self.flownetfusion(concat3)
+
+        # if not flownetfusion_flow.volatile:
+        #     flownetfusion_flow.register_hook(save_grad(self.args.grads, 'flownetfusion_flow'))
+
+        return flownetfusion_flow, 
+
+class FlowNet2C(FlowNetC.FlowNetC):
+    def __init__(self, batchNorm=False, div_flow=20):
+        super(FlowNet2C,self).__init__(batchNorm=batchNorm, div_flow=div_flow)
+        self.rgb_max = 1 #args.rgb_max
+
+    def forward(self, inputs):
+        rgb_mean = inputs.contiguous().view(inputs.size()[:2]+(-1,)).mean(dim=-1).view(inputs.size()[:2] + (1,1,1,))
+        
+        x = inputs #(inputs - rgb_mean) / self.rgb_max
+        x1 = x[:,:,0,:,:]
+        x2 = x[:,:,1,:,:]
+
+        # FlownetC top input stream
+        out_conv1a = self.conv1(x1)
+        out_conv2a = self.conv2(out_conv1a)
+        out_conv3a = self.conv3(out_conv2a)
+
+        # FlownetC bottom input stream
+        out_conv1b = self.conv1(x2)
+        
+        out_conv2b = self.conv2(out_conv1b)
+        out_conv3b = self.conv3(out_conv2b)
+
+        # Merge streams
+        out_corr = self.corr(out_conv3a, out_conv3b) # False
+        out_corr = self.corr_activation(out_corr)
+
+        # Redirect top input stream and concatenate
+        out_conv_redir = self.conv_redir(out_conv3a)
+
+        in_conv3_1 = torch.cat((out_conv_redir, out_corr), 1)
+
+        # Merged conv layers
+        out_conv3_1 = self.conv3_1(in_conv3_1)
+
+        out_conv4 = self.conv4_1(self.conv4(out_conv3_1))
+
+        out_conv5 = self.conv5_1(self.conv5(out_conv4))
+        out_conv6 = self.conv6_1(self.conv6(out_conv5))
+
+        flow6       = self.predict_flow6(out_conv6)
+        flow6_up    = self.upsampled_flow6_to_5(flow6)
+        out_deconv5 = self.deconv5(out_conv6)
+
+        concat5 = torch.cat((out_conv5,out_deconv5,flow6_up),1)
+
+        flow5       = self.predict_flow5(concat5)
+        flow5_up    = self.upsampled_flow5_to_4(flow5)
+        out_deconv4 = self.deconv4(concat5)
+        concat4 = torch.cat((out_conv4,out_deconv4,flow5_up),1)
+
+        flow4       = self.predict_flow4(concat4)
+        flow4_up    = self.upsampled_flow4_to_3(flow4)
+        out_deconv3 = self.deconv3(concat4)
+        concat3 = torch.cat((out_conv3_1,out_deconv3,flow4_up),1)
+
+        flow3       = self.predict_flow3(concat3)
+        flow3_up    = self.upsampled_flow3_to_2(flow3)
+        out_deconv2 = self.deconv2(concat3)
+        concat2 = torch.cat((out_conv2a,out_deconv2,flow3_up),1)
+
+        flow2 = self.predict_flow2(concat2)
+
+        if self.training:
+            return flow2,flow3,flow4,flow5,flow6
+        else:
+            return self.upsample1(flow2*self.div_flow)
+
+class FlowNet2S(FlowNetS.FlowNetS):
+    def __init__(self, batchNorm=False, div_flow=20):
+        super(FlowNet2S,self).__init__(input_channels = 6, batchNorm=batchNorm)
+        self.rgb_max = 1 #args.rgb_max
+        self.div_flow = div_flow
+        
+    def forward(self, inputs):
+        rgb_mean = inputs.contiguous().view(inputs.size()[:2]+(-1,)).mean(dim=-1).view(inputs.size()[:2] + (1,1,1,))
+        x = (inputs - rgb_mean) / self.rgb_max
+        x = torch.cat( (x[:,:,0,:,:], x[:,:,1,:,:]), dim = 1)
+
+        out_conv1 = self.conv1(x)
+
+        out_conv2 = self.conv2(out_conv1)
+        out_conv3 = self.conv3_1(self.conv3(out_conv2))
+        out_conv4 = self.conv4_1(self.conv4(out_conv3))
+        out_conv5 = self.conv5_1(self.conv5(out_conv4))
+        out_conv6 = self.conv6_1(self.conv6(out_conv5))
+
+        flow6       = self.predict_flow6(out_conv6)
+        flow6_up    = self.upsampled_flow6_to_5(flow6)
+        out_deconv5 = self.deconv5(out_conv6)
+        
+        concat5 = torch.cat((out_conv5,out_deconv5,flow6_up),1)
+        flow5       = self.predict_flow5(concat5)
+        flow5_up    = self.upsampled_flow5_to_4(flow5)
+        out_deconv4 = self.deconv4(concat5)
+        
+        concat4 = torch.cat((out_conv4,out_deconv4,flow5_up),1)
+        flow4       = self.predict_flow4(concat4)
+        flow4_up    = self.upsampled_flow4_to_3(flow4)
+        out_deconv3 = self.deconv3(concat4)
+        
+        concat3 = torch.cat((out_conv3,out_deconv3,flow4_up),1)
+        flow3       = self.predict_flow3(concat3)
+        flow3_up    = self.upsampled_flow3_to_2(flow3)
+        out_deconv2 = self.deconv2(concat3)
+
+        concat2 = torch.cat((out_conv2,out_deconv2,flow3_up),1)
+        flow2 = self.predict_flow2(concat2)
+
+        if self.training:
+            return flow2,flow3,flow4,flow5,flow6
+        else:
+            return self.upsample1(flow2*self.div_flow)
+
+class FlowNet2SD(FlowNetSD.FlowNetSD):
+    def __init__(self, batchNorm=False, div_flow=20):
+        super(FlowNet2SD,self).__init__(batchNorm=batchNorm)
+        self.rgb_max = 1 #args.rgb_max
+        self.div_flow = div_flow
+
+    def forward(self, inputs):
+        rgb_mean = inputs.contiguous().view(inputs.size()[:2]+(-1,)).mean(dim=-1).view(inputs.size()[:2] + (1,1,1,))
+        x = inputs #(inputs - rgb_mean) / self.rgb_max
+        x = torch.cat( (x[:,:,0,:,:], x[:,:,1,:,:]), dim = 1)
+
+        out_conv0 = self.conv0(x)
+        out_conv1 = self.conv1_1(self.conv1(out_conv0))
+        out_conv2 = self.conv2_1(self.conv2(out_conv1))
+
+        out_conv3 = self.conv3_1(self.conv3(out_conv2))
+        out_conv4 = self.conv4_1(self.conv4(out_conv3))
+        out_conv5 = self.conv5_1(self.conv5(out_conv4))
+        out_conv6 = self.conv6_1(self.conv6(out_conv5))
+
+        flow6       = self.predict_flow6(out_conv6)
+        flow6_up    = self.upsampled_flow6_to_5(flow6)
+        out_deconv5 = self.deconv5(out_conv6)
+        
+        concat5 = torch.cat((out_conv5,out_deconv5,flow6_up),1)
+        out_interconv5 = self.inter_conv5(concat5)
+        flow5       = self.predict_flow5(out_interconv5)
+
+        flow5_up    = self.upsampled_flow5_to_4(flow5)
+        out_deconv4 = self.deconv4(concat5)
+        
+        concat4 = torch.cat((out_conv4,out_deconv4,flow5_up),1)
+        out_interconv4 = self.inter_conv4(concat4)
+        flow4       = self.predict_flow4(out_interconv4)
+        flow4_up    = self.upsampled_flow4_to_3(flow4)
+        out_deconv3 = self.deconv3(concat4)
+        
+        concat3 = torch.cat((out_conv3,out_deconv3,flow4_up),1)
+        out_interconv3 = self.inter_conv3(concat3)
+        flow3       = self.predict_flow3(out_interconv3)
+        flow3_up    = self.upsampled_flow3_to_2(flow3)
+        out_deconv2 = self.deconv2(concat3)
+
+        concat2 = torch.cat((out_conv2,out_deconv2,flow3_up),1)
+        out_interconv2 = self.inter_conv2(concat2)
+        flow2 = self.predict_flow2(out_interconv2)
+
+        if self.training:
+            return flow2,flow3,flow4,flow5,flow6
+        else:
+            return self.upsample1(flow2*self.div_flow)
+
+class FlowNet2CS(nn.Module):
+
+    def __init__(self, batchNorm=False, div_flow = 20., fp16=False):
+        super(FlowNet2CS,self).__init__()
+        self.batchNorm = batchNorm
+        self.div_flow = div_flow
+        self.rgb_max = 1 #args.rgb_max
+
+        self.channelnorm = ChannelNorm()
+
+        # First Block (FlowNetC)
+        self.flownetc = FlowNetC.FlowNetC(batchNorm=self.batchNorm, fp16=fp16)
+        self.upsample1 = nn.Upsample(scale_factor=4, mode='bilinear')
+
+        if fp16:
+            self.resample1 = nn.Sequential(
+                            tofp32(), 
+                            Resample2d(),
+                            tofp16()) 
+        else:
+            self.resample1 = Resample2d()
+
+        # Block (FlowNetS1)
+        self.flownets_1 = FlowNetS.FlowNetS(batchNorm=self.batchNorm)
+        self.upsample2 = nn.Upsample(scale_factor=4, mode='bilinear')
+
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                if m.bias is not None:
+                    init.uniform_(m.bias)
+                init.xavier_uniform_(m.weight)
+
+            if isinstance(m, nn.ConvTranspose2d):
+                if m.bias is not None:
+                    init.uniform_(m.bias)
+                init.xavier_uniform_(m.weight)
+                # init_deconv_bilinear(m.weight)
+
+    def forward(self, inputs):
+        rgb_mean = inputs.contiguous().view(inputs.size()[:2]+(-1,)).mean(dim=-1).view(inputs.size()[:2] + (1,1,1,))
+        
+        x = inputs #(inputs - rgb_mean) / self.rgb_max
+        x1 = x[:,:,0,:,:]
+        x2 = x[:,:,1,:,:]
+        x = torch.cat((x1,x2), dim = 1)
+
+        # flownetc
+        flownetc_flow2 = self.flownetc(x)[0]
+        flownetc_flow = self.upsample1(flownetc_flow2*self.div_flow)
+        
+        # warp img1 to img0; magnitude of diff between img0 and warped_img1
+        resampled_img1 = self.resample1(x[:,3:,:,:], flownetc_flow)
+        diff_img0 = x[:,:3,:,:] - resampled_img1 
+        norm_diff_img0 = self.channelnorm(diff_img0)
+
+        # concat img0, img1, img1->img0, flow, diff-mag ; 
+        concat1 = torch.cat((x, resampled_img1, flownetc_flow/self.div_flow, norm_diff_img0), dim=1)
+        
+        # flownets1
+        flownets1_flow2 = self.flownets_1(concat1)[0]
+        flownets1_flow = self.upsample2(flownets1_flow2*self.div_flow) 
+
+        return flownets1_flow
+
+class FlowNet2CSS(nn.Module):
+
+    def __init__(self, batchNorm=False, div_flow = 20., fp16=False):
+        super(FlowNet2CSS,self).__init__()
+        self.batchNorm = batchNorm
+        self.div_flow = div_flow
+        self.rgb_max = 1  #args.rgb_max
+
+        self.channelnorm = ChannelNorm()
+
+        # First Block (FlowNetC)
+        self.flownetc = FlowNetC.FlowNetC(batchNorm=self.batchNorm, fp16=fp16)
+        self.upsample1 = nn.Upsample(scale_factor=4, mode='bilinear')
+
+        if fp16:
+            self.resample1 = nn.Sequential(
+                            tofp32(), 
+                            Resample2d(),
+                            tofp16()) 
+        else:
+            self.resample1 = Resample2d()
+
+        # Block (FlowNetS1)
+        self.flownets_1 = FlowNetS.FlowNetS(batchNorm=self.batchNorm)
+        self.upsample2 = nn.Upsample(scale_factor=4, mode='bilinear')
+        if fp16:
+            self.resample2 = nn.Sequential(
+                            tofp32(), 
+                            Resample2d(),
+                            tofp16()) 
+        else:
+            self.resample2 = Resample2d()
+
+
+        # Block (FlowNetS2)
+        self.flownets_2 = FlowNetS.FlowNetS(batchNorm=self.batchNorm)
+        self.upsample3 = nn.Upsample(scale_factor=4, mode='nearest') 
+
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                if m.bias is not None:
+                    init.uniform_(m.bias)
+                init.xavier_uniform_(m.weight)
+
+            if isinstance(m, nn.ConvTranspose2d):
+                if m.bias is not None:
+                    init.uniform_(m.bias)
+                init.xavier_uniform_(m.weight)
+                # init_deconv_bilinear(m.weight)
+
+    def forward(self, inputs):
+        rgb_mean = inputs.contiguous().view(inputs.size()[:2]+(-1,)).mean(dim=-1).view(inputs.size()[:2] + (1,1,1,))
+        
+        x = inputs #(inputs - rgb_mean) / self.rgb_max
+        x1 = x[:,:,0,:,:]
+        x2 = x[:,:,1,:,:]
+        x = torch.cat((x1,x2), dim = 1)
+
+        # flownetc
+        flownetc_flow2 = self.flownetc(x)[0]
+        flownetc_flow = self.upsample1(flownetc_flow2*self.div_flow)
+        
+        # warp img1 to img0; magnitude of diff between img0 and warped_img1
+        resampled_img1 = self.resample1(x[:,3:,:,:], flownetc_flow)
+        diff_img0 = x[:,:3,:,:] - resampled_img1 
+        norm_diff_img0 = self.channelnorm(diff_img0)
+
+        # concat img0, img1, img1->img0, flow, diff-mag ; 
+        concat1 = torch.cat((x, resampled_img1, flownetc_flow/self.div_flow, norm_diff_img0), dim=1)
+        
+        # flownets1
+        flownets1_flow2 = self.flownets_1(concat1)[0]
+        flownets1_flow = self.upsample2(flownets1_flow2*self.div_flow) 
+
+        # warp img1 to img0 using flownets1; magnitude of diff between img0 and warped_img1
+        resampled_img1 = self.resample2(x[:,3:,:,:], flownets1_flow)
+        diff_img0 = x[:,:3,:,:] - resampled_img1
+        norm_diff_img0 = self.channelnorm(diff_img0)
+
+        # concat img0, img1, img1->img0, flow, diff-mag
+        concat2 = torch.cat((x, resampled_img1, flownets1_flow/self.div_flow, norm_diff_img0), dim=1)
+
+        # flownets2
+        flownets2_flow2 = self.flownets_2(concat2)[0]
+        flownets2_flow = self.upsample3(flownets2_flow2 * self.div_flow)
+
+        return flownets2_flow
+
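+
+if __name__ == '__main__':
+    # Editor's sketch (not part of the original file): FlowNet2CS/FlowNet2CSS
+    # above expect two RGB frames stacked on a new depth axis, i.e. inputs of
+    # shape (B, 3, 2, H, W), while FlowNet2 itself takes the frames already
+    # concatenated along channels. Instantiating the models needs the compiled
+    # correlation/resample2d extensions, so only the input layout is shown.
+    im0 = torch.randn(1, 3, 256, 256)
+    im1 = torch.randn(1, 3, 256, 256)
+    inputs = torch.stack((im0, im1), dim=2)  # -> (1, 3, 2, 256, 256)
+    print(inputs.shape)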
diff --git a/windflow/networks/flownet/FlowNetC.py b/windflow/networks/flownet/FlowNetC.py
new file mode 100755
index 0000000000000000000000000000000000000000..3aee083ef3174543fbb6eca5ca65340b8693ccd8
--- /dev/null
+++ b/windflow/networks/flownet/FlowNetC.py
@@ -0,0 +1,130 @@
+import torch
+import torch.nn as nn
+from torch.nn import init
+
+import math
+import numpy as np
+
+from ...external_packages.correlation_package.correlation import Correlation
+
+from .submodules import *
+'Parameter count = 39,175,298'
+
+class FlowNetC(nn.Module):
+    def __init__(self, input_channels=1, batchNorm=True, div_flow = 20, fp16=False):
+        super(FlowNetC,self).__init__()
+
+        self.batchNorm = batchNorm
+        self.div_flow = div_flow
+        self.input_channels = input_channels
+
+        self.conv1   = conv(self.batchNorm, input_channels,   64, kernel_size=7, stride=2)
+        self.conv2   = conv(self.batchNorm,  64,  128, kernel_size=5, stride=2)
+        self.conv3   = conv(self.batchNorm, 128,  256, kernel_size=5, stride=2)
+        self.conv_redir  = conv(self.batchNorm, 256,   32, kernel_size=1, stride=1)
+
+        if fp16:
+            self.corr = nn.Sequential(
+                tofp32(),
+                Correlation(pad_size=20, kernel_size=1, max_displacement=20, stride1=1, stride2=2, corr_multiply=1),
+                tofp16())
+        else:
+            self.corr = Correlation(pad_size=20, kernel_size=1, max_displacement=20, stride1=1, stride2=2, corr_multiply=1)
+
+        self.corr_activation = nn.LeakyReLU(0.1,inplace=True)
+        self.conv3_1 = conv(self.batchNorm, 473,  256)
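+        # (editor's note) 473 = 441 + 32: the correlation above, with
+        # max_displacement=20 and stride2=2, yields a (2*20/2 + 1)**2 = 441
+        # channel cost volume, concatenated with the 32 conv_redir channels.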
+        self.conv4   = conv(self.batchNorm, 256,  512, stride=2)
+        self.conv4_1 = conv(self.batchNorm, 512,  512)
+        self.conv5   = conv(self.batchNorm, 512,  512, stride=2)
+        self.conv5_1 = conv(self.batchNorm, 512,  512)
+        self.conv6   = conv(self.batchNorm, 512, 1024, stride=2)
+        self.conv6_1 = conv(self.batchNorm,1024, 1024)
+
+        self.deconv5 = deconv(1024,512)
+        self.deconv4 = deconv(1026,256)
+        self.deconv3 = deconv(770,128)
+        self.deconv2 = deconv(386,64)
+
+        self.predict_flow6 = predict_flow(1024)
+        self.predict_flow5 = predict_flow(1026)
+        self.predict_flow4 = predict_flow(770)
+        self.predict_flow3 = predict_flow(386)
+        self.predict_flow2 = predict_flow(194)
+
+        self.upsampled_flow6_to_5 = nn.ConvTranspose2d(2, 2, 4, 2, 1, bias=True)
+        self.upsampled_flow5_to_4 = nn.ConvTranspose2d(2, 2, 4, 2, 1, bias=True)
+        self.upsampled_flow4_to_3 = nn.ConvTranspose2d(2, 2, 4, 2, 1, bias=True)
+        self.upsampled_flow3_to_2 = nn.ConvTranspose2d(2, 2, 4, 2, 1, bias=True)
+
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                if m.bias is not None:
+                    init.uniform_(m.bias)
+                init.xavier_uniform_(m.weight)
+
+            if isinstance(m, nn.ConvTranspose2d):
+                if m.bias is not None:
+                    init.uniform_(m.bias)
+                init.xavier_uniform_(m.weight)
+                # init_deconv_bilinear(m.weight)
+        self.upsample1 = nn.Upsample(scale_factor=4, mode='bilinear')
+
+    def forward(self, x):
+        x1 = x[:,0:self.input_channels,:,:]
+        x2 = x[:,self.input_channels:,:,:]
+
+        out_conv1a = self.conv1(x1)
+        out_conv2a = self.conv2(out_conv1a)
+        out_conv3a = self.conv3(out_conv2a)
+
+        # FlownetC bottom input stream
+        out_conv1b = self.conv1(x2)
+        
+        out_conv2b = self.conv2(out_conv1b)
+        out_conv3b = self.conv3(out_conv2b)
+
+        # Merge streams
+        out_corr = self.corr(out_conv3a, out_conv3b)
+        out_corr = self.corr_activation(out_corr)
+
+        # Redirect top input stream and concatenate
+        out_conv_redir = self.conv_redir(out_conv3a)
+
+        in_conv3_1 = torch.cat((out_conv_redir, out_corr), 1)
+
+        # Merged conv layers
+        out_conv3_1 = self.conv3_1(in_conv3_1)
+
+        out_conv4 = self.conv4_1(self.conv4(out_conv3_1))
+
+        out_conv5 = self.conv5_1(self.conv5(out_conv4))
+        out_conv6 = self.conv6_1(self.conv6(out_conv5))
+
+        flow6       = self.predict_flow6(out_conv6)
+        flow6_up    = self.upsampled_flow6_to_5(flow6)
+        out_deconv5 = self.deconv5(out_conv6)
+
+        concat5 = torch.cat((out_conv5,out_deconv5,flow6_up),1)
+
+        flow5       = self.predict_flow5(concat5)
+        flow5_up    = self.upsampled_flow5_to_4(flow5)
+        out_deconv4 = self.deconv4(concat5)
+        concat4 = torch.cat((out_conv4,out_deconv4,flow5_up),1)
+
+        flow4       = self.predict_flow4(concat4)
+        flow4_up    = self.upsampled_flow4_to_3(flow4)
+        out_deconv3 = self.deconv3(concat4)
+        concat3 = torch.cat((out_conv3_1,out_deconv3,flow4_up),1)
+
+        flow3       = self.predict_flow3(concat3)
+        flow3_up    = self.upsampled_flow3_to_2(flow3)
+        out_deconv2 = self.deconv2(concat3)
+        concat2 = torch.cat((out_conv2a,out_deconv2,flow3_up),1)
+
+        flow2 = self.predict_flow2(concat2)
+
+        if self.training:
+            return flow2,flow3,flow4,flow5,flow6
+        else:
+            return flow2,
diff --git a/windflow/networks/flownet/FlowNetFusion.py b/windflow/networks/flownet/FlowNetFusion.py
new file mode 100755
index 0000000000000000000000000000000000000000..81b079ae1bb919affc5780b9dfcdc763936fbc5a
--- /dev/null
+++ b/windflow/networks/flownet/FlowNetFusion.py
@@ -0,0 +1,67 @@
+import torch
+import torch.nn as nn
+from torch.nn import init
+
+import math
+import numpy as np
+
+from .submodules import *
+'Parameter count = 581,226'
+
+class FlowNetFusion(nn.Module):
+    def __init__(self, input_channels=11, batchNorm=True):
+        super(FlowNetFusion,self).__init__()
+
+        self.batchNorm = batchNorm
+        self.conv0   = conv(self.batchNorm,  input_channels,   64)
+        self.conv1   = conv(self.batchNorm,  64,   64, stride=2)
+        self.conv1_1 = conv(self.batchNorm,  64,   128)
+        self.conv2   = conv(self.batchNorm,  128,  128, stride=2)
+        self.conv2_1 = conv(self.batchNorm,  128,  128)
+
+        self.deconv1 = deconv(128,32)
+        self.deconv0 = deconv(162,16)
+
+        self.inter_conv1 = i_conv(self.batchNorm,  162,   32)
+        self.inter_conv0 = i_conv(self.batchNorm,  82,   16)
+
+        self.predict_flow2 = predict_flow(128)
+        self.predict_flow1 = predict_flow(32)
+        self.predict_flow0 = predict_flow(16)
+
+        self.upsampled_flow2_to_1 = nn.ConvTranspose2d(2, 2, 4, 2, 1)
+        self.upsampled_flow1_to_0 = nn.ConvTranspose2d(2, 2, 4, 2, 1)
+
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                if m.bias is not None:
+                    init.uniform_(m.bias)
+                init.xavier_uniform_(m.weight)
+
+            if isinstance(m, nn.ConvTranspose2d):
+                if m.bias is not None:
+                    init.uniform_(m.bias)
+                init.xavier_uniform_(m.weight)
+                # init_deconv_bilinear(m.weight)
+
+    def forward(self, x):
+        out_conv0 = self.conv0(x)
+        out_conv1 = self.conv1_1(self.conv1(out_conv0))
+        out_conv2 = self.conv2_1(self.conv2(out_conv1))
+
+        flow2       = self.predict_flow2(out_conv2)
+        flow2_up    = self.upsampled_flow2_to_1(flow2)
+        out_deconv1 = self.deconv1(out_conv2)
+        
+        concat1 = torch.cat((out_conv1,out_deconv1,flow2_up),1)
+        out_interconv1 = self.inter_conv1(concat1)
+        flow1       = self.predict_flow1(out_interconv1)
+        flow1_up    = self.upsampled_flow1_to_0(flow1)
+        out_deconv0 = self.deconv0(concat1)
+        
+        concat0 = torch.cat((out_conv0,out_deconv0,flow1_up),1)
+        out_interconv0 = self.inter_conv0(concat0)
+        flow0       = self.predict_flow0(out_interconv0)
+
+        return flow0
+
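+# (editor's note) the default input_channels=11 matches FlowNet2's fusion
+# input: an RGB img0 (3) plus two 2-channel flows, their two 1-channel
+# magnitudes, and two 1-channel brightness-error maps (3+4+2+2 = 11).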
diff --git a/windflow/networks/flownet/FlowNetHR.py b/windflow/networks/flownet/FlowNetHR.py
new file mode 100644
index 0000000000000000000000000000000000000000..d5eb9f9512851c896f0fc8aee7a20db9fcc149f8
--- /dev/null
+++ b/windflow/networks/flownet/FlowNetHR.py
@@ -0,0 +1,92 @@
+'''
+Portions of this code copyright 2017, Clement Pinard
+'''
+
+import torch
+import torch.nn as nn
+from torch.nn import init
+
+import math
+import numpy as np
+
+from .flownet_submodules import *
+
+#def deconv(in_planes, out_planes):
+#    return nn.Sequential(
+#        nn.Conv2d(in_planes, out_planes*4, kernel_size=1, bias=True),
+#        nn.PixelShuffle(2),
+#        nn.Conv2d(out_planes, out_planes, kernel_size=3, padding=1, bias=True),
+#        nn.LeakyReLU(0.1,inplace=True)
+#    )
+
+class FlowNetHR(nn.Module):
+    def __init__(self, args, input_channels = 1, batchNorm=True, partial=False):
+        super(FlowNetHR, self).__init__()
+
+        self.batchNorm = batchNorm
+        self.conv0 = conv(self.batchNorm, input_channels*2, 32, kernel_size=7, partial=partial)
+        self.conv0_1 = conv(self.batchNorm, 32,  32, partial=partial)
+        self.conv1   = conv(self.batchNorm, 32,  64, kernel_size=5, stride=2, partial=partial)
+        self.conv1_1 = conv(self.batchNorm, 64,  64, partial=partial)
+        self.conv2   = conv(self.batchNorm, 64,  128, kernel_size=5, stride=2, partial=partial)   
+        self.conv2_1 = conv(self.batchNorm, 128,  128, partial=partial)
+        self.conv3   = conv(self.batchNorm, 128,  256, kernel_size=5, stride=2, partial=partial)
+        self.conv3_1 = conv(self.batchNorm, 256,  256, partial=partial)
+
+        self.deconv2 = deconv(256, 64)
+        self.deconv1 = deconv(194, 32)
+        self.deconv0 = deconv(98, 16)
+
+        self.predict_flow3 = predict_flow(256, partial=partial)
+        self.predict_flow2 = predict_flow(194, partial=partial)
+        self.predict_flow1 = predict_flow(98, partial=partial)
+        self.predict_flow0 = predict_flow(50, partial=partial)
+
+        self.upsampled_flow3_to_2 = nn.ConvTranspose2d(2, 2, 4, 2, 1, bias=False)
+        self.upsampled_flow2_to_1 = nn.ConvTranspose2d(2, 2, 4, 2, 1, bias=False)
+        #self.upsampled_flow3_to_2 = nn.UpsamplingBilinear2d(scale_factor=2)
+        #self.upsampled_flow2_to_1 = nn.UpsamplingBilinear2d(scale_factor=2)
+
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                if m.bias is not None:
+                    init.uniform_(m.bias)
+                init.xavier_uniform_(m.weight)
+
+            if isinstance(m, nn.ConvTranspose2d):
+                if m.bias is not None:
+                    init.uniform_(m.bias)
+                init.xavier_uniform_(m.weight)
+                # init_deconv_bilinear(m.weight)
+        self.upsample_bilinear = nn.Upsample(scale_factor=2, mode='bilinear')
+
+    def forward(self, x1, x2):
+        x = torch.cat([x1,x2],1)
+        
+        out_conv0 = self.conv0_1(self.conv0(x))
+        out_conv1 = self.conv1_1(self.conv1(out_conv0)) # 64
+        out_conv2 = self.conv2_1(self.conv2(out_conv1)) # 128
+        out_conv3 = self.conv3_1(self.conv3(out_conv2)) # 256
+        
+        flow3       = self.predict_flow3(out_conv3) # 2
+        flow3_up    = self.upsample_bilinear(flow3) # 2 
+        out_deconv2 = self.deconv2(out_conv3) # 64
+        
+        concat = torch.cat((out_conv2, flow3_up, out_deconv2), 1) # 128 + 64 + 2 = 194
+        flow2 = self.predict_flow2(concat)
+        flow2_up = self.upsample_bilinear(flow2)
+        out_deconv1 = self.deconv1(concat)
+        
+        concat = torch.cat((out_conv1, flow2_up, out_deconv1), 1) # 64 + 32 + 2 = 98
+        flow1 = self.predict_flow1(concat)
+        flow1_up = self.upsample_bilinear(flow1)
+        out_deconv0 = self.deconv0(concat)
+        
+        concat = torch.cat((out_conv0, flow1_up, out_deconv0), 1) # 32 + 16 + 2 = 50
+        flow0 = self.predict_flow0(concat)
+        
+        if self.training:
+            return flow0,flow1,flow2,flow3
+        else:
+            return flow0,
+
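+
+if __name__ == '__main__':
+    # Editor's sketch (not part of the original file): FlowNetHR has three
+    # stride-2 stages, so inputs only need spatial dims divisible by 8; `args`
+    # is unused by the module, so None is passed. Assumes flownet_submodules
+    # provides the partial-aware conv/predict_flow helpers imported above, and
+    # that the file is run as a module given the relative import.
+    net = FlowNetHR(None, input_channels=1, batchNorm=False).eval()
+    x1, x2 = torch.randn(1, 1, 128, 128), torch.randn(1, 1, 128, 128)
+    with torch.no_grad():
+        flow, = net(x1, x2)
+    print(flow.shape)  # full-resolution flow: (1, 2, 128, 128)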
diff --git a/windflow/networks/flownet/FlowNetPartial.py b/windflow/networks/flownet/FlowNetPartial.py
new file mode 100644
index 0000000000000000000000000000000000000000..92145c8035ef8977cb683d1c829d8b9dd25ae811
--- /dev/null
+++ b/windflow/networks/flownet/FlowNetPartial.py
@@ -0,0 +1,95 @@
+'''
+Portions of this code copyright 2017, Clement Pinard
+'''
+
+import torch
+import torch.nn as nn
+from torch.nn import init
+
+import math
+import numpy as np
+
+from .flownet_submodules import *
+from .layers import partial
+
+def conv_partial(batchNorm, in_planes, out_planes, kernel_size=3, stride=1, partial=True):
+    if batchNorm:
+        return nn.Sequential(
+            partial.Conv2d_partial(in_planes, out_planes, kernel_size=kernel_size, stride=stride, bias=False, partial=partial),
+            nn.BatchNorm2d(out_planes),
+            nn.LeakyReLU(0.1,inplace=True)
+        )
+    else:
+        return nn.Sequential(
+            partial.Conv2d_partial(in_planes, out_planes, kernel_size=kernel_size, stride=stride, partial=partial, bias=True),
+            nn.LeakyReLU(0.1,inplace=True)
+        )
+
+
+class FlowNetPartial(nn.Module):
+    def __init__(self, args, input_channels = 1, batchNorm=False, partial=True):
+        super(FlowNetPartial, self).__init__()
+        self.batchNorm = batchNorm
+        self.conv0 = conv_partial(self.batchNorm, input_channels*2, 32, kernel_size=7, partial=True)
+        self.conv0_1 = conv_partial(self.batchNorm, 32,  32, partial=True)
+        self.conv1   = conv_partial(self.batchNorm, 32,  64, kernel_size=5, stride=2, partial=True)
+        self.conv1_1 = conv_partial(self.batchNorm, 64,  64, partial=True)
+        self.conv2   = conv_partial(self.batchNorm, 64,  128, kernel_size=5, stride=2, partial=True)
+        self.conv2_1 = conv_partial(self.batchNorm, 128,  128, partial=True)
+        self.conv3   = conv_partial(self.batchNorm, 128,  256, kernel_size=5, stride=2, partial=True)
+        self.conv3_1 = conv_partial(self.batchNorm, 256,  256, partial=True)
+
+        self.deconv2 = deconv(256, 64)
+        self.deconv1 = deconv(194, 32)
+        self.deconv0 = deconv(98, 16)
+
+        self.predict_flow3 = predict_flow(256)
+        self.predict_flow2 = predict_flow(194)
+        self.predict_flow1 = predict_flow(98)
+        self.predict_flow0 = predict_flow(50)
+
+
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                if m.bias is not None:
+                    init.uniform_(m.bias)
+                init.xavier_uniform_(m.weight)
+
+            if isinstance(m, nn.ConvTranspose2d):
+                if m.bias is not None:
+                    init.uniform_(m.bias)
+                init.xavier_uniform_(m.weight)
+                # init_deconv_bilinear(m.weight)
+        self.upsample_bilinear = nn.Upsample(scale_factor=2, mode='bilinear')
+
+    def forward(self, x1, x2):
+        x = torch.cat([x1,x2],1)
+        
+        out_conv0 = self.conv0_1(self.conv0(x))
+        out_conv1 = self.conv1_1(self.conv1(out_conv0)) # 64
+        out_conv2 = self.conv2_1(self.conv2(out_conv1)) # 128
+        out_conv3 = self.conv3_1(self.conv3(out_conv2)) # 256
+        
+        
+        flow3       = self.predict_flow3(out_conv3) # 2
+        flow3_up    = self.upsample_bilinear(flow3) # 2 
+        out_deconv2 = self.deconv2(out_conv3) # 64
+        
+        concat = torch.cat((out_conv2, flow3_up, out_deconv2), 1) # 128 + 64 + 2 = 194
+        flow2 = self.predict_flow2(concat)
+        flow2_up = self.upsample_bilinear(flow2)
+        out_deconv1 = self.deconv1(concat)
+        
+        concat = torch.cat((out_conv1, flow2_up, out_deconv1), 1) # 64 + 32 + 2 = 98
+        flow1 = self.predict_flow1(concat)
+        flow1_up = self.upsample_bilinear(flow1)
+        out_deconv0 = self.deconv0(concat)
+        
+        concat = torch.cat((out_conv0, flow1_up, out_deconv0), 1) # 32 + 16 + 2 = 50
+        flow0 = self.predict_flow0(concat)
+        
+        if self.training:
+            return flow0,flow1,flow2,flow3
+        else:
+            return flow0,
+
diff --git a/windflow/networks/flownet/FlowNetS.py b/windflow/networks/flownet/FlowNetS.py
new file mode 100755
index 0000000000000000000000000000000000000000..9288cddb5474d3deb89440ea1bc3eb6c65798836
--- /dev/null
+++ b/windflow/networks/flownet/FlowNetS.py
@@ -0,0 +1,95 @@
+'''
+Portions of this code copyright 2017, Clement Pinard
+'''
+
+import torch
+import torch.nn as nn
+from torch.nn import init
+
+import math
+import numpy as np
+
+from .submodules import *
+'Parameter count = 38,676,504'
+
+class FlowNetS(nn.Module):
+    def __init__(self, input_channels = 12, batchNorm=True):
+        super(FlowNetS,self).__init__()
+
+        self.batchNorm = batchNorm
+        self.conv1   = conv(self.batchNorm,  input_channels,   64, kernel_size=7, stride=2)
+        self.conv2   = conv(self.batchNorm,  64,  128, kernel_size=5, stride=2)
+        self.conv3   = conv(self.batchNorm, 128,  256, kernel_size=5, stride=2)
+        self.conv3_1 = conv(self.batchNorm, 256,  256)
+        self.conv4   = conv(self.batchNorm, 256,  512, stride=2)
+        self.conv4_1 = conv(self.batchNorm, 512,  512)
+        self.conv5   = conv(self.batchNorm, 512,  512, stride=2)
+        self.conv5_1 = conv(self.batchNorm, 512,  512)
+        self.conv6   = conv(self.batchNorm, 512, 1024, stride=2)
+        self.conv6_1 = conv(self.batchNorm,1024, 1024)
+
+        self.deconv5 = deconv(1024,512)
+        self.deconv4 = deconv(1026,256)
+        self.deconv3 = deconv(770,128)
+        self.deconv2 = deconv(386,64)
+
+        self.predict_flow6 = predict_flow(1024)
+        self.predict_flow5 = predict_flow(1026)
+        self.predict_flow4 = predict_flow(770)
+        self.predict_flow3 = predict_flow(386)
+        self.predict_flow2 = predict_flow(194)
+
+        self.upsampled_flow6_to_5 = nn.ConvTranspose2d(2, 2, 4, 2, 1, bias=False)
+        self.upsampled_flow5_to_4 = nn.ConvTranspose2d(2, 2, 4, 2, 1, bias=False)
+        self.upsampled_flow4_to_3 = nn.ConvTranspose2d(2, 2, 4, 2, 1, bias=False)
+        self.upsampled_flow3_to_2 = nn.ConvTranspose2d(2, 2, 4, 2, 1, bias=False)
+
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                if m.bias is not None:
+                    init.uniform_(m.bias)
+                init.xavier_uniform_(m.weight)
+
+            if isinstance(m, nn.ConvTranspose2d):
+                if m.bias is not None:
+                    init.uniform_(m.bias)
+                init.xavier_uniform_(m.weight)
+                # init_deconv_bilinear(m.weight)
+        self.upsample1 = nn.Upsample(scale_factor=4, mode='bilinear')
+
+    def forward(self, x):
+        out_conv1 = self.conv1(x)
+
+        out_conv2 = self.conv2(out_conv1)
+        out_conv3 = self.conv3_1(self.conv3(out_conv2))
+        out_conv4 = self.conv4_1(self.conv4(out_conv3))
+        out_conv5 = self.conv5_1(self.conv5(out_conv4))
+        out_conv6 = self.conv6_1(self.conv6(out_conv5))
+
+        flow6       = self.predict_flow6(out_conv6)
+        flow6_up    = self.upsampled_flow6_to_5(flow6)
+        out_deconv5 = self.deconv5(out_conv6)
+        
+        concat5 = torch.cat((out_conv5,out_deconv5,flow6_up),1)
+        flow5       = self.predict_flow5(concat5)
+        flow5_up    = self.upsampled_flow5_to_4(flow5)
+        out_deconv4 = self.deconv4(concat5)
+        
+        concat4 = torch.cat((out_conv4,out_deconv4,flow5_up),1)
+        flow4       = self.predict_flow4(concat4)
+        flow4_up    = self.upsampled_flow4_to_3(flow4)
+        out_deconv3 = self.deconv3(concat4)
+        
+        concat3 = torch.cat((out_conv3,out_deconv3,flow4_up),1)
+        flow3       = self.predict_flow3(concat3)
+        flow3_up    = self.upsampled_flow3_to_2(flow3)
+        out_deconv2 = self.deconv2(concat3)
+
+        concat2 = torch.cat((out_conv2,out_deconv2,flow3_up),1)
+        flow2 = self.predict_flow2(concat2)
+
+        if self.training:
+            return flow2,flow3,flow4,flow5,flow6
+        else:
+            return flow2,
+
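+
+if __name__ == '__main__':
+    # Editor's sketch (not part of the original file) of FlowNet-style
+    # multi-scale supervision over the training-mode outputs above. The loss
+    # weights are illustrative assumptions, and the downsampled target is not
+    # magnitude-rescaled here for brevity; run as a module given the relative
+    # imports.
+    import torch.nn.functional as F
+    net = FlowNetS(input_channels=2, batchNorm=False).train()
+    x = torch.randn(1, 2, 256, 256)
+    target = torch.randn(1, 2, 64, 64)  # ground truth at flow2 resolution (1/4)
+    flows = net(x)  # (flow2, flow3, flow4, flow5, flow6), fine to coarse
+    weights = [0.005, 0.01, 0.02, 0.08, 0.32]
+    loss = sum(w * (f - F.interpolate(target, size=f.shape[-2:],
+                                      mode='bilinear', align_corners=False)
+                    ).norm(dim=1).mean()  # per-pixel endpoint error
+               for w, f in zip(weights, flows))
+    loss.backward()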
diff --git a/windflow/networks/flownet/FlowNetSD.py b/windflow/networks/flownet/FlowNetSD.py
new file mode 100755
index 0000000000000000000000000000000000000000..0f9eff0d9d269b39bafa307194bf2e7879d089d1
--- /dev/null
+++ b/windflow/networks/flownet/FlowNetSD.py
@@ -0,0 +1,106 @@
+import torch
+import torch.nn as nn
+from torch.nn import init
+
+import math
+import numpy as np
+
+from .submodules import *
+'Parameter count = 45,371,666'
+
+class FlowNetSD(nn.Module):
+    def __init__(self, input_channels=2, batchNorm=True):
+        super(FlowNetSD,self).__init__()
+
+        self.batchNorm = batchNorm
+        self.conv0   = conv(self.batchNorm,  input_channels,   64)
+        self.conv1   = conv(self.batchNorm,  64,   64, stride=2)
+        self.conv1_1 = conv(self.batchNorm,  64,   128)
+        self.conv2   = conv(self.batchNorm,  128,  128, stride=2)
+        self.conv2_1 = conv(self.batchNorm,  128,  128)
+        self.conv3   = conv(self.batchNorm, 128,  256, stride=2)
+        self.conv3_1 = conv(self.batchNorm, 256,  256)
+        self.conv4   = conv(self.batchNorm, 256,  512, stride=2)
+        self.conv4_1 = conv(self.batchNorm, 512,  512)
+        self.conv5   = conv(self.batchNorm, 512,  512, stride=2)
+        self.conv5_1 = conv(self.batchNorm, 512,  512)
+        self.conv6   = conv(self.batchNorm, 512, 1024, stride=2)
+        self.conv6_1 = conv(self.batchNorm,1024, 1024)
+
+        self.deconv5 = deconv(1024,512)
+        self.deconv4 = deconv(1026,256)
+        self.deconv3 = deconv(770,128)
+        self.deconv2 = deconv(386,64)
+
+        self.inter_conv5 = i_conv(self.batchNorm,  1026,   512)
+        self.inter_conv4 = i_conv(self.batchNorm,  770,   256)
+        self.inter_conv3 = i_conv(self.batchNorm,  386,   128)
+        self.inter_conv2 = i_conv(self.batchNorm,  194,   64)
+
+        self.predict_flow6 = predict_flow(1024)
+        self.predict_flow5 = predict_flow(512)
+        self.predict_flow4 = predict_flow(256)
+        self.predict_flow3 = predict_flow(128)
+        self.predict_flow2 = predict_flow(64)
+
+        self.upsampled_flow6_to_5 = nn.ConvTranspose2d(2, 2, 4, 2, 1)
+        self.upsampled_flow5_to_4 = nn.ConvTranspose2d(2, 2, 4, 2, 1)
+        self.upsampled_flow4_to_3 = nn.ConvTranspose2d(2, 2, 4, 2, 1)
+        self.upsampled_flow3_to_2 = nn.ConvTranspose2d(2, 2, 4, 2, 1)
+
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                if m.bias is not None:
+                    init.uniform_(m.bias)
+                init.xavier_uniform_(m.weight)
+
+            if isinstance(m, nn.ConvTranspose2d):
+                if m.bias is not None:
+                    init.uniform_(m.bias)
+                init.xavier_uniform_(m.weight)
+                # init_deconv_bilinear(m.weight)
+        self.upsample1 = nn.Upsample(scale_factor=4, mode='bilinear')
+
+
+
+    def forward(self, x):
+        out_conv0 = self.conv0(x)
+        out_conv1 = self.conv1_1(self.conv1(out_conv0))
+        out_conv2 = self.conv2_1(self.conv2(out_conv1))
+
+        out_conv3 = self.conv3_1(self.conv3(out_conv2))
+        out_conv4 = self.conv4_1(self.conv4(out_conv3))
+        out_conv5 = self.conv5_1(self.conv5(out_conv4))
+        out_conv6 = self.conv6_1(self.conv6(out_conv5))
+
+        flow6       = self.predict_flow6(out_conv6)
+        flow6_up    = self.upsampled_flow6_to_5(flow6)
+        out_deconv5 = self.deconv5(out_conv6)
+        
+        concat5 = torch.cat((out_conv5,out_deconv5,flow6_up),1)
+        out_interconv5 = self.inter_conv5(concat5)
+        flow5       = self.predict_flow5(out_interconv5)
+
+        flow5_up    = self.upsampled_flow5_to_4(flow5)
+        out_deconv4 = self.deconv4(concat5)
+        
+        concat4 = torch.cat((out_conv4,out_deconv4,flow5_up),1)
+        out_interconv4 = self.inter_conv4(concat4)
+        flow4       = self.predict_flow4(out_interconv4)
+        flow4_up    = self.upsampled_flow4_to_3(flow4)
+        out_deconv3 = self.deconv3(concat4)
+        
+        concat3 = torch.cat((out_conv3,out_deconv3,flow4_up),1)
+        out_interconv3 = self.inter_conv3(concat3)
+        flow3       = self.predict_flow3(out_interconv3)
+        flow3_up    = self.upsampled_flow3_to_2(flow3)
+        out_deconv2 = self.deconv2(concat3)
+
+        concat2 = torch.cat((out_conv2,out_deconv2,flow3_up),1)
+        out_interconv2 = self.inter_conv2(concat2)
+        flow2 = self.predict_flow2(out_interconv2)
+
+        if self.training:
+            return flow2,flow3,flow4,flow5,flow6
+        else:
+            return flow2,
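+
+
+# (editor's note) FlowNetSD is the small-displacement branch of FlowNet2; with
+# the default input_channels=2 it takes the two frames stacked along the
+# channel axis and, unlike FlowNetC, relies on plain convolutions without an
+# explicit correlation layer.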
diff --git a/windflow/networks/flownet/__init__.py b/windflow/networks/flownet/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..bbe2eadb45c8e9cebe36ba1c595eb209afef653c
--- /dev/null
+++ b/windflow/networks/flownet/__init__.py
@@ -0,0 +1 @@
+from . import FlowNetS, FlowNetC, FlowNetSD, FlowNetFusion, FlowNet2
\ No newline at end of file
diff --git a/windflow/networks/flownet/submodules.py b/windflow/networks/flownet/submodules.py
new file mode 100755
index 0000000000000000000000000000000000000000..dd6aac3d581faa86968d1522edf41b2a8b7fc1a0
--- /dev/null
+++ b/windflow/networks/flownet/submodules.py
@@ -0,0 +1,92 @@
+
+import torch.nn as nn
+import torch
+import numpy as np 
+
+def conv(batchNorm, in_planes, out_planes, kernel_size=3, stride=1):
+    if batchNorm:
+        return nn.Sequential(
+            nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=(kernel_size-1)//2, bias=False),
+            nn.BatchNorm2d(out_planes),
+            nn.LeakyReLU(0.1,inplace=True)
+        )
+    else:
+        return nn.Sequential(
+            nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=(kernel_size-1)//2, bias=True),
+            nn.LeakyReLU(0.1,inplace=True)
+        )
+
+def i_conv(batchNorm, in_planes, out_planes, kernel_size=3, stride=1, bias = True):
+    if batchNorm:
+        return nn.Sequential(
+            nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=(kernel_size-1)//2, bias=bias),
+            nn.BatchNorm2d(out_planes),
+        )
+    else:
+        return nn.Sequential(
+            nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=(kernel_size-1)//2, bias=bias),
+        )
+
+def predict_flow(in_planes):
+    return nn.Conv2d(in_planes,2,kernel_size=3,stride=1,padding=1,bias=True)
+
+def deconv(in_planes, out_planes):
+    return nn.Sequential(
+        nn.ConvTranspose2d(in_planes, out_planes, kernel_size=4, stride=2, padding=1, bias=True),
+        nn.LeakyReLU(0.1,inplace=True)
+    )
+
+class tofp16(nn.Module):
+    def __init__(self):
+        super(tofp16, self).__init__()
+
+    def forward(self, input):
+        return input.half()
+
+
+class tofp32(nn.Module):
+    def __init__(self):
+        super(tofp32, self).__init__()
+
+    def forward(self, input):
+        return input.float()
+
+
+def init_deconv_bilinear(weight):
+    f_shape = weight.size()
+    height, width = f_shape[-2], f_shape[-1]
+    f = np.ceil(width/2.0)
+    c = (2 * f - 1 - f % 2) / (2.0 * f)
+    bilinear = np.zeros([height, width])
+    for x in range(width):
+        for y in range(height):
+            value = (1 - abs(x / f - c)) * (1 - abs(y / f - c))
+            bilinear[y, x] = value
+    weight.data.fill_(0.)
+    for i in range(f_shape[0]):
+        for j in range(f_shape[1]):
+            weight.data[i,j,:,:] = torch.from_numpy(bilinear)
+
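+# (editor's note) init_deconv_bilinear fills a ConvTranspose2d weight with a
+# separable triangle kernel, so the deconvolution initially performs plain
+# bilinear upsampling until training adjusts it.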
+
+def save_grad(grads, name):
+    def hook(grad):
+        grads[name] = grad
+    return hook
+
diff --git a/windflow/networks/layers/partial.py b/windflow/networks/layers/partial.py
new file mode 100644
index 0000000000000000000000000000000000000000..90cccd8b0caa4cdbbde75e517c0e2bb1778e1423
--- /dev/null
+++ b/windflow/networks/layers/partial.py
@@ -0,0 +1,42 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+class Conv2d_partial(nn.Conv2d):
+    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
+                 padding=0, dilation=1, groups=1, bias=True, partial=False):
+        super(Conv2d_partial, self).__init__(
+            in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias)
+
+        self.partial = partial
+
+    def forward(self, input):
+        if self.partial:
+            self.padding = 0
+
+            pad_val = (self.kernel_size[0] - 1) // 2
+            if pad_val > 0:
+                if (self.kernel_size[0] - self.stride[0]) % 2 == 0:
+                    pad_top = pad_val
+                    pad_bottom = pad_val
+                    pad_left = pad_val
+                    pad_right = pad_val
+                else:
+                    pad_top = pad_val
+                    pad_bottom = self.kernel_size[0] - self.stride[0] - pad_top
+                    pad_left = pad_val
+                    pad_right = self.kernel_size[0] - self.stride[0] - pad_left
+
+                p0 = torch.ones_like(input)
+                p0 = p0.sum()
+
+                input = F.pad(input, (pad_left, pad_right, pad_top, pad_bottom) , mode='constant', value=0)
+
+                p1 = torch.ones_like(input)
+                p1 = p1.sum()
+
+                ratio = torch.div(p1, p0 + 1e-8) 
+                input = torch.mul(input, ratio)  
+            
+        return F.conv2d(input, self.weight, self.bias, self.stride,
+                        self.padding, self.dilation, self.groups)
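+
+
+if __name__ == '__main__':
+    # Editor's sketch (not part of the original file): with partial=True the
+    # layer zero-pads explicitly and rescales the input by the padded/unpadded
+    # element-count ratio, a global approximation to partial-convolution
+    # border renormalization; spatial size is preserved for odd kernels.
+    layer = Conv2d_partial(1, 1, kernel_size=3, stride=1, partial=True)
+    x = torch.randn(1, 1, 8, 8)
+    print(layer(x).shape)  # -> torch.Size([1, 1, 8, 8])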
diff --git a/windflow/networks/maskflownet.py b/windflow/networks/maskflownet.py
new file mode 100644
index 0000000000000000000000000000000000000000..2c57c427a01a33a1b34736ca65a2abc685a568a4
--- /dev/null
+++ b/windflow/networks/maskflownet.py
@@ -0,0 +1,724 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.autograd import Variable
+import numpy as np
+from torchvision import ops
+
+#from .correlation_package.correlation import Correlation
+from spatial_correlation_sampler import SpatialCorrelationSampler
+
+
+def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1, activation=True):
+    if activation:
+        return nn.Sequential(
+            nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
+                      padding=padding, dilation=dilation, bias=True),
+            nn.LeakyReLU(0.1))
+    else:
+        return nn.Sequential(
+            nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
+                      padding=padding, dilation=dilation, bias=True))
+
+
+def predict_flow(in_planes):
+    return nn.Conv2d(in_planes, 2, kernel_size=3, stride=1, padding=1, bias=True)
+
+
+def predict_mask(in_planes):
+    return nn.Conv2d(in_planes, 1, kernel_size=3, stride=1, padding=1, bias=True)
+
+
+def deconv(in_planes, out_planes, kernel_size=4, stride=2, padding=1):
+    return nn.ConvTranspose2d(in_planes, out_planes, kernel_size, stride, padding, bias=True)
+
+
+def deformable_conv(in_planes, out_planes, kernel_size=3, strides=1, padding=1, use_bias=True):
+    return ops.DeformConv2d(in_planes, out_planes, kernel_size, strides, padding, bias=use_bias)
+
+
+def upsample_kernel2d(w, device):
+    c = w // 2
+    kernel = 1 - torch.abs(c - torch.arange(w, dtype=torch.float32, device=device)) / (c + 1)
+    kernel = kernel.repeat(w).view(w,-1) * kernel.unsqueeze(1)
+    return kernel.view(1, 1, w, w)
+
+
+def downsample_kernel2d(w, device):
+    kernel = ((w + 1) - torch.abs(w - torch.arange(w * 2 + 1, dtype=torch.float32, device=device))) / (2 * w + 1)
+    kernel = kernel.repeat(w).view(w,-1) * kernel.unsqueeze(1)
+    return kernel.view(1, 1, w * 2 + 1, w * 2 + 1)
+
+class CorrelationLayerSimple(nn.Module):
+    def __init__(self, md=4):
+        super(CorrelationLayerSimple,self).__init__()
+        self.md = md
+        self.model = SpatialCorrelationSampler(kernel_size=1, patch_size=md*2+1, 
+                                              stride=1, dilation_patch=1, padding=0)
+    def forward(self, x1, x2):
+        output = self.model(x1, x2)
+        b, ph, pw, h, w = output.size()
+        return output.view(b, ph * pw, h, w)
+
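+# For two (B, C, H, W) feature maps, CorrelationLayerSimple returns a cost
+# volume of shape (B, (2*md+1)**2, H, W); with md=4 that is 81 channels,
+# matching nd = (2*md+1)**2 used below to size the decoder inputs.
+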
+def Upsample(img, factor):
+    if factor == 1:
+        return img
+    B, C, H, W = img.shape
+    batch_img = img.view(B*C, 1, H, W)
+    batch_img = F.pad(batch_img, [0, 1, 0, 1], mode='replicate')
+    kernel = upsample_kernel2d(factor * 2 - 1, img.device)
+    upsamp_img = F.conv_transpose2d(batch_img, kernel, stride=factor, padding=(factor-1))
+    upsamp_img = upsamp_img[:, :, : -1, :-1]
+    _, _, H_up, W_up = upsamp_img.shape
+    return upsamp_img.view(B, C, H_up, W_up)
+
+
+def Downsample(img, factor):
+    if factor == 1:
+        return img
+    B, C, H, W = img.shape
+    batch_img = img.view(B*C, 1, H, W)
+    kernel = downsample_kernel2d(factor // 2, img.device)
+    upsamp_img = F.conv2d(batch_img, kernel, stride=factor, padding=factor//2)
+    upsamp_nom = F.conv2d(torch.ones_like(batch_img), kernel, stride=factor, padding=factor//2)
+    _, _, H_up, W_up = upsamp_img.shape
+    upsamp_img = upsamp_img.view(B, C, H_up, W_up)
+    upsamp_nom = upsamp_nom.view(B, C, H_up, W_up)
+    return upsamp_img / upsamp_nom
+
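+# Shape sketch (illustrative, for even H and W): given x of shape (B, C, H, W),
+#   Upsample(x, 2)   -> (B, C, 2*H, 2*W)  via a fixed bilinear kernel
+#   Downsample(x, 2) -> (B, C, H/2, W/2)  with border renormalization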
+
+
+class MaskFlownet_S(nn.Module):
+    """
+    PWC-DC net: PWC-Net with dilated convolutions and DenseNet connections.
+    """
+    def __init__(self, in_ch=1, **kwargs):
+        """
+        md (maximum displacement for the correlation layer, default 4, applied
+        after warping) is set internally.
+        """
+        super(MaskFlownet_S, self).__init__()
+        self.scale = 1. # 20 * config.network.flow_multiplier.get(1.)
+        md = 4
+        self.md = md
+        self.in_ch = in_ch
+        self.strides = [64, 32, 16, 8, 4]
+        self.deform_bias = True #config.network.deform_bias.get(True)
+        self.upfeat_ch = [16, 16, 16, 16] #config.network.upfeat_ch.get([16, 16, 16, 16])
+
+        self.conv1a  = conv(in_ch,   16, kernel_size=3, stride=2)
+        self.conv1b = conv(16, 16, kernel_size=3, stride=1)
+        self.conv1c  = conv(16, 16, kernel_size=3, stride=1)
+        self.conv2a  = conv(16,  32, kernel_size=3, stride=2)
+        self.conv2b = conv(32, 32, kernel_size=3, stride=1)
+        self.conv2c  = conv(32, 32, kernel_size=3, stride=1)
+        self.conv3a  = conv(32,  64, kernel_size=3, stride=2)
+        self.conv3b = conv(64, 64, kernel_size=3, stride=1)
+        self.conv3c  = conv(64, 64, kernel_size=3, stride=1)
+        self.conv4a  = conv(64,  96, kernel_size=3, stride=2)
+        self.conv4b = conv(96, 96, kernel_size=3, stride=1)
+        self.conv4c  = conv(96, 96, kernel_size=3, stride=1)
+        self.conv5a  = conv(96, 128, kernel_size=3, stride=2)
+        self.conv5b = conv(128, 128, kernel_size=3, stride=1)
+        self.conv5c  = conv(128, 128, kernel_size=3, stride=1)
+        self.conv6a = conv(128, 196, kernel_size=3, stride=2)
+        self.conv6b  = conv(196, 196, kernel_size=3, stride=1)
+        self.conv6c  = conv(196, 196, kernel_size=3, stride=1)
+
+        #self.corr    = Correlation(pad_size=md, kernel_size=1, max_displacement=md, stride1=1, stride2=1, corr_multiply=1)
+        self.corr = CorrelationLayerSimple(md=md)
+
+        self.leakyRELU = nn.LeakyReLU(0.1)
+
+        nd = (2*md+1)**2
+        dd = np.cumsum([128,128,96,64,32])
+
+        od = nd
+        self.conv6_0 = conv(od,      128, kernel_size=3, stride=1)
+        self.conv6_1 = conv(od+dd[0],128, kernel_size=3, stride=1)
+        self.conv6_2 = conv(od+dd[1],96,  kernel_size=3, stride=1)
+        self.conv6_3 = conv(od+dd[2],64,  kernel_size=3, stride=1)
+        self.conv6_4 = conv(od+dd[3],32,  kernel_size=3, stride=1)
+        self.pred_flow6 = predict_flow(od + dd[4])
+        self.pred_mask6 = predict_mask(od + dd[4])
+        self.upfeat5 = deconv(od+dd[4], self.upfeat_ch[0], kernel_size=4, stride=2, padding=1)
+        # self.deconv6 = deconv(2, 2, kernel_size=4, stride=2, padding=1)
+        # self.upfeat6 = deconv(od+dd[4], 2, kernel_size=4, stride=2, padding=1)
+
+        # od = nd+128+4
+        od = nd+128+18
+        self.conv5_0 = conv(od,      128, kernel_size=3, stride=1)
+        self.conv5_1 = conv(od+dd[0],128, kernel_size=3, stride=1)
+        self.conv5_2 = conv(od+dd[1],96,  kernel_size=3, stride=1)
+        self.conv5_3 = conv(od+dd[2],64,  kernel_size=3, stride=1)
+        self.conv5_4 = conv(od+dd[3],32,  kernel_size=3, stride=1)
+        self.pred_flow5 = predict_flow(od + dd[4])
+        self.pred_mask5 = predict_mask(od + dd[4])
+        self.upfeat4 = deconv(od+dd[4], self.upfeat_ch[1], kernel_size=4, stride=2, padding=1)
+        # self.deconv5 = deconv(2, 2, kernel_size=4, stride=2, padding=1)
+        # self.upfeat5 = deconv(od+dd[4], 2, kernel_size=4, stride=2, padding=1)
+
+        # od = nd+96+4
+        od = nd+96+18
+        self.conv4_0 = conv(od,      128, kernel_size=3, stride=1)
+        self.conv4_1 = conv(od+dd[0],128, kernel_size=3, stride=1)
+        self.conv4_2 = conv(od+dd[1],96,  kernel_size=3, stride=1)
+        self.conv4_3 = conv(od+dd[2],64,  kernel_size=3, stride=1)
+        self.conv4_4 = conv(od+dd[3],32,  kernel_size=3, stride=1)
+        self.pred_flow4 = predict_flow(od + dd[4])
+        self.pred_mask4 = predict_mask(od + dd[4])
+        self.upfeat3 = deconv(od+dd[4], self.upfeat_ch[2], kernel_size=4, stride=2, padding=1)
+        # self.deconv4 = deconv(2, 2, kernel_size=4, stride=2, padding=1)
+        # self.upfeat4 = deconv(od+dd[4], 2, kernel_size=4, stride=2, padding=1)
+
+        # od = nd+64+4
+        od = nd+64+18
+        self.conv3_0 = conv(od,      128, kernel_size=3, stride=1)
+        self.conv3_1 = conv(od+dd[0],128, kernel_size=3, stride=1)
+        self.conv3_2 = conv(od+dd[1],96,  kernel_size=3, stride=1)
+        self.conv3_3 = conv(od+dd[2],64,  kernel_size=3, stride=1)
+        self.conv3_4 = conv(od+dd[3],32,  kernel_size=3, stride=1)
+        self.pred_flow3 = predict_flow(od + dd[4])
+        self.pred_mask3 = predict_mask(od + dd[4])
+        self.upfeat2 = deconv(od+dd[4], self.upfeat_ch[3], kernel_size=4, stride=2, padding=1)
+        # self.deconv3 = deconv(2, 2, kernel_size=4, stride=2, padding=1)
+        # self.upfeat3 = deconv(od+dd[4], 2, kernel_size=4, stride=2, padding=1)
+
+        # od = nd+32+4
+        od = nd+32+18
+        self.conv2_0 = conv(od,      128, kernel_size=3, stride=1)
+        self.conv2_1 = conv(od+dd[0],128, kernel_size=3, stride=1)
+        self.conv2_2 = conv(od+dd[1],96,  kernel_size=3, stride=1)
+        self.conv2_3 = conv(od+dd[2],64,  kernel_size=3, stride=1)
+        self.conv2_4 = conv(od+dd[3],32,  kernel_size=3, stride=1)
+        self.pred_flow2 = predict_flow(od + dd[4])
+
+        self.dc_conv1 = conv(od+dd[4], 128, kernel_size=3, stride=1, padding=1,  dilation=1)
+        self.dc_conv2 = conv(128,      128, kernel_size=3, stride=1, padding=2,  dilation=2)
+        self.dc_conv3 = conv(128,      128, kernel_size=3, stride=1, padding=4,  dilation=4)
+        self.dc_conv4 = conv(128,      96,  kernel_size=3, stride=1, padding=8,  dilation=8)
+        self.dc_conv5 = conv(96,       64,  kernel_size=3, stride=1, padding=16, dilation=16)
+        self.dc_conv6 = conv(64,       32,  kernel_size=3, stride=1, padding=1,  dilation=1)
+        self.dc_conv7 = predict_flow(32)
+
+        # self.upfeat5 = deconv()
+
+        self.deform5 = deformable_conv(128, 128)
+        self.deform4 = deformable_conv(96, 96)
+        self.deform3 = deformable_conv(64, 64)
+        self.deform2 = deformable_conv(32, 32)
+
+        self.conv5f = conv(16, 128, kernel_size=3, stride=1, padding=1, activation=False)
+        self.conv4f = conv(16, 96, kernel_size=3, stride=1, padding=1, activation=False)
+        self.conv3f = conv(16, 64, kernel_size=3, stride=1, padding=1, activation=False)
+        self.conv2f = conv(16, 32, kernel_size=3, stride=1, padding=1, activation=False)
+
+        for m in self.modules():
+            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
+                nn.init.kaiming_normal_(m.weight.data, mode='fan_in')
+                if m.bias is not None:
+                    m.bias.data.zero_()
+
+
+    def warp(self, x, flo):
+        """
+        warp an image/tensor (im2) back to im1, according to the optical flow
+        x: [B, C, H, W] (im2)
+        flo: [B, 2, H, W] flow
+        """
+        B, C, H, W = x.size()
+        # mesh grid
+        xx = torch.arange(0, W).view(1,-1).repeat(H,1)
+        yy = torch.arange(0, H).view(-1,1).repeat(1,W)
+        xx = xx.view(1,1,H,W).repeat(B,1,1,1)
+        yy = yy.view(1,1,H,W).repeat(B,1,1,1)
+        grid = torch.cat((xx,yy),1).float()
+
+        device = x.device
+        grid = grid.to(device)
+        # flip the flow channels to match the (x, y) ordering of the grid
+        vgrid = grid + torch.flip(flo, [1])
+
+        # scale grid to [-1,1]
+        vgrid[:,0,:,:] = 2.0*vgrid[:,0,:,:].clone() / max(W-1,1)-1.0
+        vgrid[:,1,:,:] = 2.0*vgrid[:,1,:,:].clone() / max(H-1,1)-1.0
+
+        vgrid = vgrid.permute(0,2,3,1)
+        # vgrid = vgrid.permute(0,2,3,1).clamp(-1.1, 1.1)
+        output = nn.functional.grid_sample(x, vgrid, align_corners=True)
+        mask = torch.ones_like(x)
+        mask = nn.functional.grid_sample(mask, vgrid, align_corners=True)
+
+
+        mask[mask<0.9999] = 0
+        mask[mask>0] = 1
+
+        return output*mask
+
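+    # Sanity check (illustrative): warping with zero flow returns the input
+    # wherever the bilinear sampling mask is valid:
+    #   out = self.warp(im2, torch.zeros_like(flow))   # out approximately equals im2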
+
+    def forward(self, x):
+        im1 = x[:,:self.in_ch,:,:]
+        im2 = x[:,self.in_ch:,:,:]
+
+        c11 = self.conv1c(self.conv1b(self.conv1a(im1)))
+        c21 = self.conv1c(self.conv1b(self.conv1a(im2)))
+        c12 = self.conv2c(self.conv2b(self.conv2a(c11)))
+        c22 = self.conv2c(self.conv2b(self.conv2a(c21)))
+        c13 = self.conv3c(self.conv3b(self.conv3a(c12)))
+        c23 = self.conv3c(self.conv3b(self.conv3a(c22)))
+        c14 = self.conv4c(self.conv4b(self.conv4a(c13)))
+        c24 = self.conv4c(self.conv4b(self.conv4a(c23)))
+        c15 = self.conv5c(self.conv5b(self.conv5a(c14)))
+        c25 = self.conv5c(self.conv5b(self.conv5a(c24)))
+        c16 = self.conv6c(self.conv6b(self.conv6a(c15)))
+        c26 = self.conv6c(self.conv6b(self.conv6a(c25)))
+
+
+        corr6 = self.corr(c16, c26)
+        corr6 = self.leakyRELU(corr6)
+
+
+        x = torch.cat((self.conv6_0(corr6), corr6),1)
+        x = torch.cat((self.conv6_1(x), x),1)
+        x = torch.cat((self.conv6_2(x), x),1)
+        x = torch.cat((self.conv6_3(x), x),1)
+        x = torch.cat((self.conv6_4(x), x),1)
+        flow6 = self.pred_flow6(x)
+        mask6 = self.pred_mask6(x)
+
+        feat5 = self.leakyRELU(self.upfeat5(x))
+        flow5 = Upsample(flow6, 2)
+        mask5 = Upsample(mask6, 2)
+        warp5 = (flow5*self.scale/self.strides[1]).unsqueeze(1)
+        warp5 = torch.repeat_interleave(warp5, 9, 1)
+        S1, S2, S3, S4, S5 = warp5.shape
+        warp5 = warp5.view(S1, S2*S3, S4, S5)
+        warp5 = self.deform5(c25, warp5)
+        tradeoff5 = feat5
+        warp5 = (warp5 * torch.sigmoid(mask5)) + self.conv5f(tradeoff5)
+        warp5 = self.leakyRELU(warp5)
+        corr5 = self.corr(c15, warp5)
+        corr5 = self.leakyRELU(corr5)
+        x = torch.cat((corr5, c15, feat5, flow5), 1)
+        x = torch.cat((self.conv5_0(x), x),1)
+        x = torch.cat((self.conv5_1(x), x),1)
+        x = torch.cat((self.conv5_2(x), x),1)
+        x = torch.cat((self.conv5_3(x), x),1)
+        x = torch.cat((self.conv5_4(x), x),1)
+        flow5 = flow5 + self.pred_flow5(x)
+        mask5 = self.pred_mask5(x)
+
+        feat4 = self.leakyRELU(self.upfeat4(x))
+        flow4 = Upsample(flow5, 2)
+        mask4 = Upsample(mask5, 2)
+        warp4 = (flow4*self.scale/self.strides[2]).unsqueeze(1)
+        warp4 = torch.repeat_interleave(warp4, 9, 1)
+        S1, S2, S3, S4, S5 = warp4.shape
+        warp4 = warp4.view(S1, S2*S3, S4, S5)
+        warp4 = self.deform4(c24, warp4)
+        tradeoff4 = feat4
+        warp4 = (warp4 * torch.sigmoid(mask4)) + self.conv4f(tradeoff4)
+        warp4 = self.leakyRELU(warp4)
+        corr4 = self.corr(c14, warp4)
+        corr4 = self.leakyRELU(corr4)
+        x = torch.cat((corr4, c14, feat4, flow4), 1)
+        x = torch.cat((self.conv4_0(x), x),1)
+        x = torch.cat((self.conv4_1(x), x),1)
+        x = torch.cat((self.conv4_2(x), x),1)
+        x = torch.cat((self.conv4_3(x), x),1)
+        x = torch.cat((self.conv4_4(x), x),1)
+        flow4 = flow4 + self.pred_flow4(x)
+        mask4 = self.pred_mask4(x)
+
+        feat3 = self.leakyRELU(self.upfeat3(x))
+        flow3 = Upsample(flow4, 2)
+        mask3 = Upsample(mask4, 2)
+        warp3 = (flow3*self.scale/self.strides[3]).unsqueeze(1)
+        warp3 = torch.repeat_interleave(warp3, 9, 1)
+        S1, S2, S3, S4, S5 = warp3.shape
+        warp3 = warp3.view(S1, S2*S3, S4, S5)
+        warp3 = self.deform3(c23, warp3)
+        tradeoff3 = feat3
+        warp3 = (warp3 * torch.sigmoid(mask3)) + self.conv3f(tradeoff3)
+        warp3 = self.leakyRELU(warp3)
+        corr3 = self.corr(c13, warp3)
+        corr3 = self.leakyRELU(corr3)
+        x = torch.cat((corr3, c13, feat3, flow3), 1)
+        x = torch.cat((self.conv3_0(x), x),1)
+        x = torch.cat((self.conv3_1(x), x),1)
+        x = torch.cat((self.conv3_2(x), x),1)
+        x = torch.cat((self.conv3_3(x), x),1)
+        x = torch.cat((self.conv3_4(x), x),1)
+        flow3 = flow3 + self.pred_flow3(x)
+        mask3 = self.pred_mask3(x)
+
+        feat2 = self.leakyRELU(self.upfeat2(x))
+        flow2 = Upsample(flow3, 2)
+        mask2 = Upsample(mask3, 2)
+        warp2 = (flow2*self.scale/self.strides[4]).unsqueeze(1)
+        warp2 = torch.repeat_interleave(warp2, 9, 1)
+        S1, S2, S3, S4, S5 = warp2.shape
+        warp2 = warp2.view(S1, S2*S3, S4, S5)
+        warp2 = self.deform2(c22, warp2)
+        tradeoff2 = feat2
+        warp2 = (warp2 * torch.sigmoid(mask2)) + self.conv2f(tradeoff2)
+        warp2 = self.leakyRELU(warp2)
+        corr2 = self.corr(c12, warp2)
+        corr2 = self.leakyRELU(corr2)
+        x = torch.cat((corr2, c12, feat2, flow2), 1)
+        x = torch.cat((self.conv2_0(x), x),1)
+        x = torch.cat((self.conv2_1(x), x),1)
+        x = torch.cat((self.conv2_2(x), x),1)
+        x = torch.cat((self.conv2_3(x), x),1)
+        x = torch.cat((self.conv2_4(x), x),1)
+        flow2 = flow2 + self.pred_flow2(x)
+
+        x = self.dc_conv4(self.dc_conv3(self.dc_conv2(self.dc_conv1(x))))
+        flow2 = flow2 + self.dc_conv7(self.dc_conv6(self.dc_conv5(x)))
+
+        predictions = [flow * self.scale for flow in [flow6, flow5, flow4, flow3, flow2]]
+        occlusion_masks = []
+        occlusion_masks.append(torch.sigmoid(mask2))
+        c1s = [c11, c12, c13, c14, c15, c16]
+        c2s = [c21, c22, c23, c24, c25, c26]
+        flows = [flow6, flow5, flow4, flow3, flow2]
+        mask0 = Upsample(mask2, 4)
+        mask0 = torch.sigmoid(mask0) - 0.5
+        c30 = im1
+        c40 = self.warp(im2, Upsample(flow2, 4)*self.scale)
+        c30 = torch.cat((c30, torch.zeros_like(mask0)), 1)
+        c40 = torch.cat((c40, mask0), 1)
+        srcs = [c1s, c2s, flows, c30, c40]
+        return predictions, occlusion_masks, srcs
+
+
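+# Note: MaskFlownet_S returns its flow predictions coarse-to-fine
+# ([flow6, ..., flow2]); the stacked MaskFlownet below returns them
+# fine-to-coarse ([flow2, ..., flow6]).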
+class MaskFlownet(nn.Module):
+    def __init__(self, in_ch=1, **kwargs):
+        super(MaskFlownet, self).__init__(**kwargs)
+        self.strides = [64, 32, 16, 8, 4]
+        self.md = 2
+        self.scale = 1. #config.network.flow_multiplier.get(1.)
+        self.deform_bias = True #config.network.deform_bias.get(True)
+        self.upfeat_ch = [16, 16, 16, 16] #config.network.upfeat_ch.get([16, 16, 16, 16])
+        self.in_ch = in_ch
+        self.MaskFlownet_S = MaskFlownet_S(in_ch=in_ch)
+        self.activate = nn.LeakyReLU(0.1)
+
+        self.conv1x = conv(self.in_ch+1,   16, stride=2) # image + occlusion-mask channel
+        self.conv1y = conv(16,  16, stride=1)
+        self.conv1z = conv(16,  16, stride=1)
+        self.conv2x = conv(16,  32, stride=2)
+        self.conv2y = conv(32,  32, stride=1)
+        self.conv2z = conv(32,  32, stride=1)
+        self.conv3x = conv(32,  64, stride=2)
+        self.conv3y = conv(64,  64, stride=1)
+        self.conv3z = conv(64,  64, stride=1)
+        self.conv4x = conv(64,  96, stride=2)
+        self.conv4y = conv(96,  96, stride=1)
+        self.conv4z = conv(96,  96, stride=1)
+        self.conv5x = conv(96,  128, stride=2)
+        self.conv5y = conv(128, 128, stride=1)
+        self.conv5z = conv(128, 128, stride=1)
+        self.conv6x = conv(128, 196, stride=2)
+        self.conv6y = conv(196, 196, stride=1)
+        self.conv6z = conv(196, 196, stride=1)
+
+        self.leakyRELU = nn.LeakyReLU(0.1)
+        #self.corr    = Correlation(pad_size=self.md, kernel_size=1, max_displacement=self.md, stride1=1, stride2=1, corr_multiply=1)
+        self.corr = CorrelationLayerSimple(md=self.md)
+
+
+        nd = (2*self.md+1)**2
+        dd = np.cumsum([128,128,96,64,32])
+
+        od = nd+nd+2
+        self.conv6_0 = conv(od,       128, kernel_size=3, stride=1)
+        self.conv6_1 = conv(od+dd[0], 128, kernel_size=3, stride=1)
+        self.conv6_2 = conv(od+dd[1], 96,  kernel_size=3, stride=1)
+        self.conv6_3 = conv(od+dd[2], 64,  kernel_size=3, stride=1)
+        self.conv6_4 = conv(od+dd[3], 32,  kernel_size=3, stride=1)
+        self.pred_flow6 = predict_flow(od + dd[4])
+        self.upfeat5 = deconv(od+dd[4], self.upfeat_ch[0], kernel_size=4, stride=2, padding=1)
+
+        # od = nd+128+4
+        od = nd+nd+128+16+2+2
+        self.conv5_0 = conv(od,      128, kernel_size=3, stride=1)
+        self.conv5_1 = conv(od+dd[0],128, kernel_size=3, stride=1)
+        self.conv5_2 = conv(od+dd[1],96,  kernel_size=3, stride=1)
+        self.conv5_3 = conv(od+dd[2],64,  kernel_size=3, stride=1)
+        self.conv5_4 = conv(od+dd[3],32,  kernel_size=3, stride=1)
+        self.pred_flow5 = predict_flow(od + dd[4])
+        self.upfeat4 = deconv(od+dd[4], self.upfeat_ch[1], kernel_size=4, stride=2, padding=1)
+        # self.deconv5 = deconv(2, 2, kernel_size=4, stride=2, padding=1)
+        # self.upfeat5 = deconv(od+dd[4], 2, kernel_size=4, stride=2, padding=1)
+
+        # od = nd+96+4
+        # od = nd+96+18
+        od = nd+nd+96+16+2+2
+        self.conv4_0 = conv(od,      128, kernel_size=3, stride=1)
+        self.conv4_1 = conv(od+dd[0],128, kernel_size=3, stride=1)
+        self.conv4_2 = conv(od+dd[1],96,  kernel_size=3, stride=1)
+        self.conv4_3 = conv(od+dd[2],64,  kernel_size=3, stride=1)
+        self.conv4_4 = conv(od+dd[3],32,  kernel_size=3, stride=1)
+        self.pred_flow4 = predict_flow(od + dd[4])
+        self.upfeat3 = deconv(od+dd[4], self.upfeat_ch[2], kernel_size=4, stride=2, padding=1)
+        # self.deconv4 = deconv(2, 2, kernel_size=4, stride=2, padding=1)
+        # self.upfeat4 = deconv(od+dd[4], 2, kernel_size=4, stride=2, padding=1)
+
+        # od = nd+64+4
+        # od = nd+64+18
+        od = nd+nd+64+16+2+2
+        self.conv3_0 = conv(od,      128, kernel_size=3, stride=1)
+        self.conv3_1 = conv(od+dd[0],128, kernel_size=3, stride=1)
+        self.conv3_2 = conv(od+dd[1],96,  kernel_size=3, stride=1)
+        self.conv3_3 = conv(od+dd[2],64,  kernel_size=3, stride=1)
+        self.conv3_4 = conv(od+dd[3],32,  kernel_size=3, stride=1)
+        self.pred_flow3 = predict_flow(od + dd[4])
+        self.upfeat2 = deconv(od+dd[4], self.upfeat_ch[3], kernel_size=4, stride=2, padding=1)
+        # self.deconv3 = deconv(2, 2, kernel_size=4, stride=2, padding=1)
+        # self.upfeat3 = deconv(od+dd[4], 2, kernel_size=4, stride=2, padding=1)
+
+        # od = nd+32+4
+        # od = nd+32+18
+        od = nd+nd+32+16+2+2
+        self.conv2_0 = conv(od,      128, kernel_size=3, stride=1)
+        self.conv2_1 = conv(od+dd[0],128, kernel_size=3, stride=1)
+        self.conv2_2 = conv(od+dd[1],96,  kernel_size=3, stride=1)
+        self.conv2_3 = conv(od+dd[2],64,  kernel_size=3, stride=1)
+        self.conv2_4 = conv(od+dd[3],32,  kernel_size=3, stride=1)
+        self.pred_flow2 = predict_flow(od + dd[4])
+
+        self.dc_conv1 = conv(od+dd[4], 128, kernel_size=3, stride=1, padding=1,  dilation=1)
+        self.dc_conv2 = conv(128,      128, kernel_size=3, stride=1, padding=2,  dilation=2)
+        self.dc_conv3 = conv(128,      128, kernel_size=3, stride=1, padding=4,  dilation=4)
+        self.dc_conv4 = conv(128,      96,  kernel_size=3, stride=1, padding=8,  dilation=8)
+        self.dc_conv5 = conv(96,       64,  kernel_size=3, stride=1, padding=16, dilation=16)
+        self.dc_conv6 = conv(64,       32,  kernel_size=3, stride=1, padding=1,  dilation=1)
+        self.dc_conv7 = predict_flow(32)
+
+        # self.upfeat5 = deconv()
+
+        self.deform6 = deformable_conv(196, 196)
+        self.deform5 = deformable_conv(128, 128)
+        self.deform4 = deformable_conv(96, 96)
+        self.deform3 = deformable_conv(64, 64)
+        self.deform2 = deformable_conv(32, 32)
+
+    def warp(self, x, flo):
+        """
+        warp an image/tensor (im2) back to im1, according to the optical flow
+        x: [B, C, H, W] (im2)
+        flo: [B, 2, H, W] flow
+        """
+        B, C, H, W = x.size()
+        # mesh grid
+        xx = torch.arange(0, W).view(1,-1).repeat(H,1)
+        yy = torch.arange(0, H).view(-1,1).repeat(1,W)
+        xx = xx.view(1,1,H,W).repeat(B,1,1,1)
+        yy = yy.view(1,1,H,W).repeat(B,1,1,1)
+        grid = torch.cat((xx,yy),1).float()
+
+        grid = grid.to(x.device)
+        # flip the flow channels to match the (x, y) ordering of the grid
+        vgrid = grid + torch.flip(flo, [1])
+
+        # scale grid to [-1,1]
+        vgrid[:,0,:,:] = 2.0*vgrid[:,0,:,:].clone() / max(W-1,1)-1.0
+        vgrid[:,1,:,:] = 2.0*vgrid[:,1,:,:].clone() / max(H-1,1)-1.0
+
+        # vgrid = vgrid.permute(0,2,3,1)
+        vgrid = vgrid.permute(0,2,3,1).clamp(-1.1, 1.1)
+        output = nn.functional.grid_sample(x, vgrid, align_corners=True)
+        mask = torch.ones_like(x)
+        mask = nn.functional.grid_sample(mask, vgrid, align_corners=True)
+
+
+        mask[mask<0.9999] = 0
+        mask[mask>0] = 1
+
+        return output*mask
+
+    def forward(self, x):
+        im1 = x[:,:self.in_ch,:,:]
+        im2 = x[:,self.in_ch:,:,:]
+
+        _, _, srcs = self.MaskFlownet_S(x)
+        c1s, c2s, flows, c30, c40 = srcs
+        c11, c12, c13, c14, c15, c16 = c1s
+        c21, c22, c23, c24, c25, c26 = c2s
+
+        c31 = self.conv1z(self.conv1y(self.conv1x(c30)))
+        c32 = self.conv2z(self.conv2y(self.conv2x(c31)))
+        c33 = self.conv3z(self.conv3y(self.conv3x(c32)))
+        c34 = self.conv4z(self.conv4y(self.conv4x(c33)))
+        c35 = self.conv5z(self.conv5y(self.conv5x(c34)))
+        c36 = self.conv6z(self.conv6y(self.conv6x(c35)))
+
+        c41 = self.conv1z(self.conv1y(self.conv1x(c40)))
+        c42 = self.conv2z(self.conv2y(self.conv2x(c41)))
+        c43 = self.conv3z(self.conv3y(self.conv3x(c42)))
+        c44 = self.conv4z(self.conv4y(self.conv4x(c43)))
+        c45 = self.conv5z(self.conv5y(self.conv5x(c44)))
+        c46 = self.conv6z(self.conv6y(self.conv6x(c45)))
+
+        flow6 = flows[0]
+
+        warp6u = (flow6*self.scale/self.strides[0]).unsqueeze(1)
+        warp6u = torch.repeat_interleave(warp6u, 9, 1)
+        S1, S2, S3, S4, S5 = warp6u.shape
+        warp6u = warp6u.view(S1, S2*S3, S4, S5)
+        warp6u = self.deform6(c26, warp6u)
+        warp6u = self.leakyRELU(warp6u)
+        corr6u = self.leakyRELU(self.corr(c16, warp6u))
+        warp6v = c46
+        corr6v = self.leakyRELU(self.corr(c36, warp6v))
+        x = torch.cat((corr6u, corr6v, flow6),1)
+        x = torch.cat((self.conv6_0(x), x),1)
+        x = torch.cat((self.conv6_1(x), x),1)
+        x = torch.cat((self.conv6_2(x), x),1)
+        x = torch.cat((self.conv6_3(x), x),1)
+        x = torch.cat((self.conv6_4(x), x),1)
+        flow6 = flow6 + self.pred_flow6(x)
+
+        feat5 = self.leakyRELU(self.upfeat5(x))
+        flow5 = Upsample(flow6, 2)
+        warp5u = (flow5*self.scale/self.strides[1]).unsqueeze(1)
+        warp5u = torch.repeat_interleave(warp5u, 9, 1)
+        S1, S2, S3, S4, S5 = warp5u.shape
+        warp5u = warp5u.view(S1, S2*S3, S4, S5)
+        warp5u = self.deform5(c25, warp5u)
+        warp5u = self.leakyRELU(warp5u)
+        corr5u = self.leakyRELU(self.corr(c15, warp5u))
+        warp5v = c45
+        corr5v = self.leakyRELU(self.corr(c35, warp5v))
+        x = torch.cat((c15, feat5, corr5u, corr5v, flow5, flows[1]),1)
+        x = torch.cat((self.conv5_0(x), x),1)
+        x = torch.cat((self.conv5_1(x), x),1)
+        x = torch.cat((self.conv5_2(x), x),1)
+        x = torch.cat((self.conv5_3(x), x),1)
+        x = torch.cat((self.conv5_4(x), x),1)
+        flow5 = flow5 + self.pred_flow5(x)
+
+        feat4 = self.leakyRELU(self.upfeat4(x))
+        flow4 = Upsample(flow5, 2)
+        warp4u = (flow4*self.scale/self.strides[2]).unsqueeze(1)
+        warp4u = torch.repeat_interleave(warp4u, 9, 1)
+        S1, S2, S3, S4, S5 = warp4u.shape
+        warp4u = warp4u.view(S1, S2*S3, S4, S5)
+        warp4u = self.deform4(c24, warp4u)
+        warp4u = self.leakyRELU(warp4u)
+        corr4u = self.leakyRELU(self.corr(c14, warp4u))
+        warp4v = c44
+        corr4v = self.leakyRELU(self.corr(c34, warp4v))
+        x = torch.cat((c14, feat4, corr4u, corr4v, flow4, flows[2]),1)
+        x = torch.cat((self.conv4_0(x), x),1)
+        x = torch.cat((self.conv4_1(x), x),1)
+        x = torch.cat((self.conv4_2(x), x),1)
+        x = torch.cat((self.conv4_3(x), x),1)
+        x = torch.cat((self.conv4_4(x), x),1)
+        flow4 = flow4 + self.pred_flow4(x)
+
+        feat3 = self.leakyRELU(self.upfeat3(x))
+        flow3 = Upsample(flow4, 2)
+        warp3u = (flow3*self.scale/self.strides[3]).unsqueeze(1)
+        warp3u = torch.repeat_interleave(warp3u, 9, 1)
+        S1, S2, S3, S4, S5 = warp3u.shape
+        warp3u = warp3u.view(S1, S2*S3, S4, S5)
+        warp3u = self.deform3(c23, warp3u)
+        warp3u = self.leakyRELU(warp3u)
+        corr3u = self.leakyRELU(self.corr(c13, warp3u))
+        warp3v = c43
+        corr3v = self.leakyRELU(self.corr(c33, warp3v))
+        x = torch.cat((c13, feat3, corr3u, corr3v, flow3, flows[3]),1)
+        x = torch.cat((self.conv3_0(x), x),1)
+        x = torch.cat((self.conv3_1(x), x),1)
+        x = torch.cat((self.conv3_2(x), x),1)
+        x = torch.cat((self.conv3_3(x), x),1)
+        x = torch.cat((self.conv3_4(x), x),1)
+        flow3 = flow3 + self.pred_flow3(x)
+
+        feat2 = self.leakyRELU(self.upfeat2(x))
+        flow2 = Upsample(flow3, 2)
+        warp2u = (flow2*self.scale/self.strides[4]).unsqueeze(1)
+        warp2u = torch.repeat_interleave(warp2u, 9, 1)
+        S1, S2, S3, S4, S5 = warp2u.shape
+        warp2u = warp2u.view(S1, S2*S3, S4, S5)
+        warp2u = self.deform2(c22, warp2u)
+        warp2u = self.leakyRELU(warp2u)
+        corr2u = self.leakyRELU(self.corr(c12, warp2u))
+        warp2v = c42
+        corr2v = self.leakyRELU(self.corr(c32, warp2v))
+        x = torch.cat((c12, feat2, corr2u, corr2v, flow2, flows[4]),1)
+        x = torch.cat((self.conv2_0(x), x),1)
+        x = torch.cat((self.conv2_1(x), x),1)
+        x = torch.cat((self.conv2_2(x), x),1)
+        x = torch.cat((self.conv2_3(x), x),1)
+        x = torch.cat((self.conv2_4(x), x),1)
+        flow2 = flow2 + self.pred_flow2(x)
+
+        x = self.dc_conv4(self.dc_conv3(self.dc_conv2(self.dc_conv1(x))))
+        flow2 = flow2 + self.dc_conv7(self.dc_conv6(self.dc_conv5(x)))
+
+        # fine-to-coarse ordering: [flow2, flow3, flow4, flow5, flow6]
+        preds = [flow * self.scale for flow in [flow2, flow3, flow4, flow5, flow6]]
+
+        return preds
+
+
+class EpeLoss(nn.Module):
+    def __init__(self, eps = 0):
+        super(EpeLoss, self).__init__()
+        self.eps = eps
+
+    def forward(self, pred, label):
+        loss = ((pred - label).pow(2).sum(1) + self.eps).sqrt()
+        return loss.view(loss.shape[0], -1).mean(1)
+
+
+class EpeLossWithMask(nn.Module):
+    def __init__(self, eps=1e-8, q=None):
+        super(EpeLossWithMask, self).__init__()
+        self.eps = eps
+        self.q = q
+
+    def forward(self, pred, label, mask):
+        if self.q is not None:
+            loss = ((pred - label).abs().sum(1) + self.eps) ** self.q
+        else:
+            loss = ((pred - label).pow(2).sum(1) + self.eps).sqrt()
+        loss = loss * mask.squeeze(1)
+        loss = loss.view(loss.shape[0], -1).sum(1) / mask.view(mask.shape[0], -1).sum(1)
+        return loss
+
+
+class MultiscaleEpe(nn.Module):
+    def __init__(self, scales, weights, match, eps = 1e-8, q = None):
+        super(MultiscaleEpe, self).__init__()
+
+        self.scales = scales
+        self.weights = weights
+        self.match = match
+        self.eps = eps
+        self.q = q
+
+    def forward(self, flow, mask, *predictions):
+        losses = 0
+        if self.match == 'upsampling':
+            for p, w, s in zip(predictions, self.weights, self.scales):
+                losses += EpeLossWithMask(eps=self.eps, q=self.q)(Upsample(p, s), flow, mask) * w
+        elif self.match == 'downsampling':
+            for p, w, s in zip(predictions, self.weights, self.scales):
+                losses += EpeLossWithMask(eps=self.eps, q=self.q)(p, Downsample(flow, s), Downsample(mask, s)) * w
+        else:
+            raise NotImplementedError
+        return losses
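+
+
+# Minimal loss usage sketch (illustrative; the weights and scales here are
+# assumptions, not repository defaults):
+#
+#   criterion = MultiscaleEpe(scales=[4, 8, 16, 32, 64],
+#                             weights=[1.0, 0.5, 0.25, 0.125, 0.0625],
+#                             match='upsampling')
+#   preds = model(torch.cat([im1, im2], dim=1))   # MaskFlownet: 5 flows, fine to coarse
+#   loss = criterion(flow_gt, valid_mask, *preds).mean()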
diff --git a/windflow/networks/models.py b/windflow/networks/models.py
new file mode 100644
index 0000000000000000000000000000000000000000..395e30f2b0cdc2c359197cca365c7cb946a65e4d
--- /dev/null
+++ b/windflow/networks/models.py
@@ -0,0 +1,18 @@
+from . import warper, RAFT
+
+
+def get_flow_model(model_name, small=True):
+    # select the base optical flow model by name
+    model_name = model_name.lower()
+    if model_name == 'raft':
+        raft_args = {'small': small,
+                     'lr': 1e-5,
+                     'mixed_precision': True,
+                     'dropout': 0.0,
+                     'corr_levels': 4,
+                     'corr_radius': 4}
+        model = RAFT(raft_args)
+    else:
+        model = None  # unrecognized name; callers must handle None
+
+    return model
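+
+
+# Example (illustrative):
+#   model = get_flow_model('raft', small=True)   # returns a RAFT instance
+#   model = get_flow_model('other')              # returns None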
diff --git a/windflow/networks/raft/__init__.py b/windflow/networks/raft/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/windflow/networks/raft/corr.py b/windflow/networks/raft/corr.py
new file mode 100644
index 0000000000000000000000000000000000000000..ae260b992728197063fbe0f68b87b20459fba077
--- /dev/null
+++ b/windflow/networks/raft/corr.py
@@ -0,0 +1,91 @@
+import torch
+import torch.nn.functional as F
+from .utils.utils import bilinear_sampler, coords_grid
+
+try:
+    import alt_cuda_corr
+except ImportError:
+    # alt_cuda_corr is not compiled
+    pass
+
+
+class CorrBlock:
+    def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
+        self.num_levels = num_levels
+        self.radius = radius
+        self.corr_pyramid = []
+
+        # all pairs correlation
+        corr = CorrBlock.corr(fmap1, fmap2)
+
+        batch, h1, w1, dim, h2, w2 = corr.shape
+        corr = corr.reshape(batch*h1*w1, dim, h2, w2)
+        
+        self.corr_pyramid.append(corr)
+        for i in range(self.num_levels-1):
+            corr = F.avg_pool2d(corr, 2, stride=2)
+            self.corr_pyramid.append(corr)
+
+    def __call__(self, coords):
+        r = self.radius
+        coords = coords.permute(0, 2, 3, 1)
+        batch, h1, w1, _ = coords.shape
+
+        out_pyramid = []
+        for i in range(self.num_levels):
+            corr = self.corr_pyramid[i].float()
+            dx = torch.linspace(-r, r, 2*r+1)
+            dy = torch.linspace(-r, r, 2*r+1)
+            delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to(coords.device)
+
+            centroid_lvl = coords.reshape(batch*h1*w1, 1, 1, 2) / 2**i
+            delta_lvl = delta.view(1, 2*r+1, 2*r+1, 2)
+            coords_lvl = centroid_lvl + delta_lvl
+
+            corr = bilinear_sampler(corr, coords_lvl)
+            corr = corr.view(batch, h1, w1, -1)
+            out_pyramid.append(corr)
+
+        out = torch.cat(out_pyramid, dim=-1)
+        return out.permute(0, 3, 1, 2).contiguous().float()
+
+    @staticmethod
+    def corr(fmap1, fmap2):
+        batch, dim, ht, wd = fmap1.shape
+        fmap1 = fmap1.view(batch, dim, ht*wd)
+        fmap2 = fmap2.view(batch, dim, ht*wd) 
+        
+        corr = torch.matmul(fmap1.transpose(1,2), fmap2)
+        corr = corr.view(batch, ht, wd, 1, ht, wd)
+        return corr / torch.sqrt(torch.tensor(dim).float())
+
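+# corr(fmap1, fmap2) builds the all-pairs similarity volume
+#   corr[b, i, j, 0, k, l] = <fmap1[b, :, i, j], fmap2[b, :, k, l]> / sqrt(dim)
+# so memory scales as (H*W)**2; the pyramid then average-pools the last two
+# (h2, w2) dimensions level by level.
+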
+
+class AlternateCorrBlock:
+    def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
+        self.num_levels = num_levels
+        self.radius = radius
+
+        self.pyramid = [(fmap1, fmap2)]
+        for i in range(self.num_levels):
+            fmap1 = F.avg_pool2d(fmap1, 2, stride=2)
+            fmap2 = F.avg_pool2d(fmap2, 2, stride=2)
+            self.pyramid.append((fmap1, fmap2))
+
+    def __call__(self, coords):
+        coords = coords.permute(0, 2, 3, 1)
+        B, H, W, _ = coords.shape
+        dim = self.pyramid[0][0].shape[1]
+
+        corr_list = []
+        for i in range(self.num_levels):
+            r = self.radius
+            fmap1_i = self.pyramid[0][0].permute(0, 2, 3, 1).contiguous()
+            fmap2_i = self.pyramid[i][1].permute(0, 2, 3, 1).contiguous()
+
+            coords_i = (coords / 2**i).reshape(B, 1, H, W, 2).contiguous()
+            corr, = alt_cuda_corr.forward(fmap1_i, fmap2_i, coords_i, r)
+            corr_list.append(corr.squeeze(1))
+
+        corr = torch.stack(corr_list, dim=1)
+        corr = corr.reshape(B, -1, H, W)
+        return corr / torch.sqrt(torch.tensor(dim).float())
diff --git a/windflow/networks/raft/datasets.py b/windflow/networks/raft/datasets.py
new file mode 100644
index 0000000000000000000000000000000000000000..3411fdacfb900024005e8997d07c600e963a95ca
--- /dev/null
+++ b/windflow/networks/raft/datasets.py
@@ -0,0 +1,235 @@
+# Data loading based on https://github.com/NVIDIA/flownet2-pytorch
+
+import numpy as np
+import torch
+import torch.utils.data as data
+import torch.nn.functional as F
+
+import os
+import math
+import random
+from glob import glob
+import os.path as osp
+
+from .utils import frame_utils
+from .utils.augmentor import FlowAugmentor, SparseFlowAugmentor
+
+
+class FlowDataset(data.Dataset):
+    def __init__(self, aug_params=None, sparse=False):
+        self.augmentor = None
+        self.sparse = sparse
+        if aug_params is not None:
+            if sparse:
+                self.augmentor = SparseFlowAugmentor(**aug_params)
+            else:
+                self.augmentor = FlowAugmentor(**aug_params)
+
+        self.is_test = False
+        self.init_seed = False
+        self.flow_list = []
+        self.image_list = []
+        self.extra_info = []
+
+    def __getitem__(self, index):
+
+        if self.is_test:
+            img1 = frame_utils.read_gen(self.image_list[index][0])
+            img2 = frame_utils.read_gen(self.image_list[index][1])
+            img1 = np.array(img1).astype(np.uint8)[..., :3]
+            img2 = np.array(img2).astype(np.uint8)[..., :3]
+            img1 = torch.from_numpy(img1).permute(2, 0, 1).float()
+            img2 = torch.from_numpy(img2).permute(2, 0, 1).float()
+            return img1, img2, self.extra_info[index]
+
+        if not self.init_seed:
+            worker_info = torch.utils.data.get_worker_info()
+            if worker_info is not None:
+                torch.manual_seed(worker_info.id)
+                np.random.seed(worker_info.id)
+                random.seed(worker_info.id)
+                self.init_seed = True
+
+        index = index % len(self.image_list)
+        valid = None
+        if self.sparse:
+            flow, valid = frame_utils.readFlowKITTI(self.flow_list[index])
+        else:
+            flow = frame_utils.read_gen(self.flow_list[index])
+
+        img1 = frame_utils.read_gen(self.image_list[index][0])
+        img2 = frame_utils.read_gen(self.image_list[index][1])
+
+        flow = np.array(flow).astype(np.float32)
+        img1 = np.array(img1).astype(np.uint8)
+        img2 = np.array(img2).astype(np.uint8)
+
+        # grayscale images
+        if len(img1.shape) == 2:
+            img1 = np.tile(img1[...,None], (1, 1, 3))
+            img2 = np.tile(img2[...,None], (1, 1, 3))
+        else:
+            img1 = img1[..., :3]
+            img2 = img2[..., :3]
+
+        if self.augmentor is not None:
+            if self.sparse:
+                img1, img2, flow, valid = self.augmentor(img1, img2, flow, valid)
+            else:
+                img1, img2, flow = self.augmentor(img1, img2, flow)
+
+        img1 = torch.from_numpy(img1).permute(2, 0, 1).float()
+        img2 = torch.from_numpy(img2).permute(2, 0, 1).float()
+        flow = torch.from_numpy(flow).permute(2, 0, 1).float()
+
+        if valid is not None:
+            valid = torch.from_numpy(valid)
+        else:
+            valid = (flow[0].abs() < 1000) & (flow[1].abs() < 1000)
+
+        return img1, img2, flow, valid.float()
+
+
+    def __rmul__(self, v):
+        self.flow_list = v * self.flow_list
+        self.image_list = v * self.image_list
+        return self
+        
+    def __len__(self):
+        return len(self.image_list)
+        
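+# __rmul__ lets a dataset be oversampled by integer multiplication, e.g.
+# `100 * sintel_clean` repeats its file lists 100 times; fetch_dataloader
+# below relies on this to weight the mixed training sets.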
+
+class MpiSintel(FlowDataset):
+    def __init__(self, aug_params=None, split='training', root='datasets/Sintel', dstype='clean'):
+        super(MpiSintel, self).__init__(aug_params)
+        flow_root = osp.join(root, split, 'flow')
+        image_root = osp.join(root, split, dstype)
+
+        if split == 'test':
+            self.is_test = True
+
+        for scene in os.listdir(image_root):
+            image_list = sorted(glob(osp.join(image_root, scene, '*.png')))
+            for i in range(len(image_list)-1):
+                self.image_list += [ [image_list[i], image_list[i+1]] ]
+                self.extra_info += [ (scene, i) ] # scene and frame_id
+
+            if split != 'test':
+                self.flow_list += sorted(glob(osp.join(flow_root, scene, '*.flo')))
+
+
+class FlyingChairs(FlowDataset):
+    def __init__(self, aug_params=None, split='training', root='datasets/FlyingChairs_release/data'):
+        super(FlyingChairs, self).__init__(aug_params)
+
+        images = sorted(glob(osp.join(root, '*.ppm')))
+        flows = sorted(glob(osp.join(root, '*.flo')))
+        assert (len(images)//2 == len(flows))
+
+        split_list = np.loadtxt('chairs_split.txt', dtype=np.int32)
+        for i in range(len(flows)):
+            xid = split_list[i]
+            if (split=='training' and xid==1) or (split=='validation' and xid==2):
+                self.flow_list += [ flows[i] ]
+                self.image_list += [ [images[2*i], images[2*i+1]] ]
+
+
+class FlyingThings3D(FlowDataset):
+    def __init__(self, aug_params=None, root='datasets/FlyingThings3D', dstype='frames_cleanpass'):
+        super(FlyingThings3D, self).__init__(aug_params)
+
+        for cam in ['left']:
+            for direction in ['into_future', 'into_past']:
+                image_dirs = sorted(glob(osp.join(root, dstype, 'TRAIN/*/*')))
+                image_dirs = sorted([osp.join(f, cam) for f in image_dirs])
+
+                flow_dirs = sorted(glob(osp.join(root, 'optical_flow/TRAIN/*/*')))
+                flow_dirs = sorted([osp.join(f, direction, cam) for f in flow_dirs])
+
+                for idir, fdir in zip(image_dirs, flow_dirs):
+                    images = sorted(glob(osp.join(idir, '*.png')) )
+                    flows = sorted(glob(osp.join(fdir, '*.pfm')) )
+                    for i in range(len(flows)-1):
+                        if direction == 'into_future':
+                            self.image_list += [ [images[i], images[i+1]] ]
+                            self.flow_list += [ flows[i] ]
+                        elif direction == 'into_past':
+                            self.image_list += [ [images[i+1], images[i]] ]
+                            self.flow_list += [ flows[i+1] ]
+      
+
+class KITTI(FlowDataset):
+    def __init__(self, aug_params=None, split='training', root='datasets/KITTI'):
+        super(KITTI, self).__init__(aug_params, sparse=True)
+        if split == 'testing':
+            self.is_test = True
+
+        root = osp.join(root, split)
+        images1 = sorted(glob(osp.join(root, 'image_2/*_10.png')))
+        images2 = sorted(glob(osp.join(root, 'image_2/*_11.png')))
+
+        for img1, img2 in zip(images1, images2):
+            frame_id = img1.split('/')[-1]
+            self.extra_info += [ [frame_id] ]
+            self.image_list += [ [img1, img2] ]
+
+        if split == 'training':
+            self.flow_list = sorted(glob(osp.join(root, 'flow_occ/*_10.png')))
+
+
+class HD1K(FlowDataset):
+    def __init__(self, aug_params=None, root='datasets/HD1k'):
+        super(HD1K, self).__init__(aug_params, sparse=True)
+
+        seq_ix = 0
+        while True:
+            flows = sorted(glob(os.path.join(root, 'hd1k_flow_gt', 'flow_occ/%06d_*.png' % seq_ix)))
+            images = sorted(glob(os.path.join(root, 'hd1k_input', 'image_2/%06d_*.png' % seq_ix)))
+
+            if len(flows) == 0:
+                break
+
+            for i in range(len(flows)-1):
+                self.flow_list += [flows[i]]
+                self.image_list += [ [images[i], images[i+1]] ]
+
+            seq_ix += 1
+
+
+def fetch_dataloader(args, TRAIN_DS='C+T+K+S+H'):
+    """ Create the data loader for the corresponding trainign set """
+
+    if args.stage == 'chairs':
+        aug_params = {'crop_size': args.image_size, 'min_scale': -0.1, 'max_scale': 1.0, 'do_flip': True}
+        train_dataset = FlyingChairs(aug_params, split='training')
+    
+    elif args.stage == 'things':
+        aug_params = {'crop_size': args.image_size, 'min_scale': -0.4, 'max_scale': 0.8, 'do_flip': True}
+        clean_dataset = FlyingThings3D(aug_params, dstype='frames_cleanpass')
+        final_dataset = FlyingThings3D(aug_params, dstype='frames_finalpass')
+        train_dataset = clean_dataset + final_dataset
+
+    elif args.stage == 'sintel':
+        aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.6, 'do_flip': True}
+        things = FlyingThings3D(aug_params, dstype='frames_cleanpass')
+        sintel_clean = MpiSintel(aug_params, split='training', dstype='clean')
+        sintel_final = MpiSintel(aug_params, split='training', dstype='final')        
+
+        if TRAIN_DS == 'C+T+K+S+H':
+            kitti = KITTI({'crop_size': args.image_size, 'min_scale': -0.3, 'max_scale': 0.5, 'do_flip': True})
+            hd1k = HD1K({'crop_size': args.image_size, 'min_scale': -0.5, 'max_scale': 0.2, 'do_flip': True})
+            train_dataset = 100*sintel_clean + 100*sintel_final + 200*kitti + 5*hd1k + things
+
+        elif TRAIN_DS == 'C+T+K/S':
+            train_dataset = 100*sintel_clean + 100*sintel_final + things
+
+    elif args.stage == 'kitti':
+        aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.4, 'do_flip': False}
+        train_dataset = KITTI(aug_params, split='training')
+
+    train_loader = data.DataLoader(train_dataset, batch_size=args.batch_size, 
+        pin_memory=False, shuffle=True, num_workers=4, drop_last=True)
+
+    print('Training with %d image pairs' % len(train_dataset))
+    return train_loader
+
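+# Sketch of intended use (illustrative; the args fields shown are assumptions
+# based on how they are read above):
+#
+#   from types import SimpleNamespace
+#   args = SimpleNamespace(stage='chairs', image_size=[368, 496], batch_size=8)
+#   loader = fetch_dataloader(args)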
diff --git a/windflow/networks/raft/extractor.py b/windflow/networks/raft/extractor.py
new file mode 100644
index 0000000000000000000000000000000000000000..312f2b473fcbcf60722157add2db15deed3e5a94
--- /dev/null
+++ b/windflow/networks/raft/extractor.py
@@ -0,0 +1,267 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class ResidualBlock(nn.Module):
+    def __init__(self, in_planes, planes, norm_fn='group', stride=1):
+        super(ResidualBlock, self).__init__()
+  
+        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, stride=stride)
+        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1)
+        self.relu = nn.ReLU(inplace=True)
+
+        num_groups = planes // 8
+
+        if norm_fn == 'group':
+            self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
+            self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
+            if not stride == 1:
+                self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
+        
+        elif norm_fn == 'batch':
+            self.norm1 = nn.BatchNorm2d(planes)
+            self.norm2 = nn.BatchNorm2d(planes)
+            if not stride == 1:
+                self.norm3 = nn.BatchNorm2d(planes)
+        
+        elif norm_fn == 'instance':
+            self.norm1 = nn.InstanceNorm2d(planes)
+            self.norm2 = nn.InstanceNorm2d(planes)
+            if not stride == 1:
+                self.norm3 = nn.InstanceNorm2d(planes)
+
+        elif norm_fn == 'none':
+            self.norm1 = nn.Sequential()
+            self.norm2 = nn.Sequential()
+            if not stride == 1:
+                self.norm3 = nn.Sequential()
+
+        if stride == 1:
+            self.downsample = None
+        
+        else:    
+            self.downsample = nn.Sequential(
+                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3)
+
+
+    def forward(self, x):
+        y = x
+        y = self.relu(self.norm1(self.conv1(y)))
+        y = self.relu(self.norm2(self.conv2(y)))
+
+        if self.downsample is not None:
+            x = self.downsample(x)
+
+        return self.relu(x+y)
+
+
+
+class BottleneckBlock(nn.Module):
+    def __init__(self, in_planes, planes, norm_fn='group', stride=1):
+        super(BottleneckBlock, self).__init__()
+  
+        self.conv1 = nn.Conv2d(in_planes, planes//4, kernel_size=1, padding=0)
+        self.conv2 = nn.Conv2d(planes//4, planes//4, kernel_size=3, padding=1, stride=stride)
+        self.conv3 = nn.Conv2d(planes//4, planes, kernel_size=1, padding=0)
+        self.relu = nn.ReLU(inplace=True)
+
+        num_groups = planes // 8
+
+        if norm_fn == 'group':
+            self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4)
+            self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4)
+            self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
+            if not stride == 1:
+                self.norm4 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
+        
+        elif norm_fn == 'batch':
+            self.norm1 = nn.BatchNorm2d(planes//4)
+            self.norm2 = nn.BatchNorm2d(planes//4)
+            self.norm3 = nn.BatchNorm2d(planes)
+            if not stride == 1:
+                self.norm4 = nn.BatchNorm2d(planes)
+        
+        elif norm_fn == 'instance':
+            self.norm1 = nn.InstanceNorm2d(planes//4)
+            self.norm2 = nn.InstanceNorm2d(planes//4)
+            self.norm3 = nn.InstanceNorm2d(planes)
+            if not stride == 1:
+                self.norm4 = nn.InstanceNorm2d(planes)
+
+        elif norm_fn == 'none':
+            self.norm1 = nn.Sequential()
+            self.norm2 = nn.Sequential()
+            self.norm3 = nn.Sequential()
+            if not stride == 1:
+                self.norm4 = nn.Sequential()
+
+        if stride == 1:
+            self.downsample = None
+        
+        else:    
+            self.downsample = nn.Sequential(
+                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm4)
+
+
+    def forward(self, x):
+        y = x
+        y = self.relu(self.norm1(self.conv1(y)))
+        y = self.relu(self.norm2(self.conv2(y)))
+        y = self.relu(self.norm3(self.conv3(y)))
+
+        if self.downsample is not None:
+            x = self.downsample(x)
+
+        return self.relu(x+y)
+
+class BasicEncoder(nn.Module):
+    def __init__(self, input_dim=1, output_dim=128, norm_fn='batch', dropout=0.0):
+        super(BasicEncoder, self).__init__()
+        self.norm_fn = norm_fn
+
+        if self.norm_fn == 'group':
+            self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64)
+            
+        elif self.norm_fn == 'batch':
+            self.norm1 = nn.BatchNorm2d(64)
+
+        elif self.norm_fn == 'instance':
+            self.norm1 = nn.InstanceNorm2d(64)
+
+        elif self.norm_fn == 'none':
+            self.norm1 = nn.Sequential()
+
+        self.conv1 = nn.Conv2d(input_dim, 64, kernel_size=7, stride=2, padding=3)
+        self.relu1 = nn.ReLU(inplace=True)
+
+        self.in_planes = 64
+        self.layer1 = self._make_layer(64,  stride=1)
+        self.layer2 = self._make_layer(96, stride=2)
+        self.layer3 = self._make_layer(128, stride=2)
+
+        # output convolution
+        self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1)
+
+        self.dropout = None
+        if dropout > 0:
+            self.dropout = nn.Dropout2d(p=dropout)
+
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+            elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
+                if m.weight is not None:
+                    nn.init.constant_(m.weight, 1)
+                if m.bias is not None:
+                    nn.init.constant_(m.bias, 0)
+
+    def _make_layer(self, dim, stride=1):
+        layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
+        layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
+        layers = (layer1, layer2)
+        
+        self.in_planes = dim
+        return nn.Sequential(*layers)
+
+
+    def forward(self, x):
+
+        # if input is list, combine batch dimension
+        is_list = isinstance(x, tuple) or isinstance(x, list)
+        if is_list:
+            batch_dim = x[0].shape[0]
+            x = torch.cat(x, dim=0)
+
+        x = self.conv1(x)
+        x = self.norm1(x)
+        x = self.relu1(x)
+
+        x = self.layer1(x)
+        x = self.layer2(x)
+        x = self.layer3(x)
+
+        x = self.conv2(x)
+
+        if self.training and self.dropout is not None:
+            x = self.dropout(x)
+
+        if is_list:
+            x = torch.split(x, [batch_dim, batch_dim], dim=0)
+
+        return x
+
+
+class SmallEncoder(nn.Module):
+    def __init__(self, input_dim=1, output_dim=128, norm_fn='batch', dropout=0.0):
+        super(SmallEncoder, self).__init__()
+        self.norm_fn = norm_fn
+
+        if self.norm_fn == 'group':
+            self.norm1 = nn.GroupNorm(num_groups=8, num_channels=32)
+            
+        elif self.norm_fn == 'batch':
+            self.norm1 = nn.BatchNorm2d(32)
+
+        elif self.norm_fn == 'instance':
+            self.norm1 = nn.InstanceNorm2d(32)
+
+        elif self.norm_fn == 'none':
+            self.norm1 = nn.Sequential()
+
+        self.conv1 = nn.Conv2d(input_dim, 32, kernel_size=7, stride=2, padding=3)
+        self.relu1 = nn.ReLU(inplace=True)
+
+        self.in_planes = 32
+        self.layer1 = self._make_layer(32,  stride=1)
+        self.layer2 = self._make_layer(64, stride=2)
+        self.layer3 = self._make_layer(96, stride=2)
+
+        self.dropout = None
+        if dropout > 0:
+            self.dropout = nn.Dropout2d(p=dropout)
+        
+        self.conv2 = nn.Conv2d(96, output_dim, kernel_size=1)
+
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+            elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
+                if m.weight is not None:
+                    nn.init.constant_(m.weight, 1)
+                if m.bias is not None:
+                    nn.init.constant_(m.bias, 0)
+
+    def _make_layer(self, dim, stride=1):
+        layer1 = BottleneckBlock(self.in_planes, dim, self.norm_fn, stride=stride)
+        layer2 = BottleneckBlock(dim, dim, self.norm_fn, stride=1)
+        layers = (layer1, layer2)
+    
+        self.in_planes = dim
+        return nn.Sequential(*layers)
+
+
+    def forward(self, x):
+
+        # if input is list, combine batch dimension
+        is_list = isinstance(x, tuple) or isinstance(x, list)
+        if is_list:
+            batch_dim = x[0].shape[0]
+            x = torch.cat(x, dim=0)
+
+        x = self.conv1(x)
+        x = self.norm1(x)
+        x = self.relu1(x)
+
+        x = self.layer1(x)
+        x = self.layer2(x)
+        x = self.layer3(x)
+        x = self.conv2(x)
+
+        if self.training and self.dropout is not None:
+            x = self.dropout(x)
+
+        if is_list:
+            x = torch.split(x, [batch_dim, batch_dim], dim=0)
+
+        return x
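+
+
+# Both encoders reduce resolution by a factor of 8 (a stride-2 stem followed by
+# two stride-2 stages), which is why RAFT estimates flow on a 1/8-resolution
+# grid and upsamples by 8 at the end.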
diff --git a/windflow/networks/raft/raft.py b/windflow/networks/raft/raft.py
new file mode 100644
index 0000000000000000000000000000000000000000..1f7973d246324c2aa92f740d63a14d9c7b0da091
--- /dev/null
+++ b/windflow/networks/raft/raft.py
@@ -0,0 +1,143 @@
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .update import BasicUpdateBlock, SmallUpdateBlock
+from .extractor import BasicEncoder, SmallEncoder
+from .corr import CorrBlock, AlternateCorrBlock
+from .utils.utils import bilinear_sampler, coords_grid, upflow8
+
+try:
+    autocast = torch.cuda.amp.autocast
+except AttributeError:
+    # dummy autocast for PyTorch < 1.6
+    class autocast:
+        def __init__(self, enabled):
+            pass
+        def __enter__(self):
+            pass
+        def __exit__(self, *args):
+            pass
+
+
+class RAFT(nn.Module):
+    def __init__(self, args):
+        super(RAFT, self).__init__()
+        self.args = args
+
+        if args['small']:
+            self.hidden_dim = hdim = 96
+            self.context_dim = cdim = 64
+            args['corr_levels'] = 4
+            args['corr_radius'] = 3
+        
+        else:
+            self.hidden_dim = hdim = 128
+            self.context_dim = cdim = 128
+            args['corr_levels'] = 4
+            args['corr_radius'] = 4
+
+        if 'dropout' not in self.args:
+            self.args['dropout'] = 0.
+
+        if 'alternate_corr' not in self.args:
+            self.args['alternate_corr'] = False
+
+        # feature network, context network, and update block
+        if args['small']:
+            self.fnet = SmallEncoder(input_dim=1, output_dim=128, norm_fn='instance', dropout=args['dropout'])        
+            self.cnet = SmallEncoder(input_dim=1, output_dim=hdim+cdim, norm_fn='none', dropout=args['dropout'])
+            self.update_block = SmallUpdateBlock(self.args, hidden_dim=hdim)
+
+        else:
+            self.fnet = BasicEncoder(input_dim=1, output_dim=256, norm_fn='instance', dropout=args['dropout'])        
+            self.cnet = BasicEncoder(input_dim=1, output_dim=hdim+cdim, norm_fn='batch', dropout=args['dropout'])
+            self.update_block = BasicUpdateBlock(self.args, hidden_dim=hdim)
+
+    def freeze_bn(self):
+        for m in self.modules():
+            if isinstance(m, nn.BatchNorm2d):
+                m.eval()
+
+    def initialize_flow(self, img):
+        """ Flow is represented as difference between two coordinate grids flow = coords1 - coords0"""
+        N, C, H, W = img.shape
+        coords0 = coords_grid(N, H//8, W//8).to(img.device)
+        coords1 = coords_grid(N, H//8, W//8).to(img.device)
+
+        # optical flow computed as difference: flow = coords1 - coords0
+        return coords0, coords1
+
+    def upsample_flow(self, flow, mask):
+        """ Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination """
+        N, _, H, W = flow.shape
+        mask = mask.view(N, 1, 9, 8, 8, H, W)
+        mask = torch.softmax(mask, dim=2)
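+        # mask now holds 9 softmax weights per output pixel: each fine pixel is
+        # a convex combination of its 3x3 coarse-grid neighborhood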
+
+        up_flow = F.unfold(8 * flow, [3,3], padding=1)
+        up_flow = up_flow.view(N, 2, 9, 1, 1, H, W)
+
+        up_flow = torch.sum(mask * up_flow, dim=2)
+        up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)
+        return up_flow.reshape(N, 2, 8*H, 8*W)
+
+
+    def forward(self, image1, image2, iters=12, flow_init=None, upsample=True, test_mode=False):
+        """ Estimate optical flow between pair of frames """
+
+        #image1 = 2 * (image1 / 255.0) - 1.0
+        #image2 = 2 * (image2 / 255.0) - 1.0
+
+        #image1 = image1.contiguous()
+        #image2 = image2.contiguous()
+
+        hdim = self.hidden_dim
+        cdim = self.context_dim
+
+        # run the feature network
+        with autocast(enabled=self.args['mixed_precision']):
+            fmap1, fmap2 = self.fnet([image1, image2])        
+        
+        fmap1 = fmap1.float()
+        fmap2 = fmap2.float()
+        if self.args['alternate_corr']:
+            corr_fn = AlternateCorrBlock(fmap1, fmap2, radius=self.args['corr_radius'])
+        else:
+            corr_fn = CorrBlock(fmap1, fmap2, radius=self.args['corr_radius'])
+
+        # run the context network
+        with autocast(enabled=self.args['mixed_precision']):
+            cnet = self.cnet(image1)
+            net, inp = torch.split(cnet, [hdim, cdim], dim=1)
+            net = torch.tanh(net)
+            inp = torch.relu(inp)
+
+        coords0, coords1 = self.initialize_flow(image1)
+        if flow_init is not None:
+            coords1 = coords1 + flow_init
+
+        flow_predictions = []
+        for itr in range(iters):
+            coords1 = coords1.detach()
+            corr = corr_fn(coords1.float()) # index correlation volume
+
+            flow = coords1 - coords0
+            with autocast(enabled=self.args['mixed_precision']):
+                net, up_mask, delta_flow = self.update_block(net, inp, corr, flow)
+
+            # F(t+1) = F(t) + \Delta(t)
+            coords1 = coords1 + delta_flow
+
+            # upsample predictions
+            if up_mask is None:
+                flow_up = upflow8(coords1 - coords0)
+            else:
+                flow_up = self.upsample_flow(coords1 - coords0, up_mask)
+            
+            flow_predictions.append(flow_up)
+
+        if test_mode:
+            return flow_up,
+            
+        return flow_predictions
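+
+
+# Minimal inference sketch (illustrative, assumes 1-channel frames with H and W
+# divisible by 8): __init__ fills in 'corr_levels'/'corr_radius', so the args
+# dict only needs the keys read above.
+#
+#   args = {'small': False, 'mixed_precision': False}
+#   model = RAFT(args).eval()
+#   img1 = torch.randn(1, 1, 256, 256)
+#   img2 = torch.randn(1, 1, 256, 256)
+#   with torch.no_grad():
+#       flow_up, = model(img1, img2, iters=12, test_mode=True)  # note: 1-tuple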
diff --git a/windflow/networks/raft/update.py b/windflow/networks/raft/update.py
new file mode 100644
index 0000000000000000000000000000000000000000..68c92a4922bf0e45f16e55a5deb15c0271cb2fee
--- /dev/null
+++ b/windflow/networks/raft/update.py
@@ -0,0 +1,139 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class FlowHead(nn.Module):
+    def __init__(self, input_dim=128, hidden_dim=256):
+        super(FlowHead, self).__init__()
+        self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1)
+        self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1)
+        self.relu = nn.ReLU(inplace=True)
+
+    def forward(self, x):
+        return self.conv2(self.relu(self.conv1(x)))
+
+class ConvGRU(nn.Module):
+    def __init__(self, hidden_dim=128, input_dim=192+128):
+        super(ConvGRU, self).__init__()
+        self.convz = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1)
+        self.convr = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1)
+        self.convq = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1)
+
+    def forward(self, h, x):
+        hx = torch.cat([h, x], dim=1)
+
+        z = torch.sigmoid(self.convz(hx))
+        r = torch.sigmoid(self.convr(hx))
+        q = torch.tanh(self.convq(torch.cat([r*h, x], dim=1)))
+
+        h = (1-z) * h + z * q
+        return h
+
+class SepConvGRU(nn.Module):
+    def __init__(self, hidden_dim=128, input_dim=192+128):
+        super(SepConvGRU, self).__init__()
+        self.convz1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
+        self.convr1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
+        self.convq1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
+
+        self.convz2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
+        self.convr2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
+        self.convq2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
+
+
+    def forward(self, h, x):
+        # horizontal
+        hx = torch.cat([h, x], dim=1)
+        z = torch.sigmoid(self.convz1(hx))
+        r = torch.sigmoid(self.convr1(hx))
+        q = torch.tanh(self.convq1(torch.cat([r*h, x], dim=1)))        
+        h = (1-z) * h + z * q
+
+        # vertical
+        hx = torch.cat([h, x], dim=1)
+        z = torch.sigmoid(self.convz2(hx))
+        r = torch.sigmoid(self.convr2(hx))
+        q = torch.tanh(self.convq2(torch.cat([r*h, x], dim=1)))       
+        h = (1-z) * h + z * q
+
+        return h
+
+class SmallMotionEncoder(nn.Module):
+    def __init__(self, args):
+        super(SmallMotionEncoder, self).__init__()
+        cor_planes = args['corr_levels'] * (2*args['corr_radius'] + 1)**2
+        self.convc1 = nn.Conv2d(cor_planes, 96, 1, padding=0)
+        self.convf1 = nn.Conv2d(2, 64, 7, padding=3)
+        self.convf2 = nn.Conv2d(64, 32, 3, padding=1)
+        self.conv = nn.Conv2d(128, 80, 3, padding=1)
+
+    def forward(self, flow, corr):
+        cor = F.relu(self.convc1(corr))
+        flo = F.relu(self.convf1(flow))
+        flo = F.relu(self.convf2(flo))
+        cor_flo = torch.cat([cor, flo], dim=1)
+        out = F.relu(self.conv(cor_flo))
+        return torch.cat([out, flow], dim=1)
+
+class BasicMotionEncoder(nn.Module):
+    def __init__(self, args):
+        super(BasicMotionEncoder, self).__init__()
+        cor_planes = args['corr_levels'] * (2*args['corr_radius'] + 1)**2
+        self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0)
+        self.convc2 = nn.Conv2d(256, 192, 3, padding=1)
+        self.convf1 = nn.Conv2d(2, 128, 7, padding=3)
+        self.convf2 = nn.Conv2d(128, 64, 3, padding=1)
+        self.conv = nn.Conv2d(64+192, 128-2, 3, padding=1)
+
+    def forward(self, flow, corr):
+        cor = F.relu(self.convc1(corr))
+        cor = F.relu(self.convc2(cor))
+        flo = F.relu(self.convf1(flow))
+        flo = F.relu(self.convf2(flo))
+
+        cor_flo = torch.cat([cor, flo], dim=1)
+        out = F.relu(self.conv(cor_flo))
+        return torch.cat([out, flow], dim=1)
+
+class SmallUpdateBlock(nn.Module):
+    def __init__(self, args, hidden_dim=96):
+        super(SmallUpdateBlock, self).__init__()
+        self.encoder = SmallMotionEncoder(args)
+        self.gru = ConvGRU(hidden_dim=hidden_dim, input_dim=82+64)
+        self.flow_head = FlowHead(hidden_dim, hidden_dim=128)
+
+    def forward(self, net, inp, corr, flow):
+        motion_features = self.encoder(flow, corr)
+        inp = torch.cat([inp, motion_features], dim=1)
+        net = self.gru(net, inp)
+        delta_flow = self.flow_head(net)
+
+        return net, None, delta_flow
+
+class BasicUpdateBlock(nn.Module):
+    def __init__(self, args, hidden_dim=128, input_dim=128):
+        super(BasicUpdateBlock, self).__init__()
+        self.args = args
+        self.encoder = BasicMotionEncoder(args)
+        self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128+hidden_dim)
+        self.flow_head = FlowHead(hidden_dim, hidden_dim=256)
+
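+        # 64*9 = 8*8*9 channels: 9 convex-combination weights for every pixel
+        # of each 8x8 upsampling block (consumed by RAFT.upsample_flow)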
+        self.mask = nn.Sequential(
+            nn.Conv2d(128, 256, 3, padding=1),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(256, 64*9, 1, padding=0))
+
+    def forward(self, net, inp, corr, flow, upsample=True):
+        motion_features = self.encoder(flow, corr)
+        inp = torch.cat([inp, motion_features], dim=1)
+
+        net = self.gru(net, inp)
+        delta_flow = self.flow_head(net)
+
+        # scale mask to balance gradients
+        mask = .25 * self.mask(net)
+        return net, mask, delta_flow
+
+
+
diff --git a/windflow/networks/raft/utils/__init__.py b/windflow/networks/raft/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/windflow/networks/raft/utils/augmentor.py b/windflow/networks/raft/utils/augmentor.py
new file mode 100644
index 0000000000000000000000000000000000000000..e81c4f2b5c16c31c0ae236d744f299d430228a04
--- /dev/null
+++ b/windflow/networks/raft/utils/augmentor.py
@@ -0,0 +1,246 @@
+import numpy as np
+import random
+import math
+from PIL import Image
+
+import cv2
+cv2.setNumThreads(0)
+cv2.ocl.setUseOpenCL(False)
+
+import torch
+from torchvision.transforms import ColorJitter
+import torch.nn.functional as F
+
+
+class FlowAugmentor:
+    def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=True):
+        
+        # spatial augmentation params
+        self.crop_size = crop_size
+        self.min_scale = min_scale
+        self.max_scale = max_scale
+        self.spatial_aug_prob = 0.8
+        self.stretch_prob = 0.8
+        self.max_stretch = 0.2
+
+        # flip augmentation params
+        self.do_flip = do_flip
+        self.h_flip_prob = 0.5
+        self.v_flip_prob = 0.1
+
+        # photometric augmentation params
+        self.photo_aug = ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.5/3.14)
+        self.asymmetric_color_aug_prob = 0.2
+        self.eraser_aug_prob = 0.5
+
+    def color_transform(self, img1, img2):
+        """ Photometric augmentation """
+
+        # asymmetric
+        if np.random.rand() < self.asymmetric_color_aug_prob:
+            img1 = np.array(self.photo_aug(Image.fromarray(img1)), dtype=np.uint8)
+            img2 = np.array(self.photo_aug(Image.fromarray(img2)), dtype=np.uint8)
+
+        # symmetric
+        else:
+            image_stack = np.concatenate([img1, img2], axis=0)
+            image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8)
+            img1, img2 = np.split(image_stack, 2, axis=0)
+
+        return img1, img2
+
+    def eraser_transform(self, img1, img2, bounds=[50, 100]):
+        """ Occlusion augmentation """
+
+        ht, wd = img1.shape[:2]
+        if np.random.rand() < self.eraser_aug_prob:
+            mean_color = np.mean(img2.reshape(-1, 3), axis=0)
+            for _ in range(np.random.randint(1, 3)):
+                x0 = np.random.randint(0, wd)
+                y0 = np.random.randint(0, ht)
+                dx = np.random.randint(bounds[0], bounds[1])
+                dy = np.random.randint(bounds[0], bounds[1])
+                img2[y0:y0+dy, x0:x0+dx, :] = mean_color
+
+        return img1, img2
+
+    def spatial_transform(self, img1, img2, flow):
+        # randomly sample scale
+        ht, wd = img1.shape[:2]
+        min_scale = np.maximum(
+            (self.crop_size[0] + 8) / float(ht), 
+            (self.crop_size[1] + 8) / float(wd))
+
+        scale = 2 ** np.random.uniform(self.min_scale, self.max_scale)
+        scale_x = scale
+        scale_y = scale
+        if np.random.rand() < self.stretch_prob:
+            scale_x *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch)
+            scale_y *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch)
+        
+        scale_x = np.clip(scale_x, min_scale, None)
+        scale_y = np.clip(scale_y, min_scale, None)
+
+        if np.random.rand() < self.spatial_aug_prob:
+            # rescale the images
+            img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
+            img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
+            flow = cv2.resize(flow, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
+            flow = flow * [scale_x, scale_y]
+
+        if self.do_flip:
+            if np.random.rand() < self.h_flip_prob: # h-flip
+                img1 = img1[:, ::-1]
+                img2 = img2[:, ::-1]
+                flow = flow[:, ::-1] * [-1.0, 1.0]
+
+            if np.random.rand() < self.v_flip_prob: # v-flip
+                img1 = img1[::-1, :]
+                img2 = img2[::-1, :]
+                flow = flow[::-1, :] * [1.0, -1.0]
+
+        y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0])
+        x0 = np.random.randint(0, img1.shape[1] - self.crop_size[1])
+        
+        img1 = img1[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
+        img2 = img2[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
+        flow = flow[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
+
+        return img1, img2, flow
+
+    def __call__(self, img1, img2, flow):
+        img1, img2 = self.color_transform(img1, img2)
+        img1, img2 = self.eraser_transform(img1, img2)
+        img1, img2, flow = self.spatial_transform(img1, img2, flow)
+
+        img1 = np.ascontiguousarray(img1)
+        img2 = np.ascontiguousarray(img2)
+        flow = np.ascontiguousarray(flow)
+
+        return img1, img2, flow
+
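+# Usage sketch (illustrative): both augmentors operate on HxWx3 uint8 frames
+# and an HxWx2 float flow field, e.g.
+#
+#   aug = FlowAugmentor(crop_size=(368, 496))
+#   img1, img2, flow = aug(img1, img2, flow)
+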
+class SparseFlowAugmentor:
+    def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=False):
+        # spatial augmentation params
+        self.crop_size = crop_size
+        self.min_scale = min_scale
+        self.max_scale = max_scale
+        self.spatial_aug_prob = 0.8
+        self.stretch_prob = 0.8
+        self.max_stretch = 0.2
+
+        # flip augmentation params
+        self.do_flip = do_flip
+        self.h_flip_prob = 0.5
+        self.v_flip_prob = 0.1
+
+        # photometric augmentation params
+        self.photo_aug = ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.3/3.14)
+        self.asymmetric_color_aug_prob = 0.2
+        self.eraser_aug_prob = 0.5
+        
+    def color_transform(self, img1, img2):
+        image_stack = np.concatenate([img1, img2], axis=0)
+        image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8)
+        img1, img2 = np.split(image_stack, 2, axis=0)
+        return img1, img2
+
+    def eraser_transform(self, img1, img2):
+        ht, wd = img1.shape[:2]
+        if np.random.rand() < self.eraser_aug_prob:
+            mean_color = np.mean(img2.reshape(-1, 3), axis=0)
+            for _ in range(np.random.randint(1, 3)):
+                x0 = np.random.randint(0, wd)
+                y0 = np.random.randint(0, ht)
+                dx = np.random.randint(50, 100)
+                dy = np.random.randint(50, 100)
+                img2[y0:y0+dy, x0:x0+dx, :] = mean_color
+
+        return img1, img2
+
+    def resize_sparse_flow_map(self, flow, valid, fx=1.0, fy=1.0):
+        ht, wd = flow.shape[:2]
+        coords = np.meshgrid(np.arange(wd), np.arange(ht))
+        coords = np.stack(coords, axis=-1)
+
+        coords = coords.reshape(-1, 2).astype(np.float32)
+        flow = flow.reshape(-1, 2).astype(np.float32)
+        valid = valid.reshape(-1).astype(np.float32)
+
+        coords0 = coords[valid>=1]
+        flow0 = flow[valid>=1]
+
+        ht1 = int(round(ht * fy))
+        wd1 = int(round(wd * fx))
+
+        coords1 = coords0 * [fx, fy]
+        flow1 = flow0 * [fx, fy]
+
+        xx = np.round(coords1[:,0]).astype(np.int32)
+        yy = np.round(coords1[:,1]).astype(np.int32)
+
+        v = (xx > 0) & (xx < wd1) & (yy > 0) & (yy < ht1)
+        xx = xx[v]
+        yy = yy[v]
+        flow1 = flow1[v]
+
+        flow_img = np.zeros([ht1, wd1, 2], dtype=np.float32)
+        valid_img = np.zeros([ht1, wd1], dtype=np.int32)
+
+        flow_img[yy, xx] = flow1
+        valid_img[yy, xx] = 1
+
+        return flow_img, valid_img
+
+    def spatial_transform(self, img1, img2, flow, valid):
+        # randomly sample scale
+
+        ht, wd = img1.shape[:2]
+        min_scale = np.maximum(
+            (self.crop_size[0] + 1) / float(ht), 
+            (self.crop_size[1] + 1) / float(wd))
+
+        scale = 2 ** np.random.uniform(self.min_scale, self.max_scale)
+        scale_x = np.clip(scale, min_scale, None)
+        scale_y = np.clip(scale, min_scale, None)
+
+        if np.random.rand() < self.spatial_aug_prob:
+            # rescale the images
+            img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
+            img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
+            flow, valid = self.resize_sparse_flow_map(flow, valid, fx=scale_x, fy=scale_y)
+
+        if self.do_flip:
+            if np.random.rand() < 0.5: # h-flip
+                img1 = img1[:, ::-1]
+                img2 = img2[:, ::-1]
+                flow = flow[:, ::-1] * [-1.0, 1.0]
+                valid = valid[:, ::-1]
+
+        margin_y = 20
+        margin_x = 50
+
+        y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0] + margin_y)
+        x0 = np.random.randint(-margin_x, img1.shape[1] - self.crop_size[1] + margin_x)
+
+        y0 = np.clip(y0, 0, img1.shape[0] - self.crop_size[0])
+        x0 = np.clip(x0, 0, img1.shape[1] - self.crop_size[1])
+
+        img1 = img1[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
+        img2 = img2[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
+        flow = flow[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
+        valid = valid[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
+        return img1, img2, flow, valid
+
+
+    def __call__(self, img1, img2, flow, valid):
+        img1, img2 = self.color_transform(img1, img2)
+        img1, img2 = self.eraser_transform(img1, img2)
+        img1, img2, flow, valid = self.spatial_transform(img1, img2, flow, valid)
+
+        img1 = np.ascontiguousarray(img1)
+        img2 = np.ascontiguousarray(img2)
+        flow = np.ascontiguousarray(flow)
+        valid = np.ascontiguousarray(valid)
+
+        return img1, img2, flow, valid
diff --git a/windflow/networks/raft/utils/flow_viz.py b/windflow/networks/raft/utils/flow_viz.py
new file mode 100644
index 0000000000000000000000000000000000000000..dcee65e89b91b07ee0496aeb4c7e7436abf99641
--- /dev/null
+++ b/windflow/networks/raft/utils/flow_viz.py
@@ -0,0 +1,132 @@
+# Flow visualization code used from https://github.com/tomrunia/OpticalFlow_Visualization
+
+
+# MIT License
+#
+# Copyright (c) 2018 Tom Runia
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to conditions.
+#
+# Author: Tom Runia
+# Date Created: 2018-08-03
+
+import numpy as np
+
+def make_colorwheel():
+    """
+    Generates a color wheel for optical flow visualization as presented in:
+        Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007)
+        URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf
+
+    Code follows the original C++ source code of Daniel Scharstein
+    and the Matlab source code of Deqing Sun.
+
+    Returns:
+        np.ndarray: Color wheel
+    """
+
+    RY = 15
+    YG = 6
+    GC = 4
+    CB = 11
+    BM = 13
+    MR = 6
+
+    ncols = RY + YG + GC + CB + BM + MR
+    colorwheel = np.zeros((ncols, 3))
+    col = 0
+
+    # RY
+    colorwheel[0:RY, 0] = 255
+    colorwheel[0:RY, 1] = np.floor(255*np.arange(0,RY)/RY)
+    col = col+RY
+    # YG
+    colorwheel[col:col+YG, 0] = 255 - np.floor(255*np.arange(0,YG)/YG)
+    colorwheel[col:col+YG, 1] = 255
+    col = col+YG
+    # GC
+    colorwheel[col:col+GC, 1] = 255
+    colorwheel[col:col+GC, 2] = np.floor(255*np.arange(0,GC)/GC)
+    col = col+GC
+    # CB
+    colorwheel[col:col+CB, 1] = 255 - np.floor(255*np.arange(CB)/CB)
+    colorwheel[col:col+CB, 2] = 255
+    col = col+CB
+    # BM
+    colorwheel[col:col+BM, 2] = 255
+    colorwheel[col:col+BM, 0] = np.floor(255*np.arange(0,BM)/BM)
+    col = col+BM
+    # MR
+    colorwheel[col:col+MR, 2] = 255 - np.floor(255*np.arange(MR)/MR)
+    colorwheel[col:col+MR, 0] = 255
+    return colorwheel
+
+
+def flow_uv_to_colors(u, v, convert_to_bgr=False):
+    """
+    Applies the flow color wheel to (possibly clipped) flow components u and v.
+
+    According to the C++ source code of Daniel Scharstein
+    and the Matlab source code of Deqing Sun.
+
+    Args:
+        u (np.ndarray): Input horizontal flow of shape [H,W]
+        v (np.ndarray): Input vertical flow of shape [H,W]
+        convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.
+
+    Returns:
+        np.ndarray: Flow visualization image of shape [H,W,3]
+    """
+    flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8)
+    colorwheel = make_colorwheel()  # shape [55x3]
+    ncols = colorwheel.shape[0]
+    rad = np.sqrt(np.square(u) + np.square(v))
+    a = np.arctan2(-v, -u)/np.pi
+    fk = (a+1) / 2*(ncols-1)
+    k0 = np.floor(fk).astype(np.int32)
+    k1 = k0 + 1
+    k1[k1 == ncols] = 0
+    f = fk - k0
+    for i in range(colorwheel.shape[1]):
+        tmp = colorwheel[:,i]
+        col0 = tmp[k0] / 255.0
+        col1 = tmp[k1] / 255.0
+        col = (1-f)*col0 + f*col1
+        idx = (rad <= 1)
+        col[idx]  = 1 - rad[idx] * (1-col[idx])
+        col[~idx] = col[~idx] * 0.75   # out of range
+        # Note the 2-i => BGR instead of RGB
+        ch_idx = 2-i if convert_to_bgr else i
+        flow_image[:,:,ch_idx] = np.floor(255 * col)
+    return flow_image
+
+
+def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False):
+    """
+    Expects a two dimensional flow image of shape [H,W,2].
+
+    Args:
+        flow_uv (np.ndarray): Flow UV image of shape [H,W,2]
+        clip_flow (float, optional): Clip maximum of flow values. Defaults to None.
+        convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.
+
+    Returns:
+        np.ndarray: Flow visualization image of shape [H,W,3]
+    """
+    assert flow_uv.ndim == 3, 'input flow must have three dimensions'
+    assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]'
+    if clip_flow is not None:
+        flow_uv = np.clip(flow_uv, 0, clip_flow)
+    u = flow_uv[:,:,0]
+    v = flow_uv[:,:,1]
+    rad = np.sqrt(np.square(u) + np.square(v))
+    rad_max = np.max(rad)
+    epsilon = 1e-5
+    u = u / (rad_max + epsilon)
+    v = v / (rad_max + epsilon)
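+    # normalize by the maximum magnitude so hue encodes direction and
+    # saturation encodes relative speed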
+    return flow_uv_to_colors(u, v, convert_to_bgr)
\ No newline at end of file
diff --git a/windflow/networks/raft/utils/frame_utils.py b/windflow/networks/raft/utils/frame_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..6c491135efaffc25bd61ec3ecde99d236f5deb12
--- /dev/null
+++ b/windflow/networks/raft/utils/frame_utils.py
@@ -0,0 +1,137 @@
+import numpy as np
+from PIL import Image
+from os.path import splitext
+import re
+
+import cv2
+cv2.setNumThreads(0)
+cv2.ocl.setUseOpenCL(False)
+
+TAG_CHAR = np.array([202021.25], np.float32)
+
+def readFlow(fn):
+    """ Read .flo file in Middlebury format"""
+    # Code adapted from:
+    # http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy
+
+    # WARNING: this will work on little-endian architectures (eg Intel x86) only!
+    # print 'fn = %s'%(fn)
+    with open(fn, 'rb') as f:
+        magic = np.fromfile(f, np.float32, count=1)
+        if 202021.25 != magic:
+            print('Magic number incorrect. Invalid .flo file')
+            return None
+        else:
+            w = np.fromfile(f, np.int32, count=1)
+            h = np.fromfile(f, np.int32, count=1)
+            # print 'Reading %d x %d flo file\n' % (w, h)
+            data = np.fromfile(f, np.float32, count=2*int(w)*int(h))
+            # Reshape data into 3D array (columns, rows, bands)
+            # The reshape here is for visualization, the original code is (w,h,2)
+            return np.resize(data, (int(h), int(w), 2))
+
+def readPFM(file):
+    file = open(file, 'rb')
+
+    color = None
+    width = None
+    height = None
+    scale = None
+    endian = None
+
+    header = file.readline().rstrip()
+    if header == b'PF':
+        color = True
+    elif header == b'Pf':
+        color = False
+    else:
+        raise Exception('Not a PFM file.')
+
+    dim_match = re.match(rb'^(\d+)\s(\d+)\s$', file.readline())
+    if dim_match:
+        width, height = map(int, dim_match.groups())
+    else:
+        raise Exception('Malformed PFM header.')
+
+    scale = float(file.readline().rstrip())
+    if scale < 0: # little-endian
+        endian = '<'
+        scale = -scale
+    else:
+        endian = '>' # big-endian
+
+    data = np.fromfile(file, endian + 'f')
+    shape = (height, width, 3) if color else (height, width)
+
+    data = np.reshape(data, shape)
+    data = np.flipud(data)
+    return data
+
+def writeFlow(filename,uv,v=None):
+    """ Write optical flow to file.
+    
+    If v is None, uv is assumed to contain both u and v channels,
+    stacked in depth.
+    Original code by Deqing Sun, adapted from Daniel Scharstein.
+    """
+    nBands = 2
+
+    if v is None:
+        assert(uv.ndim == 3)
+        assert(uv.shape[2] == 2)
+        u = uv[:,:,0]
+        v = uv[:,:,1]
+    else:
+        u = uv
+
+    assert(u.shape == v.shape)
+    height,width = u.shape
+    f = open(filename,'wb')
+    # write the header
+    f.write(TAG_CHAR)
+    np.array(width).astype(np.int32).tofile(f)
+    np.array(height).astype(np.int32).tofile(f)
+    # arrange into matrix form
+    tmp = np.zeros((height, width*nBands))
+    tmp[:,np.arange(width)*2] = u
+    tmp[:,np.arange(width)*2 + 1] = v
+    tmp.astype(np.float32).tofile(f)
+    f.close()
+
+
+def readFlowKITTI(filename):
+    flow = cv2.imread(filename, cv2.IMREAD_ANYDEPTH|cv2.IMREAD_COLOR)
+    flow = flow[:,:,::-1].astype(np.float32)
+    flow, valid = flow[:, :, :2], flow[:, :, 2]
+    flow = (flow - 2**15) / 64.0
+    return flow, valid
+
+def readDispKITTI(filename):
+    disp = cv2.imread(filename, cv2.IMREAD_ANYDEPTH) / 256.0
+    valid = disp > 0.0
+    flow = np.stack([-disp, np.zeros_like(disp)], -1)
+    return flow, valid
+
+
+def writeFlowKITTI(filename, uv):
+    uv = 64.0 * uv + 2**15
+    valid = np.ones([uv.shape[0], uv.shape[1], 1])
+    uv = np.concatenate([uv, valid], axis=-1).astype(np.uint16)
+    cv2.imwrite(filename, uv[..., ::-1])
+    
+
+def read_gen(file_name, pil=False):
+    ext = splitext(file_name)[-1]
+    if ext == '.png' or ext == '.jpeg' or ext == '.ppm' or ext == '.jpg':
+        return Image.open(file_name)
+    elif ext == '.bin' or ext == '.raw':
+        return np.load(file_name)
+    elif ext == '.flo':
+        return readFlow(file_name).astype(np.float32)
+    elif ext == '.pfm':
+        flow = readPFM(file_name).astype(np.float32)
+        if len(flow.shape) == 2:
+            return flow
+        else:
+            return flow[:, :, :-1]
+    return []
\ No newline at end of file
diff --git a/windflow/networks/raft/utils/utils.py b/windflow/networks/raft/utils/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..5f32d281c1c46353a0a2bf36b0550adb74125c65
--- /dev/null
+++ b/windflow/networks/raft/utils/utils.py
@@ -0,0 +1,82 @@
+import torch
+import torch.nn.functional as F
+import numpy as np
+from scipy import interpolate
+
+
+class InputPadder:
+    """ Pads images such that dimensions are divisible by 8 """
+    def __init__(self, dims, mode='sintel'):
+        self.ht, self.wd = dims[-2:]
+        pad_ht = (((self.ht // 8) + 1) * 8 - self.ht) % 8
+        pad_wd = (((self.wd // 8) + 1) * 8 - self.wd) % 8
+        if mode == 'sintel':
+            self._pad = [pad_wd//2, pad_wd - pad_wd//2, pad_ht//2, pad_ht - pad_ht//2]
+        else:
+            self._pad = [pad_wd//2, pad_wd - pad_wd//2, 0, pad_ht]
+
+    def pad(self, *inputs):
+        return [F.pad(x, self._pad, mode='replicate') for x in inputs]
+
+    def unpad(self,x):
+        ht, wd = x.shape[-2:]
+        c = [self._pad[2], ht-self._pad[3], self._pad[0], wd-self._pad[1]]
+        return x[..., c[0]:c[1], c[2]:c[3]]
+
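+# Usage sketch (illustrative): pad a frame pair to a multiple of 8 before the
+# network, then unpad the returned flow.
+#
+#   padder = InputPadder(img1.shape)
+#   img1, img2 = padder.pad(img1, img2)
+#   flow = padder.unpad(flow)
+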
+def forward_interpolate(flow):
+    flow = flow.detach().cpu().numpy()
+    dx, dy = flow[0], flow[1]
+
+    ht, wd = dx.shape
+    x0, y0 = np.meshgrid(np.arange(wd), np.arange(ht))
+
+    x1 = x0 + dx
+    y1 = y0 + dy
+    
+    x1 = x1.reshape(-1)
+    y1 = y1.reshape(-1)
+    dx = dx.reshape(-1)
+    dy = dy.reshape(-1)
+
+    valid = (x1 > 0) & (x1 < wd) & (y1 > 0) & (y1 < ht)
+    x1 = x1[valid]
+    y1 = y1[valid]
+    dx = dx[valid]
+    dy = dy[valid]
+
+    flow_x = interpolate.griddata(
+        (x1, y1), dx, (x0, y0), method='nearest', fill_value=0)
+
+    flow_y = interpolate.griddata(
+        (x1, y1), dy, (x0, y0), method='nearest', fill_value=0)
+
+    flow = np.stack([flow_x, flow_y], axis=0)
+    return torch.from_numpy(flow).float()
+
+
+def bilinear_sampler(img, coords, mode='bilinear', mask=False):
+    """ Wrapper for grid_sample, uses pixel coordinates """
+    H, W = img.shape[-2:]
+    xgrid, ygrid = coords.split([1,1], dim=-1)
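+    # convert pixel coordinates to the [-1, 1] range expected by grid_sample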
+    xgrid = 2*xgrid/(W-1) - 1
+    ygrid = 2*ygrid/(H-1) - 1
+
+    grid = torch.cat([xgrid, ygrid], dim=-1)
+    img = F.grid_sample(img, grid, align_corners=True)
+
+    if mask:
+        mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1)
+        return img, mask.float()
+
+    return img
+
+
+def coords_grid(batch, ht, wd):
+    coords = torch.meshgrid(torch.arange(ht), torch.arange(wd))
+    coords = torch.stack(coords[::-1], dim=0).float()
+    return coords[None].repeat(batch, 1, 1, 1)
+
+
+def upflow8(flow, mode='bilinear'):
+    new_size = (8 * flow.shape[2], 8 * flow.shape[3])
+    return  8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True)
diff --git a/windflow/networks/utils.py b/windflow/networks/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..141799a72d384b4ba838f86ac95026aa54199fb8
--- /dev/null
+++ b/windflow/networks/utils.py
@@ -0,0 +1,25 @@
+import numpy as np
+
+def split_array(arr, tile_size=128, overlap=16):
+    '''
+    Split a 3D numpy array into patches for inference
+    (Channels, Height, Width)
+    Args:
+        tile_size: width and height of patches to return
+        overlap: number of pixels to overlap between patches
+    Returns:
+        dict(patches, upper_left): patches and indices of original array
+    '''
+    arr = arr[np.newaxis]
+    height, width = arr.shape[2:4]
+    arrs = dict(patches=[], upper_left=[])
+    for i in range(0, height, tile_size - overlap):
+        for j in range(0, width, tile_size - overlap):
+            i = min(i, height - tile_size)
+            j = min(j, width - tile_size)
+            arrs['patches'].append(arr[:,:,i:i+tile_size,j:j+tile_size])
+            arrs['upper_left'].append([[i,j]])
+    arrs['patches'] = np.concatenate(arrs['patches'])
+    arrs['upper_left'] = np.concatenate(arrs['upper_left'])
+    return arrs['patches'], arrs['upper_left']
+
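+# Usage sketch (illustrative; `infer` and `out` are placeholders): run a model
+# patch-wise and write each result back at its upper-left index.
+#
+#   patches, upper_left = split_array(arr, tile_size=128, overlap=16)
+#   for patch, (i, j) in zip(patches, upper_left):
+#       out[:, i:i+128, j:j+128] = infer(patch)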
diff --git a/windflow/networks/warper.py b/windflow/networks/warper.py
new file mode 100644
index 0000000000000000000000000000000000000000..958c21258cccbd693a5022cd97f237e479530cb8
--- /dev/null
+++ b/windflow/networks/warper.py
@@ -0,0 +1,43 @@
+import numpy as np
+import torch
+import torch.nn as nn
+
+class BackwardWarp(nn.Module):
+    def __init__(self):
+        super(BackwardWarp, self).__init__()
+        #self.device = device
+    # end
+
+    def forward(self, tensorInput, tensorFlow):
+        if not hasattr(self, 'tensorGrid') or self.tensorGrid.size(0) != tensorFlow.size(0) or self.tensorGrid.size(2) != tensorFlow.size(2) or self.tensorGrid.size(3) != tensorFlow.size(3):
+            tensorHorizontal = torch.linspace(-1.0, 1.0, tensorFlow.size(3)).view(1, 1, 1, tensorFlow.size(3)).expand(tensorFlow.size(0), -1, tensorFlow.size(2), -1)
+            tensorVertical = torch.linspace(-1.0, 1.0, tensorFlow.size(2)).view(1, 1, tensorFlow.size(2), 1).expand(tensorFlow.size(0), -1, -1, tensorFlow.size(3))
+
+            self.tensorGrid = torch.cat([ tensorHorizontal, tensorVertical ], 1).cuda()
+        # end
+
+        tensorFlow = torch.cat([ tensorFlow[:, 0:1, :, :] / ((tensorInput.size(3) - 1.0) / 2.0),
+                                 tensorFlow[:, 1:2, :, :] / ((tensorInput.size(2) - 1.0) / 2.0)], 1)
+        out = torch.nn.functional.grid_sample(input=tensorInput, grid=(self.tensorGrid + tensorFlow).permute(0, 2, 3, 1),
+                                              mode='bilinear', padding_mode='border')
+        return out
+    # end
+
+class BackwardWarpMV(nn.Module):
+    def __init__(self):
+        super(BackwardWarpMV, self).__init__()
+        self.warper = BackwardWarp()
+
+    def forward(self, tensorInput, tensorFlow):
+        num_channels = tensorFlow.size(1)//2
+        U = tensorFlow[:,:num_channels]
+        V = tensorFlow[:,num_channels:]
+        out = []
+        for j in range(num_channels):
+            #uv = tensorFlow[:,[j,j+num_channels],:,:]
+            uv = torch.cat([U[:,j:j+1], V[:,j:j+1]], 1)
+            out_j = self.warper(tensorInput[:,j:j+1], uv)
+            out.append(out_j)
+
+        return torch.cat(out, 1)
+    
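+
+# Usage sketch (illustrative): warp frame I1 back toward frame I0 with a flow
+# field in pixels; inputs are assumed to be CUDA tensors since the sampling
+# grid is built with .cuda().
+#
+#   warp = BackwardWarp()
+#   I0_hat = warp(I1, flow)   # I1: (N, C, H, W), flow: (N, 2, H, W)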
diff --git a/windflow/train/__init__.py b/windflow/train/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/windflow/train/base.py b/windflow/train/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..637316d6d04889c414c62c0a5684fbf6418daaf5
--- /dev/null
+++ b/windflow/train/base.py
@@ -0,0 +1,115 @@
+import os, sys
+
+import torch
+import torch.nn as nn
+from torch import optim
+from torch.utils import data
+from torch.nn.parallel import DistributedDataParallel as DDP
+import torch.distributed as dist
+import torch.multiprocessing as mp
+from torch.utils.tensorboard import SummaryWriter # there is a bug with summarywriter importing
+
+import torchvision
+
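+# min-max scale a tensor to [0, 1] so TensorBoard renders it as an image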
+def scale_image(x):
+    xmn = torch.min(x)
+    xmx = torch.max(x)
+    return (x - xmn) / (xmx - xmn)
+
+class BaseTrainer(nn.Module):
+    def __init__(self, model, 
+                 model_name, 
+                 model_path,
+                 lr=1e-4,
+                 device=None,
+                 distribute=False,
+                 rank=0):
+        super(BaseTrainer, self).__init__()
+        self.model = model
+        self.model_name = model_name
+        self.model_path = model_path
+        self.lr = lr
+        self.device = device
+        self.distribute = distribute
+        self.rank = rank
+        
+        # NEW
+        self.scaler = torch.cuda.amp.GradScaler()
+            
+        self.checkpoint_filepath = os.path.join(model_path, 'checkpoint.pth.tar')
+        if (rank == 0) and (not os.path.exists(model_path)):
+            os.makedirs(model_path)
+        
+        self.global_step = 0
+        self._set_optimizer()
+        self._set_summary_writer()
+
+        
+    def _set_optimizer(self):
+        # set optimizer
+        #if self.rank == 0:
+        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr, weight_decay=1e-4)
+
+
+    def _set_summary_writer(self):
+        self.tfwriter_train = SummaryWriter(os.path.join(self.model_path, 'train', 'tfsummary'))
+        self.tfwriter_valid = SummaryWriter(os.path.join(self.model_path, 'valid', 'tfsummary'))
+
+    def load_checkpoint(self):
+        filename = self.checkpoint_filepath
+        if os.path.isfile(filename):
+            print("loading checkpoint %s" % filename)
+            checkpoint = torch.load(filename)
+            self.global_step = checkpoint['global_step']
+            try:
+                self.model.module.load_state_dict(checkpoint['model'])
+            except AttributeError:
+                # model is not wrapped in DistributedDataParallel
+                self.model.load_state_dict(checkpoint['model'])
+                
+            self.optimizer.load_state_dict(checkpoint['optimizer'])
+            print("=> loaded checkpoint '{}' (Step {})"
+                    .format(filename, self.global_step))
+        else:
+            print("=> no checkpoint found at '{}'".format(filename))
+        
+    def save_checkpoint(self):
+        if self.distribute:
+            state = {'global_step': self.global_step, 
+                     'model': self.model.module.state_dict(),
+                     'optimizer': self.optimizer.state_dict()}
+        else:
+            state = {'global_step': self.global_step, 
+                     'model': self.model.state_dict(),
+                     'optimizer': self.optimizer.state_dict()}
+        torch.save(state, self.checkpoint_filepath)        
+
+    def log_tensorboard(self):
+        pass
+
+    def get_tfwriter(self, train):
+        if train:
+            return self.tfwriter_train
+        else:
+            return self.tfwriter_valid
+        
+    def log_scalar(self, x, name, train=True):
+        tfwriter = self.get_tfwriter(train)
+        tfwriter.add_scalar(name, x, self.global_step)
+    
+    def log_image_grid(self, img, name, train=True, N=4):
+        '''
+        img of shape (N, C, H, W)
+        '''
+        tfwriter = self.get_tfwriter(train)
+        img_grid = torchvision.utils.make_grid(img[:N])
+        tfwriter.add_image(name, scale_image(img_grid), self.global_step)
+        
+    def log_flow_grid(self, flows, name, train=True, N=4):
+        tfwriter = self.get_tfwriter(train)
+        U_grid = torchvision.utils.make_grid(flows[:N,:1])
+        V_grid = torchvision.utils.make_grid(flows[:N,1:])
+        intensity = (U_grid ** 2 + V_grid ** 2)**0.5
+        tfwriter.add_image(f'{name}/U', scale_image(U_grid), self.global_step)
+        tfwriter.add_image(f'{name}/V', scale_image(V_grid), self.global_step)
+        tfwriter.add_image(f'{name}/intensity', scale_image(intensity), self.global_step)
+        
\ No newline at end of file
diff --git a/windflow/train/flownet_trainer.py b/windflow/train/flownet_trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..18a7f5899d82bcafdbd8b9dd0a362337d1644741
--- /dev/null
+++ b/windflow/train/flownet_trainer.py
@@ -0,0 +1,118 @@
+import os
+import numpy as np
+
+import torch
+import torch.nn as nn
+from torch import optim
+import torchvision 
+
+from .base import BaseTrainer
+from ..networks import FlowNetS
+from ..losses import CharbonnierLoss, L1Loss, L2Loss, RMSVDLoss
+
+from torch.utils.tensorboard import SummaryWriter
+
+class MultiFrameTrainer(BaseTrainer):
+    def __init__(self,
+                 model, 
+                 model_name, 
+                 model_path, 
+                 lr=1e-4,
+                 device=None,
+                 distribute=False,
+                 rank=0):
+        BaseTrainer.__init__(self, model, model_name, model_path, lr=lr, 
+                             device=device, distribute=distribute, rank=rank)
+        
+        # set loss functions
+        self.photo_loss = L2Loss(None)
+        
+    def step(self, inputs, labels, log=False, train=True):
+        '''
+        Inputs shape: (N, T, C, H, W)
+        Flows shape: (N, T, 2, H, W)
+        '''
+        
+        N_pairs = len(inputs) - 1  # number of consecutive frame pairs
+        # for every pair, compute flow and accumulate the loss (unimplemented)
+        raise NotImplementedError("MultiFrameTrainer.step is not implemented")
+
+class FlownetTrainer(BaseTrainer):
+    def __init__(self,
+                 model, 
+                 model_name, 
+                 model_path, 
+                 lr=1e-4,
+                 device=None,
+                 distribute=False,
+                 rank=0,
+                 loss='L2'):
+        BaseTrainer.__init__(self, model, model_name, model_path, lr=lr, 
+                             device=device, distribute=distribute, rank=rank)
+
+        # set loss functions
+        if loss == 'L2':
+            self.photo_loss = L2Loss(None)
+        elif loss.lower() == 'rmsvd':
+            print("RMSVD Loss")
+            self.photo_loss = RMSVDLoss()
+        elif loss == 'CHAR':
+            self.photo_loss = CharbonnierLoss(alpha=0.50)
+        elif loss == 'L1':
+            self.photo_loss = L1Loss(None)
+        else:
+            raise ValueError(f"Unknown loss function: {loss}")
+
+        #self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
+        #    self.optimizer, 'min', factor=0.5, patience=100)
+    #def step(self, I0, I1, UV, log=False, train=True):
+    def step(self, inputs, labels, log=False, train=True):
+        '''
+        Inputs shape: (N, T, C, H, W)
+        Flows shape: (N, T, 2, H, W)
+        '''
+        I0 = inputs[:,0]
+        I1 = inputs[:,1]
+        x = torch.cat([I0, I1], 1)
+        
+        #with torch.cuda.amp.autocast():
+        #flows = self.model(I0, I1)
+        flows = self.model(x)
+        #if not isinstance(flows, list):
+        #    flows = [flows]
+
+        UV_l = labels[:,0]
+        losses = []
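+        # per-level loss weights: weights[layer] scales the loss of flows[layer]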
+        weights = [0.32, 0.16, 0.08, 0.04, 0.02, 0.005]
+
+        for layer, flow_l in enumerate(flows):
+            H = flow_l.shape[2]
+            scale_factor = flow_l.shape[2] / UV_l.shape[2]
+            UV_l = torch.nn.functional.interpolate(UV_l, scale_factor=scale_factor, 
+                                                   mode='bilinear', recompute_scale_factor=True)
+            losses.append(weights[layer] * self.photo_loss(flow_l, UV_l))
+            if log:
+                self.log_flow_grid(flow_l, f'flow_level_{layer}', train=train)
+
+        total_loss = sum(losses)
+
+        if train:
+            self.optimizer.zero_grad()
+            total_loss.backward()
+            #self.scaler.scale(total_loss).backward()
+            self.optimizer.step()
+            #self.scaler.step(self.optimizer)
+            #self.scaler.update()
+        #else:
+        #    self.scheduler.step(total_loss)
+
+        if log and (self.rank == 0):
+            self.log_scalar(total_loss, "total_loss", train=train)
+            self.log_image_grid(I0, "data/I0", train=train)
+            self.log_image_grid(I1, "data/I1", train=train)
+            self.log_flow_grid(labels[:,0], "label", train=train)
+            for i in range(len(losses)):
+                self.log_scalar(losses[i], f"losses/level_{i}", train=train)
+                self.log_flow_grid(flows[i], f"level_{i}", train=train)
+
+        if train:
+            self.global_step += 1
+
+        return total_loss
diff --git a/windflow/train/raft_trainer.py b/windflow/train/raft_trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..f2486db6263550fecb07febf66cc04cadf81753d
--- /dev/null
+++ b/windflow/train/raft_trainer.py
@@ -0,0 +1,158 @@
+import os
+import numpy as np
+
+import torch
+import torch.nn as nn
+from torch import optim
+#from torch.nn.parallel import DistributedDataParallel as DDP
+import torchvision 
+
+from .base import BaseTrainer
+from ..networks import raft, warper
+from ..losses import CharbonnierLoss, L1Loss, L2Loss
+
+from torch.utils.tensorboard import SummaryWriter
+
+MAX_FLOW = 200
+
+def sequence_loss(flow_preds, flow_gt, gamma=0.8, max_flow=MAX_FLOW):
+    """ Loss function defined over sequence of flow predictions """
+
+    n_predictions = len(flow_preds) 
+    flow_loss = 0.0
+
+    # exclude invalid pixels and extremely large displacements
+    mag = torch.sum(flow_gt**2, dim=1).sqrt()
+    #valid = (valid >= 0.5) & (mag < max_flow)
+    valid = (mag < max_flow)
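+    # gamma-weighting makes later iterations dominate: with 12 predictions and
+    # gamma=0.8 the last gets weight 1.0 and the first 0.8**11 ~= 0.086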
+    for i in range(n_predictions):
+        i_weight = gamma**(n_predictions - i - 1)
+        i_loss = (flow_preds[i] - flow_gt).abs()
+        flow_loss += i_weight * (valid[:, None] * i_loss).mean()
+        #flow_loss += (i_weight * i_loss).mean()
+
+    return flow_loss
+
+
+def unsupervised_sequence_loss(I0, I1, flow_pred, gamma=0.8):
+    """Loss function unsupervised warping I1 to I0 with predicted flows"""
+    warp = warper.BackwardWarp()
+    loss = CharbonnierLoss(alpha=0.50)
+    
+    n_predictions = len(flow_pred)
+    flow_loss = 0.0
+    
+    for i in range(n_predictions):
+        I0_i = warp(I1, flow_pred[i])
+        i_weight = gamma**(n_predictions - i - 1)
+        i_loss = loss(I0, I0_i)
+        flow_loss +=  (i_weight * i_loss).mean()
+
+    return flow_loss
+    
+
+class RAFTTrainer(BaseTrainer):
+    def __init__(self, model, model_path, iters=24, 
+                 device=None, lr=1e-5, distribute=False, 
+                 clip=1., rank=0):
+        BaseTrainer.__init__(self, model, 'raft', model_path, lr=lr, 
+                             device=device, distribute=distribute, rank=rank)        
+        self.iters = iters
+        self.clip = clip
+        self.optimizer = optim.AdamW(model.parameters(), lr=lr, 
+                                     weight_decay=1e-4, eps=1e-8)
+
+        self.scheduler = optim.lr_scheduler.OneCycleLR(self.optimizer, 
+                                                       max_lr=lr, 
+                                                       total_steps=500000,
+                                                       pct_start=0.05, 
+                                                       cycle_momentum=False, 
+                                                       anneal_strategy='linear')
+    
+        
+    #def step(self, I0, I1, flow_gt, flow_init=None, log=False, train=True):
+    def step(self, inputs, labels, flow_init=None, log=False, train=True):
+        if train:
+            tfwriter = self.tfwriter_train
+        else:
+            tfwriter = self.tfwriter_valid
+ 
+        with torch.cuda.amp.autocast():
+            I0 = inputs[:,0]
+            I1 = inputs[:,1]
+            flow_predictions = self.model(I0, I1, iters=self.iters, flow_init=flow_init)
+            loss = sequence_loss(flow_predictions, labels[:,0])
+
+            if train:
+                self.optimizer.zero_grad()
+                loss.backward()
+                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.clip)
+
+                self.optimizer.step()
+                #self.scaler.step(self.optimizer)
+                #self.scaler.update()
+                self.scheduler.step()
+ 
+
+            if log and (self.rank == 0):
+                self.log_scalar(loss, 'total_loss', train)
+                self.log_flow_grid(labels[:,0], 'label', train)
+                self.log_image_grid(I0, 'data/I0', train)
+                self.log_image_grid(I1, 'data/I1', train)
+ 
+                for i in range(0, self.iters, 2):
+                    self.log_flow_grid(flow_predictions[i], f'flows/iter_{i}', train)    
+            if train:
+                self.global_step = self.global_step + 1 
+        return loss
+
+    
+class RAFTGuided(BaseTrainer):
+    def __init__(self, model, model_path, iters=10, lambda_obs=0.1,
+                 device=None, lr=1e-4, distribute=False, clip=1., rank=0):
+        BaseTrainer.__init__(self, model, 'raft', model_path, lr=lr, 
+                             device=device, distribute=distribute, rank=rank)        
+        self.iters = iters
+        self.clip = clip
+        self.lambda_obs = lambda_obs
+        
+    def step(self, I0_obs, I1_obs, I0_phys, I1_phys, flow_phys, train=True, log=None):
+        if train:
+            tfwriter = self.tfwriter_train
+        else:
+            tfwriter = self.tfwriter_valid
+            
+        with torch.cuda.amp.autocast():
+            flow_predictions_phys = self.model(I0_phys, I1_phys, iters=self.iters)
+            loss_guide = sequence_loss(flow_predictions_phys, flow_phys)
+            
+            #flow_predictions_obs = self.model(I0_obs, I1_obs, iters=self.iters)
+            #loss_obs = unsupervised_sequence_loss(I0_obs, I1_obs, flow_predictions_obs)
+            print('loss', loss_guide)
+            loss = loss_guide # + loss_obs * self.lambda_obs
+            
+            if train:
+                self.optimizer.zero_grad()
+                loss.backward()
+                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.clip)
+
+                self.optimizer.step()
+                #self.scaler.step(self.optimizer)
+                #self.scaler.update()
+                
+
+            if log and (self.rank == 0):
+                self.log_scalar(loss_guide, 'loss_guide', train)
+                self.log_scalar(loss, 'total_loss', train)
+                self.log_flow_grid(flow_phys, 'label', train)
+                self.log_image_grid(I0_phys, 'data/I0', train)
+                self.log_image_grid(I1_phys, 'data/I1', train)
+
+                for i in range(0, self.iters, 2):
+                    self.log_flow_grid(flow_predictions_phys[i], f'flows/iter_{i}', train)
+                    
+            if train:
+                self.global_step = self.global_step + 1
+                
+        return loss
diff --git a/windflow/train/trainers.py b/windflow/train/trainers.py
new file mode 100644
index 0000000000000000000000000000000000000000..97d3b8076516f94af7e78745d386ff0621382cd9
--- /dev/null
+++ b/windflow/train/trainers.py
@@ -0,0 +1,16 @@
+from .raft_trainer import RAFTTrainer, RAFTGuided
+from .flownet_trainer import FlownetTrainer, MultiFrameTrainer
+
+
+def get_flow_trainer(model, model_name, model_path, distribute=False, rank=0, lr=1e-4, loss='L2'):
+    # Select trainer
+    if model_name.lower() == 'raft':
+        trainer = RAFTTrainer(model, model_path, distribute=distribute, 
+                              rank=rank, lr=lr, iters=24)
+    elif model_name.lower() in ['flownets', 'pwc-net', 'flownet2', 'flownetc', 'flownetsd', 'pwcnet', 'maskflownet']:
+        trainer = FlownetTrainer(model, model_name, model_path,
+                                 distribute=distribute, rank=rank, lr=lr,
+                                 loss=loss)
+        # trainer = MultiFrameTrainer(model, model_name, model_path,
+        #                          distribute=distribute, rank=rank, lr=lr)
+    else:
+        raise ValueError(f"Unknown model: {model_name}")
+    return trainer
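+
+# Usage sketch (illustrative; the model path is a placeholder):
+#   trainer = get_flow_trainer(model, 'raft', 'models/raft-demo/', lr=1e-4)
+#   loss = trainer.step(inputs, labels, log=True, train=True)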
diff --git a/windflow/utils/__init__.py b/windflow/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/windflow/utils/plotting.py b/windflow/utils/plotting.py
new file mode 100644
index 0000000000000000000000000000000000000000..e1e20563e3797de12d5935626c23b337d31027f5
--- /dev/null
+++ b/windflow/utils/plotting.py
@@ -0,0 +1,98 @@
+import matplotlib.pyplot as plt
+import torch
+import torch.nn
+import numpy as np
+from mpl_toolkits.basemap import Basemap
+from metpy.plots import colortables
+import cv2
+
+def downsample(x, pool=25):
+    x = torch.from_numpy(x[np.newaxis])
+    xlr = torch.nn.AvgPool2d(pool, stride=pool)(x)
+    return xlr.detach().numpy()[0]
+
+def flow_quiver_plot(u, v, x=None, y=None, ax=None, 
+                     down=25, vmin=None, vmax=None, latlon=False,
+                     size=10, cmap='jet', colorbar=False):
+    
+    intensity = (u**2 + v**2) ** 0.5
+
+    u_l = downsample(u, down)
+    v_l = downsample(v, down)
+
+    intensity_l = (u_l ** 2 + v_l ** 2) ** 0.5
+    # optionally normalize arrows to unit length:
+    #u_l = u_l / intensity_l
+    #v_l = v_l / intensity_l
+
+    if (x is None) or (y is None):
+        x = np.arange(0, u_l.shape[1])  * down + down/2.
+        y = np.arange(0, v_l.shape[0])  * down + down/2.
+        X, Y = np.meshgrid(x, y)
+    else:
+        x = downsample(x, down)
+        y = downsample(y, down)
+        X, Y = x, y
+    
+    #if X.shape[0] != u_l.shape[0]:
+    #Y = cv2.resize(Y, dsize=v_l.shape, interpolation=cv2.INTER_LINEAR)
+    #X = cv2.resize(X, dsize=u_l.shape, interpolation=cv2.INTER_LINEAR)
+
+    
+    if not ax:
+        ratio = 1.*u.shape[0] / u.shape[1]
+        hi = int(ratio * size)
+        wi = int(size)
+        fig = plt.figure(figsize=(wi,hi), frameon=False)
+        ax = fig.add_axes([0, 0, 1, 1])
+        ax.axis('off')
+        
+    if vmax is None:
+        vmax = np.nanmax(intensity)
+    if vmin is None:
+        vmin = np.nanmin(intensity)
+    
+    im1 = ax.imshow(intensity, origin='upper', cmap=cmap, vmax=vmax, vmin=vmin)
+    if colorbar:
+        plt.colorbar(im1, label='Wind Speed (m/s)', aspect=50, pad=0.01, shrink=0.85)
+        
+    #ax.quiver(X, Y, v_l, u_l, pivot='middle', 
+    #          width=0.001, headwidth=2, headlength=2)
+    scale = 150
+    try:
+        # Basemap axes accept a latlon kwarg; plain matplotlib axes do not
+        ax.quiver(X, Y, v_l, u_l, latlon=latlon,
+                  scale_units='inches', scale=scale)
+    except (TypeError, AttributeError):
+        ax.quiver(X, Y, v_l, u_l,
+                  scale_units='inches', scale=scale)
+        
+    if hasattr(ax, 'xaxis'):
+        ax.xaxis.set_ticks([])
+        ax.yaxis.set_ticks([])
+        ax.set_aspect('equal')
+    return ax
+
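+# Usage sketch (illustrative; u and v are 2D wind-component arrays in m/s):
+#   ax = flow_quiver_plot(u, v, down=25, colorbar=True)
+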
+def plot_infrared(data, x, y, sat_lon, ax=None, cmin=190, cmax=350):
+    # The geostationary projection
+    if ax is None:
+        fig = plt.figure(figsize=[8,8], frameon=False)
+        ax = fig.add_subplot(111)
+
+    m = Basemap(projection='geos', lon_0=sat_lon, resolution='i',
+                 rsphere=(6378137.00,6356752.3142),
+                 llcrnrx=x.min(),llcrnry=y.min(),
+                 urcrnrx=x.max(),urcrnry=y.max(),
+                 ax=ax)
+    m.drawcoastlines()
+    m.drawcountries()
+    m.drawstates()
+    ax.axis('off')
+    
+    # Use a colortable/colormap available from MetPy
+    ir_norm, ir_cmap = colortables.get_with_range('ir_drgb_r', cmin, cmax)
+    # Plot the data using imshow
+    im = m.imshow(data, origin='upper', cmap=ir_cmap, norm=ir_norm)
+    plt.colorbar(im, pad=0, aspect=50, ticks=range(cmin,cmax,10), shrink=0.85)
+    return ax
\ No newline at end of file
diff --git a/windflow/utils/preprocess.py b/windflow/utils/preprocess.py
new file mode 100644
index 0000000000000000000000000000000000000000..6b45d90c0c498a3bf2c8ae3cefb3c24d21f5cab5
--- /dev/null
+++ b/windflow/utils/preprocess.py
@@ -0,0 +1,15 @@
+import numpy as np
+
+def image_histogram_equalization(image, number_bins=500):
+    # from http://www.janeriksolem.net/2009/06/histogram-equalization-with-python-and.html
+    # get image histogram
+    image_histogram, bins = np.histogram(image.flatten(), number_bins, density=True)
+    cdf = image_histogram.cumsum() # cumulative distribution function
+    cdf = cdf / cdf[-1] # normalize
+
+    # use linear interpolation of cdf to find new pixel values
+    image_equalized = np.interp(image.flatten(), bins[:-1], cdf)
+
+    image_equalized[~np.isfinite(image_equalized)] = 0. 
+    
+    return image_equalized.reshape(image.shape)
\ No newline at end of file
diff --git a/windflow/utils/resample_eco1280_to_geos5nr.py b/windflow/utils/resample_eco1280_to_geos5nr.py
new file mode 100644
index 0000000000000000000000000000000000000000..f0cc21eda2ab9f4e28d33c3dc3ae6be91c219e4a
--- /dev/null
+++ b/windflow/utils/resample_eco1280_to_geos5nr.py
@@ -0,0 +1,75 @@
+from scipy.interpolate import RegularGridInterpolator
+
+import numpy as np
+from netCDF4 import Dataset
+
+# resample methods:
+linear = 'linear'
+cubic = 'cubic'
+nearest = 'nearest'
+
+eco_x = np.linspace(0.0, 2.0, 3600)
+eco_y = np.linspace(0.0, 1.0, 1801)
+g5_x = np.linspace(0.0, 2.0, 5760)
+g5_y = np.linspace(0.0, 1.0, 2881)
+
+g5_scale_x = 5760 / 3600
+g5_scale_y = 2881 / 1801
+
+
+def time_linear_interp_grids(grid_a, grid_b, time_a, time_b, time_i):
+    grd_shape = grid_a.shape
+
+    grid_a = grid_a.flatten()
+    grid_b = grid_b.flatten()
+
+    del_time = time_b - time_a
+
+    w_a = 1.0 - (time_i - time_a) / del_time
+    w_b = (time_i - time_a) / del_time
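+    # e.g. time_a=0, time_b=6, time_i=2 gives w_a=2/3, w_b=1/3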
+
+    grid_i = w_a * grid_a + w_b * grid_b
+    grid_i = np.reshape(grid_i, grd_shape)
+
+    return grid_i
+
+
+def upsample(scalar_field, ej_a=0, ej_b=1801, ei_a=0, ei_b=3600):
+    gj_a = int(g5_scale_y * ej_a)
+    gj_b = int(g5_scale_y * ej_b)
+    gi_a = int(g5_scale_x * ei_a)
+    gi_b = int(g5_scale_x * ei_b)
+
+    # match N-S orientation to GEOS5-NR
+    scalar_field = scalar_field[::-1, :]
+
+    scalar_field = scalar_field[ej_a:ej_b, ei_a:ei_b]
+    print('input dims: ', scalar_field.shape[0], scalar_field.shape[1])
+
+    intrp = RegularGridInterpolator((eco_y[ej_a:ej_b], eco_x[ei_a:ei_b]), scalar_field, method=linear, bounds_error=False)
+
+    g5_y_s = g5_y[gj_a:gj_b]
+    g5_x_s = g5_x[gi_a:gi_b]
+    print('output dims: ', g5_y_s.shape[0], g5_x_s.shape[0])
+
+    xg, yg = np.meshgrid(g5_x_s, g5_y_s, indexing='xy')
+    yg, xg = yg.flatten(), xg.flatten()
+    pts = np.array([yg, xg])
+    t_pts = np.transpose(pts)
+
+    return np.reshape(intrp(t_pts), (g5_y_s.shape[0], g5_x_s.shape[0]))
+
+
+def test(eco_file, g5_file, ej_a=0, ej_b=1801, ei_a=0, ei_b=3600):
+    rtgrp = Dataset(eco_file, 'r', format='NETCDF3')
+    gp_var = rtgrp['gp_newP']
+    eco_gp = gp_var[:, :, 0].copy()
+    print('gp_newP range/shape: ', np.nanmin(eco_gp), np.nanmax(eco_gp), eco_gp.shape)
+
+    rsm_img = upsample(eco_gp, ej_a=ej_a, ej_b=ej_b, ei_a=ei_a, ei_b=ei_b)
+
+    return rsm_img, eco_gp
+
+
+
+