ArissBandoss committed
Commit c220549 · verified · 1 Parent(s): 4c61768

Update goai_helpers/goai_traduction.py

Files changed (1)
  1. goai_helpers/goai_traduction.py +9 -25
goai_helpers/goai_traduction.py CHANGED
@@ -50,32 +50,16 @@ def goai_traduction(text, src_lang, tgt_lang):
     tgt_lang_id = tokenizer.convert_tokens_to_ids(tgt_lang)
     print(f"Target-language token ID ({tgt_lang}): {tgt_lang_id}")
     bad_words_ids = [[tokenizer.eos_token_id]]
 
-    for length_penalty in [1.0, 1.5, 2.0, 2.5, 3]:
-        for num_beams in [5, 10]:
-            print(f"\nTest with length_penalty={length_penalty}, num_beams={num_beams}")
-
-            outputs = model.generate(
-                **inputs,
-                forced_bos_token_id=tgt_lang_id,
-                max_new_tokens=2048,
-                early_stopping=False,
-                num_beams=num_beams,
-                no_repeat_ngram_size=0,
-                bad_words_ids=bad_words_ids,
-                length_penalty=length_penalty
-            )
-
-            translation = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
-            print(f"Translation ({len(translation)} characters / {len(translation.split())} words):")
-            print(translation)
-            output_ids = outputs[0]
-            print(f"- Number of generated tokens: {output_ids.shape[0]}")
-            print(f"- First generated tokens: {output_ids[:10].tolist()}")
-            print(f"- Last generated tokens: {output_ids[-10:].tolist()}")
+    outputs = model.generate(
+        **inputs,
+        forced_bos_token_id=tgt_lang_id,
+        max_length=max_length,
+        min_length=max_length,
+        num_beams=5,
+        no_repeat_ngram_size=0,
+        length_penalty=2.0
+    )
 
     return translation
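For context, here is a minimal, self-contained sketch of how the updated generation call behaves end to end. It is an illustration, not part of the commit: the checkpoint name, language codes, example text, and max_length value are assumptions, and the final batch_decode line is restored here so the sketch runs (the hunk itself references an existing max_length variable and returns translation, whose decode step lies outside the shown lines).

from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# Assumed NLLB-style checkpoint; the actual model used by goai_traduction
# is not shown in this diff.
model_name = "facebook/nllb-200-distilled-600M"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

# Assumed example input: French to Mooré, using NLLB language codes.
text = "Bonjour tout le monde"
tokenizer.src_lang = "fra_Latn"
tgt_lang = "mos_Latn"
inputs = tokenizer(text, return_tensors="pt")
max_length = 512  # assumed value; the diff reuses an existing max_length

tgt_lang_id = tokenizer.convert_tokens_to_ids(tgt_lang)
outputs = model.generate(
    **inputs,
    forced_bos_token_id=tgt_lang_id,  # first decoded token is the target-language tag
    max_length=max_length,
    min_length=max_length,  # blocks EOS before max_length, pinning the output length
    num_beams=5,
    no_repeat_ngram_size=0,  # no n-gram repetition blocking
    length_penalty=2.0,  # >1.0 biases beam search toward longer hypotheses
)
translation = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
print(translation)

One design choice worth noting in the new call: with min_length set equal to max_length, beam search cannot end a hypothesis early, so every output runs to exactly max_length decoder tokens, and length_penalty=2.0 further favors longer beams.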