"""
Code for reading/writing/manipulating JPSS Common RDR files as documented in:

    Joint Polar Satellite System (JPSS)
    Common Data Format Control Book - External (CDFCB-X)
    Volume II - RDR Formats

    http://jointmission.gsfc.nasa.gov/sciencedocs/2015-06/474-00001-02_JPSS-CDFCB-X-Vol-II_0123B.pdf
"""
__copyright__ = "Copyright (C) 2015 University of Wisconsin SSEC. All rights reserved."

import os
import logging
import ctypes as c
from collections import namedtuple

import numpy as np
from h5py import File as H5File

from edosl0util.headers import BaseStruct

# Module-level logger, named after this module.
LOG = logging.getLogger(__name__)


class StaticHeader(BaseStruct):
    """
    Common RDR static header.

    Read from the start of each RawApplicationPackets_X buffer (see
    `_rdrs_for_packet_dataset`). The `*_offset` fields are used by this
    module as byte offsets into that same buffer.
    """
    _fields_ = [
        ('satellite', c.c_char * 4),  # 4-byte character field
        ('sensor', c.c_char * 16),  # 16-byte character field
        ('type_id', c.c_char * 16),  # 16-byte character field
        # number of Apid entries read by `_read_apid_list`
        ('num_apids', c.c_uint32),
        # buffer offset of the first Apid entry
        ('apid_list_offset', c.c_uint32),
        # buffer offset of the first PacketTracker entry
        ('pkt_tracker_offset', c.c_uint32),
        # base offset that PacketTracker.offset values are added to
        ('ap_storage_offset', c.c_uint32),
        # not read by the code in this module
        ('next_pkt_pos', c.c_uint32),
        # NOTE(review): presumably the granule start/end boundaries; meaning
        # and units are not established by this module -- confirm against
        # CDFCB-X Vol II.
        ('start_boundry', c.c_uint64),
        ('end_boundry', c.c_uint64),
    ]


class Apid(BaseStruct):
    """
    Entry in the ApidList storage area.

    One entry is read per `StaticHeader.num_apids`, starting at
    `StaticHeader.apid_list_offset` (see `_read_apid_list`).
    """
    _fields_ = [
        ('name', c.c_char * 16),  # 16-byte character field
        # numeric apid; used to key per-apid output in `write_rdr_datasets`
        ('value', c.c_uint32),
        # index (not byte offset) of this apid's first PacketTracker entry
        ('pkt_tracker_start_idx', c.c_uint32),
        # not read by the code in this module
        ('pkts_reserved', c.c_uint32),
        # number of consecutive tracker entries visited for this apid
        ('pkts_received', c.c_uint32),
    ]


class PacketTracker(BaseStruct):
    """
    Entry in the PacketTracker storage area.

    Each entry locates one packet within the application packet storage
    area (see `_packets_for_apid`).
    """
    _fields_ = [
        # sort key used by `sort_packets_by_obs_time`
        # NOTE(review): epoch and units are not established here -- confirm
        ('obs_time', c.c_int64),
        # not read by the code in this module
        ('sequence_number', c.c_int32),
        # total packet size in bytes, including the 6-byte primary header
        ('size', c.c_int32),
        # offset relative to StaticHeader.ap_storage_offset; negative (-1)
        # marks a missing packet
        ('offset', c.c_int32),
        # non-zero values are treated as fill by the readers/writers here
        ('fill_percent', c.c_int32),
    ]


# Generated packet struct classes, keyed by total packet size in bytes.
# Many packets share a size, so this avoids building a new ctypes class for
# every single packet.
_PKT_IMPL_CACHE = {}


def _make_packet_impl(size):
    """
    Return a struct class for a CCSDS packet of `size` total bytes.

    The struct size is dynamic, so a distinct class is needed for each
    distinct packet size. Classes are cached by `size` and reused.

    :param size: Total packet size in bytes, including the 6-byte CCSDS
        primary header.
    :returns: A `BaseStruct` subclass with the primary header fields followed
        by a `data` array of `size - 6` bytes.
    :raises ValueError: If `size` is smaller than the primary header.
    """
    try:
        return _PKT_IMPL_CACHE[size]
    except KeyError:
        pass

    if size < 6:
        # ctypes would also raise ValueError for the negative array length,
        # but with a message that does not mention the packet size.
        raise ValueError(
            'packet size {} is smaller than the 6 byte CCSDS primary header'
            .format(size))

    # size minus the CCSDS primary header size
    data_size = size - 6

    class PktImpl(BaseStruct):  # noqa
        _fields_ = [
            ('version', c.c_uint8, 3),
            ('type', c.c_uint8, 1),
            ('secondary_header', c.c_uint8, 1),
            ('apid', c.c_uint16, 11),
            ('sequence_grouping', c.c_uint8, 2),
            ('sequence_number', c.c_uint16, 14),
            ('length_minus1', c.c_uint16),
            ('data', c.c_uint8 * data_size)
        ]

    _PKT_IMPL_CACHE[size] = PktImpl
    return PktImpl


# Pairs a PacketTracker entry with the packet struct it describes.
Packet = namedtuple('Packet', 'tracker packet')


class CommonRdr(namedtuple('CommonRdr', 'buf header apids')):
    """
    One common RDR taken from a RawApplicationPackets_X dataset.

    `buf` is the raw dataset buffer, `header` the static header read from it
    and `apids` the entries of its apid list.
    """

    def packets_for_apid(self, apid):
        """
        Generate a `Packet` (tracker, packet) for every packet available for
        the given `Apid` entry.
        """
        for pair in _packets_for_apid(self.buf, self.header, apid):
            yield Packet(*pair)

    def packets(self):
        """
        Generate the `Packet`s of every apid, walking the apid list in order.
        """
        for entry in self.apids:
            apid_packets = self.packets_for_apid(entry)
            for item in apid_packets:
                yield item


def _packets_for_apid(buf, header, apid):
    """
    Yield (PacketTracker, packet struct) pairs for `apid`, in tracker order.

    Tracker entries whose offset is negative (-1 == missing) are skipped.
    Both structs are created with `from_buffer`, i.e. on top of `buf`.
    """
    entry_size = c.sizeof(PacketTracker)
    first_entry = header.pkt_tracker_offset + apid.pkt_tracker_start_idx * entry_size
    for n in range(apid.pkts_received):
        tracker = PacketTracker.from_buffer(buf, first_entry + n * entry_size)
        if tracker.offset >= 0:
            # tracker offsets are relative to the packet storage area
            location = header.ap_storage_offset + tracker.offset
            pkt = _make_packet_impl(tracker.size).from_buffer(buf, location)
            assert c.sizeof(pkt) == tracker.size
            yield tracker, pkt


def _read_apid_list(header, buf):
    """
    Generate Apid-s
    """
    offset = header.apid_list_offset
    for idx in range(header.num_apids):
        yield Apid.from_buffer(buf, offset)
        offset += c.sizeof(Apid)


def _sorted_packet_dataset_names(names):
    names = [k for k in names if k.startswith('RawApplicationPackets_')]
    return sorted(names, key=lambda x: int(x.split('_')[-1]))


def _generate_packet_datasets(group):
    """
    Yield (name, numpy array) for each RawApplicationPackets_X dataset in
    `group`, ordered by dataset index.
    """
    ordered = _sorted_packet_dataset_names(group.keys())
    for dsname in ordered:
        yield dsname, np.array(group[dsname])


def _find_data_group(fobj, name, sensors=['viirs', 'cris', 'atms']):
    for sensor in sensors:
        group = fobj.get('/All_Data/{}-{}-RDR_All'.format(sensor.upper(), name.upper()))
        if group:
            return group


def _find_science_group(fobj):
    """
    Return the SCIENCE RDR group from `fobj` using the default sensor search
    list, or `None` if there is none.
    """
    science_group = _find_data_group(fobj, 'SCIENCE')
    return science_group


def _find_dwell_group(fobj):
    """
    Return the first dwell RDR group found in `fobj`, or `None`.

    The CrIS HSKDWELL, IMDWELL and SSMDWELL groups are tried in that order,
    then the ATMS DWELL group.
    """
    candidates = [(dwell, ['cris']) for dwell in ('HSKDWELL', 'IMDWELL', 'SSMDWELL')]
    candidates.append(('DWELL', ['atms']))
    for dwell_name, dwell_sensors in candidates:
        found = _find_data_group(fobj, dwell_name, sensors=dwell_sensors)
        if found:
            return found
    return None


def _find_diag_group(fobj):
    """
    Return the DIAGNOSTIC RDR group for CrIS or ATMS (tried in that order),
    or `None` if neither is present.
    """
    diag_sensors = ['cris', 'atms']
    return _find_data_group(fobj, 'DIAGNOSTIC', sensors=diag_sensors)


def _find_telemetry_group(fobj):
    """
    Return the TELEMETRY RDR group for CrIS or ATMS (tried in that order),
    or `None` if neither is present.
    """
    telemetry_sensors = ['cris', 'atms']
    return _find_data_group(fobj, 'TELEMETRY', sensors=telemetry_sensors)


def _find_spacecraft_group(fobj):
    """
    Return the SPACECRAFT-DIARY RDR group from `fobj`, or `None` if it is not
    present.
    """
    diary_sensors = ['SPACECRAFT']
    return _find_data_group(fobj, 'DIARY', sensors=diary_sensors)


def _rdrs_for_packet_dataset(group):
    rdrs = []
    if group:
        for name, buf in _generate_packet_datasets(group):
            if buf.shape[0] < c.sizeof(StaticHeader):
                LOG.warn('Not enough bytes for %s static header', name)
                continue
            header = StaticHeader.from_buffer(buf)
            apids = _read_apid_list(header, buf)
            rdrs.append(CommonRdr(buf, header, list(apids)))
    return rdrs


def rdr_datasets(filepath):
    """
    Read all common RDRs from the HDF5 RDR file at `filepath`.

    The file is opened read-only and is closed before returning, including
    when parsing raises. Packet datasets are read into numpy arrays by
    `_generate_packet_datasets` while the file is open.

    :param filepath: Path to the HDF5 RDR file.
    :returns: dict with keys 'telemetry', 'diagnostic', 'dwell', 'science'
        and 'ancillary', each mapping to a (possibly empty) list of
        `CommonRdr`.
    """
    # explicit read-only mode: reading must never create or modify the file
    with H5File(filepath, 'r') as fobj:
        return dict(
            telemetry=_rdrs_for_packet_dataset(_find_telemetry_group(fobj)),
            diagnostic=_rdrs_for_packet_dataset(_find_diag_group(fobj)),
            dwell=_rdrs_for_packet_dataset(_find_dwell_group(fobj)),
            science=_rdrs_for_packet_dataset(_find_science_group(fobj)),
            ancillary=_rdrs_for_packet_dataset(_find_spacecraft_group(fobj)),
        )


def rdr_ccsds_packets(rdr, skipfill=False):
    """
    Generate the CCSDS packet structs of a `CommonRdr` in packet tracker
    order.

    Packets whose tracker has a non-zero `fill_percent` are logged at debug
    level and, when `skipfill` is true, left out.
    """
    for item in rdr.packets():
        is_fill = item.tracker.fill_percent != 0
        if is_fill:
            LOG.debug('fill: %s', item.tracker)
        if not (is_fill and skipfill):
            yield item.packet


def sort_packets_by_obs_time(packets):
    """
    Return a new list of the `Packet`s ordered by their PacketTracker
    obs_time. The input is not modified; packets with equal obs_time keep
    their relative order.
    """
    def obs_time(pkt):
        return pkt.tracker.obs_time

    ordered = list(packets)
    ordered.sort(key=obs_time)
    return ordered


def sort_packets_by_apid(packets, order=None):
    """
    Return a new list of the `Packet`s sorted by apid.

    The input is not modified and the sort is stable.

    :param packets: Iterable of `Packet`.
    :param order: Optional sequence of apids giving the desired apid order.
        When omitted or empty, packets are sorted by ascending apid.
    :returns: New sorted list of packets.
    :raises ValueError: If `order` is given and a packet has an apid that is
        not in it.
    """
    if not order:
        return sorted(packets, key=lambda p: p.packet.apid)

    # Precompute each apid's position once instead of scanning `order` for
    # every packet. The first occurrence wins, matching list.index().
    rank = {}
    for position, apid in enumerate(order):
        rank.setdefault(apid, position)

    def position_of(pkt):
        apid = pkt.packet.apid
        try:
            return rank[apid]
        except KeyError:
            # same exception type list.index() raises for a missing value
            raise ValueError('{!r} is not in list'.format(apid))

    return sorted(packets, key=position_of)


def _write_packets(pkts, dest, skipfill):
    for pkt in pkts:
        if pkt.tracker.fill_percent != 0 and skipfill:
            continue
        dest.write(pkt.packet)


def write_rdr_datasets(filepath, ancillary=False, skipfill=False):
    """
    Dump the packets of every RDR type found in `filepath` to `.pkts` files.

    Output names are built from the basename of `filepath`, so files are
    written relative to the current working directory. Each non-empty RDR
    type is written to '<rdrname>.<type>.pkts'. When `ancillary` is true,
    the 'ancillary' type is instead split into one
    '<rdrname>.ancillary<apid>.pkts' file per apid.

    :param filepath: Path to the HDF5 RDR file.
    :param ancillary: Split ancillary packets into per-apid files.
    :param skipfill: Do not write packets whose tracker reports fill.
    :returns: The dict of RDRs as returned by `rdr_datasets`.
    """
    rdrname = os.path.basename(filepath)
    rdrs = rdr_datasets(filepath)
    # Per-apid ancillary files already (re)created by this call. A file is
    # truncated the first time it is opened and appended to for subsequent
    # granules, so output left over from a previous run is not duplicated.
    started = set()

    for typekey in rdrs:
        if not rdrs[typekey]:
            continue
        if ancillary and typekey == 'ancillary':
            for idx, rdr in enumerate(rdrs['ancillary']):
                packets = {a.value: rdr.packets_for_apid(a)
                           for a in rdr.apids}
                for apid, pkts in packets.items():
                    LOG.debug(
                        'writing ancillary gran %d %s-%s-%s %d',
                        idx, rdr.header.satellite, rdr.header.sensor,
                        rdr.header.type_id, apid)
                    pktfile = '{}.{}{}.pkts'.format(rdrname, typekey, apid)
                    mode = 'ab' if pktfile in started else 'wb'
                    started.add(pktfile)
                    with open(pktfile, mode) as dest:
                        _write_packets(pkts, dest, skipfill)
        else:
            pktfile = '{}.{}.pkts'.format(rdrname, typekey)
            LOG.debug('writing %s', pktfile)
            with open(pktfile, 'wb') as dest:
                for idx, rdr in enumerate(rdrs[typekey]):
                    LOG.debug(
                        '... %s gran %d %s-%s-%s', typekey,
                        idx, rdr.header.satellite, rdr.header.sensor, rdr.header.type_id)
                    _write_packets(rdr.packets(), dest, skipfill)
    return rdrs