Commit e597fdc
Parent(s): 80dc20b
- generative_inference.py +69 -1
- inference.py +6 -2
generative_inference.py
CHANGED
@@ -35,9 +35,12 @@ MODEL_OPTIONS = {
     "GPT3.5": {"base": "openai/gpt-3.5-turbo",
                "model_id": "ft:gpt-3.5-turbo-0125:asma:gpt-3-5-turbo-absa:Bb6gmwkE"},
     "GPT4o": {"base": "openai/gpt-4o",
-              "model_id": "ft:gpt-4o-mini-2024-07-18:asma:gpt4-finetune-absa:BazoEjnp"}
+              "model_id": "ft:gpt-4o-mini-2024-07-18:asma:gpt4-finetune-absa:BazoEjnp"},
+    "DeepSeek": {"base": "deepseek-ai/deepseek-llm-7b-chat", "adapter": "deepseek7bchat-json-lora-gptparam"}
+
 }
 
+
 cached_models = {}
 
 # ✅ Reusable for both mT5 + mBART
@@ -106,3 +109,68 @@ def infer_gpt_absa(text, model_key):
     except Exception as e:
         return {"error": str(e)}
 
+
+
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+def load_deepseek():
+    base_model = AutoModelForCausalLM.from_pretrained(
+        MODEL_OPTIONS["DeepSeek"]["base"],
+        device_map="auto",
+        torch_dtype=torch.float16,
+        trust_remote_code=True
+    )
+    tokenizer = AutoTokenizer.from_pretrained(
+        MODEL_OPTIONS["DeepSeek"]["adapter"],
+        trust_remote_code=True
+    )
+    model = PeftModel.from_pretrained(base_model, MODEL_OPTIONS["DeepSeek"]["adapter"])
+
+    cached_models["DeepSeek"] = (tokenizer, model)
+    return tokenizer, model
+
+
+
+def build_deepseek_prompt(review_text, output=""):
+    return f"""<|system|>
+You are an advanced AI model specialized in extracting aspects and determining their sentiment polarity from customer reviews.
+
+
+Instructions:
+1. Extract only the aspects (nouns) mentioned in the review.
+2. Assign a sentiment to each aspect: "positive", "negative", or "neutral".
+3. Return aspects in the same language as they appear.
+4. An aspect must be a noun that refers to a specific item or service the user described.
+5. Ignore adjectives, general ideas, and vague topics.
+6. Do NOT translate, explain, or add extra text.
+7. The output must be just a valid JSON list with 'aspect' and 'sentiment'. Start with `[` and stop at `]`.
+8. Do NOT output the instructions, review, or any text - only one output JSON list.
+9. Just one output and one review.
+<|user|>
+{review_text}
+<|assistant|>
+{output}"""  # ✅ include the output here
+
+
+
+def infer_deepseek(text):
+    tokenizer, model = cached_models.get("DeepSeek") or load_deepseek()
+
+    prompt = build_deepseek_prompt(text)
+    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512).to(model.device)
+
+    with torch.no_grad():
+        outputs = model.generate(
+            **inputs,
+            max_new_tokens=128,
+            do_sample=False,
+            temperature=0.0,
+            pad_token_id=tokenizer.eos_token_id
+        )
+
+    decoded = tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True).strip()
+    try:
+        parsed = json.loads(decoded)
+        return parsed
+    except Exception as e:
+        return {"error": str(e), "raw": decoded}
inference.py
CHANGED
@@ -3,7 +3,7 @@ import json
 from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModel, AutoConfig
 from peft import LoraConfig, get_peft_model, PeftModel
 from modeling_bilstm_crf import BERT_BiLSTM_CRF
-from generative_inference import infer_t5_bart, infer_gpt_absa
+from generative_inference import infer_t5_bart, infer_gpt_absa, infer_deepseek
 from huggingface_hub import hf_hub_download
 
 # Define supported models and their adapter IDs
@@ -23,7 +23,8 @@ MODEL_OPTIONS = {
     "GPT3.5": {"base": "openai/gpt-3.5-turbo",
                "model_id": "ft:gpt-3.5-turbo-0125:asma:gpt-3-5-turbo-absa:Bb6gmwkE"},
     "GPT4o": {"base": "openai/gpt-4o",
-              "model_id": "ft:gpt-4o-mini-2024-07-18:asma:gpt4-finetune-absa:BazoEjnp"}
+              "model_id": "ft:gpt-4o-mini-2024-07-18:asma:gpt4-finetune-absa:BazoEjnp"},
+    "DeepSeek": {"base": "deepseek-ai/deepseek-llm-7b-chat", "adapter": "deepseek7bchat-json-lora-gptparam"}
 }
 
 
@@ -139,4 +140,7 @@ def predict_absa(text, model_choice):
     elif model_choice == 'GPT3.5' or model_choice == 'GPT4o':
         decoded = infer_gpt_absa(text, model_choice)
 
+    elif model_choice == "DeepSeek":
+        return infer_deepseek(text)
+
     return decoded
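A matching sketch for the dispatcher side (again illustrative, not part of the commit). Note that the new "DeepSeek" branch returns immediately rather than falling through to the shared return decoded, so its error dicts pass straight to the caller.

# Hypothetical caller of the updated dispatcher; the review text is made up.
from inference import predict_absa

aspects = predict_absa("Great pizza, but the delivery was slow.", "DeepSeek")
print(aspects)  # JSON list of {"aspect", "sentiment"} pairs, or an error dict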