kai-flx / app.py
seawolf2357's picture
Update app.py
9b2f51a verified
raw
history blame
2.64 kB
import asyncio
import logging
import os
import uuid

import discord
import torch
from diffusers import StableDiffusion3Pipeline, StableDiffusion3Img2ImgPipeline
from huggingface_hub import snapshot_download
# Logging setup: DEBUG level on the root logger, written to stderr.
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s:%(levelname)s:%(name)s: %(message)s', handlers=[logging.StreamHandler()])

# Gateway intents: the bot must read message text to detect the "!image "
# command, which requires the message_content intent.
intents = discord.Intents.default()
intents.message_content = True

# Download the Stable Diffusion 3 medium snapshot from the Hugging Face Hub.
# HF_TOKEN is read from the environment (None if unset).
# NOTE(review): presumably required because the repo is gated; confirm.
huggingface_token = os.getenv("HF_TOKEN")
model_path = snapshot_download(
    repo_id="stabilityai/stable-diffusion-3-medium",
    revision="refs/pr/26",
    repo_type="model",
    # ignore_patterns are glob patterns. "*.md" skips markdown files, whereas
    # the previous ".md" only matched a file literally named ".md" and
    # "..gitattributes" matched nothing.
    ignore_patterns=["*.md", ".gitattributes"],
    local_dir="stable-diffusion-3-medium",
    token=huggingface_token,
)
# λͺ¨λΈ λ‘œλ“œ ν•¨μˆ˜
def load_pipeline(pipeline_type):
if pipeline_type == "text2img":
return StableDiffusion3Pipeline.from_pretrained(model_path, torch_dtype=torch.float16)
elif pipeline_type == "img2img":
return StableDiffusion3Img2ImgPipeline.from_pretrained(model_path, torch_dtype=torch.float16)
# Compute device: the first CUDA GPU when one is available, otherwise CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
# Discord bot client.
class MyClient(discord.Client):
    """Discord client that answers ``!image <prompt>`` messages with an SD3 image.

    The text-to-image pipeline is loaded once in ``__init__`` and moved to
    ``device``. Generation runs in a worker thread so the event loop, and with
    it the gateway heartbeat, keeps running. Only one generation runs at a
    time; overlapping requests get a short "busy" reply.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # True while a generation is in flight; checked in on_message to
        # reject overlapping requests.
        self.is_processing = False
        self.text2img_pipeline = load_pipeline("text2img").to(device)
        self.text2img_pipeline.enable_attention_slicing()  # reduce peak memory usage

    async def on_ready(self):
        logging.info(f'{self.user}둜 λ‘œκ·ΈμΈλ˜μ—ˆμŠ΅λ‹ˆλ‹€!')

    async def on_message(self, message):
        """Handle ``!image <prompt>`` by generating and posting an image."""
        if message.author == self.user:
            return
        if not message.content.startswith('!image '):
            return
        prompt = message.content[len('!image '):].strip()
        if not prompt:
            await message.channel.send('Usage: !image <prompt>')
            return
        # The check and the assignment below have no await between them, so
        # they cannot interleave with another on_message coroutine.
        if self.is_processing:
            await message.channel.send('An image is already being generated, please try again shortly.')
            return
        self.is_processing = True
        image_path = None
        try:
            image_path = await self.generate_image(prompt)
            await message.channel.send(file=discord.File(image_path, 'generated_image.png'))
        except Exception:
            logging.exception("Image generation failed for prompt %r", prompt)
            await message.channel.send('Sorry, image generation failed.')
        finally:
            self.is_processing = False
            # Remove the temporary PNG so /tmp does not fill up over time.
            if image_path is not None:
                try:
                    os.remove(image_path)
                except OSError:
                    logging.warning("Could not remove temporary image %s", image_path)

    async def generate_image(self, prompt):
        """Generate an image for *prompt* and return the path of the saved PNG.

        The diffusion call blocks for a long time, so it runs in the default
        executor instead of directly on the event loop.
        """
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self._generate_image_sync, prompt)

    def _generate_image_sync(self, prompt):
        """Blocking helper: run the pipeline once and save the first image to /tmp."""
        # Fresh non-deterministic seed per request.
        generator = torch.Generator(device=device).manual_seed(torch.seed())
        images = self.text2img_pipeline(prompt, num_inference_steps=50, generator=generator)["images"]
        image_path = f'/tmp/{uuid.uuid4()}.png'
        images[0].save(image_path)
        return image_path
# Entry point: read the Discord bot token from the environment and run the bot.
if __name__ == "__main__":
    discord_token = os.getenv('DISCORD_TOKEN')
    if not discord_token:
        # Fail fast, before MyClient() loads the diffusion pipeline, rather
        # than spending the model-load time and then failing at login.
        raise SystemExit("DISCORD_TOKEN environment variable is not set")
    discord_client = MyClient(intents=intents)
    discord_client.run(discord_token)