# Spaces: Runtime error
import asyncio
import logging
import os
import uuid

import discord
import torch
from diffusers import StableDiffusion3Pipeline, StableDiffusion3Img2ImgPipeline
from huggingface_hub import snapshot_download
# Logging setup: emit DEBUG and above to the console stream handler.
logging.basicConfig(
    level=logging.DEBUG,
    format='%(asctime)s:%(levelname)s:%(name)s: %(message)s',
    handlers=[logging.StreamHandler()],
)

# Intents setup: the bot reads message.content in on_message, so the
# message-content intent is switched on in addition to the defaults.
intents = discord.Intents.default()
intents.message_content = True

# Download the model snapshot from Hugging Face into a local directory.
# HF_TOKEN may be unset, in which case None is passed as the token.
huggingface_token = os.getenv("HF_TOKEN")
model_path = snapshot_download(
    repo_id="stabilityai/stable-diffusion-3-medium",
    revision="refs/pr/26",
    repo_type="model",
    # ignore_patterns are glob patterns; without the leading "*" the entries
    # only match files literally named ".md" / "..gitattributes", so nothing
    # was actually being skipped.
    ignore_patterns=["*.md", "*.gitattributes"],
    local_dir="stable-diffusion-3-medium",
    token=huggingface_token,
)
# Model loading function
def load_pipeline(pipeline_type):
    """Load a Stable Diffusion 3 pipeline from the downloaded model snapshot.

    Args:
        pipeline_type: ``"text2img"`` for ``StableDiffusion3Pipeline`` or
            ``"img2img"`` for ``StableDiffusion3Img2ImgPipeline``.

    Returns:
        The pipeline instance created by ``from_pretrained(model_path, ...)``.

    Raises:
        ValueError: If ``pipeline_type`` is not one of the supported names.
            (Previously this case fell through and returned ``None``, which
            only surfaced later as an ``AttributeError`` at the call site.)
    """
    if pipeline_type == "text2img":
        pipeline_cls = StableDiffusion3Pipeline
    elif pipeline_type == "img2img":
        pipeline_cls = StableDiffusion3Img2ImgPipeline
    else:
        raise ValueError(
            f"Unknown pipeline_type {pipeline_type!r}; "
            "expected 'text2img' or 'img2img'"
        )

    # Half precision is only a good fit on GPU; the script falls back to CPU
    # when CUDA is missing, so use full precision in that case.
    dtype = torch.float16 if torch.cuda.is_available() else torch.float32
    return pipeline_cls.from_pretrained(model_path, torch_dtype=dtype)
# Device setup: use the first CUDA GPU when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
# Discord bot class
class MyClient(discord.Client):
    """Discord client that answers ``!image <prompt>`` with a generated image.

    The text-to-image pipeline is loaded once in ``__init__`` and reused for
    every request. Requests are processed one at a time.
    """

    def __init__(self, *args, **kwargs):
        """Forward all arguments to ``discord.Client`` and load the pipeline."""
        super().__init__(*args, **kwargs)
        # True while an image request is being generated and uploaded.
        self.is_processing = False
        # Serializes image requests. Created lazily in on_message so the lock
        # is built while the client's event loop is already running.
        self._generation_lock = None
        self.text2img_pipeline = load_pipeline("text2img").to(device)
        self.text2img_pipeline.enable_attention_slicing()  # memory optimization

    async def on_ready(self):
        """Log the account the bot is connected as."""
        # NOTE(review): the text after {self.user} looks mis-encoded (probably
        # Korean originally); kept byte-for-byte, confirm against the source.
        logging.info(f'{self.user}λ‘ λ‘κ·ΈμΈλμμ΅λλ€!')

    async def on_message(self, message):
        """Generate and upload an image for messages starting with ``!image ``.

        Messages sent by the bot itself, messages without the command prefix,
        and commands whose prompt is empty after stripping are ignored.
        """
        if message.author == self.user:
            return
        if not message.content.startswith('!image '):
            return

        prompt = message.content[len('!image '):].strip()
        if not prompt:
            return

        if self._generation_lock is None:
            self._generation_lock = asyncio.Lock()

        # Inference runs in a worker thread, so without this lock two
        # commands could drive the same pipeline object at the same time.
        async with self._generation_lock:
            self.is_processing = True
            image_path = None
            try:
                image_path = await self.generate_image(prompt)
                await message.channel.send(
                    file=discord.File(image_path, 'generated_image.png')
                )
            finally:
                self.is_processing = False
                # The PNG is only needed for the upload; remove it so /tmp
                # does not grow with every request.
                if image_path is not None:
                    try:
                        os.remove(image_path)
                    except OSError:
                        logging.warning(
                            "Could not remove temporary image %s", image_path
                        )

    async def generate_image(self, prompt):
        """Generate one image for ``prompt`` and return the path of the PNG.

        The pipeline call is blocking, so it is run in the default executor
        to keep the event loop (and the Discord connection it services)
        responsive during inference.
        """
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self._render_image, prompt)

    def _render_image(self, prompt):
        """Run the pipeline synchronously and save the first image under /tmp."""
        # torch.seed() draws a fresh seed, so every request gets a new image.
        generator = torch.Generator(device=device).manual_seed(torch.seed())
        images = self.text2img_pipeline(
            prompt, num_inference_steps=50, generator=generator
        )["images"]
        image_path = f'/tmp/{uuid.uuid4()}.png'
        images[0].save(image_path)
        return image_path
# Discord token lookup and bot startup
if __name__ == "__main__":
    discord_token = os.getenv('DISCORD_TOKEN')
    # Fail fast with a clear message: constructing MyClient loads the whole
    # pipeline, which is wasted work if the bot cannot log in afterwards.
    if not discord_token:
        raise SystemExit("DISCORD_TOKEN environment variable is not set")
    discord_client = MyClient(intents=intents)
    discord_client.run(discord_token)