diff --git a/modules/util/plot.py b/modules/util/plot.py
index b027d29efd1558c3592ae6355113b8ab59483b20..f2163ccbe14d7d81b82aa17be26cb271a3fe78d6 100644
--- a/modules/util/plot.py
+++ b/modules/util/plot.py
@@ -330,10 +330,10 @@ def make_time_domain_hist(values, edges):
     plt.show()
 
 
-def scatter_density(x, y, color=None, fname=None, fig=None, x_rng=None, y_rng=None, dpi=72):
+def scatter_density(x, y, color=None, fname=None, fig=None, ax=None, x_rng=None, y_rng=None, dpi=72):
     if fig is None:
         fig = plt.figure()
-    ax = fig.add_subplot(1, 1, 1, projection='scatter_density')
+        ax = fig.add_subplot(1, 1, 1, projection='scatter_density')
     if color is not None:
         ax.scatter_density(x, y, color=color, dpi=dpi)
     else:
@@ -345,4 +345,4 @@ def scatter_density(x, y, color=None, fname=None, fig=None, x_rng=None, y_rng=No
     if fname is not None:
         fig.savefig(fname)
 
-    return fig
\ No newline at end of file
+    return fig, ax
\ No newline at end of file