# Open-GAMMA / app-backup.py
# Author: openfree; last commit f07fdd7 ("Update app-backup.py"); 25.4 kB
#!/usr/bin/env python
import os
import re
import tempfile
from collections.abc import Iterator
from threading import Thread
import json
import requests
import cv2
import gradio as gr
import spaces
import torch
from loguru import logger
from PIL import Image
from transformers import AutoProcessor, Gemma3ForConditionalGeneration, TextIteratorStreamer
# CSV/TXT 분석
import pandas as pd
# PDF 텍스트 추출
import PyPDF2
##############################################################################
# SERPHouse API key from environment variable
##############################################################################
# SERPHouse API token; an empty string when the environment variable is unset.
SERPHOUSE_API_KEY = os.environ.get("SERPHOUSE_API_KEY", "")
##############################################################################
# 간단한 키워드 추출 함수 (한글 + 알파벳 + 숫자 + 공백 보존)
##############################################################################
def extract_keywords(text: str, top_k: int = 5) -> str:
    """Build a short search query from the leading words of *text*.

    Every character other than ASCII letters, digits, Hangul syllables
    (가-힣) and whitespace is deleted. The cleaned text is split on
    whitespace, and the first ``top_k`` tokens are re-joined with single
    spaces. No ranking is done: "keywords" are simply the first words.
    """
    cleaned = re.sub(r"[^a-zA-Z0-9가-힣\s]", "", text)
    return " ".join(cleaned.split()[:top_k])
##############################################################################
# SERPHouse Live endpoint 호출
# - 상위 20개 결과 JSON을 LLM에 넘길 때 link, snippet 등 모두 포함
##############################################################################
def do_web_search(query: str) -> str:
    """Query the SERPHouse live endpoint and serialize the organic results.

    Up to 20 items from ``results.results.organic`` are dumped verbatim
    (title, link, snippet and any other fields) as pretty-printed JSON, so
    the LLM sees the complete records.

    Args:
        query: Search string sent as the ``q`` parameter.

    Returns:
        The formatted results. Returns ``"No web search results found."``
        when the response has no organic items, or a
        ``"Web search failed: ..."`` message on any error. Never raises.
    """
    if not SERPHOUSE_API_KEY:
        # Without a token the request can only be rejected; report that
        # directly instead of surfacing an opaque HTTP error.
        logger.warning("Web search skipped: SERPHOUSE_API_KEY is not set.")
        return "Web search failed: SERPHOUSE_API_KEY is not set."
    try:
        url = "https://api.serphouse.com/serp/live"
        params = {
            "q": query,
            "domain": "google.com",
            "lang": "en",
            "device": "desktop",
            "serp_type": "web",
            "num_result": "20",
            "api_token": SERPHOUSE_API_KEY,
        }
        resp = requests.get(url, params=params, timeout=30)
        resp.raise_for_status()
        data = resp.json()
        # ``or {}`` / ``or []`` also cover keys that are present but null.
        results = data.get("results") or {}
        organic = (results.get("results") or {}).get("organic") or []
        if not organic:
            return "No web search results found."
        summary_lines = [
            f"Result {idx}:\n{json.dumps(item, ensure_ascii=False, indent=2)}\n"
            for idx, item in enumerate(organic[:20], start=1)
        ]
        return "\n".join(summary_lines)
    except Exception as e:
        from urllib.parse import quote_plus

        # requests' exception messages embed the request URL, and the URL
        # carries ``api_token`` in its query string. Redact the token (raw and
        # URL-encoded) before logging the text or returning it; the return
        # value is pasted into the LLM's system prompt.
        message = (
            str(e)
            .replace(SERPHOUSE_API_KEY, "***")
            .replace(quote_plus(SERPHOUSE_API_KEY), "***")
        )
        logger.error(f"Web search failed: {message}")
        return f"Web search failed: {message}"
##############################################################################
# 모델/프로세서 로딩
##############################################################################
# Upper bound on the characters of any single text part sent to the model
# (document extracts and user text); longer text is cut and marked.
MAX_CONTENT_CHARS = 4000

# Model checkpoint, overridable through the MODEL_ID environment variable.
model_id = os.environ.get("MODEL_ID", "google/gemma-3-27b-it")
processor = AutoProcessor.from_pretrained(model_id, padding_side="left")
model = Gemma3ForConditionalGeneration.from_pretrained(
    model_id,
    attn_implementation="eager",
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

# Image limit (history + new turn) checked by validate_media_constraints.
MAX_NUM_IMAGES = int(os.environ.get("MAX_NUM_IMAGES", "5"))
##############################################################################
# CSV, TXT, PDF 분석 함수
##############################################################################
def analyze_csv_file(path: str, max_rows: int = 50, max_cols: int = 10) -> str:
    """Render a CSV file as a plain-text table for the LLM prompt.

    Only the first ``max_rows`` rows and ``max_cols`` columns are rendered,
    and the rendered table is cut at ``MAX_CONTENT_CHARS`` characters.
    When rows or columns were dropped, a trailing note states how much of
    the file is shown. This keeps the model from treating the preview as
    the complete data, and mirrors the page note that ``pdf_to_markdown``
    appends.

    Args:
        path: Path to the CSV file.
        max_rows: Maximum number of data rows to render.
        max_cols: Maximum number of columns to render.

    Returns:
        ``**[CSV File: <name>]**`` followed by the table text, or a
        ``Failed to read CSV (...)`` message. Errors are reported in the
        returned string rather than raised.
    """
    name = os.path.basename(path)
    try:
        df = pd.read_csv(path)
        total_rows, total_cols = df.shape
        # Slicing is a no-op when the frame already fits in the window.
        df_str = df.iloc[:max_rows, :max_cols].to_string()
        if len(df_str) > MAX_CONTENT_CHARS:
            df_str = df_str[:MAX_CONTENT_CHARS] + "\n...(truncated)..."
        if total_rows > max_rows or total_cols > max_cols:
            df_str += (
                f"\n\n...(Showing {min(total_rows, max_rows)} of {total_rows} rows, "
                f"{min(total_cols, max_cols)} of {total_cols} columns)..."
            )
        return f"**[CSV File: {name}]**\n\n{df_str}"
    except Exception as e:
        return f"Failed to read CSV ({name}): {str(e)}"
def analyze_txt_file(path: str) -> str:
    """Read a text file for inclusion in the LLM prompt.

    At most ``MAX_CONTENT_CHARS`` characters are kept; longer files are cut
    and marked with ``...(truncated)...``.

    The file is decoded as UTF-8, and a leading byte-order mark is dropped.
    Undecodable bytes are replaced with U+FFFD instead of failing the whole
    file, so files that are mostly ASCII in another encoding still yield
    usable text.

    Returns:
        ``**[TXT File: <name>]**`` followed by the text, or a
        ``Failed to read TXT (...)`` message if the file cannot be opened.
        Errors are reported in the returned string rather than raised.
    """
    name = os.path.basename(path)
    try:
        with open(path, "r", encoding="utf-8-sig", errors="replace") as f:
            # Read one character past the budget. That is enough to detect
            # truncation without loading an arbitrarily large file into memory.
            text = f.read(MAX_CONTENT_CHARS + 1)
    except Exception as e:
        return f"Failed to read TXT ({name}): {str(e)}"
    if len(text) > MAX_CONTENT_CHARS:
        text = text[:MAX_CONTENT_CHARS] + "\n...(truncated)..."
    return f"**[TXT File: {name}]**\n\n{text}"
def pdf_to_markdown(pdf_path: str) -> str:
    """Extract text from the first pages of a PDF as Markdown sections.

    Up to five pages are read. Each page with non-empty text becomes a
    ``## Page N`` section, and its text is capped at an equal share of
    ``MAX_CONTENT_CHARS``. If the PDF has more pages than were read, a
    trailing note records how many were shown. The combined text is also
    capped at ``MAX_CONTENT_CHARS``.

    Returns:
        ``**[PDF File: <name>]**`` followed by the sections, or a
        ``Failed to read PDF (...)`` message if opening or parsing fails.
        Errors are reported in the returned string rather than raised.
    """
    sections = []
    try:
        with open(pdf_path, "rb") as pdf_file:
            pages = PyPDF2.PdfReader(pdf_file).pages
            page_total = len(pages)
            pages_shown = min(5, page_total)
            for number in range(1, pages_shown + 1):
                body = (pages[number - 1].extract_text() or "").strip()
                if not body:
                    continue
                # Computed only inside the loop, where pages_shown >= 1.
                share = MAX_CONTENT_CHARS // pages_shown
                if len(body) > share:
                    body = body[:share] + "...(truncated)"
                sections.append(f"## Page {number}\n\n{body}\n")
            if page_total > pages_shown:
                sections.append(f"\n...(Showing {pages_shown} of {page_total} pages)...")
    except Exception as e:
        return f"Failed to read PDF ({os.path.basename(pdf_path)}): {str(e)}"

    full_text = "\n".join(sections)
    if len(full_text) > MAX_CONTENT_CHARS:
        full_text = full_text[:MAX_CONTENT_CHARS] + "\n...(truncated)..."
    return f"**[PDF File: {os.path.basename(pdf_path)}]**\n\n{full_text}"
##############################################################################
# 이미지/비디오 업로드 제한 검사
##############################################################################
def count_files_in_new_message(paths: list[str]) -> tuple[int, int]:
    """Count image and video attachments among *paths*.

    A path ending in ``.mp4`` counts as a video, and one ending in
    png/jpg/jpeg/gif/webp counts as an image. Both checks ignore case, so
    camera-style names such as ``CLIP.MP4`` are counted too. Any other path
    is ignored.

    Returns:
        ``(image_count, video_count)``.
    """
    image_count = 0
    video_count = 0
    for path in paths:
        if path.lower().endswith(".mp4"):
            video_count += 1
        elif re.search(r"\.(png|jpg|jpeg|gif|webp)$", path, re.IGNORECASE):
            image_count += 1
    return image_count, video_count
def count_files_in_history(history: list[dict]) -> tuple[int, int]:
    """Count image and video attachments in previous user turns.

    A file turn stores its path as ``content[0]``. Gradio's ChatInterface
    appears to build that container as a tuple ``(path,)`` rather than a
    list (TODO confirm for the pinned Gradio version), so both tuples and
    lists are accepted. Assistant turns, plain-text turns, empty containers
    and non-string paths are skipped.

    Video/image detection matches ``count_files_in_new_message``: ``.mp4``
    or png/jpg/jpeg/gif/webp, case-insensitively.

    Returns:
        ``(image_count, video_count)``.
    """
    image_count = 0
    video_count = 0
    for item in history:
        if item["role"] != "user":
            continue
        content = item["content"]
        if not isinstance(content, (list, tuple)) or not content:
            continue
        file_path = content[0]
        if not isinstance(file_path, str):
            continue
        if file_path.lower().endswith(".mp4"):
            video_count += 1
        elif re.search(r"\.(png|jpg|jpeg|gif|webp)$", file_path, re.IGNORECASE):
            image_count += 1
    return image_count, video_count
def validate_media_constraints(message: dict, history: list[dict]) -> bool:
    """Check the attachment rules for a new turn; warn the user on failure.

    Totals are counted over the whole conversation (history + new files).
    Rules, checked in order:

    1. at most one video;
    2. a video cannot be combined with images;
    3. ``<image>`` tags cannot be used alongside a video;
    4. without a video, at most ``MAX_NUM_IMAGES`` images;
    5. when the text has ``<image>`` tags, their count must equal the number
       of image files in this message.

    The first broken rule is shown via ``gr.Warning`` and False is
    returned; otherwise True.
    """
    text = message["text"]
    # count_files_in_new_message already ignores non-media paths, so the
    # raw file list can be passed straight through.
    new_images, new_videos = count_files_in_new_message(message["files"])
    old_images, old_videos = count_files_in_history(history)
    total_images = old_images + new_images
    total_videos = old_videos + new_videos

    failure = None
    if total_videos > 1:
        failure = "Only one video is supported."
    elif total_videos == 1 and total_images > 0:
        failure = "Mixing images and videos is not allowed."
    elif total_videos == 1 and "<image>" in text:
        failure = "Using <image> tags with video files is not supported."
    elif total_videos == 0 and total_images > MAX_NUM_IMAGES:
        failure = f"You can upload up to {MAX_NUM_IMAGES} images."
    elif "<image>" in text:
        tagged_images = [
            path
            for path in message["files"]
            if re.search(r"\.(png|jpg|jpeg|gif|webp)$", path, re.IGNORECASE)
        ]
        if text.count("<image>") != len(tagged_images):
            failure = "The number of <image> tags in the text does not match the number of image files."

    if failure is not None:
        gr.Warning(failure)
        return False
    return True
##############################################################################
# 비디오 처리
##############################################################################
def downsample_video(video_path: str, max_frames: int = 5) -> list[tuple[Image.Image, float]]:
    """Sample up to ``max_frames`` RGB frames spread across a video.

    Frames are taken every ``frame_interval`` frames. The interval is at
    least ``int(fps)`` frames (about one second of video) and at least
    ``total_frames // max_frames`` frames, so the samples are spread over
    the whole clip rather than stopping partway through it. Short clips
    therefore yield fewer than ``max_frames`` frames, roughly one per second.

    Args:
        video_path: Path of a video file readable by OpenCV.
        max_frames: Maximum number of frames to return.

    Returns:
        A list of ``(frame, timestamp_seconds)`` pairs, with timestamps
        rounded to 2 decimals. Frames that fail to decode are skipped.

    Raises:
        ValueError: If OpenCV reports no valid frame rate or no frames,
            e.g. because the file could not be opened.
    """
    vidcap = cv2.VideoCapture(video_path)
    try:
        fps = vidcap.get(cv2.CAP_PROP_FPS)
        total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
        # ``not fps > 0`` also rejects NaN. Without this guard the range step
        # below could be 0, and the timestamp division would fail.
        if not fps > 0 or total_frames <= 0:
            raise ValueError(
                f"Could not read video metadata: {os.path.basename(video_path)}"
            )
        frame_interval = max(int(fps), total_frames // max(max_frames, 1), 1)
        frames = []
        for i in range(0, total_frames, frame_interval):
            if len(frames) >= max_frames:
                break
            vidcap.set(cv2.CAP_PROP_POS_FRAMES, i)
            success, image = vidcap.read()
            if not success:
                continue
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            frames.append((Image.fromarray(image), round(i / fps, 2)))
        return frames
    finally:
        # Always free the capture handle, including on error paths.
        vidcap.release()
def process_video(video_path: str) -> list[dict]:
    """Turn a video into alternating "Frame <t>:" captions and frame images.

    Each sampled frame is saved to a new ``.png`` temporary file and
    referenced by path in an ``{"type": "image", "url": ...}`` part.
    ``delete=False`` keeps the files on disk after this function returns.
    NOTE(review): nothing in this function removes them afterwards.
    """
    frames_content: list[dict] = []
    for pil_image, timestamp in downsample_video(video_path):
        with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as temp_file:
            pil_image.save(temp_file.name)
        frames_content += [
            {"type": "text", "text": f"Frame {timestamp}:"},
            {"type": "image", "url": temp_file.name},
        ]
    logger.debug(f"content={frames_content!r}")
    return frames_content
##############################################################################
# interleaved <image> 처리
##############################################################################
def process_interleaved_images(message: dict) -> list[dict]:
    """Split text on ``<image>`` tags and interleave the image attachments.

    Image files (png/jpg/jpeg/gif/webp, any case) are consumed in order, one
    per tag. Text segments are stripped when they contain non-whitespace;
    whitespace-only and empty segments are kept verbatim. A tag with no image
    left to fill it is kept as the literal text ``"<image>"``.
    """
    image_paths = [
        path
        for path in message["files"]
        if re.search(r"\.(png|jpg|jpeg|gif|webp)$", path, re.IGNORECASE)
    ]
    remaining = iter(image_paths)
    content = []
    for segment in re.split(r"(<image>)", message["text"]):
        if segment == "<image>":
            path = next(remaining, None)
            if path is not None:
                content.append({"type": "image", "url": path})
                continue
        stripped = segment.strip()
        content.append({"type": "text", "text": stripped or segment})
    return content
##############################################################################
# PDF + CSV + TXT + 이미지/비디오
##############################################################################
def is_image_file(file_path: str) -> bool:
    """Return True if *file_path* ends in png/jpg/jpeg/gif/webp (any case)."""
    return re.search(r"\.(png|jpg|jpeg|gif|webp)$", file_path, re.IGNORECASE) is not None
def is_video_file(file_path: str) -> bool:
    """Return True if *file_path* has an ``.mp4`` extension, ignoring case.

    Case-insensitive like ``is_image_file``, so upper-case names such as
    ``CLIP.MP4`` are recognized as videos too.
    """
    return file_path.lower().endswith(".mp4")
def is_document_file(file_path: str) -> bool:
    """Return True for ``.pdf``, ``.csv`` or ``.txt`` paths, ignoring case."""
    return file_path.lower().endswith((".pdf", ".csv", ".txt"))
def process_new_user_message(message: dict) -> list[dict]:
    """Build the chat-template content parts for the incoming user turn.

    * No attachments: a single text part holding the user's text.
    * CSV, TXT and PDF files (matched case-insensitively, processed in that
      order) become text parts placed after the user's text.
    * With a video: frames of the first video follow the documents, and any
      images are ignored.
    * Images with ``<image>`` tags in the text: the interleaved text/image
      parts replace the plain text part and come before the documents.
    * Images without tags: image parts appended after the documents.
    """
    text = message["text"]
    files = message["files"]
    if not files:
        return [{"type": "text", "text": text}]

    videos = [path for path in files if is_video_file(path)]
    images = [path for path in files if is_image_file(path)]

    def _paths_with(suffix: str) -> list[str]:
        # Same case-insensitive suffix test for every document type.
        return [path for path in files if path.lower().endswith(suffix)]

    documents = [{"type": "text", "text": analyze_csv_file(p)} for p in _paths_with(".csv")]
    documents += [{"type": "text", "text": analyze_txt_file(p)} for p in _paths_with(".txt")]
    documents += [{"type": "text", "text": pdf_to_markdown(p)} for p in _paths_with(".pdf")]

    if videos:
        return [{"type": "text", "text": text}] + documents + process_video(videos[0])
    if "<image>" in text and images:
        return process_interleaved_images({"text": text, "files": images}) + documents
    return (
        [{"type": "text", "text": text}]
        + documents
        + [{"type": "image", "url": path} for path in images]
    )
##############################################################################
# history -> LLM 메시지 변환
##############################################################################
def process_history(history: list[dict]) -> list[dict]:
    """Convert Gradio messages-format history into chat-template messages.

    Consecutive user entries (text and attachments) are merged into one
    ``user`` message, which is flushed whenever an assistant entry appears.
    Image attachments become ``image`` parts. Other attachments become a
    ``[File: <name>]`` text placeholder; their contents are not re-read
    here.

    A file entry carries its path as ``content[0]``. Gradio's ChatInterface
    appears to store that container as a tuple ``(path,)`` rather than a list
    (TODO confirm for the pinned Gradio version), so both are accepted.
    Entries whose first element is not a string are skipped.
    """
    messages: list[dict] = []
    current_user_content: list[dict] = []
    for item in history:
        if item["role"] == "assistant":
            if current_user_content:
                messages.append({"role": "user", "content": current_user_content})
                current_user_content = []
            messages.append({"role": "assistant", "content": [{"type": "text", "text": item["content"]}]})
            continue
        content = item["content"]
        if isinstance(content, str):
            current_user_content.append({"type": "text", "text": content})
        elif isinstance(content, (list, tuple)) and content and isinstance(content[0], str):
            file_path = content[0]
            if is_image_file(file_path):
                current_user_content.append({"type": "image", "url": file_path})
            else:
                current_user_content.append({"type": "text", "text": f"[File: {os.path.basename(file_path)}]"})
    if current_user_content:
        messages.append({"role": "user", "content": current_user_content})
    return messages
##############################################################################
# 메인 추론 함수 (web search 체크 시 자동 키워드추출->검색->결과 system msg)
##############################################################################
@spaces.GPU(duration=120)
def run(
    message: dict,
    history: list[dict],
    system_prompt: str = "",
    max_new_tokens: int = 512,
    use_web_search: bool = False,
    web_search_query: str = "",
) -> Iterator[str]:
    """Stream the model's reply for one chat turn.

    Steps:

    1. Validate attachments.
    2. Build an optional system message: the system prompt plus, when
       ``use_web_search`` is set, web results for keywords extracted from
       the user's text.
    3. Append the converted history and the new user turn.
    4. Generate on a background thread, yielding the growing reply.

    Args:
        message: Gradio multimodal message with ``"text"`` and ``"files"``.
        history: Gradio messages-format history.
        system_prompt: Optional system instructions.
        max_new_tokens: Generation length limit.
        use_web_search: Whether to prepend web search results.
        web_search_query: Unused; the query is derived from ``message["text"]``.

    Yields:
        The reply accumulated so far. Yields ``""`` if validation fails, or
        an error message (in Korean) if any step raises, including errors
        raised inside the generation thread.
    """
    if not validate_media_constraints(message, history):
        yield ""
        return
    try:
        combined_system_msg = ""
        if system_prompt.strip():
            combined_system_msg += f"[System Prompt]\n{system_prompt.strip()}\n\n"
        if use_web_search:
            ws_query = extract_keywords(message["text"], top_k=5)
            if ws_query.strip():
                logger.info(f"[Auto WebSearch Keyword] {ws_query!r}")
                ws_result = do_web_search(ws_query)
                combined_system_msg += f"[Search top-20 Full Items Based on user prompt]\n{ws_result}\n\n"
            else:
                combined_system_msg += "[No valid keywords found, skipping WebSearch]\n\n"

        messages = []
        if combined_system_msg.strip():
            messages.append({
                "role": "system",
                "content": [{"type": "text", "text": combined_system_msg.strip()}],
            })
        messages.extend(process_history(history))

        user_content = process_new_user_message(message)
        for item in user_content:
            if item["type"] == "text" and len(item["text"]) > MAX_CONTENT_CHARS:
                item["text"] = item["text"][:MAX_CONTENT_CHARS] + "\n...(truncated)..."
        messages.append({"role": "user", "content": user_content})

        inputs = processor.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
        ).to(device=model.device, dtype=torch.bfloat16)

        streamer = TextIteratorStreamer(processor, timeout=30.0, skip_prompt=True, skip_special_tokens=True)
        gen_kwargs = dict(
            inputs,
            streamer=streamer,
            max_new_tokens=max_new_tokens,
        )

        # An exception raised in a Thread target never reaches this generator.
        # The loop below would wait out the streamer's 30 s timeout, and the
        # real cause (e.g. the OOM hint) would be lost. Capture the error and
        # end the stream so the loop stops; it is re-raised below.
        worker_errors: list[Exception] = []

        def _generate() -> None:
            try:
                _model_gen_with_oom_catch(**gen_kwargs)
            except Exception as exc:
                worker_errors.append(exc)
                streamer.end()

        t = Thread(target=_generate)
        t.start()
        output = ""
        for new_text in streamer:
            output += new_text
            yield output
        t.join()
        if worker_errors:
            raise worker_errors[0]
    except Exception as e:
        logger.error(f"Error in run: {str(e)}")
        yield f"죄송합니다. 오류가 발생했습니다: {str(e)}"
##############################################################################
# [추가] 별도 함수에서 model.generate(...)를 호출, OOM 캐치
##############################################################################
def _model_gen_with_oom_catch(**kwargs):
    """Run ``model.generate(**kwargs)``; intended as a background-thread target.

    A CUDA out-of-memory error is re-raised as a ``RuntimeError``. Its
    message is a user-facing (Korean) hint to lower Max New Tokens or
    shorten the prompt, and the original error is chained as the cause.
    Other exceptions propagate unchanged.

    On any failure, the ``streamer`` passed in *kwargs* (if any) is ended
    first. A consumer iterating that streamer on another thread is then
    released promptly, instead of waiting for the streamer's own timeout.
    """
    try:
        model.generate(**kwargs)
    except Exception as e:
        streamer = kwargs.get("streamer")
        if streamer is not None:
            streamer.end()
        if isinstance(e, torch.cuda.OutOfMemoryError):
            raise RuntimeError(
                "[OutOfMemoryError] GPU 메모리가 부족합니다. "
                "Max New Tokens을 줄이거나, 프롬프트 길이를 줄여주세요."
            ) from e
        raise
##############################################################################
# 예시들 (한글화)
##############################################################################
# Example prompts shown in the chat UI. Each entry is a one-element list
# holding a multimodal message dict: "text" plus attachment "files".
# The first example previously had a duplicate "files" key whose first value
# (pdf.pdf) was silently discarded by Python; only the effective value is kept.
examples = [
    [
        {
            "text": "두 PDF 파일 내용을 비교하라.",
            "files": [
                "assets/additional-examples/before.pdf",
                "assets/additional-examples/after.pdf",
            ],
        }
    ],
    [
        {
            "text": "CSV 파일 내용을 요약, 분석하라",
            "files": ["assets/additional-examples/sample-csv.csv"],
        }
    ],
    [
        {
            "text": "이 영상의 내용을 설명하라",
            "files": ["assets/additional-examples/tmp.mp4"],
        }
    ],
    [
        {
            "text": "표지 내용을 설명하고 글자를 읽어주세요.",
            "files": ["assets/additional-examples/maz.jpg"],
        }
    ],
    [
        {
            # Two <image> tags, matched in order by the two image files.
            "text": "이미 이 영양제를 <image> 가지고 있고, 이 제품 <image>을 새로 사려 합니다. 함께 섭취할 때 주의해야 할 점이 있을까요?",
            "files": ["assets/additional-examples/pill1.png", "assets/additional-examples/pill2.png"],
        }
    ],
    [
        {
            "text": "이 적분을 풀어주세요.",
            "files": ["assets/additional-examples/4.png"],
        }
    ],
    [
        {
            "text": "이 티켓은 언제 발급된 것이고, 가격은 얼마인가요?",
            "files": ["assets/additional-examples/2.png"],
        }
    ],
    [
        {
            "text": "이미지들의 순서를 바탕으로 짧은 이야기를 만들어 주세요.",
            "files": [
                "assets/sample-images/09-1.png",
                "assets/sample-images/09-2.png",
                "assets/sample-images/09-3.png",
                "assets/sample-images/09-4.png",
                "assets/sample-images/09-5.png",
            ],
        }
    ],
    [
        {
            "text": "이미지의 시각적 요소에서 영감을 받아 시를 작성해주세요.",
            "files": ["assets/sample-images/06-1.png", "assets/sample-images/06-2.png"],
        }
    ],
    [
        {
            "text": "동일한 막대 그래프를 그리는 matplotlib 코드를 작성해주세요.",
            "files": ["assets/additional-examples/barchart.png"],
        }
    ],
    [
        {
            "text": "이 세계에서 살고 있을 생물들을 상상해서 묘사해주세요.",
            "files": ["assets/sample-images/08.png"],
        }
    ],
    [
        {
            "text": "이미지에 있는 텍스트를 그대로 읽어서 마크다운 형태로 적어주세요.",
            "files": ["assets/additional-examples/3.png"],
        }
    ],
    [
        {
            "text": "이 표지판에는 무슨 문구가 적혀 있나요?",
            "files": ["assets/sample-images/02.png"],
        }
    ],
    [
        {
            # NOTE(review): the prompt asks to compare two images, but only
            # one file is attached. Confirm whether a second image is missing.
            "text": "두 이미지를 비교해서 공통점과 차이점을 말해주세요.",
            "files": ["assets/sample-images/03.png"],
        }
    ],
]
##############################################################################
# Gradio UI (Blocks) 구성
##############################################################################
# Custom stylesheet passed to gr.Blocks(css=...). It sets a gradient page
# background, a rounded, shadowed main container and centering for the
# examples row/container ids. It also styles `.sidebar` elements and gives
# buttons a gradient fill with a hover zoom.
css = """
body {
background: linear-gradient(135deg, #667eea, #764ba2);
font-family: 'Helvetica Neue', Helvetica, Arial, sans-serif;
color: #333;
margin: 0;
padding: 0;
}
.gradio-container {
background: rgba(255, 255, 255, 0.95);
border-radius: 15px;
padding: 30px 40px;
box-shadow: 0 8px 30px rgba(0, 0, 0, 0.3);
margin: 40px auto;
max-width: 1200px;
}
.gradio-container h1 {
color: #333;
text-shadow: 1px 1px 2px rgba(0, 0, 0, 0.2);
}
.fillable {
width: 95% !important;
max-width: unset !important;
}
#examples_container {
margin: auto;
width: 90%;
}
#examples_row {
justify-content: center;
}
.sidebar {
background: rgba(255, 255, 255, 0.98);
border-radius: 10px;
padding: 20px;
box-shadow: 0 4px 15px rgba(0, 0, 0, 0.2);
}
button, .btn {
background: linear-gradient(90deg, #ff8a00, #e52e71);
border: none;
color: #fff;
padding: 12px 24px;
text-transform: uppercase;
font-weight: bold;
letter-spacing: 1px;
border-radius: 5px;
cursor: pointer;
transition: transform 0.2s ease-in-out;
}
button:hover, .btn:hover {
transform: scale(1.05);
}
"""
# Page header (title and subtitle) rendered at the top of the app via gr.Markdown.
title_html = """
<h1 align="center" style="margin-bottom: 0.2em;"> 🤗 Vidraft-G3-27B : Multimodal + VLM + Deep Research </h1>
<p align="center" style="font-size:1.1em; color:#555;">
Multimodal Chat Interface + Optional Web Search
</p>
"""
# Gradio UI: options sidebar, main chat area, and an examples gallery below.
with gr.Blocks(css=css, title="Vidraft-G3-27B") as demo:
    gr.Markdown(title_html)
    with gr.Row():
        # Left sidebar: options forwarded to `run` as additional inputs.
        with gr.Column(scale=3, variant="panel"):
            gr.Markdown("### Menu / Options")
            with gr.Row():
                web_search_checkbox = gr.Checkbox(
                    label="Web Search",
                    value=False,
                    info="Check to enable a Deep Research(auto keywords) before the chat reply"
                )
                # Passed to run() as web_search_query, which ignores it.
                web_search_text = gr.Textbox(
                    lines=1,
                    label="(Unused) Web Search Query",
                    placeholder="No direct input needed"
                )
            gr.Markdown("---")
            gr.Markdown("#### System Prompt")
            system_prompt_box = gr.Textbox(
                lines=3,
                value=(
                    "You are a deeply thoughtful AI. Consider problems thoroughly and derive "
                    "correct solutions through systematic reasoning. Please answer in korean."
                ),
            )
            max_tokens_slider = gr.Slider(
                label="Max New Tokens",
                minimum=100,
                maximum=8000,
                step=50,
                value=2000,  # default kept moderate to limit GPU memory use
            )
            gr.Markdown("<br><br>")
        # Main chat area. The additional_inputs order must match run()'s
        # parameters after (message, history).
        with gr.Column(scale=7):
            chat = gr.ChatInterface(
                fn=run,
                type="messages",
                chatbot=gr.Chatbot(type="messages", scale=1, allow_tags=["image"]),
                textbox=gr.MultimodalTextbox(
                    file_types=[
                        ".webp", ".png", ".jpg", ".jpeg", ".gif",
                        ".mp4", ".csv", ".txt", ".pdf"
                    ],
                    file_count="multiple",
                    autofocus=True
                ),
                multimodal=True,
                additional_inputs=[
                    system_prompt_box,
                    max_tokens_slider,
                    web_search_checkbox,
                    web_search_text,
                ],
                stop_btn=False,
                title="Vidraft-G3-27B",
                examples=examples,
                run_examples_on_click=False,
                cache_examples=False,
                css_paths=None,
                delete_cache=(1800, 1800),
            )
    with gr.Row(elem_id="examples_row"):
        with gr.Column(scale=12, elem_id="examples_container"):
            gr.Markdown("### Example Inputs (click to load)")
            # Target the chat's multimodal textbox. With `inputs=[]` there is
            # no component to load a clicked example into, so the gallery
            # could not do what its heading promises.
            gr.Examples(
                examples=examples,
                inputs=[chat.textbox],
                cache_examples=False,
            )
if __name__ == "__main__":
    # Launch with default settings. share=True is only useful for local runs;
    # on Hugging Face Spaces it just produces a warning.
    demo.launch()