import gradio as gr
import base64
from openai import OpenAI
from PIL import Image
import io
import os
import time
import traceback
# API configuration.
BASE_URL = "https://api.stepfun.com/v1"
# SECURITY FIX: the original committed a live API key as the fallback value.
# Credentials must come from the environment only; the previously committed
# key should be rotated. With an empty default, the debug line below now
# correctly reports "Not set" when the variable is missing.
STEP_API_KEY = os.environ.get("STEP_API_KEY", "")
print(f"[DEBUG] Starting app with API key: {'Set' if STEP_API_KEY else 'Not set'}")
print(f"[DEBUG] Base URL: {BASE_URL}")
def image_to_base64(image_path):
    """Read an image file and return it as a base64-encoded JPEG string.

    Modes that JPEG cannot store are normalized first: images with an alpha
    channel are composited onto a white background, everything else is
    converted to plain RGB.

    Args:
        image_path: Path to the image file.

    Returns:
        The base64 string, or None if the file could not be read/converted.
    """
    try:
        with Image.open(image_path) as img:
            # BUGFIX: the original handled only 'RGBA'; palette ('P') images
            # and 'LA'/'CMYK' modes made img.save(format="JPEG") raise, so
            # those uploads were silently dropped (None returned).
            has_alpha = img.mode in ("RGBA", "LA") or (
                img.mode == "P" and "transparency" in img.info
            )
            if has_alpha:
                rgba = img.convert("RGBA")
                # Composite onto white using the alpha channel as the mask.
                rgb_img = Image.new("RGB", rgba.size, (255, 255, 255))
                rgb_img.paste(rgba, mask=rgba.split()[3])
                img = rgb_img
            elif img.mode != "RGB":
                img = img.convert("RGB")
            buffered = io.BytesIO()
            img.save(buffered, format="JPEG", quality=95)
            return base64.b64encode(buffered.getvalue()).decode('utf-8')
    except Exception as e:
        # Best-effort: callers treat None as "skip this image".
        print(f"[ERROR] Failed to convert image: {e}")
        return None
def user_submit(message, history, images):
    """Record the user's turn in the chat history and reset the inputs.

    Returns a 5-tuple of (cleared textbox value, updated history, cleared
    file input, saved message text, saved image paths). The saved values are
    consumed by bot_response in the chained event handler.
    """
    # Guard clause: nothing was entered, so leave everything as-is.
    if not message and not images:
        return message, history, images, "", None

    shown = message or ""
    if images:
        count = len(images) if isinstance(images, list) else 1
        tag = f"[{count} Image{'s' if count > 1 else ''}]"
        shown = f"{tag} {shown}" if shown else tag

    # Append the open user turn (assistant side None) without mutating the
    # caller's list, and stash the raw inputs for the bot step.
    return "", history + [[shown, None]], None, message, images
def bot_response(history, saved_message, saved_images, system_prompt, temperature, max_tokens, top_p):
    """Stream the assistant's reply for the previously saved user turn.

    Delegates to process_message when there is saved input; otherwise the
    history is yielded back once, unchanged.
    """
    if not (saved_message or saved_images):
        yield history
        return
    yield from process_message(
        saved_message,
        history,
        saved_images,
        system_prompt,
        temperature,
        max_tokens,
        top_p,
    )
def process_message(message, history, images, system_prompt, temperature, max_tokens, top_p):
"""处理消息并调用Step-3 API"""
print(f"[DEBUG] Processing message: {message[:100] if message else 'None'}")
print(f"[DEBUG] Has images: {images is not None}")
print(f"[DEBUG] Images type: {type(images)}")
if images:
print(f"[DEBUG] Images content: {images}")
if not message and not images:
history[-1][1] = "Please provide a message or image."
yield history
return
# 确保历史记录中有用户消息
if not history or history[-1][1] is not None:
display_message = message if message else ""
if images:
if isinstance(images, list):
num_images = len(images)
image_text = f"[{num_images} Image{'s' if num_images > 1 else ''}]"
else:
image_text = "[1 Image]"
display_message = f"{image_text} {display_message}" if display_message else image_text
history.append([display_message, None])
# 开始生成回复
history[-1][1] = "🤔 Thinking..."
yield history
try:
# 构建消息内容
content = []
# 处理图片(支持多图)
if images:
# 确保images是列表
image_list = images if isinstance(images, list) else [images]
for image_path in image_list:
if image_path:
print(f"[DEBUG] Processing image: {image_path}")
base64_image = image_to_base64(image_path)
if base64_image:
content.append({
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{base64_image}",
"detail": "high"
}
})
print(f"[DEBUG] Successfully added image to content")
else:
print(f"[ERROR] Failed to convert image: {image_path}")
# 添加文本消息
if message:
content.append({
"type": "text",
"text": message
})
print(f"[DEBUG] Added text to content: {message[:100]}")
if not content:
history[-1][1] = "❌ No valid input provided."
yield history
return
# 构造API消息
messages = []
# 添加系统提示(如果有)
if system_prompt:
messages.append({"role": "system", "content": system_prompt})
# 只使用用户消息内容,不包括之前的历史
messages.append({
"role": "user",
"content": content
})
print(f"[DEBUG] Prepared {len(messages)} messages for API")
print(f"[DEBUG] Message structure: {[{'role': m['role'], 'content_types': [c.get('type', 'text') for c in m['content']] if isinstance(m['content'], list) else 'text'} for m in messages]}")
# 处理代理问题 - 确保删除所有代理相关的环境变量
import os
import httpx
# 删除所有可能的代理环境变量
proxy_vars = ['HTTP_PROXY', 'HTTPS_PROXY', 'http_proxy', 'https_proxy',
'ALL_PROXY', 'all_proxy', 'NO_PROXY', 'no_proxy']
for var in proxy_vars:
if var in os.environ:
del os.environ[var]
print(f"[DEBUG] Removed {var} from environment")
# 尝试创建客户端
try:
# 方法1:直接创建
client = OpenAI(
api_key=STEP_API_KEY,
base_url=BASE_URL
)
print("[DEBUG] Client created successfully (method 1)")
except TypeError as e:
if 'proxies' in str(e):
print(f"[DEBUG] Method 1 failed with proxy error, trying method 2")
# 方法2:使用自定义HTTP客户端
http_client = httpx.Client(trust_env=False)
client = OpenAI(
api_key=STEP_API_KEY,
base_url=BASE_URL,
http_client=http_client
)
print("[DEBUG] Client created successfully (method 2)")
else:
raise e
print(f"[DEBUG] Making API call to {BASE_URL}")
# 调用API
response = client.chat.completions.create(
model="step-3",
messages=messages,
temperature=temperature,
max_tokens=max_tokens,
top_p=top_p,
stream=True
)
print("[DEBUG] API call successful, starting streaming")
# 流式输出
full_response = ""
in_reasoning = False
reasoning_content = ""
final_content = ""
for chunk in response:
if chunk.choices and chunk.choices[0].delta:
delta_content = chunk.choices[0].delta.content
if delta_content:
full_response += delta_content
# 检测reasoning标签
if '' in full_response and not in_reasoning:
in_reasoning = True
parts = full_response.split('')
if len(parts) > 1:
reasoning_content = parts[1]
if in_reasoning and '' in full_response:
in_reasoning = False
parts = full_response.split('')
if len(parts) > 1:
reasoning_content = parts[0].split('')[-1]
final_content = parts[1]
elif in_reasoning:
reasoning_content = full_response.split('')[-1]
elif '' in full_response:
parts = full_response.split('')
if len(parts) > 1:
final_content = parts[1]
else:
# 没有reasoning标签的情况
if '' not in full_response:
final_content = full_response
# 格式化显示
if reasoning_content and final_content:
display_text = f"💭 **Chain of Thought:**\n\n{reasoning_content.strip()}\n\n---\n\n📝 **Answer:**\n\n{final_content.strip()}"
elif reasoning_content:
display_text = f"💭 **Chain of Thought:**\n\n{reasoning_content.strip()}\n\n---\n\n📝 **Answer:**\n\n*Generating...*"
else:
display_text = full_response
history[-1][1] = display_text
yield history
# 最终格式化
if reasoning_content or final_content:
final_display = f"💭 **Chain of Thought:**\n\n{reasoning_content.strip()}\n\n---\n\n📝 **Answer:**\n\n{final_content.strip()}"
history[-1][1] = final_display
else:
history[-1][1] = full_response
print(f"[DEBUG] Streaming completed. Response length: {len(full_response)}")
yield history
except Exception as e:
error_msg = f"❌ Error: {str(e)}"
print(f"[ERROR] {error_msg}")
traceback.print_exc()
history[-1][1] = f"❌ Error: {str(e)}"
yield history
# Gradio interface styling.
# Custom CSS that pins the File-upload widget (and all of its children) to
# the same 52px height as the message textbox; Gradio's default file widget
# is much taller. The repeated !important height rules fight the component's
# own inline styles.
css = """
/* 强制设置File组件容器高度 */
.compact-file, .compact-file > * {
height: 52px !important;
max-height: 52px !important;
min-height: 52px !important;
}
/* 使用ID选择器确保优先级 */
#image-upload {
height: 52px !important;
max-height: 52px !important;
min-height: 52px !important;
}
#image-upload > div,
#image-upload .wrap,
#image-upload .block,
#image-upload .container {
height: 52px !important;
max-height: 52px !important;
min-height: 52px !important;
padding: 0 !important;
margin: 0 !important;
}
/* 文件上传按钮样式 */
#image-upload button,
.compact-file button {
height: 50px !important;
max-height: 50px !important;
min-height: 50px !important;
font-size: 13px !important;
padding: 0 12px !important;
margin: 1px !important;
}
/* 文件预览区域 */
#image-upload .file-preview,
.compact-file .file-preview {
height: 50px !important;
max-height: 50px !important;
overflow-y: auto !important;
font-size: 12px !important;
padding: 4px !important;
}
/* 隐藏标签 */
#image-upload label,
.compact-file label {
display: none !important;
}
/* 确保input元素也是正确高度 */
#image-upload input[type="file"],
.compact-file input[type="file"] {
height: 50px !important;
max-height: 50px !important;
}
/* 文本框参考高度 */
#message-textbox textarea {
min-height: 52px !important;
max-height: 52px !important;
}
/* 使用通配符确保所有子元素 */
#image-upload * {
max-height: 52px !important;
}
"""
# Assemble the Gradio UI: chat column on the left, settings column on the
# right, with submit/clear/undo/retry event wiring at the bottom.
with gr.Blocks(title="Step-3", theme=gr.themes.Soft(), css=css) as demo:
    gr.Markdown("""
#
Step-3
Welcome to Step-3, an advanced multimodal AI assistant by StepFun.
""")
    # State holders that carry the submitted message/images from user_submit
    # to the chained bot_response call.
    saved_msg = gr.State("")
    saved_imgs = gr.State([])
    with gr.Row():
        with gr.Column(scale=3):
            chatbot = gr.Chatbot(
                height=600,
                show_label=False,
                elem_id="chatbot",
                bubble_full_width=False,
                avatar_images=None,
                render_markdown=True
            )
            # Input row: message box, image upload, send button.
            with gr.Row():
                with gr.Column(scale=8):
                    msg = gr.Textbox(
                        label="Message",
                        placeholder="Type your message here...",
                        lines=2,
                        max_lines=10,
                        show_label=False,
                        elem_id="message-textbox"
                    )
                with gr.Column(scale=2):
                    image_input = gr.File(
                        label="Upload Images",
                        file_count="multiple",
                        file_types=[".png", ".jpg", ".jpeg", ".gif", ".webp"],
                        interactive=True,
                        show_label=False,
                        elem_classes="compact-file",
                        elem_id="image-upload"
                    )
                with gr.Column(scale=1, min_width=100):
                    submit_btn = gr.Button("Send", variant="primary")
            # Bottom action buttons.
            with gr.Row():
                clear_btn = gr.Button("🗑️ Clear", scale=1)
                undo_btn = gr.Button("↩️ Undo", scale=1)
                retry_btn = gr.Button("🔄 Retry", scale=1)
        with gr.Column(scale=1):
            # Settings panel (sampling parameters + system prompt).
            with gr.Accordion("⚙️ Settings", open=False):
                system_prompt = gr.Textbox(
                    label="System Prompt",
                    placeholder="Set a system prompt (optional)",
                    lines=3
                )
                temperature_slider = gr.Slider(
                    minimum=0,
                    maximum=1,
                    value=0.7,
                    step=0.1,
                    label="Temperature"
                )
                max_tokens_slider = gr.Slider(
                    minimum=100,
                    maximum=8000,
                    value=2000,
                    step=100,
                    label="Max Tokens"
                )
                top_p_slider = gr.Slider(
                    minimum=0,
                    maximum=1,
                    value=0.95,
                    step=0.05,
                    label="Top P"
                )
    # Event wiring: user_submit records the turn and clears the inputs
    # (queue=False for snappy UI), then bot_response streams the reply.
    submit_event = msg.submit(
        user_submit,
        [msg, chatbot, image_input],
        [msg, chatbot, image_input, saved_msg, saved_imgs],
        queue=False
    ).then(
        bot_response,
        [chatbot, saved_msg, saved_imgs, system_prompt, temperature_slider, max_tokens_slider, top_p_slider],
        chatbot
    )
    submit_btn.click(
        user_submit,
        [msg, chatbot, image_input],
        [msg, chatbot, image_input, saved_msg, saved_imgs],
        queue=False
    ).then(
        bot_response,
        [chatbot, saved_msg, saved_imgs, system_prompt, temperature_slider, max_tokens_slider, top_p_slider],
        chatbot
    )
    # Clear resets the chatbot display entirely.
    clear_btn.click(lambda: None, None, chatbot, queue=False)
    # Undo drops the last exchange.
    undo_btn.click(
        lambda h: h[:-1] if h else h,
        chatbot,
        chatbot,
        queue=False
    )
    # NOTE(review): Retry drops the last completed exchange and replays
    # whatever is in saved_msg/saved_imgs — i.e. the most recent submission,
    # which may differ from the removed exchange if state is stale.
    retry_btn.click(
        lambda h: h[:-1] if h and h[-1][1] is not None else h,
        chatbot,
        chatbot,
        queue=False
    ).then(
        bot_response,
        [chatbot, saved_msg, saved_imgs, system_prompt, temperature_slider, max_tokens_slider, top_p_slider],
        chatbot
    )
# Application entry point.
if __name__ == "__main__":
    # NOTE: these two prints duplicate the module-level ones that already ran
    # at import time.
    print(f"[DEBUG] Starting app with API key: {'Set' if STEP_API_KEY else 'Not set'}")
    print(f"[DEBUG] Base URL: {BASE_URL}")
    demo.queue(max_size=10)  # bound the number of queued streaming requests
    demo.launch(
        server_name="0.0.0.0",  # listen on all interfaces (container-friendly)
        server_port=7860,
        share=False,
        debug=False
    )