From 996cc96827edfeac4ef0d046ec24c92344c212df Mon Sep 17 00:00:00 2001
From: tomrink <rink@ssec.wisc.edu>
Date: Mon, 13 Nov 2023 12:24:46 -0600
Subject: [PATCH] snapshot...

---
 modules/deeplearning/icing_fcn.py | 36 +++++++++++++++++++++++++++++++
 1 file changed, 36 insertions(+)

diff --git a/modules/deeplearning/icing_fcn.py b/modules/deeplearning/icing_fcn.py
index dc61cbc8..44bc846c 100644
--- a/modules/deeplearning/icing_fcn.py
+++ b/modules/deeplearning/icing_fcn.py
@@ -986,6 +986,8 @@ class IcingIntensityFCN:
         if self.h5f_l2_tst is not None:
             self.h5f_l2_tst.close()
 
+        print('best stats: ', best_test_loss, best_test_acc, best_test_recall, best_test_precision, best_test_auc, best_test_f1, best_test_mcc)
+
         f = open(home_dir+'/best_stats_'+now+'.pkl', 'wb')
         pickle.dump((best_test_loss, best_test_acc, best_test_recall, best_test_precision, best_test_auc, best_test_f1, best_test_mcc), f)
         f.close()
@@ -1170,6 +1172,40 @@ def run_evaluate_static_2(model, data_dct, num_tiles, prob_thresh=0.5, flight_le
     return preds_dct, probs_dct
 
 
+def run_evaluate_static_avg(ckpt_dir_s_path, day_night='NIGHT', l1b_andor_l2='BOTH', use_flight_altitude=False):
+
+    ckpt_dir_s = os.listdir(ckpt_dir_s_path)
+    weight_s = []
+    for ckpt in ckpt_dir_s:
+        ckpt_dir = ckpt_dir_s_path + ckpt
+        if not os.path.isdir(ckpt_dir):
+            continue
+        model = load_model(ckpt_dir, day_night=day_night, l1b_andor_l2=l1b_andor_l2, use_flight_altitude=use_flight_altitude)
+        k_model = model.model
+        weight_s.append(k_model.get_weights())
+    sum = 0.0
+    for w in weight_s:
+        sum += w
+    avg_weights = sum / len(weight_s)
+
+    # ---------------------------------------------
+
+    new_model = IcingIntensityFCN(day_night=day_night, l1b_or_l2=l1b_andor_l2, use_flight_altitude=use_flight_altitude)
+    new_model.build_model()
+    new_model.build_training()
+    new_model.build_evaluation()
+
+    if ckpt_dir is None:
+        if not os.path.exists(modeldir):
+            os.mkdir(modeldir)
+        ckpt = tf.train.Checkpoint(step=tf.Variable(1), model=new_model.model)
+        ckpt_manager = tf.train.CheckpointManager(ckpt, modeldir, max_to_keep=3)
+    new_model.model.set_weights(avg_weights)
+    ckpt_manager.save()
+
+    return
+
+
 def load_model(model_path, day_night='NIGHT', l1b_andor_l2='BOTH', use_flight_altitude=False):
     ckpt_dir_s = os.listdir(model_path)
     ckpt_dir = model_path + ckpt_dir_s[0]
-- 
GitLab