Skip to content
Snippets Groups Projects
Commit 4e70b030 authored by tomrink's avatar tomrink
Browse files

snapshot...

parent 5e9292f2
Branches
No related tags found
No related merge requests found
......@@ -208,7 +208,7 @@ class SRCNN:
self.test_label_files = None
# self.n_chans = len(data_params_half) + len(data_params_full) + 1
self.n_chans = 6
self.n_chans = 3
self.X_img = tf.keras.Input(shape=(None, None, self.n_chans))
......@@ -267,16 +267,16 @@ class SRCNN:
tmp = normalize(tmp, param, mean_std_dct)
data_norm.append(tmp)
for param in sub_fields:
idx = params.index(param)
tmp = input_data[:, idx, :, :]
tmp = upsample_nearest(tmp)
tmp = tmp[:, self.slc_y, self.slc_x]
if param != 'refl_substddev_ch01':
tmp = normalize(tmp, 'refl_0_65um_nom', mean_std_dct)
else:
tmp = np.where(np.isnan(tmp), 0, tmp)
data_norm.append(tmp)
# for param in sub_fields:
# idx = params.index(param)
# tmp = input_data[:, idx, :, :]
# tmp = upsample_nearest(tmp)
# tmp = tmp[:, self.slc_y, self.slc_x]
# if param != 'refl_substddev_ch01':
# tmp = normalize(tmp, 'refl_0_65um_nom', mean_std_dct)
# else:
# tmp = np.where(np.isnan(tmp), 0, tmp)
# data_norm.append(tmp)
# for param in data_params_full:
# idx = params_i.index(param)
......@@ -287,6 +287,7 @@ class SRCNN:
# data_norm.append(tmp[:, self.slc_y, self.slc_x])
# ---------------------------------------------------
tmp = input_label[:, label_idx_i, ::2, ::2]
tmp = tmp.copy()
tmp = np.where(np.isnan(tmp), 0, tmp)
tmp = tmp[:, self.slc_y_2, self.slc_x_2]
tmp = self.upsample(tmp)
......@@ -299,8 +300,9 @@ class SRCNN:
# -----------------------------------------------------
# -----------------------------------------------------
label = input_label[:, label_idx_i, ::2, ::2]
label = normalize(label, label_param, mean_std_dct)
# label = scale(label, label_param, mean_std_dct)
label = label.copy()
# label = normalize(label, label_param, mean_std_dct)
label = scale(label, label_param, mean_std_dct)
label = label[:, self.y_128, self.x_128]
label = np.where(np.isnan(label), 0, label)
......@@ -415,7 +417,7 @@ class SRCNN:
conv_b = build_residual_conv2d_block(conv_b, num_filters, 'Residual_Block_2', kernel_size=KERNEL_SIZE, scale=scale)
#conv_b = build_residual_conv2d_block(conv_b, num_filters, 'Residual_Block_3', kernel_size=KERNEL_SIZE, scale=scale)
conv_b = build_residual_conv2d_block(conv_b, num_filters, 'Residual_Block_3', kernel_size=KERNEL_SIZE, scale=scale)
#conv_b = build_residual_conv2d_block(conv_b, num_filters, 'Residual_Block_4', kernel_size=KERNEL_SIZE, scale=scale)
......@@ -628,10 +630,10 @@ class SRCNN:
preds = np.concatenate(self.test_preds)
print(labels.shape, preds.shape)
labels_denorm = denormalize(labels, label_param, mean_std_dct)
preds_denorm = denormalize(preds, label_param, mean_std_dct)
# labels_denorm = descale(labels, label_param, mean_std_dct)
# preds_denorm = descale(preds, label_param, mean_std_dct)
# labels_denorm = denormalize(labels, label_param, mean_std_dct)
# preds_denorm = denormalize(preds, label_param, mean_std_dct)
labels_denorm = descale(labels, label_param, mean_std_dct)
preds_denorm = descale(preds, label_param, mean_std_dct)
return labels_denorm, preds_denorm
......@@ -761,8 +763,8 @@ def run_evaluate_static(in_file, out_file, ckpt_dir):
print('INPUT: ', data.shape)
cld_opd_sres = nn.run_evaluate(data, ckpt_dir)
# cld_opd_sres = descale(cld_opd_sres, label_param, mean_std_dct)
cld_opd_sres = denormalize(cld_opd_sres, label_param, mean_std_dct)
cld_opd_sres = descale(cld_opd_sres, label_param, mean_std_dct)
# cld_opd_sres = denormalize(cld_opd_sres, label_param, mean_std_dct)
_, ylen, xlen, _ = cld_opd_sres.shape
print('OUT: ', ylen, xlen)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment