File size: 3,535 Bytes
8987e34
6a6db8a
65038dc
6a6db8a
8987e34
 
65038dc
6a6db8a
 
61119fe
6a6db8a
8987e34
6a6db8a
 
 
 
8987e34
 
6a6db8a
79cceb8
 
 
65038dc
 
 
 
 
 
6a6db8a
 
 
 
 
 
 
8987e34
 
d3eab6a
6a6db8a
d3eab6a
 
6a6db8a
d3eab6a
 
 
 
6a6db8a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8987e34
6a6db8a
aeb0af0
6a6db8a
aeb0af0
6a6db8a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8987e34
6a6db8a
 
61119fe
6a6db8a
61119fe
6a6db8a
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import io
import os
import base64
from typing import Optional

import httpx
from PIL import Image
import gradio as gr
from fastapi import FastAPI

# ---------- Configuration (each value can be overridden via environment) ----------
# Base URL of the StepFun OpenAI-compatible API.
STEPFUN_ENDPOINT = os.environ.get("STEPFUN_ENDPOINT", "https://api.stepfun.com/v1")
# Model identifier sent in the "model" field of chat-completions requests.
MODEL_NAME = os.environ.get("STEPFUN_MODEL", "step-3")
# Title and description displayed by the Gradio interface.
TITLE = "StepFun · 图像问答(step-3)"
DESC = "上传图片 + 输入问题,走 StepFun OpenAI 兼容接口 /chat/completions"
# ----------------------------------------------------------------------------------

def _get_api_key() -> Optional[str]:
    """Look up the StepFun API key in the environment.

    Two variable names are accepted: ``OPENAI_API_KEY`` is consulted first and
    wins when it holds a non-empty value; otherwise the value of
    ``STEPFUN_KEY`` is returned as-is (which may be ``None`` or ``""``).
    """
    candidate = None
    for env_name in ("OPENAI_API_KEY", "STEPFUN_KEY"):
        candidate = os.getenv(env_name)
        if candidate:
            break
    return candidate

def _pil_to_data_url(img: "Image.Image", fmt: str = "PNG") -> str:
    """Encode an image as a base64 ``data:`` URL.

    Args:
        img: Image to encode; only its ``mode``, ``convert`` and ``save``
            members are used.
        fmt: Pillow format name passed to ``save`` (e.g. "PNG", "JPEG",
            "WEBP"). Matched case-insensitively when choosing the MIME type.

    Returns:
        A ``data:<mime>;base64,<payload>`` string whose MIME type matches
        ``fmt`` (``image/<fmt lowercased>`` for formats not in the table).
    """
    # Explicit table: labelling every non-PNG payload as image/jpeg (the old
    # behaviour) mislabels WEBP/GIF/BMP/TIFF data.
    mime_by_format = {
        "PNG": "image/png",
        "JPEG": "image/jpeg",
        "WEBP": "image/webp",
        "GIF": "image/gif",
        "BMP": "image/bmp",
        "TIFF": "image/tiff",
    }
    fmt_key = fmt.upper()
    # JPEG has no alpha channel or palette; Pillow refuses to write these
    # modes as JPEG, so flatten them to RGB first.
    if fmt_key == "JPEG" and img.mode in ("RGBA", "LA", "P", "PA"):
        img = img.convert("RGB")
    buf = io.BytesIO()
    img.save(buf, format=fmt)
    b64 = base64.b64encode(buf.getvalue()).decode("utf-8")
    mime = mime_by_format.get(fmt_key, f"image/{fmt.lower()}")
    return f"data:{mime};base64,{b64}"

def _extract_reply_text(response) -> str:
    """Return the first choice's message text from a chat-completions response.

    Raises:
        RuntimeError: if the body is not JSON or lacks ``choices[0].message``
            (e.g. an error object returned with a 2xx status). The message
            includes a truncated copy of the body for diagnosis.
    """
    try:
        data = response.json()
    except ValueError as exc:
        raise RuntimeError(f"接口返回的不是 JSON:{response.text[:500]}") from exc

    choices = data.get("choices") if isinstance(data, dict) else None
    first = choices[0] if isinstance(choices, list) and choices else None
    message = first.get("message") if isinstance(first, dict) else None
    if not isinstance(message, dict):
        raise RuntimeError(f"接口返回结构异常(缺少 choices[0].message):{str(data)[:500]}")

    content = message.get("content")
    if content is None:
        return ""
    if isinstance(content, list):
        # Some OpenAI-compatible servers return content as a list of parts.
        return "".join(
            str(part.get("text") or "") for part in content if isinstance(part, dict)
        )
    return str(content)


def _post_chat(messages: list, temperature: float = 0.7, timeout: float = 60.0) -> str:
    """Send a chat-completions request to StepFun and return the reply text.

    Args:
        messages: OpenAI-style message list (dicts with ``role``/``content``).
        temperature: Sampling temperature forwarded in the request payload.
        timeout: Request timeout in seconds, forwarded to ``httpx.post``.

    Returns:
        The first choice's message content as a string. ``null`` content
        yields ``""``; list-of-parts content is concatenated.

    Raises:
        RuntimeError: if no API key is configured, or the response body does
            not have the expected chat-completions shape.
        httpx.HTTPStatusError: on a non-2xx HTTP response.
        httpx.HTTPError: on transport failures such as timeouts.
    """
    # Strip so a secret pasted with a trailing newline/space neither yields a
    # malformed Authorization header nor counts as "set" when it is blank.
    key = (_get_api_key() or "").strip()
    if not key:
        raise RuntimeError(
            "API Key 未设置。请在 Space 的 Settings → Variables and secrets 添加:\n"
            "OPENAI_API_KEY 或 STEPFUN_KEY(值为 StepFun API Key)。"
        )

    # Tolerate a configured endpoint with or without a trailing slash.
    url = f"{STEPFUN_ENDPOINT.rstrip('/')}/chat/completions"
    headers = {
        "Authorization": f"Bearer {key}",
        "Content-Type": "application/json",
    }
    payload = {
        "model": MODEL_NAME,
        "messages": messages,
        "temperature": temperature,
    }
    r = httpx.post(url, headers=headers, json=payload, timeout=timeout)
    r.raise_for_status()
    return _extract_reply_text(r)

def infer(image: Optional[Image.Image], question: Optional[str]) -> str:
    """Answer a question about an uploaded image via the StepFun chat API.

    This is the Gradio handler; every failure is returned as text rather than
    raised, so it is displayed in the answer textbox.

    Args:
        image: The uploaded image, or None when nothing has been uploaded.
        question: The user's question; None/blank falls back to a default
            "describe this image" prompt.

    Returns:
        The model's reply, or a human-readable error message.
    """
    if image is None:
        return "请先上传图片再提问。"
    q = (question or "").strip() or "请描述这张图片。"
    try:
        # Encoding sits inside the try so an image that cannot be serialised
        # is reported in the output box just like an API failure.
        data_url = _pil_to_data_url(image, fmt="PNG")
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image_url", "image_url": {"url": data_url}},
                    {"type": "text", "text": q},
                ],
            }
        ]
        return _post_chat(messages)
    except httpx.HTTPStatusError as e:
        return f"调用失败(HTTP {e.response.status_code}):{e.response.text}"
    except Exception as e:
        # UI boundary: surface anything else to the user instead of crashing.
        return f"调用失败:{repr(e)}"

# ------- Gradio UI -------
# Image + question inputs are passed to infer(); its string result is shown
# in the answer textbox.
demo = gr.Interface(
    title=TITLE,
    description=DESC,
    fn=infer,
    inputs=[
        gr.Image(label="上传图片", type="pil"),
        gr.Textbox(placeholder="例如:这是什么菜?怎么做?", label="问题"),
    ],
    outputs=gr.Textbox(label="回答"),
)

# ------- FastAPI host app (overrides /info to avoid gradio_client's schema parsing) -------
# The routes below are registered before Gradio is mounted, so Starlette
# matches them ahead of the mounted Gradio app.
fastapi_app = FastAPI()


@fastapi_app.get("/health")
def health():
    # Constant liveness payload; no dependencies are checked.
    status_payload = {"status": "ok"}
    return status_payload


@fastapi_app.get("/info")
def info_stub():
    # Minimal stand-in for Gradio's API-info response, bypassing its
    # api_info generation entirely.
    return dict(
        api=False,
        message="API docs 已禁用(此路由由外部 FastAPI 覆盖以规避依赖冲突)。",
    )


# Serve the Gradio UI from the site root on top of the FastAPI app.
app = gr.mount_gradio_app(fastapi_app, demo, path="/")

# Local debugging entry point: `python app.py` serves the combined app.
if __name__ == "__main__":
    import uvicorn

    # Listen on all interfaces; PORT from the environment, 7860 otherwise.
    # NOTE(review): whether Hugging Face Spaces ever executes this branch
    # depends on how the Space launches the module — confirm.
    uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", "7860")))