# Hugging Face Spaces status: Runtime error
import asyncio
import logging
import os
import subprocess
import sys
import threading
import uuid

import discord
import torch
from diffusers import (
    AutoPipelineForImage2Image,
    AutoPipelineForText2Image,
    StableDiffusionImg2ImgPipeline,
    StableDiffusionPipeline,
)
from huggingface_hub import snapshot_download
from transformers import pipeline
# Logging: root logger at DEBUG, one stream handler (stderr by default),
# lines formatted as "time:LEVEL:logger: message".
logging.basicConfig(
    level=logging.DEBUG,
    format='%(asctime)s:%(levelname)s:%(name)s: %(message)s',
    handlers=[logging.StreamHandler()],
)

# Gateway intents: library defaults plus message content, which the
# `!image` command handler needs in order to read message text.
intents = discord.Intents.default()
intents.message_content = True

# Fetch the Corcelio/mobius model repository into ./mobius; `model_path` is the
# local directory returned by snapshot_download. HF_TOKEN is optional (None when unset).
huggingface_token = os.environ.get("HF_TOKEN")
model_path = snapshot_download(
    repo_id="Corcelio/mobius",
    repo_type="model",
    token=huggingface_token,
    local_dir="mobius",
)
# λͺ¨λΈ λ‘λ ν¨μ | |
def load_pipeline(pipeline_type): | |
logging.debug(f'νμ΄νλΌμΈ λ‘λ μ€: {pipeline_type}') | |
if pipeline_type == "text2img": | |
return StableDiffusionPipeline.from_pretrained(model_path, torch_dtype=torch.float16, use_fast=True) | |
elif pipeline_type == "img2img": | |
return StableDiffusionImg2ImgPipeline.from_pretrained(model_path, torch_dtype=torch.float16, use_fast=True) | |
# Torch device for the diffusion pipeline: the first CUDA GPU when one is
# available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Korean -> English translation model used to translate user prompts.
translator = pipeline(task="translation", model="Helsinki-NLP/opus-mt-ko-en")

# Negative prompt applied to every generation request.
# NOTE: "low quality" is listed twice; kept as-is.
negative_prompt = ", ".join((
    "blur",
    "low quality",
    "bad composition",
    "ugly",
    "disfigured",
    "weird colors",
    "low quality",
    "jpeg artifacts",
    "lowres",
    "grainy",
    "deformed structures",
    "blurry",
    "opaque",
    "low contrast",
    "distorted details",
    "details are low",
))
# Discord bot client.
class MyClient(discord.Client):
    """Discord client that answers ``!image <prompt>`` messages with a generated image.

    The prompt is translated by ``translate_prompt`` and rendered by the
    text-to-image pipeline, always with the module-level ``negative_prompt``.
    Model calls run in executor threads so the event loop stays responsive to
    Discord traffic while an image is being generated.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # True while an `!image` request is being handled; not read in this class.
        self.is_processing = False
        # Child process running web.py, started by on_ready.
        self.web_process = None
        # Model calls run in worker threads but share one pipeline (and device),
        # so they are serialized with this lock.
        self._model_lock = threading.Lock()
        self.text2img_pipeline = load_pipeline("text2img").to(device)
        self.text2img_pipeline.enable_attention_slicing()  # lower peak memory usage

    async def _run_blocking(self, func, *args):
        """Run ``func(*args)`` in the default executor while holding ``_model_lock``."""
        def locked_call():
            with self._model_lock:
                return func(*args)

        return await asyncio.get_running_loop().run_in_executor(None, locked_call)

    async def on_ready(self):
        """Log the login and make sure the web.py companion process is running.

        on_ready may fire again after a reconnect; web.py is only (re)started
        when it has never been started or has exited.
        """
        logging.info(f'{self.user}로 로그인되었습니다!')
        if self.web_process is not None and self.web_process.poll() is None:
            return
        # sys.executable runs web.py under the same interpreter as the bot.
        self.web_process = subprocess.Popen([sys.executable, "web.py"])
        logging.info("web.py 서버가 시작되었습니다.")

    async def on_message(self, message):
        """Handle ``!image <prompt>``: translate, generate, reply, clean up."""
        if message.author == self.user:
            return
        if not message.content.startswith('!image '):
            return
        self.is_processing = True
        image_path = None
        try:
            prompt = message.content[len('!image '):]
            prompt_en = await self._run_blocking(translate_prompt, prompt)
            logging.debug(f'번역된 프롬프트: {prompt_en}')
            logging.debug(f'고정된 네거티브 프롬프트: {negative_prompt}')
            image_path = await self.generate_image(prompt_en, negative_prompt)
            user_id = message.author.id  # mention the requesting user in the reply
            await message.channel.send(
                f"<@{user_id}> 님이 요청하신 이미지입니다.",
                file=discord.File(image_path, 'generated_image.png'),
            )
        finally:
            self.is_processing = False
            # Delete the temporary PNG so /tmp does not grow with every request.
            if image_path is not None:
                try:
                    os.remove(image_path)
                except OSError as exc:
                    logging.warning("Could not remove temporary image %s: %s", image_path, exc)

    async def generate_image(self, prompt, negative_prompt):
        """Render one image for ``prompt`` and save it as a PNG under /tmp.

        Returns:
            Path of the saved PNG; the caller is responsible for deleting it.
        """
        def render():
            # torch.seed() draws a fresh non-deterministic seed for each request.
            generator = torch.Generator(device=device).manual_seed(torch.seed())
            result = self.text2img_pipeline(
                prompt,
                negative_prompt=negative_prompt,
                num_inference_steps=50,
                generator=generator,
            )
            image_path = f'/tmp/{uuid.uuid4()}.png'
            result.images[0].save(image_path)
            return image_path

        return await self._run_blocking(render)
# Prompt translation helper.
def translate_prompt(prompt):
    """Translate a Korean prompt to English using the module-level ``translator``.

    Args:
        prompt: Prompt text as typed by the user.

    Returns:
        The translated text. Empty or whitespace-only prompts are returned
        unchanged without invoking the model.
    """
    logging.debug(f'프롬프트 번역 중: {prompt}')
    if not prompt or prompt.isspace():
        # Nothing to translate; skip the model call on blank input.
        return prompt
    translation = translator(prompt, max_length=512)
    translated_text = translation[0]['translation_text']
    logging.debug(f'번역된 텍스트: {translated_text}')
    return translated_text
# Entry point: read the Discord bot token and run the client.
if __name__ == "__main__":
    discord_token = os.getenv('DISCORD_TOKEN')
    if not discord_token:
        # Fail fast with a clear message before MyClient() eagerly loads the
        # diffusion pipeline.
        raise SystemExit("DISCORD_TOKEN environment variable is not set.")
    discord_client = MyClient(intents=intents)
    discord_client.run(discord_token)