Spaces:
Runtime error
Runtime error
File size: 2,635 Bytes
78efe79 440418c f3985af 9b2f51a 7262aa5 9b2f51a 407a575 32c38ef f3985af 440418c 1831164 440418c 22dee1c 7262aa5 9b2f51a 78efe79 08baccf dc80b35 9b2f51a 08baccf 78efe79 40d0e92 78efe79 dc80b35 7262aa5 9b2f51a 7262aa5 dc80b35 9b2f51a 0926d14 9b2f51a 34428f1 9b2f51a dc80b35 9b2f51a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 |
import asyncio
import logging
import os
import tempfile
import uuid

import discord
import torch
from diffusers import StableDiffusion3Pipeline, StableDiffusion3Img2ImgPipeline
from huggingface_hub import snapshot_download
# Logging setup
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s:%(levelname)s:%(name)s: %(message)s', handlers=[logging.StreamHandler()])

# Intents setup: the message-content intent lets on_message read message.content.
intents = discord.Intents.default()
intents.message_content = True

# Download the Stable Diffusion 3 medium weights from the Hugging Face Hub.
# The access token (if any) is taken from the HF_TOKEN environment variable.
huggingface_token = os.getenv("HF_TOKEN")
model_path = snapshot_download(
    repo_id="stabilityai/stable-diffusion-3-medium",
    revision="refs/pr/26",
    repo_type="model",
    # ignore_patterns are fnmatch-style globs: skip Markdown docs and the
    # .gitattributes file (the previous ".md" / "..gitattributes" matched nothing).
    ignore_patterns=["*.md", ".gitattributes"],
    local_dir="stable-diffusion-3-medium",
    token=huggingface_token,
)
# Model loading function
def load_pipeline(pipeline_type, torch_dtype=None):
    """Load a Stable Diffusion 3 pipeline from the downloaded ``model_path``.

    Args:
        pipeline_type: ``"text2img"`` for ``StableDiffusion3Pipeline`` or
            ``"img2img"`` for ``StableDiffusion3Img2ImgPipeline``.
        torch_dtype: dtype for the model weights. When omitted, uses
            ``torch.float16`` if CUDA is available and ``torch.float32``
            otherwise, since half precision is poorly supported for CPU inference.

    Returns:
        The pipeline instance created by ``from_pretrained``.

    Raises:
        ValueError: if *pipeline_type* is not one of the supported values.
    """
    if pipeline_type == "text2img":
        pipeline_cls = StableDiffusion3Pipeline
    elif pipeline_type == "img2img":
        pipeline_cls = StableDiffusion3Img2ImgPipeline
    else:
        # Previously fell through and returned None, failing later with an
        # unrelated AttributeError at the call site.
        raise ValueError(
            f"Unknown pipeline_type {pipeline_type!r}; expected 'text2img' or 'img2img'"
        )

    if torch_dtype is None:
        torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

    return pipeline_cls.from_pretrained(model_path, torch_dtype=torch_dtype)
# Device setup
# Prefer the first CUDA GPU; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
# Discord bot class
class MyClient(discord.Client):
    """Discord client that replies to ``!image <prompt>`` with a generated image.

    The Stable Diffusion 3 text-to-image pipeline is loaded once at
    construction time and moved to the module-level ``device``.
    """

    # Command prefix; everything after it is used as the prompt.
    _COMMAND_PREFIX = '!image '

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # True while a generation request is in flight; overlapping requests
        # are rejected so the pipeline is never called concurrently.
        self.is_processing = False
        self.text2img_pipeline = load_pipeline("text2img").to(device)
        # Compute attention in slices to reduce peak memory usage.
        self.text2img_pipeline.enable_attention_slicing()

    async def on_ready(self):
        """Log the bot's identity once the client is ready."""
        # Korean: "Logged in as {self.user}!"
        logging.info(f'{self.user}로 로그인되었습니다!')

    async def on_message(self, message):
        """Handle ``!image <prompt>`` commands; ignore the bot's own messages."""
        if message.author == self.user:
            return
        if not message.content.startswith(self._COMMAND_PREFIX):
            return

        prompt = message.content[len(self._COMMAND_PREFIX):]
        if not prompt.strip():
            await message.channel.send("Usage: !image <prompt>")
            return
        # No await between this check and setting the flag, so it cannot race
        # with another on_message coroutine on the same event loop.
        if self.is_processing:
            await message.channel.send(
                "An image is already being generated. Please try again shortly."
            )
            return

        self.is_processing = True
        image_path = None
        try:
            image_path = await self.generate_image(prompt)
            await message.channel.send(file=discord.File(image_path, 'generated_image.png'))
        except Exception:
            # Top-level handler for this command: log the traceback and tell
            # the user instead of failing silently.
            logging.exception("Image generation failed for prompt %r", prompt)
            await message.channel.send("Sorry, image generation failed.")
        finally:
            self.is_processing = False
            # Remove the temporary PNG so repeated requests do not fill the disk.
            if image_path is not None:
                try:
                    os.remove(image_path)
                except OSError:
                    logging.warning("Could not remove temporary file %s", image_path)

    async def generate_image(self, prompt):
        """Generate one image for *prompt* and return the path of the saved PNG.

        The diffusion run is synchronous and long-running, so it is executed in
        the default executor to keep the event loop (and the Discord gateway
        heartbeat) responsive while it runs.
        """
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self._generate_image_sync, prompt)

    def _generate_image_sync(self, prompt):
        """Blocking helper: run the pipeline and save the first image as a temp PNG."""
        # torch.seed() reseeds torch's default RNG with a non-deterministic
        # value and returns it; that value seeds this request's generator.
        generator = torch.Generator(device=device).manual_seed(torch.seed())
        images = self.text2img_pipeline(
            prompt, num_inference_steps=50, generator=generator
        ).images
        image_path = os.path.join(tempfile.gettempdir(), f'{uuid.uuid4()}.png')
        images[0].save(image_path)
        return image_path
# Read the Discord token and run the bot
if __name__ == "__main__":
    discord_token = os.getenv('DISCORD_TOKEN')
    # Fail fast with a clear message before MyClient loads the model onto the
    # device, instead of passing None to run() after that expensive setup.
    if not discord_token:
        raise SystemExit("DISCORD_TOKEN environment variable is not set")
    discord_client = MyClient(intents=intents)
    discord_client.run(discord_token)
|