Skip to content
Snippets Groups Projects
Commit 000c93d0 authored by tomrink's avatar tomrink
Browse files

snapshot...

parent 92e920e0
No related branches found
No related tags found
No related merge requests found
......@@ -410,21 +410,18 @@ class IcingIntensityFCN:
return self.get_in_mem_data_batch(idxs, False)
def get_in_mem_data_batch_eval(self, idxs):
# sort these to use as numpy indexing arrays
nd_idxs = np.array(idxs)
nd_idxs = np.sort(nd_idxs)
data = []
for param in self.train_params:
nda = self.data_dct[param][nd_idxs, ]
nda = self.data_dct[param]
nda = normalize(nda, param, mean_std_dct)
data.append(nda)
data = np.stack(data)
data = data.astype(np.float32)
data = np.transpose(data, axes=(1, 2, 3, 0))
data = np.transpose(data, axes=(1, 2, 0))
data = np.expand_dims(data, axis=0)
# TODO: altitude data will be specified by user at run-time
nda = np.zeros([nd_idxs.size])
nda = np.zeros([data.shape[1]*data.shape[2]])
nda[:] = self.flight_level
nda = tf.one_hot(nda, 5).numpy()
......@@ -470,9 +467,9 @@ class IcingIntensityFCN:
indexes = list(indexes)
dataset = tf.data.Dataset.from_tensor_slices(indexes)
dataset = dataset.batch(PROC_BATCH_SIZE)
# dataset = dataset.batch(PROC_BATCH_SIZE)
dataset = dataset.map(self.data_function_evaluate, num_parallel_calls=8)
dataset = dataset.cache()
# dataset = dataset.cache()
self.eval_dataset = dataset
def setup_pipeline(self, filename_l1b_trn, filename_l1b_tst, filename_l2_trn, filename_l2_tst, trn_idxs=None, tst_idxs=None, seed=None):
......@@ -548,9 +545,9 @@ class IcingIntensityFCN:
print('num test samples: ', tst_idxs.shape[0])
print('setup_test_pipeline: Done')
def setup_eval_pipeline(self, data_dct, num_tiles):
def setup_eval_pipeline(self, data_dct):
self.data_dct = data_dct
idxs = np.arange(num_tiles)
idxs = np.arange(1)
self.num_data_samples = idxs.shape[0]
self.get_evaluate_dataset(idxs)
......@@ -973,22 +970,21 @@ class IcingIntensityFCN:
self.test_preds = preds
def do_evaluate(self, ckpt_dir=None, prob_thresh=0.5):
# if ckpt_dir is not None: # if is None, this has been done already
# ckpt = tf.train.Checkpoint(step=tf.Variable(1), model=self.model)
# ckpt_manager = tf.train.CheckpointManager(ckpt, ckpt_dir, max_to_keep=3)
# ckpt.restore(ckpt_manager.latest_checkpoint)
def do_evaluate(self, prob_thresh=0.5):
self.reset_test_metrics()
pred_s = []
# for data in self.eval_dataset:
# ds = tf.data.Dataset.from_tensor_slices(data)
# ds = ds.batch(BATCH_SIZE)
# for mini_batch in ds:
# pred = self.model([mini_batch], training=False)
# pred_s.append(pred)
for data in self.eval_dataset:
ds = tf.data.Dataset.from_tensor_slices(data)
ds = ds.batch(BATCH_SIZE)
for mini_batch in ds:
pred = self.model([mini_batch], training=False)
pred_s.append(pred)
print(data[0].shape, data[1].shape)
pred = self.model([data])
print(pred.shape, pred.numpy().min(), pred.numpy().max())
preds = np.concatenate(pred_s)
preds = preds[:,0]
......@@ -1111,8 +1107,8 @@ def run_evaluate_static_avg(data_dct, ll, cc, ckpt_dir_s_path, day_night='DAY',
return ice_lons, ice_lats, preds_2d
def run_evaluate_static(data_dct, num_tiles, ckpt_dir_s_path, day_night='DAY', l1b_or_l2='both', prob_thresh=0.5,
flight_levels=[0, 1, 2, 3, 4], use_flight_altitude=False):
def run_evaluate_static_fcn(data_dct, ckpt_dir_s_path, day_night='DAY', l1b_or_l2='both', prob_thresh=0.5,
flight_levels=[0, 1, 2, 3, 4], use_flight_altitude=False):
ckpt_dir_s = os.listdir(ckpt_dir_s_path)
ckpt_dir = ckpt_dir_s_path + ckpt_dir_s[0]
......@@ -1124,7 +1120,7 @@ def run_evaluate_static(data_dct, num_tiles, ckpt_dir_s_path, day_night='DAY', l
preds_dct = {flvl: None for flvl in flight_levels}
nn = IcingIntensityFCN(day_night=day_night, l1b_or_l2=l1b_or_l2, use_flight_altitude=use_flight_altitude)
nn.num_data_samples = num_tiles
nn.num_data_samples = 1
nn.build_model()
nn.build_training()
nn.build_evaluation()
......@@ -1135,7 +1131,7 @@ def run_evaluate_static(data_dct, num_tiles, ckpt_dir_s_path, day_night='DAY', l
for flvl in flight_levels:
nn.flight_level = flvl
nn.setup_eval_pipeline(data_dct, num_tiles)
nn.setup_eval_pipeline(data_dct)
nn.do_evaluate(ckpt_dir)
probs = nn.test_probs
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment