asmashayea committed
Commit 099f387 · 1 Parent(s): ca6eb6e
Files changed (1):
  1. inference.py +12 -9
inference.py CHANGED
@@ -3,7 +3,7 @@ import json
 from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModel, AutoConfig
 from peft import LoraConfig, get_peft_model, PeftModel
 from modeling_bilstm_crf import BERT_BiLSTM_CRF
-from seq2seq_inference import infer_t5_prompt
+from seq2seq_inference import infer_t5_bart
 from huggingface_hub import hf_hub_download
 
 # Define supported models and their adapter IDs
@@ -61,7 +61,6 @@ def infer_araberta(text):
     else:
         tokenizer, model = cached_models["Araberta"]
 
-
     device = next(model.parameters()).device
 
     inputs = tokenizer(text, return_tensors='pt', truncation=True, padding='max_length', max_length=128)
@@ -73,15 +72,18 @@ def infer_araberta(text):
     predicted_ids = outputs['logits'][0].cpu().tolist()
 
     tokens = tokenizer.convert_ids_to_tokens(input_ids[0].cpu())
-    # predicted_labels = [model.config.id2label.get(p, 'O') for p in predicted_ids]
     predicted_labels = [model.id2label.get(p, 'O') for p in predicted_ids]
 
-
     clean_tokens = [t for t in tokens if t not in tokenizer.all_special_tokens]
     clean_labels = [l for t, l in zip(tokens, predicted_labels) if t not in tokenizer.all_special_tokens]
 
+    # ✅ New: map short to full sentiment
+    sentiment_map = {
+        "POS": "positive",
+        "NEG": "negative",
+        "NEU": "neutral"
+    }
 
-    # Group by aspect span
     aspects = []
     current_tokens = []
     current_sentiment = None
@@ -91,7 +93,7 @@ def infer_araberta(text):
             if current_tokens:
                 aspects.append({
                     "aspect": " ".join(current_tokens).replace("##", ""),
-                    "sentiment": current_sentiment
+                    "sentiment": sentiment_map.get(current_sentiment, current_sentiment)
                 })
             current_tokens = [token]
             current_sentiment = label.split("-")[1]
@@ -101,7 +103,7 @@ def infer_araberta(text):
             if current_tokens:
                 aspects.append({
                     "aspect": " ".join(current_tokens).replace("##", ""),
-                    "sentiment": current_sentiment
+                    "sentiment": sentiment_map.get(current_sentiment, current_sentiment)
                 })
             current_tokens = []
             current_sentiment = None
@@ -109,7 +111,7 @@ def infer_araberta(text):
     if current_tokens:
         aspects.append({
             "aspect": " ".join(current_tokens).replace("##", ""),
-            "sentiment": current_sentiment
+            "sentiment": sentiment_map.get(current_sentiment, current_sentiment)
         })
 
     token_predictions = [
@@ -125,6 +127,7 @@ def infer_araberta(text):
 
 
 
+
 def load_model(model_key):
     if model_key in cached_models:
         return cached_models[model_key]
@@ -148,7 +151,7 @@ def predict_absa(text, model_choice):
 
     if model_choice in ['mT5', 'mBART']:
         tokenizer, model = load_model(model_choice)
-        decoded = infer_t5_prompt(text, tokenizer, model)
+        decoded = infer_t5_bart(text, tokenizer, model)
 
     elif model_choice == 'Araberta':
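For context, the hunks above show only the changed lines of the aspect-grouping step in infer_araberta. Below is a minimal sketch of how the full loop plausibly reads after this commit; the for loop, the B-/I-/O branching, and the flush() helper (which the real code writes out inline three times, as the three hunks show) are assumptions inferred from the visible fragments, since the diff omits the unchanged lines between hunks.

# Hypothetical reconstruction of infer_araberta's grouping step. Only the
# lines marked with -/+ in the hunks above are confirmed by the commit;
# the loop header and the B-/I-/O branching follow the usual BIO-tagging
# pattern and are an assumption.
sentiment_map = {"POS": "positive", "NEG": "negative", "NEU": "neutral"}

def group_aspects(clean_tokens, clean_labels):
    aspects = []
    current_tokens = []
    current_sentiment = None

    def flush():
        # Close the open span; map "POS"/"NEG"/"NEU" to the full word,
        # falling back to the raw suffix if it is not in the map.
        aspects.append({
            "aspect": " ".join(current_tokens).replace("##", ""),
            "sentiment": sentiment_map.get(current_sentiment, current_sentiment)
        })

    for token, label in zip(clean_tokens, clean_labels):
        if label.startswith("B-"):          # a new span begins
            if current_tokens:
                flush()
            current_tokens = [token]
            current_sentiment = label.split("-")[1]
        elif label.startswith("I-") and current_tokens:
            current_tokens.append(token)    # the span continues
        else:                               # 'O' (or stray I-): end any open span
            if current_tokens:
                flush()
            current_tokens = []
            current_sentiment = None

    if current_tokens:                      # trailing span at end of text
        flush()
    return aspects

On tokens ["battery", "##life", "great"] with labels ["B-POS", "I-POS", "O"], this returns [{"aspect": "battery life", "sentiment": "positive"}]. The .get(current_sentiment, current_sentiment) fallback introduced by the commit keeps an unmapped tag suffix as-is rather than dropping the span.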