diff --git a/modules/deeplearning/icing_cnn.py b/modules/deeplearning/icing_cnn.py index 97ee728d01a0a875a617fab2848ed67366e2e7ba..8a5c92c5463d9bd3b411de4d1999fea618434f43 100644 --- a/modules/deeplearning/icing_cnn.py +++ b/modules/deeplearning/icing_cnn.py @@ -217,6 +217,7 @@ class IcingIntensityNN: elif NumClasses == 3: label = np.where(np.logical_or(label == 1, label == 2), 1, label) label = np.where(np.invert(np.logical_or(label == 0, label == 1)), 2, label) + label = label.reshape((label.shape[0], 1)) if CACHE_DATA_IN_MEM: self.in_mem_data_cache[key] = (data, label)