diff --git a/modules/deeplearning/icing_cnn.py b/modules/deeplearning/icing_cnn.py index 6b32892a988720d5e52ca81503ff90af392b11ef..0bf753001fe11ebab107c7d83377d8bc635c59be 100644 --- a/modules/deeplearning/icing_cnn.py +++ b/modules/deeplearning/icing_cnn.py @@ -209,6 +209,8 @@ class IcingIntensityNN: self.inputs.append(self.X_img) self.inputs.append(tf.keras.Input(5)) + self.flight_level = 0 + self.DISK_CACHE = False if datapath is not None: @@ -367,9 +369,13 @@ class IcingIntensityNN: data = np.stack(data) data = data.astype(np.float32) data = np.transpose(data, axes=(1, 2, 3, 0)) + # TODO: altitude data will be specified by user at run-time + nda = np.zeros([nd_idxs.size]) + nda = self.flight_level + nda = tf.one_hot(nda, 5).numpy() - return data + return data, nda @tf.function(input_signature=[tf.TensorSpec(None, tf.int32)]) def data_function(self, indexes): @@ -384,7 +390,7 @@ class IcingIntensityNN: @tf.function(input_signature=[tf.TensorSpec(None, tf.int32)]) def data_function_evaluate(self, indexes): # TODO: modify for user specified altitude - out = tf.numpy_function(self.get_in_mem_data_batch_eval, [indexes], tf.float32) + out = tf.numpy_function(self.get_in_mem_data_batch_eval, [indexes], [tf.float32, tf.float32]) return out def get_train_dataset(self, indexes): @@ -1011,7 +1017,7 @@ def run_restore_static(filename_l1b, filename_l2, ckpt_dir_s_path): return labels, prob_avg, cm_avg -def run_evaluate_static(h5f, ckpt_dir_s_path, prob_thresh=0.5, satellite='GOES16', domain='FD'): +def run_evaluate_static(h5f, ckpt_dir_s_path, flight_level=4, prob_thresh=0.5, satellite='GOES16', domain='FD'): data_dct, ll, cc = make_for_full_domain_predict(h5f, name_list=train_params, domain=domain) num_elems = len(cc) num_lines = len(ll) @@ -1026,6 +1032,7 @@ def run_evaluate_static(h5f, ckpt_dir_s_path, prob_thresh=0.5, satellite='GOES16 if not os.path.isdir(ckpt_dir): continue nn = IcingIntensityNN() + nn.flight_level = flight_level nn.setup_eval_pipeline(data_dct, num_lines * num_elems) nn.build_model() nn.build_training()