"""Generate quicklook images for the ceilometer instrument."""

import datetime
import logging
import os
import pathlib
import sys
from calendar import timegm
from functools import reduce
from typing import Iterable, Iterator

# Must set backend for matplotlib before importing pylab
import matplotlib as mpl
import numpy as np
import numpy.typing as npt
from netCDF4 import Dataset

mpl.use("agg")
import matplotlib.pyplot as plt

# Module-level logger. Handlers and levels are left to the caller; main()
# calls logging.basicConfig(level=logging.INFO) when run as a script.
LOG = logging.getLogger(__name__)


def get_img_filename(begin, end, site="", description="", tag=""):
    """Generate the filename for a quicklook image.

    The name has the form ``{site}_ceilo[_{description}]{stamp}[_{tag}].png``,
    with site and description lower-cased. ``stamp`` is:

    * ``.YYYY-mm-dd`` when ``end`` is falsy,
    * ``.YYYY-mm-dd_HHMMSS_HHMMSS`` when ``begin`` and ``end`` fall on the
      same calendar day,
    * ``.YYYY-mm-dd_HHMMSS.YYYY-mm-dd_HHMMSS`` otherwise.

    Args:
        begin: Start of the plotted period.
        end: End of the plotted period, or a falsy value for a date-only name.
        site: Site name used as the filename prefix.
        description: Optional description appended after the instrument name.
        tag: Optional tag appended just before the extension (e.g. ``"tn"``).

    Returns:
        The filename (no directory component).

    Raises:
        ValueError: If ``end`` precedes ``begin``.
    """
    inst = "ceilo"
    ext = "png"
    suffix = (f"_{tag}" if tag else "") + "." + ext
    descr = f"_{description}" if description else ""
    # Only the date parts go through strftime. Putting site/description/tag
    # into the format string would expand any '%' directive they contain.
    prefix = f"{site.lower()}_{inst}{descr.lower()}"

    if not end:
        return prefix + begin.strftime(".%Y-%m-%d") + suffix

    if end < begin:
        raise ValueError("Ending date cannot precede beginning date")

    stamp_fmt = ".%Y-%m-%d_%H%M%S"
    start_stamp = begin.strftime(stamp_fmt)

    if (begin.year, begin.month, begin.day) == (end.year, end.month, end.day):
        # Same day: the end only needs its time of day.
        return prefix + start_stamp + end.strftime("_%H%M%S") + suffix

    # Different days: the full end timestamp is appended.
    return prefix + start_stamp + end.strftime(stamp_fmt) + suffix


def make_plot(
    ncfiles: list[pathlib.Path],
    quality: int = 1,
    site: str = "",
    description: str = "",
    start_date: datetime.datetime | None = None,
    end_date: datetime.datetime | None = None,
    daily: bool = False,
    out_dir: pathlib.Path | None = None,
) -> list[pathlib.Path]:
    """Create quicklook plots for provided NetCDF data files.

    Args:
        ncfiles: NetCDF files to read. Files that do not exist are skipped
            with a warning.
        quality: Plot every ``quality``-th sample (1 plots every sample).
        site: Site name used as the filename prefix.
        description: Optional description inserted into the filenames.
        start_date: Start of the plotted period. Naive datetimes are
            interpreted as UTC. Defaults to the time of the earliest sample.
        end_date: Last day to plot. Samples up to 24 hours after this time
            are kept. Naive datetimes are interpreted as UTC.
        daily: Write one thumbnail/quicklook pair per UTC day.
        out_dir: Output directory. Defaults to the current working directory.

    Returns:
        Paths of all image files created.

    Raises:
        SystemExit: If none of ``ncfiles`` could be opened, or no samples fall
            inside the requested period.
    """
    if not out_dir:
        out_dir = pathlib.Path.cwd()

    datasets = tuple(_get_datasets(ncfiles))
    if not datasets:
        LOG.error("No valid nc files found")
        sys.exit(1)

    # Close the files even if reading a variable fails. Nothing below needs
    # the open datasets once these arrays have been read.
    try:
        all_x = _combine(_get_data(datasets, ("time_offset", "base_time")))
        # get order to sort all other data
        order = all_x.argsort()
        all_x = all_x[order]

        all_z = _combine(_get_data(datasets, ("backscatter",)))[order]
        all_fcbh = _combine(_get_data(datasets, ("first_cbh",)))[order]
        all_scbh = _combine(_get_data(datasets, ("second_cbh",)))[order]
        all_tcbh = _combine(_get_data(datasets, ("third_cbh",)))[order]
    finally:
        for d in datasets:
            d.close()

    # get first day base_time
    if not start_date:
        base_time = all_x[0]
        dt = datetime.datetime.fromtimestamp(base_time.item(), tz=datetime.timezone.utc)
    else:
        # Normalise to an aware UTC datetime. Daily plotting subtracts
        # UTC-aware day boundaries from dt, which fails for naive values.
        if start_date.tzinfo is None:
            dt = start_date.replace(tzinfo=datetime.timezone.utc)
        else:
            dt = start_date.astimezone(datetime.timezone.utc)
        base_time = timegm(dt.utctimetuple())

    # start first record 0 seconds from "base_time"
    all_x -= base_time

    # crop data with mask after start
    mask = all_x >= 0

    # crop again to 24 hours after end_date kwarg
    if end_date:
        end_limit = end_date + datetime.timedelta(days=1)
        mask &= all_x <= (timegm(end_limit.utctimetuple()) - base_time)

    if not mask.any():
        LOG.error("No data found in the requested date range")
        sys.exit(1)

    x, y, z, fcbh, scbh, tcbh = _prepare_data(
        all_x[mask], all_z[mask], all_fcbh[mask], all_scbh[mask], all_tcbh[mask]
    )

    try:
        created_files = _plot_quicklook(
            x,
            y,
            z,
            fcbh,
            scbh,
            tcbh,
            dt,
            quality,
            out_dir,
            site=site,
            description=description,
            daily=daily,
        )
    except Exception:
        LOG.exception("Error in matplotlib or pylab during plotting")
        raise
    finally:
        plt.close("all")
    return created_files


