Skip to content
Snippets Groups Projects
Commit b52713d8 authored by tomrink's avatar tomrink
Browse files

snapshot...

parent 8291c8bf
No related branches found
No related tags found
No related merge requests found
import tensorflow as tf
from util.plot_cm import confusion_matrix_values
from util.augment import augment_image
from util.setup_cloud_products import logdir, modeldir, now, ancillary_path
from util.util import EarlyStop, normalize, denormalize, scale, scale2, get_grid_values_all, make_tf_callable_generator
import glob
import os, datetime
import numpy as np
import pickle
import h5py
import xarray as xr
import gc
import time
# ---- tf.data / runtime settings ------------------------------------------
AUTOTUNE = tf.data.AUTOTUNE
LOG_DEVICE_PLACEMENT = False

# Number of file indices grouped into one pipeline element, and the size of
# the shuffle buffer applied to those elements.
PROC_BATCH_SIZE = 4
PROC_BATCH_BUFFER_SIZE = 5000

# ---- Classification target -----------------------------------------------
# NumClasses = 7
NumClasses = 5
# The binary case uses a single output unit; otherwise one unit per class.
NumLogits = 1 if NumClasses == 2 else NumClasses

# ---- Training schedule ---------------------------------------------------
BATCH_SIZE = 128
NUM_EPOCHS = 80

TRACK_MOVING_AVERAGE = False
EARLY_STOP = True
PATIENCE = 7

# ---- Optional input perturbation / augmentation switches -----------------
NOISE_TRAINING = False
NOISE_STDDEV = 0.01

DO_AUGMENT = True

DO_SMOOTH = False
SIGMA = 1.0
DO_ZERO_OUT = False

# File name handed to Dataset.cache() for the training dataset.
# CACHE_FILE = '/scratch/long/rink/cld_opd_abi_128x128_cache'
CACHE_FILE = ''
# Setup scaling parameters dictionary.
#
# Two pickled dictionaries (L2 and L1b statistics) are merged into
# ``mean_std_dct``. The L1b entries are applied first and the L2 entries
# second, so on a duplicate key the L2 value wins.
#
# NOTE(review): pickle.load executes arbitrary code from the file; these
# files are assumed to be trusted, locally produced ancillary data.
mean_std_dct = {}

mean_std_file = ancillary_path+'mean_std_lo_hi_l2.pkl'
# ``with`` guarantees the handle is closed even if unpickling raises.
with open(mean_std_file, 'rb') as f:
    mean_std_dct_l2 = pickle.load(f)

mean_std_file = ancillary_path+'mean_std_lo_hi_l1b.pkl'
with open(mean_std_file, 'rb') as f:
    mean_std_dct_l1b = pickle.load(f)

mean_std_dct.update(mean_std_dct_l1b)
mean_std_dct.update(mean_std_dct_l2)
IMG_DEPTH = 1

# Name of the field used both as the training label and as an input channel.
label_param = 'cloud_probability'

# data_params_half = ['temp_11_0um_nom']
data_params_half = ['temp_11_0um_nom', 'refl_0_65um_nom']
sub_fields = ['refl_submin_ch01', 'refl_submax_ch01', 'refl_substddev_ch01']
data_params_full = ['refl_0_65um_nom']

# Channel order of the two kinds of input arrays; the label is last in both.
params = data_params_half + sub_fields + [label_param]
params_i = data_params_half + [label_param]

label_idx_i = params_i.index(label_param)
label_idx = params.index(label_param)

print('data_params_half: ', data_params_half)
print('data_params_full: ', data_params_full)
print('label_param: ', label_param)

KERNEL_SIZE = 3
X_LEN = Y_LEN = 128

# The crop windows below are only defined for a 3x3 kernel.
if KERNEL_SIZE == 3:
    # Window on the coarse grid: a quarter of the tile plus 2 extra cells.
    slc_x = slice(0, X_LEN // 4 + 2)
    slc_y = slice(0, Y_LEN // 4 + 2)
    # Window on the fine grid: a full tile, offset by 4 pixels.
    x_64 = slice(4, X_LEN + 4)
    y_64 = slice(4, Y_LEN + 4)
# ----------------------------------------
def build_residual_conv2d_block(conv, num_filters, block_name, activation=tf.nn.relu, padding='SAME',
                                kernel_initializer='he_uniform', scale=None, kernel_size=3,
                                do_drop_out=True, drop_rate=0.5, do_batch_norm=True):
    """Append one residual block (two convolutions plus identity skip) to `conv`.

    The branch is: Conv2D with `activation` -> Conv2D without activation ->
    optional multiplication by `scale` -> optional Dropout -> optional
    BatchNormalization. The branch output is added to the block input.

    Args:
        conv: input tensor; it is added to the branch output.
        num_filters: number of filters of both convolutions.
        block_name: name of the ``tf.name_scope`` and prefix of the shape print.
        activation: activation of the first convolution only.
        padding: padding mode of both convolutions.
        kernel_initializer: initializer of the first convolution only; the
            second convolution uses the Keras default.
        scale: if not None, the branch is multiplied by this value.
        kernel_size: kernel size of both convolutions.
        do_drop_out: whether to apply Dropout(drop_rate) to the branch.
        drop_rate: dropout rate.
        do_batch_norm: whether to apply BatchNormalization to the branch.

    Returns:
        The tensor ``conv + branch``.
    """
    with tf.name_scope(block_name):
        first = tf.keras.layers.Conv2D(num_filters, kernel_size=kernel_size, padding=padding,
                                       kernel_initializer=kernel_initializer, activation=activation)
        second = tf.keras.layers.Conv2D(num_filters, kernel_size=kernel_size, padding=padding,
                                        activation=None)
        branch = second(first(conv))

        if scale is not None:
            branch = tf.keras.layers.Lambda(lambda x: x * scale)(branch)

        if do_drop_out:
            branch = tf.keras.layers.Dropout(drop_rate)(branch)

        if do_batch_norm:
            branch = tf.keras.layers.BatchNormalization()(branch)

        conv = conv + branch
        print(block_name+':', conv.shape)

    return conv
def upsample_mean(grd):
    """Write every 4th sample of `grd` into a zero array of twice the size.

    A zero array of shape (batch, 2*ylen, 2*xlen) is created and the values
    ``grd[:, ::4, ::4]`` are assigned to the positions whose row and column
    index are 0 or 1 modulo 4. All other positions stay zero.

    NOTE(review): the source and target slices only have broadcast-compatible
    shapes for very small grids; for larger grids the assignment raises
    ValueError. The intended stride may be different - confirm with callers.

    Args:
        grd: 3-D array (batch, y, x).

    Returns:
        float64 array of shape (batch, 2*y, 2*x).
    """
    n_batch, n_rows, n_cols = grd.shape
    out = np.zeros((n_batch, n_rows * 2, n_cols * 2))

    coarse = grd[:, ::4, ::4]
    # Same four target offsets as the explicit assignments: (0,0) (1,0) (0,1) (1,1).
    for col_off in (0, 1):
        for row_off in (0, 1):
            out[:, row_off::4, col_off::4] = coarse

    return out
def get_grid_cell_mean(grd_k):
    """Return the NaN-ignoring mean of every non-overlapping 4x4 cell.

    Args:
        grd_k: 3-D array (batch, y, x); y and x must be multiples of 4 so
            that the 16 strided sub-grids have equal shape.

    Returns:
        Array of shape (batch, y/4, x/4). Cells whose 16 values are all NaN
        are returned as 0 (previously the NaN replacement was computed but
        its result was discarded, so NaN leaked through).
    """
    # The 16 strided views each pick one position of every 4x4 cell.
    cells = [grd_k[:, row::4, col::4] for col in range(4) for row in range(4)]
    mean = np.nanmean(cells, axis=0)
    # np.where returns a new array; the result must be kept.
    mean = np.where(np.isnan(mean), 0, mean)
    return mean
def get_min_max_std(grd_k):
    """Return NaN-ignoring min, max, std and mean of every 4x4 cell.

    Args:
        grd_k: 3-D array (batch, y, x); y and x must be multiples of 4 so
            that the 16 strided sub-grids have equal shape.

    Returns:
        Tuple ``(lo, hi, std, avg)`` of arrays with shape (batch, y/4, x/4).
        Entries for cells whose 16 values are all NaN are returned as 0
        (previously the NaN replacement was computed but discarded).
    """
    # Build the 16 strided views once and reuse them for all statistics.
    cells = [grd_k[:, row::4, col::4] for col in range(4) for row in range(4)]

    lo = np.nanmin(cells, axis=0)
    hi = np.nanmax(cells, axis=0)
    std = np.nanstd(cells, axis=0)
    avg = np.nanmean(cells, axis=0)

    # np.where returns a new array; the results must be kept.
    lo = np.where(np.isnan(lo), 0, lo)
    hi = np.where(np.isnan(hi), 0, hi)
    std = np.where(np.isnan(std), 0, std)
    avg = np.where(np.isnan(avg), 0, avg)

    return lo, hi, std, avg
def get_label_data_5cat(grd_k):
    """Turn a fine grid into a 5-class label per 4x4 cell.

    NaNs are treated as 0, values are thresholded at 0.5, and the number of
    values >= 0.5 in every non-overlapping 4x4 cell (0..16) is binned:

        count 0      -> 0
        count 1..5   -> 1
        count 6..10  -> 2
        count 11..15 -> 3
        count 16     -> 4

    Args:
        grd_k: 3-D array (batch, y, x).

    Returns:
        Integer array (batch, y/4, x/4) of class indices.
    """
    flags = np.where(np.isnan(grd_k), 0, grd_k)
    flags = np.where(flags < 0.5, 0, 1)

    # Number of set flags in each 4x4 cell.
    count = sum(flags[:, row::4, col::4] for col in range(4) for row in range(4))

    # Lower bounds of classes 1..4.
    return np.digitize(count, [1, 6, 11, 16]).astype(count.dtype)
def get_label_data_7cat(grd_k):
    """Turn a fine grid into a 7-class label per 4x4 cell.

    NaNs are treated as 0, values are thresholded at 0.5, and the number of
    values >= 0.5 in every non-overlapping 4x4 cell (0..16) is binned:

        count 0       -> 0   (no value in the cell reaches 0.5)
        count 1..3    -> 1
        count 4..6    -> 2
        count 7..9    -> 3
        count 10..12  -> 4
        count 13..15  -> 5
        count 16      -> 6   (every value in the cell reaches 0.5)

    Args:
        grd_k: 3-D array (batch, y, x).

    Returns:
        Integer array (batch, y/4, x/4) of class indices.
    """
    flags = np.where(np.isnan(grd_k), 0, grd_k)
    flags = np.where(flags < 0.5, 0, 1)

    # Number of set flags in each 4x4 cell.
    count = sum(flags[:, row::4, col::4] for col in range(4) for row in range(4))

    # Lower bounds of classes 1..6.
    return np.digitize(count, [1, 4, 7, 10, 13, 16]).astype(count.dtype)
def get_label_data(grd_k):
    """Turn a fine grid into a 3-class label per 4x4 cell.

    NaNs are treated as 0, values are thresholded at 0.5, and the number of
    values >= 0.5 in every non-overlapping 4x4 cell (0..16) is binned:

        count 0..2   -> 0
        count 3..13  -> 1
        count 14..16 -> 2

    Args:
        grd_k: 3-D array (batch, y, x).

    Returns:
        Integer array (batch, y/4, x/4) of class indices.
    """
    flags = np.where(np.isnan(grd_k), 0, grd_k)
    flags = np.where(flags < 0.5, 0, 1)

    # Number of set flags in each 4x4 cell.
    count = sum(flags[:, row::4, col::4] for col in range(4) for row in range(4))

    # Lower bounds of classes 1 and 2.
    return np.digitize(count, [3, 14]).astype(count.dtype)
class SRCNN:
    """Residual CNN that predicts a cloud-fraction class per grid cell.

    The class bundles the tf.data input pipeline, the Keras model, the
    training loop with checkpointing, and several inference entry points
    that read gridded fields from an HDF5 file.

    Typical use:
        training:   ``SRCNN().run(directory)``
        validation: ``SRCNN().run_restore(directory, ckpt_dir)``
        inference:  ``nn = SRCNN(); nn.setup_inference(ckpt_dir);
                     nn.run_inference(in_file, out_file)``
    """

    def __init__(self):
        """Set all state to empty and create the Keras input placeholder."""
        # In-memory data and tf.data datasets
        self.train_data = None
        self.train_label = None
        self.test_data = None
        self.test_label = None
        self.test_data_denorm = None

        self.train_dataset = None
        self.inner_train_dataset = None
        self.test_dataset = None
        self.eval_dataset = None
        self.X_img = None
        self.X_prof = None
        self.X_u = None
        self.X_v = None
        self.X_sfc = None
        self.inputs = []
        self.y = None
        self.handle = None
        self.inner_handle = None
        self.in_mem_batch = None

        self.h5f_l1b_trn = None
        self.h5f_l1b_tst = None
        self.h5f_l2_trn = None
        self.h5f_l2_tst = None

        # Model outputs and prediction bookkeeping
        self.logits = None

        self.predict_data = None
        self.predict_dataset = None
        self.mean_list = None
        self.std_list = None

        self.training_op = None
        self.correct = None
        self.accuracy = None
        self.loss = None
        self.pred_class = None
        self.variable_averages = None

        self.global_step = None

        # TensorBoard summary writers (created in do_training)
        self.writer_train = None
        self.writer_valid = None
        self.writer_train_valid_loss = None

        self.OUT_OF_RANGE = False

        self.model = None
        self.optimizer = None
        self.ema = None

        # Keras metric objects (created in build_evaluation)
        self.train_loss = None
        self.train_accuracy = None
        self.test_loss = None
        self.test_accuracy = None
        self.test_auc = None
        self.test_recall = None
        self.test_precision = None
        self.test_confusion_matrix = None
        self.test_true_pos = None
        self.test_true_neg = None
        self.test_false_pos = None
        self.test_false_neg = None

        # Accumulators filled by predict()
        self.test_labels = []
        self.test_preds = []
        self.test_probs = None
        self.test_input = []

        self.learningRateSchedule = None
        self.num_data_samples = None
        self.initial_learning_rate = None

        self.data_dct = None
        self.train_data_files = None
        self.train_label_files = None
        self.test_data_files = None
        self.test_label_files = None

        # Five input channels; spatial size is left open.
        self.n_chans = 5

        self.X_img = tf.keras.Input(shape=(None, None, self.n_chans))

        self.inputs.append(self.X_img)

        tf.debugging.set_log_device_placement(LOG_DEVICE_PLACEMENT)

    def get_in_mem_data_batch(self, idxs, is_training):
        """Load the files selected by `idxs` and build network inputs and labels.

        Args:
            idxs: iterable of integer indices into the data/label file lists.
            is_training: use the training file lists if True, else the test
                file lists.

        Returns:
            ``(data, label)``, both float32.
            ``data`` stacks these channels on axis 3: every parameter of
            ``data_params_half`` (normalized), reflectance range (normalized
            sub-max minus normalized sub-min), reflectance sub-stddev scaled
            with ``scale2(., 0.0, 20.0)``, and the 4x4 cell mean of the label
            parameter.
            ``label`` holds the class index per cell with a trailing axis of
            length 1.
        """
        if is_training:
            data_files = self.train_data_files
            label_files = self.train_label_files
        else:
            data_files = self.test_data_files
            label_files = self.test_label_files

        data_s = []
        label_s = []
        for k in idxs:
            data_s.append(np.load(data_files[k]))
            label_s.append(np.load(label_files[k]))
        input_data = np.concatenate(data_s)
        input_label = np.concatenate(label_s)

        data_norm = []
        for param in data_params_half:
            idx = params.index(param)
            tmp = input_data[:, idx, :, :]
            tmp = tmp[:, slc_y, slc_x]
            tmp = normalize(tmp, param, mean_std_dct)
            data_norm.append(tmp)

        # Sub-pixel min/max are normalized with the 0.65um reflectance stats.
        rlo = input_data[:, params.index('refl_submin_ch01'), :, :]
        rlo = rlo[:, slc_y, slc_x]
        rlo = normalize(rlo, 'refl_0_65um_nom', mean_std_dct)

        rhi = input_data[:, params.index('refl_submax_ch01'), :, :]
        rhi = rhi[:, slc_y, slc_x]
        rhi = normalize(rhi, 'refl_0_65um_nom', mean_std_dct)

        refl_rng = rhi - rlo
        data_norm.append(refl_rng)

        rstd = input_data[:, params.index('refl_substddev_ch01'), :, :]
        rstd = rstd[:, slc_y, slc_x]
        rstd = scale2(rstd, 0.0, 20.0)
        data_norm.append(rstd)

        # Coarse (4x4 mean) version of the label field is itself an input.
        tmp = input_label[:, label_idx_i, :, :]
        tmp = get_grid_cell_mean(tmp)
        tmp = tmp[:, slc_y, slc_x]
        data_norm.append(tmp)
        # ---------
        data = np.stack(data_norm, axis=3)
        data = data.astype(np.float32)
        # -----------------------------------------------------
        # -----------------------------------------------------
        label = input_label[:, label_idx_i, :, :]
        label = label[:, y_64, x_64]
        if NumClasses == 5:
            label = get_label_data_5cat(label)
        elif NumClasses == 7:
            label = get_label_data_7cat(label)
        else:
            label = get_label_data(label)

        label = np.where(np.isnan(label), 0, label)
        label = np.expand_dims(label, axis=3)
        label = label.astype(np.float32)

        return data, label

    def get_in_mem_data_batch_train(self, idxs):
        """Return ``get_in_mem_data_batch`` for the training file lists."""
        return self.get_in_mem_data_batch(idxs, True)

    def get_in_mem_data_batch_test(self, idxs):
        """Return ``get_in_mem_data_batch`` for the test file lists."""
        return self.get_in_mem_data_batch(idxs, False)

    @tf.function(input_signature=[tf.TensorSpec(None, tf.int32)])
    def data_function(self, indexes):
        """Graph wrapper that loads a training batch via ``tf.numpy_function``."""
        out = tf.numpy_function(self.get_in_mem_data_batch_train, [indexes], [tf.float32, tf.float32])
        return out

    @tf.function(input_signature=[tf.TensorSpec(None, tf.int32)])
    def data_function_test(self, indexes):
        """Graph wrapper that loads a test batch via ``tf.numpy_function``."""
        out = tf.numpy_function(self.get_in_mem_data_batch_test, [indexes], [tf.float32, tf.float32])
        return out

    def get_train_dataset(self, num_files):
        """Build ``self.train_dataset`` over the file indices ``0..num_files-1``.

        Indices are grouped PROC_BATCH_SIZE at a time, loaded, cached,
        shuffled, optionally augmented, and prefetched.
        """
        def integer_gen(limit):
            n = 0
            while n < limit:
                yield n
                n += 1

        num_gen = integer_gen(num_files)
        gen = make_tf_callable_generator(num_gen)
        dataset = tf.data.Dataset.from_generator(gen, output_types=tf.int32)
        dataset = dataset.batch(PROC_BATCH_SIZE)
        dataset = dataset.map(self.data_function, num_parallel_calls=8)
        dataset = dataset.cache(filename=CACHE_FILE)
        dataset = dataset.shuffle(PROC_BATCH_BUFFER_SIZE, reshuffle_each_iteration=True)
        if DO_AUGMENT:
            # Augmentation comes after cache() so it is re-applied every epoch.
            dataset = dataset.map(augment_image(), num_parallel_calls=8)
        dataset = dataset.prefetch(buffer_size=1)
        self.train_dataset = dataset

    def get_test_dataset(self, num_files):
        """Build ``self.test_dataset`` over the file indices ``0..num_files-1``.

        Same loading as the training dataset, but only cached (no shuffle,
        augmentation or prefetch).
        """
        def integer_gen(limit):
            n = 0
            while n < limit:
                yield n
                n += 1

        num_gen = integer_gen(num_files)
        gen = make_tf_callable_generator(num_gen)
        dataset = tf.data.Dataset.from_generator(gen, output_types=tf.int32)
        dataset = dataset.batch(PROC_BATCH_SIZE)
        dataset = dataset.map(self.data_function_test, num_parallel_calls=8)
        dataset = dataset.cache()
        self.test_dataset = dataset

    def setup_pipeline(self, train_data_files, train_label_files, test_data_files, test_label_files, num_train_samples):
        """Store the file lists and build the training and test datasets.

        Args:
            train_data_files, train_label_files: parallel lists; entry k of
                one must belong to entry k of the other.
            test_data_files, test_label_files: parallel lists, same rule.
            num_train_samples: approximate number of training samples; used
                by build_training to derive steps per epoch.
        """
        self.train_data_files = train_data_files
        self.train_label_files = train_label_files
        self.test_data_files = test_data_files
        self.test_label_files = test_label_files

        self.get_train_dataset(len(train_data_files))
        self.get_test_dataset(len(test_data_files))

        self.num_data_samples = num_train_samples  # approximately

        print('datetime: ', now)
        print('training and test data: ')
        print('---------------------------')
        print('num train files: ', len(train_data_files))
        print('BATCH SIZE: ', BATCH_SIZE)
        print('num test files: ', len(test_data_files))
        print('setup_pipeline: Done')

    def setup_test_pipeline(self, test_data_files, test_label_files):
        """Store the test file lists and build only the test dataset."""
        self.test_data_files = test_data_files
        self.test_label_files = test_label_files
        self.get_test_dataset(len(test_data_files))
        print('setup_test_pipeline: Done')

    def build_srcnn(self, do_drop_out=False, do_batch_norm=False, drop_rate=0.5, factor=2):
        """Build the network graph and store its output in ``self.logits``.

        Layout: one VALID convolution, six residual blocks, one SAME
        convolution, and a 1x1 convolution with sigmoid (2 classes) or
        softmax (otherwise) activation.

        NOTE: the parameters ``do_drop_out``, ``do_batch_norm``, ``drop_rate``
        and ``factor`` are accepted but not used in the body; the residual
        blocks run with the defaults of ``build_residual_conv2d_block``.
        """
        print('build_cnn')
        padding = "SAME"

        activation = tf.nn.relu

        num_filters = 64

        input_2d = self.inputs[0]
        print('input: ', input_2d.shape)

        # VALID padding: this layer trims the spatial border of the input.
        conv_b = tf.keras.layers.Conv2D(num_filters, kernel_size=KERNEL_SIZE, kernel_initializer='he_uniform',
                                        activation=activation, padding='VALID')(input_2d)
        print(conv_b.shape)

        residual_scale = 0.2
        for block_num in range(1, 7):
            conv_b = build_residual_conv2d_block(conv_b, num_filters, 'Residual_Block_%d' % block_num,
                                                 kernel_size=KERNEL_SIZE, scale=residual_scale)

        conv_b = tf.keras.layers.Conv2D(num_filters, kernel_size=3, strides=1, activation=activation,
                                        kernel_initializer='he_uniform', padding=padding)(conv_b)

        conv = conv_b
        print(conv.shape)

        if NumClasses == 2:
            final_activation = tf.nn.sigmoid  # For binary
        else:
            final_activation = tf.nn.softmax  # For multi-class

        # This is effectively a Dense layer
        self.logits = tf.keras.layers.Conv2D(NumLogits, kernel_size=1, strides=1, padding=padding,
                                             activation=final_activation)(conv)
        print(self.logits.shape)

    def build_training(self):
        """Create the loss, the exponential learning-rate schedule and Adam.

        Requires ``self.num_data_samples`` to be set; it determines the
        number of steps after which the rate is multiplied by 0.95.
        """
        if NumClasses == 2:
            self.loss = tf.keras.losses.BinaryCrossentropy(from_logits=False)  # for two-class only
        else:
            self.loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False)  # For multi-class

        # decayed_learning_rate = learning_rate * decay_rate ^ (global_step / decay_steps)
        initial_learning_rate = 0.001
        decay_rate = 0.95
        steps_per_epoch = int(self.num_data_samples/BATCH_SIZE)  # one epoch
        decay_steps = int(steps_per_epoch) * 1
        print('initial rate, decay rate, steps/epoch, decay steps: ', initial_learning_rate, decay_rate, steps_per_epoch, decay_steps)

        self.learningRateSchedule = tf.keras.optimizers.schedules.ExponentialDecay(initial_learning_rate, decay_steps, decay_rate)

        optimizer = tf.keras.optimizers.Adam(learning_rate=self.learningRateSchedule)

        if TRACK_MOVING_AVERAGE:
            # Not sure that this works properly (from tfa)
            # optimizer = tfa.optimizers.MovingAverage(optimizer)
            self.ema = tf.train.ExponentialMovingAverage(decay=0.9999)

        self.optimizer = optimizer
        self.initial_learning_rate = initial_learning_rate

    def build_evaluation(self):
        """Create the Keras metric objects for training and testing.

        The AUC/recall/precision/confusion metrics are only created in the
        two-class configuration.
        """
        self.train_loss = tf.keras.metrics.Mean(name='train_loss')
        self.test_loss = tf.keras.metrics.Mean(name='test_loss')

        if NumClasses == 2:
            self.train_accuracy = tf.keras.metrics.BinaryAccuracy(name='train_accuracy')
            self.test_accuracy = tf.keras.metrics.BinaryAccuracy(name='test_accuracy')
            self.test_auc = tf.keras.metrics.AUC(name='test_auc')
            self.test_recall = tf.keras.metrics.Recall(name='test_recall')
            self.test_precision = tf.keras.metrics.Precision(name='test_precision')
            self.test_true_neg = tf.keras.metrics.TrueNegatives(name='test_true_neg')
            self.test_true_pos = tf.keras.metrics.TruePositives(name='test_true_pos')
            self.test_false_neg = tf.keras.metrics.FalseNegatives(name='test_false_neg')
            self.test_false_pos = tf.keras.metrics.FalsePositives(name='test_false_pos')
        else:
            self.train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')
            self.test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='test_accuracy')

    @tf.function(input_signature=[tf.TensorSpec(None, tf.float32), tf.TensorSpec(None, tf.float32)])
    def train_step(self, inputs, labels):
        """Run one optimizer step on a mini-batch and update train metrics.

        Returns:
            The data loss of the batch (without any regularization losses).
        """
        labels = tf.squeeze(labels, axis=[3])
        with tf.GradientTape() as tape:
            pred = self.model([inputs], training=True)
            loss = self.loss(labels, pred)
            total_loss = loss
            # Add layer regularization losses, if the model defines any.
            if len(self.model.losses) > 0:
                reg_loss = tf.math.add_n(self.model.losses)
                total_loss = loss + reg_loss
        gradients = tape.gradient(total_loss, self.model.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.model.trainable_variables))
        if TRACK_MOVING_AVERAGE:
            self.ema.apply(self.model.trainable_variables)

        self.train_loss(loss)
        self.train_accuracy(labels, pred)

        return loss

    @tf.function(input_signature=[tf.TensorSpec(None, tf.float32), tf.TensorSpec(None, tf.float32)])
    def test_step(self, inputs, labels):
        """Evaluate one mini-batch and update the test loss/accuracy metrics."""
        labels = tf.squeeze(labels, axis=[3])
        pred = self.model([inputs], training=False)
        t_loss = self.loss(labels, pred)

        self.test_loss(t_loss)
        self.test_accuracy(labels, pred)

    # Not wrapped in tf.function: pred.numpy() needs an eagerly evaluated tensor.
    def predict(self, inputs, labels):
        """Evaluate one mini-batch and also keep labels, predictions and inputs.

        Results are appended to ``self.test_labels``, ``self.test_preds`` and
        ``self.test_input``; the test metrics are updated as in test_step.
        Unlike test_step, the labels are passed on with their trailing axis.
        """
        pred = self.model([inputs], training=False)
        t_loss = self.loss(labels, pred)

        self.test_labels.append(labels)
        self.test_preds.append(pred.numpy())
        self.test_input.append(inputs)

        self.test_loss(t_loss)
        self.test_accuracy(labels, pred)

    def reset_test_metrics(self):
        """Reset the test loss and test accuracy accumulators."""
        self.test_loss.reset_states()
        self.test_accuracy.reset_states()

    def get_metrics(self):
        """Return ``(f1, mcc)`` from the binary test metrics.

        Only usable in the two-class configuration, where build_evaluation
        creates the recall/precision/confusion metrics read here.
        """
        recall = self.test_recall.result()
        precsn = self.test_precision.result()
        f1 = 2 * (precsn * recall) / (precsn + recall)

        tn = self.test_true_neg.result()
        tp = self.test_true_pos.result()
        fn = self.test_false_neg.result()
        fp = self.test_false_pos.result()

        mcc = ((tp * tn) - (fp * fn)) / np.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
        return f1, mcc

    def do_training(self, ckpt_dir=None):
        """Run the training loop with validation, checkpointing and early stop.

        Args:
            ckpt_dir: if None, checkpoints are written to ``modeldir`` and
                training starts from the current weights; otherwise the
                latest checkpoint in `ckpt_dir` is restored first and new
                checkpoints are written there.

        Every 100 steps the whole test dataset is evaluated and summaries are
        written. After each epoch the test dataset is evaluated again; a
        checkpoint is saved when the test loss improves, and training stops
        when EARLY_STOP is set and the EarlyStop helper says so.
        """
        if ckpt_dir is None:
            # exist_ok avoids the exists()/mkdir() race and creates parents.
            os.makedirs(modeldir, exist_ok=True)
            ckpt = tf.train.Checkpoint(step=tf.Variable(1), model=self.model)
            ckpt_manager = tf.train.CheckpointManager(ckpt, modeldir, max_to_keep=3)
        else:
            ckpt = tf.train.Checkpoint(step=tf.Variable(1), model=self.model)
            ckpt_manager = tf.train.CheckpointManager(ckpt, ckpt_dir, max_to_keep=3)
            ckpt.restore(ckpt_manager.latest_checkpoint)

        self.writer_train = tf.summary.create_file_writer(os.path.join(logdir, 'plot_train'))
        self.writer_valid = tf.summary.create_file_writer(os.path.join(logdir, 'plot_valid'))
        self.writer_train_valid_loss = tf.summary.create_file_writer(os.path.join(logdir, 'plot_train_valid_loss'))

        step = 0
        total_time = 0
        best_test_loss = np.finfo(dtype=np.float64).max

        if EARLY_STOP:
            es = EarlyStop(patience=PATIENCE)

        for epoch in range(NUM_EPOCHS):
            self.train_loss.reset_states()
            self.train_accuracy.reset_states()

            t0 = datetime.datetime.now().timestamp()

            proc_batch_cnt = 0
            n_samples = 0

            # Each dataset element holds the samples of PROC_BATCH_SIZE files;
            # it is re-batched into mini-batches of BATCH_SIZE samples.
            for data, label in self.train_dataset:
                trn_ds = tf.data.Dataset.from_tensor_slices((data, label))
                trn_ds = trn_ds.batch(BATCH_SIZE)
                for mini_batch in trn_ds:
                    if self.learningRateSchedule is not None:
                        loss = self.train_step(mini_batch[0], mini_batch[1])

                    if (step % 100) == 0:

                        with self.writer_train.as_default():
                            tf.summary.scalar('loss_trn', loss.numpy(), step=step)
                            tf.summary.scalar('learning_rate', self.optimizer.lr.numpy(), step=step)
                            tf.summary.scalar('num_train_steps', step, step=step)
                            tf.summary.scalar('num_epochs', epoch, step=step)

                        self.reset_test_metrics()
                        for data_tst, label_tst in self.test_dataset:
                            tst_ds = tf.data.Dataset.from_tensor_slices((data_tst, label_tst))
                            tst_ds = tst_ds.batch(BATCH_SIZE)
                            for mini_batch_test in tst_ds:
                                self.test_step(mini_batch_test[0], mini_batch_test[1])

                        with self.writer_valid.as_default():
                            tf.summary.scalar('loss_val', self.test_loss.result(), step=step)
                            tf.summary.scalar('acc_val', self.test_accuracy.result(), step=step)

                        with self.writer_train_valid_loss.as_default():
                            tf.summary.scalar('loss_trn', loss.numpy(), step=step)
                            tf.summary.scalar('loss_val', self.test_loss.result(), step=step)

                        print('****** test loss, acc, lr: ', self.test_loss.result().numpy(), self.test_accuracy.result().numpy(),
                              self.optimizer.lr.numpy())

                    step += 1
                    print('train loss: ', loss.numpy())

                proc_batch_cnt += 1
                n_samples += data.shape[0]
                print('proc_batch_cnt: ', proc_batch_cnt, n_samples)

            t1 = datetime.datetime.now().timestamp()
            print('End of Epoch: ', epoch+1, 'elapsed time: ', (t1-t0))
            total_time += (t1-t0)

            # End-of-epoch validation over the whole test dataset.
            self.reset_test_metrics()
            for data, label in self.test_dataset:
                ds = tf.data.Dataset.from_tensor_slices((data, label))
                ds = ds.batch(BATCH_SIZE)
                for mini_batch in ds:
                    self.test_step(mini_batch[0], mini_batch[1])

            print('loss, acc: ', self.test_loss.result().numpy(), self.test_accuracy.result().numpy())
            print('------------------------------------------------------')

            tst_loss = self.test_loss.result().numpy()
            if tst_loss < best_test_loss:
                best_test_loss = tst_loss
                ckpt_manager.save()

            if EARLY_STOP and es.check_stop(tst_loss):
                break

        print('total time: ', total_time)
        self.writer_train.close()
        self.writer_valid.close()
        self.writer_train_valid_loss.close()

    def build_model(self):
        """Build the graph and wrap it in ``self.model``."""
        self.build_srcnn()
        self.model = tf.keras.Model(self.inputs, self.logits)

    def restore(self, ckpt_dir):
        """Restore the latest checkpoint and predict the whole test dataset.

        Returns:
            ``(labels, preds, inputs)`` concatenated over everything that has
            been accumulated by predict() on this instance.
        """
        ckpt = tf.train.Checkpoint(step=tf.Variable(1), model=self.model)
        ckpt_manager = tf.train.CheckpointManager(ckpt, ckpt_dir, max_to_keep=3)

        ckpt.restore(ckpt_manager.latest_checkpoint)

        self.reset_test_metrics()

        for data, label in self.test_dataset:
            ds = tf.data.Dataset.from_tensor_slices((data, label))
            ds = ds.batch(BATCH_SIZE)
            for mini_batch_test in ds:
                self.predict(mini_batch_test[0], mini_batch_test[1])

        print('loss, acc: ', self.test_loss.result().numpy(), self.test_accuracy.result().numpy())

        labels = np.concatenate(self.test_labels)
        preds = np.concatenate(self.test_preds)
        inputs = np.concatenate(self.test_input)
        print(labels.shape, preds.shape)

        return labels, preds, inputs

    def do_evaluate(self, inputs, ckpt_dir):
        """Restore the latest checkpoint in `ckpt_dir` and predict `inputs`.

        Returns:
            The model output as a numpy array; the tensor is also kept in
            ``self.test_probs``.
        """
        ckpt = tf.train.Checkpoint(step=tf.Variable(1), model=self.model)
        ckpt_manager = tf.train.CheckpointManager(ckpt, ckpt_dir, max_to_keep=3)

        ckpt.restore(ckpt_manager.latest_checkpoint)

        self.reset_test_metrics()

        pred = self.model([inputs], training=False)
        self.test_probs = pred
        pred = pred.numpy()

        return pred

    def run(self, directory, ckpt_dir=None, num_data_samples=50000):
        """Train on the ``train*``/``valid*`` .npy files found in `directory`.

        Args:
            directory: path prefix the glob patterns are appended to (it
                must end with a path separator).
            ckpt_dir: forwarded to do_training.
            num_data_samples: approximate number of training samples.

        Raises:
            ValueError: if the number of data files and label files differ.
        """
        # glob returns files in arbitrary order. Data and label files are
        # matched by list position, so both lists are sorted to keep the
        # pairing deterministic (file names are assumed to differ only in
        # the 'mres'/'ires' token, as run_restore also assumes).
        train_data_files = sorted(glob.glob(directory+'train*mres*.npy'))
        valid_data_files = sorted(glob.glob(directory+'valid*mres*.npy'))
        train_label_files = sorted(glob.glob(directory+'train*ires*.npy'))
        valid_label_files = sorted(glob.glob(directory+'valid*ires*.npy'))

        if len(train_data_files) != len(train_label_files):
            raise ValueError('number of training data files (%d) and label files (%d) differ'
                             % (len(train_data_files), len(train_label_files)))
        if len(valid_data_files) != len(valid_label_files):
            raise ValueError('number of validation data files (%d) and label files (%d) differ'
                             % (len(valid_data_files), len(valid_label_files)))

        self.setup_pipeline(train_data_files, train_label_files, valid_data_files, valid_label_files, num_data_samples)
        self.build_model()
        self.build_training()
        self.build_evaluation()
        self.do_training(ckpt_dir=ckpt_dir)

    def run_restore(self, directory, ckpt_dir):
        """Predict the ``valid*`` files in `directory` with a restored model.

        Label file names are derived from the data file names by replacing
        'mres' with 'ires'.

        Returns:
            The ``(labels, preds, inputs)`` tuple of restore().
        """
        self.num_data_samples = 1000

        valid_data_files = glob.glob(directory + 'valid*mres*.npy')
        valid_label_files = [f.replace('mres', 'ires') for f in valid_data_files]
        self.setup_test_pipeline(valid_data_files, valid_label_files)

        self.build_model()
        self.build_training()
        self.build_evaluation()
        return self.restore(ckpt_dir)

    def run_evaluate(self, data, ckpt_dir):
        """Build the model, restore `ckpt_dir` and return predictions for `data`."""
        self.num_data_samples = 80000
        self.build_model()
        self.build_training()
        self.build_evaluation()
        return self.do_evaluate(data, ckpt_dir)

    def setup_inference(self, ckpt_dir):
        """Build the model once and restore the latest checkpoint in `ckpt_dir`.

        Call this before do_inference / run_inference*.
        """
        self.num_data_samples = 80000
        self.build_model()
        self.build_training()
        self.build_evaluation()

        ckpt = tf.train.Checkpoint(step=tf.Variable(1), model=self.model)
        ckpt_manager = tf.train.CheckpointManager(ckpt, ckpt_dir, max_to_keep=3)
        ckpt.restore(ckpt_manager.latest_checkpoint)

    def do_inference(self, inputs):
        """Return the model output for `inputs` as a numpy array.

        The output tensor is also kept in ``self.test_probs``.
        """
        self.reset_test_metrics()

        pred = self.model([inputs], training=False)
        self.test_probs = pred
        pred = pred.numpy()

        return pred

    def run_inference(self, in_file, out_file):
        """Classify the grid stored in the HDF5 file `in_file`.

        Args:
            in_file: path of the HDF5 file with the input fields.
            out_file: if not None, ``(cld_frac_out, bt, refl, cp)`` is saved
                there with ``np.save`` and None is returned.

        Returns:
            The int8 class grid when `out_file` is None, otherwise None.
            Pixels not covered by the network output, pixels with
            ``bt <= 161.0`` and pixels with NaN reflectance are set to -1.
        """
        gc.collect()

        # The context manager closes the file even if reading or inference fails.
        with h5py.File(in_file, 'r') as h5f:
            bt = get_grid_values_all(h5f, 'temp_11_0um_nom')
            y_len, x_len = bt.shape

            refl = get_grid_values_all(h5f, 'refl_0_65um_nom')
            refl_lo = get_grid_values_all(h5f, 'refl_0_65um_nom_min_sub')
            refl_hi = get_grid_values_all(h5f, 'refl_0_65um_nom_max_sub')
            refl_std = get_grid_values_all(h5f, 'refl_0_65um_nom_stddev_sub')
            cp = get_grid_values_all(h5f, label_param)

            cld_frac = self.run_inference_(bt, refl, refl_lo, refl_hi, refl_std, cp)

            cld_frac_out = np.full((y_len, x_len), -1, dtype=np.int8)
            # The first (VALID) convolution trims this many pixels per side.
            border = int((KERNEL_SIZE - 1) / 2)
            cld_frac_out[border:y_len - border, border:x_len - border] = cld_frac[0, :, :]

            # Use this hack for now.
            off_earth = (bt <= 161.0)
            night = np.isnan(refl)
            cld_frac_out[off_earth] = -1
            cld_frac_out[np.invert(off_earth) & night] = -1

        if out_file is not None:
            np.save(out_file, (cld_frac_out, bt, refl, cp))
        else:
            return cld_frac_out

    def run_inference_full_disk(self, in_file, out_file):
        """Classify a large grid in two overlapping halves.

        The grid is split at the middle row; the halves overlap by one row
        on each side of the split so that, after the network trims a
        one-pixel border, the two outputs join without a gap.

        Args:
            in_file: path of the HDF5 file with the input fields.
            out_file: if not None, ``(cld_frac_out, bt, refl, cp)`` is saved
                there with ``np.save`` and None is returned.

        Returns:
            The int8 class grid when `out_file` is None, otherwise None.
            Uncovered, ``bt <= 161.0`` and NaN-reflectance pixels are -1.
        """
        gc.collect()

        t0 = time.time()
        # The context manager closes the file even if reading or inference fails.
        with h5py.File(in_file, 'r') as h5f:
            bt = get_grid_values_all(h5f, 'temp_11_0um_nom')
            y_len, x_len = bt.shape
            h_y_len = int(y_len / 2)

            refl = get_grid_values_all(h5f, 'refl_0_65um_nom')
            refl_lo = get_grid_values_all(h5f, 'refl_0_65um_nom_min_sub')
            refl_hi = get_grid_values_all(h5f, 'refl_0_65um_nom_max_sub')
            refl_std = get_grid_values_all(h5f, 'refl_0_65um_nom_stddev_sub')
            cp = get_grid_values_all(h5f, label_param)
            t1 = time.time()
            print(' read time:', (t1-t0))

            # First half: rows 0 .. h_y_len (inclusive).
            bt_nh = bt[0:h_y_len + 1, :]
            refl_nh = refl[0:h_y_len + 1, :]
            refl_lo_nh = refl_lo[0:h_y_len + 1, :]
            refl_hi_nh = refl_hi[0:h_y_len + 1, :]
            refl_std_nh = refl_std[0:h_y_len + 1, :]
            cp_nh = cp[0:h_y_len + 1, :]

            # Second half: rows h_y_len - 1 .. end.
            bt_sh = bt[h_y_len - 1:y_len, :]
            refl_sh = refl[h_y_len - 1:y_len, :]
            refl_lo_sh = refl_lo[h_y_len - 1:y_len, :]
            refl_hi_sh = refl_hi[h_y_len - 1:y_len, :]
            refl_std_sh = refl_std[h_y_len - 1:y_len, :]
            cp_sh = cp[h_y_len - 1:y_len, :]

            t0 = time.time()
            cld_frac_nh = self.run_inference_(bt_nh, refl_nh, refl_lo_nh, refl_hi_nh, refl_std_nh, cp_nh)
            cld_frac_sh = self.run_inference_(bt_sh, refl_sh, refl_lo_sh, refl_hi_sh, refl_std_sh, cp_sh)
            t1 = time.time()
            print(' inference time: ', (t1-t0))

            cld_frac_out = np.full((y_len, x_len), -1, dtype=np.int8)

            border = int((KERNEL_SIZE - 1) / 2)
            cld_frac_out[border:h_y_len, border:x_len - border] = cld_frac_nh[0, :, :]
            cld_frac_out[h_y_len:y_len - border, border:x_len - border] = cld_frac_sh[0, :, :]

            # Use this hack for now.
            off_earth = (bt <= 161.0)
            night = np.isnan(refl)
            cld_frac_out[off_earth] = -1
            cld_frac_out[np.invert(off_earth) & night] = -1

        if out_file is not None:
            np.save(out_file, (cld_frac_out, bt, refl, cp))
        else:
            return cld_frac_out

    def run_inference_(self, bt, refl, refl_lo, refl_hi, refl_std, cp):
        """Assemble the 5-channel input from 2-D fields and return class indices.

        Channel order: normalized bt, normalized refl, normalized refl_hi
        minus normalized refl_lo, refl_std with NaN set to 0, cp with NaN
        set to 0.

        NOTE(review): in get_in_mem_data_batch the stddev channel is passed
        through ``scale2(., 0.0, 20.0)``, here it is only NaN-filled - confirm
        that training and inference inputs are meant to differ.

        Returns:
            int8 array (1, y', x') with the arg-max class per pixel.
        """
        bt = normalize(bt, 'temp_11_0um_nom', mean_std_dct)
        refl = normalize(refl, 'refl_0_65um_nom', mean_std_dct)
        refl_lo = normalize(refl_lo, 'refl_0_65um_nom', mean_std_dct)
        refl_hi = normalize(refl_hi, 'refl_0_65um_nom', mean_std_dct)
        refl_rng = refl_hi - refl_lo
        refl_std = np.where(np.isnan(refl_std), 0, refl_std)
        cp = np.where(np.isnan(cp), 0, cp)

        data = np.stack([bt, refl, refl_rng, refl_std, cp], axis=2)
        # Add the leading batch axis expected by the model.
        data = np.expand_dims(data, axis=0)

        probs = self.do_inference(data)
        cld_frac = probs.argmax(axis=3)
        cld_frac = cld_frac.astype(np.int8)

        return cld_frac
def run_restore_static(directory, ckpt_dir, out_file=None):
    """Predict the validation files with a restored model and optionally save.

    Args:
        directory: path prefix holding the ``valid*`` .npy files.
        ckpt_dir: checkpoint directory to restore from.
        out_file: if not None, the squeezed labels, the arg-max predictions
            and the five input channels (cropped by one cell at the low edge;
            the first two denormalized) are written there with ``np.save``.

    Returns:
        None.
    """
    nn = SRCNN()
    labels, preds, inputs = nn.run_restore(directory, ckpt_dir)
    if out_file is None:
        return

    # Interior window of the input patch: cells 1 .. LEN // 4 in each direction.
    rows = slice(1, (Y_LEN // 4) + 1)
    cols = slice(1, (X_LEN // 4) + 1)

    bt = denormalize(inputs[:, rows, cols, 0], 'temp_11_0um_nom', mean_std_dct)
    refl = denormalize(inputs[:, rows, cols, 1], 'refl_0_65um_nom', mean_std_dct)
    remaining = [inputs[:, rows, cols, chan] for chan in (2, 3, 4)]

    np.save(out_file, [np.squeeze(labels), preds.argmax(axis=3), bt, refl] + remaining)
def run_evaluate_static(in_file, out_file, ckpt_dir):
    """Classify every pixel of one gridded HDF5 file and save or return the result.

    Reads the 11.0um temperature, the 0.65um reflectance with its sub-pixel
    min/max/stddev fields and the ``label_param`` field from ``in_file`` and
    runs them through ``run_evaluate_static_``. The predicted class indices
    are written into an int8 array with the same (y, x) shape as the
    temperature grid; the outer ``border`` rows and columns are not covered by
    the model output and keep their initial value 0.

    Args:
        in_file: path of the HDF5 file to read.
        out_file: when not None, ``np.save`` target for the tuple
            ``(cld_frac_out, bt, refl, cp)``; nothing is returned in that case.
        ckpt_dir: checkpoint directory handed to ``run_evaluate_static_``.

    Returns:
        ``cld_frac_out`` when ``out_file`` is None, otherwise None.
    """
    gc.collect()

    # The context manager closes the file even when one of the reads raises.
    # All fields are read up front, so nothing below needs the open handle.
    with h5py.File(in_file, 'r') as h5f:
        bt = get_grid_values_all(h5f, 'temp_11_0um_nom')
        refl = get_grid_values_all(h5f, 'refl_0_65um_nom')
        refl_lo = get_grid_values_all(h5f, 'refl_0_65um_nom_min_sub')
        refl_hi = get_grid_values_all(h5f, 'refl_0_65um_nom_max_sub')
        refl_std = get_grid_values_all(h5f, 'refl_0_65um_nom_stddev_sub')
        cp = get_grid_values_all(h5f, label_param)

    y_len, x_len = bt.shape

    cld_frac = run_evaluate_static_(bt, refl, refl_lo, refl_hi, refl_std, cp, ckpt_dir)

    cld_frac_out = np.zeros((y_len, x_len), dtype=np.int8)
    border = int((KERNEL_SIZE - 1)/2)
    cld_frac_out[border:y_len - border, border:x_len - border] = cld_frac[0, :, :]

    # Use this hack for now: a temperature <= 161.0 is treated as off-earth and
    # a NaN reflectance as night; both kinds of pixel are flagged with -1.
    off_earth = (bt <= 161.0)
    night = np.isnan(refl)
    cld_frac_out[off_earth | night] = -1

    if out_file is None:
        return cld_frac_out
    np.save(out_file, (cld_frac_out, bt, refl, cp))
def run_evaluate_static_full_disk(in_file, out_file, ckpt_dir):
    """Classify a full-disk HDF5 file in two halves and save or return the result.

    Same inputs and outputs as ``run_evaluate_static``, but the grid is split
    at the middle row and each half is evaluated by its own call to
    ``run_evaluate_static_``. The two results are written into a single int8
    array with the (y, x) shape of the temperature grid; the outer ``border``
    rows and columns keep their initial value 0.

    Args:
        in_file: path of the HDF5 file to read.
        out_file: when not None, ``np.save`` target for the tuple
            ``(cld_frac_out, bt, refl, cp)``; nothing is returned in that case.
        ckpt_dir: checkpoint directory handed to ``run_evaluate_static_``.

    Returns:
        ``cld_frac_out`` when ``out_file`` is None, otherwise None.
    """
    gc.collect()

    # The context manager closes the file even when one of the reads raises.
    # All fields are read up front, so nothing below needs the open handle.
    with h5py.File(in_file, 'r') as h5f:
        bt = get_grid_values_all(h5f, 'temp_11_0um_nom')
        refl = get_grid_values_all(h5f, 'refl_0_65um_nom')
        refl_lo = get_grid_values_all(h5f, 'refl_0_65um_nom_min_sub')
        refl_hi = get_grid_values_all(h5f, 'refl_0_65um_nom_max_sub')
        refl_std = get_grid_values_all(h5f, 'refl_0_65um_nom_stddev_sub')
        cp = get_grid_values_all(h5f, label_param)

    y_len, x_len = bt.shape
    h_y_len = y_len // 2

    # Each half reaches one row past the middle row.
    # NOTE(review): the one-row overlap presumably matches border == 1
    # (KERNEL_SIZE == 3) so that the two outputs tile the grid -- confirm
    # before changing KERNEL_SIZE.
    fields = (bt, refl, refl_lo, refl_hi, refl_std, cp)
    north = [fld[0:h_y_len + 1, :] for fld in fields]
    south = [fld[h_y_len - 1:y_len, :] for fld in fields]

    cld_frac_nh = run_evaluate_static_(*north, ckpt_dir)
    cld_frac_sh = run_evaluate_static_(*south, ckpt_dir)

    cld_frac_out = np.zeros((y_len, x_len), dtype=np.int8)
    border = int((KERNEL_SIZE - 1)/2)
    cld_frac_out[border:h_y_len, border:x_len - border] = cld_frac_nh[0, :, :]
    cld_frac_out[h_y_len:y_len - border, border:x_len - border] = cld_frac_sh[0, :, :]

    # Use this hack for now: a temperature <= 161.0 is treated as off-earth and
    # a NaN reflectance as night; both kinds of pixel are flagged with -1.
    off_earth = (bt <= 161.0)
    night = np.isnan(refl)
    cld_frac_out[off_earth | night] = -1

    if out_file is None:
        return cld_frac_out
    np.save(out_file, (cld_frac_out, bt, refl, cp))
def run_evaluate_static_valid(in_file, out_file, ckpt_dir):
    """Evaluate a validation HDF5 file that also carries super-resolution truth.

    The model inputs are read from the ``orig/`` group and evaluated with
    ``run_evaluate_static_``. The ``super/`` copy of ``label_param`` is turned
    into a grid-cell mean (``get_grid_cell_mean``) and a category field
    (``get_label_data_5cat``), which are stored next to the prediction.

    Args:
        in_file: path of the HDF5 file to read.
        out_file: when not None, ``np.save`` target for the tuple
            ``(cld_frac_out, bt, refl, cp, lons, lats, mean_cp_sres,
            cld_frac_truth)``; nothing is returned in that case.
        ckpt_dir: checkpoint directory handed to ``run_evaluate_static_``.

    Returns:
        ``[cld_frac_out, bt, refl, cp, lons, lats]`` when ``out_file`` is None,
        otherwise None.
    """
    gc.collect()

    # The context manager closes the file even when one of the reads raises.
    with h5py.File(in_file, 'r') as h5f:
        bt = get_grid_values_all(h5f, 'orig/temp_ch38')
        refl = get_grid_values_all(h5f, 'orig/refl_ch01')
        refl_lo = get_grid_values_all(h5f, 'orig/refl_submin_ch01')
        refl_hi = get_grid_values_all(h5f, 'orig/refl_submax_ch01')
        refl_std = get_grid_values_all(h5f, 'orig/refl_substddev_ch01')
        cp = get_grid_values_all(h5f, 'orig/'+label_param)
        lons = get_grid_values_all(h5f, 'orig/longitude')
        lats = get_grid_values_all(h5f, 'orig/latitude')
        cp_sres = get_grid_values_all(h5f, 'super/'+label_param)

    y_len, x_len = bt.shape

    # Both helpers get a leading axis of length one; entry 0 of each result is kept.
    cp_sres_batch = np.expand_dims(cp_sres, axis=0)
    mean_cp_sres = get_grid_cell_mean(cp_sres_batch)[0]
    cld_frac_truth = get_label_data_5cat(np.expand_dims(cp_sres, axis=0))[0]

    cld_frac = run_evaluate_static_(bt, refl, refl_lo, refl_hi, refl_std, cp, ckpt_dir)

    # The outer `border` rows/columns are not covered by the model output and stay 0.
    cld_frac_out = np.zeros((y_len, x_len), dtype=np.int8)
    border = int((KERNEL_SIZE - 1)/2)
    cld_frac_out[border:y_len - border, border:x_len - border] = cld_frac[0, :, :]

    if out_file is None:
        return [cld_frac_out, bt, refl, cp, lons, lats]

    # NOTE(review): np.save packs this tuple into one array, so the entries
    # presumably need a common shape -- confirm for the super-resolution fields.
    np.save(out_file, (cld_frac_out, bt, refl, cp, lons, lats, mean_cp_sres, cld_frac_truth))
def run_evaluate_static_(bt, refl, refl_lo, refl_hi, refl_std, cp, ckpt_dir):
    """Assemble a six-channel sample, evaluate it and return the class per pixel.

    The temperature is normalized with the 'temp_11_0um_nom' statistics; the
    reflectance and its sub-pixel min/max are all normalized with the
    'refl_0_65um_nom' statistics. NaNs in ``refl_std`` and ``cp`` are replaced
    by 0.

    Args:
        bt: 11.0um temperature grid (presumably 2-D (y, x) -- TODO confirm).
        refl: 0.65um reflectance grid.
        refl_lo: sub-pixel minimum reflectance.
        refl_hi: sub-pixel maximum reflectance.
        refl_std: sub-pixel reflectance standard deviation.
        cp: cloud probability grid.
        ckpt_dir: checkpoint directory passed to ``SRCNN.run_evaluate``.

    Returns:
        int8 array with the argmax along axis 3 of the evaluation output.
    """
    refl_key = 'refl_0_65um_nom'
    channels = [
        normalize(bt, 'temp_11_0um_nom', mean_std_dct),
        normalize(refl, refl_key, mean_std_dct),
        normalize(refl_lo, refl_key, mean_std_dct),
        normalize(refl_hi, refl_key, mean_std_dct),
        np.where(np.isnan(refl_std), 0, refl_std),
        np.where(np.isnan(cp), 0, cp),
    ]

    # Channels go on axis 2; a leading axis of length one is then added.
    batch = np.stack(channels, axis=2)[np.newaxis, ...]

    model = SRCNN()
    probs = model.run_evaluate(batch, ckpt_dir)
    return probs.argmax(axis=3).astype(np.int8)
def analyze_3cat(file):
    """Score saved three-class predictions as three pairwise binary problems.

    ``file`` is loaded with ``np.load(..., allow_pickle=True)``; entry 0 is
    used as the labels and entry 1 as the predictions, both flattened.

    For every class pair the samples labelled with the third class are dropped
    and the rest is recoded to 0 (first class of the pair) and 1:

    * 0 vs 1: predictions of 2 count as 1.
    * 1 vs 2: labels 1 -> 0 and 2 -> 1; predictions 1 -> 0, while 0 and 2 -> 1.
    * 0 vs 2: labels 2 -> 1; predictions 2 -> 1 (a prediction of 1 stays 1).

    Code 0 is the positive class for recall, precision and MCC. The per-class
    counts and one line of ``acc recall precision mcc`` per pair are printed.

    Returns:
        Tuple ``(cm_0_1, cm_1_2, cm_0_2, accuracies, recalls, precisions,
        mccs)``; each list is ordered [0 vs 1, 1 vs 2, 0 vs 2].
    """
    def recode(values, pairs):
        # Replacements are applied one after another, in the given order.
        for old, new in pairs:
            values = np.where(values == old, new, values)
        return values

    def tally(mask):
        return np.sum(mask).astype(np.float64)

    def scores(truth, guess):
        tp = tally((truth == 0) & (guess == 0))
        fp = tally((truth == 1) & (guess == 0))
        tn = tally((truth == 1) & (guess == 1))
        fn = tally((truth == 0) & (guess == 1))
        recall = tp / (tp + fn)
        precision = tp / (tp + fp)
        mcc = ((tp * tn) - (fp * fn)) / np.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
        acc = np.sum(truth == guess)/guess.size
        return acc, recall, precision, mcc

    saved = np.load(file, allow_pickle=True)
    lbls = saved[0].flatten()
    pred = saved[1].flatten()

    print(np.sum(lbls == 0), np.sum(lbls == 1), np.sum(lbls == 2))

    keep_0_1 = lbls != 2
    keep_1_2 = lbls != 0
    keep_0_2 = lbls != 1

    # ---- 0 vs 1
    lbls_0_1 = lbls[keep_0_1]
    pred_0_1 = recode(pred[keep_0_1], [(2, 1)])

    # ---- 1 vs 2 (-9 is a temporary marker so the 0 -> 1 and 1 -> 0 swaps do not collide)
    lbls_1_2 = recode(lbls[keep_1_2], [(1, 0), (2, 1)])
    pred_1_2 = recode(pred[keep_1_2], [(0, -9), (1, 0), (2, 1), (-9, 1)])

    # ---- 0 vs 2
    lbls_0_2 = recode(lbls[keep_0_2], [(2, 1)])
    pred_0_2 = recode(pred[keep_0_2], [(2, 1)])

    cm_0_1 = confusion_matrix_values(lbls_0_1, pred_0_1)
    cm_1_2 = confusion_matrix_values(lbls_1_2, pred_1_2)
    cm_0_2 = confusion_matrix_values(lbls_0_2, pred_0_2)

    acc_0, recall_0, precision_0, mcc_0 = scores(lbls_0_1, pred_0_1)
    acc_1, recall_1, precision_1, mcc_1 = scores(lbls_1_2, pred_1_2)
    acc_2, recall_2, precision_2, mcc_2 = scores(lbls_0_2, pred_0_2)

    print(acc_0, recall_0, precision_0, mcc_0)
    print(acc_1, recall_1, precision_1, mcc_1)
    print(acc_2, recall_2, precision_2, mcc_2)

    accuracies = [acc_0, acc_1, acc_2]
    recalls = [recall_0, recall_1, recall_2]
    precisions = [precision_0, precision_1, precision_2]
    mccs = [mcc_0, mcc_1, mcc_2]
    return cm_0_1, cm_1_2, cm_0_2, accuracies, recalls, precisions, mccs
def _collapse_5cat(values):
    """Map five categories onto three: 0 -> 0, 1/2/3 -> 1, 4 -> 2 (anything else -> 0)."""
    collapsed = np.zeros(values.size, dtype=np.int32)
    collapsed[(values == 1) | (values == 2) | (values == 3)] = 1
    collapsed[values == 4] = 2
    return collapsed


def _recode_5cat(values, pairs):
    """Apply the (old, new) replacements to ``values`` one after another, in order."""
    for old, new in pairs:
        values = np.where(values == old, new, values)
    return values


def _pair_scores_5cat(truth, guess):
    """Return (accuracy, recall, precision, mcc) for 0/1 coded arrays, with 0 as positive."""
    tp = np.sum((truth == 0) & (guess == 0)).astype(np.float64)
    fp = np.sum((truth == 1) & (guess == 0)).astype(np.float64)
    tn = np.sum((truth == 1) & (guess == 1)).astype(np.float64)
    fn = np.sum((truth == 0) & (guess == 1)).astype(np.float64)

    recall = tp / (tp + fn)
    precision = tp / (tp + fp)
    mcc = ((tp * tn) - (fp * fn)) / np.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
    acc = np.sum(truth == guess)/guess.size
    return acc, recall, precision, mcc


def analyze_5cat(file):
    """Collapse saved five-class results to three classes and score them pairwise.

    ``file`` is loaded with ``np.load(..., allow_pickle=True)``; entry 0 is
    used as the labels and entry 1 as the predictions, both flattened. Both
    are first collapsed to three classes (0 -> 0, 1/2/3 -> 1, 4 -> 2).

    For every class pair the samples labelled with the third class are dropped
    and the rest is recoded to 0 (first class of the pair) and 1:

    * 0 vs 1: predictions of 2 count as 1.
    * 1 vs 2: labels 1 -> 0 and 2 -> 1; predictions 1 -> 0, while 0 and 2 -> 1.
    * 0 vs 2: labels 2 -> 1; predictions 2 -> 1 (a prediction of 1 stays 1).

    Code 0 is the positive class for recall, precision and MCC. The per-class
    counts and one line of ``acc recall precision mcc`` per pair are printed.

    Returns:
        Tuple ``(cm_0_1, cm_1_2, cm_0_2, accuracies, recalls, precisions,
        mccs, lbls, pred)``; each list is ordered [0 vs 1, 1 vs 2, 0 vs 2] and
        ``lbls``/``pred`` are the collapsed three-class int32 arrays.
    """
    saved = np.load(file, allow_pickle=True)
    lbls = _collapse_5cat(saved[0].flatten())
    pred = _collapse_5cat(saved[1].flatten())

    print(np.sum(lbls == 0), np.sum(lbls == 1), np.sum(lbls == 2))

    keep_0_1 = lbls != 2
    keep_1_2 = lbls != 0
    keep_0_2 = lbls != 1

    # ---- 0 vs 1
    lbls_0_1 = lbls[keep_0_1]
    pred_0_1 = _recode_5cat(pred[keep_0_1], [(2, 1)])

    # ---- 1 vs 2 (-9 is a temporary marker so the 0 -> 1 and 1 -> 0 swaps do not collide)
    lbls_1_2 = _recode_5cat(lbls[keep_1_2], [(1, 0), (2, 1)])
    pred_1_2 = _recode_5cat(pred[keep_1_2], [(0, -9), (1, 0), (2, 1), (-9, 1)])

    # ---- 0 vs 2
    lbls_0_2 = _recode_5cat(lbls[keep_0_2], [(2, 1)])
    pred_0_2 = _recode_5cat(pred[keep_0_2], [(2, 1)])

    cm_0_1 = confusion_matrix_values(lbls_0_1, pred_0_1)
    cm_1_2 = confusion_matrix_values(lbls_1_2, pred_1_2)
    cm_0_2 = confusion_matrix_values(lbls_0_2, pred_0_2)

    acc_0, recall_0, precision_0, mcc_0 = _pair_scores_5cat(lbls_0_1, pred_0_1)
    acc_1, recall_1, precision_1, mcc_1 = _pair_scores_5cat(lbls_1_2, pred_1_2)
    acc_2, recall_2, precision_2, mcc_2 = _pair_scores_5cat(lbls_0_2, pred_0_2)

    print(acc_0, recall_0, precision_0, mcc_0)
    print(acc_1, recall_1, precision_1, mcc_1)
    print(acc_2, recall_2, precision_2, mcc_2)

    return cm_0_1, cm_1_2, cm_0_2, [acc_0, acc_1, acc_2], [recall_0, recall_1, recall_2],\
        [precision_0, precision_1, precision_2], [mcc_0, mcc_1, mcc_2], lbls, pred
# import matplotlib.pyplot as plt
# Fig, ax = plt.subplots()
# import numpy as np
# tup = np.load('/Users/tomrink/cld_frac_viirs.npy', allow_pickle=True)
# lbls = tup[0].flatten()
# pred = tup[1].flatten()
# bt = tup[2].flatten()
# refl = tup[3].flatten()
# refl_rng = tup[4].flatten()
# refl_std = tup[5].flatten()
# cld_prob = tup[6].flatten()
# cat_0 = lbls == 0
# cat_1 = lbls == 1
# cat_2 = lbls == 2
# cat_3 = lbls == 3
# cat_4 = lbls == 4
# cat_0_hit = (lbls == 0) & (pred == 0)
# cat_1_hit = (lbls == 1) & (pred == 1)
# cat_2_hit = (lbls == 2) & (pred == 2)
# cat_3_hit = (lbls == 3) & (pred == 3)
# cat_4_hit = (lbls == 4) & (pred == 4)
# cat_0_miss = (lbls == 0) & (pred != 0)
# cat_1_miss = (lbls == 1) & (pred != 1)
# cat_2_miss = (lbls == 2) & (pred != 2)
# cat_3_miss = (lbls == 3) & (pred != 3)
# cat_4_miss = (lbls == 4) & (pred != 4)
# plt.hist(cld_prob[cat_0], log=True, histtype='step', linewidth=1.4, color='blue', label='CLR')
# plt.hist(cld_prob[cat_1], log=True, histtype='step', linewidth=1.4, color='orange', label='1/4')
# plt.hist(cld_prob[cat_2], log=True, histtype='step', linewidth=1.4, color='green', label='1/2')
# plt.hist(cld_prob[cat_3], log=True, histtype='step', linewidth=1.4, color='red', label='3/4')
# plt.hist(cld_prob[cat_4], log=True, histtype='step', linewidth=1.4, color='purple', label='CLD')
# ax.legend(loc='upper center')
# Script entry point: build the SRCNN wrapper and start a run.
# NOTE(review): 'matchup_filename' looks like a placeholder argument -- confirm
# the intended input before running this module as a script.
if __name__ == "__main__":
    nn = SRCNN()
    nn.run('matchup_filename')
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment