import os
import sys
import logging
import pandas as pd
from datetime import datetime as dt
from aosstower.level_00 import parser
from netCDF4 import Dataset
import numpy as np
import platform
from aosstower import station
from datetime import timedelta as delta
from aosstower.level_b1 import calc

# Logger for this module (named after the module via __name__).
LOG = logging.getLogger(__name__)


def make_mean_dict(source_dict):
    """Create the '_mean','_low','_high' file structure.

    Every key of *source_dict* is expanded into three keys, ``<key>_high``,
    ``<key>_mean`` and ``<key>_low`` (in that order). All three map to the very
    same value object as the original key.

    :param source_dict: mapping of variable name -> variable description
    :return: a new dict with three suffixed entries per input key
    """
    suffixes = ('_high', '_mean', '_low')
    return {name + suffix: description
            for name, description in source_dict.items()
            for suffix in suffixes}


# Variant of parser.database with '<name>_high', '<name>_mean' and '<name>_low'
# entries for every variable. main() selects it when an averaging interval is
# requested; those names match the columns average_over_interval() produces.
mean_database = make_mean_dict(parser.database)


def filter_array(array, valid_min, valid_max):
    """Compute one QC flag per value of *array*.

    Flags are checked in this order, and the first match wins:

    * ``0b1``   -- value equals the -99999 missing-value sentinel
    * ``0b10``  -- value is below ``float(valid_min)``
    * ``0b100`` -- value is above ``float(valid_max)``
    * ``0b0``   -- otherwise (NaN also ends up here, since it compares False)

    An empty-string bound disables that bound's check.

    :param array: iterable of numeric values
    :param valid_min: lower bound as something ``float()`` accepts, or ''
    :param valid_max: upper bound as something ``float()`` accepts, or ''
    :return: numpy array built from ``np.byte`` flags
    """
    def _flag_for(value):
        if value == float(-99999):
            return np.byte(0b1)
        if valid_min != '' and value < float(valid_min):
            return np.byte(0b10)
        if valid_max != '' and value > float(valid_max):
            return np.byte(0b100)
        return np.byte(0b0)

    return np.array([_flag_for(value) for value in array])


def write_dimensions(ncFile):
    """Declare the dimensions used by the output file on *ncFile*.

    ``time`` is created with size None (unlimited in netCDF4), and
    ``max_len_station_name`` is created with a fixed length of 32.

    :param ncFile: netCDF4 Dataset to add the dimensions to
    :return: the same *ncFile*
    """
    for dim_name, dim_size in (('time', None), ('max_len_station_name', 32)):
        ncFile.createDimension(dim_name, dim_size)
    return ncFile


def create_variables(ncFile, firstStamp, chunksizes, zlib, database=parser.database):
    """Define all variables and global attributes of the output file on *ncFile*.

    Creates three things:

    * the coordinate variables (lon, lat, alt, base_time, time_offset,
      station_name, time);
    * for every *database* entry except ``stamp``, a float32 variable on the
      ``time`` dimension plus a byte ``qc_<name>`` flag variable;
    * the file's global attributes.

    :param ncFile: netCDF4 Dataset on which the ``time`` and
        ``max_len_station_name`` dimensions already exist
    :param firstStamp: datetime of the first record; its date is used in the
        base_time string and in the time_offset units
    :param chunksizes: one-element list with the chunk length for dimensioned
        variables (may be falsy)
    :param zlib: whether to compress variables with zlib
    :param database: mapping of variable name -> description sequence. Items
        1, 3, 4, 5 and 6 are used as standard_name, description, units,
        valid_min and valid_max; a valid_min of '' means no range is written.
    :return: the same *ncFile*
    """
    # base_time long name
    btln = 'base time as unix timestamp'

    # base time units
    btu = 'seconds since 1970-01-01 00:00:00'

    # base time string
    bts = firstStamp.strftime('%Y-%m-%d 00:00:00Z')

    # time long name
    tln = 'time offset from base_time'

    # time units
    tu = 'seconds since ' + firstStamp.strftime('%Y-%m-%d 00:00:00Z')

    # NOTE(review): valid_min/valid_max for lon/lat are stored as the strings
    # '-180L', '180L', ... (CDL long-literal spelling) rather than numbers.
    # Confirm whether numeric attributes were intended.
    coordinates = {
        # fields: type, dimension, fill, valid_min, std_name, longname, units, valid_max, cf_role[, axis (unused)]
        'lon': [np.float32, None, float(-999), '-180L', 'longitude', None, 'degrees_east', '180L', None],
        'lat': [np.float32, None, float(-999), '-90L', 'latitude', None, 'degrees_north', '90L', None],
        'alt': [np.float32, None, float(-999), None, 'height', 'vertical distance', 'm', None, None],
        'base_time': [np.int32, None, float(-999), None, 'time', btln, btu, None, None],
        'time_offset': [np.float64, 'time', float(-999), None, 'time', tln, tu, None, None],
        'station_name': ['c', 'max_len_station_name', '-', None, None, 'station name', None, None, 'timeseries_id'],
        'time': [np.float32, 'time', float(-999), None, None, "Time offset from epoch",
                 "seconds since 1970-01-01 00:00:00Z", None, None, None]
    }

    for key, attr in coordinates.items():
        dimension = attr[1]
        if dimension:
            var_chunks = chunksizes
            # max_len_station_name is only 32 long, so its chunk cannot be larger.
            if dimension == 'max_len_station_name' and chunksizes and chunksizes[0] > 32:
                var_chunks = [32]
            variable = ncFile.createVariable(key, attr[0], dimensions=(dimension,), fill_value=attr[2],
                                             zlib=zlib, chunksizes=var_chunks)
        else:
            # Scalar coordinate. The fill value is attr[2]; attr[1] is the
            # (empty) dimension slot and must not be used as the fill value.
            variable = ncFile.createVariable(key, attr[0], fill_value=attr[2], zlib=zlib,
                                             chunksizes=chunksizes)

        # create var attributes
        if key == 'alt':
            variable.positive = 'up'
            variable.axis = 'Z'

        if attr[3]:
            variable.valid_min = attr[3]
            variable.valid_max = attr[7]

        if attr[4]:
            variable.standard_name = attr[4]

        if attr[5]:
            variable.long_name = attr[5]

        if attr[6]:
            variable.units = attr[6]

        if attr[8]:
            variable.cf_role = attr[8]

        if key == 'base_time':
            variable.string = bts

        # base_time, time_offset and time all get a calendar
        if 'time' in key:
            variable.calendar = 'gregorian'

    for entry in database:
        if entry == 'stamp':
            continue

        varTup = database[entry]
        variable = ncFile.createVariable(entry, np.float32,
                                         dimensions=('time',), fill_value=float(-99999), zlib=zlib,
                                         chunksizes=chunksizes)

        variable.standard_name = varTup[1]
        variable.description = varTup[3]
        variable.units = varTup[4]

        if varTup[5] != '':
            variable.valid_min = float(varTup[5])
            variable.valid_max = float(varTup[6])

        qcVariable = ncFile.createVariable('qc_' + entry, 'b',
                                           dimensions=('time',), fill_value=0b0, zlib=zlib, chunksizes=chunksizes)

        qcVariable.long_name = 'data quality'
        qcVariable.valid_range = np.byte(0b1), np.byte(0b1111)
        qcVariable.flag_masks = np.byte(0b1), np.byte(0b10), np.byte(0b100), np.byte(0b1000)

        flagMeaning = ['value is equal to missing_value',
                       'value is less than the valid min',
                       'value is greater than the valid max',
                       'difference between current and previous values exceeds valid_delta.']

        # NOTE(review): CF names this attribute 'flag_meanings' (blank-separated
        # words). Confirm whether 'flag_meaning' is required by a consumer.
        qcVariable.flag_meaning = ', '.join(flagMeaning)

    # create global attributes
    ncFile.source = 'surface observation'
    ncFile.conventions = 'ARM-1.2 CF-1.6'
    ncFile.institution = 'UW SSEC'
    ncFile.featureType = 'timeSeries'
    ncFile.data_level = 'b1'

    # monthly files end with .month.nc
    # these end with .day.nc

    ncFile.datastream = 'aoss.tower.nc-1d-1m.b1.v00'
    ncFile.software_version = '00'

    # generate history
    ncFile.history = ' '.join(platform.uname()) + " " + os.path.basename(__file__)

    return ncFile


def get_gust(rollingAvg, speeds):
    """Return a per-position gust list derived from rolling means and peak speeds.

    At position i the peak ``speeds['wind_speed'][i]`` is reported as a gust
    when all of these hold:

    * the rolling mean is truthy (non-zero);
    * the rolling mean is at least 4.63;
    * the peak exceeds the rolling mean by more than 2.573.

    Every other position is NaN. Positions are matched by order, not by index
    label.

    :param rollingAvg: Series of rolling-mean wind speeds
    :param speeds: DataFrame whose ``wind_speed`` column holds the peak speed
        for the same positions
    :return: list with a gust speed or NaN per element of *rollingAvg*
    """
    means = rollingAvg.tolist()
    peaks = speeds['wind_speed'].tolist()

    gusts = []
    for position, mean_speed in enumerate(means):
        is_gust = bool(mean_speed) and mean_speed >= 4.63 and peaks[position] > mean_speed + 2.573
        gusts.append(peaks[position] if is_gust else np.nan)
    return gusts


def get_rolling(series, minutes):
    """Get the rolling mean closest to the nearest minute.

    A 25-sample boxcar rolling mean is computed over *series*. For each entry
    of *minutes*, the smoothed value at the last index label at or before that
    minute (``Index.asof``) is taken, even when that value is NaN.

    :param series: Series with a timestamp index (``Index.asof`` assumes it is sorted)
    :param minutes: iterable of timestamps to sample at
    :return: Series of the sampled rolling means, indexed by *minutes*
    """
    smoothed = series.rolling(25, win_type='boxcar').mean()
    label_at_or_before = smoothed.index.asof
    return pd.Series({minute: smoothed[label_at_or_before(minute)] for minute in minutes})


def get_new_wind_direction(wind_dir, wind_speed, stamps):
    """Compute one wind direction per stamp from vector-averaged samples.

    For each stamp:

    * If the label ``stamp - 1 minute`` is absent from ``wind_speed.index``,
      the stamp gets None.
    * Otherwise the samples labelled from that label through the stamp
      (an inclusive pandas label slice) are passed to
      ``calc.mean_wind_vector(speeds, directions)``, and element [0] of its
      result is used. Presumably that element is the mean direction; confirm
      in calc.

    :param wind_dir: Series of directions sharing *wind_speed*'s index
    :param wind_speed: Series of speeds indexed by (possibly repeated) minute labels
    :param stamps: iterable of minute labels to compute a direction for
    :return: Series of directions (or None) indexed by *stamps*
    """
    one_minute = delta(minutes=1)
    directions = {}
    for stamp in stamps:
        window_start = stamp - one_minute
        if window_start in wind_speed.index:
            directions[stamp] = calc.mean_wind_vector(
                wind_speed[window_start: stamp].tolist(),
                wind_dir[window_start: stamp].tolist(),
            )[0]
        else:
            directions[stamp] = None
    return pd.Series(directions)


def minute_averages(frame):
    """Aggregate *frame* into one row per minute.

    Each record is labelled with the following whole minute: its timestamp plus
    one minute, with seconds and microseconds dropped. A label M therefore
    collects the records stamped in [M - 1 min, M).

    Columns are averaged arithmetically, except for these:

    * ``wind_speed`` is replaced by get_rolling() over the raw samples,
      sampled at each minute label.
    * ``gust`` is added from get_gust(), which compares that rolling mean with
      the per-minute maximum.
    * ``wind_dir`` is replaced by get_new_wind_direction().

    A helper ``minute`` column is added to *frame* while working. It is always
    removed again, even if an error occurs.

    :param frame: DataFrame indexed by record timestamps, with missing values as NaN
    :return: new DataFrame indexed by minute label, with NaN replaced by -99999
    """
    frame['minute'] = [(ts + delta(minutes=1)).replace(second=0, microsecond=0) for ts in frame.index]
    try:
        newFrame = frame.groupby('minute').mean()
        newFrame.index.names = ['']

        columns = list(newFrame.columns.values)
        if 'wind_speed' in columns:
            # Replace the plain minute mean with the rolling mean sampled at each minute.
            del newFrame['wind_speed']
            newFrame['wind_speed'] = get_rolling(frame['wind_speed'], list(newFrame.index))
            # Per-minute maxima. Same groupby key as newFrame, so rows line up
            # by position, which is how get_gust pairs them.
            maxSpeed = frame.groupby('minute').max()
            newFrame['gust'] = get_gust(newFrame['wind_speed'], maxSpeed)

        if 'wind_dir' in columns:
            # Directions are not averaged arithmetically; the mean is replaced.
            del newFrame['wind_dir']
            dupFrame = frame.set_index('minute')
            newFrame['wind_dir'] = get_new_wind_direction(dupFrame['wind_dir'], dupFrame['wind_speed'],
                                                          newFrame.index)
    finally:
        # Never leave the helper column on the caller's frame.
        del frame['minute']

    return newFrame.fillna(-99999)


def average_over_interval(frame, interval_width, missing_value=-99999):
    """Aggregate *frame* into fixed-width intervals, returning low/high/mean columns.

    Each timestamp is floored to a multiple of *interval_width* minutes,
    counted from the Unix epoch. Values equal to *missing_value* are treated as
    missing, so they do not drag down the minimum or the mean. minute_averages()
    fills NaN with -99999, the default here. Any interval with no valid value
    for a column gets *missing_value* in the output. The caller's frame is not
    modified.

    :param frame: DataFrame with a datetime index
    :param interval_width: interval width in minutes (may be fractional); must be positive
    :param missing_value: sentinel marking missing data on input and output
    :return: DataFrame indexed by interval start (named 'interval'). Its columns
        are ``<name>_low`` for every input column, then ``<name>_high``, then
        ``<name>_mean``.
    :raises ValueError: if *interval_width* is not a positive number of minutes
    """
    width_ns = int(round(interval_width * 60e9))
    if width_ns <= 0:
        raise ValueError('interval_width must be a positive number of minutes')

    # Work on a copy with the sentinel turned into NaN so aggregations skip it.
    data = frame.replace(missing_value, np.nan)
    # Integer nanoseconds: exact, unlike the float64 modulo arithmetic.
    ns = np.asarray(frame.index, dtype='datetime64[ns]').astype(np.int64)
    data['interval'] = (ns - ns % width_ns).astype('datetime64[ns]')

    grouped = data.groupby('interval')
    out_frames = [('low', grouped.min()), ('high', grouped.max()), ('mean', grouped.mean())]
    for suffix, out in out_frames:
        # append the appropriate suffix to each column
        out.columns = [col + '_' + suffix for col in out.columns]

    return pd.concat([out for _, out in out_frames], axis=1).fillna(missing_value)


def get_data(inputFiles):
    """Read every record from *inputFiles* into one DataFrame indexed by timestamp.

    Records without a ``stamp`` key are skipped. If the same stamp appears more
    than once, the last record read wins. The -99999 missing-value sentinel is
    replaced with NaN.

    Rows are sorted by timestamp. Later steps rely on an ordered index
    (``Index.asof``, rolling windows, label slicing), so this keeps them
    correct when the input files are not given in chronological order.

    :param inputFiles: iterable of level_00 paths readable by parser.read_frames
    :return: DataFrame with one row per stamp and one column per field
        (empty if no records were found)
    """
    records = {}
    for filename in inputFiles:
        for frame in parser.read_frames(filename):
            if 'stamp' not in frame:
                continue
            # Copy so the parser's record is not mutated.
            record = dict(frame)
            stamp = record.pop('stamp')
            records[stamp] = record

    return pd.DataFrame(records).transpose().replace(-99999, np.nan).sort_index()


def write_vars(ncFile, frame, database=parser.database):
    """Write the rows of *frame*, their QC flags and the coordinate values into *ncFile*.

    Time values:

    * base_time is midnight of the first record's date, as seconds since
      1970-01-01.
    * time_offset and time both receive each record's seconds since that
      midnight.

    Data values:

    * Every column of *frame* that has a matching variable is written. It also
      gets a ``qc_<name>`` array from filter_array(), using items 5 and 6 of its
      *database* entry as valid_min/valid_max.
    * Columns with no matching variable are logged and skipped.
    * Data variables with no column in *frame* get every QC value set to 0b1.

    :param ncFile: netCDF4 Dataset whose variables were created by create_variables
    :param frame: DataFrame whose index values str() to 'YYYY-mm-dd HH:MM:SS'
    :param database: mapping used to look up valid_min/valid_max per written column
    :return: the same *ncFile*
    """
    stamps = list(frame.index)
    baseDTObj = dt.strptime(str(stamps[0]).split(' ')[0], '%Y-%m-%d')

    # seconds from the Unix epoch to the start of the first record's day
    baseTimeValue = (baseDTObj - dt(1970, 1, 1)).total_seconds()

    # seconds from that day start to each record
    timeNumpy = np.array(
        [(dt.strptime(str(stamp), '%Y-%m-%d %H:%M:%S') - baseDTObj).total_seconds() for stamp in stamps],
        dtype='float64')

    fileVar = ncFile.variables
    fileVar['base_time'].assignValue(baseTimeValue)
    fileVar['time_offset'][:] = timeNumpy
    # NOTE(review): create_variables gives 'time' the units 'seconds since
    # 1970-01-01 00:00:00Z', yet it receives the same day-relative offsets as
    # time_offset. Confirm which is intended.
    fileVar['time'][:] = timeNumpy

    # write coordinate var values to file
    # alt might not be right, need to verify
    fileVar['lon'].assignValue(station.LONGITUDE)
    fileVar['lat'].assignValue(station.LATITUDE)
    fileVar['alt'].assignValue(328)

    # might change
    stationName = "AOSS Tower"

    # one character per element of the char-typed station_name variable
    statNumpy = np.asarray(list(stationName))
    fileVar['station_name'][0:len(statNumpy)] = statNumpy

    # writes data into file
    for varName in frame:
        if varName not in fileVar:
            LOG.warning('Extraneous key: %s in frame', varName)
            continue

        dataArray = np.asarray(frame[varName].tolist())
        fileVar[varName][:] = dataArray

        valid_min = database[varName][5]
        valid_max = database[varName][6]
        fileVar['qc_' + varName][:] = filter_array(dataArray, valid_min, valid_max)

    coordinates = {'lon', 'lat', 'alt', 'base_time', 'time_offset', 'station_name', 'time'}
    missing_flags = np.full(len(stamps), np.byte(0b1), dtype='b')

    # data variables with no column in frame are flagged as missing everywhere
    for varName in fileVar:
        if varName.startswith('qc_') or varName in frame or varName in coordinates:
            continue
        fileVar['qc_' + varName][:] = missing_flags

    return ncFile


def create_giant_netCDF(start, end, inputFiles, outputName, zlib, chunkSize,
                        interval_width=None, database=parser.database):
    """Combine the records of *inputFiles* into a single netCDF file.

    Processing steps:

    1. Average the records to one row per minute (minute_averages).
    2. Optionally aggregate further into ``interval_width``-minute
       low/high/mean values (average_over_interval).
    3. When both *start* and *end* are given, keep only rows between them
       (inclusive).
    4. Write the result to *outputName* as NETCDF4_CLASSIC.

    :param start: datetime of the first record to keep, or None
    :param end: datetime of the last record to keep, or None. Filtering is
        applied only when both *start* and *end* are given.
    :param inputFiles: list of level_00 input paths
    :param outputName: path of the netCDF file to create
    :param zlib: whether to compress variables with zlib
    :param chunkSize: chunk length for time-dimensioned variables; falsy means
        one chunk holding every record
    :param interval_width: averaging interval in minutes, or None/0 to keep
        one-minute data
    :param database: variable database passed to create_variables/write_vars
    :return: True if the file was written. False if there was no data: either
        the input was empty, or nothing fell between *start* and *end*.
    """
    frame = get_data(inputFiles)
    if frame.empty:
        return False

    frame = minute_averages(frame)
    if interval_width:
        frame = average_over_interval(frame, interval_width)

    if start and end:
        frame = frame.loc[start.strftime('%Y-%m-%d %H:%M:%S'): end.strftime('%Y-%m-%d %H:%M:%S')]

    # Nothing in the requested window: report "no data" instead of failing on
    # frame.index[0] below (and instead of creating a zero-length chunk).
    if frame.empty:
        return False

    chunksizes = [chunkSize] if chunkSize else [len(frame.index)]

    firstStamp = dt.strptime(str(frame.index[0]), '%Y-%m-%d %H:%M:%S')

    ncFile = Dataset(outputName, 'w', format='NETCDF4_CLASSIC')
    try:
        ncFile = write_dimensions(ncFile)
        ncFile = create_variables(ncFile, firstStamp, chunksizes, zlib, database)
        ncFile.inputFiles = ', '.join(inputFiles)
        ncFile = write_vars(ncFile, frame, database)
    finally:
        # Always release the file handle, even if writing failed part-way.
        ncFile.close()

    return True


def create_multiple(filenames, outputFilenames, zlib, chunkSize):
    if outputFilenames and len(filenames) != len(outputFilenames):
        raise ValueError(
            'Number of output filenames must equal number of input filenames when start and end times are not specified')

    results = []
    for idx, filename in enumerate(filenames):
        results.append(create_giant_netCDF(None, None, [filename], outputFilenames[idx], zlib, chunkSize))

    if not any(results):
        raise IOError('All ASCII files were empty')


# The purpose of this method is to take a string in the format
# YYYY-mm-ddTHH:MM:SS and convert that to a datetime object
# used in coordination with argparse -s and -e params
# @param datetime string
# @return datetime object

def _dt_convert(datetime_str):
    # parse datetime string, return datetime object
    try:
        return dt.strptime(datetime_str, '%Y-%m-%dT%H:%M:%S')
    except:
        return dt.strptime(datetime_str, '%Y-%m-%d')


def main():
    import argparse
    argparser = argparse.ArgumentParser(description="Convert level_00 aoss tower data to level_a0",
                                        fromfile_prefix_chars='@')

    argparser.add_argument('-v', '--verbose', action="count", default=int(os.environ.get("VERBOSITY", 2)),
                           dest='verbosity',
                           help='each occurrence increases verbosity 1 level through ERROR-WARNING-INFO-DEBUG (default INFO)')

    argparser.add_argument('-s', '--start-time', type=_dt_convert,
                           help="Start time of massive netcdf file, if only -s is given, a netcdf file for only that day is given" +
                                ". Formats allowed: \'YYYY-MM-DDTHH:MM:SS\', \'YYYY-MM-DD\'")
    argparser.add_argument('-e', '--end-time', type=_dt_convert,
                           help='End time of massive netcdf file. Formats allowed:' +
                                "\'YYYY-MM-DDTHH:MM:SS\', \'YYYY-MM-DD\'")
    argparser.add_argument('-i', '--interval', type=float,
                           help='Width of the interval to average input data over in minutes.' +
                                " If not specified, 1 is assumed. (Use 60 for one hour and 1440 for 1 day)")
    argparser.add_argument('--chunk-size', type=int, help='chunk size for the netCDF file')
    argparser.add_argument('-z', '--zlib', action='store_true', help='compress netCDF file with zlib')

    argparser.add_argument("input_files", nargs="+",
                           help="aoss_tower level_00 paths. Use @filename to red a list of paths from that file.")

    argparser.add_argument('-o', '--output', required=True, nargs="+", help="filename pattern or filename. " +
                                                                            "Should be along the lines of <filepath>/aoss_tower.YYYY-MM-DD.nc")
    args = argparser.parse_args()

    levels = [logging.ERROR, logging.WARN, logging.INFO, logging.DEBUG]
    level = levels[min(3, args.verbosity)]
    logging.basicConfig(level=level)

    database = mean_database if args.interval else parser.database
    if args.start_time and args.end_time:
        result = create_giant_netCDF(args.start_time, args.end_time, args.input_files, args.output[0], args.zlib,
                                     args.chunk_size,
                                     args.interval, database)
        if not result:
            raise IOError('An empty ASCII file was found')
    elif args.start_time:
        end_time = args.start_time.replace(hour=23, minute=59, second=59)
        result = create_giant_netCDF(args.start_time, end_time, args.input_files, args.output[0], args.zlib,
                                     args.chunk_size,
                                     args.interval, database)
        if not result:
            raise IOError('An empty ASCII file was found')
    elif args.end_time:
        raise ValueError('start time must be specified when end time is specified')
    else:
        create_multiple(args.input_files, args.output, args.zlib, args.chunk_size)


if __name__ == "__main__":
    # Run the command-line interface. main() has no return statement, so
    # sys.exit(None) exits with status 0 unless main() raises.
    sys.exit(main())