# plot.py: plotting helpers (training curves, ROC curves, images, satellite/icing maps).
from sklearn.metrics import roc_auc_score, roc_curve
import tensorflow as tf
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
import cartopy.crs as ccrs
import xarray as xr
import os
import h5py
from util.util import get_grid_values_all, get_cartopy_crs, homedir
import csv
from matplotlib.colors import LinearSegmentedColormap
from matplotlib.markers import MarkerStyle

from util.setup import home_dir
# import mpl_scatter_density

# ---- Module-level configuration shared by the plotting helpers below ----
logdir = home_dir + '/tf_logs'   # root directory of the TensorBoard runs
root_imgdir = home_dir           # directory that training-curve images are saved to
img_name = 'train_curve.png'     # default training-curve image name

# Number of logged steps credited per validation loss entry.
run_valid_stepping = 1

# Prefer the Tk backend; if it cannot be selected, report why and carry on.
try:
    matplotlib.use('tkagg')
except Exception as err:
    print(err)

# Cloud-phase colours as 0-255 RGB triples (left unscaled here).
cld_phase_clr_list = [[128, 128, 128], [11, 36, 250], [45, 254, 254], [127, 15, 126], [254, 192, 203], [0, 0, 0]]

from matplotlib.colors import ListedColormap

# Four-entry cloud-phase palette, rescaled from 0-255 to the 0-1 floats matplotlib expects.
cld_phase_clr_list_2 = [
    [channel / 255 for channel in rgb]
    for rgb in ([128, 128, 128], [254, 192, 203], [11, 36, 250], [127, 15, 126])
]
cld_phase_cmap = ListedColormap(cld_phase_clr_list_2, 'CldPhase')

# Two-entry cloud-mask palette, rescaled the same way.
cld_mask_list = [[channel / 255 for channel in rgb] for rgb in ([0, 220, 0], [220, 220, 220])]
cld_mask_cmap = ListedColormap(cld_mask_list, 'CldMask')


# # Read colors from CSV file
# colors_file = homedir + 'ICE.ct'
# colors = []
# with open(colors_file, 'r') as csvfile:
#     csv_reader = csv.DictReader(csvfile)
#     for row in csv_reader:
#         colors.append((float(row['R']) / 255, float(row['G']) / 255, float(row['B']) / 255))
# # Define a custom colormap
# cmap_name = 'icing_colormap'
# icing_cmap = LinearSegmentedColormap.from_list(cmap_name, colors, N=256)


def _first_event_file(rundir, subdir):
    """Return the path of the first directory entry in ``rundir/subdir``.

    NOTE(review): ``os.listdir`` order is arbitrary, so this assumes the
    directory holds a single TensorBoard events file.
    """
    event_dir = os.path.join(rundir, subdir)
    return os.path.join(event_dir, os.listdir(event_dir)[0])


def _read_train_events(path):
    """Collect training summaries from the events file at ``path``.

    Returns:
        (train_steps, train_losses, last_epoch, num_train_steps): every
        'num_train_steps' and 'loss_trn' value in file order, plus the last
        'num_epochs' and 'num_train_steps' values seen (0 when absent).
    """
    train_steps = []
    train_losses = []
    last_epoch = 0
    num_train_steps = 0
    for event in tf.compat.v1.train.summary_iterator(path):
        for value in event.summary.value:
            if value.tag == 'loss_trn':
                train_losses.append(value.simple_value)
            elif value.tag == 'num_epochs':
                last_epoch = value.simple_value
            elif value.tag == 'num_train_steps':
                train_steps.append(value.simple_value)
                num_train_steps = value.simple_value
    return train_steps, train_losses, last_epoch, num_train_steps


def _read_valid_events(path, train_steps):
    """Collect validation summaries from the events file at ``path``.

    The k-th (0-based) 'loss_val' entry is paired with
    ``train_steps[k * run_valid_stepping]`` to locate the lowest validation loss.

    Returns:
        (valid_losses, valid_acc, lowest_loss, lowest_loss_step)
    """
    valid_losses = []
    valid_acc = []
    # Start above any possible loss so that losses of any magnitude can be the best.
    lowest_loss = float('inf')
    lowest_loss_step = 0
    step = 0
    for event in tf.compat.v1.train.summary_iterator(path):
        for value in event.summary.value:
            if value.tag == 'loss_val':
                if value.simple_value < lowest_loss:
                    lowest_loss = value.simple_value
                    lowest_loss_step = train_steps[step]
                valid_losses.append(value.simple_value)
                step += run_valid_stepping
            elif value.tag == 'acc_val':
                valid_acc.append(value.simple_value)
    return valid_losses, valid_acc, lowest_loss, lowest_loss_step


def plot_train_curve(rundir=None, title=None):
    """Plot training/validation loss and validation MAE for a TensorBoard run.

    Reads the first events file under ``<run>/plot_train`` and
    ``<run>/plot_valid``, converts step counts to epochs, shows the figure
    and saves it under ``root_imgdir``.

    Args:
        rundir: run directory relative to ``logdir``. The image is named
            ``<basename of rundir>.png``. If None, ``logdir`` itself is read
            and the image is saved as the module-level ``img_name``.
        title: optional axes title.

    Note:
        All curves are drawn on the right-hand (MAE) twin axis, whose y range
        is fixed to [0, 1]. The left 'Loss' axis carries no data. Validation
        entries are assumed to line up one-to-one with the logged training
        steps. pyplot interactive mode is left on (``plt.ion()``).
    """
    if rundir is not None:
        # normpath drops a trailing separator, so 'run1/' still yields 'run1.png'.
        out_name = os.path.basename(os.path.normpath(rundir)) + '.png'
        rundir = os.path.join(logdir, rundir)
    else:
        # Use the module-level default name. This must not be assigned to the
        # global name inside this function, or it becomes an unbound local.
        out_name = img_name
        rundir = logdir

    train_steps, train_losses, last_epoch, num_train_steps = _read_train_events(
        _first_event_file(rundir, 'plot_train'))
    valid_losses, valid_acc, lowest_loss, lowest_loss_step = _read_valid_events(
        _first_event_file(rundir, 'plot_valid'), train_steps)

    # Rescale step counts to epochs. Without a logged step total, keep raw
    # steps rather than dividing by zero.
    ratio = last_epoch / num_train_steps if num_train_steps else 1.0
    epochs = [s * ratio for s in train_steps]
    lowest_loss_epoch = lowest_loss_step * ratio

    fig, ax = plt.subplots(1)
    if title is not None:
        ax.set_title(title)
    ax.set_ylabel('Loss')
    ax.set_xlabel('Epoch')
    # twinx() makes ax2 the current axes. The grid, y-limit, curves and legend
    # all go on ax2, so draw on it explicitly.
    ax2 = ax.twinx()
    ax2.set_ylabel('Mean Absolute Error')
    ax2.grid(True, linestyle=':')
    fig.subplots_adjust(bottom=0.12, top=0.94)
    ax2.set_ylim([0, 1])

    train_hdl, = ax2.plot(epochs, train_losses, '-', linewidth=1.5, color='b')
    valid_hdl, = ax2.plot(epochs, valid_losses, '-', linewidth=1.5, color=[1.0, 0.5, 0.0])
    acc_hdl, = ax2.plot(epochs, valid_acc, '-', linewidth=1.5, color=[0.5, 0.0, 0.5])
    optim_hdl, = ax2.plot(lowest_loss_epoch, lowest_loss, 'o', color='g')

    ax2.legend([train_hdl, valid_hdl, acc_hdl, optim_hdl], ['Training', 'Validation', 'MAE', 'Best loss'],
               loc='center right', fancybox=True, framealpha=0.5)

    plt.ion()
    fig.set_size_inches(8, 5.3)
    plt.show()
    fig.savefig(os.path.join(root_imgdir, out_name), dpi=100)
    plt.close(fig)


def plot_roc_curve(correct_labels, prob_pos_class):
    """Plot the ROC curve of a binary classifier and save it as ``<homedir>/test_roc``.

    Args:
        correct_labels: true binary labels.
        prob_pos_class: scores or probabilities for the positive class.

    Returns:
        The ROC AUC. It is also printed.
    """
    auc = roc_auc_score(correct_labels, prob_pos_class)
    print('AUC: ', auc)
    fpr, tpr, _ = roc_curve(correct_labels, prob_pos_class)

    fig = plt.figure(figsize=(6, 4))
    plt.plot(fpr, tpr)

    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.0])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('Receiver operating characteristic')

    # Save before show(). Once a blocking show() window is closed the figure is
    # gone, and a later savefig would write an empty figure. os.path.join
    # matches the path handling used by the other helpers in this module.
    fig.savefig(os.path.join(homedir, 'test_roc'), dpi=320)
    plt.show()
    return auc


def plot_image(image, cmap='Greys', vmin=None, vmax=None, fname=None, title=None, do_color_bar=True):
    """Display ``image`` with nearest-neighbour ``imshow``, optionally with a colour bar.

    Args:
        image: array accepted by ``imshow``.
        cmap, vmin, vmax: colour mapping passed to ``imshow``.
        fname: if given, the figure is saved to ``os.path.join(homedir, fname)``.
            Otherwise ``Figure.show()`` is called.
        title: bold title. Only its basename is displayed.
        do_color_bar: add a colour bar on the right and shrink the image axes
            to make room for it.

    The figure is left open. It is neither cleared nor closed.
    """
    fig, axes = plt.subplots()

    img = plt.imshow(image, interpolation="nearest", cmap=cmap, vmin=vmin, vmax=vmax)
    if title is not None:
        axes.set_title(os.path.basename(title), fontsize=10, fontweight='bold')

    if do_color_bar:
        colorbar = fig.colorbar(img)
        # Positions are figure fractions: [left, bottom, width, height].
        colorbar.ax.set_position([0.90, 0.05, 0.03, 0.90])
        # Place the image axes only after the colour bar exists, because
        # adding the colour bar re-lays-out the parent axes.
        axes.set_position([0.05, 0.05, 0.83, 0.90])

    fig.set_size_inches([10.0, 6.6], forward=True)
    fig.set_dpi(100)  # Noted in the original code as seemingly not applied.

    if fname is None:
        fig.show()
    else:
        fig.savefig(os.path.join(homedir, fname))


def multi_plot(image_s, title_s, cmap, vmin=None, vmax=None, do_color_bar=True):
    """Show the first two images of ``image_s`` side by side in a 1x2 grid.

    ``image_s`` and ``title_s`` must each have at least two entries. Only the
    first two are drawn, and each panel shares the same cmap/vmin/vmax. The
    figure is left open and is not returned.
    """
    fig = plt.figure(figsize=(12, 8))
    for panel in range(2):
        fig.add_subplot(1, 2, panel + 1)
        plt.imshow(image_s[panel], cmap=cmap, vmin=vmin, vmax=vmax)
        plt.title(title_s[panel], fontweight='bold')
        if do_color_bar:
            plt.colorbar()


def plot_image2(image, cmap='Greys', title='BOEING Icing Reports 2018-2020',
                image_path='/Users/tomrink/testimage', vmin=0.1, vmax=1):
    """Draw a global lat/lon image on a PlateCarree map and save it.

    Args:
        image: 2-D array covering the whole globe (extent [-180, 180] x
            [-90, 90]). Row 0 is drawn at the southern edge (``origin='lower'``).
        cmap: colormap for the image.
        title: left-aligned bold title.
        image_path: output file. The default is the location previously
            hard-coded in this function.
        vmin, vmax: colour-scale limits. They default to the previously fixed
            values 0.1 and 1.

    The figure is cleared and closed after saving.
    """
    fig = plt.figure(figsize=(12, 8))
    ax = fig.add_subplot(1, 1, 1, projection=ccrs.PlateCarree())
    img_extent = [-180, 180, -90, 90]
    ax.imshow(image, origin='lower', extent=img_extent, transform=ccrs.PlateCarree(),
              cmap=cmap, vmin=vmin, vmax=vmax)
    ax.coastlines(resolution='50m', color='blue', linewidth=0.5)

    ax.set_title(title, loc='left', fontweight='bold', fontsize=16)

    fig.savefig(image_path)

    # Release the figure so repeated calls do not accumulate open figures.
    fig.clf()
    plt.close(fig)


def plot_image3(image, bt, cmap='Greys', image_path='/Users/tomrink/image'):
    """Overlay ``image`` on a GOES-16 CONUS map and save the figure.

    Args:
        image: 2-D array on the grid described by
            ``get_cartopy_crs('GOES16', 'CONUS')``. It is drawn with the
            'Purples' colormap over [0.52, 1.0].
        bt: currently unused. The background draw that used it is disabled.
            It is kept for interface compatibility.
        cmap: currently unused, for the same reason as ``bt``.
        image_path: output file. The default is the location previously
            hard-coded in this function.

    The figure is closed after saving.
    """
    geos, _, xmin, xmax, _, ymin, ymax = get_cartopy_crs('GOES16', 'CONUS')

    fig = plt.figure(figsize=(15, 12))
    ax = fig.add_subplot(1, 1, 1, projection=geos)
    ax.set_extent([-105, -70, 15, 50], crs=ccrs.PlateCarree())
    # Disabled alternatives (kept for reference):
    #ax.imshow(bt, cmap=cmap, origin='upper', extent=(xmin, xmax, ymin, ymax), transform=geos)
    #ax.imshow(image, cmap='YlGnBu', origin='upper', extent=(xmin, xmax, ymin, ymax), vmin=0.6, vmax=1.0, transform=geos)
    ax.imshow(image, cmap='Purples', origin='upper', extent=(xmin, xmax, ymin, ymax),
              vmin=0.52, vmax=1.0, transform=geos)
    ax.coastlines(resolution='50m', color='blue', linewidth=0.25)
    ax.add_feature(ccrs.cartopy.feature.STATES, linewidth=0.25)
    ax.set_title('GOES-16', loc='left', fontweight='bold', fontsize=15)
    fig.savefig(image_path)
    # Close so repeated calls do not leak 15x12-inch figures.
    plt.close(fig)


def _scatter_validation_reports(ax, lons_vld, lats_vld, intensity):
    """Mark validation icing reports (lon/lat, PlateCarree) on ``ax``.

    Without ``intensity``, every report gets a black dot on a white halo.
    With it, reports where ``intensity == 1`` get dots and reports where
    ``intensity > 1`` get triangles. Both groups also get a large white ring.
    Reports with any other intensity value (e.g. -1 or 0) are not drawn.
    """
    pc = ccrs.PlateCarree()
    if intensity is None:
        ax.scatter(lons_vld, lats_vld, s=500.0, marker='o', color='white', transform=pc)
        ax.scatter(lons_vld, lats_vld, s=200.0, marker='o', color='black', transform=pc)
        return
    # Boolean-mask indexing: presumably numpy arrays of matching length. TODO confirm.
    for mask, marker in ((intensity == 1, 'o'), (intensity > 1, '^')):
        lons = lons_vld[mask]
        lats = lats_vld[mask]
        ax.scatter(lons, lats, s=500.0, marker=marker, color='white', transform=pc)
        ax.scatter(lons, lats, s=200.0, marker=marker, color='black', transform=pc)
        ax.scatter(lons, lats, s=5000.0, marker='o', facecolors='none', edgecolors='white', linewidths=2.0,
                   transform=pc)


