""" The gradio demo server for chatting with a single model. """ import datetime import json import os import time import uuid import logging import gradio as gr import requests from conversation import get_conv_template from gradio_patch import Chatbot as grChatbot from gradio_css import code_highlight_css from utils import ( WORKER_API_TIMEOUT, ErrorCode, server_error_msg, get_window_url_params_js, ) logging.basicConfig( format='%(asctime)s %(levelname)s: %(message)s', datefmt='%m/%d/%Y %I:%M:%S %p') logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) headers = {"User-Agent": "fastchat Client"} no_change_btn = gr.Button.update() enable_btn = gr.Button.update(interactive=True) disable_btn = gr.Button.update(interactive=False) controller_url = os.environ['controller_url'] concurrency_count = int(os.environ['concurrency_count']) learn_more_md = (""" ### Notice - All the models in this demo run on 4th Generation Intel® Xeon® (Sapphire Rapids) utilizing AMX operations and mixed precision inference - This demo is based on the FastChat demo server. [[GitHub]](https://github.com/lm-sys/FastChat) ### Terms of use By using this service, users are required to agree to the following terms: The service is a research preview intended for non-commercial use only. It can produce factually incorrect output, and should not be relied on to produce factually accurate information. The service only provides limited safety measures and may generate lewd, biased or otherwise offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. The service may collect user dialogue data for future research. ### License The service is a research preview intended for non-commercial use only, subject to the [Terms of Use](https://openai.com/policies/terms-of-use) of the data generated by OpenAI, and [Privacy Practices](https://chrome.google.com/webstore/detail/sharegpt-share-your-chatg/daiacboceoaocpibfodeljbdfacokfjb) of ShareGPT. Please contact us if you find any potential violation. """) def get_model_list(controller_url): ret = requests.post(controller_url + "/refresh_all_workers") assert ret.status_code == 200 ret = requests.post(controller_url + "/list_models") models = ret.json()["models"] models.sort() logger.info(f"Models: {models}") return models def load_demo_refresh_model_list(url_params): models = get_model_list(controller_url) selected_model = models[0] if len(models) > 0 else "" if "model" in url_params: model = url_params["model"] if model in models: selected_model = model dropdown_update = gr.Dropdown.update( choices=models, value=selected_model, visible=True ) state = None return ( state, dropdown_update, gr.Chatbot.update(visible=True), gr.Textbox.update(visible=True), gr.Button.update(visible=True), gr.Row.update(visible=True), gr.Accordion.update(visible=True), ) def load_demo_reload_model(url_params, request: gr.Request): logger.info( f"load_demo_reload_model. ip: {request.client.host}. 
params: {url_params}" ) return load_demo_refresh_model_list(url_params) def load_demo_single(models, url_params): dropdown_update = gr.Dropdown.update(visible=True) if "model" in url_params: model = url_params["model"] if model in models: dropdown_update = gr.Dropdown.update(value=model, visible=True) state = None return ( state, dropdown_update, gr.Chatbot.update(visible=True), gr.Textbox.update(visible=True), gr.Button.update(visible=True), gr.Row.update(visible=True), gr.Accordion.update(visible=True), ) def load_demo(url_params, request: gr.Request): logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}") return load_demo_single(models, url_params) def regenerate(state, request: gr.Request): logger.info(f"regenerate. ip: {request.client.host}") state.messages[-1][-1] = None state.skip_next = False return (state, state.to_gradio_chatbot(), "") + (disable_btn,) * 5 def clear_history(request: gr.Request): logger.info(f"clear_history. ip: {request.client.host}") state = None return (state, [], "") + (disable_btn,) * 5 def add_text(state, text, request: gr.Request): logger.info(f"add_text. ip: {request.client.host}. len: {len(text)}") if state is None: state = get_conv_template("vicuna_v1.1") if len(text) <= 0: state.skip_next = True return (state, state.to_gradio_chatbot(), "") + (no_change_btn,) * 5 text = text[:1536] # Hard cut-off state.append_message(state.roles[0], text) state.append_message(state.roles[1], None) state.skip_next = False return (state, state.to_gradio_chatbot(), "") + (disable_btn,) * 5 def post_process_code(code): sep = "\n```" if sep in code: blocks = code.split(sep) if len(blocks) % 2 == 1: for i in range(1, len(blocks), 2): blocks[i] = blocks[i].replace("\\_", "_") code = sep.join(blocks) return code def model_worker_stream_iter( conv, model_name, worker_addr, prompt, temperature, top_p, max_new_tokens ): # Make requests gen_params = { "model": model_name, "prompt": prompt, "temperature": temperature, "top_p": top_p, "max_new_tokens": max_new_tokens, "stop": conv.stop_str, "stop_token_ids": conv.stop_token_ids, "echo": False, } logger.info(f"==== request ====\n{gen_params}") # Stream output response = requests.post( worker_addr + "/worker_generate_stream", headers=headers, json=gen_params, stream=True, timeout=WORKER_API_TIMEOUT, ) for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"): if chunk: data = json.loads(chunk.decode()) yield data def http_bot( state, model_selector, temperature, top_p, max_new_tokens, request: gr.Request ): logger.info(f"http_bot. 


def http_bot(
    state, model_selector, temperature, top_p, max_new_tokens, request: gr.Request
):
    logger.info(f"http_bot. ip: {request.client.host}")
    start_tstamp = time.time()
    model_name = model_selector
    temperature = float(temperature)
    top_p = float(top_p)
    max_new_tokens = int(max_new_tokens)

    if state.skip_next:
        # This generate call is skipped due to invalid inputs
        yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 2
        return

    if len(state.messages) == state.offset + 2:
        # First round of conversation: switch from the default template to the
        # one matching the selected model, carrying the user message over.
        new_state = get_conv_template(model_name.lower())
        new_state.conv_id = uuid.uuid4().hex
        # The default state created by add_text has no model_name attribute,
        # so fall back to the current selection.
        new_state.model_name = getattr(state, "model_name", None) or model_selector
        new_state.append_message(new_state.roles[0], state.messages[-2][1])
        new_state.append_message(new_state.roles[1], None)
        state = new_state

    # Construct prompt
    conv = state
    if "chatglm" in model_name:
        # ChatGLM workers expect the raw message pairs rather than a flat prompt.
        prompt = list(list(x) for x in conv.messages[conv.offset :])
    else:
        prompt = conv.get_prompt()

    # The FastChat controller exposes a worker-compatible
    # /worker_generate_stream endpoint and proxies the request to a registered
    # worker, so the controller URL is passed as the worker address here.
    stream_iter = model_worker_stream_iter(
        conv, model_name, controller_url, prompt, temperature, top_p, max_new_tokens
    )

    state.messages[-1][-1] = "▌"
    yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 2

    try:
        for data in stream_iter:
            if data["error_code"] == 0:
                output = data["text"].strip()
                if "vicuna" in model_name:
                    output = post_process_code(output)
                state.messages[-1][-1] = output + "▌"
                yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 2
            else:
                output = data["text"] + f"\n\n(error_code: {data['error_code']})"
                state.messages[-1][-1] = output
                # Re-enable the regenerate and clear buttons after an error.
                yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 2
                return
            time.sleep(0.02)
    except requests.exceptions.RequestException as e:
        state.messages[-1][-1] = (
            f"{server_error_msg}\n\n"
            f"(error_code: {ErrorCode.GRADIO_REQUEST_ERROR}, {e})"
        )
        yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 2
        return
    except Exception as e:
        state.messages[-1][-1] = (
            f"{server_error_msg}\n\n"
            f"(error_code: {ErrorCode.GRADIO_STREAM_UNKNOWN_ERROR}, {e})"
        )
        yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 2
        return

    # Drop the trailing "▌" cursor from the finished reply.
    state.messages[-1][-1] = state.messages[-1][-1][:-1]
    yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 2

    finish_tstamp = time.time()
    logger.info(f"{output}")

    # TODO
    # with open(get_conv_log_filename(), "a") as fout:
    #     data = {
    #         "tstamp": round(finish_tstamp, 4),
    #         "type": "chat",
    #         "model": model_name,
    #         "gen_params": {
    #             "temperature": temperature,
    #             "top_p": top_p,
    #             "max_new_tokens": max_new_tokens,
    #         },
    #         "start": round(start_tstamp, 4),
    #         "finish": round(finish_tstamp, 4),
    #         "state": state.dict(),
    #         "ip": request.client.host,
    #     }
    #     fout.write(json.dumps(data) + "\n")
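

# Yield contract for http_bot above: every yielded tuple maps one-to-one onto
# its Gradio outputs [state, chatbot] + btn_list, i.e. the conversation state,
# the (user, assistant) pairs from state.to_gradio_chatbot(), and one
# gr.Button.update per control button (regenerate, clear).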


block_css = (
    code_highlight_css
    + """
pre {
    white-space: pre-wrap;       /* Since CSS 2.1 */
    white-space: -moz-pre-wrap;  /* Mozilla, since 1999 */
    white-space: -pre-wrap;      /* Opera 4-6 */
    white-space: -o-pre-wrap;    /* Opera 7 */
    word-wrap: break-word;       /* Internet Explorer 5.5+ */
}
#notice_markdown th {
    display: none;
}
"""
)


def build_single_model_ui(models):
    notice_markdown = """
# Chat with Intel Labs optimized Large Language Models
### Choose a model to chat with
"""

    state = gr.State()
    gr.Markdown(notice_markdown, elem_id="notice_markdown")

    with gr.Row(elem_id="model_selector_row"):
        model_selector = gr.Dropdown(
            choices=models,
            value=models[0] if len(models) > 0 else "",
            interactive=True,
            show_label=False,
        ).style(container=False)

    chatbot = grChatbot(
        elem_id="chatbot",
        label="Scroll down and start chatting",
        visible=False,
    ).style(height=550)

    with gr.Row():
        with gr.Column(scale=20):
            textbox = gr.Textbox(
                show_label=False,
                placeholder="Type your message...",
                visible=False,
            ).style(container=False)
        with gr.Column(scale=1, min_width=50):
            send_btn = gr.Button(value="Send", visible=False)

    with gr.Row(visible=False) as button_row:
        regenerate_btn = gr.Button(value="Regenerate", interactive=False)
        clear_btn = gr.Button(value="Clear history", interactive=False)

    with gr.Accordion("Parameters", open=False, visible=False) as parameter_row:
        temperature = gr.Slider(
            minimum=0.0,
            maximum=1.0,
            value=0.1,
            step=0.1,
            interactive=True,
            label="Temperature",
        )
        top_p = gr.Slider(
            minimum=0.0,
            maximum=1.0,
            value=1.0,
            step=0.1,
            interactive=True,
            label="Top P",
        )
        max_output_tokens = gr.Slider(
            minimum=0,
            maximum=1024,
            value=512,
            step=64,
            interactive=True,
            label="Max output tokens",
        )

    gr.Markdown(learn_more_md)

    # Register listeners; each handler's outputs must line up with the tuples
    # it returns (see the yield contract note above).
    btn_list = [regenerate_btn, clear_btn]
    regenerate_btn.click(regenerate, state, [state, chatbot, textbox] + btn_list).then(
        http_bot,
        [state, model_selector, temperature, top_p, max_output_tokens],
        [state, chatbot] + btn_list,
    )
    clear_btn.click(clear_history, None, [state, chatbot, textbox] + btn_list)
    model_selector.change(clear_history, None, [state, chatbot, textbox] + btn_list)

    textbox.submit(
        add_text, [state, textbox], [state, chatbot, textbox] + btn_list
    ).then(
        http_bot,
        [state, model_selector, temperature, top_p, max_output_tokens],
        [state, chatbot] + btn_list,
    )
    send_btn.click(
        add_text, [state, textbox], [state, chatbot, textbox] + btn_list
    ).then(
        http_bot,
        [state, model_selector, temperature, top_p, max_output_tokens],
        [state, chatbot] + btn_list,
    )

    return state, model_selector, chatbot, textbox, send_btn, button_row, parameter_row


def build_demo(models):
    with gr.Blocks(
        title="Chat with Open Large Language Models",
        theme=gr.themes.Soft(),
        css=block_css,
    ) as demo:
        url_params = gr.JSON(visible=False)

        with gr.Row():
            # Empty side columns keep the chat column centered.
            gr.Column(scale=1, min_width=0)
            with gr.Column(scale=9):
                (
                    state,
                    model_selector,
                    chatbot,
                    textbox,
                    send_btn,
                    button_row,
                    parameter_row,
                ) = build_single_model_ui(models)
            gr.Column(scale=1, min_width=0)

        demo.load(
            load_demo_reload_model,
            [url_params],
            [
                state,
                model_selector,
                chatbot,
                textbox,
                send_btn,
                button_row,
                parameter_row,
            ],
            _js=get_window_url_params_js,
        )

    return demo


if __name__ == "__main__":
    models = get_model_list(controller_url)
    demo = build_demo(models)
    demo.queue(
        concurrency_count=concurrency_count, status_update_rate=10, api_open=False
    ).launch()
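

# Launch sketch (assumed deployment; the environment variable names match the
# os.environ lookups above, but the controller address and script name are
# hypothetical):
#   controller_url=http://localhost:21001 concurrency_count=10 python demo.py
# A FastChat controller must be reachable at controller_url before startup;
# with no registered workers, /list_models returns an empty list and the model
# dropdown starts out empty.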