Spaces:
Sleeping
Sleeping
# src/evaluation.py | |
import torch | |
import numpy as np | |
from tqdm.auto import tqdm | |
from sacrebleu.metrics import BLEU, CHRF | |
from rouge_score import rouge_scorer | |
import Levenshtein | |
from collections import defaultdict | |
from transformers.models.whisper.english_normalizer import BasicTextNormalizer | |
import salt.constants | |
import datetime | |
import os | |
from google.cloud import translate_v3 | |
from config import GOOGLE_LANG_MAP | |
def setup_google_translate():
    """Create a Google Cloud Translation client when credentials are configured.

    Credentials are considered present when either GOOGLE_APPLICATION_CREDENTIALS
    or GOOGLE_CLOUD_PROJECT is set (non-empty) in the environment.

    Returns:
        tuple: ``(client, parent)`` where ``parent`` is the
        ``projects/<project_id>/locations/global`` resource path, or
        ``(None, None)`` when credentials are missing or setup raises.
    """
    try:
        credentials_present = (
            os.getenv("GOOGLE_APPLICATION_CREDENTIALS")
            or os.getenv("GOOGLE_CLOUD_PROJECT")
        )
        if not credentials_present:
            print("Google Cloud credentials not found. Google Translate will not be available.")
            return None, None
        translation_client = translate_v3.TranslationServiceClient()
        # Falls back to a fixed project id when only the credentials file is set.
        project = os.getenv("GOOGLE_CLOUD_PROJECT", "sb-gcp-project-01")
        return translation_client, f"projects/{project}/locations/global"
    except Exception as exc:
        print(f"Error setting up Google Translate: {exc}")
        return None, None
def google_translate_batch(texts, source_langs, target_langs, client, parent,
                           supported_langs=None):
    """Translate texts one request at a time with the Google Cloud Translation API.

    Args:
        texts: Source sentences.
        source_langs: Source language codes (SALT codes), aligned with ``texts``.
        target_langs: Target language codes (SALT codes), aligned with ``texts``.
        client: A ``translate_v3.TranslationServiceClient``.
        parent: Resource path, e.g. ``projects/<id>/locations/global``.
        supported_langs: Optional iterable of Google language codes that may be
            sent to the API. Defaults to ``('lg', 'ach', 'sw', 'en')``.

    Returns:
        list[str]: One entry per input, in order. Unsupported pairs yield
        ``"[UNSUPPORTED: src->tgt]"`` and failed requests yield
        ``"[ERROR: <first 50 chars of the error>]"``, so the output stays
        aligned with the inputs.
    """
    # Built once, not per item; a frozenset gives O(1) membership checks.
    allowed_langs = frozenset(
        supported_langs if supported_langs is not None else ('lg', 'ach', 'sw', 'en')
    )
    translations = []
    for text, src_lang, tgt_lang in tqdm(zip(texts, source_langs, target_langs),
                                         total=len(texts), desc="Google Translate"):
        try:
            # Map SALT language codes to Google's format (identity if unmapped).
            src_google = GOOGLE_LANG_MAP.get(src_lang, src_lang)
            tgt_google = GOOGLE_LANG_MAP.get(tgt_lang, tgt_lang)
            if src_google not in allowed_langs or tgt_google not in allowed_langs:
                translations.append(f"[UNSUPPORTED: {src_lang}->{tgt_lang}]")
                continue
            request = {
                "parent": parent,
                "contents": [text],
                "mime_type": "text/plain",
                "source_language_code": src_google,
                "target_language_code": tgt_google,
            }
            response = client.translate_text(request=request)
            translations.append(response.translations[0].translated_text)
        except Exception as e:
            # Best-effort: record a placeholder so outputs stay index-aligned.
            print(f"Error translating '{text}': {e}")
            translations.append(f"[ERROR: {str(e)[:50]}]")
    return translations
