diff --git a/modules/GSOC/E2_ESRGAN/lib/utils.py b/modules/GSOC/E2_ESRGAN/lib/utils.py index db1d4b2b4d674233bd1477784d620789b5a293b9..ae66b4fd9db30378e9b893ae1277bc6345da82b8 100644 --- a/modules/GSOC/E2_ESRGAN/lib/utils.py +++ b/modules/GSOC/E2_ESRGAN/lib/utils.py @@ -215,6 +215,29 @@ def RelativisticAverageLoss(non_transformed_disc, type_="G"): return loss +# Standard GAN loss functions --------------------------------------- +def gen_loss(non_transformed_disc, fake_image): + fake_logits = non_transformed_disc(fake_image) + fake_loss = tf.nn.sigmoid_cross_entropy_with_logits( + labels=tf.ones_like(fake_logits), logits=fake_logits) + return fake_loss + + +def disc_loss(non_transformed_disc, real_image, fake_image): + real_logits = non_transformed_disc(real_image) + fake_logits = non_transformed_disc(fake_image) + + real_loss = tf.nn.sigmoid_cross_entropy_with_logits( + labels=tf.ones_like(real_logits), logits=real_logits) + + fake_loss = tf.nn.sigmoid_cross_entropy_with_logits( + labels=tf.zeros_like(fake_logits), logits=fake_logits) + + return real_loss + fake_loss + +# ---------------------------------------------------------------------- + + # Strategy Utils def assign_to_worker(use_tpu):