Shriti09 committed
Commit c91a27e · verified · 1 Parent(s): 3fe707b

Update app.py

Files changed (1)
  1. app.py +59 -25
app.py CHANGED
@@ -3,47 +3,81 @@ from transformers import AutoTokenizer, AutoModelForCausalLM
  from peft import PeftModel
  import gradio as gr
+ import torch  # required for torch.float16 below

- # Base model and adapter repo
+ # Model Names
  BASE_MODEL_NAME = "microsoft/phi-2"
- ADAPTER_REPO = "Shriti09/Microsoft-Phi-QLora"
+ ADAPTER_REPO = "Shriti09/Microsoft-Phi-QLora"

- # Load the tokenizer
+ # Load tokenizer and model
  print("Loading tokenizer...")
  tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_NAME)
  tokenizer.pad_token = tokenizer.eos_token

- # Load the base model
  print("Loading base model...")
- base_model = AutoModelForCausalLM.from_pretrained(BASE_MODEL_NAME, device_map="auto")
+ base_model = AutoModelForCausalLM.from_pretrained(BASE_MODEL_NAME, device_map="auto", torch_dtype=torch.float16)

- # Load adapter weights
  print("Loading LoRA adapter...")
  model = PeftModel.from_pretrained(base_model, ADAPTER_REPO)

- # Merge adapter into base model (optional, makes inference simpler)
+ # Merge adapter into the base model
  model = model.merge_and_unload()
-
- # Put model in eval mode
  model.eval()

- # Function to generate response from prompt
- def generate_response(prompt):
-     inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
+ # Function to generate responses
+ def generate_response(message, chat_history, temperature, top_p, max_tokens):
+     # Combine history with the new message
+     full_prompt = ""
+     for user_msg, bot_msg in chat_history:
+         full_prompt += f"User: {user_msg}\nAI: {bot_msg}\n"
+     full_prompt += f"User: {message}\nAI:"
+
+     # Tokenize and generate
+     inputs = tokenizer(full_prompt, return_tensors="pt").to(model.device)
      outputs = model.generate(
          **inputs,
-         max_length=256,
+         max_length=len(inputs["input_ids"][0]) + max_tokens,
          do_sample=True,
-         top_p=0.95,
-         temperature=0.7
+         temperature=temperature,
+         top_p=top_p,
+         pad_token_id=tokenizer.eos_token_id
      )
+
+     # Decode and extract the AI response
      response = tokenizer.decode(outputs[0], skip_special_tokens=True)
-     return response
-
- # Gradio UI
- gr.Interface(
-     fn=generate_response,
-     inputs=gr.Textbox(lines=2, placeholder="Ask me something..."),
-     outputs="text",
-     title="Phi-2 QLoRA Chatbot",
-     description="Chat with Phi-2 fine-tuned with QLoRA adapters!"
- ).launch()
+     # Only return the new part of the response
+     response = response.split("AI:")[-1].strip()
+
+     # Update history
+     chat_history.append((message, response))
+     return chat_history, chat_history
+
+ # Gradio UI with Blocks
+ with gr.Blocks() as demo:
+     gr.Markdown("<h1><center>🤖 Phi-2 QLoRA Chatbot</center></h1>")
+     gr.Markdown("Chat with Microsoft Phi-2 fine-tuned using QLoRA adapters!")
+
+     chatbot = gr.Chatbot()
+     msg = gr.Textbox(placeholder="Ask me something...", label="Your Message")
+     clear = gr.Button("🗑️ Clear Chat")
+
+     # Add sliders for controlling generation behavior
+     with gr.Row():
+         temp_slider = gr.Slider(0.1, 1.0, value=0.7, step=0.1, label="Temperature")
+         top_p_slider = gr.Slider(0.1, 1.0, value=0.9, step=0.1, label="Top-p (nucleus sampling)")
+         max_tokens_slider = gr.Slider(64, 1024, value=256, step=64, label="Max Tokens")
+
+     # State to hold chat history
+     state = gr.State([])
+
+     # On send message
+     def on_message(message, history, temperature, top_p, max_tokens):
+         return generate_response(message, history, temperature, top_p, max_tokens)
+
+     # Button actions
+     msg.submit(on_message,
+                [msg, state, temp_slider, top_p_slider, max_tokens_slider],
+                [chatbot, state])
+
+     clear.click(lambda: ([], []), None, [chatbot, state])
+
+ # Launch the Gradio app
+ demo.launch()
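
A quick way to sanity-check the new chat flow without loading Phi-2 is to exercise just the prompt assembly and reply extraction that generate_response performs; this is a minimal sketch in which the canned decoded string stands in for real model output:

# Sketch of the prompt format generate_response builds and how it
# recovers the reply; "decoded" is a stand-in for real model output.
chat_history = [("Hi", "Hello! How can I help?")]
message = "What is QLoRA?"

full_prompt = ""
for user_msg, bot_msg in chat_history:
    full_prompt += f"User: {user_msg}\nAI: {bot_msg}\n"
full_prompt += f"User: {message}\nAI:"
# full_prompt is now:
# User: Hi
# AI: Hello! How can I help?
# User: What is QLoRA?
# AI:

# model.generate echoes the prompt, so the reply is whatever follows
# the last "AI:" marker in the decoded output.
decoded = full_prompt + " QLoRA fine-tunes a quantized base model with LoRA adapters."
response = decoded.split("AI:")[-1].strip()
print(response)  # -> QLoRA fine-tunes a quantized base model with LoRA adapters.

Two design notes on the committed code: max_length=len(inputs["input_ids"][0]) + max_tokens reproduces by hand what transformers' max_new_tokens=max_tokens expresses directly, and split("AI:")[-1] will also capture any extra turns the model invents after its reply, so trimming at the next "User:" marker is a common refinement.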