# Skip to content
# Snippets Groups Projects
# test_variables.py 21.75 KiB
from aeri_qc import radiometric_checks, igm_checks, state_checks, electronic_checks, thermal_checks, cooler_checks, scene_checks, global_checks
from aeri_qc.all_checks import aggregate_qc_check, aggregate_qc_flag, get_checklist
import numpy as np
import pandas as pd
import pytest
import xarray as xr
import os
import logging


# Module-level logger; TestCase.run_test uses it for its warnings about
# reference data that contains no (or only) flagged records.
logger = logging.getLogger(__name__)
# Configure the root logger at NOTSET so no record is filtered by level there.
logging.basicConfig(level=logging.NOTSET)


def cubic(x):
    """Return ``x`` raised to the third power."""
    return pow(x, 3)


def linear(x):
    """Return ``x`` unchanged (identity shape function)."""
    return x


def create_data(start, stop, length, skips=0, middle=0, func=cubic):
    """Build a normalized synthetic series with optional gaps.

    ``func`` is evaluated on ``length`` evenly spaced points between
    ``start`` and ``stop`` and the result is rescaled to span ``[-1, 1]``
    (a constant result becomes all ones instead).  The curve value is then
    kept only at every ``skips + 1``-th sample; each of the ``skips`` samples
    in between is replaced by ``middle``.

    Args:
        start: first abscissa passed to ``func``.
        stop: last abscissa passed to ``func``.
        length: number of samples to generate.
        skips: number of samples replaced by ``middle`` after each kept one.
        middle: value used for the replaced samples.
        func: scalar function defining the shape of the curve.

    Returns:
        One-dimensional NumPy array with ``length`` entries.
    """
    abscissa = np.linspace(start, stop, length)
    curve = np.vectorize(func)(abscissa)  # generate shape
    curve = curve - np.min(curve)  # shift to origin
    peak = np.max(curve)
    if peak == 0:
        # A flat curve cannot be rescaled; use a constant one instead.
        curve = np.ones(length)
    else:
        curve = curve / peak * 2 - 1

    # Repeating pattern: one kept sample followed by `skips` replaced ones.
    keep_mask = np.resize(np.repeat((1, 0), (1, skips)), curve.shape)
    fill_mask = np.resize(np.repeat((0, 1), (1, skips)), curve.shape)
    return middle * fill_mask + curve * keep_mask


class TestCase(object):
    """Common fixture and regression harness shared by the QC test classes.

    Subclasses assign ``self.qc_func`` in ``setup_method``, fill
    ``self.frame`` (and optionally ``self.parameters``) inside each test and
    finish with :meth:`run_test`, which compares the rounded output of
    ``qc_func`` with the reference stored in
    ``$QC_EXPECTED_OUTPUT/<self.name>.nc``.
    """

    @classmethod
    def setup_class(cls):
        """Make NumPy warn on invalid/divide errors and set the sample spacing."""
        np.seterr(invalid='warn')
        np.seterr(divide='warn')
        # Spacing between consecutive synthetic timestamps (240 seconds).
        cls.data_interval = pd.Timedelta(24*10, 's')

    def setup_method(self, method):
        """Reset the per-test defaults used to generate synthetic input.

        Args:
            method: the test method about to run; its ``__name__`` becomes
                the default ``self.name`` (and so the reference file name).

        Raises:
            KeyError: if the ``QC_EXPECTED_OUTPUT`` environment variable is
                not set.
        """
        self.expected_output = os.environ['QC_EXPECTED_OUTPUT']
        self.datetimes = np.vectorize(pd.Timestamp)(
            np.arange('2022-01-01', '2022-01-02', self.data_interval, dtype='datetime64'))

        # Defaults consumed by get_data(); subclasses and tests override them.
        self.start = -1
        self.stop = 1
        self.length = len(self.datetimes)
        self.skips = 2
        self.base = 0
        self.amplitude = 1
        self.middle = 0
        self.frame = pd.DataFrame({
            'datetime': self.datetimes
        })
        self.parameters = {}
        self.name = method.__name__
        self.data_func = cubic

    @property
    def should_save(self):
        """Value of ``QC_SAVE_EXPECTED`` (``None`` when unset).

        Any non-empty string is truthy, so run_test() then overwrites the
        stored reference with the current output.
        """
        return os.environ.get('QC_SAVE_EXPECTED')

    def save_expected(self, data):
        """Write ``data`` as the reference output of the current test.

        Args:
            data: array with one value per row of ``self.frame``.
        """
        # NOTE(review): newer xarray releases prefer drop_vars() over drop()
        # for removing a variable -- confirm against the pinned version.
        dataset = self.frame.to_xarray().drop(
            'index').rename_dims({'index': 'time'})

        dataset['expected_output'] = ("time", data)
        # Make sure datetime is in saved file as reference, even if not used in test.
        dataset['datetime'] = ("time", self.datetimes)
        dataset.to_netcdf(os.path.join(
            self.expected_output, f'{self.name}.nc'))

    def get_expected(self):
        """Load the stored reference output of the current test.

        Returns:
            The ``expected_output`` variable of ``<self.name>.nc``, or
            ``None`` when that file does not exist.
        """
        fp = os.path.join(self.expected_output, f'{self.name}.nc')
        if not os.path.isfile(fp):
            return None
        return xr.load_dataset(fp).get('expected_output')

    def get_data(self, start=None, stop=None, length=None, skips=None, base=None, amplitude=None, middle=None, data_func=None):
        """Generate a synthetic series; ``None`` arguments use the instance defaults.

        Returns:
            ``base + create_data(...) * amplitude`` as a NumPy array.
        """
        start = self.start if start is None else start
        stop = self.stop if stop is None else stop
        length = self.length if length is None else length
        skips = self.skips if skips is None else skips
        base = self.base if base is None else base
        amplitude = self.amplitude if amplitude is None else amplitude
        middle = self.middle if middle is None else middle
        data_func = self.data_func if data_func is None else data_func
        return base + create_data(start, stop, length, skips=skips, middle=middle, func=data_func) * amplitude

    def run_test(self):
        """Run ``self.qc_func`` and compare its output with the stored reference.

        The result is rounded to three decimals; a scalar result is repeated
        once per row of ``self.frame``.  A warning is logged when the
        reference is all zeros or all ones, since such a case does not
        exercise both outcomes of the check.

        Raises:
            AssertionError: if no reference file exists or the output differs
                from it.
        """
        actual = np.around(self.qc_func(self.frame, self.parameters), 3)
        if np.shape(actual) == ():
            actual = np.repeat(actual, self.frame.shape[0])
        # Saving then opening expected data tests that the data was saved correctly.
        if self.should_save:
            self.save_expected(actual)
        expected = self.get_expected()
        if expected is None:
            # Fail with an explicit reason rather than comparing against None.
            raise AssertionError(
                f'{self.name}: no expected output found in '
                f'{self.expected_output}; set QC_SAVE_EXPECTED to create it')
        err_msg = f'Test Name: {self.name}\nParameters: {self.parameters}\nFrame:\n{self.frame}\nACTUAL: {np.array(actual)}\nEXPECTED: {np.array(expected)}'
        np.testing.assert_array_equal(actual, expected, err_msg, verbose=False)

        reference = np.asarray(expected)
        if np.count_nonzero(reference) == 0:
            logger.warning(
                '%s: expected contained no flagged records', self.name)
        elif np.all(reference == 1):
            # The boolean result is the whole condition; comparing it with the
            # record count (as before) could never be true for real data.
            logger.warning(
                '%s: expected contained only flagged records', self.name)


class TestHbbTempOutlierFlag(TestCase):
    """Regression tests for ``electronic_checks.hbb_temp_outlier_flag``.

    The input is a linear ramp around 300, scaled by 0.002 per second of
    ``data_interval`` and generated with the default ``skips``, applied to
    one or all of the HBB thermistor columns.
    """

    def setup_method(self, method):
        super().setup_method(method)
        seconds_per_sample = self.data_interval / pd.Timedelta(1, 's')
        self.base += 300
        self.amplitude = self.amplitude * .002 * seconds_per_sample
        # WARNING: self.qc_func must be set.
        self.qc_func = electronic_checks.hbb_temp_outlier_flag
        self.data_func = linear

    @pytest.mark.parametrize(
        "suffix, bot, apex, top",
        (
            ('bot', True, False, False),
            ('apex', False, True, False),
            ('top', False, False, True),
            ('all', True, True, True),
        ),
    )
    def test_hbb_temp_outlier_flag(self, suffix, bot, apex, top):
        self.name = f'{self.name}_{suffix}'
        ramp = self.get_data()
        # Thermistors that are not selected stay constant at the base value.
        selection = (
            ('HBBbottomTemp', bot),
            ('HBBapexTemp', apex),
            ('HBBtopTemp', top),
        )
        for column, varies in selection:
            self.frame[column] = ramp if varies else self.base
        self.run_test()


class TestAbbTempOutlierFlag(TestCase):
    """Regression tests for ``electronic_checks.abb_temp_outlier_flag``.

    The input is a linear ramp around 300, scaled by 0.02 per second of
    ``data_interval`` and generated with the default ``skips``, applied to
    one or all of the ABB thermistor columns.
    """

    def setup_method(self, method):
        super().setup_method(method)
        seconds_per_sample = self.data_interval / pd.Timedelta(1, 's')
        self.base += 300
        self.amplitude = self.amplitude * .02 * seconds_per_sample
        # WARNING: self.qc_func must be set.
        self.qc_func = electronic_checks.abb_temp_outlier_flag
        self.data_func = linear

    @pytest.mark.parametrize(
        "suffix, bot, apex, top",
        (
            ('bot', True, False, False),
            ('apex', False, True, False),
            ('top', False, False, True),
            ('all', True, True, True),
        ),
    )
    def test_abb_temp_outlier_flag(self, suffix, bot, apex, top):
        self.name = f'{self.name}_{suffix}'
        ramp = self.get_data()
        # Thermistors that are not selected stay constant at the base value.
        selection = (
            ('ABBbottomTemp', bot),
            ('ABBapexTemp', apex),
            ('ABBtopTemp', top),
        )
        for column, varies in selection:
            self.frame[column] = ramp if varies else self.base
        self.run_test()


class TestAbbThermistorFlag(TestCase):
    """Regression tests for ``state_checks.abb_thermistor_flag``.

    Each case centres a gap-free cubic curve on a different base value with
    a different amplitude and applies it to one or all ABB thermistors.
    """

    def setup_method(self, method):
        super().setup_method(method)
        self.qc_func = state_checks.abb_thermistor_flag

    @pytest.mark.parametrize(
        "suffix1, base, amplitude",
        (
            ('_min', 200, 20),
            ('_max', 318, 20),
            ('_mid', 259, 59),
        ),
    )
    @pytest.mark.parametrize(
        "suffix2, bot, apex, top",
        (
            ('_bot', True, False, False),
            ('_apex', False, True, False),
            ('_top', False, False, True),
            ('_all', True, True, True),
        ),
    )
    def test_abb_thermistor_flag(self, suffix1, bot, apex, top, suffix2, base, amplitude):
        self.name = f'{self.name}{suffix1}{suffix2}'
        self.base += base
        self.amplitude *= amplitude
        curve = self.get_data(skips=0)
        # Thermistors that are not selected stay constant at the base value.
        selection = (
            ('ABBbottomTemp', bot),
            ('ABBapexTemp', apex),
            ('ABBtopTemp', top),
        )
        for column, varies in selection:
            self.frame[column] = curve if varies else self.base
        self.run_test()


class TestAirInterferometerOutlierFlag(TestCase):
    """Regression test for ``electronic_checks.air_interferometer_outlier_flag``.

    The input is a linear ramp around 300 whose amplitude equals the number
    of seconds in ``data_interval``.
    """

    def setup_method(self, method):
        super().setup_method(method)
        seconds_per_sample = self.data_interval / pd.Timedelta(1, 's')
        self.base += 300
        self.amplitude = self.amplitude * seconds_per_sample
        self.qc_func = electronic_checks.air_interferometer_outlier_flag
        self.data_func = linear

    def test_air_interferometer_outlier_flag(self):
        ramp = self.get_data()
        self.frame['airNearInterferometerTemp'] = ramp
        self.run_test()


class TestBstTempOutlierFlag(TestCase):
    """Regression test for ``electronic_checks.bst_temp_outlier_flag``.

    The input is a linear ramp around 300 whose amplitude is twice the
    number of seconds in ``data_interval``.
    """

    def setup_method(self, method):
        super().setup_method(method)
        seconds_per_sample = self.data_interval / pd.Timedelta(1, 's')
        self.base += 300
        self.amplitude = self.amplitude * 2 * seconds_per_sample
        self.qc_func = electronic_checks.bst_temp_outlier_flag
        self.data_func = linear

    def test_bst_temp_outlier_flag(self):
        ramp = self.get_data()
        self.frame['BBsupportStructureTemp'] = ramp
        self.run_test()


