Skip to content
Snippets Groups Projects
Forked from Owen Graham / visualizer-synthesis
This fork has diverged from the upstream repository.
plotting.py 10.08 KiB
"""Plotting functions."""

from io import BytesIO
from types import SimpleNamespace

from flask import Response

from .data import read_data
from .parameters import get_param, meas_type, year_type
from .records import get_station


class Selector:
    """Describe one `<select>` form control and how its value is parsed.

    Build it with keyword arguments: ``type`` (one of ``'station'``,
    ``'year'`` or ``'measurement'``), ``name``, ``id``, ``label`` and an
    optional ``onchange`` script string (``None`` when omitted).

    Raises:
        KeyError: if a required keyword is missing.
        ValueError: if ``type`` is not one of the supported kinds.
    """

    def __init__(self, **kwargs):
        kind = kwargs['type']
        if kind not in ('station', 'year', 'measurement'):
            raise ValueError('unknown type')
        self.type = kind
        self.name, self.id, self.label = (
            kwargs['name'], kwargs['id'], kwargs['label'])
        self.onchange = kwargs.get('onchange')

    def get_param(self):
        """Read this selector's request parameter without aborting.

        Years go through ``year_type``, measurements through ``meas_type``,
        and anything else (stations) is kept as a plain string.
        """
        converter = {
            'year': year_type,
            'measurement': meas_type,
        }.get(self.type, str)
        # Calls the module-level ``get_param`` imported from .parameters.
        return get_param(self.name, to=converter, call_abort=False)

    def put_param(self, param):
        """Turn a parsed parameter of this selector's type back into text.

        Measurements are represented by their ``slug``; every other kind
        uses ``str()``.
        """
        return param.slug if self.type == 'measurement' else str(param)


class Plotter:
    """Common base class for the plot types in this module.

    Direct subclasses are gathered into the module-level ``plotters`` dict
    through ``Plotter.__subclasses__()``, keyed by their ``name``.  Each
    subclass here defines ``name``, ``nav_title``, ``page_title``,
    ``onsubmit_fn``, ``selectors`` and ``source_param_maps`` class
    attributes plus a ``plot()`` classmethod returning the image response.
    """


class TimeSeries(Plotter):
    """Plot one station/year/measurement."""

    name = 'time-series'
    nav_title = 'Time Series'
    page_title = 'Plot Data'
    onsubmit_fn = 'timeSeriesVisualize'
    selectors = (
        Selector(type='station', name='station', id='station', label='Station',
                 onchange='getYears()'),
        Selector(type='year', name='year', id='year', label='Year'),
        Selector(type='measurement', name='measurement', id='measurement',
                 label='Measurement'),
    )
    source_param_maps = (
        {'station': 'station', 'year': 'year'},
    )

    @staticmethod
    def _extreme_indices(rows, field):
        """Return ``(argmax, argmin)`` of column `field` across `rows`.

        NaN readings are skipped.  The builtin ``max``/``min`` are unreliable
        here because every comparison with NaN is false, so e.g. a leading
        NaN would be reported as the extreme.  Ties resolve to the earliest
        row, as with the builtins, and ``(0, 0)`` is returned when no row
        holds a number (which is also what the builtins returned then).
        """
        import numpy as np

        values = np.array([row[field] for row in rows], dtype=float)
        if np.isnan(values).all():
            return 0, 0
        return int(np.nanargmax(values)), int(np.nanargmin(values))

    @classmethod
    def plot(cls):
        """Render the selected station/year/measurement as a PNG response.

        Reads the ``station``, ``year`` and ``measurement`` request
        parameters, plots the measurement against time with a horizontal
        line at its (NaN-ignoring) mean, and titles the chart with the
        highest and lowest readings.
        """
        import matplotlib
        # Select the non-interactive backend before pyplot is imported so
        # pyplot never tries to set up a GUI backend in the server.
        matplotlib.use('Agg')
        import matplotlib.pyplot as plt
        import numpy as np

        plt.style.use('ggplot')
        plt.rcParams['axes.xmargin'] = 0

        station_id = get_param('station')
        year = get_param('year', to=year_type)
        meas = get_param('measurement', to=meas_type)
        station = get_station(station_id)
        data = read_data(station, year)
        times = data[:, 0]
        values = data[:, meas.field]
        i_max, i_min = cls._extreme_indices(data, meas.field)
        maximum, minimum = data[i_max], data[i_min]

        fig, axes = plt.subplots(figsize=(12, 6))
        # pyplot holds a reference to every figure until it is closed, so
        # close it even if rendering fails to avoid leaking one per request.
        try:
            axes.plot(times, values)
            avg = np.nanmean(values, dtype='float32')
            axes.hlines(y=avg,
                        xmin=times[0],
                        xmax=times[-1],
                        linestyle='-',
                        alpha=0.7)
            axes.set_ylabel(f'{meas.title.title()} ({meas.units})')
            axes.grid(True)
            # %Y-%m-%d / %H:%M instead of the glibc-only %F / %R shorthands.
            axes.set_title(
                (f'Max {meas.title}: {maximum[meas.field]} {meas.units} '
                 f'on {maximum[0]:%Y-%m-%d at %H:%M} / '
                 f'Min {meas.title}: {minimum[meas.field]} {meas.units} '
                 f'on {minimum[0]:%Y-%m-%d at %H:%M}'),
                fontsize='small',
            )
            name = station['name']
            fig.suptitle(f'{meas.title.capitalize()} measurements, '
                         f'{name} Station, {times[0].year}')
            return savefig_response(fig, filename=(
                f'{cls.name}.{station_id}.{year}.{meas.slug}.png'
            ))
        finally:
            plt.close(fig)


class Overlay(Plotter):
    """Plot two station/years @ one measurement."""

    name = 'overlay'
    nav_title = 'Overlay'
    page_title = 'Overlay Stations'
    onsubmit_fn = 'overlayVisualize'
    selectors = (
        Selector(type='station', name='station1', id='station-1',
                 label='Station #1', onchange='getYears(1)'),
        Selector(type='year', name='year1', id='year-1', label='Year #1'),
        Selector(type='station', name='station2', id='station-2',
                 label='Station #2', onchange='getYears(2)'),
        Selector(type='year', name='year2', id='year-2', label='Year #2'),
        Selector(type='measurement', name='measurement', id='measurement',
                 label='Measurement'),
    )
    source_param_maps = (
        {'station': 'station1', 'year': 'year1'},
        {'station': 'station2', 'year': 'year2'},
    )

    @staticmethod
    def _extreme_indices(rows, field):
        """Return ``(argmax, argmin)`` of column `field` across `rows`.

        NaN readings are skipped.  The builtin ``max``/``min`` are unreliable
        here because every comparison with NaN is false, so e.g. a leading
        NaN would be reported as the extreme.  Ties resolve to the earliest
        row, as with the builtins, and ``(0, 0)`` is returned when no row
        holds a number (which is also what the builtins returned then).
        """
        import numpy as np

        values = np.array([row[field] for row in rows], dtype=float)
        if np.isnan(values).all():
            return 0, 0
        return int(np.nanargmax(values)), int(np.nanargmin(values))

    @classmethod
    def plot(cls):
        """Overlay one measurement from two station/year datasets.

        Reads ``station1``/``year1``, ``station2``/``year2`` and
        ``measurement`` request parameters, draws both series and their
        (NaN-ignoring) means on shared axes, and titles the chart with the
        overall highest and lowest readings and the station they came from.
        """
        import matplotlib
        # Select the non-interactive backend before pyplot is imported so
        # pyplot never tries to set up a GUI backend in the server.
        matplotlib.use('Agg')
        import matplotlib.pyplot as plt
        import numpy as np

        plt.style.use('ggplot')
        plt.rcParams['axes.xmargin'] = 0

        num_datasets = 2
        datasets = tuple(SimpleNamespace() for _ in range(num_datasets))
        for n, dset in enumerate(datasets, start=1):
            dset.station_id = get_param(f'station{n}')
            dset.year = get_param(f'year{n}', to=year_type)
            dset.station = get_station(dset.station_id)
            dset.name = dset.station['name']
        meas = get_param('measurement', to=meas_type)

        def ignore_feb_29(rows):
            """Drop Feb 29 rows so leap and non-leap years line up."""
            return [row for row in rows
                    if (row[0].month, row[0].day) != (2, 29)]

        for dset in datasets:
            raw_data = read_data(dset.station, dset.year)
            dset.data = np.array(ignore_feb_29(raw_data))
            dset.label = f'{dset.name} {dset.data[0, 0].year}'
            i_max, i_min = cls._extreme_indices(dset.data, meas.field)
            dset.max, dset.min = dset.data[i_max], dset.data[i_min]

        # Find which dataset holds each overall extreme, again ignoring NaN.
        max_dset = datasets[
            cls._extreme_indices([d.max for d in datasets], meas.field)[0]]
        min_dset = datasets[
            cls._extreme_indices([d.min for d in datasets], meas.field)[1]]

        # NOTE(review): every series is drawn against the first dataset's
        # timestamps, which assumes all datasets have the same number of
        # rows after dropping Feb 29 -- confirm read_data guarantees this.
        x = datasets[0].data[:, 0]
        fig, axes = plt.subplots(figsize=(12, 6))
        # pyplot holds a reference to every figure until it is closed, so
        # close it even if rendering fails to avoid leaking one per request.
        try:
            for i, dset in enumerate(datasets):
                alpha_kw = {'alpha': 0.6} if i else {}
                line, = axes.plot(x,
                                  dset.data[:, meas.field],
                                  **alpha_kw,
                                  label=dset.label)
                dset.color = line.get_color()
            for dset in datasets:
                dset.avg = np.nanmean(dset.data[:, meas.field],
                                      dtype='float32')
                # Colour each average like its series; otherwise both
                # average lines share one default colour in the legend.
                axes.hlines(y=dset.avg,
                            xmin=x[0],
                            xmax=x[-1],
                            colors=dset.color,
                            alpha=0.7,
                            label=f'Avg {dset.label}')
            axes.set_ylabel(f'{meas.title.title()} ({meas.units})')
            axes.grid(True)
            # %Y-%m-%d / %H:%M instead of the glibc-only %F / %R shorthands.
            axes.set_title(
                (f'Max {meas.title}: {max_dset.name} Station, '
                 f'{max_dset.max[meas.field]} {meas.units} '
                 f'on {max_dset.max[0]:%Y-%m-%d at %H:%M} / '
                 f'Min {meas.title}: {min_dset.name} Station, '
                 f'{min_dset.min[meas.field]} {meas.units} '
                 f'on {min_dset.min[0]:%Y-%m-%d at %H:%M}'),
                fontsize='small')
            axes.legend()
            # The x values are only the first dataset's dates, so hide them.
            axes.tick_params(labelbottom=False)
            title_dsets = ' / '.join(
                f'{dset.name} Station, {dset.data[0, 0].year}'
                for dset in datasets)
            fig.suptitle(f'{meas.title.capitalize()} measurements, '
                         f'{title_dsets}')
            filename_dsets = '.'.join(f'{dset.station_id}.{dset.year}'
                                      for dset in datasets)
            return savefig_response(fig, filename=(
                f'{cls.name}.{filename_dsets}.{meas.slug}.png'
            ))
        finally:
            plt.close(fig)


