diff --git a/modules/deeplearning/espcn.py b/modules/deeplearning/espcn.py index 821cf096f087c11130f516adf8112fa0e233fb2d..91297c912be481e5113bf6095a9bb1c1bd9562c4 100644 --- a/modules/deeplearning/espcn.py +++ b/modules/deeplearning/espcn.py @@ -240,9 +240,11 @@ class ESPCN: nda = np.load(f) label_s.append(nda) - #data = np.concatenate(data_s) + # data = np.concatenate(data_s) data = np.concatenate(label_s) - data = tf.image.resize(data, (32, 32)) + hr_shape = data.shape + data = tf.image.resize(data, (hr_shape[0], hr_shape[1], hr_shape[2] // 2, hr_shape[3] // 2)) + label = np.concatenate(label_s) label = label[:, label_idx, :, :]