Skip to content
Snippets Groups Projects
Commit 77f09930 authored by tomrink's avatar tomrink
Browse files

snapshot...

parent e5c55b61
No related branches found
No related tags found
No related merge requests found
......@@ -651,7 +651,6 @@ class IcingIntensityNN:
best_test_f1 = 0
best_test_mcc = 0
if EARLY_STOP:
es = EarlyStop()
......@@ -801,9 +800,7 @@ class IcingIntensityNN:
self.test_preds = preds
self.h5f_tst.close()
def do_evaluate(self, ckpt_dir, ll, cc, prob_thresh=0.5):
def do_evaluate(self, ckpt_dir, prob_thresh=0.5):
ckpt = tf.train.Checkpoint(step=tf.Variable(1), model=self.model)
ckpt_manager = tf.train.CheckpointManager(ckpt, ckpt_dir, max_to_keep=3)
......@@ -819,32 +816,14 @@ class IcingIntensityNN:
pred_s.append(pred)
preds = np.concatenate(pred_s)
preds = preds[:,0]
self.test_probs = preds
if NumClasses == 2:
preds = np.where(preds > prob_thresh, 1, 0)
else:
preds = np.argmax(preds, axis=1)
print(preds.shape[0], np.sum(preds == 1))
preds = preds[:,0]
cc = np.array(cc)
ll = np.array(ll)
ice_mask = preds == 1
print(cc.shape, ll.shape, ice_mask.shape)
ice_cc = cc[ice_mask]
ice_ll = ll[ice_mask]
nav = GEOSNavigation(sub_lon=-75.0, CFAC=5.6E-05, COFF=-0.101332, LFAC=-5.6E-05, LOFF=0.128212, num_elems=2500,
num_lines=1500)
ice_lons = []
ice_lats = []
for k in range(ice_cc.shape[0]):
lon, lat = nav.lc_to_earth(ice_cc[k], ice_ll[k])
ice_lons.append(lon)
ice_lats.append(lat)
return ice_lons, ice_lats
self.test_preds = preds
def run(self, filename_trn, filename_tst):
with tf.device('/device:GPU:'+str(self.gpu_device)):
......@@ -860,6 +839,7 @@ class IcingIntensityNN:
self.build_training()
self.build_evaluation()
self.restore(ckpt_dir)
self.h5f_tst.close()
def run_evaluate(self, filename, ckpt_dir):
data_dct, ll, cc = make_for_full_domain_predict(filename, name_list=train_params)
......@@ -867,8 +847,7 @@ class IcingIntensityNN:
self.build_model()
self.build_training()
self.build_evaluation()
ice_lons, ice_lats = self.do_evaluate(ckpt_dir, ll, cc)
return filename, ice_lons, ice_lats
self.do_evaluate(ckpt_dir)
def run_restore_static(filename_tst, ckpt_dir_s_path):
......@@ -890,8 +869,51 @@ def run_restore_static(filename_tst, ckpt_dir_s_path):
return cm_avg
def run_evaluate_static(filename, ckpt_dir_s):
nn = IcingIntensityNN()
def run_evaluate_static(filename, ckpt_dir_s_path, prob_thresh=0.5):
data_dct, ll, cc = make_for_full_domain_predict(filename, name_list=train_params)
ckpt_dir_s = os.listdir(ckpt_dir_s_path)
prob_s = []
for ckpt in ckpt_dir_s:
ckpt_dir = ckpt_dir_s_path + ckpt
if not os.path.isdir(ckpt_dir):
continue
nn = IcingIntensityNN()
nn.setup_eval_pipeline(data_dct, len(ll))
nn.build_model()
nn.build_training()
nn.build_evaluation()
nn.do_evaluate(ckpt_dir, ll, cc)
prob_s.append(nn.test_probs)
num = len(prob_s)
prob_avg = prob_s[0]
for k in range(num-1):
prob_avg += prob_s[k+1]
prob_avg /= num
probs = prob_avg
if NumClasses == 2:
preds = np.where(probs > prob_thresh, 1, 0)
else:
preds = np.argmax(probs, axis=1)
cc = np.array(cc)
ll = np.array(ll)
ice_mask = preds == 1
print(cc.shape, ll.shape, ice_mask.shape)
ice_cc = cc[ice_mask]
ice_ll = ll[ice_mask]
nav = GEOSNavigation(sub_lon=-75.0, CFAC=5.6E-05, COFF=-0.101332, LFAC=-5.6E-05, LOFF=0.128212, num_elems=2500,
num_lines=1500)
ice_lons = []
ice_lats = []
for k in range(ice_cc.shape[0]):
lon, lat = nav.lc_to_earth(ice_cc[k], ice_ll[k])
ice_lons.append(lon)
ice_lats.append(lat)
return filename, ice_lons, ice_lats
if __name__ == "__main__":
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment