import torch
from transformers import AutoTokenizer, AutoModel, AutoConfig
from peft import PeftModel
from huggingface_hub import hf_hub_download

from modeling_bilstm_crf import BERT_BiLSTM_CRF
from generative_inference import infer_t5_bart, infer_gpt_absa, infer_deepseek, infer_allam

# Supported models: each entry maps a display name to its base checkpoint
# plus either a PEFT adapter repo or a fine-tuned OpenAI model ID.
MODEL_OPTIONS = {
    "Araberta": {
        "base": "asmashayea/absa-araberta",
        "adapter": "asmashayea/absa-araberta",
    },
    "mT5": {
        "base": "google/mt5-base",
        "adapter": "asmashayea/mt4-absa",
    },
    "mBART": {
        "base": "facebook/mbart-large-50-many-to-many-mmt",
        "adapter": "asmashayea/mbart-absa",
    },
    "GPT3.5": {
        "base": "openai/gpt-3.5-turbo",
        "model_id": "ft:gpt-3.5-turbo-0125:asma:gpt-3-5-turbo-absa:Bb6gmwkE",
    },
    "GPT4o": {
        "base": "openai/gpt-4o",
        "model_id": "ft:gpt-4o-mini-2024-07-18:asma:gpt4-finetune-absa:BazoEjnp",
    },
    "ALLaM": {
        "base": "ALLaM-AI/ALLaM-7B-Instruct-preview",
        "adapter": "asmashayea/allam-absa",
    },
    "DeepSeek": {
        "base": "deepseek-ai/deepseek-llm-7b-chat",
        "adapter": "asmashayea/deepseek-absa",
    },
}

cached_models = {}


def load_araberta():
    """Load the AraBERT encoder, its LoRA adapter, and the BiLSTM-CRF head."""
    path = "asmashayea/absa-arabert"
    tokenizer = AutoTokenizer.from_pretrained(path)

    base_model = AutoModel.from_pretrained(path)
    lora_model = PeftModel.from_pretrained(base_model, path)

    # The BiLSTM-CRF head weights live in a separate .pt file in the repo.
    local_pt = hf_hub_download(repo_id="asmashayea/absa-arabert", filename="bilstm_crf_head.pt")
    config = AutoConfig.from_pretrained(path)
    model = BERT_BiLSTM_CRF(lora_model, config)
    model.load_state_dict(torch.load(local_pt, map_location=torch.device("cpu")))
    model.eval()

    cached_models["Araberta"] = (tokenizer, model)
    return tokenizer, model


def infer_araberta(text):
    """Run token-level ABSA with the AraBERT + BiLSTM-CRF model."""
    if "Araberta" not in cached_models:
        tokenizer, model = load_araberta()
    else:
        tokenizer, model = cached_models["Araberta"]

    device = next(model.parameters()).device
    inputs = tokenizer(text, return_tensors="pt", truncation=True,
                       padding="max_length", max_length=128)
    input_ids = inputs["input_ids"].to(device)
    attention_mask = inputs["attention_mask"].to(device)

    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)

    # The CRF head returns decoded label IDs under the "logits" key.
    predicted_ids = outputs["logits"][0].cpu().tolist()
    tokens = tokenizer.convert_ids_to_tokens(input_ids[0].cpu())
    predicted_labels = [model.id2label.get(p, "O") for p in predicted_ids]

    # Drop special tokens ([CLS], [SEP], [PAD]) and their labels.
    clean_tokens = [t for t in tokens if t not in tokenizer.all_special_tokens]
    clean_labels = [l for t, l in zip(tokens, predicted_labels)
                    if t not in tokenizer.all_special_tokens]

    # Map short BIO sentiment tags to full sentiment names.
    sentiment_map = {"POS": "positive", "NEG": "negative", "NEU": "neutral"}

    # Group B-/I- tagged tokens into aspect spans with their sentiment.
    aspects = []
    current_tokens = []
    current_sentiment = None
    for token, label in zip(clean_tokens, clean_labels):
        if label.startswith("B-"):
            if current_tokens:
                aspects.append({
                    "aspect": " ".join(current_tokens).replace("##", ""),
                    "sentiment": sentiment_map.get(current_sentiment, current_sentiment),
                })
            current_tokens = [token]
            current_sentiment = label.split("-")[1]
        elif label.startswith("I-") and current_sentiment == label.split("-")[1]:
            current_tokens.append(token)
        else:
            if current_tokens:
                aspects.append({
                    "aspect": " ".join(current_tokens).replace("##", ""),
                    "sentiment": sentiment_map.get(current_sentiment, current_sentiment),
                })
            current_tokens = []
            current_sentiment = None

    # Flush the final span if the sentence ends inside an aspect.
    if current_tokens:
        aspects.append({
            "aspect": " ".join(current_tokens).replace("##", ""),
            "sentiment": sentiment_map.get(current_sentiment, current_sentiment),
        })

    # Merge WordPiece continuations ("##...") back into whole tokens,
    # keeping the label of the first piece.
    token_predictions = []
    merged_token = ""
    merged_label = None
    for token, label in zip(clean_tokens, clean_labels):
        if token.startswith("##"):
            merged_token += token[2:]
        else:
            if merged_token:
                token_predictions.append({"token": merged_token, "label": merged_label})
            merged_token = token
            merged_label = label
    # Add the last token.
    if merged_token:
        token_predictions.append({"token": merged_token, "label": merged_label})

    return {"aspects": aspects, "token_predictions": token_predictions}


def predict_absa(text, model_choice):
    """Dispatch inference to the backend selected in MODEL_OPTIONS."""
    if model_choice in ("mT5", "mBART"):
        return infer_t5_bart(text, model_choice)
    elif model_choice == "Araberta":
        return infer_araberta(text)
    elif model_choice in ("GPT3.5", "GPT4o"):
        return infer_gpt_absa(text, model_choice)
    elif model_choice == "DeepSeek":
        return infer_deepseek(text)
    elif model_choice == "ALLaM":
        return infer_allam(text)
    raise ValueError(f"Unknown model choice: {model_choice!r}")