syberWolf committed on
Commit
96a0103
·
1 Parent(s): 6d33a5e

simplify handler

Browse files
Files changed (1) hide show
  1. handler.py +3 -22
handler.py CHANGED
@@ -1,18 +1,9 @@
1
  from typing import Dict, List, Any
2
- from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
3
- device = "cuda"
4
 
5
  class EndpointHandler:
6
- def __init__(self, path=""):
7
- # load the model
8
- tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-1.5B-Instruct")
9
- model = AutoModelForCausalLM.from_pretrained(
10
- "Qwen/Qwen2-1.5B-Instruct",
11
- torch_dtype="auto",
12
- device_map="auto"
13
- )
14
- # create inference pipeline
15
- self.pipeline = pipeline("text-generation", model=model, tokenizer=tokenizer)
16
 
17
  def __call__(self, data: Any) -> List[List[Dict[str, float]]]:
18
  inputs = data.pop("inputs", data)
@@ -26,13 +17,3 @@ class EndpointHandler:
26
 
27
  # postprocess the prediction
28
  return prediction
29
-
30
- # Example usage
31
- if __name__ == "__main__":
32
- handler = EndpointHandler()
33
- data = {
34
- "inputs": "Hello, how can I",
35
- "parameters": {"max_length": 50, "num_return_sequences": 1}
36
- }
37
- result = handler(data)
38
- print(result)
 
1
  from typing import Dict, List, Any
2
+ from transformers import pipeline
 
3
 
4
  class EndpointHandler:
5
+ def __init__(self, model_name="Qwen/Qwen2-1.5B-Instruct"):
6
+ self.pipeline = pipeline("text-generation", model=model_name) # Note: Model name provided as argument for flexibility
 
 
 
 
 
 
 
 
7
 
8
  def __call__(self, data: Any) -> List[List[Dict[str, float]]]:
9
  inputs = data.pop("inputs", data)
 
17
 
18
  # postprocess the prediction
19
  return prediction