Last active
December 26, 2024 09:48
-
-
Save tori29umai0123/6977e25d3f067a1d8213e4e51694c8f7 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# 学習には https://gist.github.com/tori29umai0123/fc51ea86aedc1b1e394b12829d1c95e5 のような形式のデータセットを利用する | |
# Manga109-sのアノテーションデータ(http://www.manga109.org/ja/annotations.html)を元に作成 | |
import os | |
import glob | |
import json | |
import argparse | |
import torch | |
from torch.utils.data import Dataset, DataLoader | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import torch.amp as amp | |
from tqdm import tqdm | |
from sentence_transformers import SentenceTransformer | |
# Module-level SentenceTransformer shared by the embedding helpers and
# custom_collate_fn. It stays None until main() assigns it, so those helpers
# can only be used after main() has loaded the model.
model_embedding = None
def scale_coords(coords, image_size):
    """Normalize (x1, y1, x2, y2) boxes by the page size and clamp to [0, 1].

    ``image_size`` carries (width, height) in its last dimension; after a
    trailing axis is added it must broadcast against the x/y columns of
    ``coords``. The input tensor is left untouched and a float copy is
    returned.
    """
    page_w = image_size[..., 0].unsqueeze(-1)
    page_h = image_size[..., 1].unsqueeze(-1)
    normalized = coords.clone().float()
    x_cols = [0, 2]
    y_cols = [1, 3]
    normalized[..., x_cols] = normalized[..., x_cols] / page_w
    normalized[..., y_cols] = normalized[..., y_cols] / page_h
    # hardtanh with these bounds acts as an element-wise clamp to [0, 1].
    return F.hardtanh(normalized, min_val=0.0, max_val=1.0)
def unscale_coords(coords, image_size):
    """Map normalized (x1, y1, x2, y2) boxes back to pixel coordinates.

    x columns (0 and 2) are multiplied by the width and y columns (1 and 3)
    by the height taken from the last dimension of ``image_size``. No
    clamping is applied; a new float tensor is returned.
    """
    page_w = image_size[..., 0].unsqueeze(-1)
    page_h = image_size[..., 1].unsqueeze(-1)
    restored = coords.clone().float()
    for cols, extent in (([0, 2], page_w), ([1, 3], page_h)):
        restored[..., cols] = restored[..., cols] * extent
    return restored
def generate_text_embeddings_with_order(texts, text_order):
    """Embed each non-blank dialogue line and append its order feature.

    Args:
        texts: dialogue strings for one page.
        text_order: order values parallel to ``texts`` (same length).

    Returns:
        Tensor of shape (N, embed_dim + 4), where N is the number of texts
        whose ``strip()`` is non-empty. Each row is the sentence embedding
        followed by that text's order value and three zero placeholder
        columns. When no text is non-blank, an empty (0, embed_dim + 4)
        tensor is returned without calling the encoder.

    Requires the module-level ``model_embedding`` to have been set.
    """
    embed_dim = model_embedding.get_sentence_embedding_dimension()
    # Filter texts and their order values together so rows stay aligned;
    # filtering only the texts made torch.cat fail on pages with blank lines.
    kept = [(text, order) for text, order in zip(texts, text_order) if text.strip()]
    if not kept:
        return torch.zeros(0, embed_dim + 4)
    valid_texts = [text for text, _ in kept]
    valid_orders = [order for _, order in kept]
    text_embeddings = model_embedding.encode(valid_texts, convert_to_tensor=True)  # (N, embed_dim)
    device = text_embeddings.device
    text_orders = torch.tensor(valid_orders, dtype=torch.float32, device=device).unsqueeze(1)  # (N, 1)
    placeholder = torch.zeros((len(valid_texts), 3), dtype=torch.float32, device=device)
    return torch.cat((text_embeddings, text_orders, placeholder), dim=1)  # (N, embed_dim + 4)
