# Scraped Hugging Face Spaces page header (not program code):
# "Spaces:" / "Runtime error" / "Runtime error"
#!/usr/bin/env python3 | |
import os | |
import re | |
import tempfile | |
import gc | |
from collections.abc import Iterator | |
from threading import Thread | |
import json | |
import requests | |
import cv2 | |
import gradio as gr | |
import spaces | |
import torch | |
import numpy as np | |
from loguru import logger | |
from PIL import Image | |
import time | |
import warnings | |
from typing import Dict, List, Optional, Union | |
import base64 | |
from io import BytesIO | |
# llama-cpp-python for GGUF | |
from llama_cpp import Llama | |
from llama_cpp.llama_chat_format import Llava16ChatHandler | |
# Model download | |
from huggingface_hub import hf_hub_download | |
# CSV/TXT analysis
import pandas as pd | |
# PDF text extraction
import PyPDF2 | |
# Silence library warnings globally so the Space logs stay readable.
warnings.filterwarnings("ignore")
print("๐ฎ ๋ก๋ด ์๊ฐ ์์คํ ์ด๊ธฐํ (Gemma3-4B GGUF Q4_K_M)...")

##############################################################################
# Constants
##############################################################################
# Upper bound on characters kept from any single document payload.
MAX_CONTENT_CHARS = 2_000
# NOTE(review): not referenced anywhere in the visible part of this file.
MAX_INPUT_LENGTH = 2_096
# NOTE(review): not referenced anywhere in the visible part of this file.
MAX_NUM_IMAGES = 5
# SerpHouse credentials; an empty string when the variable is unset.
SERPHOUSE_API_KEY = os.environ.get("SERPHOUSE_API_KEY", "")

##############################################################################
# Global state
##############################################################################
# Lazily created llama_cpp.Llama instance; populated by load_model().
llm = None
# Flipped to True by load_model() once `llm` is ready.
model_loaded = False
model_name = "Gemma3-4B-GGUF-Q4_K_M"
##############################################################################
# Memory management
##############################################################################
def clear_cuda_cache():
    """Free unreferenced Python objects, then release cached CUDA memory.

    ``gc.collect()`` runs first so that tensors kept alive only by reference
    cycles are actually freed *before* ``torch.cuda.empty_cache()`` hands the
    allocator's unused cached blocks back to the driver; in the opposite order
    memory freed by the collector would stay parked in torch's cache.

    The CUDA step is skipped when torch reports no GPU, so this is safe to call
    on CPU-only hosts.
    NOTE(review): in this file torch is only used for CUDA detection and this
    flush; llama.cpp manages its own GPU buffers, which torch cannot release.
    """
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
##############################################################################
# Keyword extraction
##############################################################################
# Characters kept by extract_keywords(): ASCII letters and digits, precomposed
# Hangul syllables (U+AC00..U+D7A3) and whitespace. The Hangul range is spelled
# with escapes so it survives any re-encoding of this source file.
_KEYWORD_STRIP_RE = re.compile(r"[^a-zA-Z0-9\uac00-\ud7a3\s]")


def extract_keywords(text: str, top_k: int = 5) -> str:
    """Build a short search query from free-form text.

    Removes every character that is not an ASCII letter/digit, a Hangul
    syllable or whitespace, then keeps the first ``top_k`` distinct tokens
    longer than one character, in order of first appearance.

    Args:
        text: Free-form user prompt (Korean and/or English).
        top_k: Maximum number of tokens to keep.

    Returns:
        The selected tokens joined by single spaces; ``""`` if none qualify.
    """
    cleaned = _KEYWORD_STRIP_RE.sub("", text)
    # dict.fromkeys de-duplicates while preserving first-seen order.
    unique_tokens = list(dict.fromkeys(tok for tok in cleaned.split() if len(tok) > 1))
    return " ".join(unique_tokens[:top_k])
##############################################################################
# Web search
##############################################################################
def do_web_search(query: str) -> str:
    """Query the SerpHouse live SERP endpoint and render hits as Markdown.

    Never raises: any failure (network, HTTP status, JSON decoding, unexpected
    payload shape) is logged and converted into an error string that is
    returned to the caller.

    Args:
        query: Search terms, typically produced by extract_keywords().

    Returns:
        A Markdown block (usage instructions followed by at most 10 results),
        a "no results" message, or an error message.
    """
    try:
        endpoint = "https://api.serphouse.com/serp/live"
        query_params = {
            "q": query,
            "domain": "google.com",
            "serp_type": "web",
            "device": "desktop",
            "lang": "ko",
            "num": "10",
        }
        auth_headers = {"Authorization": f"Bearer {SERPHOUSE_API_KEY}"}

        logger.info(f"์น ๊ฒ์ ์ค... ๊ฒ์์ด: {query}")
        reply = requests.get(endpoint, headers=auth_headers, params=query_params, timeout=60)
        reply.raise_for_status()

        payload = reply.json().get("results", {})
        hits = payload.get("organic", []) if isinstance(payload, dict) else []
        if not hits:
            return "๊ฒ์ ๊ฒฐ๊ณผ๋ฅผ ์ฐพ์ ์ ์์ต๋๋ค."

        rendered = []
        for rank, hit in enumerate(hits[:10], start=1):
            title = hit.get("title", "์ ๋ชฉ ์์")
            link = hit.get("link", "#")
            snippet = hit.get("snippet", "์ค๋ช ์์")
            shown_link = hit.get("displayed_link", link)
            rendered.append(
                f"### ๊ฒฐ๊ณผ {rank}: {title}\n\n"
                f"{snippet}\n\n"
                f"**์ถ์ฒ**: [{shown_link}]({link})\n\n"
                "---\n"
            )

        # Guidance for the model on how to use the results, prepended verbatim.
        preamble = """# ์น ๊ฒ์ ๊ฒฐ๊ณผ
์๋๋ ๊ฒ์ ๊ฒฐ๊ณผ์ ๋๋ค. ๋ต๋ณ ์ ์ด ์ ๋ณด๋ฅผ ํ์ฉํ์ธ์:
1. ๊ฐ ๊ฒฐ๊ณผ์ ์ ๋ชฉ, ๋ด์ฉ, ์ถ์ฒ ๋งํฌ๋ฅผ ์ฐธ์กฐํ์ธ์
2. ๊ด๋ จ ์ถ์ฒ๋ฅผ ๋ช ์์ ์ผ๋ก ์ธ์ฉํ์ธ์
3. ์ฌ๋ฌ ์ถ์ฒ์ ์ ๋ณด๋ฅผ ์ข ํฉํ์ฌ ๋ต๋ณํ์ธ์
"""
        return preamble + "\n".join(rendered)
    except Exception as err:
        logger.error(f"์น ๊ฒ์ ์คํจ: {err}")
        return f"์น ๊ฒ์ ์คํจ: {str(err)}"
##############################################################################
# Document processing
##############################################################################
def analyze_csv_file(path: str) -> str:
    """Render a CSV file as a plain-text table for inclusion in an LLM prompt.

    Only the first 50 rows and first 10 columns are kept, and the rendered
    table is cut to MAX_CONTENT_CHARS characters. Any failure is reported as
    a string instead of being raised.
    """
    try:
        # iloc slicing is a no-op for frames already within the 50x10 limit.
        table = pd.read_csv(path).iloc[:50, :10].to_string()
        if len(table) > MAX_CONTENT_CHARS:
            table = f"{table[:MAX_CONTENT_CHARS]}\n...(์ค๋ต)..."
        return f"**[CSV ํ์ผ: {os.path.basename(path)}]**\n\n{table}"
    except Exception as err:
        return f"CSV ์ฝ๊ธฐ ์คํจ ({os.path.basename(path)}): {str(err)}"
def analyze_txt_file(path: str) -> str:
    """Read a text file for an LLM prompt, truncated to MAX_CONTENT_CHARS.

    Only ``MAX_CONTENT_CHARS + 1`` characters are read (one extra so that
    truncation can be detected), so very large uploads are never loaded into
    memory in full. Bytes that are not valid UTF-8 are replaced with U+FFFD
    rather than failing the whole file, because uploaded text is not
    guaranteed to be UTF-8. Other failures are returned as a string.
    """
    try:
        with open(path, "r", encoding="utf-8", errors="replace") as f:
            text = f.read(MAX_CONTENT_CHARS + 1)
        if len(text) > MAX_CONTENT_CHARS:
            text = text[:MAX_CONTENT_CHARS] + "\n...(์ค๋ต)..."
        return f"**[TXT ํ์ผ: {os.path.basename(path)}]**\n\n{text}"
    except Exception as e:
        return f"TXT ์ฝ๊ธฐ ์คํจ ({os.path.basename(path)}): {str(e)}"
def pdf_to_markdown(pdf_path: str) -> str:
    """Extract text from the first five PDF pages as Markdown sections.

    Each non-empty page becomes a ``## ...`` section whose text is capped at
    ``MAX_CONTENT_CHARS // <pages read>`` characters. A trailer records how
    many pages were shown when the PDF is longer, and the combined text is
    capped at MAX_CONTENT_CHARS. Read errors are returned as a string.
    """
    sections = []
    try:
        with open(pdf_path, "rb") as handle:
            reader = PyPDF2.PdfReader(handle)
            total_pages = len(reader.pages)
            pages_to_read = min(5, total_pages)
            for index in range(pages_to_read):
                content = (reader.pages[index].extract_text() or "").strip()
                if not content:
                    continue
                # Computed inside the loop: pages_to_read is never 0 here.
                per_page_cap = MAX_CONTENT_CHARS // pages_to_read
                if len(content) > per_page_cap:
                    content = content[:per_page_cap] + "...(์ค๋ต)"
                sections.append(f"## ํ์ด์ง {index + 1}\n\n{content}\n")
            if total_pages > pages_to_read:
                sections.append(f"\n...({pages_to_read}/{total_pages} ํ์ด์ง ํ์)...")
    except Exception as err:
        return f"PDF ์ฝ๊ธฐ ์คํจ ({os.path.basename(pdf_path)}): {str(err)}"

    body = "\n".join(sections)
    if len(body) > MAX_CONTENT_CHARS:
        body = body[:MAX_CONTENT_CHARS] + "\n...(์ค๋ต)..."
    return f"**[PDF ํ์ผ: {os.path.basename(pdf_path)}]**\n\n{body}"
##############################################################################
# Image -> base64 data URI
##############################################################################
def image_to_base64_data_uri(image: Union[np.ndarray, Image.Image]) -> str:
    """Encode an image as a ``data:image/jpeg;base64,...`` URI.

    Args:
        image: A numpy array accepted by ``PIL.Image.fromarray`` or a PIL image.

    Returns:
        The image JPEG-encoded at quality 85, wrapped in a data URI.
    """
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    # JPEG cannot store alpha or palette images (Pillow refuses to save e.g.
    # RGBA/P as JPEG), so normalise every input -- not only arrays -- to RGB.
    if image.mode != "RGB":
        image = image.convert("RGB")
    buffered = BytesIO()
    image.save(buffered, format="JPEG", quality=85)
    encoded = base64.b64encode(buffered.getvalue()).decode("ascii")
    return f"data:image/jpeg;base64,{encoded}"
##############################################################################
# Model loading
##############################################################################
def download_model_files():
    """Download a Gemma 3 4B GGUF model and, if available, its vision projector.

    Candidate Hugging Face repositories are tried in order; the first one
    whose main model file downloads successfully wins. A failed mmproj
    download is tolerated and reported as ``None`` (text-only mode).

    Returns:
        tuple: ``(model_path, mmproj_path)`` local file paths; ``mmproj_path``
        may be ``None``.

    Raises:
        Exception: if no repository yields a model file. The last download
            error is chained as ``__cause__``.
    """
    # Candidate repositories, tried in order.
    # NOTE(review): file names differ between uploaders; confirm they still
    # exist in each repository.
    model_repos = [
        # First attempt: a generic Gemma 3 4B GGUF upload
        {
            "repo": "Mungert/gemma-3-4b-it-gguf",
            "model": "google_gemma-3-4b-it-q4_k_m.gguf",
            "mmproj": "google_gemma-3-4b-it-mmproj-bf16.gguf"
        },
        # Second attempt: LM Studio build
        {
            "repo": "lmstudio-community/gemma-3-4b-it-GGUF",
            "model": "gemma-3-4b-it-Q4_K_M.gguf",
            "mmproj": "gemma-3-4b-it-mmproj-f16.gguf"
        },
        # Third attempt: unsloth build
        {
            "repo": "unsloth/gemma-3-4b-it-GGUF",
            "model": "gemma-3-4b-it.Q4_K_M.gguf",
            "mmproj": "gemma-3-4b-it.mmproj.gguf"
        }
    ]
    last_error = None
    for repo_info in model_repos:
        try:
            logger.info(f"์ ์ฅ์ ์๋: {repo_info['repo']}")
            # Main model weights. `resume_download` is no longer passed: it is
            # deprecated in huggingface_hub (downloads resume automatically).
            model_filename = repo_info["model"]
            logger.info(f"๋ชจ๋ธ ๋ค์ด๋ก๋ ์ค: {model_filename}")
            model_path = hf_hub_download(
                repo_id=repo_info["repo"],
                filename=model_filename,
                local_files_only=False
            )
            # Vision projection (mmproj) file; optional.
            mmproj_filename = repo_info["mmproj"]
            logger.info(f"Vision ๋ชจ๋ธ ๋ค์ด๋ก๋ ์ค: {mmproj_filename}")
            try:
                mmproj_path = hf_hub_download(
                    repo_id=repo_info["repo"],
                    filename=mmproj_filename,
                    local_files_only=False
                )
            except Exception as mmproj_error:
                # Missing projector -> text-only mode. `except Exception`
                # (not a bare except) keeps Ctrl-C / SystemExit working.
                logger.warning(f"Vision ๋ชจ๋ธ์ ์ฐพ์ ์ ์์ต๋๋ค: {mmproj_filename}")
                logger.warning("ํ ์คํธ ์ ์ฉ ๋ชจ๋๋ก ์งํํฉ๋๋ค.")
                logger.debug(f"mmproj download error: {mmproj_error}")
                mmproj_path = None
            logger.info("โ ๋ชจ๋ธ ๋ค์ด๋ก๋ ์ฑ๊ณต!")
            logger.info(f"๋ชจ๋ธ ๊ฒฝ๋ก: {model_path}")
            if mmproj_path:
                logger.info(f"Vision ๊ฒฝ๋ก: {mmproj_path}")
            return model_path, mmproj_path
        except Exception as e:
            last_error = e
            logger.error(f"์ ์ฅ์ {repo_info['repo']} ์๋ ์คํจ: {e}")
    # Every repository failed.
    raise Exception("์ฌ์ฉ ๊ฐ๋ฅํ GGUF ๋ชจ๋ธ์ ์ฐพ์ ์ ์์ต๋๋ค. ์ธํฐ๋ท ์ฐ๊ฒฐ์ ํ์ธํ์ธ์.") from last_error
def load_model():
    """Load the Gemma 3 GGUF model into the module-level ``llm`` (idempotent).

    Downloads the weights via download_model_files(), offloads every layer to
    the GPU when torch reports CUDA, and attaches a vision chat handler when
    an mmproj file is available; otherwise the model runs text-only.

    Returns:
        bool: True when the model is loaded (now or previously), False when
        loading failed. Failures are logged with their traceback, not raised.
    """
    global llm, model_loaded
    if model_loaded:
        logger.info("๋ชจ๋ธ์ด ์ด๋ฏธ ๋ก๋๋์ด ์์ต๋๋ค.")
        return True
    try:
        logger.info("Gemma3-4B GGUF Q4_K_M ๋ชจ๋ธ ๋ก๋ฉ ์์...")
        clear_cuda_cache()
        model_path, mmproj_path = download_model_files()
        # -1 offloads all layers to the GPU; 0 keeps inference on the CPU.
        n_gpu_layers = -1 if torch.cuda.is_available() else 0
        # Vision chat handler, only when an mmproj file was downloaded.
        # NOTE(review): Llava16ChatHandler implements the LLaVA-1.6 chat
        # format; confirm it is compatible with Gemma 3's projector/prompt.
        chat_handler = None
        if mmproj_path:
            try:
                chat_handler = Llava16ChatHandler(
                    clip_model_path=mmproj_path,
                    verbose=False
                )
                logger.info("โ Vision ๋ชจ๋ธ ๋ก๋ ์ฑ๊ณต")
            except Exception as e:
                logger.warning(f"Vision ๋ชจ๋ธ ๋ก๋ ์คํจ, ํ ์คํธ ์ ์ฉ ๋ชจ๋๋ก ์ ํ: {e}")
                chat_handler = None
        llm_params = {
            "model_path": model_path,
            "n_ctx": 4096,                  # context window (tokens)
            "n_gpu_layers": n_gpu_layers,
            "n_threads": 8,                 # CPU threads
            "verbose": False,
            "seed": 42,
        }
        if chat_handler:
            llm_params["chat_handler"] = chat_handler
            # The original author marks this as required for the vision path.
            llm_params["logits_all"] = True
        llm = Llama(**llm_params)
        model_loaded = True
        logger.info("โ Gemma3-4B ๋ชจ๋ธ ๋ก๋ฉ ์๋ฃ!")
        if not chat_handler:
            logger.warning("โ ๏ธ ํ ์คํธ ์ ์ฉ ๋ชจ๋๋ก ์คํ ์ค์ ๋๋ค. ์ด๋ฏธ์ง ๋ถ์์ด ์ ํ๋ ์ ์์ต๋๋ค.")
        return True
    except Exception as e:
        # logger.exception records the message and the active traceback in a
        # single entry, replacing the manual traceback.format_exc() logging.
        logger.exception(f"๋ชจ๋ธ ๋ก๋ฉ ์คํจ: {e}")
        return False
##############################################################################
# Chat prompt formatting
##############################################################################
def format_chat_prompt(system_prompt: str, user_prompt: str, image_uri: Optional[str] = None) -> List[Dict]:
    """Build the message list passed to ``llm.create_chat_completion``.

    The user turn is a list of content parts: an ``image_url`` part comes
    first, but only when ``image_uri`` is truthy, followed by the text part.
    """
    image_parts = (
        [{"type": "image_url", "image_url": {"url": image_uri}}] if image_uri else []
    )
    user_parts = image_parts + [{"type": "text", "text": user_prompt}]
    return [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_parts},
    ]
##############################################################################
# Image analysis (robot-task oriented)
##############################################################################
def _add_web_context(system_prompt: str, prompt: str, use_web_search: bool) -> str:
    """Prepend web search results for ``prompt`` to ``system_prompt`` if requested.

    Returns ``system_prompt`` unchanged when search is disabled or when
    extract_keywords() finds no usable keywords in ``prompt``.
    """
    if not use_web_search:
        return system_prompt
    keywords = extract_keywords(prompt, top_k=5)
    if not keywords:
        return system_prompt
    logger.info(f"์น ๊ฒ์ ํค์๋: {keywords}")
    search_results = do_web_search(keywords)
    return f"{search_results}\n\n{system_prompt}"


def analyze_image_for_robot(
    image: Union[np.ndarray, Image.Image],
    prompt: str,
    task_type: str = "general",
    use_web_search: bool = False,
    enable_thinking: bool = False,
    max_new_tokens: int = 300
) -> str:
    """Analyse a camera frame for a robot task with the loaded GGUF model.

    Loads the model on first use. When the model has no vision chat handler,
    the image is ignored and the model answers from ``prompt`` alone; that
    result is prefixed with a text-only warning. Web search, when enabled, is
    applied in both modes (it used to be silently dropped in text-only mode).
    ``enable_thinking`` only affects the vision path.

    Args:
        image: Frame to analyse (numpy array or PIL image).
        prompt: User instruction or question.
        task_type: "general", "planning", "grounding", "affordance",
            "trajectory" or "pointing"; unknown values fall back to "general".
        use_web_search: Prepend SerpHouse results for keywords in ``prompt``.
        enable_thinking: Ask for reasoning inside <thinking></thinking> tags.
        max_new_tokens: Generation budget passed as ``max_tokens``.

    Returns:
        The model's answer, or an error message (errors are not raised).
    """
    if not model_loaded:
        if not load_model():
            return "โ ๋ชจ๋ธ ๋ก๋ฉ ์คํจ"
    try:
        # Text-only fallback when no vision handler is attached.
        if not hasattr(llm, 'chat_handler') or llm.chat_handler is None:
            logger.warning("Vision ๋ชจ๋ธ์ด ๋ก๋๋์ง ์์์ต๋๋ค. ํ ์คํธ ๊ธฐ๋ฐ ๋ถ์๋ง ๊ฐ๋ฅํฉ๋๋ค.")
            system_prompt = f"""๋น์ ์ ๋ก๋ด ์๊ฐ ์์คํ ์๋ฎฌ๋ ์ดํฐ์ ๋๋ค.
์ค์ ์ด๋ฏธ์ง๋ฅผ ๋ณผ ์๋ ์์ง๋ง, ์ฌ์ฉ์์ ์ค๋ช ์ ๋ฐํ์ผ๋ก ๋ก๋ด ์์ ์ ๊ณํํ๊ณ ๋ถ์ํฉ๋๋ค.
ํ์คํฌ ์ ํ: {task_type}"""
            system_prompt = _add_web_context(system_prompt, prompt, use_web_search)
            messages = [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": f"[์ด๋ฏธ์ง ๋ถ์ ์์ฒญ] {prompt}"}
            ]
            response = llm.create_chat_completion(
                messages=messages,
                max_tokens=max_new_tokens,
                temperature=0.7,
                top_p=0.9,
                stream=False
            )
            result = response['choices'][0]['message']['content'].strip()
            return f"โ ๏ธ ํ ์คํธ ์ ์ฉ ๋ชจ๋\n\n{result}"

        image_uri = image_to_base64_data_uri(image)
        # Task-specific system prompts.
        system_prompts = {
            "general": "๋น์ ์ ๋ก๋ด ์๊ฐ ์์คํ ์ ๋๋ค. ๋จผ์ ์ฅ๋ฉด์ 1-2์ค๋ก ์ค๋ช ํ๊ณ , ํต์ฌ ๋ด์ฉ์ ๊ฐ๊ฒฐํ๊ฒ ๋ถ์ํ์ธ์.",
            "planning": """๋น์ ์ ๋ก๋ด ์์ ๊ณํ AI์ ๋๋ค.
๋จผ์ ์ฅ๋ฉด ์ดํด๋ฅผ 1-2์ค๋ก ์ค๋ช ํ๊ณ , ๊ทธ ๋ค์ ์์ ๊ณํ์ ์์ฑํ์ธ์.
ํ์:
[์ฅ๋ฉด ์ดํด] ํ์ฌ ๋ณด์ด๋ ์ฅ๋ฉด์ 1-2์ค๋ก ์ค๋ช
[์์ ๊ณํ]
Step_1: xxx
Step_2: xxx
Step_n: xxx""",
            "grounding": "๋น์ ์ ๊ฐ์ฒด ์์น ์์คํ ์ ๋๋ค. ๋จผ์ ๋ณด์ด๋ ๊ฐ์ฒด๋ค์ ํ ์ค๋ก ์ค๋ช ํ๊ณ , ์์ฒญ๋ ๊ฐ์ฒด ์์น๋ฅผ [x1, y1, x2, y2]๋ก ๋ฐํํ์ธ์.",
            "affordance": "๋น์ ์ ํ์ง์ ๋ถ์ AI์ ๋๋ค. ๋จผ์ ๋์ ๊ฐ์ฒด๋ฅผ ํ ์ค๋ก ์ค๋ช ํ๊ณ , ํ์ง ์์ญ์ [x1, y1, x2, y2]๋ก ๋ฐํํ์ธ์.",
            "trajectory": "๋น์ ์ ๊ฒฝ๋ก ๊ณํ AI์ ๋๋ค. ๋จผ์ ํ๊ฒฝ์ ํ ์ค๋ก ์ค๋ช ํ๊ณ , ๊ฒฝ๋ก๋ฅผ [(x1,y1), (x2,y2), ...]๋ก ์ ์ํ์ธ์.",
            "pointing": "๋น์ ์ ์ง์ ์ง์ ์์คํ ์ ๋๋ค. ๋จผ์ ์ฐธ์กฐ์ ๋ค์ ํ ์ค๋ก ์ค๋ช ํ๊ณ , ์์น๋ฅผ [(x1,y1), (x2,y2), ...]๋ก ๋ฐํํ์ธ์."
        }
        system_prompt = system_prompts.get(task_type, system_prompts["general"])
        # Optional chain-of-thought instruction.
        if enable_thinking:
            system_prompt += "\n\n์ถ๋ก ๊ณผ์ ์ <thinking></thinking> ํ๊ทธ ์์ ์์ฑ ํ ์ต์ข ๋ต๋ณ์ ์ ์ํ์ธ์. ์ฅ๋ฉด ์ดํด๋ ์ถ๋ก ๊ณผ์ ๊ณผ ๋ณ๋๋ก ๋ฐ๋์ ํฌํจํ์ธ์."
        combined_system = _add_web_context(system_prompt, prompt, use_web_search)
        messages = format_chat_prompt(combined_system, prompt, image_uri)
        response = llm.create_chat_completion(
            messages=messages,
            max_tokens=max_new_tokens,
            temperature=0.7,
            top_p=0.9,
            stream=False
        )
        result = response['choices'][0]['message']['content'].strip()
        return result
    except Exception as e:
        logger.error(f"์ด๋ฏธ์ง ๋ถ์ ์ค๋ฅ: {e}")
        import traceback
        # NOTE(review): the full traceback is returned to the UI user.
        return f"โ ๋ถ์ ์ค๋ฅ: {str(e)}\n{traceback.format_exc()}"
    finally:
        clear_cuda_cache()
##############################################################################
# Document analysis (streaming)
##############################################################################
def analyze_documents_streaming(
    files: List[str],
    prompt: str,
    use_web_search: bool = False,
    max_new_tokens: int = 2048
) -> Iterator[str]:
    """Stream an LLM analysis of uploaded CSV, TXT and PDF files.

    Yields the *cumulative* response text after every streamed content chunk,
    so a consumer can display the latest value as-is. Files with any other
    extension are skipped. Errors are yielded as a final message rather than
    raised, and the CUDA cache is cleared when the generator finishes.
    """
    if not model_loaded and not load_model():
        yield "โ ๋ชจ๋ธ ๋ก๋ฉ ์คํจ"
        return

    try:
        base_system = "๋น์ ์ ๋ฌธ์๋ฅผ ๋ถ์ํ๊ณ ์์ฝํ๋ ์ ๋ฌธ AI์ ๋๋ค."
        search_terms = extract_keywords(prompt, top_k=5) if use_web_search else ""
        if search_terms:
            system_content = f"{do_web_search(search_terms)}\n\n{base_system}"
        else:
            system_content = base_system

        # Extension -> reader, checked in this order; unknown types are skipped.
        readers = (
            (".csv", analyze_csv_file),
            (".txt", analyze_txt_file),
            (".pdf", pdf_to_markdown),
        )
        doc_sections = []
        for file_path in files:
            lowered = file_path.lower()
            reader = next((fn for ext, fn in readers if lowered.endswith(ext)), None)
            if reader is not None:
                doc_sections.append(reader(file_path))

        user_message = "\n\n".join(doc_sections) + f"\n\n{prompt}"
        stream = llm.create_chat_completion(
            messages=[
                {"role": "system", "content": system_content},
                {"role": "user", "content": user_message},
            ],
            max_tokens=max_new_tokens,
            temperature=0.8,
            top_p=0.9,
            stream=True,
        )

        so_far = ""
        for chunk in stream:
            has_choice = "choices" in chunk and len(chunk["choices"]) > 0
            if not has_choice:
                continue
            delta = chunk["choices"][0].get("delta", {})
            if "content" in delta:
                so_far += delta["content"]
                yield so_far
    except Exception as err:
        logger.error(f"๋ฌธ์ ๋ถ์ ์ค๋ฅ: {err}")
        yield f"โ ์ค๋ฅ ๋ฐ์: {str(err)}"
    finally:
        clear_cuda_cache()
##############################################################################
# Gradio UI (robot-visualisation focused)
##############################################################################
# Stylesheet for the Blocks app below. Classes are referenced either via
# elem_classes= (task-button, webcam-container) or inside raw HTML strings
# (robot-header, model-info, status-box, auto-capture-status).
# NOTE(review): .info-box is not referenced anywhere in the visible code.
css = """
.robot-header {
    text-align: center;
    background: linear-gradient(135deg, #1e3c72 0%, #2a5298 50%, #667eea 100%);
    color: white;
    padding: 20px;
    border-radius: 10px;
    margin-bottom: 20px;
    box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
}
.status-box {
    text-align: center;
    padding: 10px;
    border-radius: 5px;
    margin: 10px 0;
    font-weight: bold;
}
.info-box {
    background: #f0f0f0;
    padding: 15px;
    border-radius: 8px;
    margin: 10px 0;
    border-left: 4px solid #2a5298;
}
.task-button {
    min-height: 60px;
    font-size: 1.1em;
}
.webcam-container {
    border: 3px solid #2a5298;
    border-radius: 10px;
    padding: 10px;
    background: #f8f9fa;
}
.auto-capture-status {
    text-align: center;
    padding: 5px;
    border-radius: 5px;
    margin: 5px 0;
    font-weight: bold;
    background: #e8f5e9;
    color: #2e7d32;
}
.model-info {
    background: #fff3cd;
    color: #856404;
    padding: 10px;
    border-radius: 5px;
    margin: 10px 0;
    text-align: center;
}
"""
# Top-level Gradio app. Component creation order inside this `with` block
# defines the page layout; the handlers and event wiring follow below.
with gr.Blocks(title="๐ค ๋ก๋ด ์๊ฐ ์์คํ (Gemma3-4B GGUF)", css=css) as demo:
    gr.HTML("""
    <div class="robot-header">
        <h1>๐ค ๋ก๋ด ์๊ฐ ์์คํ </h1>
        <h3>๐ฎ Gemma3-4B GGUF Q4_K_M + ๐ท ์ค์๊ฐ ์น์บ + ๐ ์น ๊ฒ์</h3>
        <p>โก ์์ํ ๋ชจ๋ธ๋ก ๋ ๋น ๋ฅด๊ณ ํจ์จ์ ์ธ ๋ก๋ด ์์ ๋ถ์!</p>
    </div>
    """)
    gr.HTML("""
    <div class="model-info">
        <strong>๋ชจ๋ธ:</strong> Gemma3-4B Q4_K_M (2.5GB) | <strong>๋ฉ๋ชจ๋ฆฌ ์ฌ์ฉ:</strong> ~3-4GB VRAM
    </div>
    """)
    with gr.Row():
        # Left column: live webcam, capture controls and task buttons.
        with gr.Column(scale=1):
            gr.Markdown("### ๐ท ์ค์๊ฐ ์น์บ ")
            with gr.Group(elem_classes="webcam-container"):
                webcam = gr.Image(
                    sources=["webcam"],
                    streaming=True,
                    type="numpy",
                    label="์ค์๊ฐ ์คํธ๋ฆฌ๋ฐ",
                    height=350
                )
                # Auto-capture status banner.
                auto_capture_status = gr.HTML(
                    '<div class="auto-capture-status">๐ ์๋ ์บก์ฒ: ๋๊ธฐ ์ค</div>'
                )
                # Preview of the captured frame; hidden until something shows it.
                captured_image = gr.Image(
                    label="์บก์ฒ๋ ์ด๋ฏธ์ง",
                    height=200,
                    visible=False
                )
            # Robot task buttons.
            gr.Markdown("### ๐ฏ ๋ก๋ด ์์ ์ ํ")
            with gr.Row():
                capture_btn = gr.Button("๐ธ ์๋ ์บก์ฒ", variant="primary", elem_classes="task-button")
                clear_capture_btn = gr.Button("๐๏ธ ์ด๊ธฐํ", elem_classes="task-button")
            with gr.Row():
                auto_capture_toggle = gr.Checkbox(
                    label="๐ ์๋ ์บก์ฒ ํ์ฑํ (10์ด๋ง๋ค)",
                    value=False,
                    info="ํ์ฑํ ์ 10์ด๋ง๋ค ์๋์ผ๋ก ์บก์ฒ ๋ฐ ๋ถ์"
                )
            with gr.Row():
                planning_btn = gr.Button("๐ ์์ ๊ณํ", elem_classes="task-button")
                grounding_btn = gr.Button("๐ ๊ฐ์ฒด ์์น", elem_classes="task-button")
            with gr.Row():
                affordance_btn = gr.Button("๐ค ํ์ง์ ๋ถ์", elem_classes="task-button")
                trajectory_btn = gr.Button("๐ค๏ธ ๊ฒฝ๋ก ๊ณํ", elem_classes="task-button")
        # Right column: analysis settings, result text and status banner.
        with gr.Column(scale=2):
            gr.Markdown("### โ๏ธ ๋ถ์ ์ค์ ")
            with gr.Row():
                with gr.Column():
                    task_prompt = gr.Textbox(
                        label="์์ ์ค๋ช / ์ง๋ฌธ",
                        placeholder="์: ํ ์ด๋ธ ์์ ์ปต์ ์ก์์ ์ฑํฌ๋์ ๋๊ธฐ",
                        value="ํ์ฌ ์ฅ๋ฉด์ ๋ถ์ํ๊ณ ๋ก๋ด์ด ์ํํ ์ ์๋ ์์ ์ ์ ์ํ์ธ์.",
                        lines=2
                    )
                    with gr.Row():
                        use_web_search = gr.Checkbox(
                            label="๐ ์น ๊ฒ์ ์ฌ์ฉ",
                            value=False,
                            info="๊ด๋ จ ์ ๋ณด๋ฅผ ์น์์ ๊ฒ์ํฉ๋๋ค"
                        )
                        enable_thinking = gr.Checkbox(
                            label="๐ค ์ถ๋ก ๊ณผ์ ํ์",
                            value=False,
                            info="Chain-of-Thought ์ถ๋ก ๊ณผ์ ์ ๋ณด์ฌ์ค๋๋ค"
                        )
                    # Generation budget forwarded as max_new_tokens.
                    max_tokens = gr.Slider(
                        label="์ต๋ ํ ํฐ ์",
                        minimum=100,
                        maximum=2048,
                        value=300,
                        step=50
                    )
            gr.Markdown("### ๐ ๋ถ์ ๊ฒฐ๊ณผ")
            # NOTE(review): confirm `show_copy_button` is supported by the
            # installed Gradio version.
            result_output = gr.Textbox(
                label="AI ๋ถ์ ๊ฒฐ๊ณผ",
                lines=20,
                max_lines=40,
                show_copy_button=True,
                elem_id="result"
            )
            status_display = gr.HTML(
                '<div class="status-box" style="background:#d4edda; color:#155724;">๐ฎ ์์คํ ์ค๋น ์๋ฃ</div>'
            )
    # Document-analysis tab; created with visible=False, so it is not shown.
    with gr.Tab("๐ ๋ฌธ์ ๋ถ์", visible=False):
        with gr.Row():
            with gr.Column():
                doc_files = gr.File(
                    label="๋ฌธ์ ์ ๋ก๋",
                    file_count="multiple",
                    file_types=[".pdf", ".csv", ".txt"],
                    type="filepath"
                )
                doc_prompt = gr.Textbox(
                    label="๋ถ์ ์์ฒญ",
                    placeholder="์: ์ด ๋ฌธ์๋ค์ ํต์ฌ ๋ด์ฉ์ ์์ฝํ๊ณ ๋น๊ต ๋ถ์ํ์ธ์.",
                    lines=3
                )
                doc_web_search = gr.Checkbox(
                    label="๐ ์น ๊ฒ์ ์ฌ์ฉ",
                    value=False
                )
                analyze_docs_btn = gr.Button("๐ ๋ฌธ์ ๋ถ์", variant="primary")
            with gr.Column():
                doc_result = gr.Textbox(
                    label="๋ถ์ ๊ฒฐ๊ณผ",
                    lines=25,
                    max_lines=50
                )
    # Event handlers and per-session state.
    # Latest webcam frame (written by the webcam stream below).
    webcam_state = gr.State(None)
    # Passed to auto_capture_and_analyze, which does not read it.
    auto_capture_state = gr.State({"enabled": False, "timer": None})
def capture_webcam(frame): | |
"""์น์บ ํ๋ ์ ์บก์ฒ""" | |
if frame is None: | |
return None, None, '<div class="status-box" style="background:#f8d7da; color:#721c24;">โ ์น์บ ํ๋ ์ ์์</div>' | |
return frame, gr.update(value=frame, visible=True), '<div class="status-box" style="background:#d4edda; color:#155724;">โ ์ด๋ฏธ์ง ์บก์ฒ ์๋ฃ</div>' | |
def clear_capture(): | |
"""์บก์ฒ ์ด๊ธฐํ""" | |
return None, gr.update(visible=False), '<div class="status-box" style="background:#d4edda; color:#155724;">๐ฎ ์์คํ ์ค๋น ์๋ฃ</div>' | |
def analyze_with_task(image, prompt, task_type, use_search, thinking, tokens): | |
"""ํน์ ํ์คํฌ๋ก ์ด๋ฏธ์ง ๋ถ์""" | |
if image is None: | |
return "โ ๋จผ์ ์ด๋ฏธ์ง๋ฅผ ์บก์ฒํ์ธ์.", '<div class="status-box" style="background:#f8d7da; color:#721c24;">โ ์ด๋ฏธ์ง ์์</div>' | |
status = f'<div class="status-box" style="background:#cce5ff; color:#004085;">๐ {task_type} ๋ถ์ ์ค...</div>' | |
result = analyze_image_for_robot( | |
image=image, | |
prompt=prompt, | |
task_type=task_type, | |
use_web_search=use_search, | |
enable_thinking=thinking, | |
max_new_tokens=tokens | |
) | |
# ๊ฒฐ๊ณผ ํฌ๋งทํ | |
timestamp = time.strftime("%H:%M:%S") | |
task_names = { | |
"planning": "์์ ๊ณํ", | |
"grounding": "๊ฐ์ฒด ์์น", | |
"affordance": "ํ์ง์ ", | |
"trajectory": "๊ฒฝ๋ก ๊ณํ" | |
} | |
formatted_result = f"""๐ค {task_names.get(task_type, '๋ถ์')} ๊ฒฐ๊ณผ ({timestamp}) | |
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ | |
{result} | |
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ""" | |
complete_status = '<div class="status-box" style="background:#d4edda; color:#155724;">โ ๋ถ์ ์๋ฃ!</div>' | |
return formatted_result, complete_status | |
# ์๋ ์บก์ฒ ๋ฐ ๋ถ์ ํจ์ | |
def auto_capture_and_analyze(webcam_frame, task_prompt, use_search, thinking, tokens, auto_state): | |
"""์๋ ์บก์ฒ ๋ฐ ๋ถ์""" | |
if webcam_frame is None: | |
return ( | |
None, | |
"์๋ ์บก์ฒ ๋๊ธฐ ์ค...", | |
'<div class="status-box" style="background:#fff3cd; color:#856404;">โณ ์น์บ ๋๊ธฐ ์ค</div>', | |
'<div class="auto-capture-status">๐ ์๋ ์บก์ฒ: ์น์บ ๋๊ธฐ ์ค</div>' | |
) | |
# ์บก์ฒ ์ํ | |
timestamp = time.strftime("%H:%M:%S") | |
# ์ด๋ฏธ์ง ๋ถ์ (์์ ๊ณํ ๋ชจ๋๋ก) | |
result = analyze_image_for_robot( | |
image=webcam_frame, | |
prompt=task_prompt, | |
task_type="planning", | |
use_web_search=use_search, | |
enable_thinking=thinking, | |
max_new_tokens=tokens | |
) | |
formatted_result = f"""๐ ์๋ ๋ถ์ ์๋ฃ ({timestamp}) | |
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ | |
{result} | |
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ""" | |
return ( | |
webcam_frame, | |
formatted_result, | |
'<div class="status-box" style="background:#d4edda; color:#155724;">โ ์๋ ๋ถ์ ์๋ฃ</div>', | |
f'<div class="auto-capture-status">๐ ์๋ ์บก์ฒ: ๋ง์ง๋ง ๋ถ์ {timestamp}</div>' | |
) | |
    # Webcam streaming: copy every streamed frame into webcam_state.
    webcam.stream(
        fn=lambda x: x,
        inputs=[webcam],
        outputs=[webcam_state]
    )
    # Capture button: freeze the frame currently held in webcam_state.
    capture_btn.click(
        fn=capture_webcam,
        inputs=[webcam_state],
        outputs=[webcam_state, captured_image, status_display]
    )
    # Reset button. NOTE(review): while the webcam keeps streaming, the stream
    # above will refill webcam_state with new frames after this reset.
    clear_capture_btn.click(
        fn=clear_capture,
        outputs=[webcam_state, captured_image, status_display]
    )
    # Task buttons: each lambda pins the task_type passed to analyze_with_task
    # and analyses the captured (not the live) image.
    planning_btn.click(
        fn=lambda img, p, s, t, tk: analyze_with_task(img, p, "planning", s, t, tk),
        inputs=[captured_image, task_prompt, use_web_search, enable_thinking, max_tokens],
        outputs=[result_output, status_display]
    )
    grounding_btn.click(
        fn=lambda img, p, s, t, tk: analyze_with_task(img, p, "grounding", s, t, tk),
        inputs=[captured_image, task_prompt, use_web_search, enable_thinking, max_tokens],
        outputs=[result_output, status_display]
    )
    affordance_btn.click(
        fn=lambda img, p, s, t, tk: analyze_with_task(img, p, "affordance", s, t, tk),
        inputs=[captured_image, task_prompt, use_web_search, enable_thinking, max_tokens],
        outputs=[result_output, status_display]
    )
    trajectory_btn.click(
        fn=lambda img, p, s, t, tk: analyze_with_task(img, p, "trajectory", s, t, tk),
        inputs=[captured_image, task_prompt, use_web_search, enable_thinking, max_tokens],
        outputs=[result_output, status_display]
    )
# ๋ฌธ์ ๋ถ์ | |
def analyze_docs(files, prompt, use_search): | |
if not files: | |
return "โ ๋ฌธ์๋ฅผ ์ ๋ก๋ํ์ธ์." | |
output = "" | |
for chunk in analyze_documents_streaming(files, prompt, use_search): | |
output = chunk | |
return output | |
    analyze_docs_btn.click(
        fn=analyze_docs,
        inputs=[doc_files, doc_prompt, doc_web_search],
        outputs=[doc_result]
    )
    # Auto-capture timer: 10-second interval, created inactive; the checkbox
    # handler below switches it on and off.
    timer = gr.Timer(10.0, active=False)
# ์๋ ์บก์ฒ ํ ๊ธ ์ด๋ฒคํธ | |
def toggle_auto_capture(enabled): | |
if enabled: | |
return gr.Timer(10.0, active=True), '<div class="auto-capture-status">๐ ์๋ ์บก์ฒ: ํ์ฑํ๋จ (10์ด๋ง๋ค)</div>' | |
else: | |
return gr.Timer(active=False), '<div class="auto-capture-status">๐ ์๋ ์บก์ฒ: ๋นํ์ฑํ๋จ</div>' | |
    auto_capture_toggle.change(
        fn=toggle_auto_capture,
        inputs=[auto_capture_toggle],
        outputs=[timer, auto_capture_status]
    )
    # Timer tick: analyse the latest streamed frame (webcam_state) with the
    # current prompt and settings.
    timer.tick(
        fn=auto_capture_and_analyze,
        inputs=[webcam_state, task_prompt, use_web_search, enable_thinking, max_tokens, auto_capture_state],
        outputs=[captured_image, result_output, status_display, auto_capture_status]
    )
# ์ด๊ธฐ ๋ชจ๋ธ ๋ก๋ | |
def initial_load(): | |
# ์ฒซ ์คํ ์ GPU์์ ๋ชจ๋ธ ๋ก๋ | |
return "์์คํ ์ค๋น ์๋ฃ! ์ฒซ ๋ถ์ ์ ๋ชจ๋ธ์ด ์๋์ผ๋ก ๋ก๋๋ฉ๋๋ค. ๐" | |
    # outputs=None: the string returned by initial_load is not shown anywhere.
    demo.load(
        fn=initial_load,
        outputs=None
    )
if __name__ == "__main__":
    print("๐ ๋ก๋ด ์๊ฐ ์์คํ ์์ (Gemma3-4B GGUF Q4_K_M)...")
    # Listen on all interfaces on port 7860 and surface errors in the UI.
    # NOTE(review): `spaces` is imported at the top of the file, but no visible
    # function is decorated with @spaces.GPU; confirm that matches the hardware.
    launch_kwargs = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": False,
        "show_error": True,
        "debug": False,
    }
    demo.launch(**launch_kwargs)