import torch
from transformers import AutoTokenizer, AutoModel, AutoConfig
from peft import PeftModel
from modeling_bilstm_crf import BERT_BiLSTM_CRF
from generative_inference import infer_t5_bart, infer_gpt_absa, infer_deepseek, infer_allam
from huggingface_hub import hf_hub_download
# Supported models: each entry maps a UI name to a base checkpoint and either
# a PEFT adapter repo or a fine-tuned OpenAI model id.
MODEL_OPTIONS = {
    "Araberta": {
        "base": "asmashayea/absa-araberta",
        "adapter": "asmashayea/absa-araberta"
    },
    "mT5": {
        "base": "google/mt5-base",
        "adapter": "asmashayea/mt4-absa"
    },
    "mBART": {
        "base": "facebook/mbart-large-50-many-to-many-mmt",
        "adapter": "asmashayea/mbart-absa"
    },
    "GPT3.5": {
        "base": "openai/gpt-3.5-turbo",
        "model_id": "ft:gpt-3.5-turbo-0125:asma:gpt-3-5-turbo-absa:Bb6gmwkE"
    },
    "GPT4o": {
        "base": "openai/gpt-4o",
        "model_id": "ft:gpt-4o-mini-2024-07-18:asma:gpt4-finetune-absa:BazoEjnp"
    },
    "ALLaM": {
        "base": "ALLaM-AI/ALLaM-7B-Instruct-preview",
        "adapter": "asmashayea/allam-absa"
    },
    "DeepSeek": {
        "base": "deepseek-ai/deepseek-llm-7b-chat",
        "adapter": "asmashayea/deepseek-absa"
    }
}
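# Illustrative only: each MODEL_OPTIONS entry pairs a base checkpoint with
# either a PEFT adapter repo ("adapter") or a fine-tuned OpenAI model id
# ("model_id"). A loader can branch on whichever key is present, as in this
# hypothetical helper (not part of the original app):
def resolve_checkpoints(choice):
    entry = MODEL_OPTIONS[choice]
    return entry["base"], entry.get("adapter", entry.get("model_id"))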
# Cache loaded (tokenizer, model) pairs so repeated requests reuse them.
cached_models = {}
def load_araberta():
    """Load the AraBERT encoder with its LoRA adapter plus the BiLSTM-CRF head."""
    path = "asmashayea/absa-arabert"
    tokenizer = AutoTokenizer.from_pretrained(path)
    # Base encoder with the LoRA adapter weights applied on top.
    base_model = AutoModel.from_pretrained(path)
    lora_model = PeftModel.from_pretrained(base_model, path)
    # Fetch the custom BiLSTM-CRF head checkpoint and restore its weights.
    local_pt = hf_hub_download(repo_id="asmashayea/absa-arabert", filename="bilstm_crf_head.pt")
    config = AutoConfig.from_pretrained(path)
    model = BERT_BiLSTM_CRF(lora_model, config)
    model.load_state_dict(torch.load(local_pt, map_location=torch.device("cpu")))
    model.eval()
    cached_models["Araberta"] = (tokenizer, model)
    return tokenizer, model
def infer_araberta(text):
    """Tag a sentence with the AraBERT + BiLSTM-CRF model and decode BIO labels."""
    if "Araberta" not in cached_models:
        tokenizer, model = load_araberta()
    else:
        tokenizer, model = cached_models["Araberta"]

    device = next(model.parameters()).device
    inputs = tokenizer(text, return_tensors='pt', truncation=True, padding='max_length', max_length=128)
    input_ids = inputs['input_ids'].to(device)
    attention_mask = inputs['attention_mask'].to(device)

    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)

    # The CRF head emits one label id per token position.
    predicted_ids = outputs['logits'][0].cpu().tolist()
    tokens = tokenizer.convert_ids_to_tokens(input_ids[0].cpu())
    predicted_labels = [model.id2label.get(p, 'O') for p in predicted_ids]

    # Drop special tokens ([CLS], [SEP], [PAD]) before decoding spans.
    clean_tokens = [t for t in tokens if t not in tokenizer.all_special_tokens]
    clean_labels = [l for t, l in zip(tokens, predicted_labels) if t not in tokenizer.all_special_tokens]
    # Map BIO tag suffixes to full sentiment names.
    sentiment_map = {
        "POS": "positive",
        "NEG": "negative",
        "NEU": "neutral"
    }
    def join_wordpieces(pieces):
        # Reattach "##" continuation pieces directly to the preceding token;
        # a plain " ".join(...).replace("##", "") would leave stray spaces
        # inside words that the tokenizer split.
        return "".join(p[2:] if p.startswith("##") else " " + p for p in pieces).strip()

    aspects = []
    current_tokens = []
    current_sentiment = None
    for token, label in zip(clean_tokens, clean_labels):
        if label.startswith("B-"):
            # A new aspect starts; close any span still open.
            if current_tokens:
                aspects.append({
                    "aspect": join_wordpieces(current_tokens),
                    "sentiment": sentiment_map.get(current_sentiment, current_sentiment)
                })
            current_tokens = [token]
            current_sentiment = label.split("-")[1]
        elif label.startswith("I-") and current_sentiment == label.split("-")[1]:
            # Continuation of the open span with a matching sentiment.
            current_tokens.append(token)
        else:
            # An "O" tag or a sentiment switch closes the open span.
            if current_tokens:
                aspects.append({
                    "aspect": join_wordpieces(current_tokens),
                    "sentiment": sentiment_map.get(current_sentiment, current_sentiment)
                })
            current_tokens = []
            current_sentiment = None
    # Flush a span that runs to the end of the input.
    if current_tokens:
        aspects.append({
            "aspect": join_wordpieces(current_tokens),
            "sentiment": sentiment_map.get(current_sentiment, current_sentiment)
        })
    # Merge WordPiece subtokens back into whole words for per-token display.
    token_predictions = []
    merged_token = ""
    merged_label = None
    for token, label in zip(clean_tokens, clean_labels):
        if token.startswith("##"):
            # Continuation piece: append to the word being built.
            merged_token += token[2:]
        else:
            if merged_token:
                token_predictions.append({
                    "token": merged_token,
                    "label": merged_label
                })
            merged_token = token
            merged_label = label
    # Add the last buffered token.
    if merged_token:
        token_predictions.append({
            "token": merged_token,
            "label": merged_label
        })
    return {
        "aspects": aspects,
        "token_predictions": token_predictions
    }
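# Illustrative output shape only; the example text and labels below are made
# up for documentation, not actual model output:
#
#   infer_araberta("الخدمة ممتازة")
#   -> {"aspects": [{"aspect": "الخدمة", "sentiment": "positive"}],
#       "token_predictions": [{"token": "الخدمة", "label": "B-POS"},
#                             {"token": "ممتازة", "label": "O"}]}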
def predict_absa(text, model_choice):
    """Route the input text to the inference path for the selected model."""
    if model_choice in ('mT5', 'mBART'):
        return infer_t5_bart(text, model_choice)
    elif model_choice == 'Araberta':
        return infer_araberta(text)
    elif model_choice in ('GPT3.5', 'GPT4o'):
        return infer_gpt_absa(text, model_choice)
    elif model_choice == 'DeepSeek':
        return infer_deepseek(text)
    elif model_choice == 'ALLaM':
        return infer_allam(text)
    raise ValueError(f"Unsupported model choice: {model_choice}")
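# Minimal usage sketch (assumes the adapter repos above are reachable and any
# OpenAI credentials are configured inside generative_inference; the sample
# sentence is illustrative):
if __name__ == "__main__":
    result = predict_absa("الخدمة ممتازة لكن الأسعار مرتفعة", "Araberta")
    print(result["aspects"])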