asmashayea committed on
Commit 80dc20b · 1 Parent(s): 598ac39
Files changed (2)
  1. generative_inference.py +6 -5
  2. inference.py +6 -9
generative_inference.py CHANGED

@@ -32,8 +32,10 @@ MODEL_OPTIONS = {
         "base": "facebook/mbart-large-50-many-to-many-mmt",
         "adapter": "asmashayea/mbart-absa"
     },
-    "GPT3.5": {"base": "openai/gpt-3.5-turbo"},
-    "GPT4o": {"base": "openai/gpt-4o"}
+    "GPT3.5": {"base": "openai/gpt-3.5-turbo",
+               "model_id": "ft:gpt-3.5-turbo-0125:asma:gpt-3-5-turbo-absa:Bb6gmwkE"},
+    "GPT4o": {"base": "openai/gpt-4o",
+              "model_id": "ft:gpt-4o-mini-2024-07-18:asma:gpt4-finetune-absa:BazoEjnp"}
 }
 
 cached_models = {}
@@ -79,12 +81,11 @@ def infer_t5_bart(text, model_choice):
 
 
 
-
 OPENAI_API_KEY = "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"
 openai.api_key = OPENAI_API_KEY
-MODEL_ID = "ft:gpt-4o-mini-2024-07-18:asma:gpt4-finetune-absa:BazoEjnp"
 
-def infer_gpt_absa(text):
+def infer_gpt_absa(text, model_key):
+    MODEL_ID = MODEL_OPTIONS[model_key]["model_id"]
     try:
         response = openai.chat.completions.create(
             model=MODEL_ID,
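
For context, a minimal sketch of what the reworked helper could look like after this hunk. Only the per-key model_id lookup and the openai.chat.completions.create(model=MODEL_ID, ...) call are taken from the commit; the trimmed MODEL_OPTIONS, the messages payload, the return value, and the error handling below are illustrative assumptions, since they fall outside the shown diff.

import openai

# Trimmed to the GPT entries touched by this commit; the mT5/mBART/Araberta
# entries from the file are omitted here.
MODEL_OPTIONS = {
    "GPT3.5": {"base": "openai/gpt-3.5-turbo",
               "model_id": "ft:gpt-3.5-turbo-0125:asma:gpt-3-5-turbo-absa:Bb6gmwkE"},
    "GPT4o": {"base": "openai/gpt-4o",
              "model_id": "ft:gpt-4o-mini-2024-07-18:asma:gpt4-finetune-absa:BazoEjnp"},
}

# openai.api_key is assumed to be configured as in the file.

def infer_gpt_absa(text, model_key):
    # New in this commit: resolve the fine-tuned model per selected key
    # instead of a single hard-coded MODEL_ID.
    MODEL_ID = MODEL_OPTIONS[model_key]["model_id"]
    try:
        response = openai.chat.completions.create(
            model=MODEL_ID,
            # Assumed prompt format; the real messages payload is outside the hunk.
            messages=[{"role": "user", "content": text}],
        )
        return response.choices[0].message.content
    except Exception as e:
        # Assumed error handling; the diff only shows the try: line.
        return {"error": str(e)}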
inference.py CHANGED

@@ -8,7 +8,6 @@ from huggingface_hub import hf_hub_download
 
 # Define supported models and their adapter IDs
 MODEL_OPTIONS = {
-
     "Araberta": {
         "base": "asmashayea/absa-araberta",
         "adapter": "asmashayea/absa-araberta"
@@ -21,10 +20,13 @@ MODEL_OPTIONS = {
         "base": "facebook/mbart-large-50-many-to-many-mmt",
         "adapter": "asmashayea/mbart-absa"
     },
-    "GPT3.5": {"base": "openai/gpt-3.5-turbo"},
-    "GPT4o": {"base": "openai/gpt-4o"}
+    "GPT3.5": {"base": "openai/gpt-3.5-turbo",
+               "model_id": "ft:gpt-3.5-turbo-0125:asma:gpt-3-5-turbo-absa:Bb6gmwkE"},
+    "GPT4o": {"base": "openai/gpt-4o",
+              "model_id": "ft:gpt-4o-mini-2024-07-18:asma:gpt4-finetune-absa:BazoEjnp"}
 }
 
+
 cached_models = {}
 
 def load_araberta():
@@ -123,11 +125,6 @@ def infer_araberta(text):
     }
 
 
-
-
-
-
-
 
 def predict_absa(text, model_choice):
 
@@ -140,6 +137,6 @@ def predict_absa(text, model_choice):
         decoded = infer_araberta(text)
 
     elif model_choice == 'GPT3.5' or model_choice == 'GPT4o':
-        decoded = infer_gpt_absa(text)
+        decoded = infer_gpt_absa(text, model_choice)
 
     return decoded
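
And a minimal usage sketch of the updated dispatch, assuming predict_absa is invoked from the app's request handler; the review text is illustrative, and only the GPT branch and the new model_choice pass-through come from this commit.

from inference import predict_absa

# Illustrative input; any review text works here.
text = "الخدمة ممتازة لكن الأسعار مرتفعة"

# The GPT branch now forwards the selected key, so infer_gpt_absa(text, "GPT4o")
# looks up MODEL_OPTIONS["GPT4o"]["model_id"]
# -> "ft:gpt-4o-mini-2024-07-18:asma:gpt4-finetune-absa:BazoEjnp".
decoded = predict_absa(text, "GPT4o")
print(decoded)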