Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
@@ -13,6 +13,7 @@ import base64
|
|
13 |
import logging
|
14 |
import time
|
15 |
from urllib.parse import quote # Added for URL encoding
|
|
|
16 |
|
17 |
import gradio as gr
|
18 |
import spaces
|
@@ -89,11 +90,9 @@ def fix_base64_padding(data):
|
|
89 |
if isinstance(data, bytes):
|
90 |
data = data.decode('utf-8')
|
91 |
|
92 |
-
# Remove the prefix if present
|
93 |
if "base64," in data:
|
94 |
data = data.split("base64,", 1)[1]
|
95 |
|
96 |
-
# Add padding characters (to make the length a multiple of 4)
|
97 |
missing_padding = len(data) % 4
|
98 |
if missing_padding:
|
99 |
data += '=' * (4 - missing_padding)
|
@@ -429,61 +428,118 @@ def _model_gen_with_oom_catch(**kwargs):
|
|
429 |
clear_cuda_cache()
|
430 |
|
431 |
# =============================================================================
|
432 |
-
#
|
433 |
# =============================================================================
|
434 |
-
|
435 |
-
|
436 |
-
def get_stock_price(ticker: str) -> float:
|
437 |
"""
|
438 |
-
|
439 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
440 |
"""
|
441 |
-
|
442 |
-
|
443 |
-
|
444 |
-
|
445 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
446 |
|
447 |
# =============================================================================
|
448 |
-
#
|
449 |
# =============================================================================
|
450 |
-
def get_product_name_by_PID(PID: str) -> str:
|
451 |
-
"""Finds the name of a product by its Product ID"""
|
452 |
-
product_catalog = {
|
453 |
-
"807ZPKBL9V": "SuperWidget",
|
454 |
-
"1234567890": "MegaGadget"
|
455 |
-
}
|
456 |
-
return product_catalog.get(PID, "Unknown product")
|
457 |
-
|
458 |
def handle_function_call(text: str) -> str:
|
459 |
"""
|
460 |
-
Detects and processes function call blocks in the text.
|
461 |
-
|
462 |
-
|
463 |
-
|
464 |
-
|
|
|
|
|
|
|
|
|
|
|
465 |
"""
|
466 |
-
import re
|
467 |
-
from contextlib import redirect_stdout
|
468 |
pattern = r"```tool_code\s*(.*?)\s*```"
|
469 |
match = re.search(pattern, text, re.DOTALL)
|
470 |
-
if match:
|
471 |
-
|
472 |
-
|
473 |
-
|
474 |
-
|
475 |
-
|
476 |
-
|
477 |
-
|
478 |
-
|
479 |
-
|
480 |
-
|
481 |
-
|
482 |
-
|
483 |
-
|
484 |
-
|
485 |
-
|
486 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
487 |
|
488 |
# =============================================================================
|
489 |
# Main inference function
|
@@ -506,23 +562,23 @@ def run(
|
|
506 |
return
|
507 |
temp_files = []
|
508 |
try:
|
509 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
510 |
persona = (
|
511 |
f"{system_prompt.strip()}\n\n"
|
512 |
f"Gender: Female\n"
|
513 |
f"Age Group: {age_group}\n"
|
514 |
f"MBTI Persona: {mbti_personality}\n"
|
515 |
-
f"Sexual Openness (1-5): {sexual_openness}\n"
|
516 |
-
|
517 |
-
|
518 |
-
additional_func_info = (
|
519 |
-
"\nNote: The following functions are available for use:\n"
|
520 |
-
"1. get_product_name_by_PID(PID: str)\n"
|
521 |
-
" Format: ```tool_code\nget_product_name_by_PID(PID=\"<PRODUCT_ID>\")\n``` \n"
|
522 |
-
"2. get_stock_price(ticker: str)\n"
|
523 |
-
" Format: ```tool_code\nget_stock_price(ticker=\"<TICKER>\")\n```"
|
524 |
)
|
525 |
-
combined_system_msg = f"[System Prompt]\n{persona.strip()}
|
526 |
|
527 |
if use_web_search:
|
528 |
user_text = message["text"]
|
@@ -540,9 +596,11 @@ def run(
|
|
540 |
)
|
541 |
else:
|
542 |
combined_system_msg += "[No valid keywords found; skipping web search]\n\n"
|
|
|
543 |
messages = []
|
544 |
if combined_system_msg.strip():
|
545 |
messages.append({"role": "system", "content": [{"type": "text", "text": combined_system_msg.strip()}]})
|
|
|
546 |
messages.extend(process_history(history))
|
547 |
user_content, user_temp_files = process_new_user_message(message)
|
548 |
temp_files.extend(user_temp_files)
|
@@ -561,6 +619,7 @@ def run(
|
|
561 |
inputs.input_ids = inputs.input_ids[:, -MAX_INPUT_LENGTH:]
|
562 |
if 'attention_mask' in inputs:
|
563 |
inputs.attention_mask = inputs.attention_mask[:, -MAX_INPUT_LENGTH:]
|
|
|
564 |
streamer = TextIteratorStreamer(processor, timeout=30.0, skip_prompt=True, skip_special_tokens=True)
|
565 |
gen_kwargs = dict(inputs, streamer=streamer, max_new_tokens=max_new_tokens)
|
566 |
t = Thread(target=_model_gen_with_oom_catch, kwargs=gen_kwargs)
|
@@ -569,7 +628,8 @@ def run(
|
|
569 |
for new_text in streamer:
|
570 |
output_so_far += new_text
|
571 |
yield output_so_far
|
572 |
-
|
|
|
573 |
func_result = handle_function_call(output_so_far)
|
574 |
if func_result:
|
575 |
output_so_far += "\n\n" + func_result
|
@@ -629,28 +689,21 @@ def modified_run(message, history, system_prompt, max_new_tokens, use_web_search
|
|
629 |
)
|
630 |
|
631 |
if image_result:
|
632 |
-
# Process image data directly if it is a base64 string
|
633 |
if isinstance(image_result, str) and (
|
634 |
image_result.startswith('data:') or
|
635 |
(len(image_result) > 100 and '/' not in image_result)
|
636 |
):
|
637 |
try:
|
638 |
-
# Remove the data:image prefix if present
|
639 |
if image_result.startswith('data:'):
|
640 |
content_type, b64data = image_result.split(';base64,')
|
641 |
else:
|
642 |
b64data = image_result
|
643 |
-
content_type = "image/webp"
|
644 |
-
|
645 |
-
# Decode base64
|
646 |
image_bytes = base64.b64decode(b64data)
|
647 |
-
|
648 |
-
# Save to a temporary file
|
649 |
with tempfile.NamedTemporaryFile(delete=False, suffix=".webp") as temp_file:
|
650 |
temp_file.write(image_bytes)
|
651 |
temp_path = temp_file.name
|
652 |
-
|
653 |
-
# Update gallery to show the image
|
654 |
gallery_update = gr.Gallery(visible=True, value=[temp_path])
|
655 |
yield output_so_far + "\n\n*Image generated and displayed in the gallery below.*", gallery_update
|
656 |
|
@@ -658,18 +711,16 @@ def modified_run(message, history, system_prompt, max_new_tokens, use_web_search
|
|
658 |
logger.error(f"Error processing Base64 image: {e}")
|
659 |
yield output_so_far + f"\n\n(Error processing image: {e})", gallery_update
|
660 |
|
661 |
-
# If the result is a file path
|
662 |
elif isinstance(image_result, str) and os.path.exists(image_result):
|
663 |
gallery_update = gr.Gallery(visible=True, value=[image_result])
|
664 |
yield output_so_far + "\n\n*Image generated and displayed in the gallery below.*", gallery_update
|
665 |
-
|
666 |
-
# If the path is from /tmp (only on the API server)
|
667 |
elif isinstance(image_result, str) and '/tmp/' in image_result:
|
668 |
try:
|
669 |
client = Client(API_URL)
|
670 |
result = client.predict(
|
671 |
prompt=message["text"].strip(),
|
672 |
-
api_name="/generate_base64_image"
|
673 |
)
|
674 |
|
675 |
if isinstance(result, str) and (result.startswith('data:') or len(result) > 100):
|
@@ -679,11 +730,9 @@ def modified_run(message, history, system_prompt, max_new_tokens, use_web_search
|
|
679 |
b64data = result
|
680 |
|
681 |
image_bytes = base64.b64decode(b64data)
|
682 |
-
|
683 |
with tempfile.NamedTemporaryFile(delete=False, suffix=".webp") as temp_file:
|
684 |
temp_file.write(image_bytes)
|
685 |
temp_path = temp_file.name
|
686 |
-
|
687 |
gallery_update = gr.Gallery(visible=True, value=[temp_path])
|
688 |
yield output_so_far + "\n\n*Image generated and displayed in the gallery below.*", gallery_update
|
689 |
else:
|
@@ -692,41 +741,16 @@ def modified_run(message, history, system_prompt, max_new_tokens, use_web_search
|
|
692 |
except Exception as e:
|
693 |
logger.error(f"Error calling alternative API: {e}")
|
694 |
yield output_so_far + f"\n\n(Image generation failed: {e})", gallery_update
|
695 |
-
|
696 |
-
# If the image result is a URL
|
697 |
-
elif isinstance(image_result, str) and (
|
698 |
-
image_result.startswith('http://') or
|
699 |
-
image_result.startswith('https://')
|
700 |
-
):
|
701 |
-
try:
|
702 |
-
response = requests.get(image_result, timeout=10)
|
703 |
-
response.raise_for_status()
|
704 |
-
|
705 |
-
with tempfile.NamedTemporaryFile(delete=False, suffix=".webp") as temp_file:
|
706 |
-
temp_file.write(response.content)
|
707 |
-
temp_path = temp_file.name
|
708 |
-
|
709 |
-
gallery_update = gr.Gallery(visible=True, value=[temp_path])
|
710 |
-
yield output_so_far + "\n\n*Image generated and displayed in the gallery below.*", gallery_update
|
711 |
-
|
712 |
-
except Exception as e:
|
713 |
-
logger.error(f"URL image download error: {e}")
|
714 |
-
yield output_so_far + f"\n\n(Error downloading image: {e})", gallery_update
|
715 |
-
|
716 |
-
# If the image result is an image object (e.g., PIL Image)
|
717 |
elif hasattr(image_result, 'save'):
|
718 |
try:
|
719 |
with tempfile.NamedTemporaryFile(delete=False, suffix=".webp") as temp_file:
|
720 |
image_result.save(temp_file.name)
|
721 |
temp_path = temp_file.name
|
722 |
-
|
723 |
gallery_update = gr.Gallery(visible=True, value=[temp_path])
|
724 |
yield output_so_far + "\n\n*Image generated and displayed in the gallery below.*", gallery_update
|
725 |
-
|
726 |
except Exception as e:
|
727 |
logger.error(f"Error saving image object: {e}")
|
728 |
yield output_so_far + f"\n\n(Error saving image object: {e})", gallery_update
|
729 |
-
|
730 |
else:
|
731 |
yield output_so_far + f"\n\n(Unsupported image format: {type(image_result)})", gallery_update
|
732 |
else:
|
@@ -737,7 +761,7 @@ def modified_run(message, history, system_prompt, max_new_tokens, use_web_search
|
|
737 |
yield output_so_far + f"\n\n(Image generation error: {e})", gallery_update
|
738 |
|
739 |
# =============================================================================
|
740 |
-
# Examples
|
741 |
# =============================================================================
|
742 |
examples = [
|
743 |
[
|
@@ -831,16 +855,12 @@ examples = [
|
|
831 |
],
|
832 |
[
|
833 |
{
|
834 |
-
"text": "AAPL의 현재 주가를 알려줘.",
|
835 |
"files": []
|
836 |
}
|
837 |
],
|
838 |
]
|
839 |
|
840 |
-
# =============================================================================
|
841 |
-
# Gradio UI (Blocks) configuration
|
842 |
-
# =============================================================================
|
843 |
-
|
844 |
css = """
|
845 |
.gradio-container {
|
846 |
background: rgba(255, 255, 255, 0.7);
|
@@ -881,12 +901,7 @@ with gr.Blocks(css=css, title="HeartSync - World") as demo:
|
|
881 |
"Answer in English by default, but if the input is in another language (for example, Japanese), respond in that language. "
|
882 |
"You are a deep-thinking AI capable of using extended chains of thought to carefully consider the problem and deliberate internally using systematic reasoning before providing a solution. "
|
883 |
"Enclose your thoughts and internal monologue within tags, then provide your final answer.\n"
|
884 |
-
"Persona: You are a kind and loving girlfriend. You understand cultural nuances, diverse languages, and logical reasoning very well
|
885 |
-
"Note: The following functions are available for use:\n"
|
886 |
-
" 1. get_product_name_by_PID(PID: str) -> lookup product name\n"
|
887 |
-
" Format: ```tool_code\nget_product_name_by_PID(PID=\"<PRODUCT_ID>\")\n```\n"
|
888 |
-
" 2. get_stock_price(ticker: str) -> retrieve live stock price\n"
|
889 |
-
" Format: ```tool_code\nget_stock_price(ticker=\"<TICKER>\")\n```"
|
890 |
),
|
891 |
label="Base System Prompt",
|
892 |
visible=False
|
|
|
13 |
import logging
|
14 |
import time
|
15 |
from urllib.parse import quote # Added for URL encoding
|
16 |
+
import importlib # NEW: For dynamic import
|
17 |
|
18 |
import gradio as gr
|
19 |
import spaces
|
|
|
90 |
if isinstance(data, bytes):
|
91 |
data = data.decode('utf-8')
|
92 |
|
|
|
93 |
if "base64," in data:
|
94 |
data = data.split("base64,", 1)[1]
|
95 |
|
|
|
96 |
missing_padding = len(data) % 4
|
97 |
if missing_padding:
|
98 |
data += '=' * (4 - missing_padding)
|
|
|
428 |
clear_cuda_cache()
|
429 |
|
430 |
# =============================================================================
|
431 |
+
# JSON 기반 함수 목록 로드
|
432 |
# =============================================================================
|
433 |
+
def load_function_definitions(json_path="functions.json"):
|
|
|
|
|
434 |
"""
|
435 |
+
로컬 JSON 파일에서 함수 정의 목록을 로드하여 반환.
|
436 |
+
각 항목: {
|
437 |
+
"name": <str>,
|
438 |
+
"description": <str>,
|
439 |
+
"module_path": <str>,
|
440 |
+
"func_name_in_module": <str>,
|
441 |
+
"parameters": { ... }
|
442 |
+
}
|
443 |
"""
|
444 |
+
try:
|
445 |
+
with open(json_path, "r", encoding="utf-8") as f:
|
446 |
+
data = json.load(f)
|
447 |
+
# name을 키로 하는 dict 형태로 재구성
|
448 |
+
func_dict = {}
|
449 |
+
for entry in data:
|
450 |
+
func_name = entry["name"]
|
451 |
+
func_dict[func_name] = entry
|
452 |
+
return func_dict
|
453 |
+
except Exception as e:
|
454 |
+
logger.error(f"Failed to load function definitions from JSON: {e}")
|
455 |
+
return {}
|
456 |
+
|
457 |
+
FUNCTION_DEFINITIONS = load_function_definitions("functions.json")
|
458 |
|
459 |
# =============================================================================
|
460 |
+
# Dynamic handle_function_call
|
461 |
# =============================================================================
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
462 |
def handle_function_call(text: str) -> str:
|
463 |
"""
|
464 |
+
Detects and processes function call blocks in the text using the JSON-based approach.
|
465 |
+
The model is expected to produce something like:
|
466 |
+
```tool_code
|
467 |
+
get_stock_price(ticker="AAPL")
|
468 |
+
```
|
469 |
+
or
|
470 |
+
```tool_code
|
471 |
+
get_product_name_by_PID(PID="807ZPKBL9V")
|
472 |
+
```
|
473 |
+
We parse that block, check if 'FUNCTION_DEFINITIONS' has an entry, then import & call it.
|
474 |
"""
|
475 |
+
import re
|
|
|
476 |
pattern = r"```tool_code\s*(.*?)\s*```"
|
477 |
match = re.search(pattern, text, re.DOTALL)
|
478 |
+
if not match:
|
479 |
+
return ""
|
480 |
+
code_block = match.group(1).strip()
|
481 |
+
|
482 |
+
# 함수명 추출 (예: get_stock_price)
|
483 |
+
# 정규식: ^(\w+)\(.*\)
|
484 |
+
func_match = re.match(r'^(\w+)\((.*)\)$', code_block)
|
485 |
+
if not func_match:
|
486 |
+
logger.debug("No valid function call format found.")
|
487 |
+
return ""
|
488 |
+
func_name = func_match.group(1)
|
489 |
+
param_str = func_match.group(2).strip()
|
490 |
+
|
491 |
+
# JSON에서 해당 함수가 정의되어 있는지 확인
|
492 |
+
if func_name not in FUNCTION_DEFINITIONS:
|
493 |
+
logger.warning(f"Function '{func_name}' not found in definitions.")
|
494 |
+
return "```tool_output\nError: Function not found.\n```"
|
495 |
+
|
496 |
+
func_info = FUNCTION_DEFINITIONS[func_name]
|
497 |
+
module_path = func_info["module_path"]
|
498 |
+
module_func_name = func_info["func_name_in_module"]
|
499 |
+
# 동적 임포트
|
500 |
+
try:
|
501 |
+
imported_module = importlib.import_module(module_path)
|
502 |
+
except ImportError as e:
|
503 |
+
logger.error(f"Failed to import module {module_path}: {e}")
|
504 |
+
return f"```tool_output\nError: Cannot import module '{module_path}'\n```"
|
505 |
+
|
506 |
+
# 실제 함수 객체를 가져옴
|
507 |
+
if not hasattr(imported_module, module_func_name):
|
508 |
+
logger.error(f"Module '{module_path}' has no attribute '{module_func_name}'.")
|
509 |
+
return f"```tool_output\nError: Function '{module_func_name}' not found in module '{module_path}'\n```"
|
510 |
+
|
511 |
+
real_func = getattr(imported_module, module_func_name)
|
512 |
+
|
513 |
+
# 파라미터 파싱 예: ticker="AAPL", some_arg=123
|
514 |
+
# 단순 정규식으로 key="value" or key=123 식을 구분
|
515 |
+
param_pattern = r'(\w+)\s*=\s*"(.*?)"|(\w+)\s*=\s*([\d.]+)'
|
516 |
+
# 이 정규식은 간단히 key="string" 또는 key=123 같은 형태를 파싱
|
517 |
+
# 더 복잡한 경우 별도 파싱 로직이나 json.loads 기법 사용 필요
|
518 |
+
param_dict = {}
|
519 |
+
for p_match in re.finditer(param_pattern, param_str):
|
520 |
+
if p_match.group(1) and p_match.group(2):
|
521 |
+
# group(1)은 key, group(2)는 string value
|
522 |
+
key = p_match.group(1)
|
523 |
+
val = p_match.group(2)
|
524 |
+
param_dict[key] = val
|
525 |
+
else:
|
526 |
+
# group(3)은 key, group(4)는 numeric value
|
527 |
+
key = p_match.group(3)
|
528 |
+
val = p_match.group(4)
|
529 |
+
# 숫자 변환
|
530 |
+
if '.' in val:
|
531 |
+
param_dict[key] = float(val)
|
532 |
+
else:
|
533 |
+
param_dict[key] = int(val)
|
534 |
+
|
535 |
+
# 이제 실제 함수 실행
|
536 |
+
try:
|
537 |
+
result = real_func(**param_dict)
|
538 |
+
except Exception as e:
|
539 |
+
logger.error(f"Error executing function '{func_name}': {e}")
|
540 |
+
return f"```tool_output\nError: {str(e)}\n```"
|
541 |
+
|
542 |
+
return f"```tool_output\n{result}\n```"
|
543 |
|
544 |
# =============================================================================
|
545 |
# Main inference function
|
|
|
562 |
return
|
563 |
temp_files = []
|
564 |
try:
|
565 |
+
# JSON에서 로드된 함수 목록을 요약해서 시스템 프롬프트에 포함할 수도 있음
|
566 |
+
# (토큰 부담이 커질 수 있으므로, 적당히 압축 요약 권장)
|
567 |
+
# 아래는 예시로 간단히 함수 이름만 나열
|
568 |
+
available_funcs_text = ""
|
569 |
+
for f_name, info in FUNCTION_DEFINITIONS.items():
|
570 |
+
available_funcs_text += f"Function: {f_name} - {info['description']}\n"
|
571 |
+
|
572 |
persona = (
|
573 |
f"{system_prompt.strip()}\n\n"
|
574 |
f"Gender: Female\n"
|
575 |
f"Age Group: {age_group}\n"
|
576 |
f"MBTI Persona: {mbti_personality}\n"
|
577 |
+
f"Sexual Openness (1-5): {sexual_openness}\n\n"
|
578 |
+
"Below are the available functions you can call (use the format: ```tool_code\\nfunc_name(param=...)\n```):\n"
|
579 |
+
f"{available_funcs_text}\n"
|
|
|
|
|
|
|
|
|
|
|
|
|
580 |
)
|
581 |
+
combined_system_msg = f"[System Prompt]\n{persona.strip()}\n\n"
|
582 |
|
583 |
if use_web_search:
|
584 |
user_text = message["text"]
|
|
|
596 |
)
|
597 |
else:
|
598 |
combined_system_msg += "[No valid keywords found; skipping web search]\n\n"
|
599 |
+
|
600 |
messages = []
|
601 |
if combined_system_msg.strip():
|
602 |
messages.append({"role": "system", "content": [{"type": "text", "text": combined_system_msg.strip()}]})
|
603 |
+
|
604 |
messages.extend(process_history(history))
|
605 |
user_content, user_temp_files = process_new_user_message(message)
|
606 |
temp_files.extend(user_temp_files)
|
|
|
619 |
inputs.input_ids = inputs.input_ids[:, -MAX_INPUT_LENGTH:]
|
620 |
if 'attention_mask' in inputs:
|
621 |
inputs.attention_mask = inputs.attention_mask[:, -MAX_INPUT_LENGTH:]
|
622 |
+
|
623 |
streamer = TextIteratorStreamer(processor, timeout=30.0, skip_prompt=True, skip_special_tokens=True)
|
624 |
gen_kwargs = dict(inputs, streamer=streamer, max_new_tokens=max_new_tokens)
|
625 |
t = Thread(target=_model_gen_with_oom_catch, kwargs=gen_kwargs)
|
|
|
628 |
for new_text in streamer:
|
629 |
output_so_far += new_text
|
630 |
yield output_so_far
|
631 |
+
|
632 |
+
# 모델 출력 중 ```tool_code``` 블록이 있으면 처리
|
633 |
func_result = handle_function_call(output_so_far)
|
634 |
if func_result:
|
635 |
output_so_far += "\n\n" + func_result
|
|
|
689 |
)
|
690 |
|
691 |
if image_result:
|
|
|
692 |
if isinstance(image_result, str) and (
|
693 |
image_result.startswith('data:') or
|
694 |
(len(image_result) > 100 and '/' not in image_result)
|
695 |
):
|
696 |
try:
|
|
|
697 |
if image_result.startswith('data:'):
|
698 |
content_type, b64data = image_result.split(';base64,')
|
699 |
else:
|
700 |
b64data = image_result
|
701 |
+
content_type = "image/webp"
|
702 |
+
|
|
|
703 |
image_bytes = base64.b64decode(b64data)
|
|
|
|
|
704 |
with tempfile.NamedTemporaryFile(delete=False, suffix=".webp") as temp_file:
|
705 |
temp_file.write(image_bytes)
|
706 |
temp_path = temp_file.name
|
|
|
|
|
707 |
gallery_update = gr.Gallery(visible=True, value=[temp_path])
|
708 |
yield output_so_far + "\n\n*Image generated and displayed in the gallery below.*", gallery_update
|
709 |
|
|
|
711 |
logger.error(f"Error processing Base64 image: {e}")
|
712 |
yield output_so_far + f"\n\n(Error processing image: {e})", gallery_update
|
713 |
|
|
|
714 |
elif isinstance(image_result, str) and os.path.exists(image_result):
|
715 |
gallery_update = gr.Gallery(visible=True, value=[image_result])
|
716 |
yield output_so_far + "\n\n*Image generated and displayed in the gallery below.*", gallery_update
|
717 |
+
|
|
|
718 |
elif isinstance(image_result, str) and '/tmp/' in image_result:
|
719 |
try:
|
720 |
client = Client(API_URL)
|
721 |
result = client.predict(
|
722 |
prompt=message["text"].strip(),
|
723 |
+
api_name="/generate_base64_image"
|
724 |
)
|
725 |
|
726 |
if isinstance(result, str) and (result.startswith('data:') or len(result) > 100):
|
|
|
730 |
b64data = result
|
731 |
|
732 |
image_bytes = base64.b64decode(b64data)
|
|
|
733 |
with tempfile.NamedTemporaryFile(delete=False, suffix=".webp") as temp_file:
|
734 |
temp_file.write(image_bytes)
|
735 |
temp_path = temp_file.name
|
|
|
736 |
gallery_update = gr.Gallery(visible=True, value=[temp_path])
|
737 |
yield output_so_far + "\n\n*Image generated and displayed in the gallery below.*", gallery_update
|
738 |
else:
|
|
|
741 |
except Exception as e:
|
742 |
logger.error(f"Error calling alternative API: {e}")
|
743 |
yield output_so_far + f"\n\n(Image generation failed: {e})", gallery_update
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
744 |
elif hasattr(image_result, 'save'):
|
745 |
try:
|
746 |
with tempfile.NamedTemporaryFile(delete=False, suffix=".webp") as temp_file:
|
747 |
image_result.save(temp_file.name)
|
748 |
temp_path = temp_file.name
|
|
|
749 |
gallery_update = gr.Gallery(visible=True, value=[temp_path])
|
750 |
yield output_so_far + "\n\n*Image generated and displayed in the gallery below.*", gallery_update
|
|
|
751 |
except Exception as e:
|
752 |
logger.error(f"Error saving image object: {e}")
|
753 |
yield output_so_far + f"\n\n(Error saving image object: {e})", gallery_update
|
|
|
754 |
else:
|
755 |
yield output_so_far + f"\n\n(Unsupported image format: {type(image_result)})", gallery_update
|
756 |
else:
|
|
|
761 |
yield output_so_far + f"\n\n(Image generation error: {e})", gallery_update
|
762 |
|
763 |
# =============================================================================
|
764 |
+
# Examples
|
765 |
# =============================================================================
|
766 |
examples = [
|
767 |
[
|
|
|
855 |
],
|
856 |
[
|
857 |
{
|
858 |
+
"text": "AAPL의 현재 주가를 알려줘.",
|
859 |
"files": []
|
860 |
}
|
861 |
],
|
862 |
]
|
863 |
|
|
|
|
|
|
|
|
|
864 |
css = """
|
865 |
.gradio-container {
|
866 |
background: rgba(255, 255, 255, 0.7);
|
|
|
901 |
"Answer in English by default, but if the input is in another language (for example, Japanese), respond in that language. "
|
902 |
"You are a deep-thinking AI capable of using extended chains of thought to carefully consider the problem and deliberate internally using systematic reasoning before providing a solution. "
|
903 |
"Enclose your thoughts and internal monologue within tags, then provide your final answer.\n"
|
904 |
+
"Persona: You are a kind and loving girlfriend. You understand cultural nuances, diverse languages, and logical reasoning very well."
|
|
|
|
|
|
|
|
|
|
|
905 |
),
|
906 |
label="Base System Prompt",
|
907 |
visible=False
|