def get_translation_function(model, tokenizer, model_path):
    """Select the translation function matching ``model`` / ``model_path``.

    Every returned callable has the form
    ``translation_fn(texts, source_langs, target_langs) -> list[str]``.
    Checks run in this order: the ``'google-translate'`` sentinel path, Gemma
    (by model class name or path), a wrapped ``Qwen2ForCausalLM``,
    M2M100/NLLB, a wrapped ``LlamaForCausalLM``, then a generic fallback.

    Raises:
        RuntimeError: ``model_path == 'google-translate'`` but no Google Cloud
            client could be created. RuntimeError subclasses Exception, so
            callers that catch Exception are unaffected.
    """
    if model_path == 'google-translate':
        client, parent = setup_google_translate()
        if client is None:
            raise RuntimeError("Google Translate credentials not available")

        def translation_fn(texts, from_langs, to_langs):
            return google_translate_batch(texts, from_langs, to_langs, client, parent)

        return translation_fn

    model_type = str(type(model)).lower()
    # Tolerate a missing (None) or non-str (e.g. Path) model_path instead of
    # failing on .lower().
    path = str(model_path or '').lower()

    if 'gemma' in model_type or 'gemma' in path:
        return get_gemma_translation_fn(model, tokenizer)

    # Wrapped models (presumably PEFT) expose the underlying transformers model
    # at model.base_model.model; its class name identifies the architecture.
    wrapped_type = ''
    if hasattr(model, 'base_model') and hasattr(model.base_model, 'model'):
        wrapped_type = str(type(model.base_model.model))

    if 'Qwen2ForCausalLM' in wrapped_type:
        return get_qwen_translation_fn(model, tokenizer)
    if 'm2m_100' in model_type:
        return get_nllb_translation_fn(model, tokenizer)
    if 'LlamaForCausalLM' in wrapped_type:
        return get_llama_translation_fn(model, tokenizer)
    # Generic function for other models.
    return get_generic_translation_fn(model, tokenizer)
def get_gemma_translation_fn(model, tokenizer):
    """Build a batched chat-prompt translation function for Gemma models.

    The returned ``translation_fn(texts, from_langs, to_langs)`` wraps each item
    as a ``"Translate from X to Y: text"`` user message under a fixed system
    message. It generates 4 prompts per batch and returns the decoded
    continuations, with prompt tokens stripped, in input order.
    """
    def translation_fn(texts, from_langs, to_langs):
        SYSTEM_MESSAGE = 'You are a linguist and translation assistant specialising in Ugandan languages.'
        translations = []
        batch_size = 4
        device = next(model.parameters()).device
        language_names = salt.constants.SALT_LANGUAGE_NAMES
        # Fall back to the raw code for languages missing from the name table,
        # matching the Qwen/Llama paths instead of raising KeyError.
        instructions = [
            f'Translate from {language_names.get(from_lang, from_lang)} '
            f'to {language_names.get(to_lang, to_lang)}: {text}'
            for text, from_lang, to_lang in zip(texts, from_langs, to_langs)
        ]
        for i in tqdm(range(0, len(instructions), batch_size), desc="Generating translations"):
            batch_instructions = instructions[i:i + batch_size]
            messages_list = [
                [
                    {"role": "system", "content": SYSTEM_MESSAGE},
                    {"role": "user", "content": instruction},
                ]
                for instruction in batch_instructions
            ]
            prompts = [
                tokenizer.apply_chat_template(
                    messages, tokenize=False, add_generation_prompt=True
                )
                for messages in messages_list
            ]
            # Left padding keeps every prompt's end aligned, so the generated
            # tokens start at the same column for every row.
            # NOTE(review): truncation uses the tokenizer's truncation_side; if
            # that is 'right', prompts over 512 tokens lose the trailing
            # generation prompt. Confirm for the tokenizer in use.
            inputs = tokenizer(
                prompts, return_tensors="pt",
                padding=True, padding_side='left',
                max_length=512, truncation=True,
            ).to(device)
            with torch.no_grad():
                # do_sample=True combined with beam search samples, so scores
                # can vary between runs; settings kept to preserve results.
                outputs = model.generate(
                    **inputs,
                    max_new_tokens=100,
                    temperature=0.5,
                    num_beams=5,
                    do_sample=True,
                    no_repeat_ngram_size=5,
                    pad_token_id=tokenizer.eos_token_id,
                )
            prompt_length = inputs['input_ids'].shape[1]
            for output_row in outputs:
                translations.append(
                    tokenizer.decode(output_row[prompt_length:], skip_special_tokens=True)
                )
        return translations
    return translation_fn
