ArissBandoss commited on
Commit
c486284
·
verified ·
1 Parent(s): b8808c9

Update goai_helpers/goai_traduction.py

Browse files
Files changed (1) hide show
  1. goai_helpers/goai_traduction.py +0 -20
goai_helpers/goai_traduction.py CHANGED
@@ -26,26 +26,6 @@ def goai_traduction(text, src_lang, tgt_lang):
26
  tokenizer = AutoTokenizer.from_pretrained(model_id, token=auth_token)
27
  model = AutoModelForSeq2SeqLM.from_pretrained(model_id, token=auth_token)
28
 
29
- if model_id == "ArissBandoss/mos2fr-5B-800":
30
- peft_config = PeftConfig.from_pretrained("ArissBandoss/mos2fr-5B-800")
31
- base_model = AutoModelForSeq2SeqLM.from_pretrained(peft_config.base_model_name_or_path)
32
- model = PeftModel.from_pretrained(base_model, "ArissBandoss/mos2fr-5B-800")
33
-
34
- # Instead of using the pipeline, do direct generation
35
- tokenizer = AutoTokenizer.from_pretrained(peft_config.base_model_name_or_path)
36
- def translate(text, src_lang, tgt_lang, max_length=512):
37
- inputs = tokenizer(text, return_tensors="pt", max_length=max_length)
38
- inputs = {k: v.to(device) for k, v in inputs.items()}
39
-
40
- generation_kwargs = {}
41
- if src_lang and tgt_lang:
42
- generation_kwargs["forced_bos_token_id"] = tokenizer.convert_tokens_to_ids[tgt_lang]
43
-
44
- outputs = model.generate(**inputs, max_length=max_length, **generation_kwargs)
45
- return tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
46
- translation_text = translate(text, src_lang, tgt_lang)
47
- return translation_text
48
-
49
 
50
 
51
  trans_pipe = pipeline("translation",
 
26
  tokenizer = AutoTokenizer.from_pretrained(model_id, token=auth_token)
27
  model = AutoModelForSeq2SeqLM.from_pretrained(model_id, token=auth_token)
28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
 
30
 
31
  trans_pipe = pipeline("translation",