Skip to content
Snippets Groups Projects
Commit 565038f5 authored by tomrink's avatar tomrink
Browse files

minor

parent 7f8db191
No related branches found
No related tags found
No related merge requests found
......@@ -341,9 +341,9 @@ class ESPCN:
print('build_cnn')
padding = "SAME"
# activation = tf.nn.relu
# activation = tf.nn.elu
activation = tf.nn.leaky_relu
# activation = tf.nn.leaky_relu
activation = tf.nn.relu
# kernel_initializer = 'glorot_uniform'
kernel_initializer = 'he_uniform'
momentum = 0.99
......@@ -392,10 +392,6 @@ class ESPCN:
print(self.logits.shape)
def build_training(self):
# if NumClasses == 2:
# self.loss = tf.keras.losses.BinaryCrossentropy(from_logits=False) # for two-class only
# else:
# self.loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False) # For multi-class
self.loss = tf.keras.losses.MeanSquaredError() # Regression
# decayed_learning_rate = learning_rate * decay_rate ^ (global_step / decay_steps)
......@@ -425,20 +421,6 @@ class ESPCN:
self.train_loss = tf.keras.metrics.Mean(name='train_loss')
self.test_loss = tf.keras.metrics.Mean(name='test_loss')
# if NumClasses == 2:
# self.train_accuracy = tf.keras.metrics.BinaryAccuracy(name='train_accuracy')
# self.test_accuracy = tf.keras.metrics.BinaryAccuracy(name='test_accuracy')
# self.test_auc = tf.keras.metrics.AUC(name='test_auc')
# self.test_recall = tf.keras.metrics.Recall(name='test_recall')
# self.test_precision = tf.keras.metrics.Precision(name='test_precision')
# self.test_true_neg = tf.keras.metrics.TrueNegatives(name='test_true_neg')
# self.test_true_pos = tf.keras.metrics.TruePositives(name='test_true_pos')
# self.test_false_neg = tf.keras.metrics.FalseNegatives(name='test_false_neg')
# self.test_false_pos = tf.keras.metrics.FalsePositives(name='test_false_pos')
# else:
# self.train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')
# self.test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='test_accuracy')
@tf.function
def train_step(self, mini_batch):
inputs = [mini_batch[0]]
......@@ -469,14 +451,6 @@ class ESPCN:
self.test_loss(t_loss)
self.test_accuracy(labels, pred)
# if NumClasses == 2:
# self.test_auc(labels, pred)
# self.test_recall(labels, pred)
# self.test_precision(labels, pred)
# self.test_true_neg(labels, pred)
# self.test_true_pos(labels, pred)
# self.test_false_neg(labels, pred)
# self.test_false_pos(labels, pred)
def predict(self, mini_batch):
inputs = [mini_batch[0]]
......@@ -493,14 +467,6 @@ class ESPCN:
def reset_test_metrics(self):
self.test_loss.reset_states()
self.test_accuracy.reset_states()
# if NumClasses == 2:
# self.test_auc.reset_states()
# self.test_recall.reset_states()
# self.test_precision.reset_states()
# self.test_true_neg.reset_states()
# self.test_true_pos.reset_states()
# self.test_false_neg.reset_states()
# self.test_false_pos.reset_states()
def get_metrics(self):
recall = self.test_recall.result()
......@@ -533,12 +499,6 @@ class ESPCN:
step = 0
total_time = 0
best_test_loss = np.finfo(dtype=np.float).max
best_test_acc = 0
best_test_recall = 0
best_test_precision = 0
best_test_auc = 0
best_test_f1 = 0
best_test_mcc = 0
if EARLY_STOP:
es = EarlyStop()
......@@ -574,20 +534,9 @@ class ESPCN:
for mini_batch_test in tst_ds:
self.test_step(mini_batch_test)
# if NumClasses == 2:
# f1, mcc = self.get_metrics()
with self.writer_valid.as_default():
tf.summary.scalar('loss_val', self.test_loss.result(), step=step)
tf.summary.scalar('acc_val', self.test_accuracy.result(), step=step)
# if NumClasses == 2:
# tf.summary.scalar('auc_val', self.test_auc.result(), step=step)
# tf.summary.scalar('recall_val', self.test_recall.result(), step=step)
# tf.summary.scalar('prec_val', self.test_precision.result(), step=step)
# tf.summary.scalar('f1_val', f1, step=step)
# tf.summary.scalar('mcc_val', mcc, step=step)
# tf.summary.scalar('num_train_steps', step, step=step)
# tf.summary.scalar('num_epochs', epoch, step=step)
with self.writer_train_valid_loss.as_default():
tf.summary.scalar('loss_trn', loss.numpy(), step=step)
......@@ -615,25 +564,11 @@ class ESPCN:
self.test_step(mini_batch)
print('loss, acc: ', self.test_loss.result().numpy(), self.test_accuracy.result().numpy())
# if NumClasses == 2:
# f1, mcc = self.get_metrics()
# print('loss, acc, recall, precision, auc, f1, mcc: ', self.test_loss.result().numpy(), self.test_accuracy.result().numpy(),
# self.test_recall.result().numpy(), self.test_precision.result().numpy(), self.test_auc.result().numpy(), f1.numpy(), mcc.numpy())
# else:
# print('loss, acc: ', self.test_loss.result().numpy(), self.test_accuracy.result().numpy())
print('------------------------------------------------------')
tst_loss = self.test_loss.result().numpy()
if tst_loss < best_test_loss:
best_test_loss = tst_loss
# if NumClasses == 2:
# best_test_acc = self.test_accuracy.result().numpy()
# best_test_recall = self.test_recall.result().numpy()
# best_test_precision = self.test_precision.result().numpy()
# best_test_auc = self.test_auc.result().numpy()
# best_test_f1 = f1.numpy()
# best_test_mcc = mcc.numpy()
ckpt_manager.save()
if EARLY_STOP and es.check_stop(tst_loss):
......@@ -741,9 +676,9 @@ class ESPCN:
def prepare(param_idx=1, filename='/Users/tomrink/data_valid_40.npy'):
nda = np.load(filename)
#nda = nda[:, param_idx, :, :]
# nda = nda[:, param_idx, :, :]
nda_lr = nda[:, param_idx, 2:133:2, 2:133:2]
# nda_lr = resample(x_70, y_70, nda, x_70_2, y_70_2)
# nda_lr = resample(x_134, y_134, nda, x_134_2, y_134_2)
nda_lr = np.expand_dims(nda_lr, axis=3)
return nda_lr
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment