"""Gradio chat demo for greyfoss/gpt2-chatbot-chinese, served on CPU
through ONNX Runtime."""
import itertools
import re

import gradio as gr
from optimum.onnxruntime import ORTModelForCausalLM
from transformers import AutoTokenizer

# Special tokens that mark conversation structure in the prompt.
user_token = "<User>"
eos_token = "<EOS>"
bos_token = "<BOS>"
bot_token = "<Assistant>"
max_context_length = 750  # prompt is truncated from the left to this many tokens

def is_english_word(tested_string):
    """Return True if the string consists solely of ASCII letters."""
    pattern = re.compile(r"^[a-zA-Z]+$")
    return pattern.match(tested_string) is not None
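
# For example, is_english_word("hello") is True, while is_english_word("你好")
# and is_english_word("don't") are False; only consecutive pure-ASCII tokens
# get a space re-inserted between them during detokenization below.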

def format(history):
    """Assemble the model prompt from alternating user/assistant turns."""
    prompt = bos_token
    for idx, txt in enumerate(history):
        if idx % 2 == 0:  # even indices are user turns
            prompt += f"{user_token}{txt}{eos_token}"
        else:  # odd indices are assistant turns
            prompt += f"{bot_token}{txt}"
    prompt += bot_token  # trailing <Assistant> cues the model to reply
    print(prompt)  # debug: inspect the assembled prompt
    return prompt
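
# For example (illustrative input only), history = ["hi", "hello!", "a joke?"]
# is formatted as:
#   <BOS><User>hi<EOS><Assistant>hello!<User>a joke?<EOS><Assistant>
# Note that assistant turns carry no <EOS> in this scheme.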

def gradio(model, tokenizer):
    def response(
        user_input,
        chat_history,
        top_k,
        top_p,
        temperature,
        repetition_penalty,
        no_repeat_ngram_size,
    ):
        # Flatten [(user, bot), ...] pairs into a single list of turns,
        # then append the new user message.
        history = list(itertools.chain(*chat_history))
        history.append(user_input)
        prompt = format(history)
        # Keep only the most recent max_context_length tokens.
        input_ids = tokenizer.encode(
            prompt,
            return_tensors="pt",
            add_special_tokens=False,
        )[:, -max_context_length:]
        prompt_length = input_ids.shape[1]
        output_ids = model.generate(
            input_ids,
            pad_token_id=tokenizer.pad_token_id,
            max_new_tokens=250,
            num_beams=1,  # no beam search; sampling is cheap enough on CPU
            top_k=top_k,
            top_p=top_p,
            no_repeat_ngram_size=no_repeat_ngram_size,
            temperature=temperature,
            repetition_penalty=repetition_penalty,
            do_sample=True,
        )
        # Drop the prompt tokens and keep only the newly generated reply.
        output = output_ids[0][prompt_length:]
        tokens = tokenizer.convert_ids_to_tokens(output)
        # The BERT-style tokenizer joins tokens without spaces, which is
        # right for Chinese but wrong for English, so re-insert a space
        # between consecutive all-letter tokens: ["how", "are", "you"]
        # becomes "how are you" while ["你", "好"] stays "你好".
        for i, token in enumerate(tokens[:-1]):
            if is_english_word(token) and is_english_word(tokens[i + 1]):
                tokens[i] = token + " "
        text = "".join(tokens).replace("##", "").replace("[UNK]", "").strip()
        return text
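
    # Build the UI: a Chatbot pane plus sliders that expose the sampling
    # parameters above as additional inputs to `response`.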
    bot = gr.Chatbot(show_copy_button=True, show_share_button=True)

    with gr.Blocks() as demo:
        gr.Markdown("GPT2 chatbot | Powered by nlp-greyfoss")
        with gr.Accordion("Parameters in generation", open=False):
            with gr.Row():
                top_k = gr.Slider(
                    2.0,
                    100.0,
                    label="top_k",
                    step=1,
                    value=50,
                    info="Limit sampling to the k most likely tokens at each decoding step.",
                )
                top_p = gr.Slider(
                    0.1,
                    1.0,
                    label="top_p",
                    value=0.9,
                    info="Sample only from the smallest set of tokens whose cumulative probability reaches top_p.",
                )
                temperature = gr.Slider(
                    0.1,
                    2.0,
                    label="temperature",
                    value=0.9,
                    info="Control the randomness of the generated text: higher values give more diverse, unpredictable output; lower values give more conservative, coherent text.",
                )
                repetition_penalty = gr.Slider(
                    0.1,
                    2.0,
                    label="repetition_penalty",
                    value=1.2,
                    info="Discourage the model from generating repetitive tokens.",
                )
                no_repeat_ngram_size = gr.Slider(
                    0,
                    100,
                    label="no_repeat_ngram_size",
                    step=1,
                    value=5,
                    info="Prevent the model from repeating any n-gram that has already appeared in the context.",
                )
        gr.ChatInterface(
            response,
            chatbot=bot,
            additional_inputs=[
                top_k,
                top_p,
                temperature,
                repetition_penalty,
                no_repeat_ngram_size,
            ],
        )

    demo.queue().launch()
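
# Load the tokenizer, export the PyTorch checkpoint to ONNX on the fly
# (export=True) so generation runs through ONNX Runtime, then launch the demo.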
tokenizer = AutoTokenizer.from_pretrained("greyfoss/gpt2-chatbot-chinese")
model = ORTModelForCausalLM.from_pretrained("greyfoss/gpt2-chatbot-chinese", export=True)
gradio(model, tokenizer)