Last active
August 10, 2024 08:05
-
-
Save tori29umai0123/b6ce6d6450e5f87c00633af7d37915ea to your computer and use it in GitHub Desktop.
AI-NovelChat
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import sys | |
import time | |
import socket | |
import gradio as gr | |
from llama_cpp import Llama | |
import datetime | |
from jinja2 import Template | |
import configparser | |
import threading | |
import asyncio | |
import csv | |
DEFAULT_INI_FILE = 'settings.ini'

# Resolve the base directory and model directory. The layout differs
# depending on whether we run as a frozen (PyInstaller-style) build or
# as a plain script.
_is_frozen = bool(getattr(sys, 'frozen', False))
if _is_frozen:
    # Frozen build: models live beside the install root, one level up.
    path = os.path.dirname(sys.executable)
    model_dir = os.path.join(os.path.dirname(path), "AI-NovelAssistant", "models")
else:
    # Plain script: models sit next to this file.
    path = os.path.dirname(os.path.abspath(__file__))
    model_dir = os.path.join(path, "models")
def get_model_files():
    """Return the names of the *.gguf model files available in ``model_dir``.

    Returns an empty list when the models directory does not exist yet,
    instead of raising FileNotFoundError at import time.
    """
    if not os.path.isdir(model_dir):
        return []
    # Sorted so the dropdown order is stable across platforms.
    return sorted(f for f in os.listdir(model_dir) if f.endswith('.gguf'))
def load_settings_from_ini(filename):
    """Load character and model settings from an INI file.

    Creates the file with default contents first if it does not exist.

    Args:
        filename: Path to the INI file.

    Returns:
        dict with 'instructions', 'example_qa' (list of lines) and
        'initial_conversation' when a [Character] section is present, plus
        'DEFAULT_CHAT_MODEL' / 'DEFAULT_GEN_MODEL' when [Models] is present.
    """
    config = configparser.ConfigParser()
    if not os.path.exists(filename):
        # Bug fix: the message previously contained no placeholder, so the
        # missing file's name was never shown to the user.
        print(f"{filename} が見つかりません。デフォルト設定で作成します。")
        create_default_ini(filename)
    config.read(filename, encoding='utf-8')
    settings = {}
    if 'Character' in config:
        settings['instructions'] = config['Character'].get('instructions', '')
        settings['example_qa'] = config['Character'].get('example_qa', '').split('\n')
        settings['initial_conversation'] = config['Character'].get('initial_conversation', '')
    if 'Models' in config:
        settings['DEFAULT_CHAT_MODEL'] = config['Models'].get('DEFAULT_CHAT_MODEL', '')
        settings['DEFAULT_GEN_MODEL'] = config['Models'].get('DEFAULT_GEN_MODEL', '')
    return settings
def save_settings_to_ini(settings, filename):
    """Persist a settings dict to *filename* as a UTF-8 INI file.

    The 'example_qa' list is flattened to newline-separated text so it can
    round-trip through configparser.
    """
    parser = configparser.ConfigParser()
    character_section = {
        'instructions': settings.get('instructions', ''),
        'example_qa': '\n'.join(settings.get('example_qa', [])),
        'initial_conversation': settings.get('initial_conversation', ''),
    }
    models_section = {
        'DEFAULT_CHAT_MODEL': settings.get('DEFAULT_CHAT_MODEL', ''),
        'DEFAULT_GEN_MODEL': settings.get('DEFAULT_GEN_MODEL', ''),
    }
    parser['Character'] = character_section
    parser['Models'] = models_section
    with open(filename, 'w', encoding='utf-8') as fh:
        parser.write(fh)
