from transformers import (
    MarianMTModel,
    MarianTokenizer,
)
# Retained for reference but removed from the final code.
# This method did not seem helpful for this retrieval-based chatbot.


class BackTranslator:
    """
    Perform Back-translation with pivot language.

    English -> German -> Spanish -> English

    Args:
        source_lang: Source language (default: 'en')
        pivot_lang: Pivot language (default: 'de')
        target_lang: Target language (default: 'es')

    Examples:
        back_translator = BackTranslator()
        back_translator.back_translate("Hello, how are you?")
    """

    def __init__(self, source_lang='en', pivot_lang='de', target_lang='es'):
        # Forward leg (e.g. English -> German)
        pivot_forward_model_name = f'Helsinki-NLP/opus-mt-{source_lang}-{pivot_lang}'
        self.tokenizer_pivot_forward = MarianTokenizer.from_pretrained(pivot_forward_model_name)
        self.model_pivot_forward = MarianMTModel.from_pretrained(pivot_forward_model_name)

        # Pivot leg (e.g. German -> Spanish)
        pivot_backward_model_name = f'Helsinki-NLP/opus-mt-{pivot_lang}-{target_lang}'
        self.tokenizer_pivot_backward = MarianTokenizer.from_pretrained(pivot_backward_model_name)
        self.model_pivot_backward = MarianMTModel.from_pretrained(pivot_backward_model_name)

        # Backward leg (e.g. Spanish -> English)
        backward_model_name = f'Helsinki-NLP/opus-mt-{target_lang}-{source_lang}'
        self.tokenizer_backward = MarianTokenizer.from_pretrained(backward_model_name)
        self.model_backward = MarianMTModel.from_pretrained(backward_model_name)

        # Inference only: disable dropout etc. for deterministic generation.
        self.model_pivot_forward.eval()
        self.model_pivot_backward.eval()
        self.model_backward.eval()

    def _translate(self, text, tokenizer, model, device=None):
        """Run one translation leg: tokenize *text*, generate, decode.

        Args:
            text: Input sentence (a single string).
            tokenizer: MarianTokenizer for this leg.
            model: MarianMTModel for this leg (assumed already on *device*).
            device: Optional torch device; when given, input tensors are
                moved there and the generated ids are brought back to CPU
                before decoding.

        Returns:
            The decoded translation as a string.
        """
        encoded = tokenizer([text], padding=True, truncation=True,
                            return_tensors='pt')
        if device is not None:
            encoded = {k: v.to(device) for k, v in encoded.items()}
        generated = model.generate(**encoded)
        if device is not None:
            # Decode on CPU so the result does not pin GPU memory.
            generated = generated.cpu()
        return tokenizer.batch_decode(generated, skip_special_tokens=True)[0]

    def back_translate(self, text, device=None):
        """Back-translate *text* through the pivot and target languages.

        Args:
            text: Input sentence in the source language.
            device: Optional torch device to run the models on.

        Returns:
            The back-translated sentence, or the original *text* unchanged
            if any step fails (deliberate best-effort fallback).
        """
        try:
            # Move models to device if specified.
            if device is not None:
                self.model_pivot_forward = self.model_pivot_forward.to(device)
                self.model_pivot_backward = self.model_pivot_backward.to(device)
                self.model_backward = self.model_backward.to(device)

            # source -> pivot -> target -> source
            pivot_text = self._translate(
                text, self.tokenizer_pivot_forward, self.model_pivot_forward, device)
            tgt_text_back = self._translate(
                pivot_text, self.tokenizer_pivot_backward, self.model_pivot_backward, device)
            src_text = self._translate(
                tgt_text_back, self.tokenizer_backward, self.model_backward, device)

            return src_text

        except Exception as e:
            # Best-effort: augmentation must never crash the caller, so fall
            # back to the untranslated input on any failure.
            print(f"Error in back translation: {e}")
            return text