Skip to content
Snippets Groups Projects
test_data_api.py 16.63 KiB
import contextlib
import json
import random
from datetime import datetime, timedelta
from typing import Any, Callable, ContextManager, Iterable, Iterator
from unittest import mock

import numpy as np
import pandas as pd
import pytest

try:
    import influxdb_client
except ImportError:
    influxdb_client = None

try:
    from influxdb.resultset import ResultSet
except ImportError:
    ResultSet = None

from metobsapi.util.query_influx import QueryResult


@pytest.fixture(params=["1", "2"], ids=["influxdbv1", "influxdbv2"])
def influxdb_version(request):
    """Provide the InfluxDB API version under test ("1" or "2").

    Every test that depends on this fixture (directly or via ``app``/``client``)
    is run once per version.
    """
    version = request.param
    return version


@pytest.fixture
def mock_influxdb_query(influxdb_version) -> Callable[[list[dict]], ContextManager[mock.Mock]]:
    """Return a context-manager factory that mocks the InfluxDB query for this version.

    Calling the returned factory with fake series data gives a context manager that:
    hides the client class of the *other* InfluxDB version, patches the query
    function of the active version to return the fake data, and yields that mock.
    """

    @contextlib.contextmanager
    def _mock_influxdb_query_with_fake_data(fake_result_data: list[dict]) -> Iterator[mock.Mock]:
        is_v1 = influxdb_version == "1"
        is_v2 = influxdb_version == "2"
        with _mock_influxdb_library_availability(is_v1, is_v2):
            if not (is_v1 or is_v2):
                # unknown version: never yield (same outcome as having no matching branch)
                return
            patch_query = _patch_v1_query if is_v1 else _patch_v2_query
            with patch_query(fake_result_data) as query_func:
                yield query_func

    return _mock_influxdb_query_with_fake_data


@contextlib.contextmanager
def _mock_influxdb_library_availability(v1_available: bool, v2_available: bool) -> Iterator[None]:
    """Simulate which InfluxDB client libraries are installed.

    Any version marked unavailable has its client class in
    ``metobsapi.util.query_influx`` patched to ``None`` for the duration of the context.
    """
    hide_v1 = _mock_if(not v1_available, "metobsapi.util.query_influx.InfluxDBClientv1")
    hide_v2 = _mock_if(not v2_available, "metobsapi.util.query_influx.InfluxDBClientv2")
    with hide_v1, hide_v2:
        yield None


@contextlib.contextmanager
def _mock_if(condition: bool, obj_path: str) -> Iterator[None]:
    if not condition:
        yield None
        return
    with mock.patch(obj_path, None):
        yield None


@contextlib.contextmanager
def _patch_v1_query(fake_result_data: list[dict]) -> Iterator[mock.Mock]:
    """Patch the InfluxDB v1 query function to return ``ResultSet`` objects built from fake data.

    Skips the current test when the optional ``influxdb`` package is not installed.
    Yields the mock that replaces ``metobsapi.util.query_influx._query_influxdbv1``.
    """
    if ResultSet is None:
        pytest.skip("Version 1 'influxdb' dependency missing. Skipping test.")
    with mock.patch("metobsapi.util.query_influx._query_influxdbv1") as patched_query:
        patched_query.return_value = _fake_data_to_v1_resultset(fake_result_data)
        yield patched_query


@contextlib.contextmanager
def _patch_v2_query(fake_result_data: list[dict]) -> Iterator[mock.Mock]:
    """Patch the InfluxDB v2 query function to return DataFrames built from fake data.

    Skips the current test when the optional ``influxdb-client`` package is not
    installed. Yields the mock that replaces
    ``metobsapi.util.query_influx._query_influxdbv2``; its return value is one
    DataFrame per fake series (see ``_fake_data_to_v2_dataframes``).
    """
    if influxdb_client is None:
        pytest.skip("Version 2 'influxdb-client' dependency missing. Skipping test.")

    site_data_frames = _fake_data_to_v2_dataframes(fake_result_data)
    with mock.patch("metobsapi.util.query_influx._query_influxdbv2") as query_func:
        query_func.return_value = site_data_frames
        yield query_func


