Extremely simple and naive video captioning with MiniCPM-V-2_6 and Llama-3.1-8B-Instruct
import argparse
import os
import random

import numpy as np
import pandas as pd
import torch
import transformers
from decord import VideoReader, cpu
from transformers import AutoModel, AutoTokenizer
from PIL import Image

def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--instance_data_root", type=str, required=True, help="Base folder where video files are located."
    )
    parser.add_argument("--cache_dir", type=str, default=None, help="Path to where models are stored.")
    parser.add_argument(
        "--output_path", type=str, default="video_dataset.csv", help="File path where the dataset CSV should be stored."
    )
    parser.add_argument("--max_new_tokens", type=int, default=226, help="Maximum number of new tokens to generate.")
    parser.add_argument("--seed", type=int, default=42, help="Seed for reproducibility.")
    return parser.parse_args()

SYSTEM_PROMPT = """
You are part of a team of people that create videos using generative models. You use a video-generation model that can generate a video about anything you describe.
For example, if you respond with "A beautiful morning in the woods with the sun peeking through the trees", the video generation model will create a video exactly as described. Your task is to summarize the descriptions of videos provided to you by users, and create detailed prompts to feed into the generative model.
There are a few rules to follow:
- You will only ever output a single video description per request.
- If the user mentions to summarize the prompt in [X] words, make sure not to exceed the limit.
Your responses should just be the video generation prompt. Here are examples:
- "A detailed wooden toy ship with intricately carved masts and sails is seen gliding smoothly over a plush, blue carpet that mimics the waves of the sea. The ship's hull is painted a rich brown, with tiny windows. The carpet, soft and textured, provides a perfect backdrop, resembling an oceanic expanse. Surrounding the ship are various other toys and children's items, hinting at a playful environment. The scene captures the innocence and imagination of childhood, with the toy ship's journey symbolizing endless adventures in a whimsical, indoor setting."
- "A street artist, clad in a worn-out denim jacket and a colorful bandana, stands before a vast concrete wall in the heart of the city, holding a can of spray paint, spray-painting a colorful bird on a mottled wall."
""".strip()

USER_PROMPT = """
Could you generate a prompt for a video generation model for the following video summary:
```
{0}
```
Please limit the prompt to [{1}] words.
""".strip()

QUESTION = """
Describe the video. You should pay close attention to every detail in the video and describe it in as much detail as possible.
""".strip()

MAX_NUM_FRAMES = 49

def set_seed(seed: int):
    """
    Helper function for reproducible behavior: sets the seed in `random`, `numpy` and `torch`.

    Args:
        seed (`int`): The seed to set.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

def encode_video(video_path: str):
    def uniform_sample(l, n):
        # Choose `n` indices spread evenly across the list, centered within each bucket.
        gap = len(l) / n
        idxs = [int(i * gap + gap / 2) for i in range(n)]
        return [l[i] for i in idxs]

    vr = VideoReader(video_path, ctx=cpu(0))
    # Sample one frame per half second of video (step of fps // 2 frames), capped at MAX_NUM_FRAMES frames total.
    sample_fps = round(vr.get_avg_fps())
    sample_fps //= 2
    frame_idx = [i for i in range(0, len(vr), sample_fps)]
    if len(frame_idx) > MAX_NUM_FRAMES:
        frame_idx = uniform_sample(frame_idx, MAX_NUM_FRAMES)
    frames = vr.get_batch(frame_idx).asnumpy()
    frames = [Image.fromarray(v.astype("uint8")) for v in frames]
    print("num frames:", len(frames))
    return frames

@torch.no_grad()
def main(args: argparse.Namespace):
    set_seed(args.seed)

    video_files = [file for file in os.listdir(args.instance_data_root) if file.endswith(".mp4")]
    video_files = [os.path.join(args.instance_data_root, file) for file in video_files]
    video_descriptions = {}

    # Stage 1: produce a detailed natural-language summary of each video with MiniCPM-V-2_6.
    model_id = "openbmb/MiniCPM-V-2_6"
    model = AutoModel.from_pretrained(
        model_id,
        trust_remote_code=True,
        attn_implementation="sdpa",
        torch_dtype=torch.bfloat16,
        cache_dir=args.cache_dir,
    ).to("cuda")
    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True, cache_dir=args.cache_dir)

    for filepath in video_files:
        print(f"Generating video summary for file: `{filepath}`")
        frames = encode_video(filepath)
        msgs = [{"role": "user", "content": frames + [QUESTION]}]
        params = {
            "use_image_id": False,
            "max_slice_nums": 10,
        }
        description = model.chat(image=None, msgs=msgs, tokenizer=tokenizer, **params)
        video_descriptions[filepath] = {"summary": description}

    # Free the vision-language model before loading the text-only model.
    del model
    del tokenizer
    torch.cuda.empty_cache()

model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct" | |
pipeline = transformers.pipeline( | |
"text-generation", | |
model=model_id, | |
device_map="auto", | |
model_kwargs={ | |
"local_files_only": True, | |
"cache_dir": args.cache_dir, | |
"torch_dtype": torch.bfloat16, | |
}, | |
) | |
for filepath in video_files: | |
print(f"Generating captions for file: `{filepath}`") | |
for prompt_type, num_words in [("short_prompt", 25), ("prompt", 75), ("verbose_prompt", 125)]: | |
user_prompt = USER_PROMPT.format(video_descriptions[filepath]["summary"], num_words) | |
messages = [ | |
{"role": "system", "content": SYSTEM_PROMPT}, | |
{"role": "user", "content": user_prompt}, | |
] | |
outputs = pipeline( | |
messages, | |
max_new_tokens=226, | |
) | |
video_descriptions[filepath][prompt_type] = outputs[0]["generated_text"][-1]["content"] | |
    # Write one row per video: its path relative to the data root, the three prompts, and the raw summary.
    df_data = []
    for filepath, description in video_descriptions.items():
        relative_path = os.path.relpath(filepath, args.instance_data_root)
        df_data.append(
            {
                "path": relative_path,
                "short_prompt": description.get("short_prompt", ""),
                "prompt": description.get("prompt", ""),
                "verbose_prompt": description.get("verbose_prompt", ""),
                "summary": description.get("summary", ""),
            }
        )

    df = pd.DataFrame(df_data)
    df.to_csv(args.output_path, index=False)

if __name__ == "__main__":
    args = get_args()
    main(args)
Usage:
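A minimal invocation, assuming the script above is saved as caption_videos.py (the filename is not given in the gist) and the videos to caption live in /path/to/videos:

python caption_videos.py \
  --instance_data_root /path/to/videos \
  --output_path video_dataset.csv \
  --max_new_tokens 226 \
  --seed 42

--cache_dir can optionally point at a local Hugging Face cache. Note that the Llama pipeline is created with local_files_only=True, so meta-llama/Meta-Llama-3.1-8B-Instruct must already be downloaded (and access to the gated repo granted). The script writes a CSV with one row per .mp4 file and the columns path, short_prompt, prompt, verbose_prompt and summary.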