Update gen.py
gen.py CHANGED
@@ -1,5 +1,6 @@
 import gc
 import copy
+import time
 from tenacity import RetryError
 from tenacity import retry, stop_after_attempt, wait_fixed

@@ -13,6 +14,7 @@ from transformers import (
     MinNewTokensLengthLogitsProcessor,
     TemperatureLogitsWarper,
     TopPLogitsWarper,
+    MinLengthLogitsProcessor
 )

 def get_output_batch(
@@ -56,6 +58,11 @@ class StreamModel:
         self.tokenizer = tokenizer
         self.device = "cuda" if torch.cuda.is_available() else "cpu"

+        self.processor = LogitsProcessorList()
+        self.processor.append(TemperatureLogitsWarper(0.9))
+        self.processor.append(TopPLogitsWarper(0.75))
+
+
     def __call__(
         self,
         prompt,
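This constructor change pins sampling behaviour to a fixed `LogitsProcessorList` (temperature 0.9, top-p 0.75) that the decoding loop applies in the final hunk below. As a minimal sketch of what such a processor list does to next-token logits before sampling; the dummy token IDs and vocabulary size are made-up values, not anything from gen.py:

```python
# Sketch only: how a fixed LogitsProcessorList like the one added in __init__
# warps next-token logits before sampling. The dummy "prompt" IDs and the
# vocabulary size below are illustrative assumptions.
import torch
from transformers import LogitsProcessorList, TemperatureLogitsWarper, TopPLogitsWarper

processor = LogitsProcessorList()
processor.append(TemperatureLogitsWarper(0.9))  # divide logits by 0.9 (slightly sharper distribution)
processor.append(TopPLogitsWarper(0.75))        # mask everything outside the top-p nucleus

input_ids = torch.tensor([[101, 2023, 2003]])   # dummy token IDs, shape (batch, seq_len)
logits = torch.randn(1, 32000)                  # dummy next-token scores, shape (batch, vocab)

warped = processor(input_ids, logits)           # applied in append order: temperature, then top-p
probs = torch.softmax(warped, dim=-1)           # masked tokens end up with probability 0
next_token = torch.multinomial(probs, num_samples=1)
print(next_token)
```

Because these values are hard-coded at construction time and the per-request `_logits_processor` helper is deleted further down, the `temperature` and `top_p` arguments that `__call__` still accepts appear to no longer influence sampling.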
@@ -71,82 +78,40 @@ class StreamModel:
         logprobs = max(logprobs, 0)

         # bigger than 1
-        chunk_size =
+        chunk_size = 2
         chunk_count = 0

         # Generate completion tokens.
-        final_tokens = torch.empty(0)
-
-        final_tokens = torch.cat((final_tokens, tokens))
-
-        if chunk_count == chunk_size-1:
-            chunk_count = 0
-            yield self.tokenizer.decode(final_tokens, skip_special_tokens=True)
-
-        if chunk_count
-        yield self.tokenizer.decode(final_tokens, skip_special_tokens=True)
-
-        del input_ids
-        gc.collect()
-
-        del final_tokens
+        final_tokens = torch.empty(0)
+
+        for tokens in self.generate(
+            input_ids[None, :].repeat(n, 1),
+            logprobs=logprobs,
+            min_new_tokens=min_tokens,
+            max_new_tokens=max_tokens,
+            temperature=temperature,
+            top_p=top_p,
+        ):
+            if chunk_count < chunk_size:
+                chunk_count = chunk_count + 1
+
+            final_tokens = torch.cat((final_tokens, tokens.to("cpu")))
+
+            if chunk_count == chunk_size-1:
+                chunk_count = 0
+                yield self.tokenizer.decode(final_tokens, skip_special_tokens=True)
+
+        if chunk_count > 0:
+            yield self.tokenizer.decode(final_tokens, skip_special_tokens=True)
+
+        del final_tokens, input_ids
         if self.device == "cuda":
             torch.cuda.empty_cache()

-    @retry(stop=stop_after_attempt(5), wait=wait_fixed(1))
     def _infer(self, model_fn, **kwargs):
-        """Call a model function in inference mode with auto retrying."""
-        # This is a temporary workaround for bitsandbytes #162:
-        # https://github.com/TimDettmers/bitsandbytes/issues/162
         with torch.inference_mode():
             return model_fn(**kwargs)

-    def _logits_processor(self, config, input_length):
-        """Set up logits processor based on the generation config."""
-        processor = LogitsProcessorList()
-
-        # Add processor for enforcing a min-length of new tokens.
-        if (
-            config.min_new_tokens is not None
-            and config.min_new_tokens > 0
-            and config.eos_token_id is not None
-        ):
-            processor.append(
-                MinNewTokensLengthLogitsProcessor(
-                    prompt_length_to_skip=input_length,
-                    min_new_tokens=config.min_new_tokens,
-                    eos_token_id=config.eos_token_id,
-                )
-            )
-
-        # Add processor for scaling output probability distribution.
-        if (
-            config.temperature is not None
-            and config.temperature > 0
-            and config.temperature != 1.0
-        ):
-            processor.append(TemperatureLogitsWarper(config.temperature))
-
-        # Add processor for nucleus sampling.
-        if config.top_p is not None and config.top_p > 0 and config.top_p < 1:
-            processor.append(TopPLogitsWarper(config.top_p))
-
-        return processor
-
     def tokenize(self, text):
         """Tokenize a string into a tensor of token IDs."""
         batch = self.tokenizer.encode(text, return_tensors="pt")
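The rewritten `__call__` streams output in chunks: each newly generated token batch is appended to `final_tokens`, and the accumulated sequence is re-decoded and yielded every `chunk_size` tokens, with a final flush for any remainder. A self-contained sketch of that accumulate-and-flush pattern follows; `fake_generate`, `stream_decode`, the hard-coded token IDs, and the simplified counter handling are illustrative assumptions, not code from this Space:

```python
# Sketch only: the accumulate-and-flush streaming pattern used above, wired to a
# dummy token generator so the chunking logic can run standalone.
import torch

def fake_generate():
    # Stand-in for StreamModel.generate(): yields one token ID tensor per step.
    for token_id in (464, 3290, 318, 257, 922, 3290, 13):
        yield torch.tensor([token_id])

def stream_decode(decode_fn, chunk_size=2):
    final_tokens = torch.empty(0, dtype=torch.long)
    chunk_count = 0
    for tokens in fake_generate():
        final_tokens = torch.cat((final_tokens, tokens))
        chunk_count += 1
        if chunk_count == chunk_size:
            chunk_count = 0
            yield decode_fn(final_tokens)  # flush the accumulated prefix every chunk_size tokens
    if chunk_count > 0:
        yield decode_fn(final_tokens)      # flush whatever is left after the loop

# With a trivial "decoder" that just reports the accumulated IDs:
for text in stream_decode(lambda t: str(t.tolist()), chunk_size=2):
    print(text)
```

As in the diff, every yield decodes the full accumulated prefix rather than only the new tokens, so a consumer presumably replaces its previously received text rather than appending to it.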
@@ -165,7 +130,7 @@ class StreamModel:
         kwargs = config.update(**kwargs)
         kwargs["output_attentions"] = False
         kwargs["output_hidden_states"] = False
-        kwargs["use_cache"] = True
+        kwargs["use_cache"] = True

         # Collect special token IDs.
         pad_token_id = config.pad_token_id
@@ -183,28 +148,6 @@ class StreamModel:
         input_ids = input_ids * eos_token_id[0]
         input_length = 1

-        # Prepare inputs for encoder-decoder models.
-        if self.model.config.is_encoder_decoder:
-            # Get outputs from the encoder.
-            encoder = self.model.get_encoder()
-            encoder_kwargs = kwargs.copy()
-            encoder_kwargs.pop("use_cache", None)
-            encoder_kwargs["input_ids"] = input_ids
-            encoder_kwargs["return_dict"] = True
-            encoder_outputs = self._infer(encoder, **encoder_kwargs)
-            kwargs["encoder_outputs"] = encoder_outputs
-
-            # Reinitialize inputs for the decoder.
-            decoder_start_token_id = config.decoder_start_token_id
-            if decoder_start_token_id is None:
-                decoder_start_token_id = bos_token_id
-            input_ids = input_ids.new_ones((batch_size, 1))
-            input_ids = input_ids * decoder_start_token_id
-            input_length = 1
-
-        # Set up logits processor.
-        processor = self._logits_processor(config, input_length)
-
         # Keep track of which sequences are already finished.
         unfinished = input_ids.new_ones(batch_size)

@@ -213,10 +156,11 @@ class StreamModel:
             inputs = self.model.prepare_inputs_for_generation(
                 input_ids, **kwargs
             )  # noqa: E501
+
             outputs = self._infer(
                 self.model,
                 **inputs,
-                return_dict=True,
+                # return_dict=True,
                 output_attentions=False,
                 output_hidden_states=False,
             )
@@ -224,7 +168,7 @@ class StreamModel:
             # Pre-process the probability distribution of the next tokens.
             logits = outputs.logits[:, -1, :]
             with torch.inference_mode():
-                logits = processor(input_ids, logits)
+                logits = self.processor(input_ids, logits)
                 probs = torch.nn.functional.softmax(logits, dim=-1)

             # Select deterministic or stochastic decoding strategy.
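The remaining hunks adjust the inner decoding step: `use_cache=True` is set on the generation kwargs, `return_dict=True` is commented out of the forward call, and the shared `self.processor` replaces the deleted per-request processor when warping the last-position logits. A rough sketch of one such sampling step is below; the `gpt2` checkpoint, the prompt, and calling the model directly rather than through `prepare_inputs_for_generation` (which gen.py uses, and whose signature varies across transformers versions) are assumptions for illustration:

```python
# Rough sketch of a single sampling step in the style of the loop above; not the
# Space's actual code. gen.py builds the forward kwargs via
# model.prepare_inputs_for_generation(...); here the model is called directly.
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    LogitsProcessorList,
    TemperatureLogitsWarper,
    TopPLogitsWarper,
)

tokenizer = AutoTokenizer.from_pretrained("gpt2")           # assumed checkpoint
model = AutoModelForCausalLM.from_pretrained("gpt2")
processor = LogitsProcessorList(
    [TemperatureLogitsWarper(0.9), TopPLogitsWarper(0.75)]  # mirrors the __init__ change above
)

input_ids = tokenizer.encode("The quick brown", return_tensors="pt")

with torch.inference_mode():
    outputs = model(
        input_ids,
        use_cache=True,                # keep the key/value cache for incremental decoding
        output_attentions=False,
        output_hidden_states=False,
    )
    logits = outputs.logits[:, -1, :]          # scores for the next token only
    logits = processor(input_ids, logits)      # shared temperature + top-p warping
    probs = torch.nn.functional.softmax(logits, dim=-1)
    next_token = torch.multinomial(probs, num_samples=1)

print(tokenizer.decode(torch.cat([input_ids, next_token], dim=-1)[0]))
```

gen.py's loop additionally tracks finished sequences via the `unfinished` mask and presumably threads the key/value cache through successive calls; this sketch performs a single step only.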