Created
May 14, 2023 18:26
-
-
Save vasmarfas/5fe59944ddc3dd8b25289baa0863f60e to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from bot.db import db, DBTables, decrypt | |
import aiohttp | |
import asyncio | |
import time | |
async def job_exists(endpoint):
    """Ask the Stable Diffusion WebUI whether it currently has queued or running jobs.

    Args:
        endpoint: Base URL of the WebUI server; ``/sdapi/v1/progress`` is
            appended to it.

    Returns:
        True if the server reports ``state.job_count > 0``, False if it reports
        no jobs (or omits the count), or None when the progress endpoint
        answers with a non-200 status.
    """
    async with aiohttp.ClientSession() as session:
        # The progress endpoint takes its options from the query string; a JSON
        # body on a GET request is ignored. skip_current_image asks the server
        # not to render a live preview we never use. aiohttp rejects bool
        # query values, hence the string "true".
        async with session.get(
            endpoint + "/sdapi/v1/progress",
            params={"skip_current_image": "true"},
        ) as response:
            if response.status != 200:
                return None
            payload = await response.json()
    # Tolerate a missing/null "state" or "job_count" instead of raising.
    state = payload.get('state') or {}
    return (state.get('job_count') or 0) > 0
async def wait_for_status(ignore_exceptions: bool = False):
    """Wait until the WebUI reports that it has no jobs.

    The endpoint URL is read from the config table and decrypted first; errors
    in that step are always raised. The server is then polled via
    ``job_exists`` for as long as it reports work. Before each subsequent poll,
    the coroutine sleeps in 5-second steps until the shared
    ``'_last_time_status_checked'`` timestamp in ``db[DBTables.cooldown]`` is at
    least 5 seconds old, then refreshes that timestamp. Every caller of this
    function is therefore throttled against the same clock.

    Args:
        ignore_exceptions: when True, an exception raised while polling is
            swallowed and the coroutine simply returns None.
    """
    api_url = decrypt(db[DBTables.config].get('endpoint'))
    try:
        while await job_exists(api_url):
            while time.time() < db[DBTables.cooldown].get('_last_time_status_checked', 0) + 5:
                await asyncio.sleep(5)
            db[DBTables.cooldown]['_last_time_status_checked'] = time.time()
    except Exception:
        if ignore_exceptions:
            return
        raise
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import asyncio
import os
import re
import tracemalloc

import deepl
from aiogram import types

# Project imports are kept in their original order so that package
# initialisation order is unchanged.
from bot.common import bot
from bot.config import ADMIN
from bot.db import db, DBTables
from bot.utils.cooldown import throttle
from bot.modules.api.txt2img import txt2img
from bot.modules.api.objects.get_prompt import get_prompt
from bot.modules.api.objects.prompt_request import Generated
from bot.modules.api.status import wait_for_status
from bot.keyboards.image_info import get_img_info_keyboard
from bot.utils.errorable_command import wrap_exception
from bot.callbacks.factories.image_info import (prompt_only, full_prompt, import_prompt, back)
# Record allocation tracebacks; presumably enabled so that warnings such as
# "coroutine ... was never awaited" show where the object was created.
tracemalloc.start()

# DeepL API key used by translate_prompt(). Set DEEPL_AUTH_KEY in the
# environment. The literal fallback only keeps existing deployments running:
# this key has been published together with the source, so rotate it and
# remove the fallback.
auth_key = os.environ.get("DEEPL_AUTH_KEY", "c38b34ad-d497-8da2-4b6a-66716ccd3960:fx")

