# kai-flx / app.py
import discord
import logging
import os
import uuid
import torch
import subprocess
from huggingface_hub import snapshot_download
from diffusers import StableDiffusionPipeline, StableDiffusionImg2ImgPipeline
from transformers import pipeline
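# Overview: this script runs a Discord bot that listens for "!image <Korean prompt>"
# messages, translates the prompt to English, generates an image with a Stable
# Diffusion pipeline, and replies to the requesting user with the result.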
# Logging configuration
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s:%(levelname)s:%(name)s: %(message)s', handlers=[logging.StreamHandler()])
# μΈν…νŠΈ μ„€μ •
intents = discord.Intents.default()
intents.message_content = True
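# Note: the message content intent is privileged and must also be enabled for the
# bot in the Discord developer portal.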
# Download the model snapshot from the Hugging Face Hub
huggingface_token = os.getenv("HF_TOKEN")
model_path = snapshot_download(
    repo_id="Corcelio/mobius",
    repo_type="model",
    local_dir="mobius",
    token=huggingface_token,
)
# λͺ¨λΈ λ‘œλ“œ ν•¨μˆ˜
def load_pipeline(pipeline_type):
logging.debug(f'νŒŒμ΄ν”„λΌμΈ λ‘œλ“œ 쀑: {pipeline_type}')
if pipeline_type == "text2img":
return StableDiffusionPipeline.from_pretrained(model_path, torch_dtype=torch.float16, use_fast=True)
elif pipeline_type == "img2img":
return StableDiffusionImg2ImgPipeline.from_pretrained(model_path, torch_dtype=torch.float16, use_fast=True)
# λ””λ°”μ΄μŠ€ μ„€μ •
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Korean-to-English translation pipeline
translator = pipeline("translation", model="Helsinki-NLP/opus-mt-ko-en")
# κ³ μ •λœ λ„€κ±°ν‹°λΈŒ ν”„λ‘¬ν”„νŠΈ
negative_prompt = "blur, low quality, bad composition, ugly, disfigured, weird colors, low quality, jpeg artifacts, lowres, grainy, deformed structures, blurry, opaque, low contrast, distorted details, details are low"
# Discord bot client
class MyClient(discord.Client):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.is_processing = False
        self.text2img_pipeline = load_pipeline("text2img").to(device)
        self.text2img_pipeline.enable_attention_slicing()  # reduce memory usage during inference

    async def on_ready(self):
        logging.info(f'Logged in as {self.user}!')
        subprocess.Popen(["python", "web.py"])
        logging.info("The web.py server has been started.")

    async def on_message(self, message):
        if message.author == self.user:
            return
        if message.content.startswith('!image '):
            # Skip overlapping requests while a previous generation is still running
            if self.is_processing:
                return
            self.is_processing = True
            try:
                prompt = message.content[len('!image '):]
                prompt_en = translate_prompt(prompt)
                logging.debug(f'Translated prompt: {prompt_en}')
                logging.debug(f'Fixed negative prompt: {negative_prompt}')
                image_path = await self.generate_image(prompt_en, negative_prompt)
                user_id = message.author.id  # capture the requesting user's ID
                await message.channel.send(
                    f"<@{user_id}> Here is the image you requested.",
                    file=discord.File(image_path, 'generated_image.png')
                )
            finally:
                self.is_processing = False

    async def generate_image(self, prompt, negative_prompt):
        # Use a fresh random seed for each request
        generator = torch.Generator(device=device).manual_seed(torch.seed())
        images = self.text2img_pipeline(prompt, negative_prompt=negative_prompt, num_inference_steps=50, generator=generator)["images"]
        image_path = f'/tmp/{uuid.uuid4()}.png'
        images[0].save(image_path)
        return image_path

# Prompt translation helper (Korean -> English)
def translate_prompt(prompt):
    logging.debug(f'Translating prompt: {prompt}')
    translation = translator(prompt, max_length=512)
    translated_text = translation[0]['translation_text']
    logging.debug(f'Translated text: {translated_text}')
    return translated_text

# Read the Discord token and run the bot
if __name__ == "__main__":
    discord_token = os.getenv('DISCORD_TOKEN')
    discord_client = MyClient(intents=intents)
    discord_client.run(discord_token)