From 3983efcb7ca182cf0b6a8bf01784f6b353449851 Mon Sep 17 00:00:00 2001
From: David Hoese <david.hoese@ssec.wisc.edu>
Date: Mon, 6 Mar 2023 20:53:37 -0600
Subject: [PATCH] Fix a few style issues

---
 aosstower/level_b1/quicklook.py           | 137 ++++++++++------------
 aosstower/tests/level_00/test_influxdb.py |  25 +---
 aosstower/tests/level_b1/test_nc.py       |  11 +-
 aosstower/tests/utils.py                  |  12 +-
 pyproject.toml                            |   4 +
 5 files changed, 84 insertions(+), 105 deletions(-)
 mode change 100755 => 100644 aosstower/tests/level_00/test_influxdb.py

diff --git a/aosstower/level_b1/quicklook.py b/aosstower/level_b1/quicklook.py
index b765d7d..7ba728a 100644
--- a/aosstower/level_b1/quicklook.py
+++ b/aosstower/level_b1/quicklook.py
@@ -1,19 +1,19 @@
-import matplotlib
-
-matplotlib.use("agg")
-
 import logging
 import math
 import os
 import sys
 from datetime import datetime, timedelta
+from pathlib import Path
 
+import matplotlib as mpl
 import matplotlib.dates as md
 import matplotlib.pyplot as plt
 import numpy as np
 import pandas as pd
 from netCDF4 import MFDataset, MFTime
 
+mpl.use("agg")
+
 LOG = logging.getLogger(__name__)
 FIGURE_TITLE_SIZE = 13
 TN_SIZE = (1, 1)
@@ -59,31 +59,27 @@ class PlotMaker:
         delta = (end_time - start_time).total_seconds()
         if delta < timedelta(hours=24).total_seconds():
             return start_time.strftime("%Y-%m-%d")
-        else:
-            return f"{start_time:%Y-%m-%d %H:%M} to {end_time:%Y-%m-%d %H:%M}"
-
-    def get_title(self, frame, is_subplot, start_time, end_time):
-        if self._title:
-            title_prefix = "AO&SS Building Tower " if not is_subplot else ""
-            title_name = TITLES.get(self.name, self.name.replace("_", " ").title())
-            unit_str = f"({self.units})" if self.units and is_subplot else ""
-            date_string = self.get_date_string(start_time, end_time)
-            title = self._title.format(
-                title_prefix=title_prefix,
-                title_name=title_name,
-                units=unit_str,
-                date_string=date_string,
-            )
-        else:
-            title = ""
-        return title
+        return f"{start_time:%Y-%m-%d %H:%M} to {end_time:%Y-%m-%d %H:%M}"
+
+    def get_title(self, frame, is_subplot, start_time, end_time):  # noqa: ARG002
+        if not self._title:
+            return ""
+        title_prefix = "AO&SS Building Tower " if not is_subplot else ""
+        title_name = TITLES.get(self.name, self.name.replace("_", " ").title())
+        unit_str = f"({self.units})" if self.units and is_subplot else ""
+        date_string = self.get_date_string(start_time, end_time)
+        return self._title.format(
+            title_prefix=title_prefix,
+            title_name=title_name,
+            units=unit_str,
+            date_string=date_string,
+        )
 
     def get_yticks(self, ymin, ymax, num_plots):
         if ymin == ymax:
             return [ymin, ymin + 0.05, ymin + 0.1]
         delta = math.ceil((ymax - ymin) / num_plots)
-        new_ticks = np.arange(ymin, (ymin + delta * num_plots), delta)
-        return new_ticks
+        return np.arange(ymin, (ymin + delta * num_plots), delta)
 
     def _get_ylabel(self, is_subplot=False):
         y_label = TITLES.get(self.name, self.name.replace("_", " ").title())
@@ -103,8 +99,7 @@ class PlotMaker:
             ax.text(0.008, 0.9, self.units, horizontalalignment="left", va="top", transform=ax.transAxes, size=8)
 
     def _call_plot(self, frame, ax):
-        lines = ax.plot(frame.index, frame, "k")
-        return lines
+        return ax.plot(frame.index, frame, "k")
 
     def _set_ylim(self, frame, ax):
         ymin = np.floor(frame.min().min())
@@ -148,18 +143,17 @@ class PlotMaker:
         xloc = md.AutoDateLocator(minticks=5, maxticks=8, interval_multiples=True)
         xfmt = md.AutoDateFormatter(xloc)
 
-        def _fmt(interval, x, pos=None):
+        def _fmt(interval, x, pos=None):  # noqa: ARG001
             x_num = md.num2date(x).replace(tzinfo=None)
             delta_seconds = (x_num - start_time.replace(hour=0, minute=0, second=0, microsecond=0)).total_seconds()
             num_hours = delta_seconds / 3600.0
             if interval == md.HOURLY:
                 return f"{num_hours:.0f}"
-            elif interval == md.MINUTELY:
+            if interval == md.MINUTELY:
                 num_minutes = delta_seconds / 60.0
                 num_minutes -= int(num_hours) * 60.0
                 return f"{int(num_hours):02.0f}:{num_minutes:02.0f}"
-            else:
-                return x.strftime("{%Y-%m-%d}")
+            return x.strftime("{%Y-%m-%d}")
 
         from functools import partial
 
@@ -212,15 +206,11 @@ class PlotMaker:
         return fig
 
     def create_plot(self, frame, fig, start_time=None, end_time=None, is_subplot=None, shared_x=None, title=None):
-        """:param frame:
-        :param fig:
-        :param is_subplot: None or (num plots, num columns, num_rows)
-        :param shared_x:
-        :return:
-        """
+        """Create series of plots."""
         specific_frame = frame[[x for x in frame.columns if x in self.deps]]
-        if frame.empty or specific_frame.empty or specific_frame.isnull().all().any():
-            raise ValueError(f"No valid data found or missing necessary data to make {self.name}")
+        if frame.empty or specific_frame.empty or specific_frame.isna().all().any():
+            msg = f"No valid data found or missing necessary data to make {self.name}"
+            raise ValueError(msg)
         if start_time is None:
             start_time = frame.index[0].to_pydatetime()
         if end_time is None:
@@ -263,15 +253,14 @@ class TDPlotMaker(PlotMaker):
 
 
 class WindDirPlotMaker(PlotMaker):
-    def _set_ylim(self, frame, ax):
+    def _set_ylim(self, frame, ax):  # noqa: ARG002
         return 0, 360
 
-    def _set_yticks(self, ax, ymin, ymax, is_subplot):
+    def _set_yticks(self, ax, ymin, ymax, is_subplot):  # noqa: ARG002
         ax.yaxis.set_ticks([0, 90, 180, 270])
 
     def _call_plot(self, frame, ax):
-        lines = ax.plot(frame.index, frame, "k.", markersize=3, linewidth=0)
-        return lines
+        return ax.plot(frame.index, frame, "k.", markersize=3, linewidth=0)
 
 
 class MeteorogramPlotMaker(PlotMaker):
@@ -282,8 +271,8 @@ class MeteorogramPlotMaker(PlotMaker):
         super().__init__(name, dependencies, title=title)
 
     def _convert_ax_to_thumbnail(self, fig):
-        if hasattr(fig, "_my_axes"):
-            for k, ax in fig._my_axes.items():
+        if hasattr(fig, "my_axes"):
+            for k, ax in fig.my_axes.items():
                 if k not in self.thumbnail_deps:
                     fig.delaxes(ax)
                     continue
@@ -305,7 +294,8 @@ class MeteorogramPlotMaker(PlotMaker):
 
     def create_plot(self, frame, fig, start_time=None, end_time=None, is_subplot=False, shared_x=None, title=None):
         if is_subplot or shared_x:
-            raise ValueError("Meteorogram Plot can not be a subplot or share X-axis")
+            msg = "Meteorogram Plot can not be a subplot or share X-axis"
+            raise ValueError(msg)
 
         if start_time is None:
             start_time = frame.index[0].to_pydatetime()
@@ -317,8 +307,8 @@ class MeteorogramPlotMaker(PlotMaker):
 
         num_plots = len(self.plot_deps)
         shared_x = None
-        # some hacky book keeping so we can create a thumbnail
-        fig._my_axes = {}
+        # some hacky bookkeeping so we can create a thumbnail
+        fig.my_axes = {}
         for idx, plot_name in enumerate(self.plot_deps):
             plot_maker = PLOT_TYPES.get(plot_name, PlotMaker(plot_name, (plot_name,)))
             title_name = TITLES.get(plot_name, plot_name.replace("_", " ").title())
@@ -329,14 +319,14 @@ class MeteorogramPlotMaker(PlotMaker):
                 shared_x=shared_x,
                 title=title_name,
             )
-            fig._my_axes[plot_name] = ax
+            fig.my_axes[plot_name] = ax
             if idx == 0:
                 shared_x = ax
             if idx != num_plots - 1:
                 # Disable the x-axis ticks so we don't interfere with other subplots
                 kwargs = {"visible": False}
-                for l in ax.get_xticklabels():
-                    l.update(kwargs)
+                for label in ax.get_xticklabels():
+                    label.update(kwargs)
             # make the top y-tick label invisible
 
         ax.set_xlabel("Time (UTC)")
@@ -392,16 +382,15 @@ def get_data(input_files, columns):
 
 
 def create_plot(plot_names, frame, output, start_time=None, end_time=None, thumbnail=False):
-    """Args:
-        plot_names:
-        frame:
-        output:
-        start_time:
-        end_time:
-        daily: Whether or not this plot should represent one day of data.
+    """Create a series of plots.
 
-    Returns
-    -------
+    Args:
+        plot_names (list of str): Plot names to generate
+        frame (pd.DataFrame): DataFrame of data
+        output (str): Output pattern for the created files
+        start_time (datetime): Start time of the data to use
+        end_time (datetime): End time of the data to use
+        thumbnail (bool): Additionally generate a thumbnail
 
     """
     if start_time is None:
@@ -414,7 +403,8 @@ def create_plot(plot_names, frame, output, start_time=None, end_time=None, thumb
         var_names = []
         for var_name in plot_maker.deps:
             if var_name not in frame:
-                raise ValueError(f"Missing required variable '{var_name}' for plot '{name}'")
+                msg = f"Missing required variable '{var_name}' for plot '{name}'"
+                raise ValueError(msg)
             var_names.append(var_name)
             # write NaNs where QC values are not 0
             qc_name = "qc_" + var_name
@@ -424,7 +414,7 @@ def create_plot(plot_names, frame, output, start_time=None, end_time=None, thumb
 
         # create a frame that doesn't include any of the bad values
         plot_frame = frame[var_names]
-        plot_frame = plot_frame[~plot_frame.isnull().any(axis=1)]
+        plot_frame = plot_frame[~plot_frame.isna().any(axis=1)]
 
         fig = plt.figure()
         try:
@@ -438,7 +428,9 @@ def create_plot(plot_names, frame, output, start_time=None, end_time=None, thumb
         fig.savefig(out_fn)
 
         if thumbnail:
-            stem, ext = os.path.splitext(out_fn)
+            out_path = Path(out_fn)
+            stem = out_path.stem
+            ext = out_path.suffix
             out_fn = f"{stem}_thumbnail{ext}"
             plot_maker.convert_to_thumbnail(fig)
             LOG.info(f"Saving thumbnail '{name}' to filename '{out_fn}'")
@@ -468,7 +460,7 @@ def main():
         action="count",
         default=int(os.environ.get("VERBOSITY", 2)),
         dest="verbosity",
-        help=("each occurence increases verbosity 1 level through" + " ERROR-WARNING-INFO-DEBUG (default INFO)"),
+        help=("each occurence increases verbosity 1 level through ERROR-WARNING-INFO-DEBUG (default INFO)"),
     )
     parser.add_argument(
         "-l",
@@ -481,14 +473,14 @@ def main():
         "--start-time",
         type=_dt_convert,
         help="Start time of plot. If only -s is given, a plot of "
-        + "only that day is created. Formats allowed: 'YYYY-MM-DDTHH:MM:SS', 'YYYY-MM-DD'",
+        "only that day is created. Formats allowed: 'YYYY-MM-DDTHH:MM:SS', 'YYYY-MM-DD'",
     )
     parser.add_argument(
         "-e",
         "--end-time",
         type=_dt_convert,
         help="End time of plot. If only -e is given, a plot of only that day is "
-        + "created. Formats allowed: 'YYYY-MM-DDTHH:MM:SS', 'YYYY-MM-DD', 'YYYYMMDD'",
+        "created. Formats allowed: 'YYYY-MM-DDTHH:MM:SS', 'YYYY-MM-DD', 'YYYYMMDD'",
     )
     parser.add_argument("--met-plots", nargs="+", help="Override plots to use in the combined meteorogram plot")
     parser.add_argument("input_files", nargs="+", help="aoss_tower_level_b1 files")
@@ -506,19 +498,21 @@ def main():
         "--daily",
         action="store_true",
         help="creates a plot for every day. Usually used to create plots "
-        + "that will line up for aoss tower quicklooks page",
+        "that will line up for aoss tower quicklooks page",
     )
     args = parser.parse_args()
 
     levels = [logging.ERROR, logging.WARN, logging.INFO, logging.DEBUG]
     setup_logging(args.log_filepath, level=levels[min(3, args.verbosity)])
 
-    if not os.path.splitext(args.output)[-1]:
+    if not Path(args.output).suffix:
         LOG.warning("File pattern provided does not have a file extension")
 
     # check the dependencies for the meteorogram
     if args.met_plots:
-        assert "meteorogram" not in args.met_plots
+        if "meteorogram" in args.met_plots:
+            msg = "The 'meteorogram' plot can not be a sub-plot of itself."
+            raise ValueError(msg)
         PLOT_TYPES["meteorogram"].deps = args.met_plots
 
     plot_deps = [PLOT_TYPES[k].deps if k in PLOT_TYPES else (k,) for k in args.plot_names]
@@ -536,11 +530,8 @@ def main():
         end_time = args.start_time.replace(hour=23, minute=59, second=59, microsecond=999999)
         frame = frame[args.start_time : end_time]
 
-    if not args.daily:
-        # allow plotting methods to write inplace on a copy
-        frames = [frame.copy()]
-    else:
-        frames = (group[1] for group in frame.groupby(frame.index.day))
+    # allow plotting methods to write inplace on a copy
+    frames = [frame.copy()] if not args.daily else (group[1] for group in frame.groupby(frame.index.day))
 
     for frame in frames:
         if args.daily:
diff --git a/aosstower/tests/level_00/test_influxdb.py b/aosstower/tests/level_00/test_influxdb.py
old mode 100755
new mode 100644
index ed7fbde..b82f14a
--- a/aosstower/tests/level_00/test_influxdb.py
+++ b/aosstower/tests/level_00/test_influxdb.py
@@ -1,4 +1,3 @@
-#!/usr/bin/env python
 import datetime
 import unittest
 
@@ -8,9 +7,9 @@ import pandas as pd
 from aosstower.level_00.influxdb import Updater, construct_url, get_url_data
 
 
-class TestCase:
-    def __init__(self, input, expected_avg, expected_url):
-        self.input = input
+class TestDataCase:
+    def __init__(self, input_data, expected_avg, expected_url):
+        self.input_data = input_data
         self.expected_avg = expected_avg
         self.expected_url = expected_url
 
@@ -53,7 +52,7 @@ def _isnan(x):
 class TestInfluxdb(unittest.TestCase):
     def setUp(self):
         self.updater = Updater()
-        self.test_data = TestCase(
+        self.test_data = TestDataCase(
             create_data(209),
             [
                 {
@@ -211,7 +210,7 @@ class TestInfluxdb(unittest.TestCase):
 
     def test_updater(self):
         output = []
-        for record in self.test_data.input:
+        for record in self.test_data.input_data:
             avg = self.updater.rolling_average(record)
             if avg is not None:
                 output.append({key: avg[key][-1] for key in avg})
@@ -225,22 +224,10 @@ class TestInfluxdb(unittest.TestCase):
 
     def test_construct_url(self):
         output = []
-        for record in self.test_data.input:
+        for record in self.test_data.input_data:
             avg = self.updater.rolling_average(record)
             if avg is not None:
                 output.append(construct_url(get_url_data(avg, "", "")))
                 assert len(self.test_data.expected_url) >= len(output)
                 assert self.test_data.expected_url[len(output) - 1] == output[-1]
         assert len(self.test_data.expected_url) == len(output)
-
-
-def suite():
-    """The test suite for influxdb."""
-    loader = unittest.TestLoader()
-    mysuite = unittest.TestSuite()
-    mysuite.addTest(loader.loadTestsFromTestCase(TestInfluxdb))
-    return mysuite
-
-
-if __name__ == "__main__":
-    unittest.main()
diff --git a/aosstower/tests/level_b1/test_nc.py b/aosstower/tests/level_b1/test_nc.py
index 9a75a17..ba70741 100644
--- a/aosstower/tests/level_b1/test_nc.py
+++ b/aosstower/tests/level_b1/test_nc.py
@@ -1,7 +1,5 @@
-#!/usr/bin/env python
 """Test basic NetCDF generation."""
 
-import os
 from datetime import datetime
 
 import numpy as np
@@ -13,11 +11,10 @@ def get_nc_schema_database(fields=None):
 
     if fields is None:
         fields = schema.met_vars
-    mini_database = {k: schema.database_dict[k] for k in fields}
-    return mini_database
+    return {k: schema.database_dict[k] for k in fields}
 
 
-def test_nc_basic1(tmpdir):
+def test_nc_basic1(tmp_path):
     """Test basic usage of the NetCDF generation."""
     from netCDF4 import Dataset
 
@@ -25,7 +22,7 @@ def test_nc_basic1(tmpdir):
     from aosstower.tests.utils import get_cached_level_00
 
     input_files = list(get_cached_level_00(num_files=2))
-    nc_out = tmpdir.join("test.nc")
+    nc_out = tmp_path / "test.nc"
     create_giant_netcdf(
         input_files,
         str(nc_out),
@@ -35,7 +32,7 @@ def test_nc_basic1(tmpdir):
         interval_width="1T",
         database=get_nc_schema_database(),
     )
-    assert os.path.isfile(nc_out)
+    assert nc_out.is_file()
     with Dataset(nc_out, "r") as nc:
         sflux = nc["solar_flux"][:]
         assert np.count_nonzero(sflux.mask) == 2
diff --git a/aosstower/tests/utils.py b/aosstower/tests/utils.py
index f43dbf8..3294819 100644
--- a/aosstower/tests/utils.py
+++ b/aosstower/tests/utils.py
@@ -1,10 +1,9 @@
-#!/usr/bin/env python
 """Utilities for running tests."""
 
-import os
 import urllib.request
+from pathlib import Path
 
-CACHE_DIR = os.path.join(os.path.dirname(os.path.realpath(__file__)), "test_data")
+CACHE_DIR = Path(__file__).resolve().parent / "test_data"
 
 
 def get_cached_level_00(num_files=2):
@@ -14,11 +13,12 @@ def get_cached_level_00(num_files=2):
         "https://metobs-test.ssec.wisc.edu/pub/cache/aoss/tower/level_00/version_00/2020/01/02/aoss_tower.2020-01-02.ascii",
     )
     if num_files > len(file_urls):
-        raise ValueError(f"Maximum of {len(file_urls)} can be loaded.")
+        msg = f"Maximum of {len(file_urls)} can be loaded."
+        raise ValueError(msg)
     file_urls = file_urls[:num_files]
 
     for u in file_urls:
-        fn = os.path.join(CACHE_DIR, os.path.basename(u))
-        if not os.path.isfile(fn):
+        fn = CACHE_DIR / u.rsplit("/", 1)[-1]
+        if not fn.is_file():
             urllib.request.urlretrieve(u, fn)
         yield fn
diff --git a/pyproject.toml b/pyproject.toml
index 9a9d4ff..20dc9ef 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -38,6 +38,10 @@ ignore = ["D100", "D101", "D102", "D103", "D104", "D105", "D106", "D107", "D203"
 
 [tool.ruff.per-file-ignores]
 "aosstower/tests/*" = ["S", "PLR2004"]
+"aosstower/level_b1/quicklook.py" = ["PLR0913"]
+
+[tool.ruff.pydocstyle]
+convention = "google"
 
 [tool.mypy]
 python_version = "3.10"
-- 
GitLab