Schmadge commited on
Commit
840ecaa
·
1 Parent(s): 23577b9

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +29 -22
handler.py CHANGED
@@ -3,23 +3,44 @@ import torch
3
  from transformers import AutoModelForCausalLM, AutoTokenizer
4
  from typing import Any, Dict
5
 
6
- class EndpointHandler:
7
- def __init__(self, path='', torch_dtype=torch.bfloat16, trust_remote_code=True):
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
 
 
 
 
 
 
9
  self.model = AutoModelForCausalLM.from_pretrained(
10
  path,
11
  torch_dtype=torch_dtype,
12
  trust_remote_code=trust_remote_code
13
  )
14
- tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
15
-
 
 
16
  if tokenizer.pad_token_id is None:
17
  warnings.warn(
18
  "pad_token_id is not set for the tokenizer. Using eos_token_id as pad_token_id."
19
  )
20
  tokenizer.pad_token = tokenizer.eos_token
21
-
22
- tokenizer.padding_side = "right" # "left"
23
  self.tokenizer = tokenizer
24
 
25
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -39,21 +60,7 @@ class EndpointHandler:
39
  }
40
 
41
  def format_instruction(self, instruction):
42
- INSTRUCTION_KEY = "### Instruction:"
43
- RESPONSE_KEY = "### Response:"
44
- END_KEY = "### End"
45
- INTRO_BLURB = "Below is an instruction that describes a task. Write a response that appropriately completes the request."
46
- PROMPT_FOR_GENERATION_FORMAT = """{intro}
47
- {instruction_key}
48
- {instruction}
49
- {response_key}
50
- """.format(
51
- intro=INTRO_BLURB,
52
- instruction_key=INSTRUCTION_KEY,
53
- instruction="{instruction}",
54
- response_key=RESPONSE_KEY,
55
- )
56
- return PROMPT_FOR_GENERATION_FORMAT.format(instruction=instruction)
57
 
58
  def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
59
  # process input
@@ -61,7 +68,7 @@ class EndpointHandler:
61
  parameters = data.pop("parameters", None)
62
 
63
  # preprocess
64
- s = PROMPT_FOR_GENERATION_FORMAT.format(instruction=inputs)
65
  input_ids = self.tokenizer(s, return_tensors="pt").input_ids.to(self.device)
66
  gkw = {**self.generate_kwargs, **parameters}
67
  # pass inputs with all kwargs in data
 
3
  from transformers import AutoModelForCausalLM, AutoTokenizer
4
  from typing import Any, Dict
5
 
6
+ class InstructionTextGenerationPipeline:
7
+ INSTRUCTION_KEY = "### Instruction:"
8
+ RESPONSE_KEY = "### Response:"
9
+ END_KEY = "### End"
10
+ INTRO_BLURB = "Below is an instruction that describes a task. Write a response that appropriately completes the request."
11
+ PROMPT_FOR_GENERATION_FORMAT = """{intro}
12
+ {instruction_key}
13
+ {instruction}
14
+ {response_key}
15
+ """.format(
16
+ intro=INTRO_BLURB,
17
+ instruction_key=INSTRUCTION_KEY,
18
+ instruction="{instruction}",
19
+ response_key=RESPONSE_KEY,
20
+ )
21
 
22
+ def __init__(
23
+ self,
24
+ path,
25
+ torch_dtype=torch.bfloat16,
26
+ trust_remote_code=True,
27
+ ) -> None:
28
  self.model = AutoModelForCausalLM.from_pretrained(
29
  path,
30
  torch_dtype=torch_dtype,
31
  trust_remote_code=trust_remote_code
32
  )
33
+ tokenizer = AutoTokenizer.from_pretrained(
34
+ "mosaicml/mpt-7b-instruct",
35
+ trust_remote_code=trust_remote_code
36
+ )
37
  if tokenizer.pad_token_id is None:
38
  warnings.warn(
39
  "pad_token_id is not set for the tokenizer. Using eos_token_id as pad_token_id."
40
  )
41
  tokenizer.pad_token = tokenizer.eos_token
42
+
43
+ tokenizer.padding_side = "right"
44
  self.tokenizer = tokenizer
45
 
46
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
60
  }
61
 
62
  def format_instruction(self, instruction):
63
+ return self.PROMPT_FOR_GENERATION_FORMAT.format(instruction=instruction)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
 
65
  def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
66
  # process input
 
68
  parameters = data.pop("parameters", None)
69
 
70
  # preprocess
71
+ s = self.format_instruction(instruction=inputs)
72
  input_ids = self.tokenizer(s, return_tensors="pt").input_ids.to(self.device)
73
  gkw = {**self.generate_kwargs, **parameters}
74
  # pass inputs with all kwargs in data