# edge_llm_chat / chat_demo.py
# Author: lemonteaa. Last commit: b4acdb9 "Update chat_demo.py" (verified).
import gradio as gr
from openai import OpenAI
import uuid
import json
import os
import tempfile
import subprocess
import threading
# Local port of the OpenAI-compatible model server started below.
MAIN_PORT: int = 5100
# Root of that server's OpenAI-style REST API, used by the OpenAI client.
BASE_URL: str = f"http://localhost:{MAIN_PORT}/v1"
# Choices for the "LLM Model" dropdown. Each entry is
# (label shown to the user, model id passed to openai_call as `model_chosen`).
MODEL_LIST: list[tuple[str, str]] = [
    (
        "Ernie-4.5-0.3B - Good generalist and small",
        "Ernie-4.5-0.3B",
    ),
    (
        "LFM2-VL-450M - Stronger RLHF? Weaker in STEM",
        "LFM2-VL-450M",
    ),
    (
        "gemma-3-270m-it - Deliberately Raw, need strong system prompt and steering if want assistant behavior",
        "gemma-3-270m-it",
    ),
    (
        "Qwen3-0.6B - hybrid thinking /no_think, can do very limited STEM?",
        "Qwen3-0.6B",
    ),
]
# Clickable examples for the ChatInterface. Each row supplies, in order:
# [user message, model id, system prompt, max new tokens] — i.e. the message
# followed by the three additional inputs declared in the UI.
example_conv = [
    [
        (
            "Compare and analyze the pros and cons of traditional vs flat "
            "organization in business administration. Feel free to use any "
            "style and formatting you want in your response."
        ),
        "LFM2-VL-450M",
        "",  # deliberately no system prompt
        2048,
    ],
    [
        "Write a recipe for pancake",
        "gemma-3-270m-it",
        "You are a friendly and cheerful AI assistant.",
        2048,
    ],
    [
        "Help me plan a quick weekend getaway trip to Tokyo?",
        "Ernie-4.5-0.3B",
        "You are a helpful AI assistant.",
        2048,
    ],
    [
        (
            "Write an essay on the role of information technology in "
            "international supply chain."
        ),
        "gemma-3-270m-it",
        "You are a helpful AI assistant.",
        2048,
    ],
]
def read_output(process):
    """Echo a subprocess's output stream to this process's console.

    Reads ``process.stdout`` line by line until EOF, printing each line with
    trailing whitespace stripped, and always closes the stream afterwards —
    even if reading or printing raises.

    Args:
        process: A ``subprocess.Popen``-like object whose ``stdout`` is a
            text-mode stream (``readline()`` returns ``""`` at EOF).
    """
    stream = process.stdout
    try:
        # iter(callable, sentinel) stops at the "" readline() returns on EOF.
        for line in iter(stream.readline, ""):
            # Flush so server logs show up promptly even when our own stdout
            # is block-buffered (e.g. redirected to a file or container log).
            print(line.rstrip(), flush=True)
    finally:
        # Release the pipe on every exit path, not just clean EOF.
        stream.close()
def start_server(command):
    """Start a server subprocess and echo its output to the console.

    stderr is merged into stdout, and a daemon thread running
    ``read_output`` drains that single pipe for the lifetime of the process.

    Args:
        command: argv list passed straight to ``subprocess.Popen`` (no shell).

    Returns:
        The ``subprocess.Popen`` handle, so the caller can stop the server.
    """
    process = subprocess.Popen(
        command,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,  # merge stderr into the same pipe
        text=True,
        # Decode explicitly as UTF-8 and substitute undecodable bytes. With
        # the locale codec in strict mode, a single bad byte in the server log
        # would raise inside the reader thread and stop it draining the pipe;
        # once the pipe buffer fills, the server could block on writing logs.
        encoding="utf-8",
        errors="replace",
    )
    # Daemon thread: it must not keep the interpreter alive at shutdown.
    reader = threading.Thread(target=read_output, args=(process,), daemon=True)
    reader.start()
    return process
# Start the model-serving backend: llama-swap on localhost:MAIN_PORT, driven by
# ./config.yaml. To serve a single model without llama-swap, a llama.cpp
# server can be started here instead, e.g.
#   ["./llama.cpp/build/bin/llama-server", "-m", "<model>.gguf", "-c", "32000",
#    "--jinja", "--no-mmap", "--port", "5100", "--threads", "2"]
server_process = start_server(
    [
        "./llamaswap/llama-swap",
        "--listen",
        f"localhost:{MAIN_PORT}",
        "--config",
        "./config.yaml",
    ]
)
# Placeholder API key — presumably the local server does not validate it.
cli = OpenAI(base_url=BASE_URL, api_key="sk-nokey")
def openai_call(message, history, model_chosen, system_prompt, max_new_tokens):
    """Stream a chat completion for ``message`` from the local model server.

    Used as the ``gr.ChatInterface`` handler (configured with
    ``type="messages"``), with the model, system prompt and token limit
    coming from the interface's additional inputs.

    Args:
        message: The new user message text.
        history: Prior turns as ``{"role": ..., "content": ...}`` dicts.
            It is read but never modified.
        model_chosen: Model id selected in the dropdown, sent as ``model``.
        system_prompt: System prompt text. When empty, no system message is
            sent, so the model's chat template runs without one.
        max_new_tokens: Upper bound on generated tokens. This is the slider
            value, coerced to ``int``.

    Yields:
        ``(partial_reply, None)`` after each streamed text delta. At the end it
        yields ``(full_reply, gr.State(messages))`` once, where ``messages`` is
        the request conversation plus the assistant reply. That state feeds the
        export feature.
    """
    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})
    # Build a fresh list (don't mutate Gradio's `history`) and copy only the
    # OpenAI message fields; Gradio's dicts may carry extra UI-only keys.
    messages.extend(
        {"role": turn["role"], "content": turn["content"]} for turn in history
    )
    messages.append({"role": "user", "content": message})

    response = cli.chat.completions.create(
        model=model_chosen,
        messages=messages,
        max_tokens=int(max_new_tokens),
        stream=True,
    )
    reply = ""
    for chunk in response:
        if not chunk.choices:
            # A streamed chunk may carry no choices (e.g. a usage-only chunk).
            continue
        delta = chunk.choices[0].delta.content
        if delta is not None:
            reply += delta
            yield reply, None
    messages.append({"role": "assistant", "content": reply})
    yield reply, gr.State(messages)
def gen_file(conv_state):
    """Write the current conversation to a new JSON file for download.

    Args:
        conv_state: The conversation state. ``openai_call`` yields
            ``gr.State(messages)`` into it, so the payload is read from
            ``.value`` when that attribute exists. A bare list of messages is
            accepted as well.

    Returns:
        ``(gr.File(path), gr.State(path))``: the file to show or download, and
        the path, so that ``clean_file`` can delete it afterwards.

    Raises:
        gr.Error: If no finished conversation has been recorded yet. For
            example, Export was clicked before any reply completed.
    """
    conversation = getattr(conv_state, "value", conv_state)
    if conversation is None:
        raise gr.Error("No conversation to export yet - send a message first.")
    # Random name so concurrent exports never collide in the working directory.
    fname = f"{uuid.uuid4()}.json"
    try:
        with open(fname, mode="w", encoding="utf-8") as f:
            json.dump(conversation, f, indent=4, ensure_ascii=False)
    except Exception:
        # Don't leave a half-written file behind if serialization fails.
        if os.path.exists(fname):
            os.remove(fname)
        raise
    return gr.File(fname), gr.State(fname)
def rm_file_wrap(path: str):
    """Delete ``path`` as best-effort cleanup, reporting instead of raising.

    An ``OSError`` from ``os.remove`` is printed and swallowed. Examples are a
    missing file or a permission error. This keeps the event handlers that call
    this from failing on cleanup. A ``None`` path means there is nothing to
    delete and is ignored. Previously it raised ``TypeError``, which escaped the
    ``OSError`` handler.
    """
    if path is None:
        return
    try:
        os.remove(path)
    except OSError as err:
        # Inform the user; cleanup failure is not fatal.
        print("Error: %s - %s." % (err.filename, err.strerror))
def on_download(download_data: gr.DownloadData):
    """Delete the file behind a completed download of ``download_file``.

    Args:
        download_data: Gradio download-event payload. ``.file.path`` is the
            local path of the file that was downloaded.
    """
    served_path = download_data.file.path
    print(f"deleting {served_path}")
    rm_file_wrap(served_path)
def clean_file(orig_path):
    """Delete the JSON file that ``gen_file`` wrote to the working directory.

    Runs only after ``gen_file`` succeeds (it is chained with ``.success``).
    NOTE(review): this assumes Gradio has already copied the returned file for
    serving, because the download handler deletes a separate path. Confirm this.

    Args:
        orig_path: The path state. ``gen_file`` returns ``gr.State(fname)``, so
            the path is read from ``.value`` when that attribute exists. A bare
            path string is accepted as well.
    """
    path = getattr(orig_path, "value", orig_path)
    print(f"Deleting {path}")
    rm_file_wrap(path)
with gr.Blocks() as demo:
    # Conversation snapshot published by openai_call (via additional_outputs),
    # and the path of the most recently exported JSON file.
    conv_state = gr.State()
    orig_path = gr.State()

    chatbot = gr.Chatbot(
        placeholder="Have fun with the AI!",
        editable="all",
        show_copy_button=True,
        type="messages",
    )
    # Passed to openai_call after (message, history), in this order:
    # model id, system prompt, max new tokens.
    additional_inputs = [
        gr.Dropdown(choices=MODEL_LIST, label="LLM Model"),
        gr.Textbox(value="You are a helpful AI assistant.", label="System Prompt"),
        gr.Slider(minimum=30, maximum=8192, value=2048, label="Max new tokens"),
    ]
    chat = gr.ChatInterface(
        fn=openai_call,
        type="messages",
        chatbot=chatbot,
        additional_inputs=additional_inputs,
        additional_outputs=[conv_state],
        examples=example_conv,
        title="Edge level LLM Chat demo",
        description=(
            "In this demo, you can chat with sub-1B param range LLM - they are "
            "small enough to run with reasonable speed on most end user device. "
            "**Warning:** Do not input sensitive info - assume everything is public!"
        ),
    )

    with gr.Accordion("Export Conversations"):
        download_file = gr.File()
        # Export flow: gen_file writes the JSON and hands it to `download_file`;
        # if that succeeds, clean_file deletes the working-directory copy.
        download_btn = (
            gr.Button("Export Conversation for Download")
            .click(
                fn=gen_file,
                inputs=[conv_state],
                outputs=[download_file, orig_path],
            )
            .success(fn=clean_file, inputs=[orig_path])
        )
        # Once the user downloads the file, delete it as well.
        download_file.download(fn=on_download, inputs=None, outputs=None)
try:
    demo.queue(max_size=10, api_open=True).launch(server_name="0.0.0.0")
finally:
    # Stop the model server whether launch() returned normally or raised
    # (e.g. on Ctrl+C). terminate() sends SIGTERM on POSIX. If the server
    # ignores it, escalate to kill() so shutdown cannot hang in wait().
    server_process.terminate()
    try:
        server_process.wait(timeout=10)
    except subprocess.TimeoutExpired:
        server_process.kill()
        server_process.wait()
    print("Server stopped.")