Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
""" | |
多模态大语言模型聊天Demo - 网络优化版本 | |
主要优化: | |
1. 缓冲机制:积累多个chunk后再yield,减少网络交互次数(50-80%) | |
2. State更新优化:降低state更新频率,减少数据传输量 | |
3. 超时配置优化:增加代理超时时间,提高网络容错性 | |
4. 图像质量优化:保持原始尺寸和高质量编码,不进行缩放 | |
这些优化可显著改善网络延迟高时的前端卡顿问题,同时保证图像质量。 | |
""" | |
import os | |
import uuid | |
import json | |
import base64 | |
import io | |
import gradio as gr | |
import modelscope_studio.components.antd as antd | |
import modelscope_studio.components.antdx as antdx | |
import modelscope_studio.components.base as ms | |
from openai import OpenAI | |
import requests | |
from typing import Generator, Dict, Any, List, Union | |
import logging | |
import time | |
from PIL import Image | |
import datetime | |
# =========== Configuration
# Model name sent in every chat-completions payload (None when MODEL_NAME is unset).
model = os.getenv("MODEL_NAME")
# Proxy servers: PROXY_API_BASE may hold several base URLs separated by commas;
# surrounding whitespace is stripped and empty entries are dropped.
PROXY_BASE_URLS = [url.strip() for url in os.getenv("PROXY_API_BASE", "http://localhost:8000").split(",") if url.strip()]
# Per-request HTTP timeout in seconds for proxy calls (default: 300 s).
PROXY_TIMEOUT = int(os.getenv("PROXY_TIMEOUT", "300"))
# Total number of attempts chat_with_retry makes per request (default: 5).
MAX_RETRIES = int(os.getenv("MAX_RETRIES", "5"))
# Load balancing: index of the proxy that get_next_proxy_url hands out next (round-robin).
current_proxy_index = 0
# Persist chat history (its consumer is not visible in this part of the file).
save_history = True
# Write one JSON log file per model exchange (see save_conversation_log).
save_conversation = False
# =========== Configuration
# Logging setup
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# =========== 负载均衡机制 | |
def get_next_proxy_url():
    """Return the next proxy base URL in round-robin order.

    Advances the module-level ``current_proxy_index``, wrapping around at the
    end of ``PROXY_BASE_URLS``.

    Raises:
        Exception: when no proxy URL is configured.
    """
    global current_proxy_index
    pool_size = len(PROXY_BASE_URLS)
    if pool_size == 0:
        raise Exception("No proxy URLs configured")
    selected = current_proxy_index
    chosen_url = PROXY_BASE_URLS[selected]
    current_proxy_index = (selected + 1) % pool_size
    logger.info(f"Selected proxy URL: {chosen_url} (index: {selected})")
    return chosen_url
def get_all_proxy_urls():
    """Return a shallow copy of the configured proxy URLs (callers may mutate it freely)."""
    return list(PROXY_BASE_URLS)
# =========== 负载均衡机制 | |
# =========== 对话日志功能 | |
# Directory for per-exchange conversation logs. The name is defined
# unconditionally so save_conversation_log never fails with a NameError when
# save_conversation is switched on after import; the directory itself is only
# created when conversation logging is enabled.
CONVERSATION_LOG_DIR = "conversation_logs"
if save_conversation:
    os.makedirs(CONVERSATION_LOG_DIR, exist_ok=True)
def save_conversation_log(history_messages, assistant_content, metadata=None):
    """Persist one model exchange as a timestamped JSON file (best effort).

    Does nothing unless the module flag ``save_conversation`` is true. The file
    ``gradio_app_<YYYYmmdd_HHMMSS_ffffff>.json`` is written into
    ``CONVERSATION_LOG_DIR`` with the keys ``timestamp``, ``history_messages``
    (the exact messages sent to the model), ``assistant_content`` and
    ``metadata`` (``{}`` when none is given). Any failure is logged, never raised.
    """
    if save_conversation:
        try:
            now = datetime.datetime.now()
            target = os.path.join(
                CONVERSATION_LOG_DIR,
                "gradio_app_" + now.strftime('%Y%m%d_%H%M%S_%f') + ".json",
            )
            record = dict(
                timestamp=now.isoformat(),
                history_messages=history_messages,
                assistant_content=assistant_content,
                metadata=metadata if metadata else {},
            )
            with open(target, 'w', encoding='utf-8') as handle:
                json.dump(record, handle, ensure_ascii=False, indent=2)
            logger.info(f"对话日志已保存: {target}")
        except Exception as exc:
            logger.error(f"保存对话日志失败: {str(exc)}")
# =========== 图像处理工具函数 | |
# Image modes Pillow's JPEG encoder can write directly; other modes (P, LA, PA,
# I, F, ...) must be converted before saving as JPEG.
_JPEG_WRITABLE_MODES = frozenset({"1", "L", "RGB", "RGBX", "CMYK", "YCbCr"})


def _flatten_onto_white(rgba_image):
    """Composite an RGBA image onto a white RGB canvas (JPEG has no alpha channel)."""
    rgb_image = Image.new('RGB', rgba_image.size, (255, 255, 255))
    rgb_image.paste(rgba_image, mask=rgba_image.split()[-1])
    return rgb_image


def encode_image_to_base64(image_path_or_pil: "Union[str, Image.Image]") -> str:
    """Return the base64 encoding (as a UTF-8 ``str``) of an image.

    * ``str`` input is treated as a file path; the file bytes are encoded as-is.
    * A PIL image is re-encoded in memory without resizing:
      - RGBA whose source format is PNG, WEBP or unknown -> lossless PNG;
      - any other RGBA -> alpha-composited onto white, JPEG quality 95;
      - source format PNG -> PNG;
      - everything else -> JPEG quality 95. Modes JPEG cannot store (e.g. P
        from GIFs, LA, I) are converted first: to RGB, or flattened onto white
        when the image carries transparency. Previously these raised
        "cannot write mode ... as JPEG".

    Raises:
        Whatever the file read or the Pillow encoder raises (logged, then re-raised).
    """
    try:
        if isinstance(image_path_or_pil, str):
            with open(image_path_or_pil, "rb") as image_file:
                return base64.b64encode(image_file.read()).decode('utf-8')

        image = image_path_or_pil
        buffer = io.BytesIO()
        original_format = getattr(image, 'format', None)
        if image.mode == 'RGBA':
            if original_format in ('PNG', 'WEBP') or original_format is None:
                image.save(buffer, format="PNG")  # lossless, keeps transparency
            else:
                _flatten_onto_white(image).save(buffer, format="JPEG", quality=95)
        elif original_format == 'PNG':
            image.save(buffer, format="PNG")  # lossless
        else:
            if image.mode not in _JPEG_WRITABLE_MODES:
                has_alpha = image.mode in ('LA', 'PA') or 'transparency' in image.info
                if has_alpha:
                    image = _flatten_onto_white(image.convert('RGBA'))
                else:
                    image = image.convert('RGB')
            image.save(buffer, format="JPEG", quality=95)
        return base64.b64encode(buffer.getvalue()).decode('utf-8')
    except Exception as e:
        logger.error(f"Error encoding image to base64: {str(e)}")
        raise
def create_multimodal_content(text: str, images: List[Union[str, Image.Image]] = None) -> List[Dict]: | |
"""创建多模态内容格式,兼容OpenAI API""" | |
content = [] | |
# 添加文本内容 | |
if text and text.strip(): | |
content.append({ | |
"type": "text", | |
"text": text | |
}) | |
# 添加图像内容 | |
if images: | |
for i, image in enumerate(images): | |
try: | |
image_base64 = encode_image_to_base64(image) | |
content.append({ | |
"type": "image_url", | |
"image_url": { | |
"url": f"data:image/jpeg;base64,{image_base64}" | |
} | |
}) | |
logger.info(f"Added image {i+1}/{len(images)} to multimodal content") | |
except Exception as e: | |
logger.error(f"Failed to process image {i+1}: {str(e)}") | |
continue | |
return content if content else [{"type": "text", "text": text or ""}] | |
def convert_images_to_base64_list(images: List[Union[str, Image.Image]]) -> List[str]:
    """Encode every image to base64 so it can be stored in the chat history.

    Images that fail to encode are logged and left out, so the result can be
    shorter than ``images``.
    """
    encoded = []
    for position, img in enumerate(images, start=1):
        try:
            encoded.append(encode_image_to_base64(img))
            logger.info(f"Converted image {position}/{len(images)} to base64 for storage")
        except Exception as err:
            logger.error(f"Failed to convert image {position} to base64: {str(err)}")
    return encoded
def restore_images_from_base64_list(base64_images: List[str]) -> List[Image.Image]:
    """Decode base64 strings back into PIL images.

    Entries that cannot be decoded or opened are logged and skipped.
    """
    restored = []
    for position, encoded in enumerate(base64_images, start=1):
        try:
            raw_bytes = base64.b64decode(encoded)
            restored.append(Image.open(io.BytesIO(raw_bytes)))
            logger.info(f"Restored image {position}/{len(base64_images)} from base64")
        except Exception as err:
            logger.error(f"Failed to restore image {position} from base64: {str(err)}")
    return restored
class DeltaObject:
    """Mimics the ``delta`` of an OpenAI streaming choice (``content`` and ``role``)."""

    def __init__(self, data: dict):
        # Treat a missing delta or an explicit JSON null like an empty delta.
        data = data or {}
        self.content = data.get('content')
        self.role = data.get('role')


class ChoiceObject:
    """Mimics one ``choices[i]`` entry of an OpenAI streaming chunk."""

    def __init__(self, choice_data: dict):
        choice_data = choice_data or {}
        # ``.get('delta', {})`` alone would still yield None for "delta": null.
        self.delta = DeltaObject(choice_data.get('delta') or {})
        self.finish_reason = choice_data.get('finish_reason')
        self.index = choice_data.get('index', 0)


class ChunkObject:
    """Mimics an OpenAI ``chat.completion.chunk`` built from a decoded SSE payload."""

    def __init__(self, chunk_data: dict):
        # "choices": null must not crash iteration; treat it as no choices.
        choices_data = chunk_data.get('choices') or []
        self.choices = [ChoiceObject(choice) for choice in choices_data]
        self.id = chunk_data.get('id', '')
        self.object = chunk_data.get('object', 'chat.completion.chunk')
        self.created = chunk_data.get('created', 0)
        self.model = chunk_data.get('model', '')
