import discord
import logging
import os
import uuid
import torch
import subprocess
from huggingface_hub import snapshot_download
from diffusers import StableDiffusionPipeline, StableDiffusionImg2ImgPipeline
from transformers import pipeline
# Logging setup
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s:%(levelname)s:%(name)s: %(message)s', handlers=[logging.StreamHandler()])
# Discord intents setup
intents = discord.Intents.default()
intents.message_content = True
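# NOTE: message_content is a privileged intent; it must also be enabled for the bot
# in the Discord Developer Portal, otherwise on_message will not see any message text.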
# Download the Hugging Face model
huggingface_token = os.getenv("HF_TOKEN")
model_path = snapshot_download(
    repo_id="Corcelio/mobius",
    repo_type="model",
    local_dir="mobius",
    token=huggingface_token,
)
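# NOTE: snapshot_download fetches the full "Corcelio/mobius" repository into ./mobius
# at startup; HF_TOKEN is only needed if the repository is gated or private.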
# Model loading helper: returns the requested diffusers pipeline
def load_pipeline(pipeline_type):
    logging.debug(f'Loading pipeline: {pipeline_type}')
    if pipeline_type == "text2img":
        return StableDiffusionPipeline.from_pretrained(model_path, torch_dtype=torch.float16)
    elif pipeline_type == "img2img":
        return StableDiffusionImg2ImgPipeline.from_pretrained(model_path, torch_dtype=torch.float16)
# Device setup
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
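# NOTE: the pipelines are loaded with torch_dtype=torch.float16, which in practice
# requires a CUDA device; float16 inference on CPU is generally unsupported or very slow.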
# Translation pipeline (Korean -> English) for user prompts
translator = pipeline("translation", model="Helsinki-NLP/opus-mt-ko-en")
# Fixed negative prompt applied to every generation
negative_prompt = "blur, low quality, bad composition, ugly, disfigured, weird colors, low quality, jpeg artifacts, lowres, grainy, deformed structures, blurry, opaque, low contrast, distorted details, details are low"
# Discord bot class
class MyClient(discord.Client):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.is_processing = False
        self.text2img_pipeline = load_pipeline("text2img").to(device)
        self.text2img_pipeline.enable_attention_slicing()  # memory optimization
    async def on_ready(self):
        logging.info(f'Logged in as {self.user}!')
        subprocess.Popen(["python", "web.py"])
        logging.info("The web.py server has been started.")
    async def on_message(self, message):
        if message.author == self.user:
            return
        if message.content.startswith('!image '):
            self.is_processing = True
            try:
                prompt = message.content[len('!image '):]
                prompt_en = translate_prompt(prompt)
                logging.debug(f'Translated prompt: {prompt_en}')
                logging.debug(f'Fixed negative prompt: {negative_prompt}')
                image_path = await self.generate_image(prompt_en, negative_prompt)
                user_id = message.author.id  # capture the requesting user's ID
                await message.channel.send(
                    f"<@{user_id}>, here is the image you requested.",
                    file=discord.File(image_path, 'generated_image.png')
                )
            finally:
                self.is_processing = False
    async def generate_image(self, prompt, negative_prompt):
        # Seed a fresh generator so each request produces a different image
        generator = torch.Generator(device=device).manual_seed(torch.seed())
        images = self.text2img_pipeline(prompt, negative_prompt=negative_prompt, num_inference_steps=50, generator=generator)["images"]
        image_path = f'/tmp/{uuid.uuid4()}.png'
        images[0].save(image_path)
        return image_path
# Prompt translation helper (Korean -> English)
def translate_prompt(prompt):
    logging.debug(f'Translating prompt: {prompt}')
    translation = translator(prompt, max_length=512)
    translated_text = translation[0]['translation_text']
    logging.debug(f'Translated text: {translated_text}')
    return translated_text
# Discord token and bot startup
if __name__ == "__main__":
    discord_token = os.getenv('DISCORD_TOKEN')
    discord_client = MyClient(intents=intents)
    discord_client.run(discord_token)
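# Assumes DISCORD_TOKEN (and optionally HF_TOKEN) are provided as environment variables,
# e.g. as Space secrets or via `export DISCORD_TOKEN=...` when running locally.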