from util import BaseCheckList
import numpy as np
import pandas as pd

def hbb_temp_outlier_check(frame, parameters):
    """Flag rows where any HBB temperature reading is a rolling-window outlier.

    Writes an integer column ``hbb_temp_outlier_check`` into ``frame``. It is 1
    where ``_find_6sigma_outliers`` flags ``HBBbottomTemp``, ``HBBapexTemp``
    or ``HBBtopTemp`` (called with its default scoring), and 0 elsewhere.

    Parameters
    ----------
    frame : pandas.DataFrame
        Input data. Modified in place when the check runs.
    parameters : dict
        Check options. ``window_length`` (default 100) sets the rolling-window
        size passed to the outlier detector.

    Returns
    -------
    pandas.DataFrame
        ``frame`` itself. If any of the three HBB columns is absent, the frame
        is returned untouched and no flag column is added.
    """
    columns = ('HBBbottomTemp', 'HBBapexTemp', 'HBBtopTemp')
    # np.in1d is deprecated since NumPy 2.0. A subset test over the (hashable)
    # column labels answers the same "are all columns present?" question.
    if not set(columns).issubset(frame.columns):
        return frame

    window_length = parameters.get('window_length', 100)

    frame['hbb_temp_outlier_check'] = (
        _find_6sigma_outliers(frame['HBBbottomTemp'], window_length) |
        _find_6sigma_outliers(frame['HBBapexTemp'], window_length) |
        _find_6sigma_outliers(frame['HBBtopTemp'], window_length)
    ) * 1  # bool mask -> 0/1 integers
    return frame

def abb_temp_outlier_check(frame, parameters):
    """Flag rows where any ABB temperature reading is a rolling-window outlier.

    Writes an integer column ``abb_temp_outlier_check`` into ``frame``. It is 1
    where ``_find_6sigma_outliers`` (called with ``use_mean=True``) flags
    ``ABBbottomTemp``, ``ABBapexTemp`` or ``ABBtopTemp``, and 0 elsewhere.

    NOTE(review): ``_compute_robust_zscore`` currently resets ``use_mean`` to
    False, so this check presently uses the same median/MAD score as the HBB
    check. Confirm whether the mean/std score was intended.

    Parameters
    ----------
    frame : pandas.DataFrame
        Input data. Modified in place when the check runs.
    parameters : dict
        Check options. ``window_length`` (default 100) sets the rolling-window
        size passed to the outlier detector.

    Returns
    -------
    pandas.DataFrame
        ``frame`` itself. If any of the three ABB columns is absent, the frame
        is returned untouched and no flag column is added.
    """
    columns = ('ABBbottomTemp', 'ABBapexTemp', 'ABBtopTemp')
    # np.in1d is deprecated since NumPy 2.0. A subset test over the (hashable)
    # column labels answers the same "are all columns present?" question.
    if not set(columns).issubset(frame.columns):
        return frame

    window_length = parameters.get('window_length', 100)

    frame['abb_temp_outlier_check'] = (
        _find_6sigma_outliers(frame['ABBbottomTemp'], window_length, use_mean=True) |
        _find_6sigma_outliers(frame['ABBapexTemp'], window_length, use_mean=True) |
        _find_6sigma_outliers(frame['ABBtopTemp'], window_length, use_mean=True)
    ) * 1  # bool mask -> 0/1 integers
    return frame

def calibrationambienttemp_outlier_check(frame, parameters):
    """Flag rows where ``calibrationAmbientTemp`` is a rolling-window outlier.

    When the column exists, writes an integer column
    ``calibrationambienttemp_outlier_check`` into ``frame``. It is 1 where
    ``_find_6sigma_outliers`` (called with ``use_mean=True``) flags the value,
    and 0 elsewhere. Without the column, ``frame`` is returned untouched.

    NOTE(review): ``_compute_robust_zscore`` currently resets ``use_mean`` to
    False, so the median/MAD score is what actually runs. Confirm intent.

    Parameters
    ----------
    frame : pandas.DataFrame
        Input data. Modified in place when the check runs.
    parameters : dict
        Check options. ``window_length`` (default 100) sets the rolling-window
        size.

    Returns
    -------
    pandas.DataFrame
        ``frame`` itself.
    """
    source_column = 'calibrationAmbientTemp'
    if source_column in frame.columns:
        outlier_mask = _find_6sigma_outliers(
            frame[source_column],
            parameters.get('window_length', 100),
            use_mean=True,
        )
        frame['calibrationambienttemp_outlier_check'] = outlier_mask * 1
    return frame

class CheckList(BaseCheckList):
    """Registers this module's temperature outlier checks.

    ``checks`` lists the check callables in order. Each takes
    ``(frame, parameters)`` and returns the frame. How they are invoked is up
    to ``BaseCheckList`` (imported from ``util``, not visible here).
    """

    checks = [
        hbb_temp_outlier_check,
        abb_temp_outlier_check,
        calibrationambienttemp_outlier_check,
    ]

def _compute_robust_zscore(frame, window_length, use_mean=False):
    use_mean = False
    
    if use_mean:
        robust_rolling_std = frame.rolling(window=window_length, center=True, min_periods=1).std()
        return abs((frame - frame.rolling(window=window_length, center=True, min_periods=1).mean()) / robust_rolling_std)
    else:
        # Compute a centered rolling MAD over window_length
        rolling_mad = abs(frame - frame.rolling(window=window_length, center=True, min_periods=1).median()
                     ).rolling(window=window_length, center=True, min_periods=1).median()
        # standard deviation is proportional to median absolute deviation I'm told
        robust_rolling_std = rolling_mad * 1.48
        return abs((frame - frame.rolling(window=window_length, center=True, min_periods=1).median()) / robust_rolling_std)
    
def _find_6sigma_outliers(frame, window_length, use_mean=False):
    """Return a boolean mask marking samples whose rolling z-score exceeds 6.

    Scores come from ``_compute_robust_zscore(frame, window_length, use_mean)``.
    NaN scores compare as False, so they are never flagged.
    """
    zscore = _compute_robust_zscore(frame, window_length, use_mean)
    return zscore > 6


#### TESTS ####


def test_hbb_temp_outlier_check():
    """A spike in the apex and top sensors is flagged only on its own row."""
    spike = [0, 1, 10, 1]
    frame = pd.DataFrame({
        'HBBapexTemp': spike,
        'HBBbottomTemp': [1, 1, 1, 1],
        'HBBtopTemp': list(spike),
        'qc_notes': '',
    })
    flags = hbb_temp_outlier_check(frame, {})['hbb_temp_outlier_check']
    assert flags.values.tolist() == [0, 0, 1, 0]

def test_abb_temp_outlier_check():
    """A spike in the apex and top sensors is flagged only on its own row.

    NOTE: this expectation holds only while ``use_mean`` is ignored by the
    z-score helper. A mean/std z-score over four samples tops out at 1.5.
    """
    spike = [0, 1, 10, 1]
    frame = pd.DataFrame({
        'ABBapexTemp': spike,
        'ABBbottomTemp': [1, 1, 1, 1],
        'ABBtopTemp': list(spike),
        'qc_notes': '',
    })
    flags = abb_temp_outlier_check(frame, {})['abb_temp_outlier_check']
    assert flags.values.tolist() == [0, 0, 1, 0]

def test_calibrationambienttemp_temp_outlier_check():
    """A single ambient-temperature spike is flagged only on its own row.

    NOTE: like the ABB test, this relies on ``use_mean`` currently being
    ignored by the z-score helper.
    """
    frame = pd.DataFrame({
        'calibrationAmbientTemp': [0, 1, 10, 1],
        'qc_notes': '',
    })
    result = calibrationambienttemp_outlier_check(frame, {})
    flags = result['calibrationambienttemp_outlier_check']
    assert flags.values.tolist() == [0, 0, 1, 0]