import ctypes as c
import errno
import logging
import os
from collections import defaultdict, deque, namedtuple

from edosl0util import headers
from edosl0util.headers import (GROUP_CONTINUING, GROUP_FIRST, GROUP_LAST,
                                GROUP_STANDALONE, PrimaryHeader)
from edosl0util.timecode import dt_to_cds

LOG = logging.getLogger(__name__)

# Max CCSDS sequence id (14-bit source sequence count field)
MAX_SEQ_ID = 2 ** 14 - 1


class Error(Exception):
    """
    Base stream error.
    """


class PacketTooShort(Error):
    """
    The number of bytes read for a packet does not match the expected length.
    After this error occurs the stream is no longer usable because data offsets
    are no longer reliable.
    """


class NonConsecutiveSeqId(Error):
    """
    Non-consecutive sequence numbers encountered for an apid, i.e., a packet
    gap.
    """


class BasicStream(object):
    """
    Basic packet stream iterator that reads the primary and secondary headers
    and maintains offsets and read sizes.
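
    A minimal usage sketch (the filename is hypothetical; assumes a JPSS
    level-0 PDS file opened in binary mode)::

        with open("P1570826VIIRSSCIENCEAT.PDS", "rb") as fp:
            stream = BasicStream(fp, headers.jpss_header_lookup)
            for h1, h2, size, offset, data in stream:
                print(h1.apid, offset, size)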
    """

    # h1: primary header, h2: secondary header (b"" when none), size: total
    # packet size in bytes, offset: byte offset of the packet in the file,
    # data: user data (None when with_data is False or there is none)
    Tracker = namedtuple("Tracker", ["h1", "h2", "size", "offset", "data"])

    def __init__(self, fobj, header_lookup=None, with_data=True):
        self.file = fobj
        self.header_lookup = header_lookup
        self.with_data = with_data
        try:
            self._offset = self.file.tell()
        except IOError as err:  # handle illegal seek for pipes
            if err.errno != errno.ESPIPE:
                raise
            self._offset = 0

    def __iter__(self):
        return self

    def __next__(self):
        return self.next()

    def _read(self, size):
        buf = self.file.read(size)
        if not buf:  # EOF
            raise StopIteration()
        if len(buf) != size:
            raise PacketTooShort(
                "expected to read {:d} bytes, got {:d} at offset {:d}".format(
                    size, len(buf), self._offset
                ),
                self.file,
            )
        self._offset += size
        return buf

    def read_primary_header(self):
        buf = self._read(headers.PRIMARY_HEADER_SIZE)
        h1 = PrimaryHeader.from_buffer_copy(buf)
        return h1, headers.PRIMARY_HEADER_SIZE

    def read_secondary_header(self, ph):
        H2Impl = self.header_lookup(ph) if self.header_lookup else None
        if H2Impl is None:
            return bytes(), 0
        h2size = c.sizeof(H2Impl)
        buf = self._read(h2size)
        return H2Impl.from_buffer_copy(buf), h2size

    def next(self):
        offset = self._offset
        h1, h1size = self.read_primary_header()
        h2, h2size = self.read_secondary_header(h1)
        # the CCSDS data length field counts the bytes of the packet data
        # field (secondary header + user data) minus 1
        total_size = h1size + h1.data_length_minus1 + 1
        data_size = h1.data_length_minus1 + 1 - h2size
        data = None
        if self.with_data and data_size:
            data = self._read(data_size)
        else:
            self.file.seek(data_size, os.SEEK_CUR)
            self._offset += data_size
        return self.Tracker(h1, h2, total_size, offset, data)


class Packet(object):
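    """
    A single CCSDS packet: primary header, optional secondary header, and
    user data.
    """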
    def __init__(self, h1, h2, data, data_size=None, offset=None):
        self.primary_header = h1
        self.secondary_header = h2
        self.data = data
        self.size = data_size
        self.offset = offset
        self.apid = h1.apid
        self.seqid = self.primary_header.source_sequence_count

    def __str__(self):
        return "<Packet apid=%d seqid=%d stamp=%s size=%s offset=%s>" % (
            self.apid,
            self.seqid,
            self.stamp,
            self.size,
            self.offset,
        )

    __repr__ = __str__

    @property
    def cds_timecode(self):
        if self.secondary_header and hasattr(self.secondary_header, "timecode"):
            return self.secondary_header.timecode.day_segmented_timecode()
        return None

    @property
    def stamp(self):
        if self.secondary_header and hasattr(self.secondary_header, "timecode"):
            return self.secondary_header.timecode.asdatetime()
        return None

    def bytes(self):
        return (
            bytearray(self.primary_header)
            + bytearray(self.secondary_header)
            + self.data
        )

    def is_group(self):
        return self.is_first() or self.is_continuing() or self.is_last()

    def is_first(self):
        return self.primary_header.sequence_grouping == GROUP_FIRST

    def is_continuing(self):
        return self.primary_header.sequence_grouping == GROUP_CONTINUING

    # backwards-compatible alias for a historical typo
    is_continuine = is_continuing

    def is_last(self):
        return self.primary_header.sequence_grouping == GROUP_LAST

    def is_standalone(self):
        return self.primary_header.sequence_grouping == GROUP_STANDALONE


class PacketStream(object):
    SEQID_NOTSET = -1

    def __init__(self, data_stream, fail_on_missing=False, fail_on_tooshort=False):
        """
        :param data_stream: An iterable of ``Tracker`` objects
        :keyword fail_on_missing:
            Raise a ``NonConsecutiveSeqId`` error when sequence id gaps are
            found. If this is False sequence id gaps are ignored.
        :keyword fail_on_tooshort:
            Raise a ``PacketTooShort`` error when a packet is encountered for
            which fewer bytes can be read than the packet headers indicate.
            Most of the time this is due to a truncated file with an
            incomplete packet at the end. If this is False a ``StopIteration``
            is raised instead, effectively truncating the file.
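
        A usage sketch (the filename is hypothetical)::

            with open("P1570826VIIRSSCIENCEAT.PDS", "rb") as fp:
                for packet in jpss_packet_stream(fp, fail_on_missing=True):
                    print(packet)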
        """
        self._stream = data_stream
        self._seek_cache = deque()
        self._first = None
        self._last = None
        self._apid_info = defaultdict(
            lambda: {"count": 0, "last_seqid": self.SEQID_NOTSET, "num_missing": 0}
        )
        self._fail_on_missing = fail_on_missing
        self._fail_on_tooshort = fail_on_tooshort

    def __repr__(self):
        filepath = getattr(self.file, "name", None)
        return "<{} file={}>".format(self.__class__.__name__, filepath)

    def __iter__(self):
        return self

    # python3
    def __next__(self):
        return self.next()

    @property
    def file(self):
        return self._stream.file

    def push_back(self, packet):
        """
        Put a packet back such that it is the next packet provided when
        ``next`` is called.
        """
        self._seek_cache.append(packet)

    def next(self):
        # return any items on the seek cache before reading more
        if len(self._seek_cache):
            return self._seek_cache.popleft()
        try:
            h1, h2, data_size, offset, data = next(self._stream)
        except PacketTooShort as err:
            if self._fail_on_tooshort:
                raise
            LOG.error("Packet too short, aborting stream: %s", err)
            # The result of this is essentially a file truncation
            raise StopIteration()
        packet = Packet(h1, h2, data, data_size=data_size, offset=offset)
        have_missing = self._update_info(packet)
        if have_missing and self._fail_on_missing:
            self.push_back(packet)
            raise NonConsecutiveSeqId()
        return packet

    def _update_info(self, packet):
        """
        Handle gap detection and first/last. Returns whether any missing
        packets were detected.
        """
        have_missing = False
        apid = self._apid_info[packet.apid]
        if packet.stamp:
            if self._first is None:
                self._first = packet.stamp
            else:
                self._first = min(packet.stamp, self._first)
            if self._last is None:
                self._last = packet.stamp
            else:
                self._last = max(packet.stamp, self._last)
        apid["count"] += 1
        if apid["last_seqid"] != self.SEQID_NOTSET:
            if packet.seqid > apid["last_seqid"]:
                num_missing = packet.seqid - apid["last_seqid"] - 1
            else:
                num_missing = packet.seqid - apid["last_seqid"] + MAX_SEQ_ID
            have_missing = num_missing != 0
            apid["num_missing"] += num_missing
        apid["last_seqid"] = packet.seqid
        return have_missing

    def info(self):
        return self._first, self._last, self._apid_info

    def seek_to(self, stamp, apid=None):
        """
        Seek forward such that the next packet provided is the first packet
        with a secondary header timestamp >= ``stamp``, optionally also
        matching ``apid``.
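
        For example, assuming ``stream`` is a ``PacketStream`` over a file
        containing the target time (the datetime and apid here are
        hypothetical)::

            from datetime import datetime
            stream.seek_to(datetime(2017, 1, 1, 12), apid=826)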
        """
        target_tc = dt_to_cds(stamp)
        # seek past partial groups
        packet = self.next()
        while not packet.cds_timecode:
            packet = self.next()

        # seek to first packet >= `stamp`
        current_tc = packet.cds_timecode
        while current_tc < target_tc:
            packet = self.next()
            current_tc = packet.cds_timecode or current_tc

        # put back so the next packet is the first packet >= `stamp`
        self.push_back(packet)

        if apid is not None:
            packet = self.next()
            while packet.apid != apid:
                packet = self.next()
            self.push_back(packet)


def collect_groups(stream):
    """
    Collects packets into lists of grouped packets. If a packet is standalone
    it will be a list of just that single packet. If a packet is part of a
    group it will be a list of the packets in the group. There is no
    guarantee the group will be complete. The first packet in a group will
    always have a timestamp.
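
    A usage sketch (the filename is hypothetical)::

        with open("P1570826VIIRSSCIENCEAT.PDS", "rb") as fp:
            for group in collect_groups(jpss_packet_stream(fp)):
                print(len(group), group[0].apid)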
    """
    done = object()  # sentinel for stopping iteration
    group = deque()
    while True:
        p = next(stream, done)
        if p is done:
            return
        # pass any non-grouped packets straight through
        if p.is_standalone():
            yield [p]
            continue

        # Yield the current group and start a new one if this is a first
        # packet or this packet's APID does not match the current group.
        if p.is_first() or (group and group[-1].apid != p.apid):
            if group:
                yield list(group)
            group = deque([p])
            continue

        group.append(p)


def jpss_packet_stream(fobj, **kwargs):
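    """Create a ``PacketStream`` over ``fobj`` using JPSS secondary headers."""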
    stream = BasicStream(fobj, headers.jpss_header_lookup)
    return PacketStream(stream, **kwargs)


def aqua_packet_stream(fobj, **kwargs):
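    """Create a ``PacketStream`` over ``fobj`` using Aqua secondary headers."""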
    stream = BasicStream(fobj, headers.aqua_header_lookup)
    return PacketStream(stream, **kwargs)
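

if __name__ == "__main__":
    # A minimal usage sketch, not part of the original API: read a JPSS
    # level-0 file named on the command line, then report the packet count
    # and number of missing packets per apid.
    import sys

    logging.basicConfig(level=logging.INFO)
    with open(sys.argv[1], "rb") as fp:
        packets = jpss_packet_stream(fp)
        for _packet in packets:
            pass
        first, last, apid_info = packets.info()
        print("first stamp: {} last stamp: {}".format(first, last))
        for apid, info in sorted(apid_info.items()):
            print(
                "apid={:d} count={:d} missing={:d}".format(
                    apid, info["count"], info["num_missing"]
                )
            )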