import torch
import spaces
import re
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from huggingface_hub import login
import os

max_length = 512
auth_token = os.getenv('HF_SPACE_TOKEN')
login(token=auth_token)

def get_tokenizer(src_lang, tgt_lang):
    """Initialise et retourne le tokenizer approprié"""
    if src_lang == "mos_Latn" and tgt_lang == "fra_Latn":
        model_id = "ArissBandoss/3b-new-400"
    else:
        model_id = "ArissBandoss/nllb-200-distilled-600M-finetuned-fr-to-mos-V4"
    
    return AutoTokenizer.from_pretrained(model_id, token=auth_token)

def split_text_by_tokens(text, src_lang, tgt_lang, max_tokens_per_chunk=200):
    """
    Divise le texte en chunks en respectant les phrases et en comptant les tokens.
    """
    tokenizer = get_tokenizer(src_lang, tgt_lang)
    tokenizer.src_lang = src_lang
    
    # Split on sentence boundaries, keeping the punctuation as separate items
    sentences = re.split(r'([.!?])', text)
    chunks = []
    current_chunk = ""
    current_tokens = 0
    
    for i in range(0, len(sentences), 2):
        # Rebuild the sentence together with its punctuation
        if i + 1 < len(sentences):
            sentence = sentences[i] + sentences[i+1]
        else:
            sentence = sentences[i]
        
        # Count the tokens in this sentence
        sentence_tokens = len(tokenizer.encode(sentence))
        
        # If adding this sentence would exceed the token limit, start a new chunk
        if current_tokens + sentence_tokens > max_tokens_per_chunk and current_chunk:
            chunks.append(current_chunk.strip())
            current_chunk = sentence
            current_tokens = sentence_tokens
        else:
            current_chunk += sentence
            current_tokens += sentence_tokens
    
    # Append the last chunk if any text remains
    if current_chunk:
        chunks.append(current_chunk.strip())
    
    return chunks

@spaces.GPU
def goai_traduction(text, src_lang, tgt_lang, max_tokens_per_chunk=200):
    # Count tokens to check whether the text needs to be split
    tokenizer = get_tokenizer(src_lang, tgt_lang)
    tokenizer.src_lang = src_lang
    
    text_tokens = len(tokenizer.encode(text))
    
    if text_tokens > max_tokens_per_chunk:
        chunks = split_text_by_tokens(text, src_lang, tgt_lang, max_tokens_per_chunk)
        translations = []
        
        for chunk in chunks:
            translated_chunk = translate_chunk(chunk, src_lang, tgt_lang)
            translations.append(translated_chunk)
        
        return " ".join(translations)
    else:
        return translate_chunk(text, src_lang, tgt_lang)

def translate_chunk(text, src_lang, tgt_lang):
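    # NOTE: the model and tokenizer are reloaded from the Hub on every call;
    # caching them (e.g. at module level) would avoid redundant loads for
    # repeated requests.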
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    if src_lang == "mos_Latn" and tgt_lang == "fra_Latn":
        model_id = "ArissBandoss/3b-new-400"
    else:
        model_id = "ArissBandoss/nllb-200-distilled-600M-finetuned-fr-to-mos-V4"
    
    tokenizer = AutoTokenizer.from_pretrained(model_id, token=auth_token)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_id, token=auth_token).to(device)
    
    # Configure the tokenizer with the source language
    tokenizer.src_lang = src_lang
    
    # Tokenize the input
    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512).to(device)
    
    # Token ID of the target language
    tgt_lang_id = tokenizer.convert_tokens_to_ids(tgt_lang)
    
    # Generation parameters tuned for translation quality
    outputs = model.generate(
        **inputs,
        forced_bos_token_id=tgt_lang_id,
        max_new_tokens=512,
        num_beams=5,
        no_repeat_ngram_size=3,
        repetition_penalty=1.5,
        length_penalty=1.0,
        early_stopping=True
    )
    
    # Decode the output
    translation = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
    
    return translation

def real_time_traduction(input_text, src_lang, tgt_lang):
    return goai_traduction(input_text, src_lang, tgt_lang, max_tokens_per_chunk=200)
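
# A minimal usage sketch: translating a short French sentence to mos_Latn by
# calling goai_traduction directly. It assumes HF_SPACE_TOKEN is set in the
# environment and that the checkpoints above are accessible; the sample
# sentence below is illustrative only.
if __name__ == "__main__":
    sample_text = "Bonjour, comment allez-vous ?"
    print(goai_traduction(sample_text, src_lang="fra_Latn", tgt_lang="mos_Latn"))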