diff --git a/modules/GSOC/E2_ESRGAN/main.py b/modules/GSOC/E2_ESRGAN/main.py index 44b219ec888a5c03a0b7ff9113225c240093e7a5..663b6f1daeb0671229bfd8cb204253d69c368f44 100644 --- a/modules/GSOC/E2_ESRGAN/main.py +++ b/modules/GSOC/E2_ESRGAN/main.py @@ -13,6 +13,8 @@ tf.compat.v1.disable_v2_behavior logging.basicConfig(filename='/Users/tomrink/esrgan_log.txt', level=logging.DEBUG) +num_chans = 2 + """ Enhanced Super Resolution GAN. Citation: @@ -69,7 +71,7 @@ def run(kwargs): if not kwargs["export_only"]: generator = model.RRDBNet(out_channel=1) logging.debug("Initiating Convolutions") - generator.unsigned_call(tf.random.normal([1, 128, 128, 1])) + generator.unsigned_call(tf.random.normal([1, 128, 128, num_chans])) training = train.Trainer( summary_writer=summary_writer_1, summary_writer_2=summary_writer_2,