def _get_datasets(files):
    """Yield an open netCDF4 ``Dataset`` for each path in *files* that exists.

    Paths that raise ``FileNotFoundError`` are logged as a warning and
    skipped. Any other error raised while opening a file propagates. The
    caller is responsible for closing the yielded datasets.
    """
    for fn in files:
        try:
            ds = Dataset(fn)
        except FileNotFoundError:
            LOG.warning("%s not found, skipping quicklook", os.path.basename(fn))
            continue
        yield ds


def _get_data(
    datasets: tuple[Dataset, ...], vs: tuple[str, ...]
) -> Iterator[npt.ArrayLike]:
    """Yield, for each dataset, the element-wise sum of the variables named in *vs*.

    With a single name this is just that variable's values. The only
    multi-name use in this module is ``("time_offset", "base_time")``,
    where the two variables are added together.
    """

    def _add(left, right):
        return left + right

    for ds in datasets:
        values = [ds.variables[name][:] for name in vs]
        yield reduce(_add, values)


def _combine(arrs: Iterable[np.typing.ArrayLike]) -> np.ndarray:
    return np.concatenate(list(arrs), axis=0)


def _prepare_data(x, z, fcbh, scbh, tcbh):
    y = np.arange(0, z[0, :].size * 30, 30) / 1000.0
    z = z.transpose()
    x = x / 3600.0

    zshape = np.shape(z)
    z = np.concatenate(z)
    indp = np.nonzero(np.greater(z, 0))[0]
    z[indp] = abs(np.log10(z[indp]))
    z = np.reshape(z, zshape)
    z = np.ma.masked_where(z <= 0, z)
    fcbh = np.ma.masked_where(fcbh == -9999, fcbh)
    fcbh = fcbh / 1000.0
    scbh = np.ma.masked_where(scbh == -9999, scbh)
    scbh = scbh / 1000.0
    tcbh = np.ma.masked_where(tcbh == -9999, tcbh)
    tcbh = tcbh / 1000.0

    return x, y, z, fcbh, scbh, tcbh


def _plot_quicklook(
    x: np.ndarray,
    y: np.ndarray,
    z: np.ndarray,
    fcbh: np.ndarray,
    scbh: np.ndarray,
    tcbh: np.ndarray,
    dt: datetime.datetime,
    quality: int,
    image_dir: pathlib.Path,
    site: str = "",
    description: str = "",
    daily: bool = False,
    end: datetime.datetime | None = None,
) -> list[pathlib.Path]:
    """Render a thumbnail and a full-size quicklook image of the data.

    Args:
        x: Sample times in hours relative to ``dt``.
        y: Vertical coordinate for each row of ``z`` (plotted as altitude in km).
        z: 2-D backscatter array indexed ``[row, sample]``.
        fcbh: First cloud-base height per sample (may be masked).
        scbh: Second cloud-base height per sample (may be masked).
        tcbh: Third cloud-base height per sample (may be masked).
        dt: Start time of the plot; used in the filenames and the title.
        quality: Plot every ``quality``-th sample.
        image_dir: Directory the images are written to (created if missing).
        site: Site name passed to ``get_img_filename``.
        description: Description passed to ``get_img_filename``.
        daily: If true, delegate to ``_plot_daily`` and write one image pair
            per UTC day instead of a single pair.
        end: End time used in the filenames. Defaults to ``dt`` plus the time
            of the last sample.

    Returns:
        Paths of the created images: thumbnail then quicklook, for each plot.
    """
    if daily:
        daily_plots = _plot_daily(
            dt,
            x,
            y,
            z,
            fcbh,
            scbh,
            tcbh,
            quality,
            image_dir,
            site,
            description,
        )
        # flatten the per-day lists of filenames
        return [fn for created_filenames in daily_plots for fn in created_filenames]

    # generate thumbnail
    fig = plt.figure(figsize=(1, 1))
    spt = plt.subplot(1, 1, 1, facecolor="black")

    if not end:
        # float() so numpy scalars of any float width are accepted by timedelta
        end = dt + datetime.timedelta(hours=float(x[-1]))

    # thumbnail uses the whole canvas, no margins
    plt.subplots_adjust(left=0, right=1, bottom=0, top=1)

    plt.pcolor(
        x[0::quality],
        y,
        z[:, 0::quality],
        cmap=plt.get_cmap("plasma"),
        shading="auto",
        vmin=0,
        vmax=10,
    )
    # Cloud-base heights drawn as marker-only series (linewidth=0) over the
    # backscatter image, in first/second/third order.
    for heights, color in ((fcbh, "red"), (scbh, "#003300"), (tcbh, "purple")):
        plt.plot(
            x[0::quality],
            heights[0::quality],
            marker="o",
            color=color,
            markeredgecolor="white",
            linewidth=0,
            markersize=3,
        )
    # x axis shows at least 24 hours; y axis fixed to 0-7
    plt.axis((0, max(x[-1], 24), 0, 7))

    # hide y tick marks on the thumbnail; restored for the full quicklook
    tl = spt.yaxis.get_ticklines()
    for t in tl:
        t.set_visible(False)

    image_dir.mkdir(parents=True, exist_ok=True)
    filename = get_img_filename(dt, end, tag="tn", site=site, description=description)
    tn_file = image_dir / filename
    plt.savefig(tn_file)
    LOG.info("Ceilometer Backscatter Thumbnail saved as %s", tn_file)

    # generate quicklook on the same figure
    # reset to default margins
    plt.subplots_adjust(
        left=mpl.rcParams["figure.subplot.left"],
        right=mpl.rcParams["figure.subplot.right"],
        bottom=mpl.rcParams["figure.subplot.bottom"],
        top=mpl.rcParams["figure.subplot.top"],
    )
    for t in tl:
        t.set_visible(True)
    # reset to default size
    fig.set_size_inches(mpl.rcParams["figure.figsize"])
    cbar = plt.colorbar()
    cbar.set_label("1.0E⁻⁷ m⁻¹·sr⁻¹", rotation=-90)
    plt.ylabel("Altitude (km)")
    plt.xlabel("Time (UTC)")
    title = f"AO&SS Ceilometer Backscatter {dt.strftime('%m/%d/%Y')}"
    plt.title(title)

    filename = get_img_filename(dt, end, site=site, description=description)
    ql_file = image_dir / filename
    plt.savefig(ql_file)
    LOG.info("Ceilometer Backscatter Quicklook saved as %s", ql_file)
    plt.close(fig=fig)

    return [tn_file, ql_file]


def _day_generator(s):
    now = datetime.datetime(
        year=s.year, month=s.month, day=s.day, tzinfo=datetime.timezone.utc
    )
    while True:
        end = now + datetime.timedelta(days=1)
        yield now, end - datetime.timedelta(seconds=1)
        now = end


def _plot_daily(dt, x, y, z, fcbh, scbh, tcbh, quality, image_dir, site, description):
    """Plot each UTC calendar day of the data separately.

    ``x`` holds sample times in hours relative to ``dt``. For every day from
    the day of ``dt`` up to the day of the last sample that contains at least
    one sample, this yields the list of image paths returned by
    ``_plot_quicklook``. Each day's times are re-based to hours since that
    day's midnight. Days without samples (data gaps) are skipped.
    """
    if np.size(x) == 0:
        return
    if dt.tzinfo is None:
        # _day_generator yields UTC-aware datetimes; subtracting them from a
        # naive dt raises TypeError, so treat a naive dt as UTC.
        dt = dt.replace(tzinfo=datetime.timezone.utc)

    one_hour = datetime.timedelta(hours=1)
    one_day = datetime.timedelta(days=1)
    last_sample = np.nanmax(x)

    for start, end in _day_generator(dt):
        first_hour = (start - dt) / one_hour
        # Stop once the day starts after the final sample. The ``not <=``
        # form also ends the loop if last_sample is NaN.
        if not first_hour <= last_sample:
            break
        # Use the next midnight as the exclusive upper bound. ``end`` is
        # 23:59:59, so bounding by it would drop the day's last-second samples.
        next_hour = (start + one_day - dt) / one_hour
        mask = (x >= first_hour) & (x < next_hour)
        if not mask.any():
            # data gap: skip this day but keep plotting later days
            continue
        yield _plot_quicklook(
            x[mask] - first_hour,
            y,
            z[:, mask],
            fcbh[mask],
            scbh[mask],
            tcbh[mask],
            start,
            quality,
            image_dir,
            site,
            description,
            end=end,
            daily=False,
        )


def main():
    """Parse command-line arguments and create the requested quicklooks.

    By default, a full-size quicklook plot and a title-less thumbnail of the
    same plot are written for all input files together. ``--daily`` writes
    one pair per UTC day instead.

    Usage::

        quicklook_ceilo.py [-h] [--quality N] [-l LOCKFILE]
                           [--description TEXT] [--site SITE]
                           [--start-date YYYY-MM-DD] [--end-date YYYY-MM-DD]
                           [--daily] [-o OUT_DIR]
                           filenames [filenames ...]

    Example::

        python quicklook_ceilo.py -o ./images $HOME/my_ceilo_data.nc
    """
    import argparse

    def _get_date(date_str):
        # naive datetime; make_plot interprets it as UTC
        return datetime.datetime.strptime(date_str, "%Y-%m-%d")

    parser = argparse.ArgumentParser(
        description="Create a full size and thumbnail quicklook plot of ceilometer data"
    )
    parser.add_argument(
        "filenames",
        type=pathlib.Path,
        nargs="+",
        help="specify NetCDF files to plot",
    )
    parser.add_argument(
        "--quality",
        dest="quality",
        default=1,
        type=int,
        help="specify the number of data points to skip by",
    )
    # NOTE(review): the lockfile option is parsed but never used by main().
    parser.add_argument(
        "-l",
        dest="lockfile",
        default="ceilo_quicklook.lock",
        help="specify the lockfile to use",
    )
    parser.add_argument(
        "--description",
        dest="description",
        default="",
        help="the description used in filenames",
    )
    parser.add_argument("--site", dest="site", default="aoss", help="the site name")
    parser.add_argument(
        "--start-date",
        dest="start_date",
        type=_get_date,
        metavar="YYYY-MM-DD",
        help="the first day of data to plot",
    )
    parser.add_argument(
        "--end-date",
        dest="end_date",
        type=_get_date,
        metavar="YYYY-MM-DD",
        help="the last day of data to plot",
    )
    parser.add_argument(
        "--daily",
        dest="daily",
        action="store_true",
        help="split plots into daily files",
    )
    parser.add_argument(
        "-o",
        "--out-dir",
        dest="out_dir",
        type=pathlib.Path,
        default=None,
        help="output directory for quicklooks",
    )
    args = parser.parse_args()

    logging.basicConfig(level=logging.INFO)

    try:
        make_plot(
            ncfiles=args.filenames,
            quality=args.quality,
            site=args.site,
            description=args.description,
            start_date=args.start_date,
            end_date=args.end_date,
            daily=args.daily,
            out_dir=args.out_dir,
        )
    finally:
        LOG.debug("terminating", exc_info=True)


if __name__ == "__main__":
    # main() has no return value, so sys.exit(None) exits with status 0
    # unless an exception (including SystemExit) propagates out of main().
    sys.exit(main())