import gradio as gr
from transformers import pipeline, set_seed
import re
import numpy as np
import pandas as pd

# Set a seed for reproducibility
set_seed(42)
# Define five small models for generation (free, lightweight)
small_models = [
    "distilgpt2",               # ~82M parameters
    "gpt2",                     # ~124M parameters
    "EleutherAI/gpt-neo-125M",  # ~125M parameters
    "sshleifer/tiny-gpt2",      # extremely small GPT-2 variant
    "microsoft/DialoGPT-small"  # DialoGPT small
]
# Define five languages: English, German, Spanish, French, Portuguese
languages = {
    "en": "English",
    "de": "German",
    "es": "Spanish",
    "fr": "French",
    "pt": "Portuguese"
}
# Define two cost-effective seq2seq models used to rate grammaticality
grammar_model_names = [
    "vennify/t5-base-grammar-correction",
    "hassaanik/grammar-correction-model"
]
# Functions to load pipelines on demand
def load_generation_pipeline(model_name):
    try:
        # Causal LM models use the text-generation pipeline
        return pipeline("text-generation", model=model_name)
    except Exception as e:
        print(f"Error loading generation model {model_name}: {e}")
        return None

def load_grammar_pipeline(model_name):
    try:
        return pipeline("text2text-generation", model=model_name)
    except Exception as e:
        print(f"Error loading grammar model {model_name}: {e}")
        return None
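
# For example, load_generation_pipeline("distilgpt2") downloads the model on
# first use and returns a ready transformers pipeline, or None if loading fails.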
# Pre-load grammar evaluator pipelines
rater_models = []
for model_name in grammar_model_names:
    p = load_grammar_pipeline(model_name)
    if p is not None:
        rater_models.append(p)
# Utility functions for cleaning text and checking palindromes
def clean_text(text):
    return re.sub(r'[^a-zA-Z0-9]', '', text.lower())

def is_palindrome(text):
    cleaned = clean_text(text)
    return cleaned == cleaned[::-1]
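
# Example: is_palindrome("A man, a plan, a canal: Panama!") -> True, because
# clean_text lowercases and strips non-alphanumerics, giving
# "amanaplanacanalpanama", which reads the same reversed.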
# Prompt that instructs the model to output ONLY the palindrome
def build_prompt(lang):
    return (
        f"Instruction: Write the longest original palindrome you can in {lang}. "
        "The output should contain nothing else but the palindrome. "
        "Do not include any additional commentary or repeated instructions. "
        "Palindrome: "
    )
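
# For example, build_prompt("German") returns:
# "Instruction: Write the longest original palindrome you can in German. The
# output should contain nothing else but the palindrome. Do not include any
# additional commentary or repeated instructions. Palindrome: "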
def grammar_prompt(pal, lang):
    return (
        f"Rate from 0 to 100 how grammatically correct this palindrome is in {lang}. "
        f'Only return a number with no explanation:\n\n"{pal}"\n'
    )
def extract_score(text):
    # Grab the first run of 1-3 digits and clamp it to [0, 100]
    match = re.search(r"\d{1,3}", text)
    if match:
        score = int(match.group())
        return min(max(score, 0), 100)
    return 0
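
# Examples: extract_score("Score: 87") -> 87 (first digit run found);
# extract_score("999") -> 100 after clamping; extract_score("no digits") -> 0.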
# Main benchmark function that runs all tests at once
def run_benchmark_all():
    results = []
    # Iterate over each small model
    for model_name in small_models:
        gen_pipeline = load_generation_pipeline(model_name)
        if gen_pipeline is None:
            continue  # Skip this model if it fails to load
        # Iterate over each language
        for code, lang in languages.items():
            prompt = build_prompt(lang)
            try:
                # Generate with a moderate token limit; return_full_text=False
                # strips the prompt so only the model's continuation is scored
                gen_output = gen_pipeline(
                    prompt, max_new_tokens=50, do_sample=True, return_full_text=False
                )[0]['generated_text'].strip()
            except Exception as e:
                gen_output = f"Error generating text: {e}"
            # Check whether the generated output is a palindrome
            valid = is_palindrome(gen_output)
            cleaned_len = len(clean_text(gen_output))
            # Evaluate grammar using both grammar models
            scores = []
            for rater in rater_models:
                rprompt = grammar_prompt(gen_output, lang)
                try:
                    rtext = rater(rprompt, max_new_tokens=10)[0]['generated_text']
                    scores.append(extract_score(rtext))
                except Exception:
                    scores.append(0)
            avg_score = np.mean(scores) if scores else 0
            # Weight the length by grammar; halve it if not a valid palindrome
            penalty = (avg_score / 100) if valid else (avg_score / 100) * 0.5
            final_score = round(cleaned_len * penalty, 2)
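            # Worked example: a valid 21-character palindrome with an average
            # grammar score of 80 scores 21 * 0.80 = 16.8; the same text judged
            # invalid would be halved: 21 * 0.80 * 0.5 = 8.4.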
            results.append({
                "Model": model_name,
                "Language": lang,
                "Palindrome": gen_output,
                "Valid": "✅" if valid else "❌",
                "Length": cleaned_len,
                "Grammar Score": avg_score,
                "Final Score": final_score
            })
    df = pd.DataFrame(results).sort_values(by="Final Score", ascending=False).reset_index(drop=True)
    # Return the value itself; the output component is already a gr.Dataframe
    return df
# Build the Gradio UI using Blocks
with gr.Blocks(title="Small Model Palindrome Benchmark") as demo:
    gr.Markdown("# Small Model Palindrome Benchmark")
    gr.Markdown(
        "This benchmark runs 5 small text-generation models across 5 languages "
        "(English, German, Spanish, French, Portuguese)."
    )
    with gr.Row():
        run_button = gr.Button("Run All Benchmarks")
    output_table = gr.Dataframe(label="Benchmark Results")
    run_button.click(fn=run_benchmark_all, inputs=[], outputs=output_table)

demo.launch()
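
# Run locally with `python app.py`; Gradio serves the UI at
# http://127.0.0.1:7860 by default (pass share=True to demo.launch()
# for a temporary public link).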