# src/evaluation.py
"""Translation back-ends and metric computation for model evaluation.

The module builds a ``translation_fn(texts, from_langs, to_langs)`` callable
for a given model (or for Google Cloud Translation) and scores its output
against reference translations, grouped by language pair.
"""
import datetime
import os
from collections import defaultdict

import Levenshtein
import numpy as np
import salt.constants
import torch
from google.cloud import translate_v3
from rouge_score import rouge_scorer
from sacrebleu.metrics import BLEU, CHRF
from tqdm.auto import tqdm
from transformers.models.whisper.english_normalizer import BasicTextNormalizer

from config import GOOGLE_LANG_MAP

# Google language codes that google_translate_batch will send to the API;
# any other pair is reported as "[UNSUPPORTED: ...]" without a request.
_GOOGLE_SUPPORTED_LANGS = frozenset({'lg', 'ach', 'sw', 'en'})


def setup_google_translate():
    """Set up a Google Cloud Translation client if credentials are available.

    Credentials are considered present when either the
    ``GOOGLE_APPLICATION_CREDENTIALS`` or the ``GOOGLE_CLOUD_PROJECT``
    environment variable is set.

    Returns:
        A ``(client, parent)`` tuple, where ``parent`` is the
        ``projects/<id>/locations/global`` resource path, or ``(None, None)``
        when credentials are missing or client creation fails.
    """
    try:
        # Check if running in HF Space with credentials
        has_credentials = (
            os.getenv("GOOGLE_APPLICATION_CREDENTIALS")
            or os.getenv("GOOGLE_CLOUD_PROJECT")
        )
        if not has_credentials:
            print("Google Cloud credentials not found. Google Translate will not be available.")
            return None, None

        client = translate_v3.TranslationServiceClient()
        project_id = os.getenv("GOOGLE_CLOUD_PROJECT", "sb-gcp-project-01")
        parent = f"projects/{project_id}/locations/global"
        return client, parent
    except Exception as e:
        # Best-effort setup: report the problem and signal "unavailable".
        print(f"Error setting up Google Translate: {e}")
        return None, None


def google_translate_batch(texts, source_langs, target_langs, client, parent):
    """Translate texts one by one with the Google Cloud Translation API.

    Args:
        texts: Source sentences.
        source_langs: Source language code for each text.
        target_langs: Target language code for each text.
        client: Translation client as returned by ``setup_google_translate``.
        parent: Resource path as returned by ``setup_google_translate``.

    Returns:
        A list with one entry per input text: the translation, an
        ``"[UNSUPPORTED: src->tgt]"`` marker for pairs outside the supported
        set, or an ``"[ERROR: ...]"`` marker when the request fails.
    """
    translations = []

    for text, src_lang, tgt_lang in tqdm(
        zip(texts, source_langs, target_langs),
        total=len(texts),
        desc="Google Translate",
    ):
        # Map SALT language codes to Google's format
        src_google = GOOGLE_LANG_MAP.get(src_lang, src_lang)
        tgt_google = GOOGLE_LANG_MAP.get(tgt_lang, tgt_lang)

        # Check if language pair is supported
        if (src_google not in _GOOGLE_SUPPORTED_LANGS
                or tgt_google not in _GOOGLE_SUPPORTED_LANGS):
            translations.append(f"[UNSUPPORTED: {src_lang}->{tgt_lang}]")
            continue

        request = {
            "parent": parent,
            "contents": [text],
            "mime_type": "text/plain",
            "source_language_code": src_google,
            "target_language_code": tgt_google,
        }
        try:
            response = client.translate_text(request=request)
            translation = response.translations[0].translated_text
        except Exception as e:
            # Keep output aligned with the input: record a marker and go on.
            print(f"Error translating '{text}': {e}")
            translations.append(f"[ERROR: {str(e)[:50]}]")
        else:
            translations.append(translation)

    return translations


