# Space status: Runtime error
| import os | |
| import json | |
| import gradio as gr | |
| from llama2 import GradioLLaMA2ChatPPManager | |
| from llama2 import gen_text | |
| from styles import MODEL_SELECTION_CSS | |
| from js import GET_LOCAL_STORAGE, UPDATE_LEFT_BTNS_STATE, UPDATE_PLACEHOLDERS | |
| from templates import templates | |
| from pingpong import PingPong | |
| from pingpong.context import CtxLastWindowStrategy | |
# Hugging Face access token read from the environment; None when HF_TOKEN is unset.
TOKEN = os.environ.get("HF_TOKEN")
# Model repository id forwarded to gen_text as ``hf_model``.
MODEL_ID = "meta-llama/Llama-2-70b-chat-hf"
def build_prompts(ppmanager, global_context, win_size=3):
    """Build a model prompt from the recent turns of a conversation.

    Args:
        ppmanager: Conversation manager to read from. It is deep-copied, so
            setting ``ctx`` below never mutates the caller's object.
        global_context: Value assigned to the copy's ``ctx`` attribute.
        win_size: Passed to ``CtxLastWindowStrategy``; presumably the number
            of most recent turns to keep -- confirm against
            ``pingpong.context``.

    Returns:
        Whatever ``CtxLastWindowStrategy(win_size)`` returns when applied to
        the prepared copy.
    """
    # `copy` is not imported at module level, so the previous version raised
    # NameError on the first chat message. Import it locally to keep this
    # block self-contained.
    import copy

    prompt_source = copy.deepcopy(ppmanager)
    prompt_source.ctx = global_context
    strategy = CtxLastWindowStrategy(win_size)
    return strategy(prompt_source)
# Example prompts rendered as buttons in the initial popup, one per line.
# `with` closes the handle (previously both files stayed open for the life of
# the process). splitlines() avoids the trailing "" that split("\n") yields for
# a newline-terminated file -- which rendered as an empty button -- and also
# drops "\r" from Windows line endings.
with open("examples.txt", "r", encoding="utf-8") as ex_file:
    examples = ex_file.read().splitlines()
ex_btns = []

# Conversation channel titles, one per line. The layout reads channels[0]
# unconditionally, so an empty file falls back to a single untitled channel
# (the same result split("\n") produced for empty input).
with open("channels.txt", "r", encoding="utf-8") as chl_file:
    channels = chl_file.read().splitlines() or [""]
channel_btns = []
def fill_up_placeholders(txt):
    """Show a chosen template and expose up to three placeholder textboxes.

    Used as the ``fn`` of the template ``gr.Examples``; the returned 5-tuple
    maps onto ``[template_md, placeholder_txt1, placeholder_txt2,
    placeholder_txt3, instruction_txtbox]``.

    Args:
        txt: The template text that was clicked.

    Returns:
        An update revealing the template markdown with ``txt`` as its value;
        three updates where the i-th placeholder box is visible (with the
        placeholder name as its hint) only if the template has at least i
        placeholders; and the new prompt text -- empty when the template has
        any placeholder, otherwise the template itself.

    NOTE(review): ``get_placeholders`` is neither defined nor imported in the
    visible file; calling this raises NameError unless it is provided
    elsewhere.
    """
    names = get_placeholders(txt)
    count = len(names)

    template_update = gr.update(visible=True, value=txt)

    box_updates = []
    for slot in range(3):
        present = count > slot
        box_updates.append(
            gr.update(
                visible=present,
                placeholder=names[slot] if present else "",
            )
        )

    prompt_text = txt if count < 1 else ""
    return (template_update, *box_updates, prompt_text)
async def chat_stream(idx, local_data, instruction_txtbox, chat_state):
    """Add a user message to the selected conversation and stream the reply.

    Args:
        idx: Position of the active conversation within ``local_data``.
        local_data: Serialized conversations; each entry is rebuilt with
            ``chat_state["ppmanager_type"].from_json``.
        instruction_txtbox: The user's message text.
        chat_state: Dict holding the ``"ppmanager_type"`` class.

    Yields:
        ``(chatbot_ui, serialized)`` for each item produced by ``gen_text``,
        where ``chatbot_ui`` is ``build_uis()`` of the active conversation and
        ``serialized`` is ``str()`` of the list of all rebuilt managers.

    NOTE(review): the prompt uses the literal "global context" and a window
    of 3 instead of the UI's global-context textbox and "recent talks" slider.
    NOTE(review): ``str(list)`` produces Python reprs, not JSON, yet it is
    routed to the ``gr.JSON`` ``local_data`` output -- confirm this
    round-trips as intended.
    """
    conversations = [
        chat_state["ppmanager_type"].from_json(json.dumps(raw))
        for raw in local_data
    ]
    active = conversations[idx]
    active.add_pingpong(PingPong(instruction_txtbox, ""))

    prompt = build_prompts(active, "global context", 3)
    # gen_text is awaited once; its result is then iterated synchronously.
    chunks = await gen_text(prompt, hf_model=MODEL_ID, hf_token=TOKEN)
    for chunk in chunks:
        active.append_pong(chunk)
        yield active.build_uis(), str(conversations)
def channel_num(btn_title):
    """Return the position of the channel titled ``btn_title``.

    When the title appears more than once, the last occurrence wins; an
    unknown title falls back to 0 (the first channel).
    """
    positions = [pos for pos, title in enumerate(channels) if title == btn_title]
    return positions[-1] if positions else 0
def set_chatbot(btn, ld, state):
    """Switch the chat view to the conversation whose channel button was clicked.

    Args:
        btn: Title of the clicked channel button.
        ld: Serialized conversations (the ``local_data`` value).
        state: Dict holding the ``"ppmanager_type"`` class used to rebuild
            each conversation via ``from_json``.

    Returns:
        A 4-tuple for ``[chatbot, idx, example_block, regenerate]``:
        the selected conversation's ``build_uis()``, its index, an update that
        shows the example popup only when the conversation has no turns, and
        an update that enables Regen only when it has some.
    """
    selected = channel_num(btn)
    managers = [
        state["ppmanager_type"].from_json(json.dumps(raw)) for raw in ld
    ]
    conversation = managers[selected]
    is_empty = len(conversation.pingpongs) == 0

    chat_history = conversation.build_uis()
    return (
        chat_history,
        selected,
        gr.update(visible=is_empty),
        gr.update(interactive=not is_empty),
    )
def set_example(btn):
    """Copy the clicked example's text into the prompt box and hide the popup.

    Returns:
        ``(btn, update)`` for ``[instruction_txtbox, example_block]``.
    """
    hide_popup = gr.update(visible=False)
    return btn, hide_popup
# Gradio UI layout, event wiring and launch.
# NOTE(review): this block's indentation was reconstructed from a flattened
# copy of the file; the nesting of the layout containers is a best reading of
# the original and should be confirmed.
with gr.Blocks(css=MODEL_SELECTION_CSS, theme='gradio/soft') as demo:
    with gr.Column() as chat_view:
        # Index of the active channel/conversation: written by set_chatbot,
        # read by chat_stream.
        idx = gr.State(0)
        # Manager class the handlers use to rebuild conversations from JSON.
        chat_state = gr.State({
            "ppmanager_type": GradioLLaMA2ChatPPManager
        })
        # Hidden store of serialized conversations, filled on page load by the
        # GET_LOCAL_STORAGE script (demo.load at the end of this block).
        # NOTE(review): initialised to a dict, but the handlers iterate it and
        # index it by position like a list -- presumably always replaced on
        # load; confirm.
        local_data = gr.JSON({}, visible=False)

        with gr.Row():
            # Left pane: title plus one button per channel.
            with gr.Column(scale=1, min_width=180):
                gr.Markdown("GradioChat", elem_id="left-top")

                with gr.Column(elem_id="left-pane"):
                    # NOTE(review): not connected to any event in this file.
                    chat_back_btn = gr.Button("Back", elem_id="chat-back-btn")

                    with gr.Accordion("Histories", elem_id="chat-history-accordion", open=True):
                        # The first channel starts with the highlighted style.
                        channel_btns.append(gr.Button(channels[0], elem_classes=["custom-btn-highlight"]))

                        for channel in channels[1:]:
                            channel_btns.append(gr.Button(channel, elem_classes=["custom-btn"]))

            # Right pane: intro popup, aux buttons, chat, templates, settings.
            with gr.Column(scale=8, elem_id="right-pane"):
                # Intro/examples popup: hidden initially, shown by set_chatbot
                # when the selected conversation is empty, hidden by set_example.
                with gr.Column(
                    elem_id="initial-popup", visible=False
                ) as example_block:
                    with gr.Row(scale=1):
                        with gr.Column(elem_id="initial-popup-left-pane"):
                            gr.Markdown("GradioChat", elem_id="initial-popup-title")
                            gr.Markdown("Making the community's best AI chat models available to everyone.")
                        with gr.Column(elem_id="initial-popup-right-pane"):
                            gr.Markdown("Chat UI is now open sourced on Hugging Face Hub")
                            gr.Markdown("check out the [↗ repository](https://huggingface.co/spaces/chansung/test-multi-conv)")

                    with gr.Column(scale=1):
                        gr.Markdown("Examples")
                        with gr.Row():
                            # One button per line of examples.txt.
                            for example in examples:
                                ex_btns.append(gr.Button(example, elem_classes=["example-btn"]))

                with gr.Column(elem_id="aux-btns-popup", visible=True):
                    with gr.Row():
                        # NOTE(review): Stop and Clean are not connected to any
                        # event in this file.
                        stop = gr.Button("Stop", elem_classes=["aux-btn"])
                        # Disabled until set_chatbot selects a non-empty conversation.
                        regenerate = gr.Button("Regen", interactive=False, elem_classes=["aux-btn"])
                        clean = gr.Button("Clean", elem_classes=["aux-btn"])

                    with gr.Accordion("Context Inspector", elem_id="aux-viewer", open=False):
                        # NOTE(review): no handler in this file writes to this textbox.
                        context_inspector = gr.Textbox(
                            "",
                            elem_id="aux-viewer-inspector",
                            label="",
                            lines=30,
                            max_lines=50,
                        )

                chatbot = gr.Chatbot(elem_id='chatbot')
                instruction_txtbox = gr.Textbox(placeholder="Ask anything", label="", elem_id="prompt-txt")

                with gr.Accordion("Example Templates", open=False):
                    # Hidden holder for the clicked template text (the Examples input).
                    template_txt = gr.Textbox(visible=False)
                    template_md = gr.Markdown(label="Chosen Template", visible=False, elem_classes="template-txt")

                    with gr.Row():
                        # Shown/hidden by fill_up_placeholders depending on how
                        # many placeholders the chosen template has.
                        placeholder_txt1 = gr.Textbox(label="placeholder #1", visible=False, interactive=True)
                        placeholder_txt2 = gr.Textbox(label="placeholder #2", visible=False, interactive=True)
                        placeholder_txt3 = gr.Textbox(label="placeholder #3", visible=False, interactive=True)

                    # One tab per template group; clicking an example loads it
                    # into template_txt and immediately runs fill_up_placeholders.
                    for template in templates:
                        with gr.Tab(template['title']):
                            gr.Examples(
                                template['template'],
                                inputs=[template_txt],
                                outputs=[template_md, placeholder_txt1, placeholder_txt2, placeholder_txt3, instruction_txtbox],
                                run_on_click=True,
                                fn=fill_up_placeholders,
                            )

                # NOTE(review): none of the controls in this panel are passed to
                # any handler in this file; chat_stream hard-codes its global
                # context and window size.
                with gr.Accordion("Control Panel", open=False) as control_panel:
                    with gr.Column():
                        with gr.Column():
                            gr.Markdown("#### Global context")
                            with gr.Accordion("global context will persist during conversation, and it is placed at the top of the prompt", open=False):
                                global_context = gr.Textbox(
                                    "global context",
                                    lines=5,
                                    max_lines=10,
                                    interactive=True,
                                    elem_id="global-context"
                                )

                            # Disabled internet-search options (kept for reference).
                            # gr.Markdown("#### Internet search")
                            # with gr.Row():
                            #     internet_option = gr.Radio(choices=["on", "off"], value="off", label="mode")
                            #     serper_api_key = gr.Textbox(
                            #         value= "" if args.serper_api_key is None else args.serper_api_key,
                            #         placeholder="Get one by visiting serper.dev",
                            #         label="Serper api key"
                            #     )

                            gr.Markdown("#### GenConfig for **response** text generation")
                            # NOTE(review): every default value below is 0, which is
                            # under the minimum for top_k (20), new_tokens (64) and
                            # beams (1), and is not one of the Radio choices
                            # [True, False] -- placeholders to be filled in?
                            with gr.Row():
                                res_temp = gr.Slider(0.0, 2.0, 0, step=0.1, label="temp", interactive=True)
                                res_topp = gr.Slider(0.0, 2.0, 0, step=0.1, label="top_p", interactive=True)
                                res_topk = gr.Slider(20, 1000, 0, step=1, label="top_k", interactive=True)
                                res_rpen = gr.Slider(0.0, 2.0, 0, step=0.1, label="rep_penalty", interactive=True)
                                res_mnts = gr.Slider(64, 8192, 0, step=1, label="new_tokens", interactive=True)
                                res_beams = gr.Slider(1, 4, 0, step=1, label="beams")
                                res_cache = gr.Radio([True, False], value=0, label="cache", interactive=True)
                                res_sample = gr.Radio([True, False], value=0, label="sample", interactive=True)
                                res_eosid = gr.Number(value=0, visible=False, precision=0)
                                res_padid = gr.Number(value=0, visible=False, precision=0)

                        # Summary generation settings; the whole column is hidden.
                        with gr.Column(visible=False):
                            gr.Markdown("#### GenConfig for **summary** text generation")
                            with gr.Row():
                                sum_temp = gr.Slider(0.0, 2.0, 0, step=0.1, label="temp", interactive=True)
                                sum_topp = gr.Slider(0.0, 2.0, 0, step=0.1, label="top_p", interactive=True)
                                sum_topk = gr.Slider(20, 1000, 0, step=1, label="top_k", interactive=True)
                                sum_rpen = gr.Slider(0.0, 2.0, 0, step=0.1, label="rep_penalty", interactive=True)
                                sum_mnts = gr.Slider(64, 8192, 0, step=1, label="new_tokens", interactive=True)
                                sum_beams = gr.Slider(1, 8, 0, step=1, label="beams", interactive=True)
                                sum_cache = gr.Radio([True, False], value=0, label="cache", interactive=True)
                                sum_sample = gr.Radio([True, False], value=0, label="sample", interactive=True)
                                sum_eosid = gr.Number(value=0, visible=False, precision=0)
                                sum_padid = gr.Number(value=0, visible=False, precision=0)

                        with gr.Column():
                            gr.Markdown("#### Context managements")
                            with gr.Row():
                                ctx_num_lconv = gr.Slider(2, 10, 3, step=1, label="number of recent talks to keep", interactive=True)
                                ctx_sum_prompt = gr.Textbox(
                                    "summarize our conversations. what have we discussed about so far?",
                                    label="design a prompt to summarize the conversations",
                                    visible=False
                                )

    # Enter in the prompt box streams the model reply into the chatbot and
    # pushes the updated conversations into local_data.
    instruction_txtbox.submit(
        chat_stream,
        [idx, local_data, instruction_txtbox, chat_state],
        [chatbot, local_data]
    )

    # Channel buttons: load the chosen conversation, then run the
    # UPDATE_LEFT_BTNS_STATE script in the browser with the clicked button as
    # input (presumably to move the highlight -- see js.py).
    for btn in channel_btns:
        btn.click(
            set_chatbot,
            [btn, local_data, chat_state],
            [chatbot, idx, example_block, regenerate]
        ).then(
            None, btn, None,
            _js=UPDATE_LEFT_BTNS_STATE
        )

    # Example buttons copy their text into the prompt box and hide the popup.
    for btn in ex_btns:
        btn.click(
            set_example,
            [btn],
            [instruction_txtbox, example_block]
        )

    # Editing any placeholder re-renders template_md client-side via the
    # UPDATE_PLACEHOLDERS script (no Python fn).
    placeholder_txt1.change(
        inputs=[template_txt, placeholder_txt1, placeholder_txt2, placeholder_txt3],
        outputs=[template_md],
        show_progress=False,
        _js=UPDATE_PLACEHOLDERS,
        fn=None
    )

    placeholder_txt2.change(
        inputs=[template_txt, placeholder_txt1, placeholder_txt2, placeholder_txt3],
        outputs=[template_md],
        show_progress=False,
        _js=UPDATE_PLACEHOLDERS,
        fn=None
    )

    placeholder_txt3.change(
        inputs=[template_txt, placeholder_txt1, placeholder_txt2, placeholder_txt3],
        outputs=[template_md],
        show_progress=False,
        _js=UPDATE_PLACEHOLDERS,
        fn=None
    )

    # Pressing Enter in any placeholder box builds the final prompt.
    # NOTE(review): get_final_template is neither defined nor imported in the
    # visible file. The name is evaluated while this `with` block runs, so
    # unless it is provided elsewhere the app fails at startup with NameError.
    placeholder_txt1.submit(
        inputs=[template_txt, placeholder_txt1, placeholder_txt2, placeholder_txt3],
        outputs=[instruction_txtbox, placeholder_txt1, placeholder_txt2, placeholder_txt3],
        fn=get_final_template
    )

    placeholder_txt2.submit(
        inputs=[template_txt, placeholder_txt1, placeholder_txt2, placeholder_txt3],
        outputs=[instruction_txtbox, placeholder_txt1, placeholder_txt2, placeholder_txt3],
        fn=get_final_template
    )

    placeholder_txt3.submit(
        inputs=[template_txt, placeholder_txt1, placeholder_txt2, placeholder_txt3],
        outputs=[instruction_txtbox, placeholder_txt1, placeholder_txt2, placeholder_txt3],
        fn=get_final_template
    )

    # On page load, run GET_LOCAL_STORAGE in the browser (no Python fn) to
    # populate the chatbot and local_data.
    demo.load(
        None,
        inputs=None,
        outputs=[chatbot, local_data],
        _js=GET_LOCAL_STORAGE,
    )

# Start the Gradio server.
demo.launch()