asmashayea committed
Commit bfc4d6a · 1 Parent(s): 5825fbe
Files changed (3):
  1. app.py +1 -1
  2. inference.py +5 -10
  3. seq2seq_inference.py +9 -38
app.py CHANGED
@@ -15,7 +15,7 @@ demo = gr.Interface(
     ],
     outputs=gr.Textbox(label="Extracted Aspect-Sentiment-Opinion Triplets"),
     title="Arabic ABSA (Aspect-Based Sentiment Analysis)",
-    description="Choose a model (Araberta, mT5, mBART, GPT) to extract aspects, opinions, and sentiment using LoRA adapters"
+    description="Choose a model (Araberta, mT5, GPT) to extract aspects, opinions, and sentiment using LoRA adapters"
 )
 
 if __name__ == "__main__":
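
For context, the diff only shows the tail of the `gr.Interface(...)` call. A minimal sketch of how the pieces plausibly fit together after this commit; the `fn` argument, the input components, and the `demo.launch()` call are assumptions, since they sit outside the hunk:

    # Sketch of app.py wiring, assuming predict_absa(text, model_choice)
    # from inference.py is the handler; input components are assumptions.
    import gradio as gr
    from inference import predict_absa

    demo = gr.Interface(
        fn=predict_absa,
        inputs=[
            gr.Textbox(label="Arabic Review"),
            gr.Dropdown(["Araberta", "mT5", "GPT3.5"], label="Model"),
        ],
        outputs=gr.Textbox(label="Extracted Aspect-Sentiment-Opinion Triplets"),
        title="Arabic ABSA (Aspect-Based Sentiment Analysis)",
        description="Choose a model (Araberta, mT5, GPT) to extract aspects, opinions, and sentiment using LoRA adapters",
    )

    if __name__ == "__main__":
        demo.launch()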
inference.py CHANGED
@@ -3,7 +3,7 @@ import json
 from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModel, AutoConfig
 from peft import LoraConfig, get_peft_model
 from modeling_bilstm_crf import BERT_BiLSTM_CRF
-from seq2seq_inference import infer_t5_prompt, infer_mBart_prompt
+from seq2seq_inference import infer_t5_prompt
 from peft import LoraConfig, get_peft_model, PeftModel
 from modeling_bilstm_crf import BERT_BiLSTM_CRF
 
@@ -18,10 +18,6 @@ MODEL_OPTIONS = {
         "base": "google/mt5-base",
         "adapter": "asmashayea/mt4-absa"
     },
-    # "mBART": {
-    #     "base": "facebook/mbart-large-50-many-to-many-mmt",
-    #     "adapter": "asmashayea/mbart-absa"
-    # },
     "GPT3.5": {
         "base": "bigscience/bloom-560m",  # example, not ideal for ABSA
         "adapter": "asmashayea/gpt-absa"
@@ -131,17 +127,16 @@ def load_model(model_key):
     cached_models[model_key] = (tokenizer, model)
     return tokenizer, model
 
+
+
+
 def predict_absa(text, model_choice):
 
 
-    if model_choice == 'mT5':
+    if model_choice in ['mT5', 'mBART']:
        tokenizer, model = load_model(model_choice)
        decoded = infer_t5_prompt(text, tokenizer, model)
 
-    elif model_choice == 'mBART':
-       tokenizer, model = load_model(model_choice)
-       decoded = infer_mBart_prompt(text, tokenizer, model)
-
     elif model_choice == 'Araberta':
 
        decoded = infer_araberta(text)
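
The body of `load_model` sits almost entirely outside this hunk, but given the `MODEL_OPTIONS` base/adapter pairs and the `PeftModel` import, it presumably stacks the LoRA adapter onto the base checkpoint and caches the pair. A sketch under that assumption, for the seq2seq entries only (the exact loading logic in the repo may differ):

    # Hypothetical sketch of load_model for the seq2seq entries, assuming
    # PeftModel.from_pretrained attaches the LoRA adapter to the base model.
    from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
    from peft import PeftModel

    cached_models = {}

    def load_model(model_key):
        if model_key in cached_models:
            return cached_models[model_key]
        cfg = MODEL_OPTIONS[model_key]  # e.g. {"base": "google/mt5-base", "adapter": "asmashayea/mt4-absa"}
        tokenizer = AutoTokenizer.from_pretrained(cfg["base"])
        base = AutoModelForSeq2SeqLM.from_pretrained(cfg["base"])
        model = PeftModel.from_pretrained(base, cfg["adapter"])  # apply LoRA weights
        model.eval()
        cached_models[model_key] = (tokenizer, model)
        return tokenizer, model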
seq2seq_inference.py CHANGED
@@ -3,7 +3,6 @@ import torch
 from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
 from peft import PeftModel
 
-# Updated Prediction Function for mBART
 SYSTEM_PROMPT = (
     "You are an advanced AI model specialized in extracting aspects and determining their sentiment polarity from customer reviews.\n\n"
     "Instructions:\n"
@@ -19,38 +18,14 @@ SYSTEM_PROMPT = (
 )
 
 
-def infer_mBart_prompt(review_text, tokenizer, model):
-    # Set target language for mBART
-    tokenizer.tgt_lang = "ar_AR"  # Change as needed ("en_XX" for English)
-
-    prompt = f"{SYSTEM_PROMPT}\nReview: {review_text}"
-
-    inputs = tokenizer(prompt, return_tensors="pt", padding="max_length", truncation=True, max_length=512).to(model.device)
-
-    with torch.no_grad():
-        outputs = model.generate(
-            **inputs,
-            max_new_tokens=128,
-            do_sample=False,
-            temperature=0.0,
-            forced_bos_token_id=tokenizer.convert_tokens_to_ids(tokenizer.tgt_lang),  # safer
-            pad_token_id=tokenizer.pad_token_id
-        )
-
-    decoded = tokenizer.decode(outputs[0], skip_special_tokens=True).replace("ar_AR ", "").replace("en_XX ", "").strip()
-    return decoded
-
-
-
-
-def infer_t5_prompt(review_text, tokenizer, model):
-    prompt = (
-        SYSTEM_PROMPT + f"\n\nReview: {review_text}"
-    )
-    inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True).to(model.device)
-
-    with torch.no_grad():
-        outputs = model.generate(
+
+def infer_t5_prompt(review_text, tokenizer, peft_model):
+    prompt = SYSTEM_PROMPT + f"\n\nReview: {review_text}"
+
+    inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True).to(peft_model.device)
+
+    with torch.no_grad():
+        outputs = peft_model.generate(
             **inputs,
             max_new_tokens=256,
             num_beams=4,
@@ -62,18 +37,14 @@ def infer_t5_prompt(review_text, tokenizer, model):
     )
 
     decoded = tokenizer.decode(
-        outputs[0],
-        skip_special_tokens=True,
+        outputs[0],
+        skip_special_tokens=True,
         clean_up_tokenization_spaces=False
     ).strip()
 
-    # Optional: remove T5 special tokens like <extra_id_0>
     decoded = decoded.replace('<extra_id_0>', '').replace('</s>', '').strip()
 
     try:
-        json_output = json.loads(decoded)
-    except json.JSONDecodeError as e:
-        print(f"⚠️ JSON decode error: {e}. Returning raw output.")
-        json_output = decoded
-
-    return json_output
+        return json.loads(decoded)
+    except json.JSONDecodeError:
+        return decoded
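
End to end, the new `infer_t5_prompt` returns parsed JSON when generation produces valid JSON and the raw string otherwise, so callers should handle both cases. A hedged usage sketch; the review text and the triplet schema (a list of dicts) are illustrative assumptions, not taken from the repo:

    # Illustrative call; "mT5" matches the MODEL_OPTIONS key in inference.py.
    tokenizer, model = load_model("mT5")
    result = infer_t5_prompt("الخدمة ممتازة لكن الأسعار مرتفعة", tokenizer, model)

    if isinstance(result, list):  # parsed JSON, e.g. [{"aspect": ..., "polarity": ...}]
        for triplet in result:
            print(triplet)
    else:  # raw string: model output was not valid JSON
        print("Unparsed output:", result)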