Skip to content
Snippets Groups Projects
Commit d3cce9cf authored by tomrink's avatar tomrink
Browse files

snapshot...

parent c5f2350c
No related branches found
No related tags found
No related merge requests found
import glob
import tensorflow as tf
from util.setup import logdir, modeldir, cachepath, now, ancillary_path
from util.util import EarlyStop, normalize, denormalize, resample, resample_2d_linear, resample_one, get_grid_values_all
from util.util import EarlyStop, normalize, denormalize, resample, resample_2d_linear, resample_one, resample_2d_linear_one, get_grid_values_all
import os, datetime
import numpy as np
import pickle
......@@ -610,24 +610,20 @@ class SRCNN:
print('loss, acc: ', self.test_loss.result().numpy(), self.test_accuracy.result().numpy())
def do_evaluate(self, nda_lr, param, ckpt_dir):
def do_evaluate(self, data, ckpt_dir):
ckpt = tf.train.Checkpoint(step=tf.Variable(1), model=self.model)
ckpt_manager = tf.train.CheckpointManager(ckpt, ckpt_dir, max_to_keep=3)
ckpt.restore(ckpt_manager.latest_checkpoint)
data = normalize(nda_lr, param, mean_std_dct)
data = np.expand_dims(data, axis=0)
data = np.expand_dims(data, axis=3)
self.reset_test_metrics()
pred = self.model([data], training=False)
self.test_probs = pred
pred = pred.numpy()
return denormalize(pred, param, mean_std_dct)
return pred
def run(self, directory):
train_data_files = glob.glob(directory+'data_train*.npy')
......@@ -648,42 +644,51 @@ class SRCNN:
self.build_evaluation()
self.restore(ckpt_dir)
def run_evaluate(self, nda_lr, param, ckpt_dir):
def run_evaluate(self, data, ckpt_dir):
self.num_data_samples = 80000
self.build_model()
self.build_training()
self.build_evaluation()
return self.do_evaluate(nda_lr, param, ckpt_dir)
return self.do_evaluate(data, ckpt_dir)
def run_evaluate_static(in_file, out_file, ckpt_dir):
nda = np.load(in_file)
nda = nda[:, data_idx, 3:131:2, 3:131:2]
nda = resample(y_64, x_64, nda, s, t)
nda = np.expand_dims(nda, axis=3)
nn = SRCNN()
out_sr = nn.run_evaluate(nda, data_param, ckpt_dir)
if out_file is not None:
np.save(out_file, out_sr)
else:
return out_sr
# def run_evaluate_static(in_file, out_file, ckpt_dir):
# nda = np.load(in_file)
#
# nda = nda[:, data_idx, 3:131:2, 3:131:2]
# nda = resample(y_64, x_64, nda, s, t)
# nda = np.expand_dims(nda, axis=3)
#
# nn = SRCNN()
# out_sr = nn.run_evaluate(nda, data_param, ckpt_dir)
# out_sr = denormalize(out_sr, param, mean_std_dct)
# if out_file is not None:
# np.save(out_file, out_sr)
# else:
# return out_sr
def run_evaluate_static_new(in_file, out_file, ckpt_dir):
h5f = h5py.File(in_file, 'r')
grd = get_grid_values_all(h5f, data_param)
leny, lenx = grd.shape
grd_a = get_grid_values_all(h5f, data_params[0])
grd_b = get_grid_values_all(h5f, data_params[1])
leny, lenx = grd_a.shape
x = np.arange(lenx)
y = np.arange(leny)
x_up = np.arange(0, lenx, 0.5)
y_up = np.arange(0, leny, 0.5)
grd = resample_one(y, x, grd, y_up, x_up)
grd_a = resample_2d_linear_one(y, x, grd_a, y_up, x_up)
grd_a = normalize(grd_a, data_params[0], mean_std_dct)
grd_b = resample_2d_linear_one(y, x, grd_b, y_up, x_up)
data = np.stack([grd_a, grd_b], axis=2)
data = np.expand_dims(data, axis=0)
nn = SRCNN()
out_sr = nn.run_evaluate(grd, data_param, ckpt_dir)
out_sr = nn.run_evaluate(data, ckpt_dir)
#out_sr = denormalize(out_sr, label_params[0], mean_std_dct)
if out_file is not None:
np.save(out_file, out_sr)
else:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment