Commit 000c93d0 authored by tomrink

snapshot...

parent 92e920e0
@@ -410,21 +410,18 @@ class IcingIntensityFCN:
return self.get_in_mem_data_batch(idxs, False)
def get_in_mem_data_batch_eval(self, idxs):
# sort these to use as numpy indexing arrays
nd_idxs = np.array(idxs)
nd_idxs = np.sort(nd_idxs)
data = []
for param in self.train_params:
-nda = self.data_dct[param][nd_idxs, ]
+nda = self.data_dct[param]
nda = normalize(nda, param, mean_std_dct)
data.append(nda)
data = np.stack(data)
data = data.astype(np.float32)
-data = np.transpose(data, axes=(1, 2, 3, 0))
+data = np.transpose(data, axes=(1, 2, 0))
+data = np.expand_dims(data, axis=0)
# TODO: altitude data will be specified by user at run-time
-nda = np.zeros([nd_idxs.size])
+nda = np.zeros([data.shape[1]*data.shape[2]])
nda[:] = self.flight_level
nda = tf.one_hot(nda, 5).numpy()
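For reference, here is a minimal standalone sketch of the single-scene input assembly the revised `get_in_mem_data_batch_eval` performs: per-channel 2-D arrays are stacked into a `(1, H, W, C)` tensor and one flight-level one-hot row is built per output pixel. The channel names and shapes below are placeholders, the `mean_std_dct` normalization step is omitted, and the explicit int dtype before `tf.one_hot` is an assumption, not part of the commit.

```python
import numpy as np
import tensorflow as tf

# Placeholder stand-ins for self.train_params / self.data_dct (hypothetical names and shapes).
H, W = 256, 512
train_params = ['channel_a', 'channel_b']
data_dct = {p: np.random.rand(H, W) for p in train_params}

# Stack channels: (C, H, W) -> transpose to (H, W, C) -> add a batch axis -> (1, H, W, C).
data = np.stack([data_dct[p] for p in train_params]).astype(np.float32)
data = np.transpose(data, axes=(1, 2, 0))
data = np.expand_dims(data, axis=0)

# One flight-level entry per output pixel, one-hot encoded over 5 levels
# (integer dtype required by tf.one_hot).
flight_level = 2
flvl = np.full(data.shape[1] * data.shape[2], flight_level, dtype=np.int32)
flvl = tf.one_hot(flvl, 5).numpy()

print(data.shape, flvl.shape)  # (1, 256, 512, 2) (131072, 5)
```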
@@ -470,9 +467,9 @@ class IcingIntensityFCN:
indexes = list(indexes)
dataset = tf.data.Dataset.from_tensor_slices(indexes)
-dataset = dataset.batch(PROC_BATCH_SIZE)
+# dataset = dataset.batch(PROC_BATCH_SIZE)
dataset = dataset.map(self.data_function_evaluate, num_parallel_calls=8)
-dataset = dataset.cache()
+# dataset = dataset.cache()
self.eval_dataset = dataset
def setup_pipeline(self, filename_l1b_trn, filename_l1b_tst, filename_l2_trn, filename_l2_tst, trn_idxs=None, tst_idxs=None, seed=None):
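As an aside, the sketch below illustrates what disabling `batch(PROC_BATCH_SIZE)` and `cache()` means for this pipeline: the mapped function now receives one scalar index per element instead of a batch of indices. The lookup function is a toy standing in for `data_function_evaluate`, whose real body is not shown in this diff.

```python
import numpy as np
import tensorflow as tf

# Toy replacement for data_function_evaluate: fetch one pre-assembled scene by index.
scenes = np.random.rand(1, 256, 512, 2).astype(np.float32)

def toy_data_function_evaluate(idx):
    return tf.gather(scenes, idx)

indexes = list(np.arange(1))
dataset = tf.data.Dataset.from_tensor_slices(indexes)
# With .batch() commented out, each dataset element is a single scalar index,
# so the mapped function runs once per scene rather than once per index batch.
dataset = dataset.map(toy_data_function_evaluate, num_parallel_calls=8)

for scene in dataset:
    print(scene.shape)  # (256, 512, 2)
```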
@@ -548,9 +545,9 @@ class IcingIntensityFCN:
print('num test samples: ', tst_idxs.shape[0])
print('setup_test_pipeline: Done')
-def setup_eval_pipeline(self, data_dct, num_tiles):
+def setup_eval_pipeline(self, data_dct):
self.data_dct = data_dct
-idxs = np.arange(num_tiles)
+idxs = np.arange(1)
self.num_data_samples = idxs.shape[0]
self.get_evaluate_dataset(idxs)
@@ -973,22 +970,21 @@ class IcingIntensityFCN:
self.test_preds = preds
-def do_evaluate(self, ckpt_dir=None, prob_thresh=0.5):
-# if ckpt_dir is not None: # if is None, this has been done already
-# ckpt = tf.train.Checkpoint(step=tf.Variable(1), model=self.model)
-# ckpt_manager = tf.train.CheckpointManager(ckpt, ckpt_dir, max_to_keep=3)
-# ckpt.restore(ckpt_manager.latest_checkpoint)
+def do_evaluate(self, prob_thresh=0.5):
self.reset_test_metrics()
pred_s = []
# for data in self.eval_dataset:
# ds = tf.data.Dataset.from_tensor_slices(data)
# ds = ds.batch(BATCH_SIZE)
# for mini_batch in ds:
# pred = self.model([mini_batch], training=False)
# pred_s.append(pred)
for data in self.eval_dataset:
ds = tf.data.Dataset.from_tensor_slices(data)
ds = ds.batch(BATCH_SIZE)
for mini_batch in ds:
pred = self.model([mini_batch], training=False)
pred_s.append(pred)
+print(data[0].shape, data[1].shape)
+pred = self.model([data])
+print(pred.shape, pred.numpy().min(), pred.numpy().max())
preds = np.concatenate(pred_s)
preds = preds[:,0]
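Conceptually, the evaluation path this hunk moves toward is a single forward pass of a fully-convolutional network over the whole `(1, H, W, C)` scene, rather than re-slicing it into `BATCH_SIZE` mini-batches of tiles. Below is a hedged sketch with a throwaway two-layer model standing in for `self.model`, which is not shown in this diff; shapes and the threshold step are illustrative only.

```python
import numpy as np
import tensorflow as tf

# Throwaway fully-convolutional stand-in for self.model; accepts any spatial size.
model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(8, 3, padding='same', activation='relu',
                           input_shape=(None, None, 2)),
    tf.keras.layers.Conv2D(1, 1, activation='sigmoid'),
])

scene = np.random.rand(1, 256, 512, 2).astype(np.float32)  # (1, H, W, C)

# One pass over the full scene -- no tf.data re-slicing, no mini-batch loop.
pred = model(scene, training=False)
print(pred.shape, pred.numpy().min(), pred.numpy().max())

# Per-pixel mask from the probability map, thresholded at prob_thresh.
prob_thresh = 0.5
icing_mask = pred.numpy()[0, :, :, 0] > prob_thresh
```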
@@ -1111,8 +1107,8 @@ def run_evaluate_static_avg(data_dct, ll, cc, ckpt_dir_s_path, day_night='DAY',
return ice_lons, ice_lats, preds_2d
-def run_evaluate_static(data_dct, num_tiles, ckpt_dir_s_path, day_night='DAY', l1b_or_l2='both', prob_thresh=0.5,
-flight_levels=[0, 1, 2, 3, 4], use_flight_altitude=False):
+def run_evaluate_static_fcn(data_dct, ckpt_dir_s_path, day_night='DAY', l1b_or_l2='both', prob_thresh=0.5,
+flight_levels=[0, 1, 2, 3, 4], use_flight_altitude=False):
ckpt_dir_s = os.listdir(ckpt_dir_s_path)
ckpt_dir = ckpt_dir_s_path + ckpt_dir_s[0]
@@ -1124,7 +1120,7 @@ def run_evaluate_static(data_dct, num_tiles, ckpt_dir_s_path, day_night='DAY', l
preds_dct = {flvl: None for flvl in flight_levels}
nn = IcingIntensityFCN(day_night=day_night, l1b_or_l2=l1b_or_l2, use_flight_altitude=use_flight_altitude)
-nn.num_data_samples = num_tiles
+nn.num_data_samples = 1
nn.build_model()
nn.build_training()
nn.build_evaluation()
@@ -1135,7 +1131,7 @@ def run_evaluate_static(data_dct, num_tiles, ckpt_dir_s_path, day_night='DAY', l
for flvl in flight_levels:
nn.flight_level = flvl
-nn.setup_eval_pipeline(data_dct, num_tiles)
+nn.setup_eval_pipeline(data_dct)
nn.do_evaluate(ckpt_dir)
probs = nn.test_probs
......
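Finally, a hedged sketch of how the renamed entry point might be invoked, assuming `data_dct` maps channel names to full 2-D scene arrays and `ckpt_dir_s_path` is a parent directory holding a single checkpoint subdirectory (as the `os.listdir(ckpt_dir_s_path)[0]` lookup above implies). The import path, channel names, and checkpoint location are placeholders, and the function's return value, which falls outside the lines shown here, is not captured.

```python
import numpy as np
# from <module containing this diff> import run_evaluate_static_fcn  # import path not shown in the commit

# Hypothetical full-scene inputs; channel names and shapes are placeholders.
H, W = 256, 512
data_dct = {name: np.random.rand(H, W).astype(np.float32)
            for name in ['channel_a', 'channel_b']}

# Parent directory expected to contain exactly one checkpoint subdirectory.
ckpt_dir_s_path = '/path/to/checkpoints/'

run_evaluate_static_fcn(data_dct, ckpt_dir_s_path,
                        day_night='DAY', l1b_or_l2='both',
                        prob_thresh=0.5,
                        flight_levels=[0, 1, 2, 3, 4],
                        use_flight_altitude=False)
```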