import os
import re
import sys
from datetime import datetime, timedelta

import rrdtool
import numpy as np

from metobs.data import wind_vector_degrees, to_unix_timestamp
from aosstower import meta


class ModelError(Exception):
    """Base class for model errors.
    """


class WrapErrors(object):
    """Class wrapper to catch exceptions and properly re-raise them such that
    the only exceptions to propagate are `ModelError`s. Essentially, this
    prevents anyone from having to import rrdtool lib.
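
    Usage sketch (mirrors the decoration of ``RrdModel`` below):

        @WrapErrors(rrdtool.error)
        class SomeModel(object):
            def query(self):
                ...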
    """

    def __init__(self, *exceptions):
        self.exceptions = exceptions

    def __call__(self, cls):
        def _wrap(fcn):
            def wrapped(*args, **kwargs):
                try:
                    return fcn(*args, **kwargs)
                except self.exceptions as err:
                    traceback = sys.exc_info()[2]
                    raise ModelError(str(err)).with_traceback(traceback)
            wrapped.__doc__ = fcn.__doc__
            return wrapped
        for name in dir(cls):
            value = getattr(cls, name)
            if not name.startswith('_') and callable(value):
                setattr(cls, name, _wrap(value))

        return cls


def initialize(filepath, start=None, days=365, data_interval=5):
    """Create a new empty RRD database.
    """
    assert not os.path.exists(filepath), "DB already exists"
    start = start or (datetime.utcnow() - timedelta(days=days))
    # normalize start to data interval
    secs = to_unix_timestamp(start)
    secs -= secs % data_interval

    rrdtool.create(filepath,
                   '--start={}'.format(secs),
                   '--step={:d}'.format(data_interval),
                   'DS:air_temp:GAUGE:10:-40:50',
                   'DS:rh:GAUGE:10:0:100',
                   'DS:dewpoint:GAUGE:10:0:100',
                   'DS:wind_speed:GAUGE:10:0:100',
                   'DS:winddir_north:GAUGE:10:-100:100',
                   'DS:winddir_east:GAUGE:10:-100:100',
                   'DS:pressure:GAUGE:10:0:1100',
                   'DS:precip:GAUGE:10:0:100',
                   'DS:accum_precip:GAUGE:10:0:100',
                   'DS:solar_flux:GAUGE:10:0:1000',
                   'DS:altimeter:GAUGE:10:0:100',
                   # native resolution
                   'RRA:AVERAGE:0.5:1:6307200',
                   # 1 minute
                   'RRA:AVERAGE:0.5:{:d}:525600'.format(60 // data_interval),
                   # 5 minute
                   'RRA:AVERAGE:0.5:{:d}:105120'.format(300 // data_interval),
                   # 30 minute
                   'RRA:AVERAGE:0.5:{:d}:17520'.format(1800 // data_interval))


@WrapErrors(rrdtool.error)
class RrdModel(object):
    """Model for storing the Level0 uncalibrated data for non-scientific
    purposes, such as web-widgets.
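
    Usage sketch ('tower.rrd' is an illustrative path created via
    ``initialize``):

        model = RrdModel('tower.rrd')
        print(model.datasets)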
    """

    def __init__(self, filepath):
        self._filepath = filepath
        self._averages = tuple()
        self._datasets = None

    @property
    def datasets(self):
        """Get dataset names available in the database.
        """
        if self._datasets is None:
            datasets = set()
            info = rrdtool.info(self._filepath)
            for key in info.keys():
                match = re.match(r'^ds\[(.*)\]', key)
                if not match:
                    continue
                datasets.add(match.groups()[0])
            self._datasets = tuple(sorted(datasets))
        return self._datasets

    def averaging_intervals(self):
        """Lazy load averaging intervals from database.
        """
        if not self._averages:
            averages = set()
            info = rrdtool.info(self._filepath)
            for key in info.keys():
                if key.startswith('rra') and key.endswith('pdp_per_row'):
                    averages.add(int(info[key] * info['step']))
            self._averages = tuple(sorted(averages))
        return self._averages

    def _format_data(self, stamp, data):
        """Format data for insert into RRD returning a template string and data
        line appropriate for arguments to rrdupdate.
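
        For example, ``{'air_temp': 21.3, 'rh': 45.0}`` yields roughly
        ``('air_temp:rh', '<utime>@21.3:45.0')``, with keys in sorted order.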
        """
        validkeys = sorted(set(self.datasets).intersection(data.keys()))
        if not validkeys:
            raise ModelError("No valid data keys provided", data)
        tmpl = ':'.join(validkeys)
        values = ':'.join([str(data[k]) for k in validkeys])
        values = '{:d}@{}'.format(to_unix_timestamp(stamp), values)
        return tmpl, values

    def add_record(self, stamp, record):
        """Add a single record to the database, where a record is a dict like
        object with keys for each dataset. Additional keys are ignored.
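
        Usage sketch (the stamp is normalized to the database's native
        interval before insert):

            model.add_record(datetime.utcnow(), {'air_temp': 21.3, 'rh': 45.0})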
        """
        # Normalize to data interval
        utime = to_unix_timestamp(stamp)
        data_interval = min(self.averaging_intervals())
        stamp = datetime.utcfromtimestamp(utime - utime % data_interval)

        tmpl, data = self._format_data(stamp, dict(record))
        rrdtool.update(self._filepath, '--template=%s' % tmpl, data)

    def get_slice(self, start, end, names=None, average=5):
        """Get a slice of data from the database.

        :param start: Start time as datetime
        :param end: Inclusive end time as datetime
        :param names: Names to query for, defaults to all available, see ``datasets``
        :param average: Averaging interval in seconds; must be one of ``averaging_intervals``
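
        Returns a 2D array whose first column is unix timestamps and whose
        remaining columns follow ``names``. Usage sketch:

            end = datetime.utcnow()
            data = model.get_slice(end - timedelta(minutes=10), end, average=60)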
        """
        if average not in self.averaging_intervals():
            raise ValueError("Invalid average: {:d}".format(average))
        names = names or self.datasets[:]

        if isinstance(start, datetime):
            start = to_unix_timestamp(start)
        if isinstance(end, datetime):
            end = to_unix_timestamp(end)

        # normalize request times to averaging interval
        start -= start % average
        end -= end % average

        # we always fetch every column, no matter which names were requested
        (rng_start, rng_end, rng_step), columns, rawdata = rrdtool.fetch(
            self._filepath,
            'AVERAGE',
            '--resolution={:d}'.format(average),
            '--start={:d}'.format(start),
            '--end={:d}'.format(end))

        # rrdtool returns None for missing values; forcing a float dtype
        # converts them to NaN
        src_data = np.array(rawdata, dtype=np.float64)
        # NaN-filled matrix big enough for the requested names
        dst_data = np.full((src_data.shape[0], len(names)), np.nan)

        # get only the columns we're interested in
        for dst_idx, name in enumerate(names):
            if name in columns:
                dst_data[:, dst_idx] = src_data[:, columns.index(name)]

            # recompose the wind direction if requested
            elif name == 'wind_dir':
                east = src_data[:, columns.index('winddir_east')]
                north = src_data[:, columns.index('winddir_north')]
                dst_data[:, dst_idx] = wind_vector_degrees(east, north)

        # column of timestamps aligned with the rows actually returned;
        # rrdtool rows nominally run from rng_start + rng_step through rng_end
        times = rng_start + rng_step * np.arange(1, src_data.shape[0] + 1)
        return np.concatenate((times.reshape(-1, 1), dst_data), axis=1)
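

if __name__ == '__main__':
    # Minimal smoke-test sketch; 'tower.rrd' is an illustrative path and is
    # only initialized when it does not already exist.
    demo_path = 'tower.rrd'
    if not os.path.exists(demo_path):
        initialize(demo_path, data_interval=5)
    demo_model = RrdModel(demo_path)
    demo_model.add_record(datetime.utcnow(), {'air_temp': 21.3, 'rh': 45.0})
    print(demo_model.datasets)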