#!/usr/bin/env python
# -*- coding: utf-8 -*-
import gc  # needed for gc.collect() after generation

import gradio as gr
import torch

from utils import *
from presets import *

# Alternatively, with any other model, e.g.:
# base_model = "project-baize/baize-v2-7b"
base_model = "microsoft/DialoGPT-medium"
tokenizer, model, device = load_tokenizer_and_model(base_model)
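
# predict is a generator: Gradio streams each yielded
# (chat pairs, history, status) triple straight to the UI, so the answer
# appears incrementally instead of all at once.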
def predict(text,
            chatbotGr,
            history,
            top_p,
            temperature,
            max_length_tokens,
            max_context_length_tokens):
    if text == "":
        yield chatbotGr, history, "Empty context."
        return
    try:
        model
    except NameError:
        yield [[text, "No Model Found"]], [], "No Model Found"
        return

    # Build the prompt from the chat history; None means it no longer fits
    # into the context window.
    inputs = generate_prompt_with_history(
        text, history, tokenizer, max_length=max_context_length_tokens
    )
    if inputs is None:
        yield chatbotGr, history, "Input too long."
        return
    prompt, inputs = inputs
    begin_length = len(prompt)

    # Keep only the most recent context tokens.
    input_ids = inputs["input_ids"][:, -max_context_length_tokens:].to(device)
    torch.cuda.empty_cache()

    with torch.no_grad():
        # greedy_search yields progressively longer partial completions.
        for x in greedy_search(input_ids, model, tokenizer,
                               stop_words=["[|Human|]", "[|AI|]"],
                               max_length=max_length_tokens,
                               temperature=temperature,
                               top_p=top_p):
            if is_stop_word_or_prefix(x, ["[|Human|]", "[|AI|]"]) is False:
                # Cut the completion off at the first role marker.
                if "[|Human|]" in x:
                    x = x[:x.index("[|Human|]")].strip()
                if "[|AI|]" in x:
                    x = x[:x.index("[|AI|]")].strip()
                x = x.strip()
                a = [[y[0], convert_to_markdown(y[1])] for y in history] + [[text, convert_to_markdown(x)]]
                b = history + [[text, x]]
                yield a, b, "Generating..."
                if shared_state.interrupted:
                    shared_state.recover()
                    try:
                        yield a, b, "Stop: Success"
                        return
                    except Exception:
                        pass

    del input_ids
    gc.collect()
    torch.cuda.empty_cache()

    try:
        # a/b may be unbound if no valid output was produced.
        yield a, b, "Generate: Success"
    except Exception:
        pass
def reset_chat():
    # The original called into a hugchat ChatBot that is not set up in this
    # app; clearing the textbox and reporting success is enough for this UI.
    return gr.update(value=""), "Reset: Success"
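
# gr.State values are kept per browser session, so each visitor gets an
# independent chat history; the layout below uses the Gradio 3.x
# Row/Column/.style() API.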
with gr.Blocks(theme=small_and_beautiful_theme) as demo:
    history = gr.State([])
    user_question = gr.State("")
    with gr.Row():
        gr.HTML(title)
        status_display = gr.Markdown("Erfolg", elem_id="status_display")
    gr.Markdown(description_top)
    with gr.Row(scale=1).style(equal_height=True):
        with gr.Column(scale=5):
            with gr.Row(scale=1):
                chatbotGr = gr.Chatbot(elem_id="LI_chatbot").style(height="100%")
            with gr.Row(scale=1):
                with gr.Column(scale=12):
                    user_input = gr.Textbox(
                        show_label=False, placeholder="Gib deinen Text / Frage ein."
                    ).style(container=False)
                with gr.Column(min_width=90, scale=1):
                    submitBtn = gr.Button("Absenden")
                with gr.Column(min_width=90, scale=1):
                    cancelBtn = gr.Button("Stoppen")
            with gr.Row(scale=1):
                emptyBtn = gr.Button(
                    "🧹 Neuer Chat",
                )
        with gr.Column():
            with gr.Column(min_width=50, scale=1):
                with gr.Tab(label="Parameter zum Model"):
                    gr.Markdown("# Parameters")
                    top_p = gr.Slider(
                        minimum=0.0,
                        maximum=1.0,
                        value=0.95,
                        step=0.05,
                        interactive=True,
                        label="Top-p",
                    )
                    temperature = gr.Slider(
                        minimum=0.1,
                        maximum=2.0,
                        value=1,
                        step=0.1,
                        interactive=True,
                        label="Temperature",
                    )
                    max_length_tokens = gr.Slider(
                        minimum=0,
                        maximum=512,
                        value=512,
                        step=8,
                        interactive=True,
                        label="Max Generation Tokens",
                    )
                    max_context_length_tokens = gr.Slider(
                        minimum=0,
                        maximum=4096,
                        value=2048,
                        step=128,
                        interactive=True,
                        label="Max History Tokens",
                    )
    gr.Markdown(description)
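
    # The event arguments are collected in plain dicts so the same
    # fn/inputs/outputs wiring can be reused for both trigger paths below
    # via **kwargs unpacking (textbox submit and button click).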
    predict_args = dict(
        fn=predict,
        inputs=[
            user_question,
            chatbotGr,
            history,
            top_p,
            temperature,
            max_length_tokens,
            max_context_length_tokens,
        ],
        outputs=[chatbotGr, history, status_display],
        show_progress=True,
    )

    # New chat
    reset_args = dict(
        fn=reset_chat, inputs=[], outputs=[user_input, status_display]
    )

    # Chatbot
    transfer_input_args = dict(
        fn=transfer_input, inputs=[user_input], outputs=[user_question, user_input, submitBtn], show_progress=True
    )
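
    # transfer_input first copies the textbox content into the user_question
    # state and clears the box; .then() chains predict so generation starts
    # only after that handoff. Note that cancelBtn has no event listener
    # here; interruption is signalled through shared_state (from utils).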
    # Listeners: start on button click or Return in the textbox
    predict_event1 = user_input.submit(**transfer_input_args).then(**predict_args)
    predict_event2 = submitBtn.click(**transfer_input_args).then(**predict_args)

    # Listener for reset: clear chat window, history, and status
    emptyBtn.click(
        reset_state,
        outputs=[chatbotGr, history, status_display],
        show_progress=True,
    )
    emptyBtn.click(**reset_args)

demo.title = "LI Chat"
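
# queue() is required so that the generator-style predict handler can stream
# partial results; concurrency_count=1 serves one generation at a time.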
# demo.queue(concurrency_count=1).launch(share=True)
demo.queue(concurrency_count=1).launch()