import logging
from datetime import datetime, timedelta

# Security: Document is only used for creating an XML document, not parsing one
from xml.dom.minidom import Document  # nosec B408

import numpy as np
import pandas as pd
from flask import Response, render_template
from flask_json import as_json_p

from metobsapi.util import data_responses
from metobsapi.util.query_influx import build_queries, query

LOG = logging.getLogger(__name__)

# number of decimal places to keep per symbol (anything not listed keeps 1)
ROUNDING = {
    "aoss.tower.rel_hum": 0,
    "aoss.tower.wind_direction": 0,
    "aoss.tower.accum_precip": 2,
    "mendota.buoy.rel_hum": 0,
    "mendota.buoy.wind_direction": 0,
    "mendota.buoy.chlorophyll_ysi": 2,
    "mendota.buoy.phycocyanin_ysi": 2,
}


def round_value(value, symbol):
    return np.round(value, ROUNDING.get(symbol, 1))


def handle_date(date):
    """Parse an ISO 8601 date or datetime string."""
    try:
        if len(date) == 10:
            # date only
            return datetime.strptime(date, "%Y-%m-%d")
        return datetime.strptime(date, "%Y-%m-%dT%H:%M:%S")
    except ValueError:
        LOG.warning("Malformed date string '%s'", date)
        raise ValueError("Malformed date string '%s'" % (date,))


def handle_time_string(date_string):
    """Convert a time query parameter to a datetime or, for strings of the
    form '-HH:MM:SS', a timedelta relative to now."""
    if date_string.startswith("-"):
        times = [float(x) for x in date_string[1:].split(":")]
        return timedelta(hours=times[0], minutes=times[1], seconds=times[2])
    return handle_date(date_string)


def handle_symbols(symbols):
    """Group fully-qualified symbols by (site, inst) and map each one to its
    InfluxDB field name."""
    ret = {}
    handled_symbols = set()
    add_winds = set()
    for symbol in symbols:
        if symbol in handled_symbols:
            raise ValueError("Symbol listed multiple times: {}".format(symbol))
        handled_symbols.add(symbol)

        try:
            site, inst, s = symbol.split(".")
            si = (site, inst)
        except ValueError:
            raise ValueError("Symbols must have 3 period-separated parts: {}".format(symbol))

        trans = data_responses.SYMBOL_TRANSLATIONS.get((site, inst))
        if not trans:
            raise ValueError("Unknown site/instrument: {},{}".format(site, inst))
        influx_name = trans.get(s)
        if s == "wind_direction":
            add_winds.add((site, inst))
        elif influx_name is None:
            raise ValueError("Unknown symbol: {}".format(symbol))

        # NOTE: if wind speed/dir is specified InfluxDB should provide it with
        # all fill values so we can fill it in later
        ret.setdefault(si, ([], []))
        ret[si][0].append(symbol)
        ret[si][1].append(influx_name or s)

    # Add the symbols needed to compute the wind_speed and wind_direction
    for si in add_winds:
        ret[si][0].extend((None, None))
        ret[si][1].extend(("wind_east", "wind_north"))

    return ret


def handle_influxdb_result(result, symbols, interval):
    """Convert InfluxDB query results to a single pandas DataFrame indexed by time."""
    frames = []
    for idx, (si, (req_syms, influx_symbs)) in enumerate(symbols.items()):
        if isinstance(result, list):
            # multiple query select statements results in a list
            res = result[idx]
        else:
            # single query statement results in a single ResultSet
            res = result
        columns = ["time"] + influx_symbs
        if not res:
            frame = pd.DataFrame(columns=columns)
        else:
            data_points = res.get_points("metobs_" + interval, tags={"site": si[0], "inst": si[1]})
            frame = pd.DataFrame(data_points, columns=columns)
        frame.set_index("time", inplace=True)
        frame.fillna(value=np.nan, inplace=True)
        # compute wind direction from the vector components, then drop them
        if influx_symbs[-1] == "wind_north" and "wind_direction" in frame.columns:
            frame["wind_direction"] = np.rad2deg(np.arctan2(frame["wind_east"], frame["wind_north"]))
            frame["wind_direction"] = frame["wind_direction"].where(
                frame["wind_direction"] > 0, frame["wind_direction"] + 360.0
            )
            frame = frame.iloc[:, :-2]
            frame.columns = req_syms[:-2]
        else:
            frame.columns = req_syms
        frame = frame.round({s: ROUNDING.get(s, 1) for s in frame.columns})
        frames.append(frame)
    return pd.concat(frames, axis=1, copy=False)
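
# Illustrative examples of the helpers above (hypothetical values; assumes
# SYMBOL_TRANSLATIONS maps "rel_hum" straight through for the AOSS tower):
#
#   >>> handle_time_string("-01:30:00")  # relative time: 1.5 hours before now
#   datetime.timedelta(seconds=5400)
#   >>> handle_time_string("2023-01-02")  # date-only ISO 8601 string
#   datetime.datetime(2023, 1, 2, 0, 0)
#   >>> handle_symbols(["aoss.tower.rel_hum", "aoss.tower.wind_direction"])
#   {('aoss', 'tower'): (['aoss.tower.rel_hum', 'aoss.tower.wind_direction', None, None],
#                        ['rel_hum', 'wind_direction', 'wind_east', 'wind_north'])}
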
def calc_num_records(begin, end, interval):
    """Estimate how many records a query will return for the given time range."""
    now = datetime.utcnow()
    if begin is None:
        begin = now
    elif isinstance(begin, timedelta):
        begin = now - begin
    if end is None:
        end = now
    elif isinstance(end, timedelta):
        end = now - end
    diff = (end - begin).total_seconds()
    return diff / data_responses.INTERVALS[interval]


def calc_file_size(num_records, num_streams):
    """Estimate the number of bytes returned for a text-based format."""
    # estimate about 7 bytes (overhead + data characters) per data point
    return num_records * num_streams * 7.0


def handle_csv(frame, epoch, sep=",", message="", code=200, status="success", **kwargs):
    output = """# status: {status}
# code: {code:d}
# message: {message}
# num_results: {num_results:d}
# fields: {epoch_str}{sep}{symbol_list}
{symbol_data}
"""
    data_lines = []
    line_format = sep.join(["{time}", "{symbols}"])
    if frame is not None and not frame.empty:
        for t, row in frame.iterrows():
            line = line_format.format(time=t, symbols=sep.join(str(x) for x in row.values))
            data_lines.append(line)

    if not epoch:
        epoch_str = "%Y-%m-%dT%H:%M:%SZ"
    else:
        epoch_str = data_responses.epoch_translation[epoch] + " since epoch (1970-01-01 00:00:00)"

    output = output.format(
        sep=sep,
        status=status,
        code=code,
        message=message,
        num_results=frame.shape[0] if frame is not None else 0,
        epoch_str=epoch_str,
        symbol_list=sep.join(frame.columns) if frame is not None else "",
        symbol_data="\n".join(data_lines),
    )

    res = Response(output, mimetype="text/csv")
    res.headers.set("Content-Disposition", "attachment", filename="data.csv")
    return res, code


@as_json_p(optional=True)
def handle_json(frame, epoch, order="row", message="", code=200, status="success", **kwargs):
    package = {}
    if frame is not None and not frame.empty:
        # force conversion to float types so they can be json'd
        for column, data_type in zip(frame.columns, frame.dtypes.values):
            if issubclass(data_type.type, np.integer):
                frame[column] = frame[column].astype(float)

        # replace NaNs with None
        frame = frame.where(pd.notnull(frame), None)
        package["timestamps"] = frame.index.values
        if epoch:
            package["timestamps"] = [float(stamp) for stamp in package["timestamps"]]
        if order == "column":
            package["data"] = dict(frame)
        else:
            package["symbols"] = frame.columns
            package["data"] = [frame.iloc[i].values for i in range(frame.shape[0])]
    else:
        package["timestamps"] = []
        if order == "column":
            package["data"] = {}
        else:
            package["data"] = []
            package["symbols"] = []

    output = {
        "status": status,
        "message": message,
        "code": code,
        "num_results": frame.shape[0] if frame is not None else 0,
        "results": package,
    }
    return output, code
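
# For reference, a hypothetical handle_json payload for a two-row, two-symbol
# frame in the default row order (timestamps become floats when an epoch is
# requested; symbol names and values are illustrative):
#
#   {"status": "success", "code": 200, "message": "", "num_results": 2,
#    "results": {"timestamps": ["2023-01-02T00:00:00Z", "2023-01-02T00:01:00Z"],
#                "symbols": ["aoss.tower.air_temp", "aoss.tower.rel_hum"],
#                "data": [[21.3, 45.0], [21.4, 44.0]]}}
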
def handle_xml(frame, epoch, sep=",", message="", code=200, status="success", **kwargs):
    doc = Document()
    head = doc.createElement("metobs")
    head.setAttribute("status", status)
    head.setAttribute("code", str(code))
    head.setAttribute("message", message)
    head.setAttribute("num_results", str(frame.shape[0]) if frame is not None else str(0))
    # NOTE: "seperator" (sic) is kept misspelled so existing clients keep working
    head.setAttribute("seperator", sep)
    data_elem = doc.createElement("data")
    doc.appendChild(head)

    columns_elem = doc.createElement("symbols")
    time_elem = doc.createElement("symbol")
    time_elem.setAttribute("name", "time")
    time_elem.setAttribute("short_name", "time")
    if not epoch:
        time_elem.setAttribute("format", "%Y-%m-%dT%H:%M:%SZ")
    else:
        time_elem.setAttribute(
            "format", data_responses.epoch_translation[epoch] + " since epoch (1970-01-01 00:00:00)"
        )
    columns_elem.appendChild(time_elem)
    if frame is not None and not frame.empty:
        for c in frame.columns:
            col_elem = doc.createElement("symbol")
            col_elem.setAttribute("name", c)
            parts = c.split(".")
            if len(parts) == 3:
                # fully-qualified site.inst.symbol name
                col_elem.setAttribute("short_name", parts[2])
                col_elem.setAttribute("site", parts[0])
                col_elem.setAttribute("inst", parts[1])
            else:
                # shorthand request: columns are already short symbol names
                col_elem.setAttribute("short_name", c)
            columns_elem.appendChild(col_elem)

        for idx, (t, row) in enumerate(frame.iterrows()):
            row_elem = doc.createElement("row")
            row_elem.setAttribute("id", str(idx))
            row_elem.appendChild(doc.createTextNode(str(t)))
            for point in row:
                row_elem.appendChild(doc.createTextNode(sep))
                row_elem.appendChild(doc.createTextNode(str(point)))
            data_elem.appendChild(row_elem)

    head.appendChild(columns_elem)
    head.appendChild(data_elem)
    txt = doc.toxml(encoding="utf-8")
    res = Response(txt, mimetype="text/xml")
    res.headers.set("Content-Disposition", "attachment", filename="data.xml")
    return res, code


def handle_error(fmt, error_str):
    """Render a known error as a (response, code) pair in the requested format."""
    handler = RESPONSE_HANDLERS[fmt]
    err_code, err_msg = data_responses.ERROR_MESSAGES.get(error_str, (400, error_str))
    return handler(None, None, message=err_msg, code=err_code, status="error")


RESPONSE_HANDLERS = {
    "csv": handle_csv,
    "xml": handle_xml,
    "json": handle_json,
}


def modify_data(fmt, begin, end, site, inst, symbols, interval, sep=",", order="row", epoch=None):
    """Validate the request, query InfluxDB, and format the result with the
    handler for fmt ('csv', 'json', or 'xml')."""
    if fmt not in RESPONSE_HANDLERS:
        return render_template("400.html", format=fmt), 400

    try:
        # these will be either datetime or timedelta objects
        begin = handle_time_string(begin) if begin else None
        end = handle_time_string(end) if end else None
    except (TypeError, ValueError):
        return handle_error(fmt, "malformed_timestamp")

    if order not in ("column", "row"):
        return handle_error(fmt, "bad_order")
    if epoch and epoch not in data_responses.epoch_translation:
        return handle_error(fmt, "bad_epoch")
    if not symbols:
        return handle_error(fmt, "missing_symbols")
    if not interval:
        interval = "1m"
    elif interval not in data_responses.INTERVALS:
        return handle_error(fmt, "bad_interval")

    if site and inst:
        # shorthand for symbols that all use the same site and inst
        short_symbols = symbols.split(":")
        symbols = ["{}.{}.{}".format(site, inst, s) for s in short_symbols]
    elif not site and not inst:
        # each symbol is fully qualified with site.inst.symbol
        short_symbols = None
        symbols = symbols.split(":")
    else:
        return handle_error(fmt, "missing_site_inst")

    try:
        influx_symbols = handle_symbols(symbols)
    except ValueError as e:
        return handle_error(fmt, str(e))

    if calc_num_records(begin, end, interval) > data_responses.RESPONSES_LIMIT:
        message = "Request will return too many values, please use files API"
        code = 413
        status = "fail"
        result = None
    else:
        message = ""
        code = 200
        status = "success"
        queries = build_queries(influx_symbols, begin, end, interval)
        result = query(queries, epoch)

    frame = handle_influxdb_result(result, influx_symbols, interval)
    # order the resulting symbols the way the user requested
    # assume time is the first column
    frame = frame[symbols]
    if site:
        frame.columns = short_symbols

    handler = RESPONSE_HANDLERS[fmt]
    return handler(frame, epoch, sep=sep, order=order, status=status, code=code, message=message)
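
# Example invocation (hypothetical values), roughly as a Flask view would call
# it; the CSV handler returns a (flask.Response, status_code) pair:
#
#   >>> response, code = modify_data(
#   ...     "csv",
#   ...     begin="-01:00:00",           # relative: the last hour
#   ...     end=None,                    # open end means "now"
#   ...     site="aoss", inst="tower",   # shorthand site/instrument
#   ...     symbols="air_temp:rel_hum",  # colon-separated short symbols
#   ...     interval="1m",
#   ... )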