"""Test functions for the utility_functions module."""

import numpy as np
import pytest

import mvcm.utility_functions as utils


# FIXTURES
@pytest.fixture
def test_arr() -> np.ndarray:
    """Define the reference 10x10 input array shared by several tests.

    The full array is the input for the spatial_var() test; its top-left 5x5
    corner (``test_arr[:5, :5]``) is the input for the local_nxn_mean() and
    local_nxn_standard_deviation() tests. Rows alternate between all-1 and
    all-5 values, except for two higher-valued 3x3 patches at rows 2-4 /
    columns 4-6 and rows 6-8 / columns 0-2.
    """
    arr = np.array(
        [
            [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
            [5, 5, 5, 5, 5, 5, 5, 5, 5, 5],
            [1, 1, 1, 1, 8, 8, 8, 1, 1, 1],
            [5, 5, 5, 5, 7, 9, 8, 5, 5, 5],
            [1, 1, 1, 1, 8, 8, 7, 1, 1, 1],
            [5, 5, 5, 5, 5, 5, 5, 5, 5, 5],
            [9, 8, 7, 1, 1, 1, 1, 1, 1, 1],
            [7, 8, 8, 5, 5, 5, 5, 5, 5, 5],
            [8, 9, 8, 1, 1, 1, 1, 1, 1, 1],
            [5, 5, 5, 5, 5, 5, 5, 5, 5, 5],
        ],
        dtype=float,
    )
    return arr


@pytest.fixture
def ref_mean() -> np.ndarray:
    """Return the expected local_nxn_mean() output for ``test_arr[:5, :5]``."""
    rows = [
        [1.33333333, 2.0, 2.0, 2.0, 1.33333333],
        [1.55555556, 2.33333333, 2.33333333, 3.11111111, 2.33333333],
        [2.44444444, 3.66666667, 3.66666667, 4.66666667, 3.44444444],
        [1.55555556, 2.33333333, 2.33333333, 4.11111111, 3.33333333],
        [1.33333333, 2.0, 2.0, 3.0, 2.33333333],
    ]
    return np.asarray(rows, dtype=np.float64)


@pytest.fixture
def ref_std() -> np.ndarray:
    """Return the expected local_nxn_standard_deviation() output for ``test_arr[:5, :5]``."""
    rows = [
        [2.12132034, 2.29128785, 2.29128785, 2.29128785, 2.12132034],
        [2.00693243, 2.0, 2.0, 2.66666667, 2.91547595],
        [2.45515331, 2.0, 2.0, 2.34520788, 3.20589734],
        [2.00693243, 2.0, 2.0, 3.14024061, 3.60555128],
        [2.12132034, 2.29128785, 2.29128785, 3.24037035, 3.35410197],
    ]
    return np.asarray(rows, dtype=np.float64)


@pytest.fixture
def ref_spatial_var() -> np.ndarray:
    """Define the expected floored spatial_var() output for the test_arr fixture.

    Only the centres of the two high-valued 3x3 patches in test_arr,
    (3, 5) and (7, 1), are 1; every other cell is 0.
    """
    arr = np.array(
        [
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0, 1, 0, 0, 0, 0],
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            [0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        ],
        dtype=float,
    )
    return arr


@pytest.fixture
def b01() -> np.ndarray:
    """Return the band-1 input values for the NDVI test."""
    values = [0.1, 0.2, 0.3, 0.5]
    return np.asarray(values, dtype=np.float64)


@pytest.fixture
def b02() -> np.ndarray:
    """Return the band-2 input values for the NDVI test."""
    values = [0.3, 0.3, 0.5, 0.5]
    return np.asarray(values, dtype=np.float64)


@pytest.fixture
def ref_ndvi() -> np.ndarray:
    """Return the NDVI values expected from the b01/b02 fixtures.

    The values match (b02 - b01) / (b02 + b01) element-wise; presumably that
    is the formula compute_ndvi() implements -- confirm against the module.
    """
    expected = np.asarray([0.5, 0.2, 0.25, 0.0], dtype=np.float64)
    return expected


@pytest.fixture
def thresholds() -> dict:
    """Return the threshold mapping passed to spatial_var() in its test."""
    return dict(thr=2)


@pytest.fixture
def height() -> np.ndarray:
    """Return heights from 0 to 4000 in steps of 500 for the elevation-correction test."""
    return np.arange(0, 4001, 500)


@pytest.fixture
def ref_correction() -> np.ndarray:
    """Return the expected elevation corrections: 0 to 20 in steps of 2.5."""
    # 2.5 is exactly representable, so every element matches the literal values.
    return 2.5 * np.arange(9, dtype=float)


@pytest.fixture
def ref_bits() -> dict:
    """Define dummy bit values to test the group_count() function.

    Keys "01" to "16" each map to a 6-element array of 0/1 values.
    """
    return {
        "01": np.array([1, 0, 0, 1, 1, 1]),
        "02": np.array([0, 1, 0, 0, 1, 1]),
        "03": np.array([1, 1, 0, 1, 0, 1]),
        "04": np.array([0, 0, 1, 1, 1, 1]),
        "05": np.array([0, 0, 0, 0, 0, 1]),
        "06": np.array([0, 1, 1, 1, 1, 1]),
        "07": np.array([1, 0, 1, 0, 1, 1]),
        "08": np.array([1, 0, 1, 0, 0, 0]),
        "09": np.array([0, 0, 1, 1, 0, 0]),
        "10": np.array([1, 1, 0, 1, 1, 0]),
        "11": np.array([0, 1, 0, 1, 1, 1]),
        "12": np.array([1, 0, 0, 1, 0, 0]),
        "13": np.array([1, 0, 1, 1, 0, 1]),
        "14": np.array([1, 0, 0, 1, 0, 1]),
        "15": np.array([0, 1, 1, 1, 1, 0]),
        "16": np.array([0, 0, 1, 0, 1, 0]),
    }


@pytest.fixture
def ref_group_count() -> np.ndarray:
    """Define the expected group_count() output for the ref_bits fixture."""
    return np.array([3, 4, 4, 4, 5, 3])


# TESTS
def test_local_nxn_mean(test_arr: np.ndarray, ref_mean: np.ndarray) -> None:
    """Compare local_nxn_mean() on the top-left 5x5 corner with the reference."""
    corner = test_arr[:5, :5]
    result = utils.local_nxn_mean(corner, 3, 3)
    assert np.allclose(ref_mean, result)


def test_local_nxn_standard_deviation(test_arr: np.ndarray, ref_std: np.ndarray) -> None:
    """Compare local_nxn_standard_deviation() on the top-left 5x5 corner with the reference."""
    corner = test_arr[:5, :5]
    result = utils.local_nxn_standard_deviation(corner, 3, 3)
    assert np.allclose(ref_std, result)


def test_spatial_var(test_arr: np.ndarray, thresholds: dict, ref_spatial_var: np.ndarray) -> None:
    """Check the floored spatial_var() output (vis=True) against the reference array."""
    threshold = thresholds["thr"]
    variability = utils.spatial_var(test_arr, threshold, vis=True)
    floored = np.floor(variability)
    assert np.allclose(floored, ref_spatial_var)


def test_compute_ndvi(b01: np.ndarray, b02: np.ndarray, ref_ndvi: np.ndarray) -> None:
    """Test compute_ndvi() on the b01/b02 fixtures against ref_ndvi."""
    ndvi = utils.compute_ndvi(b01, b02)
    assert np.allclose(ndvi, ref_ndvi)


def test_bt11_elevation_correction(height: np.ndarray, ref_correction: np.ndarray) -> None:
    """Check bt11_elevation_correction() for each height against the reference values."""
    assert np.allclose(utils.bt11_elevation_correction(height), ref_correction)


def test_group_count(ref_bits: dict, ref_group_count: np.ndarray) -> None:
    """Test group_count() on the ref_bits fixture against ref_group_count."""
    # ref_group_count must be requested as a fixture argument: the module-level
    # name refers to the decorated fixture function, not the array it returns.
    assert np.allclose(utils.group_count(ref_bits), ref_group_count)


@pytest.mark.skip(reason="test_restoral_flag is not implemented yet")
def test_restoral_flag(bits: np.ndarray) -> None:
    """Test computation of restoral flags.

    Placeholder: no ``bits`` fixture is defined in this module, so without the
    skip marker pytest would report a fixture-lookup error instead of a skip.
    TODO: add a ``bits`` fixture and real assertions, then drop the marker.
    """


@pytest.mark.skip(reason="test_combine_indices is not implemented yet")
def test_combine_indices(index1: tuple, index2: tuple, shape: tuple) -> None:
    """Test combine indices.

    Placeholder: the ``index1``, ``index2`` and ``shape`` fixtures are not
    defined in this module, so without the skip marker pytest would report a
    fixture-lookup error instead of a skip.
    TODO: add those fixtures and real assertions, then drop the marker.
    """