Skip to content
Snippets Groups Projects
test_data_api.py 17.45 KiB
from __future__ import annotations

import contextlib
import json
import random
from datetime import datetime, timedelta
from typing import TYPE_CHECKING, Any, ContextManager
from unittest import mock

import numpy as np
import pandas as pd
import pytest

try:
    import influxdb_client
except ImportError:
    influxdb_client = None

try:
    from influxdb.resultset import ResultSet
except ImportError:
    ResultSet = None

if TYPE_CHECKING:
    from collections.abc import Callable, Iterable, Iterator

    from metobsapi.util.query_influx import QueryResult


# Naive start timestamp for all generated fake InfluxDB rows.
FAKE_DATA_START = datetime(year=2017, month=3, day=5, hour=19, minute=0, second=0)


@pytest.fixture(params=["1", "2"], ids=["influxdbv1", "influxdbv2"])
def influxdb_version(request):
    """Parametrize dependent tests over InfluxDB API versions "1" and "2"."""
    selected_version = request.param
    return selected_version


@pytest.fixture()
def mock_influxdb_query(influxdb_version) -> Callable[[list[dict]], ContextManager[mock.Mock]]:
    """Provide a factory that patches the InfluxDB query helper with fake data.

    The returned callable takes a list of fake series dicts (as built by
    ``_fake_data``) and returns a context manager. Inside that context, the
    client class of the InfluxDB version *not* under test is patched to
    ``None``. The version-specific query helper is replaced by a mock that
    returns the fake data, and the mock is yielded so tests can inspect its
    calls.

    Raises:
        ValueError: on entering the context if ``influxdb_version`` is neither
            ``"1"`` nor ``"2"``.
    """

    @contextlib.contextmanager
    def _mock_influxdb_query_with_fake_data(fake_result_data: list[dict]) -> Iterator[mock.Mock]:
        # Fail loudly for an unsupported version. Otherwise the generator would
        # finish without yielding, and contextlib would report only a cryptic
        # "generator didn't yield" RuntimeError.
        if influxdb_version not in ("1", "2"):
            msg = f"Unsupported InfluxDB version for mocking: {influxdb_version!r}"
            raise ValueError(msg)
        is_v1 = influxdb_version == "1"
        with _mock_influxdb_library_availability(is_v1, not is_v1):
            patch_query = _patch_v1_query if is_v1 else _patch_v2_query
            with patch_query(fake_result_data) as query_func:
                yield query_func

    return _mock_influxdb_query_with_fake_data


@contextlib.contextmanager
def _mock_influxdb_library_availability(v1_available: bool, v2_available: bool) -> Iterator[None]:
    """Pretend the v1 and/or v2 InfluxDB client classes are not installed.

    Every client class flagged as unavailable is patched to ``None`` in
    ``metobsapi.util.query_influx`` for the duration of the context.
    """
    hide_v1 = not v1_available
    hide_v2 = not v2_available
    with (
        _mock_if(hide_v1, "metobsapi.util.query_influx.InfluxDBClientv1"),
        _mock_if(hide_v2, "metobsapi.util.query_influx.InfluxDBClientv2"),
    ):
        yield None


@contextlib.contextmanager
def _mock_if(condition: bool, obj_path: str) -> Iterator[None]:
    if not condition:
        yield None
        return
    with mock.patch(obj_path, None):
        yield None


@contextlib.contextmanager
def _patch_v1_query(fake_result_data: list[dict]) -> Iterator[mock.Mock]:
    """Replace the InfluxDB v1 query helper with a mock returning fake ``ResultSet``s.

    The calling test is skipped if the v1 ``influxdb`` package is not installed.
    """
    if ResultSet is None:
        pytest.skip("Version 1 'influxdb' dependency missing. Skipping test.")
    patcher = mock.patch("metobsapi.util.query_influx._query_influxdbv1")
    with patcher as fake_query:
        fake_query.return_value = _fake_data_to_v1_resultset(fake_result_data)
        yield fake_query


@contextlib.contextmanager
def _patch_v2_query(fake_result_data: list[dict]) -> Iterator[mock.Mock]:
    """Replace the InfluxDB v2 query helper with a mock returning one DataFrame per series.

    The calling test is skipped if the v2 ``influxdb-client`` package is not installed.
    """
    if influxdb_client is None:
        pytest.skip("Version 2 'influxdb-client' dependency missing. Skipping test.")

    site_frames = [_fake_series_to_v2_dataframe(series) for series in fake_result_data]
    with mock.patch("metobsapi.util.query_influx._query_influxdbv2") as fake_query:
        fake_query.return_value = site_frames
        yield fake_query


def _fake_series_to_v2_dataframe(series: dict) -> pd.DataFrame:
    """Convert one fake series dict into a DataFrame indexed by ``time``.

    ``None`` values become NaN, and the series tags are added as columns.
    Columns that are entirely NaN are dropped.
    """
    column_names = series["columns"]
    records = [
        {
            **{name: (np.nan if value is None else value) for name, value in zip(column_names, row, strict=True)},
            **series["tags"],
        }
        for row in series["values"]
    ]
    frame = pd.DataFrame(records, columns=["site", "inst", *column_names])
    return frame.set_index("time").dropna(axis=1, how="all")


def _fake_data_to_v1_resultset(fake_data: list[dict]) -> QueryResult:
    """Convert fake series dicts into a list of InfluxDB v1 ``ResultSet`` objects.

    Each series is wrapped in its own ``ResultSet``. The ``datetime`` in the
    first column of every row is rendered as a ``%Y-%m-%dT%H:%M:%SZ`` string,
    the form used for the default epoch format.

    The input is left unmodified: series dicts and value rows are copied, so
    the same fake data can be converted more than once or reused elsewhere.
    """
    t_format = "%Y-%m-%dT%H:%M:%SZ"
    result_sets = []
    for single_result in fake_data:
        # Build new rows instead of overwriting value_row[0] in place. In-place
        # mutation would leak strings into the caller's data, and a second
        # conversion would then call strftime on a str.
        converted_rows = [[value_row[0].strftime(t_format), *value_row[1:]] for value_row in single_result["values"]]
        converted_series = {**single_result, "values": converted_rows}
        result_sets.append(ResultSet({"series": [converted_series], "statement_id": 0}))
    return result_sets


@pytest.fixture()
def influxdb_air_temp_9_values(mock_influxdb_query) -> Iterable[mock.Mock]:
    """Mock InfluxDB to return 9 one-minute ``air_temp`` rows for aoss/tower."""
    tower_columns = ["time", "air_temp"]
    fake_series = _fake_data("1m", {("aoss", "tower"): tower_columns}, 9)
    with mock_influxdb_query(fake_series) as patched_query:
        yield patched_query


@pytest.fixture()
def influxdb_3_symbols_9_values(mock_influxdb_query) -> Iterable[mock.Mock]:
    """Mock InfluxDB to return 9 one-minute rows of temp, humidity, and wind speed."""
    tower_columns = ["time", "air_temp", "rel_hum", "wind_speed"]
    fake_series = _fake_data("1m", {("aoss", "tower"): tower_columns}, 9)
    with mock_influxdb_query(fake_series) as patched_query:
        yield patched_query


@pytest.fixture()
def influxdb_wind_fields_9_values(mock_influxdb_query) -> Iterable[mock.Mock]:
    """Mock InfluxDB to return 9 one-minute rows of wind fields for aoss/tower."""
    wind_columns = ["time", "wind_speed", "wind_direction", "wind_east", "wind_north"]
    fake_series = _fake_data("1m", {("aoss", "tower"): wind_columns}, 9)
    with mock_influxdb_query(fake_series) as patched_query:
        yield patched_query


def _fake_data(interval: str, symbols: dict[tuple[str, str], list[str]], num_vals: int) -> list[dict[str, Any]]:
    """Build fake InfluxDB query results, one series per ``(site, inst)`` key.

    Each series is named ``"metobs_" + interval`` and has ``num_vals`` rows
    spaced one minute apart, starting at ``FAKE_DATA_START``. Column 0 of each
    row is the timestamp and the other columns hold random values. Sometimes
    one randomly chosen non-time value in a row is replaced with ``None`` to
    simulate missing data.
    """
    all_series: list[dict[str, Any]] = []
    for (site, inst), columns in symbols.items():
        rows: list[list[datetime | float | None]] = []
        for minute in range(num_vals):
            row = [FAKE_DATA_START + timedelta(minutes=minute), *_generate_fake_column_values(columns[1:])]
            # The index is drawn from a range 3x the column count, so only some
            # rows get a None; index 0 (time) is never blanked.
            null_idx = int(random.random() * len(columns) * 3)
            if 0 < null_idx < len(columns):
                row[null_idx] = None
            rows.append(row)
        all_series.append(
            {
                "name": "metobs_" + interval,
                "columns": columns,
                "tags": {"site": site, "inst": inst},
                "values": rows,
            },
        )
    return all_series


def _generate_fake_column_values(columns: list[str]) -> list[float | None]:
    rand_data: list[float | None] = [random.random()] * len(columns)
    if "wind_direction" in columns:
        # wind direction shouldn't exist in the database
        rand_data[columns.index("wind_direction")] = None
    return rand_data


@pytest.fixture()
def app(influxdb_version):
    """Yield the real ``metobsapi`` Flask app configured for the current test.

    ``TESTING`` and ``DEBUG`` are enabled. For the InfluxDB v2 variant, an
    ``INFLUXDB_TOKEN`` is also set. All overrides are applied with
    ``mock.patch.dict``, so the app config is restored after each test.
    """
    from metobsapi import app as real_app

    overrides = {"TESTING": True, "DEBUG": True}
    if influxdb_version == "2":
        overrides["INFLUXDB_TOKEN"] = "1234"
    # The imported app object is shared by every test. Without restoring its
    # config, the v2 token set here would leak into later tests, including
    # the "influxdbv1" parametrizations.
    with mock.patch.dict(real_app.config, overrides):
        yield real_app


@pytest.fixture()
def client(app):
    """Return a Flask test client for the ``app`` fixture."""
    test_client = app.test_client()
    return test_client


def test_doc(client):
    """The bare data endpoint serves the documentation page."""
    response = client.get("/api/data")
    page = response.data
    assert b"Data Request Application" in page


def test_bad_format(client):
    """An unsupported file extension is rejected with an explanatory message."""
    response = client.get("/api/data.fake")
    body = response.data
    assert b"No data file format" in body


def test_bad_begin_json(client):
    """An unparsable ``begin`` gives a 400 error whose message mentions timestamps."""
    response = client.get("/api/data.json?symbols=air_temp&begin=blah")
    payload = json.loads(response.data.decode())
    assert payload["code"] == 400
    assert payload["status"] == "error"
    message = payload["message"]
    assert "timestamp" in message


def test_bad_order(client):
    """An invalid ``order`` value gives a message listing the accepted choices."""
    response = client.get("/api/data.json?order=blah&symbols=air_temp")
    message = json.loads(response.data.decode())["message"]
    for expected in ("column", "row"):
        assert expected in message


def test_bad_epoch(client):
    """An invalid ``epoch`` value gives a message listing the accepted units."""
    response = client.get("/api/data.json?epoch=blah&symbols=air_temp")
    message = json.loads(response.data.decode())["message"]
    for unit in ("'h'", "'m'", "'s'", "'u'"):
        assert unit in message


def test_bad_interval(client):
    """An invalid ``interval`` value gives a message listing the accepted intervals."""
    response = client.get("/api/data.json?interval=blah&symbols=air_temp")
    message = json.loads(response.data.decode())["message"]
    for interval in ("'1m'", "'5m'", "'1h'"):
        assert interval in message


def test_missing_inst(client):
    """Passing ``site`` without ``inst`` is a 400 error that names both parameters."""
    response = client.get("/api/data.json?site=X&symbols=air_temp&begin=-05:00:00")
    payload = json.loads(response.data.decode())
    assert payload["code"] == 400
    assert payload["status"] == "error"
    message = payload["message"]
    assert "'site'" in message
    assert "'inst'" in message


def test_missing_site(client):
    """Passing ``inst`` without ``site`` is a 400 error that names both parameters."""
    response = client.get("/api/data.json?inst=X&symbols=air_temp&begin=-05:00:00")
    payload = json.loads(response.data.decode())
    assert payload["code"] == 400
    assert payload["status"] == "error"
    message = payload["message"]
    assert "'site'" in message
    assert "'inst'" in message


def test_missing_symbols(client):
    """Omitting ``symbols`` is a 400 error that names the missing parameter."""
    response = client.get("/api/data.json?begin=-05:00:00")
    payload = json.loads(response.data.decode())
    assert payload["code"] == 400
    assert payload["status"] == "error"
    assert "'symbols'" in payload["message"]


def test_nonexistent_symbol(client):
    """Asking for a symbol the instrument doesn't provide is a 400 error."""
    url = "/api/data.json?begin=-05:00:00&symbols=aoss.tower.air_temp:aoss.tower.noexist"
    payload = json.loads(client.get(url).data.decode())
    assert payload["code"] == 400
    assert payload["status"] == "error"
    assert "Unknown symbol" in payload["message"]


def test_bad_symbol_site_inst(client):
    """A fully qualified symbol must have exactly three period-separated parts."""
    url = "/api/data.json?begin=-05:00:00&symbols=aoss.tower.something.air_temp"
    payload = json.loads(client.get(url).data.decode())
    assert payload["code"] == 400
    assert payload["status"] == "error"
    assert "3 period-separated parts" in payload["message"]


def test_symbol_unknown_site_inst(client):
    """A symbol qualified with an unknown site/instrument is a 400 error."""
    url = "/api/data.json?begin=-05:00:00&symbols=aoss2.tower.air_temp"
    payload = json.loads(client.get(url).data.decode())
    assert payload["code"] == 400
    assert payload["status"] == "error"
    assert "Unknown site/instrument" in payload["message"]


def test_multiple_symbol(client):
    """Requesting the same symbol more than once is a 400 error."""
    url = "/api/data.json?begin=-05:00:00&symbols=aoss.tower.air_temp:aoss.tower.wind_speed:aoss.tower.air_temp"
    payload = json.loads(client.get(url).data.decode())
    assert payload["code"] == 400
    assert payload["status"] == "error"
    assert "multiple times" in payload["message"]


def test_too_many_points(client):
    """A request spanning back to 1970 is refused with HTTP 413."""
    response = client.get("/api/data.json?symbols=aoss.tower.air_temp&begin=1970-01-01T00:00:00")
    assert response.status_code == 413
    payload = json.loads(response.data.decode())
    assert "too many values" in payload["message"]
    assert payload["code"] == 413
    assert payload["status"] == "fail"


@pytest.mark.parametrize(
    ("begin", "end", "exp_query_time"),
    [
        ("-00:10:00", None, "-600s"),
        (
            FAKE_DATA_START.strftime("%Y-%m-%dT%H:%M:%S"),
            (FAKE_DATA_START + timedelta(minutes=10)).strftime("%Y-%m-%dT%H:%M:%S"),
            "range(start: 2017-03-05T19:00:00Z, stop: 2017-03-05T19:10:00Z)",
        ),
    ],
)
def test_shorthand_one_symbol_json_row(client, influxdb_air_temp_9_values, begin, end, exp_query_time):
    """One symbol with site/inst shorthand returns row-ordered JSON (the default).

    For queries that don't start with ``SELECT`` (the v2 query path), the query
    string must contain the expected time-range text. ``SELECT`` queries are
    not inspected.
    """
    end_param = f"&end={end}" if end else ""
    response = client.get(f"/api/data.json?site=aoss&inst=tower&symbols=air_temp&begin={begin}{end_param}")
    payload = json.loads(response.data.decode())

    influxdb_air_temp_9_values.assert_called_once()
    query_str_arg = influxdb_air_temp_9_values.call_args[0][0]
    is_select_query = query_str_arg.startswith("SELECT")
    if not is_select_query:
        assert exp_query_time in query_str_arg

    assert payload["code"] == 200
    assert payload["num_results"] == 9
    results = payload["results"]
    assert results["symbols"] == ["air_temp"]
    assert len(results["timestamps"]) == 9
    assert len(results["data"]) == 9
    assert len(results["data"][0]) == 1


@pytest.mark.usefixtures("influxdb_air_temp_9_values")
def test_shorthand_one_symbol_json_column(client):
    """Site/inst shorthand with ``order=column`` keys the data by symbol name."""
    response = client.get("/api/data.json?site=aoss&inst=tower&symbols=air_temp&begin=-00:10:00&order=column")
    payload = json.loads(response.data.decode())
    assert payload["code"] == 200
    assert payload["num_results"] == 9
    results = payload["results"]
    assert "air_temp" in results["data"]
    assert len(results["data"]["air_temp"]) == 9
    assert len(results["timestamps"]) == 9


@pytest.mark.usefixtures("influxdb_3_symbols_9_values")
def test_shorthand_two_symbols_json_column_order_check(client):
    """Each symbol's column data must not depend on the order symbols are requested in."""
    forward = _query_with_symbols(client, "air_temp:rel_hum:wind_speed")
    swapped = _query_with_symbols(client, "rel_hum:air_temp:wind_speed")
    forward_data = forward["results"]["data"]
    swapped_data = swapped["results"]["data"]
    for symbol in ("air_temp", "rel_hum"):
        assert forward_data[symbol] == swapped_data[symbol]


@pytest.mark.usefixtures("influxdb_wind_fields_9_values")
def test_shorthand_wind_symbols_json_column_order_check(client):
    """Wind speed/direction data must not depend on the order symbols are requested in."""
    forward = _query_with_symbols(client, "wind_speed:wind_direction")
    swapped = _query_with_symbols(client, "wind_direction:wind_speed")
    forward_data = forward["results"]["data"]
    swapped_data = swapped["results"]["data"]
    for symbol in ("wind_speed", "wind_direction"):
        assert forward_data[symbol] == swapped_data[symbol]


def _query_with_symbols(client, symbols: str) -> dict:
    res = client.get(f"/api/data.json?site=aoss&inst=tower&symbols={symbols}&begin=-00:10:00&order=column")
    res = json.loads(res.data.decode())
    assert res["code"] == 200
    assert res["num_results"] == 9
    for symbol in symbols.split(":"):
        assert symbol in res["results"]["data"]
        assert len(res["results"]["data"][symbol]) == 9
        assert len(res["results"]["timestamps"]) == 9
    return res


@pytest.mark.parametrize(
    "symbols",
    [
        ["aoss.tower.wind_speed", "aoss.tower.wind_direction"],
        ["wind_speed", "wind_direction"],
    ],
)
@pytest.mark.usefixtures("influxdb_wind_fields_9_values")
def test_wind_speed_direction_json(symbols, client):
    """Wind speed and direction contain data even though the fake direction field is empty.

    Runs once with fully qualified symbols and once with site/inst shorthand.
    """
    uses_shorthand = "." not in symbols[0]
    site_inst_params = "&site=aoss&inst=tower" if uses_shorthand else ""
    symbol_param = ":".join(symbols)
    url = f"/api/data.json?symbols={symbol_param}{site_inst_params}&begin=-00:10:00&order=column"

    payload = json.loads(client.get(url).data.decode())
    assert payload["code"] == 200
    assert payload["num_results"] == 9
    data = payload["results"]["data"]
    for symbol_name in symbols:
        assert symbol_name in data
        assert not np.isnan(data[symbol_name]).all()
    assert len(list(data.keys())) == len(symbols)


def test_one_symbol_two_insts_json_row(client, mock_influxdb_query):
    """One symbol from two site/instrument pairs comes back as two row columns."""
    per_inst_columns = {
        ("aoss", "tower"): ["time", "air_temp"],
        ("mendota", "buoy"): ["time", "air_temp"],
    }
    fake_result = _fake_data("1m", per_inst_columns, 9)
    with mock_influxdb_query(fake_result):
        # row should be the default
        response = client.get("/api/data.json?symbols=aoss.tower.air_temp:mendota.buoy.air_temp&begin=-00:10:00")

    payload = json.loads(response.data.decode())
    assert payload["code"] == 200
    assert payload["num_results"] == 9
    results = payload["results"]
    assert results["symbols"] == ["aoss.tower.air_temp", "mendota.buoy.air_temp"]
    assert len(results["timestamps"]) == 9
    assert len(results["data"]) == 9
    assert len(results["data"][0]) == 2


