"""Tests for the metobsapi data API (test_data_api.py)."""
from __future__ import annotations
import contextlib
import json
import random
from datetime import datetime, timedelta
from typing import TYPE_CHECKING, Any, ContextManager
from unittest import mock
import numpy as np
import pandas as pd
import pytest
try:
import influxdb_client
except ImportError:
influxdb_client = None
try:
from influxdb.resultset import ResultSet
except ImportError:
ResultSet = None
if TYPE_CHECKING:
from collections.abc import Callable, Iterable, Iterator
from metobsapi.util.query_influx import QueryResult
FAKE_DATA_START = datetime(2017, 3, 5, 19, 0, 0)
@pytest.fixture(params=["1", "2"], ids=["influxdbv1", "influxdbv2"])
def influxdb_version(request):
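    """Run each test once for every supported InfluxDB major version ("1" and "2")."""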
return request.param
@pytest.fixture()
def mock_influxdb_query(influxdb_version) -> Callable[[list[dict]], ContextManager[mock.Mock]]:
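    """Return a context manager factory that patches the query layer to return the given fake data."""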
@contextlib.contextmanager
def _mock_influxdb_query_with_fake_data(fake_result_data: list[dict]) -> Iterator[mock.Mock]:
with _mock_influxdb_library_availability(influxdb_version == "1", influxdb_version == "2"):
if influxdb_version == "1":
with _patch_v1_query(fake_result_data) as query_func:
yield query_func
elif influxdb_version == "2":
with _patch_v2_query(fake_result_data) as query_func:
yield query_func
return _mock_influxdb_query_with_fake_data
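# Rough usage sketch for the fixture above (the test name below is illustrative only;
# real tests following this pattern appear later in this module):
#
#     def test_something(client, mock_influxdb_query):
#         fake_result = _fake_data("1m", {("aoss", "tower"): ["time", "air_temp"]}, 9)
#         with mock_influxdb_query(fake_result) as query_func:
#             res = client.get("/api/data.json?site=aoss&inst=tower&symbols=air_temp&begin=-00:10:00")
#             query_func.assert_called_once()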
@contextlib.contextmanager
def _mock_influxdb_library_availability(v1_available: bool, v2_available: bool) -> Iterator[None]:
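    """Pretend an InfluxDB client library is not installed by patching it to None when unavailable."""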
with _mock_if(not v1_available, "metobsapi.util.query_influx.InfluxDBClientv1"), _mock_if(
not v2_available,
"metobsapi.util.query_influx.InfluxDBClientv2",
):
yield None
@contextlib.contextmanager
def _mock_if(condition: bool, obj_path: str) -> Iterator[None]:
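    """Patch ``obj_path`` to None only when ``condition`` is true; otherwise do nothing."""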
if not condition:
yield None
return
with mock.patch(obj_path, None):
yield None
@contextlib.contextmanager
def _patch_v1_query(fake_result_data: list[dict]) -> Iterator[mock.Mock]:
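    """Patch the InfluxDB v1 query helper to return the fake data as ``ResultSet`` objects."""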
if ResultSet is None:
pytest.skip("Version 1 'influxdb' dependency missing. Skipping test.")
with mock.patch("metobsapi.util.query_influx._query_influxdbv1") as query_func:
fake_data_for_v1 = _fake_data_to_v1_resultset(fake_result_data)
query_func.return_value = fake_data_for_v1
yield query_func
@contextlib.contextmanager
def _patch_v2_query(fake_result_data: list[dict]) -> Iterator[mock.Mock]:
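    """Patch the InfluxDB v2 query helper to return the fake data as one pandas DataFrame per series, indexed by time."""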
if influxdb_client is None:
pytest.skip("Version 2 'influxdb-client' dependency missing. Skipping test.")
site_data_frames = []
for data_dict in fake_result_data:
data_row_dicts = []
for row_values in data_dict["values"]:
data_row = {}
for col_name, col_value in zip(data_dict["columns"], row_values, strict=True):
data_row[col_name] = np.nan if col_value is None else col_value
data_row.update(data_dict["tags"])
data_row_dicts.append(data_row)
single_df = pd.DataFrame(data_row_dicts, columns=["site", "inst"] + data_dict["columns"])
single_df = single_df.set_index("time")
single_df = single_df.dropna(axis=1, how="all")
site_data_frames.append(single_df)
with mock.patch("metobsapi.util.query_influx._query_influxdbv2") as query_func:
query_func.return_value = site_data_frames
yield query_func
def _fake_data_to_v1_resultset(fake_data: list[dict]) -> QueryResult:
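    """Convert the fake series dictionaries into influxdb (v1) ``ResultSet`` objects with string timestamps."""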
t_format = "%Y-%m-%dT%H:%M:%SZ"
result_sets = []
for single_result in fake_data:
for value_row in single_result["values"]:
            # convert datetime objects to ISO 8601 strings (the default time format when no epoch is requested)
value_row[0] = value_row[0].strftime(t_format)
result_sets.append(single_result)
return [ResultSet({"series": [single_result], "statement_id": 0}) for single_result in result_sets]
@pytest.fixture()
def influxdb_air_temp_9_values(mock_influxdb_query) -> Iterable[mock.Mock]:
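    """Patch the query layer with 9 fake air_temp rows for aoss.tower."""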
fake_result = _fake_data("1m", {("aoss", "tower"): ["time", "air_temp"]}, 9)
with mock_influxdb_query(fake_result) as query_func:
yield query_func
@pytest.fixture()
def influxdb_3_symbols_9_values(mock_influxdb_query) -> Iterable[mock.Mock]:
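    """Patch the query layer with 9 fake rows of air_temp, rel_hum, and wind_speed for aoss.tower."""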
fake_result = _fake_data("1m", {("aoss", "tower"): ["time", "air_temp", "rel_hum", "wind_speed"]}, 9)
with mock_influxdb_query(fake_result) as query_func:
yield query_func
@pytest.fixture()
def influxdb_wind_fields_9_values(mock_influxdb_query) -> Iterable[mock.Mock]:
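    """Patch the query layer with 9 fake rows of wind-related fields for aoss.tower."""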
fake_result = _fake_data(
"1m",
{("aoss", "tower"): ["time", "wind_speed", "wind_direction", "wind_east", "wind_north"]},
9,
)
with mock_influxdb_query(fake_result) as query_func:
yield query_func
def _fake_data(interval: str, symbols: dict[tuple[str, str], list[str]], num_vals: int) -> list[dict[str, Any]]:
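    """Build one fake InfluxDB-style series per (site, inst) pair with ``num_vals`` rows each.

    Each series dict looks roughly like (values abbreviated)::

        {"name": "metobs_1m", "tags": {"site": "aoss", "inst": "tower"},
         "columns": ["time", "air_temp"], "values": [[datetime(...), 0.42], ...]}

    Some values are randomly replaced with None to mimic missing data.
    """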
now = FAKE_DATA_START
measurement_name = "metobs_" + interval
series = []
for (site, inst), columns in symbols.items():
tags = {"site": site, "inst": inst}
vals: list[list[datetime | float | None]] = []
for i in range(num_vals):
vals.append([now + timedelta(minutes=i), *_generate_fake_column_values(columns[1:])])
            # randomly null out one value in some rows to mimic missing data
r_idx = int(random.random() * len(columns) * 3)
# don't include time (index 0)
if 0 < r_idx < len(columns):
vals[-1][r_idx] = None
s = {
"name": measurement_name,
"columns": columns,
"tags": tags,
"values": vals,
}
series.append(s)
return series
def _generate_fake_column_values(columns: list[str]) -> list[float | None]:
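    """Generate random column values, leaving wind_direction as None since it should not exist in the database."""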
rand_data: list[float | None] = [random.random()] * len(columns)
if "wind_direction" in columns:
# wind direction shouldn't exist in the database
rand_data[columns.index("wind_direction")] = None
return rand_data
@pytest.fixture()
def app(influxdb_version):
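    """Create the metobsapi application configured for testing."""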
from metobsapi import app as real_app
real_app.config.update({"TESTING": True, "DEBUG": True})
if influxdb_version == "2":
real_app.config.update({"INFLUXDB_TOKEN": "1234"})
return real_app
@pytest.fixture()
def client(app):
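    """Return a test client for the application."""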
return app.test_client()
def test_doc(client):
res = client.get("/api/data")
assert b"Data Request Application" in res.data
def test_bad_format(client):
res = client.get("/api/data.fake")
assert b"No data file format" in res.data
def test_bad_begin_json(client):
res = client.get("/api/data.json?symbols=air_temp&begin=blah")
res = json.loads(res.data.decode())
assert res["code"] == 400
assert res["status"] == "error"
assert "timestamp" in res["message"]
def test_bad_order(client):
res = client.get("/api/data.json?order=blah&symbols=air_temp")
res = json.loads(res.data.decode())
assert "column" in res["message"]
assert "row" in res["message"]
def test_bad_epoch(client):
res = client.get("/api/data.json?epoch=blah&symbols=air_temp")
res = json.loads(res.data.decode())
assert "'h'" in res["message"]
assert "'m'" in res["message"]
assert "'s'" in res["message"]
assert "'u'" in res["message"]
def test_bad_interval(client):
res = client.get("/api/data.json?interval=blah&symbols=air_temp")
res = json.loads(res.data.decode())
assert "'1m'" in res["message"]
assert "'5m'" in res["message"]
assert "'1h'" in res["message"]
def test_missing_inst(client):
res = client.get("/api/data.json?site=X&symbols=air_temp&begin=-05:00:00")
res = json.loads(res.data.decode())
assert res["code"] == 400
assert res["status"] == "error"
assert "'site'" in res["message"]
assert "'inst'" in res["message"]
def test_missing_site(client):
res = client.get("/api/data.json?inst=X&symbols=air_temp&begin=-05:00:00")
res = json.loads(res.data.decode())
assert res["code"] == 400
assert res["status"] == "error"
assert "'site'" in res["message"]
assert "'inst'" in res["message"]
def test_missing_symbols(client):
res = client.get("/api/data.json?begin=-05:00:00")
res = json.loads(res.data.decode())
assert res["code"] == 400
assert res["status"] == "error"
assert "'symbols'" in res["message"]
def test_nonexistent_symbol(client):
res = client.get("/api/data.json?begin=-05:00:00&symbols=aoss.tower.air_temp:aoss.tower.noexist")
res = json.loads(res.data.decode())
assert res["code"] == 400
assert res["status"] == "error"
assert "Unknown symbol" in res["message"]
def test_bad_symbol_site_inst(client):
res = client.get("/api/data.json?begin=-05:00:00&symbols=aoss.tower.something.air_temp")
res = json.loads(res.data.decode())
assert res["code"] == 400
assert res["status"] == "error"
assert "3 period-separated parts" in res["message"]
def test_symbol_unknown_site_inst(client):
res = client.get("/api/data.json?begin=-05:00:00&symbols=aoss2.tower.air_temp")
res = json.loads(res.data.decode())
assert res["code"] == 400
assert res["status"] == "error"
assert "Unknown site/instrument" in res["message"]
def test_multiple_symbol(client):
res = client.get(
"/api/data.json?begin=-05:00:00&symbols=aoss.tower.air_temp:aoss.tower.wind_speed:aoss.tower.air_temp",
)
res = json.loads(res.data.decode())
assert res["code"] == 400
assert res["status"] == "error"
assert "multiple times" in res["message"]
def test_too_many_points(client):
res = client.get("/api/data.json?symbols=aoss.tower.air_temp&begin=1970-01-01T00:00:00")
assert res.status_code == 413
res = json.loads(res.data.decode())
assert "too many values" in res["message"]
assert res["code"] == 413
assert res["status"] == "fail"
@pytest.mark.parametrize(
("begin", "end", "exp_query_time"),
[
("-00:10:00", None, "-600s"),
(
FAKE_DATA_START.strftime("%Y-%m-%dT%H:%M:%S"),
(FAKE_DATA_START + timedelta(minutes=10)).strftime("%Y-%m-%dT%H:%M:%S"),
"range(start: 2017-03-05T19:00:00Z, stop: 2017-03-05T19:10:00Z)",
),
],
)
def test_shorthand_one_symbol_json_row(client, influxdb_air_temp_9_values, begin, end, exp_query_time):
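    """Request a single symbol using site/inst shorthand and check the default row-oriented JSON."""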
end = "" if not end else f"&end={end}"
# row should be the default
res = client.get(f"/api/data.json?site=aoss&inst=tower&symbols=air_temp&begin={begin}{end}")
res = json.loads(res.data.decode())
influxdb_air_temp_9_values.assert_called_once()
query_str_arg = influxdb_air_temp_9_values.call_args[0][0]
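    # the expected time strings use Flux syntax, so only check v2 query strings;
    # v1 InfluxQL queries start with SELECT and are not checked here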
if not query_str_arg.startswith("SELECT"):
assert exp_query_time in query_str_arg
assert res["code"] == 200
assert res["num_results"] == 9
assert res["results"]["symbols"] == ["air_temp"]
assert len(res["results"]["timestamps"]) == 9
assert len(res["results"]["data"]) == 9
assert len(res["results"]["data"][0]) == 1
@pytest.mark.usefixtures("influxdb_air_temp_9_values")
def test_shorthand_one_symbol_json_column(client):
res = client.get("/api/data.json?site=aoss&inst=tower&symbols=air_temp&begin=-00:10:00&order=column")
res = json.loads(res.data.decode())
assert res["code"] == 200
assert res["num_results"] == 9
assert "air_temp" in res["results"]["data"]
assert len(res["results"]["data"]["air_temp"]) == 9
assert len(res["results"]["timestamps"]) == 9
@pytest.mark.usefixtures("influxdb_3_symbols_9_values")
def test_shorthand_two_symbols_json_column_order_check(client):
res1 = _query_with_symbols(client, "air_temp:rel_hum:wind_speed")
res2 = _query_with_symbols(client, "rel_hum:air_temp:wind_speed")
assert res1["results"]["data"]["air_temp"] == res2["results"]["data"]["air_temp"]
assert res1["results"]["data"]["rel_hum"] == res2["results"]["data"]["rel_hum"]
@pytest.mark.usefixtures("influxdb_wind_fields_9_values")
def test_shorthand_wind_symbols_json_column_order_check(client):
res1 = _query_with_symbols(client, "wind_speed:wind_direction")
res2 = _query_with_symbols(client, "wind_direction:wind_speed")
assert res1["results"]["data"]["wind_speed"] == res2["results"]["data"]["wind_speed"]
assert res1["results"]["data"]["wind_direction"] == res2["results"]["data"]["wind_direction"]
def _query_with_symbols(client, symbols: str) -> dict:
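    """Request the given symbols in column order and perform basic sanity checks on the JSON response."""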
res = client.get(f"/api/data.json?site=aoss&inst=tower&symbols={symbols}&begin=-00:10:00&order=column")
res = json.loads(res.data.decode())
assert res["code"] == 200
assert res["num_results"] == 9
for symbol in symbols.split(":"):
assert symbol in res["results"]["data"]
assert len(res["results"]["data"][symbol]) == 9
assert len(res["results"]["timestamps"]) == 9
return res
@pytest.mark.parametrize(
"symbols",
[
["aoss.tower.wind_speed", "aoss.tower.wind_direction"],
["wind_speed", "wind_direction"],
],
)
@pytest.mark.usefixtures("influxdb_wind_fields_9_values")
def test_wind_speed_direction_json(symbols, client):
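    """Check that both wind symbols come back with non-NaN data even though wind_direction is None in the fake results."""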
symbol_param = ":".join(symbols)
site_inst_params = "&site=aoss&inst=tower" if "." not in symbols[0] else ""
res = client.get(f"/api/data.json?symbols={symbol_param}{site_inst_params}&begin=-00:10:00&order=column")
res = json.loads(res.data.decode())
assert res["code"] == 200
assert res["num_results"] == 9
for symbol_name in symbols:
assert symbol_name in res["results"]["data"]
assert not np.isnan(res["results"]["data"][symbol_name]).all()
assert len(list(res["results"]["data"].keys())) == len(symbols)
def test_one_symbol_two_insts_json_row(client, mock_influxdb_query):
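    """Request the same symbol from two site/instrument pairs and check the row-oriented JSON."""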
fake_result = _fake_data(
"1m",
{
("aoss", "tower"): ["time", "air_temp"],
("mendota", "buoy"): ["time", "air_temp"],
},
9,
)
with mock_influxdb_query(fake_result):
# row should be the default
res = client.get("/api/data.json?symbols=aoss.tower.air_temp:mendota.buoy.air_temp&begin=-00:10:00")
res = json.loads(res.data.decode())
assert res["code"] == 200
assert res["num_results"] == 9
assert res["results"]["symbols"] == ["aoss.tower.air_temp", "mendota.buoy.air_temp"]
assert len(res["results"]["timestamps"]) == 9
assert len(res["results"]["data"]) == 9
assert len(res["results"]["data"][0]) == 2
def test_one_symbol_three_insts_json_row(client, mock_influxdb_query):
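    """Request the same symbol from three site/instrument pairs, patching SYMBOL_TRANSLATIONS to add the fake sites."""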
fake_result = _fake_data(
"1m",
{
("site1", "inst1"): ["time", "air_temp"],
("site2", "inst2"): ["time", "air_temp"],
("site3", "inst3"): ["time", "air_temp"],
},
9,
)
# row should be the default
from metobsapi.util.data_responses import SYMBOL_TRANSLATIONS
st = SYMBOL_TRANSLATIONS.copy()
st[("site1", "inst1")] = st[("aoss", "tower")]
st[("site2", "inst2")] = st[("aoss", "tower")]
st[("site3", "inst3")] = st[("aoss", "tower")]
with mock.patch("metobsapi.util.data_responses.SYMBOL_TRANSLATIONS", st), mock_influxdb_query(fake_result):
res = client.get(
"/api/data.json?symbols=site1.inst1.air_temp:site2.inst2.air_temp:site3.inst3.air_temp&begin=-00:10:00",
)
res = json.loads(res.data.decode())
assert res["code"] == 200
assert res["num_results"] == 9
assert res["results"]["symbols"] == ["site1.inst1.air_temp", "site2.inst2.air_temp", "site3.inst3.air_temp"]
assert len(res["results"]["timestamps"]) == 9
assert len(res["results"]["data"]) == 9
assert len(res["results"]["data"][0]) == 3
@pytest.mark.usefixtures("influxdb_air_temp_9_values")
def test_one_symbol_csv(client):
# row should be the default
res = client.get("/api/data.csv?symbols=aoss.tower.air_temp&begin=-00:10:00")
res = res.data.decode()
# header, data, newline at end
lines = res.split("\n")
assert len(lines) == 5 + 9 + 1
# time + 1 channel
assert len(lines[5].split(",")) == 2
assert "# code: 200" in res
@pytest.mark.usefixtures("influxdb_air_temp_9_values")
def test_one_symbol_xml(client):
from xml.dom.minidom import parseString
# row should be the default
res = client.get("/api/data.xml?symbols=aoss.tower.air_temp&begin=-00:10:00")
res = parseString(res.data.decode())
assert len(res.childNodes[0].childNodes[0].childNodes) == 2
# data rows
assert len(res.childNodes[0].childNodes[1].childNodes) == 9
@pytest.mark.usefixtures("influxdb_3_symbols_9_values")
def test_three_symbol_csv(client):
"""Test that multiple channels in a CSV file are structured properly."""
# row should be the default
res = client.get(
"/api/data.csv?symbols=aoss.tower.air_temp:aoss.tower.rel_hum:aoss.tower.wind_speed&begin=-00:10:00",
)
res = res.data.decode()
# header, data, newline at end
lines = res.split("\n")
assert len(lines) == 5 + 9 + 1
# time + 3 channels
assert len(lines[5].split(",")) == 4
assert "# code: 200" in res
@pytest.mark.usefixtures("influxdb_3_symbols_9_values")
def test_three_symbol_csv_repeat(client):
"""Test that multiple channels in a CSV file are structured properly."""
# row should be the default
res = client.get(
"/api/data.csv?symbols=aoss.tower.air_temp:aoss.tower.air_temp:aoss.tower.air_temp&begin=-00:10:00",
)
res = res.data.decode()
    lines = res.split("\n")
    # header, one empty data line (the repeated symbol request is an error), newline at end
    assert len(lines) == 5 + 1 + 1
    # the empty data line splits into a single empty field
assert len(lines[5].split(",")) == 1
assert "# code: 400" in res