"""Helpers for querying an InfluxDB backend."""
from datetime import timedelta
from typing import TypeAlias

import numpy as np
import pandas as pd
from flask import current_app

try:
    # version 1 library
    from influxdb import InfluxDBClient as InfluxDBClientv1
    from influxdb.resultset import ResultSet
except ImportError:
    InfluxDBClientv1 = None
    ResultSet = None

try:
    # version 2 library
    from influxdb_client import InfluxDBClient as InfluxDBClientv2
except ImportError:
    InfluxDBClientv2 = None

# InfluxQL template for one SELECT statement per (site, inst). The FROM target
# is written as database.retention_policy.measurement (``metobs``.``forever``.
# ``metobs_<interval>``); ``where_clause`` is the " AND "-joined condition list
# built by build_queries_v1. Grouping by the ``site`` and ``inst`` tags lets the
# result be split back per (site, inst) via ResultSet.get_points(..., tags=...).
QUERY_FORMAT = "SELECT {symbol_list} FROM metobs.forever.metobs_{interval} WHERE {where_clause} GROUP BY site,inst"

# What the v1 client's query() hands back: a single ResultSet for one statement,
# or a list of ResultSets when several ";"-separated statements are sent.
# NOTE: if the ``influxdb`` package is missing, ResultSet is None above and this
# alias is only a placeholder.
QueryResult: TypeAlias = ResultSet | list[ResultSet]


def parse_dt_v1(d):
    """Convert a query time bound into an InfluxQL time expression.

    Args:
        d: ``None`` for the current time, a ``timedelta`` meaning "this long
            before now", or a date/datetime-like object with ``strftime``.
            Naive datetimes are assumed to already be in UTC; timezone-aware
            datetimes are converted to UTC before formatting.

    Returns:
        str: ``"now()"``, ``"now() - <N>s"`` (``N`` = whole seconds, fractional
        part truncated), or a single-quoted UTC timestamp such as
        ``"'2020-01-02T03:04:05Z'"``.
    """
    if d is None:
        return "now()"
    if isinstance(d, timedelta):
        return "now() - {:d}s".format(int(d.total_seconds()))
    # The literal below carries a 'Z' (UTC) suffix, so an aware value in any
    # other zone must be shifted to UTC first or it would name the wrong instant.
    # ``date`` objects have no tzinfo and are formatted as-is.
    if getattr(d, "tzinfo", None) is not None:
        offset = d.utcoffset()
        if offset is not None:
            d = (d - offset).replace(tzinfo=None)
    return d.strftime("'%Y-%m-%dT%H:%M:%SZ'")


def build_queries_v1(symbols, begin, end, value):
    """Build one InfluxQL SELECT per (site, inst) and join them with ``"; "``.

    Args:
        symbols: mapping whose keys are indexable as ``key[0]`` = site and
            ``key[1]`` = inst, and whose values are indexable as ``value[1]`` =
            the list of InfluxDB field names to select for that pair.
        begin: lower time bound, anything accepted by ``parse_dt_v1``.
        end: upper time bound, anything accepted by ``parse_dt_v1``.
        value: measurement suffix; statements read from ``metobs_<value>``.

    Returns:
        str: the statements in ``symbols`` iteration order, joined by ``"; "``
        (an empty string if ``symbols`` is empty).
    """
    time_bounds = [
        "time >= {}".format(parse_dt_v1(begin)),
        "time <= {}".format(parse_dt_v1(end)),
    ]

    def _quote_literal(tag_value):
        # Tag values are embedded as single-quoted InfluxQL string literals.
        # Escape backslashes and quotes so a value cannot close the literal
        # early and change the WHERE clause.
        escaped = str(tag_value).replace("\\", "\\\\").replace("'", "\\'")
        return "'{}'".format(escaped)

    statements = []
    for site_inst, symbol_info in symbols.items():
        conditions = time_bounds + [
            "site={}".format(_quote_literal(site_inst[0])),
            "inst={}".format(_quote_literal(site_inst[1])),
        ]
        statements.append(
            QUERY_FORMAT.format(
                symbol_list=", ".join(symbol_info[1]),
                interval=value,
                where_clause=" AND ".join(conditions),
            )
        )
    return "; ".join(statements)


def _query_influxdbv1(query_str, epoch) -> QueryResult:
    """Run ``query_str`` with the legacy ``influxdb`` (v1) client library.

    Connection settings come from the Flask app config: ``INFLUXDB_HOST``,
    ``INFLUXDB_PORT`` and ``INFLUXDB_DB``. When ``INFLUXDB_TOKEN`` is set
    (the InfluxDB 2.x case), username/password are cleared and the token is
    sent in an ``Authorization: Token <token>`` header instead.

    Args:
        query_str: InfluxQL text; may hold several ``;``-separated statements.
        epoch: forwarded unchanged as the ``epoch`` argument of the client's
            ``query`` call.

    Returns:
        The client's query result: one ``ResultSet`` for a single statement,
        or a list of them for multiple statements.
    """
    config = current_app.config
    kwargs = {
        "host": config["INFLUXDB_HOST"],
        "port": config["INFLUXDB_PORT"],
        "database": config["INFLUXDB_DB"],
    }
    token = config["INFLUXDB_TOKEN"]
    if token:
        # version 2.x, ignore user/pass, use token
        kwargs["username"] = None
        kwargs["password"] = None
        kwargs["headers"] = {
            "Authorization": "Token {}".format(token),
        }

    client = InfluxDBClientv1(**kwargs)
    try:
        return client.query(query_str, epoch=epoch)
    finally:
        # A fresh client is built for every call; close it so its HTTP
        # session is released instead of leaking one per query.
        client.close()


