File size: 2,012 Bytes
511336b
27bbe11
511336b
27bbe11
511336b
 
3149dc0
477e431
61fc039
 
 
 
af954c4
511336b
 
61fc039
511336b
 
61fc039
 
 
 
511336b
 
ff382d3
511336b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
af954c4
987ec44
511336b
 
 
 
 
61fc039
 
 
 
af954c4
511336b
 
 
61fc039
 
 
511336b
27bbe11
 
 
51737ba
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import os
import gradio as gr
from openai import OpenAI

# --- Endpoint configuration (read once at import time) ---
# Base URL of the OpenAI-compatible inference endpoint; trailing "/" stripped
# so the "/v1" suffix below never doubles the slash.
BASE = os.getenv("HF_ENDPOINT_URL", "").rstrip("/")
# Bearer token for the endpoint (Hugging Face access token).
API_KEY = os.getenv("HF_TOKEN")
# Model identifier sent with each request; overridable via env.
MODEL_ID = os.getenv("MODEL_ID", "RedMod/mangrove_30b_a3b_sft_step_6000")

# Fail fast at import time with an actionable message rather than erroring
# on the first chat request.
if not BASE or not API_KEY:
    raise RuntimeError("Set HF_ENDPOINT_URL and HF_TOKEN in Settings → Repository secrets.")

# Shared OpenAI-compatible client; the endpoint exposes the API under "/v1".
client = OpenAI(base_url=f"{BASE}/v1", api_key=API_KEY)

def build_messages(history, user_msg, system_msg):
    """Build an OpenAI-style message list from chat history plus the new turn.

    Accepts both Gradio history formats so the app works regardless of the
    ChatInterface ``type`` setting:

    * tuple pairs ``[(user, assistant), ...]`` (legacy "tuples" format) —
      falsy/None halves of a pair are skipped, preserving prior behavior;
    * role dicts ``[{"role": ..., "content": ...}, ...]`` ("messages" format) —
      entries with empty content are skipped.

    Args:
        history: Prior turns in either format above, or None.
        user_msg: The new user message, appended last.
        system_msg: Optional system prompt; prepended (stripped) when non-blank.

    Returns:
        list[dict]: Messages ready for ``chat.completions.create``.
    """
    msgs = []
    if system_msg and system_msg.strip():
        msgs.append({"role": "system", "content": system_msg.strip()})
    for turn in history or []:
        if isinstance(turn, dict):
            # "messages"-format entry: pass role/content through unchanged.
            if turn.get("content"):
                msgs.append({"role": turn["role"], "content": turn["content"]})
        else:
            # Legacy tuple pair: (user_text, assistant_text).
            u, a = turn
            if u:
                msgs.append({"role": "user", "content": u})
            if a:
                msgs.append({"role": "assistant", "content": a})
    msgs.append({"role": "user", "content": user_msg})
    return msgs

def chat_fn(message, history, system_message, temperature, top_p, max_tokens):
    """Stream a chat completion, yielding the growing reply text.

    Sends the accumulated conversation to the endpoint with ``stream=True``
    and yields the partial assistant reply after each received token, which
    Gradio renders as a live-updating message.
    """
    response = client.chat.completions.create(
        model=MODEL_ID,
        messages=build_messages(history, message, system_message),
        temperature=float(temperature),
        top_p=float(top_p),
        max_tokens=int(max_tokens),
        stream=True,
    )
    reply_so_far = ""
    for event in response:
        delta = event.choices[0].delta
        # Some stream events carry no text (e.g. role-only or final chunk).
        if delta and delta.content:
            reply_so_far += delta.content
            yield reply_so_far

# --- UI layout: system prompt + sampling controls feeding a ChatInterface ---
with gr.Blocks(title="Mangrove Demo") as demo:
    # Editable system prompt, passed to chat_fn as an additional input.
    system_box = gr.Textbox(
        label="System prompt",
        value="You are a helpful assistant.",
        lines=2,
    )
    # Sampling parameters, laid out side by side.
    with gr.Row():
        temp = gr.Slider(0.0, 2.0, value=0.7, step=0.1, label="Temperature")
        topp = gr.Slider(0.0, 1.0, value=0.95, step=0.01, label="Top-p")
        maxt = gr.Slider(16, 4096, value=512, step=16, label="Max tokens")

    # ChatInterface invokes chat_fn(message, history, *additional_inputs);
    # the order here must match chat_fn's extra parameters.
    gr.ChatInterface(
        fn=chat_fn,
        additional_inputs=[system_box, temp, topp, maxt],
        submit_btn="Send",
        stop_btn="Stop",
        multimodal=False,
    )

if __name__ == "__main__":
    demo.launch()