diff --git a/modules/deeplearning/cloud_opd_srcnn_viirs.py b/modules/deeplearning/cloud_opd_srcnn_viirs.py index 9fca0151e8f23da391d3df61ce5016ba3538f44c..19653240f7fd13ec324c49f5b34c59a23dda68aa 100644 --- a/modules/deeplearning/cloud_opd_srcnn_viirs.py +++ b/modules/deeplearning/cloud_opd_srcnn_viirs.py @@ -653,8 +653,8 @@ class SRCNN: def run(self, directory, ckpt_dir=None, num_data_samples=50000): train_data_files = glob.glob(directory+'train*mres*.npy') valid_data_files = glob.glob(directory+'valid*mres*.npy') - train_label_files = glob.glob(directory+'train*ires*.npy') - valid_label_files = glob.glob(directory+'valid*ires*.npy') + train_label_files = [f.replace('mres', 'ires') for f in train_data_files] + valid_label_files = [f.replace('mres', 'ires') for f in valid_data_files] self.setup_pipeline(train_data_files, train_label_files, valid_data_files, valid_label_files, num_data_samples) self.build_model() @@ -666,7 +666,7 @@ class SRCNN: self.num_data_samples = 1000 valid_data_files = glob.glob(directory + 'valid*mres*.npy') - valid_label_files = glob.glob(directory + 'valid*ires*.npy') + valid_label_files = [f.replace('mres', 'ires') for f in valid_data_files] self.setup_test_pipeline(valid_data_files, valid_label_files) self.build_model()