File size: 2,339 Bytes
78d1101
 
49dc84f
eda40b7
78d1101
4302c2b
78d1101
1c7cbff
78d1101
 
 
 
 
 
 
 
261a5aa
78d1101
 
 
f2b394c
78d1101
 
261a5aa
ef4152c
261a5aa
090b150
 
c0087f6
beb2b9a
 
49dc84f
 
 
 
 
 
 
c0087f6
261a5aa
 
 
 
 
 
065813f
261a5aa
 
 
 
 
 
49dc84f
 
261a5aa
5f4aaf1
49dc84f
5f4aaf1
 
 
261a5aa
 
 
9d520cc
 
261a5aa
20a6b7b
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import torch
import spaces
from transformers import pipeline, AutoModelForSeq2SeqLM, AutoTokenizer, GenerationConfig
from peft import PeftModel, PeftConfig
import os
import unicodedata
from huggingface_hub import login

# Maximum token count used when truncating tokenizer inputs.
max_length = 512

# Hub token injected by the Space's secrets; may be unset in local runs.
auth_token = os.getenv('HF_SPACE_TOKEN')
# Guard the login: login(token=None) falls back to an interactive prompt,
# which would hang a headless Space at import time.
if auth_token:
    login(token=auth_token)


@spaces.GPU
def goai_traduction(text, src_lang, tgt_lang):
    """Translate ``text`` between French and Moore with fine-tuned NLLB models.

    Args:
        text: Source text to translate.
        src_lang: NLLB source language code (e.g. ``"fra_Latn"``).
        tgt_lang: NLLB target language code (e.g. ``"mos_Latn"``).

    Returns:
        The translated string (first decoded sequence, special tokens removed).
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Pick the checkpoint for the requested direction. The original code had
    # three branches, but the fr->mos branch and the fallback loaded the same
    # model, so only the mos->fr direction needs a special case.
    if src_lang == "mos_Latn" and tgt_lang == "fra_Latn":
        model_id = "ArissBandoss/mos2fr-3B-1200"
    else:
        model_id = "ArissBandoss/nllb-200-distilled-600M-finetuned-fr-to-mos-V4"

    # NOTE(review): the tokenizer and model are re-loaded on every call;
    # consider caching per model_id if call latency matters.
    tokenizer = AutoTokenizer.from_pretrained(model_id, token=auth_token)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_id, token=auth_token).to(device)

    # The source language must be set *before* tokenizing so the proper
    # language token is prepended to the encoder input.
    tokenizer.src_lang = src_lang

    # Bug fix: truncation=True / max_length were previously passed to
    # from_pretrained(), where they are ignored. They must be given to the
    # tokenizer call itself for long inputs to actually be truncated.
    inputs = tokenizer(
        text, return_tensors="pt", truncation=True, max_length=max_length
    ).to(device)

    # Force decoding to start and continue in the target language
    # (NLLB convention: target language token leads the decoder sequence).
    tgt_lang_id = tokenizer.convert_tokens_to_ids(tgt_lang)
    generation_config = GenerationConfig(
        max_new_tokens=1024,
        early_stopping=False,
        decoder_start_token_id=tgt_lang_id,
        forced_bos_token_id=tgt_lang_id,
        eos_token_id=tokenizer.eos_token_id,
    )

    outputs = model.generate(**inputs, generation_config=generation_config)

    # Decode the single generated sequence, dropping special/language tokens.
    translation = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
    return translation


def real_time_traduction(input_text, src_lang, tgt_lang):
    """Delegate live-translation requests straight to :func:`goai_traduction`."""
    return goai_traduction(input_text, src_lang=src_lang, tgt_lang=tgt_lang)