# app.py
import json
import requests
import sseclient
import gradio as gr

API_URL = "http://localhost:8000/v1/chat/completions"
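# NOTE: this assumes an OpenAI-compatible vLLM server (patched for MoI) is
# already running locally; a hypothetical launch command for this demo:
#   vllm serve Qwen/Qwen3-4B --port 8000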


def stream_completion(message, history, max_tokens, temperature, top_p, beta):
    """
    Gradio callback: takes the newest user message + full chat history,
    returns an updated history while streaming assistant tokens.
    """
    # ------- build OpenAI-style message list (no system prompt) -------------
    messages = []
    for usr, bot in history:
        if usr:
            messages.append({"role": "user", "content": usr})
        if bot:
            messages.append({"role": "assistant", "content": bot})
    messages.append({"role": "user", "content": message})

    payload = {
        "model": "Qwen/Qwen3-4B",
        "messages": messages,
        "temperature": temperature,
        "top_p": top_p,
        "max_tokens": int(max_tokens),
        "stream": True,
    }
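    # MoI-specific: the blending strength (beta) is sent per-request via a
    # custom header; the MoI-patched vLLM server is assumed to read it
    # (a stock server would simply ignore the header).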
    headers = {
        "Content-Type": "application/json",
        "X-MIXINPUTS-BETA": str(beta),
    }

    try:
        resp = requests.post(API_URL, json=payload, stream=True, headers=headers, timeout=60)
        resp.raise_for_status()
        client = sseclient.SSEClient(resp)
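        # sseclient-py wraps the streaming HTTP response and yields one event
        # per "data: ..." chunk sent by the server.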

        assistant = ""
        for event in client.events():
            if event.data.strip() == "[DONE]":
                break
            # "content" may be absent or null in some chunks (e.g. the initial
            # role-only delta), so fall back to an empty string either way.
            delta = json.loads(event.data)["choices"][0]["delta"].get("content") or ""
            assistant += delta
            yield history + [(message, assistant)]  # update the chat box live

    except Exception as err:
        yield history + [(message, f"[ERROR] {err}")]


# ----------------------- UI ---------------------------------------------
with gr.Blocks(title="🎨 Mixture of Inputs (MoI) Demo") as demo:
    gr.Markdown(
        "## 🎨 Mixture of Inputs (MoI) Demo  \n"
        "Streaming vLLM demo with dynamic **beta** adjustment in MoI, higher beta means less blending."
    )

    # sliders first – all on one row
    with gr.Row():
        beta = gr.Slider(0.0, 10.0, value=1.0, step=0.1, label="MoI Beta")
        temperature = gr.Slider(0.1, 4.0, value=0.7, step=0.1, label="Temperature")
        top_p = gr.Slider(0.1, 1.0, value=0.95, step=0.05, label="Top-p")
        max_tokens = gr.Slider(1, 2048, value=512, step=1, label="Max new tokens")

    chatbot = gr.Chatbot(height=450)
    user_box = gr.Textbox(placeholder="Type a message and press Enter…", show_label=False)
    clear_btn = gr.Button("Clear chat")

    # wiring
    user_box.submit(
        stream_completion,
        inputs=[user_box, chatbot, max_tokens, temperature, top_p, beta],
        outputs=chatbot,
    )
    clear_btn.click(lambda: None, None, chatbot, queue=False)

if __name__ == "__main__":
    demo.launch()
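
# Usage (assuming the vLLM server above is running):
#   python app.py
# Gradio prints a local URL (http://127.0.0.1:7860 by default) to open in a browser.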