import numpy as np
import ruamel_yaml as yml

from glob import glob

import read_data as rd
import ancillary_data as anc

# Abbreviation used throughout this module -- lsf: land sea flag.

# Scene names; scene_id() returns one array per name, in this order.
_scene_list = [
    'Ocean_Day', 'Ocean_Night',
    'Land_Day', 'Land_Night',
    'Snow_Day', 'Snow_Night',
    'Coast_Day',
    'Desert_Day',
    'Antarctic_Day',
    'Polar_Day_Snow', 'Polar_Day_Desert', 'Polar_Day_Ocean',
    'Polar_Day_Desert_Coast', 'Polar_Day_Coast', 'Polar_Day_Land',
    'Polar_Night_Snow', 'Polar_Night_Land', 'Polar_Night_Ocean',
    'Land_Day_Desert_Coast', 'Land_Day_Coast',
    'Desert',
]

# Per-pixel flag names; find_scene() returns one array per name.
_flags = [
    'day', 'night',
    'land', 'coast', 'sh_lake', 'sh_ocean', 'water',
    'polar', 'sunglint',
    'greenland', 'high_elevation', 'antarctica', 'desert',
    'visusd', 'vrused',
    'map_snow', 'map_ice', 'ndsi_snow', 'snow', 'ice',
    'new_zealand', 'uniform',
]

# Fill value for missing data. Provisional: still to be checked against the bad-data value
# used by the C code.
_bad_data = -999.0

# Degree <-> radian conversion factors.
_dtr = np.pi/180.
_rtd = 180./np.pi

# Provisional numeric codes for the different scenes (a better representation is still to be
# found). NOTE(review): these scalars share names with flag keys, so a bare `land`/`water`/`day`
# inside a function that has no local of that name resolves to these numbers, not to
# per-pixel flag arrays.
land = 1
# (no code assigned to coast yet; a value of .2 was considered)
sh_lake = .3
sh_ocean = .4
water = 5
polar = 60
sunglint = 70
day = 100
night = 200

# #################################################################### #
# TEST CASE
# Hard-coded inputs used by test_scene(): one VIIRS granule (A2022173.1448) together with its
# thresholds file and ancillary files.
# data:
datapath = '/ships19/hercules/pveglio/mvcm_viirs_hires'
# NOTE(review): these glob(...)[0] lookups run when the module is imported, so importing it
# raises IndexError on any machine where no file matches. Consider resolving them inside
# test_scene() instead. Only the MOD02/MOD03 names are used by test_scene(); the IMG names are
# not referenced in this file's visible code.
fname_mod02 = glob(f'{datapath}/VNP02MOD.A2022173.1448.001.*.uwssec.nc')[0]
fname_mod03 = glob(f'{datapath}/VNP03MOD.A2022173.1448.001.*.uwssec.nc')[0]
fname_img02 = glob(f'{datapath}/VNP02IMG.A2022173.1448.001.*.uwssec.nc')[0]
fname_img03 = glob(f'{datapath}/VNP03IMG.A2022173.1448.001.*.uwssec.nc')[0]

# thresholds:
# YAML file; test_scene() takes the sunglint bound from thresholds['Sun_Glint']['bounds'][3].
threshold_file = '/home/pveglio/mvcm_leo/thresholds/new_thresholds.mvcm.snpp.v1.0.0.yaml'

# ancillary files:
# Bare file names; test_scene() hands them to rd.read_ancillary_data together with
# ANC_DIR = f'{datapath}/ancillary'.
geos_atm_1 = 'GEOS.fpit.asm.inst3_2d_asm_Nx.GEOS5124.20220622_1500.V01.nc4'
geos_atm_2 = 'GEOS.fpit.asm.inst3_2d_asm_Nx.GEOS5124.20220622_1800.V01.nc4'
geos_land = 'GEOS.fpit.asm.tavg1_2d_lnd_Nx.GEOS5124.20220622_1630.V01.nc4'
geos_ocean = 'GEOS.fpit.asm.tavg1_2d_ocn_Nx.GEOS5124.20220622_1630.V01.nc4'
geos_constants = 'GEOS.fp.asm.const_2d_asm_Nx.00000000_0000.V01.nc4'
ndvi_file = 'NDVI.FM.c004.v2.0.WS.00-04.177.hdf'
sst_file = 'oisst.20220622'
# NOTE(review): eco_file is not passed by test_scene(); confirm whether it is still needed.
eco_file = 'goge1_2_img.v1'

# #################################################################### #


def test_scene():
    """Run find_scene() on the hard-coded VIIRS test granule.

    Reads the MOD02/MOD03 pair and the ancillary fields named at module level, takes the
    sunglint bound from the thresholds YAML (``['Sun_Glint']['bounds'][3]``), and returns the
    dict produced by find_scene().
    """
    anc_dir = f'{datapath}/ancillary'
    ancillary_file_names = dict(GEOS_atm_1=geos_atm_1,
                                GEOS_atm_2=geos_atm_2,
                                GEOS_land=geos_land,
                                GEOS_ocean=geos_ocean,
                                GEOS_constants=geos_constants,
                                NDVI=ndvi_file,
                                SST=sst_file,
                                ANC_DIR=anc_dir)

    granule = rd.read_data('viirs', f'{fname_mod02}', f'{fname_mod03}')
    granule = rd.read_ancillary_data(ancillary_file_names, granule)

    with open(threshold_file) as yaml_file:
        thresholds = yml.safe_load(yaml_file.read())

    return find_scene(granule, thresholds['Sun_Glint']['bounds'][3])


def find_scene(data, sunglint_angle):
    """Compute per-pixel scene flags (illumination, surface type, snow/ice) for a granule.

    Parameters
    ----------
    data : dataset-like
        Granule plus ancillary fields. Variables are read as ``data[name].values`` and the
        output shape is taken from ``data.latitude.shape`` (presumably an xarray.Dataset of
        2-D fields -- TODO confirm). Variables read: 'eco', 'geos_snowfr', 'geos_icefr',
        'geos_landicefr', 'land_water_mask', 'latitude', 'longitude', 'solar_zenith',
        'sensor_zenith', 'relative_azimuth', 'M05', 'M07', 'height', 'ndvi', 'geos_sfct'.
    sunglint_angle : float
        Glint-angle bound in degrees. Daytime pixels whose glint angle is at or below it are
        flagged 'sunglint'. test_scene() takes it from thresholds['Sun_Glint']['bounds'][3]
        (snglnt_bounds[3] in the C code).

    Returns
    -------
    dict
        One array of shape ``(dim1, dim2)`` per name in ``_flags`` (1 = flag set, 0 = not
        set). 'uniform' holds whatever anc.py_check_reg_uniformity returns under
        'loc_uniform'. Also contains 'lat' and 'lon', kept for debugging.
    """
    eco = np.array(data['eco'].values, dtype=np.uint8)
    snowfr = data['geos_snowfr'].values
    icefr = data['geos_icefr'].values
    landicefr = data['geos_landicefr'].values
    lsf = np.array(data['land_water_mask'].values, dtype=np.uint8)
    lat = data['latitude'].values
    lon = data['longitude'].values
    sza = data['solar_zenith'].values
    vza = data['sensor_zenith'].values
    raz = data['relative_azimuth'].values
    b065 = data['M05'].values
    b086 = data['M07'].values
    elev = data['height'].values
    ndvibk = data['ndvi'].values
    sfctmp = data['geos_sfct'].values

    dim1 = data.latitude.shape[0]
    dim2 = data.latitude.shape[1]

    # Day/night split at a solar zenith angle of 85 degrees.
    day = np.zeros((dim1, dim2))
    day[sza <= 85] = 1

    # Reflection (glint) angle in degrees, from the sensor/solar zenith and relative azimuth.
    cos_refang = (np.sin(vza*_dtr) * np.sin(sza*_dtr) * np.cos(raz*_dtr) +
                  np.cos(vza*_dtr) * np.cos(sza*_dtr))
    # Clip so floating-point round-off just outside [-1, 1] does not make arccos return NaN.
    refang = np.arccos(np.clip(cos_refang, -1.0, 1.0)) * _rtd

    scene_flag = {flg: np.zeros((dim1, dim2)) for flg in _flags}

    scene_flag['day'][sza <= 85] = 1
    scene_flag['visusd'][sza <= 85] = 1
    scene_flag['night'][sza > 85] = 1

    scene_flag['polar'][np.abs(lat) > 60] = 1
    scene_flag['sunglint'][(scene_flag['day'] == 1) & (refang <= sunglint_angle)] = 1

    # Force consistency between lsf and ecosystem type for water.
    eco[(lsf == 0) | ((lsf >= 5) & (lsf < 7))] = 14

    # Start by defining everything as land.
    scene_flag['land'] = np.ones((dim1, dim2))
    scene_flag['water'] = np.zeros((dim1, dim2))

    land_lsf = (lsf == 1) | (lsf == 4)
    scene_flag['land'][land_lsf] = 1

    # lsf 1/4 pixels with a water ecosystem (eco == 14) become "coast", but only inside these
    # latitude/longitude boxes. This is a fix-up for missing ecosystem data in eastern
    # Greenland and north-eastern Siberia: without it those regions become completely "coast".
    # The lsf alternatives are grouped in land_lsf so that eco and the box limits apply to
    # both lsf == 1 and lsf == 4 (`&` binds tighter than `|`).
    coast_boxes = ((lat < 64.0) |
                   ((lat >= 67.5) & (lon < -40.0) & (lon > -168.6)) |
                   ((lat >= 67.5) & (lon > -12.5)) |
                   ((lat >= 64.0) & (lat < 67.5) & (lon < -40.0) & (lon > -168.5)) |
                   ((lat >= 64.0) & (lat < 67.5) & (lon > -30.0)))
    scene_flag['coast'][land_lsf & (eco == 14) & coast_boxes] = 1

    # lsf 2: coast, processed as land.
    scene_flag['coast'][lsf == 2] = 1
    scene_flag['land'][lsf == 2] = 1

    # lsf 3: shallow lake, processed as land.
    scene_flag['land'][lsf == 3] = 1
    scene_flag['sh_lake'][lsf == 3] = 1

    # lsf 0 and 5-7: water.
    water_lsf = (lsf == 0) | ((lsf >= 5) & (lsf <= 7))
    scene_flag['water'][water_lsf] = 1
    scene_flag['land'][water_lsf] = 0

    scene_flag['sh_ocean'][lsf == 0] = 1

    # Need shallow lakes to be processed as "coast" for day, but not night.
    scene_flag['coast'][(lsf == 3) & (day == 1)] = 1

    # If the land/sea flag is missing (255), decide land vs water from the M07/M05 ratio
    # (b086/b065): > 0.9 -> land, otherwise water.
    has_vis = (lsf == 255) & (b065 != _bad_data) & (b086 != _bad_data)
    with np.errstate(divide='ignore', invalid='ignore'):
        vis_ratio = b086 / b065
    scene_flag['land'][has_vis & (vis_ratio > 0.9)] = 1
    vis_water = has_vis & (vis_ratio <= 0.9)
    scene_flag['land'][vis_water] = 0
    scene_flag['water'][vis_water] = 1

    # The land/water flags are final from here on.
    is_land = scene_flag['land'] == 1

    # Check surface elevation. First define "Greenland" as land inside three lat/lon boxes.
    greenland = is_land & (
        ((lat >= 60.0) & (lat < 67.0) & (lon >= -60.0) & (lon < -30.0)) |
        ((lat >= 67.0) & (lat < 75.0) & (lon >= -60.0) & (lon < -10.0)) |
        ((lat >= 75.0) & (lon >= -70.0) & (lon < -10.0)))
    scene_flag['greenland'][greenland] = 1

    scene_flag['high_elevation'][elev > 2000] = 1
    scene_flag['high_elevation'][(elev > 200) & (scene_flag['greenland'] == 1) & is_land] = 1
    scene_flag['high_elevation'][(lat >= 75.0) & (lat <= 79.0) &
                                 (lon >= -73.0) & (lon <= -50.0) & is_land] = 1

    scene_flag['antarctica'][lat < -60.0] = 1

    # TODO: add the 11um elevation correction (placeholder formula: elev/1000.0 * 5.0).
    # TODO: surface temperature should come from GDAS/GEOS5 over land and from the Reynolds
    #       SST otherwise; for now sfctmp is 'geos_sfct' everywhere.

    # Use background NDVI to define "desert"; land south of 69S is also treated as desert.
    scene_flag['desert'][is_land & (ndvibk < 0.3)] = 1
    scene_flag['desert'][is_land & (lat < -69.0)] = 1

    # vrused is 1 except for the ecosystem types listed here.
    scene_flag['vrused'] = np.ones((dim1, dim2))
    scene_flag['vrused'][np.isin(eco, (2, 8, 11, 40, 41, 46, 50, 51, 52, 59, 71))] = 0

    # Snow/ice maps from the GEOS fractions: a fraction in (0.10, 1.0] sets the flag.
    scene_flag['map_snow'][(snowfr > 0.10) & (snowfr <= 1.0)] = 1
    scene_flag['map_snow'][(landicefr > 0.10) & (landicefr <= 1.0)] = 1
    scene_flag['map_ice'][(icefr > 0.10) & (icefr <= 1.0)] = 1

    # TODO: run a quick version of D. Hall's snow detection algorithm on day pixels to fill
    # 'ndsi_snow'. Until then it stays all zeros, so every rule below that requires
    # ndsi_snow cannot fire.
    is_day = day == 1
    ndsi_snow = scene_flag['ndsi_snow'] == 1
    map_snow = scene_flag['map_snow'] == 1

    # Ice over daytime water. Compare against the per-pixel flag array: a bare `water` here
    # would resolve to the module-level scalar (5), making every test False.
    day_water = is_day & (scene_flag['water'] == 1)
    ice_mask = day_water & ndsi_snow & (
        ((lat >= -60.0) & (lat <= 25.0) & map_snow) |
        (lat < -60.0) |
        ((lsf == 3) | (lsf == 5)) |
        (scene_flag['map_ice'] == 1))
    scene_flag['ice'][ice_mask] = 1

    # Define New Zealand region (daytime land, 48S-34S, 165E-180E), which receives snow but
    # the snow map does not show it.
    day_land = is_day & (scene_flag['land'] == 1)
    scene_flag['new_zealand'][day_land & (lat >= -48.0) & (lat <= -34.0) &
                              (lon >= 165.0) & (lon <= 180.0)] = 1

    # Snow over daytime land. New Zealand stands in for the snow map but still needs ndsi_snow.
    # NOTE(review): the ndsi_snow-only rule below already covers the first rule; confirm
    # against the C code whether it should be limited to latitudes outside [-60, 25].
    snow_map_or_nz = map_snow | (scene_flag['new_zealand'] == 1)
    scene_flag['snow'][day_land & (lat >= -60.0) & (lat <= 25.0) & snow_map_or_nz & ndsi_snow] = 1
    scene_flag['snow'][day_land & (lat < -60.0)] = 1
    scene_flag['snow'][day_land & ndsi_snow] = 1

    # At night, clear snow/ice where the snow map is set but sfctmp > 280 and elev < 500.
    warm_low_night = (day == 0) & map_snow & (sfctmp > 280.0) & (elev < 500.0)
    scene_flag['snow'][warm_low_night] = 0
    scene_flag['ice'][warm_low_night] = 0

    # Night pixels north of 86N are flagged as ice.
    scene_flag['ice'][(day == 0) & (lat > 86.0)] = 1

    # Check regional uniformity; check for data border pixels.
    # NOTE(review): eco is passed as both of the first two arguments -- confirm against
    # anc.py_check_reg_uniformity's signature.
    scene_flag['uniform'] = anc.py_check_reg_uniformity(eco, eco, snowfr, icefr, lsf)['loc_uniform']

    # TEMP VALUES FOR DEBUGGING
    scene_flag['lat'] = lat
    scene_flag['lon'] = lon

    return scene_flag


def scene_id(scene_flag):
    """Assign pixels to scene types from the per-pixel flags.

    Parameters
    ----------
    scene_flag : dict
        Flag arrays as produced by find_scene() (1 = set, 0 = not set). Reads 'day', 'night',
        'land', 'water', 'coast', 'desert', 'snow', 'ice', 'polar' and 'antarctica'.

    Returns
    -------
    dict
        One float array per name in ``_scene_list``, with 1 where the pixel matches that
        scene and 0 elsewhere. 'Coast_Day' is not assigned here and stays all zeros.
    """
    dim1, dim2 = scene_flag['day'].shape[0], scene_flag['day'].shape[1]
    scene = {scn: np.zeros((dim1, dim2)) for scn in _scene_list}

    def matches(**required):
        """Boolean mask of pixels whose flags equal every value given in ``required``."""
        mask = np.ones((dim1, dim2), dtype=bool)
        for flag_name, value in required.items():
            mask &= scene_flag[flag_name] == value
        return mask

    snow_or_ice = (scene_flag['snow'] == 1) | (scene_flag['ice'] == 1)

    # Non-polar scenes.
    scene['Ocean_Day'][matches(water=1, day=1, ice=0, snow=0, polar=0, antarctica=0,
                               coast=0, desert=0)] = 1
    scene['Ocean_Night'][matches(water=1, night=1, ice=0, snow=0, polar=0, antarctica=0,
                                 coast=0, desert=0)] = 1
    scene['Land_Day'][matches(land=1, day=1, ice=0, snow=0, polar=0, antarctica=0,
                              coast=0, desert=0)] = 1
    # Land_Night has no desert restriction: there is no night desert scene.
    scene['Land_Night'][matches(land=1, night=1, ice=0, snow=0, polar=0, antarctica=0,
                                coast=0)] = 1

    # Snow_Day / Snow_Night are the NON-polar snow scenes. Polar snow is handled by
    # Polar_Day_Snow / Polar_Night_Snow, and Antarctica by Antarctic_Day. The previous
    # "polar == 1 & antarctica == 1" duplicated those and left non-polar snow unassigned.
    scene['Snow_Day'][matches(day=1, polar=0, antarctica=0) & snow_or_ice] = 1
    scene['Snow_Night'][matches(night=1, polar=0, antarctica=0) & snow_or_ice] = 1

    scene['Land_Day_Coast'][matches(land=1, day=1, coast=1, desert=0, polar=0,
                                    antarctica=0)] = 1
    scene['Desert_Day'][matches(land=1, day=1, desert=1, coast=0, polar=0,
                                antarctica=0)] = 1
    scene['Land_Day_Desert_Coast'][matches(land=1, day=1, desert=1, coast=1, polar=0,
                                           antarctica=0)] = 1

    # Polar scenes.
    scene['Antarctic_Day'][matches(polar=1, day=1, antarctica=1, land=1)] = 1
    scene['Polar_Day_Snow'][matches(polar=1, day=1, antarctica=0) & snow_or_ice] = 1
    scene['Polar_Day_Desert'][matches(polar=1, day=1, land=1, desert=1, coast=0,
                                      antarctica=0, ice=0, snow=0)] = 1
    scene['Polar_Day_Desert_Coast'][matches(polar=1, day=1, land=1, desert=1, coast=1,
                                            antarctica=0, ice=0, snow=0)] = 1
    scene['Polar_Day_Coast'][matches(polar=1, day=1, land=1, coast=1, desert=0,
                                     antarctica=0, ice=0, snow=0)] = 1
    scene['Polar_Day_Land'][matches(polar=1, day=1, land=1, coast=0, desert=0,
                                    antarctica=0, ice=0, snow=0)] = 1
    scene['Polar_Day_Ocean'][matches(polar=1, day=1, water=1, coast=0, desert=0,
                                     antarctica=0, ice=0, snow=0)] = 1
    scene['Polar_Night_Snow'][matches(polar=1, night=1) & snow_or_ice] = 1
    scene['Polar_Night_Land'][matches(polar=1, night=1, land=1)] = 1
    scene['Polar_Night_Ocean'][matches(polar=1, night=1, water=1)] = 1

    # Any pixel flagged desert, regardless of illumination or region.
    scene['Desert'][matches(desert=1)] = 1

    return scene