From 00fd0e6ad17b0426e24ef34fa90c8f2faf7cf20e Mon Sep 17 00:00:00 2001
From: tomrink <rink@ssec.wisc.edu>
Date: Mon, 1 Aug 2022 13:33:14 -0500
Subject: [PATCH] snapshot

---
 modules/deeplearning/espcn.py | 14 +++++++++-----
 1 file changed, 9 insertions(+), 5 deletions(-)

diff --git a/modules/deeplearning/espcn.py b/modules/deeplearning/espcn.py
index 69f654a3..4086e006 100644
--- a/modules/deeplearning/espcn.py
+++ b/modules/deeplearning/espcn.py
@@ -211,12 +211,12 @@ class ESPCN:
 
         self.n_chans = 1
 
-        self.X_img = tf.keras.Input(shape=(None, None, self.n_chans))
-        # self.X_img = tf.keras.Input(shape=(30, 30, self.n_chans))
+        # self.X_img = tf.keras.Input(shape=(None, None, self.n_chans))
+        self.X_img = tf.keras.Input(shape=(30, 30, self.n_chans))
 
         self.inputs.append(self.X_img)
-        self.inputs.append(tf.keras.Input(shape=(None, None, self.n_chans)))
-        # self.inputs.append(tf.keras.Input(shape=(30, 30, self.n_chans)))
+        # self.inputs.append(tf.keras.Input(shape=(None, None, self.n_chans)))
+        self.inputs.append(tf.keras.Input(shape=(30, 30, self.n_chans)))
 
         self.DISK_CACHE = False
 
@@ -427,7 +427,7 @@ class ESPCN:
         conv = tf.keras.layers.BatchNormalization()(conv)
         print(conv.shape)
 
-        conv = tf.keras.layers.Conv2D(num_filters/2, kernel_size=3, strides=1, padding=padding, activation=None)(conv)
+        conv = tf.keras.layers.Conv2D(num_filters, kernel_size=3, strides=1, padding=padding, activation=None)(conv)
         conv = tf.keras.layers.BatchNormalization()(conv)
         print(conv.shape)
 
@@ -435,6 +435,10 @@ class ESPCN:
         conv = tf.keras.layers.LeakyReLU()(conv)
         print(conv.shape)
 
+        conv = tf.keras.layers.Conv2D(num_filters/2, kernel_size=3, strides=1, padding=padding, activation=None)(conv)
+        conv = tf.keras.layers.BatchNormalization()(conv)
+        print(conv.shape)
+
         conv = tf.keras.layers.Conv2D(4, kernel_size=3, strides=1, padding=padding, activation=activation)(conv)
         print(conv.shape)
 
-- 
GitLab