diff --git a/modules/util/plot_cm.py b/modules/util/plot_cm.py
index 67f09d1ba8727790c5a95cfd0e3878fb73037534..619d00cdf025ac9da2a6e6fb8365ad8984e112f2 100644
--- a/modules/util/plot_cm.py
+++ b/modules/util/plot_cm.py
@@ -12,7 +12,7 @@ def confusion_matrix_values(correct_labels, predict_labels):
     return cm
 
 
-def plot_confusion_matrix(cm, labels, title='Confusion matrix', filename = 'confusion_matrix', normalize=False):
+def plot_confusion_matrix(cm, labels, title='Confusion matrix', filename = 'confusion_matrix', normalize=False, axis=1):
     '''
     Parameters:
         correct_labels                  : These are your true classification categories.
@@ -30,7 +30,10 @@ def plot_confusion_matrix(cm, labels, title='Confusion matrix', filename = 'conf
     '''
 
     if normalize:
-        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
+        if axis == 1:
+            cm = cm.astype('float') / cm.sum(axis=axis)[:, np.newaxis]
+        elif axis == 0:
+            cm = cm.astype('float') / cm.sum(axis=axis)[np.newaxis, :]
         cm *= 100
         cm = np.nan_to_num(cm, copy=True)
         cm = cm.astype('int')