import os
import sys
import logging
import pandas as pd
from datetime import datetime as dt
from netCDF4 import Dataset
import numpy as np
import platform
from aosstower import station, schema
from aosstower.level_00 import parser
from aosstower.level_b1 import calc

LOG = logging.getLogger(__name__)
STATION_NAME = "AOSS Tower"
SOFTWARE_VERSION = '00'
# ASOS wind gust thresholds, given in knots and converted to m/s
KNOTS_10 = calc.knots_to_mps(10.)
KNOTS_9 = calc.knots_to_mps(9.)
KNOTS_5 = calc.knots_to_mps(5.)
KNOTS_3 = calc.knots_to_mps(3.)
KNOTS_2 = calc.knots_to_mps(2.)
DEFAULT_FLOAT_FILL = -9999.


def make_summary_dict(source_dict):
    """Create the '_mean','_min','_max' file structure."""
    dest_dict = {}
    for key in source_dict:
        if key == 'wind_dir':
            dest_dict['wind_speed_max_dir'] = source_dict[key]
            dest_dict['wind_speed_mean_dir'] = source_dict[key]
            dest_dict['wind_speed_min_dir'] = source_dict[key]
            dest_dict['peak_gust_dir'] = source_dict[key]
        elif key == 'gust':
            dest_dict['peak_gust'] = source_dict[key]
        else:
            dest_dict[key + '_max'] = source_dict[key]
            dest_dict[key + '_mean'] = source_dict[key]
            dest_dict[key + '_min'] = source_dict[key]
    return dest_dict


def filter_array(arr, valid_min, valid_max, valid_delta):
    """Create QC field array.

    Meanings
    --------

        Bit 0 (0b1): value is equal to missing_value
        Bit 1 (0b10): value is less than the valid min
        Bit 2 (0b100): value is greater than the valid max
        Bit 3 (0b1000): difference between current and previous values exceeds valid_delta

    Arguments:
        arr (Series): Data array this QC is for
        valid_min (float): Valid minimum of the input data
        valid_max (float): Valid maximum of the input data
        valid_delta (float): Valid delta of one input data point to the previous point

    Returns:
        Series: Bit-mask describing the input data (see above bit meanings)
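
    Example:
        A minimal sketch; the column name and limits below are illustrative,
        not values from the schema::

            qc = filter_array(frame['air_temp'], -40., 50., 10.)
            suspect = frame['air_temp'][qc != 0]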
   """
    qcControl = np.zeros(arr.shape, dtype=np.int8)
    qcControl[arr.isnull().values] |= 0b1
    if valid_min:
        qcControl[(arr < valid_min).values] |= 0b10
    if valid_max:
        qcControl[(arr > valid_max).values] |= 0b100
    if valid_delta:
        # prepend a 0 so the diff array has the same shape as the input
        d = np.append([0], np.diff(arr))
        qcControl[d > valid_delta] |= 0b1000
    return qcControl


def write_dimensions(nc_file):
    nc_file.createDimension('time', None)
    nc_file.createDimension('max_len_station_name', 32)


def create_variables(nc_file, first_stamp, database, chunk_sizes=None, zlib=False):
    # base_time long name
    btln = 'Base time in Epoch'

    # base time units
    btu = 'seconds since 1970-01-01 00:00:00'

    # base time string
    bts = first_stamp.strftime('%Y-%m-%d 00:00:00Z')

    # time offset long name
    to_ln = 'Time offset from base_time'

    # time offset units
    to_u = 'seconds since ' + first_stamp.strftime('%Y-%m-%d 00:00:00Z')
    # 'time' units
    t_u = 'hours since ' + first_stamp.strftime('%Y-%m-%d 00:00:00Z')

    coordinates = {
        # fields: type, dimension, fill, valid_min, std_name, longname, units, valid_max, cf_role
        'time': [np.float64, ('time',), DEFAULT_FLOAT_FILL, None, None, "Hour offset from midnight",
                 t_u, None, None],
        'lon': [np.float32, tuple(), DEFAULT_FLOAT_FILL, -180., 'longitude', None, 'degrees_east', 180., None],
        'lat': [np.float32, tuple(), DEFAULT_FLOAT_FILL, -90., 'latitude', None, 'degrees_north', 90., None],
        'alt': [np.float32, tuple(), DEFAULT_FLOAT_FILL, None, 'height', 'vertical distance', 'm', None, None],
        # int64 for base_time would be best, but NetCDF4 Classic does not support it
        # NetCDF4 Classic mode was chosen so users can use MFDatasets (multi-file datasets)
        'base_time': [np.int32, tuple(), DEFAULT_FLOAT_FILL, None, 'time', btln, btu, None, None],
        'time_offset': [np.float64, ('time',), DEFAULT_FLOAT_FILL, None, 'time', to_ln, to_u, None, None],
        'station_name': ['c', ('max_len_station_name',), '\0', None, None, 'station name', None, None, 'timeseries_id'],
    }

    for key, attr in coordinates.items():
        if attr[1] == ('max_len_station_name',) and chunk_sizes and chunk_sizes[0] > 32:
            cs = [32]
        else:
            cs = chunk_sizes

        variable = nc_file.createVariable(key, attr[0],
                                          dimensions=attr[1],
                                          fill_value=attr[2],
                                          zlib=zlib,
                                          chunksizes=cs)

        # create var attributes
        if key == 'alt':
            variable.positive = 'up'
            variable.axis = 'Z'

        if attr[3]:
            variable.valid_min = attr[3]
            variable.valid_max = attr[7]

        if attr[4]:
            variable.standard_name = attr[4]

        if attr[5]:
            variable.long_name = attr[5]

        if attr[6]:
            variable.units = attr[6]

        if attr[8]:
            variable.cf_role = attr[8]

        if key == 'base_time':
            variable.string = bts
            variable.ancillary_variables = 'time_offset'
        elif key == 'time_offset':
            variable.ancillary_variables = 'base_time'

        # CF default
        # if 'time' in key:
        #     variable.calendar = 'gregorian'

    for entry in sorted(database.keys()):
        if entry == 'stamp':
            continue

        varTup = database[entry]
        variable = nc_file.createVariable(entry, np.float32,
                                          dimensions=('time',), fill_value=DEFAULT_FLOAT_FILL, zlib=zlib,
                                          chunksizes=chunk_sizes)

        variable.standard_name = varTup[1]
        variable.description = varTup[3]
        variable.units = varTup[4]

        if varTup[5] != '':
            variable.valid_min = float(varTup[5])
            variable.valid_max = float(varTup[6])

        qcVariable = nc_file.createVariable('qc_' + entry, 'b',
                                            dimensions=('time',), zlib=zlib, chunksizes=chunk_sizes)

        qcVariable.long_name = 'data quality'
        qcVariable.valid_range = np.byte(0b1), np.byte(0b1111)
        qcVariable.flag_masks = np.byte(0b1), np.byte(0b10), np.byte(0b100), np.byte(0b1000)

        flagMeaning = ['value is equal to missing_value',
                       'value is less than the valid min',
                       'value is greater than the valid max',
                       'difference between current and previous values exceeds valid_delta.']

        qcVariable.flag_meaning = ', '.join(flagMeaning)


def calculate_wind_gust(wind_speed_5s, wind_speed_2m):
    """Calculate reportable wind gust from wind speed averages.

    Determination of wind gusts follows the Automated Surface Observing System
    (ASOS) User's Guide published by NOAA. The guide can be found here:

        http://www.nws.noaa.gov/asos/pdfs/aum-toc.pdf

    Note: This operates on wind data in meters per second even though the
    ASOS User's Guide operates on knots.

    Arguments:
        wind_speed_5s (Series): 5-second average of instantaneous wind speed magnitudes in m/s
        wind_speed_2m (Series): 2-minute rolling average of the 5 second wind speed averages in m/s

    Returns:
        Series: 5-second rolling wind speed gusts meeting the ASOS criteria
                for being a reportable wind gust.
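
    Example:
        A minimal sketch assuming 5-second and 2-minute average wind speed
        Series built as in ``minute_averages`` (names illustrative)::

            gust = calculate_wind_gust(winds_frame_5s['wind_speed'],
                                       winds_frame_2m['wind_speed'])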

    """
    # 1 minute rolling peaks
    wind_peak_1m = wind_speed_5s.rolling(window='1T', center=False).max()
    # criteria for a fast wind to be considered a wind gust
    gust_mask = (wind_speed_2m >= KNOTS_9) & \
                (wind_peak_1m >= wind_speed_2m + KNOTS_5)
    gusts = wind_peak_1m.mask(~gust_mask)

    # determine highest gust in the last 10 minutes
    # 5 seconds * 120 = 10 minutes
    max_10m_gusts = gusts.rolling(window='10T', center=False).max()
    # Minimum 5-second average in the past 10 minutes
    min_10m_5avg = wind_speed_5s.rolling(window='10T', center=False).min()
    # criteria for a wind gust to be reportable
    reportable_mask = (max_10m_gusts >= wind_speed_2m + KNOTS_3) & \
                      (wind_speed_2m > KNOTS_2) & \
                      (max_10m_gusts >= min_10m_5avg + KNOTS_10)
    # Reportable gusts at 5 second resolution
    reportable_gusts = max_10m_gusts.mask(~reportable_mask)
    return reportable_gusts


def minute_averages(frame):
    """Average the data frame at a 1 minute interval.

    Wind direction and wind speed must be handled in a special manner
    """
    # Add wind direction components so we can average wind direction properly
    frame['wind_east'], frame['wind_north'], _ = calc.wind_vector_components(frame['wind_speed'], frame['wind_dir'])
    # round up each 1 minute group so data at time T is the average of data
    # from T - 1 (exclusive) to T (inclusive).
    new_frame = frame.resample('1T', closed='right', loffset='1T').mean()

    # 2 minute rolling average of 5 second data (5 seconds * 24 = 120 seconds = 2 minutes)
    winds_frame_5s = frame[['wind_speed', 'wind_east', 'wind_north']]
    winds_frame_5s = winds_frame_5s.resample('5S', closed='right', loffset='5S').mean()
    winds_frame_2m = winds_frame_5s.rolling(24, win_type='boxcar').mean()

    # rolling average is used for 1 minute output
    new_frame.update(winds_frame_2m.loc[new_frame.index])  # adds wind_speed, wind_east/north
    new_frame['wind_dir'] = calc.wind_vector_degrees(new_frame['wind_east'], new_frame['wind_north'])
    del new_frame['wind_east']
    del new_frame['wind_north']

    gust = calculate_wind_gust(winds_frame_5s['wind_speed'], winds_frame_2m['wind_speed'])
    # "average" the gusts to minute resolution to match the rest of the data
    new_frame['gust'] = gust.resample('1T', closed='right', loffset='1T').max()

    return new_frame.fillna(np.nan)


def summary_over_interval(frame, interval_width):
    """takes a frame and an interval to average it over, and returns a minimum,
    maximum, and average dataframe for that interval
    """
    # each output row at time X summarizes data from X (inclusive) to
    # X + interval_width (exclusive)
    exclude = ['gust', 'wind_east', 'wind_north']
    include = [c for c in frame.columns if c not in exclude]
    gb = frame[include].resample(interval_width, closed='left')

    low = gb.min()
    low.rename(columns=lambda x: x + "_min", inplace=True)
    high = gb.max()
    high.rename(columns=lambda x: x + "_max", inplace=True)
    mean = gb.mean()
    mean.rename(columns=lambda x: x + "_mean", inplace=True)

    out_frames = pd.concat((low, high, mean), axis=1)

    # wind fields need to be handled specially
    ws_min_idx = frame['wind_speed'].resample(interval_width, closed='left').apply(lambda arr_like: arr_like.idxmin())
    ws_max_idx = frame['wind_speed'].resample(interval_width, closed='left').apply(lambda arr_like: arr_like.idxmax())
    # probably redundant but need to make sure the direction indexes are
    # the same as those used in the wind speed values
    # must use .values so we don't take data at out_frames index, but rather
    # fill in the out_frames index values with the min/max values
    out_frames['wind_speed_min'] = frame['wind_speed'][ws_min_idx].values
    out_frames['wind_speed_max'] = frame['wind_speed'][ws_max_idx].values
    out_frames['wind_speed_min_dir'] = calc.wind_vector_degrees(frame['wind_east'][ws_min_idx], frame['wind_north'][ws_min_idx]).values
    out_frames['wind_speed_max_dir'] = calc.wind_vector_degrees(frame['wind_east'][ws_max_idx], frame['wind_north'][ws_max_idx]).values
    we = frame['wind_east'].resample(interval_width, closed='left').mean()
    wn = frame['wind_north'].resample(interval_width, closed='left').mean()
    out_frames['wind_speed_mean_dir'] = calc.wind_vector_degrees(we, wn).values

    gust_idx = frame['gust'].resample(interval_width, closed='left').apply(lambda arr_like: arr_like.idxmax())
    # intervals where every gust is NaN produce NaN index labels; cast so the
    # datetime lookup below still works
    gust_idx = gust_idx.astype('datetime64[ns]', copy=False)
    peak_gust = frame['gust'][gust_idx]
    out_frames['peak_gust'] = peak_gust.values
    we = frame['wind_east'][gust_idx]
    wn = frame['wind_north'][gust_idx]
    out_frames['peak_gust_dir'] = calc.wind_vector_degrees(we, wn).values

    return out_frames


def _get_data(input_files):
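    """Yield every parsed row (frame) from each level_00 input file."""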
    for filename in input_files:
        yield from parser.read_frames(filename)


def get_data(input_files):
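    """Read all input files into one DataFrame indexed by 'stamp'.

    Raw fill values (-99999.) are replaced with NaN.
    """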
    frame = pd.DataFrame(_get_data(input_files))
    frame.set_index('stamp', inplace=True)
    frame.mask(frame == -99999., inplace=True)
    frame.fillna(value=np.nan, inplace=True)
    return frame


def write_vars(nc_file, frame, database):
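    """Write coordinate values, data variables, and QC bit-mask variables into the open NetCDF file."""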
    # all time points from epoch
    time_epoch = (frame.index.astype(np.int64) // 10**9).values
    # midnight of data's first day
    base_epoch = frame.index[0].replace(hour=0, minute=0, second=0, microsecond=0).value // 10**9

    fileVar = nc_file.variables
    # base_time is midnight of the current day
    fileVar['base_time'].assignValue(base_epoch)
    fileVar['time_offset'][:] = time_epoch - base_epoch
    # hours since midnight
    fileVar['time'][:] = (time_epoch - base_epoch) / (60 * 60)

    # write coordinate var values to file
    # alt might not be right, need to verify
    fileVar['lon'].assignValue(station.LONGITUDE)
    fileVar['lat'].assignValue(station.LATITUDE)
    fileVar['alt'].assignValue(station.ELEVATION)

    # write station name to file
    fileVar['station_name'][:len(STATION_NAME)] = STATION_NAME

    # writes data into file
    for varName in frame:
        if varName not in fileVar:
            LOG.debug('Unused input variable: %s', varName)
            continue
        fileVar[varName][:] = frame[varName].fillna(DEFAULT_FLOAT_FILL).values

        valid_min = database[varName][5]
        valid_max = database[varName][6]
        valid_delta = database[varName][7]

        fileVar['qc_' + varName][:] = filter_array(frame[varName],
                                                   float(valid_min) if valid_min else valid_min,
                                                   float(valid_max) if valid_max else valid_max,
                                                   float(valid_delta) if valid_delta else valid_delta)

    coordinates = ['lon', 'lat', 'alt', 'base_time', 'time_offset', 'station_name', 'time']

    for varName in fileVar:
        if not varName.startswith('qc_') and \
                varName not in frame and \
                varName not in coordinates:
            # if a variable should be in the file (it is in the database) but
            # isn't in the input data, mark its QC field as all missing
            fileVar['qc_' + varName][:] |= 0b1


def write_global_attributes(nc_file, input_sources, interval=None, datastream=None):
    # create global attributes
    nc_file.source = 'surface observation'
    nc_file.Conventions = 'ARM-1.2 CF-1.6'
    nc_file.institution = 'University of Wisconsin - Madison (UW) Space Science and Engineering Center (SSEC)'
    nc_file.featureType = 'timeSeries'
    nc_file.data_level = 'b1'

    # monthly files end with .month.nc
    # these end with .day.nc

    if datastream:
        nc_file.datastream = datastream
    elif interval in ['1D']:
        # assume this is a monthly file, averaged daily
        nc_file.datastream = 'aoss.tower.nc-1mo-1d.b1.v{software_version}'.format(software_version=SOFTWARE_VERSION)
    elif interval in ['1T', '1min']:
        # assume this is a daily file, averaged every minute
        nc_file.datastream = 'aoss.tower.nc-1d-1m.b1.v{software_version}'.format(software_version=SOFTWARE_VERSION)
    nc_file.software_version = SOFTWARE_VERSION
    nc_file.command_line = " ".join(sys.argv)

    # generate history
    nc_file.history = ' '.join(platform.uname()) + " " + os.path.basename(__file__)
    nc_file.input_source = input_sources[0]
    nc_file.input_sources = ', '.join(input_sources)


def create_giant_netcdf(input_files, output_fn, zlib, chunk_size,
                        start=None, end=None, interval_width=None,
                        summary=False,
                        database=schema.database, datastream=None):
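    """Create a single NetCDF file from one or more level_00 input files.

    Input data is resampled to 5 seconds, 2-minute rolling wind averages and
    ASOS gusts are computed, and the result is either averaged over
    ``interval_width`` or summarized (min/mean/max) when ``summary`` is True.
    """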
    frame = get_data(input_files)
    if frame.empty:
        raise ValueError("No data found from input files: {}".format(", ".join(input_files)))

    # Add wind direction components so we can average wind direction properly
    frame['wind_east'], frame['wind_north'], _ = calc.wind_vector_components(frame['wind_speed'], frame['wind_dir'])

    if 'air_temp' in frame and 'rh' in frame and 'dewpoint' not in frame and 'dewpoint' in database:
        LOG.info("'dewpoint' is missing from the input file, will calculate it from air temp and relative humidity")
        frame['dewpoint'] = calc.dewpoint(frame['air_temp'], frame['rh'])

    # resample to a regular 5 second interval so data at time T is the average
    # of data from T - 5 seconds (exclusive) to T (inclusive)
    new_frame = frame.resample('5S', closed='right', loffset='5S').mean()

    # 2 minute rolling average of 5 second data (5 seconds * 24 = 120 seconds = 2 minutes)
    winds_frame_5s = new_frame[['wind_speed', 'wind_east', 'wind_north']]
    winds_frame_2m = winds_frame_5s.rolling('2T').mean()
    winds_frame_2m['gust'] = calculate_wind_gust(winds_frame_5s['wind_speed'], winds_frame_2m['wind_speed'])

    # rolling average is used for mean output
    new_frame.update(winds_frame_2m)  # adds wind_speed, wind_east/north
    new_frame['gust'] = winds_frame_2m['gust']

    # average the values
    if summary:
        frame = summary_over_interval(new_frame, interval_width)
    else:
        frame = new_frame.resample(interval_width, closed='right', loffset=interval_width).mean()
        frame['wind_dir'] = calc.wind_vector_degrees(frame['wind_east'], frame['wind_north'])
        frame['gust'] = new_frame['gust'].resample(interval_width, closed='right', loffset=interval_width).max()
    frame.fillna(np.nan, inplace=True)

    if start and end:
        frame = frame[start.strftime('%Y-%m-%d %H:%M:%S'): end.strftime('%Y-%m-%d %H:%M:%S')]

    if chunk_size and not isinstance(chunk_size, (list, tuple)):
        chunk_sizes = [chunk_size]
    elif chunk_size:
        chunk_sizes = chunk_size
    else:
        chunk_sizes = [frame.shape[0]]

    first_stamp = dt.strptime(str(frame.index[0]), '%Y-%m-%d %H:%M:%S')
    # NETCDF4_CLASSIC was chosen so that MFDataset reading would work. See:
    # http://unidata.github.io/netcdf4-python/#netCDF4.MFDataset
    nc_file = Dataset(output_fn, 'w', format='NETCDF4_CLASSIC')
    write_dimensions(nc_file)
    create_variables(nc_file, first_stamp, database, chunk_sizes, zlib)
    write_vars(nc_file, frame, database)
    write_global_attributes(nc_file,
                            [os.path.basename(x) for x in input_files],
                            interval=interval_width,
                            datastream=datastream)
    nc_file.close()
    return nc_file


def _dt_convert(datetime_str):
    """Parse datetime string, return datetime object"""
    try:
        return dt.strptime(datetime_str, '%Y-%m-%dT%H:%M:%S')
    except ValueError:
        return dt.strptime(datetime_str, '%Y-%m-%d')


def main():
    import argparse
    parser = argparse.ArgumentParser(description="Convert level_00 aoss tower data to level_b1",
                                     fromfile_prefix_chars='@')
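
    # Example invocation (file and path names are illustrative only):
    #   -s 2017-06-01 -i tower.2017-06-01.ascii -o aoss_tower.2017-06-01.day.nc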

    parser.add_argument('-v', '--verbose', action="count", default=int(os.environ.get("VERBOSITY", 2)),
                        dest='verbosity',
                        help='each occurrence increases verbosity 1 level through ERROR-WARNING-INFO-DEBUG (default INFO)')

    parser.add_argument('-s', '--start-time', type=_dt_convert,
                        help="Start time of massive netcdf file, if only -s is given, a netcdf file for only that day is given" +
                             ". Formats allowed: \'YYYY-MM-DDTHH:MM:SS\', \'YYYY-MM-DD\'")
    parser.add_argument('-e', '--end-time', type=_dt_convert,
                        help="End time of the output NetCDF file. Formats allowed: "
                             "'YYYY-MM-DDTHH:MM:SS', 'YYYY-MM-DD'")
    parser.add_argument('-n', '--interval', default='1T',
                        help="""Width of the interval to average input data
over in Pandas offset format. If not specified, 1 minute averages are used. If
specified then '_high', '_mean', and '_low' versions of the data fields are
written to the output NetCDF.
Use '1D' for daily or '5T' for 5 minute averages.
See this page for more details:
http://pandas.pydata.org/pandas-docs/stable/timeseries.html#offset-aliases""")
    parser.add_argument('--summary', action='store_true',
                        help="Create a file with _low, _mean, _high versions of every variable name")
    parser.add_argument('-f', '--fields', nargs='+', default=schema.met_vars,
                        help="Variable names to include in the NetCDF file (base name, no suffixes)")
    parser.add_argument('--chunk-size', type=int, help='chunk size for the netCDF file')
    parser.add_argument('-z', '--zlib', action='store_true', help='compress netCDF file with zlib')
    parser.add_argument('--data-stream', help="'datastream' global attribute to put in output file")

    parser.add_argument('-i', '--input', dest='input_files', required=True, nargs="+",
                        help="aoss_tower level_00 paths. Use @filename to red a list of paths from that file.")

    parser.add_argument('-o', '--output', dest='output_files', required=True, nargs="+",
                        help="""NetCDF filename(s) to create from input. If one
filename is specified then all input files are combined into it. Otherwise
each input file is mapped to the corresponding output file.
""")
    args = parser.parse_args()

    levels = [logging.ERROR, logging.WARN, logging.INFO, logging.DEBUG]
    level = levels[min(3, args.verbosity)]
    logging.basicConfig(level=level)

    if args.start_time and not args.end_time:
        args.end_time = args.start_time.replace(hour=23, minute=59, second=59)
    elif not args.start_time and args.end_time:
        raise ValueError('start time must be specified when end time is specified')

    mini_database = {k: schema.database[k] for k in args.fields}
    if args.summary:
        mini_database = make_summary_dict(mini_database)

    # Case 1: All inputs to 1 output file
    # Case 2: Each input in to a separate output file
    if args.output_files and len(args.output_files) not in [1, len(args.input_files)]:
        raise ValueError('Output filenames must be 1 or the same length as input files')
    elif args.output_files and len(args.output_files) == len(args.input_files):
        args.input_files = [[i] for i in args.input_files]
    else:
        args.input_files = [args.input_files]

    success = False
    for in_files, out_fn in zip(args.input_files, args.output_files):
        try:
            create_giant_netcdf(in_files, out_fn, args.zlib,
                                args.chunk_size, args.start_time,
                                args.end_time, args.interval, args.summary,
                                mini_database, args.data_stream)
            success = True
        except (ValueError, TypeError):
            LOG.error("Could not generate NetCDF file for {}".format(in_files), exc_info=True)
    if not success:
        raise IOError('All ASCII files were empty or could not be read')


if __name__ == "__main__":
    sys.exit(main())