Replace `model.generate` with a custom generation function to optimize the KV cache

Replaced the problematic high-level generate call
```
model_outputs = self.model.generate(
    input_ids=current_input_id,
    attention_mask=current_attention_mask,  # This mask is for (history in KV cache + current_input_id)
    eos_token_id=[eos_id, code_id[1]],  # code_id[1] is assumed to be </code>'s last token ID
    past_key_values=current_kv,
    generation_config=self.generation_config,  # Ensure this has return_dict_in_generate=True, use_cache=True
    # max_new_tokens should be set in self.generation_config appropriately for a segment
)
```
with custom generation and decoding functions.
This is because a stateful KV cache cannot be carried across calls to the high-level `model.generate`.
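The replacement boils down to stepping the model manually with `forward` calls, feeding only the newly appended tokens and threading `past_key_values` between steps ourselves. A minimal standalone sketch of that pattern, independent of the trainer code below (the model name is an arbitrary placeholder, and greedy decoding stands in for the trainer's sampling):

```
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Placeholder model; any causal LM exposes the same forward() cache interface.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B").eval()

input_ids = tokenizer("def add(a, b):", return_tensors="pt").input_ids
past_key_values = None          # KV cache we thread between calls ourselves
next_input = input_ids          # first step consumes the whole prompt

generated = []
with torch.no_grad():
    for _ in range(32):
        out = model(input_ids=next_input, past_key_values=past_key_values, use_cache=True)
        past_key_values = out.past_key_values   # keep the cache for the next step
        next_token = out.logits[:, -1, :].argmax(dim=-1, keepdim=True)
        generated.append(next_token.item())
        if next_token.item() == tokenizer.eos_token_id:
            break
        next_input = next_token                 # from now on feed only the new token

print(tokenizer.decode(generated))
```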
- `src/retool_trainer.py` (+168 −101)

```
@@ -25,6 +25,7 @@ from transformers import (
     is_wandb_available,
     PreTrainedTokenizer,
 )
+from transformers import AutoTokenizer, AutoModelForCausalLM, DynamicCache
```

`@@ -257,128 +258,194 @@ class ReToolTrainer(Trainer): # Change this line`

Removed (the old `model.generate`-based rollout loop; gaps shown as `...`):

```
    def ...
        ...
        batch_size = prompt_ids_batch.size(0)
        batch_completion = []
        batch_interpreter_positions = []
        for i in range(batch_size):
            current_input_id = prompt_ids_batch[i:i+1]
            current_attention_mask = attention_mask_batch[i:i+1]
            current_kv = None
            cumulative_completion_ids = torch.empty((1, 0), dtype=torch.long, device=prompt_ids_batch.device)
            interpreter_positions = []
            for turn_idx in range(max_turns):
                model_outputs = self.model.generate(
                    input_ids=current_input_id,
                    attention_mask=current_attention_mask,
                    eos_token_id=[eos_id, code_id[1]],  # code_id[1] is assumed to be </code>'s last token ID
                    past_key_values=current_kv,
                    ...
                )
                ...
                cumulative_completion_ids = torch.cat([cumulative_completion_ids, completion_id], dim=1)
                # Update current_attention_mask: it was for (history + current_input_id),
                # now append 1s for completion_id
                current_attention_mask = torch.cat([
                    current_attention_mask,
                    torch.ones_like(completion_id)
                ], dim=1)
                current_kv = model_outputs.past_key_values  # Cache for the new current_full_ids
                last_token_id = current_full_ids[0, -1].item()
                if last_token_id == eos_id or turn_idx == max_turns - 1:
                    batch_completion.append(cumulative_completion_ids.squeeze(0))
                    batch_interpreter_positions.append(interpreter_positions)
                    break
                ...
                # Extract code from the ...
                full_text = self.processing_class.decode(...)
                code_match = re.search(r'<code>(.*?)</code>', full_text, re.DOTALL)
                if code_match:
                    code_block = code_match.group(1)
                    interpreter_text = self._execute_code(code_block)
                    ...
                    formatted_feedback_text = f"{self.processing_class.decode(interpreter_id[0])}{interpreter_text}{self.processing_class.decode(interpreter_id[1])}"
                    interpreter_feedback_id = self.processing_class(
                        formatted_feedback_text,
                        return_tensors="pt",
                        add_special_tokens=False
                    ).input_ids.to(current_full_ids.device)
                    # Record positions relative to cumulative_completion_ids *before* appending feedback
                    interpreter_start_idx = cumulative_completion_ids.size(1)
                    cumulative_completion_ids = torch.cat([cumulative_completion_ids, interpreter_feedback_id], dim=1)  # Use cumulative, not current
                    interpreter_end_idx = cumulative_completion_ids.size(1) - 1
                    interpreter_positions.append((interpreter_start_idx, interpreter_end_idx))
                    # Update attention mask for the appended tool feedback
                    current_attention_mask = torch.cat([
                        current_attention_mask,
                        torch.ones_like(interpreter_feedback_id)
                    ], dim=1)
                    # Prepare for the next LM generation step:
                    # `current_kv` already holds the cache for `current_full_ids` *before* the tool feedback was appended.
                    # The `current_attention_mask` now correctly covers `current_full_ids` (which includes tool feedback).
                    current_input_id = interpreter_feedback_id
                    # The next `model.generate` call will use this `current_input_id`, `current_attention_mask`, and `current_kv`.
                else:
                    ...
            else:  # Executed if the loop finished due to max_turns without a break
                batch_completion.append(cumulative_completion_ids.squeeze(0))
                batch_interpreter_positions.append(interpreter_positions)
        ...
        return padded_sequences, batch_interpreter_positions
```
Added (the custom generation path):

```
        return advantages

    def _custom_generate(self, input_ids, attention_mask=None, past_key_values=None, max_new_tokens=50, eos_token_ids=None):
        """Custom generation function that avoids KV cache issues"""
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)

        if eos_token_ids is None:
            eos_token_ids = [self.processing_class.eos_token_id]

        # Initialize
        current_ids = input_ids.clone()
        current_mask = attention_mask.clone()
        current_kv = past_key_values

        # Generate tokens in batches for efficiency
        all_tokens = []
        batch_size = 10  # Process this many tokens at once

        for start_idx in range(0, max_new_tokens, batch_size):
            # How many tokens to generate in this batch
            batch_tokens = min(batch_size, max_new_tokens - start_idx)

            # Accumulate new tokens
            new_tokens = []

            for _ in range(batch_tokens):
                # Forward pass with proper cache handling
                with torch.no_grad():
                    outputs = self.model(
                        input_ids=current_ids if current_kv is None else current_ids[:, -1:],
                        attention_mask=current_mask if current_kv is None else current_mask[:, -1:],
                        past_key_values=DynamicCache.from_legacy_cache(current_kv) if current_kv is not None else None,
                        use_cache=True
                    )

                # Sample next token
                next_token_logits = outputs.logits[:, -1, :] / self.temperature
                filtered_logits = self._filter_logits(next_token_logits)
                probs = torch.nn.functional.softmax(filtered_logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)

                # Add to accumulated tokens
                token_id = next_token.item()
                new_tokens.append(token_id)

                # Update for next iteration
                current_ids = torch.cat([current_ids, next_token], dim=1)
                token_mask = torch.ones((1, 1), device=current_mask.device, dtype=current_mask.dtype)
                current_mask = torch.cat([current_mask, token_mask], dim=1)
                current_kv = outputs.past_key_values

                # Check for stop tokens - include both EOS and code_end
                if token_id in eos_token_ids:
                    break

            # Add batch tokens to overall result
            all_tokens.extend(new_tokens)

            # Check if we hit a stop token
            if len(new_tokens) < batch_tokens:
                break

        # Convert to tensor
        result = torch.tensor([all_tokens], device=input_ids.device)
        return result, current_kv

    def _filter_logits(self, logits):
        """Apply top-k and top-p filtering"""
        if self.top_k > 0:
            top_k_logits, top_k_indices = torch.topk(logits, self.top_k, dim=-1)
            logits[0, :] = torch.full_like(logits[0, :], float('-inf'))
            logits[0, top_k_indices[0]] = top_k_logits[0]

        if self.top_p < 1.0:
            sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
            cumulative_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)

            # Remove tokens with cumulative probability above threshold
            sorted_indices_to_remove = cumulative_probs > self.top_p
            # Shift the indices to the right to keep the first token above threshold
            sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
            sorted_indices_to_remove[:, 0] = 0

            # Scatter sorted tensors to original indexing
            indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
            logits[indices_to_remove] = float('-inf')

        return logits

    def _retool_generate_with_interpreter(self, prompt_ids_batch, attention_mask_batch, eos_id, interpreter_id, code_id, max_turns=10):
        """Implementation with custom generation to avoid KV cache issues"""
        batch_size = prompt_ids_batch.size(0)
        batch_completion = []
        batch_interpreter_positions = []

        for i in range(batch_size):
            # Initialize
            current_input_id = prompt_ids_batch[i:i+1]
            current_attention_mask = attention_mask_batch[i:i+1]
            current_kv = None

            # Track completion (excludes prompt)
            cumulative_completion_ids = torch.empty((1, 0), dtype=torch.long, device=prompt_ids_batch.device)
            interpreter_positions = []

            for turn_idx in range(max_turns):
                # Check if input is empty
                if current_input_id.size(1) == 0:
                    break

                # Generate with custom function
                newly_generated_tokens, current_kv = self._custom_generate(
                    input_ids=current_input_id,
                    attention_mask=current_attention_mask,
                    past_key_values=current_kv,
                    max_new_tokens=self.max_completion_length,  # Use class attribute
                    eos_token_ids=[eos_id, code_id[1]]
                )

                # Add to completion
                cumulative_completion_ids = torch.cat([cumulative_completion_ids, newly_generated_tokens], dim=1)

                # Check last token
                last_token_id = newly_generated_tokens[0, -1].item() if newly_generated_tokens.size(1) > 0 else None

                # Check for end conditions
                if last_token_id == eos_id or turn_idx == max_turns - 1:
                    batch_completion.append(cumulative_completion_ids.squeeze(0))
                    batch_interpreter_positions.append(interpreter_positions)
                    break

                # Check for code end token
                if last_token_id == code_id[1]:
                    # Extract code from the full text
                    full_text = self.processing_class.decode(
                        torch.cat([prompt_ids_batch[i], cumulative_completion_ids[0]], dim=0)
                    )
                    code_match = re.search(r'<code>(.*?)</code>', full_text, re.DOTALL)

                    if code_match:
                        code_block = code_match.group(1).strip()
                        interpreter_text = self._execute_code(code_block)

                        # Format and add interpreter output
                        formatted_feedback = f"{self.processing_class.decode(interpreter_id[0])}{interpreter_text}{self.processing_class.decode(interpreter_id[1])}"
                        interpreter_ids = self.processing_class(
                            formatted_feedback,
                            return_tensors="pt",
                            add_special_tokens=False
                        ).input_ids.to(prompt_ids_batch.device)

                        # Record positions
                        interpreter_start_idx = cumulative_completion_ids.size(1)
                        cumulative_completion_ids = torch.cat([cumulative_completion_ids, interpreter_ids], dim=1)
                        interpreter_end_idx = cumulative_completion_ids.size(1) - 1
                        interpreter_positions.append((interpreter_start_idx, interpreter_end_idx))

                        # Set up for next turn
                        current_input_id = interpreter_ids
                        current_attention_mask = torch.ones_like(current_input_id)
                        # Keep current_kv from previous generation
                    else:
                        # No code block found despite </code> token
                        break
                else:
                    # Continue with the newly generated tokens
                    current_input_id = newly_generated_tokens
                    current_attention_mask = torch.ones_like(current_input_id)
            else:
                # Loop finished due to max_turns without a break
                batch_completion.append(cumulative_completion_ids.squeeze(0))
                batch_interpreter_positions.append(interpreter_positions)

        # Pad sequences
        if len(batch_completion) > 0:
            # Ensure padding_value is a valid integer
            padding_value = self.processing_class.pad_token_id
            if padding_value is None:
                padding_value = 0  # Use 0 as a default if pad_token_id is None

            padded_sequences = torch.nn.utils.rnn.pad_sequence(
                batch_completion,
                batch_first=True,
                padding_value=padding_value
            )
        else:
            padded_sequences = torch.empty((0, 0), dtype=torch.long, device=prompt_ids_batch.device)

        return padded_sequences, batch_interpreter_positions


    def _create_interpreter_mask(
```
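For orientation, a sketch of how the new entry point might be invoked. The `trainer` object, the prompt, and the `<code>`/`<interpreter>` token lookups are assumptions about the surrounding setup, not part of this diff:

```
# Hypothetical call site (not part of this diff): `trainer` is an already-built
# ReToolTrainer, and <code>/<interpreter> are assumed to be single special tokens.
tok = trainer.processing_class
enc = tok("Compute 17 * 23 with Python.", return_tensors="pt").to(trainer.model.device)

completions, interpreter_positions = trainer._retool_generate_with_interpreter(
    prompt_ids_batch=enc.input_ids,           # (batch, prompt_len) LongTensor
    attention_mask_batch=enc.attention_mask,  # matching 0/1 mask
    eos_id=tok.eos_token_id,
    interpreter_id=(tok.convert_tokens_to_ids("<interpreter>"),
                    tok.convert_tokens_to_ids("</interpreter>")),
    code_id=(tok.convert_tokens_to_ids("<code>"),
             tok.convert_tokens_to_ids("</code>")),  # code_id[1] also acts as a stop token
    max_turns=10,
)
print(tok.decode(completions[0]))
```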
|