File size: 5,404 Bytes
7977967
 
 
 
 
 
 
 
 
19d6e32
 
 
 
 
 
 
 
 
7977967
f557784
 
 
 
 
 
 
7977967
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19d6e32
 
7977967
 
 
 
19d6e32
7977967
 
 
 
 
 
 
 
 
 
19d6e32
7977967
 
438d4d3
7977967
 
 
 
5856dbe
 
 
 
 
7977967
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2dfbf5d
2f871c8
19d6e32
2f871c8
 
 
7977967
 
 
044e027
2f871c8
7977967
f557784
438d4d3
 
7977967
044e027
 
 
 
 
 
7977967
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
import atexit
import json
import os
import subprocess
import tempfile
import threading
import uuid

import gradio as gr
from openai import OpenAI

# Port that llama-swap is told to listen on, also used to build BASE_URL.
# Defaults to 5100; set the LLAMA_SWAP_PORT environment variable to override.
MAIN_PORT = int(os.environ.get("LLAMA_SWAP_PORT", "5100"))
BASE_URL = f"http://localhost:{MAIN_PORT}/v1"

# (dropdown label, model id) pairs for the model picker. The id is what gets
# sent as the `model` field of each chat request.
# NOTE(review): the ids presumably must match model names in ./config.yaml
# (llama-swap's config) — verify when adding entries.
MODEL_LIST = [
    ("Ernie-4.5-0.3B - Good generalist and small", "Ernie-4.5-0.3B"),
    ("LFM2-VL-450M - Stronger RLHF? Weaker in STEM", "LFM2-VL-450M"),
    ("gemma-3-270m-it - Deliberately Raw, need strong system prompt and steering if want assistant behavior", "gemma-3-270m-it"),
    ("Qwen3-0.6B - hybrid thinking /no_think, can do very limited STEM?", "Qwen3-0.6B"),
]

# Example prompts offered in the chat UI.
example_conv = [
    "Compare and analyze the pros and cons of traditional vs flat organization in business administration. Feel free to use any style and formatting you want in your response.",
    "Write a recipe for pancake",
    "Help me plan a quick weekend getaway trip to Tokyo?",
    "Write an essay on the role of information technology in international supply chain.",
]

def read_output(process):
    """Echo each line from *process*'s stdout pipe to the console until EOF.

    Trailing whitespace (including the newline) is stripped before printing.
    The pipe is closed once ``readline`` returns the empty string.
    """
    pipe = process.stdout
    while True:
        line = pipe.readline()
        if line == "":
            # Empty string (not "\n") means the child closed its end.
            break
        print(line.rstrip())
    pipe.close()

def start_server(command):
    """Launch *command* as a child process and mirror its output to our console.

    stderr is merged into stdout, and a daemon thread running ``read_output``
    drains that pipe so the child does not block on a full pipe buffer.

    Args:
        command: argv list passed directly to ``subprocess.Popen`` (no shell).

    Returns:
        The ``subprocess.Popen`` handle. This function never stops the child;
        the caller must terminate it.
    """
    process = subprocess.Popen(
        command,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,  # one combined stream for the reader thread
        text=True,
        # Decode explicitly and replace undecodable bytes. With the locale
        # default and strict errors, a single bad byte raises
        # UnicodeDecodeError inside the reader thread and kills it; nothing
        # would then drain the pipe and the child would stall on its next write.
        encoding="utf-8",
        errors="replace",
    )

    # Daemon thread: it must not keep the interpreter alive at shutdown.
    reader = threading.Thread(target=read_output, args=(process,), daemon=True)
    reader.start()

    return process

# Alternative setup: run a single llama-server directly instead of llama-swap.
#server_process = start_server(["./llama.cpp/build/bin/llama-server", "-m" ,"./llama.cpp/build/ERNIE-4.5-0.3B-PT-UD-Q8_K_XL.gguf", "-c", "32000", "--jinja", "--no-mmap", "--port", "5100", "--threads", "2"])

# Start llama-swap on localhost:MAIN_PORT using ./config.yaml.
server_process = start_server(["./llamaswap/llama-swap", "--listen", f"localhost:{MAIN_PORT}", "--config", "./config.yaml"])


def _stop_server_at_exit(timeout=30):
    """Stop the llama-swap child if it is still running when Python exits.

    The normal shutdown path only runs once the UI has been launched. If the
    script dies earlier (for example while building the UI), the child would
    otherwise be orphaned and keep listening on MAIN_PORT. This is a no-op
    when the child has already exited and been reaped.

    Args:
        timeout: seconds to wait after SIGTERM before falling back to kill().
    """
    if server_process.poll() is not None:
        return  # already stopped, e.g. by the normal shutdown path
    server_process.terminate()
    try:
        server_process.wait(timeout=timeout)
    except subprocess.TimeoutExpired:
        # Don't hang interpreter exit forever on an unresponsive child.
        server_process.kill()
        server_process.wait()
    print("Server stopped.")


atexit.register(_stop_server_at_exit)

# The OpenAI client requires some api_key value; a placeholder is used here.
# NOTE(review): presumably the local server does not check it — confirm.
cli = OpenAI(api_key="sk-nokey", base_url=BASE_URL)

