# edge_llm_chat / chat_demo.py
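# Gradio chat front-end for a small "edge"-class LLM served locally by llama.cpp's
# llama-server. The server is launched as a subprocess below and queried through its
# OpenAI-compatible API via the OpenAI Python SDK, with responses streamed to the UI.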
import gradio as gr
from openai import OpenAI
import uuid
import json
import os
import tempfile
import subprocess
import threading
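# Local llama-server endpoint (OpenAI-compatible API). The model id is only a
# placeholder: the server serves whichever model it was started with.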
BASE_URL = "http://localhost:5100/v1"
MODEL_NAME = "placeholder-model-id"
def read_output(process):
"""Reads the output from the subprocess and prints it to the console."""
for line in iter(process.stdout.readline, ""):
print(line.rstrip())
process.stdout.close()
def start_server(command):
"""Starts the server as a subprocess and captures its stdout."""
# Start the server process
process = subprocess.Popen(
command,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT, # Redirect stderr to stdout
text=True # Automatically decode the output to text
)
# Start a thread to read the output
output_thread = threading.Thread(target=read_output, args=(process,))
output_thread.daemon = True # Daemonize the thread so it exits when the main program does
output_thread.start()
return process
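# Launch llama-server with the ERNIE-4.5-0.3B GGUF: 32k context window, the model's
# built-in chat template (--jinja), mmap disabled, listening on the port BASE_URL points at.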
server_process = start_server([
    "./llama.cpp/build/bin/llama-server",
    "-m", "./llama.cpp/build/ERNIE-4.5-0.3B-PT-UD-Q8_K_XL.gguf",
    "-c", "32000",
    "--jinja",
    "--no-mmap",
    "--port", "5100",
])
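# Dummy API key: the OpenAI SDK requires one, but the local llama-server does not check it by default.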
cli = OpenAI(api_key="sk-nokey", base_url=BASE_URL)
def openai_call(message, history, system_prompt, max_new_tokens):
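    """Stream a chat completion for the latest user message, yielding partial replies."""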
#print(history) # DEBUG
history.insert(0, {
"role": "system",
"content": system_prompt
})
history.append({
"role": "user",
"content": message
})
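    # Request a streamed completion so partial text can be pushed to the UI as it arrives.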
response = cli.chat.completions.create(
model=MODEL_NAME,
messages=history,
max_tokens=max_new_tokens,
#stop=["<|im_end|>", "</s>"],
stream=True
)
reply = ""
for chunk in response:
if len(chunk.choices) > 0:
delta = chunk.choices[0].delta.content
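            # Some chunks (e.g. the final one) carry no content delta, so guard against None.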
if delta is not None:
reply = reply + delta
yield reply, None
history.append({ "role": "assistant", "content": reply })
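    # Final yield: the complete reply plus the updated history (now including the
    # assistant turn), captured via additional_outputs.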
yield reply, gr.State(history)
def gen_file(conv_state):
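    """Write the conversation history to a uniquely named JSON file and return it for download."""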
#print(conv_state) # DEBUG
fname = f"{str(uuid.uuid4())}.json"
#with tempfile.NamedTemporaryFile(prefix=str(uuid.uuid4()), suffix=".json", mode="w", encoding="utf-8", delete_on_close=False) as f:
with open(fname, mode="w", encoding="utf-8") as f:
json.dump(conv_state.value, f, indent=4, ensure_ascii=False)
return gr.File(fname), gr.State(fname)
def rm_file_wrap(path: str):
# Try to delete the file.
try:
os.remove(path)
except OSError as e:
# If it fails, inform the user.
print("Error: %s - %s." % (e.filename, e.strerror))
def on_download(download_data: gr.DownloadData):
print(f"deleting {download_data.file.path}")
rm_file_wrap(download_data.file.path)
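# Gradio serves a cached copy of the file, so the original JSON written by gen_file
# can be removed once the export event succeeds.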
def clean_file(orig_path):
print(f"Deleting {orig_path.value}")
rm_file_wrap(orig_path.value)
with gr.Blocks() as demo:
#download=gr.DownloadButton(label="Download Conversation", value=None)
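    # conv_state tracks the full message history; orig_path remembers the exported JSON's path.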
conv_state = gr.State()
orig_path = gr.State()
chat = gr.ChatInterface(
openai_call,
type="messages",
additional_inputs=[
gr.Textbox("You are a helpful AI assistant.", label="System Prompt"),
gr.Slider(30, 8192, value=2048, label="Max new tokens"),
],
additional_outputs=[conv_state],
title="Edge level LLM Chat demo",
description="In this demo, you can chat with sub-1B param range LLM - they are small enough to run with reasonable speed on most end user device. **Warning:** Do not input sensitive info - assume everything is public!"
)
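    # Export flow: clicking the button writes the conversation to JSON and hands it to the
    # File component; the on-disk original is removed on success, and the served copy is
    # removed once the user downloads it.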
download_file = gr.File()
download_btn = gr.Button("Export Conversation for Download") \
.click(fn=gen_file, inputs=[conv_state], outputs=[download_file, orig_path]) \
.success(fn=clean_file, inputs=[orig_path])
download_file.download(on_download, None, None)
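# Run the app; make sure the llama-server subprocess is terminated when Gradio exits.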
try:
demo.queue(max_size=10, api_open=True).launch(server_name='0.0.0.0')
finally:
# Stop the server
server_process.terminate()
server_process.wait()
print("Server stopped.")