def test_one_symbol_three_insts_json_row(client, mock_influxdb_query):
    """One symbol from three made-up site/instrument pairs comes back as three row columns."""
    made_up_pairs = [("site1", "inst1"), ("site2", "inst2"), ("site3", "inst3")]
    fake_result = _fake_data("1m", {pair: ["time", "air_temp"] for pair in made_up_pairs}, 9)
    # row should be the default
    from metobsapi.util.data_responses import SYMBOL_TRANSLATIONS

    # Give each made-up pair the same symbol table as aoss/tower so the API accepts them.
    patched_translations = SYMBOL_TRANSLATIONS.copy()
    for pair in made_up_pairs:
        patched_translations[pair] = patched_translations[("aoss", "tower")]

    with (
        mock.patch("metobsapi.util.data_responses.SYMBOL_TRANSLATIONS", patched_translations),
        mock_influxdb_query(fake_result),
    ):
        response = client.get(
            "/api/data.json?symbols=site1.inst1.air_temp:site2.inst2.air_temp:site3.inst3.air_temp&begin=-00:10:00",
        )
        payload = json.loads(response.data.decode())
        assert payload["code"] == 200
        assert payload["num_results"] == 9
        results = payload["results"]
        assert results["symbols"] == ["site1.inst1.air_temp", "site2.inst2.air_temp", "site3.inst3.air_temp"]
        assert len(results["timestamps"]) == 9
        assert len(results["data"]) == 9
        assert len(results["data"][0]) == 3


@pytest.mark.usefixtures("influxdb_air_temp_9_values")
def test_one_symbol_csv(client):
    """CSV output for one symbol has 5 header lines, 9 data rows, and a trailing newline."""
    # row should be the default
    text = client.get("/api/data.csv?symbols=aoss.tower.air_temp&begin=-00:10:00").data.decode()
    lines = text.split("\n")
    header_count, data_count, trailing_count = 5, 9, 1
    assert len(lines) == header_count + data_count + trailing_count
    # The first data row holds the time plus 1 channel.
    first_data_row = lines[header_count]
    assert len(first_data_row.split(",")) == 2
    assert "# code: 200" in text


@pytest.mark.usefixtures("influxdb_air_temp_9_values")
def test_one_symbol_xml(client):
    """XML output for one symbol has 2 nodes in its first section and 9 data rows in its second."""
    from xml.dom.minidom import parseString

    # row should be the default
    response = client.get("/api/data.xml?symbols=aoss.tower.air_temp&begin=-00:10:00")
    document = parseString(response.data.decode())
    root = document.childNodes[0]
    assert len(root.childNodes[0].childNodes) == 2
    # data rows
    assert len(root.childNodes[1].childNodes) == 9


@pytest.mark.usefixtures("influxdb_3_symbols_9_values")
def test_three_symbol_csv(client):
    """Test that multiple channels in a CSV file are structured properly."""
    # row should be the default
    url = "/api/data.csv?symbols=aoss.tower.air_temp:aoss.tower.rel_hum:aoss.tower.wind_speed&begin=-00:10:00"
    text = client.get(url).data.decode()
    lines = text.split("\n")
    header_count, data_count, trailing_count = 5, 9, 1
    assert len(lines) == header_count + data_count + trailing_count
    # The first data row holds the time plus 3 channels.
    assert len(lines[header_count].split(",")) == 4
    assert "# code: 200" in text


@pytest.mark.usefixtures("influxdb_3_symbols_9_values")
def test_three_symbol_csv_repeat(client):
    """Test that a CSV request repeating the same symbol returns a 400 error with no data rows."""
    url = "/api/data.csv?symbols=aoss.tower.air_temp:aoss.tower.air_temp:aoss.tower.air_temp&begin=-00:10:00"
    text = client.get(url).data.decode()
    lines = text.split("\n")
    # 5 header lines, one line in place of the data, and the trailing newline.
    header_count = 5
    assert len(lines) == header_count + 1 + 1
    # The line after the header has no comma-separated columns.
    assert len(lines[header_count].split(",")) == 1
    assert "# code: 400" in text