syberWolf committed on
Commit
e8628b3
·
1 Parent(s): 1e592e3
Files changed (1) hide show
  1. handler.py +3 -3
handler.py CHANGED
@@ -12,10 +12,10 @@ class EndpointHandler:
12
  "Qwen/Qwen2-1.5B-Instruct",
13
  torch_dtype=torch.float16 if device == "cuda" else torch.float32,
14
  device_map="auto"
15
- ).to(device)
16
 
17
- # create inference pipeline
18
- self.pipeline = pipeline("text-generation", model=model, tokenizer=tokenizer, device=0 if device == "cuda" else -1)
19
 
20
  def __call__(self, data: Any) -> List[List[Dict[str, float]]]:
21
  inputs = data.pop("inputs", data)
 
12
  "Qwen/Qwen2-1.5B-Instruct",
13
  torch_dtype=torch.float16 if device == "cuda" else torch.float32,
14
  device_map="auto"
15
+ )
16
 
17
+ # create inference pipeline without specifying the device
18
+ self.pipeline = pipeline("text-generation", model=model, tokenizer=tokenizer)
19
 
20
  def __call__(self, data: Any) -> List[List[Dict[str, float]]]:
21
  inputs = data.pop("inputs", data)