From 9c01ce884889b78a165b10c60cd52611ff8cf83d Mon Sep 17 00:00:00 2001
From: Max Drexler <mndrexler@wisc.edu>
Date: Wed, 17 Jul 2024 20:21:58 +0000
Subject: [PATCH] separate the amqp payload from grib metadata

---
 grib_processor/grib.py | 107 ++++++++++++++++++++++++-----------------
 1 file changed, 64 insertions(+), 43 deletions(-)

diff --git a/grib_processor/grib.py b/grib_processor/grib.py
index c49f8fc..bb2aeae 100644
--- a/grib_processor/grib.py
+++ b/grib_processor/grib.py
@@ -34,16 +34,9 @@ from typing_extensions import Literal, TypedDict
 _XCD_MODELS = None
 
 
-class GribPayload(TypedDict):
-    """Amqp payload for each grib message."""
+class GribMetadata(TypedDict):
+    """Metadata extracted from each grib message."""
 
-    __payload_gen_time__: datetime
-    __injector_script__: str
-    path: str
-    directory: str
-    file_name: str
-    server_ip: str
-    server_type: Literal["realtime", "archive"]
     first_lat: float | None
     first_lon: float | None
     last_lat: float | None
@@ -70,6 +63,22 @@ class GribPayload(TypedDict):
     title: str
 
 
+class GribPayload(GribMetadata):
+    """Amqp payload for each grib message.
+
+    Extends GribMetadata (per-message fields) with fields describing the
+    file and server itself; itergrib() yields one payload per message.
+    """
+
+    __payload_gen_time__: datetime  # NOTE(review): itergrib passes format_datetime() output; confirm type
+    __injector_script__: str  # itergrib passes amqp_utils.INJECTOR_SCRIPT
+    path: str
+    directory: str
+    file_name: str  # itergrib passes os.path.basename(grib_path)
+    server_ip: str  # itergrib passes amqp_utils.SERVER_NAME
+    server_type: Literal["realtime", "archive"]  # itergrib passes "realtime"
+
+
 def load_xcd_models(
     *addtnl_models: tuple[float, float, int, int, int, str, str],
 ) -> None:
@@ -105,6 +114,50 @@ def xcd_lookup(
         return ("UNKWN", "UNKWN")
 
 
+def extract_metadata(msg: grib2io.Grib2Message) -> GribMetadata:
+    """Extract metadata about a grib2 message created using grib2io.
+
+    Returns:
+        GribMetadata dict; grid-point/ensemble attributes absent on msg are None.
+    """
+    f_lat = getattr(msg, "latitudeFirstGridpoint", None)
+    f_lon = getattr(msg, "longitudeFirstGridpoint", None)
+    rows = msg.ny
+    cols = msg.nx
+    gen_proc = msg.generatingProcess
+    xcd_info = (
+        "UNKWN",
+        "UNKWN",
+    )  # xcd_lookup(f_lat, f_lon, rows, cols, gen_proc.value)
+    return GribMetadata(
+        first_lat=f_lat,
+        last_lat=getattr(msg, "latitudeLastGridpoint", None),
+        first_lon=f_lon,
+        last_lon=getattr(msg, "longitudeLastGridpoint", None),
+        forecast_hour=int(msg.valueOfForecastTime),
+        run_hour=int(msg.hour) * 100 + int(msg.minute),
+        model_time=amqp_utils.format_datetime(msg.refDate),
+        start_time=amqp_utils.format_datetime(msg.refDate),
+        projection=msg.gridDefinitionTemplateNumber.definition,
+        center_id=int(msg.originatingCenter.value),
+        center_desc=msg.originatingCenter.definition,
+        level=msg.level,
+        parameter=msg.shortName,
+        param_long=msg.fullName,
+        param_units=msg.units,
+        grib_number=2,
+        size="%d x %d" % (rows, cols),
+        ensemble=getattr(msg, "perturbationNumber", None),
+        model_id=int(gen_proc.value),
+        model=gen_proc.definition,
+        xcd_model=xcd_info[1],
+        xcd_model_id=xcd_info[0],
+        # resolution=msg._xydivisor, THIS IS NOT CORRECT!
+        grib_record_number=msg._msgnum,
+        title=str(msg),
+    )
+
+
 def itergrib(
     grib_path: str, start_message: int | None = None
 ) -> Generator[GribPayload, None, None]:
@@ -120,15 +173,7 @@ def itergrib(
     start_message = start_message or 0
     with grib2io.open(grib_path) as grib_file:
         for msg in grib_file[start_message:]:
-            f_lat = getattr(msg, "latitudeFirstGridpoint", None)
-            f_lon = getattr(msg, "longitudeFirstGridpoint", None)
-            rows = msg.ny
-            cols = msg.nx
-            gen_proc = msg.generatingProcess
-            xcd_info = (
-                "UNKWN",
-                "UNKWN",
-            )  # xcd_lookup(f_lat, f_lon, rows, cols, gen_proc.value)
+            meta = extract_metadata(msg)
             yield GribPayload(
                 __payload_gen_time__=amqp_utils.format_datetime(datetime.now()),
                 __injector_script__=amqp_utils.INJECTOR_SCRIPT,
@@ -137,29 +182,5 @@ def itergrib(
                 file_name=os.path.basename(grib_path),
                 server_ip=amqp_utils.SERVER_NAME,
                 server_type="realtime",
-                first_lat=f_lat,
-                last_lat=getattr(msg, "latitudeLastGridpoint", None),
-                first_lon=f_lon,
-                last_lon=getattr(msg, "longitudeLastGridpoint", None),
-                forecast_hour=int(msg.valueOfForecastTime),
-                run_hour=int(msg.hour) * 100 + int(msg.minute),
-                model_time=amqp_utils.format_datetime(msg.refDate),
-                start_time=amqp_utils.format_datetime(msg.refDate),
-                projection=msg.gridDefinitionTemplateNumber.definition,
-                center_id=int(msg.originatingCenter.value),
-                center_desc=msg.originatingCenter.definition,
-                level=msg.level,
-                parameter=msg.shortName,
-                param_long=msg.fullName,
-                param_units=msg.units,
-                grib_number=2,
-                size="%d x %d" % (rows, cols),
-                ensemble=getattr(msg, "pertubationNumber", None),
-                model_id=int(gen_proc.value),
-                model=gen_proc.definition,
-                xcd_model=xcd_info[1],
-                xcd_model_id=xcd_info[0],
-                # resolution=msg._xydivisor, THIS IS NOT CORRECT!
-                grib_record_number=msg._msgnum,
-                title=str(msg),
+                **meta,
             )
-- 
GitLab