Skip to content
Snippets Groups Projects
data_api.py 16.70 KiB
import logging
from collections.abc import Callable, Iterable
from datetime import datetime, timedelta, timezone
from functools import partial

# Security: Document is only used for creating an XML document, not parsing one
from xml.dom.minidom import Document

import numpy as np
import pandas as pd
from flask import Response
from flask_json import as_json_p
from urllib3.exceptions import ReadTimeoutError
from werkzeug.exceptions import BadRequest

from metobsapi.util import data_responses
from metobsapi.util.query_influx import query

LOG = logging.getLogger(__name__)

# Length of a date-only string ("YYYY-MM-DD"). handle_date parses strings of
# exactly this length as a date and anything else as "YYYY-MM-DDTHH:MM:SS".
LENGTH_DATE_STR: int = len("YYYY-MM-DD")

# Decimal places used when rounding result columns. Keys are
# "<site>.<inst>.<influxdb column name>"; in most cases the InfluxDB column name
# is the same as the API symbol name. Columns not listed here are rounded to
# 1 decimal place by handle_influxdb_result.
ROUNDING: dict[str, int] = {
    # site "aoss", instrument "tower"
    "aoss.tower.rel_hum": 0,
    "aoss.tower.wind_direction": 0,
    "aoss.tower.accum_precip": 2,
    # site "mendota", instrument "buoy"
    "mendota.buoy.rel_hum": 0,
    "mendota.buoy.wind_direction": 0,
    "mendota.buoy.chlorophyll_ysi": 2,
    "mendota.buoy.phycocyanin_ysi": 2,
}


def handle_date(date: str) -> datetime:
    """Parse an absolute date or timestamp string into a naive ``datetime``.

    A string of exactly ``LENGTH_DATE_STR`` characters is parsed as
    ``%Y-%m-%d``; any other length is parsed as ``%Y-%m-%dT%H:%M:%S``.

    Raises:
        ValueError: if ``date`` does not match the selected format. The
            message is also logged as a warning.
    """
    is_date_only = len(date) == LENGTH_DATE_STR
    pattern = "%Y-%m-%d" if is_date_only else "%Y-%m-%dT%H:%M:%S"
    try:
        parsed = datetime.strptime(date, pattern)
    except ValueError as err:
        error_message = f"Malformed date string '{date}'"
        LOG.warning(error_message)
        raise ValueError(error_message) from err
    return parsed


def handle_time_string(date_string):
    """Convert a user-supplied time string into a ``datetime`` or ``timedelta``.

    * Empty or ``None``: returns ``None`` (calc_num_records treats this as "now").
    * ``-HH:MM:SS``: returns a ``timedelta`` of that length (fields may be
      fractional). Callers interpret it as an offset back from the current time.
    * Anything else: delegated to :func:`handle_date`.

    Raises:
        ValueError: if a relative string does not have exactly three
            ``:``-separated numeric fields or describes an out-of-range
            duration, or if ``handle_date`` rejects an absolute string.
    """
    if not date_string:
        return None
    if date_string[0] != "-":
        return handle_date(date_string)

    fields = date_string[1:].split(":")
    if len(fields) != 3:
        # Previously this surfaced as an IndexError, which callers do not catch.
        msg = f"Relative time must be formatted as '-HH:MM:SS': '{date_string}'"
        raise ValueError(msg)
    hours, minutes, seconds = (float(field) for field in fields)
    try:
        return timedelta(hours=hours, minutes=minutes, seconds=seconds)
    except OverflowError as err:
        # e.g. "-inf:0:0" or "-1e20:0:0"; report it as malformed input.
        msg = f"Relative time is out of range: '{date_string}'"
        raise ValueError(msg) from err


