Commit: 099f387
Parent(s): ca6eb6e

inference.py CHANGED (+12 -9)
@@ -3,7 +3,7 @@ import json
 from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModel, AutoConfig
 from peft import LoraConfig, get_peft_model, PeftModel
 from modeling_bilstm_crf import BERT_BiLSTM_CRF
-from seq2seq_inference import
+from seq2seq_inference import infer_t5_bart
 from huggingface_hub import hf_hub_download
 
 # Define supported models and their adapter IDs
@@ -61,7 +61,6 @@ def infer_araberta(text):
     else:
         tokenizer, model = cached_models["Araberta"]
 
-
     device = next(model.parameters()).device
 
     inputs = tokenizer(text, return_tensors='pt', truncation=True, padding='max_length', max_length=128)
@@ -73,15 +72,18 @@ def infer_araberta(text):
     predicted_ids = outputs['logits'][0].cpu().tolist()
 
     tokens = tokenizer.convert_ids_to_tokens(input_ids[0].cpu())
-    # predicted_labels = [model.config.id2label.get(p, 'O') for p in predicted_ids]
     predicted_labels = [model.id2label.get(p, 'O') for p in predicted_ids]
 
-
     clean_tokens = [t for t in tokens if t not in tokenizer.all_special_tokens]
     clean_labels = [l for t, l in zip(tokens, predicted_labels) if t not in tokenizer.all_special_tokens]
 
+    # ✅ New: map short to full sentiment
+    sentiment_map = {
+        "POS": "positive",
+        "NEG": "negative",
+        "NEU": "neutral"
+    }
 
-    # Group by aspect span
     aspects = []
     current_tokens = []
     current_sentiment = None
@@ -91,7 +93,7 @@ def infer_araberta(text):
             if current_tokens:
                 aspects.append({
                     "aspect": " ".join(current_tokens).replace("##", ""),
-                    "sentiment": current_sentiment
+                    "sentiment": sentiment_map.get(current_sentiment, current_sentiment)
                 })
             current_tokens = [token]
             current_sentiment = label.split("-")[1]
@@ -101,7 +103,7 @@ def infer_araberta(text):
             if current_tokens:
                 aspects.append({
                     "aspect": " ".join(current_tokens).replace("##", ""),
-                    "sentiment": current_sentiment
+                    "sentiment": sentiment_map.get(current_sentiment, current_sentiment)
                 })
             current_tokens = []
             current_sentiment = None
@@ -109,7 +111,7 @@ def infer_araberta(text):
     if current_tokens:
         aspects.append({
             "aspect": " ".join(current_tokens).replace("##", ""),
-            "sentiment": current_sentiment
+            "sentiment": sentiment_map.get(current_sentiment, current_sentiment)
         })
 
     token_predictions = [
@@ -125,6 +127,7 @@ def infer_araberta(text):
 
 
 
+
 def load_model(model_key):
     if model_key in cached_models:
         return cached_models[model_key]
@@ -148,7 +151,7 @@ def predict_absa(text, model_choice):
 
     if model_choice in ['mT5', 'mBART']:
         tokenizer, model = load_model(model_choice)
-        decoded =
+        decoded = infer_t5_bart(text, tokenizer, model)
 
     elif model_choice == 'Araberta':
 
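The commit completes two statements that were left dangling in the parent revision: the import (from seq2seq_inference import infer_t5_bart) and the call site in predict_absa (decoded = infer_t5_bart(text, tokenizer, model)). The body of infer_t5_bart lives in seq2seq_inference.py, which is not part of this diff; what follows is a minimal sketch of what such a helper could look like, assuming a standard encode/generate/decode round trip for mT5/mBART. The signature is taken from the call site; everything else is an assumption.

import torch

def infer_t5_bart(text, tokenizer, model, max_new_tokens=64):
    # Hypothetical reconstruction -- the real seq2seq_inference.py is not shown in this commit.
    device = next(model.parameters()).device
    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=128).to(device)
    with torch.no_grad():
        output_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)
    # predict_absa stores this decoded string as `decoded`.
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)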
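The other substantive change is the sentiment_map added in infer_araberta: the short BIO sentiment tags (POS/NEG/NEU) produced by label.split("-")[1] are now normalized to full words at each of the three places a finished aspect span is appended to aspects. The loop around those append sites is only partially visible in the hunks above, so the following is a self-contained reconstruction of the grouping logic under that assumption, with a made-up token/label pair as input:

sentiment_map = {"POS": "positive", "NEG": "negative", "NEU": "neutral"}

def group_aspects(tokens, labels):
    # Group BIO-tagged tokens ("B-POS", "I-POS", "O", ...) into aspect spans.
    aspects, current_tokens, current_sentiment = [], [], None

    def flush():
        # Mirrors the three append sites in the diff, including the ##-subword cleanup.
        if current_tokens:
            aspects.append({
                "aspect": " ".join(current_tokens).replace("##", ""),
                "sentiment": sentiment_map.get(current_sentiment, current_sentiment),
            })

    for token, label in zip(tokens, labels):
        if label.startswith("B-"):
            flush()                       # a new span closes any open one
            current_tokens = [token]
            current_sentiment = label.split("-")[1]
        elif label.startswith("I-") and current_tokens:
            current_tokens.append(token)  # continue the open span
        else:
            flush()                       # an "O" tag closes the span
            current_tokens, current_sentiment = [], None
    flush()                               # final flush at end of sequence
    return aspects

print(group_aspects(["battery", "##life", "is", "great"],
                    ["B-POS", "I-POS", "O", "O"]))
# -> [{'aspect': 'battery life', 'sentiment': 'positive'}]

The .get(current_sentiment, current_sentiment) fallback keeps any unexpected tag unchanged rather than raising a KeyError, so labels outside POS/NEG/NEU still pass through.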