diff --git a/metobsapi/data_api.py b/metobsapi/data_api.py index 8e5b09f1ab37ce75b85af1aea7ce3812fe2aae84..2e817f8268e7f448cc4f34b74649321456ec9388 100644 --- a/metobsapi/data_api.py +++ b/metobsapi/data_api.py @@ -1,14 +1,16 @@ import logging -from collections.abc import Iterable +from collections.abc import Callable, Iterable from datetime import datetime, timedelta +from functools import partial # Security: Document is only used for creating an XML document, not parsing one from xml.dom.minidom import Document import numpy as np import pandas as pd -from flask import Response, render_template +from flask import Response from flask_json import as_json_p +from werkzeug.exceptions import BadRequest from metobsapi.util import data_responses from metobsapi.util.query_influx import query @@ -162,8 +164,9 @@ def calc_num_records(begin, end, interval): return diff / data_responses.INTERVALS[interval] -def handle_csv(frame, epoch, sep=",", message="", code=200, status="success", **kwargs): - del kwargs # other kwargs used by other handlers +def handle_csv(frame, epoch, response_info, formatting_kwargs): + code = response_info.get("code", 200) + sep = formatting_kwargs.get("sep", ",") output = """# status: {status} # code: {code:d} # message: {message} @@ -187,9 +190,9 @@ def handle_csv(frame, epoch, sep=",", message="", code=200, status="success", ** output = output.format( sep=sep, - status=status, + status=response_info.get("status", "success"), code=code, - message=message, + message=response_info.get("message", ""), num_results=frame.shape[0] if frame is not None else 0, epoch_str=epoch_str, symbol_list=sep.join(frame.columns) if frame is not None else "", @@ -202,9 +205,15 @@ def handle_csv(frame, epoch, sep=",", message="", code=200, status="suc @as_json_p(optional=True) -def handle_json(frame, epoch, order="columns", message="", code=200, status="success", **kwargs): - del kwargs # other kwargs used by other handlers +def handle_json(frame, epoch, 
response_info, formatting_kwargs): package = {} + order = formatting_kwargs.get("order", "row") + if order not in ("column", "row"): + # bad order + frame = None + response_info["code"], response_info["message"] = data_responses.ERROR_MESSAGES["bad_order"] + response_info["status"] = "error" + code = response_info.get("code", 200) if frame is not None and not frame.empty: # force conversion to float types so they can be json'd @@ -233,8 +242,8 @@ def handle_json(frame, epoch, order="columns", message="", code=200, status="suc package["symbols"] = [] output = { - "status": status, - "message": message, + "status": response_info.get("status", "success"), + "message": response_info.get("message", ""), "code": code, "num_results": frame.shape[0] if frame is not None else 0, "results": package, @@ -242,15 +251,16 @@ def handle_json(frame, epoch, order="columns", message="", code=200, status="suc return output, code -def handle_xml(frame, epoch, sep=",", message="", code=200, status="success", **kwargs): - del kwargs # other kwargs used by other handlers +def handle_xml(frame, epoch, response_info, formatting_kwargs): doc = Document() header = "metobs" + code = response_info.get("code", 200) + sep = formatting_kwargs.get("sep", ",") head = doc.createElement(header) - head.setAttribute("status", status) + head.setAttribute("status", response_info.get("status", "success")) head.setAttribute("code", str(code)) - head.setAttribute("message", message) + head.setAttribute("message", response_info.get("message", "")) head.setAttribute("num_results", str(frame.shape[0]) if frame is not None else str(0)) head.setAttribute("seperator", sep) data_elem = doc.createElement("data") @@ -313,10 +323,10 @@ def _time_column_to_epoch_str(time_val: datetime | np.datetime64 | pd.Timestamp return f"{time_in_units:d}" -def handle_error(fmt, error_str): - handler = RESPONSE_HANDLERS[fmt] +def create_formatted_error(handler_func, error_str): err_code, err_msg = 
data_responses.ERROR_MESSAGES.get(error_str, (400, error_str)) - return handler(None, None, message=err_msg, code=err_code, status="error") + response_info = {"message": err_msg, "code": err_code, "status": "error"} + return FormattedBadRequest(handler_func=handler_func, response_info=response_info) RESPONSE_HANDLERS = { @@ -326,25 +336,50 @@ RESPONSE_HANDLERS = { } -def modify_data(fmt, begin, end, site, inst, symbols, interval, sep=",", order="columns", epoch=None): +def get_format_handler(fmt: str, formatting_kwargs: dict) -> Callable: if fmt not in RESPONSE_HANDLERS: - return render_template("400.html", format=fmt), 400 - if not interval: - interval = "1m" + raise BadHandlerFormat(bad_format=fmt) + handler_func = RESPONSE_HANDLERS[fmt] + return partial(handler_func, formatting_kwargs=formatting_kwargs) + + +class BadHandlerFormat(BadRequest): + """The provided file format is unknown.""" + + def __init__(self, *, bad_format: str, **kwargs): + self.bad_format = bad_format + super().__init__(**kwargs) + +class FormattedBadRequest(BadRequest): + """A request error that we want in the requested format.""" + + def __init__(self, *, handler_func: Callable, response_info: dict, **kwargs): + super().__init__(**kwargs) + self.handler_func = handler_func + self.response_info = response_info + + +def parse_time_parameters(handler_func, begin, end, interval, epoch): try: begin, end = _convert_begin_and_end(begin, end) - _check_query_parameters(order, epoch, symbols, interval) + _check_query_parameters(epoch, interval) + except ValueError as e: + raise create_formatted_error(handler_func, str(e)) from e + return begin, end, interval, epoch + + +def modify_data(handler_func, time_parameters, site, inst, symbols): + try: short_symbols, fully_qualified_symbols = _parse_symbol_names(site, inst, symbols) influx_to_requested_symbols, influx_symbols = handle_symbols(fully_qualified_symbols) except ValueError as e: - return handle_error(fmt, str(e)) + raise 
create_formatted_error(handler_func, str(e)) from e - data_frame, response_info = _query_time_series_db(begin, end, interval, influx_symbols, epoch) + data_frame, response_info = _query_time_series_db(time_parameters, influx_symbols) data_frame = handle_influxdb_result(data_frame, influx_to_requested_symbols) data_frame = _reorder_and_rename_result_dataframe(data_frame, fully_qualified_symbols, short_symbols) - handler = RESPONSE_HANDLERS[fmt] - return handler(data_frame, epoch, sep=sep, order=order, **response_info) + return handler_func(data_frame, time_parameters[3], response_info) def _convert_begin_and_end(begin, end) -> tuple[datetime | timedelta, datetime | timedelta]: @@ -357,18 +392,16 @@ def _convert_begin_and_end(begin, end) -> tuple[datetime | timedelta, datetime | return begin, end -def _check_query_parameters(order, epoch, symbols, interval): - if order not in ("column", "row"): - raise ValueError("bad_order") +def _check_query_parameters(epoch, interval): if epoch and epoch not in data_responses.epoch_translation: raise ValueError("bad_epoch") - if not symbols: - raise ValueError("missing_symbols") if interval not in data_responses.INTERVALS: raise ValueError("bad_interval") def _parse_symbol_names(site: str, inst: str, symbols: str) -> tuple[list[str] | None, list[str]]: + if not symbols: + raise ValueError("missing_symbols") if site and inst: # shorthand for symbols that all use the same site and inst short_symbols = symbols.split(":") @@ -382,8 +415,8 @@ def _parse_symbol_names(site: str, inst: str, symbols: str) -> tuple[list[str] | return short_symbols, fully_qualified_symbols -def _query_time_series_db(begin, end, interval, influx_symbols, epoch): - if calc_num_records(begin, end, interval) > data_responses.RESPONSES_LIMIT: +def _query_time_series_db(time_parameters, influx_symbols): + if calc_num_records(time_parameters[0], time_parameters[1], time_parameters[2]) > data_responses.RESPONSES_LIMIT: message = "Request will return too many values, 
please use files API" code = 413 status = "fail" @@ -392,7 +425,7 @@ def _query_time_series_db(begin, end, interval, influx_symbols, epoch): message = "" code = 200 status = "success" - result = query(influx_symbols, begin, end, interval, epoch) + result = query(influx_symbols, *time_parameters) response_info = {"message": message, "code": code, "status": status} return result, response_info diff --git a/metobsapi/server.py b/metobsapi/server.py index bd6260305bb71e40b4cc27e8f384698ca172ca78..a279b17d3867432d0e47d85740eb0355a35d7062 100644 --- a/metobsapi/server.py +++ b/metobsapi/server.py @@ -63,6 +63,16 @@ def data_index(): return render_template("data_index.html") +@app.errorhandler(data_api.BadHandlerFormat) +def bad_request_format(err: data_api.BadHandlerFormat): + return render_template("400.html", format=err.bad_format), 400 + + +@app.errorhandler(data_api.FormattedBadRequest) +def formatted_bad_request(err: data_api.FormattedBadRequest): + return err.handler_func(None, None, err.response_info) + + @app.errorhandler(404) def page_not_found(_): return render_template("404.html"), 404 @@ -86,12 +96,14 @@ def get_data(fmt): site = request.args.get("site") inst = request.args.get("inst") symbols = request.args.get("symbols") - interval = request.args.get("interval") + interval = request.args.get("interval", "1m") sep = request.args.get("sep", ",") order = request.args.get("order", "row") epoch = request.args.get("epoch") - return data_api.modify_data(fmt, begin_time, end_time, site, inst, symbols, interval, sep, order, epoch) + handler_func = data_api.get_format_handler(fmt, {"sep": sep, "order": order}) + time_parameters = data_api.parse_time_parameters(handler_func, begin_time, end_time, interval, epoch) + return data_api.modify_data(handler_func, time_parameters, site, inst, symbols) @app.route("/api/files.<fmt>", methods=["GET"]) diff --git a/metobsapi/tests/test_data_api.py 
b/metobsapi/tests/test_data_api.py index fb668cbe59b2309fce16b3c5fb202808bd64f38f..69009aa5f72dabac7080d4886cda914a53525762 100644 --- a/metobsapi/tests/test_data_api.py +++ b/metobsapi/tests/test_data_api.py @@ -295,7 +295,8 @@ def test_too_many_points(client): assert res["status"] == "fail" -def test_shorthand_one_symbol_json_row(client, influxdb_air_temp_9_values): +@pytest.mark.usefixtures("influxdb_air_temp_9_values") +def test_shorthand_one_symbol_json_row(client): # row should be the default res = client.get("/api/data.json?site=aoss&inst=tower&symbols=air_temp&begin=-00:10:00") res = json.loads(res.data.decode()) @@ -307,7 +308,8 @@ def test_shorthand_one_symbol_json_row(client, influxdb_air_temp_9_values): assert len(res["results"]["data"][0]) == 1 -def test_shorthand_one_symbol_json_column(client, influxdb_air_temp_9_values): +@pytest.mark.usefixtures("influxdb_air_temp_9_values") +def test_shorthand_one_symbol_json_column(client): res = client.get("/api/data.json?site=aoss&inst=tower&symbols=air_temp&begin=-00:10:00&order=column") res = json.loads(res.data.decode()) assert res["code"] == 200 @@ -324,7 +326,8 @@ def test_shorthand_one_symbol_json_column(client, influxdb_air_temp_9_values): ["wind_speed", "wind_direction"], ], ) -def test_wind_speed_direction_json(symbols, client, influxdb_wind_fields_9_values): +@pytest.mark.usefixtures("influxdb_wind_fields_9_values") +def test_wind_speed_direction_json(symbols, client): symbol_param = ":".join(symbols) site_inst_params = "&site=aoss&inst=tower" if "." 
not in symbols[0] else "" @@ -389,7 +392,8 @@ def test_one_symbol_three_insts_json_row(client, mock_influxdb_query): assert len(res["results"]["data"][0]) == 3 -def test_one_symbol_csv(client, influxdb_air_temp_9_values): +@pytest.mark.usefixtures("influxdb_air_temp_9_values") +def test_one_symbol_csv(client): # row should be the default res = client.get("/api/data.csv?symbols=aoss.tower.air_temp&begin=-00:10:00") res = res.data.decode() @@ -401,7 +405,8 @@ def test_one_symbol_csv(client, influxdb_air_temp_9_values): assert "# code: 200" in res -def test_one_symbol_xml(client, influxdb_air_temp_9_values): +@pytest.mark.usefixtures("influxdb_air_temp_9_values") +def test_one_symbol_xml(client): from xml.dom.minidom import parseString # row should be the default @@ -412,7 +417,8 @@ def test_one_symbol_xml(client, influxdb_air_temp_9_values): assert len(res.childNodes[0].childNodes[1].childNodes) == 9 -def test_three_symbol_csv(client, influxdb_3_symbols_9_values): +@pytest.mark.usefixtures("influxdb_3_symbols_9_values") +def test_three_symbol_csv(client): """Test that multiple channels in a CSV file are structured properly.""" # row should be the default res = client.get( @@ -427,7 +433,8 @@ def test_three_symbol_csv(client, influxdb_3_symbols_9_values): assert "# code: 200" in res -def test_three_symbol_csv_repeat(client, influxdb_3_symbols_9_values): +@pytest.mark.usefixtures("influxdb_3_symbols_9_values") +def test_three_symbol_csv_repeat(client): """Test that multiple channels in a CSV file are structured properly.""" # row should be the default res = client.get( @@ -441,7 +448,3 @@ def test_three_symbol_csv_repeat(client, influxdb_3_symbols_9_values): # time + 1 channel assert len(lines[5].split(",")) == 1 assert "# code: 400" in res - - -def test_temporary(tmp_path): - assert True diff --git a/pyproject.toml b/pyproject.toml index b4e7b00fc5c8d0fc0e2a7afa8e6b92e005d6dc0e..5b0240281239c3fd44bf18c362fcefbfb1e7ad99 100644 --- a/pyproject.toml +++ b/pyproject.toml 
@@ -37,7 +37,7 @@ select = [ ignore = ["D100", "D101", "D102", "D103", "D104", "D106", "D107", "D203", "D213", "B008"] [tool.ruff.per-file-ignores] -"metobsapi/tests/*" = ["S", "ARG001", "ARG002", "PLR2004"] +"metobsapi/tests/*" = ["S", "PLR2004"] [tool.mypy] python_version = "3.10"