def get_translation_function(model, tokenizer, model_path):
    """Pick the translation function matching the model type.

    Args:
        model: Loaded model object (not used for ``'google-translate'``).
        tokenizer: Tokenizer belonging to ``model``.
        model_path: Model identifier; the literal ``'google-translate'``
            selects the Google Cloud Translation back-end.

    Returns:
        A callable ``translation_fn(texts, from_langs, to_langs)`` returning a
        list of translated strings.

    Raises:
        RuntimeError: If ``'google-translate'`` is requested but no client
            could be created.
    """
    if model_path == 'google-translate':
        client, parent = setup_google_translate()
        if client is None:
            raise RuntimeError("Google Translate credentials not available")

        def translation_fn(texts, from_langs, to_langs):
            return google_translate_batch(texts, from_langs, to_langs, client, parent)

        return translation_fn

    model_type = str(type(model))
    # Type name of the wrapped model when ``model.base_model.model`` exists,
    # otherwise an empty string so the checks below simply do not match.
    if hasattr(model, 'base_model') and hasattr(model.base_model, 'model'):
        wrapped_type = str(type(model.base_model.model))
    else:
        wrapped_type = ''

    if 'gemma' in model_type.lower() or 'gemma' in model_path.lower():
        return get_gemma_translation_fn(model, tokenizer)
    if 'Qwen2ForCausalLM' in wrapped_type:
        return get_qwen_translation_fn(model, tokenizer)
    if 'm2m_100' in model_type.lower():
        return get_nllb_translation_fn(model, tokenizer)
    if 'LlamaForCausalLM' in wrapped_type:
        return get_llama_translation_fn(model, tokenizer)

    # Generic function for other models
    return get_generic_translation_fn(model, tokenizer)


def _build_instructions(texts, from_langs, to_langs):
    """Build one 'Translate from X to Y: text' instruction per input.

    Language codes missing from ``SALT_LANGUAGE_NAMES`` are used verbatim.
    """
    names = salt.constants.SALT_LANGUAGE_NAMES
    return [
        f'Translate from {names.get(from_lang, from_lang)} '
        f'to {names.get(to_lang, to_lang)}: {text}'
        for text, from_lang, to_lang in zip(texts, from_langs, to_langs)
    ]


def _chat_translate(model, tokenizer, instructions, *, system_message,
                    batch_size, desc, template_kwargs=None,
                    tokenizer_kwargs=None, generate_kwargs=None):
    """Run chat-templated instructions through ``model.generate`` in batches.

    Args:
        model: Model exposing ``parameters()`` and ``generate()``.
        tokenizer: Tokenizer exposing ``apply_chat_template`` and
            ``batch_decode``.
        instructions: User-turn strings, one per sentence to translate.
        system_message: Content of the system turn of every conversation.
        batch_size: Number of prompts generated per ``generate`` call.
        desc: Progress-bar label.
        template_kwargs: Extra keyword arguments for ``apply_chat_template``.
        tokenizer_kwargs: Extra keyword arguments for the tokenizer call.
        generate_kwargs: Extra keyword arguments for ``model.generate``.

    Returns:
        Decoded continuations (prompt tokens stripped), one per instruction.
    """
    template_kwargs = template_kwargs or {}
    tokenizer_kwargs = tokenizer_kwargs or {}
    generate_kwargs = generate_kwargs or {}

    device = next(model.parameters()).device
    translations = []

    for start in tqdm(range(0, len(instructions), batch_size), desc=desc):
        batch = instructions[start:start + batch_size]
        prompts = [
            tokenizer.apply_chat_template(
                [
                    {"role": "system", "content": system_message},
                    {"role": "user", "content": instruction},
                ],
                tokenize=False,
                add_generation_prompt=True,
                **template_kwargs,
            )
            for instruction in batch
        ]

        # Left padding keeps every prompt flush against the generated tokens,
        # so a single column offset strips the prompt from each row below.
        inputs = tokenizer(
            prompts,
            return_tensors="pt",
            padding=True,
            padding_side='left',
            **tokenizer_kwargs,
        ).to(device)

        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=100,
                pad_token_id=tokenizer.eos_token_id,
                **generate_kwargs,
            )

        prompt_length = inputs['input_ids'].shape[1]
        translations.extend(
            tokenizer.batch_decode(
                outputs[:, prompt_length:], skip_special_tokens=True
            )
        )

    return translations


def get_gemma_translation_fn(model, tokenizer):
    """Return a translation function for Gemma models.

    Prompts are truncated to 512 tokens and decoded with sampled beam search
    (5 beams, temperature 0.5, no repeated 5-grams) in batches of 4.
    """
    def translation_fn(texts, from_langs, to_langs):
        return _chat_translate(
            model,
            tokenizer,
            _build_instructions(texts, from_langs, to_langs),
            system_message='You are a linguist and translation assistant specialising in Ugandan languages.',
            batch_size=4,
            desc="Generating translations",
            tokenizer_kwargs={"max_length": 512, "truncation": True},
            generate_kwargs={
                "temperature": 0.5,
                "num_beams": 5,
                "do_sample": True,
                "no_repeat_ngram_size": 5,
            },
        )

    return translation_fn


def get_qwen_translation_fn(model, tokenizer):
    """Return a translation function for Qwen models (batches of 8)."""
    def translation_fn(texts, from_langs, to_langs):
        return _chat_translate(
            model,
            tokenizer,
            _build_instructions(texts, from_langs, to_langs),
            system_message='You are a Ugandan language assistant.',
            batch_size=8,
            desc="Generating translations",
            tokenizer_kwargs={"truncation": True},
            generate_kwargs={"temperature": 0.01},
        )

    return translation_fn


def get_nllb_translation_fn(model, tokenizer):
    """Return a translation function for NLLB models.

    Sentences are translated one at a time with 5-beam search. Language
    codes must be keys of ``SALT_LANGUAGE_TOKENS_NLLB_TRANSLATION``;
    an unknown code raises ``KeyError``.
    """
    def translation_fn(texts, source_langs, target_langs):
        translations = []
        language_tokens = salt.constants.SALT_LANGUAGE_TOKENS_NLLB_TRANSLATION
        device = next(model.parameters()).device

        for text, source_language, target_language in tqdm(
            zip(texts, source_langs, target_langs),
            total=len(texts),
            desc="NLLB Translation",
        ):
            inputs = tokenizer(text, return_tensors="pt").to(device)
            # Overwrite the first input token with the source-language token.
            inputs['input_ids'][0][0] = language_tokens[source_language]

            with torch.no_grad():
                translated_tokens = model.generate(
                    **inputs,
                    forced_bos_token_id=language_tokens[target_language],
                    max_length=100,
                    num_beams=5,
                )

            result = tokenizer.batch_decode(
                translated_tokens, skip_special_tokens=True)[0]
            translations.append(result)

        return translations

    return translation_fn


def get_llama_translation_fn(model, tokenizer):
    """Return a translation function for Llama models (batches of 8).

    The chat template receives today's date as ``date_string`` and an empty
    system message.
    """
    def translation_fn(texts, from_langs, to_langs):
        # Evaluated per call so the date follows the time of translation.
        date_today = datetime.datetime.now().strftime("%d %b %Y")
        return _chat_translate(
            model,
            tokenizer,
            _build_instructions(texts, from_langs, to_langs),
            system_message='',
            batch_size=8,
            desc="Llama Translation",
            template_kwargs={"date_string": date_today},
            generate_kwargs={"temperature": 0.01},
        )

    return translation_fn


def get_generic_translation_fn(model, tokenizer):
    """Return a plain-prompt translation function for unknown model types.

    Each sentence is sent as ``"Translate from <src> to <tgt>: <text>"``
    using the raw language codes, without a chat template.
    """
    def translation_fn(texts, from_langs, to_langs):
        translations = []
        device = next(model.parameters()).device

        for text, from_lang, to_lang in tqdm(
            zip(texts, from_langs, to_langs),
            total=len(texts),
            desc="Generic Translation",
        ):
            prompt = f"Translate from {from_lang} to {to_lang}: {text}"
            inputs = tokenizer(prompt, return_tensors="pt").to(device)

            with torch.no_grad():
                outputs = model.generate(
                    **inputs,
                    max_new_tokens=100,
                    temperature=0.7,
                    pad_token_id=tokenizer.eos_token_id
                )

            # Drop the prompt tokens and keep only the generated continuation.
            translation = tokenizer.decode(
                outputs[0, inputs['input_ids'].shape[1]:],
                skip_special_tokens=True
            )
            translations.append(translation)

        return translations

    return translation_fn


def calculate_metrics(reference: str, prediction: str) -> dict:
    """Calculate multiple translation quality metrics for one sentence pair.

    Args:
        reference: Reference translation.
        prediction: Model output to score.

    Returns:
        A dict with ``bleu`` (sacreBLEU sentence score, not rescaled),
        ``chrf`` (sentence score divided by 100), ``cer``, ``wer``,
        ``len_ratio`` and ``quality_score``. ``rouge1``, ``rouge2`` and
        ``rougeL`` F-measures are added when ROUGE succeeds; if it fails,
        ``quality_score`` averages the four remaining components instead
        of six.
    """
    bleu = BLEU(effective_order=True)
    bleu_score = bleu.sentence_score(prediction, [reference]).score

    chrf = CHRF()
    chrf_score = chrf.sentence_score(prediction, [reference]).score / 100.0

    # Edit distances normalised by reference length; max(..., 1) avoids a
    # division by zero on an empty reference.
    cer = Levenshtein.distance(reference, prediction) / max(len(reference), 1)

    ref_words = reference.split()
    pred_words = prediction.split()
    wer = Levenshtein.distance(ref_words, pred_words) / max(len(ref_words), 1)

    len_ratio = len(prediction) / max(len(reference), 1)

    metrics = {
        "bleu": bleu_score,
        "chrf": chrf_score,
        "cer": cer,
        "wer": wer,
        "len_ratio": len_ratio,
    }

    try:
        scorer = rouge_scorer.RougeScorer(
            ['rouge1', 'rouge2', 'rougeL'], use_stemmer=True
        )
        rouge_scores = scorer.score(reference, prediction)

        metrics["rouge1"] = rouge_scores['rouge1'].fmeasure
        metrics["rouge2"] = rouge_scores['rouge2'].fmeasure
        metrics["rougeL"] = rouge_scores['rougeL'].fmeasure

        metrics["quality_score"] = (
            bleu_score/100 +
            chrf_score +
            (1-cer) +
            (1-wer) +
            rouge_scores['rouge1'].fmeasure +
            rouge_scores['rougeL'].fmeasure
        ) / 6
    except Exception as e:
        # ROUGE is optional: fall back to the four core components.
        print(f"Error calculating ROUGE metrics: {e}")
        metrics["quality_score"] = (bleu_score/100 + chrf_score + (1-cer) + (1-wer)) / 4

    return metrics


def evaluate_model_full(model, tokenizer, model_path: str, test_data) -> dict:
    """Run the complete evaluation pipeline for one model.

    Args:
        model: Loaded model (unused for ``'google-translate'``).
        tokenizer: Tokenizer belonging to ``model``.
        model_path: Model identifier passed to ``get_translation_function``.
        test_data: Table providing ``iterrows()`` and the columns ``source``,
            ``target``, ``source.language`` and ``target.language``
            (presumably a pandas DataFrame -- confirm against the caller).

    Returns:
        A dict mapping ``"<src>_to_<tgt>"`` to the mean of each metric over
        that language pair, plus an ``"averages"`` entry holding the mean of
        the per-pair means (absent when ``test_data`` is empty).

    Raises:
        ValueError: If the translation function does not return exactly one
            prediction per row of ``test_data``.
    """
    # Get translation function
    translation_fn = get_translation_function(model, tokenizer, model_path)

    # Generate predictions
    print("Generating translations...")
    sources = list(test_data['source'])
    predictions = translation_fn(
        sources,
        list(test_data['source.language']),
        list(test_data['target.language']),
    )
    if len(predictions) != len(sources):
        raise ValueError(
            f"Expected {len(sources)} predictions, got {len(predictions)}"
        )

    # Calculate metrics by language pair
    print("Calculating metrics...")
    translation_subsets = defaultdict(list)
    # Pair rows with predictions by position: the index labels yielded by
    # iterrows() are not guaranteed to be 0..n-1 (e.g. after filtering).
    for (_, row), prediction in zip(test_data.iterrows(), predictions):
        direction = row['source.language'] + '_to_' + row['target.language']
        row_dict = dict(row)
        row_dict['prediction'] = prediction
        translation_subsets[direction].append(row_dict)

    normalizer = BasicTextNormalizer()
    grouped_metrics = {}

    for direction, examples in translation_subsets.items():
        subset_metrics = defaultdict(list)
        for example in examples:
            prediction = normalizer(str(example['prediction']))
            reference = normalizer(example['target'])
            for name, value in calculate_metrics(reference, prediction).items():
                subset_metrics[name].append(value)

        grouped_metrics[direction] = {
            name: float(np.mean(values))
            for name, values in subset_metrics.items()
            if values
        }

    # Calculate overall averages. Metric names are the union over all pairs
    # (in first-seen order) so a metric missing from one pair is still
    # averaged over the pairs that do report it.
    metric_names = list(dict.fromkeys(
        name
        for pair_metrics in grouped_metrics.values()
        for name in pair_metrics
    ))
    averages = {}
    for name in metric_names:
        values = [
            pair_metrics[name]
            for pair_metrics in grouped_metrics.values()
            if name in pair_metrics
        ]
        if values:
            averages[name] = float(np.mean(values))
    if averages:
        grouped_metrics['averages'] = averages

    return grouped_metrics