# utils.py
""" Module containing utility functions for the trainer """
import os

from absl import logging
from libs import settings
import tensorflow as tf
# Loading utilities from ESRGAN
from lib.utils import assign_to_worker
from lib.utils import PerceptualLoss
from lib.utils import RelativisticAverageLoss
from lib.utils import SingleDeviceStrategy
from lib.utils import preprocess_input

def checkpoint_exists(names, basepath="", use_student_settings=False):
  """ Checks whether a checkpoint is present for any of the given names
      Args:
        names: a single checkpoint name, or a nested structure of names
               (as judged by tf.nest.is_nested). A dict always yields False.
        basepath: base directory prepended to each configured checkpoint path.
        use_student_settings: forwarded to settings.Settings.
      Returns:
        True if a "checkpoint" entry exists under the configured path of
        at least one name, False otherwise.
  """
  config = settings.Settings(use_student_settings=use_student_settings)
  if not tf.nest.is_nested(names):
    # A lone name is handled like a one-element collection.
    names = [names]
  elif isinstance(names, dict):
    return False
  # Every path is probed (no short-circuit) before the results are combined.
  found = [
      tf.io.gfile.exists(
          os.path.join(basepath, config["checkpoint_path"][entry], "checkpoint"))
      for entry in names
  ]
  return any(found)


def save_checkpoint(checkpoint, name, basepath="", use_student_settings=False):
  """ Saves Checkpoint
      Args:
        checkpoint: tf.train.Checkpoint object to save.
        name: name of the checkpoint to save; used as the key into the
              "checkpoint_path" section of the settings.
        basepath: base directory where checkpoint should be saved
        use_student_settings: boolean to indicate if settings of the
                              student should be used.
  """
  sett = settings.Settings(use_student_settings=use_student_settings)
  dir_ = os.path.join(basepath, sett["checkpoint_path"][name])
  # Pass arguments lazily so the message is only formatted when emitted.
  logging.info("Saving checkpoint: %s Path: %s", name, dir_)
  # Checkpoint files are prefixed with the directory's own base name.
  prefix = os.path.join(dir_, os.path.basename(dir_))
  checkpoint.save(file_prefix=prefix)


def load_checkpoint(checkpoint, name, basepath="", use_student_settings=False):
  """ Restores Checkpoint
      Args:
        checkpoint: tf.train.Checkpoint object to restore.
        name: name of the checkpoint to restore; used as the key into the
              "checkpoint_path" section of the settings.
        basepath: base directory where checkpoint is located.
        use_student_settings: boolean to indicate if settings of the
                              student should be used.
      Returns:
        The status object returned by checkpoint.restore when a
        "checkpoint" entry exists in the checkpoint directory,
        otherwise None.
  """
  sett = settings.Settings(use_student_settings=use_student_settings)
  dir_ = os.path.join(basepath, sett["checkpoint_path"][name])
  if not tf.io.gfile.exists(os.path.join(dir_, "checkpoint")):
    # Nothing to restore: report it and make the None result explicit.
    logging.info("No Checkpoint found for %s", name)
    return None
  # Pass arguments lazily so the message is only formatted when emitted.
  logging.info("Found checkpoint: %s Path: %s", name, dir_)
  return checkpoint.restore(tf.train.latest_checkpoint(dir_))

# Losses


def pixelwise_mse(y_true, y_pred):
  """ Computes the mean squared error between two tensors
      Args:
        y_true: reference values.
        y_pred: predicted values, subtracted elementwise from y_true.
      Returns:
        The squared differences averaged over axis 0 first, then averaged
        over all remaining elements.
  """
  squared_diff = (y_true - y_pred)**2
  mean_over_first_axis = tf.reduce_mean(squared_diff, axis=0)
  return tf.reduce_mean(mean_over_first_axis)