Created
August 29, 2018 22:33
-
-
Save nvbn/f1365d2548f48fad449bb66d650ad95f to your computer and use it in GitHub Desktop.
Bob's Burgers to The Simpsons with TensorFlow
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from pathlib import Path | |
from typing import NamedTuple | |
from collections import defaultdict | |
from datetime import timedelta | |
from subprocess import call | |
from pycaption.srt import SRTReader | |
import lxml.html | |
import tensorflow as tf | |
import tensorflow_hub as hub | |
import numpy as np | |
# BCP-47 language tag handed to pycaption when reading and selecting captions
lang = 'en-US'
# directory where the generated clip files are written — fill in before running
output_dir = ''
# root directory holding one sub-directory of .srt subtitle files per season — fill in
root = Path('')
class Caption(NamedTuple):
    """A single subtitle cue plus the file it came from."""
    path: str    # source .srt file path
    start: int   # cue start time; microseconds (fed to timedelta(microseconds=...) later)
    length: int  # on-screen duration, same unit as start
    text: str    # plain-text caption content (HTML markup stripped)
def to_text(raw_text):
    """Return the plain-text content of an HTML caption fragment.

    Newlines are flattened to spaces before parsing so multi-line cues
    become a single line; empty or falsy input yields ''.
    """
    if raw_text:
        flattened = ' '.join(raw_text.split('\n'))
        return lxml.html.document_fromstring(flattened).text_content()
    return ''
def _read_subtitles(path, offset=0):
    """Parse the SRT file at *path* and yield Caption tuples.

    A caption's length runs up to the start of the next caption (so the
    generated clips butt up against each other); the final caption falls
    back to its own end time.  *offset* drops that many leading characters
    from the decoded text before parsing (e.g. to skip a BOM).
    """
    with open(path, 'rb') as srt_file:
        data = srt_file.read().decode()[offset:]
        raw_captions = SRTReader().read(data, lang=lang).get_captions(lang)
        total = len(raw_captions)
        for idx, raw in enumerate(raw_captions):
            if idx + 1 < total:
                length = raw_captions[idx + 1].start - raw.start
            else:
                length = raw.end - raw.start
            yield Caption(
                path=path,
                start=raw.start,
                length=length,
                text=to_text(raw.get_text()),
            )
def read_subtitles(path, offset=0):
    """Read all captions from *path*, retrying one character further in on failure.

    Fixes two defects in the original:
    - callers pass ``offset=`` (see the season-scanning loop), but the old
      signature accepted only ``path``, so every call raised TypeError;
    - ``_read_subtitles`` is a generator, so the old ``try`` could never
      catch parse errors — they only surface when the generator is
      consumed, which happened outside the ``try``.  Materializing the
      list here makes the retry actually work.

    Returns a list of Caption tuples (previously a generator; callers
    only iterate/unpack it, so this is backward compatible).
    """
    try:
        return list(_read_subtitles(path, offset))
    except Exception:
        # retry skipping one more leading character (e.g. a stray BOM)
        return list(_read_subtitles(path, offset + 1))
# Group every Simpsons caption by its (cleaned) text.  NOTE: the original
# called read_subtitles(..., offset=1), but read_subtitles takes only a
# path — the resulting TypeError was swallowed by a bare except, so NO
# data was ever collected and every file printed 'pass'.  The invalid
# keyword is dropped and the except narrowed so real parse failures are
# still skipped best-effort but programming errors surface.
data_text2captions = defaultdict(list)
for season in root.glob('*'):
    if not season.is_dir():
        continue
    for subtitles in season.glob('*.srt'):
        print(subtitles)
        try:
            for caption in read_subtitles(subtitles.as_posix()):
                data_text2captions[caption.text].append(caption)
        except Exception:
            # best-effort: skip subtitle files that fail to parse
            print('pass', subtitles)
data_texts = list(data_text2captions)
print('got data texts')
# Tina-rannosaurus Wrecks
# https://www.opensubtitles.org/en/subtitles/5643476/bob-s-burgers-tina-rannosaurus-wrecks-en
# https://www.youtube.com/watch?v=hZ_EKHGgWJQ
# Captions 1..53 of the episode form the "play" we re-enact with Simpsons clips.
play = list(read_subtitles('Bobs.Burgers.S03E07.HDTV.XviD-AFG.srt'))[1:54]
# Group the play captions by text, mirroring data_text2captions above.
play_text2captions = defaultdict(list)
for caption in play:
    play_text2captions[caption.text].append(caption)
play_texts = list(play_text2captions)
print('got play texts')
# Universal Sentence Encoder from TF Hub (TF1-style hub.Module API).
module_url = "https://tfhub.dev/google/universal-sentence-encoder/2"
embed = hub.Module(module_url)
print('got module')
# Graph for batched angular similarity between two sets of embedding vectors;
# shape=None lets the same placeholders take any batch size.
vec_a = tf.placeholder(tf.float32, shape=None)
vec_b = tf.placeholder(tf.float32, shape=None)
# For evaluation we use exactly normalized rather than
# approximately normalized.
normalized_a = tf.nn.l2_normalize(vec_a, axis=1)
normalized_b = tf.nn.l2_normalize(vec_b, axis=1)
# -arccos(cosine similarity): monotone in similarity, so a larger score
# (closer to 0) means a more similar pair.
sim_scores = -tf.acos(tf.reduce_sum(tf.multiply(normalized_a, normalized_b), axis=1))
def get_similarity_score(text_vec_a, text_vec_b):
    """Angular similarity for each aligned pair of vectors in the two batches.

    Feeds both batches through the normalize/arccos graph defined above and
    returns the per-pair scores (larger, i.e. closer to 0, = more similar).
    The original also fetched the normalized intermediates into unused
    locals (``emba``/``embb``); only ``sim_scores`` is fetched now.
    Uses the module-level ``session`` bound by the ``with tf.Session()``
    block below.
    """
    return session.run(
        sim_scores,
        feed_dict={
            vec_a: text_vec_a,
            vec_b: text_vec_b,
        })
def get_most_similar_text(vec_a, data_vectors):
    """Return the data text whose embedding ranks 4th-closest to *vec_a*.

    The query vector is broadcast against every data vector, the indices
    are sorted by descending similarity, and the element at rank 3 is
    picked — NOTE(review): skipping the top three looks deliberate
    (perhaps to avoid near-identical lines) but is undocumented; confirm.
    """
    scores = get_similarity_score([vec_a] * len(data_texts), data_vectors)
    ranked = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    return data_texts[ranked[3]]
with tf.Session() as session:
    # hub.Module creates variables and lookup tables that must both be
    # initialized before the graph can run (TF1-style API).
    session.run([tf.global_variables_initializer(), tf.tables_initializer()])
    # Embed every unique data (Simpsons) text and play (Bob's Burgers) text.
    data_vecs, play_vecs = session.run([embed(data_texts), embed(play_texts)])
    data_vecs = np.array(data_vecs).tolist()
    play_vecs = np.array(play_vecs).tolist()
    print('got vecs')
    # Map each play line to its chosen match among the data lines.
    similar_texts = {play_text: get_most_similar_text(play_vecs[n], data_vecs)
                     for n, play_text in enumerate(play_texts)}
    print('got similarity')
class Part(NamedTuple):
    """Everything needed for one ffmpeg cut of a matched clip."""
    video: str   # source video file (matched .srt path with .mp4 extension)
    start: str   # cut start timestamp, e.g. "0:00:12.345"
    end: str     # NOTE: actually a DURATION (it is passed to ffmpeg -t), not an end time
    output: str  # path the rendered clip is written to
def generate_parts():
    """Yield one Part per play caption, cut from the matched data caption.

    For each play caption, the text match chosen earlier may occur several
    times in the corpus; pick the occurrence whose on-screen duration is
    CLOSEST to the play caption's, so the cut clip fits the dialogue beat.
    (The original used ``sorted(..., reverse=True)[0]``, which selected the
    occurrence with the MOST different duration — an inverted comparison.)
    """
    for n, caption in enumerate(play):
        similar = similar_texts[caption.text]
        # closest-duration occurrence; min() is stable like sorted()[0]
        similar_caption = min(
            data_text2captions[similar],
            key=lambda candidate: abs(caption.length - candidate.length))
        yield Part(
            video=similar_caption.path.replace('.srt', '.mp4'),
            # caption times are microseconds; [:-3] trims to millisecond precision
            start=str(timedelta(microseconds=similar_caption.start))[:-3],
            end=str(timedelta(microseconds=similar_caption.length))[:-3],
            output=Path(output_dir).joinpath(f'part_{n}.mp4').as_posix())
parts = [*generate_parts()]
# Cut each matched clip out of its source video, re-encoding everything to
# the same codecs and framerate so the concat step below can join them.
for part in parts:
    # part.end holds a duration, so -t (length) rather than -to (end time) is correct
    call(['ffmpeg', '-y', '-i', part.video,
          '-ss', part.start, '-t', part.end,
          '-c:v', 'libx264', '-c:a', 'aac', '-strict', 'experimental',
          '-vf', 'fps=30',
          '-b:a', '128k', part.output])
# ffmpeg concat-demuxer playlist: one "file '...'" line per clip, in order.
concat = '\n'.join(f"file '{part.output}'" for part in parts) + '\n'
with open('concat.txt', 'w') as f:
    f.write(concat)
# -safe 0 allows absolute/non-sanitized paths inside concat.txt
call(['ffmpeg', '-y', '-safe', '0', '-f', 'concat', '-i', 'concat.txt',
      '-c:v', 'libx264', '-c:a', 'aac', '-strict', 'experimental',
      '-vf', 'fps=30', 'output.mp4'])
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment