diff --git a/modules/GSOC/E2_ESRGAN/lib/dataset.py b/modules/GSOC/E2_ESRGAN/lib/dataset.py
index 29335c9629f947fde759e68459828cf574f6f812..3ec04b42e91df01f450e9dc4b2bfcfc0a78989fe 100644
--- a/modules/GSOC/E2_ESRGAN/lib/dataset.py
+++ b/modules/GSOC/E2_ESRGAN/lib/dataset.py
@@ -193,7 +193,7 @@ class OpdNpyDataset:
         opd = np.concatenate(opd_s)
         refl = np.concatenate(refl_s)
 
-        hr_image = np.stack([refl, opd], axis=3)
+        hr_image = np.stack([refl, opd], axis=2)
         hr_image = tf.image.crop_to_bounding_box(hr_image, 0, 0, self.hr_size, self.hr_size)
 
         low_resolution = tf.image.resize(hr_image, [self.lr_size, self.lr_size], method='bicubic')