def custom_collate_fn(batch):
    """Collate ``MangaDataset`` items into padded batch tensors.

    Each item is the 6-tuple (page_size, num_frames, frame_coords, texts,
    text_order, text_coords) produced by ``MangaDataset.__getitem__``.

    Returns a 9-tuple:
        page_sizes (B, 2); num_frames (B,); padded_frame_coords (B, F, 4), or
        None when no page in the batch has frames; padded_text_embeddings
        (B, T, embed_dim + 4); masks (B, T) bool marking real texts;
        relative_positions (B, T, 1) evenly spaced in [0, 1] per page;
        combined_embeddings (one encoding of each page's space-joined texts);
        text_coords_list (unpadded, entries may be None); frame_orders
        (B, F, 1) normalized by each page's num_frames.
    F is the largest num_frames in the batch; T is the largest number of
    embedded texts on any page.
    """
    page_sizes = torch.stack([sample[0] for sample in batch])
    num_frames = torch.tensor([sample[1].item() for sample in batch])
    frame_coords = [torch.zeros(0, 4) if sample[2] is None else sample[2] for sample in batch]
    texts_list, text_orders_list, text_coords_list = (
        [sample[k] for sample in batch] for k in (3, 4, 5)
    )

    # Per-dialogue embeddings, then one embedding summarizing each page.
    text_embeddings = [
        generate_text_embeddings_with_order(page_texts, page_order)
        for page_texts, page_order in zip(texts_list, text_orders_list)
    ]
    combined_embeddings = model_embedding.encode(
        [" ".join(page_texts) for page_texts in texts_list], convert_to_tensor=True
    )

    batch_size = len(batch)
    max_text_count = max(len(embed) for embed in text_embeddings)
    feature_dim = model_embedding.get_sentence_embedding_dimension() + 4
    padded_text_embeddings = torch.zeros(batch_size, max_text_count, feature_dim)
    masks = torch.zeros(batch_size, max_text_count, dtype=torch.bool)
    relative_positions = torch.zeros(batch_size, max_text_count, 1)
    for row, embed in enumerate(text_embeddings):
        count = embed.size(0)
        padded_text_embeddings[row, :count, :] = embed
        masks[row, :count] = 1
        relative_positions[row, :count, 0] = torch.linspace(0, 1, count)

    max_frames_in_batch = 0
    if num_frames.size(0) > 0:
        max_frames_in_batch = int(max(num_frames))
    padded_frame_coords = None
    if max_frames_in_batch > 0:
        padded_frame_coords = torch.zeros(batch_size, max_frames_in_batch, 4)
    frame_orders = torch.zeros(batch_size, max_frames_in_batch, 1)
    for row, frames in enumerate(frame_coords):
        if frames.size(0) > 0:
            # Training data: frames are already listed in JSON order.
            padded_frame_coords[row, :frames.size(0), :] = frames
            order_count = frames.size(0)
        else:
            # Inference data: only the frame count is known.
            order_count = int(num_frames[row].item())
        frame_orders[row, :order_count, 0] = torch.arange(order_count).float()

    # Rescale frame order indices into [0, 1] using each page's num_frames.
    for row in range(batch_size):
        count = int(num_frames[row].item())
        if count > 1:
            frame_orders[row, :count, 0] /= (count - 1)
        else:
            frame_orders[row, :count, 0] = 0.0

    return (
        page_sizes, num_frames, padded_frame_coords, padded_text_embeddings,
        masks, relative_positions, combined_embeddings, text_coords_list,
        frame_orders,
    )
def generate_text_embeddings(texts):
    """Return sentence embeddings for the non-blank entries of ``texts``.

    Whitespace-only strings are skipped. If nothing remains, an empty
    (0, embed_dim) tensor is returned instead of calling the encoder.
    Requires the module-level ``model_embedding`` to have been set.
    """
    embed_dim = model_embedding.get_sentence_embedding_dimension()
    non_blank = list(filter(lambda text: text.strip(), texts))
    if not non_blank:
        return torch.zeros(0, embed_dim)
    return model_embedding.encode(non_blank, convert_to_tensor=True)
class MangaDataset(Dataset):
    """Page-level dataset loaded from annotation JSON files.

    Each file is read as ``{"pages": [...]}``. Every page must have
    ``width``, ``height``, ``texts`` (entries with ``"text"`` and optionally
    ``"predicted_coords"``) and ``frames`` (with ``"num_frames"`` and
    optionally ``"frame_coords"``).
    """

    def __init__(self, json_files, skip_incomplete=True, for_training=True):
        """Load every page of ``json_files`` into ``self.samples``.

        Args:
            json_files: a path or a list of paths. Missing or undecodable
                files are reported on stdout and skipped.
            skip_incomplete: currently unused; kept for API compatibility.
            for_training: when True, every text must have coordinates and
                ``frame_coords`` must contain exactly ``num_frames`` boxes.

        Raises:
            ValueError: if a page lacks a required key or, when
                ``for_training`` is set, has missing coordinates.
        """
        self.samples = []
        self.for_training = for_training
        if isinstance(json_files, str):
            json_files = [json_files]
        for json_file in json_files:
            if not os.path.isfile(json_file):
                print(f"[Warning] JSON file {json_file} does not exist. Skipping.")
                continue
            try:
                with open(json_file, "r", encoding="utf-8") as f:
                    data = json.load(f)
            except json.JSONDecodeError as e:
                print(f"[Error] Failed to decode JSON file {json_file}: {e}. Skipping.")
                continue
            for idx, page in enumerate(data.get("pages", [])):
                required_keys = ["width", "height", "texts", "frames"]
                if not all(k in page for k in required_keys):
                    raise ValueError(f"Page {idx} is missing required fields: {required_keys}")
                page_size = (page["width"], page["height"])
                # Take texts and coordinates from the same entries so they stay
                # index-aligned; entries without "text" are dropped from both.
                text_entries = [t for t in page.get("texts", []) if "text" in t]
                texts = [t["text"] for t in text_entries]
                text_predicted_coords = [t.get("predicted_coords", None) for t in text_entries]
                frames = page.get("frames", {})
                num_frames = frames.get("num_frames", 0)
                frame_coords = frames.get("frame_coords", None)
                # Training pages must be fully annotated.
                if self.for_training:
                    if not text_predicted_coords or any(coord is None for coord in text_predicted_coords):
                        raise ValueError(f"Page {idx} has missing text coordinates.")
                    if frame_coords is None or len(frame_coords) != num_frames:
                        raise ValueError(f"Page {idx} has missing or incomplete frame coordinates.")
                self.samples.append({
                    "page_size": page_size,
                    "texts": texts,
                    "text_predicted_coords": text_predicted_coords,
                    "num_frames": num_frames,
                    "frame_coords": frame_coords if frame_coords else [],
                })

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return (page_size, num_frames, frame_coords, texts, text_order, text_coords).

        ``page_size`` is a float tensor (width, height) and ``num_frames`` a
        one-element float tensor. ``frame_coords`` / ``text_coords`` are
        (N, 4) tensors normalized by the page size, or None when the
        coordinate list is empty or contains a missing entry.
        ``text_order`` is ``list(range(len(texts)))``.
        """
        sample = self.samples[idx]
        page_size = torch.tensor(sample["page_size"], dtype=torch.float32)
        num_frames = torch.tensor([sample["num_frames"]], dtype=torch.float32)
        frame_coords = self._scaled_or_none(sample.get("frame_coords", []), page_size)
        texts = sample["texts"]
        text_order = list(range(len(texts)))
        text_coords = self._scaled_or_none(sample.get("text_predicted_coords", []), page_size)
        return page_size, num_frames, frame_coords, texts, text_order, text_coords

    @staticmethod
    def _scaled_or_none(raw_coords, page_size):
        """Normalize a list of boxes, or return None if any box is missing."""
        # Validation pages may be partially annotated; torch.tensor would fail
        # on a list mixing boxes and None, so treat such lists as missing.
        if not raw_coords or any(coord is None for coord in raw_coords):
            return None
        coords = torch.tensor(raw_coords, dtype=torch.float32)
        return scale_coords(coords, page_size.unsqueeze(0))
class MangaModel(nn.Module):
    """Two MLP heads that predict normalized boxes for dialogue and panels.

    ``text_coords_model`` maps each per-text feature vector to (x1, y1, x2, y2).
    ``frame_coords_model`` maps the page-level combined embedding plus a
    frame-order value to a panel box. Both outputs go through a sigmoid, so
    coordinates lie in (0, 1).
    """

    def __init__(self, text_embed_dim):
        """Build both heads for sentence embeddings of size ``text_embed_dim``."""
        super().__init__()
        self.text_embed_dim = text_embed_dim
        # Input: text embedding with 4 order/placeholder columns (dim + 4),
        # page size (2) and relative position (1) -> text_embed_dim + 7.
        self.text_coords_model = nn.Sequential(
            nn.Linear(text_embed_dim + 7, 128),
            nn.LayerNorm(128),
            nn.GELU(),
            nn.Linear(128, 64),
            nn.LayerNorm(64),
            nn.GELU(),
            nn.Linear(64, 4),  # (x1, y1, x2, y2)
        )
        # Input: combined page embedding (text_embed_dim) + frame order (1).
        self.frame_coords_model = nn.Sequential(
            nn.Linear(text_embed_dim + 1, 128),
            nn.LayerNorm(128),
            nn.GELU(),
            nn.Linear(128, 64),
            nn.LayerNorm(64),
            nn.GELU(),
            nn.Linear(64, 4),  # (x1, y1, x2, y2)
        )

    def forward(self, text_embeddings, relative_positions, combined_embeddings, image_size, num_frames, frame_orders):
        """Predict text and frame boxes.

        Args:
            text_embeddings: (B, T, text_embed_dim + 4).
            relative_positions: (B, T, 1).
            combined_embeddings: (B, text_embed_dim).
            image_size: (B, 2) page (width, height).
            num_frames: (B,) frame counts; the largest sets F.
            frame_orders: (B, F, 1).

        Returns:
            (pred_text_coords (B, T, 4), pred_frame_coords (B, F, 4)). F is 0
            when no page has frames.
        """
        batch_size, text_count, _ = text_embeddings.size()
        max_num_frames = int(num_frames.max().item()) if num_frames.numel() > 0 else 0

        image_size_expanded = image_size.unsqueeze(1).expand(-1, text_count, -1)
        text_features = torch.cat((text_embeddings, image_size_expanded, relative_positions), dim=2)
        # Linear/LayerNorm act on the last dimension, so the heads can take the
        # (B, T, F) tensor directly; flattening with view(-1, F) raised when T == 0.
        pred_text_coords = torch.sigmoid(self.text_coords_model(text_features))

        if max_num_frames > 0:
            combined_expanded = combined_embeddings.unsqueeze(1).expand(-1, max_num_frames, -1)
            frame_features = torch.cat((combined_expanded, frame_orders), dim=2)
            pred_frame_coords = torch.sigmoid(self.frame_coords_model(frame_features))
        else:
            pred_frame_coords = torch.zeros(batch_size, 0, 4, device=image_size.device)
        return pred_text_coords, pred_frame_coords
def _masked_text_loss(pred_text_coords, masks, text_coords_list, criterion):
    """Criterion over real (non-padded) text boxes; None if the batch has none."""
    batch_size, text_count = pred_text_coords.shape[:2]
    device = pred_text_coords.device
    target = torch.zeros(batch_size, text_count, 4, device=device)
    has_target = torch.zeros(batch_size, text_count, dtype=torch.bool, device=device)
    for i, tc in enumerate(text_coords_list):
        if tc is None or tc.numel() == 0:
            continue
        # Never write past the prediction tensor, whose width is set by the
        # number of embedded texts rather than by the target list.
        n = min(tc.size(0), text_count)
        target[i, :n] = tc[:n].to(device)
        has_target[i, :n] = True
    valid = has_target & masks.to(device).bool()
    if not bool(valid.any()):
        return None
    return criterion(pred_text_coords[valid], target[valid])


def _masked_frame_loss(pred_frame_coords, frame_coords, num_frames, criterion):
    """Criterion over each page's first ``num_frames`` boxes; None if there are none."""
    if frame_coords is None:
        return None
    count = min(pred_frame_coords.size(1), frame_coords.size(1))
    if count == 0:
        return None
    device = pred_frame_coords.device
    positions = torch.arange(count, device=device).unsqueeze(0)
    # Exclude zero-padded frames of pages shorter than the batch maximum.
    valid = positions < num_frames.to(device).reshape(-1, 1)
    if not bool(valid.any()):
        return None
    target = frame_coords[:, :count].to(device)
    return criterion(pred_frame_coords[:, :count][valid], target[valid])


def train_model(train_loader, model, optimizer, criterion, device, use_fp16, alpha=1.0, beta=0.5, scaler=None):
    """Run one training epoch and return the mean loss per optimized batch.

    The loss is ``alpha * text_loss + beta * frame_loss``. Each term applies
    ``criterion`` only to real entries: text boxes selected by ``masks`` and
    the available targets, and frame boxes selected by each page's
    ``num_frames``. A term with no valid entries is omitted, and a batch
    where both are omitted is skipped.

    Args:
        train_loader: DataLoader yielding ``custom_collate_fn`` tuples.
        model: module called as in ``MangaModel.forward``.
        optimizer: optimizer over the model's parameters.
        criterion: regression loss, e.g. ``nn.L1Loss()``.
        device: ``torch.device`` to train on.
        use_fp16: enable autocast and gradient scaling.
        alpha: weight of the text-coordinate loss.
        beta: weight of the frame-coordinate loss.
        scaler: optional ``amp.GradScaler`` reused across epochs; a fresh
            one is created when omitted.

    Returns:
        Average loss over the optimized batches (0.0 if there were none).
    """
    model.train()
    if scaler is None:
        scaler = amp.GradScaler(enabled=use_fp16)
    epoch_loss = 0.0
    num_batches = 0
    for (
        page_size, num_frames, frame_coords, text_embeddings, masks,
        relative_positions, combined_embeddings, text_coords_list,
        frame_orders,
    ) in tqdm(train_loader, desc="Training", disable=True):
        page_size = page_size.to(device)
        num_frames = num_frames.to(device)
        text_embeddings = text_embeddings.to(device)
        masks = masks.to(device)
        relative_positions = relative_positions.to(device)
        combined_embeddings = combined_embeddings.to(device)
        frame_orders = frame_orders.to(device)
        # custom_collate_fn returns None when no page in the batch has frames.
        if frame_coords is not None:
            frame_coords = frame_coords.to(device)

        optimizer.zero_grad()
        with amp.autocast(device_type=device.type, enabled=use_fp16):
            pred_text_coords, pred_frame_coords = model(
                text_embeddings, relative_positions, combined_embeddings,
                page_size, num_frames, frame_orders=frame_orders,
            )
            text_loss = _masked_text_loss(pred_text_coords, masks, text_coords_list, criterion)
            frame_loss = _masked_frame_loss(pred_frame_coords, frame_coords, num_frames, criterion)

        terms = []
        if text_loss is not None:
            terms.append(alpha * text_loss)
        if frame_loss is not None:
            terms.append(beta * frame_loss)
        if not terms:
            # Nothing to supervise; a plain float loss cannot be backpropagated.
            continue
        loss = sum(terms)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        epoch_loss += loss.item()
        num_batches += 1
    return epoch_loss / num_batches if num_batches else 0.0
def validate_model(val_loader, model, device, use_fp16):
    """Predict pixel-space text and frame boxes for every validation page.

    Returns:
        One dict per page with:
        - ``page_index``: 1-based, computed as
          ``batch_idx * val_loader.batch_size + i + 1``.
        - ``predicted_text_coords``: boxes for the page's embedded texts, or
          a list of None when the page has none.
        - ``predicted_frame_coords``: boxes for the page's first
          ``num_frames`` frames, or None if the model returned no frame
          tensor.
    """
    model.eval()
    predictions = []
    with torch.no_grad():
        for batch_idx, (
            page_size, num_frames, _frame_coords, text_embeddings, masks,
            relative_positions, combined_embeddings, _text_coords_list,
            frame_orders,
        ) in enumerate(tqdm(val_loader, desc="Validation")):
            page_size = page_size.to(device)
            num_frames = num_frames.to(device)
            text_embeddings = text_embeddings.to(device)
            combined_embeddings = combined_embeddings.to(device)
            relative_positions = relative_positions.to(device)
            frame_orders = frame_orders.to(device)
            with amp.autocast(device_type=device.type, enabled=use_fp16):
                pred_text_coords, pred_frame_coords = model(
                    text_embeddings, relative_positions, combined_embeddings,
                    page_size, num_frames, frame_orders=frame_orders,
                )

            # Undo the [0, 1] normalization using each page's size.
            text_sizes = page_size.unsqueeze(1).expand(-1, pred_text_coords.size(1), -1)
            unscaled_text_coords = unscale_coords(pred_text_coords, text_sizes)
            unscaled_frame_coords = None
            if pred_frame_coords is not None:
                frame_sizes = page_size.unsqueeze(1).expand(-1, pred_frame_coords.size(1), -1)
                unscaled_frame_coords = unscale_coords(pred_frame_coords, frame_sizes)

            for i in range(page_size.size(0)):
                text_count = int(masks[i].sum().item())
                if text_count > 0:
                    valid_text_coords = unscaled_text_coords[i, :text_count].tolist()
                else:
                    valid_text_coords = [None] * text_embeddings.size(1)
                if unscaled_frame_coords is None:
                    frame_coords_result = None
                else:
                    # Drop padding rows that exist only because another page in
                    # the batch has more frames.
                    frame_count = int(num_frames[i].item())
                    frame_coords_result = unscaled_frame_coords[i, :frame_count].tolist()
                predictions.append({
                    "page_index": batch_idx * val_loader.batch_size + i + 1,
                    "predicted_text_coords": valid_text_coords,
                    "predicted_frame_coords": frame_coords_result,
                })
    return predictions
