diff --git a/modules/util/plot.py b/modules/util/plot.py
index 2d63910203f53f4df13c5b3b9636402f8f35246f..b25b7f41b87cea5a1fd6931418b05640fdd1fb73 100644
--- a/modules/util/plot.py
+++ b/modules/util/plot.py
@@ -292,11 +292,12 @@ def make_time_domain_hist(values, edges):
     plt.show()
 
 
-def scatter_density(x, y, color, fname=None, fig=None):
+def scatter_density(x, y, color=None, fname=None, fig=None):
     if fig is None:
         fig = plt.figure()
     ax = fig.add_subplot(1, 1, 1, projection='scatter_density')
-    ax.scatter_density(x, y, color=color)
+    if color is not None:
+        ax.scatter_density(x, y, color=color)
 
     if fname is not None:
         fig.savefig(fname)