ROBO-R1984 / app.py
openfree's picture
Update app.py
9d13074 verified
raw
history blame
34.5 kB
#!/usr/bin/env python3
import os
import re
import tempfile
import gc
from collections.abc import Iterator
from threading import Thread
import json
import requests
import cv2
import gradio as gr
import spaces
import torch
import numpy as np
from loguru import logger
from PIL import Image
import time
import warnings
from typing import Dict, List, Optional, Union
import base64
from io import BytesIO
# llama-cpp-python for GGUF
from llama_cpp import Llama
from llama_cpp.llama_chat_format import Llava16ChatHandler
# Model download
from huggingface_hub import hf_hub_download
# CSV/TXT ๋ถ„์„
import pandas as pd
# PDF ํ…์ŠคํŠธ ์ถ”์ถœ
import PyPDF2
warnings.filterwarnings('ignore')
print("๐ŸŽฎ ๋กœ๋ด‡ ์‹œ๊ฐ ์‹œ์Šคํ…œ ์ดˆ๊ธฐํ™” (Gemma3-4B GGUF Q4_K_M)...")
##############################################################################
# Constants
##############################################################################
MAX_CONTENT_CHARS = 2000  # cap on text extracted from any single document
MAX_INPUT_LENGTH = 2096  # NOTE(review): defined but not referenced in this file
MAX_NUM_IMAGES = 5  # NOTE(review): defined but not referenced in this file
SERPHOUSE_API_KEY = os.getenv("SERPHOUSE_API_KEY", "")  # empty string when unset
##############################################################################
# Global state (mutated by load_model)
##############################################################################
llm = None  # llama_cpp.Llama instance, lazily created by load_model()
model_loaded = False  # guards against re-loading the model
model_name = "Gemma3-4B-GGUF-Q4_K_M"
##############################################################################
# ๋ฉ”๋ชจ๋ฆฌ ๊ด€๋ฆฌ
##############################################################################
def clear_cuda_cache():
    """Explicitly release cached CUDA memory and run Python GC.

    No-op on machines without a CUDA device.
    """
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
    gc.collect()
##############################################################################
# ํ‚ค์›Œ๋“œ ์ถ”์ถœ ํ•จ์ˆ˜
##############################################################################
def extract_keywords(text: str, top_k: int = 5) -> str:
    """Return up to *top_k* unique keyword tokens from *text*, space-joined.

    Everything except ASCII alphanumerics, Hangul syllables and whitespace
    is stripped first; tokens of length 1 are discarded; first occurrence
    order is preserved.
    """
    cleaned = re.sub(r"[^a-zA-Z0-9가-힣\s]", "", text)
    keywords = []
    seen = set()
    for tok in cleaned.split():
        if len(tok) > 1 and tok not in seen:
            seen.add(tok)
            keywords.append(tok)
            if len(keywords) == top_k:
                break  # enough keywords collected
    return " ".join(keywords)
##############################################################################
# ์›น ๊ฒ€์ƒ‰ ํ•จ์ˆ˜
##############################################################################
def do_web_search(query: str) -> str:
    """Run a Google web search for *query* via the SerpHouse live API.

    Returns a markdown summary of up to 10 organic results, a "no results"
    message, or an error message string — never raises.
    """
    try:
        endpoint = "https://api.serphouse.com/serp/live"
        request_params = {
            "q": query,
            "domain": "google.com",
            "serp_type": "web",
            "device": "desktop",
            "lang": "ko",
            "num": "10",
        }
        auth_headers = {"Authorization": f"Bearer {SERPHOUSE_API_KEY}"}
        logger.info(f"웹 검색 중... 검색어: {query}")
        resp = requests.get(endpoint, headers=auth_headers,
                            params=request_params, timeout=60)
        resp.raise_for_status()
        payload = resp.json()
        results = payload.get("results", {})
        organic = results.get("organic", []) if isinstance(results, dict) else []
        if not organic:
            return "검색 결과를 찾을 수 없습니다."
        # Format at most 10 entries as markdown sections.
        summary_lines = []
        for idx, item in enumerate(organic[:10], start=1):
            link = item.get("link", "#")
            summary_lines.append(
                f"### 결과 {idx}: {item.get('title', '제목 없음')}\n\n"
                f"{item.get('snippet', '설명 없음')}\n\n"
                f"**출처**: [{item.get('displayed_link', link)}]({link})\n\n"
                f"---\n"
            )
        instructions = """# 웹 검색 결과
아래는 검색 결과입니다. 답변 시 이 정보를 활용하세요:
1. 각 결과의 제목, 내용, 출처 링크를 참조하세요
2. 관련 출처를 명시적으로 인용하세요
3. 여러 출처의 정보를 종합하여 답변하세요
"""
        return instructions + "\n".join(summary_lines)
    except Exception as e:
        logger.error(f"웹 검색 실패: {e}")
        return f"웹 검색 실패: {str(e)}"
##############################################################################
# ๋ฌธ์„œ ์ฒ˜๋ฆฌ ํ•จ์ˆ˜
##############################################################################
def analyze_csv_file(path: str) -> str:
    """Render a truncated markdown preview of a CSV file.

    Keeps at most 50 rows x 10 columns and caps the rendered text at
    MAX_CONTENT_CHARS; read failures are returned as an error string.
    """
    try:
        frame = pd.read_csv(path)
        # Trim oversized frames before rendering.
        if frame.shape[0] > 50 or frame.shape[1] > 10:
            frame = frame.iloc[:50, :10]
        rendered = frame.to_string()
        if len(rendered) > MAX_CONTENT_CHARS:
            rendered = rendered[:MAX_CONTENT_CHARS] + "\n...(중략)..."
        return f"**[CSV 파일: {os.path.basename(path)}]**\n\n{rendered}"
    except Exception as e:
        return f"CSV 읽기 실패 ({os.path.basename(path)}): {str(e)}"
def analyze_txt_file(path: str) -> str:
    """Return a truncated markdown preview of a UTF-8 text file.

    Output is capped at MAX_CONTENT_CHARS; read failures are returned as
    an error string rather than raised.
    """
    try:
        with open(path, "r", encoding="utf-8") as fh:
            body = fh.read()
        if len(body) > MAX_CONTENT_CHARS:
            body = body[:MAX_CONTENT_CHARS] + "\n...(중략)..."
        return f"**[TXT 파일: {os.path.basename(path)}]**\n\n{body}"
    except Exception as e:
        return f"TXT 읽기 실패 ({os.path.basename(path)}): {str(e)}"
def pdf_to_markdown(pdf_path: str) -> str:
    """Extract text from the first pages of a PDF as a markdown preview.

    Reads at most 5 pages; each page and the overall output are truncated
    to stay within MAX_CONTENT_CHARS. Read failures return an error string.
    """
    chunks = []
    try:
        with open(pdf_path, "rb") as fh:
            reader = PyPDF2.PdfReader(fh)
            page_limit = min(5, len(reader.pages))
            for index in range(page_limit):
                extracted = (reader.pages[index].extract_text() or "").strip()
                if not extracted:
                    continue
                # Budget computed here (not hoisted) so an empty PDF never divides by zero.
                budget = MAX_CONTENT_CHARS // page_limit
                if len(extracted) > budget:
                    extracted = extracted[:budget] + "...(중략)"
                chunks.append(f"## 페이지 {index+1}\n\n{extracted}\n")
            if len(reader.pages) > page_limit:
                chunks.append(f"\n...({page_limit}/{len(reader.pages)} 페이지 표시)...")
    except Exception as e:
        return f"PDF 읽기 실패 ({os.path.basename(pdf_path)}): {str(e)}"
    merged = "\n".join(chunks)
    if len(merged) > MAX_CONTENT_CHARS:
        merged = merged[:MAX_CONTENT_CHARS] + "\n...(중략)..."
    return f"**[PDF 파일: {os.path.basename(pdf_path)}]**\n\n{merged}"
##############################################################################
# ์ด๋ฏธ์ง€๋ฅผ base64๋กœ ๋ณ€ํ™˜
##############################################################################
def image_to_base64_data_uri(image: Union[np.ndarray, Image.Image]) -> str:
    """Encode an image as a base64 ``data:image/jpeg`` URI.

    Args:
        image: a numpy array (e.g. a webcam frame) or a PIL image.

    Returns:
        A ``data:image/jpeg;base64,...`` string.
    """
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    # JPEG has no alpha channel: convert unconditionally so PIL inputs in
    # RGBA/LA/P mode do not raise "cannot write mode ... as JPEG" on save.
    # (The original only converted numpy inputs.)
    image = image.convert('RGB')
    buffered = BytesIO()
    image.save(buffered, format="JPEG", quality=85)
    img_str = base64.b64encode(buffered.getvalue()).decode()
    return f"data:image/jpeg;base64,{img_str}"
##############################################################################
# ๋ชจ๋ธ ๋กœ๋“œ
##############################################################################
def download_model_files():
    """Download the GGUF model (and optional vision projector) from HF Hub.

    Candidate repositories are tried in order; the first successful one
    wins. Returns a ``(model_path, mmproj_path)`` tuple where
    ``mmproj_path`` is ``None`` when no vision projector file exists.

    Raises:
        Exception: when no candidate repository could be downloaded.
    """
    # Candidate repositories, tried in order.
    model_repos = [
        # First attempt: generic Gemma 3 4B GGUF
        {
            "repo": "Mungert/gemma-3-4b-it-gguf",
            "model": "google_gemma-3-4b-it-q4_k_m.gguf",
            "mmproj": "google_gemma-3-4b-it-mmproj-bf16.gguf"
        },
        # Second attempt: LM Studio build
        {
            "repo": "lmstudio-community/gemma-3-4b-it-GGUF",
            "model": "gemma-3-4b-it-Q4_K_M.gguf",
            "mmproj": "gemma-3-4b-it-mmproj-f16.gguf"
        },
        # Third attempt: unsloth build
        {
            "repo": "unsloth/gemma-3-4b-it-GGUF",
            "model": "gemma-3-4b-it.Q4_K_M.gguf",
            "mmproj": "gemma-3-4b-it.mmproj.gguf"
        }
    ]
    for repo_info in model_repos:
        try:
            logger.info(f"저장소 시도: {repo_info['repo']}")
            # Main model weights
            model_filename = repo_info["model"]
            logger.info(f"모델 다운로드 중: {model_filename}")
            # NOTE(review): resume_download is deprecated in recent
            # huggingface_hub versions (resuming is the default) — kept
            # for compatibility with the pinned version.
            model_path = hf_hub_download(
                repo_id=repo_info["repo"],
                filename=model_filename,
                resume_download=True,
                local_files_only=False
            )
            # Vision projection file (optional)
            mmproj_filename = repo_info["mmproj"]
            logger.info(f"Vision 모델 다운로드 중: {mmproj_filename}")
            try:
                mmproj_path = hf_hub_download(
                    repo_id=repo_info["repo"],
                    filename=mmproj_filename,
                    resume_download=True,
                    local_files_only=False
                )
            # Was a bare `except:` — narrowed so KeyboardInterrupt/SystemExit
            # still propagate instead of being treated as "file missing".
            except Exception:
                logger.warning(f"Vision 모델을 찾을 수 없습니다: {mmproj_filename}")
                logger.warning("텍스트 전용 모드로 진행합니다.")
                mmproj_path = None
            logger.info(f"✅ 모델 다운로드 성공!")
            logger.info(f"모델 경로: {model_path}")
            if mmproj_path:
                logger.info(f"Vision 경로: {mmproj_path}")
            return model_path, mmproj_path
        except Exception as e:
            logger.error(f"저장소 {repo_info['repo']} 시도 실패: {e}")
            continue
    # All candidates failed.
    raise Exception("사용 가능한 GGUF 모델을 찾을 수 없습니다. 인터넷 연결을 확인하세요.")
@spaces.GPU(duration=120)
def load_model():
    """Download and initialise the GGUF model (idempotent).

    Mutates the module-level ``llm`` and ``model_loaded`` globals. When an
    mmproj file is available a llava vision chat handler is attached;
    otherwise the model runs in text-only mode.

    Returns:
        bool: True on success (or when already loaded), False on failure.
    """
    global llm, model_loaded
    if model_loaded:
        logger.info("모델이 이미 로드되어 있습니다.")
        return True
    try:
        logger.info("Gemma3-4B GGUF Q4_K_M 모델 로딩 시작...")
        clear_cuda_cache()
        # Fetch model weights (and optional vision projector)
        model_path, mmproj_path = download_model_files()
        # -1 offloads every layer to GPU when CUDA is available
        n_gpu_layers = -1 if torch.cuda.is_available() else 0
        # Vision chat handler — only when an mmproj file was found
        chat_handler = None
        if mmproj_path:
            try:
                chat_handler = Llava16ChatHandler(
                    clip_model_path=mmproj_path,
                    verbose=False
                )
                logger.info("✅ Vision 모델 로드 성공")
            except Exception as e:
                logger.warning(f"Vision 모델 로드 실패, 텍스트 전용 모드로 전환: {e}")
                chat_handler = None
        # Model construction parameters
        llm_params = {
            "model_path": model_path,
            "n_ctx": 4096,  # context window size
            "n_gpu_layers": n_gpu_layers,  # GPU layer offload count
            "n_threads": 8,  # CPU threads
            "verbose": False,
            "seed": 42,
        }
        # Attach the vision handler when present
        if chat_handler:
            llm_params["chat_handler"] = chat_handler
            llm_params["logits_all"] = True  # required by vision models
        llm = Llama(**llm_params)
        model_loaded = True
        logger.info(f"✅ Gemma3-4B 모델 로딩 완료!")
        if not chat_handler:
            logger.warning("⚠️ 텍스트 전용 모드로 실행 중입니다. 이미지 분석이 제한될 수 있습니다.")
        return True
    except Exception as e:
        logger.error(f"모델 로딩 실패: {e}")
        import traceback
        logger.error(traceback.format_exc())
        return False
##############################################################################
# ์ฑ„ํŒ… ํ…œํ”Œ๋ฆฟ ํฌ๋งทํŒ…
##############################################################################
def format_chat_prompt(system_prompt: str, user_prompt: str, image_uri: Optional[str] = None) -> List[Dict]:
    """Build an OpenAI-style chat message list for llama-cpp.

    The system prompt becomes the first message. The user turn carries the
    optional image part first, then the text part, which is the order the
    llava chat handlers expect.
    """
    user_content: List[Dict] = []
    if image_uri:
        user_content.append({
            "type": "image_url",
            "image_url": {"url": image_uri}
        })
    user_content.append({"type": "text", "text": user_prompt})
    return [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_content},
    ]
##############################################################################
# ์ด๋ฏธ์ง€ ๋ถ„์„ (๋กœ๋ด‡ ํƒœ์Šคํฌ ์ค‘์‹ฌ)
##############################################################################
@spaces.GPU(duration=60)
def analyze_image_for_robot(
    image: Union[np.ndarray, Image.Image],
    prompt: str,
    task_type: str = "general",
    use_web_search: bool = False,
    enable_thinking: bool = False,
    max_new_tokens: int = 300
) -> str:
    """Analyse a captured image for a robot task.

    Lazily loads the model on first use. When no vision chat handler is
    attached, falls back to a text-only "simulator" prompt that never sees
    the image. Optionally prepends web-search results to the system prompt
    and/or requests chain-of-thought output.

    Returns the model's answer, or an error string prefixed with "❌".
    """
    global llm
    if not model_loaded:
        if not load_model():
            return "❌ 모델 로딩 실패"
    try:
        # Text-only fallback when the vision handler is missing.
        if not hasattr(llm, 'chat_handler') or llm.chat_handler is None:
            logger.warning("Vision 모델이 로드되지 않았습니다. 텍스트 기반 분석만 가능합니다.")
            system_prompt = f"""당신은 로봇 시각 시스템 시뮬레이터입니다.
실제 이미지를 볼 수는 없지만, 사용자의 설명을 바탕으로 로봇 작업을 계획하고 분석합니다.
태스크 유형: {task_type}"""
            messages = [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": f"[이미지 분석 요청] {prompt}"}
            ]
            response = llm.create_chat_completion(
                messages=messages,
                max_tokens=max_new_tokens,
                temperature=0.7,
                top_p=0.9,
                stream=False
            )
            result = response['choices'][0]['message']['content'].strip()
            return f"⚠️ 텍스트 전용 모드\n\n{result}"
        # Encode the image as a base64 data URI for the llava handler.
        image_uri = image_to_base64_data_uri(image)
        # Per-task system prompts.
        system_prompts = {
            "general": "당신은 로봇 시각 시스템입니다. 먼저 장면을 1-2줄로 설명하고, 핵심 내용을 간결하게 분석하세요.",
            "planning": """당신은 로봇 작업 계획 AI입니다.
먼저 장면 이해를 1-2줄로 설명하고, 그 다음 작업 계획을 작성하세요.
형식:
[장면 이해] 현재 보이는 장면을 1-2줄로 설명
[작업 계획]
Step_1: xxx
Step_2: xxx
Step_n: xxx""",
            "grounding": "당신은 객체 위치 시스템입니다. 먼저 보이는 객체들을 한 줄로 설명하고, 요청된 객체 위치를 [x1, y1, x2, y2]로 반환하세요.",
            "affordance": "당신은 파지점 분석 AI입니다. 먼저 대상 객체를 한 줄로 설명하고, 파지 영역을 [x1, y1, x2, y2]로 반환하세요.",
            "trajectory": "당신은 경로 계획 AI입니다. 먼저 환경을 한 줄로 설명하고, 경로를 [(x1,y1), (x2,y2), ...]로 제시하세요.",
            "pointing": "당신은 지점 지정 시스템입니다. 먼저 참조점들을 한 줄로 설명하고, 위치를 [(x1,y1), (x2,y2), ...]로 반환하세요."
        }
        system_prompt = system_prompts.get(task_type, system_prompts["general"])
        # Optional chain-of-thought instruction.
        if enable_thinking:
            system_prompt += "\n\n추론 과정을 <thinking></thinking> 태그 안에 작성 후 최종 답변을 제시하세요. 장면 이해는 추론 과정과 별도로 반드시 포함하세요."
        # Optional web search, prepended ahead of the task prompt.
        combined_system = system_prompt
        if use_web_search:
            keywords = extract_keywords(prompt, top_k=5)
            if keywords:
                logger.info(f"웹 검색 키워드: {keywords}")
                search_results = do_web_search(keywords)
                combined_system = f"{search_results}\n\n{system_prompt}"
        # Build messages and generate (non-streaming).
        messages = format_chat_prompt(combined_system, prompt, image_uri)
        response = llm.create_chat_completion(
            messages=messages,
            max_tokens=max_new_tokens,
            temperature=0.7,
            top_p=0.9,
            stream=False
        )
        # Extract the answer text.
        result = response['choices'][0]['message']['content'].strip()
        return result
    except Exception as e:
        logger.error(f"이미지 분석 오류: {e}")
        import traceback
        return f"❌ 분석 오류: {str(e)}\n{traceback.format_exc()}"
    finally:
        clear_cuda_cache()
