MixtureOfInputs / app.py
yzhuang's picture
Update app.py
dbe5296 verified
raw
history blame
3.16 kB
# app.py
import json, requests, gradio as gr
API_URL = "http://0.0.0.0:8000/v1/chat/completions"
def stream_completion(message, history, max_tokens, temperature, top_p, beta):
    """Stream the assistant's reply token-by-token from the vLLM server.

    Args:
        message: The current user message.
        history: Prior turns as a list of (user, assistant) tuples
                 (Gradio tuple-format chat history).
        max_tokens: Maximum number of new tokens to request.
        temperature: Sampling temperature, forwarded verbatim.
        top_p: Nucleus-sampling cutoff, forwarded verbatim.
        beta: MoI blending coefficient, sent via the X-MIXINPUTS-BETA header.

    Yields:
        The updated chat history including the partially streamed reply,
        so Gradio can repaint the chatbot live.
    """
    # -------- build OpenAI-style message list (no system prompt) -------------
    messages = []
    for user_msg, assistant_msg in history:
        if user_msg:                                   # past user turn
            messages.append({"role": "user", "content": user_msg})
        if assistant_msg:                              # past assistant turn
            messages.append({"role": "assistant", "content": assistant_msg})
    messages.append({"role": "user", "content": message})

    payload = {
        "model": "Qwen/Qwen3-4B",
        "messages": messages,
        "temperature": temperature,
        "top_p": top_p,
        "max_tokens": int(max_tokens),
        "stream": True,
    }
    headers = {
        "Content-Type": "application/json",
        "X-MIXINPUTS-BETA": str(beta),
    }

    try:
        # (connect timeout, read timeout): fail fast on connect, but let the
        # read side stream indefinitely.
        with requests.post(API_URL,
                           json=payload,
                           stream=True,
                           headers=headers,
                           timeout=(10, None)) as resp:
            resp.raise_for_status()
            assistant = ""
            # FIX: the original passed decode_unicode=True together with the
            # *bytes* delimiter b"\n"; urllib3 then tries to split decoded
            # str chunks with bytes and raises TypeError.  The default
            # delimiter already splits on line boundaries, so drop it.
            for raw in resp.iter_lines(decode_unicode=True):
                if not raw:
                    continue                            # keep-alive blank line
                # SSE frames look like "data: {...}"; strip the prefix.
                data = raw[6:] if raw.startswith("data: ") else raw
                if data.strip() == "[DONE]":
                    break
                # FIX: "content" may be absent *or* explicitly null (e.g. the
                # role-only first delta); `or ""` covers both so the += below
                # never sees None.
                delta = json.loads(data)["choices"][0]["delta"].get("content") or ""
                assistant += delta
                yield history + [(message, assistant)]  # live update in Gradio
    except Exception as err:
        # Surface any network / HTTP / parse error directly in the chat.
        yield history + [(message, f"[ERROR] {err}")]
# ---------------------------- UI --------------------------------------------
with gr.Blocks(title="🎨 Mixture of Inputs (MoI) Demo") as demo:
    # Header / description blurb.
    gr.Markdown(
        "## 🎨 Mixture of Inputs (MoI) Demo \n"
        "Streaming vLLM demo with dynamic **beta** adjustment in MoI "
        "(higher beta → less blending)."
    )

    # Sampling controls sit in a single row above the chat area.
    with gr.Row():
        beta = gr.Slider(minimum=0.0, maximum=10.0, value=1.0, step=0.1,
                         label="MoI β")
        temperature = gr.Slider(minimum=0.1, maximum=1.0, value=0.6, step=0.1,
                                label="Temperature")
        top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.80, step=0.05,
                          label="Top-p")
        max_tokens = gr.Slider(minimum=1, maximum=2048, value=512, step=1,
                               label="Max new tokens")

    chatbot = gr.Chatbot(height=450)
    msg_box = gr.Textbox(placeholder="Type a message and press Enter…",
                         show_label=False)
    clear_button = gr.Button("Clear chat")

    # Pressing Enter in the textbox streams a reply into the chatbot.
    msg_box.submit(
        fn=stream_completion,
        inputs=[msg_box, chatbot, max_tokens, temperature, top_p, beta],
        outputs=chatbot,
    )

    # Clearing just resets the chatbot component (no server round-trip).
    clear_button.click(fn=lambda: None, inputs=None, outputs=chatbot,
                       queue=False)

if __name__ == "__main__":
    demo.launch()