# absa-app/inference.py
import torch
from transformers import AutoTokenizer, AutoModel
from peft import PeftModel

from model import BERT_BiLSTM_CRF

# Load the tokenizer, base encoder, and LoRA adapter from the Hugging Face Hub.
model_id = "asmashayea/absa-model"
tokenizer = AutoTokenizer.from_pretrained(model_id)
base_model = AutoModel.from_pretrained(model_id)
base_model = PeftModel.from_pretrained(base_model, model_id)
# Label mapping; must match the label2id used during training.
id2label = {
    0: 'O',
    1: 'B-POS',
    2: 'I-POS',
    3: 'B-NEG',
    4: 'I-NEG',
    5: 'B-NEU',
    6: 'I-NEU',
}
model = BERT_BiLSTM_CRF(base_model=base_model, num_labels=len(id2label))
model.eval()
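
# The wrapper imported from model.py is assumed (its source is not shown here)
# to expose roughly this interface, with the CRF decode returned under the
# 'logits' key, which is what predict_absa below relies on:
#
#   class BERT_BiLSTM_CRF(nn.Module):
#       def __init__(self, base_model, num_labels): ...
#       def forward(self, input_ids, attention_mask):
#           ...
#           return {'logits': decoded_label_ids}  # one label id per token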
def bio_to_spans(tokens, labels):
    """Collapse per-token BIO labels into (aspect_text, sentiment) spans."""
    spans = []
    current = []
    current_label = ""
    for idx, label in enumerate(labels):
        if label.startswith("B-"):
            # A new span starts; flush any span in progress first.
            if current:
                spans.append((" ".join(current), current_label))
            current = [tokens[idx]]
            current_label = label[2:]
        elif label.startswith("I-") and current:
            # Continuation of the span in progress.
            current.append(tokens[idx])
        else:
            # An 'O' tag (or a stray 'I-') closes the span in progress.
            if current:
                spans.append((" ".join(current), current_label))
                current = []
    if current:
        spans.append((" ".join(current), current_label))
    return spans
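
# A quick illustration with hypothetical tokens/labels (not model output):
#
#   bio_to_spans(["the", "battery", "life", "is", "bad"],
#                ["O", "B-NEG", "I-NEG", "O", "O"])
#   # -> [("battery life", "NEG")]
#
# Note that wordpiece tokens ('##...') are joined with spaces as-is.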
def predict_absa(text):
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():
        output = model(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"])
    # The CRF decode is expected under 'logits' as label ids, one per token.
    pred_ids = output['logits'][0].tolist()
    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
    labels = [id2label.get(i, 'O') for i in pred_ids]
    result = bio_to_spans(tokens, labels)
    return [{"aspect": asp, "sentiment": pol} for asp, pol in result]
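
# Minimal usage sketch; the example sentence and __main__ guard are
# illustrative additions, not part of the deployed app.
if __name__ == "__main__":
    print(predict_absa("The battery life is great but the screen is dim."))
    # e.g. [{'aspect': 'battery life', 'sentiment': 'POS'}, ...]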