Last active
August 18, 2025 07:33
-
-
Save tamanobi/e6c3593b80864537bce9cde84d08e134 to your computer and use it in GitHub Desktop.
calling comfyui
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# /// script
# requires-python = ">=3.11"
# dependencies = [
#   "httpx",
#   "websocket-client",
#   "pillow",
#   "typing-extensions",
# ]
# ///
import copy
import hashlib
import io
import json
import logging
import os
import threading
import time
import urllib.request
import uuid
from pathlib import Path
from typing import Any, Callable, Dict, Optional

import httpx
import websocket
from PIL import Image

class ImageGenerator: | |
def __init__(self, workflow_path: Path, output_path: Path, host: str): | |
self.workflow_path = workflow_path | |
self.output_path = output_path | |
self.host = host | |
self.client_id = str(uuid.uuid4()) | |
self.workflow = json.loads(self.workflow_path.read_text()) | |
# 出力ディレクトリを作成 | |
self.output_path.mkdir(parents=True, exist_ok=True) | |
# ログ設定 | |
logging.basicConfig(level=logging.INFO) | |
self.logger = logging.getLogger(__name__) | |
# WebSocket関連 | |
self.ws = None | |
self.progress_callback: Optional[Callable[[Dict[str, Any]], None]] = None | |
self.completion_event = threading.Event() | |
self.generated_images: list[tuple[str, bytes]] = [] | |
self.current_prompt_id = None | |
self.current_node = {} | |
def __call__(self, prompt: str, negative: str, seed: int, progress_callback: Optional[Callable[[Dict[str, Any]], None]] = None): | |
return self.gen(prompt, negative, seed, progress_callback) | |
def queue_prompt(self, prompt_data: Dict[str, Any]) -> Dict[str, Any]: | |
"""プロンプトをComfyUIのキューに追加""" | |
data = json.dumps({"prompt": prompt_data, "client_id": self.client_id}).encode('utf-8') | |
req = urllib.request.Request(f"{self.host}/prompt", data=data) | |
req.add_header('Content-Type', 'application/json') | |
max_retries = 3 | |
for attempt in range(max_retries): | |
try: | |
with httpx.Client(timeout=10) as client: | |
response = client.post(f"{self.host}/prompt", data=data) | |
return json.loads(response.text) | |
except urllib.error.URLError as e: | |
if attempt < max_retries - 1: | |
self.logger.warning(f"プロンプト送信失敗 (試行 {attempt + 1}/{max_retries}): {e}") | |
time.sleep(2 ** attempt) # 指数バックオフ | |
continue | |
else: | |
self.logger.error(f"プロンプト送信エラー (最大試行回数に達しました): {e}") | |
raise | |
except Exception as e: | |
self.logger.error(f"プロンプト送信エラー: {e}") | |
raise | |
def on_websocket_message(self, ws, message: bytes | str): | |
"""WebSocketメッセージの処理""" | |
try: | |
if isinstance(message, bytes): | |
self.logger.info(f"WebSocket から bytes を受信しました: {message[:80]}") | |
if self.current_node.get(self.current_prompt_id) == "save_image_websocket_node": | |
image = message[8:] | |
hash_image = hashlib.sha256(image).hexdigest() | |
self.generated_images.append((hash_image, image)) | |
return | |
elif isinstance(message, str): | |
self.logger.info(f"WebSocket から str を受信しました: {message[:80]}") | |
msg = json.loads(message) | |
else: | |
raise ValueError("Invalid message type") | |
if msg['type'] == 'executing': | |
data = msg.get('data', {}) | |
if data.get('node'): | |
self.current_node[self.current_prompt_id] = data['node'] | |
if data.get('node') is None and data.get('prompt_id') == self.current_prompt_id: | |
# 実行完了 | |
self.logger.info("画像生成完了") | |
self.completion_event.set() | |
elif self.progress_callback: | |
self.progress_callback(msg) | |
elif msg['type'] == 'progress': | |
if self.progress_callback: | |
self.progress_callback(msg) | |
elif msg['type'] == 'executed': | |
# 生成された画像の情報を取得 | |
data = msg.get('data', {}) | |
if 'output' in data: | |
for node_id, output in data['output'].items(): | |
if 'images' in output: | |
for image_info in output['images']: | |
self.generated_images.append(image_info) | |
elif msg['type'] == 'execution_success': | |
print("success", msg["data"]) | |
else: | |
pass | |
except Exception as e: | |
self.logger.error(f"WebSocketメッセージ処理エラー: {e}", exc_info=True) | |
def on_websocket_error(self, ws, error): | |
"""WebSocketエラーの処理""" | |
self.logger.error(f"WebSocketエラー: {error}") | |
def on_websocket_close(self, ws, close_status_code, close_msg): | |
"""WebSocket接続のクローズ処理""" | |
self.logger.info("WebSocket接続が閉じられました") | |
def on_websocket_open(self, ws): | |
"""WebSocket接続のオープン処理""" | |
self.logger.info("WebSocket接続が確立されました") | |
def connect_websocket(self, max_retries: int = 3): | |
"""WebSocket接続を確立""" | |
if self.host.startswith("http://"): | |
ws_url = f"ws://{self.host.replace('http://', '')}/ws?clientId={self.client_id}" | |
elif self.host.startswith("https://"): | |
ws_url = f"wss://{self.host.replace('https://', '')}/ws?clientId={self.client_id}" | |
else: | |
raise ValueError("Invalid host: must start with http:// or https://") | |
for attempt in range(max_retries): | |
try: | |
self.ws = websocket.WebSocketApp( | |
ws_url, | |
on_message=self.on_websocket_message, | |
on_error=self.on_websocket_error, | |
on_close=self.on_websocket_close, | |
on_open=self.on_websocket_open | |
) | |
# WebSocketを別スレッドで実行 | |
def run_websocket(): | |
try: | |
self.ws.run_forever(ping_interval=30, ping_timeout=10) | |
except Exception as e: | |
self.logger.error(f"WebSocket実行エラー: {e}") | |
ws_thread = threading.Thread(target=run_websocket, daemon=True) | |
ws_thread.start() | |
# 接続が確立されるまで待つ | |
time.sleep(2) | |
# 接続確認 | |
if self.ws and self.ws.sock and self.ws.sock.connected: | |
self.logger.info("WebSocket接続成功") | |
return | |
else: | |
raise ConnectionError("WebSocket接続に失敗しました") | |
except Exception as e: | |
if attempt < max_retries - 1: | |
self.logger.warning(f"WebSocket接続失敗 (試行 {attempt + 1}/{max_retries}): {e}") | |
time.sleep(2 ** attempt) | |
continue | |
else: | |
self.logger.error(f"WebSocket接続エラー (最大試行回数に達しました): {e}") | |
raise | |
def update_workflow(self, prompt: str, negative: str, seed: int, mapping: dict[str, str]) -> Dict[str, Any]: | |
"""ワークフローにプロンプト、ネガティブプロンプト、シードを設定""" | |
workflow = self.workflow.copy() | |
# プロンプトとネガティブプロンプトの設定 | |
# これはワークフローの構造に依存するため、実際のワークフローに合わせて調整が必要 | |
if mapping: | |
prompt_node_id = mapping.get('prompt_node_id') | |
negative_node_id = mapping.get('negative_node_id') | |
seed_node_id = mapping.get('seed_node_id') | |
if prompt_node_id: | |
workflow[prompt_node_id]['inputs']['text'] = prompt | |
if negative_node_id: | |
workflow[negative_node_id]['inputs']['text'] = negative | |
if seed_node_id: | |
workflow[seed_node_id]['inputs']['seed'] = seed | |
return workflow | |
for node_id, node in workflow.items(): | |
if node.get('class_type') == 'CLIPTextEncode': | |
if 'text' in node.get('inputs', {}): | |
# ポジティブプロンプトかネガティブプロンプトかを判定 | |
# 通常、最初に見つかったものがポジティブ、次がネガティブ | |
current_text = node['inputs']['text'] | |
if 'bad' in current_text.lower() or 'negative' in current_text.lower(): | |
node['inputs']['text'] = negative | |
else: | |
node['inputs']['text'] = prompt | |
elif node.get('class_type') == 'KSampler': | |
if 'seed' in node.get('inputs', {}): | |
node['inputs']['seed'] = seed | |
return workflow | |
def download_images(self) -> list: | |
"""生成された画像をダウンロード""" | |
downloaded_files = [] | |
for digest, image_bytes in self.generated_images: | |
img = Image.open(io.BytesIO(image_bytes)) | |
saving_path = self.output_path / self.current_prompt_id / f"{digest}.png" | |
saving_path.parent.mkdir(parents=True, exist_ok=True) | |
img.save(saving_path) | |
downloaded_files.append(saving_path) | |
return downloaded_files | |
def check_server_status(self) -> bool: | |
"""ComfyUIサーバーの状態を確認""" | |
try: | |
with httpx.Client(timeout=5.0) as client: | |
response = client.get(f"{self.host}/system_stats") | |
response.raise_for_status() | |
self.logger.info("ComfyUIサーバーは正常に稼働しています") | |
return True | |
except httpx.RequestError as e: | |
self.logger.error(f"ComfyUIサーバーに接続できません: {e}") | |
return False | |
except httpx.HTTPStatusError as e: | |
self.logger.error(f"ComfyUIサーバーエラー: {e.response.status_code} - {e}") | |
return False | |
def gen(self, prompt: str, negative: str, seed: int, progress_callback: Optional[Callable[[Dict[str, Any]], None]] = None) -> list: | |
"""画像生成を実行""" | |
self.progress_callback = progress_callback | |
self.completion_event.clear() | |
self.generated_images.clear() | |
# サーバー状態確認 | |
if not self.check_server_status(): | |
self.logger.error("ComfyUIサーバーが利用できません") | |
return [] | |
try: | |
# WebSocket接続 | |
self.connect_websocket() | |
# ワークフローを更新 | |
mapping = { | |
"prompt_node_id": "6", | |
"negative_node_id": "7", | |
"seed_node_id": "3" | |
} | |
updated_workflow = self.update_workflow(prompt, negative, seed, mapping) | |
print("updated_workflow", updated_workflow) | |
# プロンプトをキューに追加 | |
result = self.queue_prompt(updated_workflow) | |
self.current_prompt_id = result.get('prompt_id') | |
print("current_prompt_id", self.current_prompt_id) | |
if not self.current_prompt_id: | |
self.logger.error("プロンプトIDが取得できませんでした") | |
return [] | |
self.logger.info(f"プロンプトID: {self.current_prompt_id}") | |
# 完了まで待機(タイムアウト付き) | |
if self.completion_event.wait(timeout=300): # 5分でタイムアウト | |
# 画像をダウンロード | |
downloaded_files = self.download_images() | |
if not downloaded_files: | |
self.logger.warning("画像が生成されませんでした") | |
return downloaded_files | |
else: | |
self.logger.error("タイムアウト: 画像生成が完了しませんでした") | |
return [] | |
except Exception as e: | |
self.logger.error(f"画像生成エラー: {e}") | |
return [] | |
finally: | |
# WebSocket接続を閉じる | |
if self.ws: | |
try: | |
self.ws.close() | |
except Exception as e: | |
self.logger.warning(f"WebSocket切断エラー: {e}") | |
if __name__ == "__main__":
    # Example run. Set COMFYUI_HOST (e.g. http://127.0.0.1:8188) to target a
    # different server than the built-in default.
    host = os.environ.get("COMFYUI_HOST", "https://comfyui-comfyui.climt.studio")
    generator = ImageGenerator(Path("workflow_api.json"), Path("output"), host)
    images = generator("a beautiful girl", "ugly, bad, deformed", 42, None)
    print("images", images)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment