ArissBandoss commited on
Commit
b8808c9
·
verified ·
1 Parent(s): 3e462e2

Update goai_helpers/goai_traduction.py

Browse files
Files changed (1) hide show
  1. goai_helpers/goai_traduction.py +2 -2
goai_helpers/goai_traduction.py CHANGED
@@ -34,12 +34,12 @@ def goai_traduction(text, src_lang, tgt_lang):
34
  # Instead of using the pipeline, do direct generation
35
  tokenizer = AutoTokenizer.from_pretrained(peft_config.base_model_name_or_path)
36
  def translate(text, src_lang, tgt_lang, max_length=512):
37
- inputs = tokenizer(text, return_tensors="pt", max_length=max_length, truncation=True)
38
  inputs = {k: v.to(device) for k, v in inputs.items()}
39
 
40
  generation_kwargs = {}
41
  if src_lang and tgt_lang:
42
- generation_kwargs["forced_bos_token_id"] = tokenizer.lang_code_to_id[tgt_lang]
43
 
44
  outputs = model.generate(**inputs, max_length=max_length, **generation_kwargs)
45
  return tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
 
34
  # Instead of using the pipeline, do direct generation
35
  tokenizer = AutoTokenizer.from_pretrained(peft_config.base_model_name_or_path)
36
  def translate(text, src_lang, tgt_lang, max_length=512):
37
+ inputs = tokenizer(text, return_tensors="pt", max_length=max_length)
38
  inputs = {k: v.to(device) for k, v in inputs.items()}
39
 
40
  generation_kwargs = {}
41
  if src_lang and tgt_lang:
42
+ generation_kwargs["forced_bos_token_id"] = tokenizer.convert_tokens_to_ids[tgt_lang]
43
 
44
  outputs = model.generate(**inputs, max_length=max_length, **generation_kwargs)
45
  return tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]