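"""Discord bot + Gradio Space that captions images with the PaliGemma
long-captioning model (gokaygokay/sd3-long-captioner, CPU-only) and
translates captions / chats in Korean via the CohereForAI/aya-23-35B
inference endpoint.

Required environment variables: HF_TOKEN, DISCORD_TOKEN, DISCORD_CHANNEL_ID
(optional: GRADIO_SERVER_PORT, defaults to 7861).
"""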
import discord
import logging
import os
import asyncio
from transformers import PaliGemmaForConditionalGeneration, PaliGemmaProcessor
import torch
import re
import requests
from PIL import Image
import io
import gradio as gr
import threading
from huggingface_hub import InferenceClient
# Logging setup
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s:%(levelname)s:%(name)s:%(message)s', handlers=[logging.StreamHandler()])

# Discord intents setup (message_content is a privileged intent and must also be
# enabled for the bot in the Discord developer portal)
intents = discord.Intents.default()
intents.message_content = True
intents.messages = True
intents.guilds = True
intents.guild_messages = True

# Inference API client used for translation and chat
hf_client = InferenceClient("CohereForAI/aya-23-35B", token=os.getenv("HF_TOKEN"))

# PaliGemma captioning model (CPU mode)
model = PaliGemmaForConditionalGeneration.from_pretrained("gokaygokay/sd3-long-captioner").to("cpu").eval()
processor = PaliGemmaProcessor.from_pretrained("gokaygokay/sd3-long-captioner")

# Global conversation history for the chat assistant
conversation_history = []
def modify_caption(caption: str) -> str:
    """Strip leading boilerplate such as 'captured from ' / 'captured at ' from a caption."""
    prefix_substrings = [
        ('captured from ', ''),
        ('captured at ', '')
    ]
    pattern = '|'.join([re.escape(opening) for opening, _ in prefix_substrings])
    replacers = {opening.lower(): replacer for opening, replacer in prefix_substrings}

    def replace_fn(match):
        # Lowercase the match so case-insensitive hits still resolve to a dict key.
        return replacers[match.group(0).lower()]

    return re.sub(pattern, replace_fn, caption, count=1, flags=re.IGNORECASE)
async def create_captions_rich(image: Image.Image) -> str:
    """Generate an English caption for a PIL image with the PaliGemma captioner."""
    prompt = "caption en"
    # A single processor call prepares both the prompt tokens and the pixel values.
    model_inputs = processor(text=prompt, images=image, return_tensors="pt").to("cpu")
    input_len = model_inputs["input_ids"].shape[-1]

    # model.generate() is blocking, so run it in a worker thread.
    loop = asyncio.get_running_loop()
    generation = await loop.run_in_executor(
        None,
        lambda: model.generate(**model_inputs, max_new_tokens=256, do_sample=False)
    )
    generation = generation[0][input_len:]
    decoded = processor.decode(generation, skip_special_tokens=True)
    modified_caption = modify_caption(decoded)
    return modified_caption
async def translate_to_korean(text: str) -> str:
    """Translate English text to Korean through the chat completion endpoint."""
    messages = [
        {"role": "system", "content": "Translate the following text from English to Korean."},
        {"role": "user", "content": text}
    ]

    def stream_chat() -> str:
        # Consume the streamed response inside the worker thread so the
        # network round-trips never block the event loop.
        chunks = []
        for part in hf_client.chat_completion(
            messages, max_tokens=1000, stream=True, temperature=0.7, top_p=0.85
        ):
            if part.choices and part.choices[0].delta and part.choices[0].delta.content:
                chunks.append(part.choices[0].delta.content)
        return ''.join(chunks)

    loop = asyncio.get_running_loop()
    full_response_text = await loop.run_in_executor(None, stream_chat)
    return full_response_text.strip()
async def interact_with_model(user_input: str) -> str:
    """Continue the Korean chat conversation, keeping the full history as context."""
    global conversation_history
    conversation_history.append({"role": "user", "content": user_input})
    messages = [
        {"role": "system", "content": "Translate the following text from English to Korean and respond as if you are an assistant who provides detailed answers in Korean."},
    ] + conversation_history

    def stream_chat() -> str:
        # Same pattern as translate_to_korean: drain the stream off the event loop.
        chunks = []
        for part in hf_client.chat_completion(
            messages, max_tokens=1000, stream=True, temperature=0.7, top_p=0.85
        ):
            if part.choices and part.choices[0].delta and part.choices[0].delta.content:
                chunks.append(part.choices[0].delta.content)
        return ''.join(chunks)

    loop = asyncio.get_running_loop()
    full_response_text = await loop.run_in_executor(None, stream_chat)
    conversation_history.append({"role": "assistant", "content": full_response_text})
    return full_response_text.strip()
# Gradio interface
def create_captions_rich_sync(image):
    """Synchronous wrapper for Gradio: caption the image, then translate the caption to Korean."""
    # Gradio calls this in a worker thread with no running event loop, so asyncio.run is safe here.
    caption = asyncio.run(create_captions_rich(image))
    translated_caption = asyncio.run(translate_to_korean(caption))
    return translated_caption
css = """
#mkd {
    height: 500px;
    overflow: auto;
    border: 1px solid #ccc;
}
"""
with gr.Blocks(css=css) as demo:
    gr.HTML("<h1><center>PaliGemma Fine-tuned for Long Captioning</center></h1>")
    with gr.Tab(label="PaliGemma Long Captioner"):
        with gr.Row():
            with gr.Column():
                input_img = gr.Image(label="Input Picture", type="pil")
                submit_btn = gr.Button(value="Submit")
                output = gr.Text(label="Caption")
        submit_btn.click(create_captions_rich_sync, [input_img], [output])
# Run the Gradio server in the background
def run_gradio():
    demo.launch(
        server_name="0.0.0.0",
        server_port=int(os.getenv("GRADIO_SERVER_PORT", 7861)),
        inbrowser=True
    )
# Channel the bot is allowed to respond in
SPECIFIC_CHANNEL_ID = int(os.getenv("DISCORD_CHANNEL_ID", "123456789012345678"))
# Discord bot
class MyClient(discord.Client):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.is_processing = False

    async def on_ready(self):
        logging.info(f'Logged in as {self.user}!')
        # Start the Gradio UI in a daemon thread so it runs alongside the bot.
        threading.Thread(target=run_gradio, daemon=True).start()
        logging.info("Gradio server started.")

    async def on_message(self, message):
        if message.author == self.user:
            return
        if not self.is_message_in_specific_channel(message):
            return
        # Drop messages that arrive while a previous request is still being processed.
        if self.is_processing:
            return
        self.is_processing = True
        try:
            if message.attachments:
                # Caption the first attached image and reply with the translated description.
                image_url = message.attachments[0].url
                response = await process_image(image_url, message)
                await message.channel.send(response)
            else:
                # Plain text: continue the chat conversation.
                response = await interact_with_model(message.content)
                await message.channel.send(response)
        finally:
            self.is_processing = False

    def is_message_in_specific_channel(self, message):
        # Accept messages from the configured channel or from threads under it.
        return message.channel.id == SPECIFIC_CHANNEL_ID or (
            isinstance(message.channel, discord.Thread) and message.channel.parent_id == SPECIFIC_CHANNEL_ID
        )
async def process_image(image_url, message):
    image = await download_image(image_url)
    caption = await create_captions_rich(image)
    translated_caption = await translate_to_korean(caption)
    intro_message = (
        f"{message.author.mention}, recognized image description: {translated_caption}\n\n"
        "Feel free to ask if you have any questions!"
    )
    return intro_message

async def download_image(url):
    # requests is blocking, so run the download in a worker thread.
    loop = asyncio.get_running_loop()
    response = await loop.run_in_executor(None, lambda: requests.get(url, timeout=30))
    response.raise_for_status()
    return Image.open(io.BytesIO(response.content)).convert("RGB")
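# discord_client.run() blocks the main thread, so the Gradio UI is launched from
# MyClient.on_ready() in a daemon thread rather than here.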
if __name__ == "__main__":
    discord_client = MyClient(intents=intents)
    discord_client.run(os.getenv('DISCORD_TOKEN'))