From 38fc2d56e549041d3fdf6442507f96fda01b8bf2 Mon Sep 17 00:00:00 2001 From: tomrink <rink@ssec.wisc.edu> Date: Tue, 14 Mar 2023 11:22:22 -0500 Subject: [PATCH] snapshot... --- modules/deeplearning/cnn_cld_frac_mod_res.py | 63 +++++++++++++++----- 1 file changed, 48 insertions(+), 15 deletions(-) diff --git a/modules/deeplearning/cnn_cld_frac_mod_res.py b/modules/deeplearning/cnn_cld_frac_mod_res.py index 43ba6cc7..1e6aa4c5 100644 --- a/modules/deeplearning/cnn_cld_frac_mod_res.py +++ b/modules/deeplearning/cnn_cld_frac_mod_res.py @@ -67,6 +67,7 @@ data_params_half = ['temp_11_0um_nom'] data_params_full = ['refl_0_65um_nom'] label_idx = params.index(label_param) +# label_idx = 0 print('data_params_half: ', data_params_half) print('data_params_full: ', data_params_full) @@ -274,14 +275,14 @@ class SRCNN: self.OUT_OF_RANGE = False - self.abi = None - self.temp = None - self.wv = None - self.lbfp = None - self.sfc = None + # self.abi = None + # self.temp = None + # self.wv = None + # self.lbfp = None + # self.sfc = None - self.in_mem_data_cache = {} - self.in_mem_data_cache_test = {} + # self.in_mem_data_cache = {} + # self.in_mem_data_cache_test = {} self.model = None self.optimizer = None @@ -313,10 +314,10 @@ class SRCNN: self.test_data_files = None self.test_label_files = None - self.train_data_nda = None - self.train_label_nda = None - self.test_data_nda = None - self.test_label_nda = None + # self.train_data_nda = None + # self.train_label_nda = None + # self.test_data_nda = None + # self.test_label_nda = None # self.n_chans = len(data_params_half) + len(data_params_full) + 1 self.n_chans = 5 @@ -343,6 +344,27 @@ class SRCNN: continue data_s.append(nda) input_data = np.concatenate(data_s) + input_label = input_data[:, label_idx, :, :] + + # if is_training: + # data_files = self.train_data_files + # label_files = self.train_label_files + # else: + # data_files = self.test_data_files + # label_files = self.test_label_files + # + # data_s = [] + # label_s = [] + # for k in idxs: + # f = data_files[k] + # nda = np.load(f) + # data_s.append(nda) + # + # f = label_files[k] + # nda = np.load(f) + # label_s.append(nda) + # input_data = np.concatenate(data_s) + # input_label = np.concatenate(label_s) DO_ADD_NOISE = False if is_training and NOISE_TRAINING: @@ -379,7 +401,7 @@ class SRCNN: data_norm.append(avg[:, 0:66, 0:66]) # data_norm.append(std[:, 0:66, 0:66]) # --------------------------------------------------- - tmp = input_data[:, label_idx, :, :] + tmp = input_label tmp = tmp.copy() tmp = np.where(np.isnan(tmp), 0, tmp) if DO_ESPCN: @@ -403,7 +425,7 @@ class SRCNN: data = data.astype(np.float32) # ----------------------------------------------------- # ----------------------------------------------------- - label = input_data[:, label_idx, :, :] + label = input_label label = label.copy() label = label[:, y_128, x_128] label = get_label_data(label) @@ -466,13 +488,18 @@ class SRCNN: dataset = dataset.cache() self.test_dataset = dataset - def setup_pipeline(self, train_data_files, test_data_files, num_train_samples): + def setup_pipeline(self, train_data_files, train_label_files, test_data_files, test_label_files, num_train_samples): + # self.train_data_files = train_data_files + # self.train_label_files = train_label_files + # self.test_data_files = test_data_files + # self.test_label_files = test_label_files self.train_data_files = train_data_files self.test_data_files = test_data_files trn_idxs = np.arange(len(train_data_files)) np.random.shuffle(trn_idxs) + tst_idxs = np.arange(len(test_data_files)) self.get_train_dataset(trn_idxs) @@ -795,10 +822,16 @@ class SRCNN: return pred def run(self, directory, ckpt_dir=None, num_data_samples=50000): + # train_data_files = glob.glob(directory+'data_train*.npy') + # valid_data_files = glob.glob(directory+'data_valid*.npy') + # train_label_files = glob.glob(directory+'label_train*.npy') + # valid_label_files = glob.glob(directory+'label_valid*.npy') + # self.setup_pipeline(train_data_files, train_label_files, valid_data_files, valid_label_files, num_data_samples) + train_data_files = glob.glob(directory+'data_train_*.npy') valid_data_files = glob.glob(directory+'data_valid_*.npy') + self.setup_pipeline(train_data_files, None, valid_data_files, None, num_data_samples) - self.setup_pipeline(train_data_files, valid_data_files, num_data_samples) self.build_model() self.build_training() self.build_evaluation() -- GitLab