"""
The gradio demo server for chatting with a single model.
"""
import datetime
import json
import os
import time
import uuid
import logging
import gradio as gr
import requests
from conversation import get_conv_template
from gradio_patch import Chatbot as grChatbot
from gradio_css import code_highlight_css
from utils import (
WORKER_API_TIMEOUT,
ErrorCode,
server_error_msg,
get_window_url_params_js,
)
logging.basicConfig(
format='%(asctime)s %(levelname)s: %(message)s', datefmt='%m/%d/%Y %I:%M:%S %p')
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
headers = {"User-Agent": "fastchat Client"}
no_change_btn = gr.Button.update()
enable_btn = gr.Button.update(interactive=True)
disable_btn = gr.Button.update(interactive=False)
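# Deployment settings come from the environment: the FastChat controller
# endpoint and the gradio queue concurrency limit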
controller_url = os.environ['controller_url']
concurrency_count = int(os.environ['concurrency_count'])
learn_more_md = ("""
### Notice
- All the models in this demo run on 4th Generation Intel® Xeon® Scalable processors (Sapphire Rapids), utilizing Intel® AMX instructions and mixed-precision inference
- This demo is based on the FastChat demo server. [[GitHub]](https://github.com/lm-sys/FastChat)
### Terms of use
By using this service, users agree to the following terms: the service is a research preview intended for non-commercial use only. It can produce factually incorrect output and should not be relied on for factually accurate information. The service provides only limited safety measures and may generate lewd, biased, or otherwise offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. The service may collect user dialogue data for future research.
### License
The service is a research preview intended for non-commercial use only, subject to the [Terms of Use](https://openai.com/policies/terms-of-use) of the data generated by OpenAI, and [Privacy Practices](https://chrome.google.com/webstore/detail/sharegpt-share-your-chatg/daiacboceoaocpibfodeljbdfacokfjb) of ShareGPT. Please contact us if you find any potential violation.
""")
def get_model_list(controller_url):
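    """Refresh the controller's worker registry, then fetch, sort, and return the list of served models."""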
ret = requests.post(controller_url + "/refresh_all_workers")
assert ret.status_code == 200
ret = requests.post(controller_url + "/list_models")
models = ret.json()["models"]
models.sort()
logger.info(f"Models: {models}")
return models
def load_demo_refresh_model_list(url_params):
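    """Fetch the current model list and reveal the UI, preselecting the model named in the URL params if present."""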
models = get_model_list(controller_url)
selected_model = models[0] if len(models) > 0 else ""
if "model" in url_params:
model = url_params["model"]
if model in models:
selected_model = model
dropdown_update = gr.Dropdown.update(
choices=models, value=selected_model, visible=True
)
state = None
return (
state,
dropdown_update,
gr.Chatbot.update(visible=True),
gr.Textbox.update(visible=True),
gr.Button.update(visible=True),
gr.Row.update(visible=True),
gr.Accordion.update(visible=True),
)
def load_demo_reload_model(url_params, request: gr.Request):
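    """Page-load callback: log the client, then rebuild the UI with a freshly fetched model list."""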
logger.info(
f"load_demo_reload_model. ip: {request.client.host}. params: {url_params}"
)
return load_demo_refresh_model_list(url_params)
def load_demo_single(models, url_params):
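    """Reveal the UI using an already-fetched model list, preselecting the model named in the URL params if present."""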
dropdown_update = gr.Dropdown.update(visible=True)
if "model" in url_params:
model = url_params["model"]
if model in models:
dropdown_update = gr.Dropdown.update(value=model, visible=True)
state = None
return (
state,
dropdown_update,
gr.Chatbot.update(visible=True),
gr.Textbox.update(visible=True),
gr.Button.update(visible=True),
gr.Row.update(visible=True),
gr.Accordion.update(visible=True),
)
def load_demo(url_params, request: gr.Request):
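    """Alternative page-load callback that reuses the module-level `models` list (populated in __main__) instead of re-querying the controller."""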
logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}")
return load_demo_single(models, url_params)
def regenerate(state, request: gr.Request):
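    """Blank out the last assistant reply so the chained http_bot call regenerates it."""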
logger.info(f"regenerate. ip: {request.client.host}")
state.messages[-1][-1] = None
state.skip_next = False
return (state, state.to_gradio_chatbot(), "") + (disable_btn,) * 5
def clear_history(request: gr.Request):
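    """Drop the conversation state and clear the chat display and textbox."""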
logger.info(f"clear_history. ip: {request.client.host}")
state = None
return (state, [], "") + (disable_btn,) * 5
def add_text(state, text, request: gr.Request):
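    """Start a conversation if needed, then append the user message (hard-capped at 1536 characters) and an empty assistant slot."""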
logger.info(f"add_text. ip: {request.client.host}. len: {len(text)}")
if state is None:
state = get_conv_template("vicuna_v1.1")
if len(text) <= 0:
state.skip_next = True
return (state, state.to_gradio_chatbot(), "") + (no_change_btn,) * 5
text = text[:1536] # Hard cut-off
state.append_message(state.roles[0], text)
state.append_message(state.roles[1], None)
state.skip_next = False
return (state, state.to_gradio_chatbot(), "") + (disable_btn,) * 5
def post_process_code(code):
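    """Unescape underscores the model escaped inside fenced code blocks.

    The replacement only runs when the fences are balanced, and only on the
    segments between fences, so prose outside code blocks is left untouched.
    """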
sep = "\n```"
if sep in code:
blocks = code.split(sep)
if len(blocks) % 2 == 1:
for i in range(1, len(blocks), 2):
blocks[i] = blocks[i].replace("\\_", "_")
code = sep.join(blocks)
return code
def model_worker_stream_iter(
conv, model_name, worker_addr, prompt, temperature, top_p, max_new_tokens
):
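    """POST a generation request to the worker endpoint and yield the decoded
    JSON chunks from the null-byte-delimited response stream. Each chunk
    carries at least "text" and "error_code" fields.
    """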
# Make requests
gen_params = {
"model": model_name,
"prompt": prompt,
"temperature": temperature,
"top_p": top_p,
"max_new_tokens": max_new_tokens,
"stop": conv.stop_str,
"stop_token_ids": conv.stop_token_ids,
"echo": False,
}
logger.info(f"==== request ====\n{gen_params}")
# Stream output
response = requests.post(
worker_addr + "/worker_generate_stream",
headers=headers,
json=gen_params,
stream=True,
timeout=WORKER_API_TIMEOUT,
)
for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
if chunk:
data = json.loads(chunk.decode())
yield data
def http_bot(
state, model_selector, temperature, top_p, max_new_tokens, request: gr.Request
):
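    """Main chat callback: build the prompt, stream the reply from the worker,
    and progressively yield (state, chat history, five button updates) to gradio.
    """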
logger.info(f"http_bot. ip: {request.client.host}")
start_tstamp = time.time()
model_name = model_selector
temperature = float(temperature)
top_p = float(top_p)
max_new_tokens = int(max_new_tokens)
if state.skip_next:
# This generate call is skipped due to invalid inputs
yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
return
if len(state.messages) == state.offset + 2:
# First round of conversation
new_state = get_conv_template(model_name.lower())
new_state.conv_id = uuid.uuid4().hex
new_state.model_name = state.model_name or model_selector
new_state.append_message(new_state.roles[0], state.messages[-2][1])
new_state.append_message(new_state.roles[1], None)
state = new_state
# Construct prompt
conv = state
if "chatglm" in model_name:
prompt = list(list(x) for x in conv.messages[conv.offset :])
else:
prompt = conv.get_prompt()
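    # FastChat's controller proxies /worker_generate_stream to the selected
    # worker, so the controller URL doubles as the worker address here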
stream_iter = model_worker_stream_iter(
conv, model_name, controller_url, prompt, temperature, top_p, max_new_tokens
)
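    # Show a cursor placeholder right away; streamed text replaces it below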
state.messages[-1][-1] = "▌"
yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
    output = ""  # guard: the final log line references `output` even if the stream yields nothing
    try:
for data in stream_iter:
if data["error_code"] == 0:
output = data["text"].strip()
if "vicuna" in model_name:
output = post_process_code(output)
state.messages[-1][-1] = output + "▌"
yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
else:
output = data["text"] + f"\n\n(error_code: {data['error_code']})"
state.messages[-1][-1] = output
yield (state, state.to_gradio_chatbot()) + (
disable_btn,
disable_btn,
disable_btn,
enable_btn,
enable_btn,
)
return
time.sleep(0.02)
except requests.exceptions.RequestException as e:
state.messages[-1][-1] = (
f"{server_error_msg}\n\n"
f"(error_code: {ErrorCode.GRADIO_REQUEST_ERROR}, {e})"
)
yield (state, state.to_gradio_chatbot()) + (
disable_btn,
disable_btn,
disable_btn,
enable_btn,
enable_btn,
)
return
except Exception as e:
state.messages[-1][-1] = (
f"{server_error_msg}\n\n"
f"(error_code: {ErrorCode.GRADIO_STREAM_UNKNOWN_ERROR}, {e})"
)
yield (state, state.to_gradio_chatbot()) + (
disable_btn,
disable_btn,
disable_btn,
enable_btn,
enable_btn,
)
return
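    # Clean finish: strip the trailing "▌" cursor and re-enable all buttons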
state.messages[-1][-1] = state.messages[-1][-1][:-1]
yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
finish_tstamp = time.time()
logger.info(f"{output}")
# TODO
# with open(get_conv_log_filename(), "a") as fout:
# data = {
# "tstamp": round(finish_tstamp, 4),
# "type": "chat",
# "model": model_name,
# "gen_params": {
# "temperature": temperature,
# "top_p": top_p,
# "max_new_tokens": max_new_tokens,
# },
# "start": round(start_tstamp, 4),
# "finish": round(start_tstamp, 4),
# "state": state.dict(),
# "ip": request.client.host,
# }
# fout.write(json.dumps(data) + "\n")
block_css = (
code_highlight_css
+ """
pre {
white-space: pre-wrap; /* Since CSS 2.1 */
white-space: -moz-pre-wrap; /* Mozilla, since 1999 */
white-space: -pre-wrap; /* Opera 4-6 */
white-space: -o-pre-wrap; /* Opera 7 */
word-wrap: break-word; /* Internet Explorer 5.5+ */
}
#notice_markdown th {
display: none;
}
"""
)
def build_single_model_ui(models):
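    """Lay out the chat interface (model dropdown, chatbot, textbox, buttons,
    parameter sliders) and wire up the gradio event handlers.
    """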
notice_markdown = ("""
# <p style="text-align: center;">Chat with Intel Labs optimized Large Language Models</p>
### Choose a model to chat with
""")
state = gr.State()
gr.Markdown(notice_markdown, elem_id="notice_markdown")
with gr.Row(elem_id="model_selector_row"):
model_selector = gr.Dropdown(
choices=models,
value=models[0] if len(models) > 0 else "",
interactive=True,
show_label=False,
).style(container=False)
chatbot = grChatbot(
elem_id="chatbot", label="Scroll down and start chatting", visible=False,
).style(height=550)
with gr.Row():
with gr.Column(scale=20):
textbox = gr.Textbox(
show_label=False,
placeholder="Type your message...",
visible=False,
).style(container=False)
with gr.Column(scale=1, min_width=50):
send_btn = gr.Button(value="Send", visible=False)
with gr.Row(visible=False) as button_row:
regenerate_btn = gr.Button(value="Regenerate", interactive=False)
clear_btn = gr.Button(value="Clear history", interactive=False)
with gr.Accordion("Parameters", open=False, visible=False) as parameter_row:
temperature = gr.Slider(
minimum=0.0,
maximum=1.0,
value=0.1,
step=0.1,
interactive=True,
label="Temperature",
)
top_p = gr.Slider(
minimum=0.0,
maximum=1.0,
value=1.0,
step=0.1,
interactive=True,
label="Top P",
)
max_output_tokens = gr.Slider(
minimum=0,
maximum=1024,
value=512,
step=64,
interactive=True,
label="Max output tokens",
)
gr.Markdown(learn_more_md)
btn_list = [regenerate_btn, clear_btn]
regenerate_btn.click(regenerate, state, [state, chatbot, textbox] + btn_list).then(
http_bot,
[state, model_selector, temperature, top_p, max_output_tokens],
[state, chatbot] + btn_list,
)
clear_btn.click(clear_history, None, [state, chatbot, textbox] + btn_list)
model_selector.change(clear_history, None, [state, chatbot, textbox] + btn_list)
textbox.submit(
add_text, [state, textbox], [state, chatbot, textbox] + btn_list
).then(
http_bot,
[state, model_selector, temperature, top_p, max_output_tokens],
[state, chatbot] + btn_list,
)
send_btn.click(
add_text, [state, textbox], [state, chatbot, textbox] + btn_list
).then(
http_bot,
[state, model_selector, temperature, top_p, max_output_tokens],
[state, chatbot] + btn_list,
)
return state, model_selector, chatbot, textbox, send_btn, button_row, parameter_row
def build_demo(models):
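    """Assemble the full gradio Blocks app and register the page-load hook that
    reads the browser URL params.
    """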
with gr.Blocks(
title="Chat with Open Large Language Models",
theme=gr.themes.Soft(),
css=block_css,
) as demo:
url_params = gr.JSON(visible=False)
with gr.Row():
gr.Column(scale=1, min_width=0)
with gr.Column(scale=9):
(
state,
model_selector,
chatbot,
textbox,
send_btn,
button_row,
parameter_row,
) = build_single_model_ui(models)
gr.Column(scale=1, min_width=0)
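        # On page load, extract the browser URL params via _js and hand them to the loader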
demo.load(
load_demo_reload_model,
[url_params],
[
state,
model_selector,
chatbot,
textbox,
send_btn,
button_row,
parameter_row,
],
_js=get_window_url_params_js,
)
return demo
if __name__ == "__main__":
models = get_model_list(controller_url)
demo = build_demo(models)
demo.queue(
concurrency_count=concurrency_count, status_update_rate=10, api_open=False
).launch()