ArissBandoss commited on
Commit
af19dcb
·
verified ·
1 Parent(s): 91812f2

Update goai_helpers/goai_traduction.py

Browse files
Files changed (1) hide show
  1. goai_helpers/goai_traduction.py +18 -14
goai_helpers/goai_traduction.py CHANGED
@@ -26,21 +26,25 @@ def goai_traduction(text, src_lang, tgt_lang):
26
  tokenizer = AutoTokenizer.from_pretrained(model_id, token=auth_token)
27
  model = AutoModelForSeq2SeqLM.from_pretrained(model_id, token=auth_token)
28
 
29
- if model_id == "ArissBandoss/mos2fr-3B":
30
- try:
31
- peft_config = PeftConfig.from_pretrained("ArissBandoss/mos2fr-3B")
32
- base_model = AutoModelForSeq2SeqLM.from_pretrained(peft_config.base_model_name_or_path)
33
- model = PeftModel.from_pretrained(base_model, "ArissBandoss/mos2fr-3B")
34
- except ValueError:
35
- from huggingface_hub import hf_hub_download
36
- adapter_config_path = hf_hub_download(
37
- repo_id="ArissBarndoss/mos2fr-3B",
38
- filename="adapter_config.json"
39
- )
40
 
41
- peft_config = PeftConfig.from_json_file(adapter_config_path)
42
- base_model = AutoModelForSeq2SeqLM.from_pretrained(peft_config.base_model_name_or_path)
43
- model = PeftModel.from_pretrained(base_model, "ArissBandoss/mos2fr-3B")
 
 
 
 
 
 
 
 
 
 
 
44
 
45
 
46
 
 
26
  tokenizer = AutoTokenizer.from_pretrained(model_id, token=auth_token)
27
  model = AutoModelForSeq2SeqLM.from_pretrained(model_id, token=auth_token)
28
 
29
+ if model_id == "ArissBarndoss/mos2fr-3B":
30
+ peft_config = PeftConfig.from_pretrained("ArissBandoss/mos2fr-3B")
31
+ base_model = AutoModelForSeq2SeqLM.from_pretrained(peft_config.base_model_name_or_path)
32
+ model = PeftModel.from_pretrained(base_model, "ArissBandoss/mos2fr-3B")
 
 
 
 
 
 
 
33
 
34
+ # Instead of using the pipeline, do direct generation
35
+ tokenizer = AutoTokenizer.from_pretrained(peft_config.base_model_name_or_path)
36
+ def translate(text, src_lang, tgt_lang, max_length=512):
37
+ inputs = tokenizer(text, return_tensors="pt", max_length=max_length, truncation=True)
38
+ inputs = {k: v.to(device) for k, v in inputs.items()}
39
+
40
+ generation_kwargs = {}
41
+ if src_lang and tgt_lang:
42
+ generation_kwargs["forced_bos_token_id"] = tokenizer.lang_code_to_id[tgt_lang]
43
+
44
+ outputs = model.generate(**inputs, max_length=max_length, **generation_kwargs)
45
+ return tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
46
+ translation_text = translate(text, src_lang, tgt_lang)
47
+ return translation_text
48
 
49
 
50