sudhakar272 committed
Commit 64631f6 · verified · 1 Parent(s): 2a45f14

Upload 2 files

Files changed (2)
  1. app.py +150 -0
  2. requirements.txt +5 -0
app.py ADDED
@@ -0,0 +1,150 @@
+ import gradio as gr
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+ from peft import PeftModel, PeftConfig
+ import torch
+ import os
+
+ def load_model(model_id, model_type="base"):
+     try:
+         if model_type == "base":
+             tokenizer = AutoTokenizer.from_pretrained(model_id)
+             model = AutoModelForCausalLM.from_pretrained(
+                 model_id,
+                 torch_dtype=torch.float16,
+                 device_map="auto"
+             )
+             return tokenizer, model
+         else:  # finetuned model with PEFT
+             # Load the base model first
+             base_model_id = "satyanayak/gemma-3-base"
+             tokenizer = AutoTokenizer.from_pretrained(base_model_id)
+             base_model = AutoModelForCausalLM.from_pretrained(
+                 base_model_id,
+                 torch_dtype=torch.float16,
+                 device_map="auto"
+             )
+
+             # Load the PEFT adapters on top of the base model
+             model = PeftModel.from_pretrained(
+                 base_model,
+                 model_id,
+                 torch_dtype=torch.float16,
+                 device_map="auto"
+             )
+             return tokenizer, model
+     except Exception as e:
+         print(f"Error loading {model_type} model: {str(e)}")
+         return None, None
+
+ # Load base model and tokenizer
+ base_model_id = "satyanayak/gemma-3-base"
+ base_tokenizer, base_model = load_model(base_model_id, "base")
+
+ # Load finetuned model and tokenizer
+ finetuned_model_id = "satyanayak/gemma-3-GRPO"
+ finetuned_tokenizer, finetuned_model = load_model(finetuned_model_id, "finetuned")
+
+ def generate_base_response(prompt, max_length=512):
+     if base_model is None or base_tokenizer is None:
+         return "Error: Base model failed to load. Please check if the model files are properly uploaded to Hugging Face."
+
+     try:
+         inputs = base_tokenizer(prompt, return_tensors="pt").to(base_model.device)
+         outputs = base_model.generate(
+             **inputs,
+             max_length=max_length,  # note: caps prompt + generated tokens together
+             num_return_sequences=1,
+             temperature=0.7,
+             do_sample=True,
+             pad_token_id=base_tokenizer.eos_token_id
+         )
+         response = base_tokenizer.decode(outputs[0], skip_special_tokens=True)
+         return response
+     except Exception as e:
+         return f"Error generating response with base model: {str(e)}"
+
+ def generate_finetuned_response(prompt, max_length=512):
+     if finetuned_model is None or finetuned_tokenizer is None:
+         return "Error: Finetuned model failed to load. Please check if the model files are properly uploaded to Hugging Face."
+
+     try:
+         inputs = finetuned_tokenizer(prompt, return_tensors="pt").to(finetuned_model.device)
+         outputs = finetuned_model.generate(
+             **inputs,
+             max_length=max_length,  # note: caps prompt + generated tokens together
+             num_return_sequences=1,
+             temperature=0.7,
+             do_sample=True,
+             pad_token_id=finetuned_tokenizer.eos_token_id
+         )
+         response = finetuned_tokenizer.decode(outputs[0], skip_special_tokens=True)
+         return response
+     except Exception as e:
+         return f"Error generating response with finetuned model: {str(e)}"
+
+ # Example prompts
+ examples = [
+     ["What is the sqrt of 101"],
+     ["How many r's are in strawberry?"],
+     ["If Tom has 3 more apples than Jerry and Jerry has 5 apples, how many apples does Tom have?"]
+ ]
+
+ # Create the Gradio interface
+ with gr.Blocks() as demo:
+     gr.Markdown("# Gemma-3 Model Comparison Demo")
+     gr.Markdown("Compare responses between the base model and the GRPO-finetuned model.")
+
+     with gr.Row():
+         # Base Model Column
+         with gr.Column(scale=1):
+             gr.Markdown("## Base Model (Gemma-3)")
+             base_input = gr.Textbox(
+                 label="Enter your prompt",
+                 placeholder="Type your prompt here...",
+                 lines=5
+             )
+             base_generate_btn = gr.Button("Generate with Base Model")
+             base_output = gr.Textbox(label="Base Model Output", lines=10)
+
+             gr.Examples(
+                 examples=examples,
+                 inputs=base_input,
+                 outputs=base_output,
+                 fn=generate_base_response,
+                 cache_examples=True
+             )
+
+         # Finetuned Model Column
+         with gr.Column(scale=1):
+             gr.Markdown("## GRPO-Finetuned Model")
+             finetuned_input = gr.Textbox(
+                 label="Enter your prompt",
+                 placeholder="Type your prompt here...",
+                 lines=5
+             )
+             finetuned_generate_btn = gr.Button("Generate with Finetuned Model")
+             finetuned_output = gr.Textbox(label="Finetuned Model Output", lines=10)
+
+             gr.Examples(
+                 examples=examples,
+                 inputs=finetuned_input,
+                 outputs=finetuned_output,
+                 fn=generate_finetuned_response,
+                 cache_examples=True
+             )
+
+     # Connect buttons to their respective functions
+     base_generate_btn.click(
+         fn=generate_base_response,
+         inputs=base_input,
+         outputs=base_output
+     )
+
+     finetuned_generate_btn.click(
+         fn=generate_finetuned_response,
+         inputs=finetuned_input,
+         outputs=finetuned_output
+     )
+
+ if __name__ == "__main__":
+     demo.launch()
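
Note on the finetuned path: PeftModel.from_pretrained() keeps the adapters attached, so every forward pass routes through the adapter layers. If generation speed matters more than keeping the adapter swappable, the adapter weights can be folded into the base model once at load time. A minimal sketch, assuming the adapters in satyanayak/gemma-3-GRPO are LoRA-style (merge_and_unload() only applies to mergeable adapter types):

    import torch
    from transformers import AutoModelForCausalLM
    from peft import PeftModel

    # Load the base model, attach the adapter, then merge its weights in place.
    base = AutoModelForCausalLM.from_pretrained(
        "satyanayak/gemma-3-base", torch_dtype=torch.float16, device_map="auto"
    )
    merged = PeftModel.from_pretrained(base, "satyanayak/gemma-3-GRPO").merge_and_unload()

After merging, merged behaves like a plain transformers model and could stand in for finetuned_model without changing the generate functions above.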
requirements.txt ADDED
@@ -0,0 +1,5 @@
+ gradio>=4.19.2
+ transformers>=4.38.0
+ torch>=2.2.0
+ accelerate>=0.27.0
+ peft>=0.9.0
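
These are minimum versions rather than exact pins; accelerate is listed because app.py loads both models with device_map="auto", which transformers delegates to accelerate. A typical way to run the Space locally (assuming a recent Python environment and enough GPU memory for two fp16 copies of the model):

    pip install -r requirements.txt
    python app.py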