Update goai_helpers/goai_traduction.py
Browse files- goai_helpers/goai_traduction.py +18 -14
goai_helpers/goai_traduction.py
CHANGED
@@ -26,21 +26,25 @@ def goai_traduction(text, src_lang, tgt_lang):
|
|
26 |
tokenizer = AutoTokenizer.from_pretrained(model_id, token=auth_token)
|
27 |
model = AutoModelForSeq2SeqLM.from_pretrained(model_id, token=auth_token)
|
28 |
|
29 |
-
if model_id == "
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
model = PeftModel.from_pretrained(base_model, "ArissBandoss/mos2fr-3B")
|
34 |
-
except ValueError:
|
35 |
-
from huggingface_hub import hf_hub_download
|
36 |
-
adapter_config_path = hf_hub_download(
|
37 |
-
repo_id="ArissBarndoss/mos2fr-3B",
|
38 |
-
filename="adapter_config.json"
|
39 |
-
)
|
40 |
|
41 |
-
|
42 |
-
|
43 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
44 |
|
45 |
|
46 |
|
|
|
26 |
tokenizer = AutoTokenizer.from_pretrained(model_id, token=auth_token)
|
27 |
model = AutoModelForSeq2SeqLM.from_pretrained(model_id, token=auth_token)
|
28 |
|
29 |
+
if model_id == "ArissBarndoss/mos2fr-3B":
|
30 |
+
peft_config = PeftConfig.from_pretrained("ArissBandoss/mos2fr-3B")
|
31 |
+
base_model = AutoModelForSeq2SeqLM.from_pretrained(peft_config.base_model_name_or_path)
|
32 |
+
model = PeftModel.from_pretrained(base_model, "ArissBandoss/mos2fr-3B")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
33 |
|
34 |
+
# Instead of using the pipeline, do direct generation
|
35 |
+
tokenizer = AutoTokenizer.from_pretrained(peft_config.base_model_name_or_path)
|
36 |
+
def translate(text, src_lang, tgt_lang, max_length=512):
|
37 |
+
inputs = tokenizer(text, return_tensors="pt", max_length=max_length, truncation=True)
|
38 |
+
inputs = {k: v.to(device) for k, v in inputs.items()}
|
39 |
+
|
40 |
+
generation_kwargs = {}
|
41 |
+
if src_lang and tgt_lang:
|
42 |
+
generation_kwargs["forced_bos_token_id"] = tokenizer.lang_code_to_id[tgt_lang]
|
43 |
+
|
44 |
+
outputs = model.generate(**inputs, max_length=max_length, **generation_kwargs)
|
45 |
+
return tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
|
46 |
+
translation_text = translate(text, src_lang, tgt_lang)
|
47 |
+
return translation_text
|
48 |
|
49 |
|
50 |
|