Update goai_helpers/goai_traduction.py
Browse files
goai_helpers/goai_traduction.py
CHANGED
|
@@ -25,6 +25,11 @@ def goai_traduction(text, src_lang, tgt_lang):
|
|
| 25 |
tokenizer = AutoTokenizer.from_pretrained(model_id, token=auth_token, truncation=True, max_length=512)
|
| 26 |
model = AutoModelForSeq2SeqLM.from_pretrained(model_id, token=auth_token).to(device)
|
| 27 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
print("decoder_start_token_id:", model.config.decoder_start_token_id)
|
| 29 |
print("forced_bos_token_id:", model.config.forced_bos_token_id)
|
| 30 |
print("eos_token_id:", model.config.eos_token_id)
|
|
@@ -45,10 +50,7 @@ def goai_traduction(text, src_lang, tgt_lang):
|
|
| 45 |
**inputs,
|
| 46 |
forced_bos_token_id=tgt_lang_id,
|
| 47 |
eos_token_id=tokenizer.eos_token_id,
|
| 48 |
-
max_length=512
|
| 49 |
-
num_beams=4,
|
| 50 |
-
do_sample=False,
|
| 51 |
-
no_repeat_ngram_size=3
|
| 52 |
)
|
| 53 |
|
| 54 |
print("Token IDs:", outputs)
|
|
|
|
| 25 |
tokenizer = AutoTokenizer.from_pretrained(model_id, token=auth_token, truncation=True, max_length=512)
|
| 26 |
model = AutoModelForSeq2SeqLM.from_pretrained(model_id, token=auth_token).to(device)
|
| 27 |
|
| 28 |
+
tgt_lang_id = tokenizer.convert_tokens_to_ids(tgt_lang)
|
| 29 |
+
|
| 30 |
+
model.config.forced_bos_token_id = tgt_lang_id
|
| 31 |
+
model.config.decoder_start_token_id = tgt_lang_id
|
| 32 |
+
|
| 33 |
print("decoder_start_token_id:", model.config.decoder_start_token_id)
|
| 34 |
print("forced_bos_token_id:", model.config.forced_bos_token_id)
|
| 35 |
print("eos_token_id:", model.config.eos_token_id)
|
|
|
|
| 50 |
**inputs,
|
| 51 |
forced_bos_token_id=tgt_lang_id,
|
| 52 |
eos_token_id=tokenizer.eos_token_id,
|
| 53 |
+
max_length=512
|
|
|
|
|
|
|
|
|
|
| 54 |
)
|
| 55 |
|
| 56 |
print("Token IDs:", outputs)
|