import gradio as gr
from transformers import pipeline, set_seed
import re
import numpy as np
import pandas as pd

# Set a seed for reproducibility
set_seed(42)

# Define five small models for generation (free, lightweight)
small_models = [
    "distilgpt2",                    # ~82M parameters
    "gpt2",                          # ~124M parameters
    "EleutherAI/gpt-neo-125M",         # ~125M parameters
    "sshleifer/tiny-gpt2",           # extremely small variant
    "microsoft/DialoGPT-small"       # DialoGPT small
]

# Define five languages: English, German, Spanish, French, Portuguese
languages = {
    "en": "English",
    "de": "German",
    "es": "Spanish",
    "fr": "French",
    "pt": "Portuguese"
}

# Define two cost-effective grammar evaluation models
grammar_model_names = [
    "vennify/t5-base-grammar-correction",
    "hassaanik/grammar-correction-model"
]

# Functions to load pipelines on demand
def load_generation_pipeline(model_name):
    try:
        # Using text-generation pipeline for causal LM models
        return pipeline("text-generation", model=model_name)
    except Exception as e:
        print(f"Error loading generation model {model_name}: {e}")
        return None

def load_grammar_pipeline(model_name):
    try:
        return pipeline("text2text-generation", model=model_name)
    except Exception as e:
        print(f"Error loading grammar model {model_name}: {e}")
        return None

# Pre-load grammar evaluator pipelines
rater_models = []
for model_name in grammar_model_names:
    p = load_grammar_pipeline(model_name)
    if p is not None:
        rater_models.append(p)

# Utility functions for checking palindromes and cleaning text
def clean_text(text):
    return re.sub(r'[^a-zA-Z0-9]', '', text.lower())

def is_palindrome(text):
    cleaned = clean_text(text)
    return cleaned == cleaned[::-1]
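# Illustrative check (example strings are assumptions, not part of the benchmark run):
# is_palindrome("A man, a plan, a canal: Panama!") -> True, since clean_text()
# reduces it to "amanaplanacanalpanama"; is_palindrome("Hello world") -> False.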

# Prompt that instructs the model to output ONLY the palindrome.
def build_prompt(lang):
    return (
        f"Instruction: Write the longest original palindrome you can in {lang}. "
        "The output should contain nothing else but the palindrome. "
        "Do not include any additional commentary or repeated instructions. "
        "Palindrome: "
    )

def grammar_prompt(pal, lang):
    return f'''Rate from 0 to 100 how grammatically correct this palindrome is in {lang}. Only return a number with no explanation:\n\n"{pal}"\n'''

def extract_score(text):
    match = re.search(r"\d{1,3}", text)
    if match:
        score = int(match.group())
        return min(max(score, 0), 100)
    return 0
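# Illustrative examples (inputs are assumptions, not outputs of the rater models):
#   extract_score("87")              -> 87
#   extract_score("Score: 250/100")  -> 100 (clamped to the 0-100 range)
#   extract_score("no digits here")  -> 0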

# Main benchmark function that runs all tests at once
def run_benchmark_all():
    results = []
    # Iterate over each small model
    for model_name in small_models:
        gen_pipeline = load_generation_pipeline(model_name)
        if gen_pipeline is None:
            continue  # Skip this model if it fails to load

        # Iterate over each language
        for code, lang in languages.items():
            prompt = build_prompt(lang)
            try:
                # Generate text with a moderate max token limit.
                # return_full_text=False keeps only the model's continuation,
                # so the prompt itself is not scored as part of the palindrome.
                gen_output = gen_pipeline(
                    prompt, max_new_tokens=50, do_sample=True, return_full_text=False
                )[0]['generated_text'].strip()
            except Exception as e:
                gen_output = f"Error generating text: {e}"
            # Check if the generated output is a palindrome
            valid = is_palindrome(gen_output)
            cleaned_len = len(clean_text(gen_output))
            
            # Evaluate grammar using both grammar models
            scores = []
            for rater in rater_models:
                rprompt = grammar_prompt(gen_output, lang)
                try:
                    rtext = rater(rprompt, max_new_tokens=10)[0]['generated_text']
                    score = extract_score(rtext)
                    scores.append(score)
                except Exception:
                    # Treat a failed rating as the minimum score
                    scores.append(0)
            avg_score = np.mean(scores) if scores else 0
            # Weight the cleaned length by the average grammar score (scaled to 0-1);
            # halve the weight if the output is not a valid palindrome.
            penalty = (avg_score / 100) if valid else (avg_score / 100) * 0.5
            final_score = round(cleaned_len * penalty, 2)
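            # Worked example: a valid 20-character palindrome with an average grammar
            # score of 80 scores 20 * 0.80 = 16.0; the same output failing the
            # palindrome check would score 20 * 0.80 * 0.5 = 8.0.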
            
            results.append({
                "Model": model_name,
                "Language": lang,
                "Palindrome": gen_output,
                "Valid": "✅" if valid else "❌",
                "Length": cleaned_len,
                "Grammar Score": avg_score,
                "Final Score": final_score
            })
    
    df = pd.DataFrame(results).sort_values(by="Final Score", ascending=False).reset_index(drop=True)
    return df

# Build the Gradio UI using Blocks
with gr.Blocks(title="Small Model Palindrome Benchmark") as demo:
    gr.Markdown("# Small Model Palindrome Benchmark")
    gr.Markdown("This benchmark automatically runs over 5 small text-generation models and 5 languages (English, German, Spanish, French, Portuguese).")
    
    with gr.Row():
        run_button = gr.Button("Run All Benchmarks")
    output_table = gr.Dataframe(label="Benchmark Results")
    
    run_button.click(fn=run_benchmark_all, inputs=[], outputs=output_table)

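# Note: the benchmark loads several models sequentially and can take a while;
# calling demo.queue() before demo.launch() is one optional way to keep the
# interface responsive during long runs.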
demo.launch()