bird-of-paradise committed
Commit a0dec77 · verified · 1 Parent(s): cfa2a65

replace `model.generate` with a custom generation function to optimize KV-cache reuse


Replaced the problematic high-level generate call
```
model_outputs = self.model.generate(
    input_ids=current_input_id,
    attention_mask=current_attention_mask,  # This mask is for (history in KV cache + current_input_id)
    eos_token_id=[eos_id, code_id[1]],      # code_id[1] is assumed to be </code>'s last token ID
    past_key_values=current_kv,
    generation_config=self.generation_config,  # Ensure this has return_dict_in_generate=True, use_cache=True
    # max_new_tokens should be set in self.generation_config appropriately for a segment
)
```
with custom generation and decoding functions, because the high-level `model.generate` does not let you carry a stateful KV cache across calls.
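
The replacement drives decoding manually: it calls the model's forward pass one token at a time and threads the returned `past_key_values` through every step, so the same cache can also be carried across generation segments (for example, before and after interpreter feedback is injected). A minimal sketch of that pattern, assuming a generic Hugging Face causal LM and greedy decoding; the model name and prompt are placeholders, not taken from this repo:

```
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "Qwen/Qwen2.5-0.5B-Instruct"  # placeholder; any causal LM works
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name).eval()

ids = tokenizer("def add(a, b):", return_tensors="pt").input_ids
past = None  # stateful KV cache carried across steps (and across segments)

with torch.no_grad():
    for _ in range(20):
        # First step feeds the whole prompt; later steps feed only the newest token,
        # because `past` already holds keys/values for everything before it.
        step_input = ids if past is None else ids[:, -1:]
        out = model(input_ids=step_input, past_key_values=past, use_cache=True)
        past = out.past_key_values  # reuse this cache on the next step
        next_id = out.logits[:, -1, :].argmax(dim=-1, keepdim=True)
        ids = torch.cat([ids, next_id], dim=1)
        if next_id.item() == tokenizer.eos_token_id:
            break

print(tokenizer.decode(ids[0], skip_special_tokens=True))
```

The patch's `_custom_generate` follows the same loop, with sampling instead of argmax and with the cache wrapped via `DynamicCache.from_legacy_cache` before each forward pass.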

Files changed (1)
  1. src/retool_trainer.py +168 -101
src/retool_trainer.py CHANGED
@@ -25,6 +25,7 @@ from transformers import (
    is_wandb_available,
    PreTrainedTokenizer,
)
+from transformers import AutoTokenizer, AutoModelForCausalLM, DynamicCache


@@ -257,128 +258,194 @@ class ReToolTrainer(Trainer): # Change this line
        return advantages


-    def _retool_generate_with_interpreter(
-        self,
-        prompt_ids_batch: torch.Tensor,      # Full batch of prompts
-        attention_mask_batch: torch.Tensor,  # Full batch of attention masks for prompts
-        #tokenizer: PreTrainedTokenizer,     # use self.processiing_class for Tokenizer
-        eos_id: int,                         # True end-of-sequence token ID
-        interpreter_id: list[int],           # [start_id, end_id]
-        code_id: list[int],                  # [start_id, end_id]
-        max_turns: int = 10
-    ) -> tuple[torch.LongTensor, list[list[tuple[int, int]]]]:
+    def _custom_generate(self, input_ids, attention_mask=None, past_key_values=None, max_new_tokens=50, eos_token_ids=None):
+        """Custom generation function that avoids KV cache issues"""
+        if attention_mask is None:
+            attention_mask = torch.ones_like(input_ids)
+
+        if eos_token_ids is None:
+            eos_token_ids = [self.processing_class.eos_token_id]
+
+        # Initialize
+        current_ids = input_ids.clone()
+        current_mask = attention_mask.clone()
+        current_kv = past_key_values
+
+        # Generate tokens in batches for efficiency
+        all_tokens = []
+        batch_size = 10  # Process this many tokens at once
+
+        for start_idx in range(0, max_new_tokens, batch_size):
+            # How many tokens to generate in this batch
+            batch_tokens = min(batch_size, max_new_tokens - start_idx)
+
+            # Accumulate new tokens
+            new_tokens = []
+
+            for _ in range(batch_tokens):
+                # Forward pass with proper cache handling
+                with torch.no_grad():
+                    outputs = self.model(
+                        input_ids=current_ids if current_kv is None else current_ids[:, -1:],
+                        attention_mask=current_mask if current_kv is None else current_mask[:, -1:],
+                        past_key_values=DynamicCache.from_legacy_cache(current_kv) if current_kv is not None else None,
+                        use_cache=True
+                    )
+
+                # Sample next token
+                next_token_logits = outputs.logits[:, -1, :] / self.temperature
+                filtered_logits = self._filter_logits(next_token_logits)
+                probs = torch.nn.functional.softmax(filtered_logits, dim=-1)
+                next_token = torch.multinomial(probs, num_samples=1)
+
+                # Add to accumulated tokens
+                token_id = next_token.item()
+                new_tokens.append(token_id)
+
+                # Update for next iteration
+                current_ids = torch.cat([current_ids, next_token], dim=1)
+                token_mask = torch.ones((1, 1), device=current_mask.device, dtype=current_mask.dtype)
+                current_mask = torch.cat([current_mask, token_mask], dim=1)
+                current_kv = outputs.past_key_values
+
+                # Check for stop tokens - include both EOS and code_end
+                if token_id in eos_token_ids:
+                    break
+
+            # Add batch tokens to overall result
+            all_tokens.extend(new_tokens)
+
+            # Check if we hit a stop token
+            if len(new_tokens) < batch_tokens:
+                break
+
+        # Convert to tensor
+        result = torch.tensor([all_tokens], device=input_ids.device)
+        return result, current_kv
+
+    def _filter_logits(self, logits):
+        """Apply top-k and top-p filtering"""
+        if self.top_k > 0:
+            top_k_logits, top_k_indices = torch.topk(logits, self.top_k, dim=-1)
+            logits[0, :] = torch.full_like(logits[0, :], float('-inf'))
+            logits[0, top_k_indices[0]] = top_k_logits[0]
+
+        if self.top_p < 1.0:
+            sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
+            cumulative_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)
+
+            # Remove tokens with cumulative probability above threshold
+            sorted_indices_to_remove = cumulative_probs > self.top_p
+            # Shift the indices to the right to keep the first token above threshold
+            sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
+            sorted_indices_to_remove[:, 0] = 0
+
+            # Scatter sorted tensors to original indexing
+            indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
+            logits[indices_to_remove] = float('-inf')
+
+        return logits

+    def _retool_generate_with_interpreter(self, prompt_ids_batch, attention_mask_batch, eos_id, interpreter_id, code_id, max_turns=10):
+        """Implementation with custom generation to avoid KV cache issues"""
        batch_size = prompt_ids_batch.size(0)
        batch_completion = []
        batch_interpreter_positions = []
-
-        for i in range(batch_size): # Process each item in the batch
-            # --- Initialization for the current sequence ---
-            current_input_id = prompt_ids_batch[i:i+1] # Initial input is the prompt
-            current_attention_mask = attention_mask_batch[i:i+1]
+
+        for i in range(batch_size):
+            # Initialize
+            current_input_id = prompt_ids_batch[i:i+1]
+            current_attention_mask = attention_mask_batch[i:i+1]
            current_kv = None
-
-            # NEW: Track only the completion part (no prompt)
+
+            # Track completion (excludes prompt)
            cumulative_completion_ids = torch.empty((1, 0), dtype=torch.long, device=prompt_ids_batch.device)
            interpreter_positions = []
-
+
            for turn_idx in range(max_turns):
-                # --- Stage 1: LM generates text ---
-                model_outputs = self.model.generate(
+                # Check if input is empty
+                if current_input_id.size(1) == 0:
+                    break
+
+                # Generate with custom function
+                newly_generated_tokens, current_kv = self._custom_generate(
                    input_ids=current_input_id,
-                    attention_mask=current_attention_mask, # This mask is for (history in KV cache + current_input_id)
-                    eos_token_id=[eos_id, code_id[1]], # code_id[1] is assumed to be </code>'s last token ID
+                    attention_mask=current_attention_mask,
                    past_key_values=current_kv,
-                    generation_config=self.generation_config, # Ensure this has return_dict_in_generate=True, use_cache=True
-                    # max_new_tokens should be set in self.generation_config appropriately for a segment
+                    max_new_tokens=self.max_completion_length,  # Use class attribute
+                    eos_token_ids=[eos_id, code_id[1]]
                )
-
-                # Update current_full_ids to the new complete sequence
-                current_full_ids = model_outputs.sequences
-
-                # Newly generated tokens by the LM in THIS step
-                completion_id = current_full_ids[:, current_input_id.size(1):]
-
-                # Add to completion tracking (excludes prompt)
-                cumulative_completion_ids = torch.cat([cumulative_completion_ids, completion_id], dim=1)
-
-                # Update current_input_id for the next generation step
-                # Update current_attention_mask: it was for (history + current_input_id),
-                # now append 1s for completion_id
-                current_attention_mask = torch.cat([
-                    current_attention_mask,
-                    torch.ones_like(completion_id)
-                ], dim=1)
-
-                current_kv = model_outputs.past_key_values # Cache for the new current_full_ids
-
-                last_token_id = current_full_ids[0, -1].item()
-
+
+                # Add to completion
+                cumulative_completion_ids = torch.cat([cumulative_completion_ids, newly_generated_tokens], dim=1)
+
+                # Check last token
+                last_token_id = newly_generated_tokens[0, -1].item() if newly_generated_tokens.size(1) > 0 else None
+
+                # Check for end conditions
                if last_token_id == eos_id or turn_idx == max_turns - 1:
                    batch_completion.append(cumulative_completion_ids.squeeze(0))
-                    batch_interpreter_positions.append(interpreter_positions) # Note: was batch_interpreter_positions[i] = ...
+                    batch_interpreter_positions.append(interpreter_positions)
                    break
-
-                if last_token_id == code_id[1]: # Assuming code_id[1] is the specific ID for </code> last token
-                    # --- Stage 2: Tool Execution ---
-                    # Extract code from the generated sequence
-                    full_text = self.processing_class.decode(current_full_ids[0])
+
+                # Check for code end token
+                if last_token_id == code_id[1]:
+                    # Extract code from the full text
+                    full_text = self.processing_class.decode(
+                        torch.cat([prompt_ids_batch[i], cumulative_completion_ids[0]], dim=0)
+                    )
                    code_match = re.search(r'<code>(.*?)</code>', full_text, re.DOTALL)
+
                    if code_match:
-                        code_block = code_match.group(1)
-                        interpreter_text = self._execute_code(code_block) # 👈 To do: code sandbox execution 👈
+                        code_block = code_match.group(1).strip()
+                        interpreter_text = self._execute_code(code_block)
+
+                        # Format and add interpreter output
+                        formatted_feedback = f"{self.processing_class.decode(interpreter_id[0])}{interpreter_text}{self.processing_class.decode(interpreter_id[1])}"
+                        interpreter_ids = self.processing_class(
+                            formatted_feedback,
+                            return_tensors="pt",
+                            add_special_tokens=False
+                        ).input_ids.to(prompt_ids_batch.device)
+
+                        # Record positions
+                        interpreter_start_idx = cumulative_completion_ids.size(1)
+                        cumulative_completion_ids = torch.cat([cumulative_completion_ids, interpreter_ids], dim=1)
+                        interpreter_end_idx = cumulative_completion_ids.size(1) - 1
+                        interpreter_positions.append((interpreter_start_idx, interpreter_end_idx))
+
+                        # Set up for next turn
+                        current_input_id = interpreter_ids
+                        current_attention_mask = torch.ones_like(current_input_id)
+                        # Keep current_kv from previous generation
                    else:
-                        interpreter_text = "Error: No code found"
-
-                    formatted_feedback_text = f"{self.processing_class.decode(interpreter_id[0])}{interpreter_text}{self.processing_class.decode(interpreter_id[1])}"
-
-                    interpreter_feedback_id = self.processing_class(
-                        formatted_feedback_text,
-                        return_tensors="pt",
-                        add_special_tokens=False
-                    ).input_ids.to(current_full_ids.device)
-
-
-                    # Record positions relative to cumulative_completion_ids *before* appending feedback
-                    interpreter_start_idx = cumulative_completion_ids.size(1)
-                    cumulative_completion_ids = torch.cat([cumulative_completion_ids, interpreter_feedback_id], dim=1) # Use cumulative, not current
-                    interpreter_end_idx = cumulative_completion_ids.size(1) - 1
-                    interpreter_positions.append((interpreter_start_idx, interpreter_end_idx))
-
-                    # Update attention mask for the appended tool feedback
-                    current_attention_mask = torch.cat([
-                        current_attention_mask,
-                        torch.ones_like(interpreter_feedback_id)
-                    ], dim=1)
-
-                    # Prepare for the next LM generation step:
-                    # The model needs to "process" the tool_output_tokens to update its KV cache.
-                    # The `current_input_id` for the next generate call will be `interpreter_feedback_id`.
-                    # `current_kv` already holds the cache for `current_full_ids` *before* the tool feedback was appended.
-                    # The `current_attention_mask` now correctly covers `current_full_ids` (which includes tool feedback).
-                    current_input_id = interpreter_feedback_id
-                    # `current_kv` is correct (it's for the prefix before `interpreter_feedback_id`).
-                    # The next `model.generate` call will use this `current_input_id`, `current_attention_mask`, and `current_kv`.
+                        # No code block found despite </code> token
+                        break
                else:
-                    # LM stopped for a reason other than EOS or code_end (e.g., max_new_tokens for the segment)
-                    batch_completion.append(cumulative_completion_ids.squeeze(0))
-                    batch_interpreter_positions.append(interpreter_positions)
-                    # At the end, return full sequence (prompt + completion)
-                    break
-            else: # Executed if the loop finished due to max_turns without a break
+                    # Continue with the newly generated tokens
+                    current_input_id = newly_generated_tokens
+                    current_attention_mask = torch.ones_like(current_input_id)
+            else:
+                # Loop finished due to max_turns without a break
                batch_completion.append(cumulative_completion_ids.squeeze(0))
                batch_interpreter_positions.append(interpreter_positions)
-
-
-        # Pad sequences in the batch to the same length for returning a single tensor
-        # This is a common step if you started with a batch loop.
-        # Alternatively, this function could return a list of tensors if lengths vary.
-        # For now, assuming you'll handle batch padding outside or return a list.
-        # The return type `torch.LongTensor` implies a padded batch.
-        padded_sequences = torch.nn.utils.rnn.pad_sequence(batch_completion, batch_first=True, padding_value=self.processing_class.pad_token_id)
-
+
+        # Pad sequences
+        if len(batch_completion) > 0:
+            # Ensure padding_value is a valid integer
+            padding_value = self.processing_class.pad_token_id
+            if padding_value is None:
+                padding_value = 0  # Use 0 as a default if pad_token_id is None
+
+            padded_sequences = torch.nn.utils.rnn.pad_sequence(
+                batch_completion,
+                batch_first=True,
+                padding_value=padding_value
+            )
+        else:
+            padded_sequences = torch.empty((0, 0), dtype=torch.long, device=prompt_ids_batch.device)
+
        return padded_sequences, batch_interpreter_positions


    def _create_interpreter_mask(
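
Side note on the sampling step: `_filter_logits` above applies standard top-k / top-p (nucleus) truncation before `torch.multinomial`. A self-contained sketch of the same filtering idea on a dummy logits tensor; the function name, vocabulary size, and thresholds are illustrative, not taken from the commit:

```
import torch

def filter_logits(logits, top_k=50, top_p=0.9):
    """Mask logits outside the top-k set and outside the top-p nucleus with -inf."""
    if top_k > 0:
        kth_best = torch.topk(logits, top_k, dim=-1).values[..., -1, None]
        logits = logits.masked_fill(logits < kth_best, float("-inf"))
    if top_p < 1.0:
        sorted_logits, sorted_idx = torch.sort(logits, descending=True, dim=-1)
        cum_probs = torch.softmax(sorted_logits, dim=-1).cumsum(dim=-1)
        remove = cum_probs > top_p
        remove[..., 1:] = remove[..., :-1].clone()  # keep the first token that crosses the threshold
        remove[..., 0] = False
        remove = remove.scatter(-1, sorted_idx, remove)  # map mask back to vocabulary order
        logits = logits.masked_fill(remove, float("-inf"))
    return logits

logits = torch.randn(1, 32000)  # dummy vocabulary logits
probs = torch.softmax(filter_logits(logits), dim=-1)
next_token = torch.multinomial(probs, num_samples=1)
print(next_token)
```

In the commit, the equivalent logic reads `self.top_k`, `self.top_p`, and `self.temperature` from the trainer instance.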