Skip to content
Snippets Groups Projects
Commit 0088ea48 authored by tomrink's avatar tomrink
Browse files

snapshot...

parent df44ff8d
Branches
No related tags found
No related merge requests found
......@@ -11,6 +11,7 @@ import h5py
# L1B M/I-bands: /apollo/cloud/scratch/cwhite/VIIRS_HRES/2019/2019_01_01/
# CLAVRx: /apollo/cloud/scratch/Satellite_Output/VIIRS_HRES/2019/2019_01_01/
# /apollo/cloud/scratch/Satellite_Output/andi/NEW/VIIRS_HRES/2019
LOG_DEVICE_PLACEMENT = False
......@@ -24,7 +25,7 @@ else:
NumLogits = NumClasses
BATCH_SIZE = 128
# Number of passes over the training data (reduced from the earlier 100).
NUM_EPOCHS = 40
TRACK_MOVING_AVERAGE = False
EARLY_STOP = True
......@@ -207,15 +208,20 @@ class UNET:
self.test_data_files = None
self.test_label_files = None
# n_chans = len(self.train_params)
n_chans = 3
self.train_data_nda = None
self.train_label_nda = None
self.test_data_nda = None
self.test_label_nda = None
# self.n_chans = len(self.train_params)
self.n_chans = 1
if TRIPLET:
n_chans *= 3
self.X_img = tf.keras.Input(shape=(None, None, n_chans))
self.n_chans *= 3
self.X_img = tf.keras.Input(shape=(None, None, self.n_chans))
self.inputs.append(self.X_img)
# self.inputs.append(tf.keras.Input(shape=(None, None, 5)))
self.inputs.append(tf.keras.Input(shape=(None, None, 3)))
self.inputs.append(tf.keras.Input(shape=(None, None, 1)))
self.flight_level = 0
......@@ -294,7 +300,7 @@ class UNET:
#
# return data, data_alt, label
def get_in_mem_data_batch(self, idxs, is_training):
def get_in_mem_data_batch_salt(self, idxs, is_training):
data = []
for k in idxs:
if is_training:
......@@ -327,6 +333,33 @@ class UNET:
return data, data, label
def get_in_mem_data_batch(self, idxs, is_training):
if is_training:
data = self.train_data_nda[idxs,]
data = np.expand_dims(data, axis=3)
label = self.train_label_nda[idxs,]
label = np.expand_dims(label, axis=3)
else:
data = self.test_data_nda[idxs,]
data = np.expand_dims(data, axis=3)
label = self.test_label_nda[idxs,]
label = np.expand_dims(label, axis=3)
data = data.astype(np.float32)
label = label.astype(np.float32)
if is_training and DO_AUGMENT:
data_ud = np.flip(data, axis=1)
label_ud = np.flip(label, axis=1)
data_lr = np.flip(data, axis=2)
label_lr = np.flip(label, axis=2)
data = np.concatenate([data, data_ud, data_lr])
label = np.concatenate([label, label_ud, label_lr])
return data, data, label
# def get_parameter_data(self, param, nd_idxs, is_training):
# if is_training:
# if param in self.train_params_l1b:
......@@ -386,12 +419,12 @@ class UNET:
@tf.function(input_signature=[tf.TensorSpec(None, tf.int32)])
def data_function(self, indexes):
    """tf.data adapter over the numpy training-batch loader.

    Labels are continuous regression targets now, so all three outputs
    are declared tf.float32 (the old int32 label dtype was removed).
    """
    out = tf.numpy_function(self.get_in_mem_data_batch_train, [indexes], [tf.float32, tf.float32, tf.float32])
    return out
@tf.function(input_signature=[tf.TensorSpec(None, tf.int32)])
def data_function_test(self, indexes):
    """tf.data adapter over the numpy test-batch loader.

    Mirrors data_function: all three outputs are tf.float32 regression
    values (the stale int32 label dtype was removed).
    """
    out = tf.numpy_function(self.get_in_mem_data_batch_test, [indexes], [tf.float32, tf.float32, tf.float32])
    return out
@tf.function(input_signature=[tf.TensorSpec(None, tf.int32)])
......@@ -428,49 +461,77 @@ class UNET:
dataset = dataset.map(self.data_function_evaluate, num_parallel_calls=8)
self.eval_dataset = dataset
def setup_pipeline(self, filename_l1b_trn, filename_l1b_tst, filename_l2_trn, filename_l2_tst, trn_idxs=None, tst_idxs=None, seed=None):
if filename_l1b_trn is not None:
self.h5f_l1b_trn = h5py.File(filename_l1b_trn, 'r')
if filename_l1b_tst is not None:
self.h5f_l1b_tst = h5py.File(filename_l1b_tst, 'r')
if filename_l2_trn is not None:
self.h5f_l2_trn = h5py.File(filename_l2_trn, 'r')
if filename_l2_tst is not None:
self.h5f_l2_tst = h5py.File(filename_l2_tst, 'r')
if trn_idxs is None:
# Note: time is same across both L1B and L2 for idxs
if self.h5f_l1b_trn is not None:
h5f = self.h5f_l1b_trn
else:
h5f = self.h5f_l2_trn
time = h5f['time']
trn_idxs = np.arange(time.shape[0])
if seed is not None:
np.random.seed(seed)
np.random.shuffle(trn_idxs)
# def setup_pipeline(self, filename_l1b_trn, filename_l1b_tst, filename_l2_trn, filename_l2_tst, trn_idxs=None, tst_idxs=None, seed=None):
# if filename_l1b_trn is not None:
# self.h5f_l1b_trn = h5py.File(filename_l1b_trn, 'r')
# if filename_l1b_tst is not None:
# self.h5f_l1b_tst = h5py.File(filename_l1b_tst, 'r')
# if filename_l2_trn is not None:
# self.h5f_l2_trn = h5py.File(filename_l2_trn, 'r')
# if filename_l2_tst is not None:
# self.h5f_l2_tst = h5py.File(filename_l2_tst, 'r')
#
# if trn_idxs is None:
# # Note: time is same across both L1B and L2 for idxs
# if self.h5f_l1b_trn is not None:
# h5f = self.h5f_l1b_trn
# else:
# h5f = self.h5f_l2_trn
# time = h5f['time']
# trn_idxs = np.arange(time.shape[0])
# if seed is not None:
# np.random.seed(seed)
# np.random.shuffle(trn_idxs)
#
# if self.h5f_l1b_tst is not None:
# h5f = self.h5f_l1b_tst
# else:
# h5f = self.h5f_l2_tst
# time = h5f['time']
# tst_idxs = np.arange(time.shape[0])
# if seed is not None:
# np.random.seed(seed)
# np.random.shuffle(tst_idxs)
#
# self.num_data_samples = trn_idxs.shape[0]
#
# self.get_train_dataset(trn_idxs)
# self.get_test_dataset(tst_idxs)
#
# print('datetime: ', now)
# print('training and test data: ')
# print(filename_l1b_trn)
# print(filename_l1b_tst)
# print(filename_l2_trn)
# print(filename_l2_tst)
# print('---------------------------')
# print('num train samples: ', self.num_data_samples)
# print('BATCH SIZE: ', BATCH_SIZE)
# print('num test samples: ', tst_idxs.shape[0])
# print('setup_pipeline: Done')
def setup_pipeline(self, data_nda, label_nda, perc=0.20):
num_samples = data_nda.shape[0]
num_test = int(num_samples * perc)
self.num_data_samples = num_samples - num_test
num_train = self.num_data_samples
self.train_data_nda = data_nda[0:num_train]
self.train_label_nda = label_nda[0:num_train]
self.test_data_nda = data_nda[num_train:]
self.test_label_nda = label_nda[num_train:]
trn_idxs = np.arange(self.train_data_nda.shape[0])
tst_idxs = np.arange(self.test_data_nda.shape[0])
if self.h5f_l1b_tst is not None:
h5f = self.h5f_l1b_tst
else:
h5f = self.h5f_l2_tst
time = h5f['time']
tst_idxs = np.arange(time.shape[0])
if seed is not None:
np.random.seed(seed)
np.random.shuffle(tst_idxs)
self.num_data_samples = trn_idxs.shape[0]
self.get_train_dataset(trn_idxs)
self.get_test_dataset(tst_idxs)
print('datetime: ', now)
print('training and test data: ')
print(filename_l1b_trn)
print(filename_l1b_tst)
print(filename_l2_trn)
print(filename_l2_tst)
print('---------------------------')
print('num train samples: ', self.num_data_samples)
print('BATCH SIZE: ', BATCH_SIZE)
......@@ -560,7 +621,7 @@ class UNET:
momentum = 0.99
# num_filters = len(self.train_params) * 4
num_filters = 3 * 4
num_filters = self.n_chans * 4
input_2d = self.inputs[0]
conv = tf.keras.layers.Conv2D(num_filters, kernel_size=5, strides=1, padding=padding, activation=None)(input_2d)
......@@ -662,10 +723,14 @@ class UNET:
conv = tf.keras.layers.Conv2DTranspose(num_filters, kernel_size=3, strides=2, padding=padding)(conv)
print('8: ', conv.shape)
if NumClasses == 2:
activation = tf.nn.sigmoid # For binary
else:
activation = tf.nn.softmax # For multi-class
conv = tf.keras.layers.Conv2DTranspose(num_filters, kernel_size=3, strides=2, padding=padding)(conv)
print('9: ', conv.shape)
# if NumClasses == 2:
# activation = tf.nn.sigmoid # For binary
# else:
# activation = tf.nn.softmax # For multi-class
activation = tf.nn.sigmoid
# Called logits, but these are actually probabilities, see activation
self.logits = tf.keras.layers.Conv2D(1, kernel_size=1, strides=1, padding=padding, name='probability', activation=activation)(conv)
......@@ -673,10 +738,11 @@ class UNET:
print(self.logits.shape)
def build_training(self):
if NumClasses == 2:
self.loss = tf.keras.losses.BinaryCrossentropy(from_logits=False) # for two-class only
else:
self.loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False) # For multi-class
# if NumClasses == 2:
# self.loss = tf.keras.losses.BinaryCrossentropy(from_logits=False) # for two-class only
# else:
# self.loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False) # For multi-class
self.loss = tf.keras.losses.MeanSquaredError() # Regression
# decayed_learning_rate = learning_rate * decay_rate ^ (global_step / decay_steps)
initial_learning_rate = 0.006
......@@ -698,22 +764,26 @@ class UNET:
self.initial_learning_rate = initial_learning_rate
def build_evaluation(self):
    """Create the streaming metrics used during training and evaluation.

    The model now produces a continuous value trained with MSE, so the
    former classification metrics (Binary/SparseCategorical accuracy,
    AUC, recall, precision, confusion counts) were removed in favor of
    mean-absolute-error trackers.
    """
    # MAE doubles as the "accuracy" measure for the regression target.
    self.train_accuracy = tf.keras.metrics.MeanAbsoluteError(name='train_accuracy')
    self.test_accuracy = tf.keras.metrics.MeanAbsoluteError(name='test_accuracy')
    # Running means of the per-batch loss values.
    self.train_loss = tf.keras.metrics.Mean(name='train_loss')
    self.test_loss = tf.keras.metrics.Mean(name='test_loss')
@tf.function
def train_step(self, mini_batch):
......@@ -745,14 +815,14 @@ class UNET:
self.test_loss(t_loss)
self.test_accuracy(labels, pred)
if NumClasses == 2:
self.test_auc(labels, pred)
self.test_recall(labels, pred)
self.test_precision(labels, pred)
self.test_true_neg(labels, pred)
self.test_true_pos(labels, pred)
self.test_false_neg(labels, pred)
self.test_false_pos(labels, pred)
# if NumClasses == 2:
# self.test_auc(labels, pred)
# self.test_recall(labels, pred)
# self.test_precision(labels, pred)
# self.test_true_neg(labels, pred)
# self.test_true_pos(labels, pred)
# self.test_false_neg(labels, pred)
# self.test_false_pos(labels, pred)
def predict(self, mini_batch):
inputs = [mini_batch[0], mini_batch[1]]
......@@ -765,26 +835,26 @@ class UNET:
self.test_loss(t_loss)
self.test_accuracy(labels, pred)
if NumClasses == 2:
self.test_auc(labels, pred)
self.test_recall(labels, pred)
self.test_precision(labels, pred)
self.test_true_neg(labels, pred)
self.test_true_pos(labels, pred)
self.test_false_neg(labels, pred)
self.test_false_pos(labels, pred)
# if NumClasses == 2:
# self.test_auc(labels, pred)
# self.test_recall(labels, pred)
# self.test_precision(labels, pred)
# self.test_true_neg(labels, pred)
# self.test_true_pos(labels, pred)
# self.test_false_neg(labels, pred)
# self.test_false_pos(labels, pred)
def reset_test_metrics(self):
self.test_loss.reset_states()
self.test_accuracy.reset_states()
if NumClasses == 2:
self.test_auc.reset_states()
self.test_recall.reset_states()
self.test_precision.reset_states()
self.test_true_neg.reset_states()
self.test_true_pos.reset_states()
self.test_false_neg.reset_states()
self.test_false_pos.reset_states()
# if NumClasses == 2:
# self.test_auc.reset_states()
# self.test_recall.reset_states()
# self.test_precision.reset_states()
# self.test_true_neg.reset_states()
# self.test_true_pos.reset_states()
# self.test_false_neg.reset_states()
# self.test_false_pos.reset_states()
def get_metrics(self):
recall = self.test_recall.result()
......@@ -858,20 +928,20 @@ class UNET:
for mini_batch_test in tst_ds:
self.test_step(mini_batch_test)
if NumClasses == 2:
f1, mcc = self.get_metrics()
# if NumClasses == 2:
# f1, mcc = self.get_metrics()
with self.writer_valid.as_default():
tf.summary.scalar('loss_val', self.test_loss.result(), step=step)
tf.summary.scalar('acc_val', self.test_accuracy.result(), step=step)
if NumClasses == 2:
tf.summary.scalar('auc_val', self.test_auc.result(), step=step)
tf.summary.scalar('recall_val', self.test_recall.result(), step=step)
tf.summary.scalar('prec_val', self.test_precision.result(), step=step)
tf.summary.scalar('f1_val', f1, step=step)
tf.summary.scalar('mcc_val', mcc, step=step)
tf.summary.scalar('num_train_steps', step, step=step)
tf.summary.scalar('num_epochs', epoch, step=step)
# if NumClasses == 2:
# tf.summary.scalar('auc_val', self.test_auc.result(), step=step)
# tf.summary.scalar('recall_val', self.test_recall.result(), step=step)
# tf.summary.scalar('prec_val', self.test_precision.result(), step=step)
# tf.summary.scalar('f1_val', f1, step=step)
# tf.summary.scalar('mcc_val', mcc, step=step)
# tf.summary.scalar('num_train_steps', step, step=step)
# tf.summary.scalar('num_epochs', epoch, step=step)
with self.writer_train_valid_loss.as_default():
tf.summary.scalar('loss_trn', loss.numpy(), step=step)
......@@ -898,24 +968,25 @@ class UNET:
for mini_batch in ds:
self.test_step(mini_batch)
if NumClasses == 2:
f1, mcc = self.get_metrics()
print('loss, acc, recall, precision, auc, f1, mcc: ', self.test_loss.result().numpy(), self.test_accuracy.result().numpy(),
self.test_recall.result().numpy(), self.test_precision.result().numpy(), self.test_auc.result().numpy(), f1.numpy(), mcc.numpy())
else:
print('loss, acc: ', self.test_loss.result().numpy(), self.test_accuracy.result().numpy())
# if NumClasses == 2:
# f1, mcc = self.get_metrics()
# print('loss, acc, recall, precision, auc, f1, mcc: ', self.test_loss.result().numpy(), self.test_accuracy.result().numpy(),
# self.test_recall.result().numpy(), self.test_precision.result().numpy(), self.test_auc.result().numpy(), f1.numpy(), mcc.numpy())
# else:
# print('loss, acc: ', self.test_loss.result().numpy(), self.test_accuracy.result().numpy())
print('------------------------------------------------------')
tst_loss = self.test_loss.result().numpy()
if tst_loss < best_test_loss:
best_test_loss = tst_loss
if NumClasses == 2:
best_test_acc = self.test_accuracy.result().numpy()
best_test_recall = self.test_recall.result().numpy()
best_test_precision = self.test_precision.result().numpy()
best_test_auc = self.test_auc.result().numpy()
best_test_f1 = f1.numpy()
best_test_mcc = mcc.numpy()
# if NumClasses == 2:
# best_test_acc = self.test_accuracy.result().numpy()
# best_test_recall = self.test_recall.result().numpy()
# best_test_precision = self.test_precision.result().numpy()
# best_test_auc = self.test_auc.result().numpy()
# best_test_f1 = f1.numpy()
# best_test_mcc = mcc.numpy()
ckpt_manager.save()
......@@ -941,9 +1012,9 @@ class UNET:
if self.h5f_l2_tst is not None:
self.h5f_l2_tst.close()
f = open(home_dir+'/best_stats_'+now+'.pkl', 'wb')
pickle.dump((best_test_loss, best_test_acc, best_test_recall, best_test_precision, best_test_auc, best_test_f1, best_test_mcc), f)
f.close()
# f = open(home_dir+'/best_stats_'+now+'.pkl', 'wb')
# pickle.dump((best_test_loss, best_test_acc, best_test_recall, best_test_precision, best_test_auc, best_test_f1, best_test_mcc), f)
# f.close()
def build_model(self):
self.build_unet()
......@@ -1008,13 +1079,20 @@ class UNET:
self.build_evaluation()
self.do_training()
def run_test(self, data_path='/Users/tomrink/data/salt/train/images/', label_path='/Users/tomrink/data/salt/train/masks/'):
def run_salt(self, data_path='/Users/tomrink/data/salt/train/images/', label_path='/Users/tomrink/data/salt/train/masks/'):
self.setup_salt_pipeline(data_path, label_path)
self.build_model()
self.build_training()
self.build_evaluation()
self.do_training()
def run_test(self, data_nda, label_nda):
self.setup_pipeline(data_nda, label_nda)
self.build_model()
self.build_training()
self.build_evaluation()
self.do_training()
def run_restore_test(self, data_path='/Users/tomrink/data/salt/train/images/', label_path='/Users/tomrink/data/salt/train/masks/', ckpt_dir=None):
#self.setup_salt_pipeline(data_path, label_path)
self.setup_salt_restore()
......
Loading…
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment