diff --git a/modules/GSOC/E2_ESRGAN/main.py b/modules/GSOC/E2_ESRGAN/main.py index d49e4fbdf7546860fb83fe923c44212a6d415c13..d78209ec8ceeac1e702773fb6a84295674d780eb 100644 --- a/modules/GSOC/E2_ESRGAN/main.py +++ b/modules/GSOC/E2_ESRGAN/main.py @@ -54,10 +54,11 @@ def run(kwargs): tf.config.experimental.set_memory_growth(physical_device, True) strategy = utils.SingleDeviceStrategy() - #scope = utils.assign_to_worker(kwargs["tpu"]) + # scope = utils.assign_to_worker(kwargs["tpu"]) sett = settings.Settings(kwargs["config"]) Stats = settings.Stats(os.path.join(sett.path, "stats.yaml")) tf.random.set_seed(10) + # if kwargs["tpu"]: # cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver( # kwargs["tpu"]) @@ -65,6 +66,7 @@ def run(kwargs): # tf.tpu.experimental.initialize_tpu_system(cluster_resolver) # strategy = tf.distribute.experimental.TPUStrategy(cluster_resolver) # with tf.device(scope), strategy.scope(): + with strategy.scope(): summary_writer_1 = tf.summary.create_file_writer(os.path.join(kwargs["log_dir"], "phase1")) summary_writer_2 = tf.summary.create_file_writer(os.path.join(kwargs["log_dir"], "phase2"))