def create_default_ini(filename):
    """Write the built-in default character/model settings to *filename*.

    Called on first launch when no settings file exists yet.
    """
    default_settings = {
        'instructions': "丁寧な敬語でアイディアのヒアリングしてください。物語をより面白くする提案、キャラクター造形の考察、世界観を膨らませる手伝いなどをお願いします。求められた時以外は基本、聞き役に徹してユーザー自身に言語化させるよう促してください。ユーザーのことは『ユーザー』と呼んでください。",
        'example_qa': [
            "user: キャラクターの設定について悩んでいます。",
            # Bug fix: a missing trailing comma here caused implicit string
            # concatenation, silently merging this entry with the next one.
            "assistant: 承知いたしました。キャラクター設定は物語の核となる重要な要素ですね。ユーザー様が現在考えているキャラクターについて、簡単にご説明いただけますでしょうか?例えば、年齢、性別、職業、性格の特徴などから始めていただけると、より具体的なアドバイスができるかと思います。",
            "user: プロットを書き出したいので、ヒアリングお願いします。",
            "assistant: 承知しました。ではまず『起承転結』の起から考えていきましょう。",
            "user: 読者を惹きこむ為のコツを提案してください",
            "assistant: 諸説ありますが、『謎・ピンチ・意外性』を冒頭に持ってくることが重要だと言います。",
            "user: プロットが面白いか自信がないので、考察のお手伝いをお願いします",
            "assistant: 承知しました。まずコメントをする前にこの物語の『売り』について簡単に言語化してください",
        ],
        'DEFAULT_CHAT_MODEL': 'Ninja-v1-RP-expressive-v2_Q4_K_M.gguf',
        'DEFAULT_GEN_MODEL': 'Mistral-Nemo-Instruct-2407-Q8_0.gguf'
    }
    save_settings_to_ini(default_settings, filename)
def list_log_files():
    """Return the names of the *.csv chat logs under ``<path>/logs``.

    An empty list is returned when the logs directory has not been
    created yet.
    """
    logs_dir = os.path.join(path, "logs")
    try:
        entries = os.listdir(logs_dir)
    except FileNotFoundError:
        return []
    return [name for name in entries if name.endswith('.csv')]
def load_chat_log(file_name):
    """Read a saved chat CSV and rebuild Gradio-style history pairs.

    Each CSV row is ``(role, message)``. A "user" row opens a new
    ``[user, assistant]`` pair; an "assistant" row fills the most recent
    open pair, or starts a pair with no user side if none is open.
    Malformed rows are skipped.
    """
    log_path = os.path.join(path, "logs", file_name)
    pairs = []
    with open(log_path, 'r', encoding='utf-8') as fh:
        rows = csv.reader(fh)
        next(rows)  # skip the "Role,Message" header row
        for row in rows:
            if len(row) != 2:
                continue
            role, message = row
            if role == "user":
                pairs.append([message, None])
            elif role == "assistant":
                if pairs and pairs[-1][1] is None:
                    pairs[-1][1] = message
                else:
                    pairs.append([None, message])
    return pairs
class GentextParams:
    """Mutable container for the two sampling-parameter sets.

    ``gen_*`` values drive the one-shot text-generation tab; ``chat_*``
    values drive the interactive chat tab. Both are updated in place from
    the UI's "apply settings" buttons.
    """

    def __init__(self):
        # Defaults for the text-generation tab.
        self.gen_temperature = 0.35
        self.gen_top_p = 1.0
        self.gen_top_k = 40
        self.gen_rep_pen = 1.0
        # Defaults for the chat tab.
        self.chat_temperature = 0.5
        self.chat_top_p = 0.7
        self.chat_top_k = 80
        self.chat_rep_pen = 1.2

    def update_generate_parameters(self, temperature, top_p, top_k, rep_pen):
        """Overwrite the generation-tab sampling parameters."""
        self.gen_temperature, self.gen_top_p = temperature, top_p
        self.gen_top_k, self.gen_rep_pen = top_k, rep_pen

    def update_chat_parameters(self, temperature, top_p, top_k, rep_pen):
        """Overwrite the chat-tab sampling parameters."""
        self.chat_temperature, self.chat_top_p = temperature, top_p
        self.chat_top_k, self.chat_rep_pen = top_k, rep_pen


