#!/usr/bin/env python
import os
import re
import tempfile
import gc  # garbage collector, used to release memory after generation
from collections.abc import Iterator
from threading import Thread
import json
import requests
import cv2
import base64
import logging
import time
from urllib.parse import quote  # URL encoding (used as needed)
import gradio as gr
import spaces
import torch
from loguru import logger
from PIL import Image
from transformers import AutoProcessor, Gemma3ForConditionalGeneration, TextIteratorStreamer
# CSV/TXT/PDF analysis
import pandas as pd
import PyPDF2
# =============================================================================
# (New) Image generation API helpers
# =============================================================================
from gradio_client import Client
API_URL = "http://211.233.58.201:7896"
logging.basicConfig(
level=logging.DEBUG,
format='%(asctime)s - %(levelname)s - %(message)s'
)
def test_api_connection() -> str:
    """Test the connection to the image API server."""
    try:
        Client(API_URL)  # raises if the server is unreachable
        return "API connection successful: operating normally"
    except Exception as e:
        logging.error(f"API connection test failed: {e}")
        return f"API connection failed: {e}"
def generate_image(prompt: str, width: float, height: float, guidance: float, inference_steps: float, seed: float):
"""
์ด๋ฏธ์ง€ ์ƒ์„ฑ ํ•จ์ˆ˜.
์—ฌ๊ธฐ์„œ๋Š” ์„œ๋ฒ„๊ฐ€ ์ตœ์ข… ์ด๋ฏธ์ง€๋ฅผ Base64(๋˜๋Š” data:image/...) ํ˜•ํƒœ๋กœ ์ง์ ‘ ๋ฐ˜ํ™˜ํ•œ๋‹ค๊ณ  ๊ฐ€์ •ํ•ฉ๋‹ˆ๋‹ค.
/tmp/... ๊ฒฝ๋กœ๋‚˜ ์ถ”๊ฐ€ ๋‹ค์šด๋กœ๋“œ๋ฅผ ์‹œ๋„ํ•˜์ง€ ์•Š์Šต๋‹ˆ๋‹ค.
"""
if not prompt:
return None, "Error: Prompt is required"
try:
logging.info(f"Calling image generation API with prompt: {prompt}")
client = Client(API_URL)
result = client.predict(
prompt=prompt,
width=int(width),
height=int(height),
guidance=float(guidance),
inference_steps=int(inference_steps),
seed=int(seed),
do_img2img=False,
init_image=None,
image2image_strength=0.8,
resize_img=True,
api_name="/generate_image"
)
logging.info(
f"Image generation result: {type(result)}, "
f"length: {len(result) if isinstance(result, (list, tuple)) else 'unknown'}"
)
        # The result is assumed to be a tuple/list: [image_base64_or_data_url, seed_info]
        if isinstance(result, (list, tuple)) and len(result) > 0:
            image_data = result[0]  # first element is the image data (Base64, data:image/..., etc.)
            seed_info = result[1] if len(result) > 1 else "Unknown seed"
            return image_data, seed_info
        else:
            # Returned in some other form
            return result, "Unknown seed"
except Exception as e:
logging.error(f"Image generation failed: {str(e)}")
return None, f"Error: {str(e)}"
# Base64 padding fix helper (use if needed)
def fix_base64_padding(data):
    """Fix the padding of a Base64 string."""
if isinstance(data, bytes):
data = data.decode('utf-8')
if "base64," in data:
data = data.split("base64,", 1)[1]
missing_padding = len(data) % 4
if missing_padding:
data += '=' * (4 - missing_padding)
return data
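# Minimal usage sketch for fix_base64_padding (illustrative only): the sample
# payload below is Base64 for "hello" with its trailing '=' removed, not real
# image data.
def _demo_fix_base64_padding() -> bytes:
    padded = fix_base64_padding("data:image/webp;base64,aGVsbG8")
    return base64.b64decode(padded)  # -> b"hello"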
# =============================================================================
# Memory cleanup helper
# =============================================================================
def clear_cuda_cache():
"""CUDA ์บ์‹œ๋ฅผ ๋ช…์‹œ์ ์œผ๋กœ ๋น„์›๋‹ˆ๋‹ค."""
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()
# =============================================================================
# SerpHouse helpers
# =============================================================================
SERPHOUSE_API_KEY = os.getenv("SERPHOUSE_API_KEY", "")
def extract_keywords(text: str, top_k: int = 5) -> str:
"""๋‹จ์ˆœ ํ‚ค์›Œ๋“œ ์ถ”์ถœ: ํ•œ๊ธ€, ์˜์–ด, ์ˆซ์ž, ๊ณต๋ฐฑ๋งŒ ๋‚จ๊น€"""
text = re.sub(r"[^a-zA-Z0-9๊ฐ€-ํžฃ\s]", "", text)
tokens = text.split()
return " ".join(tokens[:top_k])
def do_web_search(query: str) -> str:
"""
SerpHouse LIVE API ํ˜ธ์ถœํ•˜์—ฌ ๊ฒ€์ƒ‰ ๊ฒฐ๊ณผ ๋งˆํฌ๋‹ค์šด ๋ฐ˜ํ™˜
(ํ•„์š”ํ•˜๋‹ค๋ฉด ์ˆ˜์ • or ์‚ญ์ œ ๊ฐ€๋Šฅ)
"""
try:
url = "https://api.serphouse.com/serp/live"
params = {
"q": query,
"domain": "google.com",
"serp_type": "web",
"device": "desktop",
"lang": "en",
"num": "20"
}
headers = {"Authorization": f"Bearer {SERPHOUSE_API_KEY}"}
logger.info(f"SerpHouse API ํ˜ธ์ถœ ์ค‘... ๊ฒ€์ƒ‰์–ด: {query}")
response = requests.get(url, headers=headers, params=params, timeout=60)
response.raise_for_status()
data = response.json()
results = data.get("results", {})
organic = None
if isinstance(results, dict) and "organic" in results:
organic = results["organic"]
elif isinstance(results, dict) and "results" in results:
if isinstance(results["results"], dict) and "organic" in results["results"]:
organic = results["results"]["organic"]
elif "organic" in data:
organic = data["organic"]
if not organic:
            logger.warning("No organic results found in the response.")
return "No web search results found or unexpected API response structure."
max_results = min(20, len(organic))
limited_organic = organic[:max_results]
summary_lines = []
for idx, item in enumerate(limited_organic, start=1):
title = item.get("title", "No title")
link = item.get("link", "#")
snippet = item.get("snippet", "No description")
displayed_link = item.get("displayed_link", link)
summary_lines.append(
f"### Result {idx}: {title}\n\n"
f"{snippet}\n\n"
f"**์ถœ์ฒ˜**: [{displayed_link}]({link})\n\n"
f"---\n"
)
instructions = """
# ์›น ๊ฒ€์ƒ‰ ๊ฒฐ๊ณผ
์•„๋ž˜๋Š” ๊ฒ€์ƒ‰ ๊ฒฐ๊ณผ์ž…๋‹ˆ๋‹ค. ์งˆ๋ฌธ์— ๋‹ต๋ณ€ํ•  ๋•Œ ์ด ์ •๋ณด๋ฅผ ํ™œ์šฉํ•˜์„ธ์š”:
1. ์—ฌ๋Ÿฌ ์ถœ์ฒ˜ ๋‚ด์šฉ์„ ์ข…ํ•ฉํ•˜์—ฌ ๋‹ต๋ณ€.
2. ์ถœ์ฒ˜ ์ธ์šฉ ์‹œ "[์ถœ์ฒ˜ ์ œ๋ชฉ](๋งํฌ)" ๋งˆํฌ๋‹ค์šด ํ˜•์‹ ์‚ฌ์šฉ.
3. ๋‹ต๋ณ€ ๋งˆ์ง€๋ง‰์— '์ฐธ๊ณ  ์ž๋ฃŒ:' ์„น์…˜์— ์‚ฌ์šฉํ•œ ์ฃผ์š” ์ถœ์ฒ˜๋ฅผ ๋‚˜์—ด.
"""
return instructions + "\n".join(summary_lines)
except Exception as e:
logger.error(f"Web search failed: {e}")
return f"Web search failed: {str(e)}"
# =============================================================================
# ๋ชจ๋ธ ๋ฐ ํ”„๋กœ์„ธ์„œ ๋กœ๋”ฉ
# =============================================================================
MAX_CONTENT_CHARS = 2000
MAX_INPUT_LENGTH = 2096
model_id = os.getenv("MODEL_ID", "VIDraft/Gemma-3-R1984-4B")
processor = AutoProcessor.from_pretrained(model_id, padding_side="left")
model = Gemma3ForConditionalGeneration.from_pretrained(
model_id,
device_map="auto",
torch_dtype=torch.bfloat16,
attn_implementation="eager"
)
MAX_NUM_IMAGES = int(os.getenv("MAX_NUM_IMAGES", "5"))
# =============================================================================
# CSV, TXT, PDF analysis functions
# =============================================================================
def analyze_csv_file(path: str) -> str:
try:
df = pd.read_csv(path)
if df.shape[0] > 50 or df.shape[1] > 10:
df = df.iloc[:50, :10]
df_str = df.to_string()
if len(df_str) > MAX_CONTENT_CHARS:
df_str = df_str[:MAX_CONTENT_CHARS] + "\n...(truncated)..."
return f"**[CSV File: {os.path.basename(path)}]**\n\n{df_str}"
except Exception as e:
return f"Failed to read CSV ({os.path.basename(path)}): {str(e)}"
def analyze_txt_file(path: str) -> str:
try:
with open(path, "r", encoding="utf-8") as f:
text = f.read()
if len(text) > MAX_CONTENT_CHARS:
text = text[:MAX_CONTENT_CHARS] + "\n...(truncated)..."
return f"**[TXT File: {os.path.basename(path)}]**\n\n{text}"
except Exception as e:
return f"Failed to read TXT ({os.path.basename(path)}): {str(e)}"
def pdf_to_markdown(pdf_path: str) -> str:
text_chunks = []
try:
with open(pdf_path, "rb") as f:
reader = PyPDF2.PdfReader(f)
max_pages = min(5, len(reader.pages))
for page_num in range(max_pages):
page_text = reader.pages[page_num].extract_text() or ""
page_text = page_text.strip()
if page_text:
if len(page_text) > MAX_CONTENT_CHARS // max_pages:
page_text = page_text[:MAX_CONTENT_CHARS // max_pages] + "...(truncated)"
text_chunks.append(f"## Page {page_num+1}\n\n{page_text}\n")
if len(reader.pages) > max_pages:
text_chunks.append(f"\n...(Showing {max_pages} of {len(reader.pages)} pages)...")
except Exception as e:
return f"Failed to read PDF ({os.path.basename(pdf_path)}): {str(e)}"
full_text = "\n".join(text_chunks)
if len(full_text) > MAX_CONTENT_CHARS:
full_text = full_text[:MAX_CONTENT_CHARS] + "\n...(truncated)..."
return f"**[PDF File: {os.path.basename(pdf_path)}]**\n\n{full_text}"
# =============================================================================
# Image/video upload constraint checks
# =============================================================================
def count_files_in_new_message(paths: list[str]) -> tuple[int, int]:
image_count = 0
video_count = 0
for path in paths:
if path.endswith(".mp4"):
video_count += 1
elif re.search(r"\.(png|jpg|jpeg|gif|webp)$", path, re.IGNORECASE):
image_count += 1
return image_count, video_count
def count_files_in_history(history: list[dict]) -> tuple[int, int]:
image_count = 0
video_count = 0
for item in history:
if item["role"] != "user" or isinstance(item["content"], str):
continue
if isinstance(item["content"], list) and len(item["content"]) > 0:
file_path = item["content"][0]
if isinstance(file_path, str):
if file_path.endswith(".mp4"):
video_count += 1
elif re.search(r"\.(png|jpg|jpeg|gif|webp)$", file_path, re.IGNORECASE):
image_count += 1
return image_count, video_count
def validate_media_constraints(message: dict, history: list[dict]) -> bool:
"""์ด๋ฏธ์ง€/๋น„๋””์˜ค ์—…๋กœ๋“œ ์ œํ•œ ๊ฒ€์‚ฌ."""
media_files = [f for f in message["files"]
if re.search(r"\.(png|jpg|jpeg|gif|webp)$", f, re.IGNORECASE) or f.endswith(".mp4")]
new_image_count, new_video_count = count_files_in_new_message(media_files)
history_image_count, history_video_count = count_files_in_history(history)
image_count = history_image_count + new_image_count
video_count = history_video_count + new_video_count
if video_count > 1:
gr.Warning("Only one video is supported.")
return False
if video_count == 1:
if image_count > 0:
gr.Warning("Mixing images and videos is not allowed.")
return False
if "<image>" in message["text"]:
gr.Warning("Using <image> tags with video files is not supported.")
return False
if video_count == 0 and image_count > MAX_NUM_IMAGES:
gr.Warning(f"You can upload up to {MAX_NUM_IMAGES} images.")
return False
if "<image>" in message["text"]:
image_files = [f for f in message["files"]
if re.search(r"\.(png|jpg|jpeg|gif|webp)$", f, re.IGNORECASE)]
image_tag_count = message["text"].count("<image>")
if image_tag_count != len(image_files):
gr.Warning("The number of <image> tags in the text does not match the number of image files.")
return False
return True
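# Shape of the `message` dict validated above: Gradio's MultimodalTextbox
# supplies {"text": str, "files": [str, ...]}. The values here are hypothetical.
def _example_message() -> dict:
    return {"text": "Describe this picture. <image>", "files": ["photo.jpg"]}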
# =============================================================================
# Video processing helpers
# =============================================================================
def downsample_video(video_path: str) -> list[tuple[Image.Image, float]]:
vidcap = cv2.VideoCapture(video_path)
    fps = vidcap.get(cv2.CAP_PROP_FPS) or 30.0  # guard against 0 fps from broken metadata
total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
    frame_interval = max(1, int(fps), int(total_frames / 10))  # step must be >= 1 for range()
frames = []
for i in range(0, total_frames, frame_interval):
vidcap.set(cv2.CAP_PROP_POS_FRAMES, i)
success, image = vidcap.read()
if success:
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
image = cv2.resize(image, (0, 0), fx=0.5, fy=0.5)
pil_image = Image.fromarray(image)
timestamp = round(i / fps, 2)
frames.append((pil_image, timestamp))
if len(frames) >= 5:
break
vidcap.release()
return frames
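# Sampling sketch: with fps-sized steps (or total/10, whichever is larger) and
# a hard cap of 5 frames, a 30 fps, 600-frame clip is stepped every
# max(30, 60) = 60 frames, i.e. frames 0, 60, 120, 180 and 240.
def _example_frame_indices(fps: int = 30, total_frames: int = 600) -> list[int]:
    frame_interval = max(1, int(fps), int(total_frames / 10))
    return list(range(0, total_frames, frame_interval))[:5]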
def process_video(video_path: str) -> tuple[list[dict], list[str]]:
content = []
temp_files = []
frames = downsample_video(video_path)
for pil_image, timestamp in frames:
with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as temp_file:
pil_image.save(temp_file.name)
temp_files.append(temp_file.name)
content.append({"type": "text", "text": f"Frame {timestamp}:"})
content.append({"type": "image", "url": temp_file.name})
return content, temp_files
# =============================================================================
# Interleaved <image> handling (mixing <image> tags with image uploads)
# =============================================================================
def process_interleaved_images(message: dict) -> list[dict]:
parts = re.split(r"(<image>)", message["text"])
content = []
image_files = [f for f in message["files"]
if re.search(r"\.(png|jpg|jpeg|gif|webp)$", f, re.IGNORECASE)]
image_index = 0
for part in parts:
if part == "<image>" and image_index < len(image_files):
content.append({"type": "image", "url": image_files[image_index]})
image_index += 1
        elif part.strip():
            content.append({"type": "text", "text": part.strip()})
        # Empty or whitespace-only fragments (and <image> tags with no image
        # left to consume) are skipped rather than appended as empty text items.
return content
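# Interleaving sketch (hypothetical message): "<image>" tags map positionally
# onto the uploaded image files.
def _example_interleaved() -> list[dict]:
    return process_interleaved_images(
        {"text": "Before <image> after", "files": ["a.png"]}
    )
    # -> [{"type": "text", "text": "Before"},
    #     {"type": "image", "url": "a.png"},
    #     {"type": "text", "text": "after"}]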
# =============================================================================
# ํŒŒ์ผ ์ฒ˜๋ฆฌ -> content ์ƒ์„ฑ
# =============================================================================
def is_image_file(file_path: str) -> bool:
return bool(re.search(r"\.(png|jpg|jpeg|gif|webp)$", file_path, re.IGNORECASE))
def is_video_file(file_path: str) -> bool:
return file_path.endswith(".mp4")
def is_document_file(file_path: str) -> bool:
return file_path.lower().endswith(".pdf") or file_path.lower().endswith(".csv") or file_path.lower().endswith(".txt")
def process_new_user_message(message: dict) -> tuple[list[dict], list[str]]:
"""์‚ฌ์šฉ์ž๊ฐ€ ์ƒˆ๋กœ ์ž…๋ ฅํ•œ ๋ฉ”์‹œ์ง€ + ์—…๋กœ๋“œ ํŒŒ์ผ๋“ค์„ ํ•˜๋‚˜์˜ content(list)๋กœ ๋ณ€ํ™˜."""
temp_files = []
if not message["files"]:
return [{"type": "text", "text": message["text"]}], temp_files
video_files = [f for f in message["files"] if is_video_file(f)]
image_files = [f for f in message["files"] if is_image_file(f)]
csv_files = [f for f in message["files"] if f.lower().endswith(".csv")]
txt_files = [f for f in message["files"] if f.lower().endswith(".txt")]
pdf_files = [f for f in message["files"] if f.lower().endswith(".pdf")]
content_list = [{"type": "text", "text": message["text"]}]
    # Documents
for csv_path in csv_files:
content_list.append({"type": "text", "text": analyze_csv_file(csv_path)})
for txt_path in txt_files:
content_list.append({"type": "text", "text": analyze_txt_file(txt_path)})
for pdf_path in pdf_files:
content_list.append({"type": "text", "text": pdf_to_markdown(pdf_path)})
    # Video
if video_files:
video_content, video_temp_files = process_video(video_files[0])
content_list += video_content
temp_files.extend(video_temp_files)
return content_list, temp_files
    # Images
if "<image>" in message["text"] and image_files:
interleaved_content = process_interleaved_images({"text": message["text"], "files": image_files})
if content_list and content_list[0]["type"] == "text":
content_list = content_list[1:]
return interleaved_content + content_list, temp_files
else:
for img_path in image_files:
content_list.append({"type": "image", "url": img_path})
return content_list, temp_files
# =============================================================================
# history -> LLM message conversion
# =============================================================================
def process_history(history: list[dict]) -> list[dict]:
"""
๊ธฐ์กด ๋Œ€ํ™” ๊ธฐ๋ก์„ LLM์— ๋งž๊ฒŒ ๋ณ€ํ™˜.
- user -> {"role":"user","content":[{type,text},...]}
- assistant -> {"role":"assistant","content":[{type:"text",text},...]}
"""
messages = []
current_user_content = []
for item in history:
if item["role"] == "assistant":
# ์‚ฌ์šฉ์ž content ๋ˆ„์ ๋ถ„์ด ์žˆ์œผ๋ฉด ํ•œ๋ฒˆ์— user๋กœ ์ถ”๊ฐ€
if current_user_content:
messages.append({"role": "user", "content": current_user_content})
current_user_content = []
            # Then append the assistant message directly
messages.append({"role": "assistant", "content": [{"type": "text", "text": item["content"]}]})
else:
content = item["content"]
if isinstance(content, str):
current_user_content.append({"type": "text", "text": content})
elif isinstance(content, list) and len(content) > 0:
file_path = content[0]
if is_image_file(file_path):
current_user_content.append({"type": "image", "url": file_path})
else:
current_user_content.append({"type": "text", "text": f"[File: {os.path.basename(file_path)}]"})
if current_user_content:
messages.append({"role": "user", "content": current_user_content})
return messages
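# Conversion sketch (hypothetical history): consecutive user turns accumulate
# into one user message; each assistant turn flushes the buffer.
def _example_process_history() -> list[dict]:
    history = [
        {"role": "user", "content": "Hi"},
        {"role": "user", "content": ["photo.png"]},
        {"role": "assistant", "content": "Nice photo!"},
    ]
    return process_history(history)
    # -> [{"role": "user", "content": [{"type": "text", "text": "Hi"},
    #                                  {"type": "image", "url": "photo.png"}]},
    #     {"role": "assistant",
    #      "content": [{"type": "text", "text": "Nice photo!"}]}]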
# =============================================================================
# ๋ชจ๋ธ ์ƒ์„ฑ ํ•จ์ˆ˜ (OOM ์บ์น˜)
# =============================================================================
def _model_gen_with_oom_catch(**kwargs):
try:
model.generate(**kwargs)
except torch.cuda.OutOfMemoryError:
        raise RuntimeError("[OutOfMemoryError] Not enough GPU memory.")
finally:
clear_cuda_cache()
# =============================================================================
# ๋ฉ”์ธ ์ถ”๋ก  ํ•จ์ˆ˜
# =============================================================================
@spaces.GPU(duration=120)
def run(
message: dict,
history: list[dict],
system_prompt: str = "",
max_new_tokens: int = 512,
use_web_search: bool = False,
web_search_query: str = "",
    age_group: str = "20s",
mbti_personality: str = "INTP",
sexual_openness: int = 2,
image_gen: bool = False
) -> Iterator[str]:
"""
LLM ์ถ”๋ก  ํ•จ์ˆ˜.
- ์ด๋ฏธ์ง€ ์ƒ์„ฑ ์‹œ, ์„œ๋ฒ„๊ฐ€ Base64(๋˜๋Š” data:image/... ํ˜•ํƒœ)๋ฅผ ์ง์ ‘ ๋ฐ˜ํ™˜ํ•œ๋‹ค๊ณ  ๊ฐ€์ •.
- /tmp/... ํŒŒ์ผ์— ๋Œ€ํ•œ ์žฌ๋‹ค์šด๋กœ๋“œ๋ฅผ ์‹œ๋„ํ•˜์ง€ ์•Š์Œ (403 Forbidden ๋ฌธ์ œ ํšŒํ”ผ).
"""
if not validate_media_constraints(message, history):
yield ""
return
temp_files = []
try:
        # 1) System prompt + persona info
persona = (
f"{system_prompt.strip()}\n\n"
f"Gender: Female\n"
f"Age Group: {age_group}\n"
f"MBTI Persona: {mbti_personality}\n"
f"Sexual Openness (1~5): {sexual_openness}\n"
)
combined_system_msg = f"[System Prompt]\n{persona.strip()}\n\n"
        # 2) Web search (optional)
if use_web_search:
user_text = message["text"]
ws_query = extract_keywords(user_text)
if ws_query.strip():
logger.info(f"[Auto WebSearch Keyword] {ws_query!r}")
ws_result = do_web_search(ws_query)
combined_system_msg += f"[Search top-20 Full Items]\n{ws_result}\n\n"
combined_system_msg += (
"[์ฐธ๊ณ : ์œ„ ๊ฒ€์ƒ‰๊ฒฐ๊ณผ link๋ฅผ ์ถœ์ฒ˜๋กœ ์ธ์šฉํ•˜์—ฌ ๋‹ต๋ณ€]\n"
"[์ค‘์š” ์ง€์‹œ์‚ฌํ•ญ]\n"
"1. ๊ฒ€์ƒ‰ ๊ฒฐ๊ณผ์—์„œ ์ฐพ์€ ์ •๋ณด์˜ ์ถœ์ฒ˜๋ฅผ ๋ฐ˜๋“œ์‹œ ์ธ์šฉ.\n"
"2. '[์ถœ์ฒ˜ ์ œ๋ชฉ](๋งํฌ)' ํ˜•์‹์œผ๋กœ ๋งํฌ.\n"
"3. ๋‹ต๋ณ€ ๋งˆ์ง€๋ง‰์— '์ฐธ๊ณ  ์ž๋ฃŒ:' ์„น์…˜.\n"
)
else:
combined_system_msg += "[No valid keywords found, skipping WebSearch]\n\n"
        # 3) Existing history + new user message
messages = []
if combined_system_msg.strip():
messages.append({"role": "system", "content": [{"type": "text", "text": combined_system_msg.strip()}]})
messages.extend(process_history(history))
user_content, user_temp_files = process_new_user_message(message)
temp_files.extend(user_temp_files)
for item in user_content:
if item["type"] == "text" and len(item["text"]) > MAX_CONTENT_CHARS:
item["text"] = item["text"][:MAX_CONTENT_CHARS] + "\n...(truncated)..."
messages.append({"role": "user", "content": user_content})
        # 4) Tokenization
inputs = processor.apply_chat_template(
messages,
add_generation_prompt=True,
tokenize=True,
return_dict=True,
return_tensors="pt",
).to(device=model.device, dtype=torch.bfloat16)
        # Truncate to the last MAX_INPUT_LENGTH tokens. Use item assignment:
        # setting attributes on the returned BatchFeature would not update the
        # underlying dict that is later unpacked into generate() kwargs.
        if inputs.input_ids.shape[1] > MAX_INPUT_LENGTH:
            inputs["input_ids"] = inputs.input_ids[:, -MAX_INPUT_LENGTH:]
            if "attention_mask" in inputs:
                inputs["attention_mask"] = inputs.attention_mask[:, -MAX_INPUT_LENGTH:]
streamer = TextIteratorStreamer(processor, timeout=30.0, skip_prompt=True, skip_special_tokens=True)
gen_kwargs = dict(inputs, streamer=streamer, max_new_tokens=max_new_tokens)
t = Thread(target=_model_gen_with_oom_catch, kwargs=gen_kwargs)
t.start()
        # Streamed output
output_so_far = ""
for new_text in streamer:
output_so_far += new_text
yield output_so_far
        # 5) Image generation (Base64)
if image_gen:
last_user_text = message["text"].strip()
if not last_user_text:
                yield output_so_far + "\n\n(Image generation failed: empty user prompt)"
else:
try:
width, height = 512, 512
guidance, steps, seed = 7.5, 30, 42
logger.info(f"Generating image with prompt: {last_user_text}")
                    # Call the API to generate a (Base64) image
image_result, seed_info = generate_image(
prompt=last_user_text,
width=width,
height=height,
guidance=guidance,
inference_steps=steps,
seed=seed
)
logger.info(f"Received image data type: {type(image_result)}")
                    # Handle Base64 or data:image/... payloads
                    if image_result:
                        if isinstance(image_result, str):
                            # Already a data:image/ URL: use it as-is
                            if image_result.startswith("data:image/"):
                                final_md = f"\n\n**[Generated Image]**\n\n![Generated Image]({image_result})"
                                yield output_so_far + final_md
                            else:
                                # Treat as raw Base64 (plain URLs or '/tmp/...' paths cannot be handled)
                                if len(image_result) > 100 and "/" not in image_result:
                                    # Base64 payload
                                    image_data = "data:image/webp;base64," + image_result
                                    final_md = f"\n\n**[Generated Image]**\n\n![Generated Image]({image_data})"
                                    yield output_so_far + final_md
                                else:
                                    # Anything else (e.g. http://..., /tmp/...) causes 403 errors, so skip display
                                    yield output_so_far + "\n\n(Image generation result is not in Base64 format)"
                        else:
                            yield output_so_far + "\n\n(Image generation result is not a string)"
                    else:
                        yield output_so_far + f"\n\n(Image generation failed: {seed_info})"
except Exception as e:
logger.error(f"Image generation error: {e}")
                    yield output_so_far + f"\n\n(Error during image generation: {e})"
except Exception as e:
logger.error(f"Error in run: {str(e)}")
yield f"์ฃ„์†กํ•ฉ๋‹ˆ๋‹ค. ์˜ค๋ฅ˜๊ฐ€ ๋ฐœ์ƒํ–ˆ์Šต๋‹ˆ๋‹ค: {str(e)}"
finally:
for tmp in temp_files:
try:
if os.path.exists(tmp):
os.unlink(tmp)
logger.info(f"Deleted temp file: {tmp}")
except Exception as ee:
logger.warning(f"Failed to delete temp file {tmp}: {ee}")
try:
del inputs, streamer
except Exception:
pass
clear_cuda_cache()
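# Consumption sketch: run() is a generator that yields the cumulative response
# text, so the final answer is the last chunk. Defined only for illustration;
# calling it would actually invoke the model.
def _example_consume_run() -> str:
    final = ""
    for chunk in run({"text": "Hello", "files": []}, history=[]):
        final = chunk
    return final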
# =============================================================================
# Examples
# =============================================================================
examples = [
[
{
"text": "Compare the contents of the two PDF files.",
"files": [
"assets/additional-examples/before.pdf",
"assets/additional-examples/after.pdf",
],
}
],
[
{
"text": "Summarize and analyze the contents of the CSV file.",
"files": ["assets/additional-examples/sample-csv.csv"],
}
],
    # ... add more examples here if needed ...
]
# =============================================================================
# Gradio UI (Blocks) layout
# =============================================================================
css = """
.gradio-container {
background: rgba(255, 255, 255, 0.7);
padding: 30px 40px;
margin: 20px auto;
width: 100% !important;
max-width: none !important;
}
"""
title_html = """
<h1 align="center" style="margin-bottom: 0.2em; font-size: 1.6em;"> ๐Ÿ’˜ HeartSync : Love Dating AI ๐Ÿ’˜ </h1>
<p align="center" style="font-size:1.1em; color:#555;">
โœ… FLUX Image Generation โœ… Reasoning & Uncensored โœ… Multimodal & VLM โœ… Deep-Research & RAG <br>
</p>
"""
with gr.Blocks(css=css, title="HeartSync") as demo:
gr.Markdown(title_html)
# ๋ณ„๋„ ๊ฐค๋Ÿฌ๋ฆฌ ์˜ˆ์‹œ (ํ•„์š” ์‹œ ์‚ฌ์šฉ)
generated_images = gr.Gallery(
label="์ƒ์„ฑ๋œ ์ด๋ฏธ์ง€",
show_label=True,
visible=False,
elem_id="generated_images",
columns=2,
height="auto",
object_fit="contain"
)
with gr.Row():
web_search_checkbox = gr.Checkbox(label="Deep Research", value=False)
image_gen_checkbox = gr.Checkbox(label="Image Gen", value=False)
base_system_prompt_box = gr.Textbox(
lines=3,
value="You are a deep thinking AI...\nํŽ˜๋ฅด์†Œ๋‚˜: ๋‹น์‹ ์€ ๋‹ฌ์ฝคํ•˜๊ณ ...",
label="๊ธฐ๋ณธ ์‹œ์Šคํ…œ ํ”„๋กฌํ”„ํŠธ",
visible=False
)
with gr.Row():
age_group_dropdown = gr.Dropdown(
label="์—ฐ๋ น๋Œ€ ์„ ํƒ (๊ธฐ๋ณธ 20๋Œ€)",
choices=["10๋Œ€", "20๋Œ€", "30~40๋Œ€", "50~60๋Œ€", "70๋Œ€ ์ด์ƒ"],
value="20๋Œ€",
interactive=True
)
        mbti_choices = [
            "INTJ (Architect)",
            "INTP (Logician)",
            "ENTJ (Commander)",
            "ENTP (Debater)",
            "INFJ (Advocate)",
            "INFP (Mediator)",
            "ENFJ (Protagonist)",
            "ENFP (Campaigner)",
            "ISTJ (Logistician)",
            "ISFJ (Defender)",
            "ESTJ (Executive)",
            "ESFJ (Consul)",
            "ISTP (Virtuoso)",
            "ISFP (Adventurer)",
            "ESTP (Entrepreneur)",
            "ESFP (Entertainer)"
        ]
mbti_dropdown = gr.Dropdown(
label="AI ํŽ˜๋ฅด์†Œ๋‚˜ MBTI (๊ธฐ๋ณธ INTP)",
choices=mbti_choices,
value="INTP (๋…ผ๋ฆฌ์ ์ธ ์‚ฌ์ƒ‰๊ฐ€)",
interactive=True
)
sexual_openness_slider = gr.Slider(
minimum=1, maximum=5, step=1, value=2,
label="์„น์Šˆ์–ผ ๊ด€์‹ฌ๋„/๊ฐœ๋ฐฉ์„ฑ (1~5, ๊ธฐ๋ณธ=2)",
interactive=True
)
max_tokens_slider = gr.Slider(
label="Max New Tokens",
minimum=100, maximum=8000, step=50, value=1000,
visible=False
)
web_search_text = gr.Textbox(
lines=1,
label="(Unused) Web Search Query",
placeholder="No direct input needed",
visible=False
)
def modified_run(
message, history, system_prompt, max_new_tokens,
use_web_search, web_search_query,
age_group, mbti_personality, sexual_openness, image_gen
):
"""
run() ํ•จ์ˆ˜๋ฅผ ํ˜ธ์ถœํ•˜์—ฌ ํ…์ŠคํŠธ ์ŠคํŠธ๋ฆผ์„ ๋ฐ›๊ณ ,
ํ•„์š” ์‹œ ์ถ”๊ฐ€ ์ฒ˜๋ฆฌ ํ›„ ๊ฒฐ๊ณผ ๋ฐ˜ํ™˜ (๊ฐค๋Ÿฌ๋ฆฌ ์—…๋ฐ์ดํŠธ ๋“ฑ).
"""
output_so_far = ""
gallery_update = gr.Gallery(visible=False, value=[])
yield output_so_far, gallery_update
text_generator = run(
message, history,
system_prompt, max_new_tokens,
use_web_search, web_search_query,
age_group, mbti_personality,
sexual_openness, image_gen
)
for text_chunk in text_generator:
output_so_far = text_chunk
yield output_so_far, gallery_update
        # If run() already embedded the Base64 image into the chat output,
        # it may not need to be shown separately in this gallery.
        # To surface image_result from inside run(), run() would have to be
        # modified to return that data as well.
chat = gr.ChatInterface(
fn=modified_run,
type="messages",
chatbot=gr.Chatbot(type="messages", scale=1, allow_tags=["image"]),
textbox=gr.MultimodalTextbox(
file_types=[".webp", ".png", ".jpg", ".jpeg", ".gif", ".mp4", ".csv", ".txt", ".pdf"],
file_count="multiple",
autofocus=True
),
multimodal=True,
additional_inputs=[
base_system_prompt_box,
max_tokens_slider,
web_search_checkbox,
web_search_text,
age_group_dropdown,
mbti_dropdown,
sexual_openness_slider,
image_gen_checkbox,
],
additional_outputs=[generated_images],
stop_btn=False,
title='<a href="https://discord.gg/openfreeai" target="_blank">https://discord.gg/openfreeai</a>',
examples=examples,
run_examples_on_click=False,
cache_examples=False,
css_paths=None,
delete_cache=(1800, 1800),
)
with gr.Row(elem_id="examples_row"):
with gr.Column(scale=12, elem_id="examples_container"):
gr.Markdown("### Example Inputs (click to load)")
if __name__ == "__main__":
demo.launch(share=True)