diff --git a/modules/deeplearning/unet_l1b_l2.py b/modules/deeplearning/unet_l1b_l2.py index 06d7a55d39c8876b3320d6a0fbccc859305ff598..c91458838b698534baa050b838ab19ac6fb3bcd4 100644 --- a/modules/deeplearning/unet_l1b_l2.py +++ b/modules/deeplearning/unet_l1b_l2.py @@ -60,7 +60,8 @@ l2_params = ['cloud_fraction', 'cld_temp_acha', 'cld_press_acha'] zero_out_params = ['cld_reff_dcomp', 'cld_opd_dcomp', 'iwc_dcomp', 'lwc_dcomp'] DO_ZERO_OUT = False -label_param = l2_params[1] +label_idx = 1 +label_param = l2_params[label_idx] def build_conv2d_block(conv, num_filters, activation, block_name, padding='SAME'): @@ -231,7 +232,7 @@ class UNET: f = self.test_label_files nda = np.load(f) - label = nda[idxs, 0, :, :] + label = nda[idxs, label_idx, :, :] label = np.expand_dims(label, axis=3) data = data.astype(np.float32)