ArissBandoss committed on
Commit 4a24c4f · verified · 1 Parent(s): 7cc5f39

Update goai_helpers/goai_traduction.py

Files changed (1):
  goai_helpers/goai_traduction.py  +21 -26
goai_helpers/goai_traduction.py CHANGED
```diff
@@ -15,53 +15,48 @@ login(token=auth_token)
 def goai_traduction(text, src_lang, tgt_lang):
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
-    if src_lang == "fra_Latn" and tgt_lang == "mos_Latn":
-        model_id = "ArissBandoss/nllb-200-distilled-600M-finetuned-fr-to-mos-V4"
-    elif src_lang == "mos_Latn" and tgt_lang == "fra_Latn":
-        model_id = "ArissBandoss/3b-new-400"
+    if src_lang == "mos_Latn" and tgt_lang == "fra_Latn":
+        model_id = "ArissBandoss/3b-new-400"
     else:
         model_id = "ArissBandoss/nllb-200-distilled-600M-finetuned-fr-to-mos-V4"
 
-    print(f"Chargement du modèle: {model_id}")
     tokenizer = AutoTokenizer.from_pretrained(model_id, token=auth_token)
     model = AutoModelForSeq2SeqLM.from_pretrained(model_id, token=auth_token).to(device)
 
-    print(f"Texte brut ({len(text)} caractères / {len(text.split())} mots):")
-    print(text)
-
-    print(f"Configuration du modèle:")
-    print(f"- tokenizer.model_max_length: {tokenizer.model_max_length}")
-    print(f"- Position embeddings shape: {model.model.encoder.embed_positions.weights.shape}")
-    print(f"- decoder.embed_positions shape: {model.model.decoder.embed_positions.weights.shape}")
-
     # Tokenizer configuration
     tokenizer.src_lang = src_lang
 
     # Tokenization
     inputs = tokenizer(text, return_tensors="pt", truncation=False).to(device)
-    input_ids = inputs["input_ids"][0]
+    input_length = inputs["input_ids"].shape[1]
 
-    print("Tokens d'entrée:")
-    print(f"- Nombre de tokens: {input_ids.shape[0]}")
-    print(f"- Premiers tokens: {input_ids[:10].tolist()}")
-    print(f"- Derniers tokens: {input_ids[-10:].tolist()}")
+    # Rough estimate of the expected output length
+    # For Mooré to French, a factor of 1.2-1.5 is generally good
+    expected_output_length = int(input_length * 1.3)
 
     # Target-language token ID
     tgt_lang_id = tokenizer.convert_tokens_to_ids(tgt_lang)
-    print(f"Token ID de la langue cible ({tgt_lang}): {tgt_lang_id}")
-    bad_words_ids = [[tokenizer.eos_token_id]]
-
+
+    # EOS token ID
+    eos_token_id = tokenizer.eos_token_id
+
+    # Completely block the EOS token until a certain point
     outputs = model.generate(
         **inputs,
         forced_bos_token_id=tgt_lang_id,
-        max_length=max_length,
-        min_length=max_length,
+        max_new_tokens=1024,
+        min_length=expected_output_length,
         num_beams=5,
-        no_repeat_ngram_size=0,
-        length_penalty=2.0
+        no_repeat_ngram_size=4,
+        repetition_penalty=2.0,
+        length_penalty=1.5,
+        diversity_penalty=0.5,
+        num_beam_groups=5
     )
+
+    # Decoding
     translation = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
-
+
     return translation
 
 def real_time_traduction(input_text, src_lang, tgt_lang):
```
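
A note on the new generate() arguments: in Hugging Face transformers, min_length is enforced by suppressing the EOS token until that many tokens have been produced (which is what the "block the EOS token" comment refers to), and diversity_penalty only takes effect under group beam search, i.e. when num_beam_groups > 1 and num_beams is divisible by it (here, five groups of one beam each). Below is a minimal sketch of the same decoding settings packaged as a reusable GenerationConfig; it illustrates the commit's parameters but is not code from the repository, and the helper name build_generation_config is made up.

```python
# Illustration only: the commit's decoding settings expressed as a
# transformers.GenerationConfig. The helper name is hypothetical.
from transformers import GenerationConfig

def build_generation_config(tgt_lang_id: int, expected_output_length: int) -> GenerationConfig:
    return GenerationConfig(
        forced_bos_token_id=tgt_lang_id,    # force the NLLB target-language tag as the first decoder token
        max_new_tokens=1024,                # hard cap on generated tokens
        min_length=expected_output_length,  # EOS is suppressed until this length is reached
        num_beams=5,
        num_beam_groups=5,                  # group (diverse) beam search: 5 groups of 1 beam
        diversity_penalty=0.5,              # penalizes tokens already picked by other groups at the same step
        no_repeat_ngram_size=4,
        repetition_penalty=2.0,
        length_penalty=1.5,
    )

# Equivalent to the new generate() call inside goai_traduction:
# outputs = model.generate(**inputs, generation_config=build_generation_config(tgt_lang_id, expected_output_length))
```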
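
A minimal usage sketch for the updated function, assuming the module-level setup visible in the file (the torch and transformers imports plus the auth_token passed to login); the input strings are placeholders, not test data from the repository.

```python
# Hypothetical usage; the sentences below are placeholders.
if __name__ == "__main__":
    # Mooré -> French now routes to the ArissBandoss/3b-new-400 checkpoint
    print(goai_traduction("<Mooré sentence here>", src_lang="mos_Latn", tgt_lang="fra_Latn"))

    # Every other pair (e.g. French -> Mooré) falls back to the fine-tuned NLLB-200 model
    print(goai_traduction("Bonjour, comment allez-vous ?", src_lang="fra_Latn", tgt_lang="mos_Latn"))
```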