Skip to content
Snippets Groups Projects
Commit 249b3ea9 authored by tomrink's avatar tomrink
Browse files

user specified flight level...

parent 4e3be64d
No related branches found
No related tags found
No related merge requests found
...@@ -209,6 +209,8 @@ class IcingIntensityNN: ...@@ -209,6 +209,8 @@ class IcingIntensityNN:
self.inputs.append(self.X_img) self.inputs.append(self.X_img)
self.inputs.append(tf.keras.Input(5)) self.inputs.append(tf.keras.Input(5))
self.flight_level = 0
self.DISK_CACHE = False self.DISK_CACHE = False
if datapath is not None: if datapath is not None:
...@@ -367,9 +369,13 @@ class IcingIntensityNN: ...@@ -367,9 +369,13 @@ class IcingIntensityNN:
data = np.stack(data) data = np.stack(data)
data = data.astype(np.float32) data = data.astype(np.float32)
data = np.transpose(data, axes=(1, 2, 3, 0)) data = np.transpose(data, axes=(1, 2, 3, 0))
# TODO: altitude data will be specified by user at run-time # TODO: altitude data will be specified by user at run-time
nda = np.zeros([nd_idxs.size])
nda = self.flight_level
nda = tf.one_hot(nda, 5).numpy()
return data return data, nda
@tf.function(input_signature=[tf.TensorSpec(None, tf.int32)]) @tf.function(input_signature=[tf.TensorSpec(None, tf.int32)])
def data_function(self, indexes): def data_function(self, indexes):
...@@ -384,7 +390,7 @@ class IcingIntensityNN: ...@@ -384,7 +390,7 @@ class IcingIntensityNN:
@tf.function(input_signature=[tf.TensorSpec(None, tf.int32)]) @tf.function(input_signature=[tf.TensorSpec(None, tf.int32)])
def data_function_evaluate(self, indexes): def data_function_evaluate(self, indexes):
# TODO: modify for user specified altitude # TODO: modify for user specified altitude
out = tf.numpy_function(self.get_in_mem_data_batch_eval, [indexes], tf.float32) out = tf.numpy_function(self.get_in_mem_data_batch_eval, [indexes], [tf.float32, tf.float32])
return out return out
def get_train_dataset(self, indexes): def get_train_dataset(self, indexes):
...@@ -1011,7 +1017,7 @@ def run_restore_static(filename_l1b, filename_l2, ckpt_dir_s_path): ...@@ -1011,7 +1017,7 @@ def run_restore_static(filename_l1b, filename_l2, ckpt_dir_s_path):
return labels, prob_avg, cm_avg return labels, prob_avg, cm_avg
def run_evaluate_static(h5f, ckpt_dir_s_path, prob_thresh=0.5, satellite='GOES16', domain='FD'): def run_evaluate_static(h5f, ckpt_dir_s_path, flight_level=4, prob_thresh=0.5, satellite='GOES16', domain='FD'):
data_dct, ll, cc = make_for_full_domain_predict(h5f, name_list=train_params, domain=domain) data_dct, ll, cc = make_for_full_domain_predict(h5f, name_list=train_params, domain=domain)
num_elems = len(cc) num_elems = len(cc)
num_lines = len(ll) num_lines = len(ll)
...@@ -1026,6 +1032,7 @@ def run_evaluate_static(h5f, ckpt_dir_s_path, prob_thresh=0.5, satellite='GOES16 ...@@ -1026,6 +1032,7 @@ def run_evaluate_static(h5f, ckpt_dir_s_path, prob_thresh=0.5, satellite='GOES16
if not os.path.isdir(ckpt_dir): if not os.path.isdir(ckpt_dir):
continue continue
nn = IcingIntensityNN() nn = IcingIntensityNN()
nn.flight_level = flight_level
nn.setup_eval_pipeline(data_dct, num_lines * num_elems) nn.setup_eval_pipeline(data_dct, num_lines * num_elems)
nn.build_model() nn.build_model()
nn.build_training() nn.build_training()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment