ArissBandoss committed on
Commit
beb2b9a
·
verified ·
1 Parent(s): c0087f6

Update goai_helpers/goai_traduction.py

Browse files
Files changed (1) hide show
  1. goai_helpers/goai_traduction.py +6 -4
goai_helpers/goai_traduction.py CHANGED
@@ -25,6 +25,11 @@ def goai_traduction(text, src_lang, tgt_lang):
25
  tokenizer = AutoTokenizer.from_pretrained(model_id, token=auth_token, truncation=True, max_length=512)
26
  model = AutoModelForSeq2SeqLM.from_pretrained(model_id, token=auth_token).to(device)
27
 
 
 
 
 
 
28
  print("decoder_start_token_id:", model.config.decoder_start_token_id)
29
  print("forced_bos_token_id:", model.config.forced_bos_token_id)
30
  print("eos_token_id:", model.config.eos_token_id)
@@ -45,10 +50,7 @@ def goai_traduction(text, src_lang, tgt_lang):
45
  **inputs,
46
  forced_bos_token_id=tgt_lang_id,
47
  eos_token_id=tokenizer.eos_token_id,
48
- max_length=512,
49
- num_beams=4,
50
- do_sample=False,
51
- no_repeat_ngram_size=3
52
  )
53
 
54
  print("Token IDs:", outputs)
 
25
  tokenizer = AutoTokenizer.from_pretrained(model_id, token=auth_token, truncation=True, max_length=512)
26
  model = AutoModelForSeq2SeqLM.from_pretrained(model_id, token=auth_token).to(device)
27
 
28
+ tgt_lang_id = tokenizer.convert_tokens_to_ids(tgt_lang)
29
+
30
+ model.config.forced_bos_token_id = tgt_lang_id
31
+ model.config.decoder_start_token_id = tgt_lang_id
32
+
33
  print("decoder_start_token_id:", model.config.decoder_start_token_id)
34
  print("forced_bos_token_id:", model.config.forced_bos_token_id)
35
  print("eos_token_id:", model.config.eos_token_id)
 
50
  **inputs,
51
  forced_bos_token_id=tgt_lang_id,
52
  eos_token_id=tokenizer.eos_token_id,
53
+ max_length=512
 
 
 
54
  )
55
 
56
  print("Token IDs:", outputs)