# Comma-separated blocklist. generate_command appends it to non-admin negative
# prompts and uses the split word list below to reject prompts in group chats
# that have not opted out of the NSFW filter.
nsfw_filter = "milf, sex, sexy, horny, nsfw, dick, vagina, pussy, virgin, nudes, nude, clear transparent bikini, tits, boobs, porn, porno, sexual, butt, open body"
nsfw_words = [word.strip() for word in nsfw_filter.split(",")]
# Trigger words each checkpoint expects, as (prefix, suffix) pairs wrapped
# around the user's prompt when that checkpoint is the active model.
_MODEL_PROMPT_DECORATIONS = {
    'moDi-v1-pruned.ckpt [8067368533]': ('', ' modern disney style'),
    'Inkpunk-Diffusion-v2.ckpt [2182245415]': ('', ' nvinkpunk'),
    'gta5ArtworkDiffusion_v1.ckpt [607aa02fb8]': ('', ' gtav style'),
    'Cyberpunk-Anime-Diffusion.safetensors [ab55b3722e]': ('', ' in dgs illustration style'),
    'ghostmix_v12.safetensors [d7465e52e1]': (
        '(masterpiece, top quality, best quality, official art, beautiful and aesthetic:1.2), '
        '(1girl:1.3), (fractal art:1.3), ',
        '',
    ),
    'mdjrny-v4.safetensors [aba96b389d]': ('', ' mdjrny-v4 style'),
    'realisticVisionV20_v20.safetensors [c0d1994c73]': (
        'RAW photo, ',
        ', (high detailed skin:1.2), 8k uhd, dslr, soft lighting, high quality, film grain, Fujifilm XT3',
    ),
    'robo-diffusion-v1.ckpt [244dbe0dcb]': ('nousr robot ', ''),
    # The trailing ", " keeps the tag group separate from the first prompt word.
    'AnythingV5_v5PrtRE.safetensors [7f96a1a9ca]': ('(masterpiece, best quality), ', ''),
    'dreamlike-anime-1.0.safetensors [aacfb54d9c]': ('photo anime, masterpiece, high quality, absurdres ', ''),
    'dreamlike-diffusion-1.0.safetensors [c86a1a99b8]': ('dreamlikeart ', ''),
    'etherBluMix_etherBluMix31.safetensors [3ef76ede61]': ('masterpiece, best quality, ', ''),
    'sxzLuma_097.safetensors [3709e3c3c9]': ('', ' realistic'),
    'waifu_diffusion_pytorch_model.safetensors [dda5a15fe8]': ('masterpiece, best quality, ', ''),
    'etherRealMix_etherRealMix2.safetensors [96e52f4268]': ('', ' realistic'),
}

# Negative prompt forced on non-admin users who did not supply their own.
_DEFAULT_NEGATIVE_PROMPT = (
    "bad anatomy, bad proportions, blurry, cloned face, cropped, deformed, disfigured, "
    "duplicate, error, extra arms, extra fingers, extra legs, extra limbs, fused fingers, "
    "gross proportions, jpeg artifacts, long neck, low quality, lowres, malformed limbs, "
    "missing arms, missing legs, morbid, mutated hands, mutation, mutilated, out of frame, "
    "poorly drawn face, poorly drawn hands, signature, too many fingers, ugly, username, "
    "watermark, worst quality, tiling, poorly drawn feet, milf, sexy, horny, nsfw, dick, "
    "vagina, pussy, virgin, nudes, nude, clear transparent bikini, tits, boobs, porn, porno, "
    "sexual, butt, open body"
)

_RUSSIAN_LETTERS = frozenset('абвгдеёжзийклмнопрстуфхцчшщъыьэюя')

# Telegram rejects photo captions longer than this many characters.
_TELEGRAM_CAPTION_LIMIT = 1024


def _contains_russian_letters(text):
    """Return True if *text* contains any Russian letter, in either case."""
    return not _RUSSIAN_LETTERS.isdisjoint(text.lower())


def _build_archive_caption(message, prompt, model, negative_shown):
    """Build the caption for the copy of the image posted to the archive chat."""
    username = message.from_user.username
    # The "kilisauros" account is written without "@", presumably so the
    # archive does not mention that admin; TODO confirm.
    if username and username != "kilisauros":
        username = f"@{username}"
    lines = [
        f"User ID: {message.from_id}",
        f"User Full Name: {message.from_user.full_name}",
        f"User username: {username}",
        f"Chat ID: {message.chat.id}",
        f"Chat title: {message.chat.title}",
        "Info:",
        f"🖤 Prompt: {prompt.prompt}",
        f"🐊 Negative: {negative_shown}",
        f"💫 Model: {model}",
        f"🪜 Steps: {prompt.steps}",
        f"🧑🎨 CFG Scale: {prompt.cfg_scale}",
        f"🖥️ Size: {prompt.width}x{prompt.height}",
        f"😀 Restore faces: {prompt.restore_faces}",
        f"⚒️ Sampler: {prompt.sampler}",
    ]
    # Truncate so a long prompt cannot make send_photo fail after the user
    # has already received the image.
    return "\n".join(lines)[:_TELEGRAM_CAPTION_LIMIT]


@wrap_exception([ValueError], custom_loading=True)
@throttle(cooldown=30, admin_ids=db[DBTables.config].get('admins'))
async def generate_command(message: types.Message):
    """Handle a text-to-image request from *message*.

    Steps:
    1. Check the global "enabled" flag and the user's channel subscription.
    2. Resolve the prompt via ``get_prompt``.
    3. Reject Russian text, or translate it for premium users.
    4. Reject NSFW words in group chats without the NSFW opt-out.
    5. Wait in the shared queue until the WebUI is idle.
    6. Decorate the prompt for the active model and force a negative prompt
       on non-admins.
    7. Generate the image and reply with it; post a copy to the archive chat.
    8. Store the generation metadata and send the info keyboard.

    A ValueError raised during queueing or generation is reported to the user.
    The queue counter is always released, whatever exception occurs.
    """
    chat_member = await bot.get_chat_member(chat_id='-1001864315145', user_id=message.from_id)
    temp_message = await message.reply("⏳ Enqueued...")
    await increment_user_counter(message.from_id)

    if not db[DBTables.config]['enabled']:
        await message.reply('💔 Generation is disabled by admins now. Try again later')
        await temp_message.delete()
        return
    if chat_member.status not in ['member', 'creator', 'owner', 'administrator']:
        # Not subscribed to the channel: reply with a button that links to it.
        keyboard = types.InlineKeyboardMarkup()
        keyboard.add(types.InlineKeyboardButton('Подписаться', url='https://t.me/+MTmF3uzo91IxOGIy'))
        await message.reply('❌To use this bot you need to subscribe to the channel @vasmarfas', reply_markup=keyboard)
        await temp_message.delete()
        return
    if message.chat.id < 0:  # negative chat IDs are group chats
        await increment_chat_counter(message.chat.id)

    try:
        prompt = get_prompt(user_id=message.from_id, prompt_string=message.get_args())
    except AttributeError:
        await temp_message.edit_text("You didn't created any prompt. Specify prompt text at least first time. "
                                     "For example, it can be: <code>masterpiece, best quality, 1girl, white hair, "
                                     "medium hair, cat ears, closed eyes, looking at viewer, :3, cute, scarf, jacket, "
                                     "outdoors, streets</code>", parse_mode='HTML')
        return

    if _contains_russian_letters(prompt.prompt):
        if message.from_id not in (db[DBTables.config].get('premium') or []):
            await temp_message.edit_text("❌You should use only English letters! Auto-translate available only for premium users!", parse_mode='HTML')
            return
        prompt.prompt = str(await translate_prompt(prompt.prompt))

    if message.chat.id < 0 and message.chat.id not in (db[DBTables.config].get('nsfwchats') or []):
        prompt_lower = prompt.prompt.lower()  # blocklist words are lowercase
        found_word = next((word for word in nsfw_words if word in prompt_lower), None)
        if found_word:
            await temp_message.edit_text(
                f"Your prompt contains NSFW word '{found_word}'. \n"
                "NSFW filter is enabled for this chat, edit your promt or contact your chat admin "
                "for disabling NSFW filter by /disablensfwfilter")
            return

    db[DBTables.queue]['n'] = db[DBTables.queue].get('n', 0) + 1
    left_queue = False  # ensures the counter is decremented exactly once
    try:
        await temp_message.edit_text(f"⏳ Enqueued in position {db[DBTables.queue].get('n', 0)}...")
        await wait_for_status()
        await temp_message.edit_text("⌛ Generating...")
        db[DBTables.queue]['n'] = db[DBTables.queue].get('n', 1) - 1
        left_queue = True

        current_model = db[DBTables.config]['current_model']
        prefix, suffix = _MODEL_PROMPT_DECORATIONS.get(current_model, ('', ''))
        prompt.prompt = f"{prefix}{prompt.prompt}{suffix}"

        # NSFW negative filter: admins keep their settings untouched. Other
        # users keep their own negative prompt with the blocklist appended, or
        # get the default.
        used_default_negative = False
        admins = db[DBTables.config].get('admins') or []
        if message.from_id not in admins and message.from_id != ADMIN:
            if isinstance(prompt.negative_prompt, str) and prompt.negative_prompt.strip():
                prompt.negative_prompt = f"{prompt.negative_prompt}, {nsfw_filter}"
            else:
                prompt.negative_prompt = _DEFAULT_NEGATIVE_PROMPT
                used_default_negative = True

        image = await txt2img(prompt)
        await increment_request_counter(message.from_id)
        image_message = await message.reply_photo(photo=image[0])
        image_key = image_message.photo[0].file_unique_id

        if used_default_negative or not prompt.negative_prompt:
            negative_shown = "default"
        else:
            negative_shown = prompt.negative_prompt
        await bot.send_photo(-929754401, photo=image[0],
                             caption=_build_archive_caption(message, prompt, current_model, negative_shown))

        # Fall back to the configured model name if the infotext format changes.
        model_match = re.search(r", Model: ([^,]+),", image[1]['infotexts'][0])
        db[DBTables.generated][image_key] = Generated(
            prompt=prompt,
            seed=image[1]['seed'],
            model=model_match.group(1) if model_match else current_model,
        )
        await message.reply(f'Here is your image: \nCurrent model: {current_model}',
                            reply_markup=get_img_info_keyboard(image_key))
        await temp_message.delete()
        await db[DBTables.config].write()
    except ValueError as e:
        await message.reply(f'❌ Error! {e.args[0] if e.args else e}')
        await temp_message.delete()
    finally:
        if not left_queue:
            db[DBTables.queue]['n'] = db[DBTables.queue].get('n', 1) - 1
