Testing AI avatar project using Gradio
import os
import glob
import time
import random
import subprocess
import numpy as np
import cv2
import torch
from collections import OrderedDict

from utils.deep_speech import DeepSpeech
from utils.data_processing import load_landmark_openface, compute_crop_radius
from config.config import InferenceOptions
from models.segDINet import segDINet
class OpenFaceExtractor:
    def __init__(self, install_path):
        self.install_path = install_path

    def extract_features(self, video, output_dir):
        temp_cwd = os.getcwd()
        os.chdir(self.install_path)
        # Quote the paths so videos/directories containing spaces do not break the command.
        command = f'./FeatureExtraction -f "{video}" -out_dir "{output_dir}" -2Dfp'
        os.system(command)
        os.chdir(temp_cwd)
        os.system("pwd")  # debug: confirm we are back in the original working directory
        print(f"\nCompleted! Please check that the landmark features were extracted to {output_dir}\n")
class DeepSpeechExtractor:
    def __init__(self, model_path):
        if not os.path.exists(model_path):
            raise FileNotFoundError(f'Please download the pretrained DeepSpeech model: {model_path}')
        self.model = DeepSpeech(model_path)

    def extract_features(self, audio_path):
        if not os.path.exists(audio_path):
            raise FileNotFoundError(f'Wrong audio path: {audio_path}')
        ds_feature = self.model.compute_audio_feature(audio_path)
        return ds_feature
class FrameExtractor:
    def __init__(self, video_path):
        self.video_path = video_path

    def extract_frames_from_video(self, video_path, save_dir):
        videoCapture = cv2.VideoCapture(video_path)
        fps = videoCapture.get(cv2.CAP_PROP_FPS)
        if int(fps) != 25:
            print('warning: the input video is not 25 fps; it is better to convert it to 25 fps first!')
        # frames = videoCapture.get(cv2.CAP_PROP_FRAME_COUNT)
        frame_height = videoCapture.get(cv2.CAP_PROP_FRAME_HEIGHT)
        frame_width = videoCapture.get(cv2.CAP_PROP_FRAME_WIDTH)
        i = 0
        while True:
            ret, frame = videoCapture.read()
            if not ret:
                break
            result_path = os.path.join(save_dir, str(i).zfill(6) + '.jpg')
            cv2.imwrite(result_path, frame)
            i += 1
        videoCapture.release()
        print(f"\nall frames: {i} frame_height: {frame_height} frame_width: {frame_width}\n")
        return (int(frame_width), int(frame_height))

    def extract_frames(self):
        video_frame_dir = self.video_path.replace('.mp4', '')
        if not os.path.exists(video_frame_dir):
            os.mkdir(video_frame_dir)
        video_size = self.extract_frames_from_video(self.video_path, video_frame_dir)
        print(f"\nCompleted! Please check that the frames were extracted to {video_frame_dir}\n")
        return video_frame_dir, video_size
class DrivingImageSelector:
    def __init__(self, mouth_region_size):
        self.resize_w = int(mouth_region_size + mouth_region_size // 4)
        self.resize_h = int((mouth_region_size // 2) * 3 + mouth_region_size // 8)

    def select_images(self, video_frame_path_list_pad, video_landmark_data_pad, video_size):
        driving_img_list = []
        # Randomly pick 5 indices; each index marks the end of a 5-frame landmark window.
        driving_index_list = random.sample(range(5, len(video_frame_path_list_pad) - 2), 5)
        for driving_index in driving_index_list:
            crop_flag, crop_radius = compute_crop_radius(video_size, video_landmark_data_pad[driving_index - 5:driving_index, :, :])
            if not crop_flag:
                raise ValueError('Our method cannot handle videos with a large change in facial size!')
            crop_radius_1_4 = crop_radius // 4
            driving_img = cv2.imread(video_frame_path_list_pad[driving_index - 3])[:, :, ::-1]
            driving_landmark = video_landmark_data_pad[driving_index - 3, :, :]
            # Crop a face region around the nose landmarks (indices 29 and 33) and normalize to [0, 1].
            driving_img_crop = driving_img[
                driving_landmark[29, 1] - crop_radius: driving_landmark[29, 1] + crop_radius * 2 + crop_radius_1_4,
                driving_landmark[33, 0] - crop_radius - crop_radius_1_4: driving_landmark[33, 0] + crop_radius + crop_radius_1_4, :]
            driving_img_crop = cv2.resize(driving_img_crop, (self.resize_w, self.resize_h))
            driving_img_crop = driving_img_crop / 255.0
            driving_img_list.append(driving_img_crop)
        driving_video_frame = np.concatenate(driving_img_list, 2)
        driving_img_tensor = torch.from_numpy(driving_video_frame).permute(2, 0, 1).unsqueeze(0).float().cuda()
        return driving_img_tensor
class ModelHandler:
    def __init__(self, model_path, source_channel, ref_channel, audio_channel):
        if not os.path.exists(model_path):
            raise FileNotFoundError(f'Wrong path of the pretrained model weight: {model_path}')
        self.model = segDINet(source_channel, ref_channel, audio_channel).cuda()
        state_dict = torch.load(model_path)['state_dict']['net_g']
        new_state_dict = OrderedDict()
        for k, v in state_dict.items():
            name = k[7:]  # remove the `module.` prefix added by DataParallel
            new_state_dict[name] = v
        self.model.load_state_dict(new_state_dict)
        self.model.eval()

    def infer_frame(self, crop_frame_tensor, driving_img_tensor, deepspeech_tensor, gt_frame_tensor):
        with torch.no_grad():
            out = self.model(crop_frame_tensor, driving_img_tensor, deepspeech_tensor)
            # The last output channel is a blending mask; the remaining channels are the synthesized frame.
            out_frame, out_mask = out[:, :-1], out[:, -1:]
            # Where the mask is below the threshold, fall back to the ground-truth pixels.
            idx = out_mask < 250 / 255.0
            pre_frame = torch.where(idx, gt_frame_tensor, out_frame)
            pre_frame = pre_frame.squeeze(0).permute(1, 2, 0).detach().cpu().numpy() * 255
        return pre_frame
class VideoSynchronizer:
    def __init__(self, video_landmark_data, video_frame_path_list, ds_feature_padding, video_size, mouth_region_size):
        self.video_landmark_data = video_landmark_data
        self.video_frame_path_list = video_frame_path_list
        self.ds_feature_padding = ds_feature_padding
        self.video_size = video_size
        self.mouth_region_size = mouth_region_size

    def align_frames_with_audio(self, res_frame_length):
        # Frame Alignment Code
        video_frame_path_list_cycle = self.video_frame_path_list + self.video_frame_path_list[::-1]
        video_landmark_data_cycle = np.concatenate([self.video_landmark_data, np.flip(self.video_landmark_data, 0)], 0)
        video_frame_path_list_cycle_length = len(video_frame_path_list_cycle)
        # res_frame_length = self.ds_feature_padding.shape[0]
        if video_frame_path_list_cycle_length >= res_frame_length:
            res_video_frame_path_list = video_frame_path_list_cycle[:res_frame_length]
            res_video_landmark_data = video_landmark_data_cycle[:res_frame_length, :, :]
        else:
            divisor = res_frame_length // video_frame_path_list_cycle_length
            remainder = res_frame_length % video_frame_path_list_cycle_length
            res_video_frame_path_list = video_frame_path_list_cycle * divisor + video_frame_path_list_cycle[:remainder]
            res_video_landmark_data = np.concatenate([video_landmark_data_cycle] * divisor + [video_landmark_data_cycle[:remainder, :, :]], 0)
        res_video_frame_path_list_pad = [video_frame_path_list_cycle[0]] * 2 + res_video_frame_path_list + [video_frame_path_list_cycle[-1]] * 2
        res_video_landmark_data_pad = np.pad(res_video_landmark_data, ((2, 2), (0, 0), (0, 0)), mode='edge')
        assert self.ds_feature_padding.shape[0] == len(res_video_frame_path_list_pad) == res_video_landmark_data_pad.shape[0]
        return res_video_frame_path_list_pad, res_video_landmark_data_pad
class VideoPlayer:
    @staticmethod
    def play_video(video_path):
        vid = cv2.VideoCapture(video_path)
        if vid.isOpened():
            fps = vid.get(cv2.CAP_PROP_FPS)
            f_count = vid.get(cv2.CAP_PROP_FRAME_COUNT)
            f_width = vid.get(cv2.CAP_PROP_FRAME_WIDTH)
            f_height = vid.get(cv2.CAP_PROP_FRAME_HEIGHT)
            print('Frames per second : ', fps, 'FPS')
            print('Frame count : ', f_count)
            print('Frame width : ', f_width)
            print('Frame height : ', f_height)
        while vid.isOpened():
            ret, frame = vid.read()
            if ret:
                cv2.imshow('Generated Video', frame)
                key = cv2.waitKey(1)
                if key == ord('q'):
                    break
            else:
                break
        vid.release()
        cv2.destroyAllWindows()
class FacialDubbingPipeline:
    def __init__(self, opt):
        self.opt = opt
        self.openface_extractor = OpenFaceExtractor(opt.OpenFace_install_path)
        self.frame_extractor = FrameExtractor(opt.target_video_path)
        self.deepspeech_extractor = DeepSpeechExtractor(opt.deepspeech_model_path)
        self.driving_image_selector = DrivingImageSelector(opt.mouth_region_size)
        self.model_handler = ModelHandler(opt.pretrained_clip_segDinet_path, opt.source_channel, opt.ref_channel, opt.audio_channel)

    def run(self):
        # OpenFace Feature Extraction
        target_video = os.path.abspath(self.opt.target_video_path)
        output_landmark_dir = os.path.abspath(os.path.split(self.opt.target_openface_landmark_path)[0])
        self.openface_extractor.extract_features(target_video, output_landmark_dir)
        # input("if you want next step, press any key.")
        # Frame Extraction
        video_frame_dir, video_size = self.frame_extractor.extract_frames()
        # input("if you want next step, press any key.")
        # DeepSpeech Feature Extraction
        ds_feature = self.deepspeech_extractor.extract_features(self.opt.source_audio_path)
        res_frame_length = ds_feature.shape[0]
        ds_feature_padding = np.pad(ds_feature, ((2, 2), (0, 0)), mode='edge')
        # input("if you want next step, press any key.")
        # Frame & Audio Alignment
        ## Load OpenFace Landmark Data
        if not os.path.exists(self.opt.target_openface_landmark_path):
            raise FileNotFoundError(f'Wrong target openface landmark path: {self.opt.target_openface_landmark_path}')
        video_landmark_data = load_landmark_openface(self.opt.target_openface_landmark_path).astype(int)
        video_frame_path_list = glob.glob(os.path.join(video_frame_dir, '*.jpg'))
        video_frame_path_list.sort()
        if len(video_frame_path_list) != video_landmark_data.shape[0]:
            raise ValueError('video frames are misaligned with detected landmarks')
        video_synchronizer = VideoSynchronizer(video_landmark_data, video_frame_path_list, ds_feature_padding, video_size, self.opt.mouth_region_size)
        res_video_frame_path_list_pad, res_video_landmark_data_pad = video_synchronizer.align_frames_with_audio(res_frame_length)
        pad_length = ds_feature_padding.shape[0]
        print("completed aligning frames with the source audio")
        # Select Driving Images
        print("randomly selecting 5 driving images")
        driving_img_tensor = self.driving_image_selector.select_images(res_video_frame_path_list_pad, res_video_landmark_data_pad, video_size)
        # Create Video Writer
        res_video_path = os.path.join(self.opt.res_video_dir, os.path.basename(self.opt.target_video_path)[:-4] + '_facial_dubbing.mp4')
        res_face_path = res_video_path.replace('_facial_dubbing.mp4', '_synthetic_face.mp4')
        videowriter = cv2.VideoWriter(res_video_path, cv2.VideoWriter_fourcc(*'XVID'), 25, video_size)
        videowriter_face = cv2.VideoWriter(res_face_path, cv2.VideoWriter_fourcc(*'XVID'), 25,
                                           (int(self.opt.mouth_region_size + self.opt.mouth_region_size // 4),
                                            int((self.opt.mouth_region_size // 2) * 3 + self.opt.mouth_region_size // 8)))
        # synthesize video and audio
        time_sum = self.inference_frame(res_video_frame_path_list_pad, res_video_landmark_data_pad, ds_feature_padding, video_size, driving_img_tensor, pad_length, videowriter, videowriter_face)
        video_add_audio_path = self.combine_audio(res_video_path, self.opt.source_audio_path, time_sum, pad_length)
        # self.select_closing(video_add_audio_path)
        return video_add_audio_path
    def inference_frame(self, res_video_frame_path_list_pad, res_video_landmark_data_pad, ds_feature_padding, video_size, driving_img_tensor, pad_length, videowriter, videowriter_face):
        time_sum = 0
        for clip_end_index in range(5, pad_length, 1):
            print(f'synthesizing {clip_end_index - 5}/{pad_length - 5} frame')
            start = time.time()
            crop_flag, crop_radius = compute_crop_radius(video_size, res_video_landmark_data_pad[clip_end_index - 5:clip_end_index, :, :], random_scale=1.05)
            if not crop_flag:
                raise ValueError('Our method cannot handle videos with a large change in facial size!')
            crop_radius_1_4 = crop_radius // 4
            frame_data = cv2.imread(res_video_frame_path_list_pad[clip_end_index - 3])[:, :, ::-1]
            frame_landmark = res_video_landmark_data_pad[clip_end_index - 3, :, :]
            crop_frame_data = frame_data[
                frame_landmark[29, 1] - crop_radius: frame_landmark[29, 1] + crop_radius * 2 + crop_radius_1_4,
                frame_landmark[33, 0] - crop_radius - crop_radius_1_4: frame_landmark[33, 0] + crop_radius + crop_radius_1_4, :]
            crop_frame_h, crop_frame_w = crop_frame_data.shape[0], crop_frame_data.shape[1]
            crop_frame_data = cv2.resize(crop_frame_data, (int(self.opt.mouth_region_size + self.opt.mouth_region_size // 4),
                                                           int((self.opt.mouth_region_size // 2) * 3 + self.opt.mouth_region_size // 8)))
            crop_frame_data = crop_frame_data / 255.0
            gt_frame_data = crop_frame_data.copy()
            # Zero out the mouth region so the network inpaints it from the audio and driving frames.
            crop_frame_data[self.opt.mouth_region_size // 2: self.opt.mouth_region_size // 2 + self.opt.mouth_region_size,
                            self.opt.mouth_region_size // 8: self.opt.mouth_region_size // 8 + self.opt.mouth_region_size, :] = 0
            gt_frame_tensor = torch.from_numpy(gt_frame_data).float().cuda().permute(2, 0, 1).unsqueeze(0)
            crop_frame_tensor = torch.from_numpy(crop_frame_data).float().cuda().permute(2, 0, 1).unsqueeze(0)
            deepspeech_tensor = torch.from_numpy(ds_feature_padding[clip_end_index - 5:clip_end_index, :]).permute(1, 0).unsqueeze(0).float().cuda()
            pre_frame = self.model_handler.infer_frame(crop_frame_tensor, driving_img_tensor, deepspeech_tensor, gt_frame_tensor)
            videowriter_face.write(pre_frame[:, :, ::-1].copy().astype(np.uint8))
            pre_frame_resize = cv2.resize(pre_frame, (crop_frame_w, crop_frame_h))
            # Paste the synthesized crop back into the full frame.
            frame_data[
                frame_landmark[29, 1] - crop_radius: frame_landmark[29, 1] + crop_radius * 2,
                frame_landmark[33, 0] - crop_radius - crop_radius_1_4: frame_landmark[33, 0] + crop_radius + crop_radius_1_4, :] = pre_frame_resize[:crop_radius * 3, :, :]
            videowriter.write(frame_data[:, :, ::-1])
            end = time.time()
            time_sum += end - start
        videowriter.release()
        videowriter_face.release()
        return time_sum
    def combine_audio(self, res_video_path, audio_path, time_sum, pad_length):
        video_add_audio_path = res_video_path.replace('.mp4', '_add_audio.mp4')
        if os.path.exists(video_add_audio_path):
            os.remove(video_add_audio_path)
        cmd = f'ffmpeg -i "{res_video_path}" -i "{audio_path}" -c:v copy -c:a aac -strict experimental -map 0:v:0 -map 1:a:0 "{video_add_audio_path}"'
        subprocess.call(cmd, shell=True)
        print(f"\n\nVideo generation complete (average generation time: {time_sum / (pad_length - 5):.2f}s, total time: {time_sum:.2f}s)\n")
        print(f"Saved folder path: {self.opt.res_video_dir}")
        return video_add_audio_path
    def select_closing(self, video_add_audio_path):
        while True:
            print("----------------------------------------------------------------------------")
            print("\n 6. Please select a closing option.\n\n")
            print(f"1. Play the generated video now. ({os.path.basename(video_add_audio_path)})")
            print("2. Just exit.")
            option = input("\n\nSelect option: ")
            if option == '1':
                print("\nPlaying the generated video.\n")
                VideoPlayer.play_video(video_add_audio_path)
            elif option == '2':
                break
            else:
                print("Invalid option, please try again.")
        print("\nExiting the program.")
def usage():
    print("Usage: python facial_dubbing.py --target_video_path [target_video_path] --source_audio_path [source_audio_path] --res_video_dir [res_video_dir] --OpenFace_install_path [OpenFace_install_path] --deepspeech_model_path [deepspeech_model_path] --pretrained_clip_segDinet_path [pretrained_clip_segDinet_path] --source_channel [source_channel] --ref_channel [ref_channel] --audio_channel [audio_channel] --mouth_region_size [mouth_region_size]")
def segdinet_banner(opt):
    print("""
 _____ _____ _
| __ \_ _| |
___ ___ __ _| | | || | _ __ ___| |_
/ __|/ _ \/ _` | | | || | | '_ \ / _ \ __|
\__ \ __/ (_| | |__| || |_| | | | __/ |_
|___/\___|\__, |_____/_____|_| |_|\___|\__|
__/ |
|___/
    """)
    print("Loads the pre-trained segDINet and synthesizes a video \nin which the target person (target video) speaks according to the source audio.")
    print(f"target video: {os.path.basename(opt.target_video_path)}")
    print(f"source audio: {os.path.basename(opt.source_audio_path)}")
if __name__ == "__main__":
    opt = InferenceOptions().parse_args()
    # Check that the target video exists
    if not os.path.exists(opt.target_video_path):
        raise FileNotFoundError(f'Wrong target video path: {opt.target_video_path}')
    # Check that the source audio exists
    if not os.path.exists(opt.source_audio_path):
        raise FileNotFoundError(f'Wrong source audio path: {opt.source_audio_path}')
    segdinet_banner(opt)
    pipeline = FacialDubbingPipeline(opt)
    pipeline.run()
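Besides the CLI invocation shown in usage(), the pipeline can be driven from Python in the same way the FastAPI handler below does. A minimal sketch, assuming placeholder input paths and the facial_dubbing module name that the server imports:

# Sketch: programmatic invocation of the pipeline (placeholder paths).
from config.config import InferenceOptions
from facial_dubbing import FacialDubbingPipeline, segdinet_banner

opt = InferenceOptions().parse_args()
opt_ = vars(opt)
opt_['target_video_path'] = 'uploads/target.mp4'                  # placeholder input video
opt_['source_audio_path'] = 'uploads/source.wav'                  # placeholder driving audio
opt_['target_openface_landmark_path'] = 'input_file/target.csv'   # placeholder OpenFace CSV path

segdinet_banner(opt)
result_path = FacialDubbingPipeline(opt).run()
print(result_path)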
from fastapi import FastAPI, File, UploadFile, Form
from fastapi.responses import JSONResponse, HTMLResponse
from fastapi.staticfiles import StaticFiles
import os
import shutil

from config.config import InferenceOptions
from facial_dubbing import FacialDubbingPipeline, segdinet_banner

app = FastAPI()
app.mount("/result", StaticFiles(directory="result"), name="result")
@app.get("/", response_class=HTMLResponse) | |
def home(): | |
html_content = """ | |
<!DOCTYPE html> | |
<html lang="en"> | |
<head> | |
<meta charset="UTF-8"> | |
<meta name="viewport" content="width=device-width, initial-scale=1.0"> | |
<title>Facial Dubbing</title> | |
</head> | |
<body> | |
<h1>Facial Dubbing</h1> | |
<form action="/inference" method="post" enctype="multipart/form-data"> | |
<label for="target_video">Target Video</label> | |
<input type="file" id="target_video" name="target_video" accept="video/*" required><br><br> | |
<label for="source_audio">Source Audio</label> | |
<input type="file" id="source_audio" name="source_audio" accept="audio/*" required><br><br> | |
<button type="submit">Submit</button> | |
</form> | |
</body> | |
</html> | |
""" | |
return HTMLResponse(content=html_content) | |
@app.post("/inference") | |
async def inference(target_video: UploadFile = File(...), source_audio: UploadFile = File(...)): | |
opt = InferenceOptions().parse_args() | |
if not os.path.exists('uploads'): | |
os.makedirs('uploads') | |
target_video_path = os.path.join('uploads', target_video.filename) | |
source_audio_path = os.path.join('uploads', source_audio.filename) | |
target_openface_landmark_path = os.path.join('input_file', target_video.filename.split('.')[0] + '.csv') | |
with open(target_video_path, 'wb') as f: | |
shutil.copyfileobj(target_video.file, f) | |
with open(source_audio_path, 'wb') as f: | |
shutil.copyfileobj(source_audio.file, f) | |
segdinet_banner(opt) | |
opt_ = vars(opt) | |
opt_['target_video_path'] = target_video_path | |
opt_['source_audio_path'] = source_audio_path | |
opt_['target_openface_landmark_path'] = target_openface_landmark_path | |
print(opt) | |
pipeline = FacialDubbingPipeline(opt) | |
video_add_audio_path = pipeline.run() | |
host_ip = os.popen('hostname -I').read().split()[0] | |
return JSONResponse(content={'video_path': 'http://wearable.lan:8888/' + video_add_audio_path}) | |
if __name__ == '__main__': | |
import uvicorn | |
uvicorn.run(app, host='0.0.0.0', port=8888) |
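The /inference endpoint can also be exercised without the Gradio viewer. A hedged client sketch, with placeholder host and file names; the multipart field names match the UploadFile parameters above, and the server exposes the result under its /result static mount:

# Sketch: standalone client for the /inference endpoint (placeholder host and files).
import requests

API_URL = "http://localhost:8888/inference"  # placeholder; the viewer uses http://wearable.lan:8888/inference

with open("target.mp4", "rb") as video_f, open("source.wav", "rb") as audio_f:
    r = requests.post(API_URL, files={"target_video": video_f, "source_audio": audio_f})
r.raise_for_status()
video_url = r.json()["video_path"]

# Download the synthesized video returned by the server.
with open("result_add_audio.mp4", "wb") as out_f:
    out_f.write(requests.get(video_url).content)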
viewer.py
import sqlite3
import gradio as gr
import pandas as pd
import random
from pathlib import Path
from apscheduler.schedulers.background import BackgroundScheduler
import requests
import os
import tempfile

DB_FILE = "./result.db"
# Create the survey and user_study tables on first run.
db = sqlite3.connect(DB_FILE)
db.execute(
    '''
    CREATE TABLE IF NOT EXISTS survey (
        id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL,
        create_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL,
        name TEXT, age INTEGER, model TEXT)
    ''')
db.execute(
    '''
    CREATE TABLE IF NOT EXISTS user_study (
        id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL,
        create_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL,
        name TEXT, metric_a TEXT, metric_b TEXT, metric_c TEXT)
    ''')
db.commit()
db.close()
gr.set_static_paths(paths=["data/"])

def get_surveys(db):
    surveys = db.execute("SELECT * FROM survey").fetchall()
    total_surveys = db.execute("SELECT COUNT(*) FROM survey").fetchone()[0]
    # surveys = [{"name": name, "age": age} for name, age in surveys]
    surveys = pd.DataFrame(surveys, columns=["id", "date_created", "name", "age", "model"])
    return surveys, total_surveys

def get_user_studies(db):
    user_studies = db.execute("SELECT * FROM user_study").fetchall()
    total_studies = db.execute("SELECT COUNT(*) FROM user_study").fetchone()[0]
    user_studies = pd.DataFrame(user_studies, columns=["id", "date_created", "name", "metric_a", "metric_b", "metric_c"])
    return user_studies, total_studies

def insert_survey(name, age, model):
    db = sqlite3.connect(DB_FILE)
    db.execute("INSERT INTO survey (name, age, model) VALUES (?, ?, ?)", (name, age, model))
    db.commit()
    surveys, total_surveys = get_surveys(db)
    db.close()
    return surveys, total_surveys

def insert_user_study(name, metric_a, metric_b, metric_c):
    db = sqlite3.connect(DB_FILE)
    db.execute("INSERT INTO user_study (name, metric_a, metric_b, metric_c) VALUES (?, ?, ?, ?)", (name, metric_a, metric_b, metric_c))
    db.commit()
    db.close()

def validate_survey(name, age, model):
    if not name:
        raise gr.Error("Name is required")
    if not age:
        raise gr.Error("Age is required")
    if not model:
        raise gr.Error("Model is required")
    return insert_survey(name, age, model)

def validate_user_study(metric_a, metric_b, metric_c, name):
    if not metric_a:
        raise gr.Error("Metric A is required")
    if not metric_b:
        raise gr.Error("Metric B is required")
    if not metric_c:
        raise gr.Error("Metric C is required")
    if not name:
        raise gr.Error("Name is required")
    insert_user_study(name, metric_a, metric_b, metric_c)
    gr.Info('Successfully submitted!')
def load_surveys():
    db = sqlite3.connect(DB_FILE)
    surveys, total_surveys = get_surveys(db)
    db.close()
    return surveys, total_surveys

def load_user_studies():
    db = sqlite3.connect(DB_FILE)
    user_studies, total_user_studies = get_user_studies(db)
    db.close()
    return user_studies, total_user_studies

# Seed the survey table with sample rows.
insert_survey("John", 25, "model1")
insert_survey("Alice", 30, "model2")
# print(load_surveys())
def generate_images():
    images = [
        (random.choice(
            [
                "http://www.marketingtool.online/en/face-generator/img/faces/avatar-1151ce9f4b2043de0d2e3b7826127998.jpg",
                "http://www.marketingtool.online/en/face-generator/img/faces/avatar-116b5e92936b766b7fdfc242649337f7.jpg",
                "http://www.marketingtool.online/en/face-generator/img/faces/avatar-1163530ca19b5cebe1b002b8ec67b6fc.jpg",
                "http://www.marketingtool.online/en/face-generator/img/faces/avatar-1116395d6e6a6581eef8b8038f4c8e55.jpg",
                "http://www.marketingtool.online/en/face-generator/img/faces/avatar-11319be65db395d0e8e6855d18ddcef0.jpg",
            ]
        ), f"model {i}")
        for i in range(3)
    ]
    print(images)
    return images
def generate_videos():
    test_video = "test00" + random.choice(["1", "2", "3", "4"])
    videos = [
        f"data/{test_video}/{test_video}_modelA.mp4",
        f"data/{test_video}/{test_video}_modelB.mp4",
        f"data/{test_video}/{test_video}_modelC.mp4",
        f"data/{test_video}/{test_video}_modelD.mp4",
        f"data/{test_video}/{test_video}_modelE.mp4",
    ]
    return videos

def replay_videos():
    return [gr.Video(autoplay=True, value=video) for video in generate_videos()]
API_URL = "http://wearable.lan:8888/inference"
#API_URL = "http://127.0.0.1:8887/inference"

def inference_video(video_model, voice_model, image_input, sound_input, text_input, pose_video_input):
    # Send the driving video (pose_video_input) and audio (sound_input) to the AI model over the REST API.
    with open(pose_video_input, 'rb') as video_f, open(sound_input, 'rb') as audio_f:
        r = requests.post(API_URL, files={"target_video": video_f, "source_audio": audio_f})
    inferenced_video_url = r.json()["video_path"]
    # Download the synthesized video into a temporary directory so Gradio can serve it.
    inference_video_dir = tempfile.mkdtemp()
    inferenced_video = os.path.join(inference_video_dir, Path(inferenced_video_url).name)
    with open(inferenced_video, 'wb') as f:
        f.write(requests.get(inferenced_video_url).content)
    return inferenced_video
css = """ | |
.radio-group .warp { | |
display: flex !important; | |
} | |
.radio-group label { | |
flex: 1 1 auto; | |
} | |
""" | |
pose_video_url = "./pose_video.mp4" | |
with gr.Blocks(theme=gr.themes.Soft(), css=css) as SurveyDemo:
    gr.Markdown("# AI Avatar")
    with gr.Tab(label="AI framework"):
        gr.Markdown("## AI Framework")
        with gr.Row():
            with gr.Column():
                gr.Markdown("### Input")
                video_model = gr.Textbox(label="Video Model", placeholder="Enter model name", value="segDInet")
                voice_model = gr.Textbox(label="Voice Model", placeholder="Enter model name", value="?")
                image_input = gr.Image(label="Target Image")
                sound_input = gr.Audio(label="Driving Audio", type="filepath")
                text_input = gr.Textbox(label="Source Text", placeholder="Enter text")
                pose_video_input = gr.Video(label="Driving Video", format="mp4")

                def on_pose_video(value):
                    print(value)

                def on_sound_input(value):
                    print(value)

                pose_video_input.upload(on_pose_video, pose_video_input)
                sound_input.upload(on_sound_input, sound_input)
                submit = gr.Button(value="Submit")
            with gr.Column():
                gr.Markdown("### Result")
                video_output = gr.Video(label="Output Video")
        submit.click(inference_video, [video_model, voice_model, image_input, sound_input, text_input, pose_video_input], [video_output])
with gr.Tab(label="AI Avatar Result") as ai_avatar_result_tab: | |
gr.Markdown("## AI Avatar Result") | |
with gr.Row(): | |
with gr.Column(): | |
gr.Markdown("### 파일 목록") | |
gr.FileExplorer(file_count="multiple", root="./", ignore_glob=".*", | |
interactive=True, glob="**/*.*") | |
with gr.Column(): | |
gr.Markdown("### 원본 영상") | |
pose_video_output = gr.Video(autoplay=True) | |
with gr.Column(): | |
gr.Markdown("### 합성 영상") | |
synthesis_video_output = gr.Video(autoplay=True) | |
with gr.Tab(label="User Study"): | |
gr.Markdown("## User Study") | |
with gr.Row(): | |
video0 = gr.Video(autoplay=True, label="ModelA") | |
video1 = gr.Video(autoplay=True, label="ModelB") | |
video2 = gr.Video(autoplay=True, label="ModelC") | |
video3 = gr.Video(autoplay=True, label="ModelD") | |
video4 = gr.Video(autoplay=True, label="ModelE") | |
metric_a = gr.Radio(["A", "B", "C", "D", "E"], label="Metric A", info="영상 선명도", elem_classes="radio-group") | |
metric_b = gr.Radio(["A", "B", "C", "D", "E"], label="Metric B", info="입술 동기화", elem_classes="radio-group") | |
metric_c = gr.Radio(["A", "B", "C", "D", "E"], label="Metric C", info="영상 품질", elem_classes="radio-group") | |
with gr.Row(): | |
with gr.Column(): | |
name = gr.Textbox(label="Name", placeholder="Enter your name") | |
with gr.Column(): | |
retry = gr.ClearButton([metric_a, metric_b, metric_c, name], value='Retry', scale=2) | |
with gr.Column(): | |
check = gr.Button(value="Check", scale=2) | |
SurveyDemo.load(generate_videos, None, [video0, video1, video2, video3, video4]) | |
retry.click(replay_videos, None, [video0, video1, video2, video3, video4]) | |
check.click(validate_user_study, [metric_a, metric_b, metric_c, name], None) | |
with gr.Tab(label="User Study Results") as user_study_results_tab: | |
gr.Markdown("## User Study Results") | |
data = gr.Dataframe(headers=["Name", "MetricA", "MetricB", "MetricC"], visible=True) | |
count = gr.Number(label="Total User Studies") | |
SurveyDemo.load(load_user_studies, None, [data, count]) | |
    def on_ai_avatar_result_tab_update(pose_video_input, video_output):
        return pose_video_input, video_output

    def on_user_study_results_tabs_update():
        return load_user_studies()

    ai_avatar_result_tab.select(on_ai_avatar_result_tab_update, [pose_video_input, video_output], [pose_video_output, synthesis_video_output])
    user_study_results_tab.select(on_user_study_results_tabs_update, None, [data, count])

def backup_data():
    user_studies, _ = load_user_studies()
    user_studies.to_csv("./result.csv", index=False)

scheduler = BackgroundScheduler()
scheduler.add_job(backup_data, trigger='interval', seconds=1)
scheduler.start()

SurveyDemo.launch(share=True, server_name='0.0.0.0', server_port=7860)
#SurveyDemo.launch(share=True, server_name='0.0.0.0', server_port=8888)
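For inspecting the collected votes outside the UI, a small sketch that tallies the user_study table defined above; the column names follow the CREATE TABLE statement, and the output file/loop are illustrative only:

# Sketch: tally user-study votes per metric from result.db.
import sqlite3
import pandas as pd

db = sqlite3.connect("./result.db")
df = pd.read_sql_query("SELECT * FROM user_study", db)
db.close()

for metric in ["metric_a", "metric_b", "metric_c"]:
    # Count how often each model (A..E) was preferred for this metric.
    print(metric, df[metric].value_counts().to_dict())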