"""
Code for reading/writing/manipulating JPSS Common RDR files as documented in:

    Joint Polar Satellite System (JPSS)
    Common Data Format Control Book - External (CDFCB-X)
    Volume II - RDR Formats

    http://jointmission.gsfc.nasa.gov/sciencedocs/2015-06/474-00001-02_JPSS-CDFCB-X-Vol-II_0123B.pdf
"""
import logging
import ctypes as c
from collections import namedtuple

import numpy as np
from h5py import File as H5File

LOG = logging.getLogger(__name__)


class BaseStruct(c.BigEndianStructure):
    """
    Common base for the big-endian ctypes structures in this module.

    ``_pack_ = 1`` removes padding between fields. Subclasses only declare
    ``_fields_``; ``__repr__`` renders each declared field as ``name=value``.
    """
    _pack_ = 1
    _fields_ = []

    def __repr__(self):
        rendered = []
        for field in self._fields_:
            # Entries are (name, type) or (name, type, nbits) for bit fields;
            # only the name is needed here.
            name = field[0]
            rendered.append('%s=%s' % (name, getattr(self, name)))
        return '<{} {}>'.format(type(self).__name__, ' '.join(rendered))


class StaticHeader(BaseStruct):
    """
    Static header read from offset 0 of a Common RDR buffer.

    The ``*_offset`` fields are byte offsets into that same buffer. They are
    used by `_read_apid_list` (APID list) and `_packets_for_apid` (packet
    tracker array and application packet storage area).
    """

    _fields_ = [
        ('satellite', c.c_char * 4),
        ('sensor', c.c_char * 16),  # compared against 'viirs' in sort_packets_edos
        ('type_id', c.c_char * 16),
        ('num_apids', c.c_uint32),  # number of Apid entries in the APID list
        ('apid_list_offset', c.c_uint32),  # byte offset of the first Apid entry
        ('pkt_tracker_offset', c.c_uint32),  # byte offset of PacketTracker index 0
        ('ap_storage_offset', c.c_uint32),  # base that PacketTracker.offset is relative to
        ('next_pkt_pos', c.c_uint32),  # not read in this module
        # NOTE: "boundry" spelling is part of the public attribute names; keep it.
        # NOTE(review): presumably start/end time boundaries (IET?) -- not read here; confirm.
        ('start_boundry', c.c_uint64),
        ('end_boundry', c.c_uint64),
    ]


class Apid(BaseStruct):
    """
    One entry of the APID list.

    `StaticHeader.num_apids` of these are stored back to back starting at
    `StaticHeader.apid_list_offset`.
    """

    _fields_ = [
        ('name', c.c_char * 16),
        ('value', c.c_uint32),  # presumably the numeric APID -- not read in this module
        ('pkt_tracker_start_idx', c.c_uint32),  # index of this APID's first PacketTracker
        ('pkts_reserved', c.c_uint32),  # not read in this module
        ('pkts_received', c.c_uint32),  # number of consecutive trackers walked for this APID
    ]


class PacketTracker(BaseStruct):
    """
    Per-packet bookkeeping record.

    An APID's trackers are contiguous in the tracker array located at
    `StaticHeader.pkt_tracker_offset`.
    """

    _fields_ = [
        ('obs_time', c.c_int64),  # observation time (IET); the key for sort_packets_by_obs_time
        ('sequence_number', c.c_int32),
        ('size', c.c_int32),  # packet size in bytes, including the 6-byte CCSDS primary header
        ('offset', c.c_int32),  # relative to StaticHeader.ap_storage_offset; negative (-1) == missing
        ('fill_percent', c.c_int32),  # non-zero is treated as fill by convert_to_nasa_l0
    ]


# Packet struct classes already built by _packet_impl, keyed by data size.
# Creating a ctypes class is far more expensive than reusing one, and many
# packets share a size.
_PKT_IMPL_CACHE = {}


def _packet_impl(data_size):
    """
    Return the ctypes struct class for a CCSDS packet carrying `data_size`
    bytes after its 6-byte primary header, creating and caching it on first
    use.
    """
    impl = _PKT_IMPL_CACHE.get(data_size)
    if impl is None:
        class PktImpl(BaseStruct):  # noqa
            # Every header bit field is c_uint16 so each 16-bit header word
            # is a single bit-field storage unit. Mixing c_uint8/c_uint16 bit
            # fields only yields the 6-byte CCSDS header under GCC-style
            # bit-field expansion; MSVC-style layout (e.g. on Windows) starts
            # a new storage unit whenever the field type changes.
            _fields_ = [
                ('version', c.c_uint16, 3),
                ('type', c.c_uint16, 1),
                ('secondary_header', c.c_uint16, 1),
                ('apid', c.c_uint16, 11),
                ('sequence_grouping', c.c_uint16, 2),
                ('sequence_number', c.c_uint16, 14),
                ('length_minus1', c.c_uint16),
                ('data', c.c_uint8 * data_size),
            ]
        impl = _PKT_IMPL_CACHE[data_size] = PktImpl
    return impl


def _make_packet(buf, size, offset):
    """
    Create a struct for the CCSDS packet stored at byte `offset` of `buf`.

    The struct is a view onto `buf` (``from_buffer``), not a copy. Its size
    depends on the packet, so one struct class exists per data size (cached).

    :param buf: writable buffer holding the packet.
    :param size: total packet size in bytes, including the 6-byte primary
        header.
    :param offset: byte offset of the packet within `buf`.
    :raises ValueError: if `size` cannot be a CCSDS packet size, `buf` is too
        small, or the header's length field disagrees with `size`.
    """
    # size minus the CCSDS primary header size
    data_size = size - 6
    # The 16-bit length field stores (data length - 1): 1..65536 data bytes.
    if not 1 <= data_size <= 0x10000:
        raise ValueError('invalid CCSDS packet size: {}'.format(size))
    pkt = _packet_impl(data_size).from_buffer(buf, offset)
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if pkt.length_minus1 + 1 != data_size:
        raise ValueError(
            'packet length field ({}) does not match tracker size ({})'.format(
                pkt.length_minus1 + 1, data_size))
    return pkt


# A (PacketTracker, CCSDS packet struct) pair, as yielded by CommonRdr.packets.
Packet = namedtuple('Packet', 'tracker packet')


class CommonRdr(namedtuple('CommonRdr', ('buf', 'header', 'apids'))):
    """
    One Common RDR dataset: the raw buffer, its parsed static header, and
    the APID entries read from that buffer.
    """

    def packets_for_apid(self, apid):
        """
        Generate ``(tracker, packet)`` tuples for the stored packets of the
        given `Apid` in this RDR.
        """
        return _packets_for_apid(self.buf, self.header, apid)

    def packets(self):
        """
        Generate a `Packet` for every stored packet, visiting the APIDs in
        list order and each APID's packets in tracker order.
        """
        for entry in self.apids:
            for trk, pkt in self.packets_for_apid(entry):
                yield Packet(trk, pkt)


def _packets_for_apid(buf, header, apid):
    """
    Generate ``(tracker, packet)`` pairs for the packets of one `Apid`.

    Walks `apid.pkts_received` consecutive `PacketTracker` records, starting
    at index `apid.pkt_tracker_start_idx` of the tracker array. Trackers with
    a negative offset (missing packets) are skipped.
    """
    tracker_size = c.sizeof(PacketTracker)
    first = header.pkt_tracker_offset + apid.pkt_tracker_start_idx * tracker_size
    for i in range(apid.pkts_received):
        tracker = PacketTracker.from_buffer(buf, first + i * tracker_size)
        if tracker.offset >= 0:
            # tracker offsets are relative to the packet storage area
            pkt_offset = header.ap_storage_offset + tracker.offset
            yield tracker, _make_packet(buf, tracker.size, pkt_offset)


def _read_apid_list(header, buf):
    """
    Return a generator that yields the `Apid` entries of the APID list.

    The list holds `header.num_apids` contiguous `Apid` records starting at
    byte offset `header.apid_list_offset` of `buf`.
    """
    start = header.apid_list_offset
    stride = c.sizeof(Apid)
    # The end bound is relative to `start`. Computing it from zero instead
    # silently dropped the trailing entries whenever the list did not begin
    # at offset 0, which is always the case since the header precedes it.
    end = start + header.num_apids * stride
    for offset in range(start, end, stride):
        yield Apid.from_buffer(buf, offset)


def read_common_rdrs(sensor, filepath):
    """
    Generate a `CommonRdr` for each buffer yielded by `read_rdr_datasets`,
    with its `StaticHeader` and APID list already parsed.
    """
    for raw in read_rdr_datasets(sensor, filepath):
        hdr = StaticHeader.from_buffer(raw)
        yield CommonRdr(raw, hdr, list(_read_apid_list(hdr, raw)))


def read_rdr_datasets(sensor, filepath):
    """
    Return a generator that yields a bytearray for each
    ``RawApplicationPackets_<N>`` dataset of the sensor's science RDR group,
    in ascending numerical order of ``<N>``.

    :param sensor: sensor name (case-insensitive), e.g. 'viirs'.
    :param filepath: path to the HDF5 RDR file.

    The file is opened read-only. It is closed when the generator is
    exhausted, closed, or garbage collected. Each bytearray is an independent
    writable copy of the dataset bytes. It therefore stays valid after the
    file is closed and can back ctypes ``from_buffer`` views.
    """
    group_path = '/All_Data/{}-SCIENCE-RDR_All'.format(sensor.upper())
    with H5File(filepath, 'r') as fobj:
        grp = fobj[group_path]
        dsnames = [k for k in grp.keys() if k.startswith('RawApplicationPackets_')]
        # Order by the integer index after the last '_' so that _10 follows _9.
        for name in sorted(dsnames, key=lambda n: int(n.split('_')[-1])):
            yield bytearray(np.array(grp[name]).tobytes())


def sort_packets_by_obs_time(packets):
    """
    Return a new list of `Packet`s sorted by their PacketTracker ``obs_time``
    (IET).

    The input is not modified. The sort is stable, so packets with equal
    observation times keep their incoming relative order. `sort_packets_edos`
    relies on this to keep its APID ordering within a timestamp.
    """
    return sorted(packets, key=lambda pkt: pkt.tracker.obs_time)


def sort_packets_by_apid(packets, order=None):
    """
    Return a new list of packets sorted by APID.

    Without `order` (or with an empty one), packets are sorted by ascending
    APID. With `order`, each packet is ranked by the position of its APID in
    `order`; if an APID is listed more than once, its first position is used.
    A ValueError is raised for a packet whose APID is not in `order`.

    The sort is stable and the input is not modified.
    """
    if not order:
        return sorted(packets, key=lambda pkt: pkt.packet.apid)

    # Map each APID to its first position once, instead of doing an
    # O(len(order)) list.index() lookup per packet.
    rank = {}
    for pos, apid in enumerate(order):
        rank.setdefault(apid, pos)

    def _rank_of(pkt):
        apid = pkt.packet.apid
        try:
            return rank[apid]
        except KeyError:
            raise ValueError('apid {} is not in the given order'.format(apid))

    return sorted(packets, key=_rank_of)


def sort_packets_edos(rdr, packets):
    """
    Sort packets by timestamp and order the apids the same as EDOS NASA L0
    files.

    Packets are sorted by APID first and then, stably, by observation time.
    Equal-time packets therefore keep the APID order. For VIIRS that order is
    826, 821, then 0-824, and an error is raised for any APID outside that
    list (see `sort_packets_by_apid`). Other sensors use ascending APID.

    Returns a new list.
    """
    sensor = rdr.header.sensor
    # ctypes char arrays are bytes on Python 3; decode so the comparison with
    # a text literal can match there too.
    if isinstance(sensor, bytes):
        sensor = sensor.decode('ascii', 'replace')
    order = None
    if sensor.strip().lower() == 'viirs':
        # range() is not a list on Python 3; materialize before concatenating.
        # NOTE(review): APIDs 825 and >826 are not listed and would make the
        # sort raise -- confirm this matches real VIIRS data.
        order = [826, 821] + list(range(0, 825))
    return sort_packets_by_obs_time(sort_packets_by_apid(packets, order=order))


def convert_to_nasa_l0(sensor, filename, skipfill=False, outfn=None):
    """
    Convert a JPSS RDR to a NASA L0 PDS file.

    For each Common RDR dataset in `filename`, the packets are put in EDOS
    order with `sort_packets_edos` and their raw bytes are appended to the
    output. Ordering is applied within each dataset, not across datasets.

    :param sensor: sensor name, e.g. 'viirs'; selects the science RDR group.
    :param filename: path to the HDF5 RDR file.
    :param skipfill: if True, drop packets whose tracker reports a non-zero
        fill_percent instead of writing them. Such packets are logged at
        DEBUG either way.
    :param outfn: output path; defaults to ``<filename>.pds``.
    :return: the output path.
    """
    outfn = outfn or '{}.pds'.format(filename)
    with open(outfn, 'wb') as fptr:
        for rdr in read_common_rdrs(sensor, filename):
            for packet in sort_packets_edos(rdr, rdr.packets()):
                if packet.tracker.fill_percent != 0:
                    LOG.debug('fill: %s', packet.tracker)
                    if skipfill:
                        continue
                fptr.write(packet.packet)
    return outfn


if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser(
        description='Convert a JPSS Common RDR to a NASA L0 PDS file.')
    parser.add_argument('-o', '--output', help='output file (default: <rdr>.pds)')
    parser.add_argument('-f', '--skipfill', action='store_true',
                        help='drop packets whose tracker reports fill')
    parser.add_argument('sensor', choices=('viirs', 'atms', 'cris'))
    parser.add_argument('rdr', help='input RDR HDF5 file')
    args = parser.parse_args()

    logging.basicConfig(level=logging.DEBUG)
    # --skipfill used to be parsed but never forwarded, so the flag did nothing.
    convert_to_nasa_l0(args.sensor, args.rdr, skipfill=args.skipfill, outfn=args.output)