Update goai_helpers/goai_traduction.py
Browse files
goai_helpers/goai_traduction.py
CHANGED
@@ -34,12 +34,12 @@ def goai_traduction(text, src_lang, tgt_lang):
|
|
34 |
# Instead of using the pipeline, do direct generation
|
35 |
tokenizer = AutoTokenizer.from_pretrained(peft_config.base_model_name_or_path)
|
36 |
def translate(text, src_lang, tgt_lang, max_length=512):
|
37 |
-
inputs = tokenizer(text, return_tensors="pt", max_length=max_length
|
38 |
inputs = {k: v.to(device) for k, v in inputs.items()}
|
39 |
|
40 |
generation_kwargs = {}
|
41 |
if src_lang and tgt_lang:
|
42 |
-
generation_kwargs["forced_bos_token_id"] = tokenizer.
|
43 |
|
44 |
outputs = model.generate(**inputs, max_length=max_length, **generation_kwargs)
|
45 |
return tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
|
|
|
34 |
# Instead of using the pipeline, do direct generation
|
35 |
tokenizer = AutoTokenizer.from_pretrained(peft_config.base_model_name_or_path)
|
36 |
def translate(text, src_lang, tgt_lang, max_length=512):
    """Translate ``text`` by calling ``model.generate`` directly.

    Relies on ``tokenizer``, ``model`` and ``device`` from the enclosing
    scope.

    Args:
        text: Input handed to the tokenizer.
        src_lang: Source language code. Here it is only used, together with
            ``tgt_lang``, to decide whether the target-language token is
            forced.
            NOTE(review): NLLB/mBART-style tokenizers usually also expect
            ``tokenizer.src_lang`` to be set -- confirm whether that happens
            elsewhere.
        tgt_lang: Target language token. When both languages are given, its
            vocabulary id is forced as the first generated token.
        max_length: Upper bound applied to both the tokenized input and the
            generated sequence.

    Returns:
        The first decoded sequence, with special tokens stripped.
    """
    # Explicit truncation: passing max_length alone makes the tokenizer fall
    # back to truncating anyway, but with a warning on every call.
    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        max_length=max_length,
    )
    inputs = {k: v.to(device) for k, v in inputs.items()}

    generation_kwargs = {}
    if src_lang and tgt_lang:
        # convert_tokens_to_ids is a method; subscripting it raises
        # "TypeError: 'method' object is not subscriptable".
        generation_kwargs["forced_bos_token_id"] = tokenizer.convert_tokens_to_ids(tgt_lang)

    outputs = model.generate(**inputs, max_length=max_length, **generation_kwargs)
    return tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
|