# leaderboard/src/evaluation.py (Hugging Face Space file view)
# Author: akera; commit d0ca936 ("Create evaluation.py", verified); 15.9 kB
# src/evaluation.py
import torch
import numpy as np
from tqdm.auto import tqdm
from sacrebleu.metrics import BLEU, CHRF
from rouge_score import rouge_scorer
import Levenshtein
from collections import defaultdict
from transformers.models.whisper.english_normalizer import BasicTextNormalizer
import salt.constants
import datetime
import os
from google.cloud import translate_v3
from config import GOOGLE_LANG_MAP
def setup_google_translate():
    """Create a Google Cloud Translation v3 client if configuration is present.

    Configuration is detected by either the ``GOOGLE_APPLICATION_CREDENTIALS``
    or the ``GOOGLE_CLOUD_PROJECT`` environment variable being non-empty.

    Returns:
        ``(client, parent)``, where ``parent`` is the resource path
        ``projects/<project_id>/locations/global``. Returns ``(None, None)``
        when no configuration is found or client creation fails. Problems
        are reported with ``print`` instead of being raised.
    """
    try:
        configured = (
            os.getenv("GOOGLE_APPLICATION_CREDENTIALS")
            or os.getenv("GOOGLE_CLOUD_PROJECT")
        )
        if not configured:
            print("Google Cloud credentials not found. Google Translate will not be available.")
            return None, None

        translate_client = translate_v3.TranslationServiceClient()
        # Default project ID used when GOOGLE_CLOUD_PROJECT is unset.
        project = os.getenv("GOOGLE_CLOUD_PROJECT", "sb-gcp-project-01")
        return translate_client, f"projects/{project}/locations/global"
    except Exception as err:
        print(f"Error setting up Google Translate: {err}")
        return None, None
# Google language codes (after GOOGLE_LANG_MAP mapping) that
# google_translate_batch sends to the API by default.
_DEFAULT_GOOGLE_SUPPORTED_LANGS = frozenset({'lg', 'ach', 'sw', 'en'})


def google_translate_batch(texts, source_langs, target_langs, client, parent,
                           supported_langs=None):
    """Translate texts one request at a time with the Google Cloud Translation API.

    Args:
        texts: Source sentences.
        source_langs: Per-text SALT source language codes. They are mapped to
            Google codes via ``GOOGLE_LANG_MAP``; unmapped codes pass through
            unchanged.
        target_langs: Per-text SALT target language codes, mapped the same way.
        client: Translation client, e.g. the first value returned by
            ``setup_google_translate``.
        parent: Resource path ``projects/<id>/locations/global``.
        supported_langs: Optional iterable of Google language codes to allow.
            Defaults to ``lg``, ``ach``, ``sw`` and ``en``.

    Returns:
        A list of strings, one per input and in input order.
        Unsupported pairs yield ``"[UNSUPPORTED: src->tgt]"``.
        Failed requests yield ``"[ERROR: <first 50 chars of the error>]"``.
        Errors are printed, not raised, so one bad item does not abort the batch.
    """
    # Built once, not once per item as before.
    allowed = (_DEFAULT_GOOGLE_SUPPORTED_LANGS if supported_langs is None
               else frozenset(supported_langs))

    translations = []
    for text, src_lang, tgt_lang in tqdm(zip(texts, source_langs, target_langs),
                                         total=len(texts), desc="Google Translate"):
        try:
            # Map SALT language codes to Google's format.
            src_google = GOOGLE_LANG_MAP.get(src_lang, src_lang)
            tgt_google = GOOGLE_LANG_MAP.get(tgt_lang, tgt_lang)

            if src_google not in allowed or tgt_google not in allowed:
                translations.append(f"[UNSUPPORTED: {src_lang}->{tgt_lang}]")
                continue

            request = {
                "parent": parent,
                "contents": [text],
                "mime_type": "text/plain",
                "source_language_code": src_google,
                "target_language_code": tgt_google,
            }
            response = client.translate_text(request=request)
            translations.append(response.translations[0].translated_text)
        except Exception as e:
            # Deliberate best-effort: record a placeholder and keep going.
            print(f"Error translating '{text}': {e}")
            translations.append(f"[ERROR: {str(e)[:50]}]")
    return translations
def _wrapped_model_type_name(model):
    """Return ``str(type(model.base_model.model))``, or ``''`` if that chain is absent."""
    # NOTE(review): this chain presumably targets PEFT-style wrappers. An
    # unwrapped *ForCausalLM exposes no ``base_model.model``, so it does not
    # match here and falls through to later branches. Confirm that is intended.
    if hasattr(model, 'base_model') and hasattr(model.base_model, 'model'):
        return str(type(model.base_model.model))
    return ''


def get_translation_function(model, tokenizer, model_path):
    """Select a translation backend for ``model``.

    Dispatch order (first match wins):
      1. ``model_path == 'google-translate'``: Google Cloud Translation API.
      2. ``'gemma'`` in the model's type name or in ``model_path``: Gemma chat prompt.
      3. Wrapped model type contains ``'Qwen2ForCausalLM'``: Qwen chat prompt.
      4. ``'m2m_100'`` in the model's type name: NLLB.
      5. Wrapped model type contains ``'LlamaForCausalLM'``: Llama chat prompt.
      6. Otherwise: generic plain-prompt generation.

    Returns:
        Callable ``translation_fn(texts, from_langs, to_langs) -> list[str]``.

    Raises:
        RuntimeError: ``'google-translate'`` was requested but no client could
            be created. RuntimeError subclasses Exception, so existing
            ``except Exception`` handlers still catch it.
    """
    if model_path == 'google-translate':
        client, parent = setup_google_translate()
        if client is None:
            raise RuntimeError("Google Translate credentials not available")

        def translation_fn(texts, from_langs, to_langs):
            return google_translate_batch(texts, from_langs, to_langs, client, parent)

        return translation_fn

    model_type = str(type(model)).lower()
    if 'gemma' in model_type or 'gemma' in model_path.lower():
        return get_gemma_translation_fn(model, tokenizer)

    wrapped_type = _wrapped_model_type_name(model)
    if 'Qwen2ForCausalLM' in wrapped_type:
        return get_qwen_translation_fn(model, tokenizer)
    if 'm2m_100' in model_type:
        return get_nllb_translation_fn(model, tokenizer)
    if 'LlamaForCausalLM' in wrapped_type:
        return get_llama_translation_fn(model, tokenizer)
    # Generic function for other models.
    return get_generic_translation_fn(model, tokenizer)
