import os
import sys
import logging
import platform
from datetime import datetime as dt
from datetime import timedelta as delta

import numpy as np
import pandas as pd
from netCDF4 import Dataset

import avg_database
import calc
from aosstower import station
from aosstower.l00 import parser

LOG = logging.getLogger(__name__)
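
# Example invocation (script and file names below are illustrative):
#   python nc_daily.py -z -s 2016-07-01 -e 2016-07-02 \
#       rig_tower.2016-07-01.ascii -o aoss_tower.2016-07-01.nc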

def filterArray(array, valid_min, valid_max):
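    """Build the QC flag array for *array*: bit 1 flags values equal to the
    -99999 missing value, bit 2 values below valid_min, and bit 3 values
    above valid_max; 0b0 means the value passed every check."""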

    qcControl = []

    for value in array:
        if value == float(-99999):
            qcControl.append(np.byte(0b1))

        elif valid_min != '' and value < float(valid_min):
            qcControl.append(np.byte(0b10))
 
        elif valid_max != '' and value > float(valid_max):
            qcControl.append(np.byte(0b100))
 
        else:
            qcControl.append(np.byte(0b0))

    return np.array(qcControl)

# The purpose of this function is to write the dimensions
# for the nc file
# @param ncFile - netCDF4 Dataset to create the dimensions in
# @return ncFile with the unlimited 'time' dimension and the
# fixed-length 'max_len_station_name' dimension added

def writeDimensions(ncFile):
    ncFile.createDimension('time', None)
    ncFile.createDimension('max_len_station_name', 32)

    return ncFile

def createVariables(ncFile, firstStamp, chunksizes, zlib, database=parser.database):
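    """Create the coordinate, data, and QC variables described by *database*
    in ncFile, attach their attributes, and set the file's global
    attributes. Returns the modified Dataset."""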
    #base_time long name
    btln = 'base time as unix timestamp'

    #base time units
    btu = 'seconds since 1970-01-01 00:00:00'

    #base time string
    bts = firstStamp.strftime('%Y-%m-%d 00:00:00Z')

    #time long name
    tln = 'time offset from base_time'

    #time units
    tu = 'seconds since ' + firstStamp.strftime('%Y-%m-%d 00:00:00Z')

    

    coordinates = {
                      #fields: type, dimension, fill, valid_min, std_name, longname, units, valid_max, cf_role
                      'lon': [np.float32, None, float(-999), -180.0, 'longitude', None, 'degrees_east', 180.0, None],
                      'lat': [np.float32, None, float(-999), -90.0, 'latitude', None, 'degrees_north', 90.0, None],
                      'alt': [np.float32, None, float(-999), None, 'height', 'vertical distance', 'm', None, None],
                      #time variables are doubles so epoch-scale values keep full precision
                      'base_time': [np.float64, None, float(-999), None, 'time', btln, btu, None, None],
                      'time_offset': [np.float64, 'time', float(-999), None, 'time', tln, tu, None, None],
                      'station_name': ['c', 'max_len_station_name', '-', None, None, 'station name', None, None, 'timeseries_id'],
                      'time': [np.float64, 'time', float(-999), None, None, "Time offset from epoch", "seconds since 1970-01-01 00:00:00Z", None, None]
                  }

    for key in coordinates:
        attr = coordinates[key]

        if attr[1]:
            #clamp the chunk size to the fixed dimension length
            if attr[1] == 'max_len_station_name' and chunksizes and chunksizes[0] > 32:
                variable = ncFile.createVariable(key, attr[0], dimensions=(attr[1],), fill_value=attr[2], zlib=zlib, chunksizes=[32])

            else:
                variable = ncFile.createVariable(key, attr[0], dimensions=(attr[1],), fill_value=attr[2], zlib=zlib, chunksizes=chunksizes)

        else:
            #scalar variable: no dimensions, so no chunk sizes either
            variable = ncFile.createVariable(key, attr[0], fill_value=attr[2], zlib=zlib)

        #create var attributes
        if key == 'alt':
            variable.positive = 'up'
            variable.axis = 'Z'

        if attr[3]:
            variable.valid_min = attr[3]
            variable.valid_max = attr[7]

        if attr[4]:
            variable.standard_name = attr[4]

        if attr[5]:
            variable.long_name = attr[5]

        if attr[6]:
            variable.units = attr[6]

        if attr[8]:
            variable.cf_role = attr[8]

        if key == 'base_time':
            variable.string = bts

        if 'time' in key:
            variable.calendar = 'gregorian' 

    for entry in database:
        if entry == 'stamp':
            continue

        varTup = database[entry]

        variable = ncFile.createVariable(entry, np.float32,
                                         dimensions=('time',), fill_value=float(-99999), zlib=zlib, chunksizes=chunksizes)

        variable.standard_name = varTup[1]
        variable.description = varTup[3]
        variable.units = varTup[4]

        if varTup[5] != '':
            variable.valid_min = float(varTup[5])
            variable.valid_max = float(varTup[6])

        qcVariable = ncFile.createVariable('qc_' + entry, 'b',
                                           dimensions=('time',), fill_value=0b0, zlib=zlib, chunksizes=chunksizes)

        qcVariable.long_name = 'data quality'
        qcVariable.valid_range = np.byte(0b1), np.byte(0b1111)
        qcVariable.flag_masks = np.byte(0b1), np.byte(0b10), np.byte(0b100), np.byte(0b1000)

        flagMeaning = ['value is equal to missing_value',
                       'value is less than the valid min',
                       'value is greater than the valid max',
                       'difference between current and previous values exceeds valid_delta.']

        #CF spells this attribute flag_meanings
        qcVariable.flag_meanings = ', '.join(flagMeaning)

    #create global attributes
    ncFile.source = 'surface observation'
    ncFile.Conventions = 'ARM-1.2 CF-1.6'
    ncFile.institution = 'UW SSEC'
    ncFile.featureType = 'timeSeries'
    ncFile.data_level = 'b1'

    #monthly files end with .month.nc
    #these end with .day.nc

    ncFile.datastream = 'aoss.tower.nc-1d-1m.b1.v00'
    ncFile.software_version = '00'

    #generate history
    ncFile.history = ' '.join(platform.uname()) + " " + os.path.basename(__file__)
    
    return ncFile

def getGust(rollingAvg, speeds):
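    """Return a list aligned with *rollingAvg* holding the per-minute peak
    wind speed where the gust criteria are met and NaN elsewhere."""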
    averages = rollingAvg.tolist()
    maxSpeed = speeds['wind_speed'].tolist()

    gust = []

    for idx, average in enumerate(averages):
        #report a gust only when the sustained wind is at least 4.63 m/s
        #(~9 knots) and the peak exceeds the average by 2.573 m/s (~5 knots)
        if average and not np.isnan(average) and average >= 4.63 and maxSpeed[idx] > average + 2.573:
            gust.append(maxSpeed[idx])
        else:
            gust.append(np.nan)

    return gust

#gets the rolling mean closest to the nearest minute
def getRolling(series, minutes):
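    """Smooth *series* with a 25-sample boxcar mean and sample it at the
    last available stamp at or before each stamp in *minutes*."""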
    #25-sample boxcar (simple moving) average
    returnSeries = series.rolling(25, win_type='boxcar').mean()

    data = {}

    for minute in minutes:
        #asof picks the last stamp at or before the minute,
        #so the window never reaches past the minute
        closestStamp = returnSeries.index.asof(minute)
        data[minute] = returnSeries[closestStamp]

    returnSeries = pd.Series(data)

    return returnSeries

def getNewWindDirection(wind_dir, wind_speed, stamps):
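    """Compute a vector-averaged wind direction for the minute ending at
    each stamp, using calc.mean_wind_vector; stamps with no data in the
    preceding minute get None."""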
    newWindDir = {}

    for stamp in stamps:
        before = stamp - delta(minutes=1)

        if before not in wind_speed.index:
            newWindDir[stamp] = None

        else:
            speed = wind_speed[before: stamp].tolist()
            dire = wind_dir[before: stamp].tolist()

            wind_dire = calc.mean_wind_vector(speed, dire)[0]
            
            newWindDir[stamp] = wind_dire
    
    return pd.Series(newWindDir)

def minuteAverages(frame):
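    """Average *frame* into one-minute rows. wind_speed is recomputed with a
    rolling mean, a derived gust column is added, and wind_dir is replaced
    by a per-minute vector average; missing results become -99999."""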
    frame['minute'] = [(ts + delta(minutes=1)).replace(second=0) for ts in frame.index]
    newFrame = frame.groupby('minute').mean()
    newFrame.index.names = ['']

    columns = list(newFrame.columns.values)
    if 'wind_speed' in columns:
        del newFrame['wind_speed']

        windSeries = frame['wind_speed']

        windSeries = getRolling(windSeries, list(newFrame.index))

        newFrame['wind_speed'] = windSeries
  
        rollingAvg = newFrame['wind_speed']

        #the per-minute maxima supply the peak speeds used for gust detection
        maxSpeed = frame.groupby('minute').max()

        gust = getGust(rollingAvg, maxSpeed)

        newFrame['gust'] = gust
    
    if 'wind_dir' in columns:
        del newFrame['wind_dir']

        dupFrame = frame.set_index('minute')

        stamps = newFrame.index

        windDirSeries = dupFrame['wind_dir']

        windSeries = dupFrame['wind_speed']

        windDirSeries = getNewWindDirection(windDirSeries, windSeries, stamps)

        newFrame['wind_dir'] = windDirSeries

    del frame['minute']

    return newFrame.fillna(-99999)

def averageOverInterval(frame,interval_width):
    """takes a frame and an interval to average it over, and returns a minimum,
    maximum, and average dataframe for that interval"""
    ts = frame.index
    #floor each timestamp to the start of its n-minute interval
    frame['interval'] = (ts.astype('int64') - ts.astype('int64') % int(interval_width * 60e9)).astype('datetime64[ns]')
    outFrames = {}
    outFrames['low'] = frame.groupby('interval').min()
    outFrames['high'] = frame.groupby('interval').max()
    outFrames['mean'] = frame.groupby('interval').mean()
    del frame['interval']
    for key in outFrames:
        #append the appropriate suffix to each column 
        columns = outFrames[key].columns 
        outFrames[key].columns = ['_'.join([col,key]) for col in columns]
    outFrames = pd.concat(outFrames.values(),axis=1)
    return outFrames 

def getData(inputFiles):
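    """Read every file in *inputFiles* with parser.read_frames and return a
    stamp-indexed DataFrame with -99999 replaced by NaN."""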
    dictData = {}

    for filename in inputFiles:
        getFrames = list(parser.read_frames(filename))

        for frame in getFrames:
            if 'stamp' not in frame:
                continue

            stamp = frame['stamp']
            del frame['stamp']

            dictData[stamp] = frame

    return pd.DataFrame(dictData).transpose().replace(-99999, np.nan)

def writeVars(ncFile, frame, database=parser.database):
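    """Write the time, coordinate, station name, data, and QC values from
    *frame* into the variables of ncFile. Returns the modified Dataset."""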
    stamps = list(frame.index)
    baseDTObj = dt.strptime(str(stamps[0]).split(' ')[0], '%Y-%m-%d')

    #find out how much time elapsed
    #since the origin to the start of the day
    #in seconds
    baseTimeValue = baseDTObj - dt(1970,1,1)
    baseTimeValue = baseTimeValue.total_seconds()

    #create the time offsets as float64 so epoch-scale values keep precision
    timeNumpy = np.empty(len(stamps), dtype='float64')

    #write each stamp's offset from the start of the day
    for counter, stamp in enumerate(stamps):
        stampObj = dt.strptime(str(stamp), '%Y-%m-%d %H:%M:%S')
        timeNumpy[counter] = (stampObj - baseDTObj).total_seconds()

    fileVar = ncFile.variables
    fileVar['base_time'].assignValue(baseTimeValue)
    fileVar['time_offset'][:] = timeNumpy
    fileVar['time'][:] = timeNumpy + baseTimeValue

    #write coordinate var values to file
    #alt might not be right, need to verify
    fileVar['lon'].assignValue(station.LONGITUDE)
    fileVar['lat'].assignValue(station.LATITUDE)
    fileVar['alt'].assignValue(328)

    #might change
    stationName = ("AOSS Tower")
    
    #transfer station name into array of chars
    statChars = list(stationName)
    statNumpy = np.asarray(statChars)

    #write station name to file
    fileVar['station_name'][0:len(statNumpy)] = statNumpy

    #writes data into file
    for varName in frame:
        if varName not in fileVar:
            LOG.warning('Extraneous key: %s in frame', varName)
            continue
        dataList = frame[varName].tolist()

        dataArray = np.asarray(dataList)
        fileVar[varName][:] = dataArray

        valid_min = database[varName][5]
        valid_max = database[varName][6]

        fileVar['qc_' + varName][:] = filterArray(dataArray, valid_min, valid_max)

    coordinates = ['lon', 'lat', 'alt', 'base_time', 'time_offset', 'station_name', 'time']

    for varName in fileVar:
        if varName.startswith('qc_'):
            continue

        elif varName in frame:
            continue

        elif varName in coordinates:
            continue

        else:
            fileVar['qc_' + varName][:] = np.full(len(list(frame.index)), np.byte(0b1), dtype='b')
    
    return ncFile

#The purpose of this method is to take a begin date, an end date,
# input filenames, and an output filename and create a netCDF file
# based upon them
# @param start - a start datetime object
# @param end - an end datetime object
# @param inputFiles - list of input filenames
# @param outputName - filename of the netCDF file
# @return True if data was written, False if the input files were empty

def createGiantNetCDF(start, end, inputFiles, outputName, zlib, chunkSize,
                      interval_width = None,  database=parser.database):
    default = False

    if chunkSize:
        chunksizes = [chunkSize]

    else:
        default = True

    frame = getData(inputFiles)

    if frame.empty:
        return False

    else:

        frame = minuteAverages(frame)
        if interval_width:
            frame = averageOverInterval(frame, interval_width)

        if start and end:
            frame = frame[start.strftime('%Y-%m-%d %H:%M:%S'): end.strftime('%Y-%m-%d %H:%M:%S')]

        if default:
            chunksizes = [len(list(frame.index))]

        firstStamp = dt.strptime(str(list(frame.index)[0]), '%Y-%m-%d %H:%M:%S')

        ncFile = Dataset(outputName, 'w', format='NETCDF4_CLASSIC')

        ncFile = writeDimensions(ncFile)

        ncFile = createVariables(ncFile, firstStamp, chunksizes, zlib, database)
 
        ncFile.inputFiles = ', '.join(inputFiles)

        ncFile = writeVars(ncFile, frame, database)

        ncFile.close()
        
        return True

def createMultiple(filenames, outputFilenames, zlib, chunkSize):
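    """Create one netCDF file per input file, pairing each input with the
    output filename at the same position; raises IOError if every input
    was empty."""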
    if outputFilenames and len(filenames) != len(outputFilenames):
        print('USAGE: number of output filenames must equal number of input filenames when start and end times are not specified')
        sys.exit(1)

    results = []

    for idx, filename in enumerate(filenames):
        results.append(createGiantNetCDF(None, None, [filename], outputFilenames[idx], zlib, chunkSize))

    if not any(results):
        raise IOError('All ASCII files were empty')

#The purpose of this method is to take a string in the format
# YYYY-mm-ddTHH:MM:SS (or YYYY-mm-dd) and convert it to a datetime object
# used in coordination with argparse -s and -e params
# @param datetime_str - datetime string
# @return datetime object

def _dt_convert(datetime_str):
    #parse datetime string, return datetime object
    try:
        return dt.strptime(datetime_str, '%Y-%m-%dT%H:%M:%S')
    except ValueError:
        return dt.strptime(datetime_str, '%Y-%m-%d')

def main():
    import argparse

    #argparse description
    argparser = argparse.ArgumentParser(description="Convert level_00 aoss tower data to level_a0",
                                        fromfile_prefix_chars='@')

    #argparse verbosity info
    argparser.add_argument('-v', '--verbose', action="count", default=int(os.environ.get("VERBOSITY", 2)),
                         dest='verbosity',
                         help='each occurrence increases verbosity 1 level through ERROR-WARNING-INFO-DEBUG (default INFO)')

    #argparse start and end times
    argparser.add_argument('-s', '--start-time', type=_dt_convert,
        help="Start time of massive netcdf file; if only -s is given, a netcdf file for just that day is created" +
        ". Formats allowed: \'YYYY-MM-DDTHH:MM:SS\', \'YYYY-MM-DD\'")
    argparser.add_argument('-e', '--end-time', type=_dt_convert, help='End time of massive netcdf file. Formats allowed:' +
        "\'YYYY-MM-DDTHH:MM:SS\', \'YYYY-MM-DD\'")
    argparser.add_argument('-i', '--interval', type=float, 
            help='Width of the interval to average input data over in minutes.'+
        " If not specified, 1 is assumed. (Use 60 for one hour and 1440 for 1 day)")
    argparser.add_argument('-cs', '--chunk-size', type=int, help='chunk Size for the netCDF file')
    argparser.add_argument('-z', '--zlib', action='store_true', help='compress netCDF file with zlib')

    argparser.add_argument("input_files", nargs="+",
                         help="aoss_tower level_00 paths. Use @filename to red a list of paths from that file.")

    argparser.add_argument('-o', '--output', required=True, nargs="+", help="filename pattern or filename. " +
    "Should be along the lines of <filepath>/aoss_tower.YYYY-MM-DD.nc")
    args = argparser.parse_args()

    levels = [logging.ERROR, logging.WARNING, logging.INFO, logging.DEBUG]
    level = levels[min(3, args.verbosity)]
    logging.basicConfig(level=level)


    database = avg_database.AOSS_VARS if args.interval else parser.database
    if args.start_time and args.end_time:
        result = createGiantNetCDF(args.start_time, args.end_time, args.input_files, args.output[0], args.zlib, args.chunk_size,
                                   args.interval, database)
        if not result:
            raise IOError('An empty ASCII file was found')

    elif args.start_time:
        end_time = args.start_time.replace(hour=23, minute=59, second=59)
        result = createGiantNetCDF(args.start_time, end_time, args.input_files, args.output[0], args.zlib, args.chunk_size,
                                   args.interval, database)
        if not result:
            raise IOError('An empty ASCII file was found')

    elif(args.end_time):
        print('USAGE: start time must be specified when end time is specified')

    else:
        createMultiple(args.input_files, args.output, args.zlib, args.chunk_size)

if __name__ == "__main__":
    main()