Skip to content
Snippets Groups Projects
Select Git revision
  • 8488287aa57410e3af5ba26ebec90a2bcd16fe3e
  • main default protected
  • dev
3 results

__init__.py

Blame
  • Forked from Owen Graham / visualizer-synthesis
    4 commits behind the upstream repository.
    __init__.py 10.12 KiB
    from dataclasses import make_dataclass
    import datetime
    from io import BytesIO
    import json
    from types import SimpleNamespace
    from urllib.request import urlopen
    
    import asccol
    from flask import abort, Flask, g, jsonify, render_template, request, Response
    import matplotlib
    import matplotlib.pyplot as plt
    import numpy as np
    from pyld import jsonld
    
    from . import data_spec
    
    Measurement = make_dataclass('Measurement', ['url_name', 'field', 'title'])
    # Plottable quantities keyed by the name accepted in the `measurement`
    # query parameter. `field` is the column index into the arrays built by
    # read_data() (column 0 holds the timestamp).
    measurements = {
        measurement.url_name: measurement
        for measurement in [
            Measurement(url_name='temperature', field=1,
                        title='Temperature (C)'),
            Measurement(url_name='pressure', field=2, title='Pressure (hPa)'),
            Measurement(url_name='wind-speed', field=3,
                        title='Wind Speed (m/s)'),
        ]
    }
    # DCAT vocabulary IRIs looked up in flattened JSON-LD dataset metadata.
    DISTRIBUTION = 'http://www.w3.org/ns/dcat#Distribution'
    ACCESS_URL = 'http://www.w3.org/ns/dcat#accessURL'
    
    app = Flask(__name__)

    # Jinja whitespace control for the page templates.
    app.jinja_env.trim_blocks = app.jinja_env.lstrip_blocks = True
    app.jinja_env.keep_trailing_newline = True

    # Global plot styling, plus the non-interactive Agg backend: figures are
    # only rendered to PNG bytes for HTTP responses, never shown on screen.
    plt.style.use('ggplot')
    plt.rcParams.update({'axes.xmargin': 0})
    matplotlib.use('Agg')
    
    
    @app.route('/', defaults={'plot_name': 'index'})
    @app.route('/<plot_name>')
    def page(plot_name):
        """Render the HTML page for a registered plot type.

        Unknown plot names get a 404. The template `<plot_name>.html` receives
        `g.stations` (the parsed station list) and `g.record_years` (a mapping
        of station id to the years listed in that station's records).
        """
        if plot_name not in plotters:
            abort(404)

        stations = read_stations()
        g.stations = stations
        g.record_years = {
            station['id']: [record['year'] for record in station['records']]
            for station in stations
        }
        return render_template(f'{plot_name}.html')
    
    
    @app.route('/record/link')
    def record_link():
        """Return, as JSON, the source link of the selected station-year.

        Uses the `station` and `year` query parameters; aborts with 404/400
        via the helpers when either is unknown or malformed.
        """
        # Resolve the station before validating the year, so an unknown
        # station reports 404 even when the year is also bad.
        station_id = get_param('station')
        station = get_station(station_id)
        year = get_param('year', to=year_type)
        link = get_link(station, year)
        return jsonify(link)
    
    
    @app.route('/plot', defaults={'plot_name': 'index'})
    @app.route('/plot/<plot_name>')
    def plot(plot_name):
        """Dispatch to the registered plotter and return its PNG response.

        Unknown plot names get a 404.
        """
        if plot_name not in plotters:
            abort(404)
        return plotters[plot_name].plot()
    
    
    Plotter = make_dataclass('Plotter', ['name', 'plot'])
    # Registry of plot types, keyed by the name used in the URL.
    plotters = {}


    def add_plotter(name):
        """Return a decorator that registers a function as plot type `name`.

        The decorated function is stored in `plotters` inside a `Plotter`
        record and is handed back unchanged.
        """
        def register(func):
            plotters[name] = Plotter(name=name, plot=func)
            return func

        return register
    
    
    @add_plotter('index')
    def plot_index():
        """Plot one station/year/measurement.

        Reads the `station`, `year` and `measurement` query parameters, draws
        the series with a horizontal line at its NaN-ignoring mean, and
        reports the highest and lowest readings with their timestamps.
        """
        station_id = get_param('station')
        year = get_param('year', to=year_type)
        meas = get_param('measurement', to=meas_type)
        station = get_station(station_id)
        data = read_data(station, year)
        dates = data[:, 0]
        values = data[:, meas.field]

        fig, axes = plt.subplots(figsize=(12, 6))
        axes.plot(dates, values)
        avg = np.nanmean(values, dtype='float32')
        axes.hlines(y=avg, xmin=dates[0], xmax=dates[-1], linestyle='-',
                    alpha=0.7)
        axes.set_ylabel(meas.title)
        axes.grid(True)

        def nan_key(fill):
            """Return a sort key on the selected field that ranks NaN as `fill`."""
            def key(row):
                value = float(row[meas.field])
                return fill if np.isnan(value) else value
            return key

        # The series may contain NaN (hence np.nanmean above). Plain max/min
        # would return the NaN row whenever the first reading is NaN, because
        # nothing compares greater or less than NaN.
        maximum = max(data, key=nan_key(-np.inf))
        minimum = min(data, key=nan_key(np.inf))
        axes.set_title(
            (f'Max {meas.title}: {maximum[meas.field]}, Date: ({maximum[0]}). '
             f'Min {meas.title}: {minimum[meas.field]}, Date: ({minimum[0]}).'),
            fontsize='small',
        )
        name = station['name']
        # Title `fig` itself: plt.suptitle targets pyplot's process-global
        # "current figure", which a concurrent request may have replaced when
        # the server handles requests in threads.
        fig.suptitle(f'{meas.title} measurements, {name} Station, '
                     f'{data[0][0].year}')
        return savefig_response(fig)
    
    
    @add_plotter('overlay')
    def plot_overlay():
        """Plot two station/years @ one measurement.

        Reads the `station1`, `year1`, `station2`, `year2` and
        `measurement` query parameters. Both series, minus any Feb 29 rows,
        are drawn on a common calendar axis with their means. The title
        reports the highest and lowest readings across both series.
        """
        num_datasets = 2
        datasets = tuple(SimpleNamespace() for _ in range(num_datasets))
        stations = read_stations()
        for n, dset in enumerate(datasets, start=1):
            dset.station_id = get_param(f'station{n}')
            dset.year = get_param(f'year{n}', to=year_type)
            dset.station = get_station(dset.station_id, stations=stations)
            dset.name = dset.station['name']
        meas = get_param('measurement', to=meas_type)

        def ignore_feb_29(rows):
            return [row for row in rows if (row[0].month, row[0].day) != (2, 29)]

        for dset in datasets:
            raw_data = read_data(dset.station, dset.year)
            dset.data = np.array(ignore_feb_29(raw_data))
            dset.label_year = dset.data[0][0].year

        # Shift every series by whole years onto the first dataset's year so
        # the curves share an x axis. Plotting each series against
        # datasets[0]'s own timestamps failed with a length mismatch whenever
        # the series had different row counts, and misaligned them whenever
        # their gaps differed. Feb 29 rows were dropped above, so the year
        # replacement is always valid. A constant year shift keeps each
        # series sorted.
        ref_year = datasets[0].label_year
        for dset in datasets:
            offset = ref_year - dset.label_year
            dset.x = [date.replace(year=date.year + offset)
                      for date in dset.data[:, 0]]
        x_start = min(dset.x[0] for dset in datasets)
        x_end = max(dset.x[-1] for dset in datasets)

        fig, axes = plt.subplots(figsize=(12, 6))
        for i, dset in enumerate(datasets):
            alpha_kw = {'alpha': 0.6} if i else {}
            axes.plot(dset.x,
                      dset.data[:, meas.field],
                      **alpha_kw,
                      label=f'{dset.name} {dset.label_year}')
        for dset in datasets:
            dset.avg = np.nanmean(dset.data[:, meas.field], dtype='float32')
            axes.hlines(y=dset.avg,
                        xmin=x_start,
                        xmax=x_end,
                        alpha=0.7,
                        label=f'Avg {dset.name} {dset.label_year}')
        axes.set_ylabel(meas.title)

        def nan_key(fill):
            """Return a sort key on the selected field that ranks NaN as `fill`."""
            def key(row):
                value = float(row[meas.field])
                return fill if np.isnan(value) else value
            return key

        # The series may contain NaN (hence np.nanmean above). Plain max/min
        # would stick to a leading NaN reading, so rank NaN below/above
        # every number instead.
        hi_key, lo_key = nan_key(-np.inf), nan_key(np.inf)
        for dset in datasets:
            dset.max = max(dset.data, key=hi_key)
            dset.min = min(dset.data, key=lo_key)
        max_dset = max(datasets, key=lambda d: hi_key(d.max))
        min_dset = min(datasets, key=lambda d: lo_key(d.min))

        axes.grid(True)
        axes.set_title((f'Max {meas.title}: {max_dset.max[meas.field]}, '
                        f'{max_dset.name} Station, Date: ({max_dset.max[0]}). '
                        f'Min {meas.title}: {min_dset.min[meas.field]}, '
                        f'{min_dset.name} Station, Date: ({min_dset.min[0]}).'),
                       fontsize='small')
        axes.legend()
        # Tick labels would show the shared reference year, not either
        # dataset's real year, so they are hidden.
        axes.tick_params(labelbottom=False)
        title_dsets = ' / '.join(f'{dset.name} Station, {dset.label_year}'
                                 for dset in datasets)
        # fig.suptitle, not plt.suptitle: pyplot's current figure is
        # process-global and unsafe to rely on under threaded request handling.
        fig.suptitle(f'{meas.title} measurements, {title_dsets}')
        return savefig_response(fig)
    
    
    @add_plotter('boxplot')
    def plot_boxplot():
        """Boxplot one station/measurement @ multiple years.

        Reads the `station`, `year1`, `year2` and `measurement` query
        parameters. Draws one box per year over the inclusive range between
        the two years (given in either order) and reports the overall extremes
        with their timestamps.
        """
        station_id = get_param('station')
        year1 = get_param('year1', to=year_type)
        year2 = get_param('year2', to=year_type)
        meas = get_param('measurement', to=meas_type)
        station = get_station(station_id)

        start_year, end_year = sorted([year1, year2])
        years = range(start_year, end_year + 1)

        def nan_key(fill):
            """Return a sort key on the selected field that ranks NaN as `fill`."""
            def key(row):
                value = float(row[meas.field])
                return fill if np.isnan(value) else value
            return key

        # Readings may be NaN. Plain max/min would stick to a leading NaN, so
        # rank NaN below/above every number instead.
        hi_key, lo_key = nan_key(-np.inf), nan_key(np.inf)

        plot_data = []
        maximum = minimum = None
        for year in years:
            data = read_data(station, year)
            values = data[:, meas.field].astype(float)
            # A single NaN would make every quartile of this year's box NaN.
            plot_data.append(values[~np.isnan(values)])

            year_max = max(data, key=hi_key)
            year_min = min(data, key=lo_key)
            if maximum is None:
                maximum, minimum = year_max, year_min
            else:
                maximum = max(maximum, year_max, key=hi_key)
                minimum = min(minimum, year_min, key=lo_key)

        fig, axes = plt.subplots(figsize=(12, 6))
        axes.boxplot(plot_data, positions=years)
        axes.set_ylabel(meas.title)
        axes.grid(True)
        axes.set_title(
            (f'Max {meas.title}: {maximum[meas.field]}, Date: ({maximum[0]}). '
             f'Min {meas.title}: {minimum[meas.field]}, Date: ({minimum[0]}).'),
            fontsize='small',
        )
        name = station['name']
        # fig.suptitle, not plt.suptitle: pyplot's current figure is
        # process-global and unsafe to rely on under threaded request handling.
        fig.suptitle(f'{meas.title} measurements, {name} Station, '
                     f'{start_year} - {end_year}.')
        return savefig_response(fig)
    
    
    def get_param(key, to=str):
        """Fetch query-string parameter `key` and convert it with `to`.

        `to` is a callable that signals a bad value by raising ValueError
        (e.g. year_type, meas_type). Calls `abort(400)` when the parameter is
        absent or its conversion fails.
        """
        raw = request.args.get(key)
        if raw is None:
            abort(400)
        try:
            converted = to(raw)
        except ValueError:
            abort(400)
        else:
            return converted
    
    
    def meas_type(s):
        """Convert a URL measurement name into its `Measurement` record.

        Raises ValueError for names that are not keys of `measurements`.
        """
        if s in measurements:
            return measurements[s]
        raise ValueError(f'bad measurement arg: {s!r}')
    
    
    def year_type(s):
        """Convert a string of digits into a year in the range 0-9999.

        Raises ValueError for anything else, including signs, whitespace
        and decimals.
        """
        if not s.isdigit():
            raise ValueError(f'bad year arg: {s!r}')
        year = int(s)
        if not 0 <= year < 10000:
            raise ValueError(f'bad year arg: {s!r}')
        return year
    
    
    def savefig_response(fig):
        """Render `fig` as PNG and wrap the bytes in an `image/png` response.

        The figure is closed afterwards, even if rendering fails. Figures
        created through pyplot stay registered, and in memory, until closed.
        A long-running server would otherwise leak one figure per request.
        """
        buf = BytesIO()
        try:
            fig.savefig(buf, format='png')
        finally:
            plt.close(fig)
        return Response(buf.getvalue(), mimetype='image/png')
    
    
    def read_stations():
        """Read and parse the station catalogue `amrdcrecords.json`.

        The path is resolved against the current working directory. The file
        is decoded as UTF-8, the JSON standard encoding, so station names do
        not depend on the server locale's default encoding.
        """
        with open('amrdcrecords.json', encoding='utf-8') as f:
            return json.load(f)
    
    
    def get_station(station_id, stations=None):
        """Return the first station record whose `id` equals `station_id`.

        `stations` defaults to a fresh read_stations(); pass a preloaded list
        to avoid re-reading the file. Calls `abort(404)` if no record matches.
        """
        if stations is None:
            stations = read_stations()
        match = next((s for s in stations if s['id'] == station_id), None)
        if match is None:
            abort(404)
        return match
    
    
    def get_link(station, year):
        """Return the `url` of the first record in `station` for `year`.

        Calls `abort(404)` if the station has no record for that year.
        """
        record = next((r for r in station['records'] if r['year'] == year),
                      None)
        if record is None:
            abort(404)
        return record['url']
    
    
    def get_resources(link):
        """Yield the data-file URLs listed in a dataset's JSON-LD metadata.

        Flattens the document at `link + '.jsonld'` and yields the `@id` of
        every access URL on nodes typed as a DCAT Distribution.
        """
        skipped_variants = ('10min', '3hr')
        for node in jsonld.flatten(link + '.jsonld'):
            if DISTRIBUTION not in node.get('@type', []):
                continue
            for ref in node.get(ACCESS_URL, []):
                url = ref.get('@id')
                # Keep only the 1-hour files. The 10-minute and 3-hour
                # variants are dropped unconditionally, whether or not a
                # 1-hour file is also listed.
                if url is None or any(tag in url for tag in skipped_variants):
                    continue
                yield url
    
    
    def read_data(station, year):
        """Fetch one station-year of observations as a NumPy array.

        Every resource URL of the dataset is downloaded and parsed with
        `asccol`. The `south-pole` station uses its own column spec. Each
        row of the result is `[datetime, temp, pressure, wind_speed]` (an
        object-dtype array, since it mixes datetimes and numbers), sorted by
        timestamp.

        Calls `abort(404)` (via get_link) if the station has no record for
        `year`, and also if the downloaded files yield no rows.
        """
        spec = (data_spec.SOUTH_POLE if station['id'] == 'south-pole'
                else data_spec.ONE_HOUR)
        rows = []
        for url in get_resources(get_link(station, year)):
            with urlopen(url) as f:
                lines = map(bytes.decode, f)
                for row in asccol.parse_data(lines, spec):
                    date = datetime.datetime(row.year, row.month, row.day,
                                             *row.time)
                    rows.append([date, row.temp, row.pressure, row.wind_speed])
        if not rows:
            # np.array([]) is 1-D, so the column indexing below (and in every
            # plotter) would otherwise crash with an IndexError, an HTTP 500.
            abort(404)
        data = np.array(rows)
        # Sort by date, since monthly URLs arrive out of order. A stable sort
        # keeps any duplicate timestamps in download order.
        return data[data[:, 0].argsort(kind='stable')]