diff --git a/modules/deeplearning/espcn.py b/modules/deeplearning/espcn.py index 7e91a8329c68f307f6726e26481b6e7f3cb1e3d0..c31546da1d70a4318fedefd697b558c9fc2700f0 100644 --- a/modules/deeplearning/espcn.py +++ b/modules/deeplearning/espcn.py @@ -1,7 +1,7 @@ import glob import tensorflow as tf -from util.setup import logdir, modeldir, cachepath, now, ancillary_path, home_dir -from util.util import EarlyStop, normalize, scale, make_for_full_domain_predict +from util.setup import logdir, modeldir, cachepath, now, ancillary_path +from util.util import EarlyStop, normalize, denormalize import os, datetime import numpy as np import pickle @@ -740,7 +740,7 @@ class ESPCN: self.test_preds = preds - def do_evaluate(self, nda_lr, ckpt_dir): + def do_evaluate(self, nda_lr, param, ckpt_dir): ckpt = tf.train.Checkpoint(step=tf.Variable(1), model=self.model) ckpt_manager = tf.train.CheckpointManager(ckpt, ckpt_dir, max_to_keep=3) @@ -756,6 +756,8 @@ class ESPCN: pred = self.model([data]) self.test_probs = pred + return denormalize(pred, param, mean_std_dct[param]) + def run(self, directory): train_data_files = glob.glob(directory+'data_train*.npy') valid_data_files = glob.glob(directory+'data_valid*.npy')