Last active
June 11, 2024 13:02
-
-
Save tori29umai0123/615eb806832fba83025912cfc82008bb to your computer and use it in GitHub Desktop.
nsfw_filter.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import argparse | |
import csv | |
import glob | |
import os | |
from pathlib import Path | |
import cv2 | |
import numpy as np | |
import torch | |
from PIL import Image | |
from tqdm import tqdm | |
import onnx | |
import onnxruntime as ort | |
from huggingface_hub import hf_hub_download | |
import shutil | |
# Side length (pixels) of the square image fed to the tagger model.
IMAGE_SIZE = 448


def preprocess_image(image, image_size=IMAGE_SIZE):
    """Turn a PIL image into the float32 array fed to the WD14 tagger.

    Steps:
    - reverse the channel order (RGB -> BGR);
    - pad to a centred square with white (255);
    - resize to ``image_size`` x ``image_size``;
    - cast to float32.

    Pixel values are left in the 0-255 range (no normalisation).

    Args:
        image: A PIL image or array-like of shape (H, W), (H, W, 3) or
            (H, W, 4). ``main`` converts images to RGB before calling. A
            2-D (grayscale) array is replicated to three channels. A fourth
            (alpha) channel is dropped, matching PIL's RGBA -> RGB conversion.
        image_size: Output side length in pixels; defaults to ``IMAGE_SIZE``.

    Returns:
        np.ndarray of shape (image_size, image_size, 3), dtype float32, BGR order.
    """
    image = np.array(image)
    if image.ndim == 2:
        # Grayscale: replicate the single channel so the BGR flip below works.
        image = np.stack((image,) * 3, axis=-1)
    elif image.shape[-1] == 4:
        # Drop alpha, as PIL's convert("RGB") would.
        image = image[:, :, :3]

    # PIL gives RGB; reversing the channel axis yields BGR. The WD14 tagger
    # reference scripts feed BGR (NOTE(review): confirm for other models).
    image = image[:, :, ::-1]

    # Pad to a square, splitting the padding evenly so the content stays centred.
    height, width = image.shape[:2]
    side = max(height, width)
    pad_x = side - width
    pad_y = side - height
    pad_left = pad_x // 2
    pad_top = pad_y // 2
    image = np.pad(
        image,
        ((pad_top, pad_y - pad_top), (pad_left, pad_x - pad_left), (0, 0)),
        mode="constant",
        constant_values=255,
    )

    # INTER_AREA when shrinking (less aliasing), Lanczos when enlarging.
    interp = cv2.INTER_AREA if side > image_size else cv2.INTER_LANCZOS4
    image = cv2.resize(image, (image_size, image_size), interpolation=interp)
    return image.astype(np.float32)
def run_batch(path_imgs, input_name, ort_sess, rating_tags, thresh, nsfw_dir, sfw_dir):
    """Classify one batch of preprocessed images and copy each to an SFW or NSFW folder.

    An image is copied to ``nsfw_dir`` when the larger of its "questionable"
    and "explicit" scores exceeds its "general" score; otherwise it is copied
    to ``sfw_dir``. Rating tags other than those three (e.g. "sensitive")
    are not consulted. Copy failures are printed and skipped, never raised.

    Args:
        path_imgs: List of ``(image_path, preprocessed_array)`` pairs.
        input_name: Name of the model's input tensor.
        ort_sess: Session exposing ``run(output_names, feeds)`` (an
            ``onnxruntime.InferenceSession`` in this script).
        rating_tags: Rating tag names. The score for ``rating_tags[i]`` is
            read from column ``i`` of the first model output.
        thresh: Currently unused; kept so existing call sites keep working.
        nsfw_dir: Destination directory for NSFW images (must already exist).
        sfw_dir: Destination directory for SFW images (must already exist).
    """
    if not path_imgs:
        # Nothing to classify; don't feed an empty (0,)-shaped array to the model.
        return

    imgs = np.array([img for _, img in path_imgs])
    probs = ort_sess.run(None, {input_name: imgs})[0]  # first ONNX output
    probs = probs[: len(path_imgs)]

    for (image_path, _), prob in zip(path_imgs, probs):
        # NOTE(review): assumes the rating tags are the first len(rating_tags)
        # output columns, in CSV order — true if they lead selected_tags.csv.
        tag_confidences = {tag: prob[i] for i, tag in enumerate(rating_tags)}
        max_nsfw_score = max(tag_confidences.get("questionable", 0), tag_confidences.get("explicit", 0))
        max_sfw_score = tag_confidences.get("general", 0)
        destination = nsfw_dir if max_nsfw_score > max_sfw_score else sfw_dir

        # Copy (not move) the image into the chosen folder; best-effort per image.
        try:
            shutil.copy(image_path, os.path.join(destination, os.path.basename(image_path)))
        except OSError as e:
            print(f"{image_path} を {destination} にコピーできませんでした。エラー: {e}")
        else:
            print(f"{image_path} を {destination} にコピーしました。")
def _load_rating_tags(csv_path):
    """Return the names of category-"9" (rating) tags from ``selected_tags.csv``.

    Raises:
        ValueError: if the CSV header is not ``tag_id,name,category,count``.
    """
    with open(csv_path, "r", encoding="utf-8", newline="") as f:
        reader = csv.reader(f)
        header = next(reader, None)  # header row (None for an empty file)
        rows = list(reader)
    if header != ["tag_id", "name", "category", "count"]:
        raise ValueError(f"予期しないCSVフォーマット: {header}")
    # Hook for a copyright/character filter (category "4" rows), disabled.
    # `rows` already excludes the header, so no extra slicing is needed:
    # character_tags = [row[1] for row in rows if row[2] == "4"]
    return [row[1] for row in rows if row[2] == "9"]


def main():
    """Copy every image in ``input_dir`` into ``sfw_dir`` or ``nsfw_dir``.

    Reads these module-level settings, assigned in the ``__main__`` block:
    ``MODEL_ID``, ``input_dir``, ``sfw_dir``, ``nsfw_dir``, ``batch_size``
    and ``thresh``.

    Steps:
    - download ``model.onnx`` and ``selected_tags.csv`` for ``MODEL_ID``
      from the Hugging Face Hub;
    - preprocess each file matching ``input_dir/*.*``;
    - classify the images in batches of ``batch_size`` via ``run_batch``.

    Unreadable files are reported and skipped.
    """
    print("Hugging Faceからwd14 taggerをロード中")
    onnx_path = hf_hub_download(MODEL_ID, "model.onnx")
    csv_path = hf_hub_download(MODEL_ID, "selected_tags.csv")

    print("wd14 taggerでonnxを実行")
    print(f"onnxモデルをロード中: {onnx_path}")
    # GPU builds of onnxruntime require an explicit providers list.
    # Prefer CUDA when available and fall back to CPU.
    providers = ["CPUExecutionProvider"]
    if "CUDAExecutionProvider" in ort.get_available_providers():
        providers.insert(0, "CUDAExecutionProvider")
    ort_sess = ort.InferenceSession(onnx_path, providers=providers)
    input_name = ort_sess.get_inputs()[0].name  # loop-invariant; look it up once

    rating_tags = _load_rating_tags(csv_path)

    # run_batch copies into these folders; make sure they exist up front.
    os.makedirs(sfw_dir, exist_ok=True)
    os.makedirs(nsfw_dir, exist_ok=True)

    image_paths = glob.glob(os.path.join(input_dir, "*.*"))
    b_imgs = []
    for image_path in tqdm(image_paths, smoothing=0.0):
        try:
            # Context manager closes the file handle promptly.
            with Image.open(image_path) as img:
                rgb = img if img.mode == "RGB" else img.convert("RGB")
                b_imgs.append((image_path, preprocess_image(rgb)))
        except Exception as e:
            # Best-effort: non-images / corrupt files are reported and skipped.
            print(f"画像を読み込めません: {image_path}, エラー: {e}")
            continue

        if len(b_imgs) >= batch_size:
            run_batch(b_imgs, input_name, ort_sess, rating_tags, thresh, nsfw_dir, sfw_dir)
            b_imgs = []

    # Flush the final, possibly partial, batch.
    if b_imgs:
        run_batch(b_imgs, input_name, ort_sess, rating_tags, thresh, nsfw_dir, sfw_dir)

    print("処理完了!")
if __name__ == "__main__":
    # Command-line options. The defaults are the original hard-coded
    # settings, so running with no arguments behaves exactly as before.
    parser = argparse.ArgumentParser(
        description="Sort images into SFW/NSFW folders using a WD14 tagger's rating scores."
    )
    parser.add_argument("--model_id", default="SmilingWolf/wd-vit-tagger-v3",
                        help="Hugging Face repo ID providing model.onnx and selected_tags.csv")
    parser.add_argument("--input_dir", default="E:/desktop/dart/test",
                        help="folder whose *.* files are classified (non-recursive)")
    parser.add_argument("--sfw_dir", default="E:/desktop/dart/test_sfw",
                        help="destination folder for SFW images")
    parser.add_argument("--nsfw_dir", default="E:/desktop/dart/test_nsfw",
                        help="destination folder for NSFW images")
    parser.add_argument("--batch_size", type=int, default=16,
                        help="number of images per model call")
    parser.add_argument("--thresh", type=float, default=0.35,
                        help="passed through to run_batch (not used for the SFW/NSFW decision)")
    args = parser.parse_args()
    if args.batch_size < 1:
        parser.error("--batch_size must be >= 1")

    # main() reads these as module-level globals.
    MODEL_ID = args.model_id
    input_dir = args.input_dir
    sfw_dir = args.sfw_dir
    nsfw_dir = args.nsfw_dir
    batch_size = args.batch_size
    thresh = args.thresh
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment