Skip to content
Snippets Groups Projects
Commit 565038f5 authored by tomrink's avatar tomrink
Browse files

minor

parent 7f8db191
No related branches found
No related tags found
No related merge requests found
...@@ -341,9 +341,9 @@ class ESPCN: ...@@ -341,9 +341,9 @@ class ESPCN:
print('build_cnn') print('build_cnn')
padding = "SAME" padding = "SAME"
# activation = tf.nn.relu
# activation = tf.nn.elu # activation = tf.nn.elu
activation = tf.nn.leaky_relu # activation = tf.nn.leaky_relu
activation = tf.nn.relu
# kernel_initializer = 'glorot_uniform' # kernel_initializer = 'glorot_uniform'
kernel_initializer = 'he_uniform' kernel_initializer = 'he_uniform'
momentum = 0.99 momentum = 0.99
...@@ -392,10 +392,6 @@ class ESPCN: ...@@ -392,10 +392,6 @@ class ESPCN:
print(self.logits.shape) print(self.logits.shape)
def build_training(self): def build_training(self):
# if NumClasses == 2:
# self.loss = tf.keras.losses.BinaryCrossentropy(from_logits=False) # for two-class only
# else:
# self.loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False) # For multi-class
self.loss = tf.keras.losses.MeanSquaredError() # Regression self.loss = tf.keras.losses.MeanSquaredError() # Regression
# decayed_learning_rate = learning_rate * decay_rate ^ (global_step / decay_steps) # decayed_learning_rate = learning_rate * decay_rate ^ (global_step / decay_steps)
...@@ -425,20 +421,6 @@ class ESPCN: ...@@ -425,20 +421,6 @@ class ESPCN:
self.train_loss = tf.keras.metrics.Mean(name='train_loss') self.train_loss = tf.keras.metrics.Mean(name='train_loss')
self.test_loss = tf.keras.metrics.Mean(name='test_loss') self.test_loss = tf.keras.metrics.Mean(name='test_loss')
# if NumClasses == 2:
# self.train_accuracy = tf.keras.metrics.BinaryAccuracy(name='train_accuracy')
# self.test_accuracy = tf.keras.metrics.BinaryAccuracy(name='test_accuracy')
# self.test_auc = tf.keras.metrics.AUC(name='test_auc')
# self.test_recall = tf.keras.metrics.Recall(name='test_recall')
# self.test_precision = tf.keras.metrics.Precision(name='test_precision')
# self.test_true_neg = tf.keras.metrics.TrueNegatives(name='test_true_neg')
# self.test_true_pos = tf.keras.metrics.TruePositives(name='test_true_pos')
# self.test_false_neg = tf.keras.metrics.FalseNegatives(name='test_false_neg')
# self.test_false_pos = tf.keras.metrics.FalsePositives(name='test_false_pos')
# else:
# self.train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')
# self.test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='test_accuracy')
@tf.function @tf.function
def train_step(self, mini_batch): def train_step(self, mini_batch):
inputs = [mini_batch[0]] inputs = [mini_batch[0]]
...@@ -469,14 +451,6 @@ class ESPCN: ...@@ -469,14 +451,6 @@ class ESPCN:
self.test_loss(t_loss) self.test_loss(t_loss)
self.test_accuracy(labels, pred) self.test_accuracy(labels, pred)
# if NumClasses == 2:
# self.test_auc(labels, pred)
# self.test_recall(labels, pred)
# self.test_precision(labels, pred)
# self.test_true_neg(labels, pred)
# self.test_true_pos(labels, pred)
# self.test_false_neg(labels, pred)
# self.test_false_pos(labels, pred)
def predict(self, mini_batch): def predict(self, mini_batch):
inputs = [mini_batch[0]] inputs = [mini_batch[0]]
...@@ -493,14 +467,6 @@ class ESPCN: ...@@ -493,14 +467,6 @@ class ESPCN:
def reset_test_metrics(self): def reset_test_metrics(self):
self.test_loss.reset_states() self.test_loss.reset_states()
self.test_accuracy.reset_states() self.test_accuracy.reset_states()
# if NumClasses == 2:
# self.test_auc.reset_states()
# self.test_recall.reset_states()
# self.test_precision.reset_states()
# self.test_true_neg.reset_states()
# self.test_true_pos.reset_states()
# self.test_false_neg.reset_states()
# self.test_false_pos.reset_states()
def get_metrics(self): def get_metrics(self):
recall = self.test_recall.result() recall = self.test_recall.result()
...@@ -533,12 +499,6 @@ class ESPCN: ...@@ -533,12 +499,6 @@ class ESPCN:
step = 0 step = 0
total_time = 0 total_time = 0
best_test_loss = np.finfo(dtype=np.float).max best_test_loss = np.finfo(dtype=np.float).max
best_test_acc = 0
best_test_recall = 0
best_test_precision = 0
best_test_auc = 0
best_test_f1 = 0
best_test_mcc = 0
if EARLY_STOP: if EARLY_STOP:
es = EarlyStop() es = EarlyStop()
...@@ -574,20 +534,9 @@ class ESPCN: ...@@ -574,20 +534,9 @@ class ESPCN:
for mini_batch_test in tst_ds: for mini_batch_test in tst_ds:
self.test_step(mini_batch_test) self.test_step(mini_batch_test)
# if NumClasses == 2:
# f1, mcc = self.get_metrics()
with self.writer_valid.as_default(): with self.writer_valid.as_default():
tf.summary.scalar('loss_val', self.test_loss.result(), step=step) tf.summary.scalar('loss_val', self.test_loss.result(), step=step)
tf.summary.scalar('acc_val', self.test_accuracy.result(), step=step) tf.summary.scalar('acc_val', self.test_accuracy.result(), step=step)
# if NumClasses == 2:
# tf.summary.scalar('auc_val', self.test_auc.result(), step=step)
# tf.summary.scalar('recall_val', self.test_recall.result(), step=step)
# tf.summary.scalar('prec_val', self.test_precision.result(), step=step)
# tf.summary.scalar('f1_val', f1, step=step)
# tf.summary.scalar('mcc_val', mcc, step=step)
# tf.summary.scalar('num_train_steps', step, step=step)
# tf.summary.scalar('num_epochs', epoch, step=step)
with self.writer_train_valid_loss.as_default(): with self.writer_train_valid_loss.as_default():
tf.summary.scalar('loss_trn', loss.numpy(), step=step) tf.summary.scalar('loss_trn', loss.numpy(), step=step)
...@@ -615,25 +564,11 @@ class ESPCN: ...@@ -615,25 +564,11 @@ class ESPCN:
self.test_step(mini_batch) self.test_step(mini_batch)
print('loss, acc: ', self.test_loss.result().numpy(), self.test_accuracy.result().numpy()) print('loss, acc: ', self.test_loss.result().numpy(), self.test_accuracy.result().numpy())
# if NumClasses == 2:
# f1, mcc = self.get_metrics()
# print('loss, acc, recall, precision, auc, f1, mcc: ', self.test_loss.result().numpy(), self.test_accuracy.result().numpy(),
# self.test_recall.result().numpy(), self.test_precision.result().numpy(), self.test_auc.result().numpy(), f1.numpy(), mcc.numpy())
# else:
# print('loss, acc: ', self.test_loss.result().numpy(), self.test_accuracy.result().numpy())
print('------------------------------------------------------') print('------------------------------------------------------')
tst_loss = self.test_loss.result().numpy() tst_loss = self.test_loss.result().numpy()
if tst_loss < best_test_loss: if tst_loss < best_test_loss:
best_test_loss = tst_loss best_test_loss = tst_loss
# if NumClasses == 2:
# best_test_acc = self.test_accuracy.result().numpy()
# best_test_recall = self.test_recall.result().numpy()
# best_test_precision = self.test_precision.result().numpy()
# best_test_auc = self.test_auc.result().numpy()
# best_test_f1 = f1.numpy()
# best_test_mcc = mcc.numpy()
ckpt_manager.save() ckpt_manager.save()
if EARLY_STOP and es.check_stop(tst_loss): if EARLY_STOP and es.check_stop(tst_loss):
...@@ -741,9 +676,9 @@ class ESPCN: ...@@ -741,9 +676,9 @@ class ESPCN:
def prepare(param_idx=1, filename='/Users/tomrink/data_valid_40.npy'): def prepare(param_idx=1, filename='/Users/tomrink/data_valid_40.npy'):
nda = np.load(filename) nda = np.load(filename)
#nda = nda[:, param_idx, :, :] # nda = nda[:, param_idx, :, :]
nda_lr = nda[:, param_idx, 2:133:2, 2:133:2] nda_lr = nda[:, param_idx, 2:133:2, 2:133:2]
# nda_lr = resample(x_70, y_70, nda, x_70_2, y_70_2) # nda_lr = resample(x_134, y_134, nda, x_134_2, y_134_2)
nda_lr = np.expand_dims(nda_lr, axis=3) nda_lr = np.expand_dims(nda_lr, axis=3)
return nda_lr return nda_lr
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment