Update goai_helpers/goai_traduction.py
goai_helpers/goai_traduction.py
CHANGED
@@ -50,32 +50,16 @@ def goai_traduction(text, src_lang, tgt_lang):
     tgt_lang_id = tokenizer.convert_tokens_to_ids(tgt_lang)
     print(f"Target language token ID ({tgt_lang}): {tgt_lang_id}")
     bad_words_ids = [[tokenizer.eos_token_id]]
 
-    for length_penalty in [1.0, 1.5, 2.0, 2.5, 3]:
-        for num_beams in [5, 10]:
-            print(f"\nTesting with length_penalty={length_penalty}, num_beams={num_beams}")
-
-            outputs = model.generate(
-                **inputs,
-                forced_bos_token_id=tgt_lang_id,
-                max_length=max_length,
-                num_beams=num_beams,
-                length_penalty=length_penalty
-            )
-
-            translation = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
-
-            print(f"Translation ({len(translation)} characters / {len(translation.split())} words):")
-            print(translation)
-            output_ids = outputs[0]
-            print(f"- Number of generated tokens: {output_ids.shape[0]}")
-            print(f"- First generated tokens: {output_ids[:10].tolist()}")
-            print(f"- Last generated tokens: {output_ids[-10:].tolist()}")
+    outputs = model.generate(
+        **inputs,
+        forced_bos_token_id=tgt_lang_id,
+        max_length=max_length,
+        min_length=max_length,
+        num_beams=5,
+        no_repeat_ngram_size=0,
+        length_penalty=2.0
+    )
+
+    translation = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
 
     return translation
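In effect, the commit replaces an exploratory sweep over length_penalty and num_beams with a single hard-coded beam-search configuration, pinning the output length by setting min_length equal to max_length so the decoder cannot emit EOS early. Below is a minimal, self-contained sketch of the same decoding setup, assuming an NLLB-style checkpoint (facebook/nllb-200-distilled-600M here) and placeholder values for text, max_length, and the language codes; none of these appear in the diff and the actual model loading is not shown.

    import torch
    from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

    # Assumed checkpoint and language codes; the diff does not show how the
    # model, tokenizer, inputs, or max_length are actually created.
    model_id = "facebook/nllb-200-distilled-600M"
    tokenizer = AutoTokenizer.from_pretrained(model_id, src_lang="fra_Latn")
    model = AutoModelForSeq2SeqLM.from_pretrained(model_id)

    text = "Bonjour, comment allez-vous ?"  # hypothetical input
    tgt_lang = "mos_Latn"                   # hypothetical target code (Mooré in NLLB-200)
    max_length = 100                        # assumed; referenced but not defined in the diff

    inputs = tokenizer(text, return_tensors="pt")
    tgt_lang_id = tokenizer.convert_tokens_to_ids(tgt_lang)

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            forced_bos_token_id=tgt_lang_id,  # start decoding in the target language
            max_length=max_length,
            min_length=max_length,            # EOS is suppressed until max_length is reached
            num_beams=5,
            no_repeat_ngram_size=0,           # 0 leaves the repetition constraint disabled
            length_penalty=2.0,               # > 1.0 favors longer beam hypotheses
        )

    print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])

One consequence of min_length=max_length (and of banning eos_token_id via bad_words_ids, which the function still computes but, as shown, no longer passes to generate) is that every translation comes out at exactly max_length tokens: early truncation disappears, but short inputs may be padded out with repeated or spurious text.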