From 15412d46ada94d19cbcb6bf8002a97a0451087bf Mon Sep 17 00:00:00 2001
From: Owen Graham <ohgraham1@madisoncollege.edu>
Date: Mon, 27 Jun 2022 15:37:42 -0500
Subject: [PATCH] Refactor plotters into classes

---
 visualizer/plotting.py | 322 +++++++++++++++++++++--------------------
 1 file changed, 165 insertions(+), 157 deletions(-)

diff --git a/visualizer/plotting.py b/visualizer/plotting.py
index 28998e6..69d7761 100644
--- a/visualizer/plotting.py
+++ b/visualizer/plotting.py
@@ -1,6 +1,5 @@
 """Plotting functions."""
 
-from dataclasses import make_dataclass
 from io import BytesIO
 from types import SimpleNamespace
 
@@ -17,170 +16,179 @@ plt.style.use('ggplot')
 plt.rcParams['axes.xmargin'] = 0
 matplotlib.use('Agg')
 
-Plotter = make_dataclass('Plotter', ['name', 'nav_title', 'plot'])
-plotters = {}
 
+class Plotter:
+    """Parent of plotter classes."""
 
-def add_plotter(name, nav_title):
-    """Decorator to register a plot type."""
 
-    def decorator(f):
-        plotters[name] = Plotter(name, nav_title, f)
-        return f
-
-    return decorator
-
-
-@add_plotter('time-series', 'Time Series')
-def plot_time_series():
+class TimeSeries(Plotter):
     """Plot one station/year/measurement."""
-    station_id = get_param('station')
-    year = get_param('year', to=year_type)
-    meas = get_param('measurement', to=meas_type)
-    station = get_station(station_id)
-    data = read_data(station, year)
-    fig, axes = plt.subplots()
-    fig.set_figheight(6)
-    fig.set_figwidth(12)
-    axes.plot(data[:, 0], data[:, meas.field])
-    avg = np.nanmean(data[:, meas.field], dtype='float32')
-    axes.hlines(y=avg,
-                xmin=data[:, 0][0],
-                xmax=data[:, 0][-1],
-                linestyle='-',
-                alpha=0.7)
-    axes.set_ylabel(meas.title)
-    axes.grid(True)
-    maximum = max(data, key=lambda row: row[meas.field])
-    minimum = min(data, key=lambda row: row[meas.field])
-    axes.set_title(
-        (f'Max {meas.title}: {maximum[meas.field]}, Date: ({maximum[0]}). '
-         f'Min {meas.title}: {minimum[meas.field]}, Date: ({minimum[0]}).'),
-        fontsize='small',
-    )
-    name = station['name']
-    plt.suptitle(f'{meas.title} measurements, {name} Station, '
-                 f'{data[0][0].year}')
-    return savefig_response(
-        fig,
-        filename=f'time-series.{station_id}.{year}.{meas.url_name}.png',
-    )
-
-
-@add_plotter('overlay', 'Overlay')
-def plot_overlay():
-    """Plot two station/years @ one measurement."""
-    num_datasets = 2
-    datasets = tuple(SimpleNamespace() for _ in range(num_datasets))
-    stations = read_stations()
-    for n, dset in enumerate(datasets, start=1):
-        dset.station_id = get_param(f'station{n}')
-        dset.year = get_param(f'year{n}', to=year_type)
-        dset.station = get_station(dset.station_id, stations=stations)
-        dset.name = dset.station['name']
-    meas = get_param('measurement', to=meas_type)
-
-    def ignore_feb_29(rows):
-        return [row for row in rows if (row[0].month, row[0].day) != (2, 29)]
-
-    for dset in datasets:
-        raw_data = read_data(dset.station, dset.year)
-        dset.data = np.array(ignore_feb_29(raw_data))
-
-    fig, axes = plt.subplots()
-    fig.set_figheight(6)
-    fig.set_figwidth(12)
-    for i, dset in enumerate(datasets):
-        alpha_kw = {'alpha': 0.6} if i else {}
-        axes.plot(datasets[0].data[:, 0],
-                  dset.data[:, meas.field],
-                  **alpha_kw,
-                  label=f'{dset.name} {dset.data[0][0].year}')
-    for dset in datasets:
-        dset.avg = np.nanmean(dset.data[:, meas.field], dtype='float32')
-        axes.hlines(y=dset.avg,
-                    xmin=datasets[0].data[:, 0][0],
-                    xmax=datasets[0].data[:, 0][-1],
-                    alpha=0.7,
-                    label=f'Avg {dset.name} {dset.data[0][0].year}')
-    axes.set_ylabel(meas.title)
-
-    for dset in datasets:
-        dset.max = max(dset.data, key=lambda row: row[meas.field])
-        dset.min = min(dset.data, key=lambda row: row[meas.field])
-    max_dset = max(datasets, key=lambda dset: dset.max[meas.field])
-    min_dset = min(datasets, key=lambda dset: dset.min[meas.field])
-
-    axes.grid(True)
-    axes.set_title((f'Max {meas.title}: {max_dset.max[meas.field]}, '
-                    f'{max_dset.name} Station, Date: ({max_dset.max[0]}). '
-                    f'Min {meas.title}: {min_dset.min[meas.field]}, '
-                    f'{min_dset.name} Station, Date: ({min_dset.min[0]}).'),
-                   fontsize='small')
-    axes.legend()
-    axes.tick_params(labelbottom=False)
-    title_dsets = ' / '.join(f'{dset.name} Station, {dset.data[0][0].year}'
-                             for dset in datasets)
-    plt.suptitle(f'{meas.title} measurements, {title_dsets}')
-    filename_dsets = '.'.join(f'{dset.station_id}.{dset.year}'
-                              for dset in datasets)
-    return savefig_response(
-        fig,
-        filename=f'overlay.{filename_dsets}.{meas.url_name}.png',
-    )
-
-
-@add_plotter('boxplot', 'Boxplot')
-def plot_boxplot():
-    """Boxplot one station/measurement @ multiple years."""
-    station_id = get_param('station')
-    year1 = get_param('year1', to=year_type)
-    year2 = get_param('year2', to=year_type)
-    meas = get_param('measurement', to=meas_type)
-    station = get_station(station_id)
 
