syberWolf committed
Commit 6d33a5e · 1 Parent(s): e126c73

changes for testing

Files changed (1)
  1. handler.py +13 -13
handler.py CHANGED
@@ -1,30 +1,30 @@
 from typing import Dict, List, Any
 from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
-import torch
+device = "cuda"
 
 class EndpointHandler:
     def __init__(self, path=""):
-        # Load the model
+        # load the model
         tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-1.5B-Instruct")
         model = AutoModelForCausalLM.from_pretrained(
             "Qwen/Qwen2-1.5B-Instruct",
-            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
-            device_map="cuda" if torch.cuda.is_available() else "auto"  # Include device_map for correct device allocation
+            torch_dtype="auto",
+            device_map="auto"
         )
-
-        # Create inference pipeline without specifying the device
+        # create inference pipeline
         self.pipeline = pipeline("text-generation", model=model, tokenizer=tokenizer)
 
-    def __call__(self, data: Any) -> List[List[Dict[str, Any]]]:
+    def __call__(self, data: Any) -> List[List[Dict[str, float]]]:
         inputs = data.pop("inputs", data)
-        parameters = data.pop("parameters", {})
-
-        if isinstance(inputs, str):
-            inputs = [inputs]
+        parameters = data.pop("parameters", None)
 
-        # Get predictions from the pipeline
-        prediction = self.pipeline(inputs, **parameters)
+        # pass inputs with all kwargs in data
+        if parameters is not None:
+            prediction = self.pipeline(inputs, **parameters)
+        else:
+            prediction = self.pipeline(inputs)
 
+        # postprocess the prediction
         return prediction
 
 # Example usage
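
For reference, a minimal local smoke test of the updated handler might look like the following. This is a sketch, not part of the commit: it assumes the file above is saved as handler.py next to the test script, and the prompt text and the max_new_tokens value are illustrative. The payload shape follows the {"inputs": ..., "parameters": ...} convention that Inference Endpoints sends to custom handlers.

# smoke_test.py -- hypothetical local check for the handler above
from handler import EndpointHandler

# instantiating the handler downloads Qwen/Qwen2-1.5B-Instruct on first run
handler = EndpointHandler(path=".")

payload = {
    "inputs": "Summarize what a custom inference handler does.",
    "parameters": {"max_new_tokens": 64},  # illustrative generation kwarg
}

# __call__ pops "inputs" and "parameters" from the payload and forwards them
# to the text-generation pipeline; the result is a list of dicts containing
# a "generated_text" key
prediction = handler(payload)
print(prediction)

Since the model is loaded with device_map="auto", the pipeline is deliberately created without an explicit device argument and inherits the model's placement.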