import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
import torch


def load_model(model_id, model_type="base"):
    """Load either the base model or the PEFT-finetuned model, with its tokenizer."""
    try:
        if model_type == "base":
            tokenizer = AutoTokenizer.from_pretrained(model_id)
            model = AutoModelForCausalLM.from_pretrained(
                model_id,
                torch_dtype=torch.float16,
                device_map="auto"
            )
            return tokenizer, model
        else:
            # Finetuned model with PEFT: load the base model first
            base_model_id = "satyanayak/gemma-3-base"
            tokenizer = AutoTokenizer.from_pretrained(base_model_id)
            base_model = AutoModelForCausalLM.from_pretrained(
                base_model_id,
                torch_dtype=torch.float16,
                device_map="auto"
            )
            # Load the PEFT adapters on top of the base model
            # (the adapters are applied, not merged into the base weights)
            model = PeftModel.from_pretrained(
                base_model,
                model_id,
                torch_dtype=torch.float16,
                device_map="auto"
            )
            return tokenizer, model
    except Exception as e:
        print(f"Error loading {model_type} model: {str(e)}")
        return None, None


# Load base model and tokenizer
base_model_id = "satyanayak/gemma-3-base"
base_tokenizer, base_model = load_model(base_model_id, "base")

# Load finetuned model and tokenizer
finetuned_model_id = "satyanayak/gemma-3-GRPO"
finetuned_tokenizer, finetuned_model = load_model(finetuned_model_id, "finetuned")


def generate_base_response(prompt, max_new_tokens=512):
    if base_model is None or base_tokenizer is None:
        return "Error: Base model failed to load. Please check if the model files are properly uploaded to Hugging Face."
    try:
        inputs = base_tokenizer(prompt, return_tensors="pt").to(base_model.device)
        outputs = base_model.generate(
            **inputs,
            # max_new_tokens bounds only the generated tokens; max_length would
            # also count the prompt and can silently cut off long inputs
            max_new_tokens=max_new_tokens,
            num_return_sequences=1,
            temperature=0.7,
            do_sample=True,
            pad_token_id=base_tokenizer.eos_token_id
        )
        # Note: the decoded output includes the prompt, so both panels show it
        response = base_tokenizer.decode(outputs[0], skip_special_tokens=True)
        return response
    except Exception as e:
        return f"Error generating response with base model: {str(e)}"


def generate_finetuned_response(prompt, max_new_tokens=512):
    if finetuned_model is None or finetuned_tokenizer is None:
        return "Error: Finetuned model failed to load. Please check if the model files are properly uploaded to Hugging Face."
    try:
        inputs = finetuned_tokenizer(prompt, return_tensors="pt").to(finetuned_model.device)
        outputs = finetuned_model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            num_return_sequences=1,
            temperature=0.7,
            do_sample=True,
            pad_token_id=finetuned_tokenizer.eos_token_id
        )
        response = finetuned_tokenizer.decode(outputs[0], skip_special_tokens=True)
        return response
    except Exception as e:
        return f"Error generating response with finetuned model: {str(e)}"


# Example prompts
examples = [
    ["What is the sqrt of 101?"],
    ["How many r's are in strawberry?"],
    ["If Tom has 3 more apples than Jerry and Jerry has 5 apples, how many apples does Tom have?"]
]

# Create the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Gemma-3 Model Comparison Demo")
    gr.Markdown("Compare responses between the base model and the GRPO-finetuned model.")

    with gr.Row():
        # Base model column
        with gr.Column(scale=1):
            gr.Markdown("## Base Model (Gemma-3)")
            base_input = gr.Textbox(
                label="Enter your prompt",
                placeholder="Type your prompt here...",
                lines=5
            )
            base_generate_btn = gr.Button("Generate with Base Model")
            base_output = gr.Textbox(label="Base Model Output", lines=10)
            gr.Examples(
                examples=examples,
                inputs=base_input,
                outputs=base_output,
                fn=generate_base_response,
                cache_examples=True
            )

        # Finetuned model column
        with gr.Column(scale=1):
            gr.Markdown("## GRPO-Finetuned Model")
            finetuned_input = gr.Textbox(
                label="Enter your prompt",
                placeholder="Type your prompt here...",
                lines=5
            )
            finetuned_generate_btn = gr.Button("Generate with Finetuned Model")
            finetuned_output = gr.Textbox(label="Finetuned Model Output", lines=10)
            gr.Examples(
                examples=examples,
                inputs=finetuned_input,
                outputs=finetuned_output,
                fn=generate_finetuned_response,
                cache_examples=True
            )

    # Connect buttons to their respective generation functions
    base_generate_btn.click(
        fn=generate_base_response,
        inputs=base_input,
        outputs=base_output
    )
    finetuned_generate_btn.click(
        fn=generate_finetuned_response,
        inputs=finetuned_input,
        outputs=finetuned_output
    )

if __name__ == "__main__":
    demo.launch()
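
# Usage note (a sketch, assuming this script is saved as app.py and that
# gradio, transformers, peft, and torch are installed):
#
#   python app.py
#
# Gradio serves the demo at http://localhost:7860 by default; calling
# demo.launch(share=True) instead provides a temporary public URL.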