Skip to content
Snippets Groups Projects
Commit da6c9346 authored by tomrink's avatar tomrink
Browse files

snapshot...

parent bfec1c99
No related branches found
No related tags found
No related merge requests found
......@@ -757,6 +757,87 @@ class SRCNN:
self.build_evaluation()
return self.do_evaluate(data, ckpt_dir)
def setup_inference(self, ckpt_dir):
self.num_data_samples = 80000
self.build_model()
self.build_training()
self.build_evaluation()
ckpt = tf.train.Checkpoint(step=tf.Variable(1), model=self.model)
ckpt_manager = tf.train.CheckpointManager(ckpt, ckpt_dir, max_to_keep=3)
ckpt.restore(ckpt_manager.latest_checkpoint)
def do_inference(self, inputs):
self.reset_test_metrics()
pred = self.model([inputs], training=False)
self.test_probs = pred
pred = pred.numpy()
return pred
def run_inference(self, in_file, out_file):
gc.collect()
h5f = h5py.File(in_file, 'r')
bt = get_grid_values_all(h5f, 'temp_11_0um_nom')
y_len, x_len = bt.shape
refl = get_grid_values_all(h5f, 'refl_0_65um_nom')
refl_lo = get_grid_values_all(h5f, 'refl_0_65um_nom_min_sub')
refl_hi = get_grid_values_all(h5f, 'refl_0_65um_nom_max_sub')
refl_std = get_grid_values_all(h5f, 'refl_0_65um_nom_stddev_sub')
cp = get_grid_values_all(h5f, label_param)
# lons = get_grid_values_all(h5f, 'longitude')
# lats = get_grid_values_all(h5f, 'latitude')
cld_frac = self.run_inference_(bt, refl, refl_lo, refl_hi, refl_std, cp)
cld_frac_out = np.zeros((y_len, x_len), dtype=np.int8)
border = int((KERNEL_SIZE - 1) / 2)
cld_frac_out[border:y_len - border, border:x_len - border] = cld_frac[0, :, :]
# Use this hack for now.
off_earth = (bt <= 161.0)
night = np.isnan(refl)
cld_frac_out[off_earth] = -1
cld_frac_out[np.invert(off_earth) & night] = -1
# --- Make a DataArray ----------------------------------------------------
# var_names = ['cloud_fraction', 'temp_11_0um', 'refl_0_65um']
# dims = ['num_params', 'y', 'x']
# da = xr.DataArray(np.stack([cld_frac_out, bt, refl], axis=0), dims=dims)
# da.assign_coords({
# 'num_params': var_names,
# 'lat': (['y', 'x'], lats),
# 'lon': (['y', 'x'], lons)
# })
# ---------------------------------------------------------------------------
h5f.close()
if out_file is not None:
np.save(out_file, (cld_frac_out, bt, refl, cp))
else:
# return [cld_frac_out, bt, refl, cp, lons, lats]
return cld_frac_out
def run_inference_(self, bt, refl, refl_lo, refl_hi, refl_std, cp):
bt = normalize(bt, 'temp_11_0um_nom', mean_std_dct)
refl = normalize(refl, 'refl_0_65um_nom', mean_std_dct)
refl_lo = normalize(refl_lo, 'refl_0_65um_nom', mean_std_dct)
refl_hi = normalize(refl_hi, 'refl_0_65um_nom', mean_std_dct)
refl_std = np.where(np.isnan(refl_std), 0, refl_std)
cp = np.where(np.isnan(cp), 0, cp)
data = np.stack([bt, refl, refl_lo, refl_hi, refl_std, cp], axis=2)
data = np.expand_dims(data, axis=0)
probs = self.do_inference(data)
cld_frac = probs.argmax(axis=3)
cld_frac = cld_frac.astype(np.int8)
return cld_frac
def run_restore_static(directory, ckpt_dir, out_file=None):
nn = SRCNN()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment