import os
from glob import glob
import re
import netCDF4
from aeri_tools.io.dmv.housekeeping import get_all_housekeeping


import electronic_checks
import global_checks
import radiometric_checks
import scene_checks
import state_checks
import thermal_checks

# Check levels, in the order check_frame applies them. Each module supplies
# a CheckList; check_frame calls set_params() and compute() on each one.
levels = [
    checks_module.CheckList()
    for checks_module in (
        global_checks,
        scene_checks,
        state_checks,
        electronic_checks,
        radiometric_checks,
        thermal_checks,
    )
]

def save_quality(frame, qc_path):
    """Write quality-control results from *frame* to a netCDF file at *qc_path*.

    The file is opened in 'w' mode, with a single ``time`` dimension of
    length ``len(frame)``. It contains these variables:

    - ``base_time`` (int64 scalar): the first non-null ``datetime`` value.
    - ``time_offset`` (int64, time): each row's ``datetime`` minus base_time.
    - ``qc_percent`` (float32, time).
    - ``qc_notes`` (string, time). Missing notes are written as ``''``.
    - One float32 variable per column whose name contains ``_check``.
    - One float32 variable per other column whose name contains ``qc_``.

    NOTE(review): base_time/time_offset are numpy datetime64/timedelta64
    values stored into int64 variables without a ``units`` attribute. The
    integer unit follows the numpy resolution (presumably nanoseconds for
    pandas' default ``datetime64[ns]``). Confirm what downstream readers
    expect.

    Parameters
    ----------
    frame : pandas.DataFrame
        Must contain ``qc_percent``, ``qc_notes`` and ``datetime`` columns.
    qc_path : str
        Path of the netCDF file to write.

    Raises
    ------
    KeyError
        If a required column is missing.
    ValueError
        If ``datetime`` has no non-null values, so base_time is undefined.
    """
    # Validate before touching the filesystem so a bad frame does not leave
    # a partially written file behind.
    required = ['qc_percent', 'qc_notes', 'datetime']
    missing = [name for name in required if name not in frame.columns]
    if missing:
        raise KeyError('frame is missing required column(s): {}'.format(missing))

    valid_times = frame['datetime'].dropna()
    if valid_times.empty:
        raise ValueError('frame has no non-null datetime values; '
                         'cannot determine base_time')
    start = valid_times.iloc[0]

    # Collect the per-check and per-variable QC columns up front. A column
    # matching both patterns is written only once, because creating a
    # netCDF variable whose name is already in use fails.
    check_columns = list(frame.filter(like='_check').columns)
    variable_qc_columns = [
        name for name in frame.filter(like='qc_').columns
        if name not in ('qc_notes', 'qc_percent') and name not in check_columns
    ]

    # The context manager closes the file even if a write fails part-way.
    with netCDF4.Dataset(qc_path, 'w') as ncdf:
        ncdf.createDimension('time', len(frame))
        base_time = ncdf.createVariable('base_time', 'i8', ())
        time_offset = ncdf.createVariable('time_offset', 'i8', ('time',))
        qc_percent = ncdf.createVariable('qc_percent', 'f4', ('time',))
        qc_notes = ncdf.createVariable('qc_notes', str, ('time',))
        for name in check_columns + variable_qc_columns:
            ncdf.createVariable(name, 'f4', ('time',))[:] = frame[name].values
        base_time[:] = start.to_datetime64()
        time_offset[:] = (frame['datetime'] - start).values
        qc_percent[:] = frame['qc_percent'].values
        qc_notes[:] = frame['qc_notes'].fillna('').values
    

def read_frame(cxs_file, sum_file):
    """Load and merge the housekeeping data of a CXS file and its SUM file.

    Housekeeping is read from *cxs_file* first, then from *sum_file*. Values
    from the CXS file take precedence. Gaps are filled from the SUM file via
    ``combine_first``. The merged index is named ``datetime`` and then
    turned into a regular column.
    """
    cxs_housekeeping = get_all_housekeeping(cxs_file)
    sum_housekeeping = get_all_housekeeping(sum_file)
    merged = cxs_housekeeping.combine_first(sum_housekeeping)
    merged.index.name = 'datetime'
    return merged.reset_index()

def check_frame(frame, parameters):
    """Run every check level in ``levels`` over *frame*.

    First, ``qc_percent`` is set to 0 and ``qc_notes`` to None directly on
    *frame*, so the caller's object is modified. Then each level is given
    *parameters* via ``set_params``. Levels are applied in order: each
    level's ``compute`` receives the frame returned by the previous one.
    The last result is returned.
    """
    for column, initial in (('qc_percent', 0), ('qc_notes', None)):
        frame[column] = initial
    result = frame
    for check_list in levels:
        check_list.set_params(parameters)
        result = check_list.compute(result)
    return result

def update_all(ftp_dir, parameters=None, update_only=True):
    """Run quality control for every AERI CXS file under *ftp_dir* that needs it.

    Candidate files match ``<ftp_dir>/AE*/*B1.CXS`` and are processed in
    sorted order. files_to_update pairs each one with its ``.SUM`` and
    ``.qc`` paths. For each yielded triple, the data are read with
    read_frame, checked with check_frame, and written with save_quality.

    Parameters
    ----------
    ftp_dir : str
        Root directory containing the ``AE*`` subdirectories.
    parameters : dict, optional
        Passed to every check level's ``set_params``. Defaults to ``{}``.
    update_only : bool, optional
        Forwarded to files_to_update. When False, QC files are regenerated
        even if they are newer than their inputs. Defaults to True.
    """
    # Loop-invariant: resolve the default once rather than on every file.
    if parameters is None:
        parameters = {}
    # glob() order is arbitrary; sort for reproducible processing order.
    cxs_files = sorted(glob(os.path.join(ftp_dir, 'AE*', '*B1.CXS')))
    for qc_file, cxs_file, sum_file in files_to_update(cxs_files, update_only=update_only):
        print('Performing quality control for {}'.format(cxs_file))
        frame = read_frame(cxs_file, sum_file)
        frame = check_frame(frame, parameters)
        save_quality(frame, qc_file)

def files_to_update(cxs_files, update_only=True):
    """Yield ``(qc_file, cxs_file, sum_file)`` for each CXS file needing QC.

    For ``<dir>/<stem>B1.CXS`` the companion paths are ``<dir>/<stem>.SUM``
    and ``<dir>/<stem>.qc``, built in the same directory as the CXS file.
    This works for both relative and absolute paths.

    A CXS file is skipped if its name does not end in ``B1.CXS`` or its
    SUM file does not exist. Otherwise it is yielded in any of these cases:

    - the QC file does not exist;
    - the QC file is older than the CXS or SUM file;
    - *update_only* is False, in which case it is yielded regardless.

    Parameters
    ----------
    cxs_files : iterable of str
        Paths of ``*B1.CXS`` files.
    update_only : bool, optional
        If True (default), skip files whose QC output is up to date.
    """
    suffix = 'B1.CXS'
    for cxs_file in cxs_files:
        directory, name = os.path.split(cxs_file)
        if not name.endswith(suffix):
            # Cannot derive SUM/QC names. Skip rather than risk a "QC" path
            # that equals the input file.
            continue
        stem = name[:-len(suffix)]
        sum_file = os.path.join(directory, stem + '.SUM')
        qc_file = os.path.join(directory, stem + '.qc')

        if not os.path.isfile(sum_file):
            continue

        if (not update_only
                or not os.path.isfile(qc_file)
                or max(os.path.getmtime(sum_file), os.path.getmtime(cxs_file))
                > os.path.getmtime(qc_file)):
            yield (qc_file, cxs_file, sum_file)

if __name__ == '__main__':
    # Command-line entry point: run QC over one FTP directory tree.
    import argparse

    parser = argparse.ArgumentParser(
        description='Run quality control on AERI CXS/SUM files that have no '
                    'up-to-date .qc file.')
    parser.add_argument(
        'ftp',
        help='root directory to search; CXS files are found as '
             'FTP/AE*/*B1.CXS')

    args = parser.parse_args()

    update_all(args.ftp)