Files
myfirstprogram/addons/draw/handlers.py
T

196 lines
8.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import logging
import base64
from io import BytesIO
import asyncio
import aiohttp
from PIL import Image
from aiogram import Dispatcher, Bot
from aiogram.types import Message, BufferedInputFile
from aiogram.filters import Command
from models.state import BotState
from config import Config
from storage.message_storage import save_message
from transformers import pipeline
from utils.antispam import saving
logger = logging.getLogger(__name__)
SD_URL = "http://192.168.31.95:7860/sdapi/v1/txt2img"
# Загружаем пайплайн перевода один раз при старте (синхронный)
translator = pipeline("translation", model="Helsinki-NLP/opus-mt-ru-en")
async def translate_to_en(text: str) -> str:
try:
# выполняем перевод в отдельном потоке, чтобы не блокировать event loop
result = await asyncio.to_thread(translator, text, max_length=512)
return result[0]["translation_text"]
except Exception as e:
logger.error(f"Ошибка перевода: {e}")
return text
async def generate_img2img(prompt: str, init_image: BytesIO) -> BytesIO | None:
"""
Генерация изображения по методу img2img.
:param prompt: текстовый промт (уже переведённый на английский)
:param init_image: входное изображение в BytesIO
:return: BytesIO с результатом или None при ошибке
"""
try:
# Определяем размеры оригинала
init_image.seek(0)
with Image.open(init_image) as img:
width, height = img.size
# кодируем входное изображение в base64
init_image_base64 = base64.b64encode(init_image.getvalue()).decode("utf-8")
payload = {
"init_images": [init_image_base64],
"prompt": prompt,
"negative_prompt": "blurry, low quality, bad anatomy, watermark, text, cropped",
"steps": 25,
"width": width, # берём ширину оригинала
"height": height, # берём высоту оригинала
"sampler_name": "Euler a",
"scheduler": "Karras", # исправлен ключ
"cfg_scale": 10,
"seed": -1,
"denoising_strength": 0.45,
"restore_faces": True,
"override_settings": {
"sd_model_checkpoint": "waiNSFWIllustrious_v150.safetensors"
},
}
async with aiohttp.ClientSession() as session:
async with session.post(
SD_URL.replace("txt2img", "img2img"), json=payload
) as resp:
if resp.status != 200:
logger.error(f"Stable Diffusion img2img API error: {resp.status}")
return None
r = await resp.json()
image_base64 = r["images"][0]
return BytesIO(base64.b64decode(image_base64))
except Exception as e:
logger.error(f"Ошибка img2img: {e}")
return None
# sd_xl_base_1.0.safetensors
# waiNSFWIllustrious_v150.safetensors
async def generate_image(prompt: str) -> BytesIO | None:
payload = {
"prompt": prompt,
"negative_prompt": "blurry, low quality, bad anatomy, watermark, text, cropped",
"steps": 20,
"width": 1024,
"height": 1024,
"sampler_name": "Euler a", # сэмплер
"cfg_scale": 7, # насколько строго следовать промту
"seed": -1, # -1 = случайный сид
"batch_size": 1, # сколько картинок за раз
"n_iter": 1, # сколько раз повторить генерацию
"restore_faces": False, # восстановление лиц
"tiling": False, # тайлинг для текстур
"enable_hr": False, # highres fix (двухэтапная генерация)
"denoising_strength": 0.7, # сила денойзинга (актуально при enable_hr или img2img)
"hr_scale": 2, # во сколько раз увеличить при highres fix
"hr_upscaler": "Latent", # апскейлер для highres fix
"override_settings": {
"sd_model_checkpoint": "waiNSFWIllustrious_v150.safetensors" # выбор модели
},
}
try:
async with aiohttp.ClientSession() as session:
async with session.post(SD_URL, json=payload) as resp:
if resp.status != 200:
logger.error(f"Stable Diffusion API error: {resp.status}")
return None
r = await resp.json()
image_base64 = r["images"][0]
return BytesIO(base64.b64decode(image_base64))
except Exception as e:
logger.error(f"Ошибка генерации изображения: {e}")
return None
def register_handlers(dp: Dispatcher, state: BotState, bot: Bot):
@dp.message(Command("draw"))
async def draw(message: Message):
save_message(message.chat.id, message.message_id)
if message.from_user.id in Config.BAN:
msg = await message.reply("Вы в бане")
save_message(msg.chat.id, msg.message_id)
else:
user_prompt = message.text.replace("/draw", "").strip()
if not user_prompt:
confirm_msg = await message.answer("❗ Укажи промт после команды /draw")
save_message(confirm_msg.chat.id, confirm_msg.message_id)
return
en_prompt = await translate_to_en(user_prompt)
logger.info(f"Промт переведен: {user_prompt} -> {en_prompt}")
img_bytes = await generate_image(en_prompt)
if img_bytes:
img_bytes.seek(0)
photo = BufferedInputFile(img_bytes.read(), filename="result.png")
msg = await bot.send_photo(chat_id=message.chat.id, photo=photo)
save_message(msg.chat.id, msg.message_id)
else:
error_msg = await message.answer("⚠️ Ошибка при генерации изображения.")
save_message(error_msg.chat.id, error_msg.message_id)
@dp.message(Command("img2img"))
@saving
async def img2img_with_caption(message: Message, bot: Bot):
raw_caption = message.caption or ""
user_prompt = raw_caption.replace("/img2img", "").strip()
if not user_prompt:
await message.answer(
"❗ Укажи промт в подписи к фото после команды /img2img"
)
return
en_prompt = await translate_to_en(user_prompt)
logger.info(f"Промт для img2img переведен: {user_prompt} -> {en_prompt}")
try:
if message.photo:
# Берём последнее (самое большое) фото
photo = message.photo[-1].file_id
# Отправляем в SD API по file_id (как в iadmin)
# Здесь отличие: SD API требует base64, поэтому file_id нужно скачать
# Но логика построена как в iadmin — сначала берём file_id
file = await bot.get_file(photo)
file_bytes = await bot.download_file(file.file_path)
img_bytes = await generate_img2img(
en_prompt, BytesIO(file_bytes.read())
)
if img_bytes:
img_bytes.seek(0)
photo = BufferedInputFile(
img_bytes.read(), filename="img2img_result.png"
)
msg = await bot.send_photo(chat_id=message.chat.id, photo=photo)
else:
msg = await message.answer("⚠️ Ошибка при img2img генерации.")
else:
msg = await message.answer("❗ Пришли фото с подписью /img2img <промт>")
save_message(msg.chat.id, msg.message_id)
except Exception as e:
await message.answer(f"⚠️ Ошибка: {e}")