File size: 2,517 Bytes
78d1101
 
49dc84f
eda40b7
78d1101
4302c2b
78d1101
1c7cbff
78d1101
 
 
 
 
 
488bb16
78d1101
261a5aa
78d1101
 
 
24c3ce1
78d1101
 
7f3d8a9
fb3e214
dfb286c
 
fb3e214
8cf4d3b
 
fb3e214
 
 
 
 
 
7f3d8a9
261a5aa
7f3d8a9
 
d4b611c
fb3e214
 
 
 
 
 
7f3d8a9
 
261a5aa
fb3e214
289156f
fb3e214
c220549
 
 
 
 
 
 
 
 
7cc5f39
fb3e214
261a5aa
20a6b7b
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import torch
import spaces
from transformers import pipeline, AutoModelForSeq2SeqLM, AutoTokenizer, GenerationConfig
from peft import PeftModel, PeftConfig
import os
import unicodedata
from huggingface_hub import login

# Upper bound (in tokens) on the generated sequence, passed to model.generate.
max_length = 512
# Hugging Face access token read from the environment; None when the variable is unset.
auth_token = os.getenv('HF_SPACE_TOKEN')
# Only authenticate when a token is actually configured: huggingface_hub.login()
# with token=None falls back to prompting for a token, which cannot be answered
# in a headless deployment. Public checkpoints still load without a token.
if auth_token:
    login(token=auth_token)


def _select_model_id(src_lang, tgt_lang):
    """Return the Hub checkpoint id used for the ``src_lang -> tgt_lang`` pair.

    Only ``mos_Latn -> fra_Latn`` has its own checkpoint; ``fra_Latn -> mos_Latn``
    and every other (unsupported) pair use the fr-to-mos checkpoint.
    """
    if (src_lang, tgt_lang) == ("mos_Latn", "fra_Latn"):
        return "ArissBandoss/3b-new-400"
    return "ArissBandoss/nllb-200-distilled-600M-finetuned-fr-to-mos-V4"


@spaces.GPU
def goai_traduction(text, src_lang, tgt_lang):
    """Translate ``text`` from ``src_lang`` to ``tgt_lang`` with a fine-tuned seq2seq model.

    The checkpoint is picked by ``_select_model_id``, loaded from the Hub with
    ``auth_token`` on every call, and moved to CUDA when it is available.

    Args:
        text: Source text to translate.
        src_lang: Source language code assigned to ``tokenizer.src_lang``
            (e.g. ``"fra_Latn"`` or ``"mos_Latn"``).
        tgt_lang: Target language code; it must be a token of the tokenizer's
            vocabulary because it is forced as the first generated token.

    Returns:
        The decoded translation with special tokens removed, or ``""`` when
        ``text`` is empty, ``None`` or whitespace-only.

    Raises:
        ValueError: If ``tgt_lang`` is not a known token of the tokenizer.
    """
    # Nothing to translate: skip the expensive model load entirely.
    if not text or not text.strip():
        return ""

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model_id = _select_model_id(src_lang, tgt_lang)
    print(f"Chargement du modèle: {model_id}")
    # NOTE(review): the checkpoint is re-instantiated on every request; consider
    # caching per model_id if host memory allows keeping both models resident.
    tokenizer = AutoTokenizer.from_pretrained(model_id, token=auth_token)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_id, token=auth_token).to(device)

    # Must be set before tokenizing so the tokenizer encodes with the source
    # language (presumably adding its language-code token — NLLB-style tokenizer).
    tokenizer.src_lang = src_lang

    # NOTE(review): truncation is disabled, so very long inputs are passed whole
    # to the encoder — confirm the model handles them.
    inputs = tokenizer(text, return_tensors="pt", truncation=False).to(device)
    input_ids = inputs["input_ids"][0]

    print("Tokens d'entrée:")
    print(f"- Nombre de tokens: {input_ids.shape[0]}")

    # The target-language token is forced as the first generated token.
    tgt_lang_id = tokenizer.convert_tokens_to_ids(tgt_lang)
    if tgt_lang_id is None or tgt_lang_id == tokenizer.unk_token_id:
        raise ValueError(f"Unknown target language code: {tgt_lang!r}")
    print(f"Token ID de la langue cible ({tgt_lang}): {tgt_lang_id}")

    # No min_length here: forcing a minimum length suppresses EOS and makes the
    # model pad every translation out to max_length tokens with invented text.
    outputs = model.generate(
        **inputs,
        forced_bos_token_id=tgt_lang_id,
        max_length=max_length,
        num_beams=5,
        length_penalty=2.0,
    )
    return tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]

def real_time_traduction(input_text, src_lang, tgt_lang):
    """Delegate to ``goai_traduction`` with the same arguments and return its result unchanged."""
    translated = goai_traduction(input_text, src_lang, tgt_lang)
    return translated