class Boxplot(Plotter):
    """Boxplot one station/measurement @ multiple years."""

    name = 'boxplot'
    nav_title = 'Boxplot'
    page_title = 'Boxplot Years'
    onsubmit_fn = 'boxplotVisualize'
    selectors = (
        Selector(type='station', name='station', id='station', label='Station',
                 onchange='boxplotGetYears()'),
        Selector(type='year', name='year1', id='year-1', label='Year #1'),
        Selector(type='year', name='year2', id='year-2', label='Year #2'),
        Selector(type='measurement', name='measurement', id='measurement',
                 label='Measurement'),
    )
    source_param_maps = (
        {'station': 'station', 'year': 'year1'},
    )

    @staticmethod
    def _extreme_indices(rows, field):
        """Return ``(argmax, argmin)`` of column `field` across `rows`.

        NaN readings are skipped.  The builtin ``max``/``min`` are unreliable
        here because every comparison with NaN is false, so e.g. a leading
        NaN would be reported as the extreme.  Ties resolve to the earliest
        row, as with the builtins, and ``(0, 0)`` is returned when no row
        holds a number (which is also what the builtins returned then).
        """
        import numpy as np

        values = np.array([row[field] for row in rows], dtype=float)
        if np.isnan(values).all():
            return 0, 0
        return int(np.nanargmax(values)), int(np.nanargmin(values))

    @classmethod
    def plot(cls):
        """Boxplot one measurement of a station for each year in a range.

        Reads ``station``, ``year1``, ``year2`` and ``measurement`` request
        parameters.  The two years may come in either order; every year
        between them (inclusive) gets one box, and the chart is titled with
        the highest and lowest readings over the whole span.
        """
        import matplotlib
        # Select the non-interactive backend before pyplot is imported so
        # pyplot never tries to set up a GUI backend in the server.
        matplotlib.use('Agg')
        import matplotlib.pyplot as plt
        import numpy as np

        plt.style.use('ggplot')
        plt.rcParams['axes.xmargin'] = 0

        station_id = get_param('station')
        year1 = get_param('year1', to=year_type)
        year2 = get_param('year2', to=year_type)
        meas = get_param('measurement', to=meas_type)
        station = get_station(station_id)

        start_year, end_year = sorted([year1, year2])
        years = range(start_year, end_year + 1)

        plot_data = []
        all_rows = []
        for year in years:
            data = read_data(station, year)
            values = np.asarray(data[:, meas.field], dtype=float)
            # matplotlib's boxplot does not drop NaN itself; a NaN makes the
            # year's percentiles NaN and its box is not drawn.
            plot_data.append(values[~np.isnan(values)])
            all_rows.extend(data)

        # Rows are in year order, so ties go to the earliest reading.
        i_max, i_min = cls._extreme_indices(all_rows, meas.field)
        maximum, minimum = all_rows[i_max], all_rows[i_min]

        fig, axes = plt.subplots(figsize=(12, 6))
        # pyplot holds a reference to every figure until it is closed, so
        # close it even if rendering fails to avoid leaking one per request.
        try:
            axes.boxplot(plot_data, positions=years)
            axes.set_ylabel(f'{meas.title.title()} ({meas.units})')
            axes.grid(True)
            # %Y-%m-%d / %H:%M instead of the glibc-only %F / %R shorthands.
            axes.set_title(
                (f'Max {meas.title}: {maximum[meas.field]} {meas.units} '
                 f'on {maximum[0]:%Y-%m-%d at %H:%M} / '
                 f'Min {meas.title}: {minimum[meas.field]} {meas.units} '
                 f'on {minimum[0]:%Y-%m-%d at %H:%M}'),
                fontsize='small',
            )
            name = station['name']
            fig.suptitle(f'{meas.title.capitalize()} measurements, '
                         f'{name} Station, {start_year} - {end_year}.')
            return savefig_response(fig, filename=(
                f'{cls.name}.{station_id}.{year1}.{year2}.{meas.slug}.png'
            ))
        finally:
            plt.close(fig)


# Lookup table of every direct Plotter subclass, keyed by its ``name``.
plotters = {
    plotter_cls.name: plotter_cls
    for plotter_cls in Plotter.__subclasses__()
}


def savefig_response(fig, filename=None):
    """Render `fig` as a PNG and wrap it in an image response.

    Args:
        fig: Figure to render with ``fig.savefig(..., format='png')``.
        filename: Optional file name, sent as an ``inline``
            ``Content-Disposition`` header when given.

    Returns:
        A ``Response`` whose body is the PNG bytes, mimetype ``image/png``.
    """
    with BytesIO() as buf:
        fig.savefig(buf, format='png')
        png = buf.getvalue()
    res = Response(png, mimetype='image/png')
    if filename is not None:
        # The name sits inside a quoted-string, so escape backslashes and
        # double quotes; callers in this module build it from request
        # parameters (e.g. the station id).
        quoted = filename.replace('\\', '\\\\').replace('"', '\\"')
        res.headers['Content-Disposition'] = f'inline; filename="{quoted}"'
    return res