from sys import argv
import os
import time

import pandas as pd
import numpy as np
import scipy.signal
import matplotlib.pyplot as plt

from aeri_tools.io.dmv import housekeeping
from aeri_tools.io.dmv import radiance


def main(path, window_length=201, polyorder=3, save_dir=None):
    """Correlate detrended scene-H radiance with SCEtemp for one B1.CXS file.

    Loads radiance and housekeeping data from *path*. It keeps only the
    records that have both a scene-H radiance (averaged over all spectral
    columns) and an ``SCEtemp`` value. It then removes the slow trend from
    each series with a Savitzky-Golay filter and cross-correlates the
    residuals with ``np.correlate(..., mode='same')``. Finally it prints the
    max/min of the correlation and plots it against the record index.

    Any failure after loading is reported as ``FAIL: <reason>`` and the file
    is skipped, so one bad file does not abort a batch run.

    Parameters
    ----------
    path : str
        Path to a B1.CXS file, read with ``radiance.get_radiance_from_rnc``
        and ``housekeeping.get_all_housekeeping``.
    window_length : int
        Savitzky-Golay window length (number of records) used for detrending.
    polyorder : int
        Savitzky-Golay polynomial order.
    save_dir : str or None
        If given, the figure is also saved there as ``<path[-12:-6]>.png``.

    Returns
    -------
    numpy.ndarray or None
        The correlation array, or None if the file was skipped or failed.
    """
    # Load radiance and housekeeping from the same file.
    rad = radiance.get_radiance_from_rnc(path)
    hk = housekeeping.get_all_housekeeping(path)

    # Align both series on the housekeeping index. dropna() then discards
    # housekeeping rows that have no scene-H radiance (or no SCEtemp).
    data = pd.DataFrame(
        {
            'rad': rad.xs(b'H', level='scene').mean(axis=1),
            'hk': hk['SCEtemp'],
        },
        index=hk.index,
    ).dropna()

    # savgol_filter (default mode='interp') requires at least window_length
    # samples; check up front instead of relying on the exception.
    if len(data) < window_length:
        print('FAIL: %d scene-H records, need at least %d'
              % (len(data), window_length))
        return None

    fig = None
    try:
        # Subtract the smoothed trend so the correlation reflects short-term
        # variation rather than the shared slow drift.
        residuals = {
            column: data[column].values - scipy.signal.savgol_filter(
                data[column].values, window_length, polyorder)
            for column in ('rad', 'hk')
        }

        correlation = np.correlate(residuals['rad'], residuals['hk'],
                                   mode='same')

        print('max = ', np.max(correlation), ' : min = ',
              np.min(correlation))

        fig, ax = plt.subplots(figsize=(15, 10))
        ax.plot(data.index, correlation)
        if save_dir is not None:
            # Save before show(): the canvas may be gone once the window closes.
            fig.savefig(os.path.join(save_dir, path[-12:-6] + '.png'))
        plt.show()
    # Best effort for batch runs, but never swallow KeyboardInterrupt/SystemExit.
    except Exception as err:
        print('FAIL:', err)
        return None
    finally:
        # Release the figure so a long batch does not accumulate open figures.
        if fig is not None:
            plt.close(fig)

    return correlation


def _find_b1_cxs_files(root):
    """Yield the paths of ``B1.CXS`` files found at *root*.

    If *root* is a file, it is yielded when its name ends in ``B1.CXS``.
    If *root* is a directory, its direct entries and the entries of its
    direct subdirectories are searched; there is no deeper recursion.
    Listings are sorted because ``os.listdir`` order is arbitrary, and the
    ``skip_num`` resume numbering must be stable between runs.
    """
    if os.path.isfile(root):
        if root.endswith('B1.CXS'):
            yield root
        return
    if not os.path.isdir(root):
        return
    for entry in sorted(os.listdir(root)):
        entry_path = os.path.join(root, entry)
        if os.path.isdir(entry_path):
            for sub_entry in sorted(os.listdir(entry_path)):
                sub_path = os.path.join(entry_path, sub_entry)
                if os.path.isfile(sub_path) and sub_path.endswith('B1.CXS'):
                    yield sub_path
        elif os.path.isfile(entry_path) and entry_path.endswith('B1.CXS'):
            yield entry_path


if __name__ == '__main__':
    start_time = time.time()
    if len(argv) < 2:
        print('usage: %s <B1.CXS file or directory>' % argv[0])
        raise SystemExit(2)
    filepath = argv[1]
    if not os.path.exists(filepath):
        print('no such file or directory:', filepath)
        raise SystemExit(1)

    # 1-based number of the first file to process. Earlier files are only
    # listed as skipped, so an interrupted batch can be resumed.
    skip_num = 0
    print('skip_num = ', skip_num, '\n')
    for curr_num, cxs_path in enumerate(_find_b1_cxs_files(filepath), start=1):
        if curr_num >= skip_num:
            print(curr_num, ': ', cxs_path)
            main(cxs_path)
        else:
            print(curr_num, ': ', cxs_path, ' -- SKIPPED')

    # Read the clock once so minutes and seconds come from the same sample.
    minutes, seconds = divmod(time.time() - start_time, 60)
    print('execution time: %d minute(s), %.2f second(s)' % (minutes, seconds))



def plot_band_correlation(tf, band=(560, 600), window_length=51, polyorder=4):
    """Show per-file diagnostic plots of band-averaged radiance vs SCEtemp.

    For each path in *tf* this draws a 4-panel figure:
      1. detrended scene-H radiance,
      2. raw SCEtemp,
      3. detrended SCEtemp,
      4. the ``'same'``-mode cross-correlation of the two residuals.
    The figure is titled with the zero-lag correlation; that title is printed
    and the figure is shown.

    Parameters
    ----------
    tf : iterable of str
        Files readable by ``radiance.get_radiance_from_rnc`` and
        ``housekeeping.get_all_housekeeping``.
    band : (float, float)
        Exclusive lower/upper bounds on the radiance column labels that are
        averaged together (presumably wavenumbers -- TODO confirm units).
    window_length, polyorder : int
        Savitzky-Golay parameters used to remove the slow trend.
    """
    low, high = band
    for f in tf:
        rad = radiance.get_radiance_from_rnc(f)
        hk = housekeeping.get_all_housekeeping(f)

        # Mean scene-H radiance over the band columns; the cross-section is
        # computed once and reused.
        rad_h = rad.xs(b'H', level='scene')
        in_band = (rad_h.columns > low) & (rad_h.columns < high)
        rad_test = rad_h.iloc[:, in_band].mean(axis=1)
        # SCEtemp at exactly the scene-H record times.
        hk_test = hk['SCEtemp'].loc[rad_test.index]

        ax = plt.subplot(4, 1, 1)
        ax.set_title('radiance filtered')
        rad_test = rad_test - scipy.signal.savgol_filter(
            rad_test.values, window_length, polyorder)
        ax.plot(rad_test)

        ax = plt.subplot(4, 1, 2)
        ax.set_title('SCEtemp raw')
        ax.plot(hk_test)

        ax = plt.subplot(4, 1, 3)
        ax.set_title('SCEtemp filtered')
        hk_test = hk_test - scipy.signal.savgol_filter(
            hk_test.values, window_length, polyorder)
        ax.plot(hk_test)

        ax = plt.subplot(4, 1, 4)
        ax.set_title('correlation')
        ax.plot(rad_test.index,
                np.correlate(hk_test.values, rad_test.values, 'same'))

        # Equal-length inputs in the default 'valid' mode yield a single
        # value: the zero-lag correlation.
        corr = np.correlate(hk_test.values, rad_test.values)[0]
        title = f + ' -- ' + 'reduced -- ' + str(corr)
        plt.suptitle(title, fontsize=20)
        print(title)
        plt.show()
        plt.clf()

def plot_rolling_correlation(
        tf, max_roll,
        out_dir='/Users/adiebold/aeri_quality_control/testing/pngs/correlation/',
        band=(560, 600), window_length=51, polyorder=4):
    """Save a rolling-window correlation plot of radiance vs SCEtemp per file.

    For each path in *tf*, the scene-H radiance is averaged over *band* and
    paired with SCEtemp at the same records. Both series are detrended with a
    Savitzky-Golay filter. A windowed correlation is then computed for every
    record:
      - The window starts at ``max_roll // 2`` records and grows to
        ``max_roll`` while anchored at the start.
      - It then slides forward.
      - Finally it shrinks while anchored at the end.
    Each value is the larger of the zero-lag and adjacent-lag entries of
    ``np.correlate(..., 'full')`` for that window. The plot, titled with the
    overall correlation coefficient, is saved to *out_dir*.

    Parameters
    ----------
    tf : iterable of str
        Files readable by ``radiance.get_radiance_from_rnc`` and
        ``housekeeping.get_all_housekeeping``.
    max_roll : int
        Maximum window length in records; must be >= 2.
    out_dir : str
        Directory the PNGs are written to.
    band : (float, float)
        Exclusive bounds on the radiance column labels that are averaged
        (presumably wavenumbers -- TODO confirm units).
    window_length, polyorder : int
        Savitzky-Golay parameters used to remove the slow trend.

    Raises
    ------
    ValueError
        If ``max_roll < 2`` (the first window would be empty).
    """
    if max_roll < 2:
        raise ValueError('max_roll must be >= 2, got %r' % (max_roll,))
    low, high = band
    for f in tf:
        print(f)
        rad = radiance.get_radiance_from_rnc(f)
        hk = housekeeping.get_all_housekeeping(f)

        rad_h = rad.xs(b'H', level='scene')
        in_band = (rad_h.columns > low) & (rad_h.columns < high)
        rad_band = rad_h.iloc[:, in_band].mean(axis=1)
        hk_band = hk['SCEtemp'].loc[rad_band.index]

        n = len(rad_band)
        # Shorter series make savgol_filter raise, or make the windows run
        # past the end and produce empty correlations.
        min_len = max(max_roll, window_length)
        if n < min_len:
            print('skipping %s: %d records, need >= %d' % (f, n, min_len))
            continue

        rad_vals = rad_band.values - scipy.signal.savgol_filter(
            rad_band.values, window_length, polyorder)
        hk_vals = hk_band.values - scipy.signal.savgol_filter(
            hk_band.values, window_length, polyorder)

        start = 0
        width = max_roll // 2  # integer: the original '/' gave a float slice bound
        corr = []
        for _ in range(n):
            full = np.correlate(hk_vals[start:start + width],
                                rad_vals[start:start + width], 'full')
            # Index width-1 of a 'full' correlation of two length-width
            # windows is zero lag; also consider the adjacent lag.
            corr.append(max(full[width - 1:width + 1]))
            if width < max_roll and start == 0:
                width += 1          # growing phase, anchored at the start
            elif width + start == n:
                start += 1          # shrinking phase, anchored at the end
                width -= 1
            else:
                start += 1          # sliding phase

        ax = plt.subplot(111)
        ax.set_title(f + ' ||| ' + str(np.corrcoef(hk_vals, rad_vals)[0, 1]))
        # NOTE(review): corr holds raw (unnormalized) dot products, so values
        # outside this range are clipped from view -- confirm intended.
        ax.set_ylim(-0.1, 1.0)
        ax.plot(corr)
        # NOTE(review): these slices assume f is a bare file name, not a full
        # path -- confirm against how tf is built.
        plt.savefig(os.path.join(out_dir, f[:3] + '_' + f[6:12] + '.png'))
        plt.clf()