191 lines
8.3 KiB
Python
191 lines
8.3 KiB
Python
import logging
|
||
import base64
|
||
from io import BytesIO
|
||
import asyncio
|
||
import aiohttp
|
||
|
||
from aiogram import Dispatcher, Bot
|
||
from aiogram.types import Message, BufferedInputFile
|
||
from aiogram.filters import Command
|
||
|
||
from models.state import BotState
|
||
from config import Config
|
||
|
||
from storage.message_storage import save_message
|
||
|
||
from transformers import pipeline
|
||
|
||
from utils.antispam import saving
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
SD_URL = "http://192.168.31.95:7860/sdapi/v1/txt2img"
|
||
|
||
# Загружаем пайплайн перевода один раз при старте (синхронный)
|
||
translator = pipeline("translation", model="Helsinki-NLP/opus-mt-ru-en")
|
||
|
||
|
||
async def translate_to_en(text: str) -> str:
|
||
try:
|
||
# выполняем перевод в отдельном потоке, чтобы не блокировать event loop
|
||
result = await asyncio.to_thread(translator, text, max_length=512)
|
||
return result[0]["translation_text"]
|
||
except Exception as e:
|
||
logger.error(f"Ошибка перевода: {e}")
|
||
return text
|
||
|
||
|
||
async def generate_img2img(prompt: str, init_image: BytesIO) -> BytesIO | None:
|
||
"""
|
||
Генерация изображения по методу img2img.
|
||
:param prompt: текстовый промт (уже переведённый на английский)
|
||
:param init_image: входное изображение в BytesIO
|
||
:return: BytesIO с результатом или None при ошибке
|
||
"""
|
||
try:
|
||
# кодируем входное изображение в base64
|
||
init_image_base64 = base64.b64encode(init_image.getvalue()).decode("utf-8")
|
||
|
||
payload = {
|
||
"init_images": [init_image_base64],
|
||
"prompt": prompt,
|
||
"negative_prompt": "blurry, low quality, bad anatomy, watermark, text, cropped",
|
||
"steps": 20, # можно 15–20
|
||
"width": 1024, # лучше подставлять размеры исходного фото
|
||
"height": 1024,
|
||
"sampler_name": "Euler a", # мягкий и стабильный для img2img
|
||
"Schedule_type": "Karras",
|
||
"cfg_scale": 6, # чуть ниже, чем для txt2img
|
||
"seed": -1,
|
||
"denoising_strength": 0.8, # 0.3–0.5 для «сохранить стиль», 0.6–0.8 для «перерисовать»
|
||
"restore_faces": True, # если работаешь с людьми
|
||
"override_settings": {
|
||
"sd_model_checkpoint": "waiNSFWIllustrious_v150.safetensors"
|
||
},
|
||
}
|
||
|
||
async with aiohttp.ClientSession() as session:
|
||
async with session.post(
|
||
SD_URL.replace("txt2img", "img2img"), json=payload
|
||
) as resp:
|
||
if resp.status != 200:
|
||
logger.error(f"Stable Diffusion img2img API error: {resp.status}")
|
||
return None
|
||
r = await resp.json()
|
||
image_base64 = r["images"][0]
|
||
return BytesIO(base64.b64decode(image_base64))
|
||
|
||
except Exception as e:
|
||
logger.error(f"Ошибка img2img: {e}")
|
||
return None
|
||
|
||
|
||
# sd_xl_base_1.0.safetensors
|
||
# waiNSFWIllustrious_v150.safetensors
|
||
async def generate_image(prompt: str) -> BytesIO | None:
|
||
payload = {
|
||
"prompt": prompt,
|
||
"negative_prompt": "blurry, low quality, bad anatomy, watermark, text, cropped",
|
||
"steps": 20,
|
||
"width": 1024,
|
||
"height": 1024,
|
||
"sampler_name": "Euler a", # сэмплер
|
||
"cfg_scale": 7, # насколько строго следовать промту
|
||
"seed": -1, # -1 = случайный сид
|
||
"batch_size": 1, # сколько картинок за раз
|
||
"n_iter": 1, # сколько раз повторить генерацию
|
||
"restore_faces": False, # восстановление лиц
|
||
"tiling": False, # тайлинг для текстур
|
||
"enable_hr": False, # highres fix (двухэтапная генерация)
|
||
"denoising_strength": 0.7, # сила денойзинга (актуально при enable_hr или img2img)
|
||
"hr_scale": 2, # во сколько раз увеличить при highres fix
|
||
"hr_upscaler": "Latent", # апскейлер для highres fix
|
||
"override_settings": {
|
||
"sd_model_checkpoint": "waiNSFWIllustrious_v150.safetensors" # выбор модели
|
||
},
|
||
}
|
||
|
||
try:
|
||
async with aiohttp.ClientSession() as session:
|
||
async with session.post(SD_URL, json=payload) as resp:
|
||
if resp.status != 200:
|
||
logger.error(f"Stable Diffusion API error: {resp.status}")
|
||
return None
|
||
r = await resp.json()
|
||
image_base64 = r["images"][0]
|
||
return BytesIO(base64.b64decode(image_base64))
|
||
except Exception as e:
|
||
logger.error(f"Ошибка генерации изображения: {e}")
|
||
return None
|
||
|
||
|
||
def register_handlers(dp: Dispatcher, state: BotState, bot: Bot):
|
||
@dp.message(Command("draw"))
|
||
async def draw(message: Message):
|
||
save_message(message.chat.id, message.message_id)
|
||
if message.from_user.id in Config.BAN:
|
||
msg = await message.reply("Вы в бане")
|
||
save_message(msg.chat.id, msg.message_id)
|
||
else:
|
||
user_prompt = message.text.replace("/draw", "").strip()
|
||
if not user_prompt:
|
||
confirm_msg = await message.answer("❗ Укажи промт после команды /draw")
|
||
save_message(confirm_msg.chat.id, confirm_msg.message_id)
|
||
return
|
||
|
||
en_prompt = await translate_to_en(user_prompt)
|
||
logger.info(f"Промт переведен: {user_prompt} -> {en_prompt}")
|
||
|
||
img_bytes = await generate_image(en_prompt)
|
||
if img_bytes:
|
||
img_bytes.seek(0)
|
||
photo = BufferedInputFile(img_bytes.read(), filename="result.png")
|
||
msg = await bot.send_photo(chat_id=message.chat.id, photo=photo)
|
||
save_message(msg.chat.id, msg.message_id)
|
||
else:
|
||
error_msg = await message.answer("⚠️ Ошибка при генерации изображения.")
|
||
save_message(error_msg.chat.id, error_msg.message_id)
|
||
|
||
@dp.message(Command("img2img"))
|
||
@saving
|
||
async def img2img_with_caption(message: Message, bot: Bot):
|
||
raw_caption = message.caption or ""
|
||
user_prompt = raw_caption.replace("/img2img", "").strip()
|
||
if not user_prompt:
|
||
await message.answer(
|
||
"❗ Укажи промт в подписи к фото после команды /img2img"
|
||
)
|
||
return
|
||
|
||
en_prompt = await translate_to_en(user_prompt)
|
||
logger.info(f"Промт для img2img переведен: {user_prompt} -> {en_prompt}")
|
||
|
||
try:
|
||
if message.photo:
|
||
# Берём последнее (самое большое) фото
|
||
photo = message.photo[-1].file_id
|
||
|
||
# Отправляем в SD API по file_id (как в iadmin)
|
||
# Здесь отличие: SD API требует base64, поэтому file_id нужно скачать
|
||
# Но логика построена как в iadmin — сначала берём file_id
|
||
file = await bot.get_file(photo)
|
||
file_bytes = await bot.download_file(file.file_path)
|
||
|
||
img_bytes = await generate_img2img(
|
||
en_prompt, BytesIO(file_bytes.read())
|
||
)
|
||
if img_bytes:
|
||
img_bytes.seek(0)
|
||
photo = BufferedInputFile(
|
||
img_bytes.read(), filename="img2img_result.png"
|
||
)
|
||
msg = await bot.send_photo(chat_id=message.chat.id, photo=photo)
|
||
else:
|
||
msg = await message.answer("⚠️ Ошибка при img2img генерации.")
|
||
|
||
else:
|
||
msg = await message.answer("❗ Пришли фото с подписью /img2img <промт>")
|
||
save_message(msg.chat.id, msg.message_id)
|
||
except Exception as e:
|
||
await message.answer(f"⚠️ Ошибка: {e}")
|