Spaces:
Runtime error
Runtime error
File size: 3,416 Bytes
78efe79 440418c f3985af 9b2f51a 7262aa5 021392e 9b2f51a 021392e 407a575 32c38ef f3985af 440418c 1831164 440418c 22dee1c 7262aa5 9b2f51a e6b5c32 9b2f51a e6b5c32 9b2f51a e6b5c32 9b2f51a 021392e 9b2f51a 78efe79 08baccf dc80b35 9b2f51a 08baccf 78efe79 40d0e92 021392e 78efe79 dc80b35 7262aa5 9b2f51a 021392e 7262aa5 dc80b35 9b2f51a 0926d14 021392e 9b2f51a 34428f1 9b2f51a dc80b35 9b2f51a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 |
import asyncio
import logging
import os
import subprocess
import sys
import uuid

import discord
import torch
from diffusers import StableDiffusion3Pipeline, StableDiffusion3Img2ImgPipeline
from huggingface_hub import snapshot_download
from transformers import pipeline
# Logging setup: DEBUG level, records emitted through a single StreamHandler.
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s:%(levelname)s:%(name)s: %(message)s', handlers=[logging.StreamHandler()])
# Intents setup: message_content is enabled so the bot can read "!image" command text.
intents = discord.Intents.default()
intents.message_content = True
# Download the Stable Diffusion 3 Medium model from the Hugging Face Hub.
huggingface_token = os.getenv("HF_TOKEN")  # may be None if the variable is unset
model_path = snapshot_download(
    repo_id="stabilityai/stable-diffusion-3-medium",
    revision="refs/pr/26",  # NOTE(review): pinned to an open PR ref, presumably for diffusers-format weights — confirm
    repo_type="model",
    # ignore_patterns are fnmatch-style globs. "*.md" skips README/model-card
    # files. The previous ".md" only matched a file literally named ".md", and
    # "..gitattributes" was a typo that matched nothing.
    ignore_patterns=["*.md", ".gitattributes"],
    local_dir="stable-diffusion-3-medium",
    token=huggingface_token,
)
# Model loading function
def load_pipeline(pipeline_type):
    """Load a Stable Diffusion 3 pipeline from the downloaded weights.

    Args:
        pipeline_type: "text2img" for StableDiffusion3Pipeline, or "img2img"
            for StableDiffusion3Img2ImgPipeline.

    Returns:
        The pipeline loaded from the module-level ``model_path`` in float16.
        It is not moved to any device here; the caller is responsible for that.

    Raises:
        ValueError: if ``pipeline_type`` is not a supported name. Previously an
            unknown name silently returned None, and the error only surfaced
            later as an AttributeError when the caller used the result.
    """
    logging.debug(f'Loading pipeline: {pipeline_type}')
    if pipeline_type == "text2img":
        pipeline_cls = StableDiffusion3Pipeline
    elif pipeline_type == "img2img":
        pipeline_cls = StableDiffusion3Img2ImgPipeline
    else:
        raise ValueError(
            f"Unknown pipeline_type {pipeline_type!r}; expected 'text2img' or 'img2img'"
        )
    # NOTE(review): use_fast is not a documented DiffusionPipeline.from_pretrained
    # argument and is likely ignored (with a warning); kept to preserve behavior.
    return pipeline_cls.from_pretrained(model_path, torch_dtype=torch.float16, use_fast=True)
# Device setup: use the first CUDA GPU when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
# Translation pipeline setup: Helsinki-NLP Opus-MT Korean -> English model,
# created with the transformers pipeline defaults (no explicit device).
translator = pipeline(
    task="translation",
    model="Helsinki-NLP/opus-mt-ko-en",
)
# Discord bot class
class MyClient(discord.Client):
    """Discord client that replies to ``!image <prompt>`` with a generated image.

    The prompt is passed through ``translate_prompt`` and rendered with the
    SD3 text-to-image pipeline. The resulting PNG is uploaded to the channel
    the command came from.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # True while a generation request holds the pipeline.
        self.is_processing = False
        # Handle of the web.py child process. on_ready may fire again after a
        # reconnect, and this prevents a second server from being spawned.
        self._web_process = None
        # Serializes use of the single shared pipeline now that generation runs
        # in worker threads. It is created lazily inside the running event loop
        # so it is bound to the loop that actually runs the client.
        self._generation_lock = None
        self.text2img_pipeline = load_pipeline("text2img").to(device)
        self.text2img_pipeline.enable_attention_slicing()  # memory optimization

    async def on_ready(self):
        """Log the login and start the web.py companion server once."""
        logging.info(f'{self.user}로 로그인되었습니다!')
        if self._web_process is None:
            # sys.executable runs web.py with the same interpreter/venv as the bot.
            self._web_process = subprocess.Popen([sys.executable, "web.py"])
            logging.info("web.py 서버가 시작되었습니다.")

    async def on_message(self, message):
        """Handle ``!image <prompt>``: translate, generate, upload, then clean up."""
        if message.author == self.user:
            return
        if not message.content.startswith('!image '):
            return
        prompt = message.content[len('!image '):].strip()
        if not prompt:
            return
        if self._generation_lock is None:
            self._generation_lock = asyncio.Lock()
        # One request at a time. Concurrent requests wait here instead of
        # running the GPU pipeline in parallel.
        async with self._generation_lock:
            self.is_processing = True
            image_path = None
            try:
                loop = asyncio.get_running_loop()
                # Translation is blocking model inference, so it runs off the event loop.
                prompt_en = await loop.run_in_executor(None, translate_prompt, prompt)
                logging.debug(f'Translated prompt: {prompt_en}')
                image_path = await self.generate_image(prompt_en)
                await message.channel.send(file=discord.File(image_path, 'generated_image.png'))
            finally:
                self.is_processing = False
                if image_path is not None:
                    # The PNG is only needed for the upload, so it is deleted to avoid filling /tmp.
                    try:
                        os.remove(image_path)
                    except OSError:
                        logging.warning(f'Could not delete temporary image {image_path}', exc_info=True)

    async def generate_image(self, prompt):
        """Generate an image for *prompt* in a worker thread and return its PNG path.

        The diffusion call blocks for the whole sampling run. Running it in
        the default executor keeps the event loop (and the Discord gateway
        heartbeat) responsive.
        """
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self._generate_image_sync, prompt)

    def _generate_image_sync(self, prompt):
        """Run 50 inference steps with a fresh random seed and save to /tmp."""
        seed = torch.seed()
        logging.debug(f'Generation seed: {seed}')
        generator = torch.Generator(device=device).manual_seed(seed)
        images = self.text2img_pipeline(prompt, num_inference_steps=50, generator=generator)["images"]
        image_path = f'/tmp/{uuid.uuid4()}.png'
        images[0].save(image_path)
        return image_path
# ν둬ννΈ λ²μ ν¨μ
def translate_prompt(prompt):
logging.debug(f'Translating prompt: {prompt}')
translation = translator(prompt, max_length=512)
translated_text = translation[0]['translation_text']
logging.debug(f'Translated text: {translated_text}')
return translated_text
# Discord token lookup and bot startup
if __name__ == "__main__":
    discord_token = os.getenv('DISCORD_TOKEN')
    if not discord_token:
        # Fail fast with a clear message, before MyClient() spends time loading the model.
        raise SystemExit('DISCORD_TOKEN environment variable is not set.')
    discord_client = MyClient(intents=intents)
    discord_client.run(discord_token)
|