# absa-app / inference.py
import torch
import json
from transformers import AutoTokenizer, AutoModel, AutoConfig
from peft import PeftModel
from modeling_bilstm_crf import BERT_BiLSTM_CRF
from generative_inference import infer_t5_bart, infer_gpt_absa, infer_deepseek, infer_allam
from huggingface_hub import hf_hub_download
# Supported models and their base checkpoints / fine-tuned adapter or model IDs.
MODEL_OPTIONS = {
    "Araberta": {
        "base": "asmashayea/absa-araberta",
        "adapter": "asmashayea/absa-araberta"
    },
    "mT5": {
        "base": "google/mt5-base",
        "adapter": "asmashayea/mt4-absa"
    },
    "mBART": {
        "base": "facebook/mbart-large-50-many-to-many-mmt",
        "adapter": "asmashayea/mbart-absa"
    },
    "GPT3.5": {
        "base": "openai/gpt-3.5-turbo",
        "model_id": "ft:gpt-3.5-turbo-0125:asma:gpt-3-5-turbo-absa:Bb6gmwkE"
    },
    "GPT4o": {
        "base": "openai/gpt-4o",
        "model_id": "ft:gpt-4o-mini-2024-07-18:asma:gpt4-finetune-absa:BazoEjnp"
    },
    "ALLaM": {
        "base": "ALLaM-AI/ALLaM-7B-Instruct-preview",
        "adapter": "asmashayea/allam-absa"
    },
    "DeepSeek": {
        "base": "deepseek-ai/deepseek-llm-7b-chat",
        "adapter": "asmashayea/deepseek-absa"
    }
}
cached_models = {}  # lazy-loaded (tokenizer, model) pairs keyed by model name
def load_araberta():
    """Load the AraBERTa + BiLSTM-CRF token classifier and cache it."""
    path = "asmashayea/absa-arabert"
    tokenizer = AutoTokenizer.from_pretrained(path)

    # Load the base encoder, then attach the LoRA adapter stored in the same repo.
    base_model = AutoModel.from_pretrained(path)
    lora_model = PeftModel.from_pretrained(base_model, path)

    # The BiLSTM-CRF head weights are stored as a separate .pt file in the repo.
    local_pt = hf_hub_download(repo_id="asmashayea/absa-arabert", filename="bilstm_crf_head.pt")

    config = AutoConfig.from_pretrained(path)
    model = BERT_BiLSTM_CRF(lora_model, config)
    model.load_state_dict(torch.load(local_pt, map_location=torch.device("cpu")))
    model.eval()

    cached_models["Araberta"] = (tokenizer, model)
    return tokenizer, model
def infer_araberta(text):
    """Run token-level ABSA with the AraBERTa BiLSTM-CRF model."""
    if "Araberta" not in cached_models:
        tokenizer, model = load_araberta()
    else:
        tokenizer, model = cached_models["Araberta"]

    device = next(model.parameters()).device
    inputs = tokenizer(text, return_tensors='pt', truncation=True, padding='max_length', max_length=128)
    input_ids = inputs['input_ids'].to(device)
    attention_mask = inputs['attention_mask'].to(device)

    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)

    # The CRF head returns decoded label IDs under the 'logits' key.
    predicted_ids = outputs['logits'][0].cpu().tolist()
    tokens = tokenizer.convert_ids_to_tokens(input_ids[0].cpu())
    predicted_labels = [model.id2label.get(p, 'O') for p in predicted_ids]

    # Drop special tokens ([CLS], [SEP], padding) from both sequences.
    clean_tokens = [t for t in tokens if t not in tokenizer.all_special_tokens]
    clean_labels = [l for t, l in zip(tokens, predicted_labels) if t not in tokenizer.all_special_tokens]

    # Map short tag suffixes to full sentiment names.
    sentiment_map = {
        "POS": "positive",
        "NEG": "negative",
        "NEU": "neutral"
    }

    # Group BIO-tagged tokens into aspect spans with their sentiment.
    aspects = []
    current_tokens = []
    current_sentiment = None

    for token, label in zip(clean_tokens, clean_labels):
        if label.startswith("B-"):
            # A new span begins; flush any span in progress first.
            if current_tokens:
                aspects.append({
                    "aspect": " ".join(current_tokens).replace("##", ""),
                    "sentiment": sentiment_map.get(current_sentiment, current_sentiment)
                })
            current_tokens = [token]
            current_sentiment = label.split("-")[1]
        elif label.startswith("I-") and current_sentiment == label.split("-")[1]:
            current_tokens.append(token)
        else:
            # An 'O' label or a mismatched I- tag ends the current span.
            if current_tokens:
                aspects.append({
                    "aspect": " ".join(current_tokens).replace("##", ""),
                    "sentiment": sentiment_map.get(current_sentiment, current_sentiment)
                })
            current_tokens = []
            current_sentiment = None

    # Flush the final span, if any.
    if current_tokens:
        aspects.append({
            "aspect": " ".join(current_tokens).replace("##", ""),
            "sentiment": sentiment_map.get(current_sentiment, current_sentiment)
        })

    # Merge WordPiece continuations ("##...") back into whole words,
    # keeping the label of each word's first sub-token.
    token_predictions = []
    merged_token = ""
    merged_label = None

    for token, label in zip(clean_tokens, clean_labels):
        if token.startswith("##"):
            merged_token += token[2:]
        else:
            if merged_token:
                token_predictions.append({
                    "token": merged_token,
                    "label": merged_label
                })
            merged_token = token
            merged_label = label

    # Add the last merged token.
    if merged_token:
        token_predictions.append({
            "token": merged_token,
            "label": merged_label
        })

    return {
        "aspects": aspects,
        "token_predictions": token_predictions
    }
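# Illustrative output shape of infer_araberta (the review text and spans below
# are hypothetical; only the dict structure is defined by the code above):
# {
#     "aspects": [{"aspect": "الخدمة", "sentiment": "positive"}],
#     "token_predictions": [{"token": "الخدمة", "label": "B-POS"}, ...]
# }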
def predict_absa(text, model_choice):
    """Dispatch ABSA inference to the model selected by the caller."""
    if model_choice in ['mT5', 'mBART']:
        return infer_t5_bart(text, model_choice)
    elif model_choice == 'Araberta':
        return infer_araberta(text)
    elif model_choice in ['GPT3.5', 'GPT4o']:
        return infer_gpt_absa(text, model_choice)
    elif model_choice == "DeepSeek":
        return infer_deepseek(text)
    elif model_choice == "ALLaM":
        return infer_allam(text)
    raise ValueError(f"Unknown model choice: {model_choice}")
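# Minimal usage sketch: assumes the model repos above are reachable (the first
# call downloads weights) and that the sample review text is illustrative only.
if __name__ == "__main__":
    sample = "الطعام لذيذ لكن الخدمة بطيئة"  # hypothetical Arabic review
    result = predict_absa(sample, "Araberta")
    print(json.dumps(result, ensure_ascii=False, indent=2))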