syberWolf committed on
Commit 1e592e3 · 1 Parent(s): b47e2d8

try to actually use the GPU

Files changed (1): handler.py (+20 −10)
handler.py CHANGED

@@ -1,27 +1,37 @@
 from typing import Dict, List, Any
 from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
-
+import torch
 
 class EndpointHandler:
     def __init__(self, path=""):
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+
         # load the model
         tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-1.5B-Instruct")
         model = AutoModelForCausalLM.from_pretrained(
             "Qwen/Qwen2-1.5B-Instruct",
-            torch_dtype="auto",
+            torch_dtype=torch.float16 if device == "cuda" else torch.float32,
             device_map="auto"
-        )
+        ).to(device)
+
         # create inference pipeline
-        self.pipeline = pipeline("text-generation", model=model, tokenizer=tokenizer)
+        self.pipeline = pipeline("text-generation", model=model, tokenizer=tokenizer, device=0 if device == "cuda" else -1)
 
     def __call__(self, data: Any) -> List[List[Dict[str, float]]]:
         inputs = data.pop("inputs", data)
-        parameters = data.pop("parameters", None)
+        parameters = data.pop("parameters", {})
+
+        # Ensure inputs are on the GPU if available
+        if isinstance(inputs, str):
+            inputs = [inputs]
+
+        # Tensor input handling
+        try:
+            inputs = torch.tensor(inputs).cuda() if torch.cuda.is_available() else torch.tensor(inputs)
+        except:
+            pass  # If inputs are not tensors (e.g., strings), continue without conversion
 
         # pass inputs with all kwargs in data
-        if parameters is not None:
-            prediction = self.pipeline(inputs, **parameters)
-        else:
-            prediction = self.pipeline(inputs)
-        # postprocess the prediction
+        prediction = self.pipeline(inputs, **parameters)
+
         return prediction
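
For reference, a minimal sketch of a handler that uses the GPU without the pitfalls this commit introduces. It assumes recent transformers/accelerate behavior, which is not confirmed by the commit itself: device_map="auto" already dispatches the model to the GPU when one is visible (so the extra .to(device) can raise an error on a dispatched model), pipeline() can reject an explicit device= for a model loaded with a device_map, and the text-generation pipeline tokenizes raw strings itself, so the torch.tensor(inputs) conversion and the bare except: are unnecessary. This is an illustrative sketch, not part of the commit:

from typing import Any, Dict, List

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline


class EndpointHandler:
    def __init__(self, path: str = ""):
        # device_map="auto" (via accelerate) already places the model on the GPU
        # when one is available; calling .to() afterwards on a dispatched model
        # can raise an error, so it is omitted here (assumption: recent
        # transformers/accelerate versions).
        tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-1.5B-Instruct")
        model = AutoModelForCausalLM.from_pretrained(
            "Qwen/Qwen2-1.5B-Instruct",
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            device_map="auto",
        )

        # No device= argument: the model's placement is already handled by the
        # device_map, and recent transformers versions may reject an explicit
        # device for such models.
        self.pipeline = pipeline("text-generation", model=model, tokenizer=tokenizer)

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        inputs = data.pop("inputs", data)
        parameters = data.pop("parameters", {}) or {}

        # The text-generation pipeline accepts raw strings (or lists of strings)
        # and tokenizes them itself, so no manual tensor conversion or .cuda()
        # call is needed on the inputs.
        return self.pipeline(inputs, **parameters)

A typical request body for such a handler would look like {"inputs": "Hello", "parameters": {"max_new_tokens": 64}}; the parameters dict is forwarded to the pipeline call unchanged.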