Skip to content
Snippets Groups Projects
Commit 04bb51f7 authored by tomrink's avatar tomrink
Browse files

snapshot...

parent 05cd1d4f
Branches
No related tags found
No related merge requests found
from __future__ import print_function
import io
import socket
import threading
import time
from absl import logging
import argparse
import queue
import numpy as np
import datapacket_pb2
import pyaudio as pya
import pygame
import tensorflow as tf
from pygame.locals import * # pylint: disable=wildcard-import
pygame.init()
class Client(object):
SYN = b'SYN'
SYNACK = b'SYN/ACK'
ACK = b'ACK'
def __init__(self, ip, port):
self._socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self._metadata = datapacket_pb2.Metadata()
self._audio_queue = queue.Queue()
self._video_queue = queue.Queue()
self._audio_thread = threading.Thread(
target=self.write_to_stream, args=(True,))
self._video_thread = threading.Thread(
target=self.write_to_stream, args=(False,))
self._running = False
self.tolerance = 30 # Higher Tolerance Higher Frame Rate
self._lock = threading.Lock()
pyaudio = pya.PyAudio()
self._audio_stream = pyaudio.open(
format=pya.paFloat32,
channels=2,
rate=44100,
output=True,
frames_per_buffer=1024)
self._socket.connect((ip, port))
self.fetch_metadata()
def readpacket(self, buffersize=2**32):
buffer_ = io.BytesIO()
done = False
eof = False
while not done:
data = self._socket.recv(buffersize)
if data:
logging.debug("Reading Stream: Buffer Size: %d" % buffersize)
if data[-5:] == b'<EOF>':
logging.debug("Found EOF")
data = data[:-5]
eof = True
done = True
if data[-5:] == b'<END>':
logging.debug("Find End of Message")
data = data[:-5]
done = True
buffer_.write(data)
buffer_.seek(0)
return buffer_.read(), eof
def fetch_metadata(self):
logging.debug("Sending SYN...")
self._socket.send(b'SYN')
logging.debug("Sent Syn. Awating Metadata")
data, eof = self.readpacket(8)
self._metadata.ParseFromString(data)
dimension = self._metadata.dimension
self.screen = pygame.display.set_mode(dimension[:-1][::-1], 0, 32)
def fetch_video(self):
data, eof = self.readpacket()
framedata = datapacket_pb2.FramePacket()
framedata.ParseFromString(data)
frames = []
for frame in framedata.video_frames:
frames.append(self.parse_frames(frame, False))
self._audio_queue.put(framedata.audio_chunk)
self._video_queue.put(frames)
return eof
def parse_frames(self, bytestring, superresolve=False):
frame = np.asarray(bytestring)
if superresolve:
# Perform super resolution here
pass
frame = tf.cast(tf.clip_by_value(frame, 0, 255), tf.float32)
return frame.numpy()
def start(self):
with self._lock:
self._running = True
if not self._audio_thread.isAlive():
self._audio_thread.start()
if not self._video_thread.isAlive():
self._video_thread.start()
self._socket.send(b'ACK')
while not self.fetch_video():
pass # Wait till the end
self.wait_to_end()
def wait_to_end(self):
self._audio_thread.join()
self._video_thread.join()
def stop(self):
with self.lock:
self._running = False
def write_to_stream(self, isaudio=False):
while self._running:
try:
if isaudio:
if self._audio_queue.qsize() < 5:
continue
audio_chunk = self._audio_queue.get(timeout=10)
self._audio_stream.write(audio_chunk)
else:
if self._video_queue.qsize() < 5:
continue
for video_frame in self._video_queue.get(timeout=10):
video_frame = pygame.surfarray.make_surface(
np.rot90(np.fliplr(video_frame)))
self.screen.fill((0, 0, 2))
self.screen.blit(video_frame, (0, 0))
pygame.display.update()
time.sleep(
(1000 / self._metadata.video_fps - self.tolerance) / 1000)
except StopIteration:
pass
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--server",
default="127.0.0.1",
help="Address of stream server.")
parser.add_argument(
"--port",
default=8001,
help="Port of the server to connect to.")
logging.set_verbosity(logging.DEBUG)
client = Client("127.0.0.1", 8001)
client.start()
import socket
from absl import logging
from functools import partial
import threading
import datapacket_pb2
import tensorflow as tf
from moviepy import editor
class StreamClient(object):
SYN = b'SYN'
SYNACK = b'SYN/ACK'
ACK = b'ACK'
def __init__(self, client_socket, client_address, video_path):
self.video_object = editor.VideoFileClip(video_path)
self.audio_object = self.video_object.audio
dimension = list(self.video_object.get_frame(0).shape)
self.metadata = datapacket_pb2.Metadata(
duration=int(self.video_object.duration),
video_fps=int(self.video_object.fps),
audio_fps=int(self.audio_object.fps),
dimension=dimension)
self._client_socket = client_socket
self._client_address = client_address
self._video_iterator = self.video_object.iter_frames()
self._audio_iterator = self.audio_object.iter_chunks(self.metadata.audio_fps)
self.send_video()
def _handshake(self):
logging.debug("Awaiting Handshake")
data = self._client_socket.recv(128)
if data == StreamClient.SYN:
logging.debug("SYN Recieved. Sending Media Metadata")
num_bytes = self._client_socket.send(self.metadata.SerializeToString()+b'<END>')
logging.debug("Metadata sent: Num Bytes Written: %d" % num_bytes)
logging.debug("Awaiting ACK")
data = self._client_socket.recv(128)
logging.debug("Data Recieved")
if data == StreamClient.ACK:
logging.debug("ACK Recieved")
return True
return data
def _video_second(self):
def shrink_fn(image):
image = tf.convert_to_tensor(image)
return image.numpy().tostring()
frames = []
for _ in range(int(self.metadata.video_fps)):
frames.append(shrink_fn(next(self._video_iterator)))
return frames
def _fetch_video(self):
try:
audio = next(self._audio_iterator).astype("float32").tostring()
video = self._video_second()
frame_packet = datapacket_pb2.FramePacket(
video_frames=video,
audio_chunk=audio)
return frame_packet
except StopIteration:
pass
def send_video(self):
sent = False
while not sent:
handshake = self._handshake()
if handshake:
if handshake is not True:
logging.info("[%s] Says: %s" % (self._client_address, handshake))
else:
video_packet = self._fetch_video()
while video_packet:
num_bytes = self._client_socket.send(
video_packet.SerializeToString() + b'<END>')
video_packet = self._fetch_video()
logging.debug("Sending: %d" % num_bytes)
sent = True
self._client_socket.send(b'<EOF>')
class Server(object):
def __init__(self, ip, port):
self._server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self._server_socket.bind((ip, port))
self._server_socket.listen()
self._client_template = partial(threading.Thread, target=StreamClient)
def run(self):
while True:
client_socket, client_addr = self._server_socket.accept()
client_addr = list(map(str, client_addr))
video_path = "/home/rick/video.mp4"
self._client_template(
args=(client_socket, client_addr, video_path)
).start()
logging.info("[SERVER]: %s connected." % ":".join(client_addr))
if __name__ == "__main__":
logging.set_verbosity(logging.DEBUG)
server = Server("127.0.0.1", 8001)
server.run()
from absl import logging
import argparse
import tensorflow as tf
import tensorflow_hub as hub
import os
from PIL import Image
import multiprocessing
from functools import partial
import time
import pyaudio as pya
import threading
import queue
import numpy as np
from moviepy import editor
import pygame
pygame.init()
os.environ["TFHUB_DOWNLOAD_PROGRESS"] = "True"
BUFFER_SIZE = 8
class Player(object):
def __init__(self, videofile, tflite="", saved_model=""):
"""
Player Class for the Video
Args:
videofile: Path to the video file
tflite: Path to the Super Resolution TFLite
saved_model: path to Super Resolution SavedModel
"""
self.video = editor.VideoFileClip(videofile)
self.audio = self.video.audio
self.tolerance = 2.25 # Higher Tolerance Faster Video
self.running = False
self.interpreter = None
self.saved_model = None
if saved_model:
self.saved_model = hub.load(saved_model)
if tflite:
self.interpreter = tf.lite.Interpreter(model_path=tflite)
self.interpreter.allocate_tensors()
self.input_details = self.interpreter.get_input_details()
self.output_details = self.interpreter.get_output_details()
self.lock = threading.Lock()
self.audio_thread = threading.Thread(target=self.write_audio_stream)
self.video_thread = threading.Thread(target=self.write_video_stream)
self.video_iterator = self.video.iter_frames()
self.audio_iterator = self.audio.iter_chunks(int(self.audio.fps))
self.video_queue = queue.Queue()
self.audio_queue = queue.Queue()
pyaudio = pya.PyAudio()
issmallscreen = 1 if saved_model or tflite else 0.25
self.screen = pygame.display.set_mode(
(int(1280 * issmallscreen),
int(720 * issmallscreen)), 0, 32)
self.stream = pyaudio.open(
format=pya.paFloat32,
channels=2,
rate=44100,
output=True,
frames_per_buffer=1024)
def tflite_super_resolve(self, frame):
"""
Super Resolve bicubically downsampled image frames
using the TFLite of the model.
Args:
frame: Image frame to scale up.
"""
self.interpreter.set_tensor(self.input_details[0]['index'], frame)
self.interpreter.invoke()
frame = self.interpreter.get_tensor(self.output_details[0]['index'])
frame = tf.squeeze(tf.cast(tf.clip_by_value(frame, 0, 255), "uint8"))
return frame.numpy()
def saved_model_super_resolve(self, frame):
"""
Super Resolve using exported SavedModel.
Args:
frames: Batch of Frames to Scale Up.
"""
if self.saved_model:
start = time.time()
frame = self.saved_model.call(frame)
logging.debug("[SAVED_MODEL] Super Resolving Time: %f" % (time.time() - start))
logging.debug("Returning Modified Frames")
return np.squeeze(np.clip(frame.numpy(), 0, 255).astype("uint8"))
def video_second(self):
"""
Fetch Video Frames for each second
and super resolve them accordingly.
"""
frames = []
logging.debug("Fetching Frames")
start = time.time()
loop_time = time.time()
for _ in range(int(self.video.fps)):
logging.debug("Fetching Video Frame. %f" % (time.time() - loop_time))
loop_time = time.time()
frame = next(self.video_iterator)
frame = np.asarray(
Image.fromarray(frame)
.resize(
[1280 // 4, 720 // 4],
Image.BICUBIC), dtype="float32")
frames.append(tf.expand_dims(frame, 0))
logging.debug("Frame Fetching Time: %f" % (time.time() - start))
if self.interpreter and not self.saved_model:
resolution_fn = self.tflite_super_resolve
else:
resolution_fn = self.saved_model_super_resolve
start = time.time()
with multiprocessing.pool.ThreadPool(30) as pool:
frames = pool.map(resolution_fn, frames)
logging.debug("Fetched Frames. Time: %f" % (time.time() - start))
return frames
def fetch_video(self):
"""
Fetches audio and video frames from the file.
And put them in player cache.
"""
audio = next(self.audio_iterator)
video = self.video_second()
self.audio_queue.put(audio)
self.video_queue.put(video)
def write_audio_stream(self):
"""
Write Audio Frames to default audio device.
"""
try:
while self.audio_queue.qsize() < BUFFER_SIZE:
continue
while self.running:
audio = self.audio_queue.get(timeout=10)
self.stream.write(audio.astype("float32").tostring())
except BaseException:
raise
def write_video_stream(self):
"""
Write Video frames to the player display.
"""
try:
while self.video_queue.qsize() < BUFFER_SIZE:
continue
while self.running:
logging.info("Displaying Frame")
for video_frame in self.video_queue.get(timeout=10):
video_frame = pygame.surfarray.make_surface(
np.rot90(np.fliplr(video_frame)))
self.screen.fill((0, 0, 2))
self.screen.blit(video_frame, (0, 0))
pygame.display.update()
time.sleep((1000 / self.video.fps - self.tolerance) / 1000)
except BaseException:
raise
def run(self):
"""
Start the player threads and the frame streaming simulator.
"""
with self.lock:
if not self.running:
self.running = True
self.audio_thread.start()
self.video_thread.start()
for _ in range(int(self.video.duration)):
logging.debug("Fetching Video")
self.fetch_video()
time.sleep(0.1)
with self.lock:
if not self.running:
self.running = True
self.audio_thread.join()
self.video_thread.join()
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"-v", "--verbose",
action="count",
default=0,
help="Increases Verbosity of Logging")
parser.add_argument(
"--file",
default=None,
help="File to play")
parser.add_argument(
"--tflite",
default="",
help="Path to TFLite File")
parser.add_argument(
"--saved_model",
default="",
help="Path to Saved Model File")
FLAGS, unknown_args = parser.parse_known_args()
log_levels = [logging.FATAL, logging.WARNING, logging.INFO, logging.DEBUG]
current_log_level = log_levels[min(len(log_levels) - 1, FLAGS.verbose)]
logging.set_verbosity(current_log_level)
player = Player(
videofile=FLAGS.file,
saved_model=FLAGS.saved_model,
tflite=FLAGS.tflite)
player.run()
MIT License
Copyright (c) 2020 Adrish Dey
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
# GSOC
Repository for Google Summer of Code 2019 at TensorFlow
---------------------------------------------
### GSOC Project Page
https://summerofcode.withgoogle.com/projects/#4662790671826944
### Mentors
- Sachin Joglekar ([@srjoglekar246](https://github.com/srjoglekar246))
- Vojtech Bardiovsky ([@vbardiovskyg](https://github.com/vbardiovskyg))
### Blogs Written
- Training TF2.0 models on TPUs: https://github.com/captain-pool/GSOC/wiki/Training-TF-2.0-Models-on-TPUs
### Tasks
|Evaluation|Task|Link|Done|
|:-:|:-:|:-:|:-:|
|E1|Sample TF Hub Module Deploy.|[Here](E1_TFHub_Sample_Deploy)| :heavy_check_mark: |
|E1|Image Retraining with TF Hub, TF 2.0 and Cloud TPU|[Here](E1_TPU_Sample)| :heavy_check_mark: |
|E1|Convert ShuffleNet from ONNX to SavedModel 2.0|[Here](E1_ShuffleNet)| [:warning:](https://github.com/captain-pool/GSOC/issues/3) |
|E2|Train ESRGAN Model from Scratch and Export as TF Hub Module|[Here](E2_ESRGAN)|:heavy_check_mark:|
|E2|Adding Support to SavedModel 2.0 in `saved_model_cli`|[Here](https://github.com/tensorflow/tensorflow/pull/30752)|:heavy_check_mark:|
|E3|Add Sample Notebook demonstrating usage of ESRGAN TF Hub Module|[Here](https://www.tensorflow.org/hub/tutorials/image_enhancing)|:heavy_check_mark:|
|E3|Knowledge Distillation of ESRGAN|[Here](E3_Distill_ESRGAN)|:heavy_check_mark:|
|E3| Proof of concept video player for real time video frame Super Resolution|[Here](E3_Streamer)|:heavy_check_mark:|
## UPDATE
ESRGAN just got published on [tfhub.dev](https://tfhub.dev)
Link: [https://tfhub.dev/captain-pool/esrgan-tf2/1](https://tfhub.dev/captain-pool/esrgan-tf2/1)
# Copyright 2018 The TensorFlow Hub Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
workspace(name = "gsoc")
load("@bazel_tools//tools/build_defs/repo:git.bzl", "git_repository")
load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
git_repository(
name = "org_tensorflow_hub",
commit = "eeefe60480d8471b1662259f8eb5a965eba13d18",
remote = "https://github.com/tensorflow/hub.git",
)
bind(
name = "tensorflow_hub",
actual = "@org_tensorflow_hub//tensorflow_hub",
)
git_repository(
name = "protobuf_bzl",
# v3.6.1.3
commit = "66dc42d891a4fc8e9190c524fd67961688a37bbe",
remote = "https://github.com/google/protobuf.git",
)
bind(
name = "protobuf",
actual = "@protobuf_bzl//:protobuf",
)
bind(
name = "protobuf_python",
actual = "@protobuf_bzl//:protobuf_python",
)
bind(
name = "protobuf_python_genproto",
actual = "@protobuf_bzl//:protobuf_python_genproto",
)
bind(
name = "protoc",
actual = "@protobuf_bzl//:protoc",
)
# Using protobuf version 3.6.1.3
http_archive(
name = "com_google_protobuf",
strip_prefix = "protobuf-3.6.1.3",
urls = ["https://github.com/google/protobuf/archive/v3.6.1.3.zip"],
)
# required by protobuf_python
http_archive(
name = "six_archive",
build_file = "@protobuf_bzl//:six.BUILD",
sha256 = "105f8d68616f8248e24bf0e9372ef04d3cc10104f1980f54d57b2ce73a5ad56a",
url = "https://pypi.python.org/packages/source/s/six/six-1.10.0.tar.gz#md5=34eed507548117b2ab523ab14b2f8b55",
)
bind(
name = "six",
actual = "@six_archive//:six",
)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please to comment