
import numpy as np
# from cartopy import *
from metpy import *

import h5py

from aeolus.aeolus_amv import get_aeolus_time_dict
from aeolus.datasource import CLAVRx
from util.lon_lat_grid import earth_to_indexs
from util.geos_nav import get_navigation
from util.util import bin_data_by, get_bin_ranges

# Lon/lat bounds for H08 (presumably Himawari-8). Longitudes use the 0..360
# convention.
H08_lon_range = [64, 216]
H08_lat_range = [-70, 70]

# Lon/lat bounds for G16 (GOES-16). Longitudes are negative (western hemisphere).
G16_lon_range = [-130.0, -30.0]
G16_lat_range = [-64.0, 64.0]

# Height bins used to split the CLAVRx-minus-Aeolus statistics into
# low / mid / high groups.
bin_ranges = [
    [300.0, 2000.0],
    [2000.0, 8000.0],
    [8000.0, 20000.0],
]

# Latitude bins between -70 and 70 (presumably 10-degree wide; see get_bin_ranges).
lat_ranges = get_bin_ranges(-70, 70, 10)

# Grid/window sizes; 5424 x 5424 presumably matches the full-disk grid.
# None of these three are referenced by the functions in this view.
half_width = 30
num_elems = 5424
num_lines = 5424


def compare_aeolus_max_height(aeolus_dict, files_path, grid_value_name='cld_height_acha', one_cld_layer_only=False, highest_layer=False):
    """Collocate Aeolus cloud-layer profiles with a CLAVRx grid and compare heights.

    Each profile is assigned to the CLAVRx file reported by
    ``CLAVRx.get_file_containing_time`` for its key, then mapped to a grid pixel
    with the GOES16 ('FD' domain) navigation. It is compared with the grid value
    at that pixel. A profile is a "hit" when the grid value falls inside one of
    its cloud layers, padded by 100 (same units as the layer heights) above and
    below. Summary statistics are printed.

    Args:
        aeolus_dict: mapping from a time key to a 2-D array of cloud layers, one
            row per layer. Columns 0/1 are lat/lon (read from row 0), and
            columns 2/3 are the layer top/bottom heights.
        files_path: forwarded to ``CLAVRx`` to locate the grid files.
        grid_value_name: name of the CLAVRx dataset compared against.
        one_cld_layer_only: if True, skip profiles with more than one layer.
        highest_layer: if True, only test the layer with the highest top.

    Returns:
        ``(layer_bot, layer_top, grd_hgt, lat, lon, times)`` as numpy arrays with
        one entry per matched profile. ``grd_hgt`` has shape (N, 1).
        ``layer_bot``/``layer_top`` are 2-D when every entry holds the same
        number of layers, otherwise 1-D object arrays of lists.
    """
    nav = get_navigation(satellite='GOES16', domain='FD')
    clvrx_files = CLAVRx(files_path)
    num_files = len(clvrx_files.flist)

    num_profs = len(aeolus_dict)

    # Group the Aeolus profiles by the index of the CLAVRx file covering them.
    a_lons = [[] for _ in range(num_files)]
    a_lats = [[] for _ in range(num_files)]
    a_profs = [[] for _ in range(num_files)]
    a_times = [[] for _ in range(num_files)]

    for key, layers in aeolus_dict.items():
        _, _, f_idx = clvrx_files.get_file_containing_time(key)
        if f_idx is None or layers is None:
            continue
        a_lats[f_idx].append(layers[0, 0])
        a_lons[f_idx].append(layers[0, 1])
        a_times[f_idx].append(key)
        a_profs[f_idx].append(layers)

    for f_idx in range(num_files):
        a_lons[f_idx] = np.array(a_lons[f_idx])
        a_lats[f_idx] = np.array(a_lats[f_idx])
        a_times[f_idx] = np.array(a_times[f_idx])

    grd_hgt = []
    grd_mean = []
    grd_std = []
    hits = []
    num_layers = []
    lon = []
    lat = []
    layer_thickness = []
    hgt_diff = []
    hgt_top_max_diff_s = []
    layer_bot = []
    layer_top = []
    times = []

    vld_grd = 0
    num_hits = 0
    total = 0

    for f_idx, file in enumerate(clvrx_files.flist):
        try:
            h5f = h5py.File(file, 'r')
        except Exception:
            # Nothing was opened, so there is no handle to close here.
            print('Problem with file: ', file)
            continue

        # get_grid_values reads the full dataset into memory, so the file can be
        # closed as soon as the read finishes (on success or failure).
        try:
            grd_vals = get_grid_values(h5f, grid_value_name)
        except Exception:
            print('can\'t read grid')
            continue
        finally:
            h5f.close()

        nn_idxs = nav.earth_to_indexs(a_lons[f_idx], a_lats[f_idx], grd_vals.shape[0])
        if len(nn_idxs) == 0:
            continue

        # Drop profiles the navigation reports as -1 (not on the grid).
        on_earth = nn_idxs != -1
        nn_idxs = nn_idxs[on_earth]
        grd_crds = np.unravel_index(nn_idxs, grd_vals.shape)
        f_lons = a_lons[f_idx][on_earth]
        f_lats = a_lats[f_idx][on_earth]
        f_times = a_times[f_idx][on_earth]
        f_profs = [prof for prof, keep in zip(a_profs[f_idx], on_earth) if keep]

        total += len(nn_idxs)
        print(clvrx_files.flist[f_idx], len(nn_idxs))

        # Loop over the Aeolus profiles that land on this file's grid.
        for k in range(len(nn_idxs)):
            row = grd_crds[0][k]
            col = grd_crds[1][k]

            # 3x3 neighborhood statistics. Clamp the start indices at 0: a start
            # of -1 would wrap to the last row/column and give an empty window.
            grd_vals_3x3 = grd_vals[max(row - 1, 0):row + 2, max(col - 1, 0):col + 2].flatten()
            keep = np.invert(np.isnan(grd_vals_3x3))
            grd_3x3_mean = np.mean(grd_vals_3x3[keep])
            grd_3x3_std = np.std(grd_vals_3x3[keep])

            grd_vals_1x1 = grd_vals[row, col].flatten()
            if np.isnan(grd_vals_1x1[0]):
                continue

            prof = f_profs[k]
            max_idx = np.argmax(prof[:, 2])
            hht_max = prof[max_idx, 2]
            nlayers = prof.shape[0]

            if one_cld_layer_only and nlayers > 1:
                continue

            lyr_idxs = [max_idx] if highest_layer else range(nlayers)

            vld_grd += 1
            hit = False
            lyr_width = None
            hgt_top_diff = None
            hgt_top_max_diff = None
            tops = []
            bots = []

            for lyr in lyr_idxs:
                hht = prof[lyr, 2]
                hhb = prof[lyr, 3]
                tops.append(hht)
                bots.append(hhb)
                lyr_width = hht - hhb

                # The padding only widens the hit window. The top difference is
                # taken against the real (unpadded) top, matching hgt_top_max_diff.
                lo_bound = hhb - 100
                hi_bound = hht + 100

                cnt = 0
                for ghgt in grd_vals_1x1:
                    if np.isnan(ghgt):
                        continue
                    hgt_top_diff = ghgt - hht
                    hgt_top_max_diff = ghgt - hht_max
                    if lo_bound <= ghgt <= hi_bound:
                        cnt += 1

                if cnt > 0:
                    hit = True
                    num_hits += 1
                    break

            hits.append(hit)
            grd_hgt.append(grd_vals_1x1)
            grd_mean.append(grd_3x3_mean)
            grd_std.append(grd_3x3_std)

            lon.append(f_lons[k])
            lat.append(f_lats[k])
            times.append(f_times[k])

            layer_thickness.append(lyr_width)
            hgt_diff.append(hgt_top_diff)
            hgt_top_max_diff_s.append(hgt_top_max_diff)

            num_layers.append(nlayers)
            layer_bot.append(bots)
            layer_top.append(tops)

        print(f_idx, num_files)

    print('num aeolus profiles: ', num_profs)
    print('Overall: ', num_hits, vld_grd, total)
    print('*')
    # dtype=bool keeps boolean indexing valid even when no profile matched.
    hits = np.array(hits, dtype=bool)
    # Each entry is a 1-element array. reshape keeps (N, 1) even when N == 0,
    # so grd_hgt[:, 0] below stays valid.
    grd_hgt = np.array(grd_hgt).reshape(-1, 1)
    grd_mean = np.array(grd_mean)
    grd_std = np.array(grd_std)
    num_layers = np.array(num_layers)
    lon = np.array(lon)
    lat = np.array(lat)
    hgt_diff = np.array(hgt_diff)
    hgt_top_max_diff_s = np.array(hgt_top_max_diff_s)
    layer_thickness = np.array(layer_thickness)
    # Profiles may carry different numbers of layers, so these can be ragged.
    layer_bot = _as_array_allow_ragged(layer_bot)
    layer_top = _as_array_allow_ragged(layer_top)
    times = np.array(times)

    print('AEOLUS num layers: ')
    print(np.average(num_layers), np.std(num_layers))
    print('--------------------')

    print('Grid cntr height: ')
    print(np.average(grd_hgt[hits]))
    print(np.average(grd_hgt[np.invert(hits)]))
    print('--------------------')

    print('3x3 mean: ')
    print(np.average(grd_mean[hits]))
    print(np.average(grd_mean[np.invert(hits)]))
    print('--------------------')

    print('3x3 std: ')
    print(np.average(grd_std[hits]))
    print(np.average(grd_std[np.invert(hits)]))
    print('--------------------')

    print('layer thickness: ')
    print(np.average(layer_thickness[hits]))
    print(np.average(layer_thickness[np.invert(hits)]))
    print('---------------------')

    print('CLAVR - AEOLUS TOP: ')
    print(np.average(hgt_diff[hits]), np.std(hgt_diff[hits]))
    print(np.average(hgt_top_max_diff_s[np.invert(hits)]))
    print('---------------------')

    print('CLVR grid cntr height histo: ')
    print(np.histogram(grd_hgt, bins=8))
    print('----------------------')

    bdat = bin_data_by(hgt_top_max_diff_s, grd_hgt[:, 0], bin_ranges=bin_ranges)
    print('CLAVR - (HIGHEST AEOLUS CLD LYR TOP: ')
    print('All: ')
    print(np.average(hgt_top_max_diff_s), np.std(hgt_top_max_diff_s))
    print('Low: ')
    print(np.average(bdat[0]), np.std(bdat[0]))
    print('Mid: ')
    print(np.average(bdat[1]), np.std(bdat[1]))
    print('High: ')
    print(np.average(bdat[2]), np.std(bdat[2]))
    print('------------------------')

    return layer_bot, layer_top, grd_hgt, lat, lon, times


def _as_array_allow_ragged(values):
    """Convert a list of lists to a numpy array, tolerating ragged rows.

    Equal-length rows give a regular array, as ``np.array`` does. Ragged rows
    (which ``np.array`` rejects with ValueError on NumPy >= 1.24) give a 1-D
    object array holding each row unchanged.
    """
    try:
        return np.array(values)
    except ValueError:
        out = np.empty(len(values), dtype=object)
        for idx, val in enumerate(values):
            out[idx] = val
        return out


def get_grid_values(h5f, grid_name, scale_factor_name='scale_factor', add_offset_name='add_offset'):
    """Read a 2-D dataset, mask fill values with NaN, and unpack it.

    Values equal to -999, -127 or -32768 are replaced with NaN. The result is
    then unpacked as ``vals * scale_factor + add_offset`` using the named dataset
    attributes.

    Args:
        h5f: open HDF5 file (or any mapping of name -> dataset exposing
            ``attrs`` and 2-D slicing).
        grid_name: name of the dataset to read.
        scale_factor_name: attribute holding the scale factor. Pass ``None`` to
            skip scaling. Scaling is also skipped if the attribute is absent.
        add_offset_name: attribute holding the additive offset. Pass ``None`` to
            skip it. The offset is also skipped if the attribute is absent.

    Returns:
        Float numpy array with NaN where fill values were.
    """
    hfds = h5f[grid_name]
    attrs = hfds.attrs

    grd_vals = hfds[:, :]
    grd_vals = np.where(np.isin(grd_vals, (-999, -127, -32768)), np.nan, grd_vals)

    if attrs is None:
        return grd_vals

    if scale_factor_name is not None:
        scale_factor = attrs.get(scale_factor_name)
        # Absent attribute -> leave the values unscaled instead of crashing.
        if scale_factor is not None:
            # np.ravel accepts both 1-element arrays and plain scalars.
            grd_vals = grd_vals * np.ravel(scale_factor)[0]

    if add_offset_name is not None:
        add_offset = attrs.get(add_offset_name)
        if add_offset is not None:
            grd_vals = grd_vals + np.ravel(add_offset)[0]

    return grd_vals


def analyze_aelolus_cld_layers(cld_layer_dct):
    """Summarize Aeolus cloud-layer profiles and bin the highest top by latitude.

    Prints the profile counts, then groups each profile's highest layer top into
    the module-level ``lat_ranges`` bins with ``bin_data_by``.

    Args:
        cld_layer_dct: mapping from a key to a 2-D array of cloud layers, one row
            per layer. Columns 0/1 are lat/lon (read from row 0), and columns 2/3
            are the layer top/bottom heights.

    Returns:
        ``(cld_thickness, max_top, avg_maxz, std_maxz, lat_cntrs)``:
        ``cld_thickness`` holds top - bottom for every layer of every profile;
        ``max_top`` holds the highest layer top per profile; ``avg_maxz``,
        ``std_maxz`` and ``lat_cntrs`` hold one value per ``lat_ranges`` bin.
    """
    num_single_layer = 0
    num_multi_layer = 0
    num_cld_layers = 0

    # Per-layer values (one entry per layer of every profile).
    lats = []
    lons = []
    cld_tops = []
    cld_thickness = []
    cld_cntr = []
    # Per-profile values.
    lats_sl = []
    lons_sl = []
    num_layers = []
    max_top = []

    for layers in cld_layer_dct.values():
        lat = layers[0, 0]
        lon = layers[0, 1]

        n_lyrs = layers.shape[0]
        if n_lyrs > 1:
            num_multi_layer += 1
        else:
            num_single_layer += 1
        # NOTE(review): incremented once per profile, so this always equals
        # num_single_layer + num_multi_layer; confirm whether a per-layer total
        # was intended.
        num_cld_layers += 1

        num_layers.append(n_lyrs)

        max_top.append(layers[np.argmax(layers[:, 2]), 2])
        lats_sl.append(lat)
        lons_sl.append(lon)

        for hht, hhb in zip(layers[:, 2], layers[:, 3]):
            cld_tops.append(hht)
            cld_thickness.append(hht - hhb)
            # Layer center is the midpoint of top and bottom (the previous
            # (top - bottom) / 2 was half the thickness, not a height).
            cld_cntr.append((hht + hhb) / 2)
            lats.append(lat)
            lons.append(lon)

    # lats, lons, cld_tops and cld_cntr are not part of the return value.
    lats = np.array(lats)
    lons = np.array(lons)
    cld_tops = np.array(cld_tops)
    cld_thickness = np.array(cld_thickness)
    cld_cntr = np.array(cld_cntr)
    max_top = np.array(max_top)
    lats_sl = np.array(lats_sl)
    lons_sl = np.array(lons_sl)

    print(num_cld_layers, num_single_layer, num_multi_layer)

    mt_by_lat = bin_data_by(max_top, lats_sl, lat_ranges)

    lat_cntrs = [(rng[0] + rng[1]) / 2 for rng in lat_ranges]
    avg_maxz = [np.average(mt_by_lat[i]) for i in range(len(lat_ranges))]
    std_maxz = [np.std(mt_by_lat[i]) for i in range(len(lat_ranges))]

    return cld_thickness, max_top, avg_maxz, std_maxz, lat_cntrs