##############################################################################
# ๋ฌธ์„œ ๋ถ„์„ (์ŠคํŠธ๋ฆฌ๋ฐ)
##############################################################################
@spaces.GPU(duration=120)
def analyze_documents_streaming(
    files: List[str],
    prompt: str,
    use_web_search: bool = False,
    max_new_tokens: int = 2048
) -> Iterator[str]:
    """Analyse uploaded documents (CSV/TXT/PDF) and stream the answer.

    Yields the accumulated output string after each streamed chunk, so the
    UI can update progressively. Files with unsupported extensions are
    silently skipped. Lazily loads the model on first use.
    """
    global llm
    if not model_loaded:
        if not load_model():
            yield "❌ 모델 로딩 실패"
            return
    try:
        # Base system prompt.
        system_content = "당신은 문서를 분석하고 요약하는 전문 AI입니다."
        # Optional web search, prepended to the system prompt.
        if use_web_search:
            keywords = extract_keywords(prompt, top_k=5)
            if keywords:
                search_results = do_web_search(keywords)
                system_content = f"{search_results}\n\n{system_content}"
        # Extract text from each supported document.
        doc_contents = []
        for file_path in files:
            if file_path.lower().endswith('.csv'):
                content = analyze_csv_file(file_path)
            elif file_path.lower().endswith('.txt'):
                content = analyze_txt_file(file_path)
            elif file_path.lower().endswith('.pdf'):
                content = pdf_to_markdown(file_path)
            else:
                continue  # unsupported extension — skip silently
            doc_contents.append(content)
        # Full user prompt: document previews first, then the request.
        full_prompt = "\n\n".join(doc_contents) + f"\n\n{prompt}"
        messages = [
            {"role": "system", "content": system_content},
            {"role": "user", "content": full_prompt}
        ]
        # Streamed generation.
        stream = llm.create_chat_completion(
            messages=messages,
            max_tokens=max_new_tokens,
            temperature=0.8,
            top_p=0.9,
            stream=True
        )
        # Accumulate deltas and re-yield the growing output each time.
        output = ""
        for chunk in stream:
            if 'choices' in chunk and len(chunk['choices']) > 0:
                delta = chunk['choices'][0].get('delta', {})
                if 'content' in delta:
                    output += delta['content']
                    yield output
    except Exception as e:
        logger.error(f"문서 분석 오류: {e}")
        yield f"❌ 오류 발생: {str(e)}"
    finally:
        clear_cuda_cache()
##############################################################################
# Gradio UI (๋กœ๋ด‡ ์‹œ๊ฐํ™” ์ค‘์‹ฌ)
##############################################################################
# Gradio stylesheet: header banner, status boxes, task buttons, webcam frame,
# auto-capture indicator and the model-info banner.
css = """
.robot-header {
text-align: center;
background: linear-gradient(135deg, #1e3c72 0%, #2a5298 50%, #667eea 100%);
color: white;
padding: 20px;
border-radius: 10px;
margin-bottom: 20px;
box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
}
.status-box {
text-align: center;
padding: 10px;
border-radius: 5px;
margin: 10px 0;
font-weight: bold;
}
.info-box {
background: #f0f0f0;
padding: 15px;
border-radius: 8px;
margin: 10px 0;
border-left: 4px solid #2a5298;
}
.task-button {
min-height: 60px;
font-size: 1.1em;
}
.webcam-container {
border: 3px solid #2a5298;
border-radius: 10px;
padding: 10px;
background: #f8f9fa;
}
.auto-capture-status {
text-align: center;
padding: 5px;
border-radius: 5px;
margin: 5px 0;
font-weight: bold;
background: #e8f5e9;
color: #2e7d32;
}
.model-info {
background: #fff3cd;
color: #856404;
padding: 10px;
border-radius: 5px;
margin: 10px 0;
text-align: center;
}
"""
with gr.Blocks(title="🤖 로봇 시각 시스템 (Gemma3-4B GGUF)", css=css) as demo:
    # ----- Header banners ---------------------------------------------------
    gr.HTML("""
    <div class="robot-header">
        <h1>🤖 로봇 시각 시스템</h1>
        <h3>🎮 Gemma3-4B GGUF Q4_K_M + 📷 실시간 웹캠 + 🔍 웹 검색</h3>
        <p>⚡ 양자화 모델로 더 빠르고 효율적인 로봇 작업 분석!</p>
    </div>
    """)
    gr.HTML("""
    <div class="model-info">
        <strong>모델:</strong> Gemma3-4B Q4_K_M (2.5GB) | <strong>메모리 사용:</strong> ~3-4GB VRAM
    </div>
    """)
    with gr.Row():
        # ----- Left column: webcam + task buttons ---------------------------
        with gr.Column(scale=1):
            gr.Markdown("### 📷 실시간 웹캠")
            with gr.Group(elem_classes="webcam-container"):
                webcam = gr.Image(
                    sources=["webcam"],
                    streaming=True,
                    type="numpy",
                    label="실시간 스트리밍",
                    height=350
                )
            # Auto-capture status indicator
            auto_capture_status = gr.HTML(
                '<div class="auto-capture-status">🔄 자동 캡처: 대기 중</div>'
            )
            # Last captured frame (hidden until a capture happens)
            captured_image = gr.Image(
                label="캡처된 이미지",
                height=200,
                visible=False
            )
            # Robot task buttons
            gr.Markdown("### 🎯 로봇 작업 선택")
            with gr.Row():
                capture_btn = gr.Button("📸 수동 캡처", variant="primary", elem_classes="task-button")
                clear_capture_btn = gr.Button("🗑️ 초기화", elem_classes="task-button")
            with gr.Row():
                auto_capture_toggle = gr.Checkbox(
                    label="🔄 자동 캡처 활성화 (10초마다)",
                    value=False,
                    info="활성화 시 10초마다 자동으로 캡처 및 분석"
                )
            with gr.Row():
                planning_btn = gr.Button("📋 작업 계획", elem_classes="task-button")
                grounding_btn = gr.Button("📍 객체 위치", elem_classes="task-button")
            with gr.Row():
                affordance_btn = gr.Button("🤝 파지점 분석", elem_classes="task-button")
                trajectory_btn = gr.Button("🛤️ 경로 계획", elem_classes="task-button")
        # ----- Right column: settings + results -----------------------------
        with gr.Column(scale=2):
            gr.Markdown("### ⚙️ 분석 설정")
            with gr.Row():
                with gr.Column():
                    task_prompt = gr.Textbox(
                        label="작업 설명 / 질문",
                        placeholder="예: 테이블 위의 컵을 잡아서 싱크대에 놓기",
                        value="현재 장면을 분석하고 로봇이 수행할 수 있는 작업을 제안하세요.",
                        lines=2
                    )
                    with gr.Row():
                        use_web_search = gr.Checkbox(
                            label="🔍 웹 검색 사용",
                            value=False,
                            info="관련 정보를 웹에서 검색합니다"
                        )
                        enable_thinking = gr.Checkbox(
                            label="🤔 추론 과정 표시",
                            value=False,
                            info="Chain-of-Thought 추론 과정을 보여줍니다"
                        )
                    max_tokens = gr.Slider(
                        label="최대 토큰 수",
                        minimum=100,
                        maximum=2048,
                        value=300,
                        step=50
                    )
            gr.Markdown("### 📊 분석 결과")
            result_output = gr.Textbox(
                label="AI 분석 결과",
                lines=20,
                max_lines=40,
                show_copy_button=True,
                elem_id="result"
            )
            status_display = gr.HTML(
                '<div class="status-box" style="background:#d4edda; color:#155724;">🎮 시스템 준비 완료</div>'
            )
    # ----- Document-analysis tab (currently hidden) -------------------------
    with gr.Tab("📄 문서 분석", visible=False):
        with gr.Row():
            with gr.Column():
                doc_files = gr.File(
                    label="문서 업로드",
                    file_count="multiple",
                    file_types=[".pdf", ".csv", ".txt"],
                    type="filepath"
                )
                doc_prompt = gr.Textbox(
                    label="분석 요청",
                    placeholder="예: 이 문서들의 핵심 내용을 요약하고 비교 분석하세요.",
                    lines=3
                )
                doc_web_search = gr.Checkbox(
                    label="🔍 웹 검색 사용",
                    value=False
                )
                analyze_docs_btn = gr.Button("📊 문서 분석", variant="primary")
            with gr.Column():
                doc_result = gr.Textbox(
                    label="분석 결과",
                    lines=25,
                    max_lines=50
                )
    # ----- Event handlers ---------------------------------------------------
    webcam_state = gr.State(None)  # last frame seen on the webcam stream
    # NOTE(review): auto_capture_state is passed to the timer callback but its
    # value is never read or updated — appears vestigial.
    auto_capture_state = gr.State({"enabled": False, "timer": None})
    def capture_webcam(frame):
        """Freeze the current webcam frame as the captured image."""
        if frame is None:
            return None, None, '<div class="status-box" style="background:#f8d7da; color:#721c24;">❌ 웹캠 프레임 없음</div>'
        return frame, gr.update(value=frame, visible=True), '<div class="status-box" style="background:#d4edda; color:#155724;">✅ 이미지 캡처 완료</div>'
    def clear_capture():
        """Reset the captured image and status back to idle."""
        return None, gr.update(visible=False), '<div class="status-box" style="background:#d4edda; color:#155724;">🎮 시스템 준비 완료</div>'
    def analyze_with_task(image, prompt, task_type, use_search, thinking, tokens):
        """Run analyze_image_for_robot for one task type and format the result."""
        if image is None:
            return "❌ 먼저 이미지를 캡처하세요.", '<div class="status-box" style="background:#f8d7da; color:#721c24;">❌ 이미지 없음</div>'
        # NOTE(review): `status` is built but never returned/displayed — dead?
        status = f'<div class="status-box" style="background:#cce5ff; color:#004085;">🚀 {task_type} 분석 중...</div>'
        result = analyze_image_for_robot(
            image=image,
            prompt=prompt,
            task_type=task_type,
            use_web_search=use_search,
            enable_thinking=thinking,
            max_new_tokens=tokens
        )
        # Wrap the answer with a timestamped header/footer.
        timestamp = time.strftime("%H:%M:%S")
        task_names = {
            "planning": "작업 계획",
            "grounding": "객체 위치",
            "affordance": "파지점",
            "trajectory": "경로 계획"
        }
        formatted_result = f"""🤖 {task_names.get(task_type, '분석')} 결과 ({timestamp})
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
{result}
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━"""
        complete_status = '<div class="status-box" style="background:#d4edda; color:#155724;">✅ 분석 완료!</div>'
        return formatted_result, complete_status
    # Timer callback: capture the current frame and auto-analyze it.
    def auto_capture_and_analyze(webcam_frame, task_prompt, use_search, thinking, tokens, auto_state):
        """Auto-capture the webcam frame and run a planning analysis."""
        if webcam_frame is None:
            return (
                None,
                "자동 캡처 대기 중...",
                '<div class="status-box" style="background:#fff3cd; color:#856404;">⏳ 웹캠 대기 중</div>',
                '<div class="auto-capture-status">🔄 자동 캡처: 웹캠 대기 중</div>'
            )
        timestamp = time.strftime("%H:%M:%S")
        # Always analyze in "planning" mode for auto-capture.
        result = analyze_image_for_robot(
            image=webcam_frame,
            prompt=task_prompt,
            task_type="planning",
            use_web_search=use_search,
            enable_thinking=thinking,
            max_new_tokens=tokens
        )
        formatted_result = f"""🔄 자동 분석 완료 ({timestamp})
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
{result}
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━"""
        return (
            webcam_frame,
            formatted_result,
            '<div class="status-box" style="background:#d4edda; color:#155724;">✅ 자동 분석 완료</div>',
            f'<div class="auto-capture-status">🔄 자동 캡처: 마지막 분석 {timestamp}</div>'
        )
    # Webcam stream: mirror the latest frame into webcam_state.
    webcam.stream(
        fn=lambda x: x,
        inputs=[webcam],
        outputs=[webcam_state]
    )
    # Manual capture button
    capture_btn.click(
        fn=capture_webcam,
        inputs=[webcam_state],
        outputs=[webcam_state, captured_image, status_display]
    )
    # Reset button
    clear_capture_btn.click(
        fn=clear_capture,
        outputs=[webcam_state, captured_image, status_display]
    )
    # Task buttons — each lambda binds a fixed task_type.
    planning_btn.click(
        fn=lambda img, p, s, t, tk: analyze_with_task(img, p, "planning", s, t, tk),
        inputs=[captured_image, task_prompt, use_web_search, enable_thinking, max_tokens],
        outputs=[result_output, status_display]
    )
    grounding_btn.click(
        fn=lambda img, p, s, t, tk: analyze_with_task(img, p, "grounding", s, t, tk),
        inputs=[captured_image, task_prompt, use_web_search, enable_thinking, max_tokens],
        outputs=[result_output, status_display]
    )
    affordance_btn.click(
        fn=lambda img, p, s, t, tk: analyze_with_task(img, p, "affordance", s, t, tk),
        inputs=[captured_image, task_prompt, use_web_search, enable_thinking, max_tokens],
        outputs=[result_output, status_display]
    )
    trajectory_btn.click(
        fn=lambda img, p, s, t, tk: analyze_with_task(img, p, "trajectory", s, t, tk),
        inputs=[captured_image, task_prompt, use_web_search, enable_thinking, max_tokens],
        outputs=[result_output, status_display]
    )
    # Document analysis: drain the streaming generator, keep the final chunk.
    def analyze_docs(files, prompt, use_search):
        if not files:
            return "❌ 문서를 업로드하세요."
        output = ""
        for chunk in analyze_documents_streaming(files, prompt, use_search):
            output = chunk
        return output
    analyze_docs_btn.click(
        fn=analyze_docs,
        inputs=[doc_files, doc_prompt, doc_web_search],
        outputs=[doc_result]
    )
    # Auto-capture timer (fires every 10 s once activated).
    timer = gr.Timer(10.0, active=False)
    # Toggle handler: returning a new gr.Timer reconfigures the component.
    def toggle_auto_capture(enabled):
        if enabled:
            return gr.Timer(10.0, active=True), '<div class="auto-capture-status">🔄 자동 캡처: 활성화됨 (10초마다)</div>'
        else:
            return gr.Timer(active=False), '<div class="auto-capture-status">🔄 자동 캡처: 비활성화됨</div>'
    auto_capture_toggle.change(
        fn=toggle_auto_capture,
        inputs=[auto_capture_toggle],
        outputs=[timer, auto_capture_status]
    )
    # Timer tick: auto capture and analyze.
    timer.tick(
        fn=auto_capture_and_analyze,
        inputs=[webcam_state, task_prompt, use_web_search, enable_thinking, max_tokens, auto_capture_state],
        outputs=[captured_image, result_output, status_display, auto_capture_status]
    )
    # Startup hook — the real model load happens lazily on first analysis.
    # NOTE(review): initial_load returns a string but outputs=None, so the
    # return value is discarded.
    def initial_load():
        return "시스템 준비 완료! 첫 분석 시 모델이 자동으로 로드됩니다. 🚀"
    demo.load(
        fn=initial_load,
        outputs=None
    )
if __name__ == "__main__":
    print("🚀 로봇 시각 시스템 시작 (Gemma3-4B GGUF Q4_K_M)...")
    # Bind on all interfaces so the app is reachable inside a container/Space;
    # 7860 is Gradio's conventional port.
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_error=True,
        debug=False
    )