"""
The Gradio demo server for chatting with a single model.
"""

import json
import logging
import os
import time
import uuid

import gradio as gr
import requests

from conversation import get_conv_template
from gradio_patch import Chatbot as grChatbot
from gradio_css import code_highlight_css
from utils import (
    WORKER_API_TIMEOUT,
    ErrorCode,
    server_error_msg,
    get_window_url_params_js,
)

logging.basicConfig(
    format="%(asctime)s %(levelname)s: %(message)s",
    datefmt="%m/%d/%Y %I:%M:%S %p",
)
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

headers = {"User-Agent": "fastchat Client"}

# Reusable Gradio component updates for toggling the action buttons.
no_change_btn = gr.Button.update()
enable_btn = gr.Button.update(interactive=True)
disable_btn = gr.Button.update(interactive=False)

# Both environment variables are required: the address of the FastChat
# controller and the Gradio queue concurrency.
controller_url = os.environ["controller_url"]
concurrency_count = int(os.environ["concurrency_count"])
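
# Minimal launch sketch (the script name and controller port below are
# assumptions; adjust them to your deployment):
#
#   controller_url=http://localhost:21001 concurrency_count=10 \
#       python gradio_demo_server.py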

learn_more_md = """
### Notice
- All models in this demo run on 4th Generation Intel® Xeon® (Sapphire Rapids) processors, using AMX operations and mixed-precision inference.
- This demo is based on the FastChat demo server. [[GitHub]](https://github.com/lm-sys/FastChat)

### Terms of use
By using this service, users are required to agree to the following terms: The service is a research preview intended for non-commercial use only. It can produce factually incorrect output and should not be relied on to produce factually accurate information. The service provides only limited safety measures and may generate lewd, biased, or otherwise offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. The service may collect user dialogue data for future research.

### License
The service is a research preview intended for non-commercial use only, subject to the [Terms of Use](https://openai.com/policies/terms-of-use) of the data generated by OpenAI and the [Privacy Practices](https://chrome.google.com/webstore/detail/sharegpt-share-your-chatg/daiacboceoaocpibfodeljbdfacokfjb) of ShareGPT. Please contact us if you find any potential violations.
"""


def get_model_list(controller_url):
    """Refresh the controller's worker list and return the available model names."""
    ret = requests.post(controller_url + "/refresh_all_workers")
    assert ret.status_code == 200
    ret = requests.post(controller_url + "/list_models")
    models = ret.json()["models"]
    models.sort()
    logger.info(f"Models: {models}")
    return models
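
# The controller's /list_models endpoint is expected to answer with JSON of the
# form {"models": [...]}; the actual model names depend on the deployment.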


def load_demo_refresh_model_list(url_params):
    models = get_model_list(controller_url)
    selected_model = models[0] if len(models) > 0 else ""
    # A ?model=<name> URL parameter pre-selects that model when it is available.
    if "model" in url_params:
        model = url_params["model"]
        if model in models:
            selected_model = model

    dropdown_update = gr.Dropdown.update(
        choices=models, value=selected_model, visible=True
    )

    state = None
    return (
        state,
        dropdown_update,
        gr.Chatbot.update(visible=True),
        gr.Textbox.update(visible=True),
        gr.Button.update(visible=True),
        gr.Row.update(visible=True),
        gr.Accordion.update(visible=True),
    )
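
# The tuple above must line up with the `outputs` wired to demo.load() in
# build_demo(): [state, model_selector, chatbot, textbox, send_btn, button_row,
# parameter_row].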


def load_demo_reload_model(url_params, request: gr.Request):
    logger.info(
        f"load_demo_reload_model. ip: {request.client.host}. params: {url_params}"
    )
    return load_demo_refresh_model_list(url_params)


def load_demo_single(models, url_params):
    dropdown_update = gr.Dropdown.update(visible=True)
    if "model" in url_params:
        model = url_params["model"]
        if model in models:
            dropdown_update = gr.Dropdown.update(value=model, visible=True)

    state = None
    return (
        state,
        dropdown_update,
        gr.Chatbot.update(visible=True),
        gr.Textbox.update(visible=True),
        gr.Button.update(visible=True),
        gr.Row.update(visible=True),
        gr.Accordion.update(visible=True),
    )


def load_demo(url_params, request: gr.Request):
    logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}")
    # `models` is the module-level list populated in __main__ at startup.
    return load_demo_single(models, url_params)


def regenerate(state, request: gr.Request):
    logger.info(f"regenerate. ip: {request.client.host}")
    # Drop the last assistant reply so http_bot regenerates it.
    state.messages[-1][-1] = None
    state.skip_next = False
    # Two button updates, matching btn_list = [regenerate_btn, clear_btn].
    return (state, state.to_gradio_chatbot(), "") + (disable_btn,) * 2


def clear_history(request: gr.Request):
    logger.info(f"clear_history. ip: {request.client.host}")
    state = None
    return (state, [], "") + (disable_btn,) * 2


def add_text(state, text, request: gr.Request):
    logger.info(f"add_text. ip: {request.client.host}. len: {len(text)}")

    if state is None:
        # Start from the Vicuna template; http_bot switches to the
        # model-specific template on the first round.
        state = get_conv_template("vicuna_v1.1")

    if len(text) <= 0:
        state.skip_next = True
        return (state, state.to_gradio_chatbot(), "") + (no_change_btn,) * 2

    text = text[:1536]  # Hard cut-off to bound the prompt length.
    state.append_message(state.roles[0], text)
    state.append_message(state.roles[1], None)
    state.skip_next = False
    return (state, state.to_gradio_chatbot(), "") + (disable_btn,) * 2
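
# Each handler above returns [state, chatbot, textbox] followed by one update
# per button in btn_list ([regenerate, clear]); see build_single_model_ui().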


def post_process_code(code):
    """Restore underscores that the model escaped inside fenced code blocks."""
    sep = "\n```"
    if sep in code:
        blocks = code.split(sep)
        if len(blocks) % 2 == 1:
            # An odd block count means the fences are balanced; the blocks at
            # odd indices are the ones inside ``` fences.
            for i in range(1, len(blocks), 2):
                blocks[i] = blocks[i].replace("\\_", "_")
        code = sep.join(blocks)
    return code
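
# Illustrative example (hypothetical input, not from the original source):
#   post_process_code("Use:\n```\nmy\\_var = 1\n```")
#   returns "Use:\n```\nmy_var = 1\n```"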


def model_worker_stream_iter(
    conv, model_name, worker_addr, prompt, temperature, top_p, max_new_tokens
):
    gen_params = {
        "model": model_name,
        "prompt": prompt,
        "temperature": temperature,
        "top_p": top_p,
        "max_new_tokens": max_new_tokens,
        "stop": conv.stop_str,
        "stop_token_ids": conv.stop_token_ids,
        "echo": False,
    }
    logger.info(f"==== request ====\n{gen_params}")

    response = requests.post(
        worker_addr + "/worker_generate_stream",
        headers=headers,
        json=gen_params,
        stream=True,
        timeout=WORKER_API_TIMEOUT,
    )
    # The worker streams NUL-delimited JSON chunks.
    for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
        if chunk:
            data = json.loads(chunk.decode())
            yield data
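
# Each streamed chunk decodes to a JSON object shaped roughly like
# {"text": "...", "error_code": 0}, which is what http_bot consumes below;
# the error codes come from utils.ErrorCode.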


def http_bot(
    state, model_selector, temperature, top_p, max_new_tokens, request: gr.Request
):
    logger.info(f"http_bot. ip: {request.client.host}")
    start_tstamp = time.time()
    model_name = model_selector
    temperature = float(temperature)
    top_p = float(top_p)
    max_new_tokens = int(max_new_tokens)

    if state.skip_next:
        # This generate call was skipped because the input was empty.
        yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 2
        return

    if len(state.messages) == state.offset + 2:
        # First round of this conversation: switch to the conversation template
        # registered under the lowercased model name, carrying over the user
        # message that add_text appended.
        new_state = get_conv_template(model_name.lower())
        new_state.conv_id = uuid.uuid4().hex
        new_state.model_name = state.model_name or model_selector
        new_state.append_message(new_state.roles[0], state.messages[-2][1])
        new_state.append_message(new_state.roles[1], None)
        state = new_state

    # Construct the prompt. ChatGLM models consume the raw message list
    # instead of a flattened prompt string.
    conv = state
    if "chatglm" in model_name:
        prompt = list(list(x) for x in conv.messages[conv.offset :])
    else:
        prompt = conv.get_prompt()
    # The controller address is used as the worker address here, relying on
    # the controller to relay /worker_generate_stream to a worker.
    stream_iter = model_worker_stream_iter(
        conv, model_name, controller_url, prompt, temperature, top_p, max_new_tokens
    )

    # Show a blinking-cursor placeholder while tokens stream in.
    state.messages[-1][-1] = "▌"
    yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 2

    output = ""
    try:
        for data in stream_iter:
            if data["error_code"] == 0:
                output = data["text"].strip()
                if "vicuna" in model_name:
                    output = post_process_code(output)
                state.messages[-1][-1] = output + "▌"
                yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 2
            else:
                output = data["text"] + f"\n\n(error_code: {data['error_code']})"
                state.messages[-1][-1] = output
                # Re-enable the regenerate and clear buttons on worker errors.
                yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 2
                return
            time.sleep(0.02)
    except requests.exceptions.RequestException as e:
        state.messages[-1][-1] = (
            f"{server_error_msg}\n\n"
            f"(error_code: {ErrorCode.GRADIO_REQUEST_ERROR}, {e})"
        )
        yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 2
        return
    except Exception as e:
        state.messages[-1][-1] = (
            f"{server_error_msg}\n\n"
            f"(error_code: {ErrorCode.GRADIO_STREAM_UNKNOWN_ERROR}, {e})"
        )
        yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 2
        return

    # Strip the trailing cursor character from the final message.
    state.messages[-1][-1] = state.messages[-1][-1][:-1]
    yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 2

    finish_tstamp = time.time()
    logger.info(f"response ({finish_tstamp - start_tstamp:.2f}s): {output}")


block_css = (
    code_highlight_css
    + """
pre {
    white-space: pre-wrap;       /* Since CSS 2.1 */
    white-space: -moz-pre-wrap;  /* Mozilla, since 1999 */
    white-space: -pre-wrap;      /* Opera 4-6 */
    white-space: -o-pre-wrap;    /* Opera 7 */
    word-wrap: break-word;       /* Internet Explorer 5.5+ */
}
#notice_markdown th {
    display: none;
}
"""
)


def build_single_model_ui(models):
    notice_markdown = """
# <p style="text-align: center;">Chat with Intel Labs-optimized Large Language Models</p>

### Choose a model to chat with
"""

    state = gr.State()
    gr.Markdown(notice_markdown, elem_id="notice_markdown")

    with gr.Row(elem_id="model_selector_row"):
        model_selector = gr.Dropdown(
            choices=models,
            value=models[0] if len(models) > 0 else "",
            interactive=True,
            show_label=False,
        ).style(container=False)

    chatbot = grChatbot(
        elem_id="chatbot", label="Scroll down and start chatting", visible=False
    ).style(height=550)
    with gr.Row():
        with gr.Column(scale=20):
            textbox = gr.Textbox(
                show_label=False,
                placeholder="Type your message...",
                visible=False,
            ).style(container=False)
        with gr.Column(scale=1, min_width=50):
            send_btn = gr.Button(value="Send", visible=False)

    with gr.Row(visible=False) as button_row:
        regenerate_btn = gr.Button(value="Regenerate", interactive=False)
        clear_btn = gr.Button(value="Clear history", interactive=False)

    with gr.Accordion("Parameters", open=False, visible=False) as parameter_row:
        temperature = gr.Slider(
            minimum=0.0,
            maximum=1.0,
            value=0.1,
            step=0.1,
            interactive=True,
            label="Temperature",
        )
        top_p = gr.Slider(
            minimum=0.0,
            maximum=1.0,
            value=1.0,
            step=0.1,
            interactive=True,
            label="Top P",
        )
        max_output_tokens = gr.Slider(
            minimum=0,
            maximum=1024,
            value=512,
            step=64,
            interactive=True,
            label="Max output tokens",
        )

    gr.Markdown(learn_more_md)

    # Event wiring: regenerate re-runs http_bot on the existing history;
    # submitting text first appends it via add_text, then streams the reply.
    btn_list = [regenerate_btn, clear_btn]
    regenerate_btn.click(regenerate, state, [state, chatbot, textbox] + btn_list).then(
        http_bot,
        [state, model_selector, temperature, top_p, max_output_tokens],
        [state, chatbot] + btn_list,
    )
    clear_btn.click(clear_history, None, [state, chatbot, textbox] + btn_list)

    # Switching models clears the conversation.
    model_selector.change(clear_history, None, [state, chatbot, textbox] + btn_list)

    textbox.submit(
        add_text, [state, textbox], [state, chatbot, textbox] + btn_list
    ).then(
        http_bot,
        [state, model_selector, temperature, top_p, max_output_tokens],
        [state, chatbot] + btn_list,
    )
    send_btn.click(
        add_text, [state, textbox], [state, chatbot, textbox] + btn_list
    ).then(
        http_bot,
        [state, model_selector, temperature, top_p, max_output_tokens],
        [state, chatbot] + btn_list,
    )

    return state, model_selector, chatbot, textbox, send_btn, button_row, parameter_row


def build_demo(models):
    with gr.Blocks(
        title="Chat with Open Large Language Models",
        theme=gr.themes.Soft(),
        css=block_css,
    ) as demo:
        # Hidden JSON component that receives the window URL parameters via
        # the _js callback below.
        url_params = gr.JSON(visible=False)

        with gr.Row():
            # Empty side columns center the main column.
            gr.Column(scale=1, min_width=0)
            with gr.Column(scale=9):
                (
                    state,
                    model_selector,
                    chatbot,
                    textbox,
                    send_btn,
                    button_row,
                    parameter_row,
                ) = build_single_model_ui(models)
            gr.Column(scale=1, min_width=0)

        demo.load(
            load_demo_reload_model,
            [url_params],
            [
                state,
                model_selector,
                chatbot,
                textbox,
                send_btn,
                button_row,
                parameter_row,
            ],
            _js=get_window_url_params_js,
        )

    return demo


if __name__ == "__main__":
    models = get_model_list(controller_url)

    demo = build_demo(models)
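
    # queue() enables the generator-based streaming used by http_bot;
    # concurrency_count caps how many queued events run concurrently.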
    demo.queue(
        concurrency_count=concurrency_count, status_update_rate=10, api_open=False
    ).launch()