From d1f3336224fa487d31a861c94017e9f555e76968 Mon Sep 17 00:00:00 2001
From: Max Drexler <mndrexler@wisc.edu>
Date: Mon, 8 Jul 2024 16:50:38 +0000
Subject: [PATCH] new cli options, formatting

---
 grib_processor.py | 244 +++++++++++++++++++++++++++++++++-------------
 1 file changed, 177 insertions(+), 67 deletions(-)

diff --git a/grib_processor.py b/grib_processor.py
index b75c21b..dc4f2ef 100644
--- a/grib_processor.py
+++ b/grib_processor.py
@@ -4,8 +4,8 @@ Ingest grib files and publish metadata for files to RabbitMQ.
 
 from __future__ import annotations
 
-__author__ = 'Max Drexler'
-__email__ = 'mndrexler@wisc.edu'
+__author__ = "Max Drexler"
+__email__ = "mndrexler@wisc.edu"
 
 
 import argparse
@@ -28,20 +28,35 @@ from dotenv import load_dotenv
 
 
 if sys.version_info < (3, 8):
-    raise SystemError('Python version too low')
+    raise SystemError("Python version too low")
 
 # logging stuff
-LOG = logging.getLogger('grib_ingest')
-LOG_DIR = os.getenv('GRIB_LOG_DIR', os.path.join(os.path.dirname(os.path.abspath(__file__)), 'logs'))
-LOG_LEVELS = [logging.CRITICAL, logging.ERROR, logging.WARNING, logging.INFO,logging.DEBUG]
-LOG_NAME = os.path.splitext(sys.argv[0])[0] + '.log'
+LOG = logging.getLogger("grib_ingest")
+LOG_DIR = os.getenv(
+    "GRIB_LOG_DIR", os.path.join(os.path.dirname(os.path.abspath(__file__)), "logs")
+)
+LOG_LEVELS = [
+    logging.CRITICAL,
+    logging.ERROR,
+    logging.WARNING,
+    logging.INFO,
+    logging.DEBUG,
+]
+LOG_NAME = os.path.splitext(sys.argv[0])[0] + ".log"
 
 
 # Where we're publishing data from
 HOSTNAME = gethostname()
 
 # Where we're publishing data to
-DEFAULT_AMQP_SERVERS = ['mq1.ssec.wisc.edu', 'mq2.ssec.wisc.edu', 'mq3.ssec.wisc.edu']
+DEFAULT_AMQP_SERVERS = ["mq1.ssec.wisc.edu", "mq2.ssec.wisc.edu", "mq3.ssec.wisc.edu"]
+
+# Where on the servers we're putting the data.
+DEFAULT_AMQP_EXCHANGE = "model"
+
+# RabbitMQ default login details (shouldn't be authenticated to publish).
+DEFAULT_AMQP_USER = "guest"
+DEFAULT_AMQP_PASS = "guest"
 
 # Format of the route key to publish to
 ROUTE_KEY_FMT = "{xcd_model}.{xcd_model_id}.realtime.reserved.reserved.2"
@@ -50,10 +65,10 @@ ROUTE_KEY_FMT = "{xcd_model}.{xcd_model_id}.realtime.reserved.reserved.2"
 GRIB_ENDINGS = [".grb", ".grib", ".grb2", ".grib2"]
 
 # How long grib files stay in the message cache.
-CACHE_TTL = 43200 # (12 hours)
+CACHE_TTL = 43200  # (12 hours)
 
 # How often to clean the cache of previous messages.
-CACHE_CLEAN_INTERVAL = 5400 # (1.5 hours)
+CACHE_CLEAN_INTERVAL = 5400  # (1.5 hours)
 
 # Table containing XCD model and model ids for grib messages
 # (first_lat, first_lon, rows, cols, gen_proc_id)
@@ -431,45 +446,114 @@ def positive_int(_arg: str) -> int:
     if i <= 0:
         raise ValueError
     return i
-    
+
 
 def setup() -> argparse.Namespace:
-    """Initialize the script so it's ready to run.
-    """
+    """Initialize the script so it's ready to run."""
     # Set up cmdline args
-    parser = argparse.ArgumentParser(description='Set paramters for ingesting and publishing of grib data.')
-    parser.add_argument('grib_dir', help='Directory to watch for new grib files, use -- to read file paths from stdin.')
-    parser.add_argument('-v', '--verbosity', type=non_negative_int, default=3, help='Specify verbosity to stderr from 0 to 5.')
-    parser.add_argument('--watch-debounce', type=non_negative_int, default=5000, help='maximum time in milliseconds to group grib file changes over before processing them.')
-    parser.add_argument('--log-dir', default=LOG_DIR, help='Directory to put rotating file logs into.')
-    parser.add_argument('--servers', nargs='+', default=DEFAULT_AMQP_SERVERS, help='Servers to publish AMQP messages to.')
-    parser.add_argument('--clean-interval', metavar='SECONDS', type=positive_int, default=CACHE_CLEAN_INTERVAL, help='How often to clean the message cache (the default is recommended).')
-    parser.add_argument('--cache-ttl', metavar='SECONDS', type=positive_int, default=CACHE_TTL, help='How long paths stay in the cache for (the default is recommended).')
-        
+    parser = argparse.ArgumentParser(
+        description="Set parameters for ingesting and publishing of grib data."
+    )
+    parser.add_argument(
+        "grib_dir",
+        help="Directory to watch for new grib files, use -- to read file paths from stdin.",
+    )
+    parser.add_argument(
+        "-v",
+        "--verbosity",
+        type=non_negative_int,
+        default=3,
+        help="Specify verbosity to stderr from 0 (CRITICAL) to 5 (DEBUG).",
+    )
+    parser.add_argument(
+        "--watch-debounce",
+        type=non_negative_int,
+        default=5000,
+        help="Maximum time in milliseconds to group grib file changes over before processing them.",
+    )
+    parser.add_argument(
+        "--log-dir", default=LOG_DIR, help="Directory to put rotating file logs into."
+    )
+
+    ch_group = parser.add_argument_group(
+        "cache arguments",
+        "Specify how the cache will behave (the defaults work best for standard usage).",
+    )
+    ch_group.add_argument(
+        "--no-cache",
+        action="store_true",
+        help="Turn the cache off. Could be useful if watched files are only written to once.",
+    )
+    ch_group.add_argument(
+        "--clean-interval",
+        metavar="SECONDS",
+        type=positive_int,
+        default=CACHE_CLEAN_INTERVAL,
+        help="How often to clean the message cache.",
+    )
+    ch_group.add_argument(
+        "--cache-ttl",
+        metavar="SECONDS",
+        type=positive_int,
+        default=CACHE_TTL,
+        help="How long paths stay in the cache for.",
+    )
+
+    mq_group = parser.add_argument_group(
+        "RabbitMQ arguments", "arguments to specify how to publish the data."
+    )
+    mq_group.add_argument(
+        "--servers",
+        nargs="+",
+        default=DEFAULT_AMQP_SERVERS,
+        help="Servers to publish AMQP messages to. Defaults to %(default)s",
+    )
+    mq_group.add_argument(
+        "-X",
+        "--exchange",
+        default=DEFAULT_AMQP_EXCHANGE,
+        help='RabbitMQ exchange to publish the messages to. Defaults to "%(default)s"',
+    )
+    mq_group.add_argument(
+        "-u",
+        "--user",
+        default=None,
+        help="RabbitMQ user to login with. Can also specify using a .env file, see README#configuration.",
+    )
+    mq_group.add_argument(
+        "-p",
+        "--password",
+        default=None,
+        help="RabbitMQ password to login with. Can also specify using a .env file, see README#configuration.",
+    )
+
     args = parser.parse_args()
-    
+
     # Verify parsed arguments
     if not os.path.isdir(args.grib_dir):
-        if args.grib_dir == '--':
-            raise NotImplementedError('Reading from stdin is not yet implemented ¯\_(ツ)_/¯')
+        if args.grib_dir == "--":
+            raise NotImplementedError(
+                "Reading from stdin is not yet implemented ¯\\_(ツ)_/¯"
+            )
         else:
-            parser.error('{0} is not a valid directory!'.format(args.grib_dir))
+            parser.error("{0} is not a valid directory!".format(args.grib_dir))
+
+    if args.no_cache:
+        raise NotImplementedError(
+            "Turning the cache off is not yet implemented ¯\\_(ツ)_/¯"
+        )
 
     if not os.path.isdir(args.log_dir):
         os.mkdir(args.log_dir)
 
-    if not load_dotenv():
-        raise SystemError("Couldn't find .env file with RabbitMQ credentials!")
-
-    if os.getenv('RMQ_USER') is None:
-        raise ValueError("Specify the RabbitMQ user in your .env file with RMQ_USER!")
-    if os.getenv('RMQ_PASS') is None:
-        raise ValueError("Specify the RabbitMQ password in your .env file with RMQ_PASS!")
-
     # Set up logging
-    log_formatter = logging.Formatter(fmt='[%(asctime)s][%(levelname)s]-%(message)s', datefmt='%Y-%m-%dT%H:%M:%S')
-   
-    file_handler = logging.handlers.TimedRotatingFileHandler(os.path.join(args.log_dir, LOG_NAME), when='D', utc=True)
+    log_formatter = logging.Formatter(
+        fmt="[%(asctime)s][%(levelname)s]-%(message)s", datefmt="%Y-%m-%dT%H:%M:%S"
+    )
+
+    file_handler = logging.handlers.TimedRotatingFileHandler(
+        os.path.join(args.log_dir, LOG_NAME), when="D", utc=True
+    )
     file_handler.setFormatter(log_formatter)
     file_handler.setLevel(logging.DEBUG)
 
@@ -480,30 +564,36 @@ def setup() -> argparse.Namespace:
     LOG.addHandler(file_handler)
     LOG.addHandler(out_handler)
     LOG.setLevel(logging.DEBUG)
-    
+
     return args
 
 
 def _cache_entry() -> tuple[float, int]:
-    """default cache entry (current time, first grib message).
-    """
+    """default cache entry (current time, first grib message)."""
     return time.time(), 0
 
 
-def _cache_cleaner(cache: DefaultDict[str, tuple[float, int]], interval: float, ttl: float) -> None:
+def _cache_cleaner(
+    cache: DefaultDict[str, tuple[float, int]], interval: float, ttl: float
+) -> None:
     """Called from another thread, continually cleans cache in order to
     prevent memory leaks.
     """
     while True:
         to_remove = []
-        LOG.debug('Cleaning the message cache, size=%d!', len(cache))
+        LOG.debug("Cleaning the message cache, size=%d!", len(cache))
         for path, (last_update_time, last_msg_num) in cache.items():
             time_in_cache = time.time() - last_update_time
             if time_in_cache > ttl:
-                LOG.info('Evicted %s from cache! (last message: %d, time since last update: %f).', path, last_msg_num - 1, time_in_cache)
+                LOG.info(
+                    "Evicted %s from cache! (last message: %d, time since last update: %f).",
+                    path,
+                    last_msg_num - 1,
+                    time_in_cache,
+                )
                 to_remove.append(path)
         if not to_remove:
-            LOG.debug('Nothing to remove from cache...')
+            LOG.debug("Nothing to remove from cache...")
         else:
             for rem_path in to_remove:
                 cache.pop(rem_path, None)
