Skip to content
Snippets Groups Projects
Commit da6c9346 authored by tomrink's avatar tomrink
Browse files

snapshot...

parent bfec1c99
Branches
No related tags found
No related merge requests found
......@@ -757,6 +757,87 @@ class SRCNN:
self.build_evaluation()
return self.do_evaluate(data, ckpt_dir)
def setup_inference(self, ckpt_dir):
self.num_data_samples = 80000
self.build_model()
self.build_training()
self.build_evaluation()
ckpt = tf.train.Checkpoint(step=tf.Variable(1), model=self.model)
ckpt_manager = tf.train.CheckpointManager(ckpt, ckpt_dir, max_to_keep=3)
ckpt.restore(ckpt_manager.latest_checkpoint)
def do_inference(self, inputs):
self.reset_test_metrics()
pred = self.model([inputs], training=False)
self.test_probs = pred
pred = pred.numpy()
return pred
def run_inference(self, in_file, out_file):
gc.collect()
h5f = h5py.File(in_file, 'r')
bt = get_grid_values_all(h5f, 'temp_11_0um_nom')
y_len, x_len = bt.shape
refl = get_grid_values_all(h5f, 'refl_0_65um_nom')
refl_lo = get_grid_values_all(h5f, 'refl_0_65um_nom_min_sub')
refl_hi = get_grid_values_all(h5f, 'refl_0_65um_nom_max_sub')
refl_std = get_grid_values_all(h5f, 'refl_0_65um_nom_stddev_sub')
cp = get_grid_values_all(h5f, label_param)
# lons = get_grid_values_all(h5f, 'longitude')
# lats = get_grid_values_all(h5f, 'latitude')
cld_frac = self.run_inference_(bt, refl, refl_lo, refl_hi, refl_std, cp)
cld_frac_out = np.zeros((y_len, x_len), dtype=np.int8)
border = int((KERNEL_SIZE - 1) / 2)
cld_frac_out[border:y_len - border, border:x_len - border] = cld_frac[0, :, :]
# Use this hack for now.
off_earth = (bt <= 161.0)
night = np.isnan(refl)
cld_frac_out[off_earth] = -1
cld_frac_out[np.invert(off_earth) & night] = -1
# --- Make a DataArray ----------------------------------------------------
# var_names = ['cloud_fraction', 'temp_11_0um', 'refl_0_65um']
# dims = ['num_params', 'y', 'x']
# da = xr.DataArray(np.stack([cld_frac_out, bt, refl], axis=0), dims=dims)
# da.assign_coords({
# 'num_params': var_names,
# 'lat': (['y', 'x'], lats),
# 'lon': (['y', 'x'], lons)
# })
# ---------------------------------------------------------------------------
h5f.close()
if out_file is not None:
np.save(out_file, (cld_frac_out, bt, refl, cp))
else:
# return [cld_frac_out, bt, refl, cp, lons, lats]
return cld_frac_out
def run_inference_(self, bt, refl, refl_lo, refl_hi, refl_std, cp):
bt = normalize(bt, 'temp_11_0um_nom', mean_std_dct)
refl = normalize(refl, 'refl_0_65um_nom', mean_std_dct)
refl_lo = normalize(refl_lo, 'refl_0_65um_nom', mean_std_dct)
refl_hi = normalize(refl_hi, 'refl_0_65um_nom', mean_std_dct)
refl_std = np.where(np.isnan(refl_std), 0, refl_std)
cp = np.where(np.isnan(cp), 0, cp)
data = np.stack([bt, refl, refl_lo, refl_hi, refl_std, cp], axis=2)
data = np.expand_dims(data, axis=0)
probs = self.do_inference(data)
cld_frac = probs.argmax(axis=3)
cld_frac = cld_frac.astype(np.int8)
return cld_frac
def run_restore_static(directory, ckpt_dir, out_file=None):
nn = SRCNN()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment