ArissBandoss committed on
Commit
261a5aa
·
verified ·
1 Parent(s): 33f1bab

Update goai_helpers/goai_traduction.py

Browse files
Files changed (1) hide show
  1. goai_helpers/goai_traduction.py +27 -18
goai_helpers/goai_traduction.py CHANGED
@@ -13,30 +13,39 @@ login(token=auth_token)
13
@spaces.GPU
def goai_traduction(text, src_lang, tgt_lang):
    """Translate `text` from `src_lang` to `tgt_lang` with a fine-tuned NLLB model.

    Args:
        text: Input text to translate.
        src_lang: NLLB source language code (e.g. "fra_Latn", "mos_Latn").
        tgt_lang: NLLB target language code.

    Returns:
        The translated text (str), taken from the pipeline's first result.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Pick the checkpoint for the requested direction; fr->mos is also the
    # fallback for any unrecognised language pair.
    if src_lang == "fra_Latn" and tgt_lang == "mos_Latn":
        model_id = "ArissBandoss/nllb-200-distilled-600M-finetuned-fr-to-mos-V4"
    elif src_lang == "mos_Latn" and tgt_lang == "fra_Latn":
        model_id = "ArissBandoss/mos2fr-5B-800"
    else:
        model_id = "ArissBandoss/nllb-200-distilled-600M-finetuned-fr-to-mos-V4"

    tokenizer = AutoTokenizer.from_pretrained(model_id, token=auth_token)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_id, token=auth_token)

    # NOTE(review): the original also passed `max_length=max_length`, but
    # `max_length` is not defined in this scope (NameError at call time unless
    # it exists at module level — could not see the file top to confirm), and
    # transformers ignores `max_length` when `max_new_tokens` is set anyway,
    # so the kwarg is dropped here.
    trans_pipe = pipeline(
        "translation",
        model=model,
        tokenizer=tokenizer,
        src_lang=src_lang,
        tgt_lang=tgt_lang,
        max_new_tokens=512,
        device=device,
    )

    return trans_pipe(text)[0]["translation_text"]
 
 
 
 
 
 
 
 
 
 
 
40
 
41
 
42
  def real_time_traduction(input_text, src_lang, tgt_lang):
 
13
# Module-level cache so repeated calls reuse the already-loaded tokenizer and
# model instead of re-downloading / re-instantiating them on every request.
_MODEL_CACHE = {}


@spaces.GPU
def goai_traduction(text, src_lang, tgt_lang):
    """Translate `text` from `src_lang` to `tgt_lang` with a fine-tuned NLLB model.

    Args:
        text: Input text to translate.
        src_lang: NLLB source language code (e.g. "fra_Latn", "mos_Latn").
        tgt_lang: NLLB target language code, used as the forced BOS token.

    Returns:
        The translated text (str).
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Pick the checkpoint for the requested direction. The fr->mos checkpoint
    # is also the fallback for any other (unsupported) language pair.
    if src_lang == "mos_Latn" and tgt_lang == "fra_Latn":
        model_id = "ArissBandoss/mos2fr-3B"
    else:
        model_id = "ArissBandoss/nllb-200-distilled-600M-finetuned-fr-to-mos-V4"

    # Load once per checkpoint; subsequent calls hit the cache.
    if model_id not in _MODEL_CACHE:
        tok = AutoTokenizer.from_pretrained(model_id, token=auth_token)
        mdl = AutoModelForSeq2SeqLM.from_pretrained(model_id, token=auth_token)
        _MODEL_CACHE[model_id] = (tok, mdl)
    tokenizer, model = _MODEL_CACHE[model_id]
    model = model.to(device)

    # Tell the tokenizer which source language to prefix the input with.
    tokenizer.src_lang = src_lang

    inputs = tokenizer(text, return_tensors="pt").to(device)

    # convert_tokens_to_ids instead of the removed lang_code_to_id mapping.
    # NOTE(review): this silently yields the unk-token id if `tgt_lang` is not
    # in the vocabulary — verify callers only pass valid NLLB codes.
    tgt_lang_id = tokenizer.convert_tokens_to_ids(tgt_lang)

    # Inference only: no_grad avoids building the autograd graph (same output,
    # far less memory).
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            forced_bos_token_id=tgt_lang_id,  # force decoding into tgt_lang
            max_new_tokens=1024,
            num_beams=5,
            early_stopping=True,
        )

    return tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
49
 
50
 
51
  def real_time_traduction(input_text, src_lang, tgt_lang):