# palindroms / app.py
import gradio as gr
from transformers import pipeline, set_seed
import re
import numpy as np
import pandas as pd
# Set a seed for reproducibility
set_seed(42)
# Define five small models for generation (free, lightweight)
small_models = [
"distilgpt2", # ~82M parameters
"gpt2", # ~124M parameters
"EleutherAI/gpt-neo-125M", # ~125M parameters
"sshleifer/tiny-gpt2", # extremely small variant
"microsoft/DialoGPT-small" # DialoGPT small
]
# Define five languages: English, German, Spanish, French, Portuguese
languages = {
"en": "English",
"de": "German",
"es": "Spanish",
"fr": "French",
"pt": "Portuguese"
}
# Define two cost-effective grammar evaluation models
grammar_model_names = [
"vennify/t5-base-grammar-correction",
"hassaanik/grammar-correction-model"
]
# Functions to load pipelines on demand
def load_generation_pipeline(model_name):
try:
# Using text-generation pipeline for causal LM models
return pipeline("text-generation", model=model_name)
except Exception as e:
print(f"Error loading generation model {model_name}: {e}")
return None
def load_grammar_pipeline(model_name):
try:
return pipeline("text2text-generation", model=model_name)
except Exception as e:
print(f"Error loading grammar model {model_name}: {e}")
return None
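# Optional sketch (not wired into the benchmark below): cache loaded pipelines so
# repeated runs don't re-initialise the same model. `cached_generation_pipeline`
# is a hypothetical helper name, not part of the original app.
from functools import lru_cache
@lru_cache(maxsize=None)
def cached_generation_pipeline(model_name):
    # Note: a load failure returns None, and that None would be cached too.
    return load_generation_pipeline(model_name)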
# Pre-load grammar evaluator pipelines
rater_models = []
for model_name in grammar_model_names:
p = load_grammar_pipeline(model_name)
if p is not None:
rater_models.append(p)
# Utility functions for checking palindromes and cleaning text
def clean_text(text):
    # Keep Unicode letters and digits (accented characters matter for the
    # non-English languages) and compare caselessly via casefold()
    return re.sub(r'[\W_]', '', text.casefold())
def is_palindrome(text):
cleaned = clean_text(text)
return cleaned == cleaned[::-1]
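# Illustrative sanity checks for the two helpers above (a sketch; nothing calls
# this at runtime):
def _palindrome_smoke_test():
    assert clean_text("A man, a plan, a canal: Panama") == "amanaplanacanalpanama"
    assert is_palindrome("A man, a plan, a canal: Panama")
    assert not is_palindrome("Hello, world")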
# Prompt that instructs the model to output ONLY the palindrome.
def build_prompt(lang):
return (
f"Instruction: Write the longest original palindrome you can in {lang}. "
"The output should contain nothing else but the palindrome. "
"Do not include any additional commentary or repeated instructions. "
"Palindrome: "
)
def grammar_prompt(pal, lang):
    # Note: the raters are grammar-correction models, so the number pulled from
    # their output is a rough heuristic rather than a calibrated rating.
    return f'''Rate from 0 to 100 how grammatically correct this palindrome is in {lang}. Only return a number with no explanation:\n\n"{pal}"\n'''
def extract_score(text):
match = re.search(r"\d{1,3}", text)
if match:
score = int(match.group())
return min(max(score, 0), 100)
return 0
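# Illustrative behaviour of extract_score (a sketch; nothing calls this at runtime):
def _extract_score_smoke_test():
    assert extract_score("Score: 85") == 85
    assert extract_score("120") == 100   # out-of-range values are clamped to 0-100
    assert extract_score("no digits here") == 0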
# Main benchmark function that runs all tests at once
def run_benchmark_all():
results = []
# Iterate over each small model
for model_name in small_models:
gen_pipeline = load_generation_pipeline(model_name)
if gen_pipeline is None:
continue # Skip this model if it fails to load
# Iterate over each language
for code, lang in languages.items():
prompt = build_prompt(lang)
            try:
                # Generate a continuation only; return_full_text=False keeps the
                # prompt itself out of the scored output.
                gen_output = gen_pipeline(
                    prompt, max_new_tokens=50, do_sample=True, return_full_text=False
                )[0]['generated_text'].strip()
            except Exception as e:
                gen_output = f"Error generating text: {e}"
# Check if the generated output is a palindrome
valid = is_palindrome(gen_output)
cleaned_len = len(clean_text(gen_output))
# Evaluate grammar using both grammar models
scores = []
for rater in rater_models:
rprompt = grammar_prompt(gen_output, lang)
try:
rtext = rater(rprompt, max_new_tokens=10)[0]['generated_text']
score = extract_score(rtext)
scores.append(score)
                except Exception:
                    # A failed rater counts as 0 rather than aborting the run
                    scores.append(0)
avg_score = np.mean(scores) if scores else 0
# Penalize if the generated text is not a valid palindrome
penalty = (avg_score / 100) if valid else (avg_score / 100) * 0.5
final_score = round(cleaned_len * penalty, 2)
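            # Worked example with illustrative numbers: a valid palindrome of
            # cleaned length 21 with avg_score 80 gets 21 * 0.80 = 16.8; if judged
            # invalid, the 0.5 penalty halves that to 8.4.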
results.append({
"Model": model_name,
"Language": lang,
"Palindrome": gen_output,
"Valid": "✅" if valid else "❌",
"Length": cleaned_len,
"Grammar Score": avg_score,
"Final Score": final_score
})
    df = pd.DataFrame(results).sort_values(by="Final Score", ascending=False).reset_index(drop=True)
    # Return the DataFrame itself; Gradio renders it in the output Dataframe component
    return df
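# Optional headless helper (a sketch; the UI does not use it): run the benchmark
# once and persist the results. `save_benchmark_csv` and the default path are
# assumptions, not part of the original app.
def save_benchmark_csv(path="benchmark_results.csv"):
    df = run_benchmark_all()
    df.to_csv(path, index=False)
    return path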
# Build the Gradio UI using Blocks (canvas layout)
with gr.Blocks(title="Small Model Palindrome Benchmark") as demo:
gr.Markdown("# Small Model Palindrome Benchmark")
gr.Markdown("This benchmark automatically runs over 5 small text-generation models and 5 languages (English, German, Spanish, French, Portuguese).")
with gr.Row():
run_button = gr.Button("Run All Benchmarks")
output_table = gr.Dataframe(label="Benchmark Results")
run_button.click(fn=run_benchmark_all, inputs=[], outputs=output_table)
demo.launch()
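# When running outside Hugging Face Spaces, a public link can be requested with
# demo.launch(share=True) instead of the bare launch() above.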