diff --git a/modules/deeplearning/espcn.py b/modules/deeplearning/espcn.py index 91297c912be481e5113bf6095a9bb1c1bd9562c4..a09007e109548a29d22568ba020ba82508d80cc1 100644 --- a/modules/deeplearning/espcn.py +++ b/modules/deeplearning/espcn.py @@ -242,9 +242,6 @@ class ESPCN: # data = np.concatenate(data_s) data = np.concatenate(label_s) - hr_shape = data.shape - data = tf.image.resize(data, (hr_shape[0], hr_shape[1], hr_shape[2] // 2, hr_shape[3] // 2)) - label = np.concatenate(label_s) label = label[:, label_idx, :, :] @@ -252,6 +249,7 @@ class ESPCN: data = data[:, data_idx, :, :] data = np.expand_dims(data, axis=3) + data = tf.image.resize(data, (32, 32)).numpy() data = data.astype(np.float32) label = label.astype(np.float32)