import numpy as np
import h5py
import xarray as xr

from util.util import get_grid_values_all


def _check_geolocation(label, lat, lon):
    """Print a warning for NaN or out-of-range latitude/longitude values.

    Diagnostic only: problems are reported with ``print`` and processing
    continues.

    Args:
        label: Short tag naming the granule (e.g. 'cspp', 'oper') in messages.
        lat: Latitude array, degrees.
        lon: Longitude array, degrees.
    """
    if np.any(np.isnan(lat)):
        print(f'Invalid lat ({label}): NaN values')
    if np.any(np.isnan(lon)):
        print(f'Invalid lon ({label}): NaN values')
    # nanmin/nanmax so NaNs (already reported above) don't hide a range error.
    if np.nanmin(lat) < -90 or np.nanmax(lat) > 90:
        print(f'Invalid lat ({label}): outside [-90, 90]')
    if np.nanmin(lon) < -180 or np.nanmax(lon) > 180:
        print(f'Invalid lon ({label}): outside [-180, 180]')


def _find_overlap_start(cntr_lat_oper, cntr_lat_cspp, thresh):
    """Locate the first along-track row shared by the two granules.

    Scans the OPER center-column latitudes in order; for the first one that
    lies within *thresh* degrees of a CSPP center-column latitude, returns the
    pair of row indexes.

    Returns:
        ``(start_oper, start_cspp)``, or ``(-1, -1)`` if no OPER row matches,
        or if the first matching OPER row matches more than one CSPP row
        (ambiguous alignment).
    """
    for k, lat_k in enumerate(cntr_lat_oper):
        matches = np.where(np.abs(lat_k - cntr_lat_cspp) < thresh)[0]
        if len(matches) == 1:
            return k, int(matches[0])
        if len(matches) > 1:
            print('weirdness: OPER row', k, 'matches', len(matches), 'CSPP rows')
            break
    return -1, -1


def acspo_validate(oper_file, cspp_file, rel_tol=0.001, deg_tol=0.0002, outfile_nc=None):
    """Compare co-located clear-sky SST from an operational and a CSPP ACSPO granule.

    The granules are aligned along-track by matching center-column latitudes
    (within *deg_tol* degrees). The overlapping rows are then compared pixel by
    pixel wherever both SSTs are non-NaN and both pixels are flagged clear or
    probably clear.

    Args:
        oper_file: Path to the operational ACSPO file.
        cspp_file: Path to the CSPP ACSPO file.
        rel_tol: Relative tolerance passed to ``np.isclose`` when counting
            agreeing SST pairs.
        deg_tol: Maximum center-column latitude difference (degrees) for two
            rows to be treated as the same scan line. Average row spacing is
            roughly 0.00652 deg, so this should stay well below that.
        outfile_nc: If given, the retained matchups (both SSTs and both
            lat/lon pairs) are written to this netCDF path.

    Returns:
        ``(sst_cspp_2d, sst_oper_2d)``: the SST arrays restricted to the
        overlapping rows, before any clear-sky/NaN filtering.

    Raises:
        ValueError: If the granules have different cross-track widths, or no
            unambiguous overlapping row is found.
    """
    # Do all file-backed reads and slicing while both files are open; the
    # context manager guarantees they are closed even if something raises.
    with h5py.File(oper_file, 'r') as h5f_oper, h5py.File(cspp_file, 'r') as h5f_cspp:
        lon_cspp = get_grid_values_all(h5f_cspp, 'lon')
        lat_cspp = get_grid_values_all(h5f_cspp, 'lat')
        _check_geolocation('cspp', lat_cspp, lon_cspp)
        print('cspp shape: ', lat_cspp.shape)
        sst_cspp = get_grid_values_all(h5f_cspp, 'sea_surface_temperature')[0, ]
        l2p_flags_cspp = get_grid_values_all(h5f_cspp, 'l2p_flags')[0, ]

        lon_oper = get_grid_values_all(h5f_oper, 'lon')
        lat_oper = get_grid_values_all(h5f_oper, 'lat')
        _check_geolocation('oper', lat_oper, lon_oper)
        print('oper shape: ', lat_oper.shape)
        sst_oper = get_grid_values_all(h5f_oper, 'sea_surface_temperature')[0, ]
        l2p_flags_oper = get_grid_values_all(h5f_oper, 'l2p_flags')[0, ]

        # TODO: investigate why the per-pixel times ('sst_dtime' + 'time') of the
        # two products never seem to match; rows are aligned on latitude instead.

        # Pixel-by-pixel comparison below requires identical cross-track widths.
        if lat_oper.shape[1] != lat_cspp.shape[1]:
            raise ValueError(
                f'cross-track widths differ: oper {lat_oper.shape} vs cspp {lat_cspp.shape}')
        c_idx = lat_cspp.shape[1] // 2

        start_oper, start_cspp = _find_overlap_start(
            lat_oper[:, c_idx], lat_cspp[:, c_idx], deg_tol)
        if start_oper < 0:
            print('no overlap found')
            raise ValueError(f'no overlap found between {oper_file} and {cspp_file}')

        # The overlap runs from the matched rows until either granule ends;
        # n_rows counts rows inclusively, so slice stops are exclusive.
        n_rows = min(lat_oper.shape[0] - start_oper, lat_cspp.shape[0] - start_cspp)
        rows_oper = slice(start_oper, start_oper + n_rows)
        rows_cspp = slice(start_cspp, start_cspp + n_rows)
        print('oper start, stop ', rows_oper.start, rows_oper.stop)
        print('cspp start, stop ', rows_cspp.start, rows_cspp.stop)

        lon_cspp_2d = lon_cspp[rows_cspp, :]
        print('overlap shape, size: ', lon_cspp_2d.shape, np.size(lon_cspp_2d))

        sst_cspp_2d = sst_cspp[rows_cspp, :]
        sst_oper_2d = sst_oper[rows_oper, :]
        # True where bit 15 of l2p_flags is unset: "clear or probably clear".
        # NOTE(review): assumes the ACSPO 2-bit clear-sky mask in bits 14-15 -- confirm.
        cspp_clear = ((l2p_flags_cspp[rows_cspp, :] & (1 << 15)) == 0).ravel()
        oper_clear = ((l2p_flags_oper[rows_oper, :] & (1 << 15)) == 0).ravel()
        lon_cspp = lon_cspp_2d.ravel()
        lat_cspp = lat_cspp[rows_cspp, :].ravel()
        lon_oper = lon_oper[rows_oper, :].ravel()
        lat_oper = lat_oper[rows_oper, :].ravel()

    sst_cspp = sst_cspp_2d.ravel()
    sst_oper = sst_oper_2d.ravel()

    keep = ~np.isnan(sst_cspp) & ~np.isnan(sst_oper) & cspp_clear & oper_clear

    valid_sst_cspp = sst_cspp[keep]
    valid_sst_oper = sst_oper[keep]

    xarray_data = xr.Dataset({
        'sst_cspp': xr.DataArray(valid_sst_cspp, name='sst_cspp'),
        'sst_oper': xr.DataArray(valid_sst_oper, name='sst_oper'),
        'cspp_lat': xr.DataArray(lat_cspp[keep], name='cspp_lat'),
        'cspp_lon': xr.DataArray(lon_cspp[keep], name='cspp_lon'),
        'oper_lat': xr.DataArray(lat_oper[keep], name='oper_lat'),
        'oper_lon': xr.DataArray(lon_oper[keep], name='oper_lon'),
    })
    if outfile_nc is not None:
        xarray_data.to_netcdf(outfile_nc)

    n_keep = int(np.count_nonzero(keep))
    if n_keep == 0:
        # Avoid 0/0 -> nan with a RuntimeWarning.
        print('fraction approx equal: no clear-sky pixels valid in both granules')
    else:
        n_close = np.count_nonzero(np.isclose(valid_sst_cspp, valid_sst_oper, rtol=rel_tol))
        print('fraction approx equal: ', n_close / n_keep)

    return sst_cspp_2d, sst_oper_2d


def _concat_first_row(paths, var_name):
    """Read *var_name* from each HDF5/netCDF file in *paths* and concatenate them.

    For a 2-D variable only row 0 is read. For a 1-D variable the whole array
    is read; ``acspo_validate`` in this module writes its 'sst_cspp' and
    'sst_oper' outputs as 1-D arrays (presumably the source of these files --
    confirm).
    """
    parts = []
    for path in paths:
        with h5py.File(path, 'r') as h5f:
            dset = h5f[var_name]
            parts.append(dset[0, :] if dset.ndim > 1 else dset[()])
    return np.concatenate(parts)


def analyze_plot(files=None, rel_tol=0.001, title_date='2024-11-07'):
    """Plot pooled CSPP-minus-OPER SST differences from several matchup files.

    Produces two figures:
      1. a log-scale histogram of (CSPP - OPER) SST differences;
      2. a scatter of those differences vs. CSPP SST, restricted to pairs that
         are NOT ``np.isclose`` within *rel_tol*.

    Args:
        files: Iterable of matchup file paths containing 'sst_cspp' and
            'sst_oper' variables. Defaults to the 2024-11-07 set used originally.
        rel_tol: Relative tolerance for the "not approximately equal" filter.
        title_date: Date string shown in the histogram title.
    """
    # Local import so importing this module doesn't require matplotlib.
    import matplotlib.pyplot as plt

    if files is None:
        files = (
            '/Users/tomrink/20241107055000.nc',
            '/Users/tomrink/20241107072000.nc',
            '/Users/tomrink/20241107073000.nc',
            '/Users/tomrink/20241107170000.nc',
            '/Users/tomrink/20241107171000.nc',
            '/Users/tomrink/20241107202000.nc',
            '/Users/tomrink/20241107203000.nc',
        )

    sst_cspp_all = _concat_first_row(files, 'sst_cspp')
    print(sst_cspp_all.shape)
    sst_oper_all = _concat_first_row(files, 'sst_oper')
    print(sst_oper_all.shape)

    diff = sst_cspp_all - sst_oper_all

    # np.histogram cannot auto-range data containing NaN/inf, so drop them.
    plt.figure()
    plt.hist(diff[np.isfinite(diff)], bins=40)
    plt.yscale('log')
    plt.xlabel('SST difference (K)')
    plt.title(f'ACSPO SST (CSPP - OPER), Clear Sky, {title_date}')

    not_close = ~np.isclose(sst_cspp_all, sst_oper_all, rtol=rel_tol)
    plt.figure()
    plt.scatter(sst_cspp_all[not_close], diff[not_close], s=1)
    plt.title(f'ACSPO SST (CSPP-OPER) > rel_tol={rel_tol} vs SST')
    plt.xlabel('SST (K)')
    plt.ylabel('SST difference (K)')
    plt.show()