Skip to content

Instantly share code, notes, and snippets.

@donghee
Last active November 26, 2024 01:21
Show Gist options
  • Save donghee/8ca7e6c4d3c67aba061f227e76e9b2cf to your computer and use it in GitHub Desktop.
Save donghee/8ca7e6c4d3c67aba061f227e76e9b2cf to your computer and use it in GitHub Desktop.
Segdinet viewer.py 20241126
import sqlite3
import gradio as gr
import pandas as pd
import random
from pathlib import Path
from apscheduler.schedulers.background import BackgroundScheduler
import requests
import os
import tempfile
DB_FILE = "./result.db"
db = sqlite3.connect(DB_FILE)
try:
db.execute("SELECT * FROM user_study").fetchall()
db.close()
except Exception as e:
db.execute(
'''
CREATE TABLE user_study (id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL,
create_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL,
type TEXT, name TEXT, metric_a TEXT, metric_b TEXT, metric_c TEXT)
''')
db.commit()
db.close()
gr.set_static_paths(paths=["data/"])
def get_user_studies(db):
user_studies = db.execute("SELECT * FROM user_study").fetchall()
total_studies = db.execute("SELECT COUNT(*) FROM user_study").fetchone()[0]
user_studies = pd.DataFrame(user_studies, columns=["id", "date_created", "type", "name", "metric_a", "metric_b", "metric_c"])
return user_studies, total_studies
def insert_user_study(name, type_, metric_a, metric_b, metric_c):
db = sqlite3.connect(DB_FILE)
db.execute("INSERT INTO user_study (type, name, metric_a, metric_b, metric_c) VALUES (?, ?, ?, ?, ?)", (name, type_, metric_a, metric_b, metric_c))
db.commit()
db.close()
def validate_user_study(type_, metric_a, metric_b, metric_c, name):
print(type_, metric_a, metric_b, metric_c, name)
if not type_:
raise gr.Error("Type is required")
if not metric_a:
raise gr.Error("Metric A is required")
if not metric_b:
raise gr.Error("Metric B is required")
if not metric_c:
raise gr.Error("Metric C is required")
if not name:
raise gr.Error("Name is required")
insert_user_study(type_, name, metric_a, metric_b, metric_c)
gr.Info('Successfully submitted!')
def load_user_studies():
db = sqlite3.connect(DB_FILE)
user_studies, total_user_studies = get_user_studies(db)
db.close()
return user_studies, total_user_studies
def generate_videos_reconstruction():
test_video = f"test0" + random.choice(["01", "02", "03", "05", "10", "11", "12", "13", "14", "16", "17",
"18", "19", "20", "22", "26", "27", "28"])
test_video_path = Path(f"data/reconstruction/{test_video}")
read_videos = [file.name for file in test_video_path.glob("*.mp4")]
read_videos.sort()
print(test_video_path)
print(read_videos)
if len(read_videos) == 0:
return []
else:
return [f"data/reconstruction/{test_video}/{video}" for video in read_videos]
def generate_videos_dubbing():
test_video = f"test0" + random.choice(["01", "02", "03", "05", "10", "11", "12", "13", "14", "16", "17",
"18", "19", "20", "22", "26", "27"])
test_video_path = Path(f"data/dubbing/{test_video}")
read_videos = [file.name for file in test_video_path.glob("*.mp4")]
read_videos.sort()
#read_videos[0] = f"ground truth.png"
read_videos[0] = f"GT.png"
print(test_video_path)
print(read_videos)
if len(read_videos) == 0:
return []
else:
return [f"data/dubbing/{test_video}/{video}" for video in read_videos]
def replay_videos_reconstruction():
return [gr.Video(autoplay=False, value=video, elem_classes="user_study_video") for video in generate_videos_reconstruction()]
def replay_videos_dubbing():
output = []
generated_videos = generate_videos_dubbing()
output.append(gr.Image(value=generated_videos[0], elem_classes="user_study_dubbing_image"))
for video in generated_videos[1:]:
output.append(gr.Video(autoplay=False, value=video, elem_classes="user_study_dubbing_video"))
return output
blocks_js = """
function refresh() {
const url = new URL(window.location);
if (url.searchParams.get('__theme') !== 'dark') {
url.searchParams.set('__theme', 'dark');
window.location.href = url.href;
}
}
"""
sync_video_js = """
function sync_videos() {
var videos = document.querySelectorAll(".user_study_video video");
//console.log(videos);
for (var i = 0; i < videos.length; i++) {
//videos[i].currentTime = 0.0;
videos[i].muted = true;
}
videos[0].muted = false;
for (var i = 0; i < videos.length; i++) {
if (videos[i].paused) {
videos[i].play();
}
}
}
"""
pause_video_js = """
function sync_videos() {
var videos = document.querySelectorAll(".user_study_video video");
for (var i = 0; i < videos.length; i++) {
videos[i].pause();
}
}
"""
sync_dubbing_video_js = """
function sync_dubbing_videos() {
var videos = document.querySelectorAll(".user_study_dubbing_video video");
//console.log(videos);
for (var i = 0; i < videos.length; i++) {
//videos[i].currentTime = 0.0;
videos[i].muted = true;
}
videos[0].muted = false;
for (var i = 0; i < videos.length; i++) {
if (videos[i].paused) {
videos[i].play();
}
}
}
"""
pause_dubbing_video_js = """
function sync_dubbing_videos() {
var videos = document.querySelectorAll(".user_study_dubbing_video video");
for (var i = 0; i < videos.length; i++) {
videos[i].pause();
}
}
"""
API_URL = "http://127.0.0.1:8887/inference"
def inference_video(video_model, voice_model, image_input, sound_input, text_input, pose_video_input):
# send pose_video_input, and sound_input to AI model using rest api
r = requests.post(API_URL, files={"target_video": open(pose_video_input, 'rb'), "source_audio": open(sound_input, 'rb')})
inferenced_video_url = r.json()["video_path"]
inference_video_dir = tempfile.TemporaryDirectory().name
if not os.path.exists(inference_video_dir):
os.makedirs(inference_video_dir)
inferenced_video = os.path.join(inference_video_dir, Path(inferenced_video_url).name)
with open(inferenced_video, 'wb') as f:
f.write(requests.get(inferenced_video_url).content)
return inferenced_video
css = """
.radio-group .warp {
display: flex !important;
}
.radio-group label {
flex: 1 1 auto;
}
fieldset > div {
font-size: 1em;
margin-top: 2px;
color: var(--body-text-color);
}
.user_study_video_first_row > .user_study_video:nth-of-type(1) video {
border: 8px solid #cc0;
}
.user_study_video_first_row > .user_study_dubbing_image:nth-of-type(1) img {
border: 8px solid #cc0;
}
/* User Study Dubbing Image */
.image-frame {
width: 100%;
}
.image-frame img {
object-fit: contain;
}
"""
pose_video_url = "./pose_video.mp4"
with gr.Blocks(theme=gr.themes.Soft(), css=css, js=blocks_js) as SurveyDemo:
gr.Markdown("# AI Avatar")
# with gr.Tab(label="AI framework"):
# gr.Markdown("## AI Framework")
# with gr.Row():
# with gr.Column():
# gr.Markdown("### Input")
# video_model = gr.Textbox(label="Video Model", placeholder="Enter model name", value="segDInet")
# voice_model = gr.Textbox(label="Voice Model", placeholder="Enter model name", value="?")
# image_input = gr.Image(label="Target Image")
# sound_input = gr.Audio(label="Driving Audio", type="filepath")
# text_input = gr.Textbox(label="Source Text", placeholder="Enter text")
# pose_video_input = gr.Video(label="Driving Video", format="mp4")
# def on_pose_video(value):
# print(value)
# def on_sound_input(value):
# print(value)
# pose_video_input.upload(on_pose_video, pose_video_input)
# sound_input.upload(on_sound_input, sound_input)
# submit = gr.Button(value="Submit")
# with gr.Column():
# gr.Markdown("### Result")
# video_output = gr.Video(label="Output Video")
# submit.click(inference_video, [video_model, voice_model, image_input, sound_input, text_input, pose_video_input], [video_output])
# with gr.Tab(label="AI Avatar Result") as ai_avatar_result_tab:
# gr.Markdown("## AI Avatar Result")
# with gr.Row():
# with gr.Column():
# gr.Markdown("### 파일 목록")
# gr.FileExplorer(file_count="multiple", root="./", ignore_glob=".*",
# interactive=True, glob="**/*.*")
# with gr.Column():
# gr.Markdown("### 원본 영상")
# pose_video_output = gr.Video(autoplay=True)
#
# with gr.Column():
# gr.Markdown("### 합성 영상")
# synthesis_video_output = gr.Video(autoplay=True)
with gr.Tab(label="User Study (Reconstruction)"):
gr.Markdown("## 연구내용 소개")
gr.Markdown("해당 웹 페이지는 ETRI 차세대 주역 신진연구 사업 내 <span style=\"text-decoration:underline; color: yellow\">\"초실사 영상 생성AI 제어 프레임워크 요소기술 개발\"</span> 연구과제 개발내용을 포함하며, 생성AI기술을 통한 발화 인물 합성(Talking Face Generation)의 성능 평가를 위한 설문조사를 위해 구성되었습니다.")
gr.Markdown("## 설문 목적 및 방법")
gr.Markdown("이 설문은 생성AI를 이용하여 합성된 영상들을 평가해주시는 것을 목적으로 합니다. <br/> 아래를 보시면 6개의 영상이 있고. 좌측 상단에 노란색 테두리로 강조된 영상(GT, Ground Truth)은 실존하는 인물 발화영상입니다. 그 외 5개 영상은 서로 다른 생성AI모델들이 합성해낸 영상(A, B, C, D, E)입니다. 해당 영상들을 소리와 함께 재생시켜보면서 아래 설문에 응해주시기 바랍니다. <br/>")
gr.Markdown("- (요청사항1) 위 영상들은 같은 음성에 대해 표정이 합성된 영상입니다. GT는 정답이 되는 영상이므로 당연하게도 품질 면에서 가장 좋은 영상으로 보입니다. 따라서, 설문 응답으로 선택하실 수 없습니다.")
gr.Markdown("- (요청사항2) Play버튼을 통해 반복/청취해가시며 평가해주시기를 요청드립니다.")
with gr.Row(elem_classes="user_study_video_first_row"):
video0 = gr.Video(autoplay=False, label="GT", elem_classes="user_study_video")
video1 = gr.Video(autoplay=False, label="A", elem_classes="user_study_video")
video2 = gr.Video(autoplay=False, label="B", elem_classes="user_study_video")
with gr.Row(elem_classes="user_study_video_second_row"):
video3 = gr.Video(autoplay=False, label="C", elem_classes="user_study_video")
video4 = gr.Video(autoplay=False, label="D", elem_classes="user_study_video")
video5 = gr.Video(autoplay=False, label="E", elem_classes="user_study_video")
with gr.Row():
sync = gr.Button(value='Play', scale=1)
pause = gr.Button(value='Stop', scale=1)
type_ = gr.Textbox(label="Type", value="Reconstruction", visible=False)
metric_a = gr.Radio(["A", "B", "C", "D", "E"], label="(Q) 위에 제시된 합성 영상(A-E)들 중 전체적인 선명도와 품질이 가장 좋은 영상을 선택해주세요." , info="영상 선명도(Fideltiy)와 품질(Quality)", elem_classes="radio-group")
metric_b = gr.Radio(["A", "B", "C", "D", "E"], label="(Q) 위에 제시된 합성 영상(A-E)들 중 음성과 얼굴 표정(ex, 입 모양)이 가장 잘 동기화된 영상을 선택해주세요.", info="입술 동기화 (Lip Sync)", elem_classes="radio-group")
metric_c = gr.Radio(["A", "B", "C", "D", "E"], label="(Q) 위에 제시된 합성 영상(A-E)들 중 원본(GT)의 인물의 외형/생김새와 가장 유사한 영상을 성택해주세요.", info="정체성 일관성 (Identity Consistency)", elem_classes="radio-group")
with gr.Row():
with gr.Column():
name = gr.Textbox(label="Name", placeholder="Enter your name")
with gr.Column():
check = gr.Button(value="Submit", scale=2)
gr.Markdown("<br/>")
with gr.Row():
retry = gr.ClearButton([metric_a, metric_b, metric_c, name], value='Load', scale=1)
SurveyDemo.load(generate_videos_reconstruction, None, [video0, video1, video2, video3, video4, video5])
retry.click(replay_videos_reconstruction, None, [video0, video1, video2, video3, video4, video5])
sync.click(None, None, None, js=sync_video_js)
pause.click(None, None, None, js=pause_video_js)
check.click(validate_user_study, [type_, metric_a, metric_b, metric_c, name], None)
with gr.Tab(label="User Study (Dubbing)"):
gr.Markdown("## 연구내용 소개")
gr.Markdown("해당 웹 페이지는 ETRI 차세대 주역 신진연구 사업 내 <span style=\"text-decoration:underline; color: yellow\">\"초실사 영상 생성AI 제어 프레임워크 요소기술 개발\"</span> 연구과제 개발내용을 포함하며, 생성AI기술을 통한 발화 인물 합성(Talking Face Generation)의 성능 평가를 위한 설문조사를 위해 구성되었습니다.")
gr.Markdown("## 설문 목적 및 방법")
gr.Markdown("이 설문은 생성AI를 이용하여 합성된 영상들을 평가해주시는 것을 목적으로 합니다. <br/> 아래를 보시면 6개의 영상이 있고. 좌측 상단에 노란색 테두리로 강조된 이미지는 실존 인물의 참조 사진(Reference Image)입니다. 그 외 5개 영상은 서로 다른 생성AI모델들이 참조 인물에 대해 새롭게 합성해낸 영상(A, B, C, D, E)입니다. 해당 영상들을 소리와 함께 재생시켜보면서 아래 설문에 응해주시기 바랍니다. <br/>")
gr.Markdown("- (요청사항1) 위 영상들은 같은 음성에 대해 표정이 합성된 영상입니다.")
gr.Markdown("- (요청사항2) Play버튼을 통해 반복/청취해가시며 평가해주시기를 요청드립니다")
with gr.Row(elem_classes="user_study_video_first_row"):
# video0 = gr.Video(autoplay=False, label="GT", elem_classes="user_study_dubbing_video")
image0 = gr.Image(label="GT", elem_classes="user_study_dubbing_image")
video1 = gr.Video(autoplay=False, label="A", elem_classes="user_study_dubbing_video")
video2 = gr.Video(autoplay=False, label="B", elem_classes="user_study_dubbing_video")
with gr.Row(elem_classes="user_study_video_second_row"):
video3 = gr.Video(autoplay=False, label="C", elem_classes="user_study_dubbing_video")
video4 = gr.Video(autoplay=False, label="D", elem_classes="user_study_dubbing_video")
video5 = gr.Video(autoplay=False, label="E", elem_classes="user_study_dubbing_video")
with gr.Row():
sync = gr.Button(value='Play', scale=1)
pause = gr.Button(value='Stop', scale=1)
type_ = gr.Textbox(label="Type", value="Dubbing", visible=False)
metric_a = gr.Radio(["A", "B", "C", "D", "E"], label="(Q) 위에 제시된 합성 영상(A-E)들 중 전체적인 선명도와 품질이 가장 좋은 영상을 선택해주세요.", info="영상 선명도(Fideltiy)와 품질(Quality)", elem_classes="radio-group")
metric_b = gr.Radio(["A", "B", "C", "D", "E"], label="(Q) 위에 제시된 합성 영상(A-E)들 중 음성과 얼굴 표정(ex, 입 모양)이 가장 잘 동기화된 영상을 선택해주세요.", info="입술 동기화 (Lip Sync)", elem_classes="radio-group")
metric_c = gr.Radio(["A", "B", "C", "D", "E"], label="(Q) 위에 제시된 합성 영상(A-E)들 중 원본(GT)의 인물의 외형/생김새와 가장 유사한 영상을 성택해주세요.", info="정체성 일관성 (Identity Consistency)", elem_classes="radio-group")
with gr.Row():
with gr.Column():
name = gr.Textbox(label="Name", placeholder="Enter your name")
with gr.Column():
check = gr.Button(value="Submit", scale=2)
gr.Markdown("<br/>")
with gr.Row():
retry = gr.ClearButton([metric_a, metric_b, metric_c, name], value='Load', scale=1)
SurveyDemo.load(generate_videos_dubbing, None, [image0, video1, video2, video3, video4, video5])
retry.click(replay_videos_dubbing, None, [image0, video1, video2, video3, video4, video5])
sync.click(None, None, None, js=sync_dubbing_video_js)
pause.click(None, None, None, js=pause_dubbing_video_js)
check.click(validate_user_study, [type_, metric_a, metric_b, metric_c, name], None)
with gr.Tab(label="User Study Results") as user_study_results_tab:
gr.Markdown("## User Study Results")
data = gr.Dataframe(headers=["Name", "Type", "MetricA", "MetricB", "MetricC"], visible=True)
count = gr.Number(label="Total User Studies")
SurveyDemo.load(load_user_studies, None, [data, count])
def on_ai_avatar_result_tab_update(pose_video_input, video_output):
return pose_video_input, video_output
def on_user_study_results_tabs_update():
return load_user_studies()
# ai_avatar_result_tab.select(on_ai_avatar_result_tab_update, [pose_video_input, video_output], [pose_video_output, synthesis_video_output])
user_study_results_tab.select(on_user_study_results_tabs_update, None, [data, count])
def backup_data():
user_studies, _ = load_user_studies()
user_studies.to_csv("./result.csv", index=False)
scheduler = BackgroundScheduler()
scheduler.add_job(backup_data, trigger='interval', seconds=1)
scheduler.start()
SurveyDemo.launch(share=True, server_name='0.0.0.0', server_port=8888)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment