import os
import re
import sys
from datetime import datetime, timedelta

import rrdtool
import numpy as np

from metobs.data import (wind_vector_degrees, to_unix_timestamp,
                         wind_vector_components)


def dewpoint(tempC, relhum):
    """
    Compute the dewpoint with Tom Whittaker's algorithm.

    Works element-wise on numpy arrays as well as on scalars. The result is
    capped at ``tempC`` so the dewpoint never exceeds the air temperature.

    :param tempC: temperature in celsius
    :param relhum: relative humidity as a percentage
    :returns: dewpoint in degrees celsius
    """
    gas_constant = 461.5
    latent_heat = 2500800.0

    kelvin = 273.15 + tempC
    # 0.0 + forces float division even for integer humidities (Python 2)
    log_rh = np.log((0.0 + relhum) / 100)
    inverse_dp = 1.0 / kelvin - gas_constant * log_rh / (latent_heat - tempC * 2397.5)
    dp_celsius = 1.0 / inverse_dp - 273.15

    return np.minimum(dp_celsius, tempC)


class ModelError(Exception):
    """Common base exception for errors reported by the models in this module.
    """
    pass


class WrapErrors(object):
    """Class wrapper to catch exceptions and properly re-raise them such that
    the only exceptions to propagate are `ModelError`s. Essentially, this
    prevents anyone from having to import rrdtool lib.
    """

    def __init__(self, *exceptions):
        self.exceptions = exceptions

    def __call__(self, cls):
        def _wrap(fcn):
            def wrapped(*args, **kwargs):
                try:
                    return fcn(*args, **kwargs)
                except self.exceptions as err:
                    traceback = sys.exc_info()[2]
                    raise ModelError, str(err), traceback
            wrapped.__doc__ = fcn.__doc__
            return wrapped
        for name in dir(cls):
            value = getattr(cls, name)
            if not name.startswith('_') and hasattr(value, '__call__'):
                setattr(cls, name, _wrap(value))

        return cls


@WrapErrors(rrdtool.error)
class RrdModel(object):
    """Model for storing the Level0 uncalibrated data for non-scientific
    purposes, such as web-widgets.

    Backed by a single RRD file at ``filepath``. ``rrdtool.error`` raised by
    public methods is re-raised as `ModelError` via `WrapErrors`.
    """

    # data interval in seconds
    DATA_INTERVAL = 5

    def __init__(self, filepath):
        """
        :param filepath: path of the RRD file backing this model
        """
        self._filepath = filepath
        # lazily populated caches of database metadata
        self._averages = tuple()
        self._datasets = None

    @property
    def datasets(self):
        """Get dataset names available in the database, sorted by name.

        Names are parsed from the ``ds[<name>]...`` keys returned by
        ``rrdtool.info`` and cached after the first access.
        """
        if self._datasets is None:
            datasets = set()
            info = rrdtool.info(self._filepath)
            for key in info.keys():
                match = re.match(r'^ds\[(.*)\]', key)
                if match:
                    datasets.add(match.group(1))
            self._datasets = tuple(sorted(datasets))
        return self._datasets

    def averaging_intervals(self):
        """Lazy load averaging intervals from database.

        :returns: sorted tuple of distinct RRA resolutions in seconds
            (``pdp_per_row * step`` for each RRA)
        """
        if not self._averages:
            averages = set()
            info = rrdtool.info(self._filepath)
            for key in info.keys():
                if key.startswith('rra') and key.endswith('pdp_per_row'):
                    averages.add(int(info[key] * info['step']))
            self._averages = tuple(sorted(averages))
        return self._averages

    def initialize(self, start=None, days=365):
        """Create a new empty RRD database.

        :param start: datetime of the first slot; defaults to ``days`` days
            before ``datetime.utcnow()``
        :param days: look-back used only when ``start`` is not given
        """
        # NOTE(review): ``assert`` is stripped under ``python -O``; without it
        # rrdtool.create may overwrite an existing file -- consider an
        # explicit error instead.
        assert not os.path.exists(self._filepath), "DB already exists"
        start = start or (datetime.utcnow() - timedelta(days=days))
        # normalize start to data interval
        secs = to_unix_timestamp(start)
        secs -= secs % self.DATA_INTERVAL

        # Each RRA keeps 365 days of rows at its resolution
        # (5 s, 60 s, 300 s and 1800 s with DATA_INTERVAL = 5).
        # NOTE(review): dewpoint is bounded to 0..100, so negative dewpoints
        # are likely stored as unknown; get_slice recomputes dewpoint anyway.
        rrdtool.create(self._filepath,
                       '--start={}'.format(secs),
                       '--step={:d}'.format(self.DATA_INTERVAL),
                       'DS:air_temp:GAUGE:10:-40:50',
                       'DS:rh:GAUGE:10:0:100',
                       'DS:dewpoint:GAUGE:10:0:100',
                       'DS:wind_speed:GAUGE:10:0:100',
                       'DS:winddir_north:GAUGE:10:-100:100',
                       'DS:winddir_east:GAUGE:10:-100:100',
                       'DS:pressure:GAUGE:10:0:1100',
                       'DS:precip:GAUGE:10:0:100',
                       'DS:accum_precip:GAUGE:10:0:100',
                       'DS:solar_flux:GAUGE:10:0:1000',
                       'DS:altimeter:GAUGE:10:0:100',
                       'RRA:AVERAGE:0.5:1:6307200',
                       # floor division keeps the step counts integral for {:d}
                       'RRA:AVERAGE:0.5:{:d}:525600'.format(60 // self.DATA_INTERVAL),
                       'RRA:AVERAGE:0.5:{:d}:105120'.format(300 // self.DATA_INTERVAL),
                       'RRA:AVERAGE:0.5:{:d}:17520'.format(1800 // self.DATA_INTERVAL))

    def _format_data(self, stamp, data):
        """Format data for insert into RRD.

        :returns: ``"<unix seconds>@v1:v2:..."`` with values ordered like
            ``self.datasets``, the same order ``add_record`` passes as
            ``--template``
        """
        values = ':'.join([str(data[k]) for k in self.datasets])
        values = '{:d}@{}'.format(to_unix_timestamp(stamp), values)
        return values

    def _record_to_data(self, record):
        """Turn a tower record into database data.

        Wind vector components are derived from ``record.wind_speed`` and
        ``record.wind_dir``; every other dataset is copied from the record
        when present.

        :raises ModelError: if the record lacks any dataset other than the
            derived wind components
        """

        expected_keys = set(self.datasets) - {'winddir_north', 'winddir_east'}
        missing_keys = expected_keys - set(record.keys())
        if missing_keys:
            raise ModelError("Missing datasets %s" % missing_keys)

        data = {}
        wspd, wdir = (float(record.wind_speed), float(record.wind_dir))
        winds = wind_vector_components(wspd, wdir)
        data['winddir_east'] = winds[0]
        data['winddir_north'] = winds[1]
        # NOTE(review): when 'wind_speed' is a dataset, the copy loop below
        # replaces this value with record['wind_speed'] -- confirm intent.
        data['wind_speed'] = winds[2]
        for name in self.datasets:
            if name in record:
                data[name] = record[name]
        return data

    def add_record(self, record):
        """Add a single record to the database.

        The record's timestamp is floored to ``DATA_INTERVAL`` before insert.
        """
        # Normalize to data interval
        utime = to_unix_timestamp(record.get_stamp())
        stamp = datetime.utcfromtimestamp(utime - utime % self.DATA_INTERVAL)
        data = self._record_to_data(record)
        rrdtool.update(self._filepath,
                       '--template=%s' % ':'.join(self.datasets),
                       self._format_data(stamp, data))

    def get_slice(self, start, end, names=None, average=5):
        """Get a slice of data from the database.

        :param start: Start time as datetime (or unix seconds)
        :param end: Inclusive end time as datetime (or unix seconds)
        :param names: Names to query for, defaults to all available, see ``datasets``.
            ``'wind_dir'`` is derived from the wind vector components.
        :param average: Averaging interval supported by the database, see ``averaging_intervals``.
        :returns: 2-D float array; column 0 is the unix time, then one column
            per requested name (NaN where unavailable)
        :raises ValueError: if ``average`` is not a supported interval
        """
        if average not in self.averaging_intervals():
            raise ValueError("Invalid average:%s" % (average,))
        names = names or self.datasets[:]

        if isinstance(start, datetime):
            start = to_unix_timestamp(start)
        if isinstance(end, datetime):
            end = to_unix_timestamp(end)

        # normalize request times to averaging interval
        start -= start % average
        end -= end % average

        # we always get all the data, no matter what was requested
        _time_range, columns, rawdata = rrdtool.fetch(self._filepath,
                                                      'AVERAGE',
                                                      '-r {:d}'.format(average),
                                                      '-s {:d}'.format(start),
                                                      '-e {:d}'.format(end))

        # unknown values (None) become NaN; reshape keeps 2-D even if empty
        src_data = np.array(rawdata, dtype=np.float64).reshape(len(rawdata), len(columns))

        def column(name):
            # columns follow the fetch (DS definition) order, not the sorted
            # order of self.datasets
            return src_data[:, columns.index(name)]

        # NaN filled matrix of shape big enough for the request names
        dst_data = np.full((src_data.shape[0], len(names)), np.nan)

        for dst_idx, name in enumerate(names):
            if name == 'dewpoint' and name in columns:
                # we compute dewpoint since it wasn't always available
                dst_data[:, dst_idx] = dewpoint(column('air_temp'), column('rh'))
            elif name in columns:
                dst_data[:, dst_idx] = column(name)
            elif name == 'wind_dir':
                # get the wind direction in degrees from the vector components
                dst_data[:, dst_idx] = wind_vector_degrees(column('winddir_east'),
                                                           column('winddir_north'))

        # generate column of times for the req average interval
        # NOTE(review): assumes fetch returns one row per timestamp in
        # [start, end]; a row-count mismatch makes the concatenate fail.
        times = np.arange(start, end + average, average)[:, np.newaxis]
        return np.concatenate((times, dst_data), axis=1)