# Shared parameter object read by every LlamaAdapter instance.
params = GentextParams()
class LlamaAdapter:
    """Thin wrapper around ``llama_cpp.Llama`` bound to a shared params object.

    ``generate_text`` does one-shot chat-completion-style generation;
    ``generate`` does raw prompt completion for the chat loop.
    """

    # Default stop sequences for raw chat completion.
    _DEFAULT_STOP = ["user:", "・会話履歴", "<END>"]

    def __init__(self, model_path, params):
        self.llm = Llama(model_path=model_path, n_ctx=10000)
        self.params = params

    def generate_text(self, text, author_description, token_multiplier, instruction):
        """Generate text for *text* under a persona system prompt.

        The token budget is ``len(input_tokens) * token_multiplier``.
        Returns the stripped completion string.
        """
        input_tokens = self.llm.tokenize(text.encode())
        max_tokens = int(len(input_tokens) * token_multiplier)
        response = self.llm.create_chat_completion(
            messages=[
                {"role": "system", "content": author_description},
                {"role": "user", "content": f"{instruction}:\n\n{text}"},
            ],
            max_tokens=max_tokens,
            temperature=self.params.gen_temperature,
            top_p=self.params.gen_top_p,
            top_k=self.params.gen_top_k,
            repeat_penalty=self.params.gen_rep_pen,
        )
        return response["choices"][0]["message"]["content"].strip()

    def generate(self, prompt, max_new_tokens=10000, stop=None):
        """Raw completion of *prompt* using the chat sampling parameters.

        Bug fix / generalization: ``stop`` is now an accepted keyword
        (CharacterMaker.make passed ``stop=`` and previously hit a
        TypeError). ``None`` keeps the original hard-coded stop list, so
        existing callers are unaffected.
        """
        return self.llm(
            prompt,
            temperature=self.params.chat_temperature,
            max_tokens=max_new_tokens,
            top_p=self.params.chat_top_p,
            top_k=self.params.chat_top_k,
            repeat_penalty=self.params.chat_rep_pen,
            stop=self._DEFAULT_STOP if stop is None else stop,
        )
