From 9770e71d4e103fc4e71010363a7d50d9838fcd55 Mon Sep 17 00:00:00 2001
From: Max Drexler <mndrexler@wisc.edu>
Date: Wed, 17 Jul 2024 18:04:24 +0000
Subject: [PATCH] formatting

---
 grib_processor/__init__.py | 12 +++---
 grib_processor/grib.py     | 36 +++++++++---------
 grib_processor/main.py     | 77 ++++++++++++++++++++++++--------------
 grib_processor/utils.py    | 63 ++++++++++++++++++-------------
 4 files changed, 110 insertions(+), 78 deletions(-)

diff --git a/grib_processor/__init__.py b/grib_processor/__init__.py
index 8ed7d8d..82394af 100644
--- a/grib_processor/__init__.py
+++ b/grib_processor/__init__.py
@@ -12,12 +12,12 @@ from grib_processor.utils import grib_file_watch
 
 __author__ = "Max Drexler"
 __email__ = "mndrexler@wisc.edu"
-__version__ = '0.0.1'
+__version__ = "0.0.1"
 
 __all__ = [
-    'GribMessage',
-    'itergrib',
-    'dump_message',
-    'publish_message',
-    'grib_file_watch',
+    "GribMessage",
+    "itergrib",
+    "dump_message",
+    "publish_message",
+    "grib_file_watch",
 ]
diff --git a/grib_processor/grib.py b/grib_processor/grib.py
index de3dc58..bdf0dc2 100644
--- a/grib_processor/grib.py
+++ b/grib_processor/grib.py
@@ -18,13 +18,13 @@ from typing_extensions import Literal, TypedDict
 
 # Contains a serializable mapping of first_lat, first_lon, rows, cols,
 # and generating_process_ids to xcd model names and ids.
-# 
-# format: 
+#
+# format:
 
-# dict[first_lat, 
-#   dict[first_lon, 
-#       dict[rows, 
-#           dict[cols, 
+# dict[first_lat,
+#   dict[first_lon,
+#       dict[rows,
+#           dict[cols,
 #               dict[generating_process_id, list[str]]
 #           ]
 #       ]
@@ -35,15 +35,15 @@ _XCD_MODELS = None
 
 
 class GribMessage(TypedDict):
-    """Amqp payload for each grib message.
-    """
+    """Amqp payload for each grib message."""
+
     __payload_gen_time__: datetime
     __injector_script__: str
     path: str
     directory: str
     file_name: str
     server_ip: str
-    server_type: Literal['realtime', 'archive']
+    server_type: Literal["realtime", "archive"]
     first_lat: float | None
     first_lon: float | None
     last_lat: float | None
@@ -68,12 +68,12 @@ class GribMessage(TypedDict):
     xcd_model_id: str
     grib_record_number: int
     title: str
-    
 
 
-def load_xcd_models(*addtnl_models: tuple[float, float, int, int, int, str, str]) -> None:
-    """Load the xcd models from the package.
-    """
+def load_xcd_models(
+    *addtnl_models: tuple[float, float, int, int, int, str, str],
+) -> None:
+    """Load the xcd models from the package."""
     pass
 
 
