ArissBandoss commited on
Commit
dfb286c
·
verified ·
1 Parent(s): 090b150

Update goai_helpers/goai_traduction.py

Browse files
Files changed (1) hide show
  1. goai_helpers/goai_traduction.py +12 -32
goai_helpers/goai_traduction.py CHANGED
@@ -21,50 +21,30 @@ def goai_traduction(text, src_lang, tgt_lang):
21
  model_id = "ArissBandoss/mos2fr-3B-1200"
22
  else:
23
  model_id = "ArissBandoss/nllb-200-distilled-600M-finetuned-fr-to-mos-V4"
24
-
25
- tokenizer = AutoTokenizer.from_pretrained(model_id, token=auth_token, truncation=True, max_length=512)
26
- model = AutoModelForSeq2SeqLM.from_pretrained(model_id, token=auth_token).to(device)
27
- print(model.lm_head.weight.shape) # doit être [vocab_size, hidden_size]
28
- print(model.model.shared.weight.shape) # idem
29
-
30
- tgt_lang_id = tokenizer.convert_tokens_to_ids(tgt_lang)
31
 
32
- generation_config = GenerationConfig(
33
- max_new_tokens=1024,
34
- early_stopping=False,
35
- decoder_start_token_id=tokenizer.convert_tokens_to_ids(tgt_lang),
36
- forced_bos_token_id=tokenizer.convert_tokens_to_ids(tgt_lang),
37
- eos_token_id=tokenizer.eos_token_id,
38
- )
39
 
40
-
41
- # Ajout du code de langue source
42
  tokenizer.src_lang = src_lang
43
-
44
- # Tokenisation du texte d'entrée
45
  inputs = tokenizer(text, return_tensors="pt").to(device)
46
- print(inputs)
47
-
48
- # Utilisation de convert_tokens_to_ids au lieu de lang_code_to_id
49
  tgt_lang_id = tokenizer.convert_tokens_to_ids(tgt_lang)
50
-
51
- # Génération avec paramètres améliorés
52
  outputs = model.generate(
53
- **inputs,
54
- generation_config=generation_config
 
 
 
 
55
  )
56
 
57
-
58
- print("Token IDs:", outputs)
59
- print("Tokens:", [tokenizer.decode([tok]) for tok in outputs[0]])
60
-
61
-
62
- # Décodage de la sortie
63
  translation = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
64
  print("ici translation")
65
  print(translation)
66
  return translation
67
 
68
-
69
  def real_time_traduction(input_text, src_lang, tgt_lang):
70
  return goai_traduction(input_text, src_lang, tgt_lang)
 
21
  model_id = "ArissBandoss/mos2fr-3B-1200"
22
  else:
23
  model_id = "ArissBandoss/nllb-200-distilled-600M-finetuned-fr-to-mos-V4"
 
 
 
 
 
 
 
24
 
25
+ tokenizer = AutoTokenizer.from_pretrained(model_id, token=auth_token)
26
+ model = AutoModelForSeq2SeqLM.from_pretrained(model_id, token=auth_token).to(device)
 
 
 
 
 
27
 
 
 
28
  tokenizer.src_lang = src_lang
 
 
29
  inputs = tokenizer(text, return_tensors="pt").to(device)
30
+
31
+ # Ajout du code de langue cible
 
32
  tgt_lang_id = tokenizer.convert_tokens_to_ids(tgt_lang)
33
+
34
+ # Génération contrôlée
35
  outputs = model.generate(
36
+ **inputs,
37
+ forced_bos_token_id=tgt_lang_id,
38
+ eos_token_id=tokenizer.eos_token_id, # S’assurer que le modèle peut s’arrêter
39
+ max_length=512, # Teste avec 256 puis augmente progressivement
40
+ do_sample=False,
41
+ early_stopping=True
42
  )
43
 
 
 
 
 
 
 
44
  translation = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
45
  print("ici translation")
46
  print(translation)
47
  return translation
48
 
 
49
  def real_time_traduction(input_text, src_lang, tgt_lang):
50
  return goai_traduction(input_text, src_lang, tgt_lang)