def handle_symbols(symbols: list[str]):
    """Map fully qualified API symbols to the InfluxDB columns that must be queried.

    Each entry of ``symbols`` must look like ``"<site>.<inst>.<symbol>"``.

    Returns:
        A 2-tuple ``(influx_to_requested_symbols, influx_symbols)``:

        * ``influx_to_requested_symbols`` maps ``"<site>.<inst>.<influx column>"``
          to the symbol the user asked for, in request order. The helper
          ``wind_east``/``wind_north`` columns added for ``wind_direction``
          requests map to ``None``.
        * ``influx_symbols`` maps ``(site, inst)`` to the list of InfluxDB
          column names to query for that site/instrument.

    Raises:
        ValueError: on a repeated symbol, a symbol without exactly three
            period-separated parts, or an unknown site/instrument/symbol.
    """
    requested_by_influx_name: dict[str, str | None] = {}
    columns_by_site_inst: dict[tuple[str, str], list[str]] = {}
    seen: set[str] = set()
    needs_wind_vectors = set()

    for requested in symbols:
        if requested in seen:
            msg = f"Symbol listed multiple times: {requested}"
            raise ValueError(msg)
        seen.add(requested)

        try:
            site, inst, api_name = requested.split(".")
        except ValueError as e:
            msg = f"Symbols must have 3 period-separated parts: {requested}"
            raise ValueError(msg) from e

        column = _get_column_name_in_influxdb(site, inst, api_name)
        # Record the mapping before the wind_direction branch below so the
        # derived direction column can later be renamed to what was requested.
        requested_by_influx_name[f"{site}.{inst}.{column}"] = requested

        site_columns = columns_by_site_inst.setdefault((site, inst), [])
        if api_name == "wind_direction":
            # Not queried directly: wind_east/wind_north are appended below
            # instead, and the direction is derived from them afterwards.
            needs_wind_vectors.add((site, inst))
        else:
            site_columns.append(column)

    for site, inst in needs_wind_vectors:
        requested_by_influx_name[f"{site}.{inst}.wind_east"] = None
        requested_by_influx_name[f"{site}.{inst}.wind_north"] = None
        columns_by_site_inst[(site, inst)].extend(("wind_east", "wind_north"))

    return requested_by_influx_name, columns_by_site_inst


def _get_column_name_in_influxdb(site: str, inst: str, api_symbol: str) -> str:
    """Translate an API symbol name to its InfluxDB column name.

    The lookup uses ``data_responses.SYMBOL_TRANSLATIONS[(site, inst)]``.
    ``wind_direction`` may be missing from that table, because it is derived
    later from the wind vector columns. If it is missing, or if a translation
    maps to a falsy value, ``api_symbol`` is returned unchanged.

    Raises:
        ValueError: for an unknown site/instrument pair, or for an unknown
            symbol other than ``wind_direction``.
    """
    rename_map = data_responses.SYMBOL_TRANSLATIONS.get((site, inst))
    if not rename_map:
        msg = f"Unknown site/instrument: {site},{inst}"
        raise ValueError(msg)

    translated = rename_map.get(api_symbol)
    if translated is None and api_symbol != "wind_direction":
        msg = f"Unknown symbol: {site}.{inst}.{api_symbol}"
        raise ValueError(msg)
    return translated if translated else api_symbol


def handle_influxdb_result(data_frame, influx_to_requested_symbols):
    """Post-process a query result into the columns the user requested.

    Processing steps:

    1. Derive ``wind_direction`` from the wind vector columns.
    2. Round each column to ``ROUNDING`` decimals (default 1).
    3. Keep only the columns present in ``influx_to_requested_symbols``, in that
       mapping's order.
    4. Rename them to the requested symbol names.

    If ``data_frame`` is ``None`` (e.g. the query was skipped or failed), an
    empty frame indexed by ``time`` is returned instead. It has one column per
    requested (non-``None``) symbol.
    """
    if data_frame is None:
        requested = [name for name in influx_to_requested_symbols.values() if name is not None]
        return _create_empty_dataframe_for_columns(requested)

    frame = _convert_wind_vectors_to_direction_if_needed(data_frame)
    decimals = {column: ROUNDING.get(column, 1) for column in frame.columns}
    frame = frame.round(decimals)
    # Restrict to, and order by, the InfluxDB columns that map to a request
    ordered_columns = [name for name in influx_to_requested_symbols if name in frame.columns]
    frame = frame[ordered_columns]
    # Rename InfluxDB column names back to the symbols the user requested
    frame.columns = [influx_to_requested_symbols[column] for column in frame.columns]
    return frame


def _create_empty_dataframe_for_columns(columns: list[str]) -> pd.DataFrame:
    """Return an empty DataFrame indexed by ``time`` with the given data columns."""
    all_columns = ["time", *columns]
    return pd.DataFrame(columns=all_columns).set_index("time").fillna(value=np.nan)


def _convert_wind_vectors_to_direction_if_needed(data_frame: pd.DataFrame) -> pd.DataFrame:
    """Replace each site/instrument's wind_east/wind_north pair with wind_direction.

    The direction is ``degrees(arctan2(east, north))``, shifted into the range
    (0, 360]. Values <= 0 get 360 added, so a due-north vector yields 360.
    NaN inputs stay NaN.

    NOTE: the direction column is assigned before the east/north columns are
    dropped, so the caller's ``data_frame`` gains that column. The returned
    frame is a copy without the east/north columns.
    """
    for east_col, north_col, direction_col in _wind_vector_column_groups(data_frame.columns):
        degrees = np.rad2deg(np.arctan2(data_frame[east_col], data_frame[north_col]))
        data_frame[direction_col] = degrees.where(degrees > 0, degrees + 360.0)
        data_frame = data_frame.drop(columns=[east_col, north_col])
    return data_frame


def _wind_vector_column_groups(columns: list[str]) -> Iterable[tuple[str, str, str]]:
    """Yield ``(wind_east, wind_north, wind_direction)`` column-name triples.

    One triple is produced for each column returned by ``_all_wind_north_columns``
    whose sibling ``<site>.<inst>.wind_east`` column is also in ``columns``.
    North columns without an east partner are skipped.
    """
    for north_col in _all_wind_north_columns(columns):
        # Column names are assumed to be "<site>.<inst>.wind_north"
        site, inst = north_col.split(".")[:2]
        prefix = f"{site}.{inst}"
        east_col = f"{prefix}.wind_east"
        if east_col in columns:
            yield east_col, north_col, f"{prefix}.wind_direction"


def _all_wind_north_columns(columns: list[str]) -> list[str]:
    """Return the column names ending in ``.wind_north``, preserving order.

    Callers expect names of the form ``<site>.<inst>.wind_north`` and derive
    the matching ``wind_east`` column from the prefix. The match is therefore
    on the final component rather than a substring. Names such as
    ``<site>.<inst>.wind_north_avg``, or a bare ``wind_north`` with no
    site/instrument prefix, are not treated as wind vector columns.
    """
    return [col_name for col_name in columns if col_name.endswith(".wind_north")]


def calc_num_records(begin, end, interval):
    """Estimate how many records a query over ``[begin, end]`` would return.

    ``begin`` and ``end`` may each be ``None`` (meaning now), a ``timedelta``
    (meaning that long before now), or a naive ``datetime`` compared against
    the current UTC time. The span in seconds is divided by
    ``data_responses.INTERVALS[interval]``, presumably the interval length in
    seconds.
    """
    # Naive UTC "now", matching the naive datetimes produced by handle_date.
    # Equivalent to the deprecated (Python 3.12+) datetime.utcnow().
    now = datetime.now(timezone.utc).replace(tzinfo=None)
    if begin is None:
        begin = now
    elif isinstance(begin, timedelta):
        begin = now - begin
    if end is None:
        end = now
    elif isinstance(end, timedelta):
        end = now - end

    diff = (end - begin).total_seconds()
    return diff / data_responses.INTERVALS[interval]


def handle_csv(frame, epoch, response_info, formatting_kwargs):
    """Render ``frame`` as a downloadable CSV response.

    The body starts with ``#``-prefixed metadata lines (status, code, message,
    num_results, fields). One line per row follows: the formatted timestamp,
    then each value, all joined with ``formatting_kwargs["sep"]`` (default
    ``","``).

    Returns:
        ``(flask.Response, code)`` where ``code`` is ``response_info["code"]``
        (default 200).
    """
    code = response_info.get("code", 200)
    sep = formatting_kwargs.get("sep", ",")
    # Every response_info built in this module uses the "status" key, as do the
    # JSON and XML handlers. The old "status_info" key is kept only as a fallback.
    status = response_info.get("status", response_info.get("status_info", "success"))
    output = """# status: {status}
# code: {code:d}
# message: {message}
# num_results: {num_results:d}
# fields: {epoch_str}{sep}{symbol_list}
{symbol_data}
"""

    data_lines = []
    if frame is not None and not frame.empty:
        for index_time, row in frame.iterrows():
            time_str = _time_column_to_epoch_str(index_time, epoch)
            values_str = sep.join(str(x) for x in row.to_list())
            # Build the line directly instead of through str.format so that a
            # separator containing "{" or "}" cannot break the formatting.
            data_lines.append(f"{time_str}{sep}{values_str}")

    if not epoch:
        epoch_str = "%Y-%m-%dT%H:%M:%SZ"
    else:
        epoch_str = data_responses.epoch_translation[epoch] + " since epoch (1970-01-01 00:00:00)"

    output = output.format(
        sep=sep,
        status=status,
        code=code,
        message=response_info.get("message", ""),
        num_results=frame.shape[0] if frame is not None else 0,
        epoch_str=epoch_str,
        symbol_list=sep.join(frame.columns) if frame is not None else "",
        symbol_data="\n".join(data_lines),
    )

    res = Response(output, mimetype="text/csv")
    res.headers.set("Content-Disposition", "attachment", filename="data.csv")
    return res, code


@as_json_p(optional=True)
def handle_json(frame, epoch, response_info, formatting_kwargs):
    """Build the JSON(-P) payload for ``frame``; serialization is done by ``as_json_p``.

    If ``formatting_kwargs["order"] == "column"``, ``results.data`` maps each
    symbol to its list of values. Any other value, including a missing key,
    yields ``results.symbols`` plus ``results.data`` as a list of rows.
    Missing values are emitted as ``None`` (JSON ``null``).

    Returns:
        ``(payload_dict, code)`` where ``code`` is ``response_info["code"]``
        (default 200).
    """
    package = {}
    # Anything other than "column" selects row order. An empty formatting_kwargs
    # (as used for error responses) therefore gets row order.
    order = formatting_kwargs.get("order", "row")
    code = response_info.get("code", 200)

    if frame is not None and not frame.empty:
        # Force integer columns to float so all numeric data serializes alike.
        # astype returns a new frame, so the caller's frame is not modified.
        int_columns = {
            column: float for column, data_type in frame.dtypes.items() if issubclass(data_type.type, np.integer)
        }
        if int_columns:
            frame = frame.astype(int_columns)
        # Cast to object before replacing NaNs: on float columns, where(..., None)
        # stores NaN again, which would serialize as invalid JSON "NaN".
        frame = frame.astype(object).where(frame.notna(), None)

        package["timestamps"] = [_time_column_to_epoch_str(stamp, epoch) for stamp in frame.index.to_list()]

        if order == "column":
            package["data"] = {column: frame[column].to_list() for column in frame.columns}
        else:
            package["symbols"] = list(frame.columns)
            package["data"] = frame.to_dict(orient="split")["data"]
    else:
        package["timestamps"] = []
        if order == "column":
            package["data"] = {}
        else:
            package["data"] = []
            package["symbols"] = []

    output = {
        "status": response_info.get("status", "success"),
        "message": response_info.get("message", ""),
        "code": code,
        "num_results": frame.shape[0] if frame is not None else 0,
        "results": package,
    }
    return output, code


def handle_xml(frame, epoch, response_info, formatting_kwargs):
    """Render ``frame`` as a downloadable XML response.

    The document has a ``<metobs>`` root with these children:

    * root attributes: status, code, message, num_results and seperator;
    * ``<symbols>``: describes the time column and each data column;
    * ``<data>``: one ``<row>`` per timestamp. Each row's text is the time
      followed by each value, separated by ``formatting_kwargs["sep"]``
      (default ``","``).

    Returns:
        ``(flask.Response, code)`` where ``code`` is ``response_info["code"]``
        (default 200).
    """
    doc = Document()
    header = "metobs"
    code = response_info.get("code", 200)
    sep = formatting_kwargs.get("sep", ",")

    head = doc.createElement(header)
    head.setAttribute("status", response_info.get("status", "success"))
    head.setAttribute("code", str(code))
    head.setAttribute("message", response_info.get("message", ""))
    head.setAttribute("num_results", str(frame.shape[0]) if frame is not None else str(0))
    # "seperator" (sic) is the attribute name currently emitted. Renaming it
    # would change the output format, so it is kept as-is.
    head.setAttribute("seperator", sep)
    data_elem = doc.createElement("data")

    doc.appendChild(head)
    columns_elem = doc.createElement("symbols")

    time_elem = doc.createElement("symbol")
    time_elem.setAttribute("name", "time")
    time_elem.setAttribute("short_name", "time")
    if not epoch:
        time_elem.setAttribute("format", "%Y-%m-%dT%H:%M:%SZ")
    else:
        time_elem.setAttribute("format", data_responses.epoch_translation[epoch] + " since epoch (1970-01-01 00:00:00)")
    columns_elem.appendChild(time_elem)

    if frame is not None and not frame.empty:
        for c in frame.columns:
            col_elem = doc.createElement("symbol")
            col_elem.setAttribute("name", c)
            parts = c.split(".")
            if len(parts) >= 3:
                # Fully qualified "<site>.<inst>.<symbol>" column name
                col_elem.setAttribute("short_name", parts[2])
                col_elem.setAttribute("site", parts[0])
                col_elem.setAttribute("inst", parts[1])
            else:
                # Short column names (used when the request supplied site and
                # inst separately) carry no site/inst information to report;
                # indexing parts[2] here used to raise IndexError.
                col_elem.setAttribute("short_name", c)
            columns_elem.appendChild(col_elem)

        for idx, (t, row) in enumerate(frame.iterrows()):
            row_elem = doc.createElement("row")
            row_elem.setAttribute("id", str(idx))
            row_elem.appendChild(doc.createTextNode(_time_column_to_epoch_str(t, epoch)))
            for point in row:
                row_elem.appendChild(doc.createTextNode(sep))
                row_elem.appendChild(doc.createTextNode(str(point)))
            data_elem.appendChild(row_elem)
    head.appendChild(columns_elem)
    head.appendChild(data_elem)
    txt = doc.toxml(encoding="utf-8")
    res = Response(txt, mimetype="text/xml")
    res.headers.set("Content-Disposition", "attachment", filename="data.xml")
    return res, code


def _time_column_to_epoch_str(time_val: datetime | np.datetime64 | pd.Timestamp | str, epoch: str) -> str:
    if isinstance(time_val, str):
        return time_val
    if isinstance(time_val, np.datetime64):
        time_val_dt = time_val.astype("datetime64[us]").astype(datetime)
    elif isinstance(time_val, pd.Timestamp):
        time_val_dt = time_val.to_pydatetime()
    else:
        time_val_dt = time_val

    if not epoch:
        return time_val_dt.strftime("%Y-%m-%dT%H:%M:%SZ")

    # NOTE: Assumes host system is UTC
    time_seconds = time_val_dt.timestamp()
    time_in_units = data_responses.seconds_to_epoch[epoch](time_seconds)
    time_in_units = round(time_in_units)  # round to nearest integer
    return f"{time_in_units:d}"


