Victarry committed
Commit d3e7e66 · Parent: 65e77cc

Add support for 1F1B-interleave-overlap.

README.md CHANGED
@@ -57,6 +57,12 @@ uv run python main.py strategy=1f1b_overlap num_devices=4 num_stages=4 num_batch
 ```
 ![1f1b_overlap](assets/1f1b_overlap.png)
 
+Running for 1F1B-interleave-overlap strategy:
+```bash
+uv run python main.py strategy=1f1b_interleave_overlap num_devices=4 num_stages=4 num_batches=8
+```
+![1f1b_interleave_overlap](assets/1f1b_interleave_overlap.png)
+
 ## Configuration
 
 The default configuration is in `conf/config.yaml`. You can override any parameter on the command line or create configuration groups for different scenarios.
assets/1f1b_interleave_overlap.png ADDED

Git LFS Details

  • SHA256: 4447a83144505b4e82b231aca58411c4a7f3e85b06f1e40cd864179e08f94514
  • Pointer size: 130 Bytes
  • Size of remote file: 84.3 kB
main.py CHANGED
@@ -1,5 +1,6 @@
 from src.execution_model import ScheduleConfig
 from src.strategies import (
+    generate_1f1b_interleave_overlap_schedule,
     generate_1f1b_interleave_schedule,
     generate_1f1b_overlap_schedule,
     generate_1f1b_schedule,
@@ -23,6 +24,8 @@ def main(cfg: DictConfig) -> None:
         run_zero_bubble_1p(cfg)
     elif cfg.strategy == "1f1b_overlap":
         run_1f1b_overlap(cfg)
+    elif cfg.strategy == "1f1b_interleave_overlap":
+        run_1f1b_interleave_overlap(cfg)
     else:
         raise ValueError(f"Unknown strategy: {cfg.strategy}")
 
@@ -107,6 +110,24 @@ def run_1f1b_overlap(cfg: DictConfig) -> None:
     schedule.execute()
     visualize_pipeline_parallelism_dash(schedule, port=cfg.visualization_port)
 
+def run_1f1b_interleave_overlap(cfg: DictConfig) -> None:
+    """Run 1F1B interleave overlapped pipeline parallelism simulation."""
+    # Convert OmegaConf to dict for op_times if it exists
+    op_times = (
+        OmegaConf.to_container(cfg.op_times) if hasattr(cfg, "op_times") else None
+    )
+
+    schedule_config = ScheduleConfig(
+        num_devices=cfg.num_devices,
+        num_stages=cfg.num_stages,
+        num_batches=cfg.num_batches,
+        p2p_latency=cfg.p2p_latency,
+        placement_strategy="interleave",
+        op_times=op_times,
+    )
+    schedule = generate_1f1b_interleave_overlap_schedule(schedule_config)
+    schedule.execute()
+    visualize_pipeline_parallelism_dash(schedule, port=cfg.visualization_port)
 
 if __name__ == "__main__":
     main()
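
The new strategy plugs into the existing Hydra entry point, so the fields read in `run_1f1b_interleave_overlap` can be overridden on the command line. A hypothetical invocation (assuming `p2p_latency` is exposed in `conf/config.yaml`, as the `cfg.p2p_latency` lookup above suggests):

```bash
uv run python main.py strategy=1f1b_interleave_overlap \
    num_devices=4 num_stages=4 num_batches=8 p2p_latency=0.1
```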
src/execution_model.py CHANGED
@@ -158,8 +158,8 @@ class ScheduleConfig:
             # Check if we have a specific time for this combination
             if (op_type1, op_type2) in self.overlapped_op_times:
                 return self.overlapped_op_times[(op_type1, op_type2)]
-            # Otherwise, use the max of individual times plus a small overhead
-            return max(self.get_op_time(op_type1, stage_id), self.get_op_time(op_type2, stage_id)) + 0.2
+            # Otherwise, use the max of individual times
+            return max(self.get_op_time(op_type1, stage_id), self.get_op_time(op_type2, stage_id))
 
         if op_type not in self.op_times:
             raise ValueError(f"Invalid operation type: {op_type}")
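
This change only affects the fallback path: an explicit entry in `overlapped_op_times` still takes precedence, but when none is given, the estimate for a fused pair of operations drops the fixed 0.2 overhead. A minimal sketch with made-up per-op times:

```python
# Purely illustrative per-stage times (not from the repo's config).
forward_time, backward_time = 1.0, 2.0

old_fallback = max(forward_time, backward_time) + 0.2  # before this commit: 2.2
new_fallback = max(forward_time, backward_time)        # after this commit: 2.0
```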
src/strategies.py CHANGED
@@ -130,116 +130,104 @@ def generate_1f1b_overlap_schedule(config: ScheduleConfig):
     return schedule
 
 
-# Some codes are copied from Megatron-LM
-def generate_1f1b_interleave_schedule(config: ScheduleConfig):
-    schedule = Schedule(config)
-
-    def get_pp_rank_microbatches(
-        num_microbatches,
-        num_devices,
-        device_id,
-        num_stages_per_device,
-        microbatch_group_size_per_vp_stage,
-    ):
-        """Get the number of total, warmup, and remaining microbatches in PP scheduling."""
-        total_num_microbatches = num_microbatches * num_stages_per_device
-        are_all_microbatches_in_warmup = False
-
-        if num_devices > 1:
-            if num_stages_per_device is None:
-                # forward_backward_pipelining_without_interleaving
-                num_warmup_microbatches = num_devices - device_id - 1
-            else:
-                # forward_backward_pipelining_with_interleaving
-                # Run (num_model_chunks-1)*microbatch_group_size_per_vp_stage on
-                # all workers, followed by more microbatches after depending on
-                # stage ID (more forward passes for earlier stages, later stages can
-                # immediately start with 1F1B).
-                num_warmup_microbatches = (num_devices - device_id - 1) * 2
-                num_warmup_microbatches += (num_stages_per_device - 1) * microbatch_group_size_per_vp_stage
-        else:
-            # forward_backward_no_pipelining
-            num_warmup_microbatches = 1
-
-        if num_warmup_microbatches >= total_num_microbatches:
-            num_warmup_microbatches = total_num_microbatches
-            are_all_microbatches_in_warmup = True
-        num_microbatches_remaining = total_num_microbatches - num_warmup_microbatches
-
-        return (
-            total_num_microbatches,
-            are_all_microbatches_in_warmup,
-            num_warmup_microbatches,
-            num_microbatches_remaining,
-        )
-
-
-    def get_schedule_table(num_microbatches, num_model_chunks, microbatch_group_size_per_vp_stage):
-        """Get the schedule table for PP scheduling.
-
-        Create a tunable schedule lookup table.
-        The schedule lookup table uses the virtual_microbatch_id to find the corresponding microbatch_id and model_chunk_id.
-        For example, the tunable schedule table for PP2 N3M5 with VP2 is constructed as below:
-        virtual_microbatch_id | 0 1 2 3 4 5 6 7 8 9
-        microbatch_id         | 0 1 2 0 1 2 3 4 3 4
-        model_chunk_id        | 0 0 0 1 1 1 0 0 1 1
-        """
-        schedule_table = []
-        for min_microbatch_id_in_group in range(
-            0, num_microbatches, microbatch_group_size_per_vp_stage
-        ):
-            if min_microbatch_id_in_group + microbatch_group_size_per_vp_stage >= num_microbatches:
-                # Construct schedule for the last microbatch group
-                schedule_table.extend(
-                    [
-                        (microbatch_id, model_chunk_id)
-                        for model_chunk_id in range(num_model_chunks)
-                        for microbatch_id in range(min_microbatch_id_in_group, num_microbatches)
-                    ]
-                )
-            else:
-                # Construct schedule for other microbatch groups
-                schedule_table.extend(
-                    [
-                        (microbatch_id, model_chunk_id)
-                        for model_chunk_id in range(num_model_chunks)
-                        for microbatch_id in range(
-                            min_microbatch_id_in_group,
-                            min_microbatch_id_in_group + microbatch_group_size_per_vp_stage,
-                        )
-                    ]
-                )
-        return schedule_table
-
-
-    def convert_schedule_table_to_order(num_warmup_microbatches, num_model_chunks, schedule_table):
-        """Convert a tunable schedule lookup table to the te.make_graphed_callables() accepted
-        order format. For example, the tunable schedule table for PP2 N3M5 with VP2 is as below:
-        virtual_microbatch_id | 0 1 2 3 4 5 6 7 8 9
-        microbatch_id         | 0 1 2 0 1 2 3 4 3 4
-        model_chunk_id        | 0 0 0 1 1 1 0 0 1 1
-
-        Then the forward backward separated order is:
-        forward  | 1 1 1 2 2 2 1 1 2 2
-        backward | -2 -2 -2 -1 -1 -1 -2 -2 -1 -1
-
-        If num_warmup_microbatches is 5, the output order is:
-        1 1 1 2 2 2 -2 1 -2 1 -2 2 -1 2 -1 -1 -2 -2 -1 -1
-        """
-        _, model_chunk_id_table = zip(*schedule_table)
-        forward_order = [chunk_id + 1 for chunk_id in model_chunk_id_table]
-        backward_order = [chunk_id - num_model_chunks for chunk_id in model_chunk_id_table]
-        order = forward_order[:num_warmup_microbatches]
-        for i in range(num_warmup_microbatches, len(forward_order)):
-            order.append(forward_order[i])
-            order.append(backward_order[i - num_warmup_microbatches])
-        if num_warmup_microbatches > 0:
-            order.extend(backward_order[-num_warmup_microbatches:])
-        return order
+def _get_pp_rank_microbatches(
+    num_microbatches,
+    num_devices,
+    device_id,
+    num_stages_per_device,
+    microbatch_group_size_per_vp_stage,
+):
+    """Get the number of total, warmup, and remaining microbatches in PP scheduling."""
+    total_num_microbatches = num_microbatches * num_stages_per_device
+
+    if num_devices > 1:
+        # Run (num_model_chunks-1)*microbatch_group_size_per_vp_stage on
+        # all workers, followed by more microbatches after depending on
+        # stage ID (more forward passes for earlier stages, later stages can
+        # immediately start with 1F1B).
+        num_warmup_microbatches = (num_devices - device_id - 1) * 2
+        num_warmup_microbatches += (num_stages_per_device - 1) * microbatch_group_size_per_vp_stage
+    else:
+        # forward_backward_no_pipelining
+        num_warmup_microbatches = 1
+
+    if num_warmup_microbatches >= total_num_microbatches:
+        num_warmup_microbatches = total_num_microbatches
+
+    return num_warmup_microbatches
+
+
+def _get_schedule_table(num_microbatches, num_model_chunks, microbatch_group_size_per_vp_stage):
+    """Get the schedule table for PP scheduling.
+
+    Create a tunable schedule lookup table.
+    The schedule lookup table uses the virtual_microbatch_id to find the corresponding microbatch_id and model_chunk_id.
+    For example, the tunable schedule table for PP2 N3M5 with VP2 is constructed as below:
+    virtual_microbatch_id | 0 1 2 3 4 5 6 7 8 9
+    microbatch_id         | 0 1 2 0 1 2 3 4 3 4
+    model_chunk_id        | 0 0 0 1 1 1 0 0 1 1
+    """
+    schedule_table = []
+    for min_microbatch_id_in_group in range(
+        0, num_microbatches, microbatch_group_size_per_vp_stage
+    ):
+        if min_microbatch_id_in_group + microbatch_group_size_per_vp_stage >= num_microbatches:
+            # Construct schedule for the last microbatch group
+            schedule_table.extend(
+                [
+                    (microbatch_id, model_chunk_id)
+                    for model_chunk_id in range(num_model_chunks)
+                    for microbatch_id in range(min_microbatch_id_in_group, num_microbatches)
+                ]
+            )
+        else:
+            # Construct schedule for other microbatch groups
+            schedule_table.extend(
+                [
+                    (microbatch_id, model_chunk_id)
+                    for model_chunk_id in range(num_model_chunks)
+                    for microbatch_id in range(
+                        min_microbatch_id_in_group,
+                        min_microbatch_id_in_group + microbatch_group_size_per_vp_stage,
+                    )
+                ]
+            )
+    return schedule_table
+
+
+def _convert_schedule_table_to_order(num_warmup_microbatches, num_model_chunks, schedule_table):
+    """Convert a tunable schedule lookup table to the te.make_graphed_callables() accepted
+    order format. For example, the tunable schedule table for PP2 N3M5 with VP2 is as below:
+    virtual_microbatch_id | 0 1 2 3 4 5 6 7 8 9
+    microbatch_id         | 0 1 2 0 1 2 3 4 3 4
+    model_chunk_id        | 0 0 0 1 1 1 0 0 1 1
+
+    Then the forward backward separated order is:
+    forward  | 1 1 1 2 2 2 1 1 2 2
+    backward | -2 -2 -2 -1 -1 -1 -2 -2 -1 -1
+
+    If num_warmup_microbatches is 5, the output order is:
+    1 1 1 2 2 2 -2 1 -2 1 -2 2 -1 2 -1 -1 -2 -2 -1 -1
+    """
+    _, model_chunk_id_table = zip(*schedule_table)
+    forward_order = [chunk_id + 1 for chunk_id in model_chunk_id_table]
+    backward_order = [chunk_id - num_model_chunks for chunk_id in model_chunk_id_table]
+    order = forward_order[:num_warmup_microbatches]
+    for i in range(num_warmup_microbatches, len(forward_order)):
+        order.append(forward_order[i])
+        order.append(backward_order[i - num_warmup_microbatches])
+    if num_warmup_microbatches > 0:
+        order.extend(backward_order[-num_warmup_microbatches:])
+    return order
+
+
+# Some codes are copied from Megatron-LM
+def generate_1f1b_interleave_schedule(config: ScheduleConfig):
+    schedule = Schedule(config)
 
     for device_id in range(config.num_devices):
         microbatch_group_size_per_vp_stage = config.num_devices
-        total_num_microbatches, are_all_microbatches_in_warmup, num_warmup_microbatches, num_microbatches_remaining = get_pp_rank_microbatches(
+        num_warmup_microbatches = _get_pp_rank_microbatches(
            config.num_batches,
            config.num_devices,
            device_id,
@@ -247,13 +235,13 @@ def generate_1f1b_interleave_schedule(config: ScheduleConfig):
             microbatch_group_size_per_vp_stage,
         )
 
-        schedule_table = get_schedule_table(
+        schedule_table = _get_schedule_table(
             config.num_batches,
             config.num_stages_per_device,
             microbatch_group_size_per_vp_stage,
         )
 
-        order = convert_schedule_table_to_order(
+        order = _convert_schedule_table_to_order(
             num_warmup_microbatches,
             num_model_chunks=config.num_stages_per_device,
             schedule_table=schedule_table,
@@ -280,3 +268,89 @@ def generate_1f1b_interleave_schedule(config: ScheduleConfig):
             schedule.get_op(micro_batch_id, stage_id, op_type)
         )
     return schedule
+
+def generate_1f1b_interleave_overlap_schedule(config: ScheduleConfig):
+    schedule = Schedule(config)
+
+    for device_id in range(config.num_devices):
+        microbatch_group_size_per_vp_stage = config.num_devices
+        num_warmup_microbatches = _get_pp_rank_microbatches(
+            config.num_batches,
+            config.num_devices,
+            device_id,
+            config.num_stages_per_device,
+            microbatch_group_size_per_vp_stage,
+        )
+
+        schedule_table = _get_schedule_table(
+            config.num_batches,
+            config.num_stages_per_device,
+            microbatch_group_size_per_vp_stage,
+        )
+
+        # NOTE: Add one more warmup microbatch for overlapped operations!
+        num_warmup_microbatches += 1
+        order = _convert_schedule_table_to_order(
+            num_warmup_microbatches,
+            num_model_chunks=config.num_stages_per_device,
+            schedule_table=schedule_table,
+        )
+
+        cur_stage_microbatch_id = {}
+        for i in range(1, config.num_stages_per_device+1):
+            cur_stage_microbatch_id[i] = 0
+            cur_stage_microbatch_id[-i] = 0
+        i = 0
+
+        num_overlapped_batches = len(order) - num_warmup_microbatches * 2
+        while i < len(order):
+            if i < num_warmup_microbatches:
+                order_item = order[i]
+                assert order_item > 0
+                op_type = "forward"
+                micro_batch_id = cur_stage_microbatch_id[order_item]
+                cur_stage_microbatch_id[order_item] = cur_stage_microbatch_id[order_item] + 1
+
+                stage_id = schedule.device_queues[device_id].stages[abs(order_item)-1]
+                schedule.device_queues[device_id].add_operation(
+                    schedule.get_op(micro_batch_id, stage_id, op_type)
+                )
+                i += 1
+            elif i >= num_warmup_microbatches and i < num_warmup_microbatches + num_overlapped_batches - 1:
+                order_item_a = order[i]
+                order_item_b = order[i+1]
+
+                op_type_a = "forward" if order_item_a > 0 else "backward"
+                micro_batch_id_a = cur_stage_microbatch_id[order_item_a]
+                cur_stage_microbatch_id[order_item_a] = cur_stage_microbatch_id[order_item_a] + 1
+
+                op_type_b = "forward" if order_item_b > 0 else "backward"
+                micro_batch_id_b = cur_stage_microbatch_id[order_item_b]
+                cur_stage_microbatch_id[order_item_b] = cur_stage_microbatch_id[order_item_b] + 1
+
+                stage_id_a = schedule.device_queues[device_id].stages[abs(order_item_a)-1]
+                stage_id_b = schedule.device_queues[device_id].stages[abs(order_item_b)-1]
+
+                op_a = schedule.get_op(micro_batch_id_a, stage_id_a, op_type_a)
+                op_b = schedule.get_op(micro_batch_id_b, stage_id_b, op_type_b)
+                overlapped_op = OverlappedOperation([op_a, op_b])
+                schedule.register_overlapped_operation(overlapped_op)
+                schedule.device_queues[device_id].add_operation(overlapped_op)
+
+                i += 2
+            else:
+                assert i >= num_warmup_microbatches + num_overlapped_batches
+                order_item = order[i]
+                assert order_item < 0
+                op_type = "backward"
+                micro_batch_id = cur_stage_microbatch_id[order_item]
+                cur_stage_microbatch_id[order_item] = cur_stage_microbatch_id[order_item] + 1
+
+                stage_id = schedule.device_queues[device_id].stages[abs(order_item)-1]
+                schedule.device_queues[device_id].add_operation(
+                    schedule.get_op(micro_batch_id, stage_id, op_type)
+                )
+                i += 1
+
+
+    return schedule
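
The helper docstrings above spell out the PP2 N3M5 VP2 example. A small sketch (calling the module-level helpers as refactored in this commit) reproduces that table and order; `generate_1f1b_interleave_overlap_schedule` consumes the same order after bumping the warmup count by one so that each steady-state forward/backward pair can be fused into an `OverlappedOperation`:

```python
from src.strategies import _convert_schedule_table_to_order, _get_schedule_table

# PP2 N3M5 with VP2: 5 microbatches, 2 model chunks, microbatch group size 3.
table = _get_schedule_table(
    num_microbatches=5, num_model_chunks=2, microbatch_group_size_per_vp_stage=3
)
# [(0, 0), (1, 0), (2, 0), (0, 1), (1, 1), (2, 1), (3, 0), (4, 0), (3, 1), (4, 1)]

order = _convert_schedule_table_to_order(
    num_warmup_microbatches=5, num_model_chunks=2, schedule_table=table
)
# [1, 1, 1, 2, 2, 2, -2, 1, -2, 1, -2, 2, -1, 2, -1, -1, -2, -2, -1, -1]
```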