Add custom sampler, train dataloader, and GRPO-style train loop for ReTool_trainer
src/retool_trainer.py  CHANGED  +350 −7
@@ -1,8 +1,18 @@
-from typing import Any, Callable, Optional, Union
-from collections import defaultdict
-import re
 import profiling_decorator

 import datasets
 import torch
 import torch.utils.data
@@ -14,6 +24,7 @@ from torch import nn
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.utils.data import DataLoader, Sampler
 from transformers import (
     AutoModelForCausalLM,
     AutoModelForSequenceClassification,
     AutoTokenizer,
@@ -28,6 +39,105 @@ from transformers import (
 from transformers import AutoTokenizer, AutoModelForCausalLM, DynamicCache



 class ReToolTrainer(Trainer):  # Change this line

@@ -52,16 +162,19 @@ class ReToolTrainer(Trainer):  # Change this line
         mask_truncated_completions: bool = True,
         **kwargs
     ):
-        # Initialize parent Trainer (simpler call)
         super().__init__(
             model=model,
-            tokenizer=processing_class,  # Note: Trainer uses 'tokenizer', not 'processing_class'
             args=args,
             train_dataset=train_dataset,
             eval_dataset=eval_dataset,
             **kwargs
         )
-

         # Store processing_class for compatibility
         self.processing_class = processing_class or self.tokenizer
@@ -115,6 +228,76 @@ class ReToolTrainer(Trainer):  # Change this line
             use_cache=True,
             cache_implementation=args.cache_implementation,  # args.cache_implementation = 'Offloaded Cache'
         )
    def _get_interpreter_token_ids(self) -> list[int]:
        """Get token IDs for <interpreter> and </interpreter> tags."""
        start_token = self.processing_class.encode("<interpreter>", add_special_tokens=False)[0]
@@ -725,4 +908,164 @@ def _compute_loss(self, model, inputs):
         self._metrics[mode]["clip_ratio/high_max"].append(nanmax(gathered_high_clip).item())
         gathered_clip_ratio = self.accelerator.gather_for_metrics(clip_ratio)
         self._metrics[mode]["clip_ratio/region_mean"].append(gathered_clip_ratio.nanmean().item())
-        return loss
@@ -1,8 +1,18 @@
 import profiling_decorator

+import copy
+import inspect
+import os
+import re
+import textwrap
+import warnings
+from collections import defaultdict, deque
+from collections.abc import Sequence, Sized
+from contextlib import nullcontext
+from functools import partial
+from pathlib import Path
+from typing import Any, Callable, Optional, Union
+
 import datasets
 import torch
 import torch.utils.data
@@ -14,6 +24,7 @@ from torch import nn
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.utils.data import DataLoader, Sampler
 from transformers import (
+    AutoConfig,
     AutoModelForCausalLM,
     AutoModelForSequenceClassification,
     AutoTokenizer,
@@ -28,6 +39,105 @@ from transformers import (
 from transformers import AutoTokenizer, AutoModelForCausalLM, DynamicCache


+class HFRepeatSampler(Sampler):
+    """
+    Sampler that repeats the indices of a dataset in a structured manner.
+
+    Args:
+        data_source (`Sized`):
+            Dataset to sample from.
+        mini_repeat_count (`int`):
+            Number of times to repeat each index per batch.
+        batch_size (`int`, *optional*, defaults to `1`):
+            Number of unique indices per batch.
+        repeat_count (`int`, *optional*, defaults to `1`):
+            Number of times to repeat the full sampling process.
+        shuffle (`bool`, *optional*, defaults to `True`):
+            Whether to shuffle the dataset.
+        seed (`int` or `None`, *optional*, defaults to `None`):
+            Random seed for reproducibility (only affects this sampler).
+
+    Example:
+    ```python
+    >>> sampler = HFRepeatSampler(
+    ...     ["a", "b", "c", "d", "e", "f", "g"], mini_repeat_count=2, batch_size=3, repeat_count=4
+    ... )
+    >>> list(sampler)
+    [4, 4, 3, 3, 0, 0,
+     4, 4, 3, 3, 0, 0,
+     4, 4, 3, 3, 0, 0,
+     4, 4, 3, 3, 0, 0,
+     1, 1, 2, 2, 6, 6,
+     1, 1, 2, 2, 6, 6,
+     1, 1, 2, 2, 6, 6,
+     1, 1, 2, 2, 6, 6]
+    ```
+
+    ```txt
+    mini_repeat_count = 3
+          -   -   -
+         [0,  0,  0,  1,  1,  1,  2,  2,  2,  3,  3,  3,      |
+          4,  4,  4,  5,  5,  5,  6,  6,  6,  7,  7,  7,      |
+          8,  8,  8,  9,  9,  9, 10, 10, 10, 11, 11, 11,      |
+                                                                 repeat_count = 2
+          0,  0,  0,  1,  1,  1,  2,  2,  2,  3,  3,  3,      |
+          4,  4,  4,  5,  5,  5,  6,  6,  6,  7,  7,  7,      |
+          8,  8,  8,  9,  9,  9, 10, 10, 10, 11, 11, 11, ...] |
+          ---------   ---------   ---------   ---------
+           ---------   ---------   ---------   ---------
+            ---------   ---------   ---------   ---------
+                          batch_size = 12
+    ```
+    """
+
+    def __init__(
+        self,
+        data_source: Sized,
+        mini_repeat_count: int,
+        batch_size: int = 1,
+        repeat_count: int = 1,
+        shuffle: bool = True,
+        seed: Optional[int] = None,
+    ):
+        self.data_source = data_source
+        self.mini_repeat_count = mini_repeat_count
+        self.batch_size = batch_size
+        self.repeat_count = repeat_count
+        self.num_samples = len(data_source)
+        self.shuffle = shuffle
+        self.seed = seed
+
+        if shuffle:
+            self.generator = torch.Generator()  # Create a local random generator
+            if seed is not None:
+                self.generator.manual_seed(seed)
+
+    def __iter__(self):
+        if self.shuffle:
+            # E.g., [2, 4, 3, 1, 0, 6, 5] (num_samples = 7)
+            indexes = torch.randperm(self.num_samples, generator=self.generator).tolist()
+        else:
+            indexes = list(range(self.num_samples))
+
+        # [2, 4, 3, 1, 0, 6, 5]
+        # -> [[2, 4, 3], [1, 0, 6], [5]]  (batch_size = 3)
+        indexes = [indexes[i : i + self.batch_size] for i in range(0, len(indexes), self.batch_size)]
+
+        # [[2, 4, 3], [1, 0, 6], [5]]
+        # -> [[2, 4, 3], [1, 0, 6]]
+        indexes = [chunk for chunk in indexes if len(chunk) == self.batch_size]
+
+        for chunk in indexes:
+            for _ in range(self.repeat_count):
+                for index in chunk:
+                    for _ in range(self.mini_repeat_count):
+                        yield index
+
+    def __len__(self) -> int:
+        return (self.num_samples // self.batch_size) * self.batch_size * self.mini_repeat_count * self.repeat_count
+
+
+

 class ReToolTrainer(Trainer):  # Change this line

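For reference, here is a quick way to sanity-check the ordering HFRepeatSampler produces for GRPO-style grouping. The toy dataset and the values used for `mini_repeat_count`, `batch_size`, and `repeat_count` below are illustrative assumptions, not values taken from the trainer's configuration:

```python
# Illustrative only: 6 prompts, 2 completions per prompt, 2 unique prompts per generation
# batch, and each generation batch re-served twice for reuse across optimizer passes.
prompts = ["p0", "p1", "p2", "p3", "p4", "p5"]

sampler = HFRepeatSampler(
    data_source=prompts,
    mini_repeat_count=2,   # num_generations: each prompt index appears twice in a row
    batch_size=2,          # generation_batch_size // num_generations unique prompts per batch
    repeat_count=2,        # num_iterations * steps_per_generation: reuse each chunk twice
    shuffle=False,         # keep the order deterministic for the printout
)

print(list(sampler))
# [0, 0, 1, 1,  0, 0, 1, 1,  2, 2, 3, 3,  2, 2, 3, 3,  4, 4, 5, 5,  4, 4, 5, 5]
print(len(sampler))  # 24 = (6 // 2) * 2 * 2 * 2
```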
|
162 |
mask_truncated_completions: bool = True,
|
163 |
**kwargs
|
164 |
):
|
165 |
+
# Initialize parent Trainer (simpler call)
|
166 |
super().__init__(
|
167 |
model=model,
|
|
|
168 |
args=args,
|
169 |
+
tokenizer=processing_class, # Note: Trainer uses 'tokenizer', not 'processing_class'
|
170 |
+
data_collator=identity, # No data collation is needed in GRPO
|
171 |
train_dataset=train_dataset,
|
172 |
eval_dataset=eval_dataset,
|
173 |
+
processing_class=processing_class,
|
174 |
+
callbacks=callbacks,
|
175 |
+
optimizers=optimizers,
|
176 |
**kwargs
|
177 |
)
|
|
|
178 |
|
179 |
# Store processing_class for compatibility
|
180 |
self.processing_class = processing_class or self.tokenizer
|
|
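The `identity` collator passed to `data_collator` above is not shown in this diff; in TRL's GRPO trainer the corresponding collator is a simple pass-through that hands the sampled examples to the training step unchanged, since tokenization and generation happen later. A minimal sketch of what such a collator looks like (the name and signature here are assumptions for illustration, not the committed code):

```python
def identity(batch):
    # Hypothetical pass-through collator: return the list of sampled examples as-is,
    # so no tensor stacking happens before generation/scoring.
    return batch
```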
|
228 |
use_cache=True,
|
229 |
cache_implementation=args.cache_implementation, #args.cache_implementation = 'Offloaded Cache'
|
230 |
)
|
231 |
+
def _set_signature_columns_if_needed(self):
|
232 |
+
# If `self.args.remove_unused_columns` is True, non-signature columns are removed.
|
233 |
+
# By default, this method sets `self._signature_columns` to the model's expected inputs.
|
234 |
+
# In GRPOTrainer, we preprocess data, so using the model's signature columns doesn't work.
|
235 |
+
# Instead, we set them to the columns expected by the `training_step` method, hence the override.
|
236 |
+
if self._signature_columns is None:
|
237 |
+
self._signature_columns = ["prompt", "image"]
|
238 |
+
|
239 |
+
def _get_train_sampler(self, dataset=None):
|
240 |
+
"""Override to use RepeatSampler for GRPO."""
|
241 |
+
# Returns a sampler that
|
242 |
+
# 1. ensures each prompt is repeated across multiple processes. This guarantees that identical prompts are
|
243 |
+
# distributed to different GPUs, allowing rewards to be computed and normalized correctly within each prompt
|
244 |
+
# group. Using the same seed across processes ensures consistent prompt assignment, preventing discrepancies
|
245 |
+
# in group formation.
|
246 |
+
# 2. repeats the batch multiple times to allow reusing generations across multiple updates. Refer to
|
247 |
+
# _prepare_inputs to see how the generations are stored and reused.
|
248 |
+
|
249 |
+
# In the following figure, the values are the prompt indices. The first row shows the first sampled batch, the
|
250 |
+
# second row shows the second sampled batch, and so on.
|
251 |
+
#
|
252 |
+
# | GPU 0 | GPU 1 |
|
253 |
+
#
|
254 |
+
# global_step step <-───> num_generations=2
|
255 |
+
# <-───────> per_device_train_batch_size=3
|
256 |
+
# grad_accum ▲ ▲ 0 0 0 0 1 1 2 2 <- Generate for the first `steps_per_generation` (prompts 0 to 11); store the completions; use the first slice to compute the loss
|
257 |
+
# =2 ▼ | 0 1 3 3 4 4 5 5 <- Take the stored generations and use the second slice to compute the loss
|
258 |
+
# |
|
259 |
+
# | 1 2 6 6 7 7 8 8 <- Take the stored generations and use the third slice to compute the loss
|
260 |
+
# steps_per_gen=4 ▼ 1 3 9 9 10 10 11 11 <- Take the stored generations and use the fourth slice to compute the loss
|
261 |
+
#
|
262 |
+
# 2 4 12 12 13 13 14 14 <- Generate for the second `steps_per_generation` (prompts 12 to 23); store the completions; use the first slice to compute the loss
|
263 |
+
# 2 5 15 15 16 16 17 17 <- Take the stored generations and use the second slice to compute the loss
|
264 |
+
# ...
|
265 |
+
if dataset is None:
|
266 |
+
dataset = self.train_dataset
|
267 |
+
|
268 |
+
return HFRepeatSampler(
|
269 |
+
data_source=dataset,
|
270 |
+
mini_repeat_count=self.num_generations, # e.g., 4 completions per prompt
|
271 |
+
batch_size=self.args.generation_batch_size // self.num_generations, # correction
|
272 |
+
repeat_count=self.num_iterations * self.args.steps_per_generation, # correction
|
273 |
+
shuffle=True,
|
274 |
+
seed=self.args.seed
|
275 |
+
)
|
276 |
+
|
277 |
+
def get_train_dataloader(self):
|
278 |
+
"""Override to ensure our custom sampler is used."""
|
279 |
+
if self.train_dataset is None:
|
280 |
+
raise ValueError("Trainer: training requires a train_dataset.")
|
281 |
+
|
282 |
+
train_dataset = self.train_dataset
|
283 |
+
data_collator = self.data_collator
|
284 |
+
if is_datasets_available() and isinstance(train_dataset, datasets.Dataset):
|
285 |
+
train_dataset = self._remove_unused_columns(train_dataset, description="training")
|
286 |
+
else:
|
287 |
+
data_collator = self._get_collator_with_removed_columns(data_collator, description="training")
|
288 |
+
|
289 |
+
sampler = self._get_train_sampler(train_dataset)
|
290 |
+
dataloader_batch_size = self._train_batch_size * self.args.steps_per_generation
|
291 |
+
|
292 |
+
return DataLoader(
|
293 |
+
train_dataset,
|
294 |
+
batch_size= self.args.generation_batch_size, # < this is the change, HF was useing dataloader_batch_size
|
295 |
+
sampler=sampler,
|
296 |
+
collate_fn=data_collator,
|
297 |
+
drop_last=self.args.dataloader_drop_last,
|
298 |
+
num_workers=self.args.dataloader_num_workers,
|
299 |
+
)
|
300 |
+
|
301 |
def _get_interpreter_token_ids(self) -> list[int]:
|
302 |
"""Get token IDs for <interpreter> and </interpreter> tags."""
|
303 |
start_token = self.processing_class.encode("<interpreter>", add_special_tokens=False)[0]
|
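Worked through with illustrative numbers, the sampler and dataloader overrides above fit together as follows. The concrete values in this sketch are assumptions chosen for readability, not settings from the repository; only the arithmetic mirrors the code:

```python
# Illustrative GRPO batch arithmetic for the overrides above (all values assumed).
num_generations = 4            # completions per prompt (mini_repeat_count)
generation_batch_size = 16     # completions produced per sampling round
steps_per_generation = 4       # optimizer steps served by one sampling round
num_iterations = 1             # reuse passes over each stored generation batch
per_device_train_batch_size = 2

unique_prompts = generation_batch_size // num_generations       # 4 prompts per round
repeat_count = num_iterations * steps_per_generation             # each chunk yielded 4 times

# The overridden get_train_dataloader pulls one full sampling round per DataLoader batch:
custom_dataloader_batch = generation_batch_size                   # 16

# The stock Hugging Face GRPO dataloader would instead have used:
hf_dataloader_batch = per_device_train_batch_size * steps_per_generation  # 8

print(unique_prompts, repeat_count, custom_dataloader_batch, hf_dataloader_batch)
```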
|
|
@@ -725,4 +908,164 @@ def _compute_loss(self, model, inputs):
         self._metrics[mode]["clip_ratio/high_max"].append(nanmax(gathered_high_clip).item())
         gathered_clip_ratio = self.accelerator.gather_for_metrics(clip_ratio)
         self._metrics[mode]["clip_ratio/region_mean"].append(gathered_clip_ratio.nanmean().item())
+        return loss
+
+    def train(self):
+        """
+        Comprehensive training loop for ReTool with GRPO.
+        Adapted from train_with_batching to work as a method.
+        """
+        # Initialize
+        self.model.train()
+        if not hasattr(self, 'ref_model') or self.ref_model is None:
+            self.ref_model = copy.deepcopy(self.model)
+            self.ref_model.eval()
+
+        # Setup tracking
+        writer = SummaryWriter(self.args.logging_dir)
+        training_history = []
+
+        # Get dataloader with our custom sampler
+        train_dataloader = self.get_train_dataloader()
+
+        # Generation storage for reuse
+        stored_generation_outputs = None
+        generation_counter = 0
+        global_step = 0
+
+        for epoch in range(self.args.num_train_epochs):
+            epoch_metrics = []
+            start_mem = torch.cuda.memory_allocated() if torch.cuda.is_available() else 0
+
+            for batch_idx, batch in enumerate(train_dataloader):
+                # batch already has repeated prompts from our RepeatSampler
+                # Shape: (generation_batch_size, ...) where generation_batch_size = unique_prompts * num_generations
+
+                # Determine if we need new generations
+                generate_new = (global_step % (self.args.steps_per_generation * self.num_iterations)) == 0
+
+                if generate_new:
+                    print(f"Generating new completions at step {global_step}")
+                    with torch.no_grad():
+                        # This is where the ReTool magic happens - generate with code execution!
+                        stored_generation_outputs = self._generate_and_score_completions(batch)
+                    generation_counter = 0
+
+                # Now train on the stored generations
+                # This replaces the mini/micro batch logic of the original train_with_batching function
+                batch_loss = self._train_on_stored_generations(
+                    stored_generation_outputs,
+                    epoch_metrics
+                )
+
+                global_step += 1
+                generation_counter += 1
+
+                # Logging
+                if global_step % self.args.logging_steps == 0:
+                    self._log_training_metrics(writer, epoch_metrics, global_step)
+
+                # Optional: Check for training instability
+                if self._should_stop_training(epoch_metrics):
+                    print("Training instability detected! Stopping early.")
+                    return training_history
+
+            # End of epoch
+            end_mem = torch.cuda.memory_allocated() if torch.cuda.is_available() else 0
+            epoch_summary = self._compute_epoch_summary(epoch_metrics, start_mem, end_mem)
+            training_history.append(epoch_summary)
+
+            # Log epoch results
+            self._log_epoch_metrics(epoch, epoch_summary, writer)
+
+            # Update scheduler if we have one
+            if hasattr(self, 'scheduler') and self.scheduler is not None:
+                self.scheduler.step(epoch_summary['mean_reward'])
+                print(f"Current learning rate: {self.optimizer.param_groups[0]['lr']}")
+
+        writer.close()
+        return training_history
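The reuse schedule in `train()` hinges on the `generate_new` test: completions are regenerated only when `global_step` is a multiple of `steps_per_generation * num_iterations`, and every step in between trains on the stored outputs. A small check of that cadence, with assumed values for the two settings:

```python
# Illustrative check of the generation-reuse cadence used in train();
# steps_per_generation and num_iterations are assumed example values.
steps_per_generation = 4
num_iterations = 2

schedule = [
    step % (steps_per_generation * num_iterations) == 0
    for step in range(16)
]
# Fresh generations at steps 0 and 8; the seven steps after each reuse the stored completions.
print([step for step, fresh in enumerate(schedule) if fresh])  # [0, 8]
```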
+
+    def _train_on_stored_generations(self, generation_outputs, epoch_metrics):
+        """
+        Train on stored generations with mini/micro-batching.
+        This replaces the inner loops of the original train_with_batching function.
+        """
+        # Extract components from generation_outputs
+        # These already include code execution results and advantages!
+        prompt_ids = generation_outputs['prompt_ids']
+        completion_ids = generation_outputs['completion_ids']
+        advantages = generation_outputs['advantages']
+        completion_mask = generation_outputs['completion_mask']
+        interpreter_mask = generation_outputs.get('interpreter_mask', completion_mask)
+
+        batch_size = prompt_ids.size(0)
+
+        # Mini-batch size: process multiple groups together
+        # Each group has num_generations completions
+        mini_batch_size = self.args.per_device_train_batch_size * self.num_generations
+
+        # Micro-batch size: for memory efficiency within a mini-batch
+        micro_batch_size = max(self.num_generations, 4)  # At least one full group
+
+        total_loss = 0
+        num_updates = 0
+
+        # Shuffle indices for this training iteration
+        indices = torch.randperm(batch_size)
+
+        # Process in mini-batches
+        for mini_start in range(0, batch_size, mini_batch_size):
+            mini_end = min(mini_start + mini_batch_size, batch_size)
+            mini_indices = indices[mini_start:mini_end]
+
+            self.optimizer.zero_grad()
+            mini_batch_loss = 0
+            num_micro_batches = 0
+
+            # Process in micro-batches (gradient accumulation)
+            for micro_start in range(0, len(mini_indices), micro_batch_size):
+                micro_end = min(micro_start + micro_batch_size, len(mini_indices))
+                micro_indices = mini_indices[micro_start:micro_end]
+
+                # Create micro-batch
+                micro_batch = {
+                    'prompt_ids': prompt_ids[micro_indices],
+                    'prompt_mask': generation_outputs['prompt_mask'][micro_indices],
+                    'completion_ids': completion_ids[micro_indices],
+                    'completion_mask': completion_mask[micro_indices],
+                    'interpreter_mask': interpreter_mask[micro_indices],
+                    'advantages': advantages[micro_indices]
+                }
+
+                # Compute GRPO loss (this uses the _compute_loss method above)
+                loss = self._compute_loss(self.model, micro_batch)
+
+                # Scale for gradient accumulation
+                scaled_loss = loss * (len(micro_indices) / len(mini_indices))
+                scaled_loss.backward()
+
+                mini_batch_loss += loss.item()
+                num_micro_batches += 1
+
+            # Gradient clipping and optimizer step
+            grad_norm = torch.nn.utils.clip_grad_norm_(
+                self.model.parameters(),
+                max_norm=1.0
+            )
+            self.optimizer.step()
+
+            # Track metrics
+            batch_metrics = {
+                'loss': mini_batch_loss / num_micro_batches,
+                'gradient_norm': grad_norm.item(),
+                'batch_size': len(mini_indices),
+                'advantages_mean': advantages[mini_indices].mean().item(),
+                'advantages_std': advantages[mini_indices].std().item()
+            }
+            epoch_metrics.append(batch_metrics)
+
+            total_loss += mini_batch_loss
+            num_updates += 1
+
+        return total_loss / max(num_updates, 1)
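To see how the mini/micro-batch split in `_train_on_stored_generations` walks over a stored generation batch, here is the same index bookkeeping in isolation; the sizes are assumed toy values, not values from the trainer configuration:

```python
import torch

# Illustrative mini/micro-batch slicing mirroring _train_on_stored_generations.
batch_size = 16        # completions stored from one generation round
mini_batch_size = 8    # per_device_train_batch_size * num_generations
micro_batch_size = 4   # gradient-accumulation chunk, at least one prompt group

indices = torch.randperm(batch_size)

for mini_start in range(0, batch_size, mini_batch_size):
    mini_indices = indices[mini_start:mini_start + mini_batch_size]
    # One optimizer step per mini-batch.
    for micro_start in range(0, len(mini_indices), micro_batch_size):
        micro_indices = mini_indices[micro_start:micro_start + micro_batch_size]
        # Each micro-batch contributes a backward pass scaled by its share of the mini-batch.
        scale = len(micro_indices) / len(mini_indices)
        print(f"mini {mini_start // mini_batch_size}: micro-batch of {len(micro_indices)}, loss scale {scale:.2f}")
```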