diff --git a/modules/deeplearning/unet.py b/modules/deeplearning/unet.py index efe15ca3c67901b6a3c966e391b65c815a09d10f..16e44f6f630fc896a5aab2812dc7237981141034 100644 --- a/modules/deeplearning/unet.py +++ b/modules/deeplearning/unet.py @@ -286,7 +286,7 @@ class UNET: def get_in_mem_data_batch(self, idxs, is_training): if is_training: train_data = [] - label_data = [] + train_label = [] for k in idxs: f = self.train_data_files[k] nda = np.load(f) @@ -294,35 +294,35 @@ class UNET: f = self.train_label_files[k] nda = np.load(f) - label_data.append(nda) + train_label.append(nda) data = np.concatenate(train_data) data = np.expand_dims(data, axis=3) - label = np.concatenate(label_data) + label = np.concatenate(train_label) label = np.expand_dims(label, axis=3) else: - train_data = [] - label_data = [] + test_data = [] + test_label = [] for k in idxs: f = self.test_data_files[k] nda = np.load(f) - train_data.append(nda) + test_data.append(nda) f = self.test_label_files[k] nda = np.load(f) - label_data.append(nda) + test_label.append(nda) - data = np.concatenate(train_data) + data = np.concatenate(test_data) data = np.expand_dims(data, axis=3) - label = np.concatenate(label_data) + label = np.concatenate(test_label) label = np.expand_dims(label, axis=3) data = data.astype(np.float32) label = label.astype(np.float32) - normalize(data, 'M15', mean_std_dct) - normalize(label, 'M15', mean_std_dct) + data = normalize(data, 'M15', mean_std_dct) + label = normalize(label, 'M15', mean_std_dct) if is_training and DO_AUGMENT: data_ud = np.flip(data, axis=1)