From 15412d46ada94d19cbcb6bf8002a97a0451087bf Mon Sep 17 00:00:00 2001 From: Owen Graham <ohgraham1@madisoncollege.edu> Date: Mon, 27 Jun 2022 15:37:42 -0500 Subject: [PATCH] Refactor plotters into classes --- visualizer/plotting.py | 322 +++++++++++++++++++++-------------------- 1 file changed, 165 insertions(+), 157 deletions(-) diff --git a/visualizer/plotting.py b/visualizer/plotting.py index 28998e6..69d7761 100644 --- a/visualizer/plotting.py +++ b/visualizer/plotting.py @@ -1,6 +1,5 @@ """Plotting functions.""" -from dataclasses import make_dataclass from io import BytesIO from types import SimpleNamespace @@ -17,170 +16,179 @@ plt.style.use('ggplot') plt.rcParams['axes.xmargin'] = 0 matplotlib.use('Agg') -Plotter = make_dataclass('Plotter', ['name', 'nav_title', 'plot']) -plotters = {} +class Plotter: + """Parent of plotter classes.""" -def add_plotter(name, nav_title): - """Decorator to register a plot type.""" - def decorator(f): - plotters[name] = Plotter(name, nav_title, f) - return f - - return decorator - - -@add_plotter('time-series', 'Time Series') -def plot_time_series(): +class TimeSeries(Plotter): """Plot one station/year/measurement.""" - station_id = get_param('station') - year = get_param('year', to=year_type) - meas = get_param('measurement', to=meas_type) - station = get_station(station_id) - data = read_data(station, year) - fig, axes = plt.subplots() - fig.set_figheight(6) - fig.set_figwidth(12) - axes.plot(data[:, 0], data[:, meas.field]) - avg = np.nanmean(data[:, meas.field], dtype='float32') - axes.hlines(y=avg, - xmin=data[:, 0][0], - xmax=data[:, 0][-1], - linestyle='-', - alpha=0.7) - axes.set_ylabel(meas.title) - axes.grid(True) - maximum = max(data, key=lambda row: row[meas.field]) - minimum = min(data, key=lambda row: row[meas.field]) - axes.set_title( - (f'Max {meas.title}: {maximum[meas.field]}, Date: ({maximum[0]}). ' - f'Min {meas.title}: {minimum[meas.field]}, Date: ({minimum[0]}).'), - fontsize='small', - ) - name = station['name'] - plt.suptitle(f'{meas.title} measurements, {name} Station, ' - f'{data[0][0].year}') - return savefig_response( - fig, - filename=f'time-series.{station_id}.{year}.{meas.url_name}.png', - ) - - -@add_plotter('overlay', 'Overlay') -def plot_overlay(): - """Plot two station/years @ one measurement.""" - num_datasets = 2 - datasets = tuple(SimpleNamespace() for _ in range(num_datasets)) - stations = read_stations() - for n, dset in enumerate(datasets, start=1): - dset.station_id = get_param(f'station{n}') - dset.year = get_param(f'year{n}', to=year_type) - dset.station = get_station(dset.station_id, stations=stations) - dset.name = dset.station['name'] - meas = get_param('measurement', to=meas_type) - - def ignore_feb_29(rows): - return [row for row in rows if (row[0].month, row[0].day) != (2, 29)] - - for dset in datasets: - raw_data = read_data(dset.station, dset.year) - dset.data = np.array(ignore_feb_29(raw_data)) - - fig, axes = plt.subplots() - fig.set_figheight(6) - fig.set_figwidth(12) - for i, dset in enumerate(datasets): - alpha_kw = {'alpha': 0.6} if i else {} - axes.plot(datasets[0].data[:, 0], - dset.data[:, meas.field], - **alpha_kw, - label=f'{dset.name} {dset.data[0][0].year}') - for dset in datasets: - dset.avg = np.nanmean(dset.data[:, meas.field], dtype='float32') - axes.hlines(y=dset.avg, - xmin=datasets[0].data[:, 0][0], - xmax=datasets[0].data[:, 0][-1], - alpha=0.7, - label=f'Avg {dset.name} {dset.data[0][0].year}') - axes.set_ylabel(meas.title) - - for dset in datasets: - dset.max = max(dset.data, key=lambda row: row[meas.field]) - dset.min = min(dset.data, key=lambda row: row[meas.field]) - max_dset = max(datasets, key=lambda dset: dset.max[meas.field]) - min_dset = min(datasets, key=lambda dset: dset.min[meas.field]) - - axes.grid(True) - axes.set_title((f'Max {meas.title}: {max_dset.max[meas.field]}, ' - f'{max_dset.name} Station, Date: ({max_dset.max[0]}). ' - f'Min {meas.title}: {min_dset.min[meas.field]}, ' - f'{min_dset.name} Station, Date: ({min_dset.min[0]}).'), - fontsize='small') - axes.legend() - axes.tick_params(labelbottom=False) - title_dsets = ' / '.join(f'{dset.name} Station, {dset.data[0][0].year}' - for dset in datasets) - plt.suptitle(f'{meas.title} measurements, {title_dsets}') - filename_dsets = '.'.join(f'{dset.station_id}.{dset.year}' - for dset in datasets) - return savefig_response( - fig, - filename=f'overlay.{filename_dsets}.{meas.url_name}.png', - ) - - -@add_plotter('boxplot', 'Boxplot') -def plot_boxplot(): - """Boxplot one station/measurement @ multiple years.""" - station_id = get_param('station') - year1 = get_param('year1', to=year_type) - year2 = get_param('year2', to=year_type) - meas = get_param('measurement', to=meas_type) - station = get_station(station_id) - plot_data = [] + name = 'time-series' + nav_title = 'Time Series' - start_year, end_year = sorted([year1, year2]) - years = range(start_year, end_year + 1) - - maximum = minimum = None + @classmethod + def plot(cls): + station_id = get_param('station') + year = get_param('year', to=year_type) + meas = get_param('measurement', to=meas_type) + station = get_station(station_id) + data = read_data(station, year) + fig, axes = plt.subplots() + fig.set_figheight(6) + fig.set_figwidth(12) + axes.plot(data[:, 0], data[:, meas.field]) + avg = np.nanmean(data[:, meas.field], dtype='float32') + axes.hlines(y=avg, + xmin=data[:, 0][0], + xmax=data[:, 0][-1], + linestyle='-', + alpha=0.7) + axes.set_ylabel(meas.title) + axes.grid(True) + maximum = max(data, key=lambda row: row[meas.field]) + minimum = min(data, key=lambda row: row[meas.field]) + axes.set_title( + (f'Max {meas.title}: {maximum[meas.field]}, Date: ({maximum[0]}). ' + f'Min {meas.title}: {minimum[meas.field]}, ' + f'Date: ({minimum[0]}).'), + fontsize='small', + ) + name = station['name'] + plt.suptitle(f'{meas.title} measurements, {name} Station, ' + f'{data[0][0].year}') + return savefig_response(fig, filename=( + f'{cls.name}.{station_id}.{year}.{meas.url_name}.png' + )) + + +class Overlay(Plotter): + """Plot two station/years @ one measurement.""" - def at_index(row): - """Get the selected field of `row`.""" - return row[meas.field] + name = 'overlay' + nav_title = 'Overlay' + + @classmethod + def plot(cls): + num_datasets = 2 + datasets = tuple(SimpleNamespace() for _ in range(num_datasets)) + stations = read_stations() + for n, dset in enumerate(datasets, start=1): + dset.station_id = get_param(f'station{n}') + dset.year = get_param(f'year{n}', to=year_type) + dset.station = get_station(dset.station_id, stations=stations) + dset.name = dset.station['name'] + meas = get_param('measurement', to=meas_type) + + def ignore_feb_29(rows): + return [row for row in rows + if (row[0].month, row[0].day) != (2, 29)] + + for dset in datasets: + raw_data = read_data(dset.station, dset.year) + dset.data = np.array(ignore_feb_29(raw_data)) + + fig, axes = plt.subplots() + fig.set_figheight(6) + fig.set_figwidth(12) + for i, dset in enumerate(datasets): + alpha_kw = {'alpha': 0.6} if i else {} + axes.plot(datasets[0].data[:, 0], + dset.data[:, meas.field], + **alpha_kw, + label=f'{dset.name} {dset.data[0][0].year}') + for dset in datasets: + dset.avg = np.nanmean(dset.data[:, meas.field], dtype='float32') + axes.hlines(y=dset.avg, + xmin=datasets[0].data[:, 0][0], + xmax=datasets[0].data[:, 0][-1], + alpha=0.7, + label=f'Avg {dset.name} {dset.data[0][0].year}') + axes.set_ylabel(meas.title) + + for dset in datasets: + dset.max = max(dset.data, key=lambda row: row[meas.field]) + dset.min = min(dset.data, key=lambda row: row[meas.field]) + max_dset = max(datasets, key=lambda dset: dset.max[meas.field]) + min_dset = min(datasets, key=lambda dset: dset.min[meas.field]) + + axes.grid(True) + axes.set_title( + (f'Max {meas.title}: {max_dset.max[meas.field]}, ' + f'{max_dset.name} Station, Date: ({max_dset.max[0]}). ' + f'Min {meas.title}: {min_dset.min[meas.field]}, ' + f'{min_dset.name} Station, Date: ({min_dset.min[0]}).'), + fontsize='small') + axes.legend() + axes.tick_params(labelbottom=False) + title_dsets = ' / '.join(f'{dset.name} Station, {dset.data[0][0].year}' + for dset in datasets) + plt.suptitle(f'{meas.title} measurements, {title_dsets}') + filename_dsets = '.'.join(f'{dset.station_id}.{dset.year}' + for dset in datasets) + return savefig_response(fig, filename=( + f'{cls.name}.{filename_dsets}.{meas.url_name}.png' + )) + + +class Boxplot(Plotter): + """Boxplot one station/measurement @ multiple years.""" - for year in years: - data = read_data(station, year) - selected_data = data[:, meas.field] - plot_data.append(selected_data) - - year_max = max(data, key=at_index) - year_min = min(data, key=at_index) - if maximum is None: - maximum, minimum = year_max, year_min - else: - maximum = max(maximum, year_max, key=at_index) - minimum = min(minimum, year_min, key=at_index) - - fig, axes = plt.subplots() - fig.set_figheight(6) - fig.set_figwidth(12) - axes.boxplot(plot_data, positions=years) - axes.set_ylabel(meas.title) - axes.grid(True) - axes.set_title( - (f'Max {meas.title}: {maximum[meas.field]}, Date: ({maximum[0]}). ' - f'Min {meas.title}: {minimum[meas.field]}, Date: ({minimum[0]}).'), - fontsize='small', - ) - name = station['name'] - plt.suptitle(f'{meas.title} measurements, {name} Station, ' - f'{start_year} - {end_year}.') - return savefig_response( - fig, - filename=(f'boxplot.{station_id}.{year1}.{year2}.{meas.url_name}.png'), - ) + name = 'boxplot' + nav_title = 'Boxplot' + + @classmethod + def plot(cls): + station_id = get_param('station') + year1 = get_param('year1', to=year_type) + year2 = get_param('year2', to=year_type) + meas = get_param('measurement', to=meas_type) + station = get_station(station_id) + + plot_data = [] + + start_year, end_year = sorted([year1, year2]) + years = range(start_year, end_year + 1) + + maximum = minimum = None + + def at_index(row): + """Get the selected field of `row`.""" + return row[meas.field] + + for year in years: + data = read_data(station, year) + plot_data.append(data[:, meas.field]) + + year_max = max(data, key=at_index) + year_min = min(data, key=at_index) + if maximum is None: + maximum, minimum = year_max, year_min + else: + maximum = max(maximum, year_max, key=at_index) + minimum = min(minimum, year_min, key=at_index) + + fig, axes = plt.subplots() + fig.set_figheight(6) + fig.set_figwidth(12) + axes.boxplot(plot_data, positions=years) + axes.set_ylabel(meas.title) + axes.grid(True) + axes.set_title( + (f'Max {meas.title}: {maximum[meas.field]}, Date: ({maximum[0]}). ' + f'Min {meas.title}: {minimum[meas.field]}, ' + f'Date: ({minimum[0]}).'), + fontsize='small', + ) + name = station['name'] + plt.suptitle(f'{meas.title} measurements, {name} Station, ' + f'{start_year} - {end_year}.') + return savefig_response(fig, filename=( + f'{cls.name}.{station_id}.{year1}.{year2}.{meas.url_name}.png' + )) + + +plotters = {plotter.name: plotter for plotter in Plotter.__subclasses__()} def savefig_response(fig, filename=None): -- GitLab