bird-of-paradise committed (verified)
Commit c710786 · Parent: 0690c9f

Add custom sampler, train data loader and GRPO-style train loop for ReTool_trainer

Files changed (1): src/retool_trainer.py (+350 −7)

src/retool_trainer.py CHANGED
@@ -1,8 +1,18 @@
-from typing import Any, Callable, Optional, Union
-from collections import defaultdict
-import re
 import profiling_decorator
 
+import copy
+import inspect
+import os
+import re
+import textwrap
+import warnings
+from collections import defaultdict, deque
+from collections.abc import Sequence, Sized
+from contextlib import nullcontext
+from functools import partial
+from pathlib import Path
+from typing import Any, Callable, Optional, Union
+
 import datasets
 import torch
 import torch.utils.data
@@ -14,6 +24,7 @@ from torch import nn
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.utils.data import DataLoader, Sampler
 from transformers import (
+    AutoConfig,
     AutoModelForCausalLM,
     AutoModelForSequenceClassification,
     AutoTokenizer,
@@ -28,6 +39,105 @@ from transformers import (
 from transformers import AutoTokenizer, AutoModelForCausalLM, DynamicCache
 
 
+class HFRepeatSampler(Sampler):
+    """
+    Sampler that repeats the indices of a dataset in a structured manner.
+
+    Args:
+        data_source (`Sized`):
+            Dataset to sample from.
+        mini_repeat_count (`int`):
+            Number of times to repeat each index per batch.
+        batch_size (`int`, *optional*, defaults to `1`):
+            Number of unique indices per batch.
+        repeat_count (`int`, *optional*, defaults to `1`):
+            Number of times to repeat the full sampling process.
+        shuffle (`bool`, *optional*, defaults to `True`):
+            Whether to shuffle the dataset.
+        seed (`int` or `None`, *optional*, defaults to `None`):
+            Random seed for reproducibility (only affects this sampler).
+
+    Example:
+    ```python
+    >>> sampler = HFRepeatSampler(
+    ...     ["a", "b", "c", "d", "e", "f", "g"], mini_repeat_count=2, batch_size=3, repeat_count=4
+    ... )
+    >>> list(sampler)
+    [4, 4, 3, 3, 0, 0,
+     4, 4, 3, 3, 0, 0,
+     4, 4, 3, 3, 0, 0,
+     4, 4, 3, 3, 0, 0,
+     1, 1, 2, 2, 6, 6,
+     1, 1, 2, 2, 6, 6,
+     1, 1, 2, 2, 6, 6,
+     1, 1, 2, 2, 6, 6]
+    ```
+
+    ```txt
+    mini_repeat_count = 3
+          -   -   -
+         [0,  0,  0,  1,  1,  1,  2,  2,  2,  3,  3,  3,      |
+          4,  4,  4,  5,  5,  5,  6,  6,  6,  7,  7,  7,      |
+          8,  8,  8,  9,  9,  9, 10, 10, 10, 11, 11, 11,      |
+                                                                repeat_count = 2
+          0,  0,  0,  1,  1,  1,  2,  2,  2,  3,  3,  3,      |
+          4,  4,  4,  5,  5,  5,  6,  6,  6,  7,  7,  7,      |
+          8,  8,  8,  9,  9,  9, 10, 10, 10, 11, 11, 11, ...] |
+          ---------   ---------   ---------   ---------
+           ---------   ---------   ---------   ---------
+            ---------   ---------   ---------   ---------
+                         batch_size = 12
+    ```
+    """
+
+    def __init__(
+        self,
+        data_source: Sized,
+        mini_repeat_count: int,
+        batch_size: int = 1,
+        repeat_count: int = 1,
+        shuffle: bool = True,
+        seed: Optional[int] = None,
+    ):
+        self.data_source = data_source
+        self.mini_repeat_count = mini_repeat_count
+        self.batch_size = batch_size
+        self.repeat_count = repeat_count
+        self.num_samples = len(data_source)
+        self.shuffle = shuffle
+        self.seed = seed
+
+        if shuffle:
+            self.generator = torch.Generator()  # Create a local random generator
+            if seed is not None:
+                self.generator.manual_seed(seed)
+
+    def __iter__(self):
+        if self.shuffle:
+            # E.g., [2, 4, 3, 1, 0, 6, 5] (num_samples = 7)
+            indexes = torch.randperm(self.num_samples, generator=self.generator).tolist()
+        else:
+            indexes = list(range(self.num_samples))
+
+        # [2, 4, 3, 1, 0, 6, 5]
+        # -> [[2, 4, 3], [1, 0, 6], [5]]  (batch_size = 3)
+        indexes = [indexes[i : i + self.batch_size] for i in range(0, len(indexes), self.batch_size)]
+
+        # [[2, 4, 3], [1, 0, 6], [5]]
+        # -> [[2, 4, 3], [1, 0, 6]]
+        indexes = [chunk for chunk in indexes if len(chunk) == self.batch_size]
+
+        for chunk in indexes:
+            for _ in range(self.repeat_count):
+                for index in chunk:
+                    for _ in range(self.mini_repeat_count):
+                        yield index
+
+    def __len__(self) -> int:
+        return (self.num_samples // self.batch_size) * self.batch_size * self.mini_repeat_count * self.repeat_count
+
+
+
 
 class ReToolTrainer(Trainer): # Change this line
 
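A quick way to sanity-check the sampler's repeat structure outside the trainer; this is a toy sketch that assumes `HFRepeatSampler` as defined in the hunk above is in scope, with made-up parameter values:

```python
from collections import Counter

# Toy check of HFRepeatSampler (defined in the hunk above); values are illustrative only.
sampler = HFRepeatSampler(list(range(7)), mini_repeat_count=2, batch_size=3, repeat_count=4, seed=0)
order = list(sampler)

# (7 // 3) chunks * 3 indices * mini_repeat_count 2 * repeat_count 4 = 48
assert len(order) == len(sampler) == 48
# Every surviving index appears mini_repeat_count * repeat_count = 8 times
assert set(Counter(order).values()) == {8}
# A fresh sampler with the same seed yields the identical order; this is the property
# that keeps GRPO prompt groups aligned when each process builds its own sampler.
assert order == list(HFRepeatSampler(list(range(7)), 2, batch_size=3, repeat_count=4, seed=0))
```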
@@ -52,16 +162,18 @@ class ReToolTrainer(Trainer): # Change this line
         mask_truncated_completions: bool = True,
         **kwargs
     ):
         # Initialize parent Trainer (simpler call)
         super().__init__(
             model=model,
-            tokenizer=processing_class,  # Note: Trainer uses 'tokenizer', not 'processing_class'
             args=args,
+            data_collator=identity,  # No data collation is needed in GRPO
             train_dataset=train_dataset,
             eval_dataset=eval_dataset,
+            processing_class=processing_class,  # recent Trainer versions take 'processing_class'; 'tokenizer' is a deprecated alias, and passing both raises
+            callbacks=callbacks,
+            optimizers=optimizers,
             **kwargs
         )
-
 
         # Store processing_class for compatibility
         self.processing_class = processing_class or self.tokenizer
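The collator passed above, `identity`, is not defined anywhere in this diff. TRL's GRPOTrainer pairs the same `data_collator=identity` comment with a trivial pass-through; if the definition is missing from this file, a module-level sketch like the following would be needed:

```python
def identity(x):
    # Pass batches through unchanged; grouping/repetition is handled by HFRepeatSampler.
    return x
```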
 
@@ -115,6 +228,76 @@ class ReToolTrainer(Trainer): # Change this line
             use_cache=True,
             cache_implementation=args.cache_implementation,  # args.cache_implementation = 'Offloaded Cache'
         )
+    def _set_signature_columns_if_needed(self):
+        # If `self.args.remove_unused_columns` is True, non-signature columns are removed.
+        # By default, this method sets `self._signature_columns` to the model's expected inputs.
+        # In GRPOTrainer, we preprocess data, so using the model's signature columns doesn't work.
+        # Instead, we set them to the columns expected by the `training_step` method, hence the override.
+        if self._signature_columns is None:
+            self._signature_columns = ["prompt", "image"]
+
+    def _get_train_sampler(self, dataset=None):
+        """Override to use HFRepeatSampler for GRPO."""
+        # Returns a sampler that
+        # 1. ensures each prompt is repeated across multiple processes. This guarantees that identical prompts are
+        #    distributed to different GPUs, allowing rewards to be computed and normalized correctly within each prompt
+        #    group. Using the same seed across processes ensures consistent prompt assignment, preventing discrepancies
+        #    in group formation.
+        # 2. repeats the batch multiple times to allow reusing generations across multiple updates. Refer to
+        #    _prepare_inputs to see how the generations are stored and reused.
+
+        # In the following figure, the values are the prompt indices. The first row shows the first sampled batch, the
+        # second row shows the second sampled batch, and so on.
+        #
+        #                                        |   GPU 0   |   GPU 1   |
+        #
+        #                 global_step   step      <-───>  num_generations=2
+        #                                         <-───────>  per_device_train_batch_size=3
+        #  grad_accum     ▲  ▲   0        0       0   0   1   1   2   2   <- Generate for the first `steps_per_generation` (prompts 0 to 11); store the completions; use the first slice to compute the loss
+        #     =2          ▼  |   0        1       3   3   4   4   5   5   <- Take the stored generations and use the second slice to compute the loss
+        #                    |
+        #                    |   1        2       6   6   7   7   8   8   <- Take the stored generations and use the third slice to compute the loss
+        #  steps_per_gen=4   ▼   1        3       9   9  10  10  11  11   <- Take the stored generations and use the fourth slice to compute the loss
+        #
+        #                        2        4      12  12  13  13  14  14   <- Generate for the second `steps_per_generation` (prompts 12 to 23); store the completions; use the first slice to compute the loss
+        #                        2        5      15  15  16  16  17  17   <- Take the stored generations and use the second slice to compute the loss
+        #                                            ...
+        if dataset is None:
+            dataset = self.train_dataset
+
+        return HFRepeatSampler(
+            data_source=dataset,
+            mini_repeat_count=self.num_generations,  # e.g., 4 completions per prompt
+            batch_size=self.args.generation_batch_size // self.num_generations,  # correction
+            repeat_count=self.num_iterations * self.args.steps_per_generation,  # correction
+            shuffle=True,
+            seed=self.args.seed
+        )
+
+    def get_train_dataloader(self):
+        """Override to ensure our custom sampler is used."""
+        if self.train_dataset is None:
+            raise ValueError("Trainer: training requires a train_dataset.")
+
+        train_dataset = self.train_dataset
+        data_collator = self.data_collator
+        if is_datasets_available() and isinstance(train_dataset, datasets.Dataset):
+            train_dataset = self._remove_unused_columns(train_dataset, description="training")
+        else:
+            data_collator = self._get_collator_with_removed_columns(data_collator, description="training")
+
+        sampler = self._get_train_sampler(train_dataset)
+        dataloader_batch_size = self._train_batch_size * self.args.steps_per_generation
+
+        return DataLoader(
+            train_dataset,
+            batch_size=self.args.generation_batch_size,  # <- this is the change; HF was using dataloader_batch_size
+            sampler=sampler,
+            collate_fn=data_collator,
+            drop_last=self.args.dataloader_drop_last,
+            num_workers=self.args.dataloader_num_workers,
+        )
+
     def _get_interpreter_token_ids(self) -> list[int]:
         """Get token IDs for <interpreter> and </interpreter> tags."""
         start_token = self.processing_class.encode("<interpreter>", add_special_tokens=False)[0]
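To see how the sampler and dataloader overrides fit together, here is the batch arithmetic with purely illustrative values (none of these numbers come from the commit):

```python
# Illustrative sizing for the overrides above; all values are hypothetical.
num_generations = 4          # completions per prompt (mini_repeat_count)
generation_batch_size = 16   # DataLoader batch size after the override
steps_per_generation = 2
num_iterations = 1

unique_prompts_per_batch = generation_batch_size // num_generations  # 4, the HFRepeatSampler batch_size
repeat_count = num_iterations * steps_per_generation                 # 2, each chunk is re-yielded twice

# One dataloader batch holds 4 unique prompts x 4 copies each = 16 rows, and the
# same 16 rows are served `repeat_count` times before fresh prompts appear.
assert unique_prompts_per_batch * num_generations == generation_batch_size
```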
 
@@ -725,4 +908,165 @@ def _compute_loss(self, model, inputs):
         self._metrics[mode]["clip_ratio/high_max"].append(nanmax(gathered_high_clip).item())
         gathered_clip_ratio = self.accelerator.gather_for_metrics(clip_ratio)
         self._metrics[mode]["clip_ratio/region_mean"].append(gathered_clip_ratio.nanmean().item())
-        return loss
+        return loss
+
+    def train(self):
+        """
+        Comprehensive training loop for ReTool with GRPO.
+        Adapted from train_with_batching to work as a method.
+        """
+        # Initialize
+        self.model.train()
+        if not hasattr(self, 'ref_model') or self.ref_model is None:
+            self.ref_model = copy.deepcopy(self.model)  # `copy` is imported at module level
+            self.ref_model.eval()
+
+        # Setup tracking
+        from torch.utils.tensorboard import SummaryWriter  # local import; not in the module-level imports above
+        writer = SummaryWriter(self.args.logging_dir)
+        training_history = []
+
+        # Get dataloader with our custom sampler
+        train_dataloader = self.get_train_dataloader()
+
+        # Generation storage for reuse
+        stored_generation_outputs = None
+        generation_counter = 0
+        global_step = 0
+
+        for epoch in range(int(self.args.num_train_epochs)):  # num_train_epochs is a float in TrainingArguments
+            epoch_metrics = []
+            start_mem = torch.cuda.memory_allocated() if torch.cuda.is_available() else 0
+
+            for batch_idx, batch in enumerate(train_dataloader):
+                # batch already has repeated prompts from our HFRepeatSampler
+                # Shape: (generation_batch_size, ...) where generation_batch_size = unique_prompts * num_generations
+
+                # Determine if we need new generations
+                generate_new = (global_step % (self.args.steps_per_generation * self.num_iterations)) == 0
+
+                if generate_new:
+                    print(f"Generating new completions at step {global_step}")
+                    with torch.no_grad():
+                        # This is where the ReTool magic happens - generate with code execution!
+                        stored_generation_outputs = self._generate_and_score_completions(batch)
+                    generation_counter = 0
+
+                # Now train on the stored generations
+                # This replaces the mini/micro batch logic from the original train_with_batching function
+                batch_loss = self._train_on_stored_generations(
+                    stored_generation_outputs,
+                    epoch_metrics
+                )
+
+                global_step += 1
+                generation_counter += 1
+
+                # Logging
+                if global_step % self.args.logging_steps == 0:
+                    self._log_training_metrics(writer, epoch_metrics, global_step)
+
+                # Optional: Check for training instability
+                if self._should_stop_training(epoch_metrics):
+                    print("Training instability detected! Stopping early.")
+                    return training_history
+
+            # End of epoch
+            end_mem = torch.cuda.memory_allocated() if torch.cuda.is_available() else 0
+            epoch_summary = self._compute_epoch_summary(epoch_metrics, start_mem, end_mem)
+            training_history.append(epoch_summary)
+
+            # Log epoch results
+            self._log_epoch_metrics(epoch, epoch_summary, writer)
+
+            # Update scheduler if we have one
+            if hasattr(self, 'scheduler') and self.scheduler is not None:
+                self.scheduler.step(epoch_summary['mean_reward'])
+                print(f"Current learning rate: {self.optimizer.param_groups[0]['lr']}")
+
+        writer.close()
+        return training_history
+
+    def _train_on_stored_generations(self, generation_outputs, epoch_metrics):
+        """
+        Train on stored generations with mini/micro-batching.
+        This replaces the inner loops of the original train_with_batching function.
+        """
+        # Extract components from generation_outputs
+        # These already include code execution results and advantages!
+        prompt_ids = generation_outputs['prompt_ids']
+        completion_ids = generation_outputs['completion_ids']
+        advantages = generation_outputs['advantages']
+        completion_mask = generation_outputs['completion_mask']
+        interpreter_mask = generation_outputs.get('interpreter_mask', completion_mask)
+
+        batch_size = prompt_ids.size(0)
+
+        # Mini-batch size: process multiple groups together
+        # Each group has num_generations completions
+        mini_batch_size = self.args.per_device_train_batch_size * self.num_generations
+
+        # Micro-batch size: for memory efficiency within mini-batch
+        micro_batch_size = max(self.num_generations, 4)  # At least one full group
+
+        total_loss = 0
+        num_updates = 0
+
+        # Shuffle indices for this training iteration
+        indices = torch.randperm(batch_size)
+
+        # Process in mini-batches
+        for mini_start in range(0, batch_size, mini_batch_size):
+            mini_end = min(mini_start + mini_batch_size, batch_size)
+            mini_indices = indices[mini_start:mini_end]
+
+            self.optimizer.zero_grad()
+            mini_batch_loss = 0
+            num_micro_batches = 0
+
+            # Process in micro-batches (gradient accumulation)
+            for micro_start in range(0, len(mini_indices), micro_batch_size):
+                micro_end = min(micro_start + micro_batch_size, len(mini_indices))
+                micro_indices = mini_indices[micro_start:micro_end]
+
+                # Create micro-batch
+                micro_batch = {
+                    'prompt_ids': prompt_ids[micro_indices],
+                    'prompt_mask': generation_outputs['prompt_mask'][micro_indices],
+                    'completion_ids': completion_ids[micro_indices],
+                    'completion_mask': completion_mask[micro_indices],
+                    'interpreter_mask': interpreter_mask[micro_indices],
+                    'advantages': advantages[micro_indices]
+                }
+
+                # Compute GRPO loss (this uses the _compute_loss method above)
+                loss = self._compute_loss(self.model, micro_batch)
+
+                # Scale for gradient accumulation
+                scaled_loss = loss * (len(micro_indices) / len(mini_indices))
+                scaled_loss.backward()
+
+                mini_batch_loss += loss.item()
+                num_micro_batches += 1
+
+            # Gradient clipping and optimizer step
+            grad_norm = torch.nn.utils.clip_grad_norm_(
+                self.model.parameters(),
+                max_norm=1.0
+            )
+            self.optimizer.step()
+
+            # Track metrics
+            batch_metrics = {
+                'loss': mini_batch_loss / num_micro_batches,
+                'gradient_norm': grad_norm.item(),
+                'batch_size': len(mini_indices),
+                'advantages_mean': advantages[mini_indices].mean().item(),
+                'advantages_std': advantages[mini_indices].std().item()
+            }
+            epoch_metrics.append(batch_metrics)
+
+            total_loss += mini_batch_loss
+            num_updates += 1
+
+        return total_loss / max(num_updates, 1)
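For context, a minimal launch sketch. The full `ReToolTrainer` constructor signature is not visible in this diff, so the model name, dataset, and the extra config fields below are assumptions inferred from the attributes the new code reads (`generation_batch_size`, `steps_per_generation`, `cache_implementation`):

```python
# Hypothetical launch script; field names mirror the attributes read by the new
# train loop, but the actual config class for this trainer is not shown in the diff.
from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")  # placeholder model
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
train_dataset = Dataset.from_dict({"prompt": ["Compute 3 + 4.", "Compute 12 * 7."]})

args = TrainingArguments(output_dir="retool-grpo", per_device_train_batch_size=2,
                         num_train_epochs=1, logging_steps=10, seed=42)
# Fields the new code reads but TrainingArguments lacks; presumably they live on a
# custom config class, so attaching them here is only a stand-in:
args.generation_batch_size = 8
args.steps_per_generation = 2
args.cache_implementation = "offloaded"

trainer = ReToolTrainer(model=model, processing_class=tokenizer, args=args,
                        train_dataset=train_dataset)
history = trainer.train()  # returns per-epoch summaries and logs to TensorBoard
```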