def openai_call(message, history, model_chosen, system_prompt, max_new_tokens):
    """Stream a chat completion for *message*, given the earlier turns in *history*.

    Used as the ``gr.ChatInterface`` callback. While tokens arrive it yields
    ``(partial_reply, None)``. At the end it yields
    ``(full_reply, gr.State(conversation))``, where *conversation* holds the
    system prompt, prior turns, the new user message and the assistant reply.

    Args:
        message: the new user message text.
        history: prior turns as role/content message dicts (Gradio "messages"
            format). It is read, not modified.
        model_chosen: model id sent in the request's ``model`` field.
        system_prompt: content of the leading system message.
        max_new_tokens: completion token cap. Coerced to ``int`` in case the
            slider delivers a float.
    """
    # Build a new list rather than inserting into / appending to *history*,
    # which is owned by Gradio.
    conversation = [
        {"role": "system", "content": system_prompt},
        *(history or []),
        {"role": "user", "content": message},
    ]
    response = cli.chat.completions.create(
        model=model_chosen,
        messages=conversation,
        max_tokens=int(max_new_tokens),
        stream=True,
    )
    reply = ""
    for chunk in response:
        if not chunk.choices:
            continue  # some stream chunks carry no choices
        delta = chunk.choices[0].delta.content
        if delta is None:
            continue
        reply += delta
        yield reply, None
    conversation.append({"role": "assistant", "content": reply})
    yield reply, gr.State(conversation)

def gen_file(conv_state):
    """Write the exported conversation to a new ``<uuid4>.json`` file.

    *conv_state* is what the chat callback stores: a ``gr.State`` wrapper
    whose ``.value`` is the message list. A bare list is accepted too. The
    file is created in the current working directory.

    Returns:
        ``(gr.File(path), gr.State(path))``. The first shows the file for
        download; the second lets a follow-up step delete the local copy.

    Raises:
        gr.Error: if there is no finished conversation to export. conv_state
            starts empty and is set to None while a reply is still streaming.
    """
    conversation = getattr(conv_state, "value", conv_state)
    if not conversation:
        raise gr.Error("No finished conversation to export yet.")
    fname = f"{uuid.uuid4()}.json"
    with open(fname, mode="w", encoding="utf-8") as f:
        json.dump(conversation, f, indent=4, ensure_ascii=False)
    return gr.File(fname), gr.State(fname)

def rm_file_wrap(path : str):
    """Best-effort delete of the file at *path*.

    Any ``OSError`` (missing file, permission problem, ...) is reported on
    stdout and swallowed, so cleanup never interrupts the caller.
    """
    try:
        os.remove(path)
    except OSError as err:
        message = "Error: %s - %s." % (err.filename, err.strerror)
        print(message)

def _looks_like_export(path):
    """Return True if *path*'s file name has the ``<uuid4>.json`` form that ``gen_file`` creates."""
    stem, ext = os.path.splitext(os.path.basename(path or ""))
    if ext != ".json":
        return False
    try:
        # Require the canonical form so odd spellings UUID() tolerates are rejected.
        return str(uuid.UUID(stem)) == stem
    except ValueError:
        return False


def on_download(download_data: gr.DownloadData):
    """Delete the served export file after the user has downloaded it.

    The path comes from the download event's payload, which is sent by the
    client, so it is untrusted. Only files named like our exports are removed;
    anything else is refused and logged.
    NOTE(review): assumes Gradio's served copy keeps the original file name.
    If it does not, files are simply left in place (safe, but not cleaned up).
    """
    path = download_data.file.path
    if not _looks_like_export(path):
        print(f"refusing to delete unexpected path {path!r}")
        return
    print(f"deleting {path}")
    rm_file_wrap(path)

def clean_file(orig_path):
    """Delete the local export file whose path is wrapped in *orig_path*.

    *orig_path* is expected to be a ``gr.State``-like object whose ``.value``
    is the path returned by the export step. Deletion is best-effort via
    ``rm_file_wrap``.
    """
    target = orig_path.value
    print(f"Deleting {target}")
    rm_file_wrap(target)

with gr.Blocks() as demo:
    # Per-session state:
    #   conv_state - extra output of openai_call (the conversation to export)
    #   orig_path  - path of the JSON file written by gen_file
    conv_state = gr.State()
    orig_path = gr.State()
    chatbot = gr.Chatbot(placeholder="Have fun with the AI!", editable='all', show_copy_button=True, type="messages")
    additional_inputs = [
        # Pre-select the first model explicitly instead of relying on the
        # Gradio version's default, which may be "nothing selected" (None).
        gr.Dropdown(choices=MODEL_LIST, value=MODEL_LIST[0][1], label="LLM Model"),
        gr.Textbox("You are a helpful AI assistant.", label="System Prompt"),
        gr.Slider(30, 8192, value=2048, label="Max new tokens"),
    ]
    chat = gr.ChatInterface(
        openai_call,
        type="messages",
        chatbot=chatbot,
        additional_inputs=additional_inputs,
        additional_outputs=[conv_state],
        examples=example_conv,
        title="Edge level LLM Chat demo",
        description="In this demo, you can chat with sub-1B param range LLM - they are small enough to run with reasonable speed on most end user device. **Warning:** Do not input sensitive info - assume everything is public!"
    )
    with gr.Accordion("Export Conversations"):
        download_file = gr.File()
        # Bind the name to the Button itself, then wire its events separately.
        download_btn = gr.Button("Export Conversation for Download")
        # gen_file writes the JSON and shows it in download_file. On success,
        # clean_file deletes the working-directory copy.
        # NOTE(review): this presumes Gradio has already copied the file into
        # its own cache by the time .success runs — confirm.
        download_btn.click(
            fn=gen_file, inputs=[conv_state], outputs=[download_file, orig_path]
        ).success(fn=clean_file, inputs=[orig_path])
        # After the browser downloads the file, delete the served copy.
        download_file.download(on_download, None, None)

try:
    # Queue up to 10 pending requests and serve on all network interfaces.
    queued_demo = demo.queue(max_size=10, api_open=True)
    queued_demo.launch(server_name='0.0.0.0')
finally:
    # However launch() ends (normal close, Ctrl-C, or an exception), stop the
    # llama-swap child and block until it has exited.
    server_process.terminate()
    server_process.wait()
    print("Server stopped.")