Skip to content
Snippets Groups Projects
Commit 96a1d843 authored by tomrink's avatar tomrink
Browse files

snapshot...

parent 42377f83
Branches
No related tags found
No related merge requests found
...@@ -67,14 +67,15 @@ x_134_2 = x_134[2:133:2] ...@@ -67,14 +67,15 @@ x_134_2 = x_134[2:133:2]
y_134_2 = y_134[2:133:2] y_134_2 = y_134[2:133:2]
def build_residual_conv2d_block(conv, num_filters, block_name, activation=tf.nn.leaky_relu, padding='SAME'): def build_residual_conv2d_block(conv, num_filters, block_name, activation=tf.nn.relu, padding='SAME', kernel_initializer='he_uniform', scale=None):
with tf.name_scope(block_name): with tf.name_scope(block_name):
skip = tf.keras.layers.Conv2D(num_filters, kernel_size=3, strides=1, padding=padding, activation=activation)(conv) skip = tf.keras.layers.Conv2D(num_filters, kernel_size=3, padding=padding, kernel_initializer=kernel_initializer, activation=activation)(conv)
skip = tf.keras.layers.Conv2D(num_filters, kernel_size=3, strides=1, padding=padding, activation=None)(skip) skip = tf.keras.layers.Conv2D(num_filters, kernel_size=3, padding=padding, activation=None)(skip)
if scale is not None:
skip = tf.keras.layers.Lambda(lambda x: x * scale)(skip)
conv = conv + skip conv = conv + skip
print(conv.shape) print(block_name+':', conv.shape)
return conv return conv
...@@ -334,16 +335,17 @@ class SRCNN: ...@@ -334,16 +335,17 @@ class SRCNN:
self.get_evaluate_dataset(idxs) self.get_evaluate_dataset(idxs)
def build_espcn(self, do_drop_out=False, do_batch_norm=False, drop_rate=0.5, factor=2): def build_srcnn(self, do_drop_out=False, do_batch_norm=False, drop_rate=0.5, factor=2):
print('build_cnn') print('build_cnn')
padding = "SAME" padding = "SAME"
# activation = tf.nn.relu # activation = tf.nn.relu
# activation = tf.nn.elu # activation = tf.nn.elu
activation = tf.nn.leaky_relu activation = tf.nn.relu
momentum = 0.99 momentum = 0.99
num_filters = 32 num_filters = 256
num_res_blocks = 4
input_2d = self.inputs[0] input_2d = self.inputs[0]
print('input: ', input_2d.shape) print('input: ', input_2d.shape)
...@@ -352,7 +354,7 @@ class SRCNN: ...@@ -352,7 +354,7 @@ class SRCNN:
print('input: ', conv.shape) print('input: ', conv.shape)
# conv = conv_b = tf.keras.layers.Conv2D(num_filters, kernel_size=3, padding=padding)(input_2d) # conv = conv_b = tf.keras.layers.Conv2D(num_filters, kernel_size=3, padding=padding)(input_2d)
conv = conv_b = tf.keras.layers.Conv2D(num_filters, kernel_size=3, padding='VALID')(input_2d) conv = conv_b = tf.keras.layers.Conv2D(num_filters, kernel_size=3, kernel_initializer='he_uniform', activation=activation, padding='VALID')(input_2d)
print(conv.shape) print(conv.shape)
if NOISE_TRAINING: if NOISE_TRAINING:
...@@ -366,7 +368,7 @@ class SRCNN: ...@@ -366,7 +368,7 @@ class SRCNN:
conv_b = build_residual_conv2d_block(conv_b, num_filters, 'Residual_Block_4') conv_b = build_residual_conv2d_block(conv_b, num_filters, 'Residual_Block_4')
conv_b = tf.keras.layers.Conv2D(num_filters, kernel_size=3, strides=1, padding=padding)(conv_b) conv_b = tf.keras.layers.Conv2D(num_filters, kernel_size=3, strides=1, kernel_initializer='he_uniform', padding=padding)(conv_b)
conv = conv + conv_b conv = conv + conv_b
print(conv.shape) print(conv.shape)
...@@ -649,7 +651,7 @@ class SRCNN: ...@@ -649,7 +651,7 @@ class SRCNN:
# f.close() # f.close()
def build_model(self): def build_model(self):
self.build_espcn() self.build_srcnn()
self.model = tf.keras.Model(self.inputs, self.logits) self.model = tf.keras.Model(self.inputs, self.logits)
def restore(self, ckpt_dir): def restore(self, ckpt_dir):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please to comment