From 96a1d843319b086795ada98ca376a47c9bb099b8 Mon Sep 17 00:00:00 2001
From: tomrink <rink@ssec.wisc.edu>
Date: Thu, 18 Aug 2022 10:46:12 -0500
Subject: [PATCH] snapshot...

---
 modules/deeplearning/srcnn.py | 24 +++++++++++++-----------
 1 file changed, 13 insertions(+), 11 deletions(-)

diff --git a/modules/deeplearning/srcnn.py b/modules/deeplearning/srcnn.py
index 51c96323..2153d54d 100644
--- a/modules/deeplearning/srcnn.py
+++ b/modules/deeplearning/srcnn.py
@@ -67,14 +67,15 @@ x_134_2 = x_134[2:133:2]
 y_134_2 = y_134[2:133:2]
 
 
-def build_residual_conv2d_block(conv, num_filters, block_name, activation=tf.nn.leaky_relu, padding='SAME'):
+def build_residual_conv2d_block(conv, num_filters, block_name, activation=tf.nn.relu, padding='SAME', kernel_initializer='he_uniform', scale=None):
 
     with tf.name_scope(block_name):
-        skip = tf.keras.layers.Conv2D(num_filters, kernel_size=3, strides=1, padding=padding, activation=activation)(conv)
-        skip = tf.keras.layers.Conv2D(num_filters, kernel_size=3, strides=1, padding=padding, activation=None)(skip)
-
+        skip = tf.keras.layers.Conv2D(num_filters, kernel_size=3, padding=padding, kernel_initializer=kernel_initializer, activation=activation)(conv)
+        skip = tf.keras.layers.Conv2D(num_filters, kernel_size=3, padding=padding, activation=None)(skip)
+        if scale is not None:
+            skip = tf.keras.layers.Lambda(lambda x: x * scale)(skip)
         conv = conv + skip
-        print(conv.shape)
+        print(block_name+':', conv.shape)
 
     return conv
 
@@ -334,16 +335,17 @@ class SRCNN:
 
         self.get_evaluate_dataset(idxs)
 
-    def build_espcn(self, do_drop_out=False, do_batch_norm=False, drop_rate=0.5, factor=2):
+    def build_srcnn(self, do_drop_out=False, do_batch_norm=False, drop_rate=0.5, factor=2):
         print('build_cnn')
         padding = "SAME"
 
         # activation = tf.nn.relu
         # activation = tf.nn.elu
-        activation = tf.nn.leaky_relu
+        activation = tf.nn.relu
         momentum = 0.99
 
-        num_filters = 32
+        num_filters = 256
+        num_res_blocks = 4
 
         input_2d = self.inputs[0]
         print('input: ', input_2d.shape)
@@ -352,7 +354,7 @@ class SRCNN:
         print('input: ', conv.shape)
 
         # conv = conv_b = tf.keras.layers.Conv2D(num_filters, kernel_size=3, padding=padding)(input_2d)
-        conv = conv_b = tf.keras.layers.Conv2D(num_filters, kernel_size=3, padding='VALID')(input_2d)
+        conv = conv_b = tf.keras.layers.Conv2D(num_filters, kernel_size=3, kernel_initializer='he_uniform', activation=activation, padding='VALID')(input_2d)
         print(conv.shape)
 
         if NOISE_TRAINING:
@@ -366,7 +368,7 @@ class SRCNN:
 
         conv_b = build_residual_conv2d_block(conv_b, num_filters, 'Residual_Block_4')
 
-        conv_b = tf.keras.layers.Conv2D(num_filters, kernel_size=3, strides=1, padding=padding)(conv_b)
+        conv_b = tf.keras.layers.Conv2D(num_filters, kernel_size=3, strides=1, kernel_initializer='he_uniform', padding=padding)(conv_b)
 
         conv = conv + conv_b
         print(conv.shape)
@@ -649,7 +651,7 @@ class SRCNN:
         # f.close()
 
     def build_model(self):
-        self.build_espcn()
+        self.build_srcnn()
         self.model = tf.keras.Model(self.inputs, self.logits)
 
     def restore(self, ckpt_dir):
-- 
GitLab