csc525_retrieval_based_chatbot / back_translator.py
from transformers import (
    MarianMTModel,
    MarianTokenizer,
)

# Retained for reference, but dropped from the final pipeline:
# back-translation did not seem helpful for this retrieval-based chatbot.
class BackTranslator:
    """
    Perform back-translation through two pivot languages: English -> German -> Spanish -> English.

    Args:
        source_lang: Source language (default: 'en')
        pivot_lang: First pivot language (default: 'de')
        target_lang: Second pivot language, translated back into the source language at the end (default: 'es')

    Examples:
        back_translator = BackTranslator()
        back_translator.back_translate("Hello, how are you?")
    """

    def __init__(self, source_lang='en', pivot_lang='de', target_lang='es'):
        # Forward (English to German)
        pivot_forward_model_name = f'Helsinki-NLP/opus-mt-{source_lang}-{pivot_lang}'
        self.tokenizer_pivot_forward = MarianTokenizer.from_pretrained(pivot_forward_model_name)
        self.model_pivot_forward = MarianMTModel.from_pretrained(pivot_forward_model_name)

        # Pivot translation (German to Spanish)
        pivot_backward_model_name = f'Helsinki-NLP/opus-mt-{pivot_lang}-{target_lang}'
        self.tokenizer_pivot_backward = MarianTokenizer.from_pretrained(pivot_backward_model_name)
        self.model_pivot_backward = MarianMTModel.from_pretrained(pivot_backward_model_name)

        # Backward (Spanish to English)
        backward_model_name = f'Helsinki-NLP/opus-mt-{target_lang}-{source_lang}'
        self.tokenizer_backward = MarianTokenizer.from_pretrained(backward_model_name)
        self.model_backward = MarianMTModel.from_pretrained(backward_model_name)

        # Set models to eval mode
        self.model_pivot_forward.eval()
        self.model_pivot_backward.eval()
        self.model_backward.eval()

    def back_translate(self, text, device=None):
        """Round-trip translate text (source -> pivot -> target -> source).

        Args:
            text: Input string to paraphrase.
            device: Optional torch device; if given, models and input tensors are moved to it.

        Returns:
            The back-translated string, or the original text if any step fails.
        """
        try:
            # Move models to device if specified
            if device is not None:
                self.model_pivot_forward = self.model_pivot_forward.to(device)
                self.model_pivot_backward = self.model_pivot_backward.to(device)
                self.model_backward = self.model_backward.to(device)

            # Forward translation (English to German)
            encoded_pivot = self.tokenizer_pivot_forward([text], padding=True,
                                                         truncation=True, return_tensors='pt')
            if device is not None:
                encoded_pivot = {k: v.to(device) for k, v in encoded_pivot.items()}
            generated_pivot = self.model_pivot_forward.generate(**encoded_pivot)
            if device is not None:
                generated_pivot = generated_pivot.cpu()
            pivot_text = self.tokenizer_pivot_forward.batch_decode(generated_pivot,
                                                                   skip_special_tokens=True)[0]

            # Pivot translation (German to Spanish)
            encoded_back_pivot = self.tokenizer_pivot_backward([pivot_text], padding=True,
                                                               truncation=True, return_tensors='pt')
            if device is not None:
                encoded_back_pivot = {k: v.to(device) for k, v in encoded_back_pivot.items()}
            retranslated_pivot = self.model_pivot_backward.generate(**encoded_back_pivot)
            if device is not None:
                retranslated_pivot = retranslated_pivot.cpu()
            tgt_text_back = self.tokenizer_pivot_backward.batch_decode(retranslated_pivot,
                                                                       skip_special_tokens=True)[0]

            # Backward translation (Spanish to English)
            encoded_back = self.tokenizer_backward([tgt_text_back], padding=True,
                                                   truncation=True, return_tensors='pt')
            if device is not None:
                encoded_back = {k: v.to(device) for k, v in encoded_back.items()}
            retranslated = self.model_backward.generate(**encoded_back)
            if device is not None:
                retranslated = retranslated.cpu()
            src_text = self.tokenizer_backward.batch_decode(retranslated,
                                                            skip_special_tokens=True)[0]

            return src_text
        except Exception as e:
            print(f"Error in back translation: {e}")
            return text
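

# Minimal usage sketch (not part of the original module). Instantiating
# BackTranslator downloads three Helsinki-NLP MarianMT checkpoints from the
# Hugging Face Hub on first use, so the first run needs network access and
# some disk space; subsequent runs load from the local cache.
if __name__ == "__main__":
    back_translator = BackTranslator()
    paraphrase = back_translator.back_translate("Hello, how are you?")
    print(paraphrase)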