Skip to content
Snippets Groups Projects
plot.py 5.65 KiB
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Wed May 22 13:58:09 2019

@author: Wimmers
"""

from sklearn.metrics import roc_auc_score, roc_curve
import tensorflow as tf
import matplotlib
import matplotlib.pyplot as plt
import os

from util.setup import home_dir

# How far the training-point index advances per validation point when
# plot_train_curve lines the validation series up with the training series.
run_valid_stepping = 1

# Root of the run directories read by plot_train_curve; each run directory
# is expected to hold 'plot_train' and 'plot_valid' event sub-directories.
logdir = home_dir + '/tf_logs'

# Output directory for training-curve images and the default image file name.
root_imgdir = home_dir
img_name = 'train_curve.png'

# Select the Tk ("tkagg") backend for every figure this module creates.
matplotlib.use('tkagg')

def _to_unit_rgb(colors):
    """Return ``colors`` (0-255 RGB triplets) rescaled to the 0-1 floats matplotlib expects.

    Works for any number of triplets, so palettes can grow without touching
    a hard-coded loop bound.
    """
    return [[channel / 255 for channel in rgb] for rgb in colors]


# RGB triplets on the 0-255 scale; this list is neither rescaled nor turned
# into a colormap below.
cld_phase_clr_list = [[128,128,128], [11,36,250], [45,254,254], [127,15,126], [254,192,203], [0,0,0]]

from matplotlib.colors import ListedColormap

# Four-entry palette (grey, pink, blue, purple) rescaled to 0-1 for ListedColormap.
cld_phase_clr_list_2 = _to_unit_rgb([[128,128,128],[254,192,203],[11,36,250],[127,15,126]])
cld_phase_cmap = ListedColormap(cld_phase_clr_list_2, 'CldPhase')

# Two-entry palette (green, light grey) rescaled to 0-1 for ListedColormap.
cld_mask_list = _to_unit_rgb([[0,220,0], [220, 220, 220]])
cld_mask_cmap = ListedColormap(cld_mask_list, 'CldMask')


def _first_events_file(rundir, subdir):
    """Return the path of the TensorFlow events file in ``rundir/subdir``.

    Only entries whose name contains ``'tfevents'`` (TensorFlow writes
    ``events.out.tfevents.*``) are considered, so stray files such as
    ``.DS_Store`` are skipped. ``os.listdir`` order is arbitrary, so the
    candidate names are sorted and the first one is used, keeping the
    choice deterministic.

    Raises:
        FileNotFoundError: if ``rundir/subdir`` is missing or holds no events file.
    """
    event_dir = os.path.join(rundir, subdir)
    names = sorted(name for name in os.listdir(event_dir) if 'tfevents' in name)
    if not names:
        raise FileNotFoundError('no TensorFlow events file found in ' + event_dir)
    return os.path.join(event_dir, names[0])


def _read_train_events(path):
    """Collect the training-curve scalars from one events file.

    Returns:
        ``(train_losses, train_steps, last_epoch, num_train_steps)``: all
        'loss_trn' values and all 'num_train_steps' values in file order, plus
        the last 'num_epochs' and last 'num_train_steps' values seen (0 if absent).
    """
    train_losses = []
    train_steps = []
    last_epoch = 0
    num_train_steps = 0
    for event in tf.compat.v1.train.summary_iterator(path):
        for value in event.summary.value:
            if value.tag == 'loss_trn':
                train_losses.append(value.simple_value)
            elif value.tag == 'num_epochs':
                last_epoch = value.simple_value
            elif value.tag == 'num_train_steps':
                train_steps.append(value.simple_value)
                num_train_steps = value.simple_value
    return train_losses, train_steps, last_epoch, num_train_steps


def _read_valid_events(path):
    """Return the ``('loss_val', 'acc_val')`` scalar series, in file order, from one events file."""
    valid_losses = []
    valid_acc = []
    for event in tf.compat.v1.train.summary_iterator(path):
        for value in event.summary.value:
            if value.tag == 'loss_val':
                valid_losses.append(value.simple_value)
            elif value.tag == 'acc_val':
                valid_acc.append(value.simple_value)
    return valid_losses, valid_acc


def plot_train_curve(rundir=None, title=None):
    """Plot training/validation curves from TensorFlow event files and save a PNG.

    Reads the first events file under ``<run>/plot_train`` (scalars tagged
    'loss_trn', 'num_epochs', 'num_train_steps') and under ``<run>/plot_valid``
    ('loss_val', 'acc_val'). Training-step counts are converted to epochs with
    the ratio last 'num_epochs' / last 'num_train_steps'. Losses are drawn on
    the left axis (fixed to [0, 1]); the 'acc_val' series is drawn on the
    right axis labelled 'Mean Absolute Error'; the lowest validation loss is
    marked with a green dot. The image is saved under ``root_imgdir`` at
    100 dpi, then shown (interactive mode is switched on) and closed.

    Args:
        rundir: run directory relative to ``logdir``. When given, the image is
            saved as ``<basename(rundir)>.png``; when None, ``logdir`` itself is
            read and the image is saved as the module default ``img_name``.
        title: optional axes title.

    Raises:
        FileNotFoundError: if an events directory is missing or holds no events file.
        ValueError: if no non-zero 'num_train_steps' summary was recorded.
    """
    if rundir is not None:
        # normpath drops a trailing separator so the basename is never empty.
        out_name = os.path.basename(os.path.normpath(rundir)) + '.png'
        rundir = os.path.join(logdir, rundir)
    else:
        # A distinct local name keeps the module-level default reachable;
        # assigning ``img_name`` in this function would shadow it everywhere.
        out_name = img_name
        rundir = logdir

    train_losses, train_steps, last_epoch, num_train_steps = _read_train_events(
        _first_events_file(rundir, 'plot_train'))
    valid_losses, valid_acc = _read_valid_events(
        _first_events_file(rundir, 'plot_valid'))

    if not num_train_steps:
        raise ValueError("no non-zero 'num_train_steps' summary found under " + rundir)

    # Convert cumulative training-step counts into (fractional) epochs.
    ratio = last_epoch / num_train_steps
    epochs = [step * ratio for step in train_steps]
    # Validation point i is placed at training point i * run_valid_stepping.
    valid_epochs = epochs[::run_valid_stepping]

    fig, ax = plt.subplots(figsize=(8, 5.3))
    fig.subplots_adjust(bottom=0.12, top=0.94)
    if title is not None:
        ax.set_title(title)
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss')
    ax.set_ylim([0, 1])
    ax.grid(True, linestyle=':')
    # Right-hand axis sharing the epoch axis, for the 'acc_val' series.
    ax2 = ax.twinx()
    ax2.set_ylabel('Mean Absolute Error')

    # Draw on explicit axes: after twinx() the pyplot "current axes" is ax2.
    train_hdl, = ax.plot(epochs, train_losses, '-', linewidth=1.5, color='b')
    valid_hdl, = ax.plot(valid_epochs[:len(valid_losses)], valid_losses, '-',
                         linewidth=1.5, color=[1.0, 0.5, 0.0])
    acc_hdl, = ax2.plot(valid_epochs[:len(valid_acc)], valid_acc, '-',
                        linewidth=1.5, color=[0.5, 0.0, 0.5])
    handles = [train_hdl, valid_hdl, acc_hdl]
    labels = ['Training', 'Validation', 'MAE']

    if valid_losses:
        # min() returns the first minimum, matching a strict '<' scan.
        best = min(range(len(valid_losses)), key=valid_losses.__getitem__)
        optim_hdl, = ax.plot([valid_epochs[best]], [valid_losses[best]], 'o', color='g')
        handles.append(optim_hdl)
        labels.append('Best loss')

    # Legend on the twin axes so it is drawn above the lines of both axes.
    ax2.legend(handles, labels, loc='center right', fancybox=True, framealpha=0.5)

    # Save before show(): with a blocking backend the window may be closed first.
    fig.savefig(os.path.join(root_imgdir, out_name), dpi=100)
    plt.ion()
    plt.show()
    plt.close(fig)


def plot_roc_curve(correct_labels, prob_pos_class, fname='/home/rink/test_roc', dpi=320):
    """Plot the ROC curve of a binary classifier, print its AUC, save and show it.

    Args:
        correct_labels: true binary labels.
        prob_pos_class: scores/probabilities for the positive class, aligned
            with ``correct_labels``.
        fname: output path handed to ``savefig``; when it has no extension,
            matplotlib appends the default format's extension.
        dpi: resolution of the saved image.

    Returns:
        The ROC AUC computed by ``roc_auc_score``.
    """
    auc = roc_auc_score(correct_labels, prob_pos_class)
    print('AUC: ', auc)
    fpr, tpr, _ = roc_curve(correct_labels, prob_pos_class)

    fig, ax = plt.subplots(figsize=(6, 4))
    ax.plot(fpr, tpr)
    ax.set_xlim([0.0, 1.0])
    ax.set_ylim([0.0, 1.05])
    ax.set_xlabel('False Positive Rate')
    ax.set_ylabel('True Positive Rate')
    ax.set_title('Receiver operating characteristic')

    # Save before show(): with a blocking backend the user may close the
    # window first, after which pyplot's savefig would target a new, empty figure.
    fig.savefig(fname, dpi=dpi)
    plt.show()
    return auc


def plot_image(image, fname, label, cmap='Greys', img_dir='/Users/tomrink', vmin=-1.5, vmax=2.5):
    """Render a 2-D image with a colorbar and save it to ``img_dir/fname``.

    Args:
        image: array-like accepted by ``imshow``.
        fname: output file name, joined onto ``img_dir``.
        label: currently unused (no colorbar label is drawn).
        cmap: colormap name or ``Colormap`` instance.
        img_dir: output directory.
        vmin: lower limit of the colour scale.
        vmax: upper limit of the colour scale.
    """
    fig, ax = plt.subplots()

    # Nearest-neighbour interpolation keeps each pixel a flat block of colour.
    im = ax.imshow(image, interpolation="nearest", cmap=cmap, vmin=vmin, vmax=vmax)

    cbar = fig.colorbar(im, ax=ax, ticks=range(3))
    cbar.ax.set_position([0.90, 0.05, 0.03, 0.90])  # [left, bottom, width, height]

    # Position the image axes only after the colorbar exists: creating the
    # colorbar resizes ``ax`` to make room, overriding any earlier position.
    ax.set_position([0.05, 0.05, 0.83, 0.90])  # [left, bottom, width, height]

    fig.set_size_inches([10.0, 6.6], forward=True)
    # savefig uses the figure dpi unless rcParams['savefig.dpi'] overrides it.
    fig.set_dpi(100)

    fig.savefig(os.path.join(img_dir, fname))

    # Release this figure explicitly so repeated calls don't accumulate open figures.
    fig.clf()
    plt.close(fig)