File size: 1,736 Bytes
78d1101
 
49dc84f
eda40b7
78d1101
4302c2b
78d1101
1c7cbff
78d1101
 
 
 
 
 
 
 
261a5aa
78d1101
 
 
24c3ce1
78d1101
 
7f3d8a9
dfb286c
 
7f3d8a9
 
261a5aa
7f3d8a9
 
261a5aa
7f3d8a9
 
261a5aa
7f3d8a9
 
261a5aa
dfb286c
 
7f3d8a9
 
d79365c
7f3d8a9
 
261a5aa
7f3d8a9
 
261a5aa
7f3d8a9
261a5aa
20a6b7b
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import os
import unicodedata
from functools import lru_cache

import torch
import spaces
from transformers import pipeline, AutoModelForSeq2SeqLM, AutoTokenizer, GenerationConfig
from peft import PeftModel, PeftConfig
from huggingface_hub import login

# NOTE(review): not referenced anywhere in this file; kept in case other code
# imports it (e.g. as an intended input-length cap) — confirm before removing.
max_length = 512

# Hugging Face token used to download the fine-tuned checkpoints.
auth_token = os.getenv('HF_SPACE_TOKEN')

# Only log in when a token is actually configured: huggingface_hub.login()
# with token=None falls back to an interactive prompt (terminal/notebook
# widget), which nobody can answer inside a Space and would block startup.
if auth_token:
    login(token=auth_token)


# Hub checkpoints: one fine-tuned model per translation direction.
_FR_TO_MOS_MODEL_ID = "ArissBandoss/nllb-200-distilled-600M-finetuned-fr-to-mos-V4"
_MOS_TO_FR_MODEL_ID = "ArissBandoss/3b-new-400"

# (src_lang, tgt_lang) -> checkpoint. Any pair not listed falls back to the
# fr->mos checkpoint, exactly as the original if/elif/else chain did.
_MODEL_ID_BY_PAIR = {
    ("fra_Latn", "mos_Latn"): _FR_TO_MOS_MODEL_ID,
    ("mos_Latn", "fra_Latn"): _MOS_TO_FR_MODEL_ID,
}


@lru_cache(maxsize=2)
def _load_model(model_id):
    """Load the seq2seq model for ``model_id`` once and memoize it.

    maxsize=2 keeps both direction checkpoints resident after first use,
    trading host RAM for not re-reading the weights on every request.
    """
    return AutoModelForSeq2SeqLM.from_pretrained(model_id, token=auth_token)


@lru_cache(maxsize=8)
def _load_tokenizer(model_id, src_lang):
    """Load the tokenizer for ``model_id`` with ``src_lang`` preset, memoized.

    The cache key includes ``src_lang`` so a cached tokenizer's ``src_lang``
    attribute is set exactly once and never overwritten by a later call that
    uses a different source language.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_id, token=auth_token)
    tokenizer.src_lang = src_lang
    return tokenizer


@spaces.GPU
def goai_traduction(text, src_lang, tgt_lang):
    """Translate ``text`` from ``src_lang`` to ``tgt_lang``.

    Args:
        text: Source text. ``None``/empty/whitespace-only input returns ``""``
            without running the model.
        src_lang: Source language code set on the tokenizer (e.g. "fra_Latn").
        tgt_lang: Target language code; its token id is forced as the first
            generated token (e.g. "mos_Latn").

    Returns:
        The decoded translation of the first (only) sequence, with special
        tokens removed.
    """
    # Nothing to translate: skip model loading and GPU time entirely.
    if not text or (isinstance(text, str) and not text.strip()):
        return ""

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model_id = _MODEL_ID_BY_PAIR.get((src_lang, tgt_lang), _FR_TO_MOS_MODEL_ID)
    tokenizer = _load_tokenizer(model_id, src_lang)
    # nn.Module.to() moves parameters in place, so once the cached model is on
    # `device` subsequent calls are effectively a no-op.
    model = _load_model(model_id).to(device)

    # Tokenize with the source-language prefix configured on the tokenizer.
    inputs = tokenizer(text, return_tensors="pt").to(device)

    # Id of the target-language token used to steer generation.
    # NOTE(review): an unknown tgt_lang maps to the unk id here — confirm the
    # UI only offers supported codes.
    tgt_lang_id = tokenizer.convert_tokens_to_ids(tgt_lang)

    # Beam search (5 beams), no n-gram blocking, neutral length penalty,
    # up to 1024 new tokens.
    outputs = model.generate(
        **inputs,
        forced_bos_token_id=tgt_lang_id,
        max_new_tokens=1024,
        early_stopping=False,
        num_beams=5,
        no_repeat_ngram_size=0,
        length_penalty=1.0,
    )

    # Decode and return the single generated sequence.
    return tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]

def real_time_traduction(input_text, src_lang, tgt_lang):
    """Entry point for live translation; forwards to ``goai_traduction``.

    Args:
        input_text: Text to translate.
        src_lang: Source language code (e.g. "fra_Latn").
        tgt_lang: Target language code (e.g. "mos_Latn").

    Returns:
        Whatever ``goai_traduction`` returns for the same arguments.
    """
    translated = goai_traduction(input_text, src_lang, tgt_lang)
    return translated