"""

1. Cat PDS files together
2. Index packets
    - drop any leading hanging packets
    -
3. Sort index
4. Write

"""
import logging
import os
from collections import deque
from functools import total_ordering

from . import eos as _eos

LOG = logging.getLogger(__name__)

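# Default APID output order for VIIRS: science apids 800-828, with 826 and
# 821 promoted to the front. With apid_order given, packets at the same
# timecode sort in this apid order (see _sort_key).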
VIIRS_APID_ORDER = (
    (826, 821) + tuple(range(800, 821)) + tuple(range(822, 826)) + (827, 828)
)


@total_ordering
class _Ptr(object):
    """
    Represents one or more packets that share the same timecode and apid.
    """

    def __init__(self, fobj, stamp, apid, offset, size, seqid, sort_key=None):
        self.fobj = fobj
        self.stamp = stamp
        self.apid = apid
        self.offset = offset
        self.size = size
        self.seqid = seqid
        self.count = 1
        self._sort_key = sort_key or (self.stamp, self.apid)

    def __repr__(self):
        attrs = " ".join(
            "{}={}".format(k, v)
            for k, v in sorted(vars(self).items())
            if not k.startswith("_")
        )

        return "<{:s} {:s}>".format(self.__class__.__name__, attrs)

    def __eq__(self, that):
        return (self.stamp, self.apid, self.seqid) == (
            that.stamp,
            that.apid,
            that.seqid,
        )

    def __ne__(self, that):
        return not self == that

    def __lt__(self, that):
        return self._sort_key < that._sort_key

    # hash by stamp, apid, seqid, and size so we can dedup in index using a set
    def __hash__(self):
        return hash((self.stamp, self.apid, self.seqid, self.size))

    def bytes(self):
        self.fobj.seek(self.offset, os.SEEK_SET)
        return self.fobj.read(self.size)


def _sort_key(p, order=None, eos=False):
    """
    Sort key for a packet: time first, then the apid's position in ``order``
    (or the raw apid if no order is given). MODIS (apid 64) packets get an
    EOS specific key when ``eos`` is set.
    """
    if eos and p.apid == 64:
        return _eos._modis_sort_key(p)
    return (p.stamp, order.index(p.apid) if order else p.apid)
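
# For example, with order=VIIRS_APID_ORDER, apid 826 sorts before apid 800
# at the same timestamp, since keys compare as (stamp, position in order).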


def read_packet_index(stream, order=None, eos=False):
    """
    Index ``stream``, returning a deque of _Ptr objects, one per
    timecode/apid packet group.
    """
    index = deque()
    try:
        # drop any leading hanging packets, i.e., continuation packets that
        # are neither first-in-group nor standalone
        count = 0
        packet = next(stream)
        while not packet.stamp:
            packet = next(stream)
            count += 1
        if count:
            LOG.info("dropped %d leading packets", count)

        while True:
            if not packet.stamp:
                # corrupt packet groups can cause apid mismatches, so skip
                # continuation packets until we get to the next group
                packet = next(stream)
                continue
            ptr = _Ptr(
                stream.file,
                stamp=packet.stamp,
                apid=packet.apid,
                offset=packet.offset,
                size=packet.size,
                seqid=packet.primary_header.source_sequence_count,
                sort_key=_sort_key(packet, order, eos),
            )
            index.append(ptr)
            # collect any continuation packets for this timecode/group
            packet = next(stream)
            while not packet.stamp:
                # bail if apids don't match while collecting group packets;
                # this means the group is corrupt
                if ptr.apid != packet.apid:
                    break
                ptr.size += packet.size
                ptr.count += 1
                packet = next(stream)

    except StopIteration:
        pass

    return index
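
# A sketch of inspecting the index for a single stream (the filename is an
# assumption):
#
#   with open("P1.PDS", "rb") as fobj:
#       for ptr in read_packet_index(jpss_packet_stream(fobj)):
#           print(ptr.stamp, ptr.apid, ptr.count, ptr.size)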


def merge(streams, output, trunc_to=None, apid_order=None, eos=False):
    """
    Merge packets from multiple streams to an output file. Duplicate packets
    will be filtered and packets will be sorted by time and the apid order
    provided.

    :param streams: List of `PacketStream`s such as returned by
        `jpss_packet_stream`.
    :param output: File-like object to write output packets to.
    :keyword trunc_to: (start, end) datetime to truncate packets to. Start is
        inclusive and end is exclusive.
    :keyword apid_order: List of apids in the order in which they should
        appear in the output. If provided, ALL apids must be present,
        otherwise a ValueError will occur.
    :keyword eos: If True, use an EOS (MODIS) specific sort key for apid 64
        packets.
    """
    index = set()  # will remove duplicates
    for stream in streams:
        LOG.debug("indexing %s", stream)
        index.update(read_packet_index(stream, order=apid_order, eos=eos))

    LOG.debug("sorting index with %d pointers", len(index))
    # _Ptr implements a custom sort key for sorting
    index = sorted(index)

    LOG.debug("writing index to %s", output)
    for ptr in index:
        if trunc_to:
            if trunc_to[0] <= ptr.stamp < trunc_to[1]:
                output.write(ptr.bytes())
        else:
            output.write(ptr.bytes())
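

# Truncation sketch: keep only packets stamped within a two minute window
# (the datetimes here are arbitrary examples; start is inclusive, end
# exclusive):
#
#   from datetime import datetime
#   merge(
#       streams,
#       output,
#       trunc_to=(datetime(2018, 1, 1, 0, 0), datetime(2018, 1, 1, 0, 2)),
#   )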