import os
import sys

import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image

sys.path.append(os.getcwd())  # Add the current directory to the module search path

from tha3.poser.modes.load_poser import load_poser
from tha3.util import rgba_to_numpy_image, grid_change_to_numpy_image, rgb_to_numpy_image, extract_pytorch_image_from_PIL_image
from tha3.poser.general_poser_02 import GeneralPoser02

# Use CUDA if available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
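
# tha_inference_1k doubles the poser's native 512x512 working resolution: it
# splits a 1024x1024 image into four interleaved 512x512 sub-images, poses each
# one so the face region keeps full detail, re-interleaves the results, then
# remaps the full-resolution image through a warp map obtained from a posing
# pass over a downscaled copy and upscaled back to 1024x1024.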
def tha_inference_1k(poser, image, pose):
    # Base coordinate map: mapbase[0, x, y] = y and mapbase[1, x, y] = x, used
    # to turn the poser's relative warp offsets into absolute pixel coordinates
    mapbase = torch.zeros((2, 512, 512), dtype=poser.get_dtype(), device=device)
    for y in range(512):
        for x in range(512):
            mapbase[1, x, y] = x
            mapbase[0, x, y] = y
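    # Note: the nested loops above perform 262,144 scalar writes in Python. A
    # vectorized equivalent (a sketch, not part of the original gist) would be:
    #   idx = torch.arange(512, device=device, dtype=poser.get_dtype())
    #   row, col = torch.meshgrid(idx, idx, indexing="ij")
    #   mapbase = torch.stack([col, row])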
    image_size = 1024
    with torch.inference_mode():
        image_divided = [None] * 4
        face_output = [None] * 4
        image_output = [None] * 4
        # Split the 1024x1024 image into four interleaved 512x512 sub-images
        image_divided[0] = image[:, ::2, ::2]
        image_divided[1] = image[:, 1::2, ::2]
        image_divided[2] = image[:, ::2, 1::2]
        image_divided[3] = image[:, 1::2, 1::2]
        # Their average is a 512x512 downscaled copy of the input
        image_resized = (image_divided[0] + image_divided[1] + image_divided[2] + image_divided[3]) / 4.0
        # Pose each sub-image and paste the posed face region back in
        for l in range(4):
            face_output[l] = poser.pose(image_divided[l], pose, 8)[0]
            image_output[l] = image_divided[l].detach().clone()
            image_output[l][:, 32:224, 160:352] = face_output[l]
        # Re-interleave the four posed sub-images into a 1024x1024 image
        torch_image = torch.zeros((4, 1024, 1024), dtype=poser.get_dtype(), device=device)
        torch_image[:, ::2, ::2] = image_output[0]
        torch_image[:, 1::2, ::2] = image_output[1]
        torch_image[:, ::2, 1::2] = image_output[2]
        torch_image[:, 1::2, 1::2] = image_output[3]
        # A posing pass over the downscaled image yields the body warp map
        torch_map = poser.pose(image_resized, pose, 4)[0]
        map_scale = image_size / 512.0
        # Convert relative offsets to absolute coordinates on the 512 grid
        torch_image_map = mapbase + torch_map * 256.0
        # Upscale the warp map to 1024x1024 and convert to integer pixel indices
        torch_image_map_large = F.interpolate(torch.reshape(torch_image_map, (1, 2, 512, 512)), (image_size, image_size), mode='bicubic', align_corners=False)
        torch_image_map_large = (torch_image_map_large * map_scale).int().clamp(0, image_size - 1)
        # Flatten (x, y) coordinates into serial indices and gather the warped pixels
        torch_image_map_large_b = torch_image_map_large[0, 0] + torch_image_map_large[0, 1] * image_size
        torch_image_map_large_serial = torch.reshape(torch_image_map_large_b, (1, image_size * image_size)).tolist()
        torch_source_image_serial = torch.reshape(torch_image, (4, image_size * image_size))
        torch_image_serial = torch_source_image_serial[:, torch_image_map_large_serial[0]]
        torch_image = torch.reshape(torch_image_serial, (4, image_size, image_size))
    return torch_image
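
# Example call (a sketch, assuming `torch_source_image` is a 1024x1024 RGBA
# tensor already on `device` in the poser's dtype):
#   pose = torch.zeros(poser.get_num_parameters(), device=device)
#   posed = tha_inference_1k(poser, torch_source_image, pose)  # -> (4, 1024, 1024)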

def convert_output_image_from_torch_to_numpy(output_image):
    # Convert a PyTorch image tensor to a NumPy uint8 image
    if output_image.shape[2] == 2:
        # HWC grid-change tensor: transpose to CHW before dispatching below
        h, w, c = output_image.shape
        output_image = torch.transpose(output_image.reshape(h * w, c), 0, 1).reshape(c, h, w)
    if output_image.shape[0] == 4:
        numpy_image = rgba_to_numpy_image(output_image)
    elif output_image.shape[0] == 3:
        numpy_image = rgb_to_numpy_image(output_image)
    elif output_image.shape[0] == 1:
        # Single channel: replicate to RGB and add an opaque alpha channel
        c, h, w = output_image.shape
        alpha_image = torch.cat([output_image.repeat(3, 1, 1) * 2.0 - 1.0, torch.ones(1, h, w)], dim=0)
        numpy_image = rgba_to_numpy_image(alpha_image)
    elif output_image.shape[0] == 2:
        numpy_image = grid_change_to_numpy_image(output_image, num_channels=4)
    else:
        raise RuntimeError("Unsupported # image channels: %d" % output_image.shape[0])
    numpy_image = np.uint8(np.rint(numpy_image * 255.0))
    return numpy_image

def get_pose_parameter_index_for_category_and_param(poser: GeneralPoser02, category_name: str, param_name: str) -> int:
    # Look up the flat pose-vector index of a parameter from its category and name
    pose_params = poser.get_pose_parameter_groups()
    index = 0
    for param in pose_params:
        if param.get_category().name == category_name and param_name in param.get_parameter_names():
            return index + param.get_parameter_names().index(param_name)
        index += param.get_arity()
    raise ValueError(f"Parameter {param_name} in category {category_name} not found")

def save_generated_image(output_dir: str, numpy_image: np.ndarray, value: int):
    # Save the generated image under a zero-padded three-digit sequence number
    image_name = f"{value:03d}.png"
    image_path = os.path.join(output_dir, image_name)
    pil_image = Image.fromarray(numpy_image, mode='RGBA')
    pil_image.save(image_path)

def main(input_image_dir, output_dir):
    try:
        # Load the model ('standard_float' is used here as an example)
        poser = load_poser('standard_float', device)
    except RuntimeError as e:
        print(e)
        return
    # Look up the wink parameters once; the indices are the same for every frame
    l_eye_index = get_pose_parameter_index_for_category_and_param(poser, "EYE", "eye_wink_left")
    r_eye_index = get_pose_parameter_index_for_category_and_param(poser, "EYE", "eye_wink_right")
    # Process every image file in the input directory
    for image_filename in os.listdir(input_image_dir):
        if image_filename.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp')):
            image_path = os.path.join(input_image_dir, image_filename)
            print("Processing image:", image_path)
            pil_image = Image.open(image_path).convert('RGBA')  # THA3 expects RGBA input
            torch_source_image = extract_pytorch_image_from_PIL_image(pil_image).to(device).to(poser.get_dtype())
            # Create a subfolder named after the image (without extension) in the output directory
            base_filename, _ = os.path.splitext(image_filename)
            output_subdir = os.path.join(output_dir, base_filename)
            if not os.path.exists(output_subdir):
                os.makedirs(output_subdir)
            # Generate a numbered sequence of images with the pose varied from 0 to 100
            for value in range(0, 101):
                pose = [0.0] * poser.get_num_parameters()
                # Move both eyes together (open at value 0, fully closed at value 100)
                pose[l_eye_index] = value / 100
                pose[r_eye_index] = value / 100
                # Run inference with tha_inference_1k, converting the pose list to a tensor
                with torch.no_grad():
                    output_image = tha_inference_1k(poser, torch_source_image, torch.tensor(pose, device=device)).detach().cpu()
                numpy_image = convert_output_image_from_torch_to_numpy(output_image)
                save_generated_image(output_subdir, numpy_image, value)

if __name__ == "__main__":
    input_image_dir = r"E:\AI\face_anime\png2"
    output_dir = r"E:\AI\face_anime\base_png"
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    print("Input image directory:", input_image_dir)
    print("Output directory:", output_dir)
    main(input_image_dir, output_dir)
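
# Usage note (assumptions, not from the original gist): run this script from
# the root of the talking-head-anime-3-demo repository so the `tha3` package
# and its model files resolve, after editing input_image_dir / output_dir
# above, e.g.:
#   python generate_wink_sequence.py
# where generate_wink_sequence.py is whatever name this file is saved under.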