diff --git a/modules/deeplearning/esrgan_exp.py b/modules/deeplearning/esrgan_exp.py index 12fece18f3cf892c7efb8e94127c1cb2e8b25f99..87bb16699d6e72995b3918ba224eff048787666f 100644 --- a/modules/deeplearning/esrgan_exp.py +++ b/modules/deeplearning/esrgan_exp.py @@ -112,18 +112,22 @@ def preprocess_image(hr_image): def run(in_file, out_file): t0 = time.time() - hr_image = get_image_from_file(in_file) - hr_image = preprocess_image(hr_image) + image = get_image_from_file(in_file) + hr_image = preprocess_image(image) t1 = time.time() print('processing time: ', (t1-t0)) print('start inference: ', hr_image.shape) t0 = time.time() - fake_image = model(hr_image) + sres_image = model(hr_image) t1 = time.time() print('inference time: ', (t1-t0)) - fake_image = tf.squeeze(fake_image) + sres_image = tf.squeeze(sres_image) if out_file is not None: - np.save(out_file, fake_image) + np.save(out_file, sres_image) + with open(out_file, 'wb') as f: + np.save(f, sres_image) + np.save(f, hr_image) + np.save(f, image) else: - return fake_image, hr_image + return sres_image, hr_image