Commit bfc4d6a
Parent(s): 5825fbe
- app.py +1 -1
- inference.py +5 -10
- seq2seq_inference.py +9 -38
app.py
CHANGED
@@ -15,7 +15,7 @@ demo = gr.Interface(
     ],
     outputs=gr.Textbox(label="Extracted Aspect-Sentiment-Opinion Triplets"),
     title="Arabic ABSA (Aspect-Based Sentiment Analysis)",
-    description="Choose a model (Araberta, mT5,
+    description="Choose a model (Araberta, mT5, GPT) to extract aspects, opinions, and sentiment using LoRA adapters"
 )
 
 if __name__ == "__main__":
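For context, a minimal sketch of the gr.Interface call this hunk edits. Only the lines shown in the diff are confirmed; the fn argument is inferred from predict_absa in inference.py, and the inputs widgets and dropdown choices are assumptions:

# Hypothetical reconstruction of app.py around the changed hunk.
# Only outputs/title/description are confirmed by the diff.
import gradio as gr
from inference import predict_absa

demo = gr.Interface(
    fn=predict_absa,
    inputs=[
        gr.Textbox(label="Arabic review"),                           # assumed input widget
        gr.Dropdown(["Araberta", "mT5", "GPT3.5"], label="Model"),   # assumed choices
    ],
    outputs=gr.Textbox(label="Extracted Aspect-Sentiment-Opinion Triplets"),
    title="Arabic ABSA (Aspect-Based Sentiment Analysis)",
    description="Choose a model (Araberta, mT5, GPT) to extract aspects, opinions, and sentiment using LoRA adapters"
)

if __name__ == "__main__":
    demo.launch()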
inference.py
CHANGED
@@ -3,7 +3,7 @@ import json
 from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModel, AutoConfig
 from peft import LoraConfig, get_peft_model
 from modeling_bilstm_crf import BERT_BiLSTM_CRF
-from seq2seq_inference import infer_t5_prompt, infer_mBart_prompt
+from seq2seq_inference import infer_t5_prompt
 from peft import LoraConfig, get_peft_model, PeftModel
 from modeling_bilstm_crf import BERT_BiLSTM_CRF
 
@@ -18,10 +18,6 @@ MODEL_OPTIONS = {
         "base": "google/mt5-base",
         "adapter": "asmashayea/mt4-absa"
     },
-    # "mBART": {
-    #     "base": "facebook/mbart-large-50-many-to-many-mmt",
-    #     "adapter": "asmashayea/mbart-absa"
-    # },
     "GPT3.5": {
         "base": "bigscience/bloom-560m",  # example, not ideal for ABSA
         "adapter": "asmashayea/gpt-absa"
@@ -131,17 +127,16 @@ def load_model(model_key):
     cached_models[model_key] = (tokenizer, model)
     return tokenizer, model
 
+
+
+
 def predict_absa(text, model_choice):
 
 
-    if model_choice == 'mT5':
+    if model_choice in ['mT5', 'mBART']:
         tokenizer, model = load_model(model_choice)
         decoded = infer_t5_prompt(text, tokenizer, model)
 
-    elif model_choice == 'mBART':
-        tokenizer, model = load_model(model_choice)
-        decoded = infer_mBart_prompt(text, tokenizer, model)
-
    elif model_choice == 'Araberta':
 
        decoded = infer_araberta(text)
seq2seq_inference.py
CHANGED
@@ -3,7 +3,6 @@ import torch
 from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
 from peft import PeftModel
 
-# Updated Prediction Function for mBART
 SYSTEM_PROMPT = (
     "You are an advanced AI model specialized in extracting aspects and determining their sentiment polarity from customer reviews.\n\n"
     "Instructions:\n"
@@ -19,38 +18,14 @@ SYSTEM_PROMPT = (
 )
 
 
-def infer_mBart_prompt(review_text, tokenizer, model):
-    # Set target language for mBART
-    tokenizer.tgt_lang = "ar_AR"  # Change as needed ("en_XX" for English)
-
-    prompt = SYSTEM_PROMPT + f"\n\nReview: {review_text}"
-
-    inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True).to(model.device)
-
-    with torch.no_grad():
-        outputs = model.generate(
-            **inputs,
-            max_new_tokens=128,
-            do_sample=False,
-            temperature=0.0,
-            forced_bos_token_id=tokenizer.convert_tokens_to_ids(tokenizer.tgt_lang),  # safer
-            pad_token_id=tokenizer.pad_token_id
-        )
-
-    decoded = tokenizer.decode(outputs[0], skip_special_tokens=True).replace("ar_AR ", "").replace("en_XX ", "").strip()
-    return decoded
-
-
-
-
-def infer_t5_prompt(review_text, tokenizer, model):
-    prompt = (
-        SYSTEM_PROMPT + f"\n\nReview: {review_text}"
-    )
-    inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True).to(model.device)
-
-    with torch.no_grad():
-        outputs = model.generate(
+
+def infer_t5_prompt(review_text, tokenizer, peft_model):
+    prompt = SYSTEM_PROMPT + f"\n\nReview: {review_text}"
+
+    inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True).to(peft_model.device)
+
+    with torch.no_grad():
+        outputs = peft_model.generate(
             **inputs,
             max_new_tokens=256,
             num_beams=4,
@@ -62,18 +37,14 @@ def infer_t5_prompt(review_text, tokenizer, model):
         )
 
     decoded = tokenizer.decode(
-        outputs[0],
-        skip_special_tokens=True,
+        outputs[0],
+        skip_special_tokens=True,
         clean_up_tokenization_spaces=False
     ).strip()
 
-    # Optional: remove T5 special tokens like <extra_id_0>
     decoded = decoded.replace('<extra_id_0>', '').replace('</s>', '').strip()
 
     try:
-        json_output = json.loads(decoded)
-    except json.JSONDecodeError:
-
-        json_output = decoded
-
-    return json_output
+        return json.loads(decoded)
+    except json.JSONDecodeError:
+        return decoded
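After this commit, infer_t5_prompt returns parsed JSON (a Python list or dict) when the model emits valid JSON and falls back to the raw decoded string otherwise; returning directly from the try/except also removes the old json_output intermediate. Callers therefore need to handle both return types. A small usage sketch under that assumption (the Arabic example review is illustrative only):

# Usage sketch: infer_t5_prompt now returns parsed JSON on success,
# or the raw decoded string when json.loads fails.
tokenizer, model = load_model("mT5")  # loader from inference.py
result = infer_t5_prompt("الطعام لذيذ لكن الخدمة بطيئة", tokenizer, model)

if isinstance(result, str):
    # Model output was not valid JSON; keep the raw text for debugging.
    print("Unparsed output:", result)
else:
    # Expected shape: a list of aspect/sentiment/opinion records.
    for triplet in result:
        print(triplet)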