chansung committed
Commit 741ff0c · 1 Parent(s): de0a8a7

Update gen.py

Files changed (1)
  1. gen.py +30 -86
gen.py CHANGED
@@ -1,5 +1,6 @@
 import gc
 import copy
+import time
 from tenacity import RetryError
 from tenacity import retry, stop_after_attempt, wait_fixed
 
@@ -13,6 +14,7 @@ from transformers import (
     MinNewTokensLengthLogitsProcessor,
     TemperatureLogitsWarper,
     TopPLogitsWarper,
+    MinLengthLogitsProcessor
 )
 
 def get_output_batch(
@@ -56,6 +58,11 @@ class StreamModel:
         self.tokenizer = tokenizer
         self.device = "cuda" if torch.cuda.is_available() else "cpu"
 
+        self.processor = LogitsProcessorList()
+        self.processor.append(TemperatureLogitsWarper(0.9))
+        self.processor.append(TopPLogitsWarper(0.75))
+
+
     def __call__(
         self,
         prompt,
@@ -71,82 +78,40 @@ class StreamModel:
         logprobs = max(logprobs, 0)
 
         # bigger than 1
-        chunk_size = 3
+        chunk_size = 2
         chunk_count = 0
 
         # Generate completion tokens.
-        final_tokens = torch.empty(0).to(self.device)
+        final_tokens = torch.empty(0)
 
-        try:
-            for tokens in self.generate(
-                input_ids[None, :].repeat(n, 1),
-                logprobs=logprobs,
-                min_new_tokens=min_tokens,
-                max_new_tokens=max_tokens,
-                temperature=temperature,
-                top_p=top_p,
-            ):
-                if chunk_count < chunk_size:
-                    chunk_count = chunk_count + 1
-
-                final_tokens = torch.cat((final_tokens, tokens))
-
-                if chunk_count == chunk_size-1:
-                    chunk_count = 0
-                    yield self.tokenizer.decode(final_tokens, skip_special_tokens=True)
+        for tokens in self.generate(
+            input_ids[None, :].repeat(n, 1),
+            logprobs=logprobs,
+            min_new_tokens=min_tokens,
+            max_new_tokens=max_tokens,
+            temperature=temperature,
+            top_p=top_p,
+        ):
+            if chunk_count < chunk_size:
+                chunk_count = chunk_count + 1
+
+            final_tokens = torch.cat((final_tokens, tokens.to("cpu")))
 
-            if chunk_count > 0:
+            if chunk_count == chunk_size-1:
+                chunk_count = 0
                 yield self.tokenizer.decode(final_tokens, skip_special_tokens=True)
 
-        except RetryError as e:
-            print(e)
-            del input_ids
-            gc.collect()
+        if chunk_count > 0:
+            yield self.tokenizer.decode(final_tokens, skip_special_tokens=True)
 
-        del final_tokens
+        del final_tokens, input_ids
         if self.device == "cuda":
             torch.cuda.empty_cache()
 
-    @retry(stop=stop_after_attempt(5), wait=wait_fixed(1))
     def _infer(self, model_fn, **kwargs):
-        """Call a model function in inference mode with auto retrying."""
-        # This is a temporary workaround for bitsandbytes #162:
-        # https://github.com/TimDettmers/bitsandbytes/issues/162
         with torch.inference_mode():
             return model_fn(**kwargs)
 
-    def _logits_processor(self, config, input_length):
-        """Set up logits processor based on the generation config."""
-        processor = LogitsProcessorList()
-
-        # Add processor for enforcing a min-length of new tokens.
-        if (
-            config.min_new_tokens is not None
-            and config.min_new_tokens > 0
-            and config.eos_token_id is not None
-        ):
-            processor.append(
-                MinNewTokensLengthLogitsProcessor(
-                    prompt_length_to_skip=input_length,
-                    min_new_tokens=config.min_new_tokens,
-                    eos_token_id=config.eos_token_id,
-                )
-            )
-
-        # Add processor for scaling output probability distribution.
-        if (
-            config.temperature is not None
-            and config.temperature > 0
-            and config.temperature != 1.0
-        ):
-            processor.append(TemperatureLogitsWarper(config.temperature))
-
-        # Add processor for nucleus sampling.
-        if config.top_p is not None and config.top_p > 0 and config.top_p < 1:
-            processor.append(TopPLogitsWarper(config.top_p))
-
-        return processor
-
     def tokenize(self, text):
         """Tokenize a string into a tensor of token IDs."""
         batch = self.tokenizer.encode(text, return_tensors="pt")
@@ -165,7 +130,7 @@ class StreamModel:
         kwargs = config.update(**kwargs)
         kwargs["output_attentions"] = False
        kwargs["output_hidden_states"] = False
-        kwargs["use_cache"] = True  # config.use_cache
+        kwargs["use_cache"] = True
 
         # Collect special token IDs.
         pad_token_id = config.pad_token_id
@@ -183,28 +148,6 @@ class StreamModel:
                 input_ids = input_ids * eos_token_id[0]
             input_length = 1
 
-        # Prepare inputs for encoder-decoder models.
-        if self.model.config.is_encoder_decoder:
-            # Get outputs from the encoder.
-            encoder = self.model.get_encoder()
-            encoder_kwargs = kwargs.copy()
-            encoder_kwargs.pop("use_cache", None)
-            encoder_kwargs["input_ids"] = input_ids
-            encoder_kwargs["return_dict"] = True
-            encoder_outputs = self._infer(encoder, **encoder_kwargs)
-            kwargs["encoder_outputs"] = encoder_outputs
-
-            # Reinitialize inputs for the decoder.
-            decoder_start_token_id = config.decoder_start_token_id
-            if decoder_start_token_id is None:
-                decoder_start_token_id = bos_token_id
-            input_ids = input_ids.new_ones((batch_size, 1))
-            input_ids = input_ids * decoder_start_token_id
-            input_length = 1
-
-        # Set up logits processor.
-        processor = self._logits_processor(config, input_length)
-
         # Keep track of which sequences are already finished.
         unfinished = input_ids.new_ones(batch_size)
 
@@ -213,10 +156,11 @@ class StreamModel:
             inputs = self.model.prepare_inputs_for_generation(
                 input_ids, **kwargs
             )  # noqa: E501
+
             outputs = self._infer(
                 self.model,
                 **inputs,
-                return_dict=True,
+                # return_dict=True,
                 output_attentions=False,
                 output_hidden_states=False,
             )
@@ -224,7 +168,7 @@ class StreamModel:
             # Pre-process the probability distribution of the next tokens.
            logits = outputs.logits[:, -1, :]
             with torch.inference_mode():
-                logits = processor(input_ids, logits)
+                logits = self.processor(input_ids, logits)
             probs = torch.nn.functional.softmax(logits, dim=-1)
 
             # Select deterministic or stochastic decoding strategy.
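For reference, here is a minimal, self-contained sketch (not part of gen.py) of what the fixed LogitsProcessorList built in StreamModel.__init__ does to a batch of next-token logits. The toy input_ids, logits, and the multinomial sampling step are illustrative assumptions; LogitsProcessorList, TemperatureLogitsWarper, and TopPLogitsWarper are the transformers classes already imported in this file.

import torch
from transformers import (
    LogitsProcessorList,
    TemperatureLogitsWarper,
    TopPLogitsWarper,
)

# Same fixed warpers the commit builds once in StreamModel.__init__.
processor = LogitsProcessorList()
processor.append(TemperatureLogitsWarper(0.9))
processor.append(TopPLogitsWarper(0.75))

# Toy inputs: a batch of one sequence and a vocabulary of 8 tokens (assumed sizes).
input_ids = torch.tensor([[0, 1, 2]])   # (batch, seq_len)
logits = torch.randn(1, 8)              # (batch, vocab_size)

# Temperature scaling first, then nucleus (top-p) filtering.
warped = processor(input_ids, logits)
probs = torch.nn.functional.softmax(warped, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)
print(next_token)

Because the processor is now created once with fixed values (0.9 and 0.75), it does not track the temperature and top_p that the removed _logits_processor helper read from the generation config at call time.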
 
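The chunked yield logic in __call__ can be hard to follow in diff form, so below is a minimal sketch of the same accumulate-and-re-decode pattern, with a fake token stream and a toy decode function standing in for self.generate and self.tokenizer.decode (both stand-ins are assumptions, used only so the behaviour can be run in isolation).

import torch

def chunked_stream(token_stream, decode, chunk_size=2):
    # Accumulate generated tokens on the CPU and yield the decoded text of
    # the entire prefix whenever the chunk counter reaches chunk_size - 1,
    # then flush once more after the loop if the last step did not yield.
    final_tokens = torch.empty(0)
    chunk_count = 0
    for tokens in token_stream:
        if chunk_count < chunk_size:
            chunk_count = chunk_count + 1
        final_tokens = torch.cat((final_tokens, tokens.to("cpu")))
        if chunk_count == chunk_size - 1:
            chunk_count = 0
            yield decode(final_tokens)
    if chunk_count > 0:
        yield decode(final_tokens)

# Fake one-token-per-step stream and a toy decoder, for illustration only.
fake_stream = (torch.tensor([float(i)]) for i in range(5))
for text in chunked_stream(fake_stream, lambda t: " ".join(str(int(x)) for x in t)):
    print(text)  # each yield is the full text so far, not just the new tokens

Since every yield re-decodes the whole final_tokens tensor, a consumer should treat each yielded string as replacing the previous one rather than as a delta to append.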