import os
import shelve
import tempfile
from glob import glob
import re
from collections import defaultdict, OrderedDict
import netCDF4
from aeri_tools.io.dmv.housekeeping import get_all_housekeeping
import pandas as pd
import numpy as np
from zipfile import ZipFile
from io import BytesIO
from tempfile import mkstemp
from shutil import move
from datetime import datetime
from aeri_qc.igm_checks import spike_check
from aeri_qc import bomem_file
from itertools import product

def save_quality(frame, qc_path, checklist):
    """
    Save the DataFrame (frame) to a netCDF at (qc_path)

    Only rows with a non-null ``sum_index`` (records matching the SUM file) are
    written, ordered by ``sum_index``. The file contains:

    * a ``time`` variable filled from the ``datetime`` column,
    * one ``f4`` variable per check in ``checklist.checks`` (with the check's
      metadata as attributes),
    * a ``spike_check`` variable (all zeros when the frame has no
      ``spike_check`` column).

    The netCDF is written to a temporary file that is moved to ``qc_path`` only
    after it has been closed successfully, so a failed run never replaces an
    existing QC file. On failure the temporary file is removed and the exception
    is re-raised.

    :param frame: QC DataFrame with ``sum_index``, ``datetime`` and one column per check name
    :param qc_path: destination path of the netCDF file
    :param checklist: object whose ``checks`` expose ``name``, ``depends``,
        ``affects_calibration``, ``affected_by_calibration``, ``description`` and ``hides``
    """
    # First select only rows corresponding to records from the sum file
    # (.loc replaces the .ix indexer, which was removed in pandas 1.0)
    frame = frame.loc[pd.notnull(frame.sum_index)].set_index('sum_index').sort_index()

    # Define the netcdf in a temporary file first
    handle, temp = mkstemp()
    os.close(handle)
    try:
        ncdf = netCDF4.Dataset(temp, 'w')
        try:
            ncdf.algorithm_version = '0.1.0'
            ncdf.createDimension('time', len(frame))
            time = ncdf.createVariable('time', 'f8', ('time',))
            time.units = 'nanoseconds since 1970-01-01 00:00:00'

            # Write the columns ending in _check (aggregate tests)
            for check in checklist.checks:
                variable = ncdf.createVariable(check.name, 'f4', ('time',))
                variable.depends = ','.join(check.depends)
                variable.affects_calibration = str(check.affects_calibration)
                variable.affected_by_calibration = str(check.affected_by_calibration)
                variable.description = check.description
                variable.hides = ','.join(check.hides)
                variable[:] = frame[check.name].values

            spike_check_variable = ncdf.createVariable('spike_check', 'f4', ('time',))
            spike_check_variable.depends = ''
            spike_check_variable.affects_calibration = "True"
            spike_check_variable.affected_by_calibration = "False"
            spike_check_variable.description = 'Check for spikes in interferometer caused by faulty electronics'
            spike_check_variable.hides = ''
            # Do the spike check separately for now; it is only present when Igms were found
            if 'spike_check' in frame.columns:
                spike_check_variable[:] = frame.spike_check.values
            else:
                spike_check_variable[:] = 0

            # Write time information
            time[:] = frame.datetime.values
        finally:
            # Always release the netCDF handle, even if a write failed
            ncdf.close()
        move(temp, qc_path)
    except BaseException:
        # Don't leave a (possibly partial) temporary file behind; re-raise the original error
        if os.path.exists(temp):
            os.remove(temp)
        raise
    

def compute_calibration_graph(sceneMirrorPosition):
    """
    Map each calibration record to the scene records it brackets.

    The mirror positions are turned into a string with one character per record
    (missing positions become '?'). Every run of characters other than 'H' and
    'A' that is immediately preceded AND followed by an 'HA' or 'AH' pair counts
    as a scene. Each index of those two surrounding pairs is mapped to the
    indices of the scene.

    :param sceneMirrorPosition: Series of mirror positions stored as character ordinals
    :return: defaultdict(set) keyed by positional record index
    """
    pattern = re.compile('(?<=(HA|AH))([^HA]+)(?=(HA|AH))')
    # One character per record, in the Series' positional order
    positions = ''.join(chr(int(code)) for code in sceneMirrorPosition.fillna(ord('?')))

    graph = defaultdict(set)
    for match in pattern.finditer(positions):
        # Groups 1 and 3 are the calibration pairs captured by the lookbehind/lookahead
        bracketing = set(range(*match.span(1))) | set(range(*match.span(3)))
        scene = range(*match.span(2))
        for calibration_index in bracketing:
            graph[calibration_index].update(scene)
    return graph

def fill_na_sceneMirrorPosition(sceneMirrorPosition):
    """
    Fill missing mirror positions using the periodicity of the scene sequence.

    The smallest shift ``period`` for which all non-missing values agree with
    the series shifted by ``period`` is taken as the period of the sequence.
    Missing values are then filled from copies shifted by multiples of that
    period, backwards and forwards in time, until nothing is missing or every
    shift has been tried.

    :param sceneMirrorPosition: Series of mirror positions, possibly with NaNs
    :return: the input itself when nothing is missing; otherwise a new Series
        carrying the input's index (with NaNs left in place if no period is found).
        The input is not modified.
    """
    if not sceneMirrorPosition.isnull().any():
        return sceneMirrorPosition
    original_index = sceneMirrorPosition.index
    # Work on a positional index so NaTs in a datetime index don't mess things up
    positions = sceneMirrorPosition.reset_index(drop=True)
    n_records = len(positions)

    # Keep shifting until everything matches up except missing scenes
    for shift in range(1, n_records):
        if abs(positions - positions.shift(shift)).dropna().sum() == 0:
            period = shift
            break
    else:
        # No period found: nothing can be filled, but still hand back the caller's index
        positions.index = original_index
        return positions

    # Now try to fill with time-shifted versions of the same variable.
    # Shift amounts must be integers (pandas rejects float periods), hence range and //.
    for multiple in range(1, n_records // period + 1):
        positions = positions.fillna(positions.shift(multiple * period))
        positions = positions.fillna(positions.shift(-multiple * period))
        if not positions.isnull().any():
            break

    positions.index = original_index
    return positions
    

def read_mirror(mirror_beg_path):
    """
    Read the scene definitions from a ``mirror.beg`` file.

    The first token of the first line is the expected number of scenes. Scene
    lines are every other line starting with the third (line indices 2, 4, ...).
    Each scene line holds exactly three whitespace-separated fields: a name and
    two numbers (the last one is named ``angle`` here).

    :param mirror_beg_path: path to the mirror.beg file
    :return: list of ``(name, float, float)`` tuples, one per scene line
    :raises ValueError: if a scene line is malformed, or if the number of scene
        lines does not match the count in the header
    """
    with open(mirror_beg_path) as fp:
        lines = fp.read().splitlines()

    mirror_beg = []
    for scene_line in lines[2::2]:
        # NOTE(review): the meaning of the middle column is unknown ("idk" in the original)
        name, idk, angle = scene_line.split()
        mirror_beg.append((name, float(idk), float(angle)))

    expected = int(lines[0].split()[0])
    # Explicit check rather than assert, which is stripped when running with python -O
    if len(mirror_beg) != expected:
        raise ValueError('{}: header {!r} declares {} scenes but {} were read'.format(
            mirror_beg_path, lines[0], expected, len(mirror_beg)))
    return mirror_beg

def read_frame(cxs_file, sum_file):
    """
    Read housekeeping from CXS file and SUM file together

    Returns DataFrame with range index, datetime column, and sum_index column, and housekeeping data

    Every CXS record is kept. Missing ``sceneMirrorPosition`` values are filled
    with fill_na_sceneMirrorPosition first. Records whose position is not 'H' or
    'A' (the calibration views) are numbered in order as ``sum_index`` and
    combined with the SUM record of the same number. In both ``combine_first``
    calls the CXS values take precedence, and the SUM only fills gaps and adds
    columns. The ad-hoc attribute ``calibration_graph`` holds
    compute_calibration_graph of the combined positions.
    """
    # Get CXS housekeeping as dataframe, remove datetime index into a column and replace with integers
    cxs = get_all_housekeeping(cxs_file)
    cxs.index.name = 'datetime'
    cxs.reset_index(inplace=True)
    cxs.index.name = 'cxs_index'
    # Fill in any missing scenesMirrorPositions
    cxs['sceneMirrorPosition'] = fill_na_sceneMirrorPosition(cxs.sceneMirrorPosition)
    # Find all non-calibration views (.loc replaces .ix, which was removed in pandas 1.0)
    non_cal_records = cxs.loc[~cxs.sceneMirrorPosition.isin([ord('H'), ord('A')])].copy()
    non_cal_records.reset_index(inplace=True)
    non_cal_records.index.name = 'sum_index'

    # Read SUM as well
    sum_ = get_all_housekeeping(sum_file)
    sum_.index.name = 'datetime'
    sum_.reset_index(inplace=True)
    sum_.index.name = 'sum_index'

    # Combine extra data from SUM into CXS
    # also non_cal_records may have a different number of records than the sum, so limit by the sum
    # (reindex reproduces the old .ix label lookup: any absent label becomes an all-NaN row)
    non_cal_records = non_cal_records.combine_first(sum_).reindex(sum_.index)
    hk = cxs.combine_first(non_cal_records.reset_index().set_index('cxs_index')).reset_index()
    # Stored as a plain attribute, not a column; pandas does not propagate such attributes to derived frames
    hk.calibration_graph = compute_calibration_graph(hk.sceneMirrorPosition)
    return hk

def read_igms(spc_zip_path, cache=()):
    """
    Read a zip file that archives Igm files, yield dictionaries containing interferograms and index info

    Every archive member whose name ends in ``.Igm`` and is not in ``cache`` is
    parsed with ``bomem_file.read_zip``. One dict is yielded per subfile, with:

    * ``datetime``: naive UTC datetime from the subfile's ``UTCTime``
    * ``DataA``, ``DataB``: the subfile's data, ``.squeeze()``-d
    * ``sceneMirrorPosition``: ordinal of the member name's first character
    * ``subfile``: position of the subfile within the member
    * ``scene_index``: digits following the leading capital letter of the name
    * ``name``: the member name

    :param spc_zip_path: path (or file-like object) of the zip archive; nothing
        is yielded when it is None
    :param cache: member names to skip (any iterable of names)
    """
    if spc_zip_path is None:
        return
    from datetime import timezone

    # Compiled once instead of per subfile
    scene_index_regex = re.compile(r'[A-Z]([0-9]+)M_[0-9]{8}_[0-9]{6}_.*[.]Igm')
    # Callers may pass a list; build the lookup set once so membership is O(1)
    skip = set(cache)

    # Open zip file
    with ZipFile(spc_zip_path) as spc_zip:
        # Find all members with .Igm suffix that are not cached
        for name in spc_zip.namelist():
            if not name.endswith('.Igm') or name in skip:
                continue
            for index, subfile in enumerate(bomem_file.read_zip(spc_zip, name)):
                # yield row
                yield {
                    # utcfromtimestamp is deprecated since Python 3.12; this gives the same naive UTC value
                    'datetime': datetime.fromtimestamp(subfile['UTCTime'], timezone.utc).replace(tzinfo=None),
                    'DataA': subfile['DataA'].squeeze(),
                    'DataB': subfile['DataB'].squeeze(),
                    'sceneMirrorPosition': ord(name[0]),
                    'subfile': index,
                    'scene_index': int(scene_index_regex.search(name).group(1)),
                    'name': name,
                }

def check_frame(frame, parameters, checklist):
    """
    Initialise the QC bookkeeping columns, then run every check in the checklist.

    ``frame`` is modified in place: ``qc_percent`` is set to 0 and ``qc_notes``
    to None on every row. It is then passed to ``checklist.check_everything``.

    :return: whatever ``checklist.check_everything(frame, parameters)`` returns
    """
    for column, initial_value in (('qc_percent', 0), ('qc_notes', None)):
        frame[column] = initial_value
    return checklist.check_everything(frame, parameters)

def get_cache():
    """Return the path of the spike-check shelve cache inside the system temporary directory."""
    cache_name = 'aeri_quality_control_cache'
    return os.path.join(tempfile.gettempdir(), cache_name)

def get_cached_spike_check(zip_file):
    """
    Load cached spike-check rows for the members of ``zip_file``.

    Only shelve entries whose key is also a member name of the zip archive are
    returned.

    :return: DataFrame with one row per cached member. The ``name`` column holds
        the member name, and the other columns are the keys of the cached dicts.
    """
    shelf_path = get_cache()
    with shelve.open(shelf_path) as store, ZipFile(zip_file) as archive:
        shared_names = set(store.keys()) & set(archive.namelist())
        rows = {member: store[member] for member in shared_names}
    # dict-of-dicts builds one column per member; transpose to one row per member
    table = pd.DataFrame(rows).T
    table.index.name = 'name'
    return table.reset_index()

def save_cached_spike_check(dataframe):
    """
    Persist spike-check rows in the shelve cache, keyed by the ``name`` column.

    Each row is stored under its ``name`` as a dict of the remaining columns.
    Existing entries with the same name are overwritten.
    NOTE: ``to_dict(orient='index')`` requires the names to be unique (pandas raises otherwise).
    """
    shelf_path = get_cache()
    with shelve.open(shelf_path) as store:
        records = dataframe.set_index('name').to_dict(orient='index')
        store.update(records)

def prepare_frame(cxs_file, sum_file, sci_dir):
    """
    Read one day's housekeeping and, when available, merge in interferogram spike checks.

    The interferogram archive ``<sci_dir>/AE<YYMMDD>/SPC_AERI_<YYMMDD>.zip`` and
    ``<sci_dir>/AE<YYMMDD>/cal/mirror.beg`` are looked up using the first run of
    six digits in the CXS file name. If both exist, spike checks are computed for
    members not already in the cache, combined with the cached results, saved
    back to the cache, and left-merged onto the frame by ``cxs_index``.

    :param cxs_file: path of the CXS file
    :param sum_file: path of the matching SUM file
    :param sci_dir: root directory of the science data, or None to skip spike checks
    :return: housekeeping DataFrame. It has ``spike_check``/``name`` columns only
        when interferograms were processed. The ``calibration_graph`` attribute is kept.
    """
    # First read the housekeeping dataframe
    frame = read_frame(cxs_file, sum_file)
    if sci_dir is None:
        return frame

    # Try to find the Igms for these times (only needed when sci_dir is given)
    YYMMDD = re.search('([0-9]{6})', os.path.basename(cxs_file)).group(1)
    day_dir = os.path.join(sci_dir, 'AE' + YYMMDD)
    possible_zip_file = os.path.join(day_dir, 'SPC_AERI_{}.zip'.format(YYMMDD))
    possible_mirror = os.path.join(day_dir, 'cal', 'mirror.beg')
    if not (os.path.isfile(possible_zip_file) and os.path.isfile(possible_mirror)):
        return frame

    # read the interferograms
    print('Found igm file')
    cached_frame = get_cached_spike_check(possible_zip_file)
    if 'name' in cached_frame.columns:
        cached_names = set(cached_frame['name'])
    else:
        cached_names = set()
    new_igms = pd.DataFrame.from_records(
        spike_check(read_igms(possible_zip_file, cache=cached_names), read_mirror(possible_mirror)),
        columns=['cxs_index', 'name', 'spike_check'])
    igms = pd.concat([new_igms, cached_frame], axis=0)
    if igms.empty:
        return frame
    # A CXS record is flagged if any of its rows is flagged; keep the first member name.
    # 'any' is the same reduction np.any mapped to, without the pandas deprecation warning.
    igms = igms.groupby('cxs_index').agg({'spike_check': 'any', 'name': lambda x: x.iloc[0]}).reset_index()
    save_cached_spike_check(igms)

    # Add columns from igms (spike_check, name). merge() returns a new frame without
    # the ad-hoc calibration_graph attribute, so it is re-attached afterwards.
    cal_graph = frame.calibration_graph
    frame = frame.merge(igms, on=['cxs_index'], how='left', suffixes=('', '_igm'))
    # Records without an interferogram count as spike-free. Assign back instead of
    # column.fillna(inplace=True), which does not update the frame under copy-on-write.
    frame['spike_check'] = frame['spike_check'].fillna(False).astype(bool)
    frame.calibration_graph = cal_graph
    print('Processed igm')

    return frame

def update_all(ftp_dir, sci_dir, checklist, parameters=None, hint=None):
    """
    Given the root directories for ftp and sci, find all days lacking an up-to-date qc file and generate new qc

    CXS files are searched as ``<ftp_dir>/AE*/*B1.CXS``. For every day that
    files_to_update selects, the housekeeping frame is prepared, checked with
    ``checklist`` and written to its QC file.

    :param parameters: dict passed to every check (an empty dict when None;
        the same dict is shared by all days)
    :param hint: optional path of a recently changed SUM/CXS file, forwarded to files_to_update
    """
    if parameters is None:
        parameters = {}

    search_pattern = os.path.join(os.path.abspath(ftp_dir), 'AE*', '*B1.CXS')
    candidates = glob(search_pattern)
    for qc_file, cxs_file, sum_file in files_to_update(candidates, hint=hint):
        print('Performing quality control for {}'.format(cxs_file))
        housekeeping = prepare_frame(cxs_file, sum_file, sci_dir)
        checked = check_frame(housekeeping, parameters, checklist)
        save_quality(checked, qc_file, checklist)

def should_update(sum_file, qc_file):
    """
    Decide whether an existing QC file is out of date with respect to its SUM file.

    :param sum_file: path of the SUM file
    :param qc_file: path of an existing QC netCDF file written by save_quality
    :return: True when the number of SUM records differs from the number of
        records in the QC file (the length of its ``time`` dimension)
    """
    sum_length = len(get_all_housekeeping(sum_file))
    # save_quality writes one variable per check plus spike_check and time along the
    # 'time' dimension; a 'qc_percent' variable is not guaranteed to exist, so use the
    # dimension length. The context manager closes the file even if reading fails.
    with netCDF4.Dataset(qc_file) as nc:
        qc_length = len(nc.dimensions['time'])
    return sum_length != qc_length

def files_to_update(cxs_files, *, update_only=True, hint=None):
    """
    Find a matching SUM file and determine the QC filename for each CXS file in sequence

    For ``<dir>/<prefix>B1.CXS`` the SUM file is ``<dir>/<prefix>.SUM`` and the QC
    file is ``<dir>/<prefix>QC.nc``. CXS files without a SUM file are skipped.
    ``(qc_file, cxs_file, sum_file)`` is yielded when any of these holds:

    * the QC file does not exist yet;
    * the QC file is older than the SUM or the CXS file;
    * ``update_only`` is False (always regenerate);
    * ``hint`` names this day's SUM or CXS file and should_update reports that
      the record counts differ.

    :param cxs_files: iterable of CXS file paths (absolute or relative)
    :param update_only: when False, every day with a SUM file is yielded
    :param hint: optional path of a recently changed SUM/CXS file
    """
    hint_path = os.path.abspath(hint) if hint is not None else None
    for cxs_file in cxs_files:
        # Rewrite only the file name (not the directory), and join it back onto the
        # directory so relative CXS paths resolve next to the CXS file.
        directory, basename = os.path.split(cxs_file)
        sum_file = os.path.join(directory, basename.replace('B1.CXS', '.SUM'))
        qc_file = os.path.join(directory, basename.replace('B1.CXS', 'QC.nc'))

        if not os.path.isfile(sum_file):
            continue
        if not os.path.isfile(qc_file):
            # if qc doesn't exist, generate
            yield (qc_file, cxs_file, sum_file)
            continue

        # Regenerate QC if it is older than the sum or cxs file
        stale = max(os.path.getmtime(sum_file), os.path.getmtime(cxs_file)) > os.path.getmtime(qc_file)
        hinted = hint_path is not None and hint_path in (os.path.abspath(sum_file), os.path.abspath(cxs_file))
        # Cheap checks first; should_update reads both files, so it is evaluated last
        if stale or not update_only or (hinted and should_update(sum_file, qc_file)):
            yield (qc_file, cxs_file, sum_file)