From 82883d61de030828eab7c83c0a0dec38ba8cd2f6 Mon Sep 17 00:00:00 2001
From: Bruce Flynn <brucef@ssec.wisc.edu>
Date: Tue, 20 Oct 2015 14:51:05 +0000
Subject: [PATCH] Fix CLI merge utility.

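Replace the duplicated index/merge logic in cli.py and the merge.py
__main__ block with a single reusable merge.merge() function that
operates on PacketStream objects rather than raw file objects.
read_packet_index() now takes a stream, BasicStream.fobj is renamed to
.file, and PacketStream gains a .file property and a __repr__ so
streams can be logged and their underlying file reached without
touching private state. Sorting is handled by _sort_by_time_apid(),
which sorts by apid (or by a caller-supplied apid order) first and then
by timestamp, relying on Python's stable sort to preserve the apid
order among packets with equal timestamps.

A minimal sketch of the new call path (input/output file names here are
hypothetical):

    import io
    from datetime import datetime
    from edosl0util import merge, stream

    streams = [stream.jpss_packet_stream(io.open(f, 'rb'))
               for f in ('input1.pds', 'input2.pds')]
    with io.open('out.pds', 'wb') as out:
        merge.merge(streams, output=out,
                    trunc_to=(datetime(2015, 10, 20, 0, 0, 0),
                              datetime(2015, 10, 20, 6, 0, 0)))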
---
 edosl0util/cli.py    | 20 ++---------
 edosl0util/merge.py  | 81 +++++++++++++++++++++-----------------------
 edosl0util/stream.py | 15 +++++---
 3 files changed, 52 insertions(+), 64 deletions(-)

diff --git a/edosl0util/cli.py b/edosl0util/cli.py
index c07a779..d8bceef 100644
--- a/edosl0util/cli.py
+++ b/edosl0util/cli.py
@@ -104,23 +104,9 @@ def cmd_merge():
 
     _configure_logging(args)
 
-    index = set()  # will remove duplicates
-    for filepath in args.pds:
-        LOG.debug('indexing %s', filepath)
-        fobj = io.open(filepath, 'rb')
-        index.update(merge.read_packet_index(fobj))
-
-    LOG.info('sorting index with %d pointers', len(index))
-    index = sorted(index)
-    LOG.info('writing index to %s', args.output)
-    with io.open(args.output, 'wb') as fobj:
-        for ptr in index:
-            if args.trunc_to:
-                interval = args.trunc_to
-                if ptr.stamp >= interval[0] and ptr.stamp < interval[1]:
-                    fobj.write(ptr.bytes())
-            else:
-                fobj.write(ptr.bytes())
+    streams = [stream.jpss_packet_stream(io.open(f, 'rb')) for f in args.pds]
+    merge.merge(
+        streams, output=io.open(args.output, 'wb'), trunc_to=args.trunc_to)
 
 
 def cmd_rdr2l0():
diff --git a/edosl0util/merge.py b/edosl0util/merge.py
index dcc67cf..ee68ee7 100644
--- a/edosl0util/merge.py
+++ b/edosl0util/merge.py
@@ -8,15 +8,10 @@
 4. Write
 
 """
-import io
 import os
 import logging
-from datetime import datetime
 from collections import deque
 
-from edosl0util import headers
-from edosl0util.stream import PacketStream, BasicStream
-
 LOG = logging.getLogger(__name__)
 
 
@@ -55,11 +50,8 @@ class _Ptr(object):
         return self.fobj.read(self.size)
 
 
-def read_packet_index(fobj):
-    lookup = headers.jpss_header_lookup
-    stream = PacketStream(BasicStream(fobj, lookup, with_data=False))
+def read_packet_index(stream):
     index = deque()
-
     try:
         # drop any leading hanging packets
         count = 0
@@ -72,7 +64,7 @@ def read_packet_index(fobj):
 
         while True:
             ptr = _Ptr(
-                fobj,
+                stream.file,
                 stamp=packet.stamp,
                 apid=packet.apid,
                 offset=packet.offset,
@@ -92,38 +84,41 @@ def read_packet_index(fobj):
     return index
 
 
-if __name__ == '__main__':
-    import argparse
-    parser = argparse.ArgumentParser()
-    parser.add_argument('-o', '--output', default='out.pds')
-    dtarg = lambda v: datetime.strptime(v, '%Y-%m-%d %H:%M:%S')
-    def interval(v):
-        dt = lambda v: datetime.strptime(v, '%Y-%m-%d %H:%M:%S')
-        return [dt(x) for x in v.split(',')]
-    parser.add_argument(
-        '-t', '--trunc-to', type=interval,
-        help=('Truncate to the interval given as coma separated timestamps of '
-              'the format YYYY-MM-DD HH:MM:SS. The begin time is inclusive, the '
-              'end time is exclusive.'))
-    parser.add_argument('pds', nargs='+')
-    args = parser.parse_args()
-
-    logging.basicConfig(level=logging.DEBUG, format='%(message)s')
+def _sort_by_time_apid(index, order=None):
+    if order:
+        index = sorted(index, key=lambda p: order.index(p.apid))
+    else:
+        index = sorted(index, key=lambda p: p.apid)
+    return sorted(index, key=lambda p: p.stamp)
+
 
+def merge(streams, output, trunc_to=None, apid_order=None):
+    """
+    Merge packets from multiple streams into a single output file. Duplicate
+    packets are dropped and the remaining packets are sorted by time and,
+    optionally, by the provided apid order.
+
+    :param streams: List of `PacketStream`s, such as those returned by
+        `jpss_packet_stream`.
+    :param output: File-like object to write output packets to.
+    :keyword trunc_to: (start, end) datetimes to truncate packets to. Start
+        is inclusive and end is exclusive.
+    :keyword apid_order: List of apids in the order in which they should
+        appear in the output. If provided, ALL apids must be present,
+        otherwise an IndexError will occur.
+    """
     index = set()  # will remove duplicates
-    for filepath in args.pds:
-        LOG.debug('indexing %s', filepath)
-        fobj = io.open(filepath, 'rb')
-        index.update(read_packet_index(fobj))
-
-    LOG.info('sorting index with %d pointers', len(index))
-    index = sorted(index)
-    LOG.info('writing index to %s', args.output)
-    with io.open(args.output, 'wb') as fobj:
-        for ptr in index:
-            if args.trunc_to:
-                interval = args.trunc_to
-                if ptr.stamp >= interval[0] and ptr.stamp < interval[1]:
-                    fobj.write(ptr.bytes())
-            else:
-                fobj.write(ptr.bytes())
+    for stream in streams:
+        LOG.debug('indexing %s', stream)
+        index.update(read_packet_index(stream))
+
+    LOG.debug('sorting index with %d pointers', len(index))
+    index = _sort_by_time_apid(index, order=apid_order)
+
+    LOG.debug('writing index to %s', output)
+    for ptr in index:
+        if trunc_to:
+            if trunc_to[0] <= ptr.stamp < trunc_to[1]:
+                output.write(ptr.bytes())
+        else:
+            output.write(ptr.bytes())
diff --git a/edosl0util/stream.py b/edosl0util/stream.py
index e23225a..2efd489 100644
--- a/edosl0util/stream.py
+++ b/edosl0util/stream.py
@@ -1,4 +1,3 @@
-import io
 import os
 import logging
 import ctypes as c
@@ -44,7 +43,7 @@ class BasicStream(object):
     Tracker = namedtuple('Tracker', ['h1', 'h2', 'size', 'offset', 'data'])
 
     def __init__(self, fobj, header_lookup=None, with_data=True):
-        self.fobj = fobj
+        self.file = fobj
         self.header_lookup = header_lookup
         self.with_data = with_data
 
@@ -55,7 +54,7 @@ class BasicStream(object):
         return self.next()
 
     def _read(self, size):
-        buf = self.fobj.read(size)
+        buf = self.file.read(size)
         if not buf:  # EOF
             raise StopIteration()
         if len(buf) != size:
@@ -78,7 +77,7 @@ class BasicStream(object):
         return H2Impl.from_buffer_copy(buf), h2size
 
     def next(self):
-        offset = self.fobj.tell()
+        offset = self.file.tell()
         h1, h1size = self.read_primary_header()
         h2, h2size = self.read_secondary_header(h1)
         # data length includes h2size
@@ -155,6 +154,10 @@ class PacketStream(object):
                      'num_missing': 0})
         self._fail_on_missing = fail_on_missing
 
+    def __repr__(self):
+        filepath = getattr(self.file, 'name', None)
+        return '<{} file={}>'.format(self.__class__.__name__, filepath)
+
     def __iter__(self):
         return self
 
@@ -162,6 +165,10 @@ class PacketStream(object):
     def __next__(self):
         return self.next()
 
+    @property
+    def file(self):
+        return self._stream.file
+
     def push_back(self, packet):
         self._seek_cache.append(packet)
 
-- 
GitLab