def save_predictions_to_json(predictions, output_dir, epoch):
    """Write ``{"results": predictions}`` to ``validation_results_epoch_<epoch>.json``.

    ``output_dir`` is created if necessary. The file is UTF-8, indented by
    four spaces, with non-ASCII characters written as-is. Write failures are
    reported on stdout rather than raised.
    """
    os.makedirs(output_dir, exist_ok=True)
    file_name = f"validation_results_epoch_{epoch}.json"
    output_path = os.path.join(output_dir, file_name)
    payload = {"results": predictions}
    try:
        with open(output_path, "w", encoding="utf-8") as handle:
            json.dump(payload, handle, ensure_ascii=False, indent=4)
    except Exception as e:
        print(f"[Error] Failed to save validation results: {e}")
    else:
        print(f"Validation results saved to {output_path}")
def validate_training_data(train_dataset):
    """Check every training sample and raise on the first malformed one.

    Checks that ``page_size`` is a pair of numbers, that ``texts`` is a
    non-empty list of non-blank strings, and that every frame box has four
    values. When the dataset is for training, it also checks that a positive
    ``num_frames`` equals the number of frame boxes.

    Args:
        train_dataset: object with a ``samples`` list of dicts (as built by
            ``MangaDataset``) and optionally a ``for_training`` flag, which
            is treated as True when absent.

    Raises:
        ValueError: describing the first invalid sample.
    """
    # This is a free function: read the flag from the dataset (there is no self).
    for_training = getattr(train_dataset, "for_training", True)
    for idx, sample in enumerate(train_dataset.samples):
        page_size = sample.get("page_size")
        if not page_size or len(page_size) != 2 or not all(isinstance(v, (int, float)) for v in page_size):
            raise ValueError(f"Invalid page_size in sample {idx}: {page_size}")

        texts = sample.get("texts", [])
        if not texts or not all(isinstance(t, str) and t.strip() for t in texts):
            raise ValueError(f"Invalid texts in sample {idx}: {texts}")

        # Treat a stored None like an empty list instead of crashing on len().
        frame_coords = sample.get("frame_coords") or []
        if not all(isinstance(coord, (list, tuple)) and len(coord) == 4 for coord in frame_coords):
            raise ValueError(f"Invalid frame_coords in sample {idx}: {frame_coords}")

        num_frames = sample.get("num_frames", 0)
        if for_training and num_frames > 0 and num_frames != len(frame_coords):
            raise ValueError(
                f"Inconsistent number of frames in sample {idx}: num_frames={num_frames}, frame_coords={len(frame_coords)}"
            )
    print("[INFO] Training data validation passed.")
def main(train_folder, val_json, model_output, epochs, learning_rate, sentence_transformer_model, batch_size, use_fp16, alpha, beta, save_epochs):
    """Train ``MangaModel`` and periodically validate and checkpoint it.

    The model is trained every epoch. Every ``save_epochs`` epochs, and on
    the last epoch, the epoch and training loss are printed, validation
    predictions are written to ``model_output`` and ``model_epoch_<n>.pth``
    is saved. ``model_final.pth`` is written after the loop.
    """
    global model_embedding
    model_embedding = SentenceTransformer(sentence_transformer_model)
    # torch.save does not create directories; make sure the output folder
    # exists even if no validation epoch runs (e.g. epochs == 0).
    os.makedirs(model_output, exist_ok=True)

    train_files = glob.glob(os.path.join(train_folder, "*.json"))
    train_dataset = MangaDataset(train_files, skip_incomplete=True, for_training=True)
    # Fail fast on malformed training data.
    validate_training_data(train_dataset)
    val_dataset = MangaDataset([val_json], skip_incomplete=False, for_training=False)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=custom_collate_fn)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, collate_fn=custom_collate_fn)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = MangaModel(text_embed_dim=model_embedding.get_sentence_embedding_dimension()).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    criterion = nn.L1Loss()

    for epoch in range(epochs):
        # Train every epoch; save_epochs only controls how often results are
        # reported, validated and checkpointed.
        train_loss = train_model(train_loader, model, optimizer, criterion, device, use_fp16, alpha, beta)
        if (epoch + 1) % save_epochs == 0 or epoch == epochs - 1:
            print(f"Epoch {epoch + 1}/{epochs}")
            print(f"Train Loss: {train_loss:.6f}")
            predictions = validate_model(val_loader, model, device, use_fp16)
            save_predictions_to_json(predictions, model_output, epoch + 1)
            model_save_path = os.path.join(model_output, f"model_epoch_{epoch + 1}.pth")
            torch.save(model.state_dict(), model_save_path)
            print(f"Model saved at {model_save_path}")

    final_model_path = os.path.join(model_output, "model_final.pth")
    torch.save(model.state_dict(), final_model_path)
    print(f"Final model saved at {final_model_path}")
if __name__ == "__main__":
    # Command-line entry point: collect training options, then hand them to main().
    parser = argparse.ArgumentParser()
    # Data and output locations.
    parser.add_argument("-j", "--json_folder", required=True, help="Path to training JSON folder")
    parser.add_argument("-v", "--val_json", required=True, help="Path to validation JSON file")
    parser.add_argument("-m", "--model_output", required=True, help="Path to save trained model")
    # Schedule.
    parser.add_argument("-e", "--epochs", type=int, default=100000000000000, help="Number of epochs")
    parser.add_argument("-se", "--save_epochs", type=int, default=100000000, help="Save model and run validation every N epochs")
    # Optimization and model settings.
    parser.add_argument("-lr", "--learning_rate", type=float, default=1e-3, help="Learning rate")
    parser.add_argument("-s", "--sentence_transformer_model", type=str, default="sonoisa/sentence-bert-base-ja-mean-tokens-v2", help="SentenceTransformer model name")
    parser.add_argument("-b", "--batch_size", type=int, default=8, help="Batch size")
    parser.add_argument("--fp16", action="store_true", help="Use mixed precision training")
    # Loss weighting between the two prediction heads.
    parser.add_argument("--alpha", type=float, default=2.0, help="Weight for text loss")
    parser.add_argument("--beta", type=float, default=1.0, help="Weight for frame loss")
    cli = parser.parse_args()
    main(
        train_folder=cli.json_folder,
        val_json=cli.val_json,
        model_output=cli.model_output,
        epochs=cli.epochs,
        learning_rate=cli.learning_rate,
        sentence_transformer_model=cli.sentence_transformer_model,
        batch_size=cli.batch_size,
        use_fp16=cli.fp16,
        alpha=cli.alpha,
        beta=cli.beta,
        save_epochs=cli.save_epochs,
    )
何が上手くいくかわからないので学習を試してみてからでいいと思いますが、 1つのlossにtext_coordsとframe_coordsのlossを加算していくと、たとえばフレームの方ががたくさんある時に、frame_coords_modelばかりが最適化されて行ってしまうと思うので、テキスト数やフレーム数でそれぞれのlossを割るか、係数を付けるなどしてバランスを調整した方がいい気がしました
コメントありがとうございます!
アドバイスに従って係数追加してみました!とりあえずこれで学習してみます~
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
何が上手くいくかわからないので学習を試してみてからでいいと思いますが、
1つのlossにtext_coordsとframe_coordsのlossを加算していくと、たとえばフレームの方ががたくさんある時に、frame_coords_modelばかりが最適化されて行ってしまうと思うので、テキスト数やフレーム数でそれぞれのlossを割るか、係数を付けるなどしてバランスを調整した方がいい気がしました