diff --git a/modules/deeplearning/cloud_opd_srcnn_abi.py b/modules/deeplearning/cloud_opd_srcnn_abi.py
index 33022629f26a250f7a213f1408b665fd026f142e..6b09eaf7c64811bc3327c7a1fb888fddca898fe6 100644
--- a/modules/deeplearning/cloud_opd_srcnn_abi.py
+++ b/modules/deeplearning/cloud_opd_srcnn_abi.py
@@ -1,3 +1,4 @@
+import gc
 import glob
 import tensorflow as tf
 
@@ -687,6 +688,95 @@ class SRCNN:
         self.build_evaluation()
         return self.do_evaluate(data, ckpt_dir)
 
+    def setup_inference(self, ckpt_dir):
+        self.num_data_samples = 80000
+        self.build_model()
+        self.build_training()
+        self.build_evaluation()
+
+        ckpt = tf.train.Checkpoint(step=tf.Variable(1), model=self.model)
+        ckpt_manager = tf.train.CheckpointManager(ckpt, ckpt_dir, max_to_keep=3)
+        ckpt.restore(ckpt_manager.latest_checkpoint)
+
+    def do_inference(self, inputs):
+        self.reset_test_metrics()
+
+        pred = self.model([inputs], training=False)
+        self.test_probs = pred
+        pred = pred.numpy()
+
+        return pred
+
+    def run_inference(self, in_file, out_file):
+        gc.collect()
+
+        h5f = h5py.File(in_file, 'r')
+
+        refl = get_grid_values_all(h5f, 'refl_0_65um_nom')
+        LEN_Y, LEN_X = refl.shape
+        print(LEN_Y, LEN_X)
+
+        bt = get_grid_values_all(h5f, 'temp_11_0um_nom')
+
+        cld_opd = get_grid_values_all(h5f, 'cld_opd_dcomp_1')
+
+        refl_sub_lo = get_grid_values_all(h5f, 'refl_0_65um_nom_min_sub')
+        refl_sub_hi = get_grid_values_all(h5f, 'refl_0_65um_nom_max_sub')
+        refl_sub_std = get_grid_values_all(h5f, 'refl_0_65um_nom_stddev_sub')
+
+        self.run_inference_(bt, refl_sub_lo, refl_sub_hi, refl_sub_std, cld_opd, LEN_Y, LEN_X)
+
+    def run_inference_(self, bt, refl_sub_lo, refl_sub_hi, refl_sub_std, cld_opd, LEN_Y, LEN_X):
+
+        slc_x = slice(0, (LEN_X - 16) + 4)
+        slc_y = slice(0, (LEN_Y - 16) + 4)
+        x_2 = np.arange((LEN_X - 16) + 4)
+        y_2 = np.arange((LEN_Y - 16) + 4)
+        t = np.arange(0, (LEN_X - 16) + 4, 0.5)
+        s = np.arange(0, (LEN_Y - 16) + 4, 0.5)
+
+        # refl = np.where(np.isnan(refl), 0, bt)
+        # refl = refl[slc_y, slc_x]
+        # refl = np.expand_dims(refl, axis=0)
+        # refl_us = upsample_static(refl, x_2, y_2, t, s, None, None)
+        # print(refl_us.shape)
+        # refl_us = normalize(refl_us, 'refl_0_65um_nom', mean_std_dct)
+        # print('REFL done')
+
+        bt = np.where(np.isnan(bt), 0, bt)
+        bt = bt[slc_y, slc_x]
+        bt = np.expand_dims(bt, axis=0)
+        bt_us = upsample_static(bt, x_2, y_2, t, s, None, None)
+        bt_us = normalize(bt_us, 'temp_11_0um_nom', mean_std_dct)
+        print('BT done')
+
+        refl_sub_lo = refl_sub_lo[slc_y, slc_x]
+        refl_sub_lo = np.expand_dims(refl_sub_lo, axis=0)
+        refl_sub_lo = upsample_nearest(refl_sub_lo)
+        refl_sub_lo = normalize(refl_sub_lo, 'refl_0_65um_nom', mean_std_dct)
+
+        refl_sub_hi = refl_sub_hi[slc_y, slc_x]
+        refl_sub_hi = np.expand_dims(refl_sub_hi, axis=0)
+        refl_sub_hi = upsample_nearest(refl_sub_hi)
+        refl_sub_hi = normalize(refl_sub_hi, 'refl_0_65um_nom', mean_std_dct)
+
+        refl_sub_std = refl_sub_std[slc_y, slc_x]
+        refl_sub_std = np.expand_dims(refl_sub_std, axis=0)
+        refl_sub_std = upsample_nearest(refl_sub_std)
+
+        cld_opd = np.where(np.isnan(cld_opd), 0, cld_opd)
+        cld_opd = cld_opd[slc_y, slc_x]
+        cld_opd = np.expand_dims(cld_opd, axis=0)
+        cld_opd_us = upsample_static(cld_opd, x_2, y_2, t, s, None, None)
+        cld_opd_us = normalize(cld_opd_us, label_param, mean_std_dct)
+        print('OPD done')
+
+        data = np.stack([bt_us, refl_sub_lo, refl_sub_hi, refl_sub_std, cld_opd_us], axis=3)
+
+        # data = self.do_inference(data)
+
+        return None
+
 
 def run_restore_static(directory, ckpt_dir, out_file=None):
     nn = SRCNN()