akera committed on
Commit d0ca936 · verified · 1 Parent(s): 97a3aa2

Create evaluation.py

Files changed (1):
  1. src/evaluation.py +413 -0
src/evaluation.py ADDED
@@ -0,0 +1,413 @@
# src/evaluation.py
import torch
import numpy as np
from tqdm.auto import tqdm
from sacrebleu.metrics import BLEU, CHRF
from rouge_score import rouge_scorer
import Levenshtein
from collections import defaultdict
from transformers.models.whisper.english_normalizer import BasicTextNormalizer
import salt.constants
import datetime
import os
from google.cloud import translate_v3
from config import GOOGLE_LANG_MAP

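# NOTE (assumption): `GOOGLE_LANG_MAP` lives in the project-local config.py.
# From its use below, it is taken to map SALT language codes to Google
# Translate codes, e.g. something like {'lug': 'lg', 'ach': 'ach', 'swa': 'sw',
# 'eng': 'en'} (illustrative values only; the authoritative mapping is in config.py).
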
def setup_google_translate():
    """Set up a Google Cloud Translation client if credentials are available."""
    try:
        # Check if running in an HF Space with credentials configured
        if os.getenv("GOOGLE_APPLICATION_CREDENTIALS") or os.getenv("GOOGLE_CLOUD_PROJECT"):
            client = translate_v3.TranslationServiceClient()
            project_id = os.getenv("GOOGLE_CLOUD_PROJECT", "sb-gcp-project-01")
            parent = f"projects/{project_id}/locations/global"
            return client, parent
        else:
            print("Google Cloud credentials not found. Google Translate will not be available.")
            return None, None
    except Exception as e:
        print(f"Error setting up Google Translate: {e}")
        return None, None

def google_translate_batch(texts, source_langs, target_langs, client, parent):
    """Translate using the Google Cloud Translation API."""
    translations = []

    for text, src_lang, tgt_lang in tqdm(zip(texts, source_langs, target_langs),
                                         total=len(texts), desc="Google Translate"):
        try:
            # Map SALT language codes to Google's format
            src_google = GOOGLE_LANG_MAP.get(src_lang, src_lang)
            tgt_google = GOOGLE_LANG_MAP.get(tgt_lang, tgt_lang)

            # Check if the language pair is supported
            supported_langs = ['lg', 'ach', 'sw', 'en']
            if src_google not in supported_langs or tgt_google not in supported_langs:
                translations.append(f"[UNSUPPORTED: {src_lang}->{tgt_lang}]")
                continue

            # Make the translation request
            request = {
                "parent": parent,
                "contents": [text],
                "mime_type": "text/plain",
                "source_language_code": src_google,
                "target_language_code": tgt_google,
            }

            response = client.translate_text(request=request)
            translation = response.translations[0].translated_text
            translations.append(translation)

        except Exception as e:
            print(f"Error translating '{text}': {e}")
            translations.append(f"[ERROR: {str(e)[:50]}]")

    return translations

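# Minimal usage sketch (hypothetical inputs; requires valid GCP credentials):
#
#     client, parent = setup_google_translate()
#     if client is not None:
#         results = google_translate_batch(
#             ["Good morning"], ["eng"], ["lug"], client, parent)
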
def get_translation_function(model, tokenizer, model_path):
    """Get the appropriate translation function for the given model type."""

    if model_path == 'google-translate':
        client, parent = setup_google_translate()
        if client is None:
            raise Exception("Google Translate credentials not available")

        def translation_fn(texts, from_langs, to_langs):
            return google_translate_batch(texts, from_langs, to_langs, client, parent)

        return translation_fn

    elif 'gemma' in str(type(model)).lower() or 'gemma' in model_path.lower():
        return get_gemma_translation_fn(model, tokenizer)

    elif (hasattr(model, 'base_model') and hasattr(model.base_model, 'model')
          and 'Qwen2ForCausalLM' in str(type(model.base_model.model))):
        # PEFT-wrapped models expose the underlying model at base_model.model
        return get_qwen_translation_fn(model, tokenizer)

    elif 'm2m_100' in str(type(model)).lower():
        # NLLB checkpoints load as M2M100ForConditionalGeneration, hence this check
        return get_nllb_translation_fn(model, tokenizer)

    elif (hasattr(model, 'base_model') and hasattr(model.base_model, 'model')
          and 'LlamaForCausalLM' in str(type(model.base_model.model))):
        return get_llama_translation_fn(model, tokenizer)

    else:
        # Generic fallback for any other model type
        return get_generic_translation_fn(model, tokenizer)

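# Usage sketch (hypothetical arguments): `model_path` doubles as a routing key,
# so the Google path needs no local model at all.
#
#     fn = get_translation_function(None, None, 'google-translate')
#     fn(["Good morning"], ["eng"], ["lug"])
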
def get_gemma_translation_fn(model, tokenizer):
    """Translation function for Gemma models."""
    def translation_fn(texts, from_langs, to_langs):
        SYSTEM_MESSAGE = 'You are a linguist and translation assistant specialising in Ugandan languages.'
        translations = []
        batch_size = 4
        device = next(model.parameters()).device

        instructions = [
            f'Translate from {salt.constants.SALT_LANGUAGE_NAMES[from_lang]} '
            f'to {salt.constants.SALT_LANGUAGE_NAMES[to_lang]}: {text}'
            for text, from_lang, to_lang in zip(texts, from_langs, to_langs)
        ]

        for i in tqdm(range(0, len(instructions), batch_size), desc="Generating translations"):
            batch_instructions = instructions[i:i + batch_size]
            messages_list = [
                [
                    {"role": "system", "content": SYSTEM_MESSAGE},
                    {"role": "user", "content": instruction}
                ] for instruction in batch_instructions
            ]

            prompts = [
                tokenizer.apply_chat_template(
                    messages, tokenize=False, add_generation_prompt=True
                ) for messages in messages_list
            ]

            # Left padding keeps prompts right-aligned, so slicing the outputs
            # at input_ids.shape[1] below yields only the generated tokens.
            inputs = tokenizer(
                prompts, return_tensors="pt",
                padding=True, padding_side='left',
                max_length=512, truncation=True
            ).to(device)

            with torch.no_grad():
                outputs = model.generate(
                    **inputs,
                    max_new_tokens=100,
                    temperature=0.5,
                    num_beams=5,
                    do_sample=True,
                    no_repeat_ngram_size=5,
                    pad_token_id=tokenizer.eos_token_id
                )

            for j in range(len(outputs)):
                translation = tokenizer.decode(
                    outputs[j, inputs['input_ids'].shape[1]:],
                    skip_special_tokens=True
                )
                translations.append(translation)

        return translations

    return translation_fn

def get_qwen_translation_fn(model, tokenizer):
    """Translation function for Qwen models."""
    def translation_fn(texts, from_langs, to_langs):
        SYSTEM_MESSAGE = 'You are a Ugandan language assistant.'
        translations = []
        batch_size = 8
        device = next(model.parameters()).device

        instructions = [
            f'Translate from {salt.constants.SALT_LANGUAGE_NAMES.get(from_lang, from_lang)} '
            f'to {salt.constants.SALT_LANGUAGE_NAMES.get(to_lang, to_lang)}: {text}'
            for text, from_lang, to_lang in zip(texts, from_langs, to_langs)
        ]

        for i in tqdm(range(0, len(instructions), batch_size), desc="Generating translations"):
            batch_instructions = instructions[i:i + batch_size]
            messages_list = [
                [
                    {"role": "system", "content": SYSTEM_MESSAGE},
                    {"role": "user", "content": instruction}
                ] for instruction in batch_instructions
            ]

            prompts = [
                tokenizer.apply_chat_template(
                    messages, tokenize=False, add_generation_prompt=True
                ) for messages in messages_list
            ]

            inputs = tokenizer(
                prompts, return_tensors="pt",
                padding=True, padding_side='left', truncation=True
            ).to(device)

            with torch.no_grad():
                outputs = model.generate(
                    **inputs, max_new_tokens=100,
                    temperature=0.01,
                    pad_token_id=tokenizer.eos_token_id
                )

            for j in range(len(outputs)):
                translation = tokenizer.decode(
                    outputs[j, inputs['input_ids'].shape[1]:],
                    skip_special_tokens=True
                )
                translations.append(translation)

        return translations

    return translation_fn

def get_nllb_translation_fn(model, tokenizer):
    """Translation function for NLLB models."""
    def translation_fn(texts, source_langs, target_langs):
        translations = []
        language_tokens = salt.constants.SALT_LANGUAGE_TOKENS_NLLB_TRANSLATION
        device = next(model.parameters()).device

        for text, source_language, target_language in tqdm(
                zip(texts, source_langs, target_langs), total=len(texts), desc="NLLB Translation"):

            inputs = tokenizer(text, return_tensors="pt").to(device)
            # Overwrite the first input token with the source-language token,
            # then force the target-language token as the first decoded token.
            inputs['input_ids'][0][0] = language_tokens[source_language]

            with torch.no_grad():
                translated_tokens = model.generate(
                    **inputs,
                    forced_bos_token_id=language_tokens[target_language],
                    max_length=100,
                    num_beams=5,
                )

            result = tokenizer.batch_decode(
                translated_tokens, skip_special_tokens=True)[0]
            translations.append(result)

        return translations

    return translation_fn

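# Usage sketch for the NLLB path (real public checkpoint, but illustrative
# language codes; SALT_LANGUAGE_TOKENS_NLLB_TRANSLATION must cover them):
#
#     from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
#     nllb = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-distilled-600M")
#     tok = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")
#     translate = get_nllb_translation_fn(nllb, tok)
#     translate(["Hello"], ["eng"], ["lug"])
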
def get_llama_translation_fn(model, tokenizer):
    """Translation function for Llama models."""
    def translation_fn(texts, from_langs, to_langs):
        DATE_TODAY = datetime.datetime.now().strftime("%d %b %Y")
        SYSTEM_MESSAGE = ''
        translations = []
        batch_size = 8
        device = next(model.parameters()).device

        instructions = [
            f'Translate from {salt.constants.SALT_LANGUAGE_NAMES.get(from_lang, from_lang)} '
            f'to {salt.constants.SALT_LANGUAGE_NAMES.get(to_lang, to_lang)}: {text}'
            for text, from_lang, to_lang in zip(texts, from_langs, to_langs)
        ]

        for i in tqdm(range(0, len(instructions), batch_size), desc="Llama Translation"):
            batch_instructions = instructions[i:i + batch_size]
            messages_list = [
                [
                    {"role": "system", "content": SYSTEM_MESSAGE},
                    {"role": "user", "content": instruction}
                ] for instruction in batch_instructions
            ]

            prompts = [
                tokenizer.apply_chat_template(
                    messages, tokenize=False, add_generation_prompt=True,
                    date_string=DATE_TODAY,
                ) for messages in messages_list
            ]

            inputs = tokenizer(
                prompts, return_tensors="pt",
                padding=True, padding_side='left',
            ).to(device)

            with torch.no_grad():
                outputs = model.generate(
                    **inputs, max_new_tokens=100,
                    temperature=0.01,
                    pad_token_id=tokenizer.eos_token_id
                )

            for j in range(len(outputs)):
                translation = tokenizer.decode(
                    outputs[j, inputs['input_ids'].shape[1]:],
                    skip_special_tokens=True
                )
                translations.append(translation)

        return translations

    return translation_fn

def get_generic_translation_fn(model, tokenizer):
    """Generic translation function for unknown model types."""
    def translation_fn(texts, from_langs, to_langs):
        translations = []
        device = next(model.parameters()).device

        for text, from_lang, to_lang in tqdm(zip(texts, from_langs, to_langs),
                                             total=len(texts), desc="Generic Translation"):
            prompt = f"Translate from {from_lang} to {to_lang}: {text}"
            inputs = tokenizer(prompt, return_tensors="pt").to(device)

            with torch.no_grad():
                outputs = model.generate(
                    **inputs,
                    max_new_tokens=100,
                    temperature=0.7,
                    pad_token_id=tokenizer.eos_token_id
                )

            translation = tokenizer.decode(
                outputs[0, inputs['input_ids'].shape[1]:],
                skip_special_tokens=True
            )
            translations.append(translation)

        return translations

    return translation_fn

def calculate_metrics(reference: str, prediction: str) -> dict:
    """Calculate multiple translation quality metrics."""
    bleu = BLEU(effective_order=True)
    bleu_score = bleu.sentence_score(prediction, [reference]).score

    chrf = CHRF()
    chrf_score = chrf.sentence_score(prediction, [reference]).score / 100.0

    # Character error rate: edit distance normalised by reference length
    cer = Levenshtein.distance(reference, prediction) / max(len(reference), 1)

    # Word error rate: edit distance over token lists (recent versions of the
    # Levenshtein package accept arbitrary sequences, not just strings)
    ref_words = reference.split()
    pred_words = prediction.split()
    wer = Levenshtein.distance(ref_words, pred_words) / max(len(ref_words), 1)

    len_ratio = len(prediction) / max(len(reference), 1)

    metrics = {
        "bleu": bleu_score,
        "chrf": chrf_score,
        "cer": cer,
        "wer": wer,
        "len_ratio": len_ratio,
    }

    try:
        scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)
        rouge_scores = scorer.score(reference, prediction)

        metrics["rouge1"] = rouge_scores['rouge1'].fmeasure
        metrics["rouge2"] = rouge_scores['rouge2'].fmeasure
        metrics["rougeL"] = rouge_scores['rougeL'].fmeasure

        # Composite score: BLEU is rescaled to [0, 1] to match the other terms
        metrics["quality_score"] = (
            bleu_score / 100 +
            chrf_score +
            (1 - cer) +
            (1 - wer) +
            rouge_scores['rouge1'].fmeasure +
            rouge_scores['rougeL'].fmeasure
        ) / 6
    except Exception as e:
        print(f"Error calculating ROUGE metrics: {e}")
        metrics["quality_score"] = (bleu_score / 100 + chrf_score + (1 - cer) + (1 - wer)) / 4

    return metrics

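# Worked example (identical strings score perfectly; real outputs vary with
# installed metric library versions):
#
#     >>> m = calculate_metrics("omusajja agenda mu kibuga", "omusajja agenda mu kibuga")
#     >>> round(m["bleu"]), m["cer"], m["wer"]
#     (100, 0.0, 0.0)
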
def evaluate_model_full(model, tokenizer, model_path: str, test_data) -> dict:
    """Complete model evaluation pipeline."""

    # Get the translation function for this model type
    translation_fn = get_translation_function(model, tokenizer, model_path)

    # Generate predictions
    print("Generating translations...")
    predictions = translation_fn(
        list(test_data['source']),
        list(test_data['source.language']),
        list(test_data['target.language']),
    )

    # Group examples by language pair. Use positional indices so predictions
    # line up even when the DataFrame index is not the default 0..N-1 range.
    print("Calculating metrics...")
    translation_subsets = defaultdict(list)
    for position, (_, row) in enumerate(test_data.iterrows()):
        direction = row['source.language'] + '_to_' + row['target.language']
        row_dict = dict(row)
        row_dict['prediction'] = predictions[position]
        translation_subsets[direction].append(row_dict)

    normalizer = BasicTextNormalizer()
    grouped_metrics = defaultdict(dict)

    for subset in translation_subsets.keys():
        subset_metrics = defaultdict(list)
        for example in translation_subsets[subset]:
            prediction = normalizer(str(example['prediction']))
            reference = normalizer(example['target'])
            metrics = calculate_metrics(reference, prediction)
            for m in metrics.keys():
                subset_metrics[m].append(metrics[m])

        for m in subset_metrics.keys():
            if subset_metrics[m]:  # Skip empty lists
                grouped_metrics[subset][m] = float(np.mean(subset_metrics[m]))

    # Calculate overall averages across language pairs
    all_metrics = list(grouped_metrics.values())[0].keys() if grouped_metrics else []
    for m in all_metrics:
        metric_values = []
        for subset in translation_subsets.keys():
            if m in grouped_metrics[subset]:
                metric_values.append(grouped_metrics[subset][m])
        if metric_values:
            grouped_metrics['averages'][m] = float(np.mean(metric_values))

    return dict(grouped_metrics)
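
# End-to-end sketch (hypothetical model path and file name; `test_data` must be
# a pandas DataFrame with 'source', 'target', 'source.language' and
# 'target.language' columns, as consumed above):
#
#     if __name__ == "__main__":
#         from transformers import AutoModelForCausalLM, AutoTokenizer
#         import pandas as pd
#
#         model_path = "some-org/some-translation-model"  # placeholder
#         model = AutoModelForCausalLM.from_pretrained(model_path)
#         tokenizer = AutoTokenizer.from_pretrained(model_path)
#         test_data = pd.read_csv("salt_test.csv")  # placeholder path
#
#         results = evaluate_model_full(model, tokenizer, model_path, test_data)
#         print(results.get('averages', {}))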