"""

1. Cat PDS files together
2. Index packets
    - drop any leading hanging packets
    -
3. Sort index
4. Write

"""
import os
import logging
from collections import deque, OrderedDict

LOG = logging.getLogger(__name__)

VIIRS_APID_ORDER = [826, 821] + list(range(800, 821)) + list(range(822, 826))

class _Ptr(object):
    """
    Represents one or more packets that share the same time stamp and apid.
    """
    def __init__(self, fobj, stamp, apid, offset, size):
        self.fobj = fobj
        self.stamp = stamp
        self.apid = apid
        self.offset = offset
        self.size = size
        self.count = 1

    def __repr__(self):
        attrs = ' '.join(
                '{}={}'.format(k, v)
                for k, v in sorted(vars(self).items())
                if not k.startswith('_'))

        return '<{:s} {:s}>'.format(self.__class__.__name__, attrs)

    # pointers order and compare equal by (stamp, apid)
    def __eq__(self, that):
        return (self.stamp, self.apid) == (that.stamp, that.apid)

    def __lt__(self, that):
        return (self.stamp, self.apid) < (that.stamp, that.apid)

    # hash also includes size so pointers that differ only in size stay
    # distinct in a set; such duplicates are resolved later by
    # _filter_duplicates_by_size
    def __hash__(self):
        return hash((self.stamp, self.apid, self.size))

    def bytes(self):
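        """Read this pointer's raw packet bytes from the source file."""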
        self.fobj.seek(self.offset, os.SEEK_SET)
        return self.fobj.read(self.size)


def read_packet_index(stream):
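    """
    Build an index of _Ptr entries from a packet stream, one per group of
    packets, where a group is a stamped packet followed by any un-stamped
    continuation packets. Leading packets without a time stamp are dropped.
    """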
    index = deque()
    try:
        # drop any leading hanging packets
        count = 0
        packet = next(stream)
        while not packet.stamp:
            packet = next(stream)
            count += 1
        if count:
            LOG.info('dropped %d leading packets', count)

        while True:
            ptr = _Ptr(
                stream.file,
                stamp=packet.stamp,
                apid=packet.apid,
                offset=packet.offset,
                size=packet.size,
            )
            index.append(ptr)
            # absorb any continuation packets (no stamp) into this group
            packet = next(stream)
            while not packet.stamp:
                ptr.size += packet.size
                ptr.count += 1
                packet = next(stream)

    except StopIteration:
        pass

    return index


def _sort_by_time_apid(index, order=None):
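    """
    Sort pointers by time stamp, then by apid. With ``order`` given, apids
    sort in that order (apids not listed sort first); otherwise apids sort
    numerically. Both passes use ``sorted``, which is stable, so the apid
    ordering is preserved within each time stamp by the final sort.
    """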
    if order:
        index = sorted(
            index,
            key=lambda p: order.index(p.apid) if p.apid in order else -1)
    else:
        index = sorted(index, key=lambda p: p.apid)
    return sorted(index, key=lambda p: p.stamp)


def _filter_duplicates_by_size(index):
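    """
    Keep only the largest pointer for each (stamp, apid), preserving the
    original ordering.
    """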
    filtered = OrderedDict()
    for ptr in index:
        key = (ptr.stamp, ptr.apid)
        if key in filtered:
            if ptr.size > filtered[key].size:
                filtered[key] = ptr
        else:
            filtered[key] = ptr
    return list(filtered.values())


def merge(streams, output, trunc_to=None, apid_order=None):
    """
    Merge packets from multiple streams to an output file. Duplicate packets
    are filtered and packets are sorted by time and by the apid order
    provided.

    :param streams: List of `PacketStream`s such as returned by
        `jpss_packet_stream`.
    :param output: File-like object to write output packets to.
    :keyword trunc_to: (start, end) datetime to truncate packets to. Start is
        inclusive and end is exclusive.
    :keyword apid_order: List of apids in the order in which they should
        appear in the output. Apids not in the list sort ahead of those
        that are.
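
    Example (a minimal sketch; assumes ``jpss_packet_stream`` is importable
    from the companion stream module; file names and dates are placeholders)::

        from datetime import datetime

        files = [open(name, 'rb') for name in ('P1.PDS', 'P2.PDS')]
        with open('merged.PDS', 'wb') as out:
            merge([jpss_packet_stream(f) for f in files], out,
                  trunc_to=(datetime(2018, 1, 1), datetime(2018, 1, 2)),
                  apid_order=VIIRS_APID_ORDER)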
    """
    index = set()  # will remove duplicates
    for stream in streams:
        LOG.debug('indexing %s', stream)
        index.update(read_packet_index(stream))

    LOG.debug('sorting index with %d pointers', len(index))
    index = _sort_by_time_apid(index, order=apid_order)

    index = _filter_duplicates_by_size(index)

    LOG.debug('writing index to %s', output)
    for ptr in index:
        if not trunc_to or trunc_to[0] <= ptr.stamp < trunc_to[1]:
            output.write(ptr.bytes())