diff --git a/modules/util/plot.py b/modules/util/plot.py index b027d29efd1558c3592ae6355113b8ab59483b20..f2163ccbe14d7d81b82aa17be26cb271a3fe78d6 100644 --- a/modules/util/plot.py +++ b/modules/util/plot.py @@ -330,10 +330,10 @@ def make_time_domain_hist(values, edges): plt.show() -def scatter_density(x, y, color=None, fname=None, fig=None, x_rng=None, y_rng=None, dpi=72): +def scatter_density(x, y, color=None, fname=None, fig=None, ax=None, x_rng=None, y_rng=None, dpi=72): if fig is None: fig = plt.figure() - ax = fig.add_subplot(1, 1, 1, projection='scatter_density') + ax = fig.add_subplot(1, 1, 1, projection='scatter_density') if color is not None: ax.scatter_density(x, y, color=color, dpi=dpi) else: @@ -345,4 +345,4 @@ def scatter_density(x, y, color=None, fname=None, fig=None, x_rng=None, y_rng=No if fname is not None: fig.savefig(fname) - return fig \ No newline at end of file + return fig, ax \ No newline at end of file