#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Wed May 22 13:58:09 2019

@author: Wimmers
"""

import os

import cartopy.crs as ccrs
import h5py
import matplotlib
import matplotlib.dates as mdates
import matplotlib.pyplot as plt
import tensorflow as tf
import xarray as xr
from sklearn.metrics import roc_auc_score, roc_curve

from util.setup import home_dir
from util.util import get_grid_values_all, get_cartopy_crs, homedir

# Directory holding the TensorFlow event logs; plot_train_curve looks for its
# 'plot_train' / 'plot_valid' sub-directories underneath it.
logdir = home_dir+'/tf_logs'
# Directory that plot_train_curve writes its PNG into.
root_imgdir = home_dir
# Default file name for the training-curve image.
img_name = 'train_curve.png'

# Amount added to the step counter for every loss value read by plot_train_curve.
run_valid_stepping = 1

# Six RGB triplets kept as raw 0-255 integers (not rescaled here).
cld_phase_clr_list = [
    [128, 128, 128],
    [11, 36, 250],
    [45, 254, 254],
    [127, 15, 126],
    [254, 192, 203],
    [0, 0, 0],
]

# Four RGB triplets that are rescaled in place below to 0-1 floats.
cld_phase_clr_list_2 = [
    [128, 128, 128],
    [254, 192, 203],
    [11, 36, 250],
    [127, 15, 126],
]

# Divide every channel by 255, mutating each inner list in place.
for i in range(len(cld_phase_clr_list_2)):
    cld_phase_clr_list_2[i][:] = [chan / 255 for chan in cld_phase_clr_list_2[i]]

from matplotlib.colors import ListedColormap

# Discrete colormap named 'CldPhase' built from the rescaled four-colour list.
cld_phase_cmap = ListedColormap(cld_phase_clr_list_2, name='CldPhase')

# Two RGB triplets, rescaled in place below from 0-255 integers to 0-1 floats.
cld_mask_list = [
    [0, 220, 0],
    [220, 220, 220],
]

# Divide every channel by 255, mutating each inner list in place.
for i in range(len(cld_mask_list)):
    cld_mask_list[i][:] = [chan / 255 for chan in cld_mask_list[i]]

# Discrete two-colour colormap named 'CldMask' built from cld_mask_list.
cld_mask_cmap = ListedColormap(cld_mask_list, name='CldMask')

def plot_train_curve(rundir=None, title=None):
    """Plot training loss, validation loss and validation accuracy for a run.

    The first events file found in ``<run>/plot_train`` and in
    ``<run>/plot_valid`` is read with the TensorFlow summary iterator. The
    step axis is rescaled by ``last_epoch / num_train_steps`` and the figure
    is saved as a PNG under ``root_imgdir``.

    Args:
        rundir: Run directory, relative to ``logdir``. The image is named
            after its last path component. When ``None``, ``logdir`` itself
            is read and the module-level ``img_name`` is used.
        title: Optional figure title.

    Note:
        Validation values are plotted against the training step axis, so the
        two logs are assumed to contain the same number of entries.
    """
    # Use a separate local for the output name so the module-level default
    # ``img_name`` stays readable when no run directory is given.
    if rundir is not None:
        out_name = os.path.split(rundir)[1] + '.png'
        rundir = os.path.join(logdir, rundir)
    else:
        out_name = img_name
        rundir = logdir

    # --- training log ---------------------------------------------------
    last_epoch = 0
    num_train_steps = 0
    step = 0
    train_steps = []
    train_losses = []

    train_dir = os.path.join(rundir, 'plot_train')
    path_to_events_file = os.path.join(train_dir, os.listdir(train_dir)[0])

    for e in tf.compat.v1.train.summary_iterator(path_to_events_file):
        for v in e.summary.value:
            if v.tag == 'loss_trn':
                train_steps.append(step)
                train_losses.append(v.simple_value)
                step += run_valid_stepping
            elif v.tag == 'num_epochs':
                last_epoch = v.simple_value
            elif v.tag == 'num_train_steps':
                num_train_steps = v.simple_value

    # --- validation log -------------------------------------------------
    step = 0
    lowest_loss = 1000
    lowest_loss_step = 0
    valid_losses = []
    valid_acc = []

    valid_dir = os.path.join(rundir, 'plot_valid')
    path_to_events_file = os.path.join(valid_dir, os.listdir(valid_dir)[0])

    for e in tf.compat.v1.train.summary_iterator(path_to_events_file):
        for v in e.summary.value:
            if v.tag == 'loss_val':
                valid_losses.append(v.simple_value)
                if v.simple_value < lowest_loss:
                    lowest_loss = v.simple_value
                    lowest_loss_step = train_steps[step]
                step += run_valid_stepping
            elif v.tag == 'acc_val':
                valid_acc.append(v.simple_value)

    # Convert steps to epochs; leave the axis in steps when the log carries
    # no 'num_train_steps' value (avoids a ZeroDivisionError).
    ratio = last_epoch / num_train_steps if num_train_steps else 1.0
    train_steps = [s * ratio for s in train_steps]
    lowest_loss_step *= ratio

    # --- figure ---------------------------------------------------------
    fig, ax = plt.subplots(1)
    if title is not None:
        plt.title(title)
    ax2 = ax.twinx()
    ax2.set_ylabel('Mean Absolute Error')
    plt.grid(True, linestyle=':')
    plt.subplots_adjust(bottom=0.12, top=0.94)
    train_hdl, = plt.plot(train_steps, train_losses, '-', linewidth=1.5, color='b')
    valid_hdl, = plt.plot(train_steps, valid_losses, '-', linewidth=1.5, color=[1.0, 0.5, 0.0])
    acc, = plt.plot(train_steps, valid_acc, '-', linewidth=1.5, color=[0.5, 0.0, 0.5])
    optim_hdl, = plt.plot(lowest_loss_step, lowest_loss, 'o', color='g')
    plt.legend([train_hdl, valid_hdl, acc, optim_hdl], ['Training', 'Validation', 'MAE', 'Best loss'],
               loc='center right', fancybox=True, framealpha=0.5)

    fig.savefig(os.path.join(root_imgdir, out_name), dpi=100)
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)

def plot_roc_curve(correct_labels, prob_pos_class, fname='/home/rink/test_roc'):
    """Plot the ROC curve for a binary classifier and save it to ``fname``.

    Args:
        correct_labels: True binary labels, passed to scikit-learn's
            ``roc_auc_score`` and ``roc_curve``.
        prob_pos_class: Scores/probabilities of the positive class.
        fname: Output path handed to ``plt.savefig``. Defaults to the path
            that was previously hard-coded.

    Returns:
        The area under the ROC curve (also printed).
    """
    auc = roc_auc_score(correct_labels, prob_pos_class)
    print('AUC: ', auc)
    fpr, tpr, _ = roc_curve(correct_labels, prob_pos_class)

    # Drawn on the current pyplot figure, which is left open after saving.
    plt.plot(fpr, tpr)

    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('Receiver operating characteristic')
    plt.savefig(fname, dpi=320)
    return auc

def plot_image(image, fname, label, cmap='Greys'):
    """Render ``image`` with a colorbar and save it under ``/Users/tomrink``.

    Args:
        image: 2-D array-like handed to ``plt.imshow``; colours are scaled
            over the fixed range ``vmin=-1.5`` .. ``vmax=2.5``.
        fname: File name of the saved image, joined onto ``/Users/tomrink``.
        label: Text for the colorbar label; skipped when ``None``.
        cmap: Matplotlib colormap (name or object) used for the image.
    """
    fig, ax = plt.subplots()

    # Plot the image
    plt.imshow(image, interpolation="nearest", cmap=cmap, vmin=-1.5, vmax=2.5)

    # Add the colorbar with ticks at 0, 1 and 2
    cbar = plt.colorbar(ticks=range(3))
    if label is not None:
        cbar.set_label(label, fontsize=10)
    cbar.ax.set_position([0.90, 0.05, 0.03, 0.90])  # [left, bottom, width, height]

    # *Now* that the colorbar is out, we can set the plot's axis position
    ax.set_position([0.05, 0.05, 0.83, 0.90])  # [left, bottom, width, height]

    # Finish the figure settings
    fig.set_size_inches([10.0, 6.6], forward=True)
    fig.set_dpi(100)

    # Save to file before the figure is cleared below
    ImageDirAndName = os.path.join('/Users/tomrink', fname)
    fig.savefig(ImageDirAndName)

    # Wrap up
    fig.clf()  # Otherwise fig stays open every time
    plt.close(fig)
    # plt.close('all')

def plot_image2(image, cmap='Greys', title='BOEING Icing Reports 2018-2020'):
    """Draw a global image on a Plate Carree map and save it.

    The image is stretched over the full globe (``[-180, 180, -90, 90]``)
    with its first row at the bottom (``origin='lower'``), overlaid with
    coastlines, and written to ``/Users/tomrink/testimage``.

    Args:
        image: 2-D array-like handed to ``imshow``; colours are scaled over
            ``vmin=0.1`` .. ``vmax=1``.
        cmap: Matplotlib colormap (name or object).
        title: Left-aligned bold figure title.
    """
    fig = plt.figure(figsize=(12, 8))
    ax = plt.axes(projection=ccrs.PlateCarree())
    img_extent = [-180, 180, -90, 90]
    ax.imshow(image, origin='lower', extent=img_extent, transform=ccrs.PlateCarree(), cmap=cmap, vmin=0.1, vmax=1)
    ax.coastlines(resolution='50m', color='blue', linewidth=0.5)

    plt.title(title, loc='left', fontweight='bold', fontsize=16)

    # Save to file before the figure is cleared below
    ImageDirAndName = os.path.join('/Users/tomrink', 'testimage')
    fig.savefig(ImageDirAndName)

    # Wrap up
    fig.clf()  # Otherwise fig stays open every time
    plt.close(fig)

def plot_image3(image, bt, cmap='Greys'):
    """Draw ``image`` on a map using the GOES16 / CONUS projection.

    The projection and native image bounds come from ``get_cartopy_crs``.
    The view is limited to lon -105..-70, lat 15..50, and coastlines and the
    STATES feature are overlaid. The figure is left open; nothing is saved
    or returned.

    Args:
        image: 2-D array-like drawn with the 'Purples' colormap, scaled over
            ``vmin=0.52`` .. ``vmax=1.0``.
        bt: Accepted but not used by the current body.
        cmap: Accepted but not used by the current body.
    """
    crs_info = get_cartopy_crs('GOES16', 'CONUS')
    geos, xlen, xmin, xmax, ylen, ymin, ymax = crs_info
    native_bounds = (xmin, xmax, ymin, ymax)

    fig = plt.figure(figsize=(15, 12))
    ax = fig.add_subplot(1, 1, 1, projection=geos)

    # Restrict the view in geographic coordinates.
    lonlat_box = [-105, -70, 15, 50]
    ax.set_extent(lonlat_box, crs=ccrs.PlateCarree())

    ax.imshow(
        image,
        cmap='Purples',
        origin='upper',
        extent=native_bounds,
        vmin=0.52,
        vmax=1.0,
        transform=geos,
    )

    # Map overlays and title.
    ax.coastlines(resolution='50m', color='blue', linewidth=0.25)
    ax.add_feature(ccrs.cartopy.feature.STATES, linewidth=0.25)
    plt.title('GOES-16', loc='left', fontweight='bold', fontsize=15)

def make_icing_image(h5f, probs, ice_lons, ice_lats, clvrx_str_time, satellite, domain, ice_lons_vld=None, ice_lats_vld=None, imagename='icing', extent=(-110, -60, 10, 55)):
    """Render an icing-probability map and save it under ``homedir``.

    Layers, each drawn only when its input is given:
      * background: ``temp_10_4um_nom`` read from ``h5f`` via
        ``get_grid_values_all``, in greys;
      * ``probs``: semi-transparent 'Purples' overlay scaled 0.20..1.0, with
        a colorbar;
      * ``ice_lons`` / ``ice_lats``: blue markers;
      * ``ice_lons_vld`` / ``ice_lats_vld``: larger red markers.

    Args:
        h5f: Open file handle passed to ``get_grid_values_all``, or ``None``.
        probs: 2-D probability grid in the satellite projection, or ``None``.
        ice_lons, ice_lats: Longitudes/latitudes of the blue markers.
        clvrx_str_time: Time string used in the title and the file name.
        satellite, domain: Passed to ``get_cartopy_crs`` to select the
            projection and native image bounds.
        ice_lons_vld, ice_lats_vld: Longitudes/latitudes of the red markers.
        imagename: File-name prefix; output is ``<imagename>_<clvrx_str_time>``.
        extent: ``(x0, x1, y0, y1)`` view limits, or ``None`` for no limit.
            Interpreted in the satellite projection for 'H08' and as lon/lat
            (Plate Carree) for every other satellite.
    """
    geos, xlen, xmin, xmax, ylen, ymin, ymax = get_cartopy_crs(satellite, domain)
    native_bounds = (xmin, xmax, ymin, ymax)

    bg_image = None
    if h5f is not None:
        bg_image = get_grid_values_all(h5f, 'temp_10_4um_nom')

    fig = plt.figure(figsize=(48, 30))
    ax = fig.add_subplot(1, 1, 1, projection=geos)

    if extent is not None:
        # Exactly one CRS applies; previously both calls ran for H08 and
        # neither ran for any other satellite.
        if satellite == 'H08':
            ax.set_extent(extent, crs=geos)
        else:
            ax.set_extent(extent, crs=ccrs.PlateCarree())

    if bg_image is not None:
        ax.imshow(bg_image, origin='upper', extent=native_bounds, cmap='Greys', transform=geos)

    p_im = None
    if probs is not None:
        p_im = ax.imshow(probs, cmap='Purples', origin='upper', extent=native_bounds,
                         vmin=0.20, vmax=1.0, alpha=0.8, transform=geos)
    ax.coastlines(resolution='50m', color='green', linewidth=1.50)

    if ice_lons is not None and ice_lats is not None:
        ax.scatter(ice_lons, ice_lats, s=40.0, marker='o', color='blue', transform=ccrs.PlateCarree())
    if ice_lons_vld is not None and ice_lats_vld is not None:
        ax.scatter(ice_lons_vld, ice_lats_vld, s=180.0, marker='o', color='red', transform=ccrs.PlateCarree())

    title = satellite+', '+clvrx_str_time+'  >50% Probability Icing'
    plt.title(title, loc='left', fontweight='bold', fontsize=24)

    # A colorbar needs the probability image; skip it when none was drawn.
    if p_im is not None:
        plt.colorbar(p_im)

    ImageDirAndName = os.path.join(homedir, imagename+'_'+clvrx_str_time)
    fig.savefig(ImageDirAndName)
    # The figure is very large (48x30 in); release it once written.
    plt.close(fig)

# values, edges in Epoch time
def make_time_domain_hist(values, edges):
    """Draw a histogram of ``values`` on a new single-axes figure.

    Args:
        values: Samples to bin (epoch time, per the comment above the
            function); passed to ``ax.hist`` unchanged.
        edges: Passed to ``ax.hist`` as ``bins`` (also epoch time, per the
            same comment).
    """
    # convert the epoch format to matplotlib date format
    # (conversion is currently disabled, so the raw values are binned)
    #mpl_data = mdates.epoch2num(values)

    fig, ax = plt.subplots(1, 1)
    # Bars are outlined in black.
    ax.hist(values, bins=edges, ec='black')