Spaces:
Running
Running
File size: 5,539 Bytes
e3a1adb e90dd4b 14f0d71 aa9f0d3 14b1d48 e3a1adb 5825fbe 0a7f591 80dc20b e597fdc aa9f0d3 5825fbe 80dc20b 3a93c8a e3a1adb e90dd4b 678fe06 e90dd4b 678fe06 14b1d48 678fe06 e90dd4b f4612ba e90dd4b 681d8ae e90dd4b 099f387 e90dd4b 099f387 e90dd4b 099f387 e90dd4b 099f387 e90dd4b d472db4 ca5600b e90dd4b a2b7e8f e90dd4b bfc4d6a 3a93c8a aef0118 8115742 bfc4d6a 11c609e aef0118 a700ccf e13680c e90dd4b e13680c 598ac39 80dc20b e8429a7 e597fdc aa9f0d3 e597fdc 54b3c2f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 |
import torch
import json
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModel, AutoConfig
from peft import LoraConfig, get_peft_model, PeftModel
from modeling_bilstm_crf import BERT_BiLSTM_CRF
from generative_inference import infer_t5_bart, infer_gpt_absa, infer_deepseek, infer_allam
from huggingface_hub import hf_hub_download
# Supported models, keyed by the name used to select them.
# Every entry carries a "base" id plus either an "adapter" id or, for the
# two GPT entries, a "model_id".
# NOTE(review): load_araberta() reads from "asmashayea/absa-arabert", which
# differs from the "Araberta" ids below -- confirm which repo id is intended.
MODEL_OPTIONS: dict[str, dict[str, str]] = {
    "Araberta": {
        "base": "asmashayea/absa-araberta",
        "adapter": "asmashayea/absa-araberta",
    },
    "mT5": {
        "base": "google/mt5-base",
        "adapter": "asmashayea/mt4-absa",
    },
    "mBART": {
        "base": "facebook/mbart-large-50-many-to-many-mmt",
        "adapter": "asmashayea/mbart-absa",
    },
    "GPT3.5": {
        "base": "openai/gpt-3.5-turbo",
        "model_id": "ft:gpt-3.5-turbo-0125:asma:gpt-3-5-turbo-absa:Bb6gmwkE",
    },
    "GPT4o": {
        "base": "openai/gpt-4o",
        "model_id": "ft:gpt-4o-mini-2024-07-18:asma:gpt4-finetune-absa:BazoEjnp",
    },
    "ALLaM": {
        "base": "ALLaM-AI/ALLaM-7B-Instruct-preview",
        "adapter": "asmashayea/allam-absa",
    },
    "DeepSeek": {
        "base": "deepseek-ai/deepseek-llm-7b-chat",
        "adapter": "asmashayea/deepseek-absa",
    },
}

# Already-loaded models, keyed by model name; values are (tokenizer, model).
cached_models: dict = {}
def load_araberta():
    """Load the AraBERT ABSA tagger and register it in ``cached_models``.

    The tokenizer, encoder, PEFT adapter and config are all read from the
    same Hub repository. The encoder is wrapped with its adapter, passed to
    ``BERT_BiLSTM_CRF``, and the state dict stored in ``bilstm_crf_head.pt``
    is loaded into the combined model before it is switched to eval mode.

    Returns:
        tuple: ``(tokenizer, model)``; the same pair is stored under
        ``cached_models["Araberta"]``.
    """
    # NOTE(review): MODEL_OPTIONS["Araberta"] lists "asmashayea/absa-araberta";
    # confirm which of the two repo ids is the intended one.
    path = "asmashayea/absa-arabert"
    tokenizer = AutoTokenizer.from_pretrained(path)

    # Load the encoder exactly once; a second from_pretrained call would only
    # repeat the download/allocation and discard the first copy.
    base_model = AutoModel.from_pretrained(path)
    lora_model = PeftModel.from_pretrained(base_model, path)

    local_pt = hf_hub_download(repo_id=path, filename="bilstm_crf_head.pt")
    config = AutoConfig.from_pretrained(path)
    model = BERT_BiLSTM_CRF(lora_model, config)

    # Weights are mapped to CPU so loading works on machines without a GPU.
    # NOTE(review): torch.load unpickles the file; if the checkpoint holds only
    # tensors, consider passing weights_only=True on torch versions that
    # support it.
    state_dict = torch.load(local_pt, map_location=torch.device("cpu"))
    model.load_state_dict(state_dict)
    model.eval()

    cached_models["Araberta"] = (tokenizer, model)
    return tokenizer, model
def _join_wordpieces(pieces):
    """Join WordPiece tokens into text, gluing '##' continuations to the previous piece."""
    joined = " ".join(pieces).replace(" ##", "")
    # A span may start on a continuation piece; drop its dangling marker.
    return joined[2:] if joined.startswith("##") else joined


def _build_aspect(pieces, polarity):
    """Build one aspect record from its token pieces and short sentiment tag."""
    # Map the short tag to the full sentiment name; unknown tags pass through.
    sentiment_map = {
        "POS": "positive",
        "NEG": "negative",
        "NEU": "neutral"
    }
    return {
        "aspect": _join_wordpieces(pieces),
        "sentiment": sentiment_map.get(polarity, polarity)
    }


def _collect_aspects(tokens, labels):
    """Group consecutive B-/I- tagged tokens into aspect records."""
    aspects = []
    span = []
    polarity = None
    for token, label in zip(tokens, labels):
        if label.startswith("B-"):
            # A new span begins: close whatever was open first.
            if span:
                aspects.append(_build_aspect(span, polarity))
            span = [token]
            polarity = label.split("-")[1]
        elif label.startswith("I-") and polarity == label.split("-")[1]:
            # Continuation only counts when its sentiment matches the open span.
            span.append(token)
        else:
            # 'O', a stray I-, or an I- with a different sentiment ends the span.
            if span:
                aspects.append(_build_aspect(span, polarity))
            span = []
            polarity = None
    if span:
        aspects.append(_build_aspect(span, polarity))
    return aspects


def _merge_token_labels(tokens, labels):
    """Merge '##' continuation pieces into whole tokens labelled by their first piece."""
    merged = []
    word = ""
    word_label = None
    for token, label in zip(tokens, labels):
        if token.startswith("##"):
            word += token[2:]
            continue
        if word:
            merged.append({
                "token": word,
                "label": word_label
            })
        word = token
        word_label = label
    # Emit the last pending token.
    if word:
        merged.append({
            "token": word,
            "label": word_label
        })
    return merged


def infer_araberta(text):
    """Tag ``text`` with the AraBERT model and extract aspects with sentiment.

    The model is taken from ``cached_models`` when present, otherwise it is
    loaded through ``load_araberta()``.

    Args:
        text: Input text handed to the tokenizer (truncated/padded to 128 tokens).

    Returns:
        dict: ``{"aspects": [...], "token_predictions": [...]}`` where each
        aspect is ``{"aspect": str, "sentiment": str}`` and each token
        prediction is ``{"token": str, "label": str}``.
    """
    if "Araberta" not in cached_models:
        tokenizer, model = load_araberta()
    else:
        tokenizer, model = cached_models["Araberta"]

    # Run on whatever device the model's parameters live on.
    device = next(model.parameters()).device
    inputs = tokenizer(text, return_tensors='pt', truncation=True, padding='max_length', max_length=128)
    input_ids = inputs['input_ids'].to(device)
    attention_mask = inputs['attention_mask'].to(device)

    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)

    # NOTE(review): the 'logits' entry is consumed as per-token label ids,
    # presumably already decoded by the CRF head -- confirm in BERT_BiLSTM_CRF.
    predicted_ids = outputs['logits'][0].cpu().tolist()
    tokens = tokenizer.convert_ids_to_tokens(input_ids[0].cpu())
    predicted_labels = [model.id2label.get(p, 'O') for p in predicted_ids]

    # Build the special-token set once instead of re-reading it per token.
    special_tokens = set(tokenizer.all_special_tokens)
    kept = [
        (token, label)
        for token, label in zip(tokens, predicted_labels)
        if token not in special_tokens
    ]
    clean_tokens = [token for token, _ in kept]
    clean_labels = [label for _, label in kept]

    return {
        "aspects": _collect_aspects(clean_tokens, clean_labels),
        "token_predictions": _merge_token_labels(clean_tokens, clean_labels)
    }
def predict_absa(text, model_choice):
    """Run aspect-based sentiment analysis on ``text`` with the selected model.

    Args:
        text: Input text to analyse.
        model_choice: One of ``"mT5"``, ``"mBART"``, ``"Araberta"``,
            ``"GPT3.5"``, ``"GPT4o"``, ``"DeepSeek"`` or ``"ALLaM"``.

    Returns:
        Whatever the selected backend inference function returns.

    Raises:
        ValueError: If ``model_choice`` is not one of the supported names.
    """
    if model_choice in ('mT5', 'mBART'):
        return infer_t5_bart(text, model_choice)
    if model_choice == 'Araberta':
        return infer_araberta(text)
    if model_choice in ('GPT3.5', 'GPT4o'):
        return infer_gpt_absa(text, model_choice)
    if model_choice == "DeepSeek":
        return infer_deepseek(text)
    if model_choice == "ALLaM":
        return infer_allam(text)
    # Fail with a clear message instead of an UnboundLocalError on fall-through.
    raise ValueError(f"Unsupported model_choice: {model_choice!r}")
|