From 122cf7658f2404264ee149a4cf472576f17cfe2c Mon Sep 17 00:00:00 2001
From: Max Drexler <mndrexler@wisc.edu>
Date: Wed, 3 Jul 2024 21:53:42 +0000
Subject: [PATCH] rename processor, add cache cleaning

---
 grib_pipeline.py => grib_processor.py | 237 +++++++++++++++++---------
 1 file changed, 152 insertions(+), 85 deletions(-)
 rename grib_pipeline.py => grib_processor.py (87%)

diff --git a/grib_pipeline.py b/grib_processor.py
similarity index 87%
rename from grib_pipeline.py
rename to grib_processor.py
index aeae1c3..3ed14e4 100644
--- a/grib_pipeline.py
+++ b/grib_processor.py
@@ -11,7 +11,6 @@ __email__ = 'mndrexler@wisc.edu'
 import argparse
 from collections import defaultdict
 from datetime import datetime
-import json
 import os
 from pathlib import Path
 from socket import gethostname
@@ -19,7 +18,8 @@ import logging
 import logging.handlers
 import sys
 import time
-from typing import Generator
+import threading
+from typing import DefaultDict, Generator
 
 import grib2io
 from watchfiles import watch, Change
@@ -49,6 +49,12 @@ ROUTE_KEY_FMT = "{xcd_model}.{xcd_model_id}.realtime.reserved.reserved.2"
 # What file extensions to parse as grib files
 GRIB_ENDINGS = [".grb", ".grib", ".grb2", ".grib2"]
 
+# How long grib files stay in the message cache.
+CACHE_TTL = 43200 # (12 hours)
+
+# How often to clean the cache of previous messages.
+CACHE_CLEAN_INTERVAL = 5400 # (1.5 hours)
+
 # Table containing XCD model and model ids for grib messages
 # (first_lat, first_lon, rows, cols, gen_proc_id)
 XCD_MODELS = {
@@ -412,6 +418,101 @@ XCD_MODELS = {
 }
 
 
+# argparse types
+def non_negative_int(_arg: str) -> int:
+    i = int(_arg)
+    if i < 0:
+        raise ValueError
+    return i
+
+
+def positive_int(_arg: str) -> int:
+    i = int(_arg)
+    if i <= 0:
+        raise ValueError
+    return i
+    
+
+def setup() -> argparse.Namespace:
+    """Initialize the script so it's ready to run.
+    """
+    # Set up cmdline args
+    parser = argparse.ArgumentParser(description='Set parameters for ingesting and publishing of grib data.')
+    parser.add_argument('grib_dir', help='Directory to watch for new grib files, use -- to read file paths from stdin.')
+    parser.add_argument('-v', '--verbosity', type=non_negative_int, default=3, help='Specify verbosity to stderr from 0 to 5.')
+    parser.add_argument('--watch-debounce', type=non_negative_int, default=5000, help='maximum time in milliseconds to group grib file changes over before processing them.')
+    parser.add_argument('--log-dir', default=LOG_DIR, help='Directory to put rotating file logs into.')
+    parser.add_argument('--servers', nargs='+', default=DEFAULT_AMQP_SERVERS, help='Servers to publish AMQP messages to.')
+    parser.add_argument('--clean-interval', metavar='SECONDS', type=positive_int, default=CACHE_CLEAN_INTERVAL, help='How often to clean the message cache (the default is recommended).')
+    parser.add_argument('--cache-ttl', metavar='SECONDS', type=positive_int, default=CACHE_TTL, help='How long paths stay in the cache for (the default is recommended).')
+        
+    args = parser.parse_args()
+    
+    # Verify parsed arguments
+    if not os.path.isdir(args.grib_dir):
+        if args.grib_dir == '--':
+            raise NotImplementedError('Reading from stdin is not yet implemented ¯\\_(ツ)_/¯')
+        else:
+            parser.error('{0} is not a valid directory!'.format(args.grib_dir))
+
+    if not os.path.isdir(args.log_dir):
+        os.mkdir(args.log_dir)
+
+    if not load_dotenv():
+        raise SystemError("Couldn't find .env file with RabbitMQ credentials!")
+
+    if os.getenv('RMQ_USER') is None:
+        raise ValueError("Specify the RabbitMQ user in your .env file with RMQ_USER!")
+    if os.getenv('RMQ_PASS') is None:
+        raise ValueError("Specify the RabbitMQ password in your .env file with RMQ_PASS!")
+
+    # Set up logging
+    log_formatter = logging.Formatter(fmt='[%(asctime)s][%(levelname)s]-%(message)s', datefmt='%Y-%m-%dT%H:%M:%S')
+   
+    file_handler = logging.handlers.TimedRotatingFileHandler(os.path.join(args.log_dir, LOG_NAME), when='D', utc=True)
+    file_handler.setFormatter(log_formatter)
+    file_handler.setLevel(logging.DEBUG)
+
+    out_handler = logging.StreamHandler()
+    out_handler.setFormatter(log_formatter)
+    out_handler.setLevel(LOG_LEVELS[min(args.verbosity, len(LOG_LEVELS) - 1)])
+
+    LOG.addHandler(file_handler)
+    LOG.addHandler(out_handler)
+    LOG.setLevel(logging.DEBUG)
+    
+    return args
+
+
+def _cache_entry() -> tuple[float, int]:
+    """default cache entry (current time, first grib message).
+    """
+    return time.time(), 0
+
+
+def _cache_cleaner(cache: DefaultDict[str, tuple[float, int]], interval: float, ttl: float) -> None:
+    """Called from another thread, continually cleans cache in order to
+    prevent memory leaks.
+    """
+    while True:
+        to_remove = []
+        LOG.debug('Cleaning the message cache, size=%d!', len(cache))
+        for path, (last_update_time, last_msg_num) in list(cache.items()):
+            time_in_cache = time.time() - last_update_time
+            if time_in_cache > ttl:
+                LOG.info('Evicted %s from cache! (last message: %d, time since last update: %f).', path, last_msg_num - 1, time_in_cache)
+                to_remove.append(path)
+        if not to_remove:
+            LOG.debug('Nothing to remove from cache...')
+        else:
+            for rem_path in to_remove:
+                cache.pop(rem_path, None)
+            # LOG.debug('Cleaning grib2io caches')
+            # grib2io._grib2io._msg_class_store.clear()
+            # grib2io._grib2io._latlon_datastore.clear()
+        time.sleep(interval)
+
+
 def watch_filter(change: Change, path: str) -> bool:
     """Make sure files we ingest are grib files.
     """
@@ -445,7 +546,7 @@ def amqp_grib(grib_path: str, start_message: int) -> Generator[dict, None, None]
     Yields:
         JSON-able: The json formatted AMQP payload for each grib message from start_message to end of the file.
     """
-    with grib2io.open(grib_path) as grib_file: # Grib2Message._msg_class_store might lead to memory leakage
+    with grib2io.open(grib_path) as grib_file:
         for msg in grib_file[start_message:]:
             f_lat = getattr(msg, "latitudeFirstGridpoint", None)
             f_lon = getattr(msg, "longitudeFirstGridpoint", None)
@@ -476,7 +577,7 @@ def amqp_grib(grib_path: str, start_message: int) -> Generator[dict, None, None]
                 "param_units": msg.units,
                 "grib_number": 2,
                 "size": "%d x %d" % (rows, cols),
-                "ensemble": getattr(msg, "pertubatyionNumber", None),
+                "ensemble": getattr(msg, "pertubationNumber", None),
                 "model_id": int(gen_proc.value),
                 "model": gen_proc.definition,
                 "xcd_model": xcd_info[1],
@@ -487,90 +588,56 @@ def amqp_grib(grib_path: str, start_message: int) -> Generator[dict, None, None]
             }
 
 
-def init() -> argparse.Namespace:
-    """Initialize the script so it's ready to run.
-    """
-    # Set up cmdline args
-    parser = argparse.ArgumentParser(description='Set paramters for ingesting and publishing of grib data.')
-    parser.add_argument('grib_dir', help='Directory to watch for new grib files, use -- to read file paths from stdin.')
-    parser.add_argument('-v', '--verbosity', type=int, default=3, help='Specify verbosity to stderr from 0 to 5.')
-    parser.add_argument('--watch-debounce', type=int, default=5000, help='maximum time in milliseconds to group grib file changes over before processing them.')
-    parser.add_argument('--log-dir', default=LOG_DIR, help='Directory to put rotating file logs into.')
-    parser.add_argument('--servers', nargs='+', default=DEFAULT_AMQP_SERVERS, help='Servers to publish AMQP messages to.')
-    parser.add_argument('--no-monitor', action='store_true', help="Toggle off the monitoring process that collects data on how the main process is running.")
-        
-    args = parser.parse_args()
-    
-    # Verify parsed arguments
-    if not os.path.isdir(args.grib_dir):
-        if args.grib_dir == '--':
-            raise NotImplementedError('Reading from stdin is not yet implemented ¯\_(ツ)_/¯')
-        else:
-            raise FileNotFoundError(args.grib_dir)
-
-    if not os.path.isdir(args.log_dir):
-        os.mkdir(args.log_dir)
-
-    if args.watch_debounce < 0:
-        raise ValueError("--watch-debounce can't be negative!")
-
-    if args.verbosity < 0:
-        raise ValueError("--verbosity can't be negative!")
-
-    if not load_dotenv():
-        raise SystemError("Couldn't find .env file with RabbitMQ credentials!")
-
-    if os.getenv('RMQ_USER') is None:
-        raise ValueError("Specify the RabbitMQ user in your .env file with RMQ_USER!")
-    if os.getenv('RMQ_PASS') is None:
-        raise ValueError("Specify the RabbitMQ password in your .env file with RMQ_PASS!")
+def gribs_from_dir(watch_dir: str, cache: DefaultDict[str, tuple[float, int]], watch_debounce: int) -> None:
+    """Process and publish grib files by watching a directory.
 
-
-    # Set up logging
-    log_formatter = logging.Formatter(fmt='[%(asctime)s][%(levelname)s]-%(message)s', datefmt='%Y-%m-%dT%H:%M:%S')
-   
-    file_handler = logging.handlers.TimedRotatingFileHandler(os.path.join(args.log_dir, LOG_NAME), when='D', utc=True)
-    file_handler.setFormatter(log_formatter)
-    file_handler.setLevel(logging.DEBUG)
-
-    out_handler = logging.StreamHandler()
-    out_handler.setFormatter(log_formatter)
-    out_handler.setLevel(LOG_LEVELS[min(args.verbosity, len(LOG_LEVELS) - 1)])
-
-    LOG.addHandler(file_handler)
-    LOG.addHandler(out_handler)
-    LOG.setLevel(logging.DEBUG)
-    
-    return args
-    
-
-def main(args: argparse.Namespace):
-    LOG.info('Starting grib_pipeline')
-    mq.connect(
-        *args.servers,
-        user=os.environ['RMQ_USER'],
-        password=os.environ['RMQ_PASS'],
-        exchange='model'
-    )
-    LOG.info('Connected to %s', ', '.join(args.servers))
-
-    mem = defaultdict(int) # this will cause a memory leak without cleaning old data.
-    LOG.info('Starting watch on %s with debounce=%d', args.grib_dir, args.watch_debounce)
-    for changes in watch(args.grib_dir, watch_filter=watch_filter, debounce=args.watch_debounce):
+    Args:
+        watch_dir (str): The directory to watch for grib files.
+        cache (DefaultDict[str, tuple[float, int]]): The cache to use to store grib messages in (if messages arrive in chunks).
+        watch_debounce (int): debounce for watch.
+    """
+    LOG.info('Watching %s for gribs with debounce=%d', watch_dir, watch_debounce)
+    for changes in watch(watch_dir, watch_filter=watch_filter, debounce=watch_debounce):
         for _, path in changes:
-            i = 0
-            next_msg = mem[path]
-            LOG.debug('Got grib file %s, already processed up to %d messages.', path, next_msg)
-            for i, amqp in enumerate(amqp_grib(path, next_msg)):
-                route_key = ROUTE_KEY_FMT.format(xcd_model=amqp['xcd_model'], xcd_model_id=amqp['xcd_model_id'])
-                LOG.debug('Publishing %s msg %d to %s', path, next_msg + i, route_key)
-                mq.publish(amqp, route_key=route_key)
+            next_msg = cache[path][1]
+            LOG.debug('Got grib file %s (next msg to process: %d).', path, next_msg)
+            msg_num = -1 # start at -1 so zero yielded messages leaves the cache position unchanged.
+            try:
+                for msg_num, amqp in enumerate(amqp_grib(path, next_msg)):
+                    route_key = ROUTE_KEY_FMT.format(xcd_model=amqp['xcd_model'], xcd_model_id=amqp['xcd_model_id'])
+                    LOG.debug('Publishing %s msg %d to %s', path, next_msg + msg_num, route_key)
+                    mq.publish(amqp, route_key=route_key)
+            except (ValueError, IndexError, OSError):
+                LOG.exception("Error parsing grib file %s!", path)
+            else:
+                cache[path] = time.time(), next_msg + msg_num + 1
+
+
+def main() -> int | None:
+    """Main method, setup and start processing.
+    """
+    try:
+        args = setup()
+        mq.connect(
+            *args.servers,
+            user=os.environ['RMQ_USER'],
+            password=os.environ['RMQ_PASS'],
+            exchange='model'
+        )
+        LOG.info('Connected to %s', ', '.join(args.servers))
+    except KeyboardInterrupt:
+        return
 
+    # maybe add loading/storing cache to disk after exits?
+    cache = defaultdict(_cache_entry)
+    LOG.debug('Starting cache cleaner')
+    threading.Thread(target=_cache_cleaner, args=(cache, args.clean_interval, args.cache_ttl), daemon=True).start()
 
-if __name__ == "__main__":
     try:
-        args = init()
-        main(args)
+        gribs_from_dir(args.grib_dir, cache, args.watch_debounce)
     except KeyboardInterrupt:
-        LOG.info('Got interrupt, goodbye')        
-    sys.exit(0)
+        LOG.critical('Got interrupt, goodbye!')
+
+
+if __name__ == "__main__":
+    sys.exit(main())
-- 
GitLab