@@ -105,7 +105,9 @@ def xcd_lookup(
         return ("UNKWN", "UNKWN")
 
 
-def itergrib(grib_path: str, start_message: int | None = None) -> Generator[GribMessage, None, None]:
+def itergrib(
+    grib_path: str, start_message: int | None = None
+) -> Generator[GribMessage, None, None]:
     """Generate AMQP payloads from a grib file, starting at a specific message.
 
     Args:
@@ -132,7 +134,7 @@ def itergrib(grib_path: str, start_message: int | None = None) -> Generator[Grib
                 file_name=os.path.basename(grib_path),
                 server_ip=amqp_utils.SERVER_NAME,
                 server_type="realtime",
-                first_lat= f_lat,
+                first_lat=f_lat,
                 last_lat=getattr(msg, "latitudeLastGridpoint", None),
                 first_lon=f_lon,
                 last_long=getattr(msg, "longitudeLastGridpoint", None),
@@ -154,7 +156,7 @@ def itergrib(grib_path: str, start_message: int | None = None) -> Generator[Grib
                 model=gen_proc.definition,
                 xcd_model=xcd_info[1],
                 xcd_model_id=xcd_info[0],
-                #resolution=msg._xydivisor, THIS IS NOT CORRECT!
+                # resolution=msg._xydivisor, THIS IS NOT CORRECT!
                 grib_record_number=msg._msgnum,
                 title=str(msg),
-            )
\ No newline at end of file
+            )
diff --git a/grib_processor/main.py b/grib_processor/main.py
index 75923e4..9f53fbe 100644
--- a/grib_processor/main.py
+++ b/grib_processor/main.py
@@ -10,7 +10,6 @@ from __future__ import annotations
 
 import argparse
 from itertools import chain
-import json
 import logging
 import logging.handlers
 import os
@@ -23,11 +22,19 @@ from dotenv import load_dotenv
 import ssec_amqp.api as mq
 
 from grib_processor.grib import GribMessage, itergrib
-from grib_processor.utils import dump_message, grib_file_watch, publish_message, realtime, DEFAULT_WATCH_DEBOUNCE, REALTIME_WATCH_DEBOUNCE, signalcontext
+from grib_processor.utils import (
+    dump_message,
+    grib_file_watch,
+    publish_message,
+    realtime,
+    DEFAULT_WATCH_DEBOUNCE,
+    REALTIME_WATCH_DEBOUNCE,
+    signalcontext,
+)
 
 
 # Logging stuff
-LOG = logging.getLogger('grib_ingest')
+LOG = logging.getLogger("grib_ingest")
 LOG_LEVELS = [
     logging.CRITICAL,
     logging.ERROR,
@@ -35,7 +42,7 @@ LOG_LEVELS = [
     logging.INFO,
     logging.DEBUG,
 ]
-LOG_NAME = 'grib_processor.log'
+LOG_NAME = "grib_processor.log"
 
 # Where we're publishing data to
 DEFAULT_AMQP_SERVERS = ["mq1.ssec.wisc.edu", "mq2.ssec.wisc.edu", "mq3.ssec.wisc.edu"]
@@ -47,7 +54,7 @@ DEFAULT_AMQP_EXCHANGE = "model"
 def initialize_logging(verbosity: int, rotating_dir: str | None = None) -> None:
     """Set up logging for the package. Don't use basicConfig so we don't
     override other's logs.
-    
+
     Optionally, can specify an existing directory to put rotating log files into.
     """
     # Set up logging
@@ -79,7 +86,7 @@ def setup() -> tuple[Generator[GribMessage, None, None], Callable[[GribMessage],
     )
     parser.add_argument(
         "grib_src",
-        metavar='grib_source',
+        metavar="grib_source",
         help="Directory to watch for new grib files, file to read grib paths from, or '--' to read file paths from stdin.",
     )
     parser.add_argument(
@@ -90,29 +97,35 @@ def setup() -> tuple[Generator[GribMessage, None, None], Callable[[GribMessage],
         help="Specify verbosity to stderr from 0 (CRITICAL) to 5 (DEBUG).",
     )
     parser.add_argument(
-        "--log-dir", default=None, help="Set up a rotating file logger in this directory."
+        "--log-dir",
+        default=None,
+        help="Set up a rotating file logger in this directory.",
     )
-    
+
     parser.add_argument(
-        '-R', '--realtime',
-        action='store_true',
-        help='Turns on parameters for more efficient realtime processing like caching and debounce. Also adds more robust error handling.'
+        "-R",
+        "--realtime",
+        action="store_true",
+        help="Turns on parameters for more efficient realtime processing like caching and debounce. Also adds more robust error handling.",
     )
     parser.add_argument(
-        '-r', '--recursive',
-        action='store_true',
-        help='When using a directory as the source, process all files in the directory recursively instead of setting up a watch.'
+        "-r",
+        "--recursive",
+        action="store_true",
+        help="When using a directory as the source, process all files in the directory recursively instead of setting up a watch.",
     )
-    
+
     parser.add_argument(
-        '-o', '--output',
-        choices=['amqp', 'json'],
-        default='amqp',
-        help='Where to output processed grib messages to. Default is %(default)s.'
+        "-o",
+        "--output",
+        choices=["amqp", "json"],
+        default="amqp",
+        help="Where to output processed grib messages to. Default is %(default)s.",
     )
 
     mq_group = parser.add_argument_group(
-        "RabbitMQ arguments", "Options for how to output grib messages when using '-o=amqp'."
+        "RabbitMQ arguments",
+        "Options for how to output grib messages when using '-o=amqp'.",
     )
     mq_group.add_argument(
         "--servers",
@@ -147,34 +160,40 @@ def setup() -> tuple[Generator[GribMessage, None, None], Callable[[GribMessage],
 
     # Get an iterator over grib files to process
     if os.path.isdir(args.grib_src):
-        LOG.info("Sourcing grib files from directory watch @ %s realtime=%s", args.grib_src, args.realtime)
+        LOG.info(
+            "Sourcing grib files from directory watch @ %s realtime=%s",
+            args.grib_src,
+            args.realtime,
+        )
         if not args.realtime:
-            warnings.warn('Processing files from a directory without using --realtime!')
+            warnings.warn("Processing files from a directory without using --realtime!")
             debounce = DEFAULT_WATCH_DEBOUNCE
         else:
             debounce = REALTIME_WATCH_DEBOUNCE
 
-        file_iter = grib_file_watch(args.grib_src, debounce=debounce, recursive=args.recursive)
+        file_iter = grib_file_watch(
+            args.grib_src, debounce=debounce, recursive=args.recursive
+        )
     elif os.path.isfile(args.grib_src):
         LOG.info("Sourcing grib file directly from CLI, %s", args.grib_src)
         file_iter = (args.grib_src,)
     elif args.grib_src == "--":
         LOG.info("Sourcing grib files from stdin, realtime=%s", args.realtime)
-        file_iter = map(str.strip, iter(sys.stdin.readline, ''))
+        file_iter = map(str.strip, iter(sys.stdin.readline, ""))
     else:
         parser.error("{0} is not a valid source!".format(args.grib_src))
-    
+
     # Get an iterator of grib messages to emit
     if args.realtime:
         source = realtime(file_iter)
     else:
         source = chain.from_iterable(map(itergrib, file_iter))
-    
+
     # Possible to source connection parameters from the environment
     if not load_dotenv():
         LOG.warning("Couldn't find .env file")
 
-    if args.output == 'amqp':
+    if args.output == "amqp":
         LOG.info("Emitting messages to amqp")
         # will automatically disconnect on exit
         mq.connect(
@@ -188,7 +207,7 @@ def setup() -> tuple[Generator[GribMessage, None, None], Callable[[GribMessage],
     else:
         LOG.info("Emitting messages to stdout")
         emit = dump_message
-        
+
     return source, emit
 
 
@@ -207,7 +226,7 @@ def main() -> None:
     with signalcontext(signal.SIGTERM, __raise_interrupt):
         try:
             for grib_msg in source:
-                LOG.debug('Got %s from source', grib_msg)
+                LOG.debug("Got %s from source", grib_msg)
                 emit(grib_msg)
         except KeyboardInterrupt:
             LOG.critical("Got interrupt, goodbye!")
diff --git a/grib_processor/utils.py b/grib_processor/utils.py
index 54faacf..afe015a 100644
--- a/grib_processor/utils.py
+++ b/grib_processor/utils.py
@@ -19,12 +19,13 @@ from typing import Callable, DefaultDict, Generator, Iterator, Tuple
 
 from watchfiles import watch, Change
 from typing_extensions import TypeAlias
+import ssec_amqp.api as mq
 
 from grib_processor.grib import GribMessage, itergrib
 
 
 # Our logger
-LOG = logging.getLogger('grib_ingest')
+LOG = logging.getLogger("grib_ingest")
 
 # How long grib files stay in the message cache.
 CACHE_TTL = 43200  # (12 hours)
@@ -41,13 +42,15 @@ GRIB_ENDINGS = [".grb", ".grib", ".grb2", ".grib2"]
 DEFAULT_WATCH_DEBOUNCE = 1600
 
 # debounce when doing realtime processing
-REALTIME_WATCH_DEBOUNCE = 4500 
+REALTIME_WATCH_DEBOUNCE = 4500
 
 # Format of the route key to messages publish with
 GRIB_ROUTE_KEY_FMT = "{xcd_model}.{xcd_model_id}.realtime.reserved.reserved.2"
 
 # Cache path
-REALTIME_CACHE_PATH = os.path.join(os.path.expanduser('~/.cache/'), 'grib_processor_cache.json')
+REALTIME_CACHE_PATH = os.path.join(
+    os.path.expanduser("~/.cache/"), "grib_processor_cache.json"
+)
 
 # Types used for realtime caching
 _CEntry: TypeAlias = Tuple[float, int]
@@ -84,13 +87,13 @@ def _initialize_cache() -> _Cache:
     """
     LOG.info("Loading cache from disk")
     try:
-        with open(REALTIME_CACHE_PATH, 'r', encoding='utf-8') as cache_file:
+        with open(REALTIME_CACHE_PATH, "r", encoding="utf-8") as cache_file:
             cache_data = json.load(cache_file)
             LOG.info("realtime - cache loaded from disk at %s", REALTIME_CACHE_PATH)
     except (OSError, json.JSONDecodeError):
         LOG.exception("Error loading previous cache data from %s", REALTIME_CACHE_PATH)
         cache_data = {}
-        
+
     return defaultdict(_default_cache_entry, cache_data)
 
 
@@ -105,7 +108,7 @@ def _save_cache(cache: _Cache) -> None:
         LOG.debug("realtime - cache is empty")
         return
     try:
-        with open(REALTIME_CACHE_PATH, 'w', encoding='utf-8') as cache_file:
+        with open(REALTIME_CACHE_PATH, "w", encoding="utf-8") as cache_file:
             json.dump(cache, cache_file)
             LOG.info("realtime - cache saved to disk at %s", REALTIME_CACHE_PATH)
     except (OSError, ValueError):
@@ -145,7 +148,7 @@ def _clean_caches(realtime_cache: _Cache) -> None:
 
 def realtime(file_iter: Iterator[str]) -> Generator[GribMessage, None, None]:
     """Add message caching and exception handling to a grib file iterator. This
-    is useful if the iterator returns files that aren't fully written to or 
+    is useful if the iterator returns files that aren't fully written to or
     duplicates.
 
     Args:
@@ -160,27 +163,31 @@ def realtime(file_iter: Iterator[str]) -> Generator[GribMessage, None, None]:
     atexit.register(_save_cache, cache)
 
     for file in file_iter:
-        LOG.info('realtime - got file %s', file)
+        LOG.info("realtime - got file %s", file)
         next_msg = cache.get(file, (None, None))[1]
         if next_msg is None:
             LOG.debug("realtime - cache miss %s", file)
             cache[file] = _default_cache_entry()
             next_msg = 0
         else:
-            LOG.debug('realtime - cache hit %s @ %d', file, next_msg)
-        msg_num = None    
+            LOG.debug("realtime - cache hit %s @ %d", file, next_msg)
+        msg_num = None
         try:
-            for msg_num, msg in enumerate(itergrib(file, start_message=next_msg), next_msg):
-                LOG.debug('realtime - got msg %s:%d', file, msg_num)
+            for msg_num, msg in enumerate(
+                itergrib(file, start_message=next_msg), next_msg
+            ):
+                LOG.debug("realtime - got msg %s:%d", file, msg_num)
                 yield msg
         except (KeyError, OSError):
-            LOG.exception("realtime - Error processing grib file %s @ msg %d", file, msg_num or 0)
+            LOG.exception(
+                "realtime - Error processing grib file %s @ msg %d", file, msg_num or 0
+            )
 
-        if msg_num is not None: # msg_num is None when exception or end of grib file
+        if msg_num is not None:  # msg_num is None when exception or end of grib file
             cache[file] = time.time(), msg_num + 1
-        
+
         if time.time() - last_clean_time > CACHE_CLEAN_INTERVAL:
-            LOG.info('realtime - time to clean the cache')
+            LOG.info("realtime - time to clean the cache")
             _clean_caches(cache)
 
 
@@ -201,14 +208,19 @@ def grib_file_watch(
 
     Args:
         watch_dir (str): The directory to watch for grib files.
-        debounce (int | None): How long to wait to group file changes together before yielding. 
+        debounce (int | None): How long to wait to group file changes together before yielding.
         recursive (bool): Watch directories under watch_dir recursively?
-    
+
     Returns:
         Generator[str, None, None]: A grib file path generator.
     """
     debounce = debounce or DEFAULT_WATCH_DEBOUNCE
-    LOG.info("Watching directory %s for gribs with debounce=%d, recursive=%s", watch_dir, debounce, recursive)
+    LOG.info(
+        "Watching directory %s for gribs with debounce=%d, recursive=%s",
+        watch_dir,
+        debounce,
+        recursive,
+    )
     for changes in watch(
         watch_dir,
         watch_filter=grib_filter,
@@ -220,17 +232,16 @@ def grib_file_watch(
 
 
 def publish_message(msg: GribMessage) -> None:
-    """Publish a message to RabbitMQ servers using the ssec_amqp package. 
+    """Publish a message to RabbitMQ servers using the ssec_amqp package.
     This requires previous setup using ssec_amqp.api.connect().
     """
-    route_key = GRIB_ROUTE_KEY_FMT.format(xcd_model=msg.get('xcd_model'), xcd_model_id=msg.get('xcd_model_id'))
+    route_key = GRIB_ROUTE_KEY_FMT.format(
+        xcd_model=msg.get("xcd_model"), xcd_model_id=msg.get("xcd_model_id")
+    )
     status = mq.publish(msg, route_key=route_key)
-    LOG.info('Published to %s. status: %s', route_key, status)
+    LOG.info("Published to %s. status: %s", route_key, status)
 
 
 def dump_message(msg: GribMessage) -> None:
-    """Print a message to stdout.
-    """
+    """Print a message to stdout."""
     print(json.dumps(msg), flush=True)
-
-
-- 
GitLab