|
from transformers import (
    MarianMTModel,
    MarianTokenizer,
)


class BackTranslator:
    """
    Perform back-translation through a pivot language:
    source -> pivot -> target -> source (English -> German -> Spanish -> English by default).

    Args:
        source_lang: Source language code (default: 'en')
        pivot_lang: Pivot language code (default: 'de')
        target_lang: Target language code (default: 'es')

    Examples:
        back_translator = BackTranslator()
        back_translator.back_translate("Hello, how are you?")
    """
|
    def __init__(self, source_lang='en', pivot_lang='de', target_lang='es'):
        # Forward leg: source -> pivot (e.g. en -> de).
        pivot_forward_model_name = f'Helsinki-NLP/opus-mt-{source_lang}-{pivot_lang}'
        self.tokenizer_pivot_forward = MarianTokenizer.from_pretrained(pivot_forward_model_name)
        self.model_pivot_forward = MarianMTModel.from_pretrained(pivot_forward_model_name)

        # Middle leg: pivot -> target (e.g. de -> es).
        pivot_backward_model_name = f'Helsinki-NLP/opus-mt-{pivot_lang}-{target_lang}'
        self.tokenizer_pivot_backward = MarianTokenizer.from_pretrained(pivot_backward_model_name)
        self.model_pivot_backward = MarianMTModel.from_pretrained(pivot_backward_model_name)

        # Final leg: target -> source (e.g. es -> en).
        backward_model_name = f'Helsinki-NLP/opus-mt-{target_lang}-{source_lang}'
        self.tokenizer_backward = MarianTokenizer.from_pretrained(backward_model_name)
        self.model_backward = MarianMTModel.from_pretrained(backward_model_name)

        # Inference only: switch off dropout and other training-time behavior.
        self.model_pivot_forward.eval()
        self.model_pivot_backward.eval()
        self.model_backward.eval()
|
    def back_translate(self, text, device=None):
        """Translate text through the pivot chain and back to the source language."""
        try:
            # Optionally move all three models to the requested device.
            if device is not None:
                self.model_pivot_forward = self.model_pivot_forward.to(device)
                self.model_pivot_backward = self.model_pivot_backward.to(device)
                self.model_backward = self.model_backward.to(device)

            # Step 1: source -> pivot.
            encoded_pivot = self.tokenizer_pivot_forward([text], padding=True,
                                                         truncation=True, return_tensors='pt')
            if device is not None:
                encoded_pivot = {k: v.to(device) for k, v in encoded_pivot.items()}

            generated_pivot = self.model_pivot_forward.generate(**encoded_pivot)
            if device is not None:
                generated_pivot = generated_pivot.cpu()
            pivot_text = self.tokenizer_pivot_forward.batch_decode(generated_pivot,
                                                                   skip_special_tokens=True)[0]

            # Step 2: pivot -> target.
            encoded_back_pivot = self.tokenizer_pivot_backward([pivot_text], padding=True,
                                                               truncation=True, return_tensors='pt')
            if device is not None:
                encoded_back_pivot = {k: v.to(device) for k, v in encoded_back_pivot.items()}

            retranslated_pivot = self.model_pivot_backward.generate(**encoded_back_pivot)
            if device is not None:
                retranslated_pivot = retranslated_pivot.cpu()
            tgt_text_back = self.tokenizer_pivot_backward.batch_decode(retranslated_pivot,
                                                                       skip_special_tokens=True)[0]

            # Step 3: target -> source.
            encoded_back = self.tokenizer_backward([tgt_text_back], padding=True,
                                                   truncation=True, return_tensors='pt')
            if device is not None:
                encoded_back = {k: v.to(device) for k, v in encoded_back.items()}

            retranslated = self.model_backward.generate(**encoded_back)
            if device is not None:
                retranslated = retranslated.cpu()
            src_text = self.tokenizer_backward.batch_decode(retranslated,
                                                            skip_special_tokens=True)[0]

            return src_text
        except Exception as e:
            # Fall back to the original input if any model call fails.
            print(f"Error in back translation: {e}")
            return text
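

# A minimal usage sketch, not part of the original class: running it downloads three
# Helsinki-NLP/opus-mt checkpoints from the Hugging Face Hub, so it assumes network
# access and a working `transformers` + `torch` install. The 'cuda' device string
# below is an assumption; pass device=None (the default) to stay on CPU.
if __name__ == '__main__':
    back_translator = BackTranslator()

    # Round-trip a sentence: en -> de -> es -> en. The result is typically a
    # paraphrase of the input, which is what makes back-translation useful for
    # text data augmentation.
    paraphrase = back_translator.back_translate("Hello, how are you?")
    print(paraphrase)

    # Hypothetical GPU variant, if a CUDA device is available:
    # paraphrase = back_translator.back_translate("Hello, how are you?", device='cuda')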