import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
import gradio as gr

# Model names
BASE_MODEL_NAME = "microsoft/phi-2"
ADAPTER_REPO = "Shriti09/Microsoft-Phi-QLora"  # LoRA adapter weights on the HF Hub

# Load tokenizer and model
print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_NAME)
tokenizer.pad_token = tokenizer.eos_token  # Phi-2's tokenizer has no pad token by default

print("Loading base model...")
base_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL_NAME, device_map="auto", torch_dtype=torch.float16
)
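
# Optional: on a memory-constrained Space, the base model can instead be loaded
# in 4-bit (the same NF4 quantization QLoRA trains against). A minimal sketch,
# assuming bitsandbytes is installed -- uncomment to use. Note that the
# merge_and_unload() call below may not be supported on 4-bit weights; in that
# case skip the merge and run the PeftModel directly.
#
# from transformers import BitsAndBytesConfig
# bnb_config = BitsAndBytesConfig(
#     load_in_4bit=True,
#     bnb_4bit_quant_type="nf4",
#     bnb_4bit_compute_dtype=torch.float16,
# )
# base_model = AutoModelForCausalLM.from_pretrained(
#     BASE_MODEL_NAME, device_map="auto", quantization_config=bnb_config
# )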
print("Loading LoRA adapter...") | |
model = PeftModel.from_pretrained(base_model, ADAPTER_REPO) | |
# Merge adapter into the base model | |
model = model.merge_and_unload() | |
model.eval() | |
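
# Optional smoke test (commented out so the Space boots straight to the UI;
# the prompt text is illustrative, not from the repo):
# with torch.no_grad():
#     ids = tokenizer("User: Hello\nAI:", return_tensors="pt").to(model.device)
#     print(tokenizer.decode(model.generate(**ids, max_new_tokens=20)[0],
#                            skip_special_tokens=True))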

# Function to generate responses
def generate_response(message, chat_history, temperature, top_p, max_tokens):
    # Combine history with the new message
    full_prompt = ""
    for user_msg, bot_msg in chat_history:
        full_prompt += f"User: {user_msg}\nAI: {bot_msg}\n"
    full_prompt += f"User: {message}\nAI:"
    # Tokenize and generate (no_grad avoids building an autograd graph;
    # max_new_tokens counts only generated tokens, unlike max_length)
    inputs = tokenizer(full_prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=int(max_tokens),
            do_sample=True,
            temperature=temperature,
            top_p=top_p,
            pad_token_id=tokenizer.eos_token_id,
        )

    # Decode only the newly generated tokens (everything after the prompt), so
    # "AI:" strings inside the history can't be picked up by accident
    new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
    response = tokenizer.decode(new_tokens, skip_special_tokens=True)
    # The model often keeps going and writes a follow-up "User:" turn; cut it off
    response = response.split("User:")[0].strip()

    # Update history
    chat_history.append((message, response))
    return chat_history, chat_history

# Gradio UI with Blocks
with gr.Blocks() as demo:
    gr.Markdown("<h1><center>🤖 Phi-2 QLoRA Chatbot</center></h1>")
    gr.Markdown("Chat with Microsoft Phi-2 fine-tuned using QLoRA adapters!")

    chatbot = gr.Chatbot()
    msg = gr.Textbox(placeholder="Ask me something...", label="Your Message")
    clear = gr.Button("🗑️ Clear Chat")

    # Add sliders for controlling generation behavior
    with gr.Row():
        temp_slider = gr.Slider(0.1, 1.0, value=0.7, step=0.1, label="Temperature")
        top_p_slider = gr.Slider(0.1, 1.0, value=0.9, step=0.1, label="Top-p (nucleus sampling)")
        max_tokens_slider = gr.Slider(64, 1024, value=256, step=64, label="Max Tokens")
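
    # Temperature rescales the logits (lower = more deterministic output);
    # top-p keeps only the smallest token set whose cumulative probability
    # exceeds p, trimming the unlikely tail of the distribution.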

    # State to hold chat history
    state = gr.State([])

    # Wire events: generate_response already matches Gradio's
    # (inputs -> outputs) signature, so no wrapper function is needed
    msg.submit(
        generate_response,
        [msg, state, temp_slider, top_p_slider, max_tokens_slider],
        [chatbot, state],
    )
    clear.click(lambda: ([], []), None, [chatbot, state])

# Launch the Gradio app
demo.launch()
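
# A minimal requirements.txt sketch for this Space (assumed, not taken from the
# repo; accelerate is needed for device_map="auto"):
#   torch
#   transformers
#   accelerate
#   peft
#   gradio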