It's version 0.4
This commit is contained in:
@@ -0,0 +1,190 @@
|
||||
import logging
|
||||
import base64
|
||||
from io import BytesIO
|
||||
import asyncio
|
||||
import aiohttp
|
||||
|
||||
from aiogram import Dispatcher, Bot
|
||||
from aiogram.types import Message, FSInputFile, BufferedInputFile
|
||||
from aiogram.filters import Command
|
||||
|
||||
from models.state import BotState
|
||||
from config import Config
|
||||
|
||||
from storage.message_storage import save_message
|
||||
|
||||
from transformers import pipeline
|
||||
|
||||
from utils.antispam import saving
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
SD_URL = "http://192.168.31.95:7860/sdapi/v1/txt2img"
|
||||
|
||||
# Загружаем пайплайн перевода один раз при старте (синхронный)
|
||||
translator = pipeline("translation", model="Helsinki-NLP/opus-mt-ru-en")
|
||||
|
||||
|
||||
async def translate_to_en(text: str) -> str:
|
||||
try:
|
||||
# выполняем перевод в отдельном потоке, чтобы не блокировать event loop
|
||||
result = await asyncio.to_thread(translator, text, max_length=512)
|
||||
return result[0]["translation_text"]
|
||||
except Exception as e:
|
||||
logger.error(f"Ошибка перевода: {e}")
|
||||
return text
|
||||
|
||||
|
||||
async def generate_img2img(prompt: str, init_image: BytesIO) -> BytesIO | None:
|
||||
"""
|
||||
Генерация изображения по методу img2img.
|
||||
:param prompt: текстовый промт (уже переведённый на английский)
|
||||
:param init_image: входное изображение в BytesIO
|
||||
:return: BytesIO с результатом или None при ошибке
|
||||
"""
|
||||
try:
|
||||
# кодируем входное изображение в base64
|
||||
init_image_base64 = base64.b64encode(init_image.getvalue()).decode("utf-8")
|
||||
|
||||
payload = {
|
||||
"init_images": [init_image_base64],
|
||||
"prompt": prompt,
|
||||
"negative_prompt": "blurry, low quality, bad anatomy, watermark, text, cropped",
|
||||
"steps": 20, # можно 15–20
|
||||
"width": 1024, # лучше подставлять размеры исходного фото
|
||||
"height": 1024,
|
||||
"sampler_name": "Euler a", # мягкий и стабильный для img2img
|
||||
"Schedule_type": "Karras",
|
||||
"cfg_scale": 6, # чуть ниже, чем для txt2img
|
||||
"seed": -1,
|
||||
"denoising_strength": 0.8, # 0.3–0.5 для «сохранить стиль», 0.6–0.8 для «перерисовать»
|
||||
"restore_faces": True, # если работаешь с людьми
|
||||
"override_settings": {
|
||||
"sd_model_checkpoint": "waiNSFWIllustrious_v150.safetensors"
|
||||
},
|
||||
}
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
SD_URL.replace("txt2img", "img2img"), json=payload
|
||||
) as resp:
|
||||
if resp.status != 200:
|
||||
logger.error(f"Stable Diffusion img2img API error: {resp.status}")
|
||||
return None
|
||||
r = await resp.json()
|
||||
image_base64 = r["images"][0]
|
||||
return BytesIO(base64.b64decode(image_base64))
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Ошибка img2img: {e}")
|
||||
return None
|
||||
|
||||
|
||||
# sd_xl_base_1.0.safetensors
|
||||
# waiNSFWIllustrious_v150.safetensors
|
||||
async def generate_image(prompt: str) -> BytesIO | None:
|
||||
payload = {
|
||||
"prompt": prompt,
|
||||
"negative_prompt": "blurry, low quality, bad anatomy, watermark, text, cropped",
|
||||
"steps": 20,
|
||||
"width": 1024,
|
||||
"height": 1024,
|
||||
"sampler_name": "Euler a", # сэмплер
|
||||
"cfg_scale": 7, # насколько строго следовать промту
|
||||
"seed": -1, # -1 = случайный сид
|
||||
"batch_size": 1, # сколько картинок за раз
|
||||
"n_iter": 1, # сколько раз повторить генерацию
|
||||
"restore_faces": False, # восстановление лиц
|
||||
"tiling": False, # тайлинг для текстур
|
||||
"enable_hr": False, # highres fix (двухэтапная генерация)
|
||||
"denoising_strength": 0.7, # сила денойзинга (актуально при enable_hr или img2img)
|
||||
"hr_scale": 2, # во сколько раз увеличить при highres fix
|
||||
"hr_upscaler": "Latent", # апскейлер для highres fix
|
||||
"override_settings": {
|
||||
"sd_model_checkpoint": "waiNSFWIllustrious_v150.safetensors" # выбор модели
|
||||
},
|
||||
}
|
||||
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(SD_URL, json=payload) as resp:
|
||||
if resp.status != 200:
|
||||
logger.error(f"Stable Diffusion API error: {resp.status}")
|
||||
return None
|
||||
r = await resp.json()
|
||||
image_base64 = r["images"][0]
|
||||
return BytesIO(base64.b64decode(image_base64))
|
||||
except Exception as e:
|
||||
logger.error(f"Ошибка генерации изображения: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def register_handlers(dp: Dispatcher, state: BotState, bot: Bot):
|
||||
@dp.message(Command("draw"))
|
||||
async def draw(message: Message):
|
||||
save_message(message.chat.id, message.message_id)
|
||||
if message.from_user.id in Config.BAN:
|
||||
msg = await message.reply("Вы в бане")
|
||||
save_message(msg.chat.id, msg.message_id)
|
||||
else:
|
||||
user_prompt = message.text.replace("/draw", "").strip()
|
||||
if not user_prompt:
|
||||
confirm_msg = await message.answer("❗ Укажи промт после команды /draw")
|
||||
save_message(confirm_msg.chat.id, confirm_msg.message_id)
|
||||
return
|
||||
|
||||
en_prompt = await translate_to_en(user_prompt)
|
||||
logger.info(f"Промт переведен: {user_prompt} -> {en_prompt}")
|
||||
|
||||
img_bytes = await generate_image(en_prompt)
|
||||
if img_bytes:
|
||||
img_bytes.seek(0)
|
||||
photo = BufferedInputFile(img_bytes.read(), filename="result.png")
|
||||
msg = await bot.send_photo(chat_id=message.chat.id, photo=photo)
|
||||
save_message(msg.chat.id, msg.message_id)
|
||||
else:
|
||||
error_msg = await message.answer("⚠️ Ошибка при генерации изображения.")
|
||||
save_message(error_msg.chat.id, error_msg.message_id)
|
||||
|
||||
@dp.message(Command("img2img"))
|
||||
@saving
|
||||
async def img2img_with_caption(message: Message, bot: Bot):
|
||||
raw_caption = message.caption or ""
|
||||
user_prompt = raw_caption.replace("/img2img", "").strip()
|
||||
if not user_prompt:
|
||||
await message.answer(
|
||||
"❗ Укажи промт в подписи к фото после команды /img2img"
|
||||
)
|
||||
return
|
||||
|
||||
en_prompt = await translate_to_en(user_prompt)
|
||||
logger.info(f"Промт для img2img переведен: {user_prompt} -> {en_prompt}")
|
||||
|
||||
try:
|
||||
if message.photo:
|
||||
# Берём последнее (самое большое) фото
|
||||
photo = message.photo[-1].file_id
|
||||
|
||||
# Отправляем в SD API по file_id (как в iadmin)
|
||||
# Здесь отличие: SD API требует base64, поэтому file_id нужно скачать
|
||||
# Но логика построена как в iadmin — сначала берём file_id
|
||||
file = await bot.get_file(photo)
|
||||
file_bytes = await bot.download_file(file.file_path)
|
||||
|
||||
img_bytes = await generate_img2img(
|
||||
en_prompt, BytesIO(file_bytes.read())
|
||||
)
|
||||
if img_bytes:
|
||||
img_bytes.seek(0)
|
||||
photo = BufferedInputFile(
|
||||
img_bytes.read(), filename="img2img_result.png"
|
||||
)
|
||||
msg = await bot.send_photo(chat_id=message.chat.id, photo=photo)
|
||||
else:
|
||||
msg = await message.answer("⚠️ Ошибка при img2img генерации.")
|
||||
|
||||
else:
|
||||
msg = await message.answer("❗ Пришли фото с подписью /img2img <промт>")
|
||||
save_message(msg.chat.id, msg.message_id)
|
||||
except Exception as e:
|
||||
await message.answer(f"⚠️ Ошибка: {e}")
|
||||
Reference in New Issue
Block a user