syberWolf commited on
Commit
7e91a22
·
1 Parent(s): f96aa72

updates for phi

Browse files
Files changed (2) hide show
  1. handler.py +17 -8
  2. requirements.txt +4 -1
handler.py CHANGED
@@ -6,22 +6,31 @@ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
6
  class EndpointHandler:
7
  def __init__(self, path=""):
8
  # load the model
9
- tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-128k-instruct", trust_remote_code=True)
10
  model = AutoModelForCausalLM.from_pretrained(
11
  "microsoft/Phi-3-mini-128k-instruct",
12
- trust_remote_code=True
13
- )
 
 
14
  # create inference pipeline
15
  self.pipeline = pipeline("text-generation", model=model, tokenizer=tokenizer)
16
 
17
  def __call__(self, data: Any) -> List[List[Dict[str, float]]]:
18
  inputs = data.pop("inputs", data)
19
-
20
- for key in ['stop_sequences', 'watermark', 'stop']:
21
- if key in inputs:
22
- del inputs[key]
23
-
24
  parameters = data.pop("parameters", None)
 
 
 
 
 
 
 
 
 
 
 
 
25
 
26
  # pass inputs with all kwargs in data
27
  if parameters is not None:
 
6
  class EndpointHandler:
7
  def __init__(self, path=""):
8
  # load the model
9
+ tokenizer = tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-128k-instruct")
10
  model = AutoModelForCausalLM.from_pretrained(
11
  "microsoft/Phi-3-mini-128k-instruct",
12
+ device_map="cuda",
13
+ torch_dtype="auto",
14
+ trust_remote_code=True,
15
+ )
16
  # create inference pipeline
17
  self.pipeline = pipeline("text-generation", model=model, tokenizer=tokenizer)
18
 
19
  def __call__(self, data: Any) -> List[List[Dict[str, float]]]:
20
  inputs = data.pop("inputs", data)
 
 
 
 
 
21
  parameters = data.pop("parameters", None)
22
+
23
+ # Print parameters for debugging
24
+ print("Parameters before cleaning:", parameters)
25
+
26
+ # Remove unwanted keys from parameters
27
+ if parameters is not None:
28
+ for key in ['stop_sequences', 'watermark', 'stop']:
29
+ if key in parameters:
30
+ del parameters[key]
31
+
32
+ # Print parameters after cleaning
33
+ print("Parameters after cleaning:", parameters)
34
 
35
  # pass inputs with all kwargs in data
36
  if parameters is not None:
requirements.txt CHANGED
@@ -1 +1,4 @@
1
- flash-attn==latest
 
 
 
 
1
+ flash_attn==2.5.8
2
+ torch==2.3.1
3
+ accelerate==0.31.0
4
+ transformers==4.41.2