-    plot_data = []
+    name = 'time-series'
+    nav_title = 'Time Series'
 
-    start_year, end_year = sorted([year1, year2])
-    years = range(start_year, end_year + 1)
-
-    maximum = minimum = None
+    @classmethod
+    def plot(cls):
+        station_id = get_param('station')
+        year = get_param('year', to=year_type)
+        meas = get_param('measurement', to=meas_type)
+        station = get_station(station_id)
+        data = read_data(station, year)
+        fig, axes = plt.subplots()
+        fig.set_figheight(6)
+        fig.set_figwidth(12)
+        axes.plot(data[:, 0], data[:, meas.field])
+        avg = np.nanmean(data[:, meas.field], dtype='float32')
+        axes.hlines(y=avg,
+                    xmin=data[:, 0][0],
+                    xmax=data[:, 0][-1],
+                    linestyle='-',
+                    alpha=0.7)
+        axes.set_ylabel(meas.title)
+        axes.grid(True)
+        maximum = max(data, key=lambda row: row[meas.field])
+        minimum = min(data, key=lambda row: row[meas.field])
+        axes.set_title(
+            (f'Max {meas.title}: {maximum[meas.field]}, Date: ({maximum[0]}). '
+             f'Min {meas.title}: {minimum[meas.field]}, '
+             f'Date: ({minimum[0]}).'),
+            fontsize='small',
+        )
+        name = station['name']
+        plt.suptitle(f'{meas.title} measurements, {name} Station, '
+                     f'{data[0][0].year}')
+        return savefig_response(fig, filename=(
+            f'{cls.name}.{station_id}.{year}.{meas.url_name}.png'
+        ))
+
+
+class Overlay(Plotter):
+    """Plot two station/years @ one measurement."""
 
-    def at_index(row):
-        """Get the selected field of `row`."""
-        return row[meas.field]
+    name = 'overlay'
+    nav_title = 'Overlay'
+
+    @classmethod
+    def plot(cls):
+        num_datasets = 2
+        datasets = tuple(SimpleNamespace() for _ in range(num_datasets))
+        stations = read_stations()
+        for n, dset in enumerate(datasets, start=1):
+            dset.station_id = get_param(f'station{n}')
+            dset.year = get_param(f'year{n}', to=year_type)
+            dset.station = get_station(dset.station_id, stations=stations)
+            dset.name = dset.station['name']
+        meas = get_param('measurement', to=meas_type)
+
+        def ignore_feb_29(rows):
+            return [row for row in rows
+                    if (row[0].month, row[0].day) != (2, 29)]
+
+        for dset in datasets:
+            raw_data = read_data(dset.station, dset.year)
+            dset.data = np.array(ignore_feb_29(raw_data))
+
+        fig, axes = plt.subplots()
+        fig.set_figheight(6)
+        fig.set_figwidth(12)
+        for i, dset in enumerate(datasets):
+            alpha_kw = {'alpha': 0.6} if i else {}
+            axes.plot(datasets[0].data[:, 0],
+                      dset.data[:, meas.field],
+                      **alpha_kw,
+                      label=f'{dset.name} {dset.data[0][0].year}')
+        for dset in datasets:
+            dset.avg = np.nanmean(dset.data[:, meas.field], dtype='float32')
+            axes.hlines(y=dset.avg,
+                        xmin=datasets[0].data[:, 0][0],
+                        xmax=datasets[0].data[:, 0][-1],
+                        alpha=0.7,
+                        label=f'Avg {dset.name} {dset.data[0][0].year}')
+        axes.set_ylabel(meas.title)
+
+        for dset in datasets:
+            dset.max = max(dset.data, key=lambda row: row[meas.field])
+            dset.min = min(dset.data, key=lambda row: row[meas.field])
+        max_dset = max(datasets, key=lambda dset: dset.max[meas.field])
+        min_dset = min(datasets, key=lambda dset: dset.min[meas.field])
+
+        axes.grid(True)
+        axes.set_title(
+            (f'Max {meas.title}: {max_dset.max[meas.field]}, '
+             f'{max_dset.name} Station, Date: ({max_dset.max[0]}). '
+             f'Min {meas.title}: {min_dset.min[meas.field]}, '
+             f'{min_dset.name} Station, Date: ({min_dset.min[0]}).'),
+            fontsize='small')
+        axes.legend()
+        axes.tick_params(labelbottom=False)
+        title_dsets = ' / '.join(f'{dset.name} Station, {dset.data[0][0].year}'
+                                 for dset in datasets)
+        plt.suptitle(f'{meas.title} measurements, {title_dsets}')
+        filename_dsets = '.'.join(f'{dset.station_id}.{dset.year}'
+                                  for dset in datasets)
+        return savefig_response(fig, filename=(
+            f'{cls.name}.{filename_dsets}.{meas.url_name}.png'
+        ))
+
+
+class Boxplot(Plotter):
+    """Boxplot one station/measurement @ multiple years."""
 
-    for year in years:
-        data = read_data(station, year)
-        selected_data = data[:, meas.field]
-        plot_data.append(selected_data)
-
-        year_max = max(data, key=at_index)
-        year_min = min(data, key=at_index)
-        if maximum is None:
-            maximum, minimum = year_max, year_min
-        else:
-            maximum = max(maximum, year_max, key=at_index)
-            minimum = min(minimum, year_min, key=at_index)
-
-    fig, axes = plt.subplots()
-    fig.set_figheight(6)
-    fig.set_figwidth(12)
-    axes.boxplot(plot_data, positions=years)
-    axes.set_ylabel(meas.title)
-    axes.grid(True)
-    axes.set_title(
-        (f'Max {meas.title}: {maximum[meas.field]}, Date: ({maximum[0]}). '
-         f'Min {meas.title}: {minimum[meas.field]}, Date: ({minimum[0]}).'),
-        fontsize='small',
-    )
-    name = station['name']
-    plt.suptitle(f'{meas.title} measurements, {name} Station, '
-                 f'{start_year} - {end_year}.')
-    return savefig_response(
-        fig,
-        filename=(f'boxplot.{station_id}.{year1}.{year2}.{meas.url_name}.png'),
-    )
+    name = 'boxplot'
+    nav_title = 'Boxplot'
+
+    @classmethod
+    def plot(cls):
+        station_id = get_param('station')
+        year1 = get_param('year1', to=year_type)
+        year2 = get_param('year2', to=year_type)
+        meas = get_param('measurement', to=meas_type)
+        station = get_station(station_id)
+
+        plot_data = []
+
+        start_year, end_year = sorted([year1, year2])
+        years = range(start_year, end_year + 1)
+
+        maximum = minimum = None
+
+        def at_index(row):
+            """Get the selected field of `row`."""
+            return row[meas.field]
+
+        for year in years:
+            data = read_data(station, year)
+            plot_data.append(data[:, meas.field])
+
+            year_max = max(data, key=at_index)
+            year_min = min(data, key=at_index)
+            if maximum is None:
+                maximum, minimum = year_max, year_min
+            else:
+                maximum = max(maximum, year_max, key=at_index)
+                minimum = min(minimum, year_min, key=at_index)
+
+        fig, axes = plt.subplots()
+        fig.set_figheight(6)
+        fig.set_figwidth(12)
+        axes.boxplot(plot_data, positions=years)
+        axes.set_ylabel(meas.title)
+        axes.grid(True)
+        axes.set_title(
+            (f'Max {meas.title}: {maximum[meas.field]}, Date: ({maximum[0]}). '
+             f'Min {meas.title}: {minimum[meas.field]}, '
+             f'Date: ({minimum[0]}).'),
+            fontsize='small',
+        )
+        name = station['name']
+        plt.suptitle(f'{meas.title} measurements, {name} Station, '
+                     f'{start_year} - {end_year}.')
+        return savefig_response(fig, filename=(
+            f'{cls.name}.{station_id}.{year1}.{year2}.{meas.url_name}.png'
+        ))
+
+
+plotters = {plotter.name: plotter for plotter in Plotter.__subclasses__()}
 
 
 def savefig_response(fig, filename=None):
-- 
GitLab