def get_gemma_translation_fn(model, tokenizer, batch_size=4):
    """Build a batched chat-prompt translation function for Gemma models.

    Args:
        model: Generation model. Inputs are moved to the device of its first
            parameter.
        tokenizer: Tokenizer with ``apply_chat_template``. Prompts are
            left-padded and truncated to 512 tokens.
        batch_size: Prompts per ``generate`` call. Defaults to 4, the
            previously hard-coded value.

    Returns:
        ``translation_fn(texts, from_langs, to_langs) -> list[str]``, which
        yields one translation per input in input order. Language codes are
        rendered through ``salt.constants.SALT_LANGUAGE_NAMES``. Codes missing
        from that table are used verbatim; this previously raised KeyError.
        The fallback matches the Qwen and Llama variants.

    Raises:
        ValueError: ``batch_size`` is less than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    system_message = 'You are a linguist and translation assistant specialising in Ugandan languages.'

    def translation_fn(texts, from_langs, to_langs):
        language_names = salt.constants.SALT_LANGUAGE_NAMES
        translations = []
        device = next(model.parameters()).device
        instructions = [
            f'Translate from {language_names.get(from_lang, from_lang)} '
            f'to {language_names.get(to_lang, to_lang)}: {text}'
            for text, from_lang, to_lang in zip(texts, from_langs, to_langs)
        ]
        for start in tqdm(range(0, len(instructions), batch_size), desc="Generating translations"):
            prompts = [
                tokenizer.apply_chat_template(
                    [
                        {"role": "system", "content": system_message},
                        {"role": "user", "content": instruction},
                    ],
                    tokenize=False, add_generation_prompt=True,
                )
                for instruction in instructions[start:start + batch_size]
            ]
            inputs = tokenizer(
                prompts, return_tensors="pt",
                padding=True, padding_side='left',
                max_length=512, truncation=True,
            ).to(device)
            with torch.no_grad():
                # NOTE(review): do_sample=True combined with num_beams=5 is
                # beam-search sampling. Scores can vary between runs unless a
                # seed is fixed.
                outputs = model.generate(
                    **inputs,
                    max_new_tokens=100,
                    temperature=0.5,
                    num_beams=5,
                    do_sample=True,
                    no_repeat_ngram_size=5,
                    pad_token_id=tokenizer.eos_token_id,
                )
            # Left padding aligns every prompt to end at the same column, so
            # the generated text starts right after the padded prompt width.
            prompt_len = inputs['input_ids'].shape[1]
            translations.extend(
                tokenizer.decode(sequence[prompt_len:], skip_special_tokens=True)
                for sequence in outputs
            )
        return translations

    return translation_fn
def get_qwen_translation_fn(model, tokenizer):
    """Return a batched chat-prompt translator for Qwen models.

    The returned callable ``translation_fn(texts, from_langs, to_langs)``
    works as follows:
      * It renders each item as ``"Translate from <src> to <tgt>: <text>"``,
        using ``salt.constants.SALT_LANGUAGE_NAMES`` and falling back to the
        raw code.
      * It wraps each instruction in a system + user chat template.
      * It generates in groups of 8 with left padding.
      * It returns one decoded continuation per input, in input order.
    """
    def translation_fn(texts, from_langs, to_langs):
        system_prompt = 'You are a Ugandan language assistant.'
        names = salt.constants.SALT_LANGUAGE_NAMES
        group_size = 8
        model_device = next(model.parameters()).device

        requests = [
            f'Translate from {names.get(src, src)} to {names.get(tgt, tgt)}: {sentence}'
            for sentence, src, tgt in zip(texts, from_langs, to_langs)
        ]

        results = []
        for offset in tqdm(range(0, len(requests), group_size), desc="Generating translations"):
            chat_prompts = []
            for request in requests[offset:offset + group_size]:
                conversation = [
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": request},
                ]
                chat_prompts.append(
                    tokenizer.apply_chat_template(
                        conversation, tokenize=False, add_generation_prompt=True
                    )
                )

            encoded = tokenizer(
                chat_prompts,
                return_tensors="pt",
                padding=True,
                padding_side='left',
                truncation=True,
            ).to(model_device)

            with torch.no_grad():
                generated = model.generate(
                    **encoded,
                    max_new_tokens=100,
                    temperature=0.01,
                    pad_token_id=tokenizer.eos_token_id,
                )

            # Everything after the (left-padded) prompt width is new text.
            prompt_width = encoded['input_ids'].shape[1]
            results.extend(
                tokenizer.decode(row[prompt_width:], skip_special_tokens=True)
                for row in generated
            )
        return results

    return translation_fn
def get_nllb_translation_fn(model, tokenizer):
    """Return a per-sentence translator for NLLB (M2M-100-type) models.

    The returned callable ``translation_fn(texts, source_langs, target_langs)``
    encodes each sentence separately. It sets the first input token to the
    source-language token id from
    ``salt.constants.SALT_LANGUAGE_TOKENS_NLLB_TRANSLATION`` and forces the
    target-language token as the first generated token. Decoding uses beam
    search (5 beams, max_length 100). One string is returned per input.
    """
    def translation_fn(texts, source_langs, target_langs):
        lang_token_ids = salt.constants.SALT_LANGUAGE_TOKENS_NLLB_TRANSLATION
        target_device = next(model.parameters()).device

        outputs = []
        progress = tqdm(
            zip(texts, source_langs, target_langs),
            total=len(texts),
            desc="NLLB Translation",
        )
        for sentence, src, tgt in progress:
            encoded = tokenizer(sentence, return_tensors="pt").to(target_device)
            # Position 0 is overwritten with this example's source-language
            # token, presumably replacing the tokenizer's default language
            # tag. Confirm for the tokenizer in use.
            encoded['input_ids'][0][0] = lang_token_ids[src]
            with torch.no_grad():
                generated = model.generate(
                    **encoded,
                    forced_bos_token_id=lang_token_ids[tgt],
                    max_length=100,
                    num_beams=5,
                )
            decoded = tokenizer.batch_decode(generated, skip_special_tokens=True)
            outputs.append(decoded[0])
        return outputs

    return translation_fn
def get_llama_translation_fn(model, tokenizer, batch_size=8):
    """Build a batched chat-prompt translation function for Llama models.

    Args:
        model: Generation model. Inputs are moved to the device of its first
            parameter.
        tokenizer: Tokenizer with ``apply_chat_template``. It receives
            today's date as ``date_string``. If it has no pad token, its
            ``pad_token`` is set to ``eos_token`` on first call.
        batch_size: Prompts per ``generate`` call. Defaults to 8, the
            previously hard-coded value.

    Returns:
        ``translation_fn(texts, from_langs, to_langs) -> list[str]``, which
        yields one translation per input in input order. Language codes are
        rendered through ``salt.constants.SALT_LANGUAGE_NAMES``, falling back
        to the raw code.

    Raises:
        ValueError: ``batch_size`` is less than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    def translation_fn(texts, from_langs, to_langs):
        DATE_TODAY = datetime.datetime.now().strftime("%d %b %Y")
        SYSTEM_MESSAGE = ''
        # Llama tokenizers may define no pad token. In that case the batched
        # ``padding=True`` call below raises. Reuse EOS, which is also what
        # ``generate`` is told to pad with.
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        translations = []
        device = next(model.parameters()).device
        instructions = [
            f'Translate from {salt.constants.SALT_LANGUAGE_NAMES.get(from_lang, from_lang)} '
            f'to {salt.constants.SALT_LANGUAGE_NAMES.get(to_lang, to_lang)}: {text}'
            for text, from_lang, to_lang in zip(texts, from_langs, to_langs)
        ]
        for i in tqdm(range(0, len(instructions), batch_size), desc="Llama Translation"):
            prompts = [
                tokenizer.apply_chat_template(
                    [
                        {"role": "system", "content": SYSTEM_MESSAGE},
                        {"role": "user", "content": instruction},
                    ],
                    tokenize=False, add_generation_prompt=True,
                    date_string=DATE_TODAY,
                )
                for instruction in instructions[i:i + batch_size]
            ]
            inputs = tokenizer(
                prompts, return_tensors="pt",
                padding=True, padding_side='left',
            ).to(device)
            with torch.no_grad():
                outputs = model.generate(
                    **inputs, max_new_tokens=100,
                    temperature=0.01,
                    pad_token_id=tokenizer.eos_token_id,
                )
            # Generated text follows the left-padded prompt.
            prompt_len = inputs['input_ids'].shape[1]
            for sequence in outputs:
                translations.append(
                    tokenizer.decode(sequence[prompt_len:], skip_special_tokens=True)
                )
        return translations

    return translation_fn
def get_generic_translation_fn(model, tokenizer):
    """Fallback translator for model types without a dedicated prompt format.

    Each item is generated on its own from the plain prompt
    ``"Translate from <from_lang> to <to_lang>: <text>"``. Language codes are
    inserted verbatim.

    Handling of the generated sequence depends on the model type:
      * Decoder-only models: the output begins with the prompt, so the
        prompt tokens are sliced off before decoding.
      * Encoder-decoder models (``model.config.is_encoder_decoder``): the
        output does not contain the prompt and is decoded whole. Slicing it
        previously discarded real translation tokens.

    Returns:
        ``translation_fn(texts, from_langs, to_langs) -> list[str]``.
    """
    is_encoder_decoder = bool(
        getattr(getattr(model, 'config', None), 'is_encoder_decoder', False)
    )

    def translation_fn(texts, from_langs, to_langs):
        translations = []
        device = next(model.parameters()).device
        total = len(texts) if hasattr(texts, '__len__') else None
        for text, from_lang, to_lang in tqdm(zip(texts, from_langs, to_langs),
                                             total=total, desc="Generic Translation"):
            prompt = f"Translate from {from_lang} to {to_lang}: {text}"
            inputs = tokenizer(prompt, return_tensors="pt").to(device)
            with torch.no_grad():
                outputs = model.generate(
                    **inputs,
                    max_new_tokens=100,
                    temperature=0.7,
                    pad_token_id=tokenizer.eos_token_id,
                )
            if is_encoder_decoder:
                new_tokens = outputs[0]
            else:
                new_tokens = outputs[0, inputs['input_ids'].shape[1]:]
            translations.append(tokenizer.decode(new_tokens, skip_special_tokens=True))
        return translations

    return translation_fn
def calculate_metrics(reference: str, prediction: str) -> dict:
    """Score a single prediction against a single reference.

    Returned keys, in insertion order:
      bleu       sentence BLEU with ``effective_order=True`` (0-100 scale)
      chrf       chrF divided by 100
      cer        character edit distance / len(reference); can exceed 1
      wer        word edit distance / number of reference words; can exceed 1
      len_ratio  len(prediction) / len(reference), in characters
      rouge1, rouge2, rougeL
                 ROUGE F-measures; absent if ROUGE scoring raises
      quality_score
                 mean of bleu/100, chrf, 1-cer, 1-wer, rouge1 and rougeL;
                 only the first four are averaged when ROUGE raises

    Length denominators are clamped to at least 1, so an empty reference does
    not divide by zero.
    """
    bleu_score = BLEU(effective_order=True).sentence_score(prediction, [reference]).score
    chrf_score = CHRF().sentence_score(prediction, [reference]).score / 100.0

    ref_char_count = max(len(reference), 1)
    cer = Levenshtein.distance(reference, prediction) / ref_char_count

    ref_tokens, pred_tokens = reference.split(), prediction.split()
    wer = Levenshtein.distance(ref_tokens, pred_tokens) / max(len(ref_tokens), 1)

    scores = {
        "bleu": bleu_score,
        "chrf": chrf_score,
        "cer": cer,
        "wer": wer,
        "len_ratio": len(prediction) / ref_char_count,
    }
    base_terms = [bleu_score / 100, chrf_score, 1 - cer, 1 - wer]

    try:
        rouge = rouge_scorer.RougeScorer(
            ['rouge1', 'rouge2', 'rougeL'], use_stemmer=True
        ).score(reference, prediction)
        for key in ('rouge1', 'rouge2', 'rougeL'):
            scores[key] = rouge[key].fmeasure
        scores["quality_score"] = (
            sum(base_terms) + rouge['rouge1'].fmeasure + rouge['rougeL'].fmeasure
        ) / 6
    except Exception as e:
        print(f"Error calculating ROUGE metrics: {e}")
        scores["quality_score"] = sum(base_terms) / 4

    return scores
def evaluate_model_full(model, tokenizer, model_path: str, test_data) -> dict:
    """Translate ``test_data`` and score the predictions per language direction.

    Args:
        model: Model passed to ``get_translation_function``.
        tokenizer: Tokenizer passed to ``get_translation_function``.
        model_path: Model identifier. ``'google-translate'`` selects the
            Cloud Translation API backend.
        test_data: Table accessed via ``test_data[column]`` and ``iterrows()``
            (presumably a pandas DataFrame; confirm). Required columns are
            ``source``, ``target``, ``source.language`` and ``target.language``.

    Returns:
        ``{'<src>_to_<tgt>': {metric: mean, ...}, ..., 'averages': {...}}``.
        Each direction holds the mean of ``calculate_metrics`` over its rows,
        computed after ``BasicTextNormalizer``. ``averages`` is the unweighted
        mean over directions and is omitted when there are no directions.

    Raises:
        ValueError: The translation function returned a different number of
            predictions than ``test_data`` has rows.
    """
    translation_fn = get_translation_function(model, tokenizer, model_path)

    print("Generating translations...")
    predictions = translation_fn(
        list(test_data['source']),
        list(test_data['source.language']),
        list(test_data['target.language']),
    )
    if len(predictions) != len(test_data):
        raise ValueError(
            f"Expected {len(test_data)} predictions, got {len(predictions)}"
        )

    print("Calculating metrics...")
    translation_subsets = defaultdict(list)
    # Predictions are positional, so pair them with rows by position rather
    # than by the row's index label. A filtered or re-indexed DataFrame has
    # labels other than 0..n-1, which previously misaligned or raised
    # IndexError.
    for prediction, (_, row) in zip(predictions, test_data.iterrows()):
        direction = row['source.language'] + '_to_' + row['target.language']
        row_dict = dict(row)
        row_dict['prediction'] = prediction
        translation_subsets[direction].append(row_dict)

    normalizer = BasicTextNormalizer()
    per_direction = {}
    for direction, examples in translation_subsets.items():
        values_by_metric = defaultdict(list)
        for example in examples:
            prediction = normalizer(str(example['prediction']))
            reference = normalizer(example['target'])
            for name, value in calculate_metrics(reference, prediction).items():
                values_by_metric[name].append(value)
        per_direction[direction] = {
            name: float(np.mean(values)) for name, values in values_by_metric.items()
        }

    grouped_metrics = dict(per_direction)
    # Collect metric names from all directions, in first-seen order. A metric
    # missing from the first direction (e.g. ROUGE failed there) is then
    # still averaged.
    metric_names = list(dict.fromkeys(
        name for metrics in per_direction.values() for name in metrics
    ))
    if metric_names:
        grouped_metrics['averages'] = {
            name: float(np.mean([
                metrics[name] for metrics in per_direction.values() if name in metrics
            ]))
            for name in metric_names
        }
    return grouped_metrics