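"""Gradio demo comparing a base Gemma-3 model against a GRPO-finetuned variant.

Loads both models from the Hugging Face Hub (the finetuned one as PEFT adapters
applied on top of the base) and serves side-by-side generation in a Blocks UI.
"""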
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
import torch

def load_model(model_id, model_type="base"):
    try:
        if model_type == "base":
            tokenizer = AutoTokenizer.from_pretrained(model_id)
            model = AutoModelForCausalLM.from_pretrained(
                model_id,
                torch_dtype=torch.float16,
                device_map="auto"
            )
            return tokenizer, model
        else:  # finetuned model with PEFT adapters
            # Load the base model first
            base_model_id = "satyanayak/gemma-3-base"
            tokenizer = AutoTokenizer.from_pretrained(base_model_id)
            base_model = AutoModelForCausalLM.from_pretrained(
                base_model_id,
                torch_dtype=torch.float16,
                device_map="auto"
            )
            # Apply the PEFT adapters on top of the base model
            # (the adapters are not merged; PeftModel wraps the base weights)
            model = PeftModel.from_pretrained(
                base_model,
                model_id,
                torch_dtype=torch.float16,
                device_map="auto"
            )
            return tokenizer, model
    except Exception as e:
        print(f"Error loading {model_type} model: {str(e)}")
        return None, None

# Load base model and tokenizer
base_model_id = "satyanayak/gemma-3-base"
base_tokenizer, base_model = load_model(base_model_id, "base")
# Load finetuned model and tokenizer
finetuned_model_id = "satyanayak/gemma-3-GRPO"
finetuned_tokenizer, finetuned_model = load_model(finetuned_model_id, "finetuned")
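
# If either load fails, the variables above are (None, None); the generate
# functions below surface a readable error message instead of crashing the Space.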

def generate_base_response(prompt, max_length=512):
    if base_model is None or base_tokenizer is None:
        return "Error: Base model failed to load. Please check if the model files are properly uploaded to Hugging Face."
    try:
        # Tokenize on the model's device; max_length caps prompt + generated tokens combined
        inputs = base_tokenizer(prompt, return_tensors="pt").to(base_model.device)
        outputs = base_model.generate(
            **inputs,
            max_length=max_length,
            num_return_sequences=1,
            temperature=0.7,
            do_sample=True,
            pad_token_id=base_tokenizer.eos_token_id
        )
        response = base_tokenizer.decode(outputs[0], skip_special_tokens=True)
        return response
    except Exception as e:
        return f"Error generating response with base model: {str(e)}"

def generate_finetuned_response(prompt, max_length=512):
    if finetuned_model is None or finetuned_tokenizer is None:
        return "Error: Finetuned model failed to load. Please check if the model files are properly uploaded to Hugging Face."
    try:
        # Same generation settings as the base model, so the comparison is like-for-like
        inputs = finetuned_tokenizer(prompt, return_tensors="pt").to(finetuned_model.device)
        outputs = finetuned_model.generate(
            **inputs,
            max_length=max_length,
            num_return_sequences=1,
            temperature=0.7,
            do_sample=True,
            pad_token_id=finetuned_tokenizer.eos_token_id
        )
        response = finetuned_tokenizer.decode(outputs[0], skip_special_tokens=True)
        return response
    except Exception as e:
        return f"Error generating response with finetuned model: {str(e)}"

# Example prompts
examples = [
    ["What is the sqrt of 101?"],
    ["How many r's are in strawberry?"],
    ["If Tom has 3 more apples than Jerry and Jerry has 5 apples, how many apples does Tom have?"]
]

# Create the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Gemma-3 Model Comparison Demo")
    gr.Markdown("Compare responses between the base model and the GRPO-finetuned model.")

    with gr.Row():
        # Base Model Column
        with gr.Column(scale=1):
            gr.Markdown("## Base Model (Gemma-3)")
            base_input = gr.Textbox(
                label="Enter your prompt",
                placeholder="Type your prompt here...",
                lines=5
            )
            base_generate_btn = gr.Button("Generate with Base Model")
            base_output = gr.Textbox(label="Base Model Output", lines=10)
            # cache_examples=True runs each example at startup and stores the
            # outputs, so clicking an example returns instantly
            gr.Examples(
                examples=examples,
                inputs=base_input,
                outputs=base_output,
                fn=generate_base_response,
                cache_examples=True
            )

        # Finetuned Model Column
        with gr.Column(scale=1):
            gr.Markdown("## GRPO-Finetuned Model")
            finetuned_input = gr.Textbox(
                label="Enter your prompt",
                placeholder="Type your prompt here...",
                lines=5
            )
            finetuned_generate_btn = gr.Button("Generate with Finetuned Model")
            finetuned_output = gr.Textbox(label="Finetuned Model Output", lines=10)
            gr.Examples(
                examples=examples,
                inputs=finetuned_input,
                outputs=finetuned_output,
                fn=generate_finetuned_response,
                cache_examples=True
            )

    # Connect buttons to their respective functions
    base_generate_btn.click(
        fn=generate_base_response,
        inputs=base_input,
        outputs=base_output
    )
    finetuned_generate_btn.click(
        fn=generate_finetuned_response,
        inputs=finetuned_input,
        outputs=finetuned_output
    )

if __name__ == "__main__":
    demo.launch()
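
# Run locally with `python app.py`; Gradio serves on its default port (7860).
# On Hugging Face Spaces this file is picked up automatically as the app entry point.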