# GAN Distillation on Enhanced Super Resolution GAN
Link to ESRGAN Training Code: [Click Here](../E2_ESRGAN)
Knowledge distillation on Enhanced Super Resolution GAN (ESRGAN) to perform super resolution with a model that has a much
smaller number of variables.
The training algorithm is **inspired** by https://arxiv.org/abs/1902.00159, with a custom loss function specific to the
problem of image super resolution.
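The distillation runs in two phases: a comparative phase, where the student learns to reproduce the teacher's output under a pixel-wise MSE, and an adversarial phase, where that MSE is combined with a perceptual loss and a relativistic average GAN loss computed with the teacher's discriminator. A minimal sketch of the adversarial-phase generator loss (the `Trainer` included in this commit is the authoritative implementation; `lambda` and `alpha` come from the config file):

```python
def student_generator_loss(teacher_fake, student_fake, image_hr,
                           ra_generator, perceptual_loss, pixelwise_mse,
                           lambda_=5e-2, alpha=1e-4):
  """Sketch of the joint loss minimised by the student generator.

  ra_generator: relativistic average loss built on the teacher discriminator.
  perceptual_loss: VGG feature-space distance to the ground truth image.
  pixelwise_mse: mean squared error between teacher and student outputs.
  """
  mse_loss = pixelwise_mse(teacher_fake, student_fake)   # match the teacher
  ra_loss = ra_generator(teacher_fake, student_fake)     # fool the teacher's discriminator
  percep_loss = perceptual_loss(image_hr, student_fake)  # stay close to ground-truth features
  return lambda_ * percep_loss + alpha * ra_loss + (1 - alpha) * mse_loss
```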
Results
------------------
### ESRGAN
**Latency**: 17.117 Seconds
**Mean PSNR Achieved**: 28.2
**Sample**:
Input Image Shape: 180 x 320
Output image shape: 720 x 1280
![esrgan](https://user-images.githubusercontent.com/13994201/63640629-251a1e80-c6c0-11e9-98bc-04432c7064e2.jpg "ESRGAN")
**PSNR of the Image**: 30.462
### Compressed ESRGAN
**Latency**: _**0.4 Seconds**_
**Mean PSNR Achieved**: 25.3
**Sample**
Input Image Shape: 180 x 320
Output image shape: 720 x 1280
![compressed_esrgan](https://user-images.githubusercontent.com/13994201/63640526-1121ed00-c6bf-11e9-99f5-0b48069fe784.jpg "Compressed ESRGAN")
**PSNR of the Image**: 26.942
Student Model
----------------
The Residual in Residual architecture of ESRGAN was followed, with a much shallower trunk.
Specifically,
|Name of Node|Depth|
|:-:|:-:|
|Residual Dense Blocks (RDB)|2 Depthwise Convolutions|
|Residual in Residual Blocks (RRDB)|2 RDB units|
|Trunk|3 RRDB units|
|UpSample Layer|1 ConvTranspose unit with a stride length of 4|
**Size of Growth Channel (intermediate channel) used:** 32
Trained Saved Model and TF Lite File
-----------------------------------------
#### Specification of the Saved Model
Input Dimension: `[None, 180, 320, 3]`
Input Data Type: `Float32`
Output Dimension: `[None, 720, 1280, 3]`
TensorFlow loadable link: https://github.com/captain-pool/GSOC/releases/download/2.0.0/compressed_esrgan.tar.gz
#### Specification of TF Lite
Input Dimension: `[1, 180, 320, 3]`
Input Data Type: `Float32`
Output Dimension: `[1, 720, 1280, 3]`
TensorFlow Lite: https://github.com/captain-pool/GSOC/releases/download/2.0.0/compressed_esrgan.tflite
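The TF Lite file can be exercised with the standard `tf.lite.Interpreter`. A minimal sketch, assuming a single `180x320` RGB frame given as float32 pixel values in the 0-255 range used by the training pipeline (the model path is a placeholder):

```python
import numpy as np
import tensorflow as tf

# Run the compressed ESRGAN TFLite model on a single low resolution frame.
interpreter = tf.lite.Interpreter(model_path="compressed_esrgan.tflite")
interpreter.allocate_tensors()
input_index = interpreter.get_input_details()[0]["index"]
output_index = interpreter.get_output_details()[0]["index"]

frame = np.random.uniform(0, 255, size=[1, 180, 320, 3]).astype(np.float32)
interpreter.set_tensor(input_index, frame)
interpreter.invoke()
super_resolved = interpreter.get_tensor(output_index)  # shape: [1, 720, 1280, 3]
```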
# Config File for Distillation
teacher_directory: "../E2_ESRGAN"
teacher_config: "../E2_ESRGAN/config/config.yaml"
checkpoint_path:
comparative_checkpoint: "checkpoints/mse_checkpoints/"
adversarial_checkpoint: "checkpoints/adversarial_checkpoints/"
hr_size: [512, 512, 3]
train:
comparative:
print_step: 1000
num_steps: 200000
checkpoint_step: 2000
decay_rate: 0.5
decay_steps:
- 20000
- 50000
- 90000
adversarial:
initial_lr: !!float 5e-5
num_steps: 400000
print_step: 1000
checkpoint_step: 2000
decay_rate: 0.5
decay_steps:
- 50000
- 100000
- 200000
lambda: !!float 5e-2
alpha: !!float 1e-4
scale_factor: 1
scale_value: 4
student_network: "rrdb_student"
student_config:
vgg_student:
trunk_depth: 3 # Minimum 2
# growth_channels: 32
use_bias: True
residual_student:
trunk_depth: 2
use_bias: True
residual_scale_beta: 0.8
rrdb_student:
rdb_config:
depth: 3
residual_scale_beta: 0.2
rrdb_config:
rdb_units: 2
residual_scale_beta: 0.2
trunk_size: 3
growth_channels: 32
""" Module for exporting TFLite of the Student Model """
from __future__ import print_function
from __future__ import absolute_import
from __future__ import division
import os
import argparse
from tqdm import tqdm
from absl import logging
import tensorflow as tf
import numpy as np
from libs import settings
from libs import model
from libs import lazy_loader
from PIL import Image
def representative_dataset_gen(
    num_calibration_steps,
datadir,
hr_size,
tflite_size):
""" Creates Generator for Representative datasets for Quantizing TFLite
Args:
      num_calibration_steps: Number of steps to calibrate the model for.
datadir: Directory containing TFRecord dataset.
hr_size: [height, width, channel] for high resolution images.
tflite_size: [height, width] for the TFLite File.
"""
hr_size = tf.cast(tf.convert_to_tensor(hr_size), tf.float32)
lr_size = tf.cast(
hr_size * tf.convert_to_tensor([1. / 4, 1. / 4, 1]), tf.int32)
hr_size = tf.cast(hr_size, tf.int32)
# Loading TFRecord Dataset
ds = (dataset.load_dataset(
datadir,
lr_size=lr_size,
hr_size=hr_size)
      .take(num_calibration_steps)
.prefetch(10))
def _gen_fn():
    for _, image_hr in tqdm(ds, total=num_calibration_steps):
image_hr = tf.cast(image_hr, tf.uint8)
lr_image = np.asarray(
Image.fromarray(image_hr.numpy())
.resize([tflite_size[1], tflite_size[0]],
Image.BICUBIC))
yield [tf.expand_dims(tf.cast(lr_image, tf.float32), 0).numpy()]
return _gen_fn
def export_tflite(config="", modeldir="", mode="", **kwargs):
"""
    Exports the SavedModel (if not present) and the TFLite of the student generator.
Args:
config: Path to config file of the student.
modeldir: Path to export the SavedModel and the TFLite to.
      mode: Mode of training to export. (Adversarial / comparative)
"""
# TODO (@captain-pool): Fix Quantization and mention them in the args list.
lazy = lazy_loader.LazyLoader()
lazy.import_("teacher_imports", parent="libs", return_=False)
lazy.import_("utils", parent="libs", return_=False)
lazy.import_("dataset", parent="libs", return_=False)
globals().update(lazy.import_dict)
status = None
sett = settings.Settings(config, use_student_settings=True)
stats = settings.Stats(os.path.join(sett.path, "stats.yaml"))
student_name = sett["student_network"]
student_generator = model.Registry.models[student_name](first_call=False)
ckpt = tf.train.Checkpoint(student_generator=student_generator)
logging.info("Initiating Variables. Tracing Function.")
student_generator.predict(tf.random.normal([1, 180, 320, 3]))
if stats.get(mode):
status = utils.load_checkpoint(
ckpt,
"%s_checkpoint" % mode,
basepath=modeldir,
use_student_settings=True)
if not status:
raise IOError("No checkpoint found to restore")
saved_model_dir = os.path.join(modeldir, "compressed_esrgan")
if not tf.io.gfile.exists(
os.path.join(saved_model_dir, "saved_model.pb")):
tf.saved_model.save(
student_generator,
saved_model_dir)
converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
  # TODO (@captain-pool): Try to fix Quantization
# Current Error: Cannot Quantize LEAKY_RELU and CONV2D_TRANSPOSE
# Quantization Code Fragment
# converter.target_spec.supported_types = [tf.float16]
# converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
# converter.representative_dataset = representative_dataset_gen(
# kwargs["calibration_steps"],
# kwargs["datadir"],
# sett["hr_size"],
# [180, 320, 3])
tflite_model = converter.convert()
tflite_path = os.path.join(modeldir, "tflite", "compressed_esrgan.tflite")
with tf.io.gfile.GFile(tflite_path, "wb") as f:
f.write(tflite_model)
logging.info("Successfully writen the TFLite to: %s" % tflite_path)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--modeldir",
default="",
help="Directory of the saved checkpoints")
parser.add_argument(
"--datadir",
default=None,
help="Path to TFRecords dataset")
parser.add_argument(
"--calibration_steps",
default=1000,
type=int,
help="Number of Steps to calibrate on")
parser.add_argument(
"--config",
default="config/config.yaml",
help="Configuration File to be loaded")
parser.add_argument(
"--mode",
default="none",
help="mode of training to load (adversarial / comparative)")
parser.add_argument(
"--verbose",
"-v",
default=0,
action="count",
help="Increases Verbosity. Repeat to increase more")
log_levels = [logging.WARNING, logging.INFO, logging.DEBUG]
FLAGS, unknown = parser.parse_known_args()
log_level = log_levels[min(FLAGS.verbose, len(log_levels) - 1)]
logging.set_verbosity(log_level)
export_tflite(**vars(FLAGS))
from absl import logging
import os
from lib import dataset
from libs import settings
import tensorflow as tf
def generate_tf_record(
data_dir,
raw_data=False,
tfrecord_path="serialized_dataset",
num_shards=8):
teacher_sett = settings.Settings(use_student_settings=False)
student_sett = settings.Settings(use_student_settings=True)
dataset_args = teacher_sett["dataset"]
if dataset_args["name"].lower().strip() == "div2k":
assert len(data_dir) == 2
ds = dataset.load_div2k_dataset(
data_dir[0],
data_dir[1],
student_sett["hr_size"],
shuffle=True)
elif raw_data:
ds = dataset.load_dataset_directory(
dataset_args["name"],
data_dir,
dataset.scale_down(
method=dataset_args["scale_method"],
size=student_sett["hr_size"]))
else:
ds = dataset.load_dataset(
dataset_args["name"],
dataset.scale_down(
method=dataset_args["scale_method"],
size=student_sett["hr_size"]),
data_dir=data_dir)
to_tfrecord(ds, tfrecord_path, num_shards)
def load_dataset(tfrecord_path, lr_size, hr_size):
def _parse_tf_record(serialized_example):
features = {
"low_res_image": tf.io.FixedLenFeature([], dtype=tf.string),
"high_res_image": tf.io.FixedLenFeature([], dtype=tf.string)}
example = tf.io.parse_single_example(serialized_example, features)
lr_image = tf.io.parse_tensor(
example["low_res_image"],
out_type=tf.float32)
lr_image = tf.reshape(lr_image, lr_size)
hr_image = tf.io.parse_tensor(
example["high_res_image"],
out_type=tf.float32)
hr_image = tf.reshape(hr_image, hr_size)
return lr_image, hr_image
files = tf.io.gfile.glob(
os.path.join(tfrecord_path, "*.tfrecord"))
if len(files) == 0:
raise ValueError("Path Doesn't contain any file")
ds = tf.data.TFRecordDataset(files).map(_parse_tf_record)
if len(files) == 1:
    option = tf.data.Options()
    option.experimental_distribute.auto_shard = False
    ds = ds.with_options(option)
ds = ds.shuffle(128, reshuffle_each_iteration=True)
return ds
def to_tfrecord(ds, tfrecord_path, num_shards=8):
def _bytes_feature(value):
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
def serialize_to_string(image_lr, image_hr):
features = {
"low_res_image": _bytes_feature(
tf.io.serialize_tensor(image_lr).numpy()),
"high_res_image": _bytes_feature(
tf.io.serialize_tensor(image_hr).numpy())}
example_proto = tf.train.Example(
features=tf.train.Features(feature=features))
return example_proto.SerializeToString()
def write_to_tfrecord(shard_id, ds):
filename = tf.strings.join(
[tfrecord_path, "/dataset.", tf.strings.as_string(shard_id),
".tfrecord"])
writer = tf.data.experimental.TFRecordWriter(filename)
writer.write(ds.map(lambda _, x: x))
return tf.data.Dataset.from_tensors(filename)
def map_serialize_to_string(image_lr, image_hr):
map_fn = tf.py_function(
serialize_to_string,
(image_lr, image_hr),
tf.string)
return tf.reshape(map_fn, ())
ds = ds.map(map_serialize_to_string)
ds = ds.enumerate()
ds = ds.apply(tf.data.experimental.group_by_window(
lambda i, _: i % num_shards,
write_to_tfrecord,
tf.int64.max))
for data in ds:
logging.info("Written to: %s" % data.numpy())
import importlib
from .settings import singleton
@singleton
class LazyLoader:
def __init__(self):
self.__import_dict = {}
@property
def import_dict(self):
return self.__import_dict
  @import_dict.setter
  def import_dict(self, value):
    # Property setters receive a single value; accept a (key, module) pair.
    key, module = value
    self.__import_dict[key] = module
def import_(self, name, alias=None, parent=None, return_=True):
if alias is None:
alias = name
if self.__import_dict.get(alias, None) is not None and return_:
return self.__import_dict[alias]
    if parent is not None:
      # Import the submodule directly instead of relying on exec/locals().
      self.__import_dict[alias] = importlib.import_module(
          "{}.{}".format(parent, name))
else:
module = importlib.import_module(name)
self.__import_dict[alias] = module
if return_:
return self.__import_dict[alias]
def __getitem__(self, key):
return self.__import_dict[key]
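# --- Usage sketch (not part of this module) ---------------------------------
# Mirrors how the TFLite export script and main.py in this commit consume the
# loader: modules under `libs` are imported on first request, cached in
# import_dict, and then dumped into the caller's globals().
from libs import lazy_loader

lazy = lazy_loader.LazyLoader()
lazy.import_("utils", parent="libs", return_=False)
lazy.import_("dataset", parent="libs", return_=False)
globals().update(lazy.import_dict)  # `utils` and `dataset` are now usable names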
""" Module to auto-register students and expose student registry """
from libs import models
from libs.models.abstract import Registry
from libs.models.student_vgg import VGGStudent
from libs.models.student_rrdb import RRDBStudent
from libs.models.student_residual import ResidualStudent
import re
from tensorflow.python import keras
class Registry(type):
models = {}
def _convert_to_snake(name):
s1 = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', name)
return re.sub('([a-z0-9])([A-Z])', r'\1_\2', s1).lower()
def __init__(cls, name, bases, attrs):
if Registry._convert_to_snake(name) != "model":
Registry.models[Registry._convert_to_snake(cls.__name__)] = cls
# Abstract class for Auto Registration of Kernels
class Model(keras.models.Model, metaclass=Registry):
def __init__(self, *args, **kwargs):
super(Model, self).__init__()
self.init(*args, **kwargs)
def init(self, *args, **kwargs):
pass
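# --- Usage sketch (not part of this module) ---------------------------------
# Any subclass of Model is auto-registered under the snake_case form of its
# class name; this is how the config's `student_network: "rrdb_student"`
# string is resolved to a class in main.py. Assumes the student Settings
# singleton already points at config/config.yaml.
from libs import model  # importing the registry module registers the students

student_cls = model.Registry.models["rrdb_student"]  # -> RRDBStudent
student_generator = student_cls()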
""" Module containing ResNet architecture as a student Network """
import tensorflow as tf
from libs.models import abstract
from libs import settings
from functools import partial
class ResidualStudent(abstract.Model):
""" ResNet Student Network Model """
def init(self):
self.student_settings = settings.Settings(use_student_settings=True)
self._scale_factor = self.student_settings["scale_factor"]
self._scale_value = self.student_settings["scale_value"]
model_args = self.student_settings["student_config"]["residual_student"]
depth = model_args["trunk_depth"]
depthwise_convolution = partial(
tf.keras.layers.DepthwiseConv2D,
kernel_size=[3, 3],
strides=[1, 1],
padding="same",
use_bias=model_args["use_bias"])
convolution = partial(
tf.keras.layers.Conv2D,
kernel_size=[3, 3],
strides=[1, 1],
padding="same",
use_bias=model_args["use_bias"])
conv_transpose = partial(
tf.keras.layers.Conv2DTranspose,
kernel_size=[3, 3],
strides=self._scale_value,
padding="same")
self._lrelu = tf.keras.layers.LeakyReLU(alpha=0.2)
self._residual_scale = model_args["residual_scale_beta"]
self._conv_layers = {
"conv_%d" % index:
depthwise_convolution() for index in range(1, depth + 1)}
self._upscale_layers = {
"upscale_%d" % index:
conv_transpose(filters=32) for index in range(1, self._scale_factor)}
self._last_layer = conv_transpose(filters=3)
@tf.function(
input_signature=[
tf.TensorSpec(
shape=[None, 180, 270, 3], # 720x1080 Images
dtype=tf.float32)])
def call(self, inputs):
return self.unsigned_call(inputs)
def unsigned_call(self, inputs):
intermediate = inputs
for layer_name in self._conv_layers:
intermediate += self._lrelu(self._conv_layers[layer_name](intermediate))
intermediate = inputs + self._residual_scale * intermediate # Residual Trunk
# Upsampling
for layer_name in self._upscale_layers:
intermediate = self._lrelu(
self._upscale_layers[layer_name](intermediate))
return self._last_layer(intermediate)
""" Module contains Residual in Residual student network and helper layers. """
from functools import partial
from libs.models import abstract
from libs import settings
import tensorflow as tf
from absl import logging
class ResidualDenseBlock(tf.keras.layers.Layer):
"""
  Keras implementation of Residual Dense Block.
(https://arxiv.org/pdf/1809.00219.pdf)
"""
def __init__(self, first_call=True):
super(ResidualDenseBlock, self).__init__()
self.settings = settings.Settings(use_student_settings=True)
rdb_config = self.settings["student_config"]["rrdb_student"]["rdb_config"]
convolution = partial(
tf.keras.layers.DepthwiseConv2D,
depthwise_initializer="he_normal",
bias_initializer="he_normal",
kernel_size=[3, 3],
strides=[1, 1],
padding="same")
self._first_call = first_call
self._conv_layers = {
"conv_%d" % index: convolution()
for index in range(1, rdb_config["depth"])}
self._lrelu = tf.keras.layers.LeakyReLU(alpha=0.2)
self._beta = rdb_config["residual_scale_beta"]
def call(self, inputs):
intermediates = [inputs]
for layer_name in self._conv_layers:
residual_junction = tf.math.add_n(intermediates)
raw_intermediate = self._conv_layers[layer_name](residual_junction)
activated_intermediate = self._lrelu(raw_intermediate)
intermediates.append(activated_intermediate)
if self._first_call:
logging.debug("Initializing Layers with 0.1 x MSRA")
for _, layer in self._conv_layers.items():
for weight in layer.trainable_variables:
weight.assign(0.1 * weight)
self._first_call = False
return inputs + self._beta * raw_intermediate
class ResidualInResidualBlock(tf.keras.layers.Layer):
"""
  Keras implementation of Residual in Residual Block.
(https://arxiv.org/pdf/1809.00219.pdf)
"""
def __init__(self, first_call=True):
super(ResidualInResidualBlock, self).__init__()
self.settings = settings.Settings(use_student_settings=True)
rrdb_config = self.settings["student_config"]["rrdb_student"]["rrdb_config"]
self._rdb_layers = {
"rdb_%d" % index: ResidualDenseBlock(first_call=first_call)
for index in range(1, rrdb_config["rdb_units"])}
self._beta = rrdb_config["residual_scale_beta"]
def call(self, inputs):
intermediate = inputs
for layer_name in self._rdb_layers:
intermediate = self._rdb_layers[layer_name](intermediate)
return inputs + self._beta * intermediate
class RRDBStudent(abstract.Model):
"""
  Keras implementation of Residual in Residual Student Network.
(https://arxiv.org/pdf/1809.00219.pdf)
"""
def init(self, first_call=True):
self.settings = settings.Settings(use_student_settings=True)
self._scale_factor = self.settings["scale_factor"]
self._scale_value = self.settings["scale_value"]
rrdb_student_config = self.settings["student_config"]["rrdb_student"]
rrdb_block = partial(ResidualInResidualBlock, first_call=first_call)
growth_channels = rrdb_student_config["growth_channels"]
depthwise_conv = partial(
tf.keras.layers.DepthwiseConv2D,
kernel_size=[3, 3],
strides=[1, 1],
use_bias=True,
padding="same")
convolution = partial(
tf.keras.layers.Conv2D,
kernel_size=[3, 3],
use_bias=True,
strides=[1, 1],
padding="same")
conv_transpose = partial(
tf.keras.layers.Conv2DTranspose,
kernel_size=[3, 3],
use_bias=True,
strides=self._scale_value,
padding="same")
self._rrdb_trunk = tf.keras.Sequential(
[rrdb_block() for _ in range(rrdb_student_config["trunk_size"])])
self._first_conv = convolution(filters=64)
self._upsample_layers = {
"upsample_%d" % index: conv_transpose(filters=growth_channels)
for index in range(1, self._scale_factor)}
key = "upsample_%d" % self._scale_factor
self._upsample_layers[key] = conv_transpose(filters=3)
self._conv_last = depthwise_conv()
self._lrelu = tf.keras.layers.LeakyReLU(alpha=0.2)
def call(self, inputs):
return self.unsigned_call(inputs)
def unsigned_call(self, inputs):
residual_start = self._first_conv(inputs)
intermediate = residual_start + self._rrdb_trunk(residual_start)
for layer_name in self._upsample_layers:
intermediate = self._lrelu(
self._upsample_layers[layer_name](intermediate))
out = self._conv_last(intermediate)
return out
""" Student Network with VGG Architecture """
import tensorflow as tf
from libs import settings
from libs.models import abstract
from functools import partial
class VGGStudent(abstract.Model):
def init(self):
sett = settings.Settings(use_student_settings=True)
model_args = sett["student_config"]["vgg_student"]
self._scale_factor = sett["scale_factor"]
self._scale_value = sett["scale_value"]
depth = model_args["trunk_depth"] # Minimum 2 for scale factor of 4
depthwise_convolution = partial(
tf.keras.layers.DepthwiseConv2D,
kernel_size=[3, 3],
padding="same",
use_bias=model_args["use_bias"])
conv_transpose = partial(
tf.keras.layers.Conv2DTranspose,
kernel_size=[3, 3],
strides=self._scale_value,
padding="same")
convolution = partial(
tf.keras.layers.Conv2D,
kernel_size=[3, 3],
strides=[1, 1],
use_bias=model_args["use_bias"],
padding="same")
trunk_depth = depth - self._scale_factor
self._conv_layers = {
"conv_%d" % index: depthwise_convolution()
for index in range(1, trunk_depth + 1)}
self._upsample_layers = {
"upsample_%d" % index: conv_transpose(filters=32)
for index in range(1, self._scale_factor)}
self._last_layer = conv_transpose(filters=3)
@tf.function(
input_signature=[
tf.TensorSpec(
shape=[None, 180, 270, 3], # For 720x1080 images
dtype=tf.float32)])
def call(self, inputs):
    return self.unsigned_call(inputs)
def unsigned_call(self, inputs):
intermediate = inputs
for layer_name in self._conv_layers:
intermediate = self._conv_layers[layer_name](intermediate)
intermediate = tf.keras.layers.LeakyReLU(alpha=0.2)(intermediate)
for layer_name in self._upsample_layers:
intermediate = self._upsample_layers[layer_name](intermediate)
intermediate = tf.keras.layers.LeakyReLU(alpha=0.2)(intermediate)
return self._last_layer(intermediate)
""" Module to load Teacher Models from Teacher Directory """
import os
import sys
from libs import teacher_imports
from lib import settings
settings.Settings("%s/config/config.yaml" % teacher_imports.TEACHER_DIR)
from lib.model import VGGArch as discriminator
from lib.model import RRDBNet as generator
""" Module containing settings and training statistics file handler."""
import os
import yaml
def singleton(cls):
instances = {}
def getinstance(*args, **kwargs):
distill_config = kwargs.get("use_student_settings", "")
key = cls.__name__
if distill_config:
key = "%s_student" % (cls.__name__)
if key not in instances:
instances[key] = cls(*args, **kwargs)
return instances[key]
return getinstance
@singleton
class Settings(object):
""" Settings class to handle yaml config files """
def __init__(
self,
filename="config/config.yaml",
use_student_settings=False):
self.__path = os.path.abspath(filename)
@property
def path(self):
return os.path.dirname(self.__path)
def __getitem__(self, index):
with open(self.__path, "r") as file_:
return yaml.load(file_.read(), Loader=yaml.FullLoader)[index]
def get(self, index, default=None):
with open(self.__path, "r") as file_:
return yaml.load(
file_.read(),
Loader=yaml.FullLoader).get(
index,
default)
class Stats(object):
""" Class to handle Training Statistics File """
def __init__(self, filename="stats.yaml"):
if os.path.exists(filename):
with open(filename, "r") as file_:
self.__data = yaml.load(file_.read(), Loader=yaml.FullLoader)
else:
self.__data = {}
self.file = filename
def get(self, index, default=None):
return self.__data.get(index, default)
def __getitem__(self, index):
return self.__data[index]
def __repr__(self):
return "Stats(" + str(self.__data) + ")"
def __setitem__(self, index, data):
self.__data[index] = data
with open(self.file, "w") as file_:
yaml.dump(self.__data, file_, default_flow_style=False)
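# --- Usage sketch (not part of this module) ---------------------------------
# Settings is a per-role singleton keyed on use_student_settings, so every
# caller shares the same (student or teacher) instance and config file.
# Stats persists every assignment straight back to its YAML file.
from libs import settings

student_sett = settings.Settings("config/config.yaml", use_student_settings=True)
print(student_sett["student_network"])     # e.g. "rrdb_student"
print(student_sett.get("scale_value", 4))

stats = settings.Stats("stats.yaml")
stats["comparative"] = True                # written to stats.yaml immediately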
import sys
import os
from libs import settings
_teacher_directory = settings.Settings(use_student_settings=True)[
"teacher_directory"]
TEACHER_DIR = os.path.abspath(_teacher_directory)
if TEACHER_DIR not in sys.path:
sys.path.insert(0, TEACHER_DIR)
""" Trainer class to train student network to compress ESRGAN """
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl import logging
from libs import dataset
from libs import settings
from libs import utils
import tensorflow as tf
class Trainer(object):
"""Trainer Class for Knowledge Distillation of ESRGAN"""
def __init__(
self,
teacher,
discriminator,
summary_writer,
summary_writer_2=None,
model_dir="",
data_dir="",
strategy=None):
"""
Args:
teacher: Keras Model of pre-trained teacher generator.
(Generator of ESRGAN)
discriminator: Keras Model of pre-trained teacher discriminator.
(Discriminator of ESRGAN)
      summary_writer: tf.summary.SummaryWriter object for writing
                      student summaries for TensorBoard.
      summary_writer_2: optional tf.summary.SummaryWriter object for writing
                        teacher summaries for TensorBoard.
      model_dir: Location to store checkpoints and SavedModel directory.
      data_dir: Location of the stored TFRecord dataset.
      strategy: tf.distribute.Strategy object to distribute training with.
"""
self.teacher_generator = teacher
self.teacher_discriminator = discriminator
self.teacher_settings = settings.Settings(use_student_settings=False)
self.student_settings = settings.Settings(use_student_settings=True)
self.model_dir = model_dir
self.strategy = strategy
self.train_args = self.student_settings["train"]
self.batch_size = self.teacher_settings["batch_size"]
self.hr_size = self.student_settings["hr_size"]
self.lr_size = tf.unstack(self.hr_size)[:-1]
self.lr_size.append(tf.gather(self.hr_size, len(self.hr_size) - 1) * 4)
self.lr_size = tf.stack(self.lr_size) // 4
self.summary_writer = summary_writer
self.summary_writer_2 = summary_writer_2
# Loading TFRecord Dataset
self.dataset = dataset.load_dataset(
data_dir,
lr_size=self.lr_size,
hr_size=self.hr_size)
self.dataset = (
self.dataset.repeat()
.batch(self.batch_size, drop_remainder=True)
.prefetch(1024))
self.dataset = iter(self.strategy.experimental_distribute_dataset(
self.dataset))
# Reloading Checkpoint from Phase 2 Training of ESRGAN
checkpoint = tf.train.Checkpoint(
G=self.teacher_generator,
D=self.teacher_discriminator)
utils.load_checkpoint(
checkpoint,
"phase_2",
basepath=model_dir,
use_student_settings=False)
def train_comparative(self, student, export_only=False):
"""
Trains the student using a comparative loss function (Mean Squared Error)
    based on the output of the teacher.
Args:
student: Keras model of the student.
"""
train_args = self.train_args["comparative"]
total_steps = train_args["num_steps"]
decay_rate = train_args["decay_rate"]
decay_steps = train_args["decay_steps"]
optimizer = tf.optimizers.Adam()
checkpoint = tf.train.Checkpoint(
student_generator=student,
student_optimizer=optimizer)
status = utils.load_checkpoint(
checkpoint,
"comparative_checkpoint",
basepath=self.model_dir,
use_student_settings=True)
if export_only:
return
loss_fn = tf.keras.losses.MeanSquaredError(reduction="none")
metric_fn = tf.keras.metrics.Mean()
student_psnr = tf.keras.metrics.Mean()
teacher_psnr = tf.keras.metrics.Mean()
def step_fn(image_lr, image_hr):
"""
Function to be replicated among the worker nodes
Args:
image_lr: Distributed Batch of Low Resolution Images
image_hr: Distributed Batch of High Resolution Images
"""
with tf.GradientTape() as tape:
teacher_fake = self.teacher_generator.unsigned_call(image_lr)
student_fake = student.unsigned_call(image_lr)
student_fake = tf.clip_by_value(student_fake, 0, 255)
teacher_fake = tf.clip_by_value(teacher_fake, 0, 255)
image_hr = tf.clip_by_value(image_hr, 0, 255)
        student_psnr(tf.reduce_mean(tf.image.psnr(student_fake, image_hr, max_val=255.0)))
        teacher_psnr(tf.reduce_mean(tf.image.psnr(teacher_fake, image_hr, max_val=255.0)))
loss = utils.pixelwise_mse(teacher_fake, student_fake)
loss = tf.reduce_mean(loss) * (1.0 / self.batch_size)
metric_fn(loss)
student_vars = list(set(student.trainable_variables))
gradient = tape.gradient(loss, student_vars)
train_op = optimizer.apply_gradients(
zip(gradient, student_vars))
with tf.control_dependencies([train_op]):
return tf.cast(optimizer.iterations, tf.float32)
@tf.function
def train_step(image_lr, image_hr):
"""
In Graph Function to assign trainer function to
replicate among worker nodes.
Args:
image_lr: Distributed batch of Low Resolution Images
image_hr: Distributed batch of High Resolution Images
"""
distributed_metric = self.strategy.experimental_run_v2(
step_fn, args=(image_lr, image_hr))
mean_metric = self.strategy.reduce(
tf.distribute.ReduceOp.MEAN, distributed_metric, axis=None)
return mean_metric
logging.info("Starting comparative loss training")
while True:
image_lr, image_hr = next(self.dataset)
step = train_step(image_lr, image_hr)
if step >= total_steps:
return
for _step in decay_steps.copy():
if step >= _step:
decay_steps.pop(0)
logging.debug("Reducing Learning Rate by: %f" % decay_rate)
optimizer.learning_rate.assign(optimizer.learning_rate * decay_rate)
if status:
status.assert_consumed()
logging.info("Checkpoint loaded successfully")
status = None
# Writing Summary
with self.summary_writer.as_default():
tf.summary.scalar("student_loss", metric_fn.result(), step=optimizer.iterations)
tf.summary.scalar("mean_psnr", student_psnr.result(), step=optimizer.iterations)
if self.summary_writer_2:
with self.summary_writer_2.as_default():
tf.summary.scalar("mean_psnr", teacher_psnr.result(), step=optimizer.iterations)
if not step % train_args["print_step"]:
logging.info("[COMPARATIVE LOSS] Step: %s\tLoss: %s" %
(step, metric_fn.result()))
# Saving Checkpoint
if not step % train_args["checkpoint_step"]:
utils.save_checkpoint(
checkpoint,
"comparative_checkpoint",
basepath=self.model_dir,
use_student_settings=True)
def train_adversarial(self, student, export_only=False):
"""
    Trains the student adversarially using a joint loss that combines the
    relativistic average loss from the teacher's discriminator, a perceptual
    loss, and the mean squared error between the outputs of the student and
    teacher generators.
Args:
student: Keras model of the student to train.
"""
train_args = self.train_args["adversarial"]
total_steps = train_args["num_steps"]
decay_steps = train_args["decay_steps"]
decay_rate = train_args["decay_rate"]
lambda_ = train_args["lambda"]
alpha = train_args["alpha"]
generator_metric = tf.keras.metrics.Mean()
discriminator_metric = tf.keras.metrics.Mean()
generator_optimizer = tf.optimizers.Adam(learning_rate=train_args["initial_lr"])
dummy_optimizer = tf.optimizers.Adam()
discriminator_optimizer = tf.optimizers.Adam(learning_rate=train_args["initial_lr"])
checkpoint = tf.train.Checkpoint(
student_generator=student,
student_optimizer=generator_optimizer,
teacher_optimizer=discriminator_optimizer,
teacher_generator=self.teacher_generator,
teacher_discriminator=self.teacher_discriminator)
status = None
if not utils.checkpoint_exists(
names="adversarial_checkpoint",
basepath=self.model_dir,
use_student_settings=True):
if export_only:
raise ValueError("Adversarial checkpoints not found")
else:
status = utils.load_checkpoint(
checkpoint,
"adversarial_checkpoint",
basepath=self.model_dir,
use_student_settings=True)
if export_only:
if not status:
raise ValueError("No Checkpoint Loaded!")
return
ra_generator = utils.RelativisticAverageLoss(
self.teacher_discriminator, type_="G")
ra_discriminator = utils.RelativisticAverageLoss(
self.teacher_discriminator, type_="D")
perceptual_loss = utils.PerceptualLoss(
weights="imagenet",
input_shape=self.hr_size,
loss_type="L2")
student_psnr = tf.keras.metrics.Mean()
teacher_psnr = tf.keras.metrics.Mean()
def step_fn(image_lr, image_hr):
"""
Function to be replicated among the worker nodes
Args:
image_lr: Distributed Batch of Low Resolution Images
image_hr: Distributed Batch of High Resolution Images
"""
with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
teacher_fake = self.teacher_generator.unsigned_call(image_lr)
logging.debug("Fetched Fake: Teacher")
teacher_fake = tf.clip_by_value(teacher_fake, 0, 255)
student_fake = student.unsigned_call(image_lr)
logging.debug("Fetched Fake: Student")
student_fake = tf.clip_by_value(student_fake, 0, 255)
image_hr = tf.clip_by_value(image_hr, 0, 255)
psnr = tf.image.psnr(student_fake, image_hr, max_val=255.0)
student_psnr(tf.reduce_mean(psnr))
psnr = tf.image.psnr(teacher_fake, image_hr, max_val=255.0)
teacher_psnr(tf.reduce_mean(psnr))
mse_loss = utils.pixelwise_mse(teacher_fake, student_fake)
image_lr = utils.preprocess_input(image_lr)
image_hr = utils.preprocess_input(image_hr)
student_fake = utils.preprocess_input(student_fake)
teacher_fake = utils.preprocess_input(teacher_fake)
student_ra_loss = ra_generator(teacher_fake, student_fake)
logging.debug("Relativistic Average Loss: Student")
discriminator_loss = ra_discriminator(teacher_fake, student_fake)
discriminator_metric(discriminator_loss)
discriminator_loss = tf.reduce_mean(
discriminator_loss) * (1.0 / self.batch_size)
logging.debug("Relativistic Average Loss: Teacher")
percep_loss = perceptual_loss(image_hr, student_fake)
generator_loss = lambda_ * percep_loss + alpha * student_ra_loss + (1 - alpha) * mse_loss
generator_metric(generator_loss)
logging.debug("Calculated Joint Loss for Generator")
generator_loss = tf.reduce_mean(
generator_loss) * (1.0 / self.batch_size)
generator_gradient = gen_tape.gradient(
generator_loss, student.trainable_variables)
logging.debug("calculated gradient: generator")
discriminator_gradient = disc_tape.gradient(
discriminator_loss, self.teacher_discriminator.trainable_variables)
logging.debug("calculated gradient: discriminator")
generator_op = generator_optimizer.apply_gradients(
zip(generator_gradient, student.trainable_variables))
logging.debug("applied generator gradients")
discriminator_op = discriminator_optimizer.apply_gradients(
zip(discriminator_gradient, self.teacher_discriminator.trainable_variables))
logging.debug("applied discriminator gradients")
with tf.control_dependencies(
[generator_op, discriminator_op]):
return tf.cast(discriminator_optimizer.iterations, tf.float32)
@tf.function
def train_step(image_lr, image_hr):
"""
In Graph Function to assign trainer function to
replicate among worker nodes.
Args:
image_lr: Distributed batch of Low Resolution Images
image_hr: Distributed batch of High Resolution Images
"""
distributed_metric = self.strategy.experimental_run_v2(
step_fn,
args=(image_lr, image_hr))
mean_metric = self.strategy.reduce(
tf.distribute.ReduceOp.MEAN,
distributed_metric, axis=None)
return mean_metric
logging.info("Starting Adversarial Training")
while True:
image_lr, image_hr = next(self.dataset)
step = train_step(image_lr, image_hr)
if status:
status.assert_consumed()
logging.info("Loaded Checkpoint successfully!")
status = None
if not isinstance(decay_steps, list):
if not step % decay_steps:
logging.debug("Decaying Learning Rate by: %s" % decay_rate)
generator_optimizer.learning_rate.assign(
generator_optimizer.learning_rate * decay_rate)
discriminator_optimizer.learning_rate.assign(
discriminator_optimizer.learning_rate * decay_rate)
else:
for decay_step in decay_steps.copy():
if decay_step <= step:
decay_steps.pop(0)
logging.debug("Decaying Learning Rate by: %s" % decay_rate)
generator_optimizer.learning_rate.assign(
generator_optimizer.learning_rate * decay_rate)
discriminator_optimizer.learning_rate.assign(
discriminator_optimizer.learning_rate * decay_rate)
# Setting Up Logging
with self.summary_writer.as_default():
tf.summary.scalar(
"student_loss",
generator_metric.result(),
step=discriminator_optimizer.iterations)
tf.summary.scalar(
"teacher_discriminator_loss",
discriminator_metric.result(),
step=discriminator_optimizer.iterations)
tf.summary.scalar(
"mean_psnr",
student_psnr.result(),
step=discriminator_optimizer.iterations)
if self.summary_writer_2:
with self.summary_writer_2.as_default():
tf.summary.scalar(
"mean_psnr",
teacher_psnr.result(),
step=discriminator_optimizer.iterations)
if not step % train_args["print_step"]:
logging.info(
"[ADVERSARIAL] Step: %s\tStudent Loss: %s\t"
"Discriminator Loss: %s" %
(step, generator_metric.result(),
discriminator_metric.result()))
# Setting Up Checkpoint
if not step % train_args["checkpoint_step"]:
utils.save_checkpoint(
checkpoint,
"adversarial_checkpoint",
basepath=self.model_dir,
use_student_settings=True)
if step >= total_steps:
return
""" Module containing utility functions for the trainer """
import os
from absl import logging
from libs import settings
import tensorflow as tf
# Loading utilities from ESRGAN
from lib.utils import assign_to_worker
from lib.utils import PerceptualLoss
from lib.utils import RelativisticAverageLoss
from lib.utils import SingleDeviceStrategy
from lib.utils import preprocess_input
def checkpoint_exists(names, basepath="", use_student_settings=False):
sett = settings.Settings(use_student_settings=use_student_settings)
if tf.nest.is_nested(names):
if isinstance(names, dict):
return False
else:
names = [names]
values = []
for name in names:
dir_ = os.path.join(basepath, sett["checkpoint_path"][name], "checkpoint")
values.append(tf.io.gfile.exists(dir_))
return any(values)
def save_checkpoint(checkpoint, name, basepath="", use_student_settings=False):
""" Saves Checkpoint
Args:
checkpoint: tf.train.Checkpoint object to save.
name: name of the checkpoint to save.
basepath: base directory where checkpoint should be saved
      use_student_settings: boolean to indicate if settings of the student should be used.
"""
sett = settings.Settings(use_student_settings=use_student_settings)
dir_ = os.path.join(basepath, sett["checkpoint_path"][name])
logging.info("Saving checkpoint: %s Path: %s" % (name, dir_))
prefix = os.path.join(dir_, os.path.basename(dir_))
checkpoint.save(file_prefix=prefix)
def load_checkpoint(checkpoint, name, basepath="", use_student_settings=False):
""" Restores Checkpoint
Args:
checkpoint: tf.train.Checkpoint object to restore.
name: name of the checkpoint to restore.
basepath: base directory where checkpoint is located.
      use_student_settings: boolean to indicate if settings of the student should be used.
"""
sett = settings.Settings(use_student_settings=use_student_settings)
dir_ = os.path.join(basepath, sett["checkpoint_path"][name])
if tf.io.gfile.exists(os.path.join(dir_, "checkpoint")):
logging.info("Found checkpoint: %s Path: %s" % (name, dir_))
status = checkpoint.restore(tf.train.latest_checkpoint(dir_))
return status
logging.info("No Checkpoint found for %s" % name)
# Losses
def pixelwise_mse(y_true, y_pred):
mean_squared_error = tf.reduce_mean(tf.reduce_mean(
(y_true - y_pred)**2, axis=0))
return mean_squared_error
"""
Compressing GANs using Knowledge Distillation.
Teacher GAN: ESRGAN (https://github.com/captain-pool/E2_ESRGAN)
Citation:
@article{DBLP:journals/corr/abs-1902-00159,
author = {Angeline Aguinaldo and
Ping{-}Yeh Chiang and
Alexander Gain and
Ameya Patil and
Kolten Pearson and
Soheil Feizi},
title = {Compressing GANs using Knowledge Distillation},
journal = {CoRR},
volume = {abs/1902.00159},
year = {2019},
url = {http://arxiv.org/abs/1902.00159},
archivePrefix = {arXiv},
eprint = {1902.00159},
timestamp = {Tue, 21 May 2019 18:03:39 +0200},
biburl = {https://dblp.org/rec/bib/journals/corr/abs-1902-00159},
bibsource = {dblp computer science bibliography, https://dblp.org}
}
"""
import os
from absl import logging
import argparse
from libs import lazy_loader
from libs import model
from libs import settings
import tensorflow as tf
def train_and_export(**kwargs):
""" Train and Export Compressed ESRGAN
Args:
config: path to config file.
logdir: path to logging directory
modeldir: Path to store the checkpoints and exported model.
datadir: Path to custom data directory.
manual: Boolean to indicate if `datadir` contains Raw Files(True) / TFRecords (False)
"""
lazy = lazy_loader.LazyLoader()
student_settings = settings.Settings(
kwargs["config"], use_student_settings=True)
# Lazy importing dependencies from teacher
lazy.import_("teacher_imports", parent="libs", return_=False)
lazy.import_("teacher", parent="libs.models", return_=False)
lazy.import_("train", parent="libs", return_=False)
lazy.import_("utils", parent="libs", return_=False)
globals().update(lazy.import_dict)
tf.random.set_seed(10)
teacher_settings = settings.Settings(
student_settings["teacher_config"], use_student_settings=False)
stats = settings.Stats(os.path.join(student_settings.path, "stats.yaml"))
strategy = utils.SingleDeviceStrategy()
if kwargs["tpu"]:
cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
kwargs["tpu"])
tf.config.experimental_connect_to_host(cluster_resolver.get_master())
tf.tpu.experimental.initialize_tpu_system(cluster_resolver)
strategy = tf.distribute.experimental.TPUStrategy(cluster_resolver)
device_name = utils.assign_to_worker(kwargs["tpu"])
with tf.device(device_name), strategy.scope():
summary_writer = tf.summary.create_file_writer(
os.path.join(kwargs["logdir"], "student"))
teacher_summary_writer = tf.summary.create_file_writer(
os.path.join(kwargs["logdir"], "teacher"))
teacher_generator = teacher.generator(out_channel=3, first_call=False)
teacher_discriminator = teacher.discriminator(
batch_size=teacher_settings["batch_size"])
student_generator = (
model.Registry
.models[student_settings["student_network"]]())
hr_size = tf.cast(tf.convert_to_tensor([1] + student_settings['hr_size']), tf.float32)
lr_size = tf.cast(hr_size * tf.convert_to_tensor([1, 1/4, 1/4, 1]), tf.int32)
logging.debug("Initializing Convolutions")
student_generator.unsigned_call(tf.random.normal(lr_size))
trainer = train.Trainer(
teacher_generator,
teacher_discriminator,
summary_writer,
summary_writer_2=teacher_summary_writer,
model_dir=kwargs["modeldir"],
data_dir=kwargs["datadir"],
strategy=strategy)
phase_name = None
if kwargs["type"].lower().startswith("comparative"):
trainer.train_comparative(
student_generator,
export_only=stats.get("comparative") or kwargs["export_only"])
if not kwargs["export_only"]:
stats["comparative"] = True
elif kwargs["type"].lower().startswith("adversarial"):
trainer.train_adversarial(
student_generator,
export_only=stats.get("adversarial") or kwargs["export_only"])
if not kwargs["export_only"]:
stats["adversarial"] = True
# Tracing Graph to put input signature
_ = student_generator.predict(
tf.random.normal([1, 180, 320, 3]))
tf.saved_model.save(
student_generator,
os.path.join(
kwargs["modeldir"],
"compressed_esrgan"))
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--tpu", default=None, help="Name of the TPU to use")
parser.add_argument("--logdir", default=None, help="Path to log directory")
parser.add_argument(
"--export_only",
default=False,
action="store_true",
help="Do not train, only export the model")
parser.add_argument(
"--config",
default="config/config.yaml",
help="path to config file")
parser.add_argument(
"--datadir",
default=None,
help="Path to custom data directory containing sharded TFRecords")
parser.add_argument(
"--modeldir",
default=None,
help="directory to store checkpoints and SavedModel")
parser.add_argument(
"--type",
default=None,
help="Train Student 'adversarial'-ly / 'comparative'-ly")
parser.add_argument(
"--verbose",
"-v",
default=0,
action="count",
help="Increases Verbosity. Repeat to increase more")
FLAGS, unparsed = parser.parse_known_args()
log_levels = [logging.FATAL, logging.WARNING, logging.INFO, logging.DEBUG]
log_level = log_levels[min(FLAGS.verbose, len(log_levels) - 1)]
logging.set_verbosity(log_level)
train_and_export(**vars(FLAGS))
#! /bin/bash
# $BASE contains the path to the cloud bucket.
# $TPU_NAME contains the name of the TPU to use.
gsutil -m rm -r $BASE/distillation/logdir/** &>/dev/null
sudo tensorboard --logdir $BASE/distillation/logdir --port 80 &> $HOME/logdir/access.log &
python3 main.py --logdir $BASE/distillation/logdir \
        --datadir $BASE/datadir/coco2014_esrgan \
--tpu $TPU_NAME \
--modeldir $BASE/distillation/modeldir \
--type "comparative" -vvv &> $HOME/logdir/main.log &
## Super Resolution Video Player
Building on the success of distilling ESRGAN to produce high resolution images in a fraction of a second (0.3 seconds, to be specific),
this application pushes the model to its limit. This proof-of-concept video player takes a normal video, downscales its frames
bicubically, performs super resolution on the frames on a CPU, and displays them in real time.
Here's a sample.
Video used: Slash's Anastasia, Live in Sydney. (https://www.youtube.com/watch?v=bC8EmPA6H6g)
![sr_player](https://user-images.githubusercontent.com/13994201/63652070-cdd88480-c779-11e9-8890-8dc379755926.png)
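At its core the player runs the distilled generator once per frame. A hedged sketch of that loop is given below; the actual `player.py` may differ, and `frame` stands in for whatever the video decoder supplies:

```python
import numpy as np
import tensorflow as tf
from PIL import Image

# Load the exported student SavedModel (path is a placeholder).
sr_model = tf.saved_model.load("/path/to/saved_model")

def super_resolve(frame):
  """frame: HxWx3 uint8 array; returns a 720x1280x3 uint8 array."""
  # Bicubically downscale to the model's 320x180 input, then super resolve.
  lr = Image.fromarray(frame).resize((320, 180), Image.BICUBIC)
  lr = tf.expand_dims(tf.cast(np.asarray(lr), tf.float32), 0)
  sr = sr_model(lr)  # assumes a traced call signature; otherwise use sr_model.signatures
  return tf.cast(tf.clip_by_value(sr, 0, 255), tf.uint8)[0].numpy()
```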
Usage
-----------
```bash
$ python3 player.py --file /path/to/video/file \
--saved_model /path/to/saved_model
```
Issues
-------
Check the issue tracker with "Super Resolution Player" as issue label or
[click here](https://github.com/captain-pool/GSOC/issues?q=is%3Aissue+is%3Aopen+label%3A%22Super+Resolution+Player%22)
Further Concepts
------------------
Building an online video streamer that sends low resolution video and performs super resolution on the client side.
Check the [experimental](experimental) folder for the codes.
syntax = "proto2";
// Protocol Buffer for Video
// Packets
message Metadata{
required int32 duration = 1;
required int32 video_fps = 2;
required int32 audio_fps = 3;
repeated int32 dimension = 4;
}
message FramePacket{
repeated bytes video_frames = 3;
optional bytes audio_chunk = 4;
}
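# --- Usage sketch (not part of the .proto file above) ------------------------
# Assumes the messages are compiled with protoc into a hypothetical Python
# module named video_pb2, e.g.:  protoc --python_out=. video.proto
import video_pb2

frame_bytes = b"\x00" * 16  # placeholder for one encoded video frame
audio_bytes = b"\x00" * 16  # placeholder for an audio chunk

metadata = video_pb2.Metadata(
    duration=180, video_fps=24, audio_fps=44100, dimension=[720, 1280])
packet = video_pb2.FramePacket(
    video_frames=[frame_bytes], audio_chunk=audio_bytes)
payload = packet.SerializeToString()  # bytes ready to send over the wire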