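"""Gradio app for a Hugging Face Space: side-by-side comparison of generations
from the base Gemma-3 model and its GRPO-finetuned variant."""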
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
import torch

def load_model(model_id, model_type="base"):
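    """Load a tokenizer/model pair.

    For the base model, load it directly; for the finetuned model, load the
    base model first and apply the PEFT adapters on top. Returns
    (tokenizer, model), or (None, None) if loading fails.
    """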
    try:
        if model_type == "base":
            tokenizer = AutoTokenizer.from_pretrained(model_id)
            model = AutoModelForCausalLM.from_pretrained(
                model_id,
                torch_dtype=torch.float16,
                device_map="auto"
            )
            return tokenizer, model
        else:  # finetuned model with PEFT
            # Load the base model first
            base_model_id = "satyanayak/gemma-3-base"
            tokenizer = AutoTokenizer.from_pretrained(base_model_id)
            base_model = AutoModelForCausalLM.from_pretrained(
                base_model_id,
                torch_dtype=torch.float16,
                device_map="auto"
            )
            # Apply the PEFT adapters on top of the base model
            model = PeftModel.from_pretrained(
                base_model,
                model_id,
                torch_dtype=torch.float16,
                device_map="auto"
            )
            return tokenizer, model
    except Exception as e:
        print(f"Error loading {model_type} model: {str(e)}")
        return None, None

# Load base model and tokenizer
base_model_id = "satyanayak/gemma-3-base"
base_tokenizer, base_model = load_model(base_model_id, "base")

# Load finetuned model and tokenizer
finetuned_model_id = "satyanayak/gemma-3-GRPO"
finetuned_tokenizer, finetuned_model = load_model(finetuned_model_id, "finetuned")

def generate_base_response(prompt, max_new_tokens=512):
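    """Generate a sampled completion from the base model.

    max_new_tokens caps the number of generated tokens, not counting the prompt.
    """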
    if base_model is None or base_tokenizer is None:
        return "Error: Base model failed to load. Please check if the model files are properly uploaded to Hugging Face."
    try:
        inputs = base_tokenizer(prompt, return_tensors="pt").to(base_model.device)
        outputs = base_model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            num_return_sequences=1,
            temperature=0.7,
            do_sample=True,
            pad_token_id=base_tokenizer.eos_token_id
        )
        response = base_tokenizer.decode(outputs[0], skip_special_tokens=True)
        return response
    except Exception as e:
        return f"Error generating response with base model: {str(e)}"

def generate_finetuned_response(prompt, max_new_tokens=512):
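    """Generate a sampled completion from the GRPO-finetuned model.

    max_new_tokens caps the number of generated tokens, not counting the prompt.
    """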
    if finetuned_model is None or finetuned_tokenizer is None:
        return "Error: Finetuned model failed to load. Please check if the model files are properly uploaded to Hugging Face."
    try:
        inputs = finetuned_tokenizer(prompt, return_tensors="pt").to(finetuned_model.device)
        outputs = finetuned_model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            num_return_sequences=1,
            temperature=0.7,
            do_sample=True,
            pad_token_id=finetuned_tokenizer.eos_token_id
        )
        response = finetuned_tokenizer.decode(outputs[0], skip_special_tokens=True)
        return response
    except Exception as e:
        return f"Error generating response with finetuned model: {str(e)}"

# Example prompts
examples = [
    ["What is the sqrt of 101"],
    ["How many r's are in strawberry?"],
    ["If Tom has 3 more apples than Jerry and Jerry has 5 apples, how many apples does Tom have?"]
]

# Create the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Gemma-3 Model Comparison Demo")
    gr.Markdown("Compare responses between the base model and the GRPO-finetuned model.")

    with gr.Row():
        # Base Model Column
        with gr.Column(scale=1):
            gr.Markdown("## Base Model (Gemma-3)")
            base_input = gr.Textbox(
                label="Enter your prompt",
                placeholder="Type your prompt here...",
                lines=5
            )
            base_generate_btn = gr.Button("Generate with Base Model")
            base_output = gr.Textbox(label="Base Model Output", lines=10)
            gr.Examples(
                examples=examples,
                inputs=base_input,
                outputs=base_output,
                fn=generate_base_response,
                cache_examples=True
            )

        # Finetuned Model Column
        with gr.Column(scale=1):
            gr.Markdown("## GRPO-Finetuned Model")
            finetuned_input = gr.Textbox(
                label="Enter your prompt",
                placeholder="Type your prompt here...",
                lines=5
            )
            finetuned_generate_btn = gr.Button("Generate with Finetuned Model")
            finetuned_output = gr.Textbox(label="Finetuned Model Output", lines=10)
            gr.Examples(
                examples=examples,
                inputs=finetuned_input,
                outputs=finetuned_output,
                fn=generate_finetuned_response,
                cache_examples=True
            )

    # Connect buttons to their respective functions
    base_generate_btn.click(
        fn=generate_base_response,
        inputs=base_input,
        outputs=base_output
    )
    finetuned_generate_btn.click(
        fn=generate_finetuned_response,
        inputs=finetuned_input,
        outputs=finetuned_output
    )

if __name__ == "__main__":
    demo.launch()