asmashayea committed
Commit 8115742 · 1 Parent(s): bfe203c
Files changed (1)
  1. inference.py +34 -19
inference.py CHANGED
@@ -1,8 +1,13 @@
 import torch
-from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
+from transformers import (
+    AutoTokenizer,
+    AutoModelForSeq2SeqLM,
+    MBartForConditionalGeneration,
+    MBart50TokenizerFast
+)
 from peft import PeftModel
 
-# Define supported models and their adapter IDs
+# Supported models and their adapter IDs
 MODEL_OPTIONS = {
     "mT5": {
         "base": "google/mt5-base",
@@ -12,12 +17,13 @@ MODEL_OPTIONS = {
         "base": "facebook/mbart-large-50-many-to-many-mmt",
         "adapter": "asmashayea/mbart-absa"
     },
+    # You can customize GPT-like entries later
     "GPT3.5": {
-        "base": "bigscience/bloom-560m",  # placeholder
+        "base": "bigscience/bloom-560m",  # Placeholder only
        "adapter": "asmashayea/gpt-absa"
     },
     "GPT4o": {
-        "base": "bigscience/bloom-560m",  # placeholder
+        "base": "bigscience/bloom-560m",  # Placeholder only
         "adapter": "asmashayea/gpt-absa"
     }
 }
@@ -31,8 +37,14 @@ def load_model(model_key):
     base_id = MODEL_OPTIONS[model_key]["base"]
     adapter_id = MODEL_OPTIONS[model_key]["adapter"]
 
-    tokenizer = AutoTokenizer.from_pretrained(adapter_id)
-    base_model = AutoModelForSeq2SeqLM.from_pretrained(base_id)
+    if model_key == "mBART":
+        tokenizer = MBart50TokenizerFast.from_pretrained(base_id)
+        tokenizer.src_lang = "ar_AR"  # Required for input
+        base_model = MBartForConditionalGeneration.from_pretrained(base_id)
+    else:
+        tokenizer = AutoTokenizer.from_pretrained(adapter_id)
+        base_model = AutoModelForSeq2SeqLM.from_pretrained(base_id)
+
     model = PeftModel.from_pretrained(base_model, adapter_id)
     model.eval()
 
@@ -41,22 +53,25 @@ def load_model(model_key):
 
 def predict_absa(text, model_choice):
     tokenizer, model = load_model(model_choice)
-    prompt = f"استخرج الجوانب والآراء والمشاعر من النص التالي:\n{text}"
+    prompt = f"استخرج الجوانب والآراء والمشاعر من النص التالي:\n{text.strip()}"
 
     if model_choice == "mBART":
-        # Critical fix
-        tokenizer.src_lang = "ar_AR"
-        inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
-        forced_bos_token_id = tokenizer.lang_code_to_id["ar_AR"]
-        outputs = model.generate(
-            **inputs,
-            max_new_tokens=128,
-            forced_bos_token_id=forced_bos_token_id  # Force Arabic generation
-        )
+        tokenizer.tgt_lang = "ar_AR"  # Required for output
+        inputs = tokenizer(prompt, return_tensors="pt", padding="max_length", truncation=True, max_length=512)
+
+        with torch.no_grad():
+            outputs = model.generate(
+                **inputs,
+                max_new_tokens=128,
+                do_sample=False,
+                temperature=0.0,
+                forced_bos_token_id=tokenizer.lang_code_to_id["ar_AR"],
+                pad_token_id=tokenizer.pad_token_id,
+            )
     else:
         inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
-        outputs = model.generate(**inputs, max_new_tokens=128)
+        with torch.no_grad():
+            outputs = model.generate(**inputs, max_new_tokens=128)
 
-    decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
+    decoded = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
     return decoded
-
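
The Arabic prompt in predict_absa translates to "Extract the aspects, opinions, and sentiments from the following text:". A minimal usage sketch of the updated code, assuming this file is importable as inference and the asmashayea adapters above are reachable on the Hub (the sample review text is made up for illustration):

# Minimal usage sketch. Assumes inference.py is on the import path and the
# Hub adapters above are accessible; the review string is a made-up example.
from inference import predict_absa

review = "الطعام لذيذ لكن الخدمة بطيئة"  # "The food is delicious but the service is slow"

# mBART path: MBart50TokenizerFast with src_lang/tgt_lang = "ar_AR",
# output pinned to Arabic via forced_bos_token_id.
print(predict_absa(review, "mBART"))

# mT5 path: AutoTokenizer + AutoModelForSeq2SeqLM with the PEFT adapter.
print(predict_absa(review, "mT5"))

The explicit ar_AR configuration matters because mbart-large-50-many-to-many-mmt is multilingual: without src_lang the input is tokenized under the wrong language code, and without forced_bos_token_id generation may decode into a language other than Arabic.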