def _fake_data_to_v2_dataframes(fake_data: list[dict]) -> list[pd.DataFrame]:
    """Convert fake series dicts into one DataFrame per series, mimicking v2 query output.

    Each row combines the series' column values (``None`` mapped to ``NaN``) with its
    tags. The frame has ``site``/``inst`` columns followed by the series columns, is
    indexed by ``time``, and drops any column whose values are all NaN.
    """
    site_data_frames = []
    for series in fake_data:
        columns = series["columns"]
        row_dicts = []
        for row_values in series["values"]:
            data_row = {col_name: (np.nan if col_value is None else col_value) for col_name, col_value in zip(columns, row_values)}
            # tags (site/inst) are added after the values, overriding same-named keys
            data_row.update(series["tags"])
            row_dicts.append(data_row)

        single_df = pd.DataFrame(row_dicts, columns=["site", "inst"] + columns)
        single_df = single_df.set_index("time").dropna(axis=1, how="all")
        site_data_frames.append(single_df)
    return site_data_frames


def _fake_data_to_v1_resultset(fake_data: list[dict]) -> QueryResult:
    """Convert fake series dicts into the list of ``ResultSet`` objects a v1 query returns.

    Each series in ``fake_data`` becomes its own single-series ``ResultSet`` with
    ``statement_id`` 0. The first element of every value row (a ``datetime``) is
    rendered with ``%Y-%m-%dT%H:%M:%SZ``, matching the default epoch format.

    The input is left unmodified: converted copies of each series and row are built
    instead, so the same fake data can safely be converted more than once.
    """
    t_format = "%Y-%m-%dT%H:%M:%SZ"
    result_sets = []
    for single_result in fake_data:
        converted_rows = []
        for value_row in single_result["values"]:
            timestamp = value_row[0]
            # pass through timestamps that are already formatted strings
            if not isinstance(timestamp, str):
                timestamp = timestamp.strftime(t_format)
            converted_rows.append([timestamp, *value_row[1:]])
        converted_series = {**single_result, "values": converted_rows}
        result_sets.append(ResultSet({"series": [converted_series], "statement_id": 0}))
    return result_sets


@pytest.fixture
def influxdb_air_temp_9_values(mock_influxdb_query) -> Iterable[None]:
    """Mock InfluxDB to return 9 one-minute ``air_temp`` rows for aoss/tower."""
    tower_columns = ["time", "air_temp"]
    with mock_influxdb_query(_fake_data("1m", {("aoss", "tower"): tower_columns}, 9)):
        yield


@pytest.fixture
def influxdb_3_symbols_9_values(mock_influxdb_query) -> Iterable[None]:
    """Mock InfluxDB to return 9 one-minute rows of air_temp, rel_hum and wind_speed for aoss/tower."""
    tower_columns = ["time", "air_temp", "rel_hum", "wind_speed"]
    with mock_influxdb_query(_fake_data("1m", {("aoss", "tower"): tower_columns}, 9)):
        yield


@pytest.fixture
def influxdb_wind_fields_9_values(mock_influxdb_query) -> Iterable[None]:
    """Mock InfluxDB to return 9 one-minute rows of wind fields for aoss/tower.

    ``wind_direction`` is always null in the fake data (see ``_generate_fake_column_values``).
    """
    tower_columns = ["time", "wind_speed", "wind_direction", "wind_east", "wind_north"]
    with mock_influxdb_query(_fake_data("1m", {("aoss", "tower"): tower_columns}, 9)):
        yield


def _fake_data(interval: str, symbols: dict[tuple[str, str], list[str]], num_vals: int) -> list[dict[str, Any]]:
    """Build fake InfluxDB series, one per (site, inst) key in ``symbols``.

    Args:
        interval: Interval suffix; the measurement name is ``"metobs_" + interval``.
        symbols: Maps ``(site, inst)`` to column names; the first column is the time column.
        num_vals: Number of rows per series, one minute apart starting at 2017-03-05 19:00:00.

    Returns:
        A list of dicts with ``name``, ``columns``, ``tags`` and ``values`` keys. Rows
        contain random values, and some rows have one non-time value set to ``None``.
    """
    start_time = datetime(2017, 3, 5, 19, 0, 0)
    series_list = []
    for (site, inst), columns in symbols.items():
        rows: list[list[datetime | float | None]] = []
        for minute in range(num_vals):
            row = [start_time + timedelta(minutes=minute), *_generate_fake_column_values(columns[1:])]
            # occasionally null out a single value; index 0 (time) is never touched
            null_idx = int(random.random() * len(columns) * 3)
            if 0 < null_idx < len(columns):
                row[null_idx] = None
            rows.append(row)
        series_list.append(
            {
                "name": "metobs_" + interval,
                "columns": columns,
                "tags": {"site": site, "inst": inst},
                "values": rows,
            }
        )
    return series_list


