Created
May 31, 2023 02:32
-
-
Save iwalton3/55a0dff6a53ccc0fa832d6df23c1cded to your computer and use it in GitHub Desktop.
Discord Exllama Chatbot
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3
# Discord chatbot backed by ExLlama (quantized LLaMA inference).
from model import ExLlama, ExLlamaCache, ExLlamaConfig
from tokenizer import ExLlamaTokenizer
from generator import ExLlamaGenerator
import argparse
import torch
from timeit import default_timer as timer
# Inference only: disable autograd globally and force CUDA init up front.
torch.set_grad_enabled(False)
torch.cuda._lazy_init()  # NOTE(review): private torch API — may break across torch versions; confirm
import asyncio
import traceback
import discord
from discord import app_commands
import re
from threading import RLock
# ---- Command-line arguments ----
parser = argparse.ArgumentParser(description = "Simple chatbot example for ExLlama")
# Model files
parser.add_argument("-t", "--tokenizer", type = str, help = "Tokenizer model path", required = True)
parser.add_argument("-c", "--config", type = str, help = "Model config path (config.json)", required = True)
parser.add_argument("-m", "--model", type = str, help = "Model weights path (.pt or .safetensors file)", required = True)
# Inference backend tuning
parser.add_argument("-a", "--attention", type = ExLlamaConfig.AttentionMethod.argparse, choices = list(ExLlamaConfig.AttentionMethod), help="Attention method", default = ExLlamaConfig.AttentionMethod.SWITCHED)
parser.add_argument("-mm", "--matmul", type = ExLlamaConfig.MatmulMethod.argparse, choices = list(ExLlamaConfig.MatmulMethod), help="Matmul method", default = ExLlamaConfig.MatmulMethod.SWITCHED)
# bugfix: help text for the MLP option previously said "Matmul method" (copy-paste error)
parser.add_argument("-mlp", "--mlp", type = ExLlamaConfig.MLPMethod.argparse, choices = list(ExLlamaConfig.MLPMethod), help="MLP method", default = ExLlamaConfig.MLPMethod.SWITCHED)
parser.add_argument("-s", "--stream", type = int, help = "Stream layer interval", default = 0)
parser.add_argument("-gs", "--gpu_split", type = str, help = "Comma-separated list of VRAM (in GB) to use per GPU device for model layers, e.g. -gs 20,7,7")
parser.add_argument("-dq", "--dequant", type = str, help = "Number of layers (per GPU) to de-quantize at load time")
# Context/output budgets (in tokens)
parser.add_argument("-l", "--length", type = int, help = "Maximum sequence length", default = 2048)
parser.add_argument("-l-out", "--length-out", type = int, help = "Maximum output", default = 768)
parser.add_argument("-l-grace", "--length-grace", type = int, help = "Space to leave in token pool", default = 768)
# Sampling parameters
parser.add_argument("-temp", "--temperature", type = float, help = "Temperature", default = 0.72)
parser.add_argument("-topk", "--top_k", type = int, help = "Top-K", default = 500)
parser.add_argument("-topp", "--top_p", type = float, help = "Top-P", default = 0.65)
parser.add_argument("-minp", "--min_p", type = float, help = "Min-P", default = 0.00)
parser.add_argument("-repp", "--repetition_penalty", type = float, help = "Repetition penalty", default = 1.1)
parser.add_argument("-repps", "--repetition_penalty_sustain", type = int, help = "Past length for repetition penalty", default = 256)
parser.add_argument("-beams", "--beams", type = int, help = "Number of beams for beam search", default = 1)
parser.add_argument("-beamlen", "--beam_length", type = int, help = "Number of future tokens to consider", default = 1)
args = parser.parse_args()
# It seems the shallow LoRA needs more prompting
system_prompt = (
    '(The system prompt for your chatbot goes here. It is inserted before messages in a conversation.)'
)
use_system_prompt = True
last_bot_message = None  # NOTE(review): shadowed by per-Bot state below; appears unused at module level
messages = []            # NOTE(review): shadowed by per-Bot state below; appears unused at module level
# hex encoded since I am tired of seeing these every time I open the file
# put any bad words you don't want in bot output here
slurs = [
    '\x66\x61\x67',
    '\x66\x61\x67\x67\x6f\x74',
    '\x74\x72\x61\x6e\x6e\x79',
    '\x6e\x69\x67\x67\x65\x72',
    '\x6e\x69\x67\x67\x61',
    '\x72\x65\x74\x61\x72\x64'
]
# Discord raw user mention, e.g. <@123> or <@!123>; captures the numeric id.
user_regex = re.compile(r'<@!?(\d+)>')
# Custom (possibly animated) emote <a:name:id>; captures the emote name.
emote_regex = re.compile(r'<a?:([a-zA-Z0-9_]+):\d+>')
# Plain :name: emote shorthand.
emote_replace_regex = re.compile(r':([a-zA-Z0-9_]+):')
# Runs of ASCII letters.
word_regex = re.compile(r'([a-zA-Z]+)')
# Serializes GPU generation across all Bot instances (one shared model/cache).
promptLock = RLock()
# Instantiate model and generator
config = ExLlamaConfig(args.config)
config.model_path = args.model
config.attention_method = args.attention
config.matmul_method = args.matmul
config.stream_layer_interval = args.stream
config.mlp_method = args.mlp
# --length has a default (2048), so this branch is effectively always taken
if args.length is not None: config.max_seq_len = args.length
config.set_auto_map(args.gpu_split)  # per-GPU VRAM split; may be None
config.set_dequant(args.dequant)     # layers to de-quantize at load; may be None
# One model, cache, and tokenizer shared by every Bot instance.
model = ExLlama(config)
cache = ExLlamaCache(model)
tokenizer = ExLlamaTokenizer(args.tokenizer)
def get_generator():
    """Build an ExLlamaGenerator on the shared model/cache, configured from CLI args."""
    gen = ExLlamaGenerator(model, tokenizer, cache)
    settings = ExLlamaGenerator.Settings()
    settings.temperature = args.temperature
    settings.top_k = args.top_k
    settings.top_p = args.top_p
    settings.min_p = args.min_p
    settings.token_repetition_penalty_max = args.repetition_penalty
    settings.token_repetition_penalty_sustain = args.repetition_penalty_sustain
    # The penalty decays over half the sustain window.
    settings.token_repetition_penalty_decay = settings.token_repetition_penalty_sustain // 2
    settings.beams = args.beams
    settings.beam_length = args.beam_length
    gen.settings = settings
    return gen
def message_wordcount_ok(messages):
    """Return True when the total token count of all messages fits the context budget.

    The budget is the max sequence length minus the grace space reserved for output.
    """
    # note this approach using the tokenizer isn't great...
    # we may optimize it later
    total_tokens = sum(tokenizer.encode(text).shape[-1] for text, _author in messages)
    return total_tokens <= args.length - args.length_grace
def ensure_wordcount_ok(messages):
    """Trim the oldest messages in place until the history fits the token budget.

    Messages are dropped two at a time (roughly one user/bot exchange).
    Bugfix: the original called `messages.pop(0)` twice unguarded, so a history
    with exactly one over-budget message raised IndexError on the second pop.
    `del messages[:2]` removes up to two entries safely, and the emptiness guard
    prevents looping forever if the budget is unsatisfiable.
    """
    while messages and not message_wordcount_ok(messages):
        del messages[:2]
class Bot:
    """A single chat session (one per DM user or per guild channel).

    Holds the rolling message history, a dedicated ExLlama generator, and the
    Discord plumbing to send, reject, and regenerate responses.
    """

    def __init__(self, context=""):
        self.messages = []             # (text, author_name) tuples, oldest first
        self.last_bot_message = None   # last discord.Message sent, for strikethrough on rejection
        self.context = context         # log prefix, e.g. "(GuildName) "
        self.lock = asyncio.Lock()     # serializes message handling for this session
        self.generator = get_generator()
        self.gen_cache = ""            # prompt text the generator has already consumed
        # NOTE(review): needs_regen is never set back to False anywhere, so the
        # incremental gen_feed_tokens branch in sendPrompt is currently
        # unreachable; it looks like this should be cleared after a successful
        # generation — confirm against the ExLlama generator API before changing.
        self.needs_regen = True

    def sendPrompt(self, prompt: str) -> str:
        """Run the model on `prompt` and return the generated reply (blocking).

        Streams tokens to stdout until the model emits "<!end!>" or the output
        cap (--length-out) is reached; the trailing "<!end!>" marker is stripped
        from the returned text. Raises whatever the generator raises, after
        flagging the cache for regeneration.
        """
        with promptLock:  # one generation at a time across all bots (shared model/cache)
            with torch.no_grad():
                start = timer()
                try:
                    if prompt.startswith(self.gen_cache) and not self.needs_regen:
                        # Cache hit: feed only the new suffix of the prompt.
                        self.generator.gen_feed_tokens(tokenizer.encode(prompt[len(self.gen_cache):]))
                        self.gen_cache = prompt
                    else:
                        # Cache miss: re-ingest the whole prompt.
                        self.gen_cache = prompt
                        self.generator.gen_begin(tokenizer.encode(prompt))
                    num_res_tokens = 0
                    res_line = ""
                    print(f'{self.context}bot: ', flush=True, end='')
                    self.generator.begin_beam_search()
                    for i in range(args.length_out):
                        gen_token = self.generator.beam_search()
                        if gen_token.item() == tokenizer.eos_token_id:
                            # Swap EOS for a newline so the sequence stays well-formed.
                            self.generator.replace_last_token(tokenizer.newline_token_id)
                        num_res_tokens += 1
                        # Re-decode the whole generated tail each step; decoding
                        # tokens one at a time can mangle multi-token characters.
                        text = tokenizer.decode(self.generator.sequence_actual[:, -num_res_tokens:][0])
                        new_text = text[len(res_line):]
                        res_line += new_text
                        print(new_text, end="", flush=True)
                        # <!end!> is the end string for the model between messages
                        if res_line.endswith("<!end!>"):
                            break
                    self.generator.end_beam_search()
                    self.gen_cache += res_line
                    end = timer()
                    print(f'\n[generated {num_res_tokens} tokens in {end - start:.2f} s at {num_res_tokens/(end - start):.2f} t/s]', flush=True)
                    return res_line.replace("<!end!>", "")
                except:
                    # don't know state when it failed, so clear cache
                    self.needs_regen = True
                    raise

    async def reset(self):
        """Forget the conversation history for this session."""
        self.messages = []

    async def send_next_message(self, channel: discord.TextChannel, interaction: discord.Interaction = None):
        """Trim history to the token budget, then generate and deliver a reply.

        Shows the typing indicator when responding to a plain message rather
        than a slash-command interaction.
        """
        ensure_wordcount_ok(self.messages)
        if interaction is not None:
            await self._send_next_message(channel, interaction)
        else:
            async with channel.typing():
                await self._send_next_message(channel)

    async def _send_next_message(self, channel: discord.TextChannel, interaction: discord.Interaction = None):
        """Build the prompt, generate (retrying on failure or banned words), and post."""
        # Copy the history so "<!end!>" separators can be appended without mutating state.
        prompt_messages = [[message, author] for message, author in self.messages]
        last_message = None
        for message in prompt_messages:
            # Mark speaker changes with the model's end-of-message token.
            if last_message is not None and message[1] != last_message[1]:
                last_message[0] += '<!end!>'
            last_message = message
        prompt = ''
        if use_system_prompt:
            # Drop leading bot messages so the system prompt is followed by a user turn.
            # (bugfix: previously compared the [message, author] pair itself to 'bot',
            # which was always False, so leading bot messages were never dropped)
            while len(prompt_messages) > 0 and prompt_messages[0][1] == 'bot':
                prompt_messages.pop(0)
            prompt += system_prompt + '<!end!>\n'
        prompt += '\n'.join(msg[0] for msg in prompt_messages)
        tries = 0
        while True:
            try:
                def task():
                    return self.sendPrompt(prompt + '<!end!>\n')
                # Run the blocking generation off the event loop.
                response = await asyncio.to_thread(task)
            except Exception:
                print("Generation failed!", flush=True)
                traceback.print_exc()
                response = None
            # Accept only a successful generation containing no banned words.
            if response is not None and not any(slur in response.lower() for slur in slurs):
                break
            tries += 1
            if tries > 3:
                if interaction is not None:
                    await interaction.followup.send('I\'m sorry, I\'m having trouble coming up with a response. Try saying something else!')
                else:
                    await channel.send('I\'m sorry, I\'m having trouble coming up with a response. Try saying something else!')
                return
        self.messages.append((response, 'bot'))
        if len(response) < 2000:
            # Fits in a single Discord message.
            if interaction is not None:
                await interaction.followup.send(response)
                self.last_bot_message = await interaction.original_response()
            else:
                self.last_bot_message = await channel.send(response)
        else:
            # Split on newlines into chunks under Discord's 2000-char limit.
            # NOTE(review): a single line longer than 2000 chars would still overflow.
            acc_messages = []
            acc_text = ""
            for message in response.split('\n'):
                if len(acc_text) + len(message) > 2000:
                    acc_messages.append(acc_text)
                    acc_text = ""
                acc_text += message + '\n'
            acc_messages.append(acc_text)
            if interaction is not None:
                await interaction.followup.send(acc_messages.pop(0))
                self.last_bot_message = await interaction.original_response()
            for message in acc_messages:
                await asyncio.sleep(0.5)  # mild pacing between chunks
                self.last_bot_message = await channel.send(message)

    async def no(self, channel: discord.TextChannel, interaction: discord.Interaction = None, should_defer=True):
        """Reject the last response: strike it through, drop it from history, regenerate."""
        self.messages = self.messages[:-1]
        if self.last_bot_message:
            # ~~...~~ renders as strikethrough; strip existing markers to avoid nesting.
            await self.last_bot_message.edit(content='~~' + self.last_bot_message.content.replace('~~', '') + '~~')
            self.last_bot_message = None
        if should_defer:
            await interaction.response.defer()
        await self.send_next_message(channel, interaction)

    async def get_response(self, message: discord.Message):
        """Handle an incoming Discord message: commands, rejection, or a normal chat turn."""
        if message.author.bot:
            # Track our own messages so !no can strike them through later.
            self.last_bot_message = message
            return
        if message.content.startswith(';'):
            # Leading ';' lets users talk in-channel without involving the bot.
            return
        if message.content.startswith('!reset'):
            await message.channel.send('Context reset!')
            await self.reset()
            return
        if message.content.startswith('!help'):
            if message.channel.guild is not None:
                await message.delete()
            await message.channel.send('Chatbot Commands:\n\n !help - display this message\n !reset - forget current discussion\n !no - reject and send a new response')
            return
        # Strip the model's control token from user input.
        message.content = message.content.replace('<!end!>', '')
        async with self.lock:
            def find_user(match):
                # Resolve a raw <@id> mention to a readable @name using the message's mentions.
                user_id = int(match.group(1))
                for user in message.mentions:
                    if user.id == user_id:
                        return f"@{user.name}"
                return "@unknown"
            if message.content == '!n' or message.content == '!no':
                if message.channel.guild is not None:
                    await message.delete()
                self.messages = self.messages[:-1]
                if self.last_bot_message:
                    await self.last_bot_message.edit(content='~~' + self.last_bot_message.content.replace('~~', '') + '~~')
                    self.last_bot_message = None
            else:
                message_string = message.content
                message_string = user_regex.sub(find_user, message_string)
                # Collapse custom emotes <a:name:id> down to :name:.
                message_string = emote_regex.sub(r':\1:', message_string)
                currMsg = message_string
                print(f'{self.context}{message.author.name}: {message_string}', flush=True)
                self.messages.append((currMsg, message.author.name))
            await self.send_next_message(message.channel)
# One Bot session per guild and one per DM user; created lazily on first contact.
server_bots = {}  # guild.id -> Bot
dm_bots = {}      # user.id -> Bot
class MyClient(discord.Client):
    """Discord client that routes traffic to per-DM or per-guild Bot sessions."""

    async def on_ready(self):
        """Sync slash commands and announce readiness."""
        await tree.sync()
        self.idleTimer = None
        print(f'Logged on as {self.user}!', flush=True)

    async def on_message(self, message: discord.Message):
        """Dispatch a message to its session: DMs by author, guilds via #ai-friend."""
        if message.guild is None:
            bot = dm_bots.get(message.author.id)
            if bot is None:
                bot = dm_bots[message.author.id] = Bot(f"({message.author.name}) ")
            await bot.get_response(message)
        elif message.channel.name == 'ai-friend':
            bot = server_bots.get(message.guild.id)
            if bot is None:
                bot = server_bots[message.guild.id] = Bot(f"({message.guild.name}) ")
            await bot.get_response(message)
# message_content and members intents are needed to read messages and resolve mentions.
intents = discord.Intents.default()
intents.message_content = True
intents.members = True
client = MyClient(intents=intents)
tree = app_commands.CommandTree(client)  # slash-command registry, synced in on_ready
@tree.command(name='reset', description='Forget the current discussion')
async def reset(interaction: discord.Interaction):
    """Slash command: clear the history of the session tied to this DM or guild."""
    if interaction.channel.guild is None:
        bot = dm_bots.get(interaction.user.id)
        if bot is None:
            bot = dm_bots[interaction.user.id] = Bot(f"({interaction.user.name}) ")
        await bot.reset()
        await interaction.response.send_message('Context reset!')
    elif interaction.channel.name == 'ai-friend':
        bot = server_bots.get(interaction.guild.id)
        if bot is None:
            bot = server_bots[interaction.guild.id] = Bot(f"({interaction.guild.name}) ")
        await bot.reset()
        await interaction.response.send_message('Context reset!')
@tree.command(name='no', description='Reject the last response and send a new one')
async def no(interaction: discord.Interaction):
    """Slash command: strike out the last reply and generate a replacement."""
    should_defer = True
    if interaction.channel.guild is None:
        bot = dm_bots.get(interaction.user.id)
        if bot is None:
            bot = dm_bots[interaction.user.id] = Bot(f"({interaction.user.name}) ")
        await bot.no(interaction.channel, interaction, should_defer)
    elif interaction.channel.name == 'ai-friend':
        bot = server_bots.get(interaction.guild.id)
        if bot is None:
            bot = server_bots[interaction.guild.id] = Bot(f"({interaction.guild.name}) ")
        await bot.no(interaction.channel, interaction, should_defer)
# NOTE(review): hard-coded token placeholder — consider loading from an environment variable instead.
client.run('(Discord token goes here!)')
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment