diff --git a/modules/deeplearning/icing_fcn.py b/modules/deeplearning/icing_fcn.py index 5f3aa4e6c6c90e0ec05d7427a6d3727450ddd944..b5706202072d2d75057917f9e85936d6be20f345 100644 --- a/modules/deeplearning/icing_fcn.py +++ b/modules/deeplearning/icing_fcn.py @@ -239,10 +239,11 @@ class IcingIntensityFCN: self.test_probs = None self.learningRateSchedule = None - self.num_data_samples = None + self.num_data_samples = 1 self.initial_learning_rate = None self.data_dct = None + self.cth_max = None n_chans = len(self.train_params) if TRIPLET: @@ -446,7 +447,11 @@ class IcingIntensityFCN: # TODO: altitude data will be specified by user at run-time nda = np.zeros([nd_idxs.size]) - nda[:] = self.flight_level + if self.cth_max is not None: + nda[:] = self.cth_max[nd_idxs] + else: + nda[:] = self.flight_level + nda = tf.one_hot(nda, 5).numpy() nda = np.expand_dims(nda, axis=1) nda = np.expand_dims(nda, axis=1) @@ -587,6 +592,7 @@ class IcingIntensityFCN: def setup_eval_pipeline(self, data_dct, num_tiles=1): self.data_dct = data_dct + self.cth_max = data_dct.get('cth_high_avg', None) idxs = np.arange(num_tiles) self.num_data_samples = idxs.shape[0] @@ -1058,7 +1064,7 @@ class IcingIntensityFCN: pred_s.append(pred) preds = np.concatenate(pred_s) - preds = preds[:,0] + preds = np.squeeze(preds) self.test_probs = preds if NumClasses == 2: @@ -1213,8 +1219,7 @@ def run_evaluate_static(data_dct, num_tiles, ckpt_dir_s_path, day_night='DAY', l return preds_dct, probs_dct -def run_evaluate_static_2(model, data_dct, num_tiles, day_night='DAY', l1b_or_l2='both', satellite='GOES16', - prob_thresh=0.5, flight_levels=[0, 1, 2, 3, 4], use_flight_altitude=False): +def run_evaluate_static_2(model, data_dct, num_tiles, prob_thresh=0.5, flight_levels=[0, 1, 2, 3, 4], use_flight_altitude=False): if not use_flight_altitude: flight_levels = [0] @@ -1233,6 +1238,22 @@ def run_evaluate_static_2(model, data_dct, num_tiles, day_night='DAY', l1b_or_l2 return preds_dct, probs_dct +def load_model(model_path, day_night='NIGHT', l1b_andor_l2='BOTH', satellite='GOES16', use_flight_altitude=False): + ckpt_dir_s = os.listdir(model_path) + ckpt_dir = model_path + ckpt_dir_s[0] + + model = IcingIntensityFCN(day_night=day_night, l1b_or_l2=l1b_andor_l2, satellite=satellite, use_flight_altitude=use_flight_altitude) + model.build_model() + model.build_training() + model.build_evaluation() + + ckpt = tf.train.Checkpoint(step=tf.Variable(1), model=model.model) + ckpt_manager = tf.train.CheckpointManager(ckpt, ckpt_dir, max_to_keep=3) + ckpt.restore(ckpt_manager.latest_checkpoint) + + return model + + if __name__ == "__main__": nn = IcingIntensityFCN() nn.run('matchup_filename')