asmashayea committed
Commit aef0118 · 1 Parent(s): 8115742
Files changed (2)
  1. inference.py +20 -37
  2. seq2seq_inference.py +78 -0
inference.py CHANGED
@@ -1,13 +1,8 @@
  import torch
- from transformers import (
-     AutoTokenizer,
-     AutoModelForSeq2SeqLM,
-     MBartForConditionalGeneration,
-     MBart50TokenizerFast
- )
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
  from peft import PeftModel
-
- # Supported models and their adapter IDs
+ from seq2seq_inference import infer_t5_prompt, infer_mBart_prompt
+ # Define supported models and their adapter IDs
  MODEL_OPTIONS = {
      "mT5": {
          "base": "google/mt5-base",
@@ -17,13 +12,12 @@ MODEL_OPTIONS = {
          "base": "facebook/mbart-large-50-many-to-many-mmt",
          "adapter": "asmashayea/mbart-absa"
      },
-     # You can customize GPT-like entries later
      "GPT3.5": {
-         "base": "bigscience/bloom-560m",  # Placeholder only
+         "base": "bigscience/bloom-560m",  # example, not ideal for ABSA
          "adapter": "asmashayea/gpt-absa"
      },
      "GPT4o": {
-         "base": "bigscience/bloom-560m",  # Placeholder only
+         "base": "bigscience/bloom-560m",  # example, not ideal for ABSA
          "adapter": "asmashayea/gpt-absa"
      }
  }
@@ -37,14 +31,8 @@ def load_model(model_key):
      base_id = MODEL_OPTIONS[model_key]["base"]
      adapter_id = MODEL_OPTIONS[model_key]["adapter"]

-     if model_key == "mBART":
-         tokenizer = MBart50TokenizerFast.from_pretrained(base_id)
-         tokenizer.src_lang = "ar_AR"  # Required for input
-         base_model = MBartForConditionalGeneration.from_pretrained(base_id)
-     else:
-         tokenizer = AutoTokenizer.from_pretrained(adapter_id)
-         base_model = AutoModelForSeq2SeqLM.from_pretrained(base_id)
-
+     tokenizer = AutoTokenizer.from_pretrained(adapter_id)
+     base_model = AutoModelForSeq2SeqLM.from_pretrained(base_id)
      model = PeftModel.from_pretrained(base_model, adapter_id)
      model.eval()

@@ -52,26 +40,21 @@
      return tokenizer, model

  def predict_absa(text, model_choice):
+
      tokenizer, model = load_model(model_choice)
-     prompt = f"استخرج الجوانب والآراء والمشاعر من النص التالي:\n{text.strip()}"

-     if model_choice == "mBART":
-         tokenizer.tgt_lang = "ar_AR"  # Required for output
-         inputs = tokenizer(prompt, return_tensors="pt", padding="max_length", truncation=True, max_length=512)

-         with torch.no_grad():
-             outputs = model.generate(
-                 **inputs,
-                 max_new_tokens=128,
-                 do_sample=False,
-                 temperature=0.0,
-                 forced_bos_token_id=tokenizer.lang_code_to_id["ar_AR"],
-                 pad_token_id=tokenizer.pad_token_id,
-             )
-     else:
-         inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
-         with torch.no_grad():
-             outputs = model.generate(**inputs, max_new_tokens=128)
+     if model_choice == 'mT5':
+         decoded = infer_t5_prompt(text, tokenizer, model)
+
+     elif model_choice == 'mBART':
+         decoded = infer_mBart_prompt(text, tokenizer, model)
+
+     # prompt = f"استخرج الجوانب والآراء والمشاعر من النص التالي:\n{text}"
+     # inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
+
+     # with torch.no_grad():
+     #     outputs = model.generate(**inputs, max_new_tokens=128)

-     decoded = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
+     # decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
      return decoded
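
For context, a minimal driver sketch showing how the reworked predict_absa entry point would be called. This is illustrative only, not part of the commit; it assumes the asmashayea/* adapters are reachable on the Hub, and the example review text is invented:

# Hypothetical usage of the updated inference.py (illustrative sketch)
from inference import predict_absa

review = "الخدمة ممتازة لكن الأسعار مرتفعة"  # "The service is excellent but the prices are high"
print(predict_absa(review, "mT5"))    # routes to infer_t5_prompt; returns parsed JSON on success
print(predict_absa(review, "mBART"))  # routes to infer_mBart_prompt; returns a decoded string

Note that decoded is only assigned in the 'mT5' and 'mBART' branches, so as committed, passing 'GPT3.5' or 'GPT4o' would raise an UnboundLocalError at the return statement.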
seq2seq_inference.py ADDED
@@ -0,0 +1,78 @@
+ import json
+ import torch
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
+ from peft import PeftModel
+
+ # Shared instruction prompt for both seq2seq models
+ SYSTEM_PROMPT = (
+     "You are an advanced AI model specialized in extracting aspects and determining their sentiment polarity from customer reviews.\n\n"
+     "Instructions:\n"
+     "1. Extract only the aspects (nouns) mentioned in the review.\n"
+     "2. Assign a sentiment to each aspect: \"positive\", \"negative\", or \"neutral\".\n"
+     "3. Return aspects in the same language as they appear.\n"
+     "4. An aspect must be a noun that refers to a specific item or service the user described.\n"
+     "5. Ignore adjectives, general ideas, and vague topics.\n"
+     "6. Do NOT translate, explain, or add extra text.\n"
+     "7. The output must be just a valid JSON list with 'aspect' and 'sentiment'. Start with `[` and stop at `]`.\n"
+     "8. Do NOT output the instructions, review, or any text — only one output JSON list.\n"
+     "9. Just one output and one review."
+ )
+
+ def infer_mBart_prompt(review_text, tokenizer, model):
+     # Set target language for mBART
+     tokenizer.tgt_lang = "ar_AR"  # Change as needed ("en_XX" for English)
+
+     prompt = f"{SYSTEM_PROMPT}\nReview: {review_text}"
+
+     inputs = tokenizer(prompt, return_tensors="pt", padding="max_length", truncation=True, max_length=512).to(model.device)
+
+     with torch.no_grad():
+         outputs = model.generate(
+             **inputs,
+             max_new_tokens=128,
+             do_sample=False,
+             temperature=0.0,
+             forced_bos_token_id=tokenizer.convert_tokens_to_ids(tokenizer.tgt_lang),  # safer than lang_code_to_id
+             pad_token_id=tokenizer.pad_token_id
+         )
+
+     decoded = tokenizer.decode(outputs[0], skip_special_tokens=True).replace("ar_AR ", "").replace("en_XX ", "").strip()
+     return decoded
+
+
+ def infer_t5_prompt(review_text, tokenizer, model):
+     prompt = SYSTEM_PROMPT + f"\n\nReview: {review_text}"
+     inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True).to(model.device)
+
+     with torch.no_grad():
+         outputs = model.generate(
+             **inputs,
+             max_new_tokens=256,
+             num_beams=4,
+             do_sample=False,
+             temperature=0.0,
+             early_stopping=True,
+             pad_token_id=tokenizer.pad_token_id,
+             eos_token_id=tokenizer.eos_token_id,
+         )
+
+     decoded = tokenizer.decode(
+         outputs[0],
+         skip_special_tokens=True,
+         clean_up_tokenization_spaces=False
+     ).strip()
+
+     # Optional: remove T5 special tokens like <extra_id_0>
+     decoded = decoded.replace('<extra_id_0>', '').replace('</s>', '').strip()
+
+     try:
+         json_output = json.loads(decoded)
+     except json.JSONDecodeError as e:
+         print(f"⚠️ JSON decode error: {e}. Returning raw output.")
+         json_output = decoded
+
+     return json_output
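
Per SYSTEM_PROMPT, both helpers are expected to emit a bare JSON list of aspect/sentiment pairs. A short sketch of downstream handling, assuming the model actually follows that format; the review aspects and sentiments below are invented for illustration:

# Illustrative consumer of the helpers' output (format assumed from SYSTEM_PROMPT)
import json

raw = '[{"aspect": "الخدمة", "sentiment": "positive"}, {"aspect": "الأسعار", "sentiment": "negative"}]'
for item in json.loads(raw):  # "service": positive, "prices": negative
    print(f'{item["aspect"]}: {item["sentiment"]}')

One asymmetry worth noting: infer_mBart_prompt returns the raw decoded string, so its caller must run json.loads itself, while infer_t5_prompt returns a parsed list and falls back to the raw string only when parsing fails.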