diff --git a/modules/deeplearning/srcnn_l1b_l2.py b/modules/deeplearning/srcnn_l1b_l2.py index 33a18220c0f86180a448d996d3edcc512e3fffdb..94aa3b3675bc0b4ce1aac20213b7c1d0cfb60362 100644 --- a/modules/deeplearning/srcnn_l1b_l2.py +++ b/modules/deeplearning/srcnn_l1b_l2.py @@ -49,8 +49,8 @@ f.close() mean_std_dct.update(mean_std_dct_l1b) mean_std_dct.update(mean_std_dct_l2) -label_param = 'cloud_fraction' -# label_param = 'cld_opd_dcomp' +# label_param = 'cloud_fraction' +label_param = 'cld_opd_dcomp' # label_param = 'cloud_probability' params = ['temp_11_0um_nom', 'temp_12_0um_nom', 'refl_0_65um_nom', label_param] @@ -251,7 +251,8 @@ class SRCNN: # label = input_data[:, label_idx, 3:131:2, 3:131:2] label = input_data[:, label_idx, 3:131, 3:131] if label_param != 'cloud_fraction': - label = normalize(label, label_param, mean_std_dct) + #label = normalize(label, label_param, mean_std_dct) + label = np.where(np.isnan(label), 0, label) else: label = np.where(np.isnan(label), 0, label) label = np.expand_dims(label, axis=3) @@ -399,9 +400,9 @@ class SRCNN: conv_b = build_residual_conv2d_block(conv_b, num_filters, 'Residual_Block_3', scale=scale) - conv_b = build_residual_conv2d_block(conv_b, num_filters, 'Residual_Block_4', scale=scale) + # conv_b = build_residual_conv2d_block(conv_b, num_filters, 'Residual_Block_4', scale=scale) - conv_b = build_residual_conv2d_block(conv_b, num_filters, 'Residual_Block_5', scale=scale) + # conv_b = build_residual_conv2d_block(conv_b, num_filters, 'Residual_Block_5', scale=scale) conv_b = tf.keras.layers.Conv2D(num_filters, kernel_size=3, strides=1, kernel_initializer='he_uniform', padding=padding)(conv_b) @@ -646,6 +647,8 @@ class SRCNN: def run(self, directory, ckpt_dir=None, num_data_samples=50000): train_data_files = glob.glob(directory+'data_train_*.npy') valid_data_files = glob.glob(directory+'data_valid_*.npy') + train_data_files = train_data_files[::2] + valid_data_files = valid_data_files[::2] self.setup_pipeline(train_data_files, valid_data_files, num_data_samples) self.build_model()