From 0b2349b8cb5812401e45db5c480e9e5fede0bd32 Mon Sep 17 00:00:00 2001 From: tomrink <rink@ssec.wisc.edu> Date: Wed, 27 Sep 2023 14:35:19 -0500 Subject: [PATCH] snapshot... --- modules/util/infer_cloud_fraction.py | 25 +++++++++++++++++-------- 1 file changed, 17 insertions(+), 8 deletions(-) diff --git a/modules/util/infer_cloud_fraction.py b/modules/util/infer_cloud_fraction.py index 69e063e8..38c9c805 100644 --- a/modules/util/infer_cloud_fraction.py +++ b/modules/util/infer_cloud_fraction.py @@ -1,8 +1,9 @@ from util.geos_nav import get_navigation -from util.setup_cloud_fraction import model_path +from util.setup_cloud_products import model_path_cld_frac, model_path_cld_opd from aeolus.datasource import CLAVRx import os -from deeplearning.cloud_fraction_fcn_abi import SRCNN +from deeplearning.cloud_fraction_fcn_abi import SRCNN as SRCNN_CLD_FRAC +from deeplearning.cloud_opd_fcn_abi import SRCNN as SRCNN_CLD_OPD from util.util import get_cartopy_crs, write_cld_frac_file_nc4 import numpy as np import time @@ -10,8 +11,11 @@ import time def infer_cloud_fraction(clvrx_path, output_dir, full_disk=True, satellite='GOES16', domain='FD', pattern=None): # -- location of the trained model - ckpt_dir_s = os.listdir(model_path) - ckpt_dir = model_path + ckpt_dir_s[0] + ckpt_dir_s = os.listdir(model_path_cld_frac) + ckpt_dir_cld_frac = model_path_cld_frac + ckpt_dir_s[0] + + ckpt_dir_s = os.listdir(model_path_cld_opd) + ckpt_dir_cld_opd = model_path_cld_opd + ckpt_dir_s[0] # -- Navigation parameters geos, xlen, xmin, xmax, ylen, ymin, ymax = get_cartopy_crs(satellite, domain) @@ -22,8 +26,11 @@ def infer_cloud_fraction(clvrx_path, output_dir, full_disk=True, satellite='GOES y_rad = ll * nav.LFAC + nav.LOFF # -- Create a model instance and initialize with trained model above - nn = SRCNN() - nn.setup_inference(ckpt_dir) + nn_cld_frac = SRCNN_CLD_FRAC() + nn_cld_frac.setup_inference(ckpt_dir_cld_frac) + + nn_cld_opd = SRCNN_CLD_OPD() + nn_cld_opd.setup_inference(ckpt_dir_cld_opd) if pattern is not None: clvrx_ds = CLAVRx(clvrx_path, pattern=pattern) @@ -36,9 +43,11 @@ def infer_cloud_fraction(clvrx_path, output_dir, full_disk=True, satellite='GOES t0 = time.time() if full_disk: - cld_frac = nn.run_inference_full_disk(pname, None) + cld_frac = nn_cld_frac.run_inference_full_disk(pname, None) + cld_opd = nn_cld_opd.run_inference_full_disk(pname, None) else: - cld_frac = nn.run_inference(pname, None) + cld_frac = nn_cld_frac.run_inference(pname, None) + cld_opd = nn_cld_opd.run_inference(pname, None) write_cld_frac_file_nc4(clvrx_str_time, out_file, cld_frac, x_rad, y_rad, None, None, satellite=satellite, domain=domain, has_time=True) -- GitLab