Skip to content
Snippets Groups Projects
plot.py 9.41 KiB
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Wed May 22 13:58:09 2019

@author: Wimmers
"""

from sklearn.metrics import roc_auc_score, roc_curve
import tensorflow as tf
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
import cartopy.crs as ccrs
import xarray as xr
import os
import h5py
from util.util import get_grid_values_all, get_cartopy_crs, homedir

from util.setup import home_dir

from matplotlib.colors import ListedColormap

# Directory the TensorBoard event files are read from, plus the directory and
# default file name used when the training-curve image is written.
logdir = home_dir + '/tf_logs'
root_imgdir = home_dir
img_name = 'train_curve.png'

# Index increment applied for every validation record read from an event file.
run_valid_stepping = 1

# Best-effort backend selection: a failure is printed and otherwise ignored.
try:
    matplotlib.use('tkagg')
except Exception as backend_error:
    print(backend_error)

# Six RGB triplets on the 0-255 scale; this list is left unscaled.
cld_phase_clr_list = [
    [128, 128, 128],
    [11, 36, 250],
    [45, 254, 254],
    [127, 15, 126],
    [254, 192, 203],
    [0, 0, 0],
]

# Four-entry palette rescaled from 0-255 to the 0-1 floats matplotlib expects.
cld_phase_clr_list_2 = [
    [channel / 255 for channel in rgb]
    for rgb in ([128, 128, 128], [254, 192, 203], [11, 36, 250], [127, 15, 126])
]
cld_phase_cmap = ListedColormap(cld_phase_clr_list_2, 'CldPhase')

# Two-entry palette, rescaled the same way.
cld_mask_list = [
    [channel / 255 for channel in rgb]
    for rgb in ([0, 220, 0], [220, 220, 220])
]
cld_mask_cmap = ListedColormap(cld_mask_list, 'CldMask')


def _first_event_file(run_path, subdir):
    """Return the path of the name-sorted first file in ``run_path/subdir``."""
    event_dir = os.path.join(run_path, subdir)
    # Sorting makes the choice deterministic when several files are present.
    return os.path.join(event_dir, sorted(os.listdir(event_dir))[0])


def plot_train_curve(rundir=None, title=None):
    """Plot training/validation curves read from TensorBoard event files.

    Scalars tagged ``loss_trn``, ``num_epochs`` and ``num_train_steps`` are
    read from the ``plot_train`` sub-directory of the run, and ``loss_val`` /
    ``acc_val`` from ``plot_valid``. Step counts are rescaled to epochs with
    ``last_epoch / num_train_steps``. The figure is written into
    ``root_imgdir`` and then shown.

    Args:
        rundir: Run directory relative to ``logdir``. The image is named after
            its last path component plus ``.png``. When ``None``, ``logdir``
            itself is read and the module-level ``img_name`` is used.
        title: Optional axes title.

    Raises:
        ZeroDivisionError: If no ``num_train_steps`` record was found.
    """
    # A distinct local name is required here: assigning to ``img_name`` would
    # make it function-local and raise UnboundLocalError when rundir is None.
    if rundir is not None:
        out_name = os.path.split(rundir)[1] + '.png'
        rundir = os.path.join(logdir, rundir)
    else:
        out_name = img_name
        rundir = logdir

    last_epoch = 0
    num_train_steps = 0
    train_steps = []
    train_losses = []
    train_events = _first_event_file(rundir, 'plot_train')
    for event in tf.compat.v1.train.summary_iterator(train_events):
        for value in event.summary.value:
            if value.tag == 'loss_trn':
                train_losses.append(value.simple_value)
            elif value.tag == 'num_epochs':
                last_epoch = value.simple_value
            elif value.tag == 'num_train_steps':
                train_steps.append(value.simple_value)
                num_train_steps = value.simple_value

    # ``step`` indexes into train_steps; it advances by run_valid_stepping for
    # every validation loss record.
    step = 0
    lowest_loss = float('inf')
    lowest_loss_step = 0
    valid_steps = []
    valid_losses = []
    valid_acc = []
    valid_events = _first_event_file(rundir, 'plot_valid')
    for event in tf.compat.v1.train.summary_iterator(valid_events):
        for value in event.summary.value:
            if value.tag == 'loss_val':
                if value.simple_value < lowest_loss:
                    lowest_loss = value.simple_value
                    lowest_loss_step = train_steps[step]
                valid_steps.append(step)
                valid_losses.append(value.simple_value)
                step += run_valid_stepping
            elif value.tag == 'acc_val':
                valid_acc.append(value.simple_value)

    # Convert step counts to (fractional) epochs.
    ratio = last_epoch / num_train_steps
    train_epochs = [count * ratio for count in train_steps]
    # Validation points sit at the training step they were indexed against, so
    # the x positions stay correct when run_valid_stepping is not 1.
    valid_epochs = [train_epochs[index] for index in valid_steps]
    lowest_loss_step *= ratio

    fig, ax = plt.subplots(1)
    if title is not None:
        ax.set_title(title)
    ax.set_ylabel('Loss')
    ax.set_xlabel('Epoch')
    ax2 = ax.twinx()
    ax2.set_ylabel('Mean Absolute Error')

    # NOTE: every curve is drawn on the twin axes. This is what the implicit
    # pyplot calls did previously (the twin is the current axes after twinx()),
    # so the rendered figure is unchanged.
    ax2.grid(True, linestyle=':')
    fig.subplots_adjust(bottom=0.12, top=0.94)
    ax2.set_ylim([0, 1])

    train_hdl, = ax2.plot(train_epochs, train_losses, '-', linewidth=1.5, color='b')
    valid_hdl, = ax2.plot(valid_epochs, valid_losses, '-', linewidth=1.5, color=[1.0, 0.5, 0.0])
    acc_hdl, = ax2.plot(valid_epochs, valid_acc, '-', linewidth=1.5, color=[0.5, 0.0, 0.5])
    handles = [train_hdl, valid_hdl, acc_hdl]
    labels = ['Training', 'Validation', 'MAE']
    if valid_losses:
        # Only mark a best point when at least one validation loss was read.
        optim_hdl, = ax2.plot(lowest_loss_step, lowest_loss, 'o', color='g')
        handles.append(optim_hdl)
        labels.append('Best loss')

    ax2.legend(handles, labels, loc='center right', fancybox=True, framealpha=0.5)

    fig.set_size_inches(8, 5.3)
    # Save before showing so the file never depends on the state of the window.
    fig.savefig(os.path.join(root_imgdir, out_name), dpi=100)
    plt.ion()
    plt.show()
    plt.close(fig)


def plot_roc_curve(correct_labels, prob_pos_class, fname='/home/rink/test_roc'):
    """Print the ROC AUC and plot, save and show the ROC curve.

    Args:
        correct_labels: True labels, passed to ``roc_auc_score`` and
            ``roc_curve``.
        prob_pos_class: Scores for the positive class, passed to the same
            two functions.
        fname: Path the figure is saved to. Defaults to the location that was
            previously hard-coded.
    """
    auc = roc_auc_score(correct_labels, prob_pos_class)
    print('AUC: ', auc)
    fpr, tpr, _ = roc_curve(correct_labels, prob_pos_class)

    fig = plt.figure(figsize=(6, 4))
    plt.plot(fpr, tpr)

    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('Receiver operating characteristic')

    # Save first: in blocking mode show() returns only after the window is
    # closed, and a savefig issued afterwards would write an empty canvas.
    fig.savefig(fname, dpi=320)
    plt.show()


def plot_image(image, fname, label, cmap='Greys'):
    """Render ``image`` with a colorbar and save it as ``/Users/tomrink/<fname>``.

    Args:
        image: Array-like handed to ``imshow``; drawn with nearest-neighbour
            interpolation and a fixed colour range of -1.5 to 2.5.
        fname: File name of the saved figure inside ``/Users/tomrink``.
        label: Accepted for interface compatibility; currently unused.
        cmap: Colormap passed to ``imshow``.
    """
    figure, axes = plt.subplots()

    rendered = plt.imshow(image, interpolation="nearest", cmap=cmap, vmin=-1.5, vmax=2.5)

    # Colorbar with ticks at 0, 1, 2, placed in a narrow strip on the right.
    colorbar = plt.colorbar(rendered, ticks=range(3))
    colorbar.ax.set_position([0.90, 0.05, 0.03, 0.90])  # [left, bottom, width, height]

    # The image axes are positioned only after the colorbar has been created.
    axes.set_position([0.05, 0.05, 0.83, 0.90])  # [left, bottom, width, height]

    figure.set_size_inches([10.0, 6.6], forward=True)
    figure.set_dpi(100)

    figure.savefig(os.path.join('/Users/tomrink', fname))

    # Clear and close so figures do not pile up across calls.
    figure.clf()
    plt.close()
    # plt.close('all')


def plot_image2(image, cmap='Greys', title='BOEING Icing Reports 2018-2020'):
    """Draw ``image`` on a Plate Carree map and save it.

    The image is stretched over the extent [-180, 180, -90, 90] with its
    origin at the lower edge, coastlines are overlaid, and the figure is
    written to ``/Users/tomrink/testimage``.

    Args:
        image: Array-like handed to ``imshow``; colour range fixed at 0.1 to 1.
        cmap: Colormap passed to ``imshow``.
        title: Bold, left-aligned title text.
    """
    figure = plt.figure(figsize=(12, 8))
    map_axes = plt.axes(projection=ccrs.PlateCarree())

    map_axes.imshow(
        image,
        origin='lower',
        extent=[-180, 180, -90, 90],
        transform=ccrs.PlateCarree(),
        cmap=cmap,
        vmin=0.1,
        vmax=1,
    )
    map_axes.coastlines(resolution='50m', color='blue', linewidth=0.5)
    map_axes.set_title(title, loc='left', fontweight='bold', fontsize=16)

    figure.savefig(os.path.join('/Users/tomrink', 'testimage'))

    # Clear and close so figures do not pile up across calls.
    figure.clf()
    plt.close()


def plot_image3(image, bt, cmap='Greys'):
    """Draw ``image`` in the GOES-16 CONUS projection and save it.

    The projection and image extent come from
    ``get_cartopy_crs('GOES16', 'CONUS')``. The view is limited to
    [-105, -70, 15, 50] in Plate Carree coordinates, coastlines and state
    borders are overlaid, and the result is written to
    ``/Users/tomrink/image``.

    Args:
        image: Array-like drawn with the 'Purples' colormap and a colour range
            of 0.52 to 1.0.
        bt: Accepted for interface compatibility; currently unused.
        cmap: Accepted for interface compatibility; currently unused.
    """
    # Imported explicitly: reaching the submodule through ``ccrs.cartopy``
    # only works if something else has already imported ``cartopy.feature``.
    import cartopy.feature as cfeature

    geos, _, xmin, xmax, _, ymin, ymax = get_cartopy_crs('GOES16', 'CONUS')

    fig = plt.figure(figsize=(15, 12))
    ax = fig.add_subplot(1, 1, 1, projection=geos)
    ax.set_extent([-105, -70, 15, 50], crs=ccrs.PlateCarree())
    ax.imshow(image, cmap='Purples', origin='upper', extent=(xmin, xmax, ymin, ymax),
              vmin=0.52, vmax=1.0, transform=geos)
    ax.coastlines(resolution='50m', color='blue', linewidth=0.25)
    ax.add_feature(cfeature.STATES, linewidth=0.25)
    ax.set_title('GOES-16', loc='left', fontweight='bold', fontsize=15)
    fig.savefig('/Users/tomrink/image')

    # Release the figure, as the sibling plot_image/plot_image2 functions do;
    # otherwise every call leaves a 15x12 inch figure registered with pyplot.
    fig.clf()
    plt.close(fig)


def make_icing_image(h5f, probs, ice_lons, ice_lats, clvrx_str_time, satellite, domain, ice_lons_vld=None, ice_lats_vld=None, imagename='icing', extent=(-110, -60, 10, 55)):
    """Render an icing-probability map and save it under ``homedir``.

    Args:
        h5f: Optional handle passed to ``get_grid_values_all`` to read the
            'temp_10_4um_nom' field, which is drawn as a grey background.
            ``None`` skips the background.
        probs: Optional array drawn with the 'Purples' colormap (colour range
            0.20 to 1.0, alpha 0.8). A colorbar is added only when it is given.
        ice_lons, ice_lats: Optional point coordinates drawn as blue markers,
            interpreted in Plate Carree coordinates.
        clvrx_str_time: Time string used in the title and the file name.
        satellite: Satellite name passed to ``get_cartopy_crs``; also used in
            the title. For 'H08' the extent is applied in the satellite
            projection rather than Plate Carree.
        domain: Domain name passed to ``get_cartopy_crs``.
        ice_lons_vld, ice_lats_vld: Optional second point set drawn as larger
            red markers.
        imagename: File-name prefix; the file is ``<imagename>_<clvrx_str_time>``.
        extent: Map extent as four numbers, or ``None`` to leave the default
            extent untouched.
    """
    geos, _, xmin, xmax, _, ymin, ymax = get_cartopy_crs(satellite, domain)
    img_extent = (xmin, xmax, ymin, ymax)

    bg_image = None
    if h5f is not None:
        bg_image = get_grid_values_all(h5f, 'temp_10_4um_nom')

    fig = plt.figure(figsize=(48, 30))
    ax = fig.add_subplot(1, 1, 1, projection=geos)
    if extent is not None:
        extent_crs = geos if satellite == 'H08' else ccrs.PlateCarree()
        ax.set_extent(extent, crs=extent_crs)

    if bg_image is not None:
        ax.imshow(bg_image, origin='upper', extent=img_extent, cmap='Greys', transform=geos)

    # Stays None when no probabilities are supplied so the colorbar step below
    # can be skipped instead of failing on an unbound name.
    p_im = None
    if probs is not None:
        p_im = ax.imshow(probs, cmap='Purples', origin='upper', extent=img_extent,
                         vmin=0.20, vmax=1.0, alpha=0.8, transform=geos)
    ax.coastlines(resolution='50m', color='green', linewidth=1.50)

    # Each point set is drawn only when both of its coordinate arrays are given.
    if ice_lons is not None and ice_lats is not None:
        ax.scatter(ice_lons, ice_lats, s=40.0, marker='o', color='blue', transform=ccrs.PlateCarree())
    if ice_lons_vld is not None and ice_lats_vld is not None:
        ax.scatter(ice_lons_vld, ice_lats_vld, s=180.0, marker='o', color='red', transform=ccrs.PlateCarree())

    title = satellite+', '+clvrx_str_time+'  >50% Probability Icing'
    ax.set_title(title, loc='left', fontweight='bold', fontsize=24)

    if p_im is not None:
        cbar = plt.colorbar(p_im)
        cbar.ax.tick_params(labelsize=30)

    fig.savefig(os.path.join(homedir, imagename+'_'+clvrx_str_time))
    # The 48x30 inch canvas is large; close it so repeated calls do not
    # accumulate open figures.
    plt.close(fig)


# values, edges in Epoch time
def make_time_domain_hist(values, edges):
    """Show a histogram of ``values`` binned by ``edges``.

    Args:
        values: Samples to histogram (epoch times, per the module comment).
        edges: Bin edges, in the same units as ``values``.

    The x axis is left in the raw input units; no date formatting is applied.
    """
    _, axes = plt.subplots(nrows=1, ncols=1)
    axes.hist(values, bins=edges, edgecolor='black')
    plt.show()