Spaces:
Sleeping
Sleeping
Commit
·
80dc20b
1
Parent(s):
598ac39
- generative_inference.py +6 -5
- inference.py +6 -9
generative_inference.py
CHANGED
@@ -32,8 +32,10 @@ MODEL_OPTIONS = {
|
|
32 |
"base": "facebook/mbart-large-50-many-to-many-mmt",
|
33 |
"adapter": "asmashayea/mbart-absa"
|
34 |
},
|
35 |
-
"GPT3.5": {"base": "openai/gpt-3.5-turbo"
|
36 |
-
|
|
|
|
|
37 |
}
|
38 |
|
39 |
cached_models = {}
|
@@ -79,12 +81,11 @@ def infer_t5_bart(text, model_choice):
|
|
79 |
|
80 |
|
81 |
|
82 |
-
|
83 |
OPENAI_API_KEY = "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"
|
84 |
openai.api_key = OPENAI_API_KEY
|
85 |
-
MODEL_ID = "ft:gpt-4o-mini-2024-07-18:asma:gpt4-finetune-absa:BazoEjnp"
|
86 |
|
87 |
-
def infer_gpt_absa(text):
|
|
|
88 |
try:
|
89 |
response = openai.chat.completions.create(
|
90 |
model=MODEL_ID,
|
|
|
32 |
"base": "facebook/mbart-large-50-many-to-many-mmt",
|
33 |
"adapter": "asmashayea/mbart-absa"
|
34 |
},
|
35 |
+
"GPT3.5": {"base": "openai/gpt-3.5-turbo",
|
36 |
+
"model_id": "ft:gpt-3.5-turbo-0125:asma:gpt-3-5-turbo-absa:Bb6gmwkE"},
|
37 |
+
"GPT4o": {"base": "openai/gpt-4o",
|
38 |
+
"model_id": "ft:gpt-4o-mini-2024-07-18:asma:gpt4-finetune-absa:BazoEjnp"}
|
39 |
}
|
40 |
|
41 |
cached_models = {}
|
|
|
81 |
|
82 |
|
83 |
|
|
|
84 |
OPENAI_API_KEY = "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"
|
85 |
openai.api_key = OPENAI_API_KEY
|
|
|
86 |
|
87 |
+
def infer_gpt_absa(text, model_key):
|
88 |
+
MODEL_ID = MODEL_OPTIONS[model_key]["model_id"]
|
89 |
try:
|
90 |
response = openai.chat.completions.create(
|
91 |
model=MODEL_ID,
|
inference.py
CHANGED
@@ -8,7 +8,6 @@ from huggingface_hub import hf_hub_download
|
|
8 |
|
9 |
# Define supported models and their adapter IDs
|
10 |
MODEL_OPTIONS = {
|
11 |
-
|
12 |
"Araberta": {
|
13 |
"base": "asmashayea/absa-araberta",
|
14 |
"adapter": "asmashayea/absa-araberta"
|
@@ -21,10 +20,13 @@ MODEL_OPTIONS = {
|
|
21 |
"base": "facebook/mbart-large-50-many-to-many-mmt",
|
22 |
"adapter": "asmashayea/mbart-absa"
|
23 |
},
|
24 |
-
"GPT3.5": {"base": "openai/gpt-3.5-turbo"
|
25 |
-
|
|
|
|
|
26 |
}
|
27 |
|
|
|
28 |
cached_models = {}
|
29 |
|
30 |
def load_araberta():
|
@@ -123,11 +125,6 @@ def infer_araberta(text):
|
|
123 |
}
|
124 |
|
125 |
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
|
132 |
def predict_absa(text, model_choice):
|
133 |
|
@@ -140,6 +137,6 @@ def predict_absa(text, model_choice):
|
|
140 |
decoded = infer_araberta(text)
|
141 |
|
142 |
elif model_choice == 'GPT3.5' or model_choice == 'GPT4o':
|
143 |
-
decoded = infer_gpt_absa(text)
|
144 |
|
145 |
return decoded
|
|
|
8 |
|
9 |
# Define supported models and their adapter IDs
|
10 |
MODEL_OPTIONS = {
|
|
|
11 |
"Araberta": {
|
12 |
"base": "asmashayea/absa-araberta",
|
13 |
"adapter": "asmashayea/absa-araberta"
|
|
|
20 |
"base": "facebook/mbart-large-50-many-to-many-mmt",
|
21 |
"adapter": "asmashayea/mbart-absa"
|
22 |
},
|
23 |
+
"GPT3.5": {"base": "openai/gpt-3.5-turbo",
|
24 |
+
"model_id": "ft:gpt-3.5-turbo-0125:asma:gpt-3-5-turbo-absa:Bb6gmwkE"},
|
25 |
+
"GPT4o": {"base": "openai/gpt-4o",
|
26 |
+
"model_id": "ft:gpt-4o-mini-2024-07-18:asma:gpt4-finetune-absa:BazoEjnp"}
|
27 |
}
|
28 |
|
29 |
+
|
30 |
cached_models = {}
|
31 |
|
32 |
def load_araberta():
|
|
|
125 |
}
|
126 |
|
127 |
|
|
|
|
|
|
|
|
|
|
|
128 |
|
129 |
def predict_absa(text, model_choice):
|
130 |
|
|
|
137 |
decoded = infer_araberta(text)
|
138 |
|
139 |
elif model_choice == 'GPT3.5' or model_choice == 'GPT4o':
|
140 |
+
decoded = infer_gpt_absa(text, model_choice)
|
141 |
|
142 |
return decoded
|