File size: 4,473 Bytes
3190e1e
 
 
 
 
300fe5d
 
3190e1e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
300fe5d
3190e1e
 
 
 
 
 
 
 
300fe5d
 
 
 
 
3190e1e
300fe5d
 
 
 
 
 
 
 
 
 
 
 
 
3190e1e
300fe5d
 
 
 
 
3190e1e
300fe5d
 
 
 
 
 
 
 
 
 
 
3190e1e
300fe5d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
from transformers import (
    MarianMTModel,
    MarianTokenizer,
)

# Retained for reference but removed from the final code.
# This method did not seem helpful for this retrieval-based chatbot.
class BackTranslator:
    """
    Perform back-translation with a pivot language:
    source -> pivot -> target -> source (default: English -> German -> Spanish -> English).

    Args:
        source_lang: Source language code (default: 'en')
        pivot_lang: Pivot language code (default: 'de')
        target_lang: Target language code (default: 'es')

    Examples:
        back_translator = BackTranslator()
        back_translator.back_translate("Hello, how are you?")
    """

    def __init__(self, source_lang='en', pivot_lang='de', target_lang='es'):
        # Each hop uses an independent Helsinki-NLP MarianMT checkpoint.
        # Forward (source -> pivot, e.g. English -> German)
        self.tokenizer_pivot_forward, self.model_pivot_forward = \
            self._load_pair(source_lang, pivot_lang)
        # Pivot translation (pivot -> target, e.g. German -> Spanish)
        self.tokenizer_pivot_backward, self.model_pivot_backward = \
            self._load_pair(pivot_lang, target_lang)
        # Backward (target -> source, e.g. Spanish -> English)
        self.tokenizer_backward, self.model_backward = \
            self._load_pair(target_lang, source_lang)

    @staticmethod
    def _load_pair(src, tgt):
        """Load the MarianMT tokenizer/model pair for one hop; model is set to eval mode."""
        model_name = f'Helsinki-NLP/opus-mt-{src}-{tgt}'
        tokenizer = MarianTokenizer.from_pretrained(model_name)
        model = MarianMTModel.from_pretrained(model_name)
        model.eval()  # inference only: disable dropout etc.
        return tokenizer, model

    @staticmethod
    def _translate(text, tokenizer, model, device=None):
        """
        Run a single string through one tokenizer/model hop and return the
        decoded translation. When `device` is given, inputs are moved to it
        and generated ids are brought back to CPU before decoding.
        """
        encoded = tokenizer([text], padding=True,
                            truncation=True, return_tensors='pt')
        if device is not None:
            encoded = {k: v.to(device) for k, v in encoded.items()}
        generated = model.generate(**encoded)
        if device is not None:
            generated = generated.cpu()
        return tokenizer.batch_decode(generated, skip_special_tokens=True)[0]

    def back_translate(self, text, device=None):
        """
        Paraphrase `text` via source -> pivot -> target -> source translation.

        Args:
            text: Input sentence in the source language.
            device: Optional torch device; models and inputs are moved to it.

        Returns:
            The back-translated string, or the original `text` unchanged if
            any step fails (deliberate best-effort behavior).
        """
        try:
            # Move models to device if specified; assignment keeps the moved
            # models for subsequent calls.
            if device is not None:
                self.model_pivot_forward = self.model_pivot_forward.to(device)
                self.model_pivot_backward = self.model_pivot_backward.to(device)
                self.model_backward = self.model_backward.to(device)

            # Forward translation (source -> pivot)
            pivot_text = self._translate(text, self.tokenizer_pivot_forward,
                                         self.model_pivot_forward, device)
            # Pivot translation (pivot -> target)
            target_text = self._translate(pivot_text, self.tokenizer_pivot_backward,
                                          self.model_pivot_backward, device)
            # Backward translation (target -> source)
            return self._translate(target_text, self.tokenizer_backward,
                                   self.model_backward, device)
        except Exception as e:
            # Best-effort augmentation: never let a translation failure
            # propagate to the caller; fall back to the unmodified input.
            print(f"Error in back translation: {e}")
            return text