#!/usr/bin/env python3
# Based on https://github.com/deeplycloudy/glmtools/blob/master/examples/grid/make_GLM_grids.py

parse_desc = """Grid the past X minutes of GLM flash data, given a single input file.
"""

import numpy as np
from datetime import datetime, timedelta
import os
import sys
import tempfile
import shutil
import atexit
from glob import glob
#from multiprocessing import freeze_support # https://docs.python.org/2/library/multiprocessing.html#multiprocessing.freeze_support
from functools import partial
from lmatools.grid.make_grids import write_cf_netcdf_fixedgrid
from glmtools.grid.make_grids import grid_GLM_flashes
from glmtools.io.glm import parse_glm_filename
from lmatools.grid.fixed import get_GOESR_grid, get_GOESR_coordsys

import logging

log = logging.getLogger(__name__)

def create_parser():
    import argparse
    parser = argparse.ArgumentParser(description=parse_desc)
    parser.add_argument('-v', '--verbose', dest='verbosity', action="count", default=0,
                        help='each occurrence increases verbosity one level through ERROR-WARNING-INFO-DEBUG (default ERROR)')
    parser.add_argument('-l', '--log', dest="log_fn", default=None,
                        help="specify the log filename")
    # from Requirements: "Output is Gridded GLM in the native glmtools NetCDF4 format, with a user option to produce AWIPS-compatible NetCDF tiles as described below"
    parser.add_argument('-o', '--output-dir', metavar='directory',
                        default=os.getcwd(),
                        help="directory for the gridded (and, optionally, tiled) output files (default: current directory)")
    parser.add_argument('--goes-sector', default="full",
                        help="One of [full|conus|meso]. "
                             "If the sector is meso, --ctr-lon and --ctr-lat "
                             "give the center of the mesoscale field of view "
                             "(converted internally to fixed-grid ctr_x, ctr_y).")
    parser.add_argument('--goes-position', default="auto",
                        help="One of [east|west|test|auto]. "
                             "With 'auto', the position is inferred from the "
                             "input filenames.")
    parser.add_argument("-t", "--create-tiles", default=False, action='store_true',
                        help="create AWIPS-compatible tiles") # FIXME: improve this help text
    parser.add_argument('--ctr-lat', metavar='latitude',
                        type=float, help='center latitude (required for meso)')
    parser.add_argument('--ctr-lon', metavar='longitude',
                        type=float, help='center longitude (required for meso)')
    # from Requirements: "Input is one or more GLM LCFA (L2+) files in mission standard format (nominally three 20-second input files)"
    parser.add_argument(dest='filenames', metavar='filename', nargs='+')
    return parser
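
# Minimal sketch of a typical invocation of the parser (hypothetical filenames
# and paths, for illustration only):
#
#     args = create_parser().parse_args([
#         '-vv', '-t', '-o', '/path/to/output',
#         'OR_GLM-L2-LCFA_G16_s20202691559000_e20202691559200_c20202691559217.nc',
#         'OR_GLM-L2-LCFA_G16_s20202691559200_e20202691559400_c20202691559417.nc',
#         'OR_GLM-L2-LCFA_G16_s20202691559400_e20202691600000_c20202691600028.nc',
#     ])
#     # -> goes_sector='full', goes_position='auto', create_tiles=True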

"""
old arguments for reference

FIXME: remove this whole comment once everything is working

    parser.add_argument('--dx', metavar='km',
                        default=10.0, type=float,
                        help='approximate east-west grid spacing')
    parser.add_argument('--dy', metavar='km',
                        default=10.0, type=float,
                        help='approximate north-south grid spacing')
    parser.add_argument('--dt', metavar='seconds',
                        default=60.0, type=float,
                        help='frame duration')
    parser.add_argument('--width', metavar='distance in km',
                        default=400.0,
                        type=float, help='total width of the grid')
    parser.add_argument('--height', metavar='distance in km',
                        default=400.0,
                        type=float, help='total height of the grid')
    parser.add_argument('--nevents', metavar='minimum events per flash',
                        type=int, dest='min_events', default=1,
                        help='minimum number of events per flash')
    parser.add_argument('--ngroups', metavar='minimum groups per flash',
                        type=int, dest='min_groups', default=1,
                        help='minimum number of groups per flash')
    parser.add_argument('--subdivide-grid', metavar='sqrt(number of subgrids)',
                        type=int, default=1,
                        help="subdivide the grid this many times along "
                             "each dimension")
"""


def get_resolution(args):
    # Hard-code the resolution to 2.0 km for now; see nearest_resolution in
    # make_glm_grids for how we could expose this as an argument if we change
    # our minds. The returned string is e.g. '2.0km'.
    closest_resln = 2.0
    resln = '{0:4.1f}km'.format(closest_resln).replace(' ', '')
    return resln


# if provided "auto" position, we determine the sensor from the filename
def get_goes_position(filenames):
    if all("_G16_" in f for f in filenames):
        return "east"
    if all("_G17_" in f for f in filenames):
        return "west"

    # we require that all files are from the same sensor and raise an exception if not
    raise ValueError("position 'auto' but could not determine position - did you provide a mix of satellites?")


def get_start_end(filenames, start_time=None, end_time=None):
    """Compute the start and end time of the data from the filenames, unless
    they are explicitly provided."""
    base_filenames = [os.path.basename(p) for p in filenames]

    filename_infos = [parse_glm_filename(f) for f in base_filenames]
    # opsenv, algorithm, platform, start, end, created = parse_glm_filename(f)
    filename_starts = [info[3] for info in filename_infos]

    if start_time is None:
        start_time = min(filename_starts)

    if end_time is None:
        # We used to use the max of the filename end times, but on 27 Oct 2020
        # the filename end times started to report the time of the last event
        # in the file, causing a slight leakage (usually less than a second)
        # into the next minute. This caused two minutes of grids to be produced
        # for every three twenty-second files passed to this script.
        # Instead, we now assume every LCFA file is 20 s long, beginning at its
        # start time. No doubt in the future we will see filenames that no
        # longer start on an even minute boundary.
        end_time = max(filename_starts) + timedelta(0, 20)

    if start_time is None or end_time is None:
        raise ValueError("Could not determine start/end time")

    return start_time, end_time
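
# Worked example (hypothetical filenames): three 20 s LCFA files with start stamps
# s20202691559000, s20202691559200 and s20202691559400 (2020 day 269 at 15:59:00,
# 15:59:20 and 15:59:40 UTC) yield start_time = 15:59:00 and
# end_time = 15:59:40 + 20 s = 16:00:00, i.e. exactly one minute of data.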


def grid_setup(args, work_dir=os.getcwd()):
    # When passed None for the minimum event or group counts, the gridder will skip
    # the check, saving a bit of time.
    min_events = None
    min_groups = None

    try:
        start_time, end_time = get_start_end(args.filenames)
    except ValueError:
        log.error("Non-standard filenames provided, use --start and --end to specify data times.")
        raise

    base_date = datetime(start_time.year, start_time.month, start_time.day)
    proj_name = 'geos'

    outputpath = os.path.join(work_dir, "{dataset_name}") # glmtools expects a filename template, not just a directory

    if args.goes_position == "auto":
        goes_position = get_goes_position(args.filenames)
    else:
        goes_position = args.goes_position

    resln = get_resolution(args)
    view = get_GOESR_grid(position=goes_position,
                          view=args.goes_sector,
                          resolution=resln)
    nadir_lon = view['nadir_lon']
    dx = dy = view['resolution']
    nx, ny = view['pixelsEW'], view['pixelsNS']
    geofixcs, grs80lla = get_GOESR_coordsys(sat_lon_nadir=nadir_lon)

    if 'centerEW' in view:
        x_ctr, y_ctr = view['centerEW'], view['centerNS']
    elif args.goes_sector == 'meso':
        # Mesoscale sectors have no fixed center; convert the user-supplied
        # ctr_lon, ctr_lat to fixed-grid coordinates to center the FOV.
        if args.ctr_lon is None or args.ctr_lat is None:
            raise ValueError("--ctr-lon and --ctr-lat are required for the meso sector")
        x_ctr, y_ctr, z_ctr = geofixcs.fromECEF(
            *grs80lla.toECEF(args.ctr_lon, args.ctr_lat, 0.0))
    else:
        # get_GOESR_grid is only expected to omit the grid center for mesoscale
        # sectors, so this branch should be unreachable for valid sector/position
        # combinations.
        raise RuntimeError("Could not determine grid center for sector "
                           "'{0}'".format(args.goes_sector))

    # Need to use +1 here to convert to xedge, yedge expected by gridder
    # instead of the pixel centroids that will result in the final image
    nx += 1
    ny += 1
    x_bnd = (np.arange(nx, dtype='float') - (nx) / 2.0) * dx + x_ctr + 0.5 * dx
    y_bnd = (np.arange(ny, dtype='float') - (ny) / 2.0) * dy + y_ctr + 0.5 * dy
    log.debug(("initial x,y_ctr", x_ctr, y_ctr))
    log.debug(("initial x,y_bnd", x_bnd.shape, y_bnd.shape))
    x_bnd = np.asarray([x_bnd.min(), x_bnd.max()])
    y_bnd = np.asarray([y_bnd.min(), y_bnd.max()])
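    # Worked example (illustrative, assuming the standard 2 km full-disk fixed grid
    # of 5424 x 5424 pixels with 56 microradian spacing and x_ctr = y_ctr = 0):
    # the 5425 edge values span -0.151872 to +0.151872 radians (5424 pixels of
    # width dx), and x_bnd, y_bnd are now just those [min, max] extents.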

    # geofixcs and grs80lla were already constructed above; reuse them to get
    # the grid center in longitude/latitude.
    ctr_lon, ctr_lat, ctr_alt = grs80lla.fromECEF(
        *geofixcs.toECEF(x_ctr, y_ctr, 0.0))
    fixed_grid = geofixcs
    log.debug((x_bnd, y_bnd, dx, dy, nx, ny))

    output_writer = partial(write_cf_netcdf_fixedgrid, nadir_lon=nadir_lon)

    gridder = grid_GLM_flashes
    output_filename_prefix = 'GLM'
    grid_kwargs = dict(proj_name=proj_name,
                       base_date=base_date, do_3d=False,
                       dx=dx, dy=dy, frame_interval=60.0,
                       x_bnd=x_bnd, y_bnd=y_bnd,
                       ctr_lat=ctr_lat, ctr_lon=ctr_lon, outpath=outputpath,
                       min_points_per_flash=min_events,
                       output_writer=output_writer, subdivide=1, # subdivide the grid this many times along each dimension
                       output_filename_prefix=output_filename_prefix,
                       output_kwargs={'scale_and_offset': False},
                       spatial_scale_factor=1.0)

    #if args.fixed_grid:
    #    grid_kwargs['fixed_grid'] = True
    #    grid_kwargs['nadir_lon'] = nadir_lon
    # if args.split_events:
    grid_kwargs['clip_events'] = True
    if min_groups is not None:
        grid_kwargs['min_groups_per_flash'] = min_groups
    grid_kwargs['energy_grids'] = ('total_energy',)
    if (proj_name == 'pixel_grid') or (proj_name == 'geos'):
        grid_kwargs['pixel_coords'] = fixed_grid
    grid_kwargs['ellipse_rev'] = -1  # -1 (default) = infer from date in each GLM file
    return gridder, args.filenames, start_time, end_time, grid_kwargs


if __name__ == '__main__':
#    freeze_support() # nb. I don't think this is needed as we're not making windows execs at this time
    parser = create_parser()
    args = parser.parse_args()

    # Configure logging
    levels = [logging.ERROR, logging.WARNING, logging.INFO, logging.DEBUG]
    log_level = levels[min(3, args.verbosity)]
    logging.basicConfig(level=log_level, filename=args.log_fn)
    if log_level > logging.DEBUG:
        import warnings
        warnings.filterwarnings("ignore")
    log.info("Starting GLM Gridding")
    log.debug("Starting script with: %s", sys.argv)

    # set up output dir
    os.makedirs(args.output_dir, exist_ok=True)

    # set up temporary dir
    tempdir_path = tempfile.mkdtemp(suffix=None, prefix="tmp-glm-grids-", dir=os.getcwd())
    log.info("working in: {}".format(tempdir_path))
    # clean our temporary dir on exit
    atexit.register(shutil.rmtree, tempdir_path)

    # do the gridding
    gridder, glm_filenames, start_time, end_time, grid_kwargs = grid_setup(args, work_dir=tempdir_path)
    gridder(glm_filenames, start_time, end_time, **grid_kwargs)

    # pick up gridded files from the tempdir
    # output looks like: OR_GLM-L2-GLMC-M3_G17_s20202691559400_e20202691600400_c20210120141010.nc
    log.debug("gridded files in {}".format(tempdir_path))
    gridded_path = os.path.join(tempdir_path, 'OR_GLM-L2-GLM?-M?_G??_s*_e*_c*.nc')
    log.debug(gridded_path)
    gridded_files = glob(gridded_path)
    log.debug(gridded_files)

    # (optionally) do tiling
    if args.create_tiles:
        from satpy import Scene
        for gridded_file in gridded_files:
            log.info("TILING: {}".format(gridded_files))
            scn = Scene(reader='glm_l2', filenames=[gridded_file]) # n.b. satpy requires a list of filenames
            scn.load([
                'DQF',
                'flash_extent_density',
                'minimum_flash_area',
                'total_energy',
            ])

            scn.save_datasets(writer='awips_tiled',
                              template='glm_l2_radf',
                              sector_id="GOES_EAST", # becomes an attribute in the output files; required by the writer but appears to be legacy and is not otherwise used for GLM output
                              source_name="", # required by the writer for legacy reasons; not used in the GLM output, so an empty string suffices
                              base_dir=tempdir_path, # output directory for the tiles (blank would mean the current directory); they are moved to args.output_dir below
                              tile_size=(506, 904), # matches the GLMF sample tiles we were given, which in turn match the full-disk ABI tile size
                              check_categories=False) # DQF is flagged all-valid all the time, so empty tiles cannot be detected unless the "category" products are ignored


    # pick up output files from the tempdir
    # output looks like: OR_GLM-L2-GLMC-M3_G17_T03_20200925160040.nc
    log.debug("files in {}".format(tempdir_path))
    log.debug(os.listdir(tempdir_path))
    log.debug("moving output to {}".format(args.output_dir))
    tiled_path = os.path.join(tempdir_path, 'OR_GLM-L2-GLM?-M?_G??_T??_*.nc')
    tiled_files = glob(tiled_path)
    for f in tiled_files:
        shutil.move(f, os.path.join(args.output_dir, os.path.basename(f)))
    for f in gridded_files:
        shutil.move(f, os.path.join(args.output_dir, os.path.basename(f)))

    # tempdir cleans itself up via atexit, above