subtitle inpaint remove — locate hard-coded subtitles with PaddleOCR, export them as SRT, and erase them from the frames with OpenCV inpainting.
import typing

import cv2 as cv
import numpy as np

from galipa import Rect


def contour_remove(image: 'cv.typing.MatLike',
                   chunk: 'Rect',
                   white_text: 'bool' = True) -> 'cv.typing.MatLike':
    """Erase text inside `chunk` by thresholding it into a mask and inpainting the largest contour."""
    image_h, image_w, image_c = image.shape
    assert image_c == 3, 'input must be BGR'
    mask = np.zeros((image_h, image_w), dtype=np.uint8)
    k = cv.getStructuringElement(cv.MORPH_RECT, (7, 5))
    r_img = cv.cvtColor(image, cv.COLOR_BGR2GRAY)[chunk.top:chunk.bottom, chunk.left:chunk.right]
    # threshold bright pixels for white text, dark pixels (via inversion) otherwise
    if white_text:
        _, r_img = cv.threshold(r_img, 200, 255, cv.THRESH_BINARY)
    else:
        _, r_img = cv.threshold(255 - r_img, 200, 255, cv.THRESH_BINARY)
    # close gaps between glyphs, then fatten the strokes
    r_img = cv.morphologyEx(r_img, cv.MORPH_CLOSE, k)
    r_img = cv.dilate(r_img, k, iterations=3)
    contours, _ = cv.findContours(r_img, cv.RETR_EXTERNAL, cv.CHAIN_APPROX_NONE)
    if len(contours) > 0:
        # keep only the bounding box of the largest contour
        c = max(contours, key=cv.contourArea)
        x, y, w, h = cv.boundingRect(c)
        mask[chunk.top + y:chunk.top + y + h, chunk.left + x:chunk.left + x + w] = r_img[y:y + h, x:x + w]
    else:
        mask[chunk.top:chunk.bottom, chunk.left:chunk.right] = r_img
    mask = cv.erode(mask, k, iterations=1)
    mask = cv.GaussianBlur(mask, (7, 7), 0)
    return cv.inpaint(image, mask, 3, cv.INPAINT_TELEA)


def inpaint_remove(image: 'cv.typing.MatLike',
                   chunks: 'typing.Iterable[Rect]') -> 'cv.typing.MatLike':
    """Inpaint every `chunk` rectangle, padded by 10px on each side, in a single pass."""
    image_h, image_w, image_c = image.shape
    assert image_c == 3, 'input must be BGR'
    mask = np.zeros((image_h, image_w), dtype=np.uint8)
    for chunk in chunks:
        # clamp the padded top/left at 0 so negative indices don't wrap around
        mask[max(chunk.top - 10, 0):chunk.bottom + 10, max(chunk.left - 10, 0):chunk.right + 10].fill(255)
    return cv.inpaint(image, mask, 5, cv.INPAINT_TELEA)
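
A minimal usage sketch for the module above, assuming a single BGR frame loaded with OpenCV; the file paths and rectangle coordinates are hypothetical:

import cv2 as cv
from galipa import Rect
from galipa.subtitle.remover.inpaint import contour_remove, inpaint_remove

frame = cv.imread('frame.png')  # hypothetical input frame (BGR)
subtitle_box = Rect(left=200, top=960, width=880, height=80)  # hypothetical subtitle region
# contour-based removal: masks only the thresholded text pixels inside the box
cleaned = contour_remove(frame, subtitle_box, white_text=True)
# rectangle-based removal: inpaints the whole padded box
cleaned = inpaint_remove(frame, [subtitle_box])
cv.imwrite('frame.clean.png', cleaned)

The second file of the gist, below, drives these removers from OCR results.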
import io
import math
import os.path
import typing
from collections import defaultdict
from dataclasses import dataclass, field

import cv2 as cv
import numpy as np
from paddleocr import PaddleOCR
from scipy.cluster.vq import kmeans2
from textdistance import levenshtein
from zhconv import convert

from galipa import Rect, VideoWatermark, clip
from galipa.codec import reconvert_video
from galipa.subtitle.extractor import BaseExtractor, SrtChunk, SrtLine, remove_inner_empty
from galipa.subtitle.remover.inpaint import inpaint_remove

__all__ = [
    'OcrSrtExtractor',
    'OcrSrtError',
]

K_MEANS_CLUSTER_COUNT = 10

# noinspection PyUnresolvedReferences
FOURCC_MP4V = cv.VideoWriter_fourcc(*'mp4v')
class OcrSrtExtractor(BaseExtractor):
    def __init__(self, video_path: 'str',
                 lang: 'str' = 'ch',
                 clean_video_path: 'typing.Optional[str]' = None,
                 sample_frame_count: 'int' = 64,
                 use_tolerance_mask: 'bool' = False,
                 debug: 'bool' = False,
                 watermarks: 'typing.Optional[typing.List[VideoWatermark]]' = None,
                 reconvert: 'bool' = True):
        self.debug = debug
        if reconvert:
            reconvert_path = os.path.join(
                os.path.dirname(video_path),
                f'{os.path.basename(video_path)}.reconvert.mp4'
            )
            if not os.path.isfile(reconvert_path) and not reconvert_video(video_path, reconvert_path):
                raise RuntimeError("can't reconvert input video")
            capture_path = reconvert_path
        else:
            capture_path = video_path
        self.capture = cv.VideoCapture(capture_path)
        if not self.capture.isOpened():
            raise OcrSrtError(f"can't open video {capture_path}")
        self.fps = self.capture.get(cv.CAP_PROP_FPS)
        self.total = int(self.capture.get(cv.CAP_PROP_FRAME_COUNT))
        self.width = int(self.capture.get(cv.CAP_PROP_FRAME_WIDTH))
        self.height = int(self.capture.get(cv.CAP_PROP_FRAME_HEIGHT))
        if clean_video_path is None:
            self.output = None
        else:
            self.output = cv.VideoWriter(clean_video_path,
                                         FOURCC_MP4V,  # use mp4v for compatibility
                                         self.fps,
                                         (self.width, self.height))
        self.processed_frame = 0
        self.chunk_text_size = 0
        self.frames = []  # type: typing.List[Frame]
        self.sample_frame_count = sample_frame_count
        self.ocr_engine = PaddleOCR(use_gpu=True, use_angle_cls=True, lang=lang, show_log=False)
        # scale the maximum text height up for videos taller than 1080p
        self.max_text_size = 96
        if self.height > 1080:
            self.max_text_size = int(self.max_text_size * self.height / 1080)
        self.text_height_min = 14
        self.text_height_max = self.max_text_size
        # default: search the bottom half of the frame for subtitles
        self.mid_y_range = [range(self.height // 2, self.height)]  # type: typing.List[range]
        self.use_tolerance_mask = use_tolerance_mask
        self.tolerance_mask = []  # type: typing.List[Rect]
        self.guess_srt_pos = None
        if watermarks is None:
            self.watermarks = []
        else:
            self.watermarks = watermarks
        self.preprocess_data = None
        self.cache_pos = 0  # type: int
        self.cache_data = None  # type: typing.Optional[OcrData]
    def debug_skip_preprocess(self):
        if not self.debug:
            return
        if self.preprocess_data is not None:
            return
        self.preprocess_data = OcrData()

    def preprocess(self):
        """Sample frames across the video to estimate subtitle text height and vertical position."""
        if self.preprocess_data is not None:
            return
        data = OcrData()
        for pos in range(0, self.total, max(self.total // self.sample_frame_count, 1)):
            self.frame_pos = pos
            ok, image = self.capture.read()
            if not ok:
                raise EOFError()
            frame_data = OcrData()
            frame_time = pos / self.fps
            # ocr bottom half
            enhanced_image_to_data(image[self.height // 2:], frame_data, self.ocr_engine)
            # shift coordinates back to full-frame space
            for w in frame_data.words:
                w.top += self.height // 2
            for word in frame_data.words:
                word_is_watermark = False
                for watermark in self.watermarks:
                    if not watermark.present_at(frame_time):
                        continue
                    if word.region.has_overlap(watermark.rect):
                        word_is_watermark = True
                        break
                # '备案号' marks an ICP filing notice, not subtitle text
                if not word_is_watermark and '备案号' in word.text:
                    word_is_watermark = True
                if not word_is_watermark:
                    data.words.append(word)
        self.preprocess_data = data
        words = data.words
        if self.debug:
            print(f'sample word {words}')
        # cluster word heights; take the dominant clusters as the subtitle text height
        height = np.array([word.height for word in words], dtype=np.float32)
        if len(height) > K_MEANS_CLUSTER_COUNT:
            height_center, height_id = kmeans2(height, K_MEANS_CLUSTER_COUNT)
            if self.debug:
                print(f' center {height_center}')
                print(f' id {height_id}')
            height = [height_center[v] for v in find_top_id(0.8, height_id)]
            self.text_height_min = min(height) * 0.8
            self.text_height_max = max(height) * 1.2
        # filter words by height
        words = [word for word in words if self.text_height_min <= word.height <= self.text_height_max]
        text_left = int(self.width * 0.2)
        text_right = int(self.width * 0.8)
        for word in words:
            text_left = min(text_left, word.left)
            text_right = max(text_right, word.right)
        text_width = text_right - text_left
        # cluster vertical centers to find the subtitle line positions
        mid_y = np.array([word.mid_y for word in words], dtype=np.float32)
        if len(mid_y) > K_MEANS_CLUSTER_COUNT:
            mid_y_center, mid_y_id = kmeans2(mid_y, K_MEANS_CLUSTER_COUNT)
            delta = min(self.text_height_max, self.max_text_size) * 0.5
            mid_y_range = []
            tolerance_mask = []
            guess_mid_y_sum = 0
            max_mid_y_distance = delta * 3
            for mid_y in (mid_y_center[v] for v in find_top_id(0.8, mid_y_id)):
                # drop clusters too far from every range found so far
                if len(mid_y_range) > 0 and not any(
                        abs(r.start + (r.stop - r.start) * 0.5 - mid_y) < max_mid_y_distance for r in mid_y_range):
                    continue
                mid_y_range.append(range(int(mid_y - delta), int(mid_y + delta)))
                tolerance_mask.append(Rect(
                    left=text_left,
                    top=int(mid_y - self.text_height_max * 0.6),
                    width=text_width,
                    height=int(self.text_height_max * 1.2),
                ))
                guess_mid_y_sum += mid_y
            if len(mid_y_range) > 0:
                self.tolerance_mask = tolerance_mask
                self.mid_y_range = mid_y_range
                self.guess_srt_pos = guess_mid_y_sum / len(mid_y_range)
    def postprocess(self):
        self.capture.release()
        if self.output is not None:
            self.output.release()

    @property
    def total_duration(self) -> 'float':
        return self.total / self.fps

    @property
    def processed_duration(self) -> 'float':
        return self.processed_frame / self.fps

    @processed_duration.setter
    def processed_duration(self, value: 'float'):
        self.processed_frame = int(math.ceil(value * self.fps))

    @property
    def ocr_top(self) -> 'int':
        return clip(int(min(r.start for r in self.mid_y_range) - self.text_height_max * 2), 0, self.height)

    @property
    def ocr_bottom(self) -> 'int':
        return clip(int(max(r.stop for r in self.mid_y_range) + self.text_height_max * 2), 0, self.height)
    @property
    def chunk(self) -> 'SrtChunk':
        """Collapse runs of identical frames into timed SRT lines."""
        if self.guess_srt_pos is None:
            chunk = SrtChunk()
        else:
            chunk = SrtChunk(comment=f'pos={int(self.guess_srt_pos)}')
        cache = None
        prev_index = 0
        if self.output is None:
            # no clean video output: frames are sampled every 0.3s, so pad each line by that step
            srt_extend = 0.3
        else:
            # clean video output: processed frame by frame, no padding needed
            srt_extend = 0
        for frame in self.frames:
            if cache is None:
                cache = frame
                prev_index = frame.frame_index
                continue
            if cache == frame:
                prev_index = frame.frame_index
                continue
            if len(cache.blocks) > 0:
                chunk.lines.append(SrtLine(
                    text=remove_inner_empty('\n'.join(b.text for b in cache.blocks if len(b.text.strip()))),
                    start=max(cache.frame_index / self.fps - srt_extend, 0),
                    end=max(prev_index / self.fps, cache.frame_index / self.fps + srt_extend),
                ))
            cache = frame
            prev_index = frame.frame_index
        # flush the last cached run
        if cache is not None and len(cache.blocks) > 0:
            chunk.lines.append(SrtLine(
                text=remove_inner_empty('\n'.join(b.text for b in cache.blocks if len(b.text.strip()))),
                start=max(cache.frame_index / self.fps - srt_extend, 0),
                end=max(prev_index / self.fps, cache.frame_index / self.fps + srt_extend),
            ))
        return chunk
    def process_step(self):
        # decode one frame
        self.frame_pos = self.processed_frame
        ok, frame = self.capture.read()
        if not ok:
            raise EOFError()
        # ocr
        self.ocr_frame(frame, self.processed_frame)
        # forward: step 0.3s when only extracting, frame by frame when writing a clean video
        if self.output is None:
            self.processed_frame += max(int(self.fps * 0.3), 1)
        else:
            self.processed_frame += 1
    def ocr_frame(self, frame, index):
        # basic image info
        video_height, _, _ = frame.shape
        # run ocr on the estimated subtitle band only
        ocr_data = OcrData()
        enhanced_image_to_data(frame[self.ocr_top:self.ocr_bottom], ocr_data, self.ocr_engine)
        # shift coordinates back to full-frame space
        for w in ocr_data.words:
            w.top += self.ocr_top
        # filter
        if self.debug:
            print(f'frame = {self.processed_frame}')
            print(f' word = {ocr_data.words}')
        ocr_data.words = [word for word in ocr_data.words if self.is_valid_word(word)]
        if self.debug:
            print(f' used = {ocr_data.words}')
        # reuse the previous frame's result when this frame's words are a subset of it
        if self.cache_data is not None and self.cache_pos == index - 1 and ocr_data.is_subset_of(self.cache_data):
            ocr_data = self.cache_data
        else:
            self.cache_data = ocr_data
        self.cache_pos = index
        # merge words into text blocks
        if len(ocr_data.words) > 0:
            ocr_frame = Frame(frame_index=index)
            merge_frame_data(ocr_data, ocr_frame, video_height)
            if len(ocr_frame.blocks) > 0:
                self.frames.append(ocr_frame)
        # remove subtitle & watermark
        if self.output is not None:
            masks = []
            for watermark in self.watermarks:
                if not watermark.present_at(self.processed_duration):
                    continue
                masks.append(watermark.rect)
            if self.use_tolerance_mask:
                masks += self.tolerance_mask
            ocr_masks = [
                Rect(left=word.left, top=word.top, width=word.width, height=word.height)
                for word in ocr_data.words
            ]
            masks += ocr_masks
            if self.debug:
                # debug mode: draw mask borders instead of inpainting
                for watermark in self.watermarks:
                    if not watermark.present_at(self.processed_duration):
                        continue
                    watermark.rect.draw_border(frame, [0x00, 0x00, 0xff], width=1)
                if self.use_tolerance_mask:
                    for mask in self.tolerance_mask:
                        mask.draw_border(frame, [0x00, 0xff, 0x00], width=2)
                for mask in ocr_masks:
                    mask.draw_border(frame, [0x99, 0x66, 0x33], width=4)
            elif len(masks) > 0:
                frame = inpaint_remove(frame, masks)
            self.output.write(frame)
    def is_valid_word(self, word: 'OcrWord') -> 'bool':
        if '备案号' in word.text:  # ICP filing notice, never subtitle text
            return False
        if word.height < self.text_height_min:
            return False
        if word.height > self.text_height_max:
            return False
        for watermark in self.watermarks:
            if not watermark.present_at(self.processed_duration):
                continue
            if word.region.has_overlap(watermark.rect):
                return False
        for r in self.mid_y_range:
            if word.mid_y in r:
                return True
        return False

    @property
    def frame_pos(self) -> 'int':
        return int(self.capture.get(cv.CAP_PROP_POS_FRAMES))

    @frame_pos.setter
    def frame_pos(self, value: 'int'):
        # seek only when needed; CAP_PROP_POS_FRAMES seeks are expensive
        if int(self.capture.get(cv.CAP_PROP_POS_FRAMES)) != value:
            self.capture.set(cv.CAP_PROP_POS_FRAMES, value)

    @property
    def debug_info(self) -> 'str':
        with io.StringIO() as summary:
            print(f'preprocess data={self.preprocess_data}', file=summary)
            print(f'max_text_size={self.max_text_size}', file=summary)
            print(f'height_max={self.text_height_max}', file=summary)
            print(f'height_min={self.text_height_min}', file=summary)
            for r in self.mid_y_range:
                print(f'range ({r.start}, {r.stop})', file=summary)
            for frame in self.frames:
                print(f'Frame {frame.frame_index}', file=summary)
                for block in frame.blocks:
                    print(f' {block.y:.1f}: {block.text}', file=summary)
            return summary.getvalue()
class OcrSrtError(Exception):
    def __init__(self, message: 'str'):
        self.message = message

    def __str__(self):
        return f'OcrSrtError({self.message})'

    def __repr__(self):
        return f'OcrSrtError({self.message})'
@dataclass
class OcrData:
    words: 'typing.List[OcrWord]' = field(default_factory=list)

    def is_subset_of(self, data: 'OcrData') -> 'bool':
        if len(self.words) == 0:  # empty data is never a subset
            return False
        if len(self.words) > len(data.words):
            return False
        for word in self.words:
            if not any(word.is_same_word(w) for w in data.words):
                return False
        return True

    @staticmethod
    def from_raw(raw,
                 offset_left: 'int' = 0,
                 offset_top: 'int' = 0) -> 'OcrData':
        data = OcrData()
        for i in range(0, len(raw['text'])):
            text = raw['text'][i].strip()  # type: str
            if len(text) == 0:
                continue
            data.words.append(OcrWord(
                text=text,
                left=raw['left'][i] + offset_left,
                top=raw['top'][i] + offset_top,
                width=raw['width'][i],
                height=raw['height'][i],
            ))
        return data

    def __len__(self):
        return len(self.words)

    def __add__(self, other: 'OcrData') -> 'OcrData':
        # merge two word lists, dropping near-duplicates (same text within 10px)
        data = OcrData()
        for word in self.words + other.words:
            is_dup = False
            for current in data.words:
                if word.text == current.text and abs(word.top - current.top) < 10 and abs(
                        word.left - current.left) < 10:
                    is_dup = True
                    break
            if not is_dup:
                data.words.append(word)
        return data

    def height_array(self, indices=None, dtype: 'np.dtype' = np.float32) -> 'np.ndarray':
        if indices is None:
            indices = range(0, len(self))
        return np.array([self.words[i].height for i in indices], dtype=dtype)

    def x_array(self, indices=None, dtype: 'np.dtype' = np.float32) -> 'np.ndarray':
        if indices is None:
            indices = range(0, len(self))
        return np.array([self.words[i].left for i in indices], dtype=dtype)

    def mid_y_array(self, indices=None, dtype: 'np.dtype' = np.float32) -> 'np.ndarray':
        if indices is None:
            indices = range(0, len(self))
        return np.array([self.words[i].mid_y for i in indices], dtype=dtype)
@dataclass
class OcrWord:
    text: 'str'
    left: 'int'
    top: 'int'
    width: 'int'
    height: 'int'

    def is_same_word(self, word: 'OcrWord') -> 'bool':
        # same word = near-identical text (edit distance <= 3) at nearly the same position (within 5px)
        if levenshtein(self.text, word.text) > 3:
            return False
        if abs(self.left - word.left) > 5:
            return False
        if abs(self.top - word.top) > 5:
            return False
        if abs(self.right - word.right) > 5:
            return False
        if abs(self.bottom - word.bottom) > 5:
            return False
        return True

    @property
    def right(self) -> 'int':
        return self.left + self.width

    @property
    def bottom(self) -> 'int':
        return self.top + self.height

    @property
    def mid_y(self) -> 'int':
        return self.top + self.height // 2

    @property
    def region(self) -> 'Rect':
        return Rect(left=self.left, top=self.top, width=self.width, height=self.height)
@dataclass
class Frame:
    frame_index: 'int'
    blocks: 'typing.List[Block]' = field(default_factory=list)

    def __eq__(self, other: 'Frame') -> 'bool':
        if len(self.blocks) != len(other.blocks):
            return False
        for a, b in zip(self.blocks, other.blocks):  # type: Block, Block
            if a.text != b.text:
                return False
            if abs(a.y - b.y) > 96:
                return False
        return True


@dataclass
class Block:
    y: 'float'
    text: 'str'
def enhanced_image_to_data(input_image, output_data: 'OcrData', ocr: 'PaddleOCR'):
    r = ocr.ocr(input_image, cls=False)
    if r is None or len(r) == 0 or r[0] is None:
        return
    for [box, [text, conf]] in r[0]:
        # confidence check
        if conf < 0.75:
            continue
        # empty text check
        text = text.strip()
        if len(text) == 0:
            continue
        # box to numpy array
        box = np.array(box, dtype=np.int32)
        # horizontal layout check: top and bottom edges must be nearly level
        if abs(box[0, 1] - box[1, 1]) > 8:
            continue
        if abs(box[2, 1] - box[3, 1]) > 8:
            continue
        # vertical layout check: left and right edges must be nearly vertical
        if abs(box[0, 0] - box[3, 0]) > 8:
            continue
        if abs(box[1, 0] - box[2, 0]) > 8:
            continue
        # save output, normalizing the text to simplified Chinese
        left = box[:, 0].min()
        right = box[:, 0].max()
        top = box[:, 1].min()
        bottom = box[:, 1].max()
        output_data.words.append(OcrWord(
            text=convert(text, 'zh-cn').strip(),
            left=int(left),
            top=int(top),
            width=int(right - left),
            height=int(bottom - top),
        ))
def merge_frame_data(data: 'OcrData', frame: 'Frame', video_height: 'int'):
    min_text_size = 28
    max_text_size = 96
    if video_height > 1080:
        max_text_size = int(max_text_size * video_height / 1080)
    indices = [i for i in range(0, len(data)) if min_text_size <= data.words[i].height <= max_text_size]
    if len(indices) == 0:
        return
    # cluster words by vertical center to find subtitle lines
    if len(indices) > K_MEANS_CLUSTER_COUNT:
        mid_y_center, mid_y_id = kmeans2(data.mid_y_array(indices), K_MEANS_CLUSTER_COUNT)
    else:
        mid_y_center = data.mid_y_array(indices)
        mid_y_id = list(range(0, len(indices)))
    for cluster in merge_mid_y_id(mid_y_center, mid_y_id, max_text_size):
        y = mid_y_center[cluster[0]]
        words = [data.words[i] for i, mid in zip(indices, mid_y_id) if mid in cluster]
        words = sorted(words, key=lambda w: w.left)
        text = ''
        right = 0
        for i, w in enumerate(words):  # type: int, OcrWord
            # insert a space when the horizontal gap exceeds ~30% of the text height
            if i > 0 and w.left > right + w.height * 0.3:
                text += ' '
            text += w.text
            right = w.right
        frame.blocks.append(Block(text=text, y=float(y)))
def find_top_id(pick_ratio: 'float', ary: 'typing.Iterable[int]') -> 'typing.List[int]':
    """Return the values of `ary`, most frequent first, until they cover at least
    `pick_ratio` of all samples.

    e.g. find_top_id(0.8, [0, 0, 0, 1, 1, 2]) == [0, 1]
    """
    cnt = defaultdict(int)
    total = 0
    for v in ary:
        cnt[v] += 1
        total += 1
    to_pick = int(total * pick_ratio)
    picked = []
    sorted_items = sorted(cnt.items(), key=lambda pair: pair[1])
    pick_pos = len(sorted_items) - 1
    while to_pick > 0 and pick_pos >= 0:
        max_id, max_id_count = sorted_items[pick_pos]
        picked.append(max_id)
        to_pick -= max_id_count
        pick_pos -= 1
    return picked
def merge_mid_y_id(center, ids, max_diff: 'float') -> 'typing.List[typing.List[int]]':
    """Group cluster ids whose centers lie within `max_diff` of the group's first
    member, scanning in ascending center order."""
    ids = set(ids)  # type: typing.Set[int]
    merged = []
    while len(ids) > 0:
        min_id = min(ids, key=lambda i: center[i])
        ids.remove(min_id)
        if len(merged) > 0 and abs(center[merged[-1][0]] - center[min_id]) < max_diff:
            merged[-1].append(min_id)
        else:
            merged.append([min_id])
    return merged
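
A minimal driving loop for OcrSrtExtractor, based only on the methods defined above (preprocess, process_step, chunk, postprocess); whatever runner BaseExtractor normally supplies is bypassed here, and the file paths are hypothetical:

extractor = OcrSrtExtractor('input.mp4', lang='ch', clean_video_path='input.clean.mp4')
extractor.preprocess()  # sample frames to estimate subtitle height and vertical position
try:
    while True:
        extractor.process_step()  # OCR one frame, inpaint it, advance
except EOFError:
    pass  # process_step raises EOFError once the capture runs out of frames
extractor.postprocess()  # release capture and writer
srt_chunk = extractor.chunk  # SrtChunk with deduplicated, timed SrtLine entries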