Skip to content
Snippets Groups Projects
Commit abc81791 authored by tomrink's avatar tomrink
Browse files

work on do_evaluate...

parent df9caeee
No related branches found
No related tags found
No related merge requests found
import glob import glob
import tensorflow as tf import tensorflow as tf
from util.setup import logdir, modeldir, cachepath, now, ancillary_path, home_dir from util.setup import logdir, modeldir, cachepath, now, ancillary_path
from util.util import EarlyStop, normalize, scale, make_for_full_domain_predict from util.util import EarlyStop, normalize, denormalize
import os, datetime import os, datetime
import numpy as np import numpy as np
import pickle import pickle
...@@ -740,7 +740,7 @@ class ESPCN: ...@@ -740,7 +740,7 @@ class ESPCN:
self.test_preds = preds self.test_preds = preds
def do_evaluate(self, nda_lr, ckpt_dir): def do_evaluate(self, nda_lr, param, ckpt_dir):
ckpt = tf.train.Checkpoint(step=tf.Variable(1), model=self.model) ckpt = tf.train.Checkpoint(step=tf.Variable(1), model=self.model)
ckpt_manager = tf.train.CheckpointManager(ckpt, ckpt_dir, max_to_keep=3) ckpt_manager = tf.train.CheckpointManager(ckpt, ckpt_dir, max_to_keep=3)
...@@ -756,6 +756,8 @@ class ESPCN: ...@@ -756,6 +756,8 @@ class ESPCN:
pred = self.model([data]) pred = self.model([data])
self.test_probs = pred self.test_probs = pred
return denormalize(pred, param, mean_std_dct[param])
def run(self, directory): def run(self, directory):
train_data_files = glob.glob(directory+'data_train*.npy') train_data_files = glob.glob(directory+'data_train*.npy')
valid_data_files = glob.glob(directory+'data_valid*.npy') valid_data_files = glob.glob(directory+'data_valid*.npy')
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment