From b7a276305b0aee801b24fd72afdc914819f57794 Mon Sep 17 00:00:00 2001
From: tomrink <rink@ssec.wisc.edu>
Date: Thu, 6 Oct 2022 14:40:11 -0500
Subject: [PATCH] Add optional dropout and batch normalization to build_residual_conv2d_block

---
 modules/deeplearning/srcnn_l1b_l2.py | 12 +++++++++++-
 1 file changed, 11 insertions(+), 1 deletion(-)

diff --git a/modules/deeplearning/srcnn_l1b_l2.py b/modules/deeplearning/srcnn_l1b_l2.py
index 60d89f0d..bb3bfbf3 100644
--- a/modules/deeplearning/srcnn_l1b_l2.py
+++ b/modules/deeplearning/srcnn_l1b_l2.py
@@ -73,13 +73,23 @@ t = np.arange(0, 64, 0.5)
 s = np.arange(0, 64, 0.5)
 
 
-def build_residual_conv2d_block(conv, num_filters, block_name, activation=tf.nn.relu, padding='SAME', kernel_initializer='he_uniform', scale=None):
+def build_residual_conv2d_block(conv, num_filters, block_name, activation=tf.nn.relu, padding='SAME',
+                                kernel_initializer='he_uniform', scale=None,
+                                do_drop_out=False, drop_rate=0.5, do_batch_norm=False):
 
     with tf.name_scope(block_name):
         skip = tf.keras.layers.Conv2D(num_filters, kernel_size=3, padding=padding, kernel_initializer=kernel_initializer, activation=activation)(conv)
         skip = tf.keras.layers.Conv2D(num_filters, kernel_size=3, padding=padding, activation=None)(skip)
+
         if scale is not None:
             skip = tf.keras.layers.Lambda(lambda x: x * scale)(skip)
+
+        if do_drop_out:
+            skip = tf.keras.layers.Dropout(drop_rate)(skip)
+
+        if do_batch_norm:
+            skip = tf.keras.layers.BatchNormalization()(skip)
+
         conv = conv + skip
         print(block_name+':', conv.shape)
 
-- 
GitLab