From f4f693c1131e6f9f5803faddd20f95bf7e152f92 Mon Sep 17 00:00:00 2001
From: tomrink <rink@ssec.wisc.edu>
Date: Mon, 28 Aug 2023 13:42:08 -0500
Subject: [PATCH] snapshot...

---
 modules/deeplearning/cloud_opd_fcn_abi.py | 12 +++++++-----
 1 file changed, 7 insertions(+), 5 deletions(-)

diff --git a/modules/deeplearning/cloud_opd_fcn_abi.py b/modules/deeplearning/cloud_opd_fcn_abi.py
index 4bff2d3d..df4c7d4e 100644
--- a/modules/deeplearning/cloud_opd_fcn_abi.py
+++ b/modules/deeplearning/cloud_opd_fcn_abi.py
@@ -686,14 +686,15 @@ class SRCNN:
         labels = np.concatenate(self.test_labels)
         preds = np.concatenate(self.test_preds)
         inputs = np.concatenate(self.test_input)
-        print(labels.shape, preds.shape)
+        cat_cld_frac = np.concatenate(self.test_cat_cf)
+        print(labels.shape, cat_cld_frac.shape, preds.shape)
 
         # labels = denormalize(labels, label_param, mean_std_dct)
         # preds = denormalize(preds, label_param, mean_std_dct)
         labels = descale(labels, label_param, mean_std_dct)
         preds= descale(preds, label_param, mean_std_dct)
 
-        return labels, preds, inputs
+        return labels, cat_cld_frac, preds, inputs
 
     def do_evaluate(self, inputs, ckpt_dir):
 
@@ -892,16 +893,17 @@ class SRCNN:
 
 def run_restore_static(directory, ckpt_dir, out_file=None):
     nn = SRCNN()
-    labels, preds, inputs = nn.run_restore(directory, ckpt_dir)
+    labels, cat_cld_frac, preds, inputs = nn.run_restore(directory, ckpt_dir)
     if out_file is not None:
         y_hi, x_hi = (Y_LEN // 4) + 1, (X_LEN // 4) + 1
         np.save(out_file,
-                [np.squeeze(labels), preds.argmax(axis=3),
+                [labels, preds,
                  inputs[:, 1:y_hi, 1:x_hi, 0],
                  descale(inputs[:, 1:y_hi, 1:x_hi, 1], 'refl_0_65um_nom', mean_std_dct),
                  descale(inputs[:, 1:y_hi, 1:x_hi, 2], 'refl_0_65um_nom', mean_std_dct),
                  inputs[:, 1:y_hi, 1:x_hi, 3],
-                 descale(inputs[:, 1:y_hi, 1:x_hi, 4], label_param, mean_std_dct)])
+                 descale(inputs[:, 1:y_hi, 1:x_hi, 4], label_param, mean_std_dct),
+                 cat_cld_frac[:, 1:y_hi, 1:x_hi]])
 
 
 def run_evaluate_static(in_file, out_file, ckpt_dir):
-- 
GitLab