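"""Gradio demo comparing the base Gemma-3 model with a GRPO-finetuned variant.

Both models are loaded from the Hugging Face Hub (the finetuned one as PEFT
adapters applied on top of the base weights) and presented side by side so
the same prompt can be run through each.
"""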
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
import torch

def load_model(model_id, model_type="base"):
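    """Load a tokenizer/model pair from the Hugging Face Hub.

    model_type="base" loads the plain checkpoint; any other value loads PEFT
    adapters from model_id on top of the shared Gemma-3 base model.
    """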
    try:
        if model_type == "base":
            tokenizer = AutoTokenizer.from_pretrained(model_id)
            model = AutoModelForCausalLM.from_pretrained(
                model_id,
                torch_dtype=torch.float16,  # bfloat16 may be preferable on GPUs that support it
                device_map="auto"
            )
            return tokenizer, model
        else:  # finetuned model with PEFT
            # Load the base model first
            base_model_id = "satyanayak/gemma-3-base"
            tokenizer = AutoTokenizer.from_pretrained(base_model_id)
            base_model = AutoModelForCausalLM.from_pretrained(
                base_model_id,
                torch_dtype=torch.float16,
                device_map="auto"
            )
            
            # Load the PEFT adapters on top of the base model
            model = PeftModel.from_pretrained(
                base_model,
                model_id,
                torch_dtype=torch.float16,
                device_map="auto"
            )
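            # Optionally, the adapters could be merged into the base weights
            # for slightly faster inference (no extra adapter pass per step):
            # model = model.merge_and_unload()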
            return tokenizer, model
    except Exception as e:
        print(f"Error loading {model_type} model: {str(e)}")
        return None, None

# Load base model and tokenizer
base_model_id = "satyanayak/gemma-3-base"
base_tokenizer, base_model = load_model(base_model_id, "base")

# Load finetuned model and tokenizer
finetuned_model_id = "satyanayak/gemma-3-GRPO"
finetuned_tokenizer, finetuned_model = load_model(finetuned_model_id, "finetuned")
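
# Note: the base checkpoint is loaded twice above (standalone and under the
# PEFT wrapper), which roughly doubles memory use. If memory is tight, one
# option is to load a single PeftModel and generate base-model output inside
# its disable_adapter() context manager instead.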

def generate_base_response(prompt, max_new_tokens=512):
    """Run the prompt through the base model and return the generated text."""
    if base_model is None or base_tokenizer is None:
        return "Error: Base model failed to load. Please check if the model files are properly uploaded to Hugging Face."

    try:
        inputs = base_tokenizer(prompt, return_tensors="pt").to(base_model.device)
        outputs = base_model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,  # cap on generated tokens, independent of prompt length
            num_return_sequences=1,
            temperature=0.7,
            do_sample=True,
            pad_token_id=base_tokenizer.eos_token_id
        )
        # Decode only the newly generated tokens so the prompt is not echoed back
        generated_tokens = outputs[0][inputs["input_ids"].shape[1]:]
        return base_tokenizer.decode(generated_tokens, skip_special_tokens=True)
    except Exception as e:
        return f"Error generating response with base model: {str(e)}"

def generate_finetuned_response(prompt, max_new_tokens=512):
    """Run the prompt through the GRPO-finetuned model and return the generated text."""
    if finetuned_model is None or finetuned_tokenizer is None:
        return "Error: Finetuned model failed to load. Please check if the model files are properly uploaded to Hugging Face."

    try:
        inputs = finetuned_tokenizer(prompt, return_tensors="pt").to(finetuned_model.device)
        outputs = finetuned_model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,  # cap on generated tokens, independent of prompt length
            num_return_sequences=1,
            temperature=0.7,
            do_sample=True,
            pad_token_id=finetuned_tokenizer.eos_token_id
        )
        # Decode only the newly generated tokens so the prompt is not echoed back
        generated_tokens = outputs[0][inputs["input_ids"].shape[1]:]
        return finetuned_tokenizer.decode(generated_tokens, skip_special_tokens=True)
    except Exception as e:
        return f"Error generating response with finetuned model: {str(e)}"

# Example prompts
examples = [
    ["What is the sqrt of 101"],
    ["How many r's are in strawberry?"],
    ["If Tom has 3 more apples than Jerry and Jerry has 5 apples, how many apples does Tom have?"]
]

# Create the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Gemma-3 Model Comparison Demo")
    gr.Markdown("Compare responses between the base model and the GRPO-finetuned model.")
    
    with gr.Row():
        # Base Model Column
        with gr.Column(scale=1):
            gr.Markdown("## Base Model (Gemma-3)")
            base_input = gr.Textbox(
                label="Enter your prompt",
                placeholder="Type your prompt here...",
                lines=5
            )
            base_generate_btn = gr.Button("Generate with Base Model")
            base_output = gr.Textbox(label="Base Model Output", lines=10)
            
            gr.Examples(
                examples=examples,
                inputs=base_input,
                outputs=base_output,
                fn=generate_base_response,
                cache_examples=True  # runs each example once at startup and caches the output
            )
        
        # Finetuned Model Column
        with gr.Column(scale=1):
            gr.Markdown("## GRPO-Finetuned Model")
            finetuned_input = gr.Textbox(
                label="Enter your prompt",
                placeholder="Type your prompt here...",
                lines=5
            )
            finetuned_generate_btn = gr.Button("Generate with Finetuned Model")
            finetuned_output = gr.Textbox(label="Finetuned Model Output", lines=10)
            
            gr.Examples(
                examples=examples,
                inputs=finetuned_input,
                outputs=finetuned_output,
                fn=generate_finetuned_response,
                cache_examples=True  # runs each example once at startup and caches the output
            )
    
    # Connect buttons to their respective functions
    base_generate_btn.click(
        fn=generate_base_response,
        inputs=base_input,
        outputs=base_output
    )
    
    finetuned_generate_btn.click(
        fn=generate_finetuned_response,
        inputs=finetuned_input,
        outputs=finetuned_output
    )

if __name__ == "__main__":
    demo.launch()