File size: 3,556 Bytes
6a6db8a
e1d60fa
65038dc
fe0c240
8987e34
6a6db8a
e1d60fa
fe0c240
6a6db8a
61119fe
fe0c240
8987e34
e1d60fa
fe0c240
e1d60fa
8987e34
fe0c240
e1d60fa
fe0c240
e1d60fa
79cceb8
 
e1d60fa
79cceb8
e1d60fa
fe0c240
e1d60fa
65038dc
 
 
 
 
 
e1d60fa
fe0c240
e1d60fa
fe0c240
e1d60fa
6a6db8a
 
fe0c240
e1d60fa
fe0c240
 
6a6db8a
8987e34
fe0c240
d3eab6a
6a6db8a
d3eab6a
 
6a6db8a
d3eab6a
 
 
 
e1d60fa
fe0c240
 
 
 
 
 
 
 
 
6a6db8a
e1d60fa
fe0c240
e1d60fa
fe0c240
e1d60fa
6a6db8a
e1d60fa
 
fe0c240
6a6db8a
fe0c240
6a6db8a
 
 
 
 
 
 
 
 
fe0c240
6a6db8a
8987e34
fe0c240
 
e1d60fa
fe0c240
e1d60fa
fe0c240
e1d60fa
fe0c240
e1d60fa
fe0c240
 
 
e1d60fa
fe0c240
 
 
e1d60fa
fe0c240
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import os
import io
import base64
from typing import List, Dict, Any

import gradio as gr
from PIL import Image
import httpx
from fastapi import FastAPI

# ====================== Configuration ======================
# Both values can be overridden through environment variables.
# Base URL of the StepFun OpenAI-compatible API.
STEPFUN_ENDPOINT = os.environ.get("STEPFUN_ENDPOINT", "https://api.stepfun.com/v1")
# Model identifier sent with every chat-completions request.
MODEL_NAME = os.environ.get("MODEL_NAME", "step-3")
# ===========================================================


def _get_api_key() -> str | None:
    """
    优先读 OPENAI_API_KEY(兼容 OpenAI SDK 约定),其次读 STEPFUN_KEY
    """
    return os.getenv("OPENAI_API_KEY") or os.getenv("STEPFUN_KEY")


def _pil_to_data_url(img: Image.Image, fmt: str = "PNG") -> str:
    """
    Encode a PIL image as a base64 ``data:`` URL.

    Args:
        img: Image to encode.
        fmt: Pillow format name passed to ``Image.save`` (e.g. "PNG",
            "JPEG", "WEBP").

    Returns:
        A string of the form ``data:<mime>;base64,<payload>``.
    """
    buf = io.BytesIO()
    img.save(buf, format=fmt)
    b64 = base64.b64encode(buf.getvalue()).decode("utf-8")
    # Look the MIME type up in Pillow's format registry. The old code
    # labelled every non-PNG format as image/jpeg, which mislabels WEBP,
    # GIF, etc. This runs after save(), which loads the plugin for `fmt`
    # (and plugins register their MIME type). Unknown formats fall back to
    # a generic image/<fmt>.
    mime = Image.MIME.get(fmt.upper(), f"image/{fmt.lower()}")
    return f"data:{mime};base64,{b64}"


def _post_chat(
    messages: List[Dict[str, Any]],
    temperature: float = 0.7,
    timeout: float = 60,
) -> str:
    """
    Call StepFun's OpenAI-compatible ``/chat/completions`` endpoint directly.

    Args:
        messages: Chat messages in OpenAI format. They are sent verbatim as
            the request's ``messages`` field.
        temperature: Sampling temperature forwarded to the API.
        timeout: Request timeout in seconds, passed to ``httpx.post``.

    Returns:
        On success, the ``content`` of the first choice. On failure the
        function returns a human-readable error message instead of raising,
        so the Gradio UI can show it directly. Failures include a missing
        API key, an HTTP or network error, or an unexpected response shape.
    """
    # Cap on how much of an error body / unexpected payload is echoed back.
    max_detail_chars = 500

    key = _get_api_key()
    if not key:
        return (
            "API Key 未设置。\n"
            "请在 Space 的 Settings → Variables and secrets 中添加:\n"
            "OPENAI_API_KEY=<你的 StepFun API Key>(或设置 STEPFUN_KEY 也可)"
        )

    # Tolerate a trailing slash in STEPFUN_ENDPOINT (avoids ".../v1//chat/...").
    url = f"{STEPFUN_ENDPOINT.rstrip('/')}/chat/completions"
    headers = {
        "Authorization": f"Bearer {key}",
        "Content-Type": "application/json",
    }
    payload = {
        "model": MODEL_NAME,
        "messages": messages,
        "temperature": temperature,
    }

    try:
        resp = httpx.post(url, headers=headers, json=payload, timeout=timeout)
        resp.raise_for_status()
        data = resp.json()
    except httpx.HTTPStatusError as e:
        # str(e) only describes the status; the response body typically
        # carries the API's own error details, so include a bounded slice.
        detail = e.response.text[:max_detail_chars]
        if detail:
            return f"调用失败(HTTP):{e}\n{detail}"
        return f"调用失败(HTTP):{e}"
    except httpx.HTTPError as e:
        return f"调用失败(HTTP):{e}"
    except Exception as e:
        # Catch-all kept on purpose (e.g. a non-JSON body from resp.json());
        # the UI callback should report the error rather than crash.
        return f"调用失败:{e}"

    try:
        return data["choices"][0]["message"]["content"]
    except (KeyError, IndexError, TypeError):
        # A successful status but an unexpected body shape: show what came
        # back instead of a bare "'choices'" KeyError message.
        return f"调用失败:响应格式异常:{str(data)[:max_detail_chars]}"


def chat_fn(image: Image.Image, question: str) -> str:
    """
    Gradio callback: uploaded image + text question -> model answer.

    Args:
        image: The uploaded picture as a PIL image, or ``None`` when nothing
            has been uploaded.
        question: The user's question. If it is blank (or ``None``), a
            default "describe this image" prompt is used.

    Returns:
        The model's answer, or a user-facing message. The message is either
        the missing-image hint below or an error string from ``_post_chat``.
    """
    if image is None:
        return "请先上传图片。"

    # question may be None (e.g. if a client omits the textbox value), so
    # guard before calling .strip() instead of raising AttributeError.
    q = (question or "").strip() or "请描述这张图片。"
    data_url = _pil_to_data_url(image, fmt="PNG")

    # Single user turn carrying the image (as a data URL) plus the question.
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image_url", "image_url": {"url": data_url}},
                {"type": "text", "text": q},
            ],
        }
    ]
    return _post_chat(messages)


with gr.Blocks(title="Step-3 · 图片问答") as demo:
    gr.Markdown("## Step-3 图片问答\n上传一张图,随便问。")
    # Image uploader on the left; question box and submit button stacked on the right.
    with gr.Row():
        img = gr.Image(type="pil", label="上传图片")
        with gr.Column():
            q = gr.Textbox(label="问题", placeholder="比如:这是什么菜?怎么做?")
            btn = gr.Button("提交")
    out = gr.Textbox(label="答案", lines=8)

    # Clicking the button and pressing Enter in the question box run the same callback.
    for _event in (btn.click, q.submit):
        _event(fn=chat_fn, inputs=[img, q], outputs=[out])

# ---- FastAPI application with the Gradio UI mounted at "/" (used by HF Spaces) ----
_fastapi = FastAPI()
app = gr.mount_gradio_app(
    app=_fastapi,
    blocks=demo,
    path="/",
)

# ---- Local debugging: only launches when run directly and not on HF Spaces ----
if __name__ == "__main__" and os.environ.get("SYSTEM") != "spaces":
    # Local run: python app.py
    # Do not start uvicorn on HF Spaces, otherwise the port would conflict.
    queued_demo = demo.queue(max_size=32)
    queued_demo.launch(
        server_name="0.0.0.0",
        server_port=int(os.getenv("PORT", "7860")),
    )