From 1e19e76e50f3e512b21328e28d131e54cb41432a Mon Sep 17 00:00:00 2001 From: tomrink <rink@ssec.wisc.edu> Date: Tue, 21 Mar 2023 16:05:29 -0500 Subject: [PATCH] snapshot... --- modules/deeplearning/cloud_fraction_fcn.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/modules/deeplearning/cloud_fraction_fcn.py b/modules/deeplearning/cloud_fraction_fcn.py index a0479ce4..7799a0e6 100644 --- a/modules/deeplearning/cloud_fraction_fcn.py +++ b/modules/deeplearning/cloud_fraction_fcn.py @@ -717,7 +717,7 @@ class SRCNN: return labels, preds - def do_evaluate(self, data, ckpt_dir): + def do_evaluate(self, inputs, ckpt_dir): ckpt = tf.train.Checkpoint(step=tf.Variable(1), model=self.model) ckpt_manager = tf.train.CheckpointManager(ckpt, ckpt_dir, max_to_keep=3) @@ -725,7 +725,7 @@ class SRCNN: self.reset_test_metrics() - pred = self.model([data], training=False) + pred = self.model([inputs], training=False) self.test_probs = pred pred = pred.numpy() @@ -756,7 +756,7 @@ class SRCNN: return self.restore(ckpt_dir) def run_evaluate(self, data, ckpt_dir): - #data = tf.convert_to_tensor(data, dtype=tf.float32) + # data = tf.convert_to_tensor(data, dtype=tf.float32) self.num_data_samples = 80000 self.build_model() self.build_training() @@ -813,10 +813,11 @@ def run_evaluate_static(in_file, out_file, ckpt_dir): nn = SRCNN() out_sr = nn.run_evaluate(data, ckpt_dir) + out_sr = out_sr.argmax(axis=3) if out_file is not None: - np.save(out_file, (out_sr[0, :, :, 0], grd_a, avg, grd_c)) + np.save(out_file, (out_sr[0, :, :], grd_a, avg, grd_c)) else: - return out_sr, grd_a, avg, grd_c + return out_sr[0, :, :], grd_a, avg, grd_c def analyze2(nda_m, nda_i): -- GitLab