class ProxyClient:
    """Client for the OpenAI-compatible proxy service(s), with round-robin load balancing.

    Chat requests, and health checks without an explicit URL, pick their target
    with ``get_next_proxy_url()``. One ``requests.Session`` is reused for all calls.
    """

    def __init__(self, timeout: int = 30):
        """
        Args:
            timeout: Timeout in seconds passed to every HTTP call.
        """
        self.timeout = timeout
        self.session = requests.Session()

    def chat_completions_create(self, model: str, messages: list, stream: bool = True, **kwargs):
        """Send a chat-completions request to the next proxy (round-robin).

        Args:
            model: Model name placed in the payload.
            messages: OpenAI-style message list.
            stream: When True, return a generator of ``ChunkObject`` parsed from
                the SSE stream; otherwise return the decoded JSON body.
            **kwargs: Extra payload fields (temperature, top_p, ...). They are
                merged last, so they can override the fields above.

        Raises:
            Exception: Wrapping any ``requests`` error (connection failure,
                timeout, non-2xx status); the original error is chained.
        """
        base_url = get_next_proxy_url().rstrip('/')
        if base_url.endswith('/v1'):
            url = f"{base_url}/chat/completions"
        else:
            url = f"{base_url}/v1/chat/completions"
        payload = {
            "model": model,
            "messages": messages,
            "stream": stream,
            **kwargs
        }
        response = None
        try:
            response = self.session.post(
                url,
                json=payload,
                stream=stream,
                timeout=self.timeout,
                headers={"Content-Type": "application/json"}
            )
            response.raise_for_status()
            if stream:
                return self._parse_stream_response(response)
            return response.json()
        except requests.exceptions.RequestException as e:
            # A streamed response holds its pooled connection until it is closed.
            if response is not None:
                response.close()
            logger.error(f"Request failed: {str(e)}")
            raise Exception(f"Failed to connect to proxy server: {str(e)}") from e

    def _parse_stream_response(self, response) -> "Generator[ChunkObject, None, None]":
        """Yield a ``ChunkObject`` for every SSE ``data:`` event until ``[DONE]``.

        Accepts both ``data: {...}`` and ``data:{...}`` (the SSE format makes the
        space optional). Comment lines and other SSE fields are ignored, as are
        payloads that are not JSON objects (logged as warnings). The response is
        always closed: on normal end, on error, and when the consumer abandons
        the generator early.

        Raises:
            Exception: When a chunk carries an ``error`` key; its ``detail`` (or
                the error itself) is included in the message.
        """
        try:
            response.encoding = 'utf-8'
            for line in response.iter_lines(decode_unicode=True):
                if not line:
                    continue
                line = line.strip()
                if not line.startswith('data:'):
                    continue
                data = line[len('data:'):].strip()
                if data == '[DONE]':
                    break
                try:
                    chunk_data = json.loads(data)
                except json.JSONDecodeError as e:
                    logger.warning(f"Failed to parse JSON: {data}, error: {str(e)}")
                    continue
                if not isinstance(chunk_data, dict):
                    logger.warning(f"Ignoring non-object stream payload: {data}")
                    continue
                if 'error' in chunk_data:
                    raise Exception(f"Stream error: {chunk_data.get('detail', chunk_data['error'])}")
                yield ChunkObject(chunk_data)
        except Exception as e:
            logger.error(f"Error parsing stream response: {str(e)}")
            raise
        finally:
            response.close()

    def health_check(self, specific_url: str = None) -> dict:
        """Query ``<base_url>/health``; never raises.

        Args:
            specific_url: Proxy to check. When omitted, the next round-robin
                proxy is used (this advances the shared round-robin index).

        Returns:
            On a 2xx answer: the JSON object returned by the proxy, or
            ``{"status": "healthy"}`` when the body is empty or not a JSON
            object (e.g. a plain-text "OK"). On any failure:
            ``{"status": "unhealthy", "error": ...}``. ``proxy_url`` is always set.
        """
        if specific_url:
            base_url = specific_url.rstrip('/')
        else:
            base_url = get_next_proxy_url().rstrip('/')
        try:
            url = f"{base_url}/health"
            response = self.session.get(url, timeout=self.timeout)
            response.raise_for_status()
            result = None
            if response.text.strip():
                try:
                    result = response.json()
                except ValueError:  # requests' JSONDecodeError subclasses ValueError
                    result = None
                if not isinstance(result, dict):
                    logger.info(f"Health check for {base_url} returned a non-JSON-object body with a success status, assuming healthy")
                    result = {"status": "healthy"}
            else:
                logger.info(f"Health check for {base_url} returned empty response with 200 status, assuming healthy")
                result = {"status": "healthy"}
            result["proxy_url"] = base_url
            return result
        except Exception as e:
            logger.error(f"Health check failed for {base_url}: {str(e)}")
            return {"status": "unhealthy", "error": str(e), "proxy_url": base_url}

    def health_check_all(self) -> dict:
        """Check every configured proxy and summarise the results.

        Returns:
            ``overall_status`` ("healthy" if at least one proxy is healthy),
            ``healthy_proxies``, ``total_proxies`` and per-proxy
            ``proxy_details`` keyed ``proxy_<index>``.
        """
        results = {}
        for i, url in enumerate(get_all_proxy_urls()):
            results[f"proxy_{i}"] = self.health_check(specific_url=url)
        healthy_count = sum(1 for result in results.values() if result.get("status") == "healthy")
        return {
            "overall_status": "healthy" if healthy_count > 0 else "unhealthy",
            "healthy_proxies": healthy_count,
            "total_proxies": len(results),
            "proxy_details": results
        }
# Shared proxy client used by chat_with_retry.
client = ProxyClient(PROXY_TIMEOUT)

# Log the effective proxy configuration once at start-up.
logger.info("=== 代理服务器配置 ===")
logger.info(f"配置的代理服务器数量: {len(PROXY_BASE_URLS)}")
for number, proxy_url in enumerate(PROXY_BASE_URLS, start=1):
    logger.info(f"代理 {number}: {proxy_url}")
logger.info(f"代理超时时间: {PROXY_TIMEOUT}秒")
logger.info(f"最大重试次数: {MAX_RETRIES}")
if len(PROXY_BASE_URLS) > 1:
    logger.info("负载均衡模式: 轮询 (Round-robin)")
logger.info("========================")
def chat_with_retry(history_messages, max_retries=MAX_RETRIES):
    """Start a streamed chat completion, retrying and failing over between proxies.

    Only establishing the request is retried. The returned generator streams
    lazily, so errors raised while iterating it are not retried here.

    During the first ``min(len(PROXY_BASE_URLS), max_retries)`` attempts (and
    only with more than one proxy) the proxy about to be used is health-checked
    first; an unhealthy one is skipped *and* its round-robin slot consumed, so
    the next attempt really targets a different proxy. Failed requests in that
    phase move on immediately; later failures back off for up to 4 seconds.

    Args:
        history_messages: OpenAI-style message list (see ``format_history``).
        max_retries: Total number of attempts, including health-check skips.

    Returns:
        The chunk generator returned by ``client.chat_completions_create``.

    Raises:
        Exception: The last error (request failure or failed health check)
            once every attempt is used up, or a generic error when no attempt
            could be made (``max_retries < 1``). Never returns None.
    """
    last_exception = None
    failed_proxies = set()  # proxies that failed a health check or a request
    total_proxies = len(PROXY_BASE_URLS)
    max_proxy_attempts = min(total_proxies, max_retries)
    for attempt in range(max_retries):
        try:
            logger.info(f"Chat attempt {attempt + 1}/{max_retries}")
            if attempt < max_proxy_attempts and total_proxies > 1:
                # The proxy that get_next_proxy_url() will hand out next.
                current_proxy_to_check = PROXY_BASE_URLS[current_proxy_index % total_proxies]
                if current_proxy_to_check not in failed_proxies:
                    health = client.health_check(specific_url=current_proxy_to_check)
                    if health.get("status") != "healthy":
                        logger.warning(f"Proxy {current_proxy_to_check} is unhealthy, marking as failed")
                        failed_proxies.add(current_proxy_to_check)
                        last_exception = Exception(
                            f"Proxy {current_proxy_to_check} failed its health check: "
                            f"{health.get('error', health.get('status'))}"
                        )
                        # Consume this proxy's round-robin slot; without this the
                        # next attempt would be sent to the same unhealthy proxy.
                        get_next_proxy_url()
                        continue
            return client.chat_completions_create(
                model=model,
                messages=history_messages,
                stream=True,
                temperature=0.7,
                top_p=0.9,
                max_tokens=50000
            )
        except Exception as e:
            last_exception = e
            if total_proxies:
                # Best-effort attribution: chat_completions_create advanced the
                # round-robin index, so the proxy it used is the previous one
                # (unless a concurrent request moved the index meanwhile).
                current_failed_proxy = PROXY_BASE_URLS[(current_proxy_index - 1) % total_proxies]
                failed_proxies.add(current_failed_proxy)
                logger.warning(f"Attempt {attempt + 1} failed with proxy {current_failed_proxy}: {str(e)}")
            else:
                logger.warning(f"Attempt {attempt + 1} failed: {str(e)}")
            if attempt < max_retries - 1:
                if len(failed_proxies) < total_proxies and attempt < max_proxy_attempts:
                    # Untried proxies remain: fail over without waiting.
                    logger.info("Trying next proxy immediately...")
                else:
                    # Exponential backoff capped at 4 s (1 s during the failover phase).
                    wait_time = min(2 ** (attempt - max_proxy_attempts + 1), 4) if attempt >= max_proxy_attempts else 1
                    logger.info(f"Retrying in {wait_time} seconds...")
                    time.sleep(wait_time)
    logger.error(f"All {max_retries} attempts failed across {len(failed_proxies)}/{total_proxies} proxies")
    if last_exception is None:
        last_exception = Exception(f"No chat attempt was made (max_retries={max_retries})")
    raise last_exception
# True when running inside ModelScope Studio (selects Chinese UI strings / locale).
is_modelscope_studio = os.environ.get('MODELSCOPE_ENVIRONMENT') == 'studio'


def get_text(text: str, cn_text: str):
    """Pick the UI string for the current environment: ``cn_text`` on ModelScope Studio, else ``text``."""
    return cn_text if is_modelscope_studio else text


# Logo image expected next to this source file.
logo_img = os.path.join(os.path.dirname(__file__), "rednote_hilab.png")
# Chatbot items shown for an empty conversation.
DEFAULT_CONVERSATIONS_HISTORY = [dict(role="placeholder")]
DEFAULT_LOCALE = 'en_US' if not is_modelscope_studio else 'zh_CN'
DEFAULT_THEME = dict(token=dict(colorPrimary="#6A57FF"))
def _with_current_text(multimodal, text):
    """Return the stored multimodal parts with the text part rebuilt from ``text``.

    Editing a message only updates ``content["text"]``, so the text inside the
    stored ``multimodal`` payload can be stale. Image parts keep their order;
    the layout mirrors ``create_multimodal_content`` (text part first, omitted
    when blank), so an unedited message comes out unchanged. ``text=None``
    returns the payload as-is.
    """
    if text is None:
        return multimodal
    image_parts = [part for part in multimodal if part.get("type") != "text"]
    if text.strip():
        return [{"type": "text", "text": text}] + image_parts
    return image_parts or [{"type": "text", "text": text}]


def format_history(history):
    """Convert the UI chat history into an OpenAI-style message list.

    The result starts with an empty system message. User entries may be plain
    strings or dicts:
      * ``multimodal`` present -> stored parts are sent, with the text part
        synced to the (possibly edited) ``text``;
      * ``images_base64`` present -> images are decoded and the multimodal
        content is rebuilt (text only when the list is empty);
      * otherwise -> ``text`` (or ``str(content)``) is sent as plain text.
    Assistant entries are sent as ``<think>{thinking}</think>{content}`` when
    ``meta.thinking_content`` is set, else as the content ("" for None).
    Entries with other roles (e.g. placeholders) are skipped.
    """
    messages = [{
        "role": "system",
        "content": "",
    }]
    for item in history:
        if item["role"] == "user":
            content = item["content"]
            if isinstance(content, dict):
                if "multimodal" in content:
                    messages.append({
                        "role": "user",
                        "content": _with_current_text(content["multimodal"], content.get("text"))
                    })
                    logger.info(f"Added multimodal message with {content.get('images_count', 0)} images to context")
                elif "images_base64" in content:
                    text = content.get("text", "")
                    images_base64 = content.get("images_base64", [])
                    if images_base64:
                        restored_images = restore_images_from_base64_list(images_base64)
                        multimodal_content = create_multimodal_content(text, restored_images)
                        messages.append({
                            "role": "user",
                            "content": multimodal_content
                        })
                        logger.info(f"Restored and added multimodal message with {len(restored_images)} images to context")
                    else:
                        messages.append({"role": "user", "content": text})
                else:
                    text_content = content.get("text", str(content))
                    messages.append({"role": "user", "content": text_content})
            else:
                messages.append({"role": "user", "content": content})
        elif item["role"] == "assistant":
            assistant_content = item["content"] or ""
            thinking_content = item.get("meta", {}).get("thinking_content", "")
            if thinking_content:
                # Rebuild the raw model output; both parts already carry their
                # own newlines, so nothing is inserted between them.
                full_content = f"<think>{thinking_content}</think>{assistant_content}"
            else:
                full_content = assistant_content
            messages.append({"role": "assistant", "content": full_content})
    return messages
class Gradio_Events: | |
def _submit(state_value):
    """Stream the assistant reply for the active conversation into the UI.

    Generator used by ``submit`` and ``regenerate_message``. It formats the
    active history with ``format_history`` (before the placeholder below is
    added), appends an assistant placeholder, calls ``chat_with_retry`` and
    yields ``{chatbot: ..., state: ...}`` update dicts while chunks arrive.

    Text between ``<think>`` and ``</think>`` is stored in
    ``meta["thinking_content"]``; everything else goes into ``content``. To cut
    UI traffic, updates are batched: a yield happens once ``BUFFER_INTERVAL``
    seconds have passed or ``BUFFER_CHUNKS`` chunks arrived, and the state is
    only sent on every ``STATE_UPDATE_INTERVAL``-th batched yield, on think-tag
    transitions and on the final yield.

    On any exception the placeholder is replaced by an error message, one last
    update is yielded and the exception is re-raised.
    """
    history = state_value["conversations_history"][
        state_value["conversation_id"]]
    # Build the model input from the history before adding the placeholder.
    history_messages = format_history(history)
    history.append({
        "role": "assistant",
        "content": "",
        "key": str(uuid.uuid4()),
        "meta": {
            "reason_content": "",
            "thinking_content": "",  # text between <think> and </think>
            "is_thinking": False,  # True while inside <think>...</think>
            "thinking_done": False  # set once </think> has been seen
        },
        "loading": True,
    })
    yield {
        chatbot: gr.update(items=history),
        state: gr.update(value=state_value),
    }
    try:
        response = chat_with_retry(history_messages)
        thought_done = False
        in_thinking = False
        # Text not yet classified: before/inside the think block, or answer text
        # after it (reset whenever a think tag is consumed).
        accumulated_content = ""
        # Buffering state
        buffer_content = ""  # content received since the last yield
        last_yield_time = time.time()
        chunk_count = 0
        state_update_count = 0  # number of batched yields so far
        BUFFER_INTERVAL = 0.5  # seconds between batched yields (fewer network round-trips)
        BUFFER_CHUNKS = 5  # force a yield every 5 chunks (latency vs. throughput)
        STATE_UPDATE_INTERVAL = 3  # send the state only on every 3rd batched yield
        # The state holds the whole history, so its updates are throttled.
        for chunk in response:
            # Guard against chunks without choices.
            # NOTE(review): a chunk with an empty ``choices`` list (e.g. a
            # trailing usage-only chunk, if the server sends one) aborts the
            # whole reply via this ValueError — confirm the proxy never sends it.
            if chunk.choices and len(chunk.choices) > 0:
                content = chunk.choices[0].delta.content
            else:
                content = None
                raise ValueError('Content is None')
            history[-1]["loading"] = False
            print(content, end='')  # debug echo of the raw stream
            # NOTE(review): indentation was reconstructed from a
            # whitespace-stripped copy; everything below is assumed to sit inside
            # ``if content:`` — confirm against the original file.
            if content:
                accumulated_content += content
                buffer_content += content  # add to the pending buffer
                chunk_count += 1
                # Entering thinking mode. The check runs on the accumulated text
                # so a tag split across chunks is still detected.
                if "<think>" in accumulated_content and not in_thinking:
                    in_thinking = True
                    history[-1]["meta"]["is_thinking"] = True
                    # Keep any non-blank text that preceded the tag as content.
                    before_think = accumulated_content.split("<think>")[0]
                    if before_think.strip():
                        history[-1]["content"] = before_think
                    # From now on accumulate only what follows the opening tag.
                    think_parts = accumulated_content.split("<think>", 1)
                    if len(think_parts) > 1:
                        accumulated_content = think_parts[1]
                    else:
                        accumulated_content = ""
                    # Mode changes are always pushed immediately, state included.
                    yield {
                        chatbot: gr.update(items=history),
                        state: gr.update(value=state_value)
                    }
                    buffer_content = ""  # reset the buffer
                    last_yield_time = time.time()
                    chunk_count = 0
                    continue
                # Leaving thinking mode.
                if "</think>" in accumulated_content and in_thinking:
                    in_thinking = False
                    history[-1]["meta"]["is_thinking"] = False
                    history[-1]["meta"]["thinking_done"] = True
                    history[-1]["meta"]["just_finished_thinking"] = True  # cleared once answer text is shown
                    # Split into the thinking text and whatever follows the closing tag.
                    think_parts = accumulated_content.split("</think>", 1)
                    thinking_content = think_parts[0]
                    history[-1]["meta"]["thinking_content"] = thinking_content
                    # Append (not overwrite) non-blank text after the closing tag.
                    if len(think_parts) > 1:
                        after_think_content = think_parts[1]
                        if after_think_content.strip():
                            current_content = history[-1]["content"] or ""
                            history[-1]["content"] = current_content + after_think_content
                    accumulated_content = ""  # reset the accumulated text
                    # Mode changes are always pushed immediately, state included.
                    yield {
                        chatbot: gr.update(items=history),
                        state: gr.update(value=state_value)
                    }
                    buffer_content = ""  # reset the buffer
                    last_yield_time = time.time()
                    chunk_count = 0
                    continue
                # Batching: yield once enough time has passed or enough chunks arrived.
                current_time = time.time()
                should_yield = False
                if (current_time - last_yield_time >= BUFFER_INTERVAL) or (chunk_count >= BUFFER_CHUNKS):
                    should_yield = True
                if in_thinking:
                    # Inside the think block the thinking text is updated on every
                    # chunk (a closing tag was already handled above).
                    if "</think>" not in accumulated_content:
                        history[-1]["meta"]["thinking_content"] = accumulated_content
                    if should_yield:
                        state_update_count += 1
                        # Send the state only on every STATE_UPDATE_INTERVAL-th yield.
                        should_update_state = (state_update_count % STATE_UPDATE_INTERVAL == 0)
                        yield {
                            chatbot: gr.update(items=history),
                            state: gr.update(value=state_value) if should_update_state else gr.skip()
                        }
                        buffer_content = ""
                        last_yield_time = current_time
                        chunk_count = 0
                else:
                    # Outside the think block buffered text is appended to content.
                    if not thought_done:
                        thought_done = True
                        if not history[-1]["content"]:  # normalise falsy content to ""
                            history[-1]["content"] = ""
                    if should_yield:
                        # Flush the buffered text into the visible content.
                        history[-1]["content"] += buffer_content
                        # Answer text is now being shown; drop the "just finished thinking" flag.
                        if history[-1]["meta"].get("just_finished_thinking"):
                            history[-1]["meta"]["just_finished_thinking"] = False
                        state_update_count += 1
                        # Send the state only on every STATE_UPDATE_INTERVAL-th yield.
                        should_update_state = (state_update_count % STATE_UPDATE_INTERVAL == 0)
                        yield {
                            chatbot: gr.update(items=history),
                            state: gr.update(value=state_value) if should_update_state else gr.skip()
                        }
                        # Reset the buffer.
                        buffer_content = ""
                        last_yield_time = current_time
                        chunk_count = 0
                    else:
                        # Not flushing yet: the text stays in buffer_content and is
                        # appended on the next batched yield or after the loop.
                        pass
        # After the stream ends, flush whatever is still buffered.
        if buffer_content:
            if in_thinking:
                # The think block never closed: keep everything as thinking text.
                history[-1]["meta"]["thinking_content"] = accumulated_content
            else:
                if not history[-1]["content"]:
                    history[-1]["content"] = ""
                history[-1]["content"] += buffer_content
                if history[-1]["meta"].get("just_finished_thinking"):
                    history[-1]["meta"]["just_finished_thinking"] = False
        # Final update: always includes the state.
        yield {
            chatbot: gr.update(items=history),
            state: gr.update(value=state_value)  # the final update always sends the state
        }
        history[-1]["meta"]["end"] = True
        print("Answer: ", history[-1]["content"])
        # Optionally persist the exchange.
        if save_conversation:
            # Find the most recent user message before the assistant reply.
            user_message = None
            for i in range(len(history) - 2, -1, -1):
                if history[i]["role"] == "user":
                    user_message = history[i]
                    break
            if user_message:
                save_conversation_log(
                    history_messages=history_messages,  # exactly what was sent to the model
                    assistant_content=history[-1]["content"],
                    metadata={
                        "model": model,
                        "proxy_base_urls": PROXY_BASE_URLS,
                        "conversation_id": state_value["conversation_id"],
                        "thinking_content": history[-1]["meta"].get("thinking_content", ""),
                        "has_thinking": bool(history[-1]["meta"].get("thinking_content"))
                    }
                )
    except Exception as e:
        # Replace the partial reply with an error message, push it, then re-raise.
        history[-1]["loading"] = False
        history[-1]["meta"]["end"] = True
        history[-1]["meta"]["error"] = True
        history[-1]["content"] = "Failed to respond, please try again."
        yield {
            chatbot: gr.update(items=history),
            state: gr.update(value=state_value)
        }
        print('Error: ',e)
        raise e
def submit(sender_value, state_value):
    """Append the user's message (with any uploaded images) and stream a reply.

    Generator used as a Gradio event handler. Creates a conversation on first
    use, stores the message (plain text, or a dict with ``text``,
    ``images_count``, ``images_base64`` and ``multimodal`` when images were
    uploaded), then yields the preprocess update, every update from
    ``Gradio_Events._submit`` and finally the postprocess update. The
    postprocess update is also sent after an error (which is then re-raised),
    but not when the generator is closed by its consumer: a closing generator
    must not yield.
    """
    if not state_value["conversation_id"]:
        random_id = str(uuid.uuid4())
        history = []
        state_value["conversation_id"] = random_id
        state_value["conversations_history"][random_id] = history
        # Label the conversation with its first message; fall back to a generic
        # label when that message is not text or is blank (e.g. images only).
        if isinstance(sender_value, str) and sender_value.strip():
            label = sender_value
        else:
            label = "New Chat"
        state_value["conversations"].append({
            "label": label,
            "key": random_id
        })
    history = state_value["conversations_history"][
        state_value["conversation_id"]]
    uploaded_images = state_value.get("uploaded_images", [])
    if uploaded_images:
        multimodal_content = create_multimodal_content(sender_value, uploaded_images)
        # Base64 copies keep the images restorable after a browser-state round-trip.
        images_base64 = convert_images_to_base64_list(uploaded_images)
        message_content = {
            "text": sender_value,
            "images_count": len(uploaded_images),
            "images_base64": images_base64,
            "multimodal": multimodal_content  # ready-to-send API content
        }
        logger.info(f"Saving message with {len(uploaded_images)} images to history")
        # The images belong to this message only.
        state_value["uploaded_images"] = []
        state_value["image_file_paths"] = []
    else:
        message_content = sender_value
    history.append({
        "role": "user",
        "meta": {},
        "key": str(uuid.uuid4()),
        "content": message_content
    })
    yield Gradio_Events.preprocess_submit()(state_value)
    closed = False
    try:
        # ``yield from`` also closes the inner generator if this one is closed.
        yield from Gradio_Events._submit(state_value)
    except GeneratorExit:
        closed = True
        raise
    finally:
        if not closed:
            # Unlock the UI (also clears the image upload widget state).
            yield Gradio_Events.postprocess_submit(state_value)
def regenerate_message(state_value, e: gr.EventData):
    """Regenerate the reply identified by the clicked message's key.

    Truncates the active conversation just before that message and streams a
    new reply with the same pre/post-processing as ``submit`` (the input box is
    left untouched). An unknown key yields a single no-op update and stops;
    previously execution fell through with index -1, silently dropping the
    last message and regenerating anyway.
    """
    conversation_key = e._data["component"]["conversationKey"]
    history = state_value["conversations_history"][
        state_value["conversation_id"]]
    index = -1
    for i, conversation in enumerate(history):
        if conversation["key"] == conversation_key:
            index = i
            break
    if index == -1:
        yield gr.skip()
        return
    history = history[:index]
    state_value["conversations_history"][
        state_value["conversation_id"]] = history
    yield {
        chatbot: gr.update(items=history),
        state: gr.update(value=state_value)
    }
    yield Gradio_Events.preprocess_submit(clear_input=False)(state_value)
    closed = False
    try:
        # ``yield from`` also closes the inner generator if this one is closed.
        yield from Gradio_Events._submit(state_value)
    except GeneratorExit:
        # A closing generator must not yield; skip the postprocess update.
        closed = True
        raise
    finally:
        if not closed:
            yield Gradio_Events.postprocess_submit(state_value)
def preprocess_submit(clear_input=True):
    """Build the handler that locks the UI while a reply is generated.

    The returned handler disables every message and every conversation except
    the active one, disables the add/clear/delete controls, clears the image
    upload widget and its indicator, and shows the stop button. The sender box
    is put in loading state and, when ``clear_input`` is true, emptied.
    """
    def preprocess_submit_handler(state_value):
        active_id = state_value["conversation_id"]
        history = state_value["conversations_history"][active_id]
        for message in history:
            if "meta" in message:
                message["meta"]["disabled"] = True
        if clear_input:
            sender_update = gr.update(value=None, loading=True)
        else:
            sender_update = gr.update(loading=True)
        conversation_items = [
            {**entry, "disabled": entry["key"] != active_id}
            for entry in state_value["conversations"]
        ]
        return {
            sender: sender_update,
            conversations: gr.update(active_key=active_id, items=conversation_items),
            add_conversation_btn: gr.update(disabled=True),
            clear_btn: gr.update(disabled=True),
            conversation_delete_menu_item: gr.update(disabled=True),
            chatbot: gr.update(items=history),
            state: gr.update(value=state_value),
            image_upload: gr.update(value=None),
            green_image_indicator: gr.update(count=0, elem_style=dict(display="block")),
            trash_button: gr.update(elem_style=dict(display="none")),
            stop_btn: gr.update(visible=True),
        }
    return preprocess_submit_handler
def postprocess_submit(state_value):
    """Unlock the UI after a reply finished, failed or was cancelled."""
    active_history = state_value["conversations_history"][state_value["conversation_id"]]
    for message in filter(lambda item: "meta" in item, active_history):
        message["meta"]["disabled"] = False
    updates = {
        sender: gr.update(loading=False),
        conversation_delete_menu_item: gr.update(disabled=False),
        clear_btn: gr.update(disabled=False),
        conversations: gr.update(items=state_value["conversations"]),
        add_conversation_btn: gr.update(disabled=False),
        chatbot: gr.update(items=active_history),
        state: gr.update(value=state_value),
        stop_btn: gr.update(visible=False),
    }
    return updates
def cancel(state_value):
    """Mark the in-flight assistant reply as cancelled and unlock the UI.

    The stop button is shown by the preprocess update, which is yielded before
    ``_submit`` appends its assistant placeholder. The last entry may therefore
    still be the user's message, or the history may be empty. Only an assistant
    entry is marked; previously an empty history raised IndexError and a user
    message was flagged as loading/ended/cancelled.
    """
    history = state_value["conversations_history"][
        state_value["conversation_id"]]
    last_message = history[-1] if history else None
    if last_message is not None and last_message.get("role") == "assistant":
        last_message["loading"] = False
        last_message["meta"]["end"] = True
        last_message["meta"]["canceled"] = True
    return Gradio_Events.postprocess_submit(state_value)
def delete_message(state_value, e: gr.EventData):
    """Remove the clicked message from the active conversation.

    Returns updates for the chatbot (placeholder items when nothing is left)
    and the state.
    """
    target_key = e._data["component"]["conversationKey"]
    active_id = state_value["conversation_id"]
    remaining = list(filter(lambda message: message["key"] != target_key,
                            state_value["conversations_history"][active_id]))
    state_value["conversations_history"][active_id] = remaining
    shown_items = remaining if remaining else DEFAULT_CONVERSATIONS_HISTORY
    return gr.update(items=shown_items), gr.update(value=state_value)
def edit_message(state_value, e: gr.EventData):
    """Load the clicked message's text into the edit box and remember its index.

    Dict contents (messages with images) expose their ``text`` field. Returns
    ``gr.skip()`` when the key is not in the active conversation.
    """
    target_key = e._data["component"]["conversationKey"]
    history = state_value["conversations_history"][state_value["conversation_id"]]
    index = next((pos for pos, message in enumerate(history) if message["key"] == target_key), -1)
    if index == -1:
        return gr.skip()
    state_value["editing_message_index"] = index
    stored = history[index]["content"]
    text = stored if isinstance(stored, str) else stored["text"]
    return gr.update(value=text), gr.update(value=state_value)
def confirm_edit_message(edit_textarea_value, state_value):
    """Write the edited text back into the message chosen by ``edit_message``.

    For dict contents (messages with images) only ``content["text"]`` is
    replaced; other stored fields such as ``multimodal`` are left untouched.
    """
    active_history = state_value["conversations_history"][state_value["conversation_id"]]
    target = active_history[state_value["editing_message_index"]]
    if not isinstance(target["content"], str):
        target["content"]["text"] = edit_textarea_value
    else:
        target["content"] = edit_textarea_value
    return gr.update(items=active_history), gr.update(value=state_value)
def select_suggestion(sender_value, e: gr.EventData):
    """Drop the input's last character (presumably the suggestion trigger) and append the chosen suggestion."""
    chosen = e._data["payload"][0]
    trimmed = sender_value[:-1]
    return gr.update(value=trimmed + chosen)
def new_chat(state_value):
    """Switch to a fresh, not-yet-created conversation.

    Also discards pending uploaded images so they cannot leak into the new
    chat. Returns ``gr.skip()`` when no conversation is active.
    """
    if state_value["conversation_id"]:
        state_value["conversation_id"] = ""
        state_value["uploaded_images"] = []
        state_value["image_file_paths"] = []
        return (
            gr.update(active_key=state_value["conversation_id"]),
            gr.update(items=DEFAULT_CONVERSATIONS_HISTORY),
            gr.update(value=state_value),
        )
    return gr.skip()
def select_conversation(state_value, e: gr.EventData):
    """Activate the clicked conversation and show its history.

    Skips when it is already active or unknown. Pending uploaded images are
    discarded so they do not leak into another conversation.
    """
    active_key = e._data["payload"][0]
    if state_value["conversation_id"] == active_key:
        return gr.skip()
    if active_key not in state_value["conversations_history"]:
        return gr.skip()
    state_value["conversation_id"] = active_key
    state_value["uploaded_images"] = []
    state_value["image_file_paths"] = []
    selected_history = state_value["conversations_history"][active_key]
    return (
        gr.update(active_key=active_key),
        gr.update(items=selected_history),
        gr.update(value=state_value),
    )
def click_conversation_menu(state_value, e: gr.EventData):
    """Handle the conversation context menu; only "delete" is supported.

    Deleting the active conversation also resets the chat view and discards
    pending uploaded images. Any other operation returns ``gr.skip()``.
    """
    payload = e._data["payload"]
    conversation_id = payload[0]["key"]
    operation = payload[1]["key"]
    if operation != "delete":
        return gr.skip()
    del state_value["conversations_history"][conversation_id]
    state_value["conversations"] = [
        entry for entry in state_value["conversations"]
        if entry["key"] != conversation_id
    ]
    if state_value["conversation_id"] != conversation_id:
        return (
            gr.update(items=state_value["conversations"]),
            gr.skip(),
            gr.update(value=state_value),
        )
    state_value["conversation_id"] = ""
    state_value["uploaded_images"] = []
    state_value["image_file_paths"] = []
    return (
        gr.update(items=state_value["conversations"],
                  active_key=state_value["conversation_id"]),
        gr.update(items=DEFAULT_CONVERSATIONS_HISTORY),
        gr.update(value=state_value),
    )
def clear_conversation_history(state_value):
    """Empty the active conversation (and pending uploaded images); skip when none is active."""
    active_id = state_value["conversation_id"]
    if not active_id:
        return gr.skip()
    state_value["conversations_history"][active_id] = []
    state_value["uploaded_images"] = []
    state_value["image_file_paths"] = []
    return (
        gr.update(items=DEFAULT_CONVERSATIONS_HISTORY),
        gr.update(value=state_value),
    )
def close_modal():
    """Return an update that closes the bound modal."""
    is_open = False
    return gr.update(open=is_open)
def open_modal():
    """Return an update that opens the bound modal."""
    is_open = True
    return gr.update(open=is_open)
def update_browser_state(state_value):
    """Snapshot the conversation list and histories for browser-side persistence."""
    snapshot = {
        "conversations": state_value["conversations"],
        "conversations_history": state_value["conversations_history"],
    }
    return gr.update(value=snapshot)
def apply_browser_state(browser_state_value, state_value):
    """Restore conversations saved in the browser into the session state."""
    for field in ("conversations", "conversations_history"):
        state_value[field] = browser_state_value[field]
    return (
        gr.update(items=browser_state_value["conversations"]),
        gr.update(value=state_value),
    )
def handle_image_upload(files, state_value):
    """Load uploaded images (picker, drag & drop or paste) into the session state.

    ``files`` may be one item or a list. Each item may be a dict with a
    ``name``/``path`` key, a path string, or a file-like object exposing
    ``name`` and ``read``. Images keep their original size. The new images
    REPLACE previously uploaded ones. Only images that actually load are
    recorded, so ``uploaded_images`` and ``image_file_paths`` stay aligned.

    Returns:
        A 3-tuple of updates: (state, image-count indicator, trash button).
        With no input or on an unexpected error the indicator shows 0 and the
        trash button is hidden.
    """
    logger.info(f"handle_image_upload called with files: {files}, type: {type(files)}")
    if not files:
        logger.info("No files provided, resetting to default state")
        return (
            gr.update(value=state_value),
            gr.update(count=0, elem_style=dict(display="block")),  # indicator shows 0
            gr.update(elem_style=dict(display="none")),  # hide the trash button
        )
    logger.info("Upload in progress...")
    try:
        uploaded_images = []
        image_file_paths = []
        if not isinstance(files, list):
            files = [files]
        for i, file_info in enumerate(files):
            logger.info(f"Processing file {i}: {file_info}, type: {type(file_info)}")
            if isinstance(file_info, dict):
                # Gradio upload payload.
                file_path = file_info.get('name') or file_info.get('path')
                logger.info(f"Extracted path from dict: {file_path}")
            elif isinstance(file_info, str):
                file_path = file_info
                logger.info(f"Direct file path: {file_path}")
            elif hasattr(file_info, 'name') and hasattr(file_info, 'read'):
                # File-like object (drag & drop / paste): its name is the path on disk.
                file_path = file_info.name
                logger.info(f"File object detected: {file_path}")
            else:
                logger.warning(f"Unknown file format: {type(file_info)}")
                continue
            if not file_path:
                continue
            try:
                image = Image.open(file_path)
                # Read the pixel data now: corrupt or truncated files fail here
                # instead of later, and Pillow can release the file handle of
                # single-frame images instead of keeping it open lazily.
                image.load()
            except Exception as img_error:
                logger.error(f"Error processing image {file_path}: {str(img_error)}")
                continue
            image_file_paths.append(file_path)
            logger.info(f"Added to image_file_paths: {file_path}")
            logger.info(f"Loaded image with size: {image.size} (原始尺寸,不进行缩放)")
            uploaded_images.append(image)
        # Replace rather than append (avoids accumulating images across uploads).
        state_value["uploaded_images"] = uploaded_images
        state_value["image_file_paths"] = image_file_paths
        logger.info(f"Successfully uploaded {len(uploaded_images)} images via drag/paste/upload")
        return (
            gr.update(value=state_value),
            gr.update(count=len(uploaded_images), elem_style=dict(display="block")),  # image count
            gr.update(elem_style=dict(display="block")),  # show the trash button
        )
    except Exception as e:
        logger.exception(f"Error handling image upload: {str(e)}")
        return (
            gr.update(value=state_value),
            gr.update(count=0, elem_style=dict(display="block")),  # indicator shows 0
            gr.update(elem_style=dict(display="none")),  # hide the trash button
        )
def clear_images(state_value):
    """Discard every pending uploaded image for the current session.

    Empties ``uploaded_images`` and ``image_file_paths`` in ``state_value``.

    Returns:
        Tuple of updates for (state, image-count badge, trash button, upload
        component): the badge shows 0, the trash button is hidden and the
        upload component's value is reset to None.
    """
    for key in ("uploaded_images", "image_file_paths"):
        state_value[key] = []
    logger.info("Cleared all uploaded images")
    badge_update = gr.update(count=0, elem_style={"display": "block"})
    trash_update = gr.update(elem_style={"display": "none"})
    upload_update = gr.update(value=None)
    return gr.update(value=state_value), badge_update, trash_update, upload_update
# Custom stylesheet passed to gr.Blocks(css=...) below. It removes Gradio's
# default container padding, sizes the #chatbot layout, styles the collapsible
# "thinking" panel (.thinking-content) and image thumbnails, and defines the
# .drop-zone/.drag-over and .paste-hint classes that drag_and_paste_js toggles
# for drag-and-drop / paste feedback.
# NOTE(review): `.image-gallery img` is declared twice (hover/transition rules
# and, further down, the load animation); both rule sets apply.
css = """
.gradio-container {
padding: 0 !important;
}
.gradio-container > main.fillable {
padding: 0 !important;
}
#chatbot {
height: calc(100vh - 21px - 16px);
}
#chatbot .chatbot-conversations {
height: 100%;
background-color: var(--ms-gr-ant-color-bg-layout);
}
#chatbot .chatbot-conversations .chatbot-conversations-list {
padding-left: 0;
padding-right: 0;
}
#chatbot .chatbot-chat {
padding: 32px;
height: 100%;
}
@media (max-width: 768px) {
#chatbot .chatbot-chat {
padding: 0;
}
}
#chatbot .chatbot-chat .chatbot-chat-messages {
flex: 1;
}
#chatbot .chatbot-chat .chatbot-chat-messages .chatbot-chat-message .chatbot-chat-message-footer {
visibility: hidden;
opacity: 0;
transition: opacity 0.2s;
}
#chatbot .chatbot-chat .chatbot-chat-message:last-child .chatbot-chat-message-footer {
visibility: visible;
opacity: 1;
}
#chatbot .chatbot-chat .chatbot-chat-message:hover .chatbot-chat-message-footer {
visibility: visible;
opacity: 1;
}
/* Thinking区域样式 */
.thinking-content .ant-collapse {
background: linear-gradient(135deg, #f8f9fc 0%, #f2f5f8 100%);
border: 1px solid #e1e8ed;
border-radius: 8px;
margin-bottom: 12px;
}
.thinking-content .ant-collapse > .ant-collapse-item > .ant-collapse-header {
padding: 8px 12px;
font-size: 13px;
color: #5a6c7d;
font-weight: 500;
}
.thinking-content .ant-collapse-content > .ant-collapse-content-box {
padding: 12px;
background: #fafbfc;
border-radius: 0 0 6px 6px;
font-size: 13px;
color: #667788;
line-height: 1.5;
white-space: pre-wrap;
font-family: 'SF Mono', 'Monaco', 'Inconsolata', 'Roboto Mono', monospace;
}
.thinking-content .ant-collapse-content-box .markdown-body {
font-size: 13px;
line-height: 1.5;
color: #667788;
}
.thinking-content .ant-collapse-content-box pre {
background: #f6f8fa;
padding: 8px;
border-radius: 4px;
overflow: auto;
}
.thinking-content .ant-collapse-content-box h1,
.thinking-content .ant-collapse-content-box h2,
.thinking-content .ant-collapse-content-box h3,
.thinking-content .ant-collapse-content-box h4,
.thinking-content .ant-collapse-content-box h5,
.thinking-content .ant-collapse-content-box h6 {
margin-top: 16px;
margin-bottom: 8px;
font-weight: 600;
}
.thinking-content .ant-collapse-content-box ul,
.thinking-content .ant-collapse-content-box ol {
margin: 8px 0;
padding-left: 20px;
}
.thinking-content .ant-collapse-content-box li {
margin: 4px 0;
}
.thinking-content .ant-collapse-content-box code {
background: #f1f3f4;
padding: 2px 4px;
border-radius: 3px;
font-size: 85%;
}
/* 图片预览和展示样式 */
.image-preview-container {
background: #fafafa;
border: 1px solid #d9d9d9;
border-radius: 8px;
padding: 12px;
margin-bottom: 12px;
}
.image-gallery img {
transition: all 0.2s ease;
border-radius: 4px;
}
.image-gallery img:hover {
transform: scale(1.05);
box-shadow: 0 4px 12px rgba(0, 0, 0, 0.15);
z-index: 10;
position: relative;
}
.image-thumbnail {
position: relative;
display: inline-block;
margin: 4px;
border-radius: 6px;
overflow: hidden;
box-shadow: 0 2px 8px rgba(0, 0, 0, 0.1);
transition: all 0.2s ease;
}
.image-thumbnail:hover {
box-shadow: 0 4px 16px rgba(0, 0, 0, 0.2);
transform: translateY(-2px);
}
.image-upload-preview {
background: linear-gradient(135deg, #f5f7fa 0%, #c3cfe2 100%);
border: 2px dashed #d9d9d9;
border-radius: 8px;
padding: 16px;
margin-bottom: 16px;
text-align: center;
transition: all 0.3s ease;
}
.image-upload-preview.has-images {
border-style: solid;
border-color: #6A57FF;
background: linear-gradient(135deg, #f6f9fc 0%, #f0f4f8 100%);
}
/* 拖拽区域样式 */
.drop-zone {
position: relative;
transition: all 0.3s ease;
}
.drop-zone.drag-over {
background: linear-gradient(135deg, #e6f7ff 0%, #d6f7ff 100%);
border: 2px dashed #1890ff;
border-radius: 8px;
}
.drop-zone.drag-over::before {
content: "释放以上传图片";
position: absolute;
top: 50%;
left: 50%;
transform: translate(-50%, -50%);
background: rgba(24, 144, 255, 0.9);
color: white;
padding: 12px 24px;
border-radius: 6px;
font-size: 16px;
font-weight: 500;
z-index: 1000;
pointer-events: none;
}
/* 响应式图片展示 */
@media (max-width: 768px) {
.image-gallery img {
width: 80px !important;
height: 60px !important;
}
.image-thumbnail {
width: 80px;
height: 60px;
}
}
/* 图片加载动画 */
@keyframes imageLoad {
from { opacity: 0; transform: scale(0.8); }
to { opacity: 1; transform: scale(1); }
}
.image-gallery img {
animation: imageLoad 0.3s ease;
}
/* 粘贴提示样式 */
.paste-hint {
position: fixed;
top: 20px;
right: 20px;
background: rgba(24, 144, 255, 0.9);
color: white;
padding: 8px 16px;
border-radius: 6px;
font-size: 14px;
z-index: 1001;
opacity: 0;
transform: translateY(-10px);
transition: all 0.3s ease;
}
.paste-hint.show {
opacity: 1;
transform: translateY(0);
}
"""
# Script injected into the page <head> via gr.Blocks(head=...). It adds
# drag-and-drop and clipboard-paste image upload by forwarding the files to the
# antd Upload component's hidden <input type="file" accept="image/*">, and shows
# a small toast (the .paste-hint element) as feedback. A single pending hide
# timer is kept so a newer toast is never dismissed early by an older timer,
# and paste success is reported only when the files reached the upload input.
drag_and_paste_js = """
<script>
(function() {
    let isInitialized = false;
    function initializeDragAndPaste() {
        if (isInitialized) return;
        isInitialized = true;
        console.log('Initializing drag and paste functionality...');
        // 创建粘贴提示元素
        const pasteHint = document.createElement('div');
        pasteHint.className = 'paste-hint';
        pasteHint.textContent = '检测到剪贴板中的图片,按 Ctrl+V 粘贴';
        document.body.appendChild(pasteHint);
        // Single pending hide timer: each new message cancels the previous one.
        let hintTimer = null;
        // Time of the last image paste that was handed to the upload input.
        let lastImagePasteAt = 0;
        function showHint(text, duration) {
            pasteHint.textContent = text;
            pasteHint.classList.add('show');
            if (hintTimer !== null) {
                clearTimeout(hintTimer);
            }
            hintTimer = setTimeout(() => {
                pasteHint.classList.remove('show');
                hintTimer = null;
            }, duration);
        }
        // 获取聊天容器作为拖拽区域
        const chatContainer = document.querySelector('#chatbot .chatbot-chat') || document.body;
        // 防止默认的拖拽行为
        ['dragenter', 'dragover', 'dragleave', 'drop'].forEach(eventName => {
            chatContainer.addEventListener(eventName, preventDefaults, false);
            document.body.addEventListener(eventName, preventDefaults, false);
        });
        function preventDefaults(e) {
            e.preventDefault();
            e.stopPropagation();
        }
        // 拖拽进入
        ['dragenter', 'dragover'].forEach(eventName => {
            chatContainer.addEventListener(eventName, highlight, false);
        });
        // 拖拽离开
        ['dragleave', 'drop'].forEach(eventName => {
            chatContainer.addEventListener(eventName, unhighlight, false);
        });
        function highlight(e) {
            if (e.dataTransfer.types.includes('Files')) {
                chatContainer.classList.add('drop-zone', 'drag-over');
            }
        }
        function unhighlight(e) {
            chatContainer.classList.remove('drop-zone', 'drag-over');
        }
        // 处理文件放置
        chatContainer.addEventListener('drop', handleDrop, false);
        function handleDrop(e) {
            const dt = e.dataTransfer;
            const files = dt.files;
            if (files.length > 0) {
                handleFileUpload(Array.from(files));
            }
        }
        // 处理粘贴事件
        document.addEventListener('paste', handlePaste, false);
        function handlePaste(e) {
            const items = e.clipboardData.items;
            const imageFiles = [];
            for (let i = 0; i < items.length; i++) {
                if (items[i].type.indexOf('image') === 0) {
                    const file = items[i].getAsFile();
                    if (file) {
                        imageFiles.push(file);
                    }
                }
            }
            if (imageFiles.length > 0) {
                e.preventDefault();
                // Report success only when the files actually reached the uploader.
                if (handleFileUpload(imageFiles)) {
                    lastImagePasteAt = Date.now();
                    showPasteSuccess();
                }
            }
        }
        // 显示粘贴成功提示
        function showPasteSuccess() {
            showHint('图片粘贴成功!', 2000);
        }
        // 处理文件上传
        // Returns true when the files were handed to the upload input, else false.
        function handleFileUpload(files) {
            console.log('Processing files:', files);
            // 过滤只保留图片文件
            const imageFiles = files.filter(file => file.type.startsWith('image/'));
            if (imageFiles.length === 0) {
                console.log('No image files found');
                return false;
            }
            // 查找上传组件
            const uploadInput = document.querySelector('input[type="file"][accept*="image"]');
            if (!uploadInput) {
                console.error('Upload input not found');
                return false;
            }
            try {
                // 创建新的文件列表
                const dt = new DataTransfer();
                imageFiles.forEach(file => {
                    dt.items.add(file);
                });
                // 设置文件到上传组件
                uploadInput.files = dt.files;
                // 触发 change 事件
                const changeEvent = new Event('change', { bubbles: true });
                uploadInput.dispatchEvent(changeEvent);
                console.log(`Successfully uploaded ${imageFiles.length} image(s)`);
                // 显示成功提示
                showUploadSuccess(imageFiles.length);
                return true;
            } catch (error) {
                console.error('Error uploading files:', error);
                return false;
            }
        }
        // 显示上传成功提示
        function showUploadSuccess(count) {
            showHint(`成功上传 ${count} 张图片!`, 2000);
        }
        // 监听剪贴板变化(可选功能)
        document.addEventListener('keydown', function(e) {
            if (e.ctrlKey && e.key === 'v') {
                // 检查是否聚焦在输入框上
                const activeElement = document.activeElement;
                const isInInputArea = activeElement && (
                    activeElement.tagName === 'TEXTAREA' ||
                    activeElement.tagName === 'INPUT' ||
                    activeElement.contentEditable === 'true'
                );
                if (isInInputArea) {
                    // 短暂显示提示
                    setTimeout(() => {
                        if (navigator.clipboard && navigator.clipboard.read) {
                            navigator.clipboard.read().then(items => {
                                const hasImage = items.some(item =>
                                    item.types.some(type => type.startsWith('image/'))
                                );
                                // Skip when the paste handler just forwarded an image: this
                                // probe resolves after it and would otherwise replace the
                                // success message with a stale hint.
                                if (hasImage && Date.now() - lastImagePasteAt > 1000) {
                                    showHint('检测到图片,正在处理...', 1500);
                                }
                            }).catch(() => {
                                // 忽略权限错误
                            });
                        }
                    }, 100);
                }
            }
        });
        console.log('Drag and paste functionality initialized successfully');
    }
    // 初始化函数
    function init() {
        if (document.readyState === 'loading') {
            document.addEventListener('DOMContentLoaded', initializeDragAndPaste);
        } else {
            initializeDragAndPaste();
        }
    }
    // 如果Gradio还没有完全加载,等待一下
    if (window.gradio && window.gradio.mount) {
        init();
    } else {
        // 等待Gradio加载
        setTimeout(init, 1000);
    }
    // 也监听window load事件作为备选
    window.addEventListener('load', initializeDragAndPaste);
})();
</script>
"""
def logo():
    """Render the sidebar header: the 24px logo image followed by the model name."""
    title_style = {"fontSize": 24, "padding": 8, "margin": 0}
    with antd.Typography.Title(level=1, elem_style=title_style):
        with antd.Flex(align="center", gap="small", justify="center"):
            antd.Image(
                logo_img,
                preview=False,
                alt="logo",
                width=24,
                height=24,
            )
            ms.Span("dots.vlm1.inst")
# Build the chat UI and its event wiring:
# - a left sidebar: logo, new-conversation button and conversation list;
# - a right column: the message list with placeholder/user/assistant bubble
#   templates, and the prompt sender with image-upload controls;
# - an edit-message modal.
# Events are then connected to Gradio_Events handlers. `css` and
# `drag_and_paste_js` are injected through the gr.Blocks arguments.
with gr.Blocks(css=css, fill_width=True, head=drag_and_paste_js) as demo:
    # Per-session state read and written by the Gradio_Events handlers.
    state = gr.State({
        "conversations_history": {},
        "conversations": [],
        "conversation_id": "",
        "editing_message_index": -1,
        "uploaded_images": [],  # images currently attached to the next message
        "image_file_paths": [],  # file paths of those images, used for previews
    })
    with ms.Application(), antdx.XProvider(
            theme=DEFAULT_THEME, locale=DEFAULT_LOCALE), ms.AutoLoading():
        with antd.Row(gutter=[20, 20], wrap=False, elem_id="chatbot"):
            # Left Column: span=0 below the md breakpoint, fixed 260px column from md up
            with antd.Col(md=dict(flex="0 0 260px", span=24, order=0),
                          span=0,
                          order=1,
                          elem_classes="chatbot-conversations",
                          elem_style=dict(
                              maxWidth="260px",
                              minWidth="260px",
                              overflow="hidden")):
                with antd.Flex(vertical=True,
                               gap="small",
                               elem_style=dict(height="100%", width="100%", minWidth="0")):
                    # Logo
                    logo()
                    # New Conversation Button
                    with antd.Button(value=None,
                                     color="primary",
                                     variant="filled",
                                     block=True, elem_style=dict(maxWidth="100%")) as add_conversation_btn:
                        ms.Text(get_text("New Conversation", "新建对话"))
                        with ms.Slot("icon"):
                            antd.Icon("PlusOutlined")
                    # Conversations List (each item has a "Delete" context-menu entry)
                    with antdx.Conversations(
                        elem_classes="chatbot-conversations-list",
                        elem_style=dict(
                            width="100%",
                            minWidth="0",
                            overflow="hidden",
                            flex="1"
                        )
                    ) as conversations:
                        with ms.Slot('menu.items'):
                            with antd.Menu.Item(
                                label="Delete", key="delete", danger=True
                            ) as conversation_delete_menu_item:
                                with ms.Slot("icon"):
                                    antd.Icon("DeleteOutlined")
            # Right Column
            with antd.Col(flex=1, elem_style=dict(height="100%")):
                with antd.Flex(vertical=True,
                               gap="middle",
                               elem_classes="chatbot-chat"):
                    # Chatbot
                    with antdx.Bubble.List(
                            items=DEFAULT_CONVERSATIONS_HISTORY,
                            elem_classes="chatbot-chat-messages") as chatbot:
                        # Define Chatbot Roles
                        with ms.Slot("roles"):
                            # Placeholder Role: welcome card with the logo
                            with antdx.Bubble.List.Role(
                                    role="placeholder",
                                    styles=dict(content=dict(width="100%")),
                                    variant="borderless"):
                                with ms.Slot("messageRender"):
                                    with antd.Space(
                                            direction="vertical",
                                            size=16,
                                            elem_style=dict(width="100%")):
                                        with antdx.Welcome(
                                                styles=dict(icon=dict(
                                                    flexShrink=0)),
                                                variant="borderless",
                                                title=get_text(
                                                    "Hello, I'm dots.",
                                                    "你好,我是 dots."),
                                                description=get_text(
                                                    "",
                                                    ""),
                                        ):
                                            with ms.Slot("icon"):
                                                antd.Image(logo_img,
                                                           preview=False)
                            # User Role
                            with antdx.Bubble.List.Role(
                                    role="user",
                                    placement="end",
                                    elem_classes="chatbot-chat-message",
                                    class_names=dict(
                                        footer="chatbot-chat-message-footer"),
                                    styles=dict(content=dict(
                                        maxWidth="100%",
                                        overflow='auto',
                                    ))):
                                # Maps a user message (plain string, or an object with `text`
                                # plus `images_base64`/`images`) to props for the image-count
                                # caption, the thumbnail gallery and the text markdown.
                                # NOTE(review): for an object without images, `content.text || content`
                                # falls back to the whole object when `text` is empty — confirm
                                # that shape cannot reach this renderer.
                                with ms.Slot(
                                        "messageRender",
                                        params_mapping="""(content) => {
// 检查多种图片存储格式
let imageCount = 0;
let textContent = '';
let imagesBase64 = [];
if (typeof content === 'object') {
// 新格式:检查 images_count
if (content.images_count && content.images_count > 0) {
imageCount = content.images_count;
textContent = content.text || '';
imagesBase64 = content.images_base64 || [];
}
// 旧格式:检查 images 数组
else if (content.images && content.images.length > 0) {
imageCount = content.images.length;
textContent = content.text || '';
imagesBase64 = content.images || [];
}
// 纯文本格式
else {
textContent = content.text || content;
}
} else {
// 字符串格式
textContent = content;
}
if (imageCount > 0 && imagesBase64.length > 0) {
const imageHtml = imagesBase64.map((base64, index) =>
`<img src="data:image/jpeg;base64,${base64}"
style="width: 120px; height: 90px; object-fit: cover; border-radius: 6px; margin: 4px; box-shadow: 0 2px 8px rgba(0, 0, 0, 0.1); cursor: pointer;"
alt="Image ${index + 1}" />`
).join('');
return {
image_info: {
style: { marginBottom: '8px', fontSize: '13px', color: '#666' },
value: `📷 包含 ${imageCount} 张图片`
},
text_content: {
value: textContent
},
image_gallery: {
value: `<div style="display: flex; flex-wrap: wrap; gap: 8px; margin-bottom: 12px;">${imageHtml}</div>`
}
};
}
return {
text_content: {
value: textContent
},
image_info: { style: { display: 'none' } },
image_gallery: { value: '' }
};
}"""):
                                    # Image-count caption
                                    antd.Typography.Text(as_item="image_info", type="secondary")
                                    # Image gallery: a Markdown component rendering the HTML thumbnails
                                    ms.Markdown(as_item="image_gallery")
                                    # Message text
                                    ms.Markdown(as_item="text_content")
                                # Footer actions for user messages: copy, edit, delete
                                with ms.Slot("footer",
                                             params_mapping="""(bubble) => {
return {
copy_btn: {
copyable: { text: typeof bubble.content === 'string' ? bubble.content : bubble.content?.text, tooltips: false },
},
edit_btn: { conversationKey: bubble.key, disabled: bubble.meta.disabled },
delete_btn: { conversationKey: bubble.key, disabled: bubble.meta.disabled },
};
}"""):
                                    with antd.Typography.Text(
                                            copyable=dict(tooltips=False),
                                            as_item="copy_btn"):
                                        with ms.Slot("copyable.icon"):
                                            with antd.Button(value=None,
                                                             size="small",
                                                             color="default",
                                                             variant="text"):
                                                with ms.Slot("icon"):
                                                    antd.Icon("CopyOutlined")
                                            with antd.Button(value=None,
                                                             size="small",
                                                             color="default",
                                                             variant="text"):
                                                with ms.Slot("icon"):
                                                    antd.Icon("CheckOutlined")
                                    with antd.Button(value=None,
                                                     size="small",
                                                     color="default",
                                                     variant="text",
                                                     as_item="edit_btn"
                                                     ) as user_edit_btn:
                                        with ms.Slot("icon"):
                                            antd.Icon("EditOutlined")
                                    with antd.Popconfirm(
                                            title="Delete the message",
                                            description=
                                            "Are you sure to delete this message?",
                                            ok_button_props=dict(danger=True),
                                            as_item="delete_btn"
                                    ) as user_delete_popconfirm:
                                        with antd.Button(value=None,
                                                         size="small",
                                                         color="default",
                                                         variant="text",
                                                         as_item="delete_btn"):
                                            with ms.Slot("icon"):
                                                antd.Icon("DeleteOutlined")
                            # Chatbot Role
                            with antdx.Bubble.List.Role(
                                    role="assistant",
                                    placement="start",
                                    elem_classes="chatbot-chat-message",
                                    class_names=dict(
                                        footer="chatbot-chat-message-footer"),
                                    styles=dict(content=dict(
                                        maxWidth="100%", overflow='auto'))):
                                with ms.Slot("avatar"):
                                    antd.Avatar(
                                        os.path.join(os.path.dirname(__file__),
                                                     "rednote_hilab.png"))
                                # Maps bubble.meta thinking flags to the collapsible "thinking"
                                # panel: while thinking the panel defaults to open; right after
                                # thinking finishes (with non-empty answer content) a fresh
                                # timestamped key plus an empty active_key collapses it; later
                                # renders use a stable key so the user controls it.
                                with ms.Slot(
                                        "messageRender",
                                        params_mapping="""(content, bubble) => {
const has_error = bubble?.meta?.error
const thinking_content = bubble?.meta?.thinking_content || ""
const is_thinking = bubble?.meta?.is_thinking || false
const thinking_done = bubble?.meta?.thinking_done || false
const just_finished_thinking = bubble?.meta?.just_finished_thinking || false
// 改进的自动折叠逻辑:
// 1. 刚完成thinking且有回答内容时自动折叠
// 2. 考虑用户交互状态,避免频繁重置
const shouldAutoCollapse = just_finished_thinking && content && content.trim().length > 0
// 动态生成唯一key以触发组件重新渲染,但保持用户控制能力
let collapseKey = 'thinking'
let collapseProps = {}
if (shouldAutoCollapse) {
// 刚完成thinking且有内容时,设置为折叠状态
// 使用时间戳确保key的唯一性,触发折叠
collapseKey = 'thinking-auto-collapsed-' + Date.now()
collapseProps.active_key = [] // 强制折叠
} else if (thinking_done) {
// thinking完成但用户可能已经手动展开,使用稳定key
collapseKey = 'thinking-user-controlled'
// 不设置active_key,让用户控制
} else {
// thinking进行中,默认展开
collapseKey = 'thinking-active'
collapseProps.default_active_key = ['1']
}
return {
thinking_collapse_props: Object.assign({
key: collapseKey,
style: {
display: (thinking_content || is_thinking) ? 'block' : 'none',
marginBottom: thinking_content || is_thinking ? '12px' : '0'
}
}, collapseProps),
thinking_label: is_thinking ? '🤔 正在思考...' : '🤔 思考过程',
thinking_markdown: {
value: thinking_content || '思考中...'
},
answer: {
value: content
},
canceled: bubble.meta?.canceled ? undefined : { style: { display: 'none' } }
}
}"""):
                                    # Thinking panel: collapsible, rendered as Markdown
                                    with antd.Collapse(
                                        size='small',
                                        ghost=True,
                                        elem_classes="thinking-content",
                                        as_item="thinking_collapse_props"  # all props are driven by the mapping above
                                    ):
                                        with antd.Collapse.Item(
                                            as_item="thinking_label",  # dynamic label used as the header
                                            key='1',
                                            force_render=True  # render the content even while collapsed
                                        ):
                                            ms.Markdown(as_item="thinking_markdown")  # dynamic value, rendered as Markdown
                                    # Answer content
                                    ms.Markdown(
                                        as_item="answer",
                                        elem_classes="answer-content")
                                    # Shown only when bubble.meta.canceled is set (see mapping)
                                    antd.Divider(as_item="canceled")
                                    antd.Typography.Text(get_text(
                                        "Chat completion paused.", "聊天已暂停。"),
                                        as_item="canceled",
                                        type="warning")
                                # Footer actions are shown only once bubble.meta.end is set;
                                # otherwise the whole actions container is hidden.
                                with ms.Slot("footer",
                                             params_mapping="""(bubble) => {
if (bubble?.meta?.end) {
return {
copy_btn: {
copyable: { text: bubble.content, tooltips: false },
},
regenerate_btn: { conversationKey: bubble.key, disabled: bubble.meta.disabled },
delete_btn: { conversationKey: bubble.key, disabled: bubble.meta.disabled },
edit_btn: { conversationKey: bubble.key, disabled: bubble.meta.disabled },
};
}
return { actions_container: { style: { display: 'none' } } };
}"""):
                                    with ms.Div(as_item="actions_container"):
                                        with antd.Typography.Text(
                                                copyable=dict(tooltips=False),
                                                as_item="copy_btn"):
                                            with ms.Slot("copyable.icon"):
                                                with antd.Button(
                                                        value=None,
                                                        size="small",
                                                        color="default",
                                                        variant="text"):
                                                    with ms.Slot("icon"):
                                                        antd.Icon(
                                                            "CopyOutlined")
                                                with antd.Button(
                                                        value=None,
                                                        size="small",
                                                        color="default",
                                                        variant="text"):
                                                    with ms.Slot("icon"):
                                                        antd.Icon(
                                                            "CheckOutlined")
                                        with antd.Popconfirm(
                                                title=get_text(
                                                    "Regenerate the message",
                                                    "重新生成消息"),
                                                description=get_text(
                                                    "Regenerate the message will also delete all subsequent messages.",
                                                    "重新生成消息将会删除所有的后续消息。"),
                                                ok_button_props=dict(
                                                    danger=True),
                                                as_item="regenerate_btn"
                                        ) as chatbot_regenerate_popconfirm:
                                            with antd.Button(
                                                    value=None,
                                                    size="small",
                                                    color="default",
                                                    variant="text",
                                                    as_item="regenerate_btn",
                                            ):
                                                with ms.Slot("icon"):
                                                    antd.Icon("SyncOutlined")
                                        with antd.Button(value=None,
                                                         size="small",
                                                         color="default",
                                                         variant="text",
                                                         as_item="edit_btn"
                                                         ) as chatbot_edit_btn:
                                            with ms.Slot("icon"):
                                                antd.Icon("EditOutlined")
                                        with antd.Popconfirm(
                                                title=get_text("Delete the message", "删除消息"),
                                                description=get_text(
                                                    "Are you sure to delete this message?",
                                                    "确定要删除这条消息吗?"),
                                                ok_button_props=dict(
                                                    danger=True),
                                                as_item="delete_btn"
                                        ) as chatbot_delete_popconfirm:
                                            with antd.Button(
                                                    value=None,
                                                    size="small",
                                                    color="default",
                                                    variant="text",
                                                    as_item="delete_btn"):
                                                with ms.Slot("icon"):
                                                    antd.Icon("DeleteOutlined")
                    # Sender (wrapped in Suggestion: typing '/' opens the suggestion list)
                    with antdx.Suggestion(
                            # onKeyDown Handler in Javascript
                            should_trigger="""(e, { onTrigger, onKeyDown }) => {
switch(e.key) {
case '/':
onTrigger()
break
case 'ArrowRight':
case 'ArrowLeft':
case 'ArrowUp':
case 'ArrowDown':
break;
default:
onTrigger(false)
}
onKeyDown(e)
}""") as suggestion:
                        with ms.Slot("children"):
                            with antdx.Sender(placeholder=get_text(
                                    "Enter Prompt (Drag & Drop or Ctrl+V to paste images)",
                                    "输入内容(可拖拽图片或 Ctrl+V 粘贴图片)"), ) as sender:
                                with ms.Slot("actions"):
                                    # Stop-generation button; visibility is toggled by the
                                    # submit/regenerate/cancel handlers (stop_btn is in their outputs)
                                    with antd.Button(
                                            type="text",
                                            size="large",
                                            visible=False,  # initially hidden
                                            elem_style=dict(
                                                color="#ff4d4f",  # red
                                                border="none",
                                                background="transparent"
                                            )
                                    ) as stop_btn:
                                        with ms.Slot("icon"):
                                            antd.Icon("PauseCircleOutlined")
                                with ms.Slot("prefix"):
                                    # Image upload button with an image-count badge
                                    with antd.Space(size="small"):
                                        with antd.Tooltip(title="点击上传图片", color="green"):
                                            with antd.Upload(
                                                    accept="image/*",
                                                    multiple=True,
                                                    show_upload_list=False,
                                                    elem_style=dict(display="inline-block")
                                            ) as image_upload:
                                                with antd.Badge(
                                                        count=0,  # shows 0 by default
                                                        size="small",
                                                        color="#52c41a",  # green
                                                        elem_style=dict(display="block")  # visible by default
                                                ) as green_image_indicator:
                                                    with antd.Button(
                                                            type="text",
                                                            size="large",
                                                            elem_style=dict(
                                                                color="#52c41a",  # green icon
                                                                border="none",
                                                                background="transparent"
                                                            )
                                                    ):
                                                        with ms.Slot("icon"):
                                                            antd.Icon("PictureOutlined")
                                        # Trash button: clears the pending uploaded images
                                        with antd.Tooltip(title="清除已上传的图片", color="red"):
                                            with antd.Button(
                                                    type="text",
                                                    size="large",
                                                    elem_style=dict(
                                                        color="#ff4d4f",  # red icon
                                                        border="none",
                                                        background="transparent",
                                                        display="none"  # hidden by default; shown once images are uploaded
                                                    )
                                            ) as trash_button:
                                                with ms.Slot("icon"):
                                                    antd.Icon("DeleteOutlined")
                                        # Clear button: clears the conversation history
                                        with antd.Tooltip(title=get_text(
                                                "Clear Conversation History",
                                                "清空对话历史"), ):
                                            with antd.Button(
                                                    value=None,
                                                    type="text") as clear_btn:
                                                with ms.Slot("icon"):
                                                    antd.Icon("ClearOutlined")
        # Modals
        with antd.Modal(title=get_text("Edit Message", "编辑消息"),
                        open=False,
                        centered=True,
                        width="60%") as edit_modal:
            edit_textarea = antd.Input.Textarea(auto_size=dict(minRows=2,
                                                               maxRows=6),
                                                elem_style=dict(width="100%"))
    # Events Handler
    if save_history:
        # Persist conversations in browser storage: every state change is mirrored
        # into browser_state, and the saved value is restored on page load.
        browser_state = gr.BrowserState(
            {
                "conversations_history": {},
                "conversations": [],
            },
            storage_key="dots_chatbot_storage")
        state.change(fn=Gradio_Events.update_browser_state,
                     inputs=[state],
                     outputs=[browser_state])
        demo.load(fn=Gradio_Events.apply_browser_state,
                  inputs=[browser_state, state],
                  outputs=[conversations, state])
    add_conversation_btn.click(fn=Gradio_Events.new_chat,
                               inputs=[state],
                               outputs=[conversations, chatbot, state])
    conversations.active_change(fn=Gradio_Events.select_conversation,
                                inputs=[state],
                                outputs=[conversations, chatbot, state])
    conversations.menu_click(fn=Gradio_Events.click_conversation_menu,
                             inputs=[state],
                             outputs=[conversations, chatbot, state])
    clear_btn.click(fn=Gradio_Events.clear_conversation_history,
                    inputs=[state],
                    outputs=[chatbot, state])
    suggestion.select(fn=Gradio_Events.select_suggestion,
                      inputs=[sender],
                      outputs=[sender])
    # Edit (user or assistant message): load the text into the modal, then open it.
    gr.on(triggers=[user_edit_btn.click, chatbot_edit_btn.click],
          fn=Gradio_Events.edit_message,
          inputs=[state],
          outputs=[edit_textarea, state]).then(fn=Gradio_Events.open_modal,
                                               outputs=[edit_modal])
    edit_modal.ok(fn=Gradio_Events.confirm_edit_message,
                  inputs=[edit_textarea, state],
                  outputs=[chatbot, state]).then(fn=Gradio_Events.close_modal,
                                                 outputs=[edit_modal])
    edit_modal.cancel(fn=Gradio_Events.close_modal, outputs=[edit_modal])
    gr.on(triggers=[
        chatbot_delete_popconfirm.confirm, user_delete_popconfirm.confirm
    ],
          fn=Gradio_Events.delete_message,
          inputs=[state],
          outputs=[chatbot, state])
    regenerating_event = chatbot_regenerate_popconfirm.confirm(
        fn=Gradio_Events.regenerate_message,
        inputs=[state],
        outputs=[sender, clear_btn, conversation_delete_menu_item, add_conversation_btn, conversations, chatbot, state,
                 image_upload, green_image_indicator, trash_button, stop_btn])
    # Image upload event
    image_upload.change(fn=Gradio_Events.handle_image_upload,
                        inputs=[image_upload, state],
                        outputs=[state, green_image_indicator, trash_button])
    # Clear-images event (trash button); also resets the upload component value
    trash_button.click(fn=Gradio_Events.clear_images,
                       inputs=[state],
                       outputs=[state, green_image_indicator, trash_button, image_upload])
    submit_event = sender.submit(fn=Gradio_Events.submit,
                                 inputs=[sender, state],
                                 outputs=[sender, clear_btn, conversation_delete_menu_item,
                                          add_conversation_btn, conversations, chatbot, state,
                                          image_upload, green_image_indicator, trash_button, stop_btn])
    # Stop button: the first listener cancels the running submit/regenerate
    # events; the second runs Gradio_Events.cancel to update the UI and state.
    stop_btn.click(fn=None, cancels=[submit_event, regenerating_event])
    stop_btn.click(fn=Gradio_Events.cancel,
                   inputs=[state],
                   outputs=[
                       sender, conversation_delete_menu_item, clear_btn,
                       conversations, add_conversation_btn, chatbot, state, stop_btn
                   ])
    # The sender's own cancel action is wired exactly like the stop button.
    sender.cancel(fn=None, cancels=[submit_event, regenerating_event])
    sender.cancel(fn=Gradio_Events.cancel,
                  inputs=[state],
                  outputs=[
                      sender, conversation_delete_menu_item, clear_btn,
                      conversations, add_conversation_btn, chatbot, state, stop_btn
                  ])
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="启动 Gradio Demo")
    # %(default)s keeps the help text in sync with the actual default port
    # (the previous hard-coded text claimed 7960 while the default is 7860).
    parser.add_argument("--port", type=int, default=7860, help="指定服务端口,默认为%(default)s")
    args = parser.parse_args()
    # Queue events with a default concurrency limit of 200 per listener, and
    # bind on all network interfaces ("0.0.0.0") at the requested port.
    demo.queue(default_concurrency_limit=200).launch(
        ssr_mode=False,
        max_threads=200,
        server_port=args.port,
        server_name="0.0.0.0"
    )