From 338e895f3350bf548c507ed9140956083261cb28 Mon Sep 17 00:00:00 2001
From: David Hoese <david.hoese@ssec.wisc.edu>
Date: Mon, 27 Feb 2023 14:25:40 -0600
Subject: [PATCH] Start refactoring influxdb requests to return DataFrames

---
 metobsapi/common_config.py       |   3 +
 metobsapi/data_api.py            |  98 +++++++++++++++--------
 metobsapi/tests/test_data_api.py |  44 +++++++----
 metobsapi/util/query_influx.py   | 129 +++++++++++++++++++++++++++----
 4 files changed, 213 insertions(+), 61 deletions(-)

diff --git a/metobsapi/common_config.py b/metobsapi/common_config.py
index 76603e0..a4f0d0c 100644
--- a/metobsapi/common_config.py
+++ b/metobsapi/common_config.py
@@ -14,4 +14,7 @@ INFLUXDB_PORT = 8086
 INFLUXDB_USER = "root"
 # Security: This is expected to be overwritten either via environment variable or sub-configuration
 INFLUXDB_PASS = "root"  # nosec B105
+# If token is provided, user and password are ignored
+# Security: This is expected to be overwritten either via environment variable or sub-configuration
+INFLUXDB_TOKEN = ""  # nosec B105
 INFLUXDB_DB = "metobs"
diff --git a/metobsapi/data_api.py b/metobsapi/data_api.py
index 81e1e08..45c9c55 100644
--- a/metobsapi/data_api.py
+++ b/metobsapi/data_api.py
@@ -1,5 +1,6 @@
 import logging
 from datetime import datetime, timedelta
+from typing import Iterable
 
 # Security: Document is only used for creating an XML document, not parsing one
 from xml.dom.minidom import Document  # nosec B408
@@ -10,7 +11,7 @@ from flask import Response, render_template
 from flask_json import as_json_p
 
 from metobsapi.util import data_responses
-from metobsapi.util.query_influx import build_queries, query
+from metobsapi.util.query_influx import query
 
 LOG = logging.getLogger(__name__)
 
@@ -91,40 +92,72 @@ def handle_symbols(symbols):
     return ret
 
 
-def handle_influxdb_result(result, symbols, interval):
-    frames = []
-    for idx, (si, (req_syms, influx_symbs)) in enumerate(symbols.items()):
-        if isinstance(result, list):
-            # multiple query select statements results in a list
-            res = result[idx]
-        else:
-            # single query statement results in a single ResultSet
-            res = result
-
-        columns = ["time"] + influx_symbs
-        if not res:
+def handle_influxdb_result(data_frames: list[pd.DataFrame] | None, symbols: dict, interval: str) -> pd.DataFrame:
+    new_frames = []
+    for idx, (_, (req_syms, _)) in enumerate(symbols.items()):  # idx pairs each symbols entry with its frame
+        if data_frames is None:
+            # TODO: Combine this with the same logic in query_influx
+            # invalid query
+            columns = ["time"] + req_syms
             frame = pd.DataFrame(columns=columns)
-        else:
-            data_points = res.get_points("metobs_" + interval, tags={"site": si[0], "inst": si[1]})
-            frame = pd.DataFrame(data_points, columns=["time"] + influx_symbs)
-        frame.set_index("time", inplace=True)
-        frame.fillna(value=np.nan, inplace=True)
+            frame.set_index("time", inplace=True)
+            frame.fillna(value=np.nan, inplace=True)
+            new_frames.append(frame)
+            continue
+
+        frame = data_frames[idx]
         # remove wind components
-        if influx_symbs[-1] == "wind_north" and "wind_direction" in frame.columns:
-            frame["wind_direction"] = np.rad2deg(np.arctan2(frame["wind_east"], frame["wind_north"]))
-            frame["wind_direction"] = frame["wind_direction"].where(
-                frame["wind_direction"] > 0, frame["wind_direction"] + 360.0
-            )
-            frame = frame.iloc[:, :-2]
-            frame.columns = req_syms[:-2]
+        if _has_any_wind_vectors(frame.columns):
+            for weast_col_name, wnorth_col_name, wdir_col_name in _wind_vector_column_groups(frame.columns):
+                frame[wdir_col_name] = np.rad2deg(np.arctan2(frame[weast_col_name], frame[wnorth_col_name]))
+                frame[wdir_col_name] = frame[wdir_col_name].where(
+                    frame[wdir_col_name] > 0, frame[wdir_col_name] + 360.0
+                )  # non-positive angles are shifted up by 360 degrees
+                weast_col_idx = frame.columns.get_loc(weast_col_name)
+                wnorth_col_idx = frame.columns.get_loc(wnorth_col_name)
+                keep_requested_symbols = [  # assumes req_syms aligns positionally with frame columns - TODO confirm
+                    req_sym
+                    for sym_idx, req_sym in enumerate(req_syms)
+                    if sym_idx not in [weast_col_idx, wnorth_col_idx]
+                ]
+                frame = frame.drop(columns=[weast_col_name, wnorth_col_name])
+                frame.columns = keep_requested_symbols  # remaining columns take their requested names
         else:
             frame.columns = req_syms
         frame = frame.round({s: ROUNDING.get(s, 1) for s in frame.columns})
-        frames.append(frame)
-    frame = pd.concat(frames, axis=1, copy=False)
+        new_frames.append(frame)
+    frame = pd.concat(new_frames, axis=1, copy=False)  # join the per-entry frames column-wise
     return frame
 
 
+def _has_any_wind_vectors(columns: list[str]) -> bool:
+    """Check that east, north, and direction wind columns are all present."""
+    wind_components = ("wind_east", "wind_north", "wind_direction")
+    col_names = list(columns)
+    return all(any(comp in name for name in col_names) for comp in wind_components)
+
+
+def _wind_vector_column_groups(columns: list[str]) -> Iterable[tuple[str, str, str]]:
+    """Yield (wind_east, wind_north, wind_direction) column names for each complete group."""
+    for wind_north_column in _all_wind_north_columns(columns):
+        # assume column name is <site>.<inst>.wind_north; skip names without that prefix
+        name_parts = wind_north_column.split(".")
+        if len(name_parts) < 3:
+            continue
+        wind_east_column = ".".join(name_parts[:2] + ["wind_east"])
+        wind_direction_column = ".".join(name_parts[:2] + ["wind_direction"])
+        if wind_east_column in columns and wind_direction_column in columns:
+            yield wind_east_column, wind_north_column, wind_direction_column
+
+
+def _all_wind_north_columns(columns: list[str]) -> list[str]:
+    return [col_name for col_name in columns if "wind_north" in col_name]  # substring match on each name
+
+
+def _convert_wind_vectors_to_direction(frame: pd.DataFrame) -> pd.DataFrame:
+    pass  # TODO: unimplemented placeholder; returns None despite the annotated DataFrame return type
+
+
 def calc_num_records(begin, end, interval):
     now = datetime.utcnow()
     if begin is None:
@@ -304,11 +337,11 @@ def modify_data(fmt, begin, end, site, inst, symbols, interval, sep=",", order="
     except ValueError as e:
         return handle_error(fmt, str(e))
 
-    result, response_info = _query_time_series_db(begin, end, interval, influx_symbols, epoch)
-    frame = handle_influxdb_result(result, influx_symbols, interval)
-    frame = _reorder_and_rename_result_dataframe(frame, symbols, short_symbols)
+    data_frames, response_info = _query_time_series_db(begin, end, interval, influx_symbols, epoch)
+    data_frame = handle_influxdb_result(data_frames, influx_symbols, interval)
+    data_frame = _reorder_and_rename_result_dataframe(data_frame, symbols, short_symbols)
     handler = RESPONSE_HANDLERS[fmt]
-    return handler(frame, epoch, sep=sep, order=order, **response_info)
+    return handler(data_frame, epoch, sep=sep, order=order, **response_info)
 
 
 def _convert_begin_and_end(begin, end) -> tuple[datetime | timedelta, datetime | timedelta]:
@@ -356,8 +389,7 @@ def _query_time_series_db(begin, end, interval, influx_symbols, epoch):
         message = ""
         code = 200
         status = "success"
-        queries = build_queries(influx_symbols, begin, end, interval)
-        result = query(queries, epoch)
+        result = query(influx_symbols, begin, end, interval, epoch)
     response_info = {"message": message, "code": code, "status": status}
     return result, response_info
 
diff --git a/metobsapi/tests/test_data_api.py b/metobsapi/tests/test_data_api.py
index d80edee..9b1d19f 100644
--- a/metobsapi/tests/test_data_api.py
+++ b/metobsapi/tests/test_data_api.py
@@ -2,20 +2,21 @@ import contextlib
 import json
 import random
 from datetime import datetime, timedelta
-from typing import Callable, ContextManager, Iterable, Iterator, TypeAlias
+from typing import Callable, ContextManager, Iterable, Iterator
 from unittest import mock
 
+import numpy as np
 import pytest
 from influxdb.resultset import ResultSet
 
-QueryResult: TypeAlias = ResultSet | list[ResultSet]
+from metobsapi.util.query_influx import QueryResult
 
 
 @pytest.fixture
 def mock_influxdb_query() -> Callable[[QueryResult], ContextManager[mock.Mock]]:
     @contextlib.contextmanager
     def _mock_influxdb_query_with_fake_data(fake_result_data: QueryResult) -> Iterator[mock.Mock]:
-        with mock.patch("metobsapi.data_api.query") as query_func:
+        with mock.patch("metobsapi.util.query_influx._query_influxdbv1") as query_func:
             query_func.return_value = fake_result_data
             yield query_func
 
@@ -45,16 +46,16 @@ def influxdb_wind_fields_9_values(mock_influxdb_query) -> Iterable[None]:
         yield
 
 
-def _fake_data(interval, symbols, num_vals, single_result=False):
+def _fake_data(interval: str, symbols: dict, num_vals: int, single_result: bool = False) -> QueryResult:
     now = datetime(2017, 3, 5, 19, 0, 0)
     t_format = "%Y-%m-%dT%H:%M:%SZ"
     measurement_name = "metobs_" + interval
     series = []
     for (site, inst), columns in symbols.items():
         tags = {"site": site, "inst": inst}
-        vals = []
+        vals: list[list[str | float | None]] = []
         for i in range(num_vals):
-            vals.append([(now + timedelta(minutes=i)).strftime(t_format)] + [random.random()] * (len(columns) - 1))
+            vals.append([(now + timedelta(minutes=i)).strftime(t_format)] + _generate_fake_column_values(columns[1:]))
             # make up some Nones/nulls (but not all the time)
             r_idx = int(random.random() * len(columns) * 3)
             # don't include time (index 0)
@@ -88,6 +89,14 @@ def _fake_data(interval, symbols, num_vals, single_result=False):
         return series
 
 
+def _generate_fake_column_values(columns: list[str]) -> list[float | None]:
+    """Return one shared random value per column, with None for wind direction."""
+    shared_value = random.random()
+    # wind direction shouldn't exist in the database
+    null_idx = columns.index("wind_direction") if "wind_direction" in columns else -1
+    return [None if idx == null_idx else shared_value for idx, _ in enumerate(columns)]
+
+
 @pytest.fixture
 def app():
     from metobsapi import app as real_app
@@ -189,16 +198,25 @@ class TestDataAPI:
         assert len(res["results"]["data"]["air_temp"]) == 9
         assert len(res["results"]["timestamps"]) == 9
 
-    def test_wind_speed_direction_json(self, client, influxdb_wind_fields_9_values):
-        res = client.get(
-            "/api/data.json?symbols=aoss.tower.wind_speed:aoss.tower.wind_direction&begin=-00:10:00&order=column"
-        )
+    @pytest.mark.parametrize(
+        "symbols",
+        [
+            ["aoss.tower.wind_speed", "aoss.tower.wind_direction"],  # names carrying a site.inst prefix
+            ["wind_speed", "wind_direction"],  # bare names; site/inst are sent as separate parameters
+        ],
+    )
+    def test_wind_speed_direction_json(self, symbols, client, influxdb_wind_fields_9_values):
+        symbol_param = ":".join(symbols)  # symbols are colon-separated in the query string
+        site_inst_params = "&site=aoss&inst=tower" if "." not in symbols[0] else ""
+
+        res = client.get(f"/api/data.json?symbols={symbol_param}{site_inst_params}&begin=-00:10:00&order=column")
         res = json.loads(res.data.decode())
         assert res["code"] == 200
         assert res["num_results"] == 9
-        assert "aoss.tower.wind_direction" in res["results"]["data"]
-        assert "aoss.tower.wind_speed" in res["results"]["data"]
-        assert len(list(res["results"]["data"].keys())) == 2
+        for symbol_name in symbols:
+            assert symbol_name in res["results"]["data"]
+            assert not np.isnan(res["results"]["data"][symbol_name]).all()  # must not be entirely NaN
+        assert len(list(res["results"]["data"].keys())) == len(symbols)  # no extra symbols are returned
 
     def test_one_symbol_two_insts_json_row(self, client, mock_influxdb_query):
         fake_result = _fake_data(
diff --git a/metobsapi/util/query_influx.py b/metobsapi/util/query_influx.py
index 57f4054..f50b2ab 100644
--- a/metobsapi/util/query_influx.py
+++ b/metobsapi/util/query_influx.py
@@ -1,13 +1,31 @@
 """Helpers for querying an InfluxDB backend."""
 from datetime import timedelta
+from typing import TypeAlias
 
+import numpy as np
+import pandas as pd
 from flask import current_app
-from influxdb import InfluxDBClient
+
+try:
+    # version 1 library
+    from influxdb import InfluxDBClient as InfluxDBClientv1
+    from influxdb.resultset import ResultSet
+except ImportError:
+    InfluxDBClientv1 = None
+    ResultSet = None
+
+try:
+    # version 2 library
+    from influxdb_client import InfluxDBClient as InfluxDBClientv2
+except ImportError:
+    InfluxDBClientv2 = None
 
 QUERY_FORMAT = "SELECT {symbol_list} FROM metobs.forever.metobs_{interval} WHERE {where_clause} GROUP BY site,inst"
 
+QueryResult: TypeAlias = ResultSet | list[ResultSet]
 
-def parse_dt(d):
+
+def parse_dt_v1(d):
     if d is None:
         return "now()"
     elif isinstance(d, timedelta):
@@ -16,12 +34,9 @@ def parse_dt(d):
         return d.strftime("'%Y-%m-%dT%H:%M:%SZ'")
 
 
-def build_queries(symbols, begin, end, value):
-    if begin is None:
-        # TODO: This should use the "FIRST(symbol_name)" function
-        begin = timedelta(minutes=5)
-    begin = parse_dt(begin)
-    end = parse_dt(end)
+def build_queries_v1(symbols, begin, end, value):
+    begin = parse_dt_v1(begin)
+    end = parse_dt_v1(end)
 
     where_clause = []
     where_clause.append("time >= {}".format(begin))
@@ -42,12 +57,96 @@ def build_queries(symbols, begin, end, value):
     return queries
 
 
-def query(query_str, epoch):
-    client = InfluxDBClient(
-        current_app.config["INFLUXDB_HOST"],
-        current_app.config["INFLUXDB_PORT"],
-        current_app.config["INFLUXDB_USER"],
-        current_app.config["INFLUXDB_PASS"],
-        current_app.config["INFLUXDB_DB"],
+def _query_influxdbv1(query_str, epoch) -> QueryResult:
+    """Run ``query_str`` with the v1 client using the app's InfluxDB settings and return the raw result."""
+    kwargs = {
+        "host": current_app.config["INFLUXDB_HOST"],
+        "port": current_app.config["INFLUXDB_PORT"],
+        "username": current_app.config["INFLUXDB_USER"],
+        "password": current_app.config["INFLUXDB_PASS"],
+        "database": current_app.config["INFLUXDB_DB"],
+    }
+    if current_app.config["INFLUXDB_TOKEN"]:
+        # version 2.x, ignore user/pass, use token
+        kwargs.update(username=None, password=None)
+        kwargs["headers"] = {"Authorization": "Token {}".format(current_app.config["INFLUXDB_TOKEN"])}
+
+    client = InfluxDBClientv1(
+        **kwargs,
     )
     return client.query(query_str, epoch=epoch)
+
+
+def query(symbols, begin, end, interval, epoch):
+    """Query InfluxDB for the given symbols and return the result as DataFrames."""
+    handler = QueryHandler(symbols, begin, end, interval)
+    return handler.query(epoch)
+
+
+class QueryHandler:
+    """Helper class to build queries and submit the query for v1 or v2 of InfluxDB."""
+
+    def __init__(self, symbols, begin, end, interval):
+        """Store the query parameters.
+
+        ``symbols`` maps ``(site, inst)`` to a pair whose second item lists the columns
+        read from the query result. A ``begin`` of None becomes a 5 minute timedelta.
+        """
+        self._symbols = symbols
+        if begin is None:
+            begin = timedelta(minutes=5)
+        self._begin = begin
+        self._end = end
+        self._interval = interval
+
+    @property
+    def is_v2_db(self):
+        """Whether a token is configured; a token is taken to mean an InfluxDB v2 server."""
+        return bool(current_app.config["INFLUXDB_TOKEN"])
+
+    def query(self, epoch):
+        """Run the query and return a list with one DataFrame per ``(site, inst)``.
+
+        Raises ImportError when the needed client library is missing and
+        NotImplementedError if the (currently unreachable) v2 library path is hit.
+        """
+        if not self.is_v2_db and InfluxDBClientv1 is None:
+            raise ImportError("Missing 'influxdb' dependency to communicate with InfluxDB v1 database")
+        if InfluxDBClientv1 is None and InfluxDBClientv2 is None:
+            raise ImportError(
+                "Missing 'influxdb-client' and legacy 'influxdb' dependencies to communicate with InfluxDB database"
+            )
+
+        # TODO: "True or" forces the v1 library for every request until the v2
+        #   library code path below is implemented
+        if True or InfluxDBClientv2 is None:
+            # fallback to v1 library
+            query_str = build_queries_v1(self._symbols, self._begin, self._end, self._interval)
+            result = _query_influxdbv1(query_str, epoch)
+            return self._convert_v1_result_to_dataframe(result)
+
+        # version 2 library
+        raise NotImplementedError()
+
+    def _convert_v1_result_to_dataframe(self, result):
+        """Convert v1 query results into one DataFrame per ``(site, inst)``.
+
+        Each frame is indexed by "time" and its columns are renamed to
+        ``<site>.<inst>.<field>``. An empty result gives an empty frame.
+        """
+        frames = []
+        for idx, ((site, inst), (_, influx_symbs)) in enumerate(self._symbols.items()):
+            # multiple query select statements result in a list, a single one in a lone ResultSet
+            res = result[idx] if isinstance(result, list) else result
+            columns = ["time"] + influx_symbs
+            if not res:
+                frame = pd.DataFrame(columns=columns)
+            else:
+                data_points = res.get_points("metobs_" + self._interval, tags={"site": site, "inst": inst})
+                frame = pd.DataFrame(data_points, columns=columns)
+            frame.set_index("time", inplace=True)
+            frame.fillna(value=np.nan, inplace=True)
+            # rename channels to be site-based; "time" is the index now, so every column is a field
+            frame.columns = [f"{site}.{inst}.{col_name}" for col_name in frame.columns]
+            frames.append(frame)
+        return frames
-- 
GitLab