From 3245acdf564e247133fb90cd5e19efc32f5f1352 Mon Sep 17 00:00:00 2001 From: tomrink <rink@ssec.wisc.edu> Date: Wed, 27 Oct 2021 10:44:25 -0500 Subject: [PATCH] new methods for running predictions --- modules/deeplearning/icing_cnn.py | 43 +++++++++++++++++++++++++++---- 1 file changed, 38 insertions(+), 5 deletions(-) diff --git a/modules/deeplearning/icing_cnn.py b/modules/deeplearning/icing_cnn.py index d8cc5523..afb0cf37 100644 --- a/modules/deeplearning/icing_cnn.py +++ b/modules/deeplearning/icing_cnn.py @@ -1,6 +1,6 @@ import tensorflow as tf import tensorflow_addons as tfa -from util.setup import logdir, modeldir, cachepath, now, ewa_varsdir +from util.setup import logdir, modeldir, cachepath, now from util.util import homedir, EarlyStop, normalize, make_for_full_domain_predict from util.geos_nav import get_navigation @@ -939,12 +939,14 @@ class IcingIntensityNN: self.test_preds = preds - def do_evaluate(self, ckpt_dir, prob_thresh=0.5): + def do_evaluate(self, ckpt_dir=None, prob_thresh=0.5): - ckpt = tf.train.Checkpoint(step=tf.Variable(1), model=self.model) - ckpt_manager = tf.train.CheckpointManager(ckpt, ckpt_dir, max_to_keep=3) + if ckpt_dir is not None: # if is None, this has been done already + ckpt = tf.train.Checkpoint(step=tf.Variable(1), model=self.model) + ckpt_manager = tf.train.CheckpointManager(ckpt, ckpt_dir, max_to_keep=3) + ckpt.restore(ckpt_manager.latest_checkpoint) - ckpt.restore(ckpt_manager.latest_checkpoint) + self.reset_test_metrics() pred_s = [] for data in self.eval_dataset: @@ -1079,6 +1081,37 @@ def run_evaluate_static(h5f, ckpt_dir_s_path, flight_level=4, prob_thresh=0.5, s return ice_lons, ice_lats, preds_2d, lons_2d, lats_2d, x_rad, y_rad +def run_evaluate_static_new(data_dct, num_lines, num_elems, ckpt_dir_s_path, flight_level=4, prob_thresh=0.5): + + ckpt_dir_s = os.listdir(ckpt_dir_s_path) + ckpt_dir = ckpt_dir_s[0] + + nn = IcingIntensityNN() + nn.flight_level = flight_level + nn.setup_eval_pipeline(data_dct, num_lines * num_elems) + nn.build_model() + nn.build_training() + nn.build_evaluation() + + ckpt = tf.train.Checkpoint(step=tf.Variable(1), model=nn.model) + ckpt_manager = tf.train.CheckpointManager(ckpt, ckpt_dir, max_to_keep=3) + ckpt.restore(ckpt_manager.latest_checkpoint) + + for flvl in [0, 1, 2, 3, 4]: + nn.flight_level = flvl + nn.do_evaluate() + probs = nn.test_probs + + if NumClasses == 2: + preds = np.where(probs > prob_thresh, 1, 0) + else: + preds = np.argmax(probs, axis=1) + preds_2d = preds.reshape((num_lines, num_elems)) + probs_2d = probs.reshape((num_lines, num_elems)) + + return preds_2d, probs_2d + + if __name__ == "__main__": nn = IcingIntensityNN() nn.run('matchup_filename') -- GitLab