import torch
import spaces
from transformers import pipeline, AutoModelForSeq2SeqLM, AutoTokenizer, GenerationConfig
from peft import PeftModel, PeftConfig
import os
import unicodedata
from huggingface_hub import login

max_length = 512
auth_token = os.getenv('HF_SPACE_TOKEN')
# Log in to the Hugging Face Hub only when a token is available (set as a Space secret)
if auth_token:
    login(token=auth_token)


@spaces.GPU
def goai_traduction(text, src_lang, tgt_lang):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    if src_lang == "mos_Latn" and tgt_lang == "fra_Latn":
        model_id = "ArissBandoss/3b-new-400"
    else:
        model_id = "ArissBandoss/nllb-200-distilled-600M-finetuned-fr-to-mos-V4"
    
    tokenizer = AutoTokenizer.from_pretrained(model_id, token=auth_token)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_id, token=auth_token).to(device)
    
    # Configure the tokenizer with the source language
    tokenizer.src_lang = src_lang
    
    # Tokenize the input text (no truncation, so long inputs are passed through in full)
    inputs = tokenizer(text, return_tensors="pt", truncation=False).to(device)
    input_length = inputs["input_ids"].shape[1]

    # ID of the target-language token, used to force the first generated token
    tgt_lang_id = tokenizer.convert_tokens_to_ids(tgt_lang)
    
    # EOS token ID (currently unused by the generation call below)
    eos_token_id = tokenizer.eos_token_id
    
    # Beam-search generation; forced_bos_token_id steers decoding into the target language
    outputs = model.generate(
        **inputs,
        forced_bos_token_id=tgt_lang_id,
        max_new_tokens=1024,
        num_beams=5,
        repetition_penalty=2.0,
        length_penalty=2.0,
    )
    
    # Decode the generated tokens and return the first sequence
    translation = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
    
    return translation

def real_time_traduction(input_text, src_lang, tgt_lang):
    return goai_traduction(input_text, src_lang, tgt_lang)
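

# Minimal usage sketch (assumption: running this file as a plain script, outside the
# Gradio/Spaces UI that normally calls real_time_traduction; the sample sentence is
# illustrative only, the language codes come from the function above).
if __name__ == "__main__":
    sample_text = "Bonjour, comment allez-vous ?"
    print(real_time_traduction(sample_text, src_lang="fra_Latn", tgt_lang="mos_Latn"))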