#!/usr/bin/env python
import os
import re
import tempfile
from collections.abc import Iterator
from threading import Thread
import cv2
import gradio as gr
import spaces
import torch
from loguru import logger
from PIL import Image
from transformers import AutoProcessor, Gemma3ForConditionalGeneration, TextIteratorStreamer
import pandas as pd
import PyPDF2
##################################################
# Basic settings
##################################################
MAX_CONTENT_CHARS = 8000  # maximum number of characters passed to the model as text
model_id = os.getenv("MODEL_ID", "google/gemma-3-27b-it")
processor = AutoProcessor.from_pretrained(model_id, padding_side="left")
model = Gemma3ForConditionalGeneration.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    attn_implementation="eager",
)
MAX_NUM_IMAGES = int(os.getenv("MAX_NUM_IMAGES", "5"))
##################################################
# 1) CSV, TXT, and PDF analysis functions (with empty-file guards)
##################################################
def analyze_csv_file(path: str) -> str:
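    """Read a CSV file and return its contents as plain text, truncated to MAX_CONTENT_CHARS.

    Example (path is illustrative): analyze_csv_file("data.csv") returns
    "**[CSV File: data.csv]**" followed by the table rendered as text.
    """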
    try:
        df = pd.read_csv(path)
        df_str = df.to_string().strip()
        if not df_str:
            df_str = "(CSV is empty)"
        if len(df_str) > MAX_CONTENT_CHARS:
            df_str = df_str[:MAX_CONTENT_CHARS] + "\n...(truncated)..."
        return f"**[CSV File: {os.path.basename(path)}]**\n\n{df_str}"
    except Exception as e:
        return f"Failed to read CSV ({os.path.basename(path)}): {str(e)}"
def analyze_txt_file(path: str) -> str:
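    """Read a UTF-8 text file and return its contents, truncated to MAX_CONTENT_CHARS."""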
    try:
        with open(path, "r", encoding="utf-8") as f:
            text = f.read().strip()
        if not text:
            text = "(TXT is empty)"
        if len(text) > MAX_CONTENT_CHARS:
            text = text[:MAX_CONTENT_CHARS] + "\n...(truncated)..."
        return f"**[TXT File: {os.path.basename(path)}]**\n\n{text}"
    except Exception as e:
        return f"Failed to read TXT ({os.path.basename(path)}): {str(e)}"
def pdf_to_markdown(pdf_path: str) -> str:
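    """Extract text from a PDF with PyPDF2 and format it as per-page Markdown sections."""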
    try:
        with open(pdf_path, "rb") as f:
            reader = PyPDF2.PdfReader(f)
            chunks = []
            for page_num, page in enumerate(reader.pages, start=1):
                ptext = (page.extract_text() or "").strip()
                if ptext:
                    chunks.append(f"## Page {page_num}\n\n{ptext}\n")
        full_text = "\n".join(chunks).strip()
        if not full_text:
            full_text = "(PDF is empty)"
        if len(full_text) > MAX_CONTENT_CHARS:
            full_text = full_text[:MAX_CONTENT_CHARS] + "\n...(truncated)..."
        return f"**[PDF File: {os.path.basename(pdf_path)}]**\n\n{full_text}"
    except Exception as e:
        return f"Failed to read PDF ({os.path.basename(pdf_path)}): {str(e)}"
##################################################
# 2) Image/video upload limits
##################################################
def count_files_in_new_message(paths: list[str]) -> tuple[int, int]:
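    """Count images and videos among newly uploaded paths (.mp4 counts as video, everything else as image)."""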
    image_count = 0
    video_count = 0
    for path in paths:
        if path.endswith(".mp4"):
            video_count += 1
        else:
            image_count += 1
    return image_count, video_count
def count_files_in_history(history: list[dict]) -> tuple[int, int]:
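    """Count images and videos referenced by file paths in earlier user turns."""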
    image_count = 0
    video_count = 0
    for item in history:
        # Skip assistant turns and plain-string (text-only) content.
        if item["role"] != "user" or isinstance(item["content"], str):
            continue
        file_path = item["content"][0]
        if file_path.endswith(".mp4"):
            video_count += 1
        else:
            image_count += 1
    return image_count, video_count
def validate_media_constraints(message: dict, history: list[dict]) -> bool:
"""
์ด๋ฏธ์ง€/๋น„๋””์˜ค ๊ฐœ์ˆ˜ ์ œํ•œ
"""
media_files = []
for f in message["files"]:
if re.search(r"\.(png|jpg|jpeg|gif|webp)$", f, re.IGNORECASE) or f.endswith(".mp4"):
media_files.append(f)
new_image_count, new_video_count = count_files_in_new_message(media_files)
history_image_count, history_video_count = count_files_in_history(history)
image_count = history_image_count + new_image_count
video_count = history_video_count + new_video_count
# ๋น„๋””์˜ค 1๊ฐœ ์ดˆ๊ณผ ๋ถˆ๊ฐ€
if video_count > 1:
gr.Warning("Only one video is supported.")
return False
# ๋น„๋””์˜ค+์ด๋ฏธ์ง€ ํ˜ผํ•ฉ ๋ถˆ๊ฐ€
if video_count == 1:
if image_count > 0:
gr.Warning("Mixing images and videos is not allowed.")
return False
if "<image>" in message["text"]:
gr.Warning("Using <image> tags with video files is not supported.")
return False
# ์ด๋ฏธ์ง€ ๊ฐœ์ˆ˜ ์ œํ•œ
if video_count == 0 and image_count > MAX_NUM_IMAGES:
gr.Warning(f"You can upload up to {MAX_NUM_IMAGES} images.")
return False
# <image> ํƒœ๊ทธ ์ˆ˜์™€ ์ด๋ฏธ์ง€ ํŒŒ์ผ ์ˆ˜ ์ผ์น˜
if "<image>" in message["text"] and message["text"].count("<image>") != new_image_count:
gr.Warning("The number of <image> tags in the text does not match the number of images.")
return False
return True
##################################################
# 3) Video processing
##################################################
def downsample_video(video_path: str) -> list[tuple[Image.Image, float]]:
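    """Sample a video at roughly three frames per second.

    Returns a list of (PIL.Image, timestamp-in-seconds) pairs, e.g.
    [(<frame>, 0.0), (<frame>, 0.33), ...] for a 30 fps clip.
    """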
    vidcap = cv2.VideoCapture(video_path)
    fps = vidcap.get(cv2.CAP_PROP_FPS) or 30.0  # assume 30 fps if the FPS metadata is missing
    total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
    frame_interval = max(int(fps / 3), 1)  # guard against a zero step for low-FPS videos
    frames = []
    for i in range(0, total_frames, frame_interval):
        vidcap.set(cv2.CAP_PROP_POS_FRAMES, i)
        success, image = vidcap.read()
        if success:
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            pil_image = Image.fromarray(image)
            timestamp = round(i / fps, 2)
            frames.append((pil_image, timestamp))
    vidcap.release()
    return frames
def process_video(video_path: str) -> list[dict]:
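    """Convert a video into alternating "Frame <t>:" labels and frame images saved as temporary PNGs."""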
    content = []
    frames = downsample_video(video_path)
    for pil_image, timestamp in frames:
        with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as temp_file:
            pil_image.save(temp_file.name)
            content.append({"type": "text", "text": f"Frame {timestamp}:"})
            content.append({"type": "image", "url": temp_file.name})
    return content
##################################################
# 4) Interleaved <image> handling
##################################################
def process_interleaved_images(message: dict) -> list[dict]:
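    """Split the message text on <image> tags and interleave the text segments with the uploaded images."""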
    # Index only image files so <image> tags stay aligned even when documents were uploaded too.
    image_files = [f for f in message["files"] if re.search(r"\.(png|jpg|jpeg|gif|webp)$", f, re.IGNORECASE)]
    parts = re.split(r"(<image>)", message["text"])
    content = []
    image_index = 0
    for part in parts:
        if part == "<image>":
            content.append({"type": "image", "url": image_files[image_index]})
            image_index += 1
        elif part.strip():
            content.append({"type": "text", "text": part.strip()})
        elif part:
            # Keep whitespace-only fragments so spacing between segments survives.
            content.append({"type": "text", "text": part})
    return content
##################################################
# 5) CSV/PDF/TXT are passed as extracted text; images/videos are passed as file paths
##################################################
def process_new_user_message(message: dict) -> list[dict]:
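    """Build the user-turn content list: document files become extracted text, media files keep their paths."""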
    user_text = (message["text"] or "").strip() or "(No text)"
    if not message["files"]:
        return [{"type": "text", "text": user_text}]
    video_files = [f for f in message["files"] if f.endswith(".mp4")]
    image_files = [f for f in message["files"] if re.search(r"\.(png|jpg|jpeg|gif|webp)$", f, re.IGNORECASE)]
    csv_files = [f for f in message["files"] if f.lower().endswith(".csv")]
    txt_files = [f for f in message["files"] if f.lower().endswith(".txt")]
    pdf_files = [f for f in message["files"] if f.lower().endswith(".pdf")]
    content_list = [{"type": "text", "text": user_text}]
    # CSV
    for csv_path in csv_files:
        csv_analysis = analyze_csv_file(csv_path)
        if not csv_analysis.strip():
            csv_analysis = "(No CSV content?)"
        content_list.append({"type": "text", "text": csv_analysis})
    # TXT
    for txt_path in txt_files:
        txt_analysis = analyze_txt_file(txt_path)
        if not txt_analysis.strip():
            txt_analysis = "(No TXT content?)"
        content_list.append({"type": "text", "text": txt_analysis})
    # PDF
    for pdf_path in pdf_files:
        pdf_md = pdf_to_markdown(pdf_path)
        if not pdf_md.strip():
            pdf_md = "(No PDF content?)"
        content_list.append({"type": "text", "text": pdf_md})
    if video_files:
        # Only the first video is processed.
        content_list += process_video(video_files[0])
        return content_list
    if "<image>" in user_text:
        # <image> tags take precedence; on this path only the tagged text and images are sent.
        return process_interleaved_images(message)
    # Plain images: append each one after the text.
    for img_path in image_files:
        content_list.append({"type": "image", "url": img_path})
    return content_list
##################################################
# 6) Convert chat history to LLM messages (non-image file paths are dropped)
##################################################
def process_history(history: list[dict]) -> list[dict]:
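    """Convert Gradio chat history into the model's message format, grouping consecutive user items into one turn."""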
    messages = []
    current_user_content = []
    for item in history:
        if item["role"] == "assistant":
            if current_user_content:
                messages.append({"role": "user", "content": current_user_content})
                current_user_content = []
            messages.append({"role": "assistant", "content": [{"type": "text", "text": item["content"]}]})
        else:
            # User turn: either plain text or a (file_path,) tuple.
            content = item["content"]
            if isinstance(content, str):
                current_user_content.append({"type": "text", "text": content})
            else:
                fpath = content[0]
                # Keep only images and .mp4 files; drop everything else.
                if re.search(r"\.(png|jpg|jpeg|gif|webp)$", fpath, re.IGNORECASE) or fpath.endswith(".mp4"):
                    current_user_content.append({"type": "image", "url": fpath})
    # Flush a trailing user turn that has no assistant reply yet.
    if current_user_content:
        messages.append({"role": "user", "content": current_user_content})
    return messages
##################################################
# 7) Main inference (guards against empty token input)
##################################################
@spaces.GPU(duration=120)
def run(message: dict, history: list[dict], system_prompt: str = "", max_new_tokens: int = 512) -> Iterator[str]:
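    """Validate the request, assemble the prompt, and stream the model's reply incrementally."""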
    if not validate_media_constraints(message, history):
        yield ""
        return
    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": [{"type": "text", "text": system_prompt}]})
    messages.extend(process_history(history))
    user_content = process_new_user_message(message)
    messages.append({"role": "user", "content": user_content})
    # 1) Render the chat template as text (tokenize=False), then check the token length.
    # NOTE: this text-only path renders image entries as placeholder tokens; pixel data
    # is not passed to the model here.
    raw_text = processor.tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    token_ids = processor.tokenizer.encode(raw_text, add_special_tokens=False)
    if len(token_ids) == 0:
        # Empty input: add a placeholder phrase so generation has something to work with.
        raw_text += " (No content?)"
        token_ids = processor.tokenizer.encode(raw_text, add_special_tokens=False)
    # 2) Tokenize for real.
    inputs = processor.tokenizer(
        raw_text,
        return_tensors="pt",
        padding=True,
    )
    # Move tensors to the model device; token IDs and attention masks must stay integer-typed,
    # so do not cast them to bfloat16.
    inputs = {k: v.to(model.device) for k, v in inputs.items()}
    # 3) Streaming generation on a background thread.
    streamer = TextIteratorStreamer(processor.tokenizer, timeout=30.0, skip_prompt=True, skip_special_tokens=True)
    gen_kwargs = {
        "inputs": inputs["input_ids"],
        "attention_mask": inputs.get("attention_mask"),
        "streamer": streamer,
        "max_new_tokens": max_new_tokens,
        "do_sample": True,
        "temperature": 0.3,
        "top_p": 0.95,
    }
    gen_kwargs = {k: v for k, v in gen_kwargs.items() if v is not None}
    t = Thread(target=model.generate, kwargs=gen_kwargs)
    t.start()
    output = ""
    for chunk in streamer:
        output += chunk
        yield output
##################################################
# 8) Examples
##################################################
examples = [
    [
        {
            "text": "Summarize and analyze the contents of the PDF file.",
            "files": ["assets/additional-examples/pdf.pdf"],
        }
    ],
    [
        {
            "text": "Summarize and analyze the contents of the CSV file.",
            "files": ["assets/additional-examples/sample-csv.csv"],
        }
    ],
    [
        {
            "text": "Write matplotlib code that draws the same bar chart.",
            "files": ["assets/additional-examples/barchart.png"],
        }
    ],
    [
        {
            "text": "What is unusual about this video?",
            "files": ["assets/additional-examples/tmp.mp4"],
        }
    ],
    [
        {
            "text": "I already have this supplement <image> and I am thinking of buying this product <image>. Is there anything to watch out for when taking them together?",
            "files": ["assets/additional-examples/pill1.png", "assets/additional-examples/pill2.png"],
        }
    ],
    [
        {
            "text": "Write a poem inspired by the visual elements of the images.",
            "files": ["assets/sample-images/06-1.png", "assets/sample-images/06-2.png"],
        }
    ],
    [
        {
            "text": "Compose a short piece of music based on the visual elements of the images.",
            "files": [
                "assets/sample-images/07-1.png",
                "assets/sample-images/07-2.png",
                "assets/sample-images/07-3.png",
                "assets/sample-images/07-4.png",
            ],
        }
    ],
    [
        {
            "text": "Make up a short story about what might have happened in this house.",
            "files": ["assets/sample-images/08.png"],
        }
    ],
    [
        {
            "text": "Create a short story based on the sequence of the images.",
            "files": [
                "assets/sample-images/09-1.png",
                "assets/sample-images/09-2.png",
                "assets/sample-images/09-3.png",
                "assets/sample-images/09-4.png",
                "assets/sample-images/09-5.png",
            ],
        }
    ],
    [
        {
            "text": "Imagine and describe the creatures that might live in this world.",
            "files": ["assets/sample-images/10.png"],
        }
    ],
    [
        {
            "text": "Read the text written in the image.",
            "files": ["assets/additional-examples/1.png"],
        }
    ],
    [
        {
            "text": "When was this ticket issued, and what is its price?",
            "files": ["assets/additional-examples/2.png"],
        }
    ],
    [
        {
            "text": "Transcribe the text in the image exactly as written, formatted as Markdown.",
            "files": ["assets/additional-examples/3.png"],
        }
    ],
    [
        {
            "text": "Solve this integral.",
            "files": ["assets/additional-examples/4.png"],
        }
    ],
    [
        {
            "text": "Describe this image with a brief caption.",
            "files": ["assets/sample-images/01.png"],
        }
    ],
    [
        {
            "text": "What message is written on this sign?",
            "files": ["assets/sample-images/02.png"],
        }
    ],
    [
        {
            "text": "Compare the two images and describe their similarities and differences.",
            "files": ["assets/sample-images/03.png"],
        }
    ],
    [
        {
            "text": "List every object visible in the image along with its color.",
            "files": ["assets/sample-images/04.png"],
        }
    ],
    [
        {
            "text": "Describe the mood of the scene.",
            "files": ["assets/sample-images/05.png"],
        }
    ],
]
demo = gr.ChatInterface(
    fn=run,
    type="messages",
    chatbot=gr.Chatbot(type="messages", scale=1, allow_tags=["image"]),
    textbox=gr.MultimodalTextbox(
        file_types=[
            ".png", ".jpg", ".jpeg", ".gif", ".webp",
            ".mp4", ".csv", ".txt", ".pdf",
        ],
        file_count="multiple",
        autofocus=True,
    ),
    multimodal=True,
    additional_inputs=[
        gr.Textbox(
            label="System Prompt",
            value="You are a deeply thoughtful AI. Consider problems thoroughly and derive correct solutions through systematic reasoning. Please answer in Korean.",
        ),
        gr.Slider(label="Max New Tokens", minimum=100, maximum=8000, step=50, value=2000),
    ],
    stop_btn=False,
    title="Gemma 3 27B IT",
    examples=examples,
    run_examples_on_click=False,
    cache_examples=False,
    css_paths="style.css",
    delete_cache=(1800, 1800),
)
if __name__ == "__main__":
    demo.launch()