def _generate_fake_column_values(columns: list[str]) -> list[float | None]:
    rand_data: list[float | None] = [random.random()] * len(columns)
    if "wind_direction" in columns:
        # wind direction shouldn't exist in the database
        rand_data[columns.index("wind_direction")] = None
    return rand_data


@pytest.fixture
def app(influxdb_version):
    """Yield the real ``metobsapi`` Flask app configured for testing.

    ``TESTING`` and ``DEBUG`` are enabled, and for InfluxDB v2 a dummy
    ``INFLUXDB_TOKEN`` is set. The overrides are applied with ``mock.patch.dict``,
    so the global app config is restored after each test. Previously the v2 token
    leaked into every later test, including v1-parametrized ones.
    """
    from metobsapi import app as real_app

    config_overrides = {"TESTING": True, "DEBUG": True}
    if influxdb_version == "2":
        config_overrides["INFLUXDB_TOKEN"] = "1234"
    with mock.patch.dict(real_app.config, config_overrides):
        yield real_app


@pytest.fixture
def client(app):
    """Provide a Flask test client for the configured ``app`` fixture."""
    test_client = app.test_client()
    return test_client


class TestDataAPI:
    """Endpoint tests for ``/api/data``.

    Every test uses the ``client`` fixture, so each runs once per InfluxDB version.
    Tests that expect real data pull in a fixture that mocks the InfluxDB query layer.
    """

    def test_doc(self, client):
        res = client.get("/api/data")
        assert b"Data Request Application" in res.data

    def test_bad_format(self, client):
        res = client.get("/api/data.fake")
        assert b"No data file format" in res.data

    def test_bad_begin_json(self, client):
        res = client.get("/api/data.json?symbols=air_temp&begin=blah")
        res = json.loads(res.data.decode())
        assert res["code"] == 400
        assert res["status"] == "error"
        assert "timestamp" in res["message"]

    def test_bad_order(self, client):
        res = client.get("/api/data.json?order=blah&symbols=air_temp")
        res = json.loads(res.data.decode())
        assert "column" in res["message"]
        assert "row" in res["message"]

    def test_bad_epoch(self, client):
        res = client.get("/api/data.json?epoch=blah&symbols=air_temp")
        res = json.loads(res.data.decode())
        assert "'h'" in res["message"]
        assert "'m'" in res["message"]
        assert "'s'" in res["message"]
        assert "'u'" in res["message"]

    def test_bad_interval(self, client):
        res = client.get("/api/data.json?interval=blah&symbols=air_temp")
        res = json.loads(res.data.decode())
        assert "'1m'" in res["message"]
        assert "'5m'" in res["message"]
        assert "'1h'" in res["message"]

    def test_missing_inst(self, client):
        res = client.get("/api/data.json?site=X&symbols=air_temp&begin=-05:00:00")
        res = json.loads(res.data.decode())
        assert res["code"] == 400
        assert res["status"] == "error"
        assert "'site'" in res["message"]
        assert "'inst'" in res["message"]

    def test_missing_site(self, client):
        res = client.get("/api/data.json?inst=X&symbols=air_temp&begin=-05:00:00")
        res = json.loads(res.data.decode())
        assert res["code"] == 400
        assert res["status"] == "error"
        assert "'site'" in res["message"]
        assert "'inst'" in res["message"]

    def test_missing_symbols(self, client):
        res = client.get("/api/data.json?begin=-05:00:00")
        res = json.loads(res.data.decode())
        assert res["code"] == 400
        assert res["status"] == "error"
        assert "'symbols'" in res["message"]

    def test_nonexistent_symbol(self, client):
        res = client.get("/api/data.json?begin=-05:00:00&symbols=aoss.tower.air_temp:aoss.tower.noexist")
        res = json.loads(res.data.decode())
        assert res["code"] == 400
        assert res["status"] == "error"
        assert "Unknown symbol" in res["message"]

    def test_bad_symbol_site_inst(self, client):
        res = client.get("/api/data.json?begin=-05:00:00&symbols=aoss.tower.something.air_temp")
        res = json.loads(res.data.decode())
        assert res["code"] == 400
        assert res["status"] == "error"
        assert "3 period-separated parts" in res["message"]

    def test_symbol_unknown_site_inst(self, client):
        res = client.get("/api/data.json?begin=-05:00:00&symbols=aoss2.tower.air_temp")
        res = json.loads(res.data.decode())
        assert res["code"] == 400
        assert res["status"] == "error"
        assert "Unknown site/instrument" in res["message"]

    def test_multiple_symbol(self, client):
        res = client.get(
            "/api/data.json?begin=-05:00:00&symbols=aoss.tower.air_temp:aoss.tower.wind_speed:aoss.tower.air_temp"
        )
        res = json.loads(res.data.decode())
        assert res["code"] == 400
        assert res["status"] == "error"
        assert "multiple times" in res["message"]

    def test_too_many_points(self, client):
        res = client.get("/api/data.json?symbols=aoss.tower.air_temp&begin=1970-01-01T00:00:00")
        assert res.status_code == 413
        res = json.loads(res.data.decode())
        assert "too many values" in res["message"]
        assert res["code"] == 413
        assert res["status"] == "fail"

    def test_shorthand_one_symbol_json_row(self, client, influxdb_air_temp_9_values):
        # row should be the default
        res = client.get("/api/data.json?site=aoss&inst=tower&symbols=air_temp&begin=-00:10:00")
        res = json.loads(res.data.decode())
        assert res["code"] == 200
        assert res["num_results"] == 9
        assert res["results"]["symbols"] == ["air_temp"]
        assert len(res["results"]["timestamps"]) == 9
        assert len(res["results"]["data"]) == 9
        assert len(res["results"]["data"][0]) == 1

    def test_shorthand_one_symbol_json_column(self, client, influxdb_air_temp_9_values):
        res = client.get("/api/data.json?site=aoss&inst=tower&symbols=air_temp&begin=-00:10:00&order=column")
        res = json.loads(res.data.decode())
        assert res["code"] == 200
        assert res["num_results"] == 9
        assert "air_temp" in res["results"]["data"]
        assert len(res["results"]["data"]["air_temp"]) == 9
        assert len(res["results"]["timestamps"]) == 9

    @pytest.mark.parametrize(
        "symbols",
        [
            ["aoss.tower.wind_speed", "aoss.tower.wind_direction"],
            ["wind_speed", "wind_direction"],
        ],
    )
    def test_wind_speed_direction_json(self, symbols, client, influxdb_wind_fields_9_values):
        symbol_param = ":".join(symbols)
        # shorthand (un-prefixed) symbols need explicit site/inst query parameters
        site_inst_params = "&site=aoss&inst=tower" if "." not in symbols[0] else ""

        res = client.get(f"/api/data.json?symbols={symbol_param}{site_inst_params}&begin=-00:10:00&order=column")
        res = json.loads(res.data.decode())
        assert res["code"] == 200
        assert res["num_results"] == 9
        data = res["results"]["data"]
        for symbol_name in symbols:
            assert symbol_name in data
            # dtype=float turns JSON nulls (None) into NaN; np.isnan on a list
            # containing None would raise TypeError instead
            symbol_values = np.array(data[symbol_name], dtype=float)
            assert not np.isnan(symbol_values).all()
        assert len(data) == len(symbols)

    def test_one_symbol_two_insts_json_row(self, client, mock_influxdb_query):
        fake_result = _fake_data(
            "1m",
            {
                ("aoss", "tower"): ["time", "air_temp"],
                ("mendota", "buoy"): ["time", "air_temp"],
            },
            9,
        )
        with mock_influxdb_query(fake_result):
            # row should be the default
            res = client.get("/api/data.json?symbols=aoss.tower.air_temp:mendota.buoy.air_temp&begin=-00:10:00")
        res = json.loads(res.data.decode())
        assert res["code"] == 200
        assert res["num_results"] == 9
        assert res["results"]["symbols"] == ["aoss.tower.air_temp", "mendota.buoy.air_temp"]
        assert len(res["results"]["timestamps"]) == 9
        assert len(res["results"]["data"]) == 9
        assert len(res["results"]["data"][0]) == 2

    def test_one_symbol_three_insts_json_row(self, client, mock_influxdb_query):
        fake_site_insts = [("site1", "inst1"), ("site2", "inst2"), ("site3", "inst3")]
        fake_result = _fake_data(
            "1m",
            {site_inst: ["time", "air_temp"] for site_inst in fake_site_insts},
            9,
        )
        from metobsapi.util.data_responses import SYMBOL_TRANSLATIONS

        # give the made-up site/inst pairs the same symbol translations as aoss/tower
        translations = SYMBOL_TRANSLATIONS.copy()
        for site_inst in fake_site_insts:
            translations[site_inst] = SYMBOL_TRANSLATIONS[("aoss", "tower")]
        with mock.patch("metobsapi.util.data_responses.SYMBOL_TRANSLATIONS", translations), mock_influxdb_query(
            fake_result
        ):
            # row should be the default
            res = client.get(
                "/api/data.json?symbols=site1.inst1.air_temp:site2.inst2.air_temp:site3.inst3.air_temp&begin=-00:10:00"
            )
            res = json.loads(res.data.decode())
            assert res["code"] == 200
            assert res["num_results"] == 9
            assert res["results"]["symbols"] == ["site1.inst1.air_temp", "site2.inst2.air_temp", "site3.inst3.air_temp"]
            assert len(res["results"]["timestamps"]) == 9
            assert len(res["results"]["data"]) == 9
            assert len(res["results"]["data"][0]) == 3

    def test_one_symbol_csv(self, client, influxdb_air_temp_9_values):
        # row should be the default
        res = client.get("/api/data.csv?symbols=aoss.tower.air_temp&begin=-00:10:00")
        res = res.data.decode()
        # header, data, newline at end
        lines = res.split("\n")
        assert len(lines) == 5 + 9 + 1
        # time + 1 channel
        assert len(lines[5].split(",")) == 2
        assert "# code: 200" in res

    def test_one_symbol_xml(self, client, influxdb_air_temp_9_values):
        from xml.dom.minidom import parseString

        # row should be the default
        res = client.get("/api/data.xml?symbols=aoss.tower.air_temp&begin=-00:10:00")
        res = parseString(res.data.decode())
        # symbols: time and air_temp
        assert len(res.childNodes[0].childNodes[0].childNodes) == 2
        # data rows
        assert len(res.childNodes[0].childNodes[1].childNodes) == 9

    def test_three_symbol_csv(self, client, influxdb_3_symbols_9_values):
        """Test that multiple channels in a CSV file are structured properly."""
        # row should be the default
        res = client.get(
            "/api/data.csv?symbols=aoss.tower.air_temp:aoss.tower.rel_hum:aoss.tower.wind_speed&begin=-00:10:00"
        )
        res = res.data.decode()
        # header, data, newline at end
        lines = res.split("\n")
        assert len(lines) == 5 + 9 + 1
        # time + 3 channels
        assert len(lines[5].split(",")) == 4
        assert "# code: 200" in res

    def test_three_symbol_csv_repeat(self, client, influxdb_3_symbols_9_values):
        """Test that requesting the same symbol more than once in CSV output is rejected with a 400."""
        res = client.get(
            "/api/data.csv?symbols=aoss.tower.air_temp:aoss.tower.air_temp:aoss.tower.air_temp&begin=-00:10:00"
        )
        res = res.data.decode()
        lines = res.split("\n")
        # header, data (one empty line), newline at end
        assert len(lines) == 5 + 1 + 1
        # the data line is empty, so it contains no separators
        assert len(lines[5].split(",")) == 1
        assert "# code: 400" in res

    # @mock.patch('metobsapi.data_api.query')
    # def test_jsonp_bad_symbol_400(self, query_func, client):
    #     XXX: Not currently possible with flask-json
    #     r = _fake_data('1m', {('aoss', 'tower'): ['time', 'air_temp']}, 9)
    #     query_func.return_value = r
    #     # row should be the default
    #     res = client.get('/api/data.json?site=aoss&inst=tower&symbols=bad&begin=-00:10:00&callback=test')
    #     assert res.status_code == 400
    #     res = res.data.decode()
    #     assert res['code'] == 400