class CharacterMaker:
    """Holds the chat model, conversation history and character settings."""

    def __init__(self):
        self.llama = None      # LlamaAdapter once a model load succeeds
        self.history = []      # list of {"user": ..., "assistant": ...} dicts
        self.settings = None   # dict loaded from the INI file
        self.model_loaded = threading.Event()  # set when a load attempt finishes

    def set_model(self, model_name):
        """Load *model_name* from ``model_dir`` on a background thread.

        ``model_loaded`` is set on success AND on failure so waiters never
        block forever; on failure ``self.llama`` stays None.
        """
        def load_model():
            try:
                model_path = os.path.join(model_dir, model_name)
                self.llama = LlamaAdapter(model_path, params)
                self.model_loaded.set()
                print(f"モデル {model_name} のロードが完了しました。")
            except Exception as e:
                print(f"モデルのロード中にエラーが発生しました: {str(e)}")
                self.model_loaded.set()  # unblock waiters even on failure
        threading.Thread(target=load_model).start()

    def make(self, input_str: str):
        """Generate one assistant reply for *input_str* and record it."""
        if not self.model_loaded.is_set():
            return "モデルをロード中です。しばらくお待ちください。"
        if not self.llama:
            return "モデルのロードに失敗しました。設定を確認してください。"
        prompt = self._generate_prompt(input_str)
        # Bug fix: this call previously passed stop=["<END>", "\n"], a
        # keyword LlamaAdapter.generate() did not accept, raising TypeError.
        # The adapter's built-in stop list already includes "<END>".
        res = self.llama.generate(prompt, max_new_tokens=1000)
        res_text = res["choices"][0]["text"]
        self.history.append({"user": input_str, "assistant": res_text})
        return res_text

    def make_prompt(self, input_str: str):
        """Render the full chat prompt: instructions, QA examples, history."""
        prompt_template = """{{instructions}}
・キャラクターの回答例
{% for qa in example_qa %}
{{qa}}
{% endfor %}
・会話履歴
{% for history in histories %}
user: {{history.user}}
assistant: {{history.assistant}}
{% endfor %}
user: {{input_str}}
assistant:"""
        template = Template(prompt_template)
        return template.render(
            instructions=self.settings.get('instructions', ''),
            example_qa=self.settings.get('example_qa', []),
            histories=self.history,
            input_str=input_str
        )

    def _generate_prompt(self, input_str: str):
        """Internal alias kept for the Gradio callbacks."""
        return self.make_prompt(input_str)

    def update_settings(self, new_settings, filename):
        """Merge *new_settings*, persist to INI, and reload the chat model."""
        if self.settings is None:
            # Robustness: previously crashed with AttributeError when called
            # before any settings had been loaded.
            self.settings = {}
        self.settings.update(new_settings)
        save_settings_to_ini(self.settings, filename)
        self.set_model(self.settings['DEFAULT_CHAT_MODEL'])

    def load_character(self, filename):
        """Load settings from *filename* and kick off the model load.

        Accepts a list (Gradio file components may deliver one) and uses
        its first element. Returns a user-facing status message.
        """
        if isinstance(filename, list):
            filename = filename[0] if filename else ""
        self.settings = load_settings_from_ini(filename)
        if self.settings:
            self.set_model(self.settings['DEFAULT_CHAT_MODEL'])
            # Bug fix: both messages below previously never interpolated the
            # file name (the f-string had no placeholder).
            return f"{filename}から設定を読み込み、モデル {self.settings['DEFAULT_CHAT_MODEL']} を設定しました。"
        return f"{filename}の読み込みに失敗しました。"

    def reset(self):
        """Clear the conversation history and reload the configured model."""
        self.history = []
        if self.llama:
            self.set_model(self.settings['DEFAULT_CHAT_MODEL'])
# Singleton chat-state holder shared by every Gradio callback below.
character_maker = CharacterMaker()
async def chat_with_character(message, history):
    """Gradio ChatInterface callback: stream the reply character by character.

    Args:
        message: The new user message.
        history: List of [user, assistant] pairs from the Chatbot widget;
                 it is mirrored into ``character_maker.history`` so the
                 prompt template sees the full conversation.

    Yields:
        Growing prefixes of the generated reply for the typing effect.
    """
    character_maker.history = [{"user": h[0], "assistant": h[1]} for h in history]
    prompt = character_maker._generate_prompt(message)
    response = character_maker.llama.generate(prompt, max_new_tokens=1000)["choices"][0]["text"]
    for i in range(len(response)):
        # Bug fix: time.sleep() blocked the asyncio event loop inside this
        # coroutine, freezing the whole UI while streaming; asyncio.sleep
        # yields control back to the loop.
        await asyncio.sleep(0.01)
        yield response[: i + 1]
def generate_text_with_token_multiplier(text, author_type, genre, writing_style, target_audience, token_multiplier, model_name, instruction):
    """One-shot text generation with a freshly loaded model.

    Builds a persona system prompt from the author fields, loads
    *model_name* from ``model_dir``, and delegates to
    LlamaAdapter.generate_text with a token budget proportional to the
    input length.
    """
    persona = f"あなたは{author_type}で、{genre}と{writing_style}の文体で{target_audience}に人気があります。"
    adapter = LlamaAdapter(os.path.join(model_dir, model_name), params)
    return adapter.generate_text(text, persona, token_multiplier, instruction)
def clear_chat():
    """Reset the character's history/model state and empty the Chatbot widget."""
    character_maker.reset()
    return []
def save_chat_log(chat_history):
    """Write *chat_history* ([user, assistant] pairs) to ``logs/<timestamp>.csv``.

    Empty sides of a pair are skipped. Returns a user-facing status message
    containing the created file's path.
    """
    current_time = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
    filename = f"{current_time}.csv"
    logs_dir = os.path.join(path, "logs")
    # exist_ok avoids the check-then-create race of the original code.
    os.makedirs(logs_dir, exist_ok=True)
    file_path = os.path.join(logs_dir, filename)
    with open(file_path, 'w', newline='', encoding='utf-8') as csvfile:
        writer = csv.writer(csvfile)
        writer.writerow(["Role", "Message"])
        for user_message, assistant_message in chat_history:
            if user_message:
                writer.writerow(["user", user_message])
            if assistant_message:
                writer.writerow(["assistant", assistant_message])
    return f"チャットログが {file_path} に保存されました。"
def get_ip_address():
    """Best-effort discovery of this machine's LAN IP address.

    Connecting a UDP socket sends no packets but makes the OS choose the
    outbound interface; falls back to loopback on any failure.
    """
    probe = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
    try:
        probe.connect(('10.255.255.255', 1))
        return probe.getsockname()[0]
    except Exception:
        return '127.0.0.1'
    finally:
        probe.close()
def is_port_in_use(port):
    """Return True if something accepts TCP connections on localhost:*port*."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
        result = probe.connect_ex(('localhost', port))
    return result == 0
def find_available_port(starting_port):
    """Return the first free TCP port at or above *starting_port*.

    Raises:
        RuntimeError: if no free port exists up to 65535 (the original loop
            could increment past the valid port range forever).
    """
    port = starting_port
    while is_port_in_use(port):
        print(f"Port {port} is in use, trying next one.")
        port += 1
        if port > 65535:
            raise RuntimeError("No available TCP port found.")
    return port
# Enumerate the available *.gguf models once at import time; both model
# dropdowns in the UI are populated from this list.
model_files = get_model_files()
def build_gradio_interface():
    """Construct the Gradio Blocks UI and store it in the module-global ``demo``.

    Tabs: streaming chat, one-shot text generation, and a chat-log viewer.
    The caller is responsible for queueing and launching ``demo``.
    """
    global demo
    # Custom CSS: resizable chat panels.
    custom_css = """
    #chatbot, #chatbot_read {
        height: 50vh;
        overflow-y: auto;
        resize: vertical;
        border: 1px solid #ccc;
    }
    /* サイズ変更用のグリップをより直感的に操作できるようにスタイリング */
    .resizer-grip {
        height: 10px;
        background: #ccc;
        cursor: ns-resize;
    }
    """
    with gr.Blocks(css=custom_css) as demo:
        # Inject additional CSS and JavaScript via an HTML block.
        gr.HTML("""
        <style>
        #chatbot, #chatbot_read {
            resize: both;
            overflow: auto;
            min-height: 100px;
            max-height: 80vh;
        }
        </style>
        <script>
        // リサイズを処理するためのJavaScript、必要であれば
        document.addEventListener('DOMContentLoaded', function() {
            const chatboxes = document.querySelectorAll('#chatbot, #chatbot_read');
            chatboxes.forEach(chatbox => {
                chatbox.addEventListener('mousedown', function(e) {
                    console.log('Resizing started');
                });
            });
        });
        </script>
        """)
        with gr.Tab("チャット"):
            chatbot = gr.Chatbot(elem_id="chatbot")
            chat_interface = gr.ChatInterface(
                chat_with_character,
                chatbot=chatbot,
                textbox=gr.Textbox(placeholder="メッセージを入力してください...", container=False, scale=7),
                theme="soft",
                submit_btn="送信",
                stop_btn="停止",
                retry_btn="もう一度生成",
                undo_btn="前のメッセージを取り消す",
                clear_btn="チャットをクリア",
            )
            with gr.Row():
                model_dropdown = gr.Dropdown(choices=model_files, label="モデル選択", value=character_maker.settings.get('DEFAULT_CHAT_MODEL', ''))
                save_log_button = gr.Button("チャットログを保存")
                save_log_output = gr.Textbox(label="保存状態")
            with gr.Accordion("詳細設定", open=False):
                chat_temperature = gr.Slider(label="Temperature", value=0.5, minimum=0.0, maximum=1.0, step=0.05, interactive=True)
                chat_top_p = gr.Slider(label="Top-p (nucleus sampling)", value=0.7, minimum=0.0, maximum=1, step=0.05, interactive=True)
                chat_top_k = gr.Slider(label="Top-k", value=80, minimum=1, maximum=200, step=1, interactive=True)
                chat_rep_pen = gr.Slider(label="Repetition penalty", value=1.2, minimum=1.0, maximum=2.0, step=0.05, interactive=True)
                apply_settings = gr.Button("設定を適用")

            def apply_chat_settings(temp, top_p, top_k, rep_pen):
                # Push the chosen values into the shared params object.
                params.update_chat_parameters(temp, top_p, top_k, rep_pen)
                return f"設定を適用しました: Temperature={temp}, Top-p={top_p}, Top-k={top_k}, Repetition Penalty={rep_pen}"

            apply_settings.click(
                apply_chat_settings,
                inputs=[chat_temperature, chat_top_p, chat_top_k, chat_rep_pen],
                outputs=[save_log_output]
            )
            save_log_button.click(
                save_chat_log,
                inputs=[chatbot],
                outputs=[save_log_output]
            )
            model_dropdown.change(
                lambda x: character_maker.set_model(x),
                inputs=[model_dropdown],
                outputs=[]
            )
        with gr.Tab("文章生成"):
            with gr.Row():
                with gr.Column(scale=2):
                    instruction_type = gr.Dropdown(
                        choices=["自由入力", "推敲", "プロット作成", "あらすじ作成"],
                        label="指示タイプ",
                        value="自由入力"
                    )
                    gen_instruction = gr.Textbox(
                        label="指示",
                        value="",
                        lines=3
                    )
                    gen_input_text = gr.Textbox(lines=5, label="処理されるテキストを入力してください")
                with gr.Column(scale=1):
                    gen_author_type = gr.Textbox(label="作家のタイプ", value="新進気鋭のSF小説家")
                    gen_genre = gr.Textbox(label="ジャンル", value="斬新なアイデア")
                    gen_writing_style = gr.Textbox(label="文体", value="切れ味のある文体、流麗な文章")
                    gen_target_audience = gr.Textbox(label="ターゲット読者", value="若い世代")
                    token_multiplier = gr.Slider(minimum=0.1, maximum=20, value=1.5, step=0.1, label="トークン倍率", info="入力トークン数に対する生成トークン数の倍率(0.1〜20)")
                    gen_model = gr.Dropdown(choices=model_files, label="モデル選択", value=character_maker.settings.get('DEFAULT_GEN_MODEL', ''))
            generate_button = gr.Button("文章生成開始")
            generated_output = gr.Textbox(label="生成された文章")
            with gr.Accordion("詳細設定", open=False):
                gen_temperature = gr.Slider(label="Temperature", value=0.35, minimum=0.0, maximum=1.0, step=0.05, interactive=True)
                gen_top_p = gr.Slider(label="Top-p (nucleus sampling)", value=0.9, minimum=0.0, maximum=1, step=0.05, interactive=True)
                gen_top_k = gr.Slider(label="Top-k", value=40, minimum=1, maximum=200, step=1, interactive=True)
                gen_rep_pen = gr.Slider(label="Repetition penalty", value=1.2, minimum=1.0, maximum=2.0, step=0.05)
                gen_apply_settings = gr.Button("設定を適用")

            def apply_gen_settings(temp, top_p, top_k, rep_pen):
                params.update_generate_parameters(temp, top_p, top_k, rep_pen)
                return f"設定を適用しました: Temperature={temp}, Top-p={top_p}, Top-k={top_k}, Repetition Penalty={rep_pen}"

            # NOTE(review): like the original code, this writes its status into
            # the chat tab's save_log_output textbox — confirm that is intended.
            gen_apply_settings.click(
                apply_gen_settings,
                inputs=[gen_temperature, gen_top_p, gen_top_k, gen_rep_pen],
                outputs=[save_log_output]
            )
            # Bug fix: this click handler was registered TWICE in the original
            # code, so every button press ran the (expensive) generation twice.
            generate_button.click(
                generate_text_with_token_multiplier,
                inputs=[gen_input_text, gen_author_type, gen_genre, gen_writing_style, gen_target_audience, token_multiplier, gen_model, gen_instruction],
                outputs=[generated_output]
            )

            def update_instruction(choice):
                # Preset instruction text for each instruction type.
                instructions = {
                    "自由入力": "",
                    "推敲": "以下のテキストを推敲してください。原文の文体や特徴的な表現は保持しつつ、必要に応じて微調整を加えてください。文章の流れを自然にし、表現を洗練させることが目標ですが、元の雰囲気や個性を損なわないよう注意してください",
                    "プロット作成": "以下のテキストをプロットにしてください。起承転結に分割すること。",
                    "あらすじ作成": "以下のテキストをあらすじにして、簡潔にまとめて下さい。"
                }
                return instructions.get(choice, "")

            instruction_type.change(
                update_instruction,
                inputs=[instruction_type],
                outputs=[gen_instruction]
            )
        with gr.Tab("ログ閲覧"):
            gr.Markdown("## チャットログ閲覧")
            chatbot_read = gr.Chatbot(elem_id="chatbot_read")
            log_file_dropdown = gr.Dropdown(label="ログファイル選択", choices=list_log_files())
            refresh_log_list_button = gr.Button("ログファイルリストを更新")

            def update_log_dropdown():
                return gr.update(choices=list_log_files())

            def load_and_display_chat_log(file_name):
                chat_history = load_chat_log(file_name)
                return gr.update(value=chat_history)

            refresh_log_list_button.click(
                update_log_dropdown,
                outputs=[log_file_dropdown]
            )
            log_file_dropdown.change(
                load_and_display_chat_log,
                inputs=[log_file_dropdown],
                outputs=[chatbot_read]
            )
async def load_model_and_start_gradio():
    """Application entry point: load settings and model, then launch the UI.

    Waits (without blocking the event loop) for the background model-load
    thread to finish before building and launching the Gradio interface.
    """
    # Create the default INI file if it does not exist yet.
    if not os.path.exists(DEFAULT_INI_FILE):
        print(f"{DEFAULT_INI_FILE} が見つかりません。デフォルト設定で作成します。")
        create_default_ini(DEFAULT_INI_FILE)
    # Load the default settings; this also starts the model-load thread.
    result = character_maker.load_character(DEFAULT_INI_FILE)
    print(result)
    # Poll until the load attempt (success or failure) signals completion.
    while not character_maker.model_loaded.is_set():
        await asyncio.sleep(1)
    # llama stays None when the load failed — abort instead of serving a broken UI.
    if not character_maker.llama:
        print("モデルのロードに失敗しました。アプリケーションを終了します。")
        return
    # Build the Gradio interface (assigns the module-global `demo`).
    build_gradio_interface()
    ip_address = get_ip_address()
    starting_port = 7860
    port = find_available_port(starting_port)
    print(f"サーバーのアドレス: http://{ip_address}:{port}")
    demo.queue()
    demo.launch(
        server_name='0.0.0.0',
        server_port=port,
        share=False,
        # NOTE(review): favicon_path normally expects an image file, but this
        # points at an HTML file — confirm "custom.html" is intentional.
        favicon_path=os.path.join(path, "custom.html")
    )
if __name__ == "__main__":
    # Drive the async startup routine to completion.
    asyncio.run(load_model_and_start_gradio())
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment