import torch
from transformers import AutoTokenizer, AutoModel
from peft import PeftModel
from model import BERT_BiLSTM_CRF

# Load the tokenizer and base encoder from the HuggingFace Hub, then attach
# the LoRA adapter weights stored in the same repo.
model_id = "asmashayea/absa-model"
tokenizer = AutoTokenizer.from_pretrained(model_id)
base_model = AutoModel.from_pretrained(model_id)
base_model = PeftModel.from_pretrained(base_model, model_id)

# Label mapping must match your training label2id
id2label = {
    0: 'O',
    1: 'B-POS',
    2: 'I-POS',
    3: 'B-NEG',
    4: 'I-NEG',
    5: 'B-NEU',
    6: 'I-NEU'
}

# Wrap the LoRA-augmented encoder with the BiLSTM-CRF tagging head from model.py.
model = BERT_BiLSTM_CRF(base_model=base_model, num_labels=len(id2label))
model.eval()

def bio_to_spans(tokens, labels):
    """Merge BIO-tagged tokens into (aspect span, sentiment label) pairs."""
    spans = []
    current = []
    current_label = ""
    for idx, label in enumerate(labels):
        if label.startswith("B-"):
            # A new span begins; flush any span still in progress.
            if current:
                spans.append((" ".join(current), current_label))
            current = [tokens[idx]]
            current_label = label[2:]
        elif label.startswith("I-") and current:
            current.append(tokens[idx])
        else:
            # "O" (or a stray "I-" with no open span) closes the current span.
            if current:
                spans.append((" ".join(current), current_label))
                current = []
    if current:
        spans.append((" ".join(current), current_label))
    return spans
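
# Quick sanity check for bio_to_spans (toy inputs, not real model output):
#   bio_to_spans(["the", "battery", "life", "is", "bad"],
#                ["O", "B-NEG", "I-NEG", "O", "O"])
#   -> [("battery life", "NEG")]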

def predict_absa(text):
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():
        output = model(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"])
        # If your BERT_BiLSTM_CRF forward already returns CRF-decoded tag ids,
        # use those directly; otherwise take the argmax over the label dimension.
        logits = output["logits"][0]          # [seq_len, num_labels]
        pred_ids = logits.argmax(dim=-1).tolist()

    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
    labels = [id2label.get(i, "O") for i in pred_ids]

    # Drop special tokens ([CLS]/[SEP]/padding) so they cannot end up in spans.
    keep = [t not in tokenizer.all_special_tokens for t in tokens]
    tokens = [t for t, k in zip(tokens, keep) if k]
    labels = [l for l, k in zip(labels, keep) if k]

    result = bio_to_spans(tokens, labels)
    return [{"aspect": asp, "sentiment": pol} for asp, pol in result]