import gradio as gr
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import torch
from threading import Thread
phi4_model_path = "microsoft/phi-4"
phi4_mini_model_path = "microsoft/Phi-4-mini-instruct"

device = "cuda:0" if torch.cuda.is_available() else "cpu"

phi4_model = AutoModelForCausalLM.from_pretrained(phi4_model_path, torch_dtype="auto").to(device)
phi4_tokenizer = AutoTokenizer.from_pretrained(phi4_model_path)
phi4_mini_model = AutoModelForCausalLM.from_pretrained(phi4_mini_model_path, torch_dtype="auto").to(device)
phi4_mini_tokenizer = AutoTokenizer.from_pretrained(phi4_mini_model_path)
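# NOTE: keeping both models resident makes switching instant, but it requires GPU
# memory for roughly 14B + 3.8B parameters. On smaller GPUs, consider loading the
# selected model lazily or passing device_map="auto" to from_pretrained instead
# of calling .to(device).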
# @spaces.GPU requests GPU time on ZeroGPU Spaces for the duration of the call;
# without it the `spaces` import above goes unused.
@spaces.GPU
def generate_response(user_message, model_name, max_tokens, temperature, top_k, top_p, repetition_penalty, history_state):
    # This is a generator function, so it must yield (not return) a value
    # for Gradio to receive any output on the empty-message early exit.
    if not user_message.strip():
        yield history_state, history_state
        return
    # Select the model, tokenizer, and chat-format tags
    if model_name == "Phi-4":
        model = phi4_model
        tokenizer = phi4_tokenizer
        start_tag = "<|im_start|>"
        sep_tag = "<|im_sep|>"
        end_tag = "<|im_end|>"
    elif model_name == "Phi-4-mini-instruct":
        model = phi4_mini_model
        tokenizer = phi4_mini_tokenizer
        start_tag = ""
        sep_tag = ""
        end_tag = "<|end|>"
    else:
        raise ValueError("Invalid model selected")
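    # Phi-4 uses a ChatML-style format (<|im_start|>role<|im_sep|>content<|im_end|>),
    # while Phi-4-mini-instruct uses bare <|system|>/<|user|>/<|assistant|> role tags
    # each closed by <|end|>, which is why its start/sep tags are empty strings.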
    # System prompt recommended by Microsoft for the Phi-4 models
    system_message = "You are a friendly and knowledgeable assistant, here to help with any questions or tasks."

    if model_name == "Phi-4":
        prompt = f"{start_tag}system{sep_tag}{system_message}{end_tag}"
        for message in history_state:
            if message["role"] == "user":
                prompt += f"{start_tag}user{sep_tag}{message['content']}{end_tag}"
            elif message["role"] == "assistant" and message["content"]:
                prompt += f"{start_tag}assistant{sep_tag}{message['content']}{end_tag}"
        prompt += f"{start_tag}user{sep_tag}{user_message}{end_tag}{start_tag}assistant{sep_tag}"
    else:
        prompt = f"<|system|>{system_message}{end_tag}"
        for message in history_state:
            if message["role"] == "user":
                prompt += f"<|user|>{message['content']}{end_tag}"
            elif message["role"] == "assistant" and message["content"]:
                prompt += f"<|assistant|>{message['content']}{end_tag}"
        prompt += f"<|user|>{user_message}{end_tag}<|assistant|>"
    inputs = tokenizer(prompt, return_tensors="pt").to(device)

    # Fall back to greedy decoding when every sampling knob is at its neutral setting
    do_sample = not (temperature == 1.0 and top_k >= 100 and top_p == 1.0)

    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)
    # Sampling parameters passed straight through to model.generate
    generation_kwargs = {
        "input_ids": inputs["input_ids"],
        "attention_mask": inputs["attention_mask"],
        "max_new_tokens": int(max_tokens),
        "do_sample": do_sample,
        "temperature": temperature,
        "top_k": int(top_k),
        "top_p": top_p,
        "repetition_penalty": repetition_penalty,
        "streamer": streamer,
    }
    # Run generation on a background thread so tokens can be streamed as they arrive
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()
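    # TextIteratorStreamer is fed by generate() on the worker thread; iterating it
    # below blocks until the next decoded chunk is ready and stops when generation ends.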
    # Stream the response, stripping any chat-format tags the model emits
    assistant_response = ""
    new_history = history_state + [
        {"role": "user", "content": user_message},
        {"role": "assistant", "content": ""},
    ]
    format_tags = ("<|im_start|>", "<|im_sep|>", "<|im_end|>", "<|end|>",
                   "<|system|>", "<|user|>", "<|assistant|>")
    for new_token in streamer:
        cleaned_token = new_token
        for tag in format_tags:
            cleaned_token = cleaned_token.replace(tag, "")
        assistant_response += cleaned_token
        new_history[-1]["content"] = assistant_response.strip()
        yield new_history, new_history

    thread.join()
    # Yield once more so the final history is delivered even if the stream was empty
    yield new_history, new_history
example_messages = {
    "Learn about physics": "Explain Newton's laws of motion.",
    "Discover space facts": "What are some interesting facts about black holes?",
    "Write a factorial function": "Write a Python function to calculate the factorial of a number."
}
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # Phi-4 Chatbot Demo
        Welcome to the Phi-4 Chatbot Demo! You can chat with Microsoft's Phi-4 or Phi-4-mini-instruct models. Adjust the settings on the left to customize the model's responses.
        """
    )
    history_state = gr.State([])
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### Settings")
            model_dropdown = gr.Dropdown(
                choices=["Phi-4", "Phi-4-mini-instruct"],
                label="Select Model",
                value="Phi-4"
            )
            max_tokens_slider = gr.Slider(
                minimum=64,
                maximum=4096,
                step=50,
                value=512,
                label="Max Tokens"
            )
            with gr.Accordion("Advanced Settings", open=False):
                temperature_slider = gr.Slider(
                    minimum=0.1,
                    maximum=2.0,
                    value=1.0,
                    label="Temperature"
                )
                top_k_slider = gr.Slider(
                    minimum=1,
                    maximum=100,
                    step=1,
                    value=50,
                    label="Top-k"
                )
                top_p_slider = gr.Slider(
                    minimum=0.1,
                    maximum=1.0,
                    value=0.9,
                    label="Top-p"
                )
                repetition_penalty_slider = gr.Slider(
                    minimum=1.0,
                    maximum=2.0,
                    value=1.0,
                    label="Repetition Penalty"
                )
        with gr.Column(scale=4):
            chatbot = gr.Chatbot(label="Chat", type="messages")
            with gr.Row():
                user_input = gr.Textbox(
                    label="Your message",
                    placeholder="Type your message here...",
                    scale=3
                )
                submit_button = gr.Button("Send", variant="primary", scale=1)
                clear_button = gr.Button("Clear", scale=1)
            gr.Markdown("**Try these examples:**")
            with gr.Row():
                example1_button = gr.Button("Learn about physics")
                example2_button = gr.Button("Discover space facts")
                example3_button = gr.Button("Write a factorial function")
    submit_button.click(
        fn=generate_response,
        inputs=[user_input, model_dropdown, max_tokens_slider, temperature_slider, top_k_slider, top_p_slider, repetition_penalty_slider, history_state],
        outputs=[chatbot, history_state]
    ).then(
        fn=lambda: gr.update(value=""),
        inputs=None,
        outputs=user_input
    )
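    # NOTE: .then fires only after the streaming event completes, so the textbox
    # clears once the full response has arrived; attach the clear as the first
    # event and generation in .then if you prefer it to empty immediately.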
    clear_button.click(
        fn=lambda: ([], []),
        inputs=None,
        outputs=[chatbot, history_state]
    )
    example1_button.click(
        fn=lambda: gr.update(value=example_messages["Learn about physics"]),
        inputs=None,
        outputs=user_input
    )
    example2_button.click(
        fn=lambda: gr.update(value=example_messages["Discover space facts"]),
        inputs=None,
        outputs=user_input
    )
    example3_button.click(
        fn=lambda: gr.update(value=example_messages["Write a factorial function"]),
        inputs=None,
        outputs=user_input
    )
# ssr_mode=False opts out of Gradio's server-side rendering
demo.launch(ssr_mode=False)