Commit
·
974d5cf
1
Parent(s):
8c4c5ff
add custom handler
Browse files- handler.py +80 -0
- requirements.txt +0 -0
handler.py
ADDED
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Dict, List, Any
|
2 |
+
from unsloth import FastLanguageModel
|
3 |
+
class EndpointHandler():
    """Hugging Face Inference Endpoints custom handler.

    Loads a 4-bit quantized DeepSeek-R1-Distill-Llama-8B fine-tune via
    unsloth and, on each request, extracts food/beverage/supplement
    entities from a social-media post using a fixed prompt template.
    """

    def __init__(self, path="prashanthbsp/DeepSeek-R1-Distill-Llama-8B-unsloth-bnb-4bit-reasoning-cpg-entity-v1"):
        # Preload model and tokenizer once at endpoint startup so that
        # per-request latency excludes model loading.
        max_seq_length = 2048   # context window used at load time
        dtype = None            # let unsloth auto-select fp16/bf16
        load_in_4bit = True     # bitsandbytes 4-bit quantized weights
        model, tokenizer = FastLanguageModel.from_pretrained(
            model_name=path,
            max_seq_length=max_seq_length,
            dtype=dtype,
            load_in_4bit=load_in_4bit,
        )
        self.model = model
        self.tokenizer = tokenizer

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        data args:
            inputs (:obj:`str` | :obj:`dict`): request payload; a dict may
                nest the post text under "context", otherwise the value
                itself is treated as the post text.
        Return:
            The decoded model text following the "### Response:" marker
            (raw string; callers are expected to parse the JSON inside).
        """
        inputs = data.pop("inputs", data)
        # Robustness fix: the payload may be a plain string (as the
        # docstring allows); str has no .pop, so only unwrap "context"
        # when we actually received a dict.
        if isinstance(inputs, dict):
            context = inputs.pop("context", inputs)
        else:
            context = inputs
        prompt_style = """Below is an instruction that describes a task, paired with an input that provides further context.
Write a response that appropriately completes the request.
Before answering, think carefully about the task to ensure a logical and accurate response.

### Instruction
You are a helpful assistant analyzing social media posts. Your task is to extract ANY food, beverage, or supplement entities mentioned in the post and determine whether each entity is used as an ingredient or consumed as a product.

Guidelines:
- Extract ONLY food, beverage, or supplement entities mentioned in the post
- An entity is considered an ingredient if it's used as part of a recipe or combined with other foods
- An entity is considered a product if it's a food, beverage, or supplement consumed as is
- Focus on specific items rather than general categories when possible

Main thing to note - we ONLY want to extract food, beverage, or supplement entities, nothing else

Output in JSON format only:
{{
"entities": [
{{
"entity": "name of first entity",
"type": "ingredient or product"
}},
{{
"entity": "name of second entity",
"type": "ingredient or product"
}}
]
}}

If no entities are found, output:
{{
"entities": []
}}

### Social Media Post:
{0}
### Response:
<think>{1}"""
        # BUG FIX: the original referenced the bare names `model` and
        # `tokenizer`, which exist only as locals inside __init__ — every
        # request would raise NameError. Use the instances stored on self.
        FastLanguageModel.for_inference(self.model)
        # NOTE(review): device is hard-coded to "cuda"; endpoint is
        # presumably GPU-backed — confirm before deploying on CPU.
        encoded = self.tokenizer([prompt_style.format(context, "")], return_tensors="pt").to("cuda")
        outputs = self.model.generate(
            input_ids=encoded.input_ids,
            attention_mask=encoded.attention_mask,
            max_new_tokens=1200,
            use_cache=True,
        )
        response = self.tokenizer.batch_decode(outputs)
        # Return only the generated answer portion after the prompt echo.
        return response[0].split("### Response:")[1]
|
requirements.txt
ADDED
File without changes
|