diff --git a/modules/deeplearning/esrgan_exp.py b/modules/deeplearning/esrgan_exp.py index 007874e57b9cec25e13baeb292a02be23df3ab93..97c722e69b5703546ceead602fa2c051ccd81783 100644 --- a/modules/deeplearning/esrgan_exp.py +++ b/modules/deeplearning/esrgan_exp.py @@ -76,7 +76,7 @@ def preprocess_image(hr_image): return tf.expand_dims(hr_image, 0) -def run(in_file): +def run(in_file, out_file): # model = hub.load(SAVED_MODEL_PATH) t0 = time.time() hr_image = get_image(in_file) @@ -90,4 +90,7 @@ def run(in_file): print('inference time: ', (t1-t0)) fake_image = tf.squeeze(fake_image) - return fake_image \ No newline at end of file + if out_file is not None: + np.save(out_file, fake_image) + else: + return fake_image