Commit e90dd4b
Parent: afe677f
Author: araberta

Files changed:
- inference.py (+89 -32)
- model.py → modeling_bilstm_crf.py (+21 -10)

inference.py
CHANGED
@@ -1,35 +1,96 @@
 import torch
-
-from
+import json
+from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModel, AutoConfig
+from peft import LoraConfig, get_peft_model
+from modeling_bilstm_crf import BERT_BiLSTM_CRF
 from seq2seq_inference import infer_t5_prompt, infer_mBart_prompt
-from
-
-# Define supported models and their adapter IDs
-MODEL_OPTIONS = {
-    "Araberta": {
-        "base": "asmashayea/absa-araberta",
-        "adapter": "asmashayea/absa-araberta"
-    },
-    "mT5": {
-        "base": "google/mt5-base",
-        "adapter": "asmashayea/mt4-absa"
-    },
-    "mBART": {
-        "base": "facebook/mbart-large-50-many-to-many-mmt",
-        "adapter": "asmashayea/mbart-absa"
-    },
-    "GPT3.5": {
-        "base": "bigscience/bloom-560m",  # example, not ideal for ABSA
-        "adapter": "asmashayea/gpt-absa"
-    },
-    "GPT4o": {
-        "base": "bigscience/bloom-560m",  # example, not ideal for ABSA
-        "adapter": "asmashayea/gpt-absa"
-    }
-}
+from peft import LoraConfig, get_peft_model, PeftModel
+from modeling_bilstm_crf import BERT_BiLSTM_CRF
 
 cached_models = {}
 
+def load_araberta():
+    path = "asmashayea/absa-arabert"
+
+    tokenizer = AutoTokenizer.from_pretrained(path)
+    base_model = AutoModel.from_pretrained(path)
+    lora_config = LoraConfig.from_pretrained(path)
+    lora_model = get_peft_model(base_model, lora_config)
+
+    config = AutoConfig.from_pretrained(path)
+    model = BERT_BiLSTM_CRF(lora_model, config)
+    model.load_state_dict(torch.load("bilstm_crf_head.pt"))
+    model.eval()
+
+    cached_models["Araberta"] = (tokenizer, model)
+    return tokenizer, model
+
+
+def infer_araberta(text):
+    if "Araberta" not in cached_models:
+        tokenizer, model = load_araberta()
+    else:
+        tokenizer, model = cached_models["Araberta"]
+
+    device = next(model.parameters()).device
+
+    inputs = tokenizer(text, return_tensors='pt', truncation=True, padding='max_length', max_length=128)
+    input_ids = inputs['input_ids'].to(device)
+    attention_mask = inputs['attention_mask'].to(device)
+
+    with torch.no_grad():
+        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
+        predicted_ids = outputs['logits'][0].cpu().tolist()
+
+    tokens = tokenizer.convert_ids_to_tokens(input_ids[0].cpu())
+    predicted_labels = [model.id2label.get(p, 'O') for p in predicted_ids]  # id2label is set in BERT_BiLSTM_CRF.__init__
+
+    clean_tokens = [t for t in tokens if t not in tokenizer.all_special_tokens]
+    clean_labels = [l for t, l in zip(tokens, predicted_labels) if t not in tokenizer.all_special_tokens]
+
+    # Horizontal output
+    pairs = [f"{token}: {label}" for token, label in zip(clean_tokens, clean_labels)]
+    horizontal_output = " | ".join(pairs)
+
+    # Group by aspect span
+    aspects = []
+    current_tokens = []
+    current_sentiment = None
+
+    for token, label in zip(clean_tokens, clean_labels):
+        if label.startswith("B-"):
+            if current_tokens:
+                aspects.append({
+                    "aspect": " ".join(current_tokens).replace("##", ""),
+                    "sentiment": current_sentiment
+                })
+            current_tokens = [token]
+            current_sentiment = label.split("-")[1]
+        elif label.startswith("I-") and current_sentiment == label.split("-")[1]:
+            current_tokens.append(token)
+        else:
+            if current_tokens:
+                aspects.append({
+                    "aspect": " ".join(current_tokens).replace("##", ""),
+                    "sentiment": current_sentiment
+                })
+            current_tokens = []
+            current_sentiment = None
+
+    if current_tokens:
+        aspects.append({
+            "aspect": " ".join(current_tokens).replace("##", ""),
+            "sentiment": current_sentiment
+        })
+
+    return {
+        "token_predictions": horizontal_output,
+        "aspects": aspects
+    }
+
+
 def load_model(model_key):
     if model_key in cached_models:
         return cached_models[model_key]
@@ -47,8 +108,6 @@ def load_model(model_key):
 
 def predict_absa(text, model_choice):
 
-
-
 
     if model_choice == 'mT5':
        tokenizer, model = load_model(model_choice)
@@ -60,9 +119,7 @@ def predict_absa(text, model_choice):
 
    elif model_choice == 'Araberta':
 
-
-        tokenizer = AutoTokenizer.from_pretrained("asmashayea/absa-araberta")
-        decoded = infer_mBart_prompt(text, tokenizer, model)
+        decoded = infer_araberta(text)
 
 
    # prompt = f"استخرج الجوانب والآراء والمشاعر من النص التالي:\n{text}"
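For quick sanity-checking, a minimal usage sketch of the new infer_araberta() entry point follows. The sample sentence and the surrounding harness are illustrative assumptions; only the return structure ("token_predictions" plus a list of {"aspect", "sentiment"} dicts) comes from the code above.

# Hypothetical smoke test for the new Araberta path (assumes the Space's
# dependencies and the local bilstm_crf_head.pt checkpoint are available).
from inference import infer_araberta

result = infer_araberta("الخدمة ممتازة لكن الأسعار مرتفعة")  # "the service is excellent but the prices are high"

# Flat per-token view: "token: label | token: label | ..."
print(result["token_predictions"])

# Grouped B-/I- spans with their sentiment polarity
for item in result["aspects"]:
    print(item["aspect"], "->", item["sentiment"])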
model.py → modeling_bilstm_crf.py
RENAMED
@@ -3,32 +3,43 @@ import torch.nn as nn
 from torchcrf import CRF
 
 class BERT_BiLSTM_CRF(nn.Module):
-    def __init__(self, base_model,
+    def __init__(self, base_model, config, dropout_rate=0.2, rnn_dim=256):
         super().__init__()
         self.bert = base_model
+        self.label2id = config.label2id  # <-- pulled from config
+        self.id2label = config.id2label
+        self.num_labels = config.num_labels
+
         self.bilstm = nn.LSTM(
-
-
+            self.bert.config.hidden_size,
+            rnn_dim,
             num_layers=2,
             batch_first=True,
             bidirectional=True,
-            dropout=
+            dropout=0.2
         )
         self.dropout = nn.Dropout(dropout_rate)
-        self.classifier = nn.Linear(rnn_dim * 2, num_labels)
-        self.crf = CRF(num_labels, batch_first=True)
+        self.classifier = nn.Linear(rnn_dim * 2, self.num_labels)
+        self.crf = CRF(self.num_labels, batch_first=True)
 
     def forward(self, input_ids, attention_mask, token_type_ids=None, labels=None):
-
-
+        outputs = self.bert(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids
+        )
+        lstm_out, _ = self.bilstm(self.dropout(outputs.last_hidden_state))
         emissions = self.classifier(lstm_out)
         mask = attention_mask.bool()
 
         if labels is not None:
             safe_labels = labels.clone()
-            safe_labels[labels == -100] =
+            safe_labels[labels == -100] = self.label2id['O']
             loss = -self.crf(emissions, safe_labels, mask=mask, reduction='mean')
             return {'loss': loss, 'logits': emissions}
         else:
             decoded = self.crf.decode(emissions, mask=mask)
-
+            max_len = input_ids.shape[1]
+            padded_decoded = [seq + [0] * (max_len - len(seq)) for seq in decoded]
+            logits = torch.tensor(padded_decoded, device=input_ids.device)
+            return {'logits': logits}
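To show how the renamed module is now driven entirely by a checkpoint config, here is a small self-contained sketch. The encoder id (aubmindlab/bert-base-arabertv02) and the five-tag BIO label set are placeholder assumptions; the Space itself loads asmashayea/absa-arabert with LoRA weights, as in load_araberta() above.

import torch
from transformers import AutoConfig, AutoModel, AutoTokenizer
from modeling_bilstm_crf import BERT_BiLSTM_CRF

labels = ["O", "B-POS", "I-POS", "B-NEG", "I-NEG"]  # placeholder tag set
model_id = "aubmindlab/bert-base-arabertv02"        # placeholder encoder

# label2id / id2label / num_labels ride along on the config object,
# which is exactly what BERT_BiLSTM_CRF.__init__ now reads.
config = AutoConfig.from_pretrained(
    model_id,
    num_labels=len(labels),
    id2label=dict(enumerate(labels)),
    label2id={l: i for i, l in enumerate(labels)},
)
encoder = AutoModel.from_pretrained(model_id)
model = BERT_BiLSTM_CRF(encoder, config).eval()

tokenizer = AutoTokenizer.from_pretrained(model_id)
batch = tokenizer("مثال قصير", return_tensors="pt")  # "a short example"
with torch.no_grad():
    out = model(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"])

# In eval mode forward() returns CRF-decoded label ids, padded to the
# full sequence length, under the 'logits' key.
print(out["logits"].shape)  # (1, seq_len)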