Skip to content
Snippets Groups Projects
read_data.py 23.3 KiB
Newer Older
import logging
import os
from datetime import datetime as dt
from typing import Dict

import numpy as np
import numpy.typing as npt
import xarray as xr
from attrs import define, field, validators, Factory

import ancillary_data as anc
# Conversion factors between degrees and radians:
# degrees * _DTR -> radians, radians * _RTD -> degrees.
_DTR = np.pi/180.
_RTD = 180./np.pi
# Base directory used to build the default VIIRS and ancillary file paths.
_datapath = '/ships19/hercules/pveglio/mvcm_viirs_hires'
# Configure the root logger when the module is imported. The level is INFO, so
# the logging.debug() calls in this module are not emitted with this setup.
logging.basicConfig(level=logging.INFO, format='%(name)s - %(levelname)s - %(message)s')
# Alternative configuration that writes the log to a file instead:
# logging.basicConfig(level=logging.INFO, filename='logfile.log', filemode='w',
#                     format='%(name)s %(levelname)s %(message)s')
@define(kw_only=True, slots=True)
class CollectInputs(object):
    """Class that collects all input files

    Every attribute is keyword-only, has a default value and is validated
    against its expected type when the instance is created.

    Parameters
    ----------
    file_name_geo: str
        name of VIIRS geolocation file
    file_name_l1b: str
        name of VIIRS L1b file
    ancillary_dir: str
        path for the ancillary files. (this is necessary because the C functions
        still require path and file name separately)
    sst_file: str
        file name for the Reynolds SST
    eco_file: str
        file name for the ecosystem file (presumably the Olson ecosystem map,
        TODO confirm)
    ndvi_file: str
        file name for the NDVI hdf4 file
    geos_file_1: str
        file name of the GEOS-5 containing atmospheric profiles before the VIIRS timestamp
    geos_file_2: str
        file name of the GEOS-5 containing atmospheric profiles after the VIIRS timestamp
    geos_land: str
        file name of the GEOS-5 file that has the land ice/snow information
    geos_ocean: str
        file name of the GEOS-5 file that has the ocean information
    geos_constants: str
        file name for the GEOS-5 constants
    dims [optional]: tuple
        name of the dimensions for the arrays in the xarray.Dataset output
    """
    file_name_geo: str = field(default=f'{_datapath}/VNP03MOD.A2022173.1312.001.2022174012746.uwssec.nc',
                               validator=[validators.instance_of(str), ])
    file_name_l1b: str = field(default=f'{_datapath}/VNP02MOD.A2022173.1312.001.2022174011547.uwssec.nc',
                               validator=[validators.instance_of(str), ])
    ancillary_dir: str = field(default=f'{_datapath}/ancillary',
                               validator=[validators.instance_of(str), ])
    sst_file: str = field(default='oisst.20220622',
                          validator=[validators.instance_of(str), ])
    eco_file: str = field(default='goge1_2_img.v1',
                          validator=[validators.instance_of(str), ])
    ndvi_file: str = field(default='NDVI.FM.c004.v2.0.WS.00-04-177.hdf',
                           validator=[validators.instance_of(str), ])
    geos_file_1: str = field(default='GEOS.fpit.asm.inst3_2d_asm_Nx.GEOS5124.20220622_1200.V01.nc4',
                             validator=[validators.instance_of(str), ])
    geos_file_2: str = field(default='GEOS.fpit.asm.inst3_2d_asm_Nx.GEOS5124.20220622_1500.V01.nc4',
                             validator=[validators.instance_of(str), ])
    geos_land: str = field(default='GEOS.fpit.asm.tavg1_2d_lnd_Nx.GEOS5124.20220622_1330.V01.nc4',
                           validator=[validators.instance_of(str), ])
    geos_ocean: str = field(default='GEOS.fpit.asm.tavg1_2d_ocn_Nx.GEOS5124.20220622_1330.V01.nc4',
                            validator=[validators.instance_of(str), ])
    geos_constants: str = field(default='GEOS.fp.asm.const_2d_asm_Nx.00000000_0000.V01.nc4',
                                validator=[validators.instance_of(str), ])

    dims: tuple = field(default=('number_of_lines', 'number_of_pixels'),
                        validator=[validators.instance_of(tuple), ])
@define(slots=True, kw_only=True)
class ReadData(CollectInputs):
    """Class that reads the L1b/geolocation data from VIIRS. Inherits file names from CollectInputs

    Parameters
    ----------
    satellite: str
        satellite name. Only 'snpp' passes validation.
    sensor: str
        sensor name. Only 'viirs' passes validation.
    """
    satellite: str = field(validator=[validators.instance_of(str),
                                      validators.in_(['snpp'])])
    sensor: str = field(validator=[validators.instance_of(str),
                                   validators.in_(['viirs'])])

    def __attrs_post_init__(self) -> None:
        """Log the creation of each instance once attrs has initialised it."""
        # In the class body this statement would run only once, when the class
        # is defined, and never when an instance is actually created.
        logging.debug('Instance of ReadData created')

    def read_viirs_geo(self) -> xr.Dataset:
        """Read VIIRS geolocation data and generate additional angles

        Parameters
        ----------
        self.file_name_geo: str
            file name of the geolocation file

        Returns
        -------
        geo_data xarray.Dataset
            dataset containing all geolocation data plus 'relative_azimuth',
            'sunglint_angle' and 'scattering_angle'
        """
        logging.debug(f'Reading {self.file_name_geo}')
        geo_data = xr.open_dataset(self.file_name_geo, group='geolocation_data')

        sensor_zenith = geo_data.sensor_zenith.values
        solar_zenith = geo_data.solar_zenith.values

        relazi = self.relative_azimuth_angle(geo_data.sensor_azimuth.values,
                                             geo_data.solar_azimuth.values)
        sunglint = self.sun_glint_angle(sensor_zenith, solar_zenith, relazi)
        scatt_angle = self.scattering_angle(solar_zenith, sensor_zenith, relazi)

        geo_data['relative_azimuth'] = (self.dims, relazi)
        geo_data['sunglint_angle'] = (self.dims, sunglint)
        geo_data['scattering_angle'] = (self.dims, scatt_angle)

        logging.debug('Geolocation file read correctly')

        return geo_data

    def read_viirs_l1b(self,
                       solar_zenith: npt.NDArray[float]) -> xr.Dataset:
        """Read VIIRS L1b data

        Parameters
        ----------
        self.file_name_l1b: str
            file name of the L1b file
        solar_zenith: np.ndarray
            solar zenith angle derived from the geolocation file

        Returns
        -------
        rad_data: xr.Dataset
            one variable per band: scaled values divided by cos(solar zenith)
            for variables whose long_name contains 'reflectance', and values
            looked up in '<band>_brightness_temperature_lut' for variables
            whose long_name contains 'radiance'
        """
        logging.debug(f'Reading {self.file_name_l1b}')
        # decode_cf=False keeps the stored values unscaled: the scaling is done
        # explicitly below and the raw radiance values are used as LUT indices.
        l1b_data = xr.open_dataset(self.file_name_l1b, group='observation_data', decode_cf=False)

        rad_data = xr.Dataset()
        for band in list(l1b_data.variables):
            variable = l1b_data[band]
            # Variables without a long_name attribute are skipped instead of
            # raising AttributeError.
            long_name = variable.attrs.get('long_name', '')

            if 'reflectance' in long_name:
                if 'VCST_scale_factor' in variable.attrs:
                    scale_factor = variable.attrs['VCST_scale_factor'] * variable.attrs['bias_correction']
                else:
                    scale_factor = variable.attrs['scale_factor']
                rad_data[band] = (self.dims,
                                  variable.values * scale_factor / np.cos(solar_zenith*_DTR))

            elif 'radiance' in long_name:
                bt_lut = f'{band}_brightness_temperature_lut'
                rad_data[band] = (self.dims,
                                  l1b_data[bt_lut].values[variable.values])

        logging.debug('L1b file read correctly')

        # The dataset was previously built but never handed back to the caller.
        return rad_data

    def preprocess_viirs(self,
                         geo_data: xr.Dataset,
                         viirs: xr.Dataset) -> xr.Dataset:
        """Create band combinations (e.g. ratios, differences) that are used by some spectral tests.

        Parameters
        ----------
        geo_data: xarray.Dataset
            dataset containing all geolocation data
        viirs: xr.Dataset
            VIIRS L1b data

        Returns
        -------
        viirs: xr.Dataset
            VIIRS L1b data plus band combinations for the spectral tests and
            the geolocation variables, with latitude/longitude set as coordinates
        """
        has_vis_bands = ('M05' in viirs) and ('M07' in viirs)

        if has_vis_bands:
            m05 = viirs.M05.values
            m07 = viirs.M07.values
            r1 = 2.0 * (np.power(m07, 2.0) - np.power(m05, 2.0)) + (1.5 * m07) + (0.5 * m05)
            r2 = m05 + m07 + 0.5
            r3 = r1 / r2
            gemi = r3 * (1.0 - 0.25*r3) - ((m05 - 0.125) / (1.0 - m05))
        else:
            gemi = np.full((viirs.M15.shape), _bad_data)

        # Range screening of the individual bands is currently disabled:
        #
        # if 'M05' in viirs:
        #     idx = np.nonzero((viirs.M05.values < -99) | (viirs.M05.values > 2))
        #     viirs['M05'].values[idx] = _bad_data
        # else:
        #     viirs['M05'] = (self.dims, np.full(viirs.M15.shape, _bad_data))
        # if 'M07' in viirs:
        #     idx = np.nonzero((viirs.M07.values < -99) | (viirs.M07.values > 2))
        #     viirs['M07'].values[idx] = _bad_data
        # else:
        #     viirs['M07'] = (self.dims, np.full(viirs.M15.shape, _bad_data))
        #
        # idx = np.nonzero((viirs.M12.values < 0) | (viirs.M12.values > 1000))
        # viirs['M12'].values[idx] = _bad_data
        # idx = np.nonzero((viirs.M13.values < 0) | (viirs.M13.values > 1000))
        # viirs['M13'].values[idx] = _bad_data
        # idx = np.nonzero((viirs.M14.values < 0) | (viirs.M14.values > 1000))
        # viirs['M14'].values[idx] = _bad_data
        # idx = np.nonzero((viirs.M15.values < 0) | (viirs.M15.values > 1000))
        # viirs['M15'].values[idx] = _bad_data
        # idx = np.nonzero((viirs.M16.values < 0) | (viirs.M16.values > 1000))
        # viirs['M16'].values[idx] = _bad_data

        # Compute channel differences and ratios that are used in the tests
        if has_vis_bands:
            viirs['M07-M05ratio'] = (self.dims, viirs.M07.values / viirs.M05.values)
        else:
            viirs['M07-M05ratio'] = (self.dims, np.full(viirs.M15.shape, _bad_data))
        viirs['M13-M16'] = (self.dims, viirs.M13.values - viirs.M16.values)
        viirs['M14-M15'] = (self.dims, viirs.M14.values - viirs.M15.values)
        viirs['M15-M12'] = (self.dims, viirs.M15.values - viirs.M12.values)
        viirs['M15-M13'] = (self.dims, viirs.M15.values - viirs.M13.values)
        viirs['M15-M16'] = (self.dims, viirs.M15.values - viirs.M16.values)
        viirs['GEMI'] = (self.dims, gemi)

        # temp value to force the code to work
        viirs['M128'] = (self.dims, np.zeros(viirs.M15.shape))

        # Merge the geolocation variables into the L1b dataset (in place).
        viirs.update(geo_data)
        viirs = viirs.set_coords(['latitude', 'longitude'])

        logging.debug('Viirs preprocessing completed successfully.')
        return viirs

    def relative_azimuth_angle(self,
                               sensor_azimuth: npt.NDArray[float],
                               solar_azimuth: npt.NDArray[float]) -> npt.NDArray[float]:
        """Computation of the relative azimuth angle

        Parameters
        ----------
        sensor_azimuth: np.ndarray
            sensor azimuth angle from the geolocation file
        solar_azimuth: np.ndarray
            solar azimuth angle from the geolocation file

        Returns
        -------
        relative_azimuth: np.ndarray
            |180 - |sensor_azimuth - solar_azimuth||, in the units of the inputs
        """
        rel_azimuth = np.abs(180. - np.abs(sensor_azimuth - solar_azimuth))
        logging.debug('Relative azimuth calculated successfully.')
        return rel_azimuth

    def sun_glint_angle(self,
                        sensor_zenith: npt.NDArray[float],
                        solar_zenith: npt.NDArray[float],
                        rel_azimuth: npt.NDArray[float]) -> npt.NDArray[float]:
        """Computation of the sun glint angle

        Parameters
        ----------
        sensor_zenith: np.ndarray
            sensor zenith angle from the geolocation file
        solar_zenith: np.ndarray
            solar zenith angle from the geolocation file
        rel_azimuth: np.ndarray
            relative azimuth computed from function relative_azimuth_angle()

        Returns
        -------
        sunglint_angle: np.ndarray
            sun glint angle in degrees
        """
        cossna = (np.sin(sensor_zenith*_DTR) * np.sin(solar_zenith*_DTR) * np.cos(rel_azimuth*_DTR) +
                  np.cos(sensor_zenith*_DTR) * np.cos(solar_zenith*_DTR))
        # Rounding can push the cosine slightly outside [-1, 1], which would
        # make arccos return NaN; clamp both ends (NaN inputs stay NaN).
        cossna = np.clip(cossna, -1., 1.)
        sunglint_angle = np.arccos(cossna) * _RTD
        logging.debug('Sunglint generated')
        return sunglint_angle

    def scattering_angle(self,
                         solar_zenith: npt.NDArray[float],
                         sensor_zenith: npt.NDArray[float],
                         relative_azimuth: npt.NDArray[float]) -> npt.NDArray[float]:
        """Computation of the scattering angle

        Parameters
        ----------
        solar_zenith: np.ndarray
            solar zenith angle from the geolocation file
        sensor_zenith: np.ndarray
            sensor zenith angle angle from the geolocation file
        relative_azimuth: np.ndarray
            relative azimuth computed from function relative_azimuth_angle()

        Returns
        -------
        scattering_angle: np.ndarray
            scattering angle in degrees
        """
        cos_scatt_angle = -1. * (np.cos(solar_zenith*_DTR) * np.cos(sensor_zenith*_DTR) -
                                 np.sin(solar_zenith*_DTR) * np.sin(sensor_zenith*_DTR) *
                                 np.cos(relative_azimuth*_DTR))
        # Same guard as in sun_glint_angle(): keep the argument of arccos
        # inside its domain so rounding errors do not produce NaN.
        cos_scatt_angle = np.clip(cos_scatt_angle, -1., 1.)

        scatt_angle = np.arccos(cos_scatt_angle) * _RTD

        logging.debug('Scattering angle calculated successfully')

        return scatt_angle

@define(kw_only=True, slots=True)
class ReadAncillary(CollectInputs):
    """Class that processes all the ancillary files and makes them available for the MVCM

    Parameters
    ----------
    latitude: npt.NDArray[float]
        array of latitudes for the granule that is being processed
    longitude: npt.NDArray[float]
        array of longitudes for the granule that is being processed
    resolution: int
        flag to switch between MOD (1) and IMG (2) resolution
    """
    latitude: npt.NDArray[float] = field(validator=[validators.instance_of(np.ndarray), ])
    longitude: npt.NDArray[float] = field(validator=[validators.instance_of(np.ndarray), ])
    resolution: int = field(validator=[validators.instance_of(int),
                                       validators.in_([1, 2]), ])

    # Shape of the output arrays; not an __init__ argument, always taken from latitude.
    out_shape: tuple = field(init=False,
                             default=Factory(lambda self: self.latitude.shape, takes_self=True))

    def _log_if_missing(self, file_name: str, message: str) -> None:
        """Log `message` as an error when `file_name` is not a file in ancillary_dir."""
        # Only reports the problem: processing continues as it did before.
        if not os.path.isfile(os.path.join(self.ancillary_dir, file_name)):
            logging.error(message)

    def get_granule_time(self) -> str:
        """Get granule timestamp and format it for MVCM

        The timestamp is parsed from the second and third dot-separated fields
        of the L1b file name ('A%Y%j.%H%M').

        Returns
        -------
        str
            timestamp formatted as '%Y-%m-%d %H:%M:%S.000'
        """
        vnp_time = '.'.join(os.path.basename(self.file_name_l1b).split('.')[1:3])
        return dt.strptime(vnp_time, 'A%Y%j.%H%M').strftime('%Y-%m-%d %H:%M:%S.000')

    def get_sst(self) -> npt.NDArray[float]:
        """Read Reynolds SST file

        Parameters
        ----------

        Returns
        -------
        sst: npt.NDArray[float]
            array containing the Reynolds SST interpolated at the sensor's resolution
        """
        self._log_if_missing(self.sst_file, 'SST file not found')
        # Flat float32 buffer handed to the anc routine.
        sst = np.empty(self.out_shape, dtype=np.float32).ravel()
        sst = anc.py_get_Reynolds_SST(self.latitude.ravel(), self.longitude.ravel(), self.resolution,
                                      self.ancillary_dir, self.sst_file, sst)
        logging.debug('SST file read successfully')
        return sst.reshape(self.out_shape)

    def get_ndvi(self) -> npt.NDArray[float]:
        """Read NDVI file

        Parameters
        ----------

        Returns
        -------
        ndvi: npt.NDArray[float]
            NDVI interpolated at the sensor's resolution
        """
        self._log_if_missing(self.ndvi_file, 'NDVI file not found')
        ndvi = np.empty(self.out_shape, dtype=np.float32).ravel()
        ndvi = anc.py_get_NDVI_background(self.latitude.ravel(), self.longitude.ravel(), self.resolution,
                                          self.ancillary_dir, self.ndvi_file, ndvi)
        logging.debug('NDVI file read successfully')
        return ndvi.reshape(self.out_shape)

    def get_eco(self) -> npt.NDArray[float]:
        """Read ECO file

        Parameters
        ----------

        Returns
        -------
        eco: npt.NDArray[float]
            Olson ecosystem type interpolated at the sensor's resolution
        """
        eco = np.empty(self.out_shape, dtype=np.ubyte).ravel()
        # NOTE(review): unlike the SST/NDVI calls, self.eco_file is not passed
        # here - verify against the signature of anc.py_get_Olson_eco.
        eco = anc.py_get_Olson_eco(self.latitude.ravel(), self.longitude.ravel(), self.resolution,
                                   self.ancillary_dir, eco)
        logging.debug('Olson ecosystem file read successfully')
        return eco.reshape(self.out_shape)

    def get_geos(self) -> Dict:
        """Read GEOS-5 data and interpolate the fields to the sensor resolution

        Parameters
        ----------

        Returns
        -------
        geos_data: Dict
            dictionary containing all quantities required by MVCM (see geos_variables here below)
        """
        # The checks use the attribute names declared on CollectInputs
        # (ancillary_dir, geos_land, geos_ocean, geos_constants).
        geos_inputs = ((self.geos_file_1, 'GEOS-5 file 1 not found'),
                       (self.geos_file_2, 'GEOS-5 file 2 not found'),
                       (self.geos_land, 'GEOS-5 land file not found'),
                       (self.geos_ocean, 'GEOS-5 ocean file not found'),
                       (self.geos_constants, 'GEOS-5 constants file not found'))
        for file_name, message in geos_inputs:
            self._log_if_missing(file_name, message)

        geos_variables = ['tpw', 'snow_fraction', 'ice_fraction', 'ocean_fraction',
                          'land_ice_fraction', 'surface_temperature']
        geos_data = {var: np.empty(self.out_shape, dtype=np.float32).ravel() for var in geos_variables}

        geos_data = anc.py_get_GEOS(self.latitude.ravel(), self.longitude.ravel(), self.resolution,
                                    self.get_granule_time(), self.ancillary_dir, self.geos_file_1,
                                    self.geos_file_2, self.geos_land, self.geos_ocean,
                                    self.geos_constants, geos_data)

        # Restore the granule shape of every flat array.
        for var in geos_variables:
            geos_data[var] = geos_data[var].reshape(self.out_shape)

        logging.debug('GEOS data read successfully')
        return geos_data

    def pack_data(self) -> xr.Dataset:
        """Pack all the ancillary data into a single dataset

        Parameters
        ----------

        Returns
        -------
        ancillary_data: xr.Dataset
            dataset containing all the ancillary data: 'sst', 'ecosystem',
            'ndvi' and one 'geos_<name>' variable per GEOS-5 quantity
        """
        ancillary_data = xr.Dataset()
        ancillary_data['sst'] = (self.dims, self.get_sst())
        ancillary_data['ecosystem'] = (self.dims, self.get_eco())
        ancillary_data['ndvi'] = (self.dims, self.get_ndvi())
        for var, values in self.get_geos().items():
            ancillary_data[f'geos_{var}'] = (self.dims, values)
        return ancillary_data
    # NOTE(review): these statements have no visible enclosing function, and
    # `viirs`, `sunglint_angle` and `scn` are not defined at this scope. They
    # repeat the closing statements of get_data() - this looks like an orphaned
    # fragment; confirm and remove.
    scene_flags = scn.find_scene(viirs, sunglint_angle)
    scene = scn.scene_id(scene_flags)

    # Copy each scene flag into a dataset on the (lines, pixels) grid.
    scene_xr = xr.Dataset()
    for s in scn._scene_list:
        scene_xr[s] = (('number_of_lines', 'number_of_pixels'), scene[s])

    scene['lat'] = viirs.latitude
    scene['lon'] = viirs.longitude

    data = xr.Dataset(viirs, coords=scene_xr)
    # NOTE(review): drop_vars returns a new dataset; the result is discarded
    # here, so `data` still contains latitude/longitude.
    data.drop_vars(['latitude', 'longitude'])
def _flag_out_of_range(viirs: xr.Dataset, bands, lower: float, upper: float) -> None:
    """Overwrite in place the values of `bands` outside [lower, upper] with _bad_data."""
    for band in bands:
        values = viirs[band].values
        viirs[band].values[(values < lower) | (values > upper)] = _bad_data


def get_data(satellite: str,
             sensor: str,
             file_names: Dict[str, str],
             sunglint_angle: float,
             hires: bool = False) -> xr.Dataset:
    """Read VIIRS and ancillary data, derive the band combinations and attach the scene flags.

    NOTE(review): the two resolution branches were restored as an if/else.
    Without it `viirs` was undefined when `hires` is False, valid M05/M07 data
    were always overwritten with _bad_data, and the M-band results were thrown
    away by the I-band read. Confirm against the intended behaviour.

    Parameters
    ----------
    satellite: str
        satellite name (not used in this function body)
    sensor: str
        sensor name, passed to read_viirs() for the MOD-resolution files
    file_names: Dict[str, str]
        file names; 'MOD02' and 'MOD03' are always read, 'IMG02' and 'IMG03'
        only when hires is True. The dictionary is also passed to
        read_ancillary_data()
    sunglint_angle: float
        value passed to scn.find_scene()
    hires: bool
        when True the I-band files are processed (read_ancillary_data is
        called with resolution=2), otherwise the M-band files

    Returns
    -------
    data: xr.Dataset
        VIIRS data, band combinations and ancillary data, with the scene
        flags attached as coordinates
    """
    mod_dims = ('number_of_lines', 'number_of_pixels')
    img_dims = ('number_of_lines_2', 'number_of_pixels_2')

    # Looked up unconditionally, so a missing key fails early in both modes.
    mod02 = file_names['MOD02']
    mod03 = file_names['MOD03']

    if hires is True:
        img02 = file_names['IMG02']
        img03 = file_names['IMG03']

        viirs = read_viirs('viirs', f'{img02}', f'{img03}')
        viirs['M05'] = viirs.I01
        viirs['M07'] = viirs.I02

        _flag_out_of_range(viirs, ('M05', 'M07'), -99, 2)
        # NOTE(review): I01-I03 are screened with the same 0-1000 limits as
        # I04/I05 - confirm these limits are intended for every I band.
        _flag_out_of_range(viirs, ('I01', 'I02', 'I03', 'I04', 'I05'), 0, 1000)

        viirs = read_ancillary_data(file_names, viirs, resolution=2)

        viirs['I05-I04'] = (img_dims, viirs.I05.values - viirs.I04.values)
        viirs['I02-I01ratio'] = (img_dims, viirs.I02.values / viirs.I01.values)
    else:
        viirs = read_viirs(sensor, f'{mod02}', f'{mod03}')
        viirs = read_ancillary_data(file_names, viirs)

        # GEMI is computed before the range screening below.
        if (('M05' in viirs) and ('M07' in viirs)):
            m05 = viirs.M05.values
            m07 = viirs.M07.values
            r1 = 2.0 * (np.power(m07, 2.0) - np.power(m05, 2.0)) + (1.5 * m07) + (0.5 * m05)
            r2 = m07 + m05 + 0.5
            r3 = r1 / r2
            gemi = r3 * (1.0 - 0.25*r3) - ((m05 - 0.125) / (1.0 - m05))
        else:
            gemi = np.full((viirs.M15.shape), _bad_data)

        # Screen the visible bands when present, otherwise create them filled
        # with _bad_data so the ratio below can always be computed.
        for band in ('M05', 'M07'):
            if band in viirs:
                _flag_out_of_range(viirs, (band,), -99, 2)
            else:
                viirs[band] = (mod_dims, np.full(viirs.M15.shape, _bad_data))

        _flag_out_of_range(viirs, ('M12', 'M13', 'M14', 'M15', 'M16'), 0, 1000)

        # Compute channel differences and ratios that are used in the tests
        viirs['M15-M13'] = (mod_dims, viirs.M15.values - viirs.M13.values)
        viirs['M14-M15'] = (mod_dims, viirs.M14.values - viirs.M15.values)
        viirs['M15-M16'] = (mod_dims, viirs.M15.values - viirs.M16.values)
        viirs['M15-M12'] = (mod_dims, viirs.M15.values - viirs.M12.values)
        viirs['M13-M16'] = (mod_dims, viirs.M13.values - viirs.M16.values)
        viirs['M07-M05ratio'] = (mod_dims, viirs.M07.values / viirs.M05.values)
        viirs['GEMI'] = (mod_dims, gemi)

        # temp value to force the code to work
        viirs['M128'] = (mod_dims, np.zeros(viirs.M15.shape))

    scene_flags = scn.find_scene(viirs, sunglint_angle)
    scene = scn.scene_id(scene_flags)

    scene_xr = xr.Dataset()
    for s in scn._scene_list:
        scene_xr[s] = (mod_dims, scene[s])

    scene['lat'] = viirs.latitude
    scene['lon'] = viirs.longitude

    data = xr.Dataset(viirs, coords=scene_xr)
    # NOTE(review): drop_vars returns a new dataset and the result is not
    # assigned, so latitude/longitude remain in `data`. Left unchanged because
    # callers may rely on those variables being present - confirm the intent.
    data.drop_vars(['latitude', 'longitude'])

    return data