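"""app.py: Gradio demo comparing a base Gemma-3 model with a GRPO-finetuned variant."""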
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
def load_model(model_id, model_type="base"):
    """Load a tokenizer/model pair; returns (None, None) on failure."""
    try:
        if model_type == "base":
            tokenizer = AutoTokenizer.from_pretrained(model_id)
            model = AutoModelForCausalLM.from_pretrained(
                model_id,
                torch_dtype=torch.float16,
                device_map="auto",
            )
            return tokenizer, model
        else:  # finetuned model with PEFT adapters
            # The adapters were trained on top of this base checkpoint,
            # so load the base model and tokenizer first.
            base_model_id = "satyanayak/gemma-3-base"
            tokenizer = AutoTokenizer.from_pretrained(base_model_id)
            base_model = AutoModelForCausalLM.from_pretrained(
                base_model_id,
                torch_dtype=torch.float16,
                device_map="auto",
            )
            # Attach the PEFT (LoRA) adapters to the base model. Note that
            # this wraps the base model rather than merging the weights.
            model = PeftModel.from_pretrained(
                base_model,
                model_id,
                torch_dtype=torch.float16,
                device_map="auto",
            )
            return tokenizer, model
    except Exception as e:
        print(f"Error loading {model_type} model: {e}")
        return None, None
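
# Optional: for inference-only use, the adapter weights can be folded into
# the base model to remove the PEFT wrapper overhead. A minimal sketch,
# assuming `peft_model` was produced by PeftModel.from_pretrained as above
# (this helper is illustrative and not called in this app):
def merge_adapters(peft_model):
    # merge_and_unload() merges the LoRA deltas into the base weights and
    # returns a plain transformers model.
    return peft_model.merge_and_unload()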
# Load base model and tokenizer
base_model_id = "satyanayak/gemma-3-base"
base_tokenizer, base_model = load_model(base_model_id, "base")
# Load finetuned model and tokenizer
finetuned_model_id = "satyanayak/gemma-3-GRPO"
finetuned_tokenizer, finetuned_model = load_model(finetuned_model_id, "finetuned")
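# Note: the finetuned path above loads a second full copy of the base
# weights. If memory is tight, one alternative (an assumption, not what
# this app does) is to load the base model once, wrap it with
# PeftModel.from_pretrained, and generate base-model responses inside a
# `with model.disable_adapter():` block.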
def generate_base_response(prompt, max_new_tokens=512):
    if base_model is None or base_tokenizer is None:
        return "Error: Base model failed to load. Please check if the model files are properly uploaded to Hugging Face."
    try:
        inputs = base_tokenizer(prompt, return_tensors="pt").to(base_model.device)
        outputs = base_model.generate(
            **inputs,
            # max_new_tokens bounds the generated text only; max_length would
            # also count the prompt tokens and could truncate the answer.
            max_new_tokens=max_new_tokens,
            num_return_sequences=1,
            temperature=0.7,
            do_sample=True,
            pad_token_id=base_tokenizer.eos_token_id,
        )
        # Note: decoding outputs[0] returns the prompt plus the completion.
        response = base_tokenizer.decode(outputs[0], skip_special_tokens=True)
        return response
    except Exception as e:
        return f"Error generating response with base model: {e}"
def generate_finetuned_response(prompt, max_new_tokens=512):
    if finetuned_model is None or finetuned_tokenizer is None:
        return "Error: Finetuned model failed to load. Please check if the model files are properly uploaded to Hugging Face."
    try:
        inputs = finetuned_tokenizer(prompt, return_tensors="pt").to(finetuned_model.device)
        outputs = finetuned_model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            num_return_sequences=1,
            temperature=0.7,
            do_sample=True,
            pad_token_id=finetuned_tokenizer.eos_token_id,
        )
        response = finetuned_tokenizer.decode(outputs[0], skip_special_tokens=True)
        return response
    except Exception as e:
        return f"Error generating response with finetuned model: {e}"
# Example prompts (each is a single-input example for gr.Examples)
examples = [
    ["What is the sqrt of 101?"],
    ["How many r's are in strawberry?"],
    ["If Tom has 3 more apples than Jerry and Jerry has 5 apples, how many apples does Tom have?"],
]
# Create the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Gemma-3 Model Comparison Demo")
    gr.Markdown("Compare responses between the base model and the GRPO-finetuned model.")

    with gr.Row():
        # Base Model Column
        with gr.Column(scale=1):
            gr.Markdown("## Base Model (Gemma-3)")
            base_input = gr.Textbox(
                label="Enter your prompt",
                placeholder="Type your prompt here...",
                lines=5,
            )
            base_generate_btn = gr.Button("Generate with Base Model")
            base_output = gr.Textbox(label="Base Model Output", lines=10)
            gr.Examples(
                examples=examples,
                inputs=base_input,
                outputs=base_output,
                fn=generate_base_response,
                cache_examples=True,
            )

        # Finetuned Model Column
        with gr.Column(scale=1):
            gr.Markdown("## GRPO-Finetuned Model")
            finetuned_input = gr.Textbox(
                label="Enter your prompt",
                placeholder="Type your prompt here...",
                lines=5,
            )
            finetuned_generate_btn = gr.Button("Generate with Finetuned Model")
            finetuned_output = gr.Textbox(label="Finetuned Model Output", lines=10)
            gr.Examples(
                examples=examples,
                inputs=finetuned_input,
                outputs=finetuned_output,
                fn=generate_finetuned_response,
                cache_examples=True,
            )
    # Connect buttons to their respective generation functions
    base_generate_btn.click(
        fn=generate_base_response,
        inputs=base_input,
        outputs=base_output,
    )
    finetuned_generate_btn.click(
        fn=generate_finetuned_response,
        inputs=finetuned_input,
        outputs=finetuned_output,
    )
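# Note (assumption, not in the original app): with several concurrent users,
# enabling demo.queue() before launch() serializes generation requests so
# they do not contend for the GPU.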
if __name__ == "__main__":
    demo.launch()