#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Plotting helpers for training diagnostics (loss curves, ROC curve) and for
saving 2-D image fields, optionally on map projections.

Created on Wed May 22 13:58:09 2019

@author: Wimmers
"""
import os

from sklearn.metrics import roc_auc_score, roc_curve
import tensorflow as tf
import matplotlib
# Choose the backend before pyplot is imported. Calling use() after pyplot has
# been imported triggers an immediate backend switch, which can raise
# ImportError at import time (e.g. when Tk is unavailable on a headless host).
matplotlib.use('tkagg')
import matplotlib.pyplot as plt  # noqa: E402
from matplotlib.colors import ListedColormap  # noqa: E402
import cartopy.crs as ccrs  # noqa: E402
import xarray as xr  # noqa: E402

from util.setup import home_dir  # noqa: E402

logdir = home_dir+'/tf_logs'
root_imgdir = home_dir
img_name = 'train_curve.png'
run_valid_stepping = 1


def _to_unit_rgb(colors):
    """Return a new list of RGB triples scaled from 0-255 to 0-1."""
    return [[c / 255 for c in rgb] for rgb in colors]


# RGB triples on a 0-255 scale (left un-normalized here).
cld_phase_clr_list = [[128, 128, 128], [11, 36, 250], [45, 254, 254], [127, 15, 126], [254, 192, 203], [0, 0, 0]]

cld_phase_clr_list_2 = _to_unit_rgb([[128, 128, 128], [254, 192, 203], [11, 36, 250], [127, 15, 126]])
cld_phase_cmap = ListedColormap(cld_phase_clr_list_2, 'CldPhase')

cld_mask_list = _to_unit_rgb([[0, 220, 0], [220, 220, 220]])
cld_mask_cmap = ListedColormap(cld_mask_list, 'CldMask')


def _first_events_file(rundir, subdir):
    """Return the path of the first entry os.listdir() reports in rundir/subdir."""
    event_dir = os.path.join(rundir, subdir)
    # NOTE(review): os.listdir order is arbitrary; if several event files are
    # present, which one is read is not guaranteed — confirm only one is expected.
    return os.path.join(event_dir, os.listdir(event_dir)[0])


def plot_train_curve(rundir=None, title=None):
    """Plot training/validation curves read from TensorBoard event files and save a PNG.

    Reads the first event file in ``<run dir>/plot_train`` (tags ``loss_trn``,
    ``num_epochs``, ``num_train_steps``) and in ``<run dir>/plot_valid`` (tags
    ``loss_val``, ``acc_val``). Step counts are converted to epochs with
    ``last_epoch / num_train_steps``. The function plots:

    - the training loss,
    - the validation loss,
    - the ``acc_val`` series (legend label 'MAE'),
    - a marker at the lowest validation loss.

    :param rundir: run sub-directory name under ``logdir``; when None, ``logdir``
        itself is used as the run directory.
    :param title: optional title for the plot.

    The image is written to ``root_imgdir`` under the name
    ``<last path component of rundir>.png``. When ``rundir`` is None, the
    module-level ``img_name`` is used instead.

    NOTE(review): raises ZeroDivisionError if no ``num_train_steps`` summary is
    found, and IndexError if there are more ``loss_val`` records than
    ``num_train_steps`` records.
    """
    # Use a distinct local for the output name: assigning to ``img_name`` here
    # would make it local for the whole function and raise UnboundLocalError
    # when rundir is None.
    out_name = img_name
    if rundir is not None:
        out_name = os.path.split(rundir)[1]+'.png'
        rundir = os.path.join(logdir, rundir)
    else:
        rundir = logdir

    # --- training summaries ---
    last_epoch = 0
    num_train_steps = 0
    train_steps = []
    train_losses = []
    for e in tf.compat.v1.train.summary_iterator(_first_events_file(rundir, 'plot_train')):
        for v in e.summary.value:
            if v.tag == 'loss_trn':
                train_losses.append(v.simple_value)
            if v.tag == 'num_epochs':
                last_epoch = v.simple_value
            if v.tag == 'num_train_steps':
                train_steps.append(v.simple_value)
                num_train_steps = v.simple_value

    # --- validation summaries ---
    lowest_loss = 1000
    lowest_loss_step = 0
    valid_losses = []
    valid_acc = []
    step = 0
    for e in tf.compat.v1.train.summary_iterator(_first_events_file(rundir, 'plot_valid')):
        for v in e.summary.value:
            if v.tag == 'loss_val':
                if v.simple_value < lowest_loss:
                    lowest_loss = v.simple_value
                    # Index into the training step list: assumes validation
                    # records line up with num_train_steps records.
                    lowest_loss_step = train_steps[step]
                valid_losses.append(v.simple_value)
                step += run_valid_stepping
            if v.tag == 'acc_val':
                valid_acc.append(v.simple_value)

    # Convert training-step counts to epochs.
    ratio = last_epoch/num_train_steps
    train_steps = [s * ratio for s in train_steps]
    lowest_loss_step *= ratio

    fig, ax = plt.subplots(1)
    if title is not None:
        ax.set_title(title)
    ax.set_ylabel('Loss')
    ax.set_xlabel('Epoch')
    # All curves go on the twin axis (pyplot's current axes after twinx()).
    ax2 = ax.twinx()
    ax2.set_ylabel('Mean Absolute Error')
    ax2.grid(True, linestyle=':')
    fig.subplots_adjust(bottom=0.12, top=0.94)
    ax2.set_ylim([0, 1])

    # NOTE(review): validation series share train_steps as x values, so they
    # must have the same length as the training series.
    train_hdl, = ax2.plot(train_steps, train_losses, '-', linewidth=1.5, color='b')
    valid_hdl, = ax2.plot(train_steps, valid_losses, '-', linewidth=1.5, color=[1.0, 0.5, 0.0])
    acc_hdl, = ax2.plot(train_steps, valid_acc, '-', linewidth=1.5, color=[0.5, 0.0, 0.5])
    optim_hdl, = ax2.plot(lowest_loss_step, lowest_loss, 'o', color='g')
    ax2.legend([train_hdl, valid_hdl, acc_hdl, optim_hdl],
               ['Training', 'Validation', 'MAE', 'Best loss'],
               loc='center right', fancybox=True, framealpha=0.5)

    fig.set_size_inches(8, 5.3)
    plt.ion()
    # Save before show() so the written file never depends on window state.
    fig.savefig(os.path.join(root_imgdir, out_name), dpi=100)
    plt.show()
    plt.close(fig)


def plot_roc_curve(correct_labels, prob_pos_class, out_file='/home/rink/test_roc'):
    """Print the ROC AUC, then plot, save and show the ROC curve.

    :param correct_labels: true binary labels, passed to sklearn's
        ``roc_auc_score`` / ``roc_curve``.
    :param prob_pos_class: scores for the positive class, same length.
    :param out_file: path the figure is saved to (dpi=320).
    """
    auc = roc_auc_score(correct_labels, prob_pos_class)
    print('AUC: ', auc)
    fpr, tpr, _ = roc_curve(correct_labels, prob_pos_class)

    fig = plt.figure(figsize=(6, 4))
    plt.plot(fpr, tpr)
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('Receiver operating characteristic')
    # Save before show(): in non-interactive mode show() blocks until the window
    # is closed, and a later plt.savefig would then write a new, empty figure.
    fig.savefig(out_file, dpi=320)
    plt.show()


def plot_image(image, fname, label, cmap='Greys', out_dir='/Users/tomrink'):
    """Render a 2-D image with a colorbar and save it to ``out_dir/fname``.

    The color scale is fixed to vmin=-1.5, vmax=2.5, with colorbar ticks at
    0, 1 and 2.

    :param image: array passed to ``imshow``.
    :param fname: output file name.
    :param label: currently unused (the colorbar label call is disabled).
    :param cmap: matplotlib colormap name or object.
    :param out_dir: directory the image is written to.
    """
    fig, ax = plt.subplots()
    plt.imshow(image, interpolation="nearest", cmap=cmap, vmin=-1.5, vmax=2.5)
    cbar = plt.colorbar(ticks=range(3))
    # Explicit placement, [left, bottom, width, height] in figure fractions.
    # The image axes is positioned only after the colorbar exists.
    cbar.ax.set_position([0.90, 0.05, 0.03, 0.90])
    ax.set_position([0.05, 0.05, 0.83, 0.90])

    fig.set_size_inches([10.0, 6.6], forward=True)
    fig.set_dpi(100)  # NOTE(review): original note says this seems to always be 100 — confirm.
    fig.savefig(os.path.join(out_dir, fname))

    fig.clf()
    plt.close(fig)


def plot_image2(image, cmap='Greys', fname='testimage', out_dir='/Users/tomrink'):
    """Draw a global image on a PlateCarree map with coastlines and save it.

    The image is placed on a [-180, 180] x [-90, 90] extent with origin='lower',
    using a fixed color range of vmin=-1.5, vmax=2.5.

    :param image: array passed to ``ax.imshow``.
    :param cmap: matplotlib colormap name or object.
    :param fname: output file name.
    :param out_dir: directory the image is written to.
    """
    fig = plt.figure(figsize=(12, 8))
    ax = plt.axes(projection=ccrs.PlateCarree())
    img_extent = [-180, 180, -90, 90]
    ax.imshow(image, origin='lower', extent=img_extent,
              transform=ccrs.PlateCarree(), cmap=cmap, vmin=-1.5, vmax=2.5)
    ax.coastlines(resolution='50m', color='blue', linewidth=1)

    fig.savefig(os.path.join(out_dir, fname))

    fig.clf()
    plt.close(fig)


# Example GOES file to retrieve GEOS parameters in MetPy form (CONUS)
exmp_file = '/Users/tomrink/data/OR_ABI-L1b-RadC-M6C14_G16_s20193140811215_e20193140813588_c20193140814070.nc'


def plot_image3(image, cmap='Greys'):
    # Draw the image in the geostationary projection described by exmp_file.
    # NOTE(review): the ``.metpy`` accessor is registered only once MetPy's
    # xarray integration has been imported (e.g. ``import metpy``); this module
    # does not import it — confirm the caller does.
    # NOTE(review): ``cmap`` is unused, and the dataset is never closed.
    exmpl_ds = xr.open_dataset(exmp_file)
    mdat = exmpl_ds.metpy.parse_cf('Rad')
    geos = mdat.metpy.cartopy_crs
    x = mdat.x
    y = mdat.y

    fig = plt.figure(figsize=(15, 12))
    ax = fig.add_subplot(1, 1, 1, projection=geos)
    ax.imshow(image, origin='upper', extent=(x.min(), x.max(), y.min(), y.max()), transform=geos)
    ax.coastlines(resolution='50m', color='black', linewidth=0.25)
    # NOTE(review): ``ccrs.cartopy.feature`` resolves via the package name bound
    # inside cartopy.crs; importing cartopy.feature explicitly would be clearer.
    ax.add_feature(ccrs.cartopy.feature.STATES, linewidth=0.25)
    plt.title('GOES-16', loc='left', fontweight='bold', fontsize=15)