Skip to content
Snippets Groups Projects
Commit 996cc968 authored by tomrink's avatar tomrink
Browse files

snapshot...

parent de1c7b24
No related branches found
No related tags found
No related merge requests found
...@@ -986,6 +986,8 @@ class IcingIntensityFCN: ...@@ -986,6 +986,8 @@ class IcingIntensityFCN:
if self.h5f_l2_tst is not None: if self.h5f_l2_tst is not None:
self.h5f_l2_tst.close() self.h5f_l2_tst.close()
print('best stats: ', best_test_loss, best_test_acc, best_test_recall, best_test_precision, best_test_auc, best_test_f1, best_test_mcc)
f = open(home_dir+'/best_stats_'+now+'.pkl', 'wb') f = open(home_dir+'/best_stats_'+now+'.pkl', 'wb')
pickle.dump((best_test_loss, best_test_acc, best_test_recall, best_test_precision, best_test_auc, best_test_f1, best_test_mcc), f) pickle.dump((best_test_loss, best_test_acc, best_test_recall, best_test_precision, best_test_auc, best_test_f1, best_test_mcc), f)
f.close() f.close()
...@@ -1170,6 +1172,40 @@ def run_evaluate_static_2(model, data_dct, num_tiles, prob_thresh=0.5, flight_le ...@@ -1170,6 +1172,40 @@ def run_evaluate_static_2(model, data_dct, num_tiles, prob_thresh=0.5, flight_le
return preds_dct, probs_dct return preds_dct, probs_dct
def run_evaluate_static_avg(ckpt_dir_s_path, day_night='NIGHT', l1b_andor_l2='BOTH', use_flight_altitude=False):
ckpt_dir_s = os.listdir(ckpt_dir_s_path)
weight_s = []
for ckpt in ckpt_dir_s:
ckpt_dir = ckpt_dir_s_path + ckpt
if not os.path.isdir(ckpt_dir):
continue
model = load_model(ckpt_dir, day_night=day_night, l1b_andor_l2=l1b_andor_l2, use_flight_altitude=use_flight_altitude)
k_model = model.model
weight_s.append(k_model.get_weights())
sum = 0.0
for w in weight_s:
sum += w
avg_weights = sum / len(weight_s)
# ---------------------------------------------
new_model = IcingIntensityFCN(day_night=day_night, l1b_or_l2=l1b_andor_l2, use_flight_altitude=use_flight_altitude)
new_model.build_model()
new_model.build_training()
new_model.build_evaluation()
if ckpt_dir is None:
if not os.path.exists(modeldir):
os.mkdir(modeldir)
ckpt = tf.train.Checkpoint(step=tf.Variable(1), model=new_model.model)
ckpt_manager = tf.train.CheckpointManager(ckpt, modeldir, max_to_keep=3)
new_model.model.set_weights(avg_weights)
ckpt_manager.save()
return
def load_model(model_path, day_night='NIGHT', l1b_andor_l2='BOTH', use_flight_altitude=False): def load_model(model_path, day_night='NIGHT', l1b_andor_l2='BOTH', use_flight_altitude=False):
ckpt_dir_s = os.listdir(model_path) ckpt_dir_s = os.listdir(model_path)
ckpt_dir = model_path + ckpt_dir_s[0] ckpt_dir = model_path + ckpt_dir_s[0]
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment