plot.py 5.65 KiB
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Wed May 22 13:58:09 2019
@author: Wimmers
"""
from sklearn.metrics import roc_auc_score, roc_curve
import tensorflow as tf
import matplotlib
import matplotlib.pyplot as plt
import os
from util.setup import home_dir
# Locations used by plot_train_curve: TensorBoard run directories live under
# logdir; the rendered training-curve image is written to root_imgdir.
logdir = home_dir+'/tf_logs'
root_imgdir = home_dir
img_name = 'train_curve.png'  # default training-curve image file name
# Amount the validation-record index advances per 'loss_val' summary.
run_valid_stepping = 1
# Force the Tk GUI backend.
# NOTE(review): this runs after matplotlib.pyplot has already been imported;
# the conventional order is to call matplotlib.use() before importing pyplot.
matplotlib.use('tkagg')

# RGB triplets on a 0-255 scale.
cld_phase_clr_list = [[128, 128, 128], [11, 36, 250], [45, 254, 254],
                      [127, 15, 126], [254, 192, 203], [0, 0, 0]]

# Colour tables for ListedColormap, normalised to the [0, 1] range matplotlib
# expects. Every row is normalised (instead of a hard-coded range(N)), so the
# tables stay valid if colours are added or removed.
cld_phase_clr_list_2 = [[channel / 255 for channel in rgb]
                        for rgb in ([128, 128, 128], [254, 192, 203],
                                    [11, 36, 250], [127, 15, 126])]

from matplotlib.colors import ListedColormap

cld_phase_cmap = ListedColormap(cld_phase_clr_list_2, 'CldPhase')

cld_mask_list = [[channel / 255 for channel in rgb]
                 for rgb in ([0, 220, 0], [220, 220, 220])]
cld_mask_cmap = ListedColormap(cld_mask_list, 'CldMask')
def _events_file_path(rundir, subdir):
    """Return the path of the first event file (in sorted order) in rundir/subdir."""
    event_dir = os.path.join(rundir, subdir)
    # os.listdir() order is arbitrary; sort so the choice is deterministic.
    # NOTE(review): if a directory holds several event files, only one is read.
    return os.path.join(event_dir, sorted(os.listdir(event_dir))[0])


def plot_train_curve(rundir=None, title=None):
    """Plot training/validation curves from TensorBoard event files and save a PNG.

    Scalar summaries are read from ``<rundir>/plot_train`` (tags 'loss_trn',
    'num_epochs', 'num_train_steps') and ``<rundir>/plot_valid`` (tags
    'loss_val', 'acc_val'). Step counts are rescaled to epochs using the last
    'num_epochs' / 'num_train_steps' values. Losses are drawn on the left axis,
    the 'acc_val' series on the right ("Mean Absolute Error") axis. The lowest
    validation loss is marked with a dot.

    Args:
        rundir: Run sub-directory relative to ``logdir``. If None, ``logdir``
            itself is read and the image is saved as the module-level
            ``img_name``. Otherwise the image is named after the last path
            component of ``rundir``.
        title: Optional plot title.

    The figure is saved under ``root_imgdir``, shown, and then closed.
    """
    if rundir is not None:
        # A distinct local name: assigning to img_name here would make it local
        # for the whole function and break the rundir=None path.
        out_name = os.path.split(rundir)[1] + '.png'
        rundir = os.path.join(logdir, rundir)
    else:
        out_name = img_name
        rundir = logdir

    # --- training summaries -------------------------------------------------
    last_epoch = 0
    num_train_steps = 0
    train_steps = []
    train_losses = []
    for event in tf.compat.v1.train.summary_iterator(_events_file_path(rundir, 'plot_train')):
        for v in event.summary.value:
            if v.tag == 'loss_trn':
                train_losses.append(v.simple_value)
            elif v.tag == 'num_epochs':
                last_epoch = v.simple_value
            elif v.tag == 'num_train_steps':
                train_steps.append(v.simple_value)
                num_train_steps = v.simple_value

    # --- validation summaries -----------------------------------------------
    # Start from +inf so the first validation loss always becomes the best one
    # (a finite sentinel fails when every loss exceeds it).
    lowest_loss = float('inf')
    lowest_loss_step = 0
    valid_steps = []   # index into train_steps for each validation loss
    valid_losses = []
    valid_acc = []
    step = 0
    for event in tf.compat.v1.train.summary_iterator(_events_file_path(rundir, 'plot_valid')):
        for v in event.summary.value:
            if v.tag == 'loss_val':
                if v.simple_value < lowest_loss:
                    lowest_loss = v.simple_value
                    lowest_loss_step = train_steps[step]
                valid_steps.append(step)
                valid_losses.append(v.simple_value)
                step += run_valid_stepping
            elif v.tag == 'acc_val':
                valid_acc.append(v.simple_value)

    # Convert step counts to (fractional) epochs.
    ratio = last_epoch / num_train_steps
    epochs = [s * ratio for s in train_steps]
    # Each validation point sits at the training step it was indexed against.
    valid_epochs = [epochs[s] for s in valid_steps]
    lowest_loss_epoch = lowest_loss_step * ratio
    # NOTE(review): assumes 'acc_val' records pair up with 'loss_val' records;
    # only the overlapping prefix is plotted.
    n_acc = min(len(valid_acc), len(valid_epochs))

    # --- plotting -------------------------------------------------------------
    fig, ax = plt.subplots(1)
    fig.set_size_inches(8, 5.3)
    if title is not None:
        ax.set_title(title)
    ax.set_ylabel('Loss')
    ax.set_xlabel('Epoch')
    ax.set_ylim([0, 1])
    ax.grid(True, linestyle=':')
    ax2 = ax.twinx()
    ax2.set_ylabel('Mean Absolute Error')
    ax2.set_ylim([0, 1])
    fig.subplots_adjust(bottom=0.12, top=0.94)

    train_hdl, = ax.plot(epochs, train_losses, '-', linewidth=1.5, color='b')
    valid_hdl, = ax.plot(valid_epochs, valid_losses, '-', linewidth=1.5, color=[1.0, 0.5, 0.0])
    acc_hdl, = ax2.plot(valid_epochs[:n_acc], valid_acc[:n_acc], '-', linewidth=1.5,
                        color=[0.5, 0.0, 0.5])
    handles = [train_hdl, valid_hdl, acc_hdl]
    labels = ['Training', 'Validation', 'MAE']
    if valid_losses:  # no best-loss marker without validation data
        optim_hdl, = ax.plot(lowest_loss_epoch, lowest_loss, 'o', color='g')
        handles.append(optim_hdl)
        labels.append('Best loss')
    # The legend goes on ax2, which is drawn on top of ax after twinx().
    ax2.legend(handles, labels, loc='center right', fancybox=True, framealpha=0.5)

    plt.ion()
    # Save before show() so the file is written regardless of backend behaviour.
    fig.savefig(os.path.join(root_imgdir, out_name), dpi=100)
    plt.show()
    plt.close(fig)
def plot_roc_curve(correct_labels, prob_pos_class, fname='/home/rink/test_roc', dpi=320):
    """Plot a binary classifier's ROC curve, print its AUC, and save the figure.

    Args:
        correct_labels: Ground-truth binary labels, as accepted by
            sklearn.metrics.roc_curve / roc_auc_score.
        prob_pos_class: Scores or probabilities for the positive class, one per
            label.
        fname: Output path passed to savefig. The default keeps the previously
            hard-coded location.
        dpi: Resolution of the saved image.

    Returns:
        The ROC AUC score (also printed).
    """
    auc = roc_auc_score(correct_labels, prob_pos_class)
    print('AUC: ', auc)
    fpr, tpr, _ = roc_curve(correct_labels, prob_pos_class)

    fig = plt.figure(figsize=(6, 4))
    plt.plot(fpr, tpr)
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('Receiver operating characteristic')

    # Save *before* show(): with a blocking backend, show() returns only after
    # the window is closed, and a later savefig would write a fresh, empty figure.
    fig.savefig(fname, dpi=dpi)
    plt.show()
    plt.close(fig)  # avoid accumulating open figures across calls
    return auc
def plot_image(image, fname, label, cmap='Greys', outdir='/Users/tomrink', vmin=-1.5, vmax=2.5):
    """Render a 2-D array with a colorbar and save it to ``outdir/fname``.

    Args:
        image: Array passed to ``imshow``.
        fname: Output file name, joined onto ``outdir``.
        label: Colorbar label. NOTE(review): currently unused; the call that
            applied it was commented out in the original code.
        cmap: Colormap name or Colormap object (e.g. ``cld_phase_cmap``).
        outdir: Output directory. The default keeps the previously hard-coded
            location.
        vmin, vmax: Colour-scale limits. The defaults match the previous
            hard-coded values.
    """
    fig, ax = plt.subplots()
    im = ax.imshow(image, interpolation="nearest", cmap=cmap, vmin=vmin, vmax=vmax)

    cbar = fig.colorbar(im, ax=ax, ticks=range(3))
    cbar.ax.set_position([0.90, 0.05, 0.03, 0.90])  # [left, bottom, width, height]
    # Position the image axes only after the colorbar exists: creating the
    # colorbar resizes the parent axes and would undo an earlier set_position.
    ax.set_position([0.05, 0.05, 0.83, 0.90])  # [left, bottom, width, height]

    fig.set_size_inches([10.0, 6.6], forward=True)
    # savefig uses rcParams['savefig.dpi']; with its default ('figure') that is
    # this figure dpi.
    fig.set_dpi(100)
    try:
        fig.savefig(os.path.join(outdir, fname))
    finally:
        # Always release the figure, even if saving fails.
        plt.close(fig)