Skip to content
Snippets Groups Projects
Commit 3245acdf authored by tomrink's avatar tomrink
Browse files

new methods for running predictions

parent 790da061
Branches
No related tags found
No related merge requests found
import tensorflow as tf
import tensorflow_addons as tfa
from util.setup import logdir, modeldir, cachepath, now, ewa_varsdir
from util.setup import logdir, modeldir, cachepath, now
from util.util import homedir, EarlyStop, normalize, make_for_full_domain_predict
from util.geos_nav import get_navigation
......@@ -939,12 +939,14 @@ class IcingIntensityNN:
self.test_preds = preds
def do_evaluate(self, ckpt_dir, prob_thresh=0.5):
def do_evaluate(self, ckpt_dir=None, prob_thresh=0.5):
ckpt = tf.train.Checkpoint(step=tf.Variable(1), model=self.model)
ckpt_manager = tf.train.CheckpointManager(ckpt, ckpt_dir, max_to_keep=3)
if ckpt_dir is not None: # if is None, this has been done already
ckpt = tf.train.Checkpoint(step=tf.Variable(1), model=self.model)
ckpt_manager = tf.train.CheckpointManager(ckpt, ckpt_dir, max_to_keep=3)
ckpt.restore(ckpt_manager.latest_checkpoint)
ckpt.restore(ckpt_manager.latest_checkpoint)
self.reset_test_metrics()
pred_s = []
for data in self.eval_dataset:
......@@ -1079,6 +1081,37 @@ def run_evaluate_static(h5f, ckpt_dir_s_path, flight_level=4, prob_thresh=0.5, s
return ice_lons, ice_lats, preds_2d, lons_2d, lats_2d, x_rad, y_rad
def run_evaluate_static_new(data_dct, num_lines, num_elems, ckpt_dir_s_path, flight_level=4, prob_thresh=0.5):
ckpt_dir_s = os.listdir(ckpt_dir_s_path)
ckpt_dir = ckpt_dir_s[0]
nn = IcingIntensityNN()
nn.flight_level = flight_level
nn.setup_eval_pipeline(data_dct, num_lines * num_elems)
nn.build_model()
nn.build_training()
nn.build_evaluation()
ckpt = tf.train.Checkpoint(step=tf.Variable(1), model=nn.model)
ckpt_manager = tf.train.CheckpointManager(ckpt, ckpt_dir, max_to_keep=3)
ckpt.restore(ckpt_manager.latest_checkpoint)
for flvl in [0, 1, 2, 3, 4]:
nn.flight_level = flvl
nn.do_evaluate()
probs = nn.test_probs
if NumClasses == 2:
preds = np.where(probs > prob_thresh, 1, 0)
else:
preds = np.argmax(probs, axis=1)
preds_2d = preds.reshape((num_lines, num_elems))
probs_2d = probs.reshape((num_lines, num_elems))
return preds_2d, probs_2d
if __name__ == "__main__":
nn = IcingIntensityNN()
nn.run('matchup_filename')
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment