Step3 / app.py
Zenithwang's picture
Update app.py
e1d60fa verified
raw
history blame
4.34 kB
import os
import io
import base64
from typing import Optional
import httpx
import gradio as gr
from PIL import Image
from fastapi import FastAPI
# --------- Basic configuration ---------
# Base URL of the StepFun OpenAI-compatible API; the STEPFUN_ENDPOINT
# environment variable overrides the default.
STEPFUN_ENDPOINT = os.getenv("STEPFUN_ENDPOINT", "https://api.stepfun.com/v1")
# Model identifier sent with every request; the MODEL_NAME environment
# variable overrides the default.
MODEL_NAME = os.getenv("MODEL_NAME", "step-3")
# User-facing UI strings: page title, description, and footer hint telling the
# operator which environment variables hold the API key.
TITLE = "StepFun · step-3 图片问答 Demo"
DESC = "上传一张图片,问一个问题;后台通过 StepFun OpenAI 兼容接口完成图文对话。"
FOOTER = "提示:在 HF Spaces 的 Settings -> Variables 里设置 OPENAI_API_KEY 或 STEPFUN_KEY"
# ---------------------------
def _get_api_key() -> Optional[str]:
    """Return the API key configured in the environment, or ``None``.

    ``OPENAI_API_KEY`` (the conventional name for OpenAI-compatible
    endpoints) takes precedence; ``STEPFUN_KEY`` is the fallback.

    Values are stripped of surrounding whitespace, and a variable that is
    empty or whitespace-only is treated as unset. This keeps a blank
    ``OPENAI_API_KEY`` from masking a valid ``STEPFUN_KEY`` and keeps stray
    newlines (common when pasting secrets) out of the Authorization header.
    """
    for var_name in ("OPENAI_API_KEY", "STEPFUN_KEY"):
        value = (os.getenv(var_name) or "").strip()
        if value:
            return value
    return None
def _pil_to_data_url(img: Image.Image, fmt: str = "PNG") -> str:
    """Encode a PIL image as a ``data:image/...;base64,...`` URL.

    This is the form expected in the ``image_url`` field of an
    OpenAI-compatible chat message.

    Args:
        img: The image to encode.
        fmt: Target image format name, case-insensitive (e.g. ``"PNG"``,
            ``"JPEG"``, ``"WEBP"``). ``"JPG"`` is accepted as an alias for
            ``"JPEG"``.

    Returns:
        The data URL string, with a MIME subtype matching ``fmt``.
    """
    fmt_name = fmt.upper()
    if fmt_name == "JPG":
        # "JPG" is a common spelling but the encoder is registered as "JPEG".
        fmt_name = "JPEG"

    # JPEG has no alpha channel or palette; encoding e.g. an RGBA or P image
    # directly fails, so fall back to plain RGB for such modes.
    if fmt_name == "JPEG" and img.mode not in ("L", "RGB", "CMYK"):
        img = img.convert("RGB")

    buf = io.BytesIO()
    img.save(buf, format=fmt_name)
    b64 = base64.b64encode(buf.getvalue()).decode("ascii")

    # Derive the MIME subtype from the format instead of labelling every
    # non-PNG format as JPEG.
    mime = f"image/{fmt_name.lower()}"
    return f"data:{mime};base64,{b64}"
def _post_chat(messages: list, temperature: float = 0.7, timeout: float = 60.0) -> str:
    """Call the StepFun ``/chat/completions`` endpoint via httpx and return the reply text.

    Args:
        messages: Chat messages in the OpenAI-compatible format.
        temperature: Sampling temperature forwarded to the API.
        timeout: Request timeout in seconds passed to ``httpx.post``.

    Returns:
        The ``content`` of the first choice's message.

    Raises:
        RuntimeError: If no API key is configured, or if the response JSON
            does not contain ``choices[0].message.content``.
        httpx.HTTPStatusError: If the server answers with a 4xx/5xx status.
    """
    key = _get_api_key()
    if not key:
        raise RuntimeError(
            "API Key 未设置。\n"
            "请在本地环境变量或 HF Spaces -> Settings -> Variables 里设置:\n"
            "OPENAI_API_KEY 或 STEPFUN_KEY(值为 StepFun API Key)。"
        )

    # Tolerate a trailing slash on the configured endpoint.
    url = f"{STEPFUN_ENDPOINT.rstrip('/')}/chat/completions"
    headers = {
        "Authorization": f"Bearer {key}",
        "Content-Type": "application/json",
    }
    payload = {
        "model": MODEL_NAME,
        "messages": messages,
        "temperature": temperature,
    }

    resp = httpx.post(url, headers=headers, json=payload, timeout=timeout)
    resp.raise_for_status()
    data = resp.json()

    # A 2xx response with an unexpected shape would otherwise surface as an
    # opaque KeyError/IndexError; report what was actually received instead.
    try:
        return data["choices"][0]["message"]["content"]
    except (KeyError, IndexError, TypeError) as exc:
        snippet = repr(data)[:500]
        raise RuntimeError(f"后端响应格式异常,未找到回答内容:{snippet}") from exc
def chat_with_step3(image: Optional[Image.Image], question: Optional[str]) -> str:
    """Gradio callback: send the uploaded image and question to the model.

    Args:
        image: The uploaded picture, or ``None`` when nothing was uploaded.
        question: The user's question; blank or ``None`` selects a default
            prompt.

    Returns:
        The model's answer, or a human-readable error message. Failures of
        the backend call are reported as text rather than raised, so the UI
        always has something to display.
    """
    # Guard: nothing to analyse without a picture.
    if image is None:
        return "请先上传图片。"

    # Use the default prompt when the question is blank.
    prompt = (question or "").strip()
    if not prompt:
        prompt = "请描述这张图片的内容,并指出可能的菜品名称与做法要点。"

    # One user turn carrying the image first, then the text.
    image_part = {
        "type": "image_url",
        "image_url": {"url": _pil_to_data_url(image, fmt="PNG")},
    }
    text_part = {"type": "text", "text": prompt}
    conversation = [{"role": "user", "content": [image_part, text_part]}]

    try:
        answer = _post_chat(conversation)
    except httpx.HTTPStatusError as http_err:
        # Surface the backend's own error payload to ease troubleshooting;
        # prefer parsed JSON and fall back to the raw body text.
        response = http_err.response
        try:
            body = response.json()
        except Exception:
            body = response.text
        return f"后端返回错误:HTTP {response.status_code}\n{body}"
    except Exception as err:
        return f"调用失败:{err!r}"
    return answer
# --------- Build the Gradio UI ---------
# Page layout: title and description on top, then a two-column row
# (inputs on the left, answer on the right), then the footer hint.
with gr.Blocks(title=TITLE, analytics_enabled=False) as demo:
    gr.Markdown(f"## {TITLE}")
    gr.Markdown(DESC)
    with gr.Row():
        # Left column: image upload, optional question, submit button.
        with gr.Column():
            # type="pil" so the callback receives a PIL image object.
            img_in = gr.Image(type="pil", label="上传图片")
            txt_in = gr.Textbox(
                label="问题(可留空)",
                placeholder="例如:这是什么菜?做法是怎样的?",
            )
            btn = gr.Button("提交")
        # Right column: the model's answer.
        with gr.Column():
            out = gr.Textbox(label="回答", lines=12)
    # Clicking the button runs chat_with_step3 on the image and question and
    # writes its return value into the answer box.
    btn.click(fn=chat_with_step3, inputs=[img_in, txt_in], outputs=out)
    gr.Markdown(f"<small>{FOOTER}</small>")
# Expose a FastAPI/ASGI application so that HF Spaces can detect it.
app = FastAPI()
# No custom sub-path: mount the Gradio UI directly at the root path.
app = gr.mount_gradio_app(app, demo, path="/")
# --------- Local debugging only ---------
# Starts uvicorn when this file is run as a script and the SPACE_BUILD
# environment variable is not set. The port comes from PORT (default 7860).
# NOTE(review): the original comment states this branch is not executed in the
# Spaces environment — confirm that SPACE_BUILD is actually set there.
if __name__ == "__main__" and os.environ.get("SPACE_BUILD") is None:
    import uvicorn
    port = int(os.getenv("PORT", "7860"))
    uvicorn.run(app, host="0.0.0.0", port=port)