From 8cc7dae290b527edf4b707bdccbc9dc71407e4ac Mon Sep 17 00:00:00 2001
From: David Hoese <david.hoese@ssec.wisc.edu>
Date: Mon, 27 Feb 2023 15:17:22 -0600
Subject: [PATCH] Update influx queries to return a single large DataFrame

Instead of building one DataFrame per site/instrument pair and carrying
parallel lists of requested/influx symbol names, query results are now
concatenated into a single DataFrame whose columns are named
"<site>.<inst>.<influx_name>". handle_symbols() now also returns a
mapping from those influx column names back to the symbols the user
requested, wind east/north components are converted to wind_direction
by a dedicated helper, and single-statement InfluxDB v1 results are
normalized to a list so downstream code only handles one shape.

---
 metobsapi/data_api.py          | 101 ++++++++++++++-------------------
 metobsapi/util/query_influx.py |  56 +++++++-----------
 2 files changed, 63 insertions(+), 94 deletions(-)

diff --git a/metobsapi/data_api.py b/metobsapi/data_api.py
index 45c9c55..71c4296 100644
--- a/metobsapi/data_api.py
+++ b/metobsapi/data_api.py
@@ -16,6 +16,8 @@ from metobsapi.util.query_influx import query
 LOG = logging.getLogger(__name__)
 
 
+# NOTE: These use the influxdb column name (default API name)
+#     In most cases these names are equivalent
 ROUNDING = {
     "aoss.tower.rel_hum": 0,
     "aoss.tower.wind_direction": 0,
@@ -54,7 +56,8 @@ def handle_time_string(date_string):
 
 
 def handle_symbols(symbols):
-    ret = {}
+    influx_symbols: dict[tuple[str, str], list[str]] = {}
+    influx_to_requested_symbols = {}
     handled_symbols = set()
 
     add_winds = set()
@@ -69,72 +72,56 @@ def handle_symbols(symbols):
         except ValueError:
             raise ValueError("Symbols must have 3 period-separated parts: {}".format(symbol))
 
-        trans = data_responses.SYMBOL_TRANSLATIONS.get((site, inst))
-        if not trans:
+        api_to_influx_rename_map = data_responses.SYMBOL_TRANSLATIONS.get((site, inst))
+        if not api_to_influx_rename_map:
             raise ValueError("Unknown site/instrument: {},{}".format(site, inst))
-        influx_name = trans.get(s)
+        influx_name = api_to_influx_rename_map.get(s)
         if s == "wind_direction":
             add_winds.add((site, inst))
         elif influx_name is None:
             raise ValueError("Unknown symbol: {}".format(symbol))
+        influx_name = influx_name or s
 
         # NOTE: if wind speed/dir is specified InfluxDB should provide it with
-        #       all fill values so we can fill it in later
-        ret.setdefault(si, ([], []))
-        ret[si][0].append(symbol)
-        ret[si][1].append(influx_name or s)
+        #       all fill values, so we can fill it in later
+        influx_to_requested_symbols[f"{si[0]}.{si[1]}.{influx_name}"] = symbol
+        influx_symbols.setdefault(si, []).append(influx_name)
 
     # Add the symbols needed to compute the wind_speed and wind_direction
     for si in add_winds:
-        ret[si][0].extend((None, None))
-        ret[si][1].extend(("wind_east", "wind_north"))
-
-    return ret
-
-
-def handle_influxdb_result(data_frames, symbols, interval):
-    new_frames = []
-    for idx, (_, (req_syms, _)) in enumerate(symbols.items()):
-        if data_frames is None:
-            # TODO: Combine this with the same logic in query_influx
-            # invalid query
-            columns = ["time"] + req_syms
-            frame = pd.DataFrame(columns=columns)
-            frame.set_index("time", inplace=True)
-            frame.fillna(value=np.nan, inplace=True)
-            new_frames.append(frame)
-            continue
+        influx_to_requested_symbols[f"{si[0]}.{si[1]}.wind_east"] = None
+        influx_to_requested_symbols[f"{si[0]}.{si[1]}.wind_north"] = None
+        influx_symbols[si].extend(("wind_east", "wind_north"))
 
-        frame = data_frames[idx]
-        # remove wind components
-        if _has_any_wind_vectors(frame.columns):
-            for weast_col_name, wnorth_col_name, wdir_col_name in _wind_vector_column_groups(frame.columns):
-                frame[wdir_col_name] = np.rad2deg(np.arctan2(frame[weast_col_name], frame[wnorth_col_name]))
-                frame[wdir_col_name] = frame[wdir_col_name].where(
-                    frame[wdir_col_name] > 0, frame[wdir_col_name] + 360.0
-                )
-                weast_col_idx = frame.columns.get_loc(weast_col_name)
-                wnorth_col_idx = frame.columns.get_loc(wnorth_col_name)
-                keep_requested_symbols = [
-                    req_sym
-                    for sym_idx, req_sym in enumerate(req_syms)
-                    if sym_idx not in [weast_col_idx, wnorth_col_idx]
-                ]
-                frame = frame.drop(columns=[weast_col_name, wnorth_col_name])
-                frame.columns = keep_requested_symbols
-        else:
-            frame.columns = req_syms
-        frame = frame.round({s: ROUNDING.get(s, 1) for s in frame.columns})
-        new_frames.append(frame)
-    frame = pd.concat(new_frames, axis=1, copy=False)
-    return frame
+    return influx_to_requested_symbols, influx_symbols
+
+
+def handle_influxdb_result(data_frame, influx_to_requested_symbols, interval):
+    valid_requested_symbols = [symbol for symbol in influx_to_requested_symbols.values() if symbol is not None]
+    if data_frame is None:
+        # invalid query
+        columns = ["time"] + valid_requested_symbols
+        data_frame = pd.DataFrame(columns=columns)
+        data_frame.set_index("time", inplace=True)
+        data_frame.fillna(value=np.nan, inplace=True)
+        return data_frame
+
+    data_frame = _convert_wind_vectors_to_direction_if_needed(data_frame)
+    data_frame = data_frame.round({s: ROUNDING.get(s, 1) for s in data_frame.columns})
+    # rename columns to API names requested by the user
+    # could be "<site>.<inst>.name" or just "name"
+    data_frame.columns = valid_requested_symbols
+    return data_frame
 
 
-def _has_any_wind_vectors(columns: list[str]) -> bool:
-    has_wind_east = any("wind_east" in col_name for col_name in columns)
-    has_wind_north = any("wind_north" in col_name for col_name in columns)
-    has_wind_direction = any("wind_direction" in col_name for col_name in columns)
-    return has_wind_east and has_wind_north and has_wind_direction
+def _convert_wind_vectors_to_direction_if_needed(data_frame: pd.DataFrame) -> pd.DataFrame:
+    for weast_col_name, wnorth_col_name, wdir_col_name in _wind_vector_column_groups(data_frame.columns):
+        data_frame[wdir_col_name] = np.rad2deg(np.arctan2(data_frame[weast_col_name], data_frame[wnorth_col_name]))
+        data_frame[wdir_col_name] = data_frame[wdir_col_name].where(
+            data_frame[wdir_col_name] > 0, data_frame[wdir_col_name] + 360.0
+        )
+        data_frame = data_frame.drop(columns=[weast_col_name, wnorth_col_name])
+    return data_frame
 
 
 def _wind_vector_column_groups(columns: list[str]) -> Iterable[tuple[str, str, str]]:
@@ -333,12 +320,12 @@ def modify_data(fmt, begin, end, site, inst, symbols, interval, sep=",", order="
         begin, end = _convert_begin_and_end(begin, end)
         _check_query_parameters(order, epoch, symbols, interval)
         short_symbols, symbols = _parse_symbol_names(site, inst, symbols)
-        influx_symbols = handle_symbols(symbols)
+        influx_to_requested_symbols, influx_symbols = handle_symbols(symbols)
     except ValueError as e:
         return handle_error(fmt, str(e))
 
-    data_frames, response_info = _query_time_series_db(begin, end, interval, influx_symbols, epoch)
-    data_frame = handle_influxdb_result(data_frames, influx_symbols, interval)
+    data_frame, response_info = _query_time_series_db(begin, end, interval, influx_symbols, epoch)
+    data_frame = handle_influxdb_result(data_frame, influx_to_requested_symbols, interval)
     data_frame = _reorder_and_rename_result_dataframe(data_frame, symbols, short_symbols)
     handler = RESPONSE_HANDLERS[fmt]
     return handler(data_frame, epoch, sep=sep, order=order, **response_info)
diff --git a/metobsapi/util/query_influx.py b/metobsapi/util/query_influx.py
index f50b2ab..f1a3d69 100644
--- a/metobsapi/util/query_influx.py
+++ b/metobsapi/util/query_influx.py
@@ -22,7 +22,7 @@ except ImportError:
 
 QUERY_FORMAT = "SELECT {symbol_list} FROM metobs.forever.metobs_{interval} WHERE {where_clause} GROUP BY site,inst"
 
-QueryResult: TypeAlias = ResultSet | list[ResultSet]
+QueryResult: TypeAlias = list[ResultSet]
 
 
 def parse_dt_v1(d):
@@ -43,12 +43,12 @@ def build_queries_v1(symbols, begin, end, value):
     where_clause.append("time <= {}".format(end))
 
     queries = []
-    for si, s_list in symbols.items():
+    for (site, inst), symbol_names in symbols.items():
         wc = where_clause.copy()
-        wc.append("site='{}'".format(si[0]))
-        wc.append("inst='{}'".format(si[1]))
+        wc.append("site='{}'".format(site))
+        wc.append("inst='{}'".format(inst))
         query = QUERY_FORMAT.format(
-            symbol_list=", ".join(s_list[1]),
+            symbol_list=", ".join(symbol_names),
             interval=value,
             where_clause=" AND ".join(wc),
         )
@@ -74,7 +74,10 @@ def _query_influxdbv1(query_str, epoch) -> QueryResult:
     client = InfluxDBClientv1(
         **kwargs,
     )
-    return client.query(query_str, epoch=epoch)
+    res = client.query(query_str, epoch=epoch)
+    if not isinstance(res, list):
+        res = [res]
+    return res
 
 
 def query(symbols, begin, end, interval, epoch):
@@ -115,38 +118,17 @@ class QueryHandler:
         # version 2 library
         raise NotImplementedError()
 
-    def _convert_v1_result_to_dataframe(self, result):
+    def _convert_v1_result_to_dataframe(self, result: QueryResult) -> pd.DataFrame:
         frames = []
-        for idx, ((site, inst), (_, influx_symbs)) in enumerate(self._symbols.items()):
-            if isinstance(result, list):
-                # multiple query select statements results in a list
-                res = result[idx]
-            else:
-                # single query statement results in a single ResultSet
-                res = result
-
-            columns = ["time"] + influx_symbs
-            if not res:
-                frame = pd.DataFrame(columns=columns)
-            else:
-                data_points = res.get_points("metobs_" + self._interval, tags={"site": site, "inst": inst})
-                frame = pd.DataFrame(data_points, columns=["time"] + influx_symbs)
+        for idx, ((site, inst), influx_symbols) in enumerate(self._symbols.items()):
+            # we expect one result for each site/inst combination
+            res = result[idx]
+            data_points = res.get_points("metobs_" + self._interval, tags={"site": site, "inst": inst})
+            frame = pd.DataFrame(data_points, columns=["time"] + influx_symbols)
+            # rename channels to be site-based (time will be the index)
+            frame.columns = ["time"] + [f"{site}.{inst}.{col_name}" for col_name in frame.columns[1:]]
             frame.set_index("time", inplace=True)
             frame.fillna(value=np.nan, inplace=True)
-            # rename channels to be site-based
-            frame.columns = ["time"] + [f"{site}.{inst}.{col_name}" for col_name in frame.columns[1:]]
             frames.append(frame)
-            # # remove wind components
-            # if influx_symbs[-1] == "wind_north" and "wind_direction" in frame.columns:
-            #     frame["wind_direction"] = np.rad2deg(np.arctan2(frame["wind_east"], frame["wind_north"]))
-            #     frame["wind_direction"] = frame["wind_direction"].where(
-            #         frame["wind_direction"] > 0, frame["wind_direction"] + 360.0
-            #     )
-            #     frame = frame.iloc[:, :-2]
-            #     frame.columns = req_syms[:-2]
-            # else:
-            #     frame.columns = req_syms
-            # frame = frame.round({s: ROUNDING.get(s, 1) for s in frame.columns})
-            # frames.append(frame)
-        # frame = pd.concat(frames, axis=1, copy=False)
-        return frames
+        frame = pd.concat(frames, axis=1, copy=False)
+        return frame
-- 
GitLab