@@ -514,8 +604,7 @@ def _cache_cleaner(cache: DefaultDict[str, tuple[float, int]], interval: float,
 
 
 def watch_filter(change: Change, path: str) -> bool:
-    """Make sure files we ingest are grib files.
-    """
+    """Make sure files we ingest are grib files."""
     if change == Change.deleted:
         return False
     if Path(path).suffix in GRIB_ENDINGS:
@@ -526,7 +615,7 @@ def watch_filter(change: Change, path: str) -> bool:
 def format_dt(dt: datetime) -> str:
     """Format datetimes in a manner consistent with other AMQP publishers at
     the SSEC.
-    
+
     Args:
         dt (datetime): the datetime object to format.
 
@@ -553,7 +642,10 @@ def amqp_grib(grib_path: str, start_message: int) -> Generator[dict, None, None]
             rows = msg.ny
             cols = msg.nx
             gen_proc = msg.generatingProcess
-            xcd_info = XCD_MODELS.get((str(f_lat), str(f_lon), str(rows), str(cols), str(gen_proc.value)), ['UKNWN', 'UKNWN'])
+            xcd_info = XCD_MODELS.get(
+                (str(f_lat), str(f_lon), str(rows), str(cols), str(gen_proc.value)),
+                ["UKNWN", "UKNWN"],
+            )
             yield {
                 "__payload_gen_time__": format_dt(datetime.now()),
                 "path": grib_path,
@@ -568,6 +660,7 @@ def amqp_grib(grib_path: str, start_message: int) -> Generator[dict, None, None]
                 "forecast_hour": int(msg.valueOfForecastTime),
                 "run_hour": int(msg.hour) * 100 + int(msg.minute),
                 "model_time": format_dt(msg.refDate),
+                "start_time": format_dt(msg.refDate),
                 "projection": msg.gridDefinitionTemplateNumber.definition,
                 "center_id": int(msg.originatingCenter.value),
                 "center_desc": msg.originatingCenter.definition,
@@ -588,7 +681,9 @@ def amqp_grib(grib_path: str, start_message: int) -> Generator[dict, None, None]
             }
 
 
-def gribs_from_dir(watch_dir: str, cache: DefaultDict[str, tuple[int, float]], watch_debounce: int) -> None:
+def gribs_from_dir(
+    watch_dir: str, cache: DefaultDict[str, tuple[float, int]], watch_debounce: int
+) -> None:
     """Process and publish grib files by watching a directory.
 
     Args:
@@ -596,40 +691,55 @@ def gribs_from_dir(watch_dir: str, cache: DefaultDict[str, tuple[int, float]], w
         cache (DefaultDict[str, tuple[int, float]]): The cache to use to store grib messages in (if messages arrive in chunks).
         watch_debounce (int): debounce for watch.
     """
-    LOG.info('Watching %s for gribs with debounce=%d', watch_dir, watch_debounce)
+    LOG.info("Watching %s for gribs with debounce=%d", watch_dir, watch_debounce)
     for changes in watch(watch_dir, watch_filter=watch_filter, debounce=watch_debounce):
         for _, path in changes:
             next_msg = cache[path][1]
-            LOG.debug('Got grib file %s (next msg to process: %d).', path, next_msg)
-            msg_num = 0 # declare here to avoid unbound errors.
+            LOG.debug("Got grib file %s (next msg to process: %d).", path, next_msg)
+            msg_num = 0  # declare here to avoid unbound errors.
             try:
                 for msg_num, amqp in enumerate(amqp_grib(path, next_msg)):
-                    route_key = ROUTE_KEY_FMT.format(xcd_model=amqp['xcd_model'], xcd_model_id=amqp['xcd_model_id'])
-                    LOG.debug('Publishing %s msg %d to %s', path, next_msg + msg_num, route_key)
+                    route_key = ROUTE_KEY_FMT.format(
+                        xcd_model=amqp["xcd_model"], xcd_model_id=amqp["xcd_model_id"]
+                    )
+                    LOG.debug(
+                        "Publishing %s msg %d to %s",
+                        path,
+                        next_msg + msg_num,
+                        route_key,
+                    )
                     mq.publish(amqp, route_key=route_key)
-            except (ValueError, IndexError, OSError):
+            except (ValueError, KeyError, OSError):
                 LOG.exception("Error parsing grib file %s!", path)
             else:
                 cache[path] = time.time(), next_msg + msg_num + 1
 
 
-def main() -> int | None:
-    """Main method, setup and start processing.
-    """
+def main() -> None:
+    """Main method, setup and start processing."""
     try:
         args = setup()
+
+        if not load_dotenv():
+            LOG.warning("Couldn't find .env file")
+
         mq.connect(
             *args.servers,
-            user=os.environ['RMQ_USER'],
-            password=os.environ['RMQ_PASS'],
-            exchange='model'
+            user=args.user or os.getenv("RMQ_USER", DEFAULT_AMQP_USER),
+            password=args.password or os.getenv("RMQ_PASS", DEFAULT_AMQP_PASS),
+            exchange=args.exchange,
         )
-        LOG.info('Connected to %s', ', '.join(args.servers))
+        LOG.info("Connected to %s", ", ".join(args.servers))
 
         # maybe add loading/storing cache to disk after exits?
-        cache = defaultdict(_cache_entry)
-        LOG.debug('Starting cache cleaner')
-        threading.Thread(name='grib_pipeline_cleaner', target=_cache_cleaner, args=(cache, args.clean_interval, args.cache_ttl), daemon=True).start()
+        cache: DefaultDict[str, tuple[float, int]] = defaultdict(_cache_entry)
+        LOG.debug("Starting cache cleaner")
+        threading.Thread(
+            name="grib_pipeline_cleaner",
+            target=_cache_cleaner,
+            args=(cache, args.clean_interval, args.cache_ttl),
+            daemon=True,
+        ).start()
     except KeyboardInterrupt:
         mq.disconnect()
         return
@@ -637,9 +747,9 @@ def main() -> int | None:
     try:
         gribs_from_dir(args.grib_dir, cache, args.watch_debounce)
     except KeyboardInterrupt:
-        LOG.critical('Got interrupt, goodbye!')
+        LOG.critical("Got interrupt, goodbye!")
         mq.disconnect()
-
+    return
 
 if __name__ == "__main__":
     sys.exit(main())
-- 
GitLab