def make_icing_image(h5f, probs, ice_lons, ice_lats, clvrx_str_time, satellite, domain, ice_lons_vld=None, ice_lats_vld=None,
                     ice_intensity=None, imagename='icing', extent=(-110, -60, 10, 55), prob_thresh=0.50):
    """Render an icing-probability map and save it as ``<homedir>/<imagename>_<clvrx_str_time>``.

    Args:
        h5f: source for the background field ``'temp_10_4um_nom'``, read with
            get_grid_values_all (assumed an open HDF5 file). None means no background.
        probs: 2-D probability grid on the satellite grid, or None. Values are
            coloured from ``prob_thresh`` to 1.0.
        ice_lons, ice_lats: locations drawn as small blue dots, or None.
        clvrx_str_time: time string used in the title and output file name.
        satellite, domain: passed to get_cartopy_crs. For 'H08'/'H09',
            ``extent`` is interpreted in the satellite projection's
            coordinates. Otherwise it is interpreted as lon/lat.
        ice_lons_vld, ice_lats_vld: validation report locations (lon/lat).
        ice_intensity: optional per-report intensity used to pick markers.
        imagename: output file-name prefix.
        extent: map extent (x0, x1, y0, y1), or None to keep the default view.
        prob_thresh: lower bound of the colour scale. It is also shown in the title.

    The figure is closed after saving.
    """
    # keep for reference
    # mrkr_styl = MarkerStyle('o', fillstyle=None)
    geos, _, xmin, xmax, _, ymin, ymax = get_cartopy_crs(satellite, domain)
    bg_image = None
    if h5f is not None:
        bg_image = get_grid_values_all(h5f, 'temp_10_4um_nom')

    fig = plt.figure(figsize=(48, 30))
    ax = fig.add_subplot(1, 1, 1, projection=geos)
    if extent is not None:
        extent_crs = geos if satellite in ('H08', 'H09') else ccrs.PlateCarree()
        ax.set_extent(extent, crs=extent_crs)
    if bg_image is not None:
        ax.imshow(bg_image, origin='upper', extent=(xmin, xmax, ymin, ymax), cmap='Greys', transform=geos)

    p_im = None
    if probs is not None:
        # Alternatives tried: cmap='Purples', cmap=icing_cmap.
        p_im = ax.imshow(probs, cmap='jet', origin='upper', extent=(xmin, xmax, ymin, ymax),
                         vmin=prob_thresh, vmax=1.0, alpha=0.8, transform=geos)
    ax.coastlines(resolution='50m', color='green', linewidth=1.50)
    ax.add_feature(ccrs.cartopy.feature.STATES, linewidth=1.50, edgecolor='green')
    ax.add_feature(ccrs.cartopy.feature.BORDERS, linewidth=1.50, edgecolor='green')

    if ice_lons is not None:
        ax.scatter(ice_lons, ice_lats, s=40.0, marker='o', color='blue', transform=ccrs.PlateCarree())
    if ice_lats_vld is not None:
        _scatter_validation_reports(ax, ice_lons_vld, ice_lats_vld, ice_intensity)

    # The title reflects the actual threshold (0.50 -> '>50%').
    title = f'{satellite}, {clvrx_str_time}  >{prob_thresh:.0%} Probability Icing'
    ax.set_title(title, loc='left', fontweight='bold', fontsize=24)
    # A colour bar is only possible when a probability layer was drawn.
    if p_im is not None:
        cbar = fig.colorbar(p_im, ax=ax)
        cbar.set_label(label='Icing Probability', fontsize=30)
        cbar.ax.tick_params(labelsize=30)

    fig.savefig(os.path.join(homedir, imagename + '_' + clvrx_str_time))
    # The 48x30-inch figure is large; close it so batch runs do not leak memory.
    plt.close(fig)


def plot_grid2d(grid_2d, x_map_2d, y_map_2d, proj, vmin=None, vmax=None, cmap='jet'):
    """Plot a 2-D grid with pcolormesh on a cartopy map and show it.

    Args:
        grid_2d: values to colour.
        x_map_2d, y_map_2d: cell coordinates for pcolormesh. No ``transform``
            is passed, so presumably they are in ``proj`` coordinates. TODO confirm.
        proj: cartopy CRS used as the axes projection.
        vmin, vmax, cmap: colour mapping.
    """
    fig = plt.figure(figsize=(10, 10))
    ax = fig.add_subplot(1, 1, 1, projection=proj)

    mesh = ax.pcolormesh(x_map_2d, y_map_2d, grid_2d, vmin=vmin, vmax=vmax, cmap=cmap, shading='auto')

    ax.coastlines(resolution='50m', color='green', linewidth=1.50)

    # Longitude/latitude gridlines with labels.
    gl = ax.gridlines(crs=ccrs.PlateCarree(), draw_labels=True,
                      linewidth=2, color='gray', alpha=0.5, linestyle='--')

    # top_labels/right_labels replace the xlabels_top/ylabels_right names
    # deprecated in cartopy 0.18. ypadding is the Gridliner attribute for
    # y-label padding (there is no ylabel_padding attribute).
    gl.top_labels = False
    gl.right_labels = False
    gl.ylabel_style = {'rotation': 90}
    gl.ypadding = 15

    cbar = fig.colorbar(mesh, ax=ax, orientation="horizontal", pad=0.05)
    cbar.set_label('Value', size=15)

    plt.show()


def make_time_domain_hist(values, edges):
    """Show a histogram of ``values`` binned by ``edges`` (both in epoch time).

    The x axis shows the raw epoch numbers. No date formatting is applied.
    """
    # Date-axis formatting that could be enabled later:
    #   mdates.epoch2num(values); ax.xaxis.set_major_locator(mdates.MonthLocator())
    #   ax.xaxis.set_major_formatter(mdates.DateFormatter('%y-%m-%d_%h')); fig.autofmt_xdate()
    _, hist_ax = plt.subplots(1, 1)
    hist_ax.hist(values, bins=edges, ec='black')
    plt.show()


def scatter_density(x, y, color=None, fname=None, fig=None, ax=None, x_rng=None, y_rng=None, dpi=72):
    """Draw a density scatter plot using the mpl-scatter-density extension.

    Args:
        x, y: point coordinates.
        color: optional colour passed to ``scatter_density``.
        fname: if given, the figure is saved to this path.
        fig, ax: existing figure/axes to draw on. When ``fig`` is None, a new
            figure and 'scatter_density' axes are created and any ``ax``
            passed is ignored. When ``fig`` is given without ``ax``, a
            'scatter_density' subplot is added to it.
        x_rng, y_rng: optional (min, max) axis limits. Each is applied
            independently.
        dpi: density-raster resolution passed to ``scatter_density``.

    Returns:
        (fig, ax) so that further layers can be drawn on the same axes.
    """
    if fig is None:
        fig = plt.figure()
        ax = None
    if ax is None:
        # Importing mpl_scatter_density registers the 'scatter_density'
        # projection. Import it here so this works even when the
        # module-level import is disabled.
        import mpl_scatter_density  # noqa: F401
        ax = fig.add_subplot(1, 1, 1, projection='scatter_density')

    density_kwargs = {'dpi': dpi}
    if color is not None:
        density_kwargs['color'] = color
    ax.scatter_density(x, y, **density_kwargs)

    if x_rng is not None:
        ax.set_xlim(x_rng[0], x_rng[1])
    if y_rng is not None:
        ax.set_ylim(y_rng[0], y_rng[1])

    if fname is not None:
        fig.savefig(fname)

    return fig, ax