class TestHbbStableFlag(TestCase):
    """Regression tests for ``state_checks.hbb_stable_flag``.

    The input is a linear ramp around 300, scaled by 0.004 per second of
    ``data_interval`` and generated with the default ``skips``, applied to
    one or all of the HBB thermistor columns.
    """

    def setup_method(self, method):
        super().setup_method(method)
        seconds_per_sample = self.data_interval / pd.Timedelta(1, 's')
        self.base += 300
        self.amplitude = self.amplitude * .004 * seconds_per_sample
        self.qc_func = state_checks.hbb_stable_flag
        self.data_func = linear

    @pytest.mark.parametrize(
        "suffix, bot, apex, top",
        (
            ('bot', True, False, False),
            ('apex', False, True, False),
            ('top', False, False, True),
            ('all', True, True, True),
        ),
    )
    def test_hbb_stable_flag(self, suffix, bot, apex, top):
        self.name = f'{self.name}_{suffix}'
        ramp = self.get_data()
        # Thermistors that are not selected stay constant at the base value.
        selection = (
            ('HBBbottomTemp', bot),
            ('HBBapexTemp', apex),
            ('HBBtopTemp', top),
        )
        for column, varies in selection:
            self.frame[column] = ramp if varies else self.base
        self.run_test()


class TestHatchFlag(TestCase):
    """Regression test for ``scene_checks.hatch_flag``.

    The scene mirror position cycles through runs of ``ord('A')``,
    ``ord('H')`` and ``0`` whose length grows from 0 to 15, while
    ``hatchOpen`` alternates between 1 and 0 from sample to sample.
    """

    def setup_method(self, method):
        super().setup_method(method)
        self.qc_func = scene_checks.hatch_flag

    def test_hatch_flag(self):
        mirror_positions = np.concatenate(
            tuple(np.repeat((ord('A'), ord('H'), 0), n) for n in range(16)))
        self.frame['sceneMirrorPosition'] = mirror_positions

        # Build the whole column before handing it to pandas.  Writing through
        # ``frame[col][::2] = 1`` is chained assignment, which does not update
        # the frame when pandas runs with Copy-on-Write.
        hatch_open = np.zeros(self.frame['sceneMirrorPosition'].shape)
        hatch_open[::2] = 1
        self.frame['hatchOpen'] = hatch_open
        self.run_test()


class TestHbbCovarianceFlag(TestCase):
    """Regression tests for ``thermal_checks.hbb_covariance_flag``.

    A cubic curve around 300 with amplitude 5 is applied to one or all HBB
    thermistors, once with the default gaps and once gap-free.
    """

    def setup_method(self, method):
        super().setup_method(method)
        self.base += 300
        self.amplitude *= 5
        self.qc_func = thermal_checks.hbb_covariance_flag

    @pytest.mark.parametrize(
        "suffix1, bot, apex, top",
        (
            ('_bot', True, False, False),
            ('_apex', False, True, False),
            ('_top', False, False, True),
            ('_all', True, True, True),
        ),
    )
    @pytest.mark.parametrize(
        "suffix2, skips",
        (
            ('_spike', None),
            ('_smooth', 0),
        ),
    )
    def test_hbb_covariance_flag(self, suffix1, bot, apex, top, suffix2, skips):
        self.name = f'{self.name}{suffix1}{suffix2}'
        curve = self.get_data(skips=skips)
        # Column order (bottom, top, apex) is kept as in the stored frames.
        selection = (
            ('HBBbottomTemp', bot),
            ('HBBtopTemp', top),
            ('HBBapexTemp', apex),
        )
        for column, varies in selection:
            self.frame[column] = curve if varies else self.base
        self.run_test()


class TestHbbStdDevFlag(TestCase):
    """Regression tests for ``radiometric_checks.hbb_std_dev_flag``.

    The input follows ``np.arctan`` evaluated from -1000 up to a
    per-case stop value, offset to 300 with amplitude 5.
    """

    def setup_method(self, method):
        super().setup_method(method)
        self.base += 300
        self.amplitude *= 5
        self.qc_func = radiometric_checks.hbb_std_dev_flag
        # Make a very steep function with low standard deviation.
        self.data_func = np.arctan
        self.start = -1000

    @pytest.mark.parametrize(
        "suffix, skips, stop",
        (
            ('spike', 5, 1000),
            ('smooth', 0, 100),
        ),
    )
    def test_hbb_std_dev_flag(self, suffix, skips, stop):
        self.name = f'{self.name}_{suffix}'
        std_dev = self.get_data(stop=stop, skips=skips, middle=-1)
        self.frame['HBBviewStdDevRadiance985_990'] = std_dev
        self.run_test()


class TestHbbThermistorFlag(TestCase):
    """Regression tests for ``state_checks.hbb_thermistor_flag``.

    A gap-free cubic curve with a per-case base and amplitude is applied to
    one or all HBB thermistors; the others are held at 333.
    """

    def setup_method(self, method):
        super().setup_method(method)
        self.qc_func = state_checks.hbb_thermistor_flag

    @pytest.mark.parametrize(
        "suffix1, base, amplitude",
        (
            ('_min', 331, 1),
            ('_max', 335, 1),
            ('_mid', 333, 2),
        ),
    )
    @pytest.mark.parametrize(
        "suffix2, bot, apex, top",
        (
            ('_bot', True, False, False),
            ('_apex', False, True, False),
            ('_top', False, False, True),
            ('_all', True, True, True),
        ),
    )
    def test_hbb_thermistor_flag(self, suffix1, base, amplitude, suffix2, bot, apex, top):
        self.name = f'{self.name}{suffix1}{suffix2}'
        curve = self.get_data(base=base, amplitude=amplitude, skips=0)
        resting_temp = 333
        selection = (
            ('HBBbottomTemp', bot),
            ('HBBapexTemp', apex),
            ('HBBtopTemp', top),
        )
        for column, varies in selection:
            self.frame[column] = curve if varies else resting_temp
        self.run_test()


class TestImaginaryRadianceFlag(TestCase):
    """Regression test for ``radiometric_checks.imaginary_radiance_flag``.

    The input is a gap-free linear ramp with amplitude 2 around the default
    base.
    """

    def setup_method(self, method):
        super().setup_method(method)
        self.qc_func = radiometric_checks.imaginary_radiance_flag
        self.data_func = linear

    def test_imaginary_radiance_flag(self):
        ramp = self.get_data(skips=0, amplitude=2)
        self.frame['skyViewImaginaryRadiance2510_2515'] = ramp
        self.run_test()


class TestLwResponsivityFlag(TestCase):
    """Regression test for ``radiometric_checks.lw_responsivity_flag``.

    The input follows ``sin(x)**2 * x**3`` over ``[-10*pi, 10*pi]`` with
    every second sample replaced by the base value.
    """

    def setup_method(self, method):
        super().setup_method(method)

        def modulated_cubic(x):
            return np.sin(x)**2 * x**3

        self.start, self.stop = -10*np.pi, 10*np.pi
        self.data_func = modulated_cubic
        self.qc_func = radiometric_checks.lw_responsivity_flag

    def test_lw_responsivity_flag(self):
        responsivity = self.get_data(skips=1)
        self.frame['LWresponsivity'] = responsivity
        self.run_test()


class TestMissingDataFlag(TestCase):
    """Regression test for ``global_checks.missing_data_flag``.

    The ``missingDataFlag`` column repeats the pattern 1, 0, NaN.
    """

    def setup_method(self, method):
        super().setup_method(method)
        self.qc_func = global_checks.missing_data_flag

    def test_missing_data_flag(self):
        # Fill a plain array first and assign it once.  Writing through
        # ``frame[col][::3] = ...`` is chained assignment, which does not
        # update the frame when pandas runs with Copy-on-Write.
        flags = np.empty(self.frame['datetime'].shape)
        # The three strided writes together cover every element, so no
        # uninitialised value from np.empty survives.
        flags[::3] = 1
        flags[1::3] = 0
        flags[2::3] = np.nan
        self.frame['missingDataFlag'] = flags
        self.run_test()


class TestSafingFlag(TestCase):
    """Regression test for ``scene_checks.safing_flag``.

    Both columns are built from a constant shape function, so each series
    only switches between its base value and base plus amplitude according
    to the ``skips`` pattern.
    """

    def setup_method(self, method):
        super().setup_method(method)
        self.qc_func = scene_checks.safing_flag
        self.data_func = lambda x: 1

    def test_safing_flag(self):
        hatch_open = self.get_data(base=-3, amplitude=3, skips=1)
        self.frame['hatchOpen'] = hatch_open

        # Sum of two patterns with different periods (every 2nd / every 3rd
        # sample) and amplitudes 65 and 72.
        every_second = self.get_data(amplitude=65, skips=1)
        every_third = self.get_data(amplitude=72, skips=2)
        self.frame['sceneMirrorPosition'] = every_second + every_third
        self.run_test()


class TestSkyBrightnessTempSpectralAveragesCh1Flag(TestCase):
    """Regression tests for
    ``radiometric_checks.sky_brightness_temp_spectral_averages_ch1_flag``.

    One of the three temperature columns receives a gap-free cubic curve
    around 300; the other two are held at ``base - amplitude``.
    """

    def setup_method(self, method):
        super().setup_method(method)
        self.qc_func = radiometric_checks.sky_brightness_temp_spectral_averages_ch1_flag
        self.base += 300

    @pytest.mark.parametrize(
        "suffix1, amplitude",
        (
            ('_lower', 20),
            ('_upper', -20),
        ),
    )
    @pytest.mark.parametrize(
        "suffix2, surface, structure, nearBBs",
        (
            ('_surface', True, False, False),
            ('_structure', False, True, False),
            ('_nearBBs', False, False, True),
        ),
    )
    def test_sky_brightness_temp_spectral_averages_ch1_flag(self, suffix1, amplitude, suffix2, surface, structure, nearBBs):
        self.name = f'{self.name}{suffix1}{suffix2}'
        curve = self.get_data(amplitude=amplitude, skips=0, base=self.base)
        resting_temp = self.base - amplitude
        selection = (
            ('surfaceLayerAirTemp675_680', surface),
            ('BBsupportStructureTemp', structure),
            ('airNearBBsTemp', nearBBs),
        )
        for column, varies in selection:
            self.frame[column] = curve if varies else resting_temp
        self.run_test()


class TestSkyBrightnessTempSpectralAveragesCh2Flag(TestCase):
    """Regression tests for
    ``radiometric_checks.sky_brightness_temp_spectral_averages_ch2_flag``.

    One of the three temperature columns receives a gap-free cubic curve
    around 300; the other two are held at ``base - amplitude``.
    """

    def setup_method(self, method):
        super().setup_method(method)
        self.qc_func = radiometric_checks.sky_brightness_temp_spectral_averages_ch2_flag
        self.base += 300

    @pytest.mark.parametrize(
        "suffix1, amplitude",
        (
            ('_lower', 20),
            ('_upper', -20),
        ),
    )
    @pytest.mark.parametrize(
        "suffix2, surface, structure, nearBBs",
        (
            ('_surface', True, False, False),
            ('_structure', False, True, False),
            ('_nearBBs', False, False, True),
        ),
    )
    def test_sky_brightness_temp_spectral_averages_ch2_flag(self, suffix1, amplitude, suffix2, surface, structure, nearBBs):
        self.name = f'{self.name}{suffix1}{suffix2}'
        curve = self.get_data(amplitude=amplitude, skips=0, base=self.base)
        resting_temp = self.base - amplitude
        selection = (
            ('surfaceLayerAirTemp2295_2300', surface),
            ('BBsupportStructureTemp', structure),
            ('airNearBBsTemp', nearBBs),
        )
        for column, varies in selection:
            self.frame[column] = curve if varies else resting_temp
        self.run_test()


class TestSwResponsivityFlag(TestCase):
    """Regression test for ``radiometric_checks.sw_responsivity_flag``.

    The input follows ``sin(x)**2 * x**3`` over ``[-10*pi, 10*pi]`` with
    every second sample replaced by the base value.
    """

    def setup_method(self, method):
        super().setup_method(method)

        def modulated_cubic(x):
            return np.sin(x)**2 * x**3

        self.start, self.stop = -10*np.pi, 10*np.pi
        self.data_func = modulated_cubic
        self.qc_func = radiometric_checks.sw_responsivity_flag

    def test_sw_responsivity_flag(self):
        responsivity = self.get_data(skips=1)
        self.frame['SWresponsivity'] = responsivity
        self.run_test()


class TestCrossCorrelationCheck(TestCase):
    """Regression test for ``thermal_checks.cross_correlation_check``.

    The frame holds a sine-shaped ``radiance`` column and a constant
    ``SCEtemp`` column; ``self.parameters`` carries constant series under
    the same two names, indexed by the synthetic timestamps.
    """

    def setup_method(self, method):
        super().setup_method(method)
        self.qc_func = thermal_checks.cross_correlation_check
        self.base += 300
        self.start, self.stop = -10*np.pi, 10*np.pi
        # Previously tried shape: lambda x: np.sin(x)**2 * x**3
        self.data_func = lambda x: 0

    def test_cross_correlation_check(self):
        # With the constant default shape this series is flat.
        flat = self.get_data(skips=0)

        radiance = self.get_data(
            data_func=(lambda x: np.sin(x)), base=300, skips=0)
        sce_temp = self.get_data(data_func=(lambda x: 0), base=100, skips=0)
        self.frame['radiance'] = radiance
        self.frame['SCEtemp'] = sce_temp

        for key in ('radiance', 'SCEtemp'):
            self.parameters[key] = pd.Series(data=flat, index=self.datetimes)
        self.run_test()


class TestDetectorCheck(TestCase):
    """Regression tests for ``state_checks.detector_check``.

    ``detectorTemp`` is a gap-free cubic curve with amplitude 5 centred on
    80, 90 or 85.
    """

    def setup_method(self, method):
        super().setup_method(method)
        self.qc_func = state_checks.detector_check

    @pytest.mark.parametrize(
        "suffix, base, amplitude",
        (
            ('min', 80, 5),
            ('max', 90, 5),
            ('mid', 85, 5),
        ),
    )
    def test_detector_check(self, suffix, base, amplitude):
        self.name = f'{self.name}_{suffix}'
        self.frame['detectorTemp'] = self.get_data(
            base=base, amplitude=amplitude, skips=0)
        self.run_test()


class TestDetectorTempCheck(TestCase):
    """Regression tests for ``cooler_checks.detector_temp_check``.

    ``detectorTemp`` is a cubic curve scaled to base 0.5 and amplitude 0.5,
    generated for several start values and gap widths.
    """

    def setup_method(self, method):
        super().setup_method(method)
        self.qc_func = cooler_checks.detector_temp_check

    # seems to not affect
    @pytest.mark.parametrize(
        "suffix1, freq",
        (
            ('_1Hz', 1),
            ('_2Hz', 2),
            ('_3Hz', 3),
        ),
    )
    @pytest.mark.parametrize(
        "suffix2, spikes",
        (
            ('_no_spikes', 0),
            ('_medium_spikes', 3),
            ('_high_spikes', 5),
        ),
    )
    def test_detector_temp_check(self, suffix1, suffix2, spikes, freq):
        self.name = f'{self.name}{suffix1}{suffix2}'
        half_span = freq*np.pi
        self.start = -half_span
        # NOTE(review): get_data() reads ``self.stop``, never ``self.end``, so
        # only the start of the range follows ``freq``.  This looks like a
        # typo for ``self.stop``; changing it alters the generated data and
        # would require regenerating the stored references -- confirm intent.
        self.end = half_span
        self.frame['detectorTemp'] = self.get_data(
            skips=spikes, base=.5, amplitude=.5)
        self.run_test()


class TestHbbLwNenCheck(TestCase):
    """Regression tests for ``radiometric_checks.hbb_lw_nen_check``.

    ``LW_HBB_NEN`` is a gap-free cubic curve with amplitude 0.15 centred on
    0, 0.5 or 0.3.
    """

    def setup_method(self, method):
        super().setup_method(method)
        self.qc_func = radiometric_checks.hbb_lw_nen_check

    @pytest.mark.parametrize(
        "suffix, base, amplitude",
        (
            ('min', 0, .15),
            ('max', .5, .15),
            ('mid', .3, .15),
        ),
    )
    def test_hbb_lw_nen_check(self, suffix, base, amplitude):
        self.name = f'{self.name}_{suffix}'
        self.base += base
        self.amplitude *= amplitude
        self.frame['LW_HBB_NEN'] = self.get_data(skips=0)
        self.run_test()


class TestHbbSwNenCheck(TestCase):
    """Regression tests for ``radiometric_checks.hbb_sw_nen_check``.

    ``SW_HBB_NEN`` is a gap-free cubic curve with a per-case base and
    amplitude.
    """

    def setup_method(self, method):
        super().setup_method(method)
        self.qc_func = radiometric_checks.hbb_sw_nen_check

    @pytest.mark.parametrize(
        "suffix, base, amplitude",
        (
            ('min', 0, .01),
            ('max', .1, .08),
            ('mid', .015, .01),
        ),
    )
    def test_hbb_sw_nen_check(self, suffix, base, amplitude):
        self.name = f'{self.name}_{suffix}'
        self.base += base
        self.amplitude *= amplitude
        self.frame['SW_HBB_NEN'] = self.get_data(skips=0)
        self.run_test()


class TestHysteresisCheck(TestCase):
    """Regression test for ``radiometric_checks.hysteresis_check``.

    ``longwaveWindowAirTemp985_990`` is pseudo-random noise kept at every
    12th sample; ``even_odd`` is built from a constant shape function and
    switches between 1 and 0 with a period of six samples.
    """

    def setup_method(self, method):
        super().setup_method(method)
        self.qc_func = radiometric_checks.hysteresis_check

    def test_hysteresis_check(self):
        # The result is compared with a stored reference, so the noise must
        # be reproducible: draw it from a generator with a fixed seed instead
        # of the unseeded global NumPy state.
        # NOTE(review): references saved from an unseeded run may need to be
        # regenerated once (QC_SAVE_EXPECTED).
        rng = np.random.default_rng(0)
        self.frame['longwaveWindowAirTemp985_990'] = self.get_data(
            data_func=lambda x: rng.random(), skips=11)
        self.frame['even_odd'] = self.get_data(
            data_func=lambda x: 0, amplitude=1, skips=5)
        self.run_test()


class TestSceTempDeviationCheck(TestCase):
    """Regression test for ``thermal_checks.sce_temp_deviation_check``.

    Both ``SCEtemp`` and ``timeHHMMSS`` receive the default cubic curve with
    amplitude 400.
    """

    def setup_method(self, method):
        super().setup_method(method)
        self.qc_func = thermal_checks.sce_temp_deviation_check

    def test_sce_temp_deviation_check(self):
        for column in ('SCEtemp', 'timeHHMMSS'):
            # Generate a fresh array for each column, as before.
            self.frame[column] = self.get_data(amplitude=400)
        self.run_test()


class TestSpikeCheck(TestCase):
    """Placeholder tests for ``igm_checks.spike_check``.

    The test body has not been written yet, so every case is reported as
    skipped rather than as a pass that checks nothing.
    """

    def setup_method(self, method):
        super().setup_method(method)
        self.qc_func = igm_checks.spike_check
        #self.data_func = cubic
        self.data_func = lambda x: x*x + x

    @pytest.mark.parametrize("suffix", ('A', 'H', 'S'))
    def test_spike_check(self, suffix):
        self.name = f'{self.name}_{suffix}'
        # Draft of the intended test, kept for whoever finishes it:
        #base_data = base_data = np.array([self.get_data(length=self.length, skips=150+ i//10, data_func=lambda x: x) for i in range(self.length)])
        #input_data = np.fft.fft(np.fft.fftshift(base_data))
        #self.parameters['ScanCounts'] = [base_data]
        #self.frame['sceneMirrorPosition'] = np.repeat(ord(suffix), self.length)
        #self.frame['hatch_flag'] = self.get_data(amplitude=0)
        #self.frame['safing_flag'] = self.get_data(amplitude=0)
        #self.frame['hysteresis_check'] = self.get_data(amplitude=0)

        #self.run_test()
        pytest.skip('spike_check test is not implemented yet')


class TestAggregateQcCheck(TestCase):
    """Regression test for ``aggregate_qc_check``.

    The frame holds one column per entry of ``get_checklist().checks`` whose
    name does not contain ``'aggregate'``, each filled with an exponential
    curve scaled to base 0.5 and amplitude 0.5.
    """

    def setup_method(self, method):
        super().setup_method(method)
        self.data_func = np.exp
        self.start, self.stop = -5, 5
        self.qc_func = aggregate_qc_check

    def test_aggregate_qc_check(self):
        # Start from an empty frame: only check columns, no datetime column.
        self.frame = pd.DataFrame({})
        for check in get_checklist().checks:
            if 'aggregate' in check.name:
                continue
            self.frame[check.name] = self.get_data(
                amplitude=.5, base=.5, skips=0)
        self.run_test()


class TestAggregateQcFlag(TestCase):
    """Regression test for ``aggregate_qc_flag``.

    The frame holds one column per entry of ``get_checklist().checks`` whose
    name does not contain ``'aggregate'``, each filled with an exponential
    curve scaled to base 0.5 and amplitude 0.5.
    """

    def setup_method(self, method):
        super().setup_method(method)
        self.data_func = np.exp
        self.start, self.stop = -5, 5
        self.qc_func = aggregate_qc_flag

    def test_aggregate_qc_flag(self):
        # Start from an empty frame: only check columns, no datetime column.
        self.frame = pd.DataFrame({})
        for check in get_checklist().checks:
            if 'aggregate' in check.name:
                continue
            self.frame[check.name] = self.get_data(
                amplitude=.5, base=.5, skips=0)
        self.run_test()