Skip to content
Snippets Groups Projects
Commit e35f27d7 authored by tomrink's avatar tomrink
Browse files

snapshot...

parent 559500db
No related branches found
No related tags found
No related merge requests found
...@@ -67,34 +67,6 @@ print('label_param: ', label_param) ...@@ -67,34 +67,6 @@ print('label_param: ', label_param)
KERNEL_SIZE = 3 # target size: (128, 128) KERNEL_SIZE = 3 # target size: (128, 128)
LEN_X = LEN_Y = 128 LEN_X = LEN_Y = 128
if KERNEL_SIZE == 3:
slc_x_m = slice(1, int(LEN_X/2) + 4)
slc_y_m = slice(1, int(LEN_Y/2) + 4)
slc_x = slice(3, LEN_X + 5)
slc_y = slice(3, LEN_Y + 5)
slc_x_2 = slice(2, LEN_X + 7, 2)
slc_y_2 = slice(2, LEN_Y + 7, 2)
x_2 = np.arange(int(LEN_X/2) + 3)
y_2 = np.arange(int(LEN_Y/2) + 3)
t = np.arange(0, int(LEN_X/2) + 3, 0.5)
s = np.arange(0, int(LEN_Y/2) + 3, 0.5)
x_k = slice(1, LEN_X + 3)
y_k = slice(1, LEN_Y + 3)
x_128 = slice(4, LEN_X + 4)
y_128 = slice(4, LEN_Y + 4)
elif KERNEL_SIZE == 5:
slc_x = slice(3, 135)
slc_y = slice(3, 135)
slc_x_2 = slice(2, 137, 2)
slc_y_2 = slice(2, 137, 2)
x_128 = slice(5, 133)
y_128 = slice(5, 133)
t = np.arange(1, 67, 0.5)
s = np.arange(1, 67, 0.5)
x_2 = np.arange(68)
y_2 = np.arange(68)
# ----------------------------------------
def build_residual_conv2d_block(conv, num_filters, block_name, activation=tf.nn.relu, padding='SAME', def build_residual_conv2d_block(conv, num_filters, block_name, activation=tf.nn.relu, padding='SAME',
kernel_initializer='he_uniform', scale=None, kernel_size=3, kernel_initializer='he_uniform', scale=None, kernel_size=3,
...@@ -119,12 +91,6 @@ def build_residual_conv2d_block(conv, num_filters, block_name, activation=tf.nn. ...@@ -119,12 +91,6 @@ def build_residual_conv2d_block(conv, num_filters, block_name, activation=tf.nn.
return conv return conv
def upsample(tmp):
tmp = resample_2d_linear(x_2, y_2, tmp, t, s)
tmp = tmp[:, y_k, x_k]
return tmp
def upsample_nearest(grd): def upsample_nearest(grd):
bsize, ylen, xlen = grd.shape bsize, ylen, xlen = grd.shape
up = np.zeros((bsize, ylen*2, xlen*2)) up = np.zeros((bsize, ylen*2, xlen*2))
...@@ -177,7 +143,7 @@ def get_min_max_std(grd_k): ...@@ -177,7 +143,7 @@ def get_min_max_std(grd_k):
class SRCNN: class SRCNN:
def __init__(self): def __init__(self, LEN_Y=128, LEN_X=128):
self.train_data = None self.train_data = None
self.train_label = None self.train_label = None
...@@ -190,20 +156,7 @@ class SRCNN: ...@@ -190,20 +156,7 @@ class SRCNN:
self.test_dataset = None self.test_dataset = None
self.eval_dataset = None self.eval_dataset = None
self.X_img = None self.X_img = None
self.X_prof = None
self.X_u = None
self.X_v = None
self.X_sfc = None
self.inputs = [] self.inputs = []
self.y = None
self.handle = None
self.inner_handle = None
self.in_mem_batch = None
self.h5f_l1b_trn = None
self.h5f_l1b_tst = None
self.h5f_l2_trn = None
self.h5f_l2_tst = None
self.logits = None self.logits = None
...@@ -225,8 +178,6 @@ class SRCNN: ...@@ -225,8 +178,6 @@ class SRCNN:
self.writer_valid = None self.writer_valid = None
self.writer_train_valid_loss = None self.writer_train_valid_loss = None
self.OUT_OF_RANGE = False
self.model = None self.model = None
self.optimizer = None self.optimizer = None
self.ema = None self.ema = None
...@@ -234,14 +185,6 @@ class SRCNN: ...@@ -234,14 +185,6 @@ class SRCNN:
self.train_accuracy = None self.train_accuracy = None
self.test_loss = None self.test_loss = None
self.test_accuracy = None self.test_accuracy = None
self.test_auc = None
self.test_recall = None
self.test_precision = None
self.test_confusion_matrix = None
self.test_true_pos = None
self.test_true_neg = None
self.test_false_pos = None
self.test_false_neg = None
self.test_labels = [] self.test_labels = []
self.test_preds = [] self.test_preds = []
...@@ -251,7 +194,6 @@ class SRCNN: ...@@ -251,7 +194,6 @@ class SRCNN:
self.num_data_samples = None self.num_data_samples = None
self.initial_learning_rate = None self.initial_learning_rate = None
self.data_dct = None
self.train_data_files = None self.train_data_files = None
self.train_label_files = None self.train_label_files = None
self.test_data_files = None self.test_data_files = None
...@@ -264,8 +206,28 @@ class SRCNN: ...@@ -264,8 +206,28 @@ class SRCNN:
self.inputs.append(self.X_img) self.inputs.append(self.X_img)
self.slc_x_m = slice(1, int(LEN_X / 2) + 4)
self.slc_y_m = slice(1, int(LEN_Y / 2) + 4)
self.slc_x = slice(3, LEN_X + 5)
self.slc_y = slice(3, LEN_Y + 5)
self.slc_x_2 = slice(2, LEN_X + 7, 2)
self.slc_y_2 = slice(2, LEN_Y + 7, 2)
self.x_2 = np.arange(int(LEN_X / 2) + 3)
self.y_2 = np.arange(int(LEN_Y / 2) + 3)
self.t = np.arange(0, int(LEN_X / 2) + 3, 0.5)
self.s = np.arange(0, int(LEN_Y / 2) + 3, 0.5)
self.x_k = slice(1, LEN_X + 3)
self.y_k = slice(1, LEN_Y + 3)
self.x_128 = slice(4, LEN_X + 4)
self.y_128 = slice(4, LEN_Y + 4)
tf.debugging.set_log_device_placement(LOG_DEVICE_PLACEMENT) tf.debugging.set_log_device_placement(LOG_DEVICE_PLACEMENT)
def upsample(self, grd):
grd = resample_2d_linear(self.x_2, self.y_2, grd, self.t, self.s)
grd = grd[:, self.y_k, self.x_k]
return grd
def get_in_mem_data_batch(self, idxs, is_training): def get_in_mem_data_batch(self, idxs, is_training):
if is_training: if is_training:
data_files = self.train_data_files data_files = self.train_data_files
...@@ -292,8 +254,8 @@ class SRCNN: ...@@ -292,8 +254,8 @@ class SRCNN:
idx = params.index(param) idx = params.index(param)
tmp = input_data[:, idx, :, :] tmp = input_data[:, idx, :, :]
tmp = np.where(np.isnan(tmp), 0, tmp) tmp = np.where(np.isnan(tmp), 0, tmp)
tmp = tmp[:, slc_y_m, slc_x_m] tmp = tmp[:, self.slc_y_m, self.slc_x_m]
tmp = upsample(tmp) tmp = self.upsample(tmp)
tmp = normalize(tmp, param, mean_std_dct) tmp = normalize(tmp, param, mean_std_dct)
data_norm.append(tmp) data_norm.append(tmp)
...@@ -310,17 +272,17 @@ class SRCNN: ...@@ -310,17 +272,17 @@ class SRCNN:
# hi = normalize(hi, param, mean_std_dct) # hi = normalize(hi, param, mean_std_dct)
# avg = normalize(avg, param, mean_std_dct) # avg = normalize(avg, param, mean_std_dct)
# #
# data_norm.append(lo[:, slc_y, slc_x]) # data_norm.append(lo[:, self.slc_y, self.slc_x])
# data_norm.append(hi[:, slc_y, slc_x]) # data_norm.append(hi[:, self.slc_y, self.slc_x])
# data_norm.append(avg[:, slc_y, slc_x]) # data_norm.append(avg[:, self.slc_y, self.slc_x])
tmp = normalize(tmp, param, mean_std_dct) tmp = normalize(tmp, param, mean_std_dct)
data_norm.append(tmp[:, slc_y, slc_x]) data_norm.append(tmp[:, self.slc_y, self.slc_x])
# --------------------------------------------------- # ---------------------------------------------------
tmp = input_label[:, label_idx_i, :, :] tmp = input_label[:, label_idx_i, :, :]
tmp = np.where(np.isnan(tmp), 0, tmp) tmp = np.where(np.isnan(tmp), 0, tmp)
tmp = tmp[:, slc_y_2, slc_x_2] tmp = tmp[:, self.slc_y_2, self.slc_x_2]
tmp = upsample(tmp) tmp = self.upsample(tmp)
tmp = normalize(tmp, label_param, mean_std_dct) tmp = normalize(tmp, label_param, mean_std_dct)
data_norm.append(tmp) data_norm.append(tmp)
# --------- # ---------
...@@ -331,7 +293,7 @@ class SRCNN: ...@@ -331,7 +293,7 @@ class SRCNN:
# ----------------------------------------------------- # -----------------------------------------------------
label = input_label[:, label_idx_i, :, :] label = input_label[:, label_idx_i, :, :]
label = normalize(label, label_param, mean_std_dct) label = normalize(label, label_param, mean_std_dct)
label = label[:, y_128, x_128] label = label[:, self.y_128, self.x_128]
label = np.where(np.isnan(label), 0, label) label = np.where(np.isnan(label), 0, label)
label = np.expand_dims(label, axis=3) label = np.expand_dims(label, axis=3)
...@@ -535,19 +497,6 @@ class SRCNN: ...@@ -535,19 +497,6 @@ class SRCNN:
self.test_loss.reset_states() self.test_loss.reset_states()
self.test_accuracy.reset_states() self.test_accuracy.reset_states()
def get_metrics(self):
recall = self.test_recall.result()
precsn = self.test_precision.result()
f1 = 2 * (precsn * recall) / (precsn + recall)
tn = self.test_true_neg.result()
tp = self.test_true_pos.result()
fn = self.test_false_neg.result()
fp = self.test_false_pos.result()
mcc = ((tp * tn) - (fp * fn)) / np.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
return f1, mcc
def do_training(self, ckpt_dir=None): def do_training(self, ckpt_dir=None):
if ckpt_dir is None: if ckpt_dir is None:
...@@ -671,10 +620,10 @@ class SRCNN: ...@@ -671,10 +620,10 @@ class SRCNN:
preds = np.concatenate(self.test_preds) preds = np.concatenate(self.test_preds)
print(labels.shape, preds.shape) print(labels.shape, preds.shape)
labels = denormalize(labels, label_param, mean_std_dct) labels_denorm = denormalize(labels, label_param, mean_std_dct)
preds = denormalize(preds, label_param, mean_std_dct) preds_denorm = denormalize(preds, label_param, mean_std_dct)
return labels, preds return labels_denorm, preds_denorm
def do_evaluate(self, inputs, ckpt_dir): def do_evaluate(self, inputs, ckpt_dir):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment