asmashayea commited on
Commit
e089772
Β·
1 Parent(s): a458c5b

πŸš€ Initial deploy of ABSA Space

Browse files
Files changed (4) hide show
  1. app.py +14 -0
  2. inference.py +21 -0
  3. model.py +41 -0
  4. requirements.txt +5 -0
app.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from ABSA.inference import model, tokenizer, label2id, id2label
3
+ import torch
4
+
5
+ def predict(text):
6
+ inputs = tokenizer(text, return_tensors="pt", truncation=True, padding="max_length", max_length=128)
7
+ with torch.no_grad():
8
+ outputs = model(**inputs)
9
+ preds = outputs["logits"].squeeze(0).tolist()
10
+ labels = [id2label.get(p, "O") for p in preds]
11
+ tokens = tokenizer.tokenize(text)
12
+ return list(zip(tokens, labels))
13
+
14
+ gr.Interface(fn=predict, inputs="text", outputs="json", title="Arabic ABSA Model").launch()
inference.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from ABSA.model import BERT_BiLSTM_CRF # Same model class you defined
3
+ from transformers import AutoTokenizer, AutoModel
4
+ import json
5
+
6
+ # Load tokenizer and base model
7
+ model_path = "saved_model"
8
+ tokenizer = AutoTokenizer.from_pretrained(model_path)
9
+ base_model = AutoModel.from_pretrained(model_path)
10
+
11
+ # Load label mappings
12
+ with open(f"{model_path}/label2id.json") as f:
13
+ label2id = json.load(f)
14
+ with open(f"{model_path}/id2label.json") as f:
15
+ id2label = {int(k): v for k, v in json.load(f).items()}
16
+
17
+ # Init and load model
18
+ num_labels = len(label2id)
19
+ model = BERT_BiLSTM_CRF(base_model, num_labels)
20
+ model.load_state_dict(torch.load(f"{model_path}/full_model.pth", map_location="cpu"))
21
+ model.eval()
model.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from torchcrf import CRF
4
+
5
+ class BERT_BiLSTM_CRF(nn.Module):
6
+ def __init__(self, base_model, num_labels, dropout_rate=0.2, rnn_dim=256):
7
+ super().__init__()
8
+ self.bert = base_model
9
+ self.bilstm = nn.LSTM(
10
+ self.bert.config.hidden_size,
11
+ rnn_dim,
12
+ num_layers=2,
13
+ batch_first=True,
14
+ bidirectional=True,
15
+ dropout=0.2
16
+ )
17
+ self.dropout = nn.Dropout(dropout_rate)
18
+ self.classifier = nn.Linear(rnn_dim * 2, num_labels)
19
+ self.crf = CRF(num_labels, batch_first=True)
20
+
21
+ def forward(self, input_ids, attention_mask, token_type_ids=None, labels=None):
22
+ outputs = self.bert(
23
+ input_ids=input_ids,
24
+ attention_mask=attention_mask,
25
+ token_type_ids=token_type_ids
26
+ )
27
+ lstm_out, _ = self.bilstm(self.dropout(outputs.last_hidden_state))
28
+ emissions = self.classifier(lstm_out)
29
+ mask = attention_mask.bool()
30
+
31
+ if labels is not None:
32
+ safe_labels = labels.clone()
33
+ safe_labels[labels == -100] = 0 # Default to "O" index
34
+ loss = -self.crf(emissions, safe_labels, mask=mask, reduction='mean')
35
+ return {'loss': loss, 'logits': emissions}
36
+ else:
37
+ decoded = self.crf.decode(emissions, mask=mask)
38
+ max_len = input_ids.shape[1]
39
+ padded_decoded = [seq + [0] * (max_len - len(seq)) for seq in decoded]
40
+ logits = torch.tensor(padded_decoded, device=input_ids.device)
41
+ return {'logits': logits}
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ torch
2
+ transformers
3
+ gradio
4
+ torchcrf
5
+ peft