"""

1. Cat PDS files together
2. Index packets
    - drop any leading hanging packets
    -
3. Sort index
4. Write
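
Example usage (a sketch; the `edosl0util.stream` import path for
`jpss_packet_stream` is an assumption, adjust it to match this package's
actual layout)::

    from edosl0util.stream import jpss_packet_stream  # assumed import path

    # hypothetical input file names
    inputs = [open(name, "rb") for name in ("in1.pds", "in2.pds")]
    streams = [jpss_packet_stream(f) for f in inputs]
    with open("merged.pds", "wb") as out:
        merge(streams, out, apid_order=VIIRS_APID_ORDER)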

"""
import logging
import os
from collections import deque, OrderedDict
from functools import total_ordering

LOG = logging.getLogger(__name__)

# Preferred apid output order for VIIRS: 826 and 821 first, then 800-820,
# then 822-825
VIIRS_APID_ORDER = (826, 821) + tuple(range(800, 821)) + tuple(range(822, 826))


@total_ordering
class _Ptr(object):
    """
    Represents one or more packets that share the same timecode and apid.
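
    Equality compares (stamp, apid) only, while the hash also includes size,
    so a set keeps same-(stamp, apid) pointers whose sizes differ and
    `_filter_duplicates_by_size` can pick between them later. A quick sketch
    (None stands in for the file object, ints for timestamps):

    >>> a = _Ptr(None, stamp=1, apid=800, offset=0, size=100)
    >>> b = _Ptr(None, stamp=1, apid=800, offset=64, size=120)
    >>> a == b, len({a, b})
    (True, 2)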
    """

    def __init__(self, fobj, stamp, apid, offset, size):
        self.fobj = fobj
        self.stamp = stamp
        self.apid = apid
        self.offset = offset
        self.size = size
        self.count = 1

    def __repr__(self):
        attrs = " ".join(
            "{}={}".format(k, v)
            for k, v in sorted(vars(self).items())
            if not k.startswith("_")
        )
        return "<{:s} {:s}>".format(self.__class__.__name__, attrs)

    def __eq__(self, that):
        return (self.stamp, self.apid) == (that.stamp, that.apid)

    def __ne__(self, that):
        return not self == that

    def __lt__(self, that):
        return (self.stamp, self.apid) < (that.stamp, that.apid)

    # Hash by (stamp, apid, size) so we can dedup the index using a set.
    def __hash__(self):
        return hash((self.stamp, self.apid, self.size))

    def bytes(self):
        self.fobj.seek(self.offset, os.SEEK_SET)
        return self.fobj.read(self.size)


def read_packet_index(stream):
    index = deque()
    try:
        # drop any leading hanging packets, i.e., packets without a timecode
        # that arrive before the first stamped packet
        count = 0
        packet = next(stream)
        while not packet.stamp:
            packet = next(stream)
            count += 1
        if count:
            LOG.info("dropped %d leading packets", count)

        while True:
            if not packet.stamp:
                # corrupt packet groups can cause apid mismatch,
                # so skip until we get to the next group
                packet = next(stream)
                continue
            ptr = _Ptr(
                stream.file,
                stamp=packet.stamp,
                apid=packet.apid,
                offset=packet.offset,
                size=packet.size,
            )
            index.append(ptr)
            # collect all packets for this timecode/group
            packet = next(stream)
            while not packet.stamp:
                # Bail if we're collecting group packets and apids don't match.
                # This means the group is corrupt.
                if ptr.apid != packet.apid:
                    break
                ptr.size += packet.size
                ptr.count += 1
                packet = next(stream)

    except StopIteration:
        pass

    return index


def _sort_by_time_apid(index, order=None):
    """
    Return a new list of pointers sorted by time, then by apid. The apid sort
    runs first and the stable time sort second, so pointers with the same
    stamp keep the requested apid order. Apids missing from `order` sort
    before those that are present.
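
    A quick sketch (None stands in for the file object, ints for
    timestamps):

    >>> a = _Ptr(None, stamp=2, apid=800, offset=0, size=10)
    >>> b = _Ptr(None, stamp=1, apid=821, offset=0, size=10)
    >>> c = _Ptr(None, stamp=1, apid=800, offset=0, size=10)
    >>> [(p.stamp, p.apid) for p in _sort_by_time_apid([a, b, c], order=(821, 800))]
    [(1, 821), (1, 800), (2, 800)]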
    """
    if order:
        index = sorted(
            index, key=lambda p: order.index(p.apid) if p.apid in order else -1
        )
    else:
        index = sorted(index, key=lambda p: p.apid)
    return sorted(index, key=lambda p: p.stamp)


def _filter_duplicates_by_size(index):
    """
    For pointers that share a (stamp, apid) key, keep the one with the
    largest size.
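
    For example (None stands in for the file object, ints for timestamps):

    >>> a = _Ptr(None, stamp=1, apid=800, offset=0, size=100)
    >>> b = _Ptr(None, stamp=1, apid=800, offset=64, size=120)
    >>> [p.size for p in _filter_duplicates_by_size([a, b])]
    [120]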
    """
    filtered = OrderedDict()
    for ptr in index:
        key = (ptr.stamp, ptr.apid)
        if key in filtered:
            if ptr.size > filtered[key].size:
                filtered[key] = ptr
        else:
            filtered[key] = ptr
    return filtered.values()


def merge(streams, output, trunc_to=None, apid_order=None):
    """
    Merge packets from multiple streams to an output file. Duplicate packets
    will be filtered and packets will be sorted by time and the apid order
    provided.

    :param streams: List of `PacketStream`s such as returned by
        `jpss_packet_stream`.
    :param output: File-like object to write output packets to.
    :keyword trunc_to: (start, end) datetime to truncate packets to. Start is
        inclusive and end is exclusive.
    :keyword apid_order: List of apids in the order in which they should
        appear in the output. Apids not present in the list will sort before
        those that are.
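
    For example, to write only packets stamped within a two hour window
    (hypothetical datetimes; `streams` and `out` as in the module
    docstring)::

        from datetime import datetime

        merge(
            streams,
            out,
            # start inclusive, end exclusive; example window only
            trunc_to=(datetime(2018, 11, 1, 0), datetime(2018, 11, 1, 2)),
        )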
    """
    index = set()  # will remove duplicates
    for stream in streams:
        LOG.debug("indexing %s", stream)
        index.update(read_packet_index(stream))

    LOG.debug("sorting index with %d pointers", len(index))
    index = _sort_by_time_apid(index, order=apid_order)

    index = _filter_duplicates_by_size(index)

    LOG.debug("writing index to %s", output)
    for ptr in index:
        # skip packets outside the truncation window, if one was given
        if trunc_to and not (trunc_to[0] <= ptr.stamp < trunc_to[1]):
            continue
        output.write(ptr.bytes())