import torch
import spaces
from transformers import pipeline, AutoModelForSeq2SeqLM, AutoTokenizer, GenerationConfig
from peft import PeftModel, PeftConfig
import os
import unicodedata
from huggingface_hub import login

max_length = 512
auth_token = os.getenv('HF_SPACE_TOKEN')
if auth_token:
    login(token=auth_token)


@spaces.GPU
def goai_traduction(text, src_lang, tgt_lang):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    # Pick the fine-tuned checkpoint according to the translation direction
    # (the fr -> mos model is also used as the fallback for any other pair).
    if src_lang == "fra_Latn" and tgt_lang == "mos_Latn":
        model_id = "ArissBandoss/nllb-200-distilled-600M-finetuned-fr-to-mos-V4"
    elif src_lang == "mos_Latn" and tgt_lang == "fra_Latn":
        model_id = "ArissBandoss/3b-new-400"
    else:
        model_id = "ArissBandoss/nllb-200-distilled-600M-finetuned-fr-to-mos-V4"
    
    print(f"Chargement du modèle: {model_id}")
    tokenizer = AutoTokenizer.from_pretrained(model_id, token=auth_token)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_id, token=auth_token).to(device)
    
    print(f"Texte brut ({len(text)} caractères / {len(text.split())} mots):")
    print(text)
    
    print(f"Configuration du modèle:")
    print(f"- tokenizer.model_max_length: {tokenizer.model_max_length}")
    print(f"- Position embeddings shape: {model.model.encoder.embed_positions.weights.shape}")
    print(f"- decoder.embed_positions shape: {model.model.decoder.embed_positions.weights.shape}")
    
    # Configure the tokenizer's source language
    tokenizer.src_lang = src_lang
    
    # Tokenize the full input without truncation (token counts are logged below)
    inputs = tokenizer(text, return_tensors="pt", truncation=False).to(device)
    input_ids = inputs["input_ids"][0]
    
    print("Tokens d'entrée:")
    print(f"- Nombre de tokens: {input_ids.shape[0]}")
    print(f"- Premiers tokens: {input_ids[:10].tolist()}")
    print(f"- Derniers tokens: {input_ids[-10:].tolist()}")
    
    # Token ID of the target language
    tgt_lang_id = tokenizer.convert_tokens_to_ids(tgt_lang)
    print(f"Token ID de la langue cible ({tgt_lang}): {tgt_lang_id}")
    
    # Sweep decoding hyperparameters for debugging; the translation produced by the
    # last (length_penalty, num_beams) combination is the one returned below.
    bad_words_ids = None  # referenced but never defined in the original code; None disables the constraint
    for length_penalty in [1.0, 1.5, 2.0]:
        for num_beams in [5, 10]:
            print(f"\nTest avec length_penalty={length_penalty}, num_beams={num_beams}")

            outputs = model.generate(
                **inputs,
                forced_bos_token_id=tgt_lang_id,
                max_new_tokens=2048,
                early_stopping=False,
                num_beams=num_beams,
                no_repeat_ngram_size=0,
                bad_words_ids=bad_words_ids,
                length_penalty=length_penalty
            )

            translation = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
            
            print(f"Traduction ({len(translation)} caractères / {len(translation.split())} mots):")
            print(translation)
            output_ids = outputs[0]
            print(f"- Nombre de tokens générés: {output_ids.shape[0]}")
            print(f"- Premiers tokens générés: {output_ids[:10].tolist()}")
            print(f"- Derniers tokens générés: {output_ids[-10:].tolist()}")
            
    return translation

def real_time_traduction(input_text, src_lang, tgt_lang):
    return goai_traduction(input_text, src_lang, tgt_lang)
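
# Illustrative sketch, not part of the original file: the @spaces.GPU decorator
# suggests this module backs a Gradio-based Hugging Face Space, so a minimal UI
# wired to real_time_traduction could look like the following. The component
# labels and the language list are assumptions for demonstration only.
if __name__ == "__main__":
    import gradio as gr

    demo = gr.Interface(
        fn=real_time_traduction,
        inputs=[
            gr.Textbox(label="Texte à traduire", lines=6),
            gr.Dropdown(choices=["fra_Latn", "mos_Latn"], value="fra_Latn", label="Langue source"),
            gr.Dropdown(choices=["fra_Latn", "mos_Latn"], value="mos_Latn", label="Langue cible"),
        ],
        outputs=gr.Textbox(label="Traduction"),
        title="Traduction fr / mos",
    )
    demo.launch()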