From dc06a6eafa5efc43b10affc2d15bdba07c7d3358 Mon Sep 17 00:00:00 2001
From: tomrink <rink@ssec.wisc.edu>
Date: Tue, 9 May 2023 10:53:24 -0500
Subject: [PATCH] snapshot...

---
 modules/deeplearning/cloud_fraction_fcn_abi.py | 14 +++++++++++---
 1 file changed, 11 insertions(+), 3 deletions(-)

diff --git a/modules/deeplearning/cloud_fraction_fcn_abi.py b/modules/deeplearning/cloud_fraction_fcn_abi.py
index e990e5fa..683f2c2b 100644
--- a/modules/deeplearning/cloud_fraction_fcn_abi.py
+++ b/modules/deeplearning/cloud_fraction_fcn_abi.py
@@ -280,6 +280,7 @@ class SRCNN:
         self.test_labels = []
         self.test_preds = []
         self.test_probs = None
+        self.test_input = []
 
         self.learningRateSchedule = None
         self.num_data_samples = None
@@ -590,6 +591,7 @@ class SRCNN:
 
         self.test_labels.append(labels)
         self.test_preds.append(pred.numpy())
+        self.test_input.append(inputs)
 
         self.test_loss(t_loss)
         self.test_accuracy(labels, pred)
@@ -732,9 +734,10 @@ class SRCNN:
 
         labels = np.concatenate(self.test_labels)
         preds = np.concatenate(self.test_preds)
+        inputs = np.concatenate(self.test_input)
         print(labels.shape, preds.shape)
 
-        return labels, preds
+        return labels, preds, inputs
 
     def do_evaluate(self, inputs, ckpt_dir):
 
@@ -785,10 +788,15 @@ class SRCNN:
 
 def run_restore_static(directory, ckpt_dir, out_file=None):
     nn = SRCNN()
-    labels, preds = nn.run_restore(directory, ckpt_dir)
+    labels, preds, inputs = nn.run_restore(directory, ckpt_dir)
     if out_file is not None:
         np.save(out_file,
-                [np.squeeze(labels), preds.argmax(axis=3)])
+                [np.squeeze(labels), preds.argmax(axis=3),
+                 denormalize(inputs[:, 1:65, 1:65, 0], 'temp_11_0um_nom', mean_std_dct),
+                 denormalize(inputs[:, 1:65, 1:65, 1], 'refl_0_65um_nom', mean_std_dct),
+                 denormalize(inputs[:, 1:65, 1:65, 2], 'refl_0_65um_nom', mean_std_dct),
+                 inputs[:, 1:65, 1:65, 3],
+                 inputs[:, 1:65, 1:65, 4]])
 
 
 def run_evaluate_static(in_file, out_file, ckpt_dir):
-- 
GitLab