Last active
July 9, 2024 12:36
-
-
Save tori29umai0123/c8e478925cc6524401dbb6b28f5a32c6 to your computer and use it in GitHub Desktop.
rag_stable_diffusion_prompt.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import csv
import re
from pathlib import Path

import faiss
import numpy as np
from llama_cpp import Llama
from sentence_transformers import SentenceTransformer
class CustomEmbeddings:
    """Adapter that returns an encoder's output as float32 NumPy arrays.

    The wrapped ``model`` only needs an ``encode(list_of_texts)`` method; in
    this file it is a ``SentenceTransformer`` instance.
    """

    def __init__(self, model):
        """Keep a reference to the encoder whose ``encode`` produces vectors."""
        self.model = model

    def embed_documents(self, texts):
        """Encode ``texts`` in one call and return the result as float32.

        Args:
            texts: Sequence of strings handed to ``model.encode`` unchanged.

        Returns:
            numpy.ndarray: C-contiguous float32 array of the encoder output
            (presumably one row per text -- depends on the encoder).
        """
        # FAISS consumes C-contiguous float32 matrices; ascontiguousarray
        # converts only when the encoder output is not already in that form.
        return np.ascontiguousarray(self.model.encode(texts), dtype=np.float32)

    def embed_query(self, text):
        """Encode a single string and keep the leading batch axis.

        The text is wrapped in a one-element list, so the result has one row;
        that is the form passed to ``index.search`` elsewhere in this file.

        Args:
            text: The query string.

        Returns:
            numpy.ndarray: C-contiguous float32 array with a single row.
        """
        return np.ascontiguousarray(self.model.encode([text]), dtype=np.float32)
# Load the Danbooru tags
def load_danbooru_tags(file_path):
    """Read tag names from the first column of a comma-separated file.

    Args:
        file_path: Path of a UTF-8 encoded CSV file whose first field on each
            row is the tag name.

    Returns:
        list[str]: The whitespace-stripped first field of every row, in file
        order. Blank lines and rows whose first field is empty are skipped so
        that no empty string ends up being indexed as a tag.
    """
    # newline="" hands line handling to the csv module, so a quoted first
    # field that contains a comma stays in one piece instead of being cut at
    # the first "," as a plain str.split would do.
    with open(file_path, "r", encoding="utf-8", newline="") as f:
        first_fields = (row[0].strip() for row in csv.reader(f) if row)
        return [tag for tag in first_fields if tag]
# Build the vector store
def create_vector_store(tags):
    """Embed every tag and build a flat L2 FAISS index over the vectors.

    Args:
        tags: Iterable of tag strings to index.

    Returns:
        tuple: ``(index, documents, embeddings)`` where ``index`` is the
        populated ``faiss.IndexFlatL2``, ``documents`` is a list copy of
        ``tags`` in the order their vectors were added, and ``embeddings`` is
        the ``CustomEmbeddings`` wrapper that produced the vectors.

    Raises:
        ValueError: If ``tags`` is empty.
    """
    documents = list(tags)
    if not documents:
        # Fail before the embedding model is loaded; with no vectors there is
        # no dimension to size the index with.
        raise ValueError("cannot build a vector store from an empty tag list")

    # NOTE(review): trust_remote_code=True lets the model repository run its
    # own Python code at load time -- confirm the repository is trusted.
    model = SentenceTransformer("nomic-ai/nomic-embed-text-v1.5", trust_remote_code=True)
    embeddings = CustomEmbeddings(model)
    vectors = embeddings.embed_documents(documents)

    # The index dimension is the width of one embedding row.
    index = faiss.IndexFlatL2(vectors.shape[1])
    index.add(vectors)
    return index, documents, embeddings
# Set up the RAG system
def setup_rag_system(
    index,
    documents,
    embeddings,
    num_of_ref_tags=10,
    *,
    model_path="gemma-2-9b-it-Q8_0.gguf",
    llm=None,
):
    """Build the retrieval and tag-matching callables used by the pipeline.

    Args:
        index: Object with ``search(query_vector, k)`` returning a
            ``(distances, ids)`` pair, e.g. a FAISS index.
        documents: Sequence of tags; position ``i`` must correspond to id
            ``i`` in ``index``.
        embeddings: Object with ``embed_query(text)`` producing the query
            vector passed to ``index.search``.
        num_of_ref_tags: Default number of neighbours ``retrieve`` asks for.
        model_path: GGUF model file loaded when ``llm`` is not supplied.
        llm: Optional ready-made callable with the ``Llama`` call interface
            (``llm(prompt, max_tokens=..., stop=...)``). Supplying it avoids
            loading the model file again.

    Returns:
        tuple: ``(retrieve, generate_response)``.
    """
    if llm is None:
        llm = Llama(model_path=model_path, n_ctx=2048, n_batch=512)

    def retrieve(query, k=num_of_ref_tags):
        """Return up to ``k`` tags whose vectors are nearest to ``query``."""
        query_vector = embeddings.embed_query(query)
        _, neighbour_ids = index.search(query_vector, k)
        # FAISS fills missing neighbours with -1, which Python indexing would
        # silently turn into the last document. Ids past the end of the list
        # (an index that does not match `documents`) are dropped as well.
        return [documents[i] for i in neighbour_ids[0] if 0 <= i < len(documents)]

    def generate_response(query, context):
        """Ask the LLM for the single context tag that matches ``query``.

        Args:
            query: The prompt element to match.
            context: Iterable of candidate tag strings shown to the model.

        Returns:
            str: The first generated line, stripped of surrounding whitespace.
        """
        prompt = (
            "\n"
            "You are a precise tag matcher for Danbooru tags. Your task is to find exact matches in the provided context for the given input element.\n"
            "Rules:\n"
            "1. Only return tags that are exact matches to the input element and present in the context.\n"
            "2. Do not add any tags that are not exact matches to the input.\n"
            "3. If an exact match is found in the context, return only that match.\n"
            "4. If no exact match is found, return the input element as is.\n"
            "5. Never modify, complete, or expand the input element.\n"
            "6. Provide a maximum of one tag as output.\n"
            f"Context: {', '.join(context)}\n"
            f"Input: {query}\n"
            "Output:\n"
        )
        # Generation stops at the first newline so only one line comes back.
        response = llm(prompt, max_tokens=50, stop=["\n"])
        return response["choices"][0]["text"].strip()

    return retrieve, generate_response
