Skip to content
Snippets Groups Projects
Forked from Owen Graham / visualizer-synthesis
This fork has diverged from the upstream repository.
plotting.py 10.35 KiB
"""Plotting functions."""

import math
from io import BytesIO
import operator
from types import SimpleNamespace

from flask import Response

from .data import read_data
from .parameters import get_param, meas_type, year_type
from .records import get_station


class Selector:
    """Describe one `<select>` element and the parameter behind it."""

    def __init__(self, **kwargs):
        """Store the selector's description.

        Required keywords: `type` (one of 'measurement', 'station' or
        'year'), `name`, `id` and `label`.  `onchange` is optional and
        defaults to None.

        Raises ValueError for any other `type` and KeyError when a
        required keyword is missing.
        """
        kind = kwargs['type']
        known_kinds = ('measurement', 'station', 'year')
        if kind not in known_kinds:
            raise ValueError('unknown type')
        self.type = kind
        # Required keywords, copied in a fixed order.
        for attr in ('name', 'id', 'label'):
            setattr(self, attr, kwargs[attr])
        self.onchange = kwargs.get('onchange')

    def get_param(self):
        """Fetch this selector's parameter, converted for its type.

        Delegates to the module-level `get_param()` with
        `call_abort=False`; measurement and year selectors are converted
        with `meas_type` / `year_type`, everything else with `str`.
        """
        if self.type == 'measurement':
            return get_param(self.name, to=meas_type, call_abort=False)
        if self.type == 'year':
            return get_param(self.name, to=year_type, call_abort=False)
        return get_param(self.name, to=str, call_abort=False)

    def put_param(self, param):
        """Convert a parameter of this type back to a string."""
        # Measurements are represented by their slug; the rest by str().
        return param.slug if self.type == 'measurement' else str(param)


class Plotter:
    """Parent of plotter classes.

    Direct subclasses are collected into the module-level `plotters`
    registry, keyed by their `name` attribute.  The subclasses in this
    module each define `name`, `nav_title`, `page_title`, `onsubmit_fn`,
    `selectors`, `source_param_maps` and a `plot()` classmethod that
    returns an image response.
    """


class TimeSeries(Plotter):
    """Plot one measurement/station/year."""

    name = 'time-series'
    nav_title = 'Time Series'
    page_title = 'Plot Data'
    onsubmit_fn = 'timeSeriesVisualize'
    # One Selector per `<select>` element of this plotter.
    selectors = (
        Selector(type='measurement', name='measurement', id='measurement',
                 label='Measurement'),
        Selector(type='station', name='station', id='station', label='Station',
                 onchange='getYears()'),
        Selector(type='year', name='year', id='year', label='Year'),
    )
    # NOTE(review): presumably maps data-source keys ('station', 'year')
    # to this plotter's parameter names -- confirm against the consumer.
    source_param_maps = (
        {'station': 'station', 'year': 'year'},
    )

    @classmethod
    def plot(cls):
        """Render the requested time series and return it as a PNG response.

        Reads the `measurement`, `station` and `year` parameters, plots
        the measurement column against column 0 of the data, draws a
        horizontal line at the NaN-ignoring mean, and titles the plot
        with the extreme rows found by `data_max_min()`.

        The figure is always closed before returning, so pyplot does not
        keep one more figure alive for every request served.
        """
        import matplotlib
        # Select the non-interactive backend before pyplot is imported.
        matplotlib.use('Agg')
        import matplotlib.pyplot as plt
        import numpy as np

        plt.style.use('ggplot')
        plt.rcParams['axes.xmargin'] = 0

        meas = get_param('measurement', to=meas_type)
        station_id = get_param('station')
        year = get_param('year', to=year_type)
        station = get_station(station_id)
        data = read_data(station, year)
        fig, axes = plt.subplots()
        try:
            fig.set_figheight(6)
            fig.set_figwidth(12)
            axes.plot(data[:, 0], data[:, meas.field])
            avg = np.nanmean(data[:, meas.field], dtype='float32')
            axes.hlines(y=avg, xmin=data[0, 0], xmax=data[-1, 0])
            axes.set_ylabel(f'{meas.title.title()} ({meas.units})')
            axes.grid(True)
            maximum, minimum = data_max_min(data, field=meas.field)
            axes.set_title(
                (f'Max {meas.title}: {maximum[meas.field]} {meas.units} '
                 f'on {maximum[0]:%F at %R} / '
                 f'Min {meas.title}: {minimum[meas.field]} {meas.units} '
                 f'on {minimum[0]:%F at %R}'),
                fontsize='small',
            )
            name = station['name']
            # Title this figure explicitly instead of pyplot's global
            # "current figure".
            fig.suptitle(f'{meas.title.capitalize()} measurements, '
                         f'{name}, {data[0, 0].year}')
            return savefig_response(fig, filename=(
                f'{cls.name}.{meas.slug}.{station_id}.{year}.png'
            ))
        finally:
            # pyplot keeps every figure it creates until it is closed.
            plt.close(fig)


class Overlay(Plotter):
    """Plot one measurement @ two station/years."""

    name = 'overlay'
    nav_title = 'Overlay'
    page_title = 'Overlay Stations'
    onsubmit_fn = 'overlayVisualize'
    # One Selector per `<select>` element of this plotter.
    selectors = (
        Selector(type='measurement', name='measurement', id='measurement',
                 label='Measurement'),
        Selector(type='station', name='station1', id='station-1',
                 label='Station #1', onchange='getYears(1)'),
        Selector(type='year', name='year1', id='year-1', label='Year #1'),
        Selector(type='station', name='station2', id='station-2',
                 label='Station #2', onchange='getYears(2)'),
        Selector(type='year', name='year2', id='year-2', label='Year #2'),
    )
    # NOTE(review): presumably maps data-source keys ('station', 'year')
    # to this plotter's parameter names, one dict per dataset -- confirm
    # against the consumer.
    source_param_maps = (
        {'station': 'station1', 'year': 'year1'},
        {'station': 'station2', 'year': 'year2'},
    )

    @classmethod
    def plot(cls):
        """Render two station/year datasets on one axis as a PNG response.

        Reads `measurement` plus `station1`/`year1` and
        `station2`/`year2`.  Rows dated February 29 are dropped and every
        remaining timestamp is moved to 1999 so both datasets share one
        x axis.  Each dataset gets a horizontal line at its NaN-ignoring
        mean, and the title reports the overall extremes.

        The figure is always closed before returning, so pyplot does not
        keep one more figure alive for every request served.
        """
        import matplotlib
        # Select the non-interactive backend before pyplot is imported.
        matplotlib.use('Agg')
        import matplotlib.pyplot as plt
        import numpy as np

        plt.style.use('ggplot')
        plt.rcParams['axes.xmargin'] = 0

        meas = get_param('measurement', to=meas_type)
        num_datasets = 2
        datasets = tuple(SimpleNamespace() for _ in range(num_datasets))
        for n, dset in enumerate(datasets, start=1):
            dset.station_id = get_param(f'station{n}')
            dset.year = get_param(f'year{n}', to=year_type)
            dset.station = get_station(dset.station_id)
            dset.name = dset.station['name']

        def ignore_feb_29(rows):
            # Feb 29 has no counterpart in the non-leap year 1999 used
            # for the shared x axis below.
            return [row for row in rows
                    if (row[0].month, row[0].day) != (2, 29)]

        for dset in datasets:
            raw_data = read_data(dset.station, dset.year)
            dset.data = np.array(ignore_feb_29(raw_data))
            dset.avg = np.nanmean(dset.data[:, meas.field], dtype='float32')

        fig, axes = plt.subplots()
        try:
            fig.set_figheight(6)
            fig.set_figwidth(12)
            for i, dset in enumerate(datasets):
                # Plot all dates as though they were in 1999.
                x_data = [dt.replace(year=1999) for dt in dset.data[:, 0]]
                axes.plot(x_data,
                          dset.data[:, meas.field],
                          alpha=0.875 if i else 1.0,
                          label=f'{dset.name}, {dset.data[0, 0].year}')
                axes.hlines(y=dset.avg,
                            xmin=x_data[0],
                            xmax=x_data[-1],
                            colors=f'C{i}',
                            alpha=0.875 if i else 1.0)
            axes.set_ylabel(f'{meas.title.title()} ({meas.units})')

            row_meas = operator.itemgetter(meas.field)
            for dset in datasets:
                dset.max, dset.min = data_max_min(dset.data, key=row_meas)
            # Pick the dataset holding the overall extreme of each kind.
            max_dset = max(datasets, key=lambda dset: row_meas(dset.max))
            min_dset = min(datasets, key=lambda dset: row_meas(dset.min))

            axes.grid(True)
            axes.set_title(
                (f'Max {meas.title}: {max_dset.name}, '
                 f'{max_dset.max[meas.field]} {meas.units} '
                 f'on {max_dset.max[0]:%F at %R} / '
                 f'Min {meas.title}: {min_dset.name}, '
                 f'{min_dset.min[meas.field]} {meas.units} '
                 f'on {min_dset.min[0]:%F at %R}'),
                fontsize='small')
            axes.legend()
            # The x tick labels would show the placeholder year; hide them.
            axes.tick_params(labelbottom=False)
            title_dsets = ' / '.join(f'{dset.name}, {dset.data[0, 0].year}'
                                     for dset in datasets)
            # Title this figure explicitly instead of pyplot's global
            # "current figure".
            fig.suptitle(f'{meas.title.capitalize()} measurements, '
                         f'{title_dsets}')
            filename_dsets = '.'.join(f'{dset.station_id}.{dset.year}'
                                      for dset in datasets)
            return savefig_response(fig, filename=(
                f'{cls.name}.{meas.slug}.{filename_dsets}.png'
            ))
        finally:
            # pyplot keeps every figure it creates until it is closed.
            plt.close(fig)


class Boxplot(Plotter):
    """Boxplot one measurement/station @ multiple years."""

    name = 'boxplot'
    nav_title = 'Boxplot'
    page_title = 'Boxplot Years'
    onsubmit_fn = 'boxplotVisualize'
    # One Selector per `<select>` element of this plotter.
    selectors = (
        Selector(type='measurement', name='measurement', id='measurement',
                 label='Measurement'),
        Selector(type='station', name='station', id='station', label='Station',
                 onchange='boxplotGetYears()'),
        Selector(type='year', name='year1', id='year-1', label='Year #1'),
        Selector(type='year', name='year2', id='year-2', label='Year #2'),
    )
    # NOTE(review): presumably maps data-source keys ('station', 'year')
    # to this plotter's parameter names -- confirm against the consumer.
    source_param_maps = (
        {'station': 'station', 'year': 'year1'},
    )

    @classmethod
    def plot(cls):
        """Render one box per year for a station as a PNG response.

        Reads `measurement`, `station`, `year1` and `year2`; the two
        years may be given in either order and every year between them
        (inclusive) is read and plotted.  The title reports the extreme
        rows across all of those years.

        NaN readings are removed from each year's values before they are
        handed to `axes.boxplot()`.  The figure is always closed before
        returning, so pyplot does not keep one more figure alive for
        every request served.
        """
        import matplotlib
        # Select the non-interactive backend before pyplot is imported.
        matplotlib.use('Agg')
        import matplotlib.pyplot as plt

        plt.style.use('ggplot')
        plt.rcParams['axes.xmargin'] = 0

        meas = get_param('measurement', to=meas_type)
        station_id = get_param('station')
        year1 = get_param('year1', to=year_type)
        year2 = get_param('year2', to=year_type)
        station = get_station(station_id)

        plot_data = []

        start_year, end_year = sorted([year1, year2])
        years = range(start_year, end_year + 1)

        row_meas = operator.itemgetter(meas.field)
        years_maxima = []
        years_minima = []
        for year in years:
            data = read_data(station, year)
            # Box statistics are percentile based and not NaN-aware, so
            # drop NaN readings here like the rest of this module does.
            plot_data.append(non_nan(data[:, meas.field]))
            year_max, year_min = data_max_min(data, key=row_meas)
            years_maxima.append(year_max)
            years_minima.append(year_min)
        maximum = max(years_maxima, key=row_meas)
        minimum = min(years_minima, key=row_meas)

        fig, axes = plt.subplots()
        try:
            fig.set_figheight(6)
            fig.set_figwidth(12)
            axes.boxplot(plot_data, positions=years)
            axes.set_ylabel(f'{meas.title.title()} ({meas.units})')
            axes.grid(True)
            axes.set_title(
                (f'Max {meas.title}: {maximum[meas.field]} {meas.units} '
                 f'on {maximum[0]:%F at %R} / '
                 f'Min {meas.title}: {minimum[meas.field]} {meas.units} '
                 f'on {minimum[0]:%F at %R}'),
                fontsize='small',
            )
            name = station['name']
            # Title this figure explicitly instead of pyplot's global
            # "current figure".
            fig.suptitle(f'{meas.title.capitalize()} measurements, '
                         f'{name}, {start_year} - {end_year}.')
            return savefig_response(fig, filename=(
                f'{cls.name}.{meas.slug}.{station_id}.{year1}.{year2}.png'
            ))
        finally:
            # pyplot keeps every figure it creates until it is closed.
            plt.close(fig)


# Registry of the direct `Plotter` subclasses, keyed by their `name`.
plotters = {
    plotter_cls.name: plotter_cls
    for plotter_cls in Plotter.__subclasses__()
}


def savefig_response(fig, filename=None):
    """Save `fig` as a PNG and wrap the bytes in an image response.

    When `filename` is given, it is advertised through an inline
    `Content-Disposition` header.
    """
    png_stream = BytesIO()
    fig.savefig(png_stream, format='png')
    response = Response(png_stream.getvalue(), mimetype='image/png')
    if filename is None:
        return response
    response.headers['Content-Disposition'] = (
        f'inline; filename="{filename}"'
    )
    return response


def data_max_min(data, field=None, key=None):
    """Return the `(max_row, min_row)` pair of `data`, skipping NaNs.

    Rows are compared by `key(row)`; when `key` is omitted they are
    compared by `row[field]`.  Rows whose comparison value is NaN are
    left out.  If no row remains, `data[0]` is returned for both.
    """
    by_value = operator.itemgetter(field) if key is None else key
    candidates = non_nan(data, key=by_value)
    fallback = data[0]
    return max_min(candidates, key=by_value, default=fallback)


def max_min(*args, **kwargs):
    """Call `max()` and `min()` with the same arguments.

    Accepts whatever `max()`/`min()` accept (one iterable or several
    positional values, plus `key`/`default`) and returns the pair
    `(maximum, minimum)`.

    A single iterable argument is copied into a list first: `max()`
    would otherwise exhaust a one-shot iterator or generator and leave
    nothing for `min()` to look at.
    """
    if len(args) == 1:
        args = (list(args[0]),)
    return max(*args, **kwargs), min(*args, **kwargs)


def non_nan(items, key=None):
    """Return a list of the items whose value is not NaN.

    The value tested is `key(item)` when `key` is given, otherwise the
    item itself.
    """
    if key is None:
        return [item for item in items if not math.isnan(item)]
    return [item for item in items if not math.isnan(key(item))]