ArissBandoss committed on
Commit
49dc84f
·
verified ·
1 Parent(s): beb2b9a

Update goai_helpers/goai_traduction.py

Browse files
Files changed (1) hide show
  1. goai_helpers/goai_traduction.py +11 -11
goai_helpers/goai_traduction.py CHANGED
@@ -1,6 +1,6 @@
1
  import torch
2
  import spaces
3
- from transformers import pipeline, AutoModelForSeq2SeqLM, AutoTokenizer
4
  from peft import PeftModel, PeftConfig
5
  import os
6
  import unicodedata
@@ -27,12 +27,13 @@ def goai_traduction(text, src_lang, tgt_lang):
27
 
28
  tgt_lang_id = tokenizer.convert_tokens_to_ids(tgt_lang)
29
 
30
- model.config.forced_bos_token_id = tgt_lang_id
31
- model.config.decoder_start_token_id = tgt_lang_id
32
-
33
- print("decoder_start_token_id:", model.config.decoder_start_token_id)
34
- print("forced_bos_token_id:", model.config.forced_bos_token_id)
35
- print("eos_token_id:", model.config.eos_token_id)
 
36
 
37
 
38
  # Ajout du code de langue source
@@ -47,12 +48,11 @@ def goai_traduction(text, src_lang, tgt_lang):
47
 
48
  # Génération avec paramètres améliorés
49
  outputs = model.generate(
50
- **inputs,
51
- forced_bos_token_id=tgt_lang_id,
52
- eos_token_id=tokenizer.eos_token_id,
53
- max_length=512
54
  )
55
 
 
56
  print("Token IDs:", outputs)
57
  print("Tokens:", [tokenizer.decode([tok]) for tok in outputs[0]])
58
 
 
1
  import torch
2
  import spaces
3
+ from transformers import pipeline, AutoModelForSeq2SeqLM, AutoTokenizer, GenerationConfig
4
  from peft import PeftModel, PeftConfig
5
  import os
6
  import unicodedata
 
27
 
28
  tgt_lang_id = tokenizer.convert_tokens_to_ids(tgt_lang)
29
 
30
+ generation_config = GenerationConfig(
31
+ max_new_tokens=1024,
32
+ early_stopping=False,
33
+ decoder_start_token_id=tokenizer.convert_tokens_to_ids(tgt_lang),
34
+ forced_bos_token_id=tokenizer.convert_tokens_to_ids(tgt_lang),
35
+ eos_token_id=tokenizer.eos_token_id,
36
+ )
37
 
38
 
39
  # Ajout du code de langue source
 
48
 
49
  # Génération avec paramètres améliorés
50
  outputs = model.generate(
51
+ **inputs,
52
+ generation_config=generation_config
 
 
53
  )
54
 
55
+
56
  print("Token IDs:", outputs)
57
  print("Tokens:", [tokenizer.decode([tok]) for tok in outputs[0]])
58