diff --git a/modules/GSOC/E2_ESRGAN/lib/utils.py b/modules/GSOC/E2_ESRGAN/lib/utils.py index f724c3d271779c561b8fdcf9d0dcc564bd3ce54e..99d463d6c0b972dc05947a1ec57af5bdf2fc67aa 100644 --- a/modules/GSOC/E2_ESRGAN/lib/utils.py +++ b/modules/GSOC/E2_ESRGAN/lib/utils.py @@ -164,6 +164,12 @@ def pixel_loss(y_true, y_pred): return tf.reduce_mean(tf.reduce_mean(tf.abs(y_true - y_pred), axis=0)) +def pixel_loss_mse(y_true, y_pred): + y_true = tf.cast(y_true, tf.float32) + y_pred = tf.cast(y_pred, tf.float32) + return tf.reduce_mean(tf.reduce_mean(tf.square(y_true - y_pred), axis=0)) + + def RelativisticAverageLoss(non_transformed_disc, type_="G"): """ Relativistic Average Loss based on RaGAN Args: