From f9360f7c9a5322e80fcfea3e9c2a849662ada6f4 Mon Sep 17 00:00:00 2001
From: tomrink <rink@ssec.wisc.edu>
Date: Tue, 19 Sep 2023 11:47:20 -0500
Subject: [PATCH] snapshot...

---
 modules/deeplearning/cloud_fraction_fcn_abi.py | 13 ++++++++-----
 1 file changed, 8 insertions(+), 5 deletions(-)

diff --git a/modules/deeplearning/cloud_fraction_fcn_abi.py b/modules/deeplearning/cloud_fraction_fcn_abi.py
index ca89dd75..9938f546 100644
--- a/modules/deeplearning/cloud_fraction_fcn_abi.py
+++ b/modules/deeplearning/cloud_fraction_fcn_abi.py
@@ -31,6 +31,7 @@ NUM_EPOCHS = 80
 
 TRACK_MOVING_AVERAGE = False
 EARLY_STOP = True
+PATIENCE = 7
 
 NOISE_TRAINING = False
 NOISE_STDDEV = 0.01
@@ -123,19 +124,16 @@ def upsample_mean(grd):
 
 
 def get_grid_cell_mean(grd_k):
-    grd_k = np.where(np.isnan(grd_k), 0, grd_k)
-
     mean = np.nanmean([grd_k[:, 0::4, 0::4], grd_k[:, 1::4, 0::4], grd_k[:, 2::4, 0::4], grd_k[:, 3::4, 0::4],
                        grd_k[:, 0::4, 1::4], grd_k[:, 1::4, 1::4], grd_k[:, 2::4, 1::4], grd_k[:, 3::4, 1::4],
                        grd_k[:, 0::4, 2::4], grd_k[:, 1::4, 2::4], grd_k[:, 2::4, 2::4], grd_k[:, 3::4, 2::4],
                        grd_k[:, 0::4, 3::4], grd_k[:, 1::4, 3::4], grd_k[:, 2::4, 3::4], grd_k[:, 3::4, 3::4]], axis=0)
+    np.where(np.isnan(mean), 0, mean)
 
     return mean
 
 
 def get_min_max_std(grd_k):
-    grd_k = np.where(np.isnan(grd_k), 0, grd_k)
-
     lo = np.nanmin([grd_k[:, 0::4, 0::4], grd_k[:, 1::4, 0::4], grd_k[:, 2::4, 0::4], grd_k[:, 3::4, 0::4],
                     grd_k[:, 0::4, 1::4], grd_k[:, 1::4, 1::4], grd_k[:, 2::4, 1::4], grd_k[:, 3::4, 1::4],
                     grd_k[:, 0::4, 2::4], grd_k[:, 1::4, 2::4], grd_k[:, 2::4, 2::4], grd_k[:, 3::4, 2::4],
@@ -156,6 +154,11 @@ def get_min_max_std(grd_k):
                       grd_k[:, 0::4, 2::4], grd_k[:, 1::4, 2::4], grd_k[:, 2::4, 2::4], grd_k[:, 3::4, 2::4],
                       grd_k[:, 0::4, 3::4], grd_k[:, 1::4, 3::4], grd_k[:, 2::4, 3::4], grd_k[:, 3::4, 3::4]], axis=0)
 
+    np.where(np.isnan(lo), 0, lo)
+    np.where(np.isnan(hi), 0, hi)
+    np.where(np.isnan(std), 0, std)
+    np.where(np.isnan(avg), 0, avg)
+
     return lo, hi, std, avg
 
 
@@ -618,7 +621,7 @@ class SRCNN:
         best_test_loss = np.finfo(dtype=np.float64).max
 
         if EARLY_STOP:
-            es = EarlyStop()
+            es = EarlyStop(patience=PATIENCE)
 
         for epoch in range(NUM_EPOCHS):
             self.train_loss.reset_states()
-- 
GitLab