diff --git a/modules/deeplearning/cloud_fraction_fcn_abi.py b/modules/deeplearning/cloud_fraction_fcn_abi.py index 32ea7dbb83d4186bb91abc1c890f22d13aea3826..c140172a6d5c0adfe429d38beaa269ab766f0267 100644 --- a/modules/deeplearning/cloud_fraction_fcn_abi.py +++ b/modules/deeplearning/cloud_fraction_fcn_abi.py @@ -757,6 +757,87 @@ class SRCNN: self.build_evaluation() return self.do_evaluate(data, ckpt_dir) + def setup_inference(self, ckpt_dir): + self.num_data_samples = 80000 + self.build_model() + self.build_training() + self.build_evaluation() + + ckpt = tf.train.Checkpoint(step=tf.Variable(1), model=self.model) + ckpt_manager = tf.train.CheckpointManager(ckpt, ckpt_dir, max_to_keep=3) + ckpt.restore(ckpt_manager.latest_checkpoint) + + def do_inference(self, inputs): + self.reset_test_metrics() + + pred = self.model([inputs], training=False) + self.test_probs = pred + pred = pred.numpy() + + return pred + + def run_inference(self, in_file, out_file): + gc.collect() + + h5f = h5py.File(in_file, 'r') + + bt = get_grid_values_all(h5f, 'temp_11_0um_nom') + y_len, x_len = bt.shape + refl = get_grid_values_all(h5f, 'refl_0_65um_nom') + refl_lo = get_grid_values_all(h5f, 'refl_0_65um_nom_min_sub') + refl_hi = get_grid_values_all(h5f, 'refl_0_65um_nom_max_sub') + refl_std = get_grid_values_all(h5f, 'refl_0_65um_nom_stddev_sub') + cp = get_grid_values_all(h5f, label_param) + # lons = get_grid_values_all(h5f, 'longitude') + # lats = get_grid_values_all(h5f, 'latitude') + + cld_frac = self.run_inference_(bt, refl, refl_lo, refl_hi, refl_std, cp) + + cld_frac_out = np.zeros((y_len, x_len), dtype=np.int8) + border = int((KERNEL_SIZE - 1) / 2) + cld_frac_out[border:y_len - border, border:x_len - border] = cld_frac[0, :, :] + + # Use this hack for now. + off_earth = (bt <= 161.0) + night = np.isnan(refl) + cld_frac_out[off_earth] = -1 + cld_frac_out[np.invert(off_earth) & night] = -1 + + # --- Make a DataArray ---------------------------------------------------- + # var_names = ['cloud_fraction', 'temp_11_0um', 'refl_0_65um'] + # dims = ['num_params', 'y', 'x'] + # da = xr.DataArray(np.stack([cld_frac_out, bt, refl], axis=0), dims=dims) + # da.assign_coords({ + # 'num_params': var_names, + # 'lat': (['y', 'x'], lats), + # 'lon': (['y', 'x'], lons) + # }) + # --------------------------------------------------------------------------- + + h5f.close() + + if out_file is not None: + np.save(out_file, (cld_frac_out, bt, refl, cp)) + else: + # return [cld_frac_out, bt, refl, cp, lons, lats] + return cld_frac_out + + def run_inference_(self, bt, refl, refl_lo, refl_hi, refl_std, cp): + bt = normalize(bt, 'temp_11_0um_nom', mean_std_dct) + refl = normalize(refl, 'refl_0_65um_nom', mean_std_dct) + refl_lo = normalize(refl_lo, 'refl_0_65um_nom', mean_std_dct) + refl_hi = normalize(refl_hi, 'refl_0_65um_nom', mean_std_dct) + refl_std = np.where(np.isnan(refl_std), 0, refl_std) + cp = np.where(np.isnan(cp), 0, cp) + + data = np.stack([bt, refl, refl_lo, refl_hi, refl_std, cp], axis=2) + data = np.expand_dims(data, axis=0) + probs = self.do_inference(data) + cld_frac = probs.argmax(axis=3) + cld_frac = cld_frac.astype(np.int8) + + return cld_frac + def run_restore_static(directory, ckpt_dir, out_file=None): nn = SRCNN()