asmashayea committed
Commit e90dd4b · 1 Parent(s): afe677f
Files changed (2)
  1. inference.py +89 -32
  2. model.py → modeling_bilstm_crf.py +21 -10
inference.py CHANGED
@@ -1,35 +1,96 @@
 import torch
-from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
-from peft import PeftModel
+import json
+from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModel, AutoConfig
+from peft import LoraConfig, get_peft_model
+from modeling_bilstm_crf import BERT_BiLSTM_CRF
 from seq2seq_inference import infer_t5_prompt, infer_mBart_prompt
-from transformers import AutoTokenizer, AutoModelForTokenClassification
-
-# Define supported models and their adapter IDs
-MODEL_OPTIONS = {
-    "Araberta": {
-        "base": "asmashayea/absa-araberta",
-        "adapter": "asmashayea/absa-araberta"
-    },
-    "mT5": {
-        "base": "google/mt5-base",
-        "adapter": "asmashayea/mt4-absa"
-    },
-    "mBART": {
-        "base": "facebook/mbart-large-50-many-to-many-mmt",
-        "adapter": "asmashayea/mbart-absa"
-    },
-    "GPT3.5": {
-        "base": "bigscience/bloom-560m",  # example, not ideal for ABSA
-        "adapter": "asmashayea/gpt-absa"
-    },
-    "GPT4o": {
-        "base": "bigscience/bloom-560m",  # example, not ideal for ABSA
-        "adapter": "asmashayea/gpt-absa"
-    }
-}
+from peft import LoraConfig, get_peft_model, PeftModel
+from modeling_bilstm_crf import BERT_BiLSTM_CRF
 
 cached_models = {}
 
+def load_araberta():
+    path = "asmashayea/absa-arabert"
+
+    tokenizer = AutoTokenizer.from_pretrained(path)
+    base_model = AutoModel.from_pretrained(path)
+    lora_config = LoraConfig.from_pretrained(path)
+    lora_model = get_peft_model(base_model, lora_config)
+
+    config = AutoConfig.from_pretrained(path)
+    model = BERT_BiLSTM_CRF(lora_model, config)
+    model.load_state_dict(torch.load("bilstm_crf_head.pt"))
+    model.eval()
+
+    cached_models["Araberta"] = (tokenizer, model)
+    return tokenizer, model
+
+
+def infer_araberta(text):
+    if "Araberta" not in cached_models:
+        tokenizer, model = load_araberta()
+    else:
+        tokenizer, model = cached_models["Araberta"]
+
+
+    device = next(model.parameters()).device
+
+    inputs = tokenizer(text, return_tensors='pt', truncation=True, padding='max_length', max_length=128)
+    input_ids = inputs['input_ids'].to(device)
+    attention_mask = inputs['attention_mask'].to(device)
+
+    with torch.no_grad():
+        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
+        predicted_ids = outputs['logits'][0].cpu().tolist()
+
+    tokens = tokenizer.convert_ids_to_tokens(input_ids[0].cpu())
+    predicted_labels = [model.id2label.get(p, 'O') for p in predicted_ids]
+
+    clean_tokens = [t for t in tokens if t not in tokenizer.all_special_tokens]
+    clean_labels = [l for t, l in zip(tokens, predicted_labels) if t not in tokenizer.all_special_tokens]
+
+    # Horizontal output
+    pairs = [f"{token}: {label}" for token, label in zip(clean_tokens, clean_labels)]
+    horizontal_output = " | ".join(pairs)
+
+    # Group by aspect span
+    aspects = []
+    current_tokens = []
+    current_sentiment = None
+
+    for token, label in zip(clean_tokens, clean_labels):
+        if label.startswith("B-"):
+            if current_tokens:
+                aspects.append({
+                    "aspect": " ".join(current_tokens).replace("##", ""),
+                    "sentiment": current_sentiment
+                })
+            current_tokens = [token]
+            current_sentiment = label.split("-")[1]
+        elif label.startswith("I-") and current_sentiment == label.split("-")[1]:
+            current_tokens.append(token)
+        else:
+            if current_tokens:
+                aspects.append({
+                    "aspect": " ".join(current_tokens).replace("##", ""),
+                    "sentiment": current_sentiment
+                })
+            current_tokens = []
+            current_sentiment = None
+
+    if current_tokens:
+        aspects.append({
+            "aspect": " ".join(current_tokens).replace("##", ""),
+            "sentiment": current_sentiment
+        })
+
+    return {
+        "token_predictions": horizontal_output,
+        "aspects": aspects
+    }
+
+
+
 def load_model(model_key):
     if model_key in cached_models:
         return cached_models[model_key]
@@ -47,8 +108,6 @@ def load_model(model_key):
 
 def predict_absa(text, model_choice):
 
-
-
 
     if model_choice == 'mT5':
        tokenizer, model = load_model(model_choice)
@@ -60,9 +119,7 @@ def predict_absa(text, model_choice):
 
    elif model_choice == 'Araberta':
 
-        model = AutoModelForTokenClassification.from_pretrained("asmashayea/absa-araberta")
-        tokenizer = AutoTokenizer.from_pretrained("asmashayea/absa-araberta")
-        decoded = infer_mBart_prompt(text, tokenizer, model)
+        decoded = infer_araberta(text)
 
 
    # prompt = f"استخرج الجوانب والآراء والمشاعر من النص التالي:\n{text}"
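(The commented-out prompt above translates to "Extract the aspects, opinions, and sentiments from the following text:".)

For orientation, a minimal usage sketch of the new infer_araberta entry point. It is illustrative only: the Arabic review, the B-POS/B-NEG label names, and the printed output are assumptions, not part of this commit. On first call the function loads the tokenizer, the LoRA-wrapped encoder, and the bilstm_crf_head.pt weights, then caches everything in cached_models.

# Hypothetical usage sketch (not part of the commit); assumes BIO tags of the
# form B-<sentiment>/I-<sentiment>, e.g. B-POS, B-NEG.
from inference import infer_araberta

# "The food is delicious but the service is slow"
result = infer_araberta("الطعام لذيذ لكن الخدمة بطيئة")

# "token_predictions" is the flat "token: label" view, e.g.
#   "الطعام: B-POS | لذيذ: O | لكن: O | الخدمة: B-NEG | بطيئة: O"
print(result["token_predictions"])

# "aspects" groups contiguous B-/I- spans into dicts, e.g.
#   [{"aspect": "الطعام", "sentiment": "POS"}, {"aspect": "الخدمة", "sentiment": "NEG"}]
for item in result["aspects"]:
    print(item["aspect"], "->", item["sentiment"])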
model.py → modeling_bilstm_crf.py RENAMED
@@ -3,32 +3,43 @@ import torch.nn as nn
 from torchcrf import CRF
 
 class BERT_BiLSTM_CRF(nn.Module):
-    def __init__(self, base_model, num_labels, rnn_dim=256, dropout_rate=0.2):
+    def __init__(self, base_model, config, dropout_rate=0.2, rnn_dim=256):
         super().__init__()
         self.bert = base_model
+        self.label2id = config.label2id  # <-- pulled from config
+        self.id2label = config.id2label
+        self.num_labels = config.num_labels
+
         self.bilstm = nn.LSTM(
-            input_size=self.bert.config.hidden_size,
-            hidden_size=rnn_dim,
+            self.bert.config.hidden_size,
+            rnn_dim,
             num_layers=2,
             batch_first=True,
             bidirectional=True,
-            dropout=dropout_rate
+            dropout=0.2
         )
         self.dropout = nn.Dropout(dropout_rate)
-        self.classifier = nn.Linear(rnn_dim * 2, num_labels)
-        self.crf = CRF(num_labels, batch_first=True)
+        self.classifier = nn.Linear(rnn_dim * 2, self.num_labels)
+        self.crf = CRF(self.num_labels, batch_first=True)
 
     def forward(self, input_ids, attention_mask, token_type_ids=None, labels=None):
-        bert_output = self.bert(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
-        lstm_out, _ = self.bilstm(self.dropout(bert_output))
+        outputs = self.bert(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids
+        )
+        lstm_out, _ = self.bilstm(self.dropout(outputs.last_hidden_state))
         emissions = self.classifier(lstm_out)
         mask = attention_mask.bool()
 
         if labels is not None:
             safe_labels = labels.clone()
-            safe_labels[labels == -100] = 0
+            safe_labels[labels == -100] = self.label2id['O']
             loss = -self.crf(emissions, safe_labels, mask=mask, reduction='mean')
             return {'loss': loss, 'logits': emissions}
         else:
             decoded = self.crf.decode(emissions, mask=mask)
-            return {'logits': torch.tensor(decoded)}
+            max_len = input_ids.shape[1]
+            padded_decoded = [seq + [0] * (max_len - len(seq)) for seq in decoded]
+            logits = torch.tensor(padded_decoded, device=input_ids.device)
+            return {'logits': logits}
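A short sketch of how the renamed modeling_bilstm_crf.BERT_BiLSTM_CRF is wired together, mirroring load_araberta above. It assumes the checkpoint's config.json carries label2id/id2label/num_labels, and it skips restoring the trained head, so decoded ids are meaningless until load_state_dict is called; the placeholder input text is likewise an assumption.

# Illustrative wiring only; the real flow also restores the trained head:
#   model.load_state_dict(torch.load("bilstm_crf_head.pt"))
import torch
from transformers import AutoConfig, AutoModel, AutoTokenizer
from modeling_bilstm_crf import BERT_BiLSTM_CRF

path = "asmashayea/absa-arabert"
tokenizer = AutoTokenizer.from_pretrained(path)
config = AutoConfig.from_pretrained(path)   # supplies label2id / id2label / num_labels
encoder = AutoModel.from_pretrained(path)

model = BERT_BiLSTM_CRF(encoder, config).eval()

batch = tokenizer("النص هنا",  # placeholder Arabic input ("the text here")
                  return_tensors="pt", padding="max_length",
                  truncation=True, max_length=128)
with torch.no_grad():
    out = model(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"])

# In eval mode 'logits' holds CRF-decoded label ids; the commit pads each
# decoded sequence with 0 up to the input length, so the tensor is rectangular.
print(out["logits"].shape)  # torch.Size([1, 128])

The padding change in forward is what lets infer_araberta zip the decoded ids against the full 128-token input: crf.decode returns one Python list per sequence, truncated at the mask boundary, so without padding the old torch.tensor(decoded) call would fail on ragged batches.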