import gradio as gr
from huggingface_hub import InferenceClient
from typing import List, Dict, Optional

# Your endpoint root (no trailing /v1; the client appends /v1/chat/completions for chat,
# and the text-generation fallback posts to the root directly)
ENDPOINT_URL = "https://x6leavj4hgm2fdyx.us-east-2.aws.endpoints.huggingface.cloud"
def respond(
    user_msg: str,
    history: List[Dict[str, str]],
    system_message: str,
    max_tokens: int,
    temperature: float,
    top_p: float,
    hf_token: Optional[gr.OAuthToken],   # injected by Gradio when the user signs in via LoginButton
    pat_override: str,                   # user-pasted PAT (password field in additional_inputs)
):
    """
    Use PAT override if provided; otherwise fall back to LoginButton token.
    NOTE: OAuth token from LoginButton usually lacks `inference.endpoints.infer.write`,
    so for Inference Endpoints you almost always need to paste a PAT here.
    """

    # Choose a token: prefer user-supplied PAT with endpoints write scope
    token = pat_override.strip() or (getattr(hf_token, "token", None) if hf_token else None)
    if not token:
        yield "🔒 Please click **Login** OR paste a **Hugging Face PAT** with `inference.endpoints.infer.write`."
        return

    client = InferenceClient(base_url=ENDPOINT_URL, token=token)

    # Build messages (OpenAI-style)
    messages = []
    if system_message:
        messages.append({"role": "system", "content": system_message})
    messages.extend(history or [])
    messages.append({"role": "user", "content": user_msg})

    # Try OpenAI-compatible chat route first: /v1/chat/completions
    try:
        out = ""
        for chunk in client.chat_completion(
            messages=messages,
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            stream=True,
        ):
            tok = ""
            if getattr(chunk, "choices", None) and getattr(chunk.choices[0], "delta", None):
                tok = chunk.choices[0].delta.content or ""
            out += tok
            yield out
        return
    except Exception as e_chat:
        chat_err = str(e_chat)

    # Fallback to plain generation (for non-OpenAI runtimes)
    try:
        def to_prompt(msgs: List[Dict[str, str]]) -> str:
            lines = []
            for m in msgs:
                role = m.get("role", "user")
                content = m.get("content", "")
                tag = {"system": "SYSTEM", "user": "USER"}.get(role, "ASSISTANT")
                lines.append(f"[{tag}] {content}")
            lines.append("[ASSISTANT]")
            return "\n".join(lines)

        prompt = to_prompt(messages)
        out = ""
        for tok in client.text_generation(
            prompt,
            max_new_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            stream=True,
            return_full_text=False,
        ):
            # stream=True with details disabled yields plain strings; defensively
            # unwrap TextGenerationStreamOutput (.token.text) and raw dicts too.
            piece = getattr(tok, "token", tok)
            piece = getattr(piece, "text", piece)
            if isinstance(piece, dict) and "text" in piece:
                piece = piece["text"]
            out += str(piece)
            yield out
    except Exception as e_gen:
        yield (
            "❗ Endpoint call failed.\n\n"
            f"• Chat API error: {chat_err}\n"
            f"• Text-generation fallback error: {e_gen}\n\n"
            "Most likely cause: the token used does NOT have `inference.endpoints.infer.write`.\n"
            "Paste a PAT with that scope in the sidebar."
        )
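
def _smoke_test(pat: str) -> str:
    """Drive `respond` directly, without the UI, and return the final reply.

    A minimal local-debugging sketch; `pat` is assumed to be a PAT carrying
    `inference.endpoints.infer.write`. Not called anywhere by default.
    """
    reply = ""
    for partial in respond(
        user_msg="Say hello in five words.",
        history=[],
        system_message="You are a friendly Chatbot.",
        max_tokens=64,
        temperature=0.7,
        top_p=0.95,
        hf_token=None,
        pat_override=pat,
    ):
        reply = partial
    return reply
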

# --- UI ---
chat = gr.ChatInterface(
    respond,
    type="messages",
    additional_inputs=[
        gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
        gr.Slider(1, 4096, value=512, step=1, label="Max new tokens"),
        gr.Slider(0.0, 4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(0.0, 1.0, value=0.95, step=0.05, label="Top-p"),
        # Password-type field so the PAT never shows on screen
        gr.Textbox(value="", label="HF PAT (with `inference.endpoints.infer.write`)", type="password"),
    ],
)
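
# Wiring note: gr.ChatInterface calls respond(message, history, *additional_inputs);
# Gradio injects the `gr.OAuthToken` parameter by its type hint, so the five inputs
# above map to system_message, max_tokens, temperature, top_p, and pat_override.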

with gr.Blocks() as demo:
    with gr.Sidebar():
        gr.Markdown("### Hugging Face Login (optional)")
        gr.LoginButton()
        gr.Markdown(
            "**Important:** Inference Endpoints require a PAT with\n"
            "`inference.endpoints.infer.write`. The Login token usually does **not** have this.\n"
            "Paste a PAT in the password field if you see 403 errors."
        )
        gr.Markdown(f"**Endpoint**: `{ENDPOINT_URL}`")
    chat.render()

if __name__ == "__main__":
    demo.launch()