|
import os |
|
import gradio as gr |
|
from openai import OpenAI |
|
|
|
# Base URL of the OpenAI-compatible inference endpoint; trailing slash is
# stripped so that appending "/v1" below cannot produce a double slash.
BASE = os.getenv("HF_ENDPOINT_URL", "").rstrip("/")

# Hugging Face token, used as the API key for the OpenAI-compatible client.
API_KEY = os.getenv("HF_TOKEN")

# Model identifier sent with every completion request; overridable via env.
MODEL_ID = os.getenv("MODEL_ID", "RedMod/mangrove_30b_a3b_sft_step_6000")

# Fail fast at import time when required configuration is missing, rather
# than erroring on the first chat request.
if not BASE or not API_KEY:
    raise RuntimeError("Set HF_ENDPOINT_URL and HF_TOKEN in Settings → Repository secrets.")

# Single shared client pointed at the endpoint's OpenAI-compatible /v1 API.
client = OpenAI(base_url=f"{BASE}/v1", api_key=API_KEY)
|
|
|
def build_messages(history, user_msg, system_msg):
    """Convert Gradio chat history into an OpenAI chat.completions message list.

    Accepts both Gradio history formats:
    - legacy tuple pairs ``(user_text, assistant_text)`` (either may be falsy), and
    - "messages" format entries, i.e. ``{"role": ..., "content": ...}`` dicts
      (used by ``ChatInterface(type="messages")`` in Gradio >= 4.44).

    Args:
        history: list of tuple pairs or role/content dicts (see above).
        user_msg: the new user message appended as the final entry.
        system_msg: optional system prompt; prepended (stripped) when non-blank.

    Returns:
        A list of ``{"role", "content"}`` dicts ready for the OpenAI API.
    """
    msgs = []
    if system_msg and system_msg.strip():
        msgs.append({"role": "system", "content": system_msg.strip()})
    for turn in history:
        if isinstance(turn, dict):
            # Messages-format entry: pass through role/content, skipping
            # entries with empty content (mirrors the falsy check below).
            if turn.get("content"):
                msgs.append({"role": turn["role"], "content": turn["content"]})
        else:
            # Legacy tuple format: (user_text, assistant_text).
            u, a = turn
            if u:
                msgs.append({"role": "user", "content": u})
            if a:
                msgs.append({"role": "assistant", "content": a})
    msgs.append({"role": "user", "content": user_msg})
    return msgs
|
|
|
def chat_fn(message, history, system_message, temperature, top_p, max_tokens):
    """Stream a chat completion for the Gradio ChatInterface.

    Args:
        message: the new user message.
        history: prior conversation, as provided by Gradio.
        system_message: optional system prompt (additional input).
        temperature / top_p / max_tokens: sampling controls from the sliders;
            coerced to float/int because Gradio may deliver them as strings
            or numpy scalars.

    Yields:
        The accumulated assistant reply after each received token delta,
        so the UI renders the response progressively.
    """
    msgs = build_messages(history, message, system_message)
    stream = client.chat.completions.create(
        model=MODEL_ID,
        messages=msgs,
        temperature=float(temperature),
        top_p=float(top_p),
        max_tokens=int(max_tokens),
        stream=True,
    )
    partial = ""
    for chunk in stream:
        # Some OpenAI-compatible backends emit chunks with an empty
        # `choices` list (e.g. keep-alives or trailing usage chunks);
        # indexing [0] unguarded would raise IndexError mid-stream.
        if not chunk.choices:
            continue
        delta = chunk.choices[0].delta
        if delta and delta.content:
            partial += delta.content
            yield partial
|
|
|
with gr.Blocks(title="Mangrove Demo") as demo:
    # Free-form system prompt, forwarded to chat_fn on every turn.
    system_prompt_box = gr.Textbox(
        label="System prompt",
        value="You are a helpful assistant.",
        lines=2,
    )

    # Sampling controls laid out side by side.
    with gr.Row():
        temperature_slider = gr.Slider(
            minimum=0.0, maximum=2.0, value=0.7, step=0.1, label="Temperature"
        )
        top_p_slider = gr.Slider(
            minimum=0.0, maximum=1.0, value=0.95, step=0.01, label="Top-p"
        )
        max_tokens_slider = gr.Slider(
            minimum=16, maximum=4096, value=512, step=16, label="Max tokens"
        )

    # Chat widget; additional inputs are appended after (message, history)
    # in chat_fn's argument list, in the order given here.
    gr.ChatInterface(
        fn=chat_fn,
        additional_inputs=[
            system_prompt_box,
            temperature_slider,
            top_p_slider,
            max_tokens_slider,
        ],
        submit_btn="Send",
        stop_btn="Stop",
        multimodal=False,
    )
|
|
|
if __name__ == "__main__":
    # Start the Gradio server only when run as a script (not on import).
    demo.launch()
|
|
|
|