diff --git a/modules/deeplearning/cloud_opd_srcnn_abi.py b/modules/deeplearning/cloud_opd_srcnn_abi.py
index 4815fced6f3d899a41d57ccf389293c8c0bd7bfa..58b91b0015c342e4e248557da0d5726e3568ea38 100644
--- a/modules/deeplearning/cloud_opd_srcnn_abi.py
+++ b/modules/deeplearning/cloud_opd_srcnn_abi.py
@@ -2,8 +2,10 @@ import gc
 import glob
 import tensorflow as tf
 
+from util.augment import augment_image
 from util.setup import logdir, modeldir, now, ancillary_path
-from util.util import EarlyStop, normalize, denormalize, scale, descale, get_grid_values_all, resample_2d_linear, smooth_2d
+from util.util import EarlyStop, normalize, denormalize, scale, descale, get_grid_values_all, resample_2d_linear,\
+    smooth_2d, make_tf_callable_generator
 import os, datetime
 import numpy as np
 import pickle
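
Note on the new imports: `make_tf_callable_generator` is referenced below but its body is not part of this patch. Assuming it only adapts a generator object to the zero-argument callable that `tf.data.Dataset.from_generator` expects, a minimal sketch would be:

    # Hypothetical sketch of util.util.make_tf_callable_generator; the real
    # helper is not shown in this patch. from_generator() wants a callable
    # that returns a fresh iterator, so we close over the generator object.
    def make_tf_callable_generator(gen):
        def tf_callable():
            yield from gen
        return tf_callable

Because the closure captures a single generator object, the resulting dataset is one-shot; the pipeline below compensates by calling cache().
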
@@ -343,28 +345,9 @@ class SRCNN:
         label = scale(label, label_param, mean_std_dct)
         label = label[:, self.y_128, self.x_128]
 
-        label = np.where(np.isnan(label), 0.0, label)
         label = np.expand_dims(label, axis=3)
-
-        data = data.astype(np.float32)
         label = label.astype(np.float32)
 
-        if is_training and DO_AUGMENT:
-            # data_ud = np.flip(data, axis=1)
-            # label_ud = np.flip(label, axis=1)
-            #
-            # data_lr = np.flip(data, axis=2)
-            # label_lr = np.flip(label, axis=2)
-            #
-            # data = np.concatenate([data, data_ud, data_lr])
-            # label = np.concatenate([label, label_ud, label_lr])
-
-            data_rot = np.rot90(data, axes=(1, 2))
-            label_rot = np.rot90(label, axes=(1, 2))
-
-            data = np.concatenate([data, data_rot])
-            label = np.concatenate([label, label_rot])
-
         return data, label
 
     def get_in_mem_data_batch_train(self, idxs):
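
The rot90 augmentation removed here reappears as a tf.data map stage (see get_train_dataset below). `augment_image` lives in util.augment and is not shown in this patch; given that it is invoked as `dataset.map(augment_image(), ...)` on batched (data, label) pairs, it is presumably a factory returning a mapper, along these lines:

    # Hypothetical sketch of util.augment.augment_image; the real function is
    # not part of this patch. dataset.map(augment_image(), ...) implies a
    # factory that returns a (data, label) -> (data, label) callable.
    import tensorflow as tf

    def augment_image():
        def augment(data, label):
            # Rotate inputs and labels by the same random multiple of 90
            # degrees so they stay spatially aligned ([batch, y, x, channel]).
            k = tf.random.uniform([], minval=0, maxval=4, dtype=tf.int32)
            return tf.image.rot90(data, k=k), tf.image.rot90(label, k=k)
        return augment

Unlike the removed numpy version, a map-based augmentation transforms each batch in place rather than concatenating rotated copies, so it no longer doubles the effective sample count.
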
@@ -383,22 +366,35 @@ class SRCNN:
         out = tf.numpy_function(self.get_in_mem_data_batch_test, [indexes], [tf.float32, tf.float32])
         return out
 
-    def get_train_dataset(self, indexes):
-        indexes = list(indexes)
+    def get_train_dataset(self, num_files):
+        def integer_gen(limit):
+            n = 0
+            while n < limit:
+                yield n
+                n += 1
+        num_gen = integer_gen(num_files)
+        gen = make_tf_callable_generator(num_gen)
 
-        dataset = tf.data.Dataset.from_tensor_slices(indexes)
+        dataset = tf.data.Dataset.from_generator(gen, output_types=tf.int32)
         dataset = dataset.batch(PROC_BATCH_SIZE)
         dataset = dataset.map(self.data_function, num_parallel_calls=8)
-        dataset = dataset.cache()
         if DO_AUGMENT:
-            dataset = dataset.shuffle(PROC_BATCH_BUFFER_SIZE)
+            dataset = dataset.map(augment_image(), num_parallel_calls=8)
+        dataset = dataset.cache()
+        dataset = dataset.shuffle(PROC_BATCH_BUFFER_SIZE, reshuffle_each_iteration=False)
         dataset = dataset.prefetch(buffer_size=1)
         self.train_dataset = dataset
 
-    def get_test_dataset(self, indexes):
-        indexes = list(indexes)
+    def get_test_dataset(self, num_files):
+        def integer_gen(limit):
+            n = 0
+            while n < limit:
+                yield n
+                n += 1
+        num_gen = integer_gen(num_files)
+        gen = make_tf_callable_generator(num_gen)
 
-        dataset = tf.data.Dataset.from_tensor_slices(indexes)
+        dataset = tf.data.Dataset.from_generator(gen, output_types=tf.int32)
         dataset = dataset.batch(PROC_BATCH_SIZE)
         dataset = dataset.map(self.data_function_test, num_parallel_calls=8)
         dataset = dataset.cache()
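
One subtlety worth calling out: the callable produced by make_tf_callable_generator wraps a single generator instance, so the dataset can stream the file indices only once. cache() is what makes later epochs possible, and with reshuffle_each_iteration=False each epoch then replays the same shuffled order. A self-contained illustration (toy limit, no file I/O):

    # Why cache() matters with a one-shot generator: the second pass would
    # otherwise be empty, because the generator object is already exhausted.
    import tensorflow as tf

    def integer_gen(limit):
        n = 0
        while n < limit:
            yield n
            n += 1

    num_gen = integer_gen(3)  # a single, one-shot generator object
    ds = tf.data.Dataset.from_generator(lambda: num_gen, output_types=tf.int32)
    ds = ds.cache()
    print(list(ds.as_numpy_iterator()))  # [0, 1, 2] streamed from the generator
    print(list(ds.as_numpy_iterator()))  # [0, 1, 2] replayed from the cache

Note also that because the augmentation map precedes cache() in get_train_dataset, the random rotations are drawn once on the first epoch and replayed identically thereafter.
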
@@ -410,22 +406,17 @@ class SRCNN:
         self.test_data_files = test_data_files
         self.test_label_files = test_label_files
 
-        trn_idxs = np.arange(len(train_data_files))
-        np.random.shuffle(trn_idxs)
-
-        tst_idxs = np.arange(len(test_data_files))
-
-        self.get_train_dataset(trn_idxs)
-        self.get_test_dataset(tst_idxs)
+        self.get_train_dataset(len(train_data_files))
+        self.get_test_dataset(len(test_data_files))
 
         self.num_data_samples = num_train_samples  # approximately
 
         print('datetime: ', now)
         print('training and test data: ')
         print('---------------------------')
-        print('num train samples: ', self.num_data_samples)
+        print('num train files: ', len(train_data_files))
         print('BATCH SIZE: ', BATCH_SIZE)
-        print('num test samples: ', tst_idxs.shape[0])
+        print('num test files: ', len(test_data_files))
         print('setup_pipeline: Done')
 
     def setup_test_pipeline(self, test_data_files, test_label_files):
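
For completeness, a minimal sketch of driving the revised pipeline. The signature is inferred from the names used in the hunk above (train/test file lists plus num_train_samples), and the paths are illustrative, not taken from this patch:

    # Illustrative usage; the paths, sample count, and the bare SRCNN()
    # constructor call are assumptions, not shown in this patch.
    nn = SRCNN()
    train_data_files = sorted(glob.glob('/path/to/train/data/*'))
    train_label_files = sorted(glob.glob('/path/to/train/label/*'))
    test_data_files = sorted(glob.glob('/path/to/test/data/*'))
    test_label_files = sorted(glob.glob('/path/to/test/label/*'))
    nn.setup_pipeline(train_data_files, train_label_files,
                      test_data_files, test_label_files,
                      num_train_samples=50000)

Since the datasets now address files purely by integer position, the data and label lists must be sorted consistently with each other.
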