def generate_rich_description(scene_description, *, model_path="gemma-2-9b-it-Q8_0.gguf", llm=None):
    """Expand a short scene description into comma-separated prompt elements.

    Args:
        scene_description: Brief free-text description of the scene.
        model_path: GGUF model file loaded when ``llm`` is not supplied.
        llm: Optional ready-made callable with the ``Llama`` call interface
            (``llm(prompt, max_tokens=..., stop=...)``). Supplying it avoids
            loading a second copy of the model.

    Returns:
        str: The generated text up to the first blank line, stripped of
        surrounding whitespace.
    """
    if llm is None:
        # NOTE(review): tensor_split=[48, 0, 0] presumably keeps the whole
        # model on the first of three GPUs on the author's machine -- confirm
        # before running on different hardware.
        llm = Llama(model_path=model_path, n_ctx=2048, n_batch=512, tensor_split=[48, 0, 0])

    prompt = (
        "\n"
        "Based on the following brief scene description, generate a comma-separated list of danbooru tags that could be used as a Stable Diffusion prompt in English.\n"
        f"Brief description: {scene_description}\n"
        "Stable Diffusion Prompt Elements:\n"
    )
    # A blank line ends the generation, so a multi-line list is still allowed.
    response = llm(prompt, max_tokens=200, stop=["\n\n"])
    return response["choices"][0]["text"].strip()
def convert_description_to_danbooru_tags(prompt_elements, retrieve, generate_response):
    """Map each comma-separated prompt element to a tag and join the results.

    Args:
        prompt_elements: Comma-separated string of prompt elements.
        retrieve: Callable ``retrieve(element)`` returning candidate tags.
        generate_response: Callable ``generate_response(element, candidates)``
            returning the chosen tag, or an empty string for no tag.

    Returns:
        str: The chosen tags joined by ``", "``, without duplicates and in
        first-seen order.
    """
    matched = []
    for element in (part.strip() for part in prompt_elements.split(",")):
        if not element:
            # Trailing or doubled commas produce empty pieces; there is
            # nothing to look up for them, so skip the retrieval and LLM call.
            continue
        tag = generate_response(element, retrieve(element))
        if tag:
            matched.append(tag)
        # Progress trace: shows what each element was mapped to.
        print(f"{element}: {tag}")

    # dict.fromkeys drops repeated tags while keeping first-seen order.
    return ", ".join(dict.fromkeys(matched))
if __name__ == "__main__":
    # Script entry point: load (or build and cache) the tag index, expand a
    # scene description into prompt elements, then map them to Danbooru tags.
    vector_store_path = Path("danbooru_tags_vector_store.faiss")

    # The tag list is needed whether or not a cached index exists, because
    # search results are ids that get mapped back to tags by list position.
    tags = load_danbooru_tags("danbooru_tags.csv")

    index = None
    if vector_store_path.exists():
        index = faiss.read_index(str(vector_store_path))
        if index.ntotal != len(tags):
            # The cache was built from a different tag list; using it would
            # pair ids with the wrong tags, so discard it and rebuild.
            print(
                f"Cached index holds {index.ntotal} vectors but {len(tags)} tags were loaded; "
                "rebuilding the vector store."
            )
            index = None

    if index is None:
        index, _, embeddings = create_vector_store(tags)
        faiss.write_index(index, str(vector_store_path))
    else:
        # Only the index is cached on disk; the query encoder is re-created.
        model = SentenceTransformer("nomic-ai/nomic-embed-text-v1.5", trust_remote_code=True)
        embeddings = CustomEmbeddings(model)

    retrieve, generate_response = setup_rag_system(index, tags, embeddings, num_of_ref_tags=10)

    # Sample input, in Japanese: "A girl in a yukata is laughing while
    # watching fireworks".
    scene_description = "浴衣の少女が花火を見て笑っている"
    prompt_elements = generate_rich_description(scene_description)
    print(f"Generated Danbooru Tags:\n{prompt_elements}")

    danbooru_tags = convert_description_to_danbooru_tags(prompt_elements, retrieve, generate_response)
    print(f"\nMatched Danbooru Tags:\n{danbooru_tags}")

    # Release resources: drop the index and the closures that hold the LLM.
    del index
    del retrieve
    del generate_response
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment