import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
import gradio as gr
# Model Names
BASE_MODEL_NAME = "microsoft/phi-2"
ADAPTER_REPO = "Shriti09/Microsoft-Phi-QLora"
# Load tokenizer and model
print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_NAME)
tokenizer.pad_token = tokenizer.eos_token
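# phi-2's tokenizer ships without a pad token, so reuse EOS for padding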
print("Loading base model...")
base_model = AutoModelForCausalLM.from_pretrained(BASE_MODEL_NAME, device_map="auto", torch_dtype=torch.float16)
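# Optional (assumption, not part of the original setup): on GPUs with limited
# VRAM the base model could instead be loaded in 4-bit via bitsandbytes, e.g.:
#   from transformers import BitsAndBytesConfig
#   bnb_config = BitsAndBytesConfig(load_in_4bit=True,
#                                   bnb_4bit_compute_dtype=torch.float16)
#   base_model = AutoModelForCausalLM.from_pretrained(
#       BASE_MODEL_NAME, device_map="auto", quantization_config=bnb_config)
# Note that merging the adapter (merge_and_unload below) may not be supported
# on 4-bit weights, in which case the PeftModel would be used unmerged.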
print("Loading LoRA adapter...")
model = PeftModel.from_pretrained(base_model, ADAPTER_REPO)
# Merge adapter into the base model
model = model.merge_and_unload()
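# Merging folds the low-rank adapter deltas into the base weights, so
# generation runs as a plain Phi-2 model without PEFT wrapper overhead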
model.eval()
# Function to generate responses
def generate_response(message, chat_history, temperature, top_p, max_tokens):
    # Combine history with the new message into a single prompt
    full_prompt = ""
    for user_msg, bot_msg in chat_history:
        full_prompt += f"User: {user_msg}\nAI: {bot_msg}\n"
    full_prompt += f"User: {message}\nAI:"
    # Tokenize and generate
    inputs = tokenizer(full_prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(
        **inputs,
        # Cap only the newly generated tokens; cast to int because Gradio
        # sliders pass floats
        max_new_tokens=int(max_tokens),
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        pad_token_id=tokenizer.eos_token_id,
    )
    # Decode and extract the AI response
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    # Only return the new part of the response
    response = response.split("AI:")[-1].strip()
    # Update history
    chat_history.append((message, response))
    return chat_history, chat_history
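# Quick smoke test outside the UI (hypothetical prompt, commented out so the
# Space only serves the Gradio app):
#   history, _ = generate_response("What is QLoRA?", [], 0.7, 0.9, 128)
#   print(history[-1][1])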
# Gradio UI with Blocks
with gr.Blocks() as demo:
    gr.Markdown("<h1><center>🤖 Phi-2 QLoRA Chatbot</center></h1>")
    gr.Markdown("Chat with Microsoft Phi-2 fine-tuned using QLoRA adapters!")
    chatbot = gr.Chatbot()
    msg = gr.Textbox(placeholder="Ask me something...", label="Your Message")
    clear = gr.Button("🗑️ Clear Chat")
    # Add sliders for controlling generation behavior
    with gr.Row():
        temp_slider = gr.Slider(0.1, 1.0, value=0.7, step=0.1, label="Temperature")
        top_p_slider = gr.Slider(0.1, 1.0, value=0.9, step=0.1, label="Top-p (nucleus sampling)")
        max_tokens_slider = gr.Slider(64, 1024, value=256, step=64, label="Max Tokens")
    # State to hold chat history
    state = gr.State([])
    # On send message
    def on_message(message, history, temperature, top_p, max_tokens):
        return generate_response(message, history, temperature, top_p, max_tokens)
    # Button actions
    msg.submit(on_message,
               [msg, state, temp_slider, top_p_slider, max_tokens_slider],
               [chatbot, state])
    clear.click(lambda: ([], []), None, [chatbot, state])
# Launch the Gradio app
demo.launch()
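# Assumption: when hosted outside a Space, a public link and a request queue
# could be enabled with, e.g., demo.queue().launch(share=True)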