import torch
from transformers import AutoTokenizer, AutoModel, AutoConfig
from peft import PeftModel
from huggingface_hub import hf_hub_download

from modeling_bilstm_crf import BERT_BiLSTM_CRF
from generative_inference import infer_t5_bart, infer_gpt_absa, infer_deepseek, infer_allam

# Supported models: Hugging Face base checkpoints with LoRA adapter repos,
# or fine-tuned OpenAI model IDs for the GPT variants.
MODEL_OPTIONS = {
    "Araberta": {
        "base": "asmashayea/absa-araberta",
        "adapter": "asmashayea/absa-araberta"
    },
    "mT5": {
        "base": "google/mt5-base",
        "adapter": "asmashayea/mt4-absa"
    },
    "mBART": {
        "base": "facebook/mbart-large-50-many-to-many-mmt",
        "adapter": "asmashayea/mbart-absa"
    },
    "GPT3.5": {
        "base": "openai/gpt-3.5-turbo",
        "model_id": "ft:gpt-3.5-turbo-0125:asma:gpt-3-5-turbo-absa:Bb6gmwkE"
    },
    "GPT4o": {
        "base": "openai/gpt-4o",
        "model_id": "ft:gpt-4o-mini-2024-07-18:asma:gpt4-finetune-absa:BazoEjnp"
    },
    "ALLaM": {
        "base": "ALLaM-AI/ALLaM-7B-Instruct-preview",
        "adapter": "asmashayea/allam-absa"
    },
    "DeepSeek": {
        "base": "deepseek-ai/deepseek-llm-7b-chat",
        "adapter": "asmashayea/deepseek-absa"
    }
}
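
# Note (illustrative): entries with an "adapter" key are loaded locally via PEFT,
# while the GPT entries carry a fine-tuned OpenAI "model_id" that is sent to the
# API instead, e.g. MODEL_OPTIONS["mT5"]["adapter"] -> "asmashayea/mt4-absa".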


# Cache of loaded (tokenizer, model) pairs, keyed by model name.
cached_models = {}

def load_araberta():
    """Load the AraBERT encoder, its LoRA adapter, and the BiLSTM-CRF head."""
    path = "asmashayea/absa-arabert"

    tokenizer = AutoTokenizer.from_pretrained(path)

    # Attach the trained LoRA adapter to the base encoder.
    base_model = AutoModel.from_pretrained(path)
    lora_model = PeftModel.from_pretrained(base_model, path)

    # The BiLSTM-CRF head weights live in a separate .pt file in the same repo.
    local_pt = hf_hub_download(repo_id=path, filename="bilstm_crf_head.pt")

    config = AutoConfig.from_pretrained(path)
    model = BERT_BiLSTM_CRF(lora_model, config)
    model.load_state_dict(torch.load(local_pt, map_location=torch.device("cpu")))
    model.eval()

    cached_models["Araberta"] = (tokenizer, model)
    return tokenizer, model


def infer_araberta(text):
    """Run token-level ABSA and aggregate BIO tags into aspect/sentiment pairs."""
    if "Araberta" not in cached_models:
        tokenizer, model = load_araberta()
    else:
        tokenizer, model = cached_models["Araberta"]

    device = next(model.parameters()).device

    inputs = tokenizer(text, return_tensors='pt', truncation=True, padding='max_length', max_length=128)
    input_ids = inputs['input_ids'].to(device)
    attention_mask = inputs['attention_mask'].to(device)

    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        predicted_ids = outputs['logits'][0].cpu().tolist()

    tokens = tokenizer.convert_ids_to_tokens(input_ids[0].cpu())
    predicted_labels = [model.id2label.get(p, 'O') for p in predicted_ids]

    clean_tokens = [t for t in tokens if t not in tokenizer.all_special_tokens]
    clean_labels = [l for t, l in zip(tokens, predicted_labels) if t not in tokenizer.all_special_tokens]

    # Map short BIO sentiment tags to full sentiment names.
    sentiment_map = {
        "POS": "positive",
        "NEG": "negative",
        "NEU": "neutral"
    }

    # Aggregate BIO-tagged tokens into (aspect, sentiment) spans.
    aspects = []
    current_tokens = []
    current_sentiment = None

    for token, label in zip(clean_tokens, clean_labels):
        if label.startswith("B-"):
            if current_tokens:
                aspects.append({
                    "aspect": " ".join(current_tokens).replace("##", ""),
                    "sentiment": sentiment_map.get(current_sentiment, current_sentiment)
                })
            current_tokens = [token]
            current_sentiment = label.split("-")[1]
        elif label.startswith("I-") and current_sentiment == label.split("-")[1]:
            current_tokens.append(token)
        else:
            if current_tokens:
                aspects.append({
                    "aspect": " ".join(current_tokens).replace("##", ""),
                    "sentiment": sentiment_map.get(current_sentiment, current_sentiment)
                })
                current_tokens = []
                current_sentiment = None

    if current_tokens:
        aspects.append({
            "aspect": " ".join(current_tokens).replace("##", ""),
            "sentiment": sentiment_map.get(current_sentiment, current_sentiment)
        })

    # Merge WordPiece sub-tokens ("##...") back into whole words for display.
    token_predictions = []
    merged_token = ""
    merged_label = None

    for token, label in zip(clean_tokens, clean_labels):
        if token.startswith("##"):
            merged_token += token[2:]
        else:
            if merged_token:
                token_predictions.append({
                    "token": merged_token,
                    "label": merged_label
                })
            merged_token = token
            merged_label = label

    # Add last token
    if merged_token:
        token_predictions.append({
            "token": merged_token,
            "label": merged_label
        })

    return {
        "aspects": aspects,
        "token_predictions": token_predictions
    }
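
# Example return shape from infer_araberta (illustrative values, not real output):
# {
#     "aspects": [{"aspect": "الخدمة", "sentiment": "positive"}],
#     "token_predictions": [{"token": "الخدمة", "label": "B-POS"}]
# }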


def predict_absa(text, model_choice):
    """Dispatch ABSA inference to the implementation for the chosen model."""
    if model_choice in ('mT5', 'mBART'):
        return infer_t5_bart(text, model_choice)

    elif model_choice == 'Araberta':
        return infer_araberta(text)

    elif model_choice in ('GPT3.5', 'GPT4o'):
        return infer_gpt_absa(text, model_choice)

    elif model_choice == 'DeepSeek':
        return infer_deepseek(text)

    elif model_choice == 'ALLaM':
        return infer_allam(text)

    raise ValueError(f"Unsupported model choice: {model_choice}")
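

# Minimal smoke test (a sketch: assumes the checkpoints above are reachable
# and that any GPU/API credentials are configured where needed).
if __name__ == "__main__":
    sample = "الطعام لذيذ لكن الخدمة بطيئة"  # "The food is tasty but the service is slow."
    print(predict_absa(sample, "Araberta"))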