def create_formatted_error(handler_func, error_str):
    """Raise a ``FormattedBadRequest`` describing ``error_str``.

    ``error_str`` is looked up in ``data_responses.ERROR_MESSAGES`` to get an
    ``(http_code, message)`` pair. Unknown keys become a 400 error whose
    message is ``error_str`` itself.

    This function always raises and never returns. Call sites still write
    ``raise create_formatted_error(...)``; that outer ``raise`` is never reached.
    """
    code_and_message = data_responses.ERROR_MESSAGES.get(error_str, (400, error_str))
    err_code, err_msg = code_and_message
    raise FormattedBadRequest(
        handler_func=handler_func,
        response_info={"message": err_msg, "code": err_code, "status": "error"},
    )


# Output format name (the ``fmt`` passed to get_format_handler) -> handler.
# Each handler is called as ``handler(frame, epoch, response_info,
# formatting_kwargs=...)`` and returns ``(response_or_payload, code)``.
RESPONSE_HANDLERS: dict[str, Callable] = dict(
    csv=handle_csv,
    xml=handle_xml,
    json=handle_json,
)


def get_format_handler(fmt: str, formatting_kwargs: dict) -> Callable | tuple:
    """Return the response handler for ``fmt`` with ``formatting_kwargs`` bound.

    Raises:
        BadHandlerFormat: if ``fmt`` is not a key of ``RESPONSE_HANDLERS``.
        FormattedBadRequest: (via ``create_formatted_error``) if
            ``formatting_kwargs["order"]`` is missing or not "column"/"row".
            The error is rendered by the same handler with empty formatting
            kwargs.
    """
    handler_func = RESPONSE_HANDLERS.get(fmt)
    if handler_func is None:
        raise BadHandlerFormat(bad_format=fmt)

    order = formatting_kwargs.get("order")
    if order in ("column", "row"):
        return partial(handler_func, formatting_kwargs=formatting_kwargs)
    raise create_formatted_error(partial(handler_func, formatting_kwargs={}), "bad_order")


class BadHandlerFormat(BadRequest):
    """Raised when the requested output format has no registered handler.

    Attributes:
        bad_format: the format string that was not recognized.
    """

    def __init__(self, *, bad_format: str, **kwargs):
        super().__init__(**kwargs)
        self.bad_format = bad_format


class FormattedBadRequest(BadRequest):
    """A ``BadRequest`` whose body should be rendered in the requested output format.

    Attributes:
        handler_func: the format handler to render the error with.
        response_info: the dict of ``message``/``code``/``status`` for that
            handler (as built by ``create_formatted_error``).
    """

    def __init__(self, *, handler_func: Callable, response_info: dict, **kwargs):
        self.handler_func = handler_func
        self.response_info = response_info
        super().__init__(**kwargs)


def parse_time_parameters(handler_func, begin, end, interval, epoch):
    """Validate and convert the time-related request parameters.

    Returns:
        ``(begin, end, interval, epoch)``. ``begin`` and ``end`` have been
        converted by ``handle_time_string`` to ``None``, a ``datetime`` or a
        ``timedelta``; ``interval`` and ``epoch`` are returned unchanged.

    Raises:
        FormattedBadRequest: for a malformed timestamp, an unknown epoch, or an
            unknown interval. The error is rendered with ``handler_func``.
    """
    try:
        parsed_begin, parsed_end = _convert_begin_and_end(begin, end)
        _check_query_parameters(epoch, interval)
    except ValueError as err:
        raise create_formatted_error(handler_func, str(err)) from err
    return parsed_begin, parsed_end, interval, epoch


def modify_data(handler_func, time_parameters, site, inst, symbols):
    """Run a data request end to end and render the result with ``handler_func``.

    ``time_parameters`` is expected to be the ``(begin, end, interval, epoch)``
    tuple returned by ``parse_time_parameters``; index 3 is passed to the
    handler as the epoch. ``site``, ``inst`` and ``symbols`` are the raw
    request values interpreted by ``_parse_symbol_names``.

    Raises:
        FormattedBadRequest: if the symbol parameters are invalid.
    """
    try:
        short_symbols, fully_qualified_symbols = _parse_symbol_names(site, inst, symbols)
        influx_to_requested_symbols, influx_symbols = handle_symbols(fully_qualified_symbols)
    except ValueError as err:
        raise create_formatted_error(handler_func, str(err)) from err

    query_result, response_info = _query_time_series_db(time_parameters, influx_symbols)
    frame = handle_influxdb_result(query_result, influx_to_requested_symbols)
    frame = _reorder_and_rename_result_dataframe(frame, fully_qualified_symbols, short_symbols)
    epoch = time_parameters[3]
    return handler_func(frame, epoch, response_info)


def _convert_begin_and_end(begin, end) -> tuple[datetime | timedelta, datetime | timedelta]:
    """Parse the raw ``begin``/``end`` values with ``handle_time_string``.

    Each result is ``None``, a ``datetime`` or a ``timedelta``.

    Raises:
        ValueError: with the message ``"malformed_timestamp"`` if parsing
            either value raises ``TypeError`` or ``ValueError``.
    """
    try:
        parsed = tuple(handle_time_string(raw) for raw in (begin, end))
    except (TypeError, ValueError) as e:
        raise ValueError("malformed_timestamp") from e
    return parsed


def _check_query_parameters(epoch, interval):
    """Validate ``epoch`` and ``interval`` against ``data_responses``.

    A falsy ``epoch`` is accepted; timestamps are then formatted as ISO-style
    strings.

    Raises:
        ValueError: ``"bad_epoch"`` for a non-empty epoch not in
            ``epoch_translation``, or ``"bad_interval"`` for an interval not in
            ``INTERVALS``.
    """
    invalid_epoch = bool(epoch) and epoch not in data_responses.epoch_translation
    if invalid_epoch:
        raise ValueError("bad_epoch")
    if interval in data_responses.INTERVALS:
        return
    raise ValueError("bad_interval")


def _parse_symbol_names(site: str, inst: str, symbols: str) -> tuple[list[str] | None, list[str]]:
    if not symbols:
        raise ValueError("missing_symbols")
    if site and inst:
        # shorthand for symbols that all use the same site and inst
        short_symbols = symbols.split(":")
        fully_qualified_symbols = [f"{site}.{inst}.{s}" for s in short_symbols]
    elif not site and not inst:
        # each symbol is fully qualified with site.inst.symbol
        short_symbols = None
        fully_qualified_symbols = symbols.split(":")
    else:
        raise ValueError("missing_site_inst")
    return short_symbols, fully_qualified_symbols


def _query_time_series_db(time_parameters, influx_symbols):
    """Run the InfluxDB query unless it would exceed the response limit.

    Returns:
        ``(result, response_info)``. ``result`` is ``None`` when the estimated
        record count exceeds ``data_responses.RESPONSES_LIMIT`` or the query
        raises ``ReadTimeoutError``. In both cases ``response_info`` carries
        code 413 with status "fail". Otherwise ``result`` is the query output
        and ``response_info`` reports code 200 with status "success".
    """
    estimated = calc_num_records(time_parameters[0], time_parameters[1], time_parameters[2])
    if estimated > data_responses.RESPONSES_LIMIT:
        return None, {
            "message": "Request will return too many values, please use files API",
            "code": 413,
            "status": "fail",
        }

    try:
        result = query(influx_symbols, *time_parameters)
    except ReadTimeoutError:
        return None, {
            "message": "Request took too long to process. It may be too much data. Try a shorter time range.",
            "code": 413,
            "status": "fail",
        }
    return result, {"message": "", "code": 200, "status": "success"}


def _reorder_and_rename_result_dataframe(
    frame: pd.DataFrame,
    ordered_symbols: list[str],
    new_symbol_names: list[str],
) -> pd.DataFrame:
    """Select ``ordered_symbols`` from ``frame`` (in that order) and optionally rename them.

    ``new_symbol_names`` may be ``None``, in which case the selected columns
    keep their names. Otherwise the columns are renamed positionally (short
    names from a site/inst request).

    Time lives in the index, not in a column. The input frame is not
    modified, because column selection produces a new frame.
    """
    selected = frame[ordered_symbols]
    if new_symbol_names is None:
        return selected
    selected.columns = new_symbol_names
    return selected