import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

model_name_2_7B_instruct = "Zyphra/Zamba2-2.7B-instruct"
model_name_7B_instruct = "Zyphra/Zamba2-7B-instruct"

# Load both instruct checkpoints onto the GPU in bfloat16.
tokenizer_2_7B_instruct = AutoTokenizer.from_pretrained(model_name_2_7B_instruct)
model_2_7B_instruct = AutoModelForCausalLM.from_pretrained(
    model_name_2_7B_instruct, device_map="cuda", torch_dtype=torch.bfloat16
)

tokenizer_7B_instruct = AutoTokenizer.from_pretrained(model_name_7B_instruct)
model_7B_instruct = AutoModelForCausalLM.from_pretrained(
    model_name_7B_instruct, device_map="cuda", torch_dtype=torch.bfloat16
)


def extract_assistant_response(generated_text):
    """Return the text of the last assistant turn in the decoded output."""
    assistant_token = '<|im_start|> assistant'
    end_token = '<|im_end|>'
    start_idx = generated_text.rfind(assistant_token)
    if start_idx == -1:
        # Assistant token not found
        return generated_text.strip()
    start_idx += len(assistant_token)
    end_idx = generated_text.find(end_token, start_idx)
    if end_idx == -1:
        # End token not found, return from start_idx to end
        return generated_text[start_idx:].strip()
    return generated_text[start_idx:end_idx].strip()


def generate_response_2_7B_instruct(chat_history, max_new_tokens):
    # Rebuild the message list from Gradio's [[user, assistant], ...] history.
    sample = []
    for turn in chat_history:
        if turn[0]:
            sample.append({'role': 'user', 'content': turn[0]})
        if turn[1]:
            sample.append({'role': 'assistant', 'content': turn[1]})
    chat_sample = tokenizer_2_7B_instruct.apply_chat_template(sample, tokenize=False)
    input_ids = tokenizer_2_7B_instruct(
        chat_sample, return_tensors='pt', add_special_tokens=False
    ).to(model_2_7B_instruct.device)
    # Greedy decoding (num_beams=1, do_sample=False).
    outputs = model_2_7B_instruct.generate(
        **input_ids,
        max_new_tokens=int(max_new_tokens),
        return_dict_in_generate=False,
        output_scores=False,
        use_cache=True,
        num_beams=1,
        do_sample=False,
    )
    # Sampling variant, kept for reference; it needs the generation-parameter
    # sliders that are currently commented out in the UI below.
    """
    outputs = model_2_7B_instruct.generate(
        **input_ids,
        max_new_tokens=int(max_new_tokens),
        do_sample=True,
        use_cache=True,
        temperature=temperature,
        top_k=int(top_k),
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        num_beams=int(num_beams),
        length_penalty=length_penalty,
        num_return_sequences=1
    )
    """
    generated_text = tokenizer_2_7B_instruct.decode(outputs[0])
    assistant_response = extract_assistant_response(generated_text)
    return assistant_response


def generate_response_7B_instruct(chat_history, max_new_tokens):
    sample = []
    for turn in chat_history:
        if turn[0]:
            sample.append({'role': 'user', 'content': turn[0]})
        if turn[1]:
            sample.append({'role': 'assistant', 'content': turn[1]})
    chat_sample = tokenizer_7B_instruct.apply_chat_template(sample, tokenize=False)
    input_ids = tokenizer_7B_instruct(
        chat_sample, return_tensors='pt', add_special_tokens=False
    ).to(model_7B_instruct.device)
    outputs = model_7B_instruct.generate(
        **input_ids,
        max_new_tokens=int(max_new_tokens),
        return_dict_in_generate=False,
        output_scores=False,
        use_cache=True,
        num_beams=1,
        do_sample=False,
    )
    # Sampling variant, kept for reference (see the note in the 2.7B function).
    """
    outputs = model_7B_instruct.generate(
        **input_ids,
        max_new_tokens=int(max_new_tokens),
        do_sample=True,
        use_cache=True,
        temperature=temperature,
        top_k=int(top_k),
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        num_beams=int(num_beams),
        length_penalty=length_penalty,
        num_return_sequences=1
    )
    """
    generated_text = tokenizer_7B_instruct.decode(outputs[0])
    assistant_response = extract_assistant_response(generated_text)
    return assistant_response


with gr.Blocks() as demo:
    gr.Markdown("# Zamba2 Model Selector")
    with gr.Tabs():
        with gr.TabItem("2.7B Instruct Model"):
            gr.Markdown("### Zamba2-2.7B Instruct Model")
            with gr.Column():
                chat_history_2_7B_instruct = gr.State([])
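                # The history is a list of [user, assistant] pairs: gr.State keeps
                # it per-session, and gr.Chatbot renders the same pair structure.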
                chatbot_2_7B_instruct = gr.Chatbot()
                message_2_7B_instruct = gr.Textbox(lines=2, placeholder="Enter your message...", label="Your Message")
                with gr.Accordion("Generation Parameters", open=False):
                    max_new_tokens_2_7B_instruct = gr.Slider(50, 1000, step=50, value=500, label="Max New Tokens")
                    # temperature_2_7B_instruct = gr.Slider(0.1, 1.5, step=0.1, value=0.2, label="Temperature")
                    # top_k_2_7B_instruct = gr.Slider(1, 100, step=1, value=50, label="Top K")
                    # top_p_2_7B_instruct = gr.Slider(0.1, 1.0, step=0.1, value=1.0, label="Top P")
                    # repetition_penalty_2_7B_instruct = gr.Slider(1.0, 2.0, step=0.1, value=1.2, label="Repetition Penalty")
                    # num_beams_2_7B_instruct = gr.Slider(1, 10, step=1, value=1, label="Number of Beams")
                    # length_penalty_2_7B_instruct = gr.Slider(0.0, 2.0, step=0.1, value=1.0, label="Length Penalty")

                def user_message_2_7B_instruct(message, chat_history):
                    # Append the new user turn and clear the input box.
                    chat_history = chat_history + [[message, None]]
                    return gr.update(value=""), chat_history, chat_history

                def bot_response_2_7B_instruct(chat_history, max_new_tokens):
                    # Fill in the assistant half of the most recent turn.
                    response = generate_response_2_7B_instruct(chat_history, max_new_tokens)
                    chat_history[-1][1] = response
                    return chat_history, chat_history

                send_button_2_7B_instruct = gr.Button("Send")
                # Two-stage click: display the user message first, then generate.
                send_button_2_7B_instruct.click(
                    fn=user_message_2_7B_instruct,
                    inputs=[message_2_7B_instruct, chat_history_2_7B_instruct],
                    outputs=[message_2_7B_instruct, chat_history_2_7B_instruct, chatbot_2_7B_instruct]
                ).then(
                    fn=bot_response_2_7B_instruct,
                    inputs=[chat_history_2_7B_instruct, max_new_tokens_2_7B_instruct],
                    outputs=[chat_history_2_7B_instruct, chatbot_2_7B_instruct]
                )

        with gr.TabItem("7B Instruct Model"):
            gr.Markdown("### Zamba2-7B Instruct Model")
            with gr.Column():
                chat_history_7B_instruct = gr.State([])
                chatbot_7B_instruct = gr.Chatbot()
                message_7B_instruct = gr.Textbox(lines=2, placeholder="Enter your message...", label="Your Message")
                with gr.Accordion("Generation Parameters", open=False):
                    max_new_tokens_7B_instruct = gr.Slider(50, 1000, step=50, value=500, label="Max New Tokens")
                    # temperature_7B_instruct = gr.Slider(0.1, 1.5, step=0.1, value=0.2, label="Temperature")
                    # top_k_7B_instruct = gr.Slider(1, 100, step=1, value=50, label="Top K")
                    # top_p_7B_instruct = gr.Slider(0.1, 1.0, step=0.1, value=1.0, label="Top P")
                    # repetition_penalty_7B_instruct = gr.Slider(1.0, 2.0, step=0.1, value=1.2, label="Repetition Penalty")
                    # num_beams_7B_instruct = gr.Slider(1, 10, step=1, value=1, label="Number of Beams")
                    # length_penalty_7B_instruct = gr.Slider(0.0, 2.0, step=0.1, value=1.0, label="Length Penalty")

                def user_message_7B_instruct(message, chat_history):
                    chat_history = chat_history + [[message, None]]
                    return gr.update(value=""), chat_history, chat_history

                def bot_response_7B_instruct(chat_history, max_new_tokens):
                    response = generate_response_7B_instruct(chat_history, max_new_tokens)
                    chat_history[-1][1] = response
                    return chat_history, chat_history

                send_button_7B_instruct = gr.Button("Send")
                send_button_7B_instruct.click(
                    fn=user_message_7B_instruct,
                    inputs=[message_7B_instruct, chat_history_7B_instruct],
                    outputs=[message_7B_instruct, chat_history_7B_instruct, chatbot_7B_instruct]
                ).then(
                    fn=bot_response_7B_instruct,
                    inputs=[chat_history_7B_instruct, max_new_tokens_7B_instruct],
                    outputs=[chat_history_7B_instruct, chatbot_7B_instruct]
                )


if __name__ == "__main__":
    demo.queue().launch()
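
# A minimal sketch of re-enabling the sampling controls (not wired in): uncomment
# the parameter sliders above, extend the handler signatures, pass the sliders
# through the `.then(...)` inputs list, and switch to the sampling `generate`
# call kept for reference in the generation functions. Shown for the 2.7B tab;
# the names match the commented-out components above.
#
#     def bot_response_2_7B_instruct(chat_history, max_new_tokens,
#                                    temperature, top_k, top_p):
#         response = generate_response_2_7B_instruct(
#             chat_history, max_new_tokens, temperature, top_k, top_p)
#         ...
#
#     ).then(
#         fn=bot_response_2_7B_instruct,
#         inputs=[chat_history_2_7B_instruct, max_new_tokens_2_7B_instruct,
#                 temperature_2_7B_instruct, top_k_2_7B_instruct, top_p_2_7B_instruct],
#         outputs=[chat_history_2_7B_instruct, chatbot_2_7B_instruct]
#     )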