From f73a7bc148e21d8b9cb558b45e2b4545b4fc1070 Mon Sep 17 00:00:00 2001
From: tomrink <rink@ssec.wisc.edu>
Date: Thu, 20 Apr 2023 13:03:48 -0500
Subject: [PATCH] snapshot...

---
 modules/deeplearning/cloud_fraction_fcn_viirs.py | 14 +++++++++-----
 1 file changed, 9 insertions(+), 5 deletions(-)

diff --git a/modules/deeplearning/cloud_fraction_fcn_viirs.py b/modules/deeplearning/cloud_fraction_fcn_viirs.py
index 28f76db1..3417af38 100644
--- a/modules/deeplearning/cloud_fraction_fcn_viirs.py
+++ b/modules/deeplearning/cloud_fraction_fcn_viirs.py
@@ -278,6 +278,7 @@ class SRCNN:
         self.test_labels = []
         self.test_preds = []
         self.test_probs = None
+        self.test_input = []
 
         self.learningRateSchedule = None
         self.num_data_samples = None
@@ -342,7 +343,7 @@ class SRCNN:
 
             data_norm.append(lo[:, slc_y, slc_x])
             data_norm.append(hi[:, slc_y, slc_x])
-            data_norm.append(avg[:, slc_y, slc_x])
+            data_norm.append(std[:, slc_y, slc_x])
         # ---------------------------------------------------
         # If next uncommented, take out get_grid_cell_mean
         # tmp = input_data[:, label_idx, :, :]
@@ -580,6 +581,7 @@ class SRCNN:
 
         self.test_labels.append(labels)
         self.test_preds.append(pred.numpy())
+        self.test_input.append(inputs[:, :, :, 4])
 
         self.test_loss(t_loss)
         self.test_accuracy(labels, pred)
@@ -722,9 +724,10 @@ class SRCNN:
 
         labels = np.concatenate(self.test_labels)
         preds = np.concatenate(self.test_preds)
+        inputs = np.concatenate(self.test_input)
         print(labels.shape, preds.shape)
 
-        return labels, preds
+        return labels, preds, inputs
 
     def do_evaluate(self, inputs, ckpt_dir):
 
@@ -775,10 +778,10 @@ class SRCNN:
 
 def run_restore_static(directory, ckpt_dir, out_file=None):
     nn = SRCNN()
-    labels, preds = nn.run_restore(directory, ckpt_dir)
+    labels, preds, inputs = nn.run_restore(directory, ckpt_dir)
     if out_file is not None:
         np.save(out_file,
-                [np.squeeze(labels), preds.argmax(axis=3)])
+                [np.squeeze(labels), preds.argmax(axis=3), np.squeeze(inputs)])
 
 
 def run_evaluate_static(in_file, out_file, ckpt_dir):
@@ -799,7 +802,8 @@ def run_evaluate_static(in_file, out_file, ckpt_dir):
     refl_lo, refl_hi, refl_std, refl_avg = get_min_max_std(refl)
     refl_lo = normalize(refl_lo, 'refl_0_65um_nom', mean_std_dct)
     refl_hi = normalize(refl_hi, 'refl_0_65um_nom', mean_std_dct)
-    refl_avg = normalize(refl_avg, 'refl_0_65um_nom', mean_std_dct)
+    #refl_avg = normalize(refl_avg, 'refl_0_65um_nom', mean_std_dct)
+    refl_avg = refl_std
     refl_lo = np.squeeze(refl_lo)
     refl_hi = np.squeeze(refl_hi)
     refl_avg = np.squeeze(refl_avg)
-- 
GitLab