def query(symbols, begin, end, interval, epoch):
    """Query InfluxDB for ``symbols`` between ``begin`` and ``end``.

    Convenience wrapper that builds a ``QueryHandler`` from the arguments and
    returns whatever its ``query(epoch)`` method returns.
    """
    return QueryHandler(symbols, begin, end, interval).query(epoch)


class QueryHandler:
    """Build InfluxQL queries for a set of symbols and submit them to InfluxDB.

    Queries are currently always sent through the legacy ``influxdb`` (v1)
    client library. With ``INFLUXDB_TOKEN`` configured, that library
    authenticates with a token (InfluxDB 2.x); otherwise it uses the plain v1
    connection. A native ``influxdb-client`` (v2) code path is not implemented.
    """

    def __init__(self, symbols, begin, end, interval):
        """Store the query parameters.

        Args:
            symbols: mapping of ``(site, inst)`` keys to 2-item values whose
                second item is the list of InfluxDB field names to select.
            begin: start bound accepted by ``parse_dt_v1``; ``None`` becomes
                ``timedelta(minutes=5)``, i.e. five minutes before now.
            end: end bound accepted by ``parse_dt_v1``; ``None`` means now.
            interval: measurement suffix, read from ``metobs_<interval>``.
        """
        self._symbols = symbols
        self._begin = timedelta(minutes=5) if begin is None else begin
        self._end = end
        self._interval = interval

    @property
    def is_v2_db(self):
        """True when an ``INFLUXDB_TOKEN`` is configured (token auth, InfluxDB 2.x)."""
        return bool(current_app.config["INFLUXDB_TOKEN"])

    def query(self, epoch):
        """Run the queries and return one DataFrame per ``(site, inst)``.

        Args:
            epoch: forwarded to the v1 client's ``query`` call.

        Returns:
            list of DataFrames, as produced by ``_convert_v1_result_to_dataframe``.

        Raises:
            ImportError: the client library needed for the configured server
                is not installed.
            NotImplementedError: only ``influxdb-client`` (v2) is installed;
                querying through that library is not implemented yet.
        """
        if not self.is_v2_db and InfluxDBClientv1 is None:
            raise ImportError("Missing 'influxdb' dependency to communicate with InfluxDB v1 database")
        if InfluxDBClientv1 is None and InfluxDBClientv2 is None:
            raise ImportError(
                "Missing 'influxdb-client' and legacy 'influxdb' dependencies to communicate " "with InfluxDB database"
            )
        if InfluxDBClientv1 is None:
            # Only the v2 library is available. The v1 path below would call
            # None(...) and fail with an opaque TypeError, so report clearly.
            raise NotImplementedError(
                "Querying through the 'influxdb-client' (v2) library is not implemented; "
                "install the legacy 'influxdb' package"
            )

        # The v1 library serves both server versions (token auth for 2.x).
        query_str = build_queries_v1(self._symbols, self._begin, self._end, self._interval)
        result = _query_influxdbv1(query_str, epoch)
        return self._convert_v1_result_to_dataframe(result)

    def _convert_v1_result_to_dataframe(self, result):
        """Split a v1 query result into one DataFrame per ``(site, inst)``.

        Args:
            result: a single ``ResultSet`` (one SELECT statement) or a list of
                them, one per statement, in ``self._symbols`` order.

        Returns:
            list of DataFrames in ``self._symbols`` order. Each is indexed by
            ``time`` and has one column per requested field, renamed to
            ``<site>.<inst>.<field>``; missing values are filled with NaN. A pair
            with no data yields an empty frame with the same columns.
        """
        frames = []
        for idx, ((site, inst), (_, influx_symbs)) in enumerate(self._symbols.items()):
            # Multiple SELECT statements come back as a list of results;
            # a single statement comes back as one bare ResultSet.
            res = result[idx] if isinstance(result, list) else result

            columns = ["time"] + influx_symbs
            if not res:
                frame = pd.DataFrame(columns=columns)
            else:
                data_points = res.get_points("metobs_" + self._interval, tags={"site": site, "inst": inst})
                frame = pd.DataFrame(data_points, columns=columns)
            frame.set_index("time", inplace=True)
            frame.fillna(value=np.nan, inplace=True)
            # "time" is now the index, so every remaining column is a data field:
            # prefix all of them rather than reserving a slot for "time".
            frame.columns = [f"{site}.{inst}.{col_name}" for col_name in frame.columns]
            frames.append(frame)
        # TODO: derive wind_direction from wind_east/wind_north
        # (np.rad2deg(np.arctan2(east, north)), shifted into (0, 360]) and drop
        # the two components, apply per-symbol rounding, and optionally merge the
        # frames with pd.concat(frames, axis=1).
        return frames