Update app.py

app.py CHANGED
@@ -13,7 +13,7 @@ import base64
 import logging
 import time
 from urllib.parse import quote # Added for URL encoding
-import importlib #
+import importlib # For dynamic import
 
 import gradio as gr
 import spaces
@@ -84,7 +84,6 @@ def generate_image(prompt: str, width: float, height: float, guidance: float, in
         logging.error(f"Image generation failed: {str(e)}")
         return None, f"Error: {str(e)}"
 
-# Base64 padding fix function
 def fix_base64_padding(data):
     """Fix the padding of a Base64 string."""
     if isinstance(data, bytes):
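Note: the body of fix_base64_padding is otherwise untouched by this commit. For reference, a padding fix of this kind typically appends "=" until the string length is a multiple of four; a minimal standalone sketch, not the app's exact implementation:

    import base64

    def fix_base64_padding_sketch(data):
        """Pad a Base64 string to a multiple of 4 so base64.b64decode accepts it."""
        if isinstance(data, bytes):
            data = data.decode("utf-8", errors="ignore")
        missing = len(data) % 4
        if missing:
            data += "=" * (4 - missing)
        return data

    # "aGVsbG8" is "hello" with its trailing padding character removed.
    print(base64.b64decode(fix_base64_padding_sketch("aGVsbG8")))  # b'hello'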
@@ -99,18 +98,12 @@ def fix_base64_padding(data)
 
     return data
 
-# =============================================================================
-# Memory cleanup function
-# =============================================================================
 def clear_cuda_cache():
     """Explicitly clear the CUDA cache."""
     if torch.cuda.is_available():
         torch.cuda.empty_cache()
     gc.collect()
 
-# =============================================================================
-# SerpHouse related functions
-# =============================================================================
 SERPHOUSE_API_KEY = os.getenv("SERPHOUSE_API_KEY", "")
 
 def extract_keywords(text: str, top_k: int = 5) -> str:
@@ -176,9 +169,6 @@ Below are the search results. Use this information to answer the query:
         logger.error(f"Web search failed: {e}")
         return f"Web search failed: {str(e)}"
 
-# =============================================================================
-# Model and processor loading
-# =============================================================================
 MAX_CONTENT_CHARS = 2000
 MAX_INPUT_LENGTH = 2096
 model_id = os.getenv("MODEL_ID", "VIDraft/Gemma-3-R1984-4B")
@@ -191,9 +181,6 @@ model = Gemma3ForConditionalGeneration.from_pretrained(
 )
 MAX_NUM_IMAGES = int(os.getenv("MAX_NUM_IMAGES", "5"))
 
-# =============================================================================
-# CSV, TXT, PDF analysis functions
-# =============================================================================
 def analyze_csv_file(path: str) -> str:
     try:
         df = pd.read_csv(path)
@@ -238,9 +225,6 @@ def pdf_to_markdown(pdf_path: str) -> str:
         full_text = full_text[:MAX_CONTENT_CHARS] + "\n...(truncated)..."
     return f"**[PDF File: {os.path.basename(pdf_path)}]**\n\n{full_text}"
 
-# =============================================================================
-# Check media file limits
-# =============================================================================
 def count_files_in_new_message(paths: list[str]) -> tuple[int, int]:
     image_count = 0
     video_count = 0
@@ -293,9 +277,6 @@ def validate_media_constraints(message: dict, history: list[dict]) -> bool:
             return False
     return True
 
-# =============================================================================
-# Video processing functions
-# =============================================================================
 def downsample_video(video_path: str) -> list[tuple[Image.Image, float]]:
     vidcap = cv2.VideoCapture(video_path)
     fps = vidcap.get(cv2.CAP_PROP_FPS)
@@ -328,9 +309,6 @@ def process_video(video_path: str) -> tuple[list[dict], list[str]]:
         content.append({"type": "image", "url": temp_file.name})
     return content, temp_files
 
-# =============================================================================
-# Interleaved <image> processing function
-# =============================================================================
 def process_interleaved_images(message: dict) -> list[dict]:
     parts = re.split(r"(<image>)", message["text"])
     content = []
@@ -347,9 +325,6 @@ def process_interleaved_images(message: dict) -> list[dict]:
             content.append({"type": "text", "text": part})
     return content
 
-# =============================================================================
-# File processing -> content creation
-# =============================================================================
 def is_image_file(file_path: str) -> bool:
     return bool(re.search(r"\.(png|jpg|jpeg|gif|webp)$", file_path, re.IGNORECASE))
 
@@ -390,9 +365,6 @@ def process_new_user_message(message: dict) -> tuple[list[dict], list[str]]:
             content_list.append({"type": "image", "url": img_path})
     return content_list, temp_files
 
-# =============================================================================
-# Convert history to LLM messages
-# =============================================================================
 def process_history(history: list[dict]) -> list[dict]:
     messages = []
     current_user_content = []
@@ -416,9 +388,6 @@ def process_history(history: list[dict]) -> list[dict]:
         messages.append({"role": "user", "content": current_user_content})
     return messages
 
-# =============================================================================
-# Model generation function (with OOM catching)
-# =============================================================================
 def _model_gen_with_oom_catch(**kwargs):
     try:
         model.generate(**kwargs)
@@ -433,18 +402,10 @@ def _model_gen_with_oom_catch(**kwargs):
 def load_function_definitions(json_path="functions.json"):
     """
     Load and return the list of function definitions from a local JSON file.
-    Each item: {
-        "name": <str>,
-        "description": <str>,
-        "module_path": <str>,
-        "func_name_in_module": <str>,
-        "parameters": { ... }
-    }
     """
     try:
         with open(json_path, "r", encoding="utf-8") as f:
             data = json.load(f)
-        # Restructure into a dict keyed by name
         func_dict = {}
         for entry in data:
             func_name = entry["name"]
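Note: judging from the docstring removed above and the fields read elsewhere in this diff (description, module_path, func_name_in_module, and the example_usage key used later in run()), an entry in functions.json presumably looks roughly like the following; the concrete function and module names are made up for illustration:

    # Hypothetical functions.json entry, shown as the equivalent Python structure.
    # Field names follow the removed docstring plus "example_usage"; the stock-price
    # function and "tools.stocks" module path are illustrative only.
    example_entry = {
        "name": "get_stock_price",
        "description": "Return the latest price for a ticker symbol.",
        "module_path": "tools.stocks",            # imported later with importlib
        "func_name_in_module": "get_stock_price",
        "parameters": {"ticker": {"type": "string"}},
        "example_usage": '```tool_code\nget_stock_price(ticker="AAPL")\n```',
    }

    # load_function_definitions() re-keys such entries by their "name" field:
    FUNCTION_DEFINITIONS = {example_entry["name"]: example_entry}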
@@ -456,9 +417,6 @@ def load_function_definitions(json_path="functions.json"):
 
 FUNCTION_DEFINITIONS = load_function_definitions("functions.json")
 
-# =============================================================================
-# Dynamic handle_function_call
-# =============================================================================
 def handle_function_call(text: str) -> str:
     """
     Detects and processes function call blocks in the text using the JSON-based approach.
@@ -470,7 +428,6 @@ def handle_function_call(text: str) -> str:
     ```tool_code
     get_product_name_by_PID(PID="807ZPKBL9V")
     ```
-    We parse that block, check if 'FUNCTION_DEFINITIONS' has an entry, then import & call it.
     """
     import re
     pattern = r"```tool_code\s*(.*?)\s*```"
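Note: the pattern above only captures a multi-line tool_code block when the regex is applied across newlines; a minimal extraction sketch (the re.DOTALL flag and the sample text are assumptions, since the search call itself is not part of this diff):

    import re

    pattern = r"```tool_code\s*(.*?)\s*```"
    sample = 'Sure.\n```tool_code\nget_product_name_by_PID(PID="807ZPKBL9V")\n```'

    match = re.search(pattern, sample, flags=re.DOTALL)
    if match:
        code_block = match.group(1).strip()
        print(code_block)  # get_product_name_by_PID(PID="807ZPKBL9V")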
@@ -479,12 +436,11 @@
         return ""
     code_block = match.group(1).strip()
 
-    # Extract the function name (e.g., get_stock_price)
-    # Regex: ^(\w+)\(.*\)
     func_match = re.match(r'^(\w+)\((.*)\)$', code_block)
     if not func_match:
         logger.debug("No valid function call format found.")
         return ""
+
     func_name = func_match.group(1)
     param_str = func_match.group(2).strip()
 
@@ -496,43 +452,35 @@
     func_info = FUNCTION_DEFINITIONS[func_name]
     module_path = func_info["module_path"]
     module_func_name = func_info["func_name_in_module"]
-
+
     try:
         imported_module = importlib.import_module(module_path)
     except ImportError as e:
         logger.error(f"Failed to import module {module_path}: {e}")
         return f"```tool_output\nError: Cannot import module '{module_path}'\n```"
 
-    # Get the actual function object
     if not hasattr(imported_module, module_func_name):
         logger.error(f"Module '{module_path}' has no attribute '{module_func_name}'.")
         return f"```tool_output\nError: Function '{module_func_name}' not found in module '{module_path}'\n```"
 
     real_func = getattr(imported_module, module_func_name)
 
-    # Parse parameters
-    # A simple regex distinguishes key="value" from key=123 forms
+    # Simple parameter parsing (key="value" or key=123)
     param_pattern = r'(\w+)\s*=\s*"(.*?)"|(\w+)\s*=\s*([\d.]+)'
-    # This regex only handles simple forms like key="string" or key=123
-    # More complex cases need separate parsing logic or a json.loads-based approach
     param_dict = {}
     for p_match in re.finditer(param_pattern, param_str):
         if p_match.group(1) and p_match.group(2):
-            # group(1) is the key, group(2) the string value
            key = p_match.group(1)
            val = p_match.group(2)
            param_dict[key] = val
        else:
-            # group(3) is the key, group(4) the numeric value
            key = p_match.group(3)
            val = p_match.group(4)
-            # Convert to a number
            if '.' in val:
                param_dict[key] = float(val)
            else:
                param_dict[key] = int(val)
 
-    # Now call the actual function
     try:
         result = real_func(**param_dict)
     except Exception as e:
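Note: for illustration, the alternation in param_pattern yields groups 1-2 for quoted string values and groups 3-4 for bare numeric values; a small self-contained check (the sample call string is made up):

    import re

    param_pattern = r'(\w+)\s*=\s*"(.*?)"|(\w+)\s*=\s*([\d.]+)'
    param_str = 'ticker="AAPL", days=30, threshold=0.75'

    param_dict = {}
    for m in re.finditer(param_pattern, param_str):
        if m.group(1) and m.group(2):
            param_dict[m.group(1)] = m.group(2)  # quoted string value
        else:
            val = m.group(4)
            param_dict[m.group(3)] = float(val) if "." in val else int(val)

    print(param_dict)  # {'ticker': 'AAPL', 'days': 30, 'threshold': 0.75}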
@@ -541,9 +489,6 @@
 
     return f"```tool_output\n{result}\n```"
 
-# =============================================================================
-# Main inference function
-# =============================================================================
 @spaces.GPU(duration=120)
 def run(
     message: dict,
@@ -555,19 +500,18 @@ def run(
     age_group: str = "20s",
     mbti_personality: str = "INTP",
     sexual_openness: int = 2,
-    image_gen: bool = False
+    image_gen: bool = False
 ) -> Iterator[str]:
     if not validate_media_constraints(message, history):
         yield ""
         return
     temp_files = []
     try:
-        # Functions loaded from JSON
-        # (this can add a lot of tokens, so a reasonably condensed summary is recommended)
-        # As an example, only the function names are listed below
+        # Turn the function info loaded from JSON into a string (e.g., only the name and example_usage)
        available_funcs_text = ""
        for f_name, info in FUNCTION_DEFINITIONS.items():
-
+            example_usage = info.get("example_usage", "")
+            available_funcs_text += f"\n\nFunction: {f_name}\nDescription: {info['description']}\nExample:\n{example_usage}\n"
 
        persona = (
            f"{system_prompt.strip()}\n\n"
@@ -575,7 +519,9 @@
             f"Age Group: {age_group}\n"
             f"MBTI Persona: {mbti_personality}\n"
             f"Sexual Openness (1-5): {sexual_openness}\n\n"
-            "Below are the available functions you can call
+            "Below are the available functions you can call.\n"
+            "Important: Use the format exactly like: ```tool_code\nfunctionName(param=\"string\", ...)\n```\n"
+            "(Strings must be in double quotes)\n"
             f"{available_funcs_text}\n"
         )
         combined_system_msg = f"[System Prompt]\n{persona.strip()}\n\n"
@@ -629,7 +575,6 @@
             output_so_far += new_text
             yield output_so_far
 
-        # Handle a ```tool_code``` block if it appears in the model output
         func_result = handle_function_call(output_so_far)
         if func_result:
             output_so_far += "\n\n" + func_result
@@ -652,17 +597,12 @@
         pass
     clear_cuda_cache()
 
-# =============================================================================
-# Modified model run function - handles image generation and gallery update
-# =============================================================================
 def modified_run(message, history, system_prompt, max_new_tokens, use_web_search, web_search_query,
                  age_group, mbti_personality, sexual_openness, image_gen):
-    # Initialize and hide the gallery component
     output_so_far = ""
     gallery_update = gr.Gallery(visible=False, value=[])
     yield output_so_far, gallery_update
 
-    # Execute the original run function
     text_generator = run(message, history, system_prompt, max_new_tokens, use_web_search,
                          web_search_query, age_group, mbti_personality, sexual_openness, image_gen)
 
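Note: modified_run wraps the text-only run generator so that every yield also carries a gallery update; a stripped-down sketch of that wrapping pattern, assuming Gradio's gr.Gallery component and a stand-in text generator:

    import gradio as gr

    def run_text_only(message: str):
        # Stand-in for the real run() generator.
        yield "Hello"
        yield "Hello, world"

    def modified_run_sketch(message: str):
        # Start with an empty answer and a hidden gallery, then re-yield each chunk.
        gallery_update = gr.Gallery(visible=False, value=[])
        output_so_far = ""
        yield output_so_far, gallery_update
        for text_chunk in run_text_only(message):
            output_so_far = text_chunk
            yield output_so_far, gallery_update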
@@ -670,15 +610,12 @@ def modified_run(message, history, system_prompt, max_new_tokens, use_web_search
         output_so_far = text_chunk
         yield output_so_far, gallery_update
 
-    # If image generation is enabled and there is text input, update the gallery
     if image_gen and message["text"].strip():
         try:
             width, height = 512, 512
             guidance, steps, seed = 7.5, 30, 42
 
             logger.info(f"Calling image generation for gallery with prompt: {message['text']}")
-
-            # Call the API to generate an image
             image_result, seed_info = generate_image(
                 prompt=message["text"].strip(),
                 width=width,
@@ -687,7 +624,6 @@
                 inference_steps=steps,
                 seed=seed
             )
-
             if image_result:
                 if isinstance(image_result, str) and (
                     image_result.startswith('data:') or
@@ -699,22 +635,18 @@
                     else:
                         b64data = image_result
                         content_type = "image/webp"
-
                     image_bytes = base64.b64decode(b64data)
                     with tempfile.NamedTemporaryFile(delete=False, suffix=".webp") as temp_file:
                         temp_file.write(image_bytes)
                         temp_path = temp_file.name
                     gallery_update = gr.Gallery(visible=True, value=[temp_path])
                     yield output_so_far + "\n\n*Image generated and displayed in the gallery below.*", gallery_update
-
                 except Exception as e:
                     logger.error(f"Error processing Base64 image: {e}")
                     yield output_so_far + f"\n\n(Error processing image: {e})", gallery_update
-
             elif isinstance(image_result, str) and os.path.exists(image_result):
                 gallery_update = gr.Gallery(visible=True, value=[image_result])
                 yield output_so_far + "\n\n*Image generated and displayed in the gallery below.*", gallery_update
-
             elif isinstance(image_result, str) and '/tmp/' in image_result:
                 try:
                     client = Client(API_URL)
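Note: both base64 branches above share the same decode-and-save step; a condensed standalone sketch of it (the sample data URI is fabricated):

    import base64
    import tempfile

    def save_base64_image_to_webp(image_result: str) -> str:
        """Decode a base64 image (optionally a data: URI) into a temporary .webp file."""
        if image_result.startswith("data:"):
            # e.g. "data:image/webp;base64,AAAA..." -> keep only the payload
            _content_type, b64data = image_result.split(";base64,")
        else:
            b64data = image_result
        image_bytes = base64.b64decode(b64data)
        with tempfile.NamedTemporaryFile(delete=False, suffix=".webp") as temp_file:
            temp_file.write(image_bytes)
            return temp_file.name

    # Fabricated 1-byte payload, just to show the call shape.
    print(save_base64_image_to_webp("data:image/webp;base64,AA=="))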
@@ -722,13 +654,11 @@
                         prompt=message["text"].strip(),
                         api_name="/generate_base64_image"
                     )
-
                    if isinstance(result, str) and (result.startswith('data:') or len(result) > 100):
                        if result.startswith('data:'):
                            content_type, b64data = result.split(';base64,')
                        else:
                            b64data = result
-
                        image_bytes = base64.b64decode(b64data)
                        with tempfile.NamedTemporaryFile(delete=False, suffix=".webp") as temp_file:
                            temp_file.write(image_bytes)
@@ -737,7 +667,6 @@
                            yield output_so_far + "\n\n*Image generated and displayed in the gallery below.*", gallery_update
                    else:
                        yield output_so_far + "\n\n(Image generation failed: Invalid format)", gallery_update
-
                except Exception as e:
                    logger.error(f"Error calling alternative API: {e}")
                    yield output_so_far + f"\n\n(Image generation failed: {e})", gallery_update
@@ -755,14 +684,10 @@ def modified_run(message, history, system_prompt, max_new_tokens, use_web_search
                yield output_so_far + f"\n\n(Unsupported image format: {type(image_result)})", gallery_update
            else:
                yield output_so_far + f"\n\n(Image generation failed: {seed_info})", gallery_update
-
        except Exception as e:
            logger.error(f"Error during gallery image generation: {e}")
            yield output_so_far + f"\n\n(Image generation error: {e})", gallery_update
 
-# =============================================================================
-# Examples
-# =============================================================================
 examples = [
     [
         {
@@ -855,7 +780,7 @@ examples = [
     ],
     [
         {
-            "text": "AAPL의 현재 주가를 알려줘.",
+            "text": "AAPL의 현재 주가를 알려줘.",
             "files": []
         }
     ],