def get_qwen_translation_fn(model, tokenizer):
    """Build a batched chat-prompt translation function for wrapped Qwen models.

    ``translation_fn(texts, from_langs, to_langs)`` renders each item through
    the tokenizer's chat template, with a short system message and a
    ``"Translate from X to Y: text"`` user turn. Language names fall back to
    the raw code. It generates 8 prompts per batch and returns the decoded
    continuations in input order.
    """
    def translation_fn(texts, from_langs, to_langs):
        system_text = 'You are a Ugandan language assistant.'
        names = salt.constants.SALT_LANGUAGE_NAMES
        step = 8
        target_device = next(model.parameters()).device
        user_turns = [
            f'Translate from {names.get(src, src)} to {names.get(tgt, tgt)}: {sentence}'
            for sentence, src, tgt in zip(texts, from_langs, to_langs)
        ]
        decoded = []
        for offset in tqdm(range(0, len(user_turns), step), desc="Generating translations"):
            chat_prompts = []
            for turn in user_turns[offset:offset + step]:
                conversation = [
                    {"role": "system", "content": system_text},
                    {"role": "user", "content": turn},
                ]
                chat_prompts.append(
                    tokenizer.apply_chat_template(
                        conversation, tokenize=False, add_generation_prompt=True
                    )
                )
            encoded = tokenizer(
                chat_prompts,
                return_tensors="pt",
                padding=True,
                padding_side='left',
                truncation=True,
            ).to(target_device)
            with torch.no_grad():
                generated = model.generate(
                    **encoded,
                    max_new_tokens=100,
                    temperature=0.01,
                    pad_token_id=tokenizer.eos_token_id,
                )
            # Left padding means the prompt occupies the same leading columns
            # of every row; everything after it is the model's answer.
            prompt_len = encoded['input_ids'].shape[1]
            for sequence in generated:
                decoded.append(tokenizer.decode(sequence[prompt_len:], skip_special_tokens=True))
        return decoded
    return translation_fn
def get_nllb_translation_fn(model, tokenizer):
    """Build a sentence-by-sentence translation function for NLLB (M2M100) models.

    ``translation_fn(texts, source_langs, target_langs)`` encodes each sentence,
    replaces its first input token with the source-language token from
    ``SALT_LANGUAGE_TOKENS_NLLB_TRANSLATION``, and decodes with 5-beam search.
    The target-language token is forced as the first generated token.
    """
    def translation_fn(texts, source_langs, target_langs):
        lang_token_ids = salt.constants.SALT_LANGUAGE_TOKENS_NLLB_TRANSLATION
        target_device = next(model.parameters()).device
        outputs = []
        progress = tqdm(
            zip(texts, source_langs, target_langs),
            total=len(texts),
            desc="NLLB Translation",
        )
        for sentence, src, tgt in progress:
            encoded = tokenizer(sentence, return_tensors="pt").to(target_device)
            # Overwrite position 0 with this row's source-language token;
            # presumably the tokenizer put its default language tag there.
            # Confirm for the tokenizer in use.
            encoded['input_ids'][0][0] = lang_token_ids[src]
            with torch.no_grad():
                generated = model.generate(
                    **encoded,
                    forced_bos_token_id=lang_token_ids[tgt],
                    max_length=100,
                    num_beams=5,
                )
            outputs.append(tokenizer.batch_decode(generated, skip_special_tokens=True)[0])
        return outputs
    return translation_fn
def get_llama_translation_fn(model, tokenizer):
    """Build a batched chat-prompt translation function for wrapped Llama models.

    ``translation_fn(texts, from_langs, to_langs)`` renders each item through
    the tokenizer's chat template, with an empty system message and today's
    date passed as ``date_string``. It generates 8 prompts per batch and
    returns the decoded continuations in input order.
    """
    def translation_fn(texts, from_langs, to_langs):
        today = datetime.datetime.now().strftime("%d %b %Y")
        system_text = ''
        names = salt.constants.SALT_LANGUAGE_NAMES
        step = 8
        target_device = next(model.parameters()).device
        user_turns = [
            f'Translate from {names.get(src, src)} to {names.get(tgt, tgt)}: {sentence}'
            for sentence, src, tgt in zip(texts, from_langs, to_langs)
        ]
        decoded = []
        for offset in tqdm(range(0, len(user_turns), step), desc="Llama Translation"):
            chat_prompts = []
            for turn in user_turns[offset:offset + step]:
                conversation = [
                    {"role": "system", "content": system_text},
                    {"role": "user", "content": turn},
                ]
                chat_prompts.append(
                    tokenizer.apply_chat_template(
                        conversation,
                        tokenize=False,
                        add_generation_prompt=True,
                        date_string=today,
                    )
                )
            encoded = tokenizer(
                chat_prompts,
                return_tensors="pt",
                padding=True,
                padding_side='left',
            ).to(target_device)
            with torch.no_grad():
                generated = model.generate(
                    **encoded,
                    max_new_tokens=100,
                    temperature=0.01,
                    pad_token_id=tokenizer.eos_token_id,
                )
            # Left padding aligns prompt ends, so the answer starts at prompt_len.
            prompt_len = encoded['input_ids'].shape[1]
            for sequence in generated:
                decoded.append(tokenizer.decode(sequence[prompt_len:], skip_special_tokens=True))
        return decoded
    return translation_fn
def get_generic_translation_fn(model, tokenizer):
    """Build a fallback, one-sentence-at-a-time translation function.

    Used for model types not matched by ``get_translation_function``. Each item
    is prompted as plain text ``"Translate from <src> to <tgt>: <text>"`` using
    the raw language codes (no chat template).
    """
    def translation_fn(texts, from_langs, to_langs):
        translations = []
        device = next(model.parameters()).device
        # Decoder-only models echo the prompt before the generated tokens, so
        # the prompt length must be sliced off. Encoder-decoder models return
        # only decoder tokens; slicing them would drop the start of the
        # translation.
        is_encoder_decoder = bool(
            getattr(getattr(model, 'config', None), 'is_encoder_decoder', False)
        )
        total = len(texts) if hasattr(texts, '__len__') else None
        for text, from_lang, to_lang in tqdm(zip(texts, from_langs, to_langs),
                                             total=total, desc="Generic Translation"):
            prompt = f"Translate from {from_lang} to {to_lang}: {text}"
            inputs = tokenizer(prompt, return_tensors="pt").to(device)
            with torch.no_grad():
                # NOTE(review): temperature only takes effect when sampling is
                # enabled (do_sample here or in the model's generation config).
                outputs = model.generate(
                    **inputs,
                    max_new_tokens=100,
                    temperature=0.7,
                    pad_token_id=tokenizer.eos_token_id,
                )
            if is_encoder_decoder:
                generated = outputs[0]
            else:
                generated = outputs[0, inputs['input_ids'].shape[1]:]
            translations.append(tokenizer.decode(generated, skip_special_tokens=True))
        return translations
    return translation_fn
def calculate_metrics(reference: str, prediction: str) -> dict:
    """Score a single prediction against a single reference.

    Returns a dict with these keys:
    - ``bleu``: sentence BLEU (0-100).
    - ``chrf``: chrF rescaled to 0-1.
    - ``cer``: character error rate.
    - ``wer``: word error rate.
    - ``len_ratio``: prediction chars / reference chars.
    - ``rouge1``, ``rouge2``, ``rougeL``: ROUGE F-measures.
    - ``quality_score``: the mean of the six terms bleu/100, chrf, 1-cer,
      1-wer, rouge1 and rougeL.

    Error rates divide by ``max(len, 1)`` to avoid division by zero. If ROUGE
    scoring raises, the ROUGE keys are omitted and ``quality_score`` averages
    only the first four terms.
    """
    bleu_value = BLEU(effective_order=True).sentence_score(prediction, [reference]).score
    chrf_value = CHRF().sentence_score(prediction, [reference]).score / 100.0

    ref_char_count = max(len(reference), 1)
    cer = Levenshtein.distance(reference, prediction) / ref_char_count
    ref_tokens, pred_tokens = reference.split(), prediction.split()
    wer = Levenshtein.distance(ref_tokens, pred_tokens) / max(len(ref_tokens), 1)

    metrics = {
        "bleu": bleu_value,
        "chrf": chrf_value,
        "cer": cer,
        "wer": wer,
        "len_ratio": len(prediction) / ref_char_count,
    }
    base_terms = [bleu_value / 100, chrf_value, 1 - cer, 1 - wer]
    try:
        rouge = rouge_scorer.RougeScorer(
            ['rouge1', 'rouge2', 'rougeL'], use_stemmer=True
        ).score(reference, prediction)
        for rouge_key in ('rouge1', 'rouge2', 'rougeL'):
            metrics[rouge_key] = rouge[rouge_key].fmeasure
        metrics["quality_score"] = sum(
            base_terms + [rouge['rouge1'].fmeasure, rouge['rougeL'].fmeasure]
        ) / 6
    except Exception as err:
        print(f"Error calculating ROUGE metrics: {err}")
        metrics["quality_score"] = sum(base_terms) / 4
    return metrics
def evaluate_model_full(model, tokenizer, model_path: str, test_data) -> dict:
    """Translate ``test_data`` and compute per-direction and averaged metrics.

    Args:
        model: Model object passed to ``get_translation_function``.
        tokenizer: Tokenizer passed to ``get_translation_function``.
        model_path: Model identifier; ``'google-translate'`` selects the
            Google Cloud Translation path.
        test_data: Table with ``source``, ``target``, ``source.language`` and
            ``target.language`` columns, iterated via ``iterrows()``
            (presumably a pandas DataFrame).

    Returns:
        dict mapping ``"<src>_to_<tgt>"`` to ``{metric_name: mean_value}``,
        plus an ``'averages'`` entry holding the unweighted mean of each
        metric across directions. Returns an empty dict for empty data.

    Raises:
        ValueError: If the translation function returns a different number of
            predictions than there are rows in ``test_data``.
    """
    translation_fn = get_translation_function(model, tokenizer, model_path)

    print("Generating translations...")
    predictions = translation_fn(
        list(test_data['source']),
        list(test_data['source.language']),
        list(test_data['target.language']),
    )
    if len(predictions) != len(test_data):
        raise ValueError(
            f"Expected {len(test_data)} predictions, got {len(predictions)}"
        )

    print("Calculating metrics...")
    translation_subsets = defaultdict(list)
    # Pair predictions with rows by position, not by index label. After
    # filtering or shuffling, the index is not 0..n-1, and predictions[idx]
    # would misalign rows or raise IndexError.
    for (_, row), prediction in zip(test_data.iterrows(), predictions):
        direction = row['source.language'] + '_to_' + row['target.language']
        row_dict = dict(row)
        row_dict['prediction'] = prediction
        translation_subsets[direction].append(row_dict)

    normalizer = BasicTextNormalizer()
    grouped_metrics = defaultdict(dict)
    for direction, examples in translation_subsets.items():
        subset_metrics = defaultdict(list)
        for example in examples:
            prediction = normalizer(str(example['prediction']))
            reference = normalizer(example['target'])
            for name, value in calculate_metrics(reference, prediction).items():
                subset_metrics[name].append(value)
        for name, values in subset_metrics.items():
            if values:
                grouped_metrics[direction][name] = float(np.mean(values))

    # Use the union of metric names across directions, in first-seen order.
    # A direction in which every ROUGE computation failed lacks the ROUGE keys,
    # so taking only the first direction's keys could drop metrics.
    metric_names = []
    for direction in translation_subsets:
        for name in grouped_metrics[direction]:
            if name not in metric_names:
                metric_names.append(name)

    for name in metric_names:
        values = [
            grouped_metrics[direction][name]
            for direction in translation_subsets
            if name in grouped_metrics[direction]
        ]
        if values:
            grouped_metrics['averages'][name] = float(np.mean(values))

    return dict(grouped_metrics)