Spaces:
Sleeping
Sleeping
Commit
·
8115742
1
Parent(s):
bfe203c
- inference.py +34 -19
inference.py
CHANGED
@@ -1,8 +1,13 @@
|
|
1 |
import torch
|
2 |
-
from transformers import
|
|
|
|
|
|
|
|
|
|
|
3 |
from peft import PeftModel
|
4 |
|
5 |
-
#
|
6 |
MODEL_OPTIONS = {
|
7 |
"mT5": {
|
8 |
"base": "google/mt5-base",
|
@@ -12,12 +17,13 @@ MODEL_OPTIONS = {
|
|
12 |
"base": "facebook/mbart-large-50-many-to-many-mmt",
|
13 |
"adapter": "asmashayea/mbart-absa"
|
14 |
},
|
|
|
15 |
"GPT3.5": {
|
16 |
-
"base": "bigscience/bloom-560m", #
|
17 |
"adapter": "asmashayea/gpt-absa"
|
18 |
},
|
19 |
"GPT4o": {
|
20 |
-
"base": "bigscience/bloom-560m", #
|
21 |
"adapter": "asmashayea/gpt-absa"
|
22 |
}
|
23 |
}
|
@@ -31,8 +37,14 @@ def load_model(model_key):
|
|
31 |
base_id = MODEL_OPTIONS[model_key]["base"]
|
32 |
adapter_id = MODEL_OPTIONS[model_key]["adapter"]
|
33 |
|
34 |
-
|
35 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
36 |
model = PeftModel.from_pretrained(base_model, adapter_id)
|
37 |
model.eval()
|
38 |
|
@@ -41,22 +53,25 @@ def load_model(model_key):
|
|
41 |
|
42 |
def predict_absa(text, model_choice):
|
43 |
tokenizer, model = load_model(model_choice)
|
44 |
-
prompt = f"استخرج الجوانب والآراء والمشاعر من النص التالي:\n{text}"
|
45 |
|
46 |
if model_choice == "mBART":
|
47 |
-
#
|
48 |
-
tokenizer
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
|
|
|
|
|
|
56 |
else:
|
57 |
inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
|
58 |
-
|
|
|
59 |
|
60 |
-
decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
61 |
return decoded
|
62 |
-
|
|
|
1 |
import torch
|
2 |
+
from transformers import (
|
3 |
+
AutoTokenizer,
|
4 |
+
AutoModelForSeq2SeqLM,
|
5 |
+
MBartForConditionalGeneration,
|
6 |
+
MBart50TokenizerFast
|
7 |
+
)
|
8 |
from peft import PeftModel
|
9 |
|
10 |
+
# Supported models and their adapter IDs
|
11 |
MODEL_OPTIONS = {
|
12 |
"mT5": {
|
13 |
"base": "google/mt5-base",
|
|
|
17 |
"base": "facebook/mbart-large-50-many-to-many-mmt",
|
18 |
"adapter": "asmashayea/mbart-absa"
|
19 |
},
|
20 |
+
# You can customize GPT-like entries later
|
21 |
"GPT3.5": {
|
22 |
+
"base": "bigscience/bloom-560m", # Placeholder only
|
23 |
"adapter": "asmashayea/gpt-absa"
|
24 |
},
|
25 |
"GPT4o": {
|
26 |
+
"base": "bigscience/bloom-560m", # Placeholder only
|
27 |
"adapter": "asmashayea/gpt-absa"
|
28 |
}
|
29 |
}
|
|
|
37 |
base_id = MODEL_OPTIONS[model_key]["base"]
|
38 |
adapter_id = MODEL_OPTIONS[model_key]["adapter"]
|
39 |
|
40 |
+
if model_key == "mBART":
|
41 |
+
tokenizer = MBart50TokenizerFast.from_pretrained(base_id)
|
42 |
+
tokenizer.src_lang = "ar_AR" # Required for input
|
43 |
+
base_model = MBartForConditionalGeneration.from_pretrained(base_id)
|
44 |
+
else:
|
45 |
+
tokenizer = AutoTokenizer.from_pretrained(adapter_id)
|
46 |
+
base_model = AutoModelForSeq2SeqLM.from_pretrained(base_id)
|
47 |
+
|
48 |
model = PeftModel.from_pretrained(base_model, adapter_id)
|
49 |
model.eval()
|
50 |
|
|
|
53 |
|
54 |
def predict_absa(text, model_choice):
    """Run aspect-based sentiment analysis (ABSA) on Arabic text.

    Wraps *text* in an Arabic extraction prompt, tokenizes it with the
    tokenizer returned for *model_choice*, generates greedily, and decodes
    the result.

    Args:
        text: Input text (whitespace is stripped before prompting).
        model_choice: Key into MODEL_OPTIONS, e.g. "mBART" or "mT5".
            "mBART" gets the MBart-specific path (ar_AR target language,
            fixed-length padding, forced Arabic BOS token).

    Returns:
        str: the decoded model output with surrounding whitespace stripped.
    """
    tokenizer, model = load_model(model_choice)
    prompt = f"استخرج الجوانب والآراء والمشاعر من النص التالي:\n{text.strip()}"

    if model_choice == "mBART":
        tokenizer.tgt_lang = "ar_AR"  # Required for output
        inputs = tokenizer(
            prompt,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=512,
        )

        with torch.no_grad():
            # FIX: removed `temperature=0.0` — transformers ignores (and warns
            # about) `temperature` when do_sample=False, and 0.0 would be an
            # invalid sampling temperature anyway. Decoding stays greedy.
            outputs = model.generate(
                **inputs,
                max_new_tokens=128,
                do_sample=False,
                # Force the decoder to start with the Arabic language token.
                forced_bos_token_id=tokenizer.lang_code_to_id["ar_AR"],
                pad_token_id=tokenizer.pad_token_id,
            )
    else:
        inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
        with torch.no_grad():
            outputs = model.generate(**inputs, max_new_tokens=128)

    decoded = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
    return decoded
|