@tori29umai0123
Last active March 16, 2025 07:33
import os
import sys

import numpy as np
import torch
import torch.nn.functional as F  # needed for F.interpolate in tha_inference_1k
from PIL import Image

sys.path.append(os.getcwd())  # Add the current directory to the system path
from tha3.poser.modes.load_poser import load_poser
from tha3.util import rgba_to_numpy_image, grid_change_to_numpy_image, rgb_to_numpy_image, extract_pytorch_image_from_PIL_image
from tha3.poser.general_poser_02 import GeneralPoser02

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # Use CUDA if available, otherwise CPU
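
# Note on tha_inference_1k (a reading of the code below; the meaning of the
# poser.pose output indices is an assumption, not documented here): the
# 1024x1024 RGBA source is split into four 512x512 sub-images by 2x2 strided
# sampling, each sub-image is posed by the 512-resolution THA3 poser (output
# index 8 is assumed to be the posed face crop), the four results are
# re-interleaved back into a 1024x1024 image, and that image is finally warped
# by a grid-change map (output index 4) estimated on the 512x512 average of
# the sub-images.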
def tha_inference_1k(poser, image, pose):
    # Identity sampling grid: mapbase[0, x, y] = y and mapbase[1, x, y] = x
    mapbase = torch.zeros((2, 512, 512), dtype=poser.get_dtype(), device=device)
    for y in range(512):
        for x in range(512):
            mapbase[1, x, y] = x
            mapbase[0, x, y] = y
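    # The nested loops above can be replaced by a vectorized equivalent; a
    # minimal sketch (assumes the same layout: channel 0 holds the second
    # index, channel 1 the first):
    #
    #     ii, jj = torch.meshgrid(
    #         torch.arange(512, device=device, dtype=poser.get_dtype()),
    #         torch.arange(512, device=device, dtype=poser.get_dtype()),
    #         indexing="ij",
    #     )
    #     mapbase = torch.stack([jj, ii], dim=0)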
    image_size = 1024
    with torch.inference_mode():
        image_divided = [None] * 4
        image_resized = None
        face_output = [None] * 4
        image_output = [None] * 4
        torch_map = None
        # Split the 1024x1024 image into four 512x512 sub-images by 2x2 strided sampling
        image_divided[0] = image[:, ::2, ::2]
        image_divided[1] = image[:, 1::2, ::2]
        image_divided[2] = image[:, ::2, 1::2]
        image_divided[3] = image[:, 1::2, 1::2]
        # 512x512 average of the four sub-images, used below to estimate the motion grid
        image_resized = (image_divided[0] + image_divided[1] + image_divided[2] + image_divided[3]) / 4.0
        for l in range(4):
            # Pose each sub-image (output index 8 is assumed to be the posed face crop)
            # and paste the result back into the face region
            face_output[l] = poser.pose(image_divided[l], pose, 8)[0]
            image_output[l] = image_divided[l].detach().clone()
            image_output[l][:, 32:224, 160:352] = face_output[l]
        # Re-interleave the four posed sub-images into a 1024x1024 image
        torch_image = torch.zeros((4, 1024, 1024), dtype=poser.get_dtype(), device=device)
        torch_image[:, ::2, ::2] = image_output[0]
        torch_image[:, 1::2, ::2] = image_output[1]
        torch_image[:, ::2, 1::2] = image_output[2]
        torch_image[:, 1::2, 1::2] = image_output[3]
        # Estimate the grid-change (motion) map at 512x512 (output index 4 is
        # assumed to be the grid-change output) and build an absolute sampling grid
        torch_map = poser.pose(image_resized, pose, 4)[0]
        map_scale = image_size / 512.0
        torch_image_map = torch_map * 256.0
        torch_image_map = mapbase + torch_image_map
        # Upsample the sampling grid to 1024x1024 and convert to integer pixel indices
        torch_image_map_large = F.interpolate(torch.reshape(torch_image_map, (1, 2, 512, 512)), (image_size, image_size), mode='bicubic', align_corners=False)
        torch_image_map_large = (torch_image_map_large * map_scale).int().clamp(0, image_size - 1)
        # Flatten the (x, y) indices into serial indices and gather the warped pixels
        torch_image_map_large_b = torch_image_map_large[0, 0] + torch_image_map_large[0, 1] * image_size
        torch_image_map_large_serial = torch.reshape(torch_image_map_large_b, (1, image_size * image_size)).tolist()
        torch_source_image_serial = torch.reshape(torch_image, (4, image_size * image_size))
        torch_image_serial_l = [torch_source_image_serial[:, x] for x in torch_image_map_large_serial]
        torch_image_serial = torch_image_serial_l[0]
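        # The Python-list gather above round-trips through host memory; an
        # on-device equivalent sketch (assumes the same flattened layouts):
        #
        #     idx = torch_image_map_large_b.reshape(-1).long()
        #     torch_image_serial = torch.index_select(torch_source_image_serial, 1, idx)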
        torch_image = torch.reshape(torch_image_serial, (1, 4, image_size, image_size))[0]
    return torch_image
def convert_output_image_from_torch_to_numpy(output_image):
    # Convert a PyTorch image tensor to a NumPy image
    if output_image.shape[2] == 2:
        h, w, c = output_image.shape
        numpy_image = torch.transpose(output_image.reshape(h * w, c), 0, 1).reshape(c, h, w)
    elif output_image.shape[0] == 4:
        numpy_image = rgba_to_numpy_image(output_image)
    elif output_image.shape[0] == 3:
        numpy_image = rgb_to_numpy_image(output_image)
    elif output_image.shape[0] == 1:
        c, h, w = output_image.shape
        alpha_image = torch.cat([output_image.repeat(3, 1, 1) * 2.0 - 1.0, torch.ones(1, h, w)], dim=0)
        numpy_image = rgba_to_numpy_image(alpha_image)
    elif output_image.shape[0] == 2:
        numpy_image = grid_change_to_numpy_image(output_image, num_channels=4)
    else:
        raise RuntimeError("Unsupported # image channels: %d" % output_image.shape[0])
    numpy_image = np.uint8(np.rint(numpy_image * 255.0))
    return numpy_image
def get_pose_parameter_index_for_category_and_param(poser: GeneralPoser02, category_name: str, param_name: str) -> int:
    # Look up the flat pose-parameter index for a given category and parameter name
    pose_params = poser.get_pose_parameter_groups()
    index = 0
    for param in pose_params:
        if param.get_category().name == category_name and param_name in param.get_parameter_names():
            return index + param.get_parameter_names().index(param_name)
        index += param.get_arity()
    raise ValueError(f"Parameter {param_name} in Category {category_name} not found")
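
# Usage sketch (the "EYE" category and "eye_wink_*" parameter names follow the
# calls made in main() below; 1.0 is assumed to mean a fully winked eye):
#
#     idx = get_pose_parameter_index_for_category_and_param(poser, "EYE", "eye_wink_left")
#     pose[idx] = 1.0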
def save_generated_image(output_dir: str, numpy_image: np.ndarray, value: int):
    # Save a generated image to the given directory (value must be an int for the :03d format)
    image_name = f"{value:03d}.png"  # three-digit sequential file name
    image_path = os.path.join(output_dir, image_name)
    pil_image = Image.fromarray(numpy_image, mode='RGBA')
    pil_image.save(image_path)
def main(input_image_dir, output_dir):
    try:
        # Load the model ('standard_float' is used as an example here)
        poser = load_poser('standard_float', device)
    except RuntimeError as e:
        print(e)
        return

    # Process the image files in the input directory
    for image_filename in os.listdir(input_image_dir):
        if image_filename.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp')):
            image_path = os.path.join(input_image_dir, image_filename)
            print("Processing image:", image_path)
            pil_image = Image.open(image_path)
            torch_source_image = extract_pytorch_image_from_PIL_image(pil_image).to(device).to(poser.get_dtype())

            # Create a subfolder named after the image (without extension) under the output directory
            base_filename, _ = os.path.splitext(image_filename)
            output_subdir = os.path.join(output_dir, base_filename)
            if not os.path.exists(output_subdir):
                os.makedirs(output_subdir)

            # The eye-wink parameter indices do not change per frame, so look them up once
            l_eye_index = get_pose_parameter_index_for_category_and_param(poser, "EYE", "eye_wink_left")
            r_eye_index = get_pose_parameter_index_for_category_and_param(poser, "EYE", "eye_wink_right")

            # Generate images with the pose varied over a sequence from 0 to 100
            for value in range(0, 101):
                pose = [0.0] * poser.get_num_parameters()
                # Change both eye parameters together (0 at value 0, 1 at value 100)
                pose[l_eye_index] = value / 100
                pose[r_eye_index] = value / 100
                # Run inference with tha_inference_1k, which takes the image and a pose tensor;
                # it already returns an unbatched (4, 1024, 1024) tensor, so no extra indexing
                with torch.no_grad():
                    output_image = tha_inference_1k(poser, torch_source_image, torch.tensor(pose, device=device)).detach().cpu()
                numpy_image = convert_output_image_from_torch_to_numpy(output_image)
                save_generated_image(output_subdir, numpy_image, value)
if __name__ == "__main__":
    input_image_dir = r"E:\AI\face_anime\png2"
    output_dir = r"E:\AI\face_anime\base_png"
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    print("Input image directory:", input_image_dir)
    print("Output directory:", output_dir)
    main(input_image_dir, output_dir)
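
# The directories above are hardcoded; a minimal command-line sketch (argument
# names are illustrative, not part of the original script):
#
#     import argparse
#     parser = argparse.ArgumentParser()
#     parser.add_argument("input_image_dir")
#     parser.add_argument("output_dir")
#     args = parser.parse_args()
#     os.makedirs(args.output_dir, exist_ok=True)
#     main(args.input_image_dir, args.output_dir)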