async def increment_request_counter(user_id):
    """Record one completed generation request for *user_id*.

    ``db[DBTables.config]['requests']`` keeps each distinct requesting user ID
    once. The per-user request totals live in
    ``db[DBTables.config]['request_counts']``.

    The previous implementation incremented the stored user ID itself on a
    repeat request. Lists written by it may therefore contain shifted IDs.
    """
    config = db[DBTables.config]

    requesters = config.get('requests') or []
    if user_id not in requesters:
        requesters.append(user_id)
        config['requests'] = requesters

    # String keys in case the config table is serialised to JSON, which would
    # turn int keys into strings on reload. TODO confirm the storage format.
    counts = config.get('request_counts') or {}
    key = str(user_id)
    counts[key] = counts.get(key, 0) + 1
    config['request_counts'] = counts
async def increment_user_counter(user_id):
    """Add *user_id* to the config table's ``'users'`` list unless it is already there.

    Despite the name, nothing is counted: the list holds each distinct user ID
    once. The list state is printed to stdout before and after the update.
    """
    known_users = db[DBTables.config].get('users') or []
    print("users before adding the user_id: ", known_users)
    if user_id in known_users:
        print("User id already exists in the list")
        return
    known_users.append(user_id)
    db[DBTables.config]['users'] = known_users
    print("users after adding the user_id: ", known_users)
async def increment_chat_counter(chat_id):
    """Add *chat_id* to the config table's ``'chats'`` list unless it is already there.

    Despite the name, nothing is counted: the list holds each distinct chat ID
    once. The list state is printed to stdout before and after the update.
    """
    known_chats = db[DBTables.config].get('chats') or []
    print("chats before adding the chat_id: ", known_chats)
    if chat_id in known_chats:
        print("Chat id already exists in the list")
        return
    known_chats.append(chat_id)
    db[DBTables.config]['chats'] = known_chats
    print("chats after adding the chat_id: ", known_chats)
async def translate_prompt(prompt):
    """Translate *prompt* to US English with DeepL.

    The DeepL client makes a blocking HTTP request. It runs in the event
    loop's default executor so other handlers are not stalled meanwhile.

    Returns:
        The object returned by ``deepl.Translator.translate_text``. Callers
        convert it with ``str()`` to get the translated text.
    """
    def _translate_blocking():
        translator = deepl.Translator(auth_key)
        return translator.translate_text(prompt, target_lang="EN-US")

    return await asyncio.get_running_loop().run_in_executor(None, _translate_blocking)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment