Skip to content
Snippets Groups Projects
Commit cab215cf authored by tomrink's avatar tomrink
Browse files

snapshot...

parent 9782049a
No related branches found
No related tags found
No related merge requests found
......@@ -40,7 +40,6 @@ DO_AUGMENT = True
DO_SMOOTH = False
SIGMA = 1.0
DO_ZERO_OUT = False
DO_ESPCN = False # Note: If True, cannot do mixed resolution input fields (Adjust accordingly below)
# setup scaling parameters dictionary
mean_std_dct = {}
......@@ -103,12 +102,6 @@ elif KERNEL_SIZE == 5:
x_2 = np.arange(68)
y_2 = np.arange(68)
# ----------------------------------------
# Exp for ESPCN version
if DO_ESPCN:
slc_x_2 = slice(0, 132, 2)
slc_y_2 = slice(0, 132, 2)
x_128 = slice(2, 130)
y_128 = slice(2, 130)
def build_residual_conv2d_block(conv, num_filters, block_name, activation=tf.nn.relu, padding='SAME',
......@@ -340,10 +333,7 @@ class SRCNN:
idx = params.index(param)
tmp = input_data[:, idx, :, :]
tmp = tmp.copy()
if DO_ESPCN:
tmp = tmp[:, slc_y_2, slc_x_2]
else:
tmp = tmp[:, slc_y, slc_x]
tmp = tmp[:, slc_y, slc_x]
tmp = normalize(tmp, param, mean_std_dct)
data_norm.append(tmp)
......@@ -365,16 +355,12 @@ class SRCNN:
# ---------------------------------------------------
tmp = input_data[:, label_idx, :, :]
tmp = tmp.copy()
if DO_ESPCN:
tmp = tmp[:, slc_y_2, slc_x_2]
else:
tmp = tmp[:, slc_y, slc_x]
if label_param != 'cloud_probability':
tmp = normalize(tmp, label_param, mean_std_dct)
tmp = tmp[:, slc_y, slc_x]
data_norm.append(tmp)
# ---------
data = np.stack(data_norm, axis=3)
data = data.astype(np.float32)
# -----------------------------------------------------
# -----------------------------------------------------
label = input_label[:, label_idx_i, :, :]
......@@ -385,10 +371,7 @@ class SRCNN:
else:
label = get_label_data(label)
if label_param != 'cloud_probability':
label = normalize(label, label_param, mean_std_dct)
else:
label = np.where(np.isnan(label), 0, label)
label = np.where(np.isnan(label), 0, label)
label = np.expand_dims(label, axis=3)
data = data.astype(np.float32)
......@@ -519,15 +502,8 @@ class SRCNN:
else:
final_activation = tf.nn.softmax # For multi-class
if not DO_ESPCN:
# This is effectively a Dense layer
self.logits = tf.keras.layers.Conv2D(NumLogits, kernel_size=1, strides=1, padding=padding, activation=final_activation)(conv)
else:
conv = tf.keras.layers.Conv2D(num_filters * (factor ** 2), 3, padding=padding, activation=activation)(conv)
print(conv.shape)
conv = tf.nn.depth_to_space(conv, factor)
print(conv.shape)
self.logits = tf.keras.layers.Conv2D(IMG_DEPTH, kernel_size=3, strides=1, padding=padding, activation=final_activation)(conv)
# This is effectively a Dense layer
self.logits = tf.keras.layers.Conv2D(NumLogits, kernel_size=1, strides=1, padding=padding, activation=final_activation)(conv)
print(self.logits.shape)
def build_training(self):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment