Skip to content
Snippets Groups Projects
Verified Commit 15412d46 authored by Owen Graham's avatar Owen Graham
Browse files

Refactor plotters into classes

parent 401f0f6c
No related branches found
No related tags found
No related merge requests found
"""Plotting functions.""" """Plotting functions."""
from dataclasses import make_dataclass
from io import BytesIO from io import BytesIO
from types import SimpleNamespace from types import SimpleNamespace
...@@ -17,23 +16,19 @@ plt.style.use('ggplot') ...@@ -17,23 +16,19 @@ plt.style.use('ggplot')
plt.rcParams['axes.xmargin'] = 0 plt.rcParams['axes.xmargin'] = 0
matplotlib.use('Agg') matplotlib.use('Agg')
Plotter = make_dataclass('Plotter', ['name', 'nav_title', 'plot'])
plotters = {}
class Plotter:
"""Parent of plotter classes."""
def add_plotter(name, nav_title):
"""Decorator to register a plot type."""
def decorator(f): class TimeSeries(Plotter):
plotters[name] = Plotter(name, nav_title, f) """Plot one station/year/measurement."""
return f
return decorator
name = 'time-series'
nav_title = 'Time Series'
@add_plotter('time-series', 'Time Series') @classmethod
def plot_time_series(): def plot(cls):
"""Plot one station/year/measurement."""
station_id = get_param('station') station_id = get_param('station')
year = get_param('year', to=year_type) year = get_param('year', to=year_type)
meas = get_param('measurement', to=meas_type) meas = get_param('measurement', to=meas_type)
...@@ -55,21 +50,26 @@ def plot_time_series(): ...@@ -55,21 +50,26 @@ def plot_time_series():
minimum = min(data, key=lambda row: row[meas.field]) minimum = min(data, key=lambda row: row[meas.field])
axes.set_title( axes.set_title(
(f'Max {meas.title}: {maximum[meas.field]}, Date: ({maximum[0]}). ' (f'Max {meas.title}: {maximum[meas.field]}, Date: ({maximum[0]}). '
f'Min {meas.title}: {minimum[meas.field]}, Date: ({minimum[0]}).'), f'Min {meas.title}: {minimum[meas.field]}, '
f'Date: ({minimum[0]}).'),
fontsize='small', fontsize='small',
) )
name = station['name'] name = station['name']
plt.suptitle(f'{meas.title} measurements, {name} Station, ' plt.suptitle(f'{meas.title} measurements, {name} Station, '
f'{data[0][0].year}') f'{data[0][0].year}')
return savefig_response( return savefig_response(fig, filename=(
fig, f'{cls.name}.{station_id}.{year}.{meas.url_name}.png'
filename=f'time-series.{station_id}.{year}.{meas.url_name}.png', ))
)
@add_plotter('overlay', 'Overlay') class Overlay(Plotter):
def plot_overlay():
"""Plot two station/years @ one measurement.""" """Plot two station/years @ one measurement."""
name = 'overlay'
nav_title = 'Overlay'
@classmethod
def plot(cls):
num_datasets = 2 num_datasets = 2
datasets = tuple(SimpleNamespace() for _ in range(num_datasets)) datasets = tuple(SimpleNamespace() for _ in range(num_datasets))
stations = read_stations() stations = read_stations()
...@@ -81,7 +81,8 @@ def plot_overlay(): ...@@ -81,7 +81,8 @@ def plot_overlay():
meas = get_param('measurement', to=meas_type) meas = get_param('measurement', to=meas_type)
def ignore_feb_29(rows): def ignore_feb_29(rows):
return [row for row in rows if (row[0].month, row[0].day) != (2, 29)] return [row for row in rows
if (row[0].month, row[0].day) != (2, 29)]
for dset in datasets: for dset in datasets:
raw_data = read_data(dset.station, dset.year) raw_data = read_data(dset.station, dset.year)
...@@ -112,7 +113,8 @@ def plot_overlay(): ...@@ -112,7 +113,8 @@ def plot_overlay():
min_dset = min(datasets, key=lambda dset: dset.min[meas.field]) min_dset = min(datasets, key=lambda dset: dset.min[meas.field])
axes.grid(True) axes.grid(True)
axes.set_title((f'Max {meas.title}: {max_dset.max[meas.field]}, ' axes.set_title(
(f'Max {meas.title}: {max_dset.max[meas.field]}, '
f'{max_dset.name} Station, Date: ({max_dset.max[0]}). ' f'{max_dset.name} Station, Date: ({max_dset.max[0]}). '
f'Min {meas.title}: {min_dset.min[meas.field]}, ' f'Min {meas.title}: {min_dset.min[meas.field]}, '
f'{min_dset.name} Station, Date: ({min_dset.min[0]}).'), f'{min_dset.name} Station, Date: ({min_dset.min[0]}).'),
...@@ -124,15 +126,19 @@ def plot_overlay(): ...@@ -124,15 +126,19 @@ def plot_overlay():
plt.suptitle(f'{meas.title} measurements, {title_dsets}') plt.suptitle(f'{meas.title} measurements, {title_dsets}')
filename_dsets = '.'.join(f'{dset.station_id}.{dset.year}' filename_dsets = '.'.join(f'{dset.station_id}.{dset.year}'
for dset in datasets) for dset in datasets)
return savefig_response( return savefig_response(fig, filename=(
fig, f'{cls.name}.{filename_dsets}.{meas.url_name}.png'
filename=f'overlay.{filename_dsets}.{meas.url_name}.png', ))
)
@add_plotter('boxplot', 'Boxplot') class Boxplot(Plotter):
def plot_boxplot():
"""Boxplot one station/measurement @ multiple years.""" """Boxplot one station/measurement @ multiple years."""
name = 'boxplot'
nav_title = 'Boxplot'
@classmethod
def plot(cls):
station_id = get_param('station') station_id = get_param('station')
year1 = get_param('year1', to=year_type) year1 = get_param('year1', to=year_type)
year2 = get_param('year2', to=year_type) year2 = get_param('year2', to=year_type)
...@@ -152,8 +158,7 @@ def plot_boxplot(): ...@@ -152,8 +158,7 @@ def plot_boxplot():
for year in years: for year in years:
data = read_data(station, year) data = read_data(station, year)
selected_data = data[:, meas.field] plot_data.append(data[:, meas.field])
plot_data.append(selected_data)
year_max = max(data, key=at_index) year_max = max(data, key=at_index)
year_min = min(data, key=at_index) year_min = min(data, key=at_index)
...@@ -171,16 +176,19 @@ def plot_boxplot(): ...@@ -171,16 +176,19 @@ def plot_boxplot():
axes.grid(True) axes.grid(True)
axes.set_title( axes.set_title(
(f'Max {meas.title}: {maximum[meas.field]}, Date: ({maximum[0]}). ' (f'Max {meas.title}: {maximum[meas.field]}, Date: ({maximum[0]}). '
f'Min {meas.title}: {minimum[meas.field]}, Date: ({minimum[0]}).'), f'Min {meas.title}: {minimum[meas.field]}, '
f'Date: ({minimum[0]}).'),
fontsize='small', fontsize='small',
) )
name = station['name'] name = station['name']
plt.suptitle(f'{meas.title} measurements, {name} Station, ' plt.suptitle(f'{meas.title} measurements, {name} Station, '
f'{start_year} - {end_year}.') f'{start_year} - {end_year}.')
return savefig_response( return savefig_response(fig, filename=(
fig, f'{cls.name}.{station_id}.{year1}.{year2}.{meas.url_name}.png'
filename=(f'boxplot.{station_id}.{year1}.{year2}.{meas.url_name}.png'), ))
)
plotters = {plotter.name: plotter for plotter in Plotter.__subclasses__()}
def savefig_response(fig, filename=None): def savefig_response(fig, filename=None):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment