Skip to content
Snippets Groups Projects
Commit 4751ca86 authored by tomrink's avatar tomrink
Browse files

snapshot...

parent 09ef318d
Branches
No related tags found
No related merge requests found
...@@ -350,17 +350,19 @@ class SRCNN: ...@@ -350,17 +350,19 @@ class SRCNN:
tmp = tmp[:, slc_y, slc_x] tmp = tmp[:, slc_y, slc_x]
data_norm.append(tmp) data_norm.append(tmp)
# --------- # ---------
data = np.stack(data_norm, axis=3) # data = np.stack(data_norm, axis=3)
data = data.astype(np.float32) # data = data.astype(np.float32)
# ----------------------------------------------------- # -----------------------------------------------------
# ----------------------------------------------------- # -----------------------------------------------------
label = input_label[:, label_idx_i, :, :] label = input_label[:, label_idx_i, :, :]
label = label[:, y_64, x_64] label = label[:, y_64, x_64]
cld_prob = cld_prob[:, y_64, x_64] cld_prob = cld_prob[:, y_64, x_64]
if not is_training:
cat_cf = get_label_data_5cat(cld_prob) cat_cf = get_label_data_5cat(cld_prob)
self.test_cat_cf.append(cat_cf) data_norm.append(cat_cf)
data = np.stack(data_norm, axis=3)
data = data.astype(np.float32)
label = get_cldy_frac_opd(cld_prob, label) label = get_cldy_frac_opd(cld_prob, label)
# label = scale(label, label_param, mean_std_dct) # label = scale(label, label_param, mean_std_dct)
label = np.where(np.isnan(label), 0, label) label = np.where(np.isnan(label), 0, label)
...@@ -470,6 +472,7 @@ class SRCNN: ...@@ -470,6 +472,7 @@ class SRCNN:
num_filters = 64 num_filters = 64
input_2d = self.inputs[0] input_2d = self.inputs[0]
input_2d = input_2d[:, :, :, 0:self.n_chans]
print('input: ', input_2d.shape) print('input: ', input_2d.shape)
conv = conv_b = tf.keras.layers.Conv2D(num_filters, kernel_size=KERNEL_SIZE, kernel_initializer='he_uniform', activation=activation, padding='VALID')(input_2d) conv = conv_b = tf.keras.layers.Conv2D(num_filters, kernel_size=KERNEL_SIZE, kernel_initializer='he_uniform', activation=activation, padding='VALID')(input_2d)
...@@ -568,6 +571,7 @@ class SRCNN: ...@@ -568,6 +571,7 @@ class SRCNN:
self.test_labels.append(labels) self.test_labels.append(labels)
self.test_preds.append(pred.numpy()) self.test_preds.append(pred.numpy())
self.test_input.append(inputs) self.test_input.append(inputs)
self.test_cat_cf.append(inputs[:, :, :, self.n_chans])
self.test_loss(t_loss) self.test_loss(t_loss)
self.test_accuracy(labels, pred) self.test_accuracy(labels, pred)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please to comment