kai-vision / app.py
seawolf2357's picture
Update app.py
e616357 verified
raw
history blame
4.75 kB
import discord
import logging
import os
import asyncio # asyncio ๋ชจ๋“ˆ ์ถ”๊ฐ€
from transformers import PaliGemmaForConditionalGeneration, PaliGemmaProcessor
import torch
import re
import requests
from PIL import Image
import io
import gradio as gr
import threading
# Logging setup: DEBUG level, records emitted through a single StreamHandler.
# The previous format string opened "%(asctime" without closing it, which made
# every record fail to format (ValueError reported by the handler instead of
# the message).
LOG_FORMAT = '%(asctime)s:%(levelname)s:%(name)s: %(message)s'
logging.basicConfig(level=logging.DEBUG, format=LOG_FORMAT, handlers=[logging.StreamHandler()])
# Discord gateway intents: start from the library defaults, then explicitly
# enable message content plus guild/message events.
# NOTE(review): discord.py documents Intents.default() as already enabling
# guilds/messages; they are set again here to be explicit — verify against the
# installed discord.py version before trimming.
intents = discord.Intents.default()
intents.message_content = intents.messages = intents.guilds = intents.guild_messages = True
# PaliGemma captioning model and its processor, loaded from the same Hub repo.
# The model is moved to the CPU and switched to inference (eval) mode.
_CAPTIONER_REPO = "gokaygokay/sd3-long-captioner"
model = PaliGemmaForConditionalGeneration.from_pretrained(_CAPTIONER_REPO)
model = model.to("cpu").eval()
processor = PaliGemmaProcessor.from_pretrained(_CAPTIONER_REPO)
def modify_caption(caption: str) -> str:
    """Strip the first "captured from " / "captured at " phrase from *caption*.

    Matching is case-insensitive and applies to the first occurrence anywhere
    in the string (it is not anchored to the start).

    Args:
        caption: Raw caption text decoded from the model.

    Returns:
        The caption with at most one such phrase removed.
    """
    prefix_substrings = [
        ('captured from ', ''),
        ('captured at ', ''),
    ]
    pattern = '|'.join(re.escape(opening) for opening, _ in prefix_substrings)
    # Keys are lower-cased because re.IGNORECASE lets the match differ in case
    # from the literal (e.g. "Captured from "), and the lookup must still hit.
    replacers = {opening.lower(): replacer for opening, replacer in prefix_substrings}

    def replace_fn(match: re.Match) -> str:
        return replacers[match.group(0).lower()]

    return re.sub(pattern, replace_fn, caption, count=1, flags=re.IGNORECASE)
def _caption_image_blocking(image) -> str:
    """Run preprocessing, generation and decoding synchronously; return the cleaned caption."""
    # The processor does its own resize/rescale/normalise, so the raw image is
    # passed in once. Re-feeding already-normalised pixel_values scaled by 255
    # and cast to uint8 (as before) wraps any negative values.
    # NOTE(review): assumes the processor normalises into a signed range
    # (presumably about [-1, 1] for PaliGemma) — confirm if reverting.
    model_inputs = processor(text="caption en", images=image, return_tensors="pt").to("cpu")
    input_len = model_inputs["input_ids"].shape[-1]
    generation = model.generate(**model_inputs, max_new_tokens=256, do_sample=False)
    # generate() returns prompt + continuation; keep only the newly generated tokens.
    decoded = processor.decode(generation[0][input_len:], skip_special_tokens=True)
    return modify_caption(decoded)


async def create_captions_rich(image: Image.Image) -> str:
    """Generate a long English caption for *image* without blocking the event loop.

    The whole pipeline is CPU-bound, so it runs on the loop's default executor.

    Args:
        image: Image to caption. The Discord path passes an RGB PIL image. The
            Gradio UI may pass a NumPy array; the processor is presumed to
            accept both — confirm.

    Returns:
        The decoded caption after ``modify_caption`` clean-up.
    """
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, _caption_image_blocking, image)
# Gradio interface setup.
def create_captions_rich_sync(image):
    """Blocking wrapper so Gradio can call the async ``create_captions_rich``.

    Each call drives the coroutine on a fresh event loop via ``asyncio.run``.
    This presumes Gradio invokes the callback from a thread with no running
    loop (``asyncio.run`` raises otherwise).
    """
    pending = create_captions_rich(image)
    return asyncio.run(pending)
# Gradio UI: one tab with an image input, a submit button and a caption box.
# NOTE(review): no component sets elem_id="mkd", so this CSS rule currently
# applies to nothing — confirm before relying on or removing it.
css = """
#mkd {
height: 500px;
overflow: auto;
border: 1px solid #ccc;
}
"""
with gr.Blocks(css=css) as demo:
    # Header markup with properly closed <center> and <h1> tags.
    gr.HTML("<h1><center>PaliGemma Fine-tuned for Long Captioning</center></h1>")
    with gr.Tab(label="PaliGemma Long Captioner"):
        with gr.Row():
            with gr.Column():
                input_img = gr.Image(label="Input Picture")
                submit_btn = gr.Button(value="Submit")
            output = gr.Text(label="Caption")
        # Gradio calls the synchronous wrapper, which drives the async captioner.
        submit_btn.click(create_captions_rich_sync, [input_img], [output])
# Gradio server launcher (MyClient.on_ready runs this on a daemon thread).
def run_gradio():
    """Serve the Gradio demo on all interfaces.

    The port is read from GRADIO_SERVER_PORT and defaults to 7861.
    """
    port = int(os.getenv("GRADIO_SERVER_PORT", 7861))
    demo.launch(server_name="0.0.0.0", server_port=port, inbrowser=True)
# Channel the bot listens in (threads whose parent is this channel also count).
# Read from DISCORD_CHANNEL_ID; the literal fallback looks like a placeholder
# ID and should be overridden via the environment.
SPECIFIC_CHANNEL_ID = int(os.environ.get("DISCORD_CHANNEL_ID", "123456789012345678"))
# Discord bot client.
class MyClient(discord.Client):
    """Discord client that captions the first attachment posted in SPECIFIC_CHANNEL_ID.

    Messages in threads whose parent is that channel are handled too. Only one
    message is processed at a time; messages arriving while a caption is in
    flight are ignored, not queued.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # True while a caption request is being processed.
        self.is_processing = False
        # Background thread serving the Gradio UI; started on the first on_ready only.
        self._gradio_thread = None

    async def on_ready(self):
        logging.info(f'{self.user}๋กœ ๋กœ๊ทธ์ธ๋˜์—ˆ์Šต๋‹ˆ๋‹ค!')
        # discord.py may dispatch on_ready more than once (e.g. after reconnects).
        # Launching Gradio again would try to bind the same port a second time.
        if self._gradio_thread is None:
            self._gradio_thread = threading.Thread(target=run_gradio, daemon=True)
            self._gradio_thread.start()
            logging.info("Gradio ์„œ๋ฒ„๊ฐ€ ์‹œ์ž‘๋˜์—ˆ์Šต๋‹ˆ๋‹ค.")

    async def on_message(self, message):
        """Caption the first attachment of an eligible message and reply in-channel."""
        if message.author == self.user:
            return
        if not self.is_message_in_specific_channel(message):
            return
        if self.is_processing:
            return
        self.is_processing = True
        try:
            if message.attachments:
                image_url = message.attachments[0].url
                response = await process_image(image_url, message)
                await message.channel.send(response)
        finally:
            # Always release the busy flag, even if download/captioning raised.
            self.is_processing = False

    def is_message_in_specific_channel(self, message):
        """Return True if *message* is in SPECIFIC_CHANNEL_ID or a thread under it."""
        return message.channel.id == SPECIFIC_CHANNEL_ID or (
            isinstance(message.channel, discord.Thread) and message.channel.parent_id == SPECIFIC_CHANNEL_ID
        )
async def process_image(image_url, message):
    """Download the image at *image_url*, caption it, and build the reply text.

    The reply mentions the author of *message* and then includes the caption.
    """
    caption = await create_captions_rich(await download_image(image_url))
    mention = message.author.mention
    return f"{mention}, ์ธ์‹๋œ ์ด๋ฏธ์ง€ ์„ค๋ช…: {caption}"
async def download_image(url, timeout=30):
    """Fetch the image at *url* and return it as an RGB PIL image.

    The HTTP request and image decoding are blocking calls, so they run on the
    event loop's default executor instead of stalling the Discord client.

    Args:
        url: Direct URL of the image (here, a Discord attachment URL).
        timeout: Seconds passed to ``requests.get``. Without a timeout, a
            stalled connection could block indefinitely.

    Raises:
        requests.HTTPError: if the server responds with an error status.
        requests.Timeout: if the request exceeds *timeout*.
        PIL.UnidentifiedImageError: if the payload is not a decodable image.
    """
    def _fetch_and_decode():
        response = requests.get(url, timeout=timeout)
        response.raise_for_status()
        # Normalise palette/alpha/greyscale images to three RGB channels.
        return Image.open(io.BytesIO(response.content)).convert("RGB")

    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, _fetch_and_decode)
if __name__ == "__main__":
    token = os.getenv('DISCORD_TOKEN')
    if not token:
        # Fail fast with a clear message rather than handing None to discord.py.
        raise SystemExit("DISCORD_TOKEN environment variable is not set.")
    discord_client = MyClient(intents=intents)
    discord_client.run(token)