import gradio as gr
from huggingface_hub import InferenceClient
from typing import List, Dict
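
# Dedicated Inference Endpoint this app talks to; replace with your own endpoint URL.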
ENDPOINT_URL = "https://x6leavj4hgm2fdyx.us-east-2.aws.endpoints.huggingface.cloud"


def respond(
    user_msg: str,
    history: List[Dict[str, str]],
    system_message: str,
    max_tokens: int,
    temperature: float,
    top_p: float,
    hf_token: gr.OAuthToken,
):
"""
Streams chat responses from a Hugging Face Inference Endpoint.
Notes:
- Requires your endpoint to allow inference with your token (permission:
`inference.endpoints.infer.write`).
- If the endpoint doesn't support OpenAI-style /v1/chat (e.g., plain TGI),
we fallback to a single-prompt `.text_generation()` call using a simple
prompt format built from the chat history.
"""
    # 1) Client that talks directly to the endpoint
    client = InferenceClient(
        base_url=ENDPOINT_URL,
        token=hf_token.token,  # OAuth token supplied by the LoginButton
    )
    # 2) Build OpenAI-style messages for chat backends
    messages = []
    if system_message:
        messages.append({"role": "system", "content": system_message})
    # With type="messages", Gradio passes `history` as a list of
    # {"role": ..., "content": ...} dicts; append it, then the new user turn.
    messages.extend(history or [])
    messages.append({"role": "user", "content": user_msg})
    # 3) Try OpenAI-style chat first (works if the endpoint exposes
    # /v1/chat/completions)
    try:
        response_text = ""
        for chunk in client.chat_completion(
            messages=messages,
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            stream=True,
        ):
            # chunk.choices[0].delta.content holds the streamed token (if present)
            token = ""
            if getattr(chunk, "choices", None) and getattr(chunk.choices[0], "delta", None):
                token = chunk.choices[0].delta.content or ""
            response_text += token
            yield response_text
        return  # success via the chat API
    except Exception as e:
        # If the chat route isn't available (common when the endpoint runs
        # plain TGI without the OpenAI route enabled), remember why and fall
        # back to text_generation below.
        fallback_reason = str(e)
    # 4) Fallback: plain text generation with a simple chat-to-prompt adapter
    try:

        def to_plain_prompt(msgs: List[Dict[str, str]]) -> str:
            lines = []
            for m in msgs:
                role = m.get("role", "user")
                content = m.get("content", "")
                if role == "system":
                    lines.append(f"[SYSTEM] {content}")
                elif role == "user":
                    lines.append(f"[USER] {content}")
                else:
                    lines.append(f"[ASSISTANT] {content}")
            lines.append("[ASSISTANT]")  # cue the model to speak
            return "\n".join(lines)

        prompt = to_plain_prompt(messages)
        response_text = ""
        # Stream text_generation tokens if the backend supports it
        for tok in client.text_generation(
            prompt,
            max_new_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            stream=True,
            # Many TGI backends respect this kwarg; safe to include
            return_full_text=False,
        ):
            # `tok` is usually a plain string, but some backends yield an
            # object (or dict) whose token carries the text; normalize to str.
            piece = getattr(tok, "token", tok)
            piece = getattr(piece, "text", piece)
            if isinstance(piece, dict) and "text" in piece:
                piece = piece["text"]
            piece = str(piece)
            response_text += piece
            yield response_text
    except Exception as e2:
        # Surface a readable error in the chat window
        err = (
            "Failed to query the endpoint.\n\n"
            f"- Chat attempt error: {fallback_reason}\n"
            f"- Text-generation fallback error: {e2}\n\n"
            "Check that your endpoint is running, your token has "
            "`inference.endpoints.infer.write`, and the runtime supports either "
            "OpenAI chat (/v1/chat/completions) or TGI text-generation."
        )
        yield err


# --- Gradio UI ---
chatbot = gr.ChatInterface(
    respond,
    type="messages",  # history arrives as [{"role": ..., "content": ...}]
    additional_inputs=[
        gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
        gr.Slider(minimum=1, maximum=4096, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.0, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(minimum=0.0, maximum=1.0, value=0.95, step=0.05, label="Top-p"),
    ],
)
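# The additional inputs bind positionally to `respond`'s system_message,
# max_tokens, temperature, and top_p parameters; the `hf_token: gr.OAuthToken`
# argument is injected separately by Gradio once the user logs in.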

with gr.Blocks() as demo:
    with gr.Sidebar():
        gr.Markdown("### Hugging Face Login")
        # Provides `hf_token: gr.OAuthToken` to `respond`
        gr.LoginButton()
        gr.Markdown(
            "Make sure your token has **`inference.endpoints.infer.write`** permission."
        )
        gr.Markdown(f"**Endpoint**:\n\n`{ENDPOINT_URL}`")
    chatbot.render()

if __name__ == "__main__":
    demo.launch()
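
# Running locally: `python app.py` (needs `gradio` and `huggingface_hub`
# installed). On Spaces, the Gradio runtime serves `demo` from this file.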