kernel
drbh committed
Commit 0e97a7c · 1 Parent(s): 13afbbe

feat: bump build

Files changed (40)
  1. build/torch26-cxx11-cu118-x86_64-linux/megablocks/{_megablocks_63599de.abi3.so → _megablocks_13afbbe_dirty.abi3.so} +2 -2
  2. build/torch26-cxx11-cu118-x86_64-linux/megablocks/_ops.py +3 -3
  3. build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers.py +195 -20
  4. build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/all_to_all_benchmark.py +121 -21
  5. build/torch26-cxx11-cu124-x86_64-linux/megablocks/{_megablocks_63599de.abi3.so → _megablocks_13afbbe_dirty.abi3.so} +2 -2
  6. build/torch26-cxx11-cu124-x86_64-linux/megablocks/_ops.py +3 -3
  7. build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers.py +195 -20
  8. build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/all_to_all_benchmark.py +121 -21
  9. build/{torch27-cxx11-cu126-x86_64-linux/megablocks/_megablocks_63599de.abi3.so → torch26-cxx11-cu126-x86_64-linux/megablocks/_megablocks_13afbbe_dirty.abi3.so} +1 -1
  10. build/torch26-cxx11-cu126-x86_64-linux/megablocks/_ops.py +3 -3
  11. build/torch26-cxx11-cu126-x86_64-linux/megablocks/layers.py +195 -20
  12. build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/all_to_all_benchmark.py +121 -21
  13. build/{torch26-cxx11-cu126-x86_64-linux/megablocks/_megablocks_63599de.abi3.so → torch26-cxx98-cu118-x86_64-linux/megablocks/_megablocks_13afbbe_dirty.abi3.so} +2 -2
  14. build/torch26-cxx98-cu118-x86_64-linux/megablocks/_ops.py +3 -3
  15. build/torch26-cxx98-cu118-x86_64-linux/megablocks/layers.py +195 -20
  16. build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/all_to_all_benchmark.py +121 -21
  17. build/{torch26-cxx98-cu118-x86_64-linux/megablocks/_megablocks_63599de.abi3.so → torch26-cxx98-cu124-x86_64-linux/megablocks/_megablocks_13afbbe_dirty.abi3.so} +2 -2
  18. build/torch26-cxx98-cu124-x86_64-linux/megablocks/_megablocks_63599de.abi3.so +0 -3
  19. build/torch26-cxx98-cu124-x86_64-linux/megablocks/_ops.py +3 -3
  20. build/torch26-cxx98-cu124-x86_64-linux/megablocks/layers.py +195 -20
  21. build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/all_to_all_benchmark.py +121 -21
  22. build/torch26-cxx98-cu126-x86_64-linux/megablocks/_megablocks_13afbbe_dirty.abi3.so +3 -0
  23. build/torch26-cxx98-cu126-x86_64-linux/megablocks/_megablocks_63599de.abi3.so +0 -3
  24. build/torch26-cxx98-cu126-x86_64-linux/megablocks/_ops.py +3 -3
  25. build/torch26-cxx98-cu126-x86_64-linux/megablocks/layers.py +195 -20
  26. build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/all_to_all_benchmark.py +121 -21
  27. build/torch27-cxx11-cu118-x86_64-linux/megablocks/_megablocks_13afbbe_dirty.abi3.so +3 -0
  28. build/torch27-cxx11-cu118-x86_64-linux/megablocks/_megablocks_63599de.abi3.so +0 -3
  29. build/torch27-cxx11-cu118-x86_64-linux/megablocks/_ops.py +3 -3
  30. build/torch27-cxx11-cu118-x86_64-linux/megablocks/layers.py +195 -20
  31. build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/all_to_all_benchmark.py +121 -21
  32. build/torch27-cxx11-cu126-x86_64-linux/megablocks/_megablocks_13afbbe_dirty.abi3.so +3 -0
  33. build/torch27-cxx11-cu126-x86_64-linux/megablocks/_ops.py +3 -3
  34. build/torch27-cxx11-cu126-x86_64-linux/megablocks/layers.py +195 -20
  35. build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/all_to_all_benchmark.py +121 -21
  36. build/torch27-cxx11-cu128-x86_64-linux/megablocks/_megablocks_13afbbe_dirty.abi3.so +3 -0
  37. build/torch27-cxx11-cu128-x86_64-linux/megablocks/_megablocks_63599de.abi3.so +0 -3
  38. build/torch27-cxx11-cu128-x86_64-linux/megablocks/_ops.py +3 -3
  39. build/torch27-cxx11-cu128-x86_64-linux/megablocks/layers.py +195 -20
  40. build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/all_to_all_benchmark.py +121 -21
build/torch26-cxx11-cu118-x86_64-linux/megablocks/{_megablocks_63599de.abi3.so → _megablocks_13afbbe_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:9b35f3f60e0cbf0ce9e84e1224754d353f9de646cf30df5828168222889d312f
- size 10517576
+ oid sha256:5683ac8b3e98fc8b8ab19f964b0dbfb9a980b6135220b0a0c1b50180665ce341
+ size 10517608
build/torch26-cxx11-cu118-x86_64-linux/megablocks/_ops.py CHANGED
@@ -1,9 +1,9 @@
  import torch
- from . import _megablocks_63599de
- ops = torch.ops._megablocks_63599de
+ from . import _megablocks_13afbbe_dirty
+ ops = torch.ops._megablocks_13afbbe_dirty

  def add_op_namespace_prefix(op_name: str):
      """
      Prefix op by namespace.
      """
-     return f"_megablocks_63599de::{op_name}"
+     return f"_megablocks_13afbbe_dirty::{op_name}"
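The regenerated shim only re-points the Python side at the freshly built extension. A minimal sketch of how it is consumed (the `megablocks._ops` import path and the `sort` op name are taken from this build's layout and from the layers.py diff below, not a guaranteed public API):

    from megablocks._ops import ops, add_op_namespace_prefix

    # Fully qualified name used when registering or looking up the custom op.
    qualified = add_op_namespace_prefix("sort")  # "_megablocks_13afbbe_dirty::sort"

    # Kernels are then called through the bound torch.ops namespace,
    # e.g. ops.sort(tensor, sort_end_bit) as in layers.py below.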
build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers.py CHANGED
@@ -121,7 +121,15 @@ def scale_grad(


  # Forward pass for the MLP layer
- def mlp_forward(x, w1, w2, w1_bias, w2_bias, gradient_scale=None, alpha: float = 1.702):
+ def mlp_forward(
+     x: torch.Tensor,
+     w1: torch.Tensor,
+     w2: torch.Tensor,
+     w1_bias: torch.Tensor,
+     w2_bias: torch.Tensor,
+     gradient_scale: Optional[float] = None,
+     alpha: float = 1.702,
+ ):
      # Scale weights
      w1 = scale_grad(w1, gradient_scale)
      w2 = scale_grad(w2, gradient_scale)
@@ -144,8 +152,6 @@ def mlp_forward(x, w1, w2, w1_bias, w2_bias, gradient_scale=None, alpha: float =
      return torch.bmm(x, w2) + w2_bias[..., None, :]


- ## START: Load Balancing Loss (unused at the moment)
-
  # Global variable to store load balancing loss
  _LOAD_BALANCING_LOSS = []

@@ -234,9 +240,6 @@ def batched_load_balancing_loss(args):
      return scale * torch.dot(tokens_per_expert, expert_scores)


- ## END Load Balancing Loss
-
-
  # Calculate the expert capacity based on tokens, top_k, number of experts,
  # expert parallel group, capacity factor, and whether expert model parallelism is used.
  def expert_capacity(
@@ -410,7 +413,6 @@ def forward_once(
      return x, tokens_per_expert


- # TODO: replace with functional logic once aligned with ref
  def parallel_forward_once(
      x: torch.Tensor,
      expert_weights: torch.Tensor,
@@ -429,15 +431,180 @@ def parallel_forward_once(
      moe_expert_model_parallelism: bool = True,
      hidden_size: int = 1152,
  ):
-     pass
+     # Flatten inputs
+     expert_weights = expert_weights.flatten()
+     top_experts = top_experts.flatten()
+
+     with torch.no_grad():
+         # Step 1: Local permutation setup
+         indices, bin_ids, bins, tokens_per_expert = indices_and_bins(
+             top_experts, sort_end_bit, num_experts
+         )

+         # Calculate sharding parameters
+         world_size = dist.get_world_size(expert_parallel_group)
+         hidden_sharding_deg = hidden_sharding_degree(
+             world_size, num_experts, hidden_size
+         )
+         experts_per_rank_val = experts_per_rank(num_experts, world_size)

+         # Replicate token counts for hidden sharding
+         repeated_tokens_per_expert = ops.repeat(
+             tokens_per_expert, (hidden_sharding_deg,)
+         )
+
+         # Exchange token counts across devices
+         parallel_tokens_per_expert = torch.empty_like(repeated_tokens_per_expert)
+         # print("world_size:", world_size)
+         # print("experts_per_rank_val:", experts_per_rank_val)
+
+         # Ensure CUB knows which device to use
+         tpe_handle = dist.all_to_all_single(
+             parallel_tokens_per_expert,
+             repeated_tokens_per_expert,
+             group=expert_parallel_group,
+             async_op=True,
+         )
+
+     # Step 2: Local permutation - group tokens by target device
+     x = x.view(-1, x.shape[-1])  # [sl * bs, hs]
+     x = ops.gather(x, indices, bin_ids, bins, top_k)
+
+     # Step 3: Compute communication counts and exchange tokens
+     with torch.no_grad():
+         tpe_handle.wait()
+
+         # Reshape for per-device calculations
+         repeated_tokens_per_expert = repeated_tokens_per_expert.view(
+             world_size, experts_per_rank_val
+         )
+         parallel_tokens_per_expert = parallel_tokens_per_expert.view(
+             world_size, experts_per_rank_val
+         )
+
+         # Calculate send/recv counts
+         send_counts = repeated_tokens_per_expert.cpu().sum(dim=-1).tolist()
+         # recv_counts = parallel_tokens_per_expert.cpu().sum(dim=-1).tolist()
+         parallel_tokens_per_expert_cpu = parallel_tokens_per_expert.cpu()
+         recv_counts = parallel_tokens_per_expert_cpu.sum(dim=-1).tolist()
+         tokens_received = sum(recv_counts)
+
+     # Replicate for hidden sharding
+     x = ops.repeat(x, (hidden_sharding_deg, 1))
+
+     # Cross-device token exchange
+     parallel_x, parallel_x_handle = ops.all_to_all(
+         x,
+         recv_counts,
+         send_counts,
+         expert_parallel_group,
+         async_op=True,
+     )

+     with torch.no_grad():
+         # Step 4: Setup for local expert computation
+         replicate_bins = ops.inclusive_cumsum(
+             parallel_tokens_per_expert.flatten(),
+             0,
+         )
+         replicate_bins = (
+             replicate_bins.view(1) if not len(replicate_bins.size()) else replicate_bins
+         )
+
+         # Create expert indices for received tokens
+         parallel_top_expert = torch.remainder(
+             torch.arange(
+                 num_experts * hidden_sharding_deg,
+                 dtype=torch.int32,
+                 device=indices.device,
+             ),
+             experts_per_rank_val,
+         )
+         parallel_top_expert = ops.replicate(
+             parallel_top_expert.unsqueeze(dim=0),
+             replicate_bins,
+             tokens_received,
+         ).flatten()
+
+         # Sort tokens by expert assignment
+         parallel_bin_ids, parallel_indices = ops.sort(
+             parallel_top_expert,
+             sort_end_bit,
+         )
+
+         # Calculate bins for local experts
+         parallel_tokens_per_expert = parallel_tokens_per_expert.sum(
+             dim=0, dtype=torch.int
+         )
+         parallel_bins = ops.inclusive_cumsum(
+             parallel_tokens_per_expert,
+             0,
+         )
+         parallel_bins = (
+             parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins
+         )
+
+         # Calculate expert capacity
+         expert_capacity = expert_capacity_fn(
+             tokens_received,
+             top_k,
+             experts_per_rank_val,
+             expert_parallel_group,
+             moe_capacity_factor,
+             moe_expert_model_parallelism,
+         )
+         if expert_capacity == 0:
+             expert_capacity = torch.max(parallel_tokens_per_expert).item()
+
+     # Locally permute the tokens and perform the expert computation.
+     # Block to make sure that the cross-device permutation is complete.
+     # if self.args.mlp_impl == 'grouped':
+
+     # TODO: dont always assume grouped MLP
+     if True:
+         # GroupedMLP requires counts on CPU. We can use the tensor already
+         # moved to CPU for the prior all_to_all, which avoids an extra
+         # device synchronization.
+         parallel_tokens_per_expert = parallel_tokens_per_expert_cpu.sum(
+             dim=0,
+             dtype=torch.int,
+         )
+
+     # Step 5: Expert computation
+     parallel_x_handle.wait()
+
+     parallel_x = permute_and_compute(
+         parallel_x,
+         parallel_tokens_per_expert,
+         parallel_indices,
+         parallel_bin_ids,
+         None,  # expert_weights
+         parallel_bins,
+         expert_capacity,
+         top_k=1,
+         w1=w1,
+         w2=w2,
+         w1_bias=w1_bias,
+         w2_bias=w2_bias,
+         gradient_scale=gradient_scale,
+         alpha=alpha,
+     )
+
+     # Step 6: Reverse communication - send results back
+     x, _ = ops.all_to_all(parallel_x, send_counts, recv_counts, expert_parallel_group)
+
+     # Step 7: Reduce across hidden sharding dimension
+     shape = (hidden_sharding_deg, -1, hidden_size)
+     x = x.view(shape).sum(dim=0)
+
+     # Step 8: Final local unpermutation
+     x = ops.scatter(x, indices, bin_ids, expert_weights, bins, top_k)
+
+     return x, tokens_per_expert.flatten()
+
+
- class MyReplacementLayer(torch.nn.Module):
-     # def __init__(self):
-     #     super().__init__()
+ class MyReplacementLayer(torch.nn.Module):
      def forward(
-         # self,
          x: torch.Tensor,
          router_weight: torch.Tensor,
          moe_top_k: int,
@@ -446,7 +613,6 @@ class MyReplacementLayer(torch.nn.Module):
          moe_normalize_expert_weights: int = None,
          uniform_expert_assignment: bool = False,
          training: bool = False,
-         #
          w1: torch.Tensor = None,
          w2: torch.Tensor = None,
          w1_bias: torch.Tensor = None,
@@ -522,7 +688,6 @@ class MyReplacementLayer(torch.nn.Module):
          return x, expert_weights, router_scores


-
  class MegaBlocksMoeMLP(torch.nn.Module):

      def forward(
@@ -536,11 +701,21 @@ class MegaBlocksMoeMLP(torch.nn.Module):
          w2 = self.experts.down_proj.data
          w1_bias = self.experts.gate_up_proj_bias.data
          w2_bias = self.experts.down_proj_bias.data
-         expert_parallel_group = None

-         sort_end_bit = max(int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1)
+         # check if the expert_parallel_group attribute is set
+         if hasattr(self, "expert_parallel_group"):
+             expert_parallel_group = self.expert_parallel_group
+             moe_expert_model_parallelism = True
+             forward_fn = parallel_forward_once
+         else:
+             expert_parallel_group = None
+             moe_expert_model_parallelism = False
+             forward_fn = forward_once
+
+         sort_end_bit = max(
+             int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1
+         )
          hidden_size = self.experts.hidden_size
-
          output, expert_weights_out, router_scores = MyReplacementLayer.forward(
              x=x,
              router_weight=router_weight,
@@ -559,8 +734,8 @@ class MegaBlocksMoeMLP(torch.nn.Module):
              sort_end_bit=sort_end_bit,
              expert_parallel_group=expert_parallel_group,
              moe_capacity_factor=1.0,
-             moe_expert_model_parallelism=False,
-             forward_fn=forward_once,
+             moe_expert_model_parallelism=moe_expert_model_parallelism,
+             forward_fn=forward_fn,
              hidden_size=hidden_size,
          )
-         return output, expert_weights_out
+         return output, expert_weights_out
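The new dispatch in MegaBlocksMoeMLP.forward keys off a plain attribute: when `expert_parallel_group` is set on the module, the layer takes the new `parallel_forward_once` path with expert model parallelism enabled; otherwise it falls back to `forward_once`. A minimal opt-in sketch (layer construction and weight loading are not part of this diff and are assumed here):

    import torch.distributed as dist
    from megablocks.layers import MegaBlocksMoeMLP  # module path per this build

    def enable_expert_parallelism(mlp: MegaBlocksMoeMLP, group=None):
        # Presence of the attribute is what forward() checks via hasattr();
        # it selects parallel_forward_once and moe_expert_model_parallelism=True.
        mlp.expert_parallel_group = group if group is not None else dist.group.WORLD
        return mlp

The `sort_end_bit` passed to the radix sort is computed as max(ceil(log2(moe_num_experts)), 1); with 128 experts, for example, that is 7 bits.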
build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/all_to_all_benchmark.py CHANGED
@@ -7,28 +7,126 @@ import torch.distributed as dist
  # from megablocks import benchmark_util
  # from megablocks.layers.all_to_all import all_to_all

- from .. import benchmark_util
- from .._layers.all_to_all import all_to_all
+ # from .. import benchmark_util
+
+ # Copyright 2024 Databricks
+ # SPDX-License-Identifier: Apache-2.0
+
+ import numpy as np
+ import torch
+
+
+ def log_benchmark(name, arguments, time, std):
+     print("=" * 60)
+     print(f"{name} Benchmark")
+     print("Benchmark Parameters:")
+     for key, value in arguments.items():
+         print(f"{key} = {value}")
+     print("Results:")
+     print("mean time = {:.3f}ms, std time = {:.3f}ms".format(time, std))
+     print("=" * 60)
+
+
+ def benchmark_function(fn, iterations=100, warmup=10):
+     print(f"Benchmarking {fn.__name__} with {iterations} iterations and {warmup} warmup iterations")
+     # Warmup iterations.
+     for _ in range(warmup):
+         fn()
+
+     times = []
+     print(f"Running {iterations} iterations...")
+     for i in range(iterations):
+         start = torch.cuda.Event(enable_timing=True)
+         end = torch.cuda.Event(enable_timing=True)
+
+         start.record()
+         fn()
+         end.record()
+
+         torch.cuda.synchronize()
+         times.append(start.elapsed_time(end))
+     return np.mean(times), np.std(times)
+
+
+ # from .._layers.all_to_all import all_to_all
+
+ # Copyright 2024 Databricks
+ # SPDX-License-Identifier: Apache-2.0
+
+ import torch
+ import torch.distributed as dist
+
+
+ class AllToAllOp(torch.autograd.Function):
+
+     @staticmethod
+     def forward(ctx, x, output_split_sizes, input_split_sizes, group, async_op):
+         out = torch.empty(
+             (sum(output_split_sizes),) + x.shape[1:], device=x.device, dtype=x.dtype
+         )
+
+         ctx.input_shape = x.shape
+         ctx.output_split_sizes = output_split_sizes
+         ctx.input_split_sizes = input_split_sizes
+         ctx.group = group
+         handle = dist.all_to_all_single(
+             out,
+             x,
+             output_split_sizes=output_split_sizes,
+             input_split_sizes=input_split_sizes,
+             group=group,
+             async_op=async_op,
+         )
+         return out, handle
+
+     @staticmethod
+     def backward(ctx, grad, _):
+         if ctx.needs_input_grad[0]:
+             out = torch.empty(
+                 ctx.input_shape,
+                 device=grad.device,
+                 dtype=grad.dtype,
+             )
+             dist.all_to_all_single(
+                 out,
+                 grad,
+                 output_split_sizes=ctx.input_split_sizes,
+                 input_split_sizes=ctx.output_split_sizes,
+                 group=ctx.group,
+             )
+             return out, None, None, None, None
+         return None, None, None, None, None
+
+
+ def all_to_all(x, output_split_sizes, input_split_sizes, group, async_op=False):
+     return AllToAllOp.apply(
+         x,
+         output_split_sizes,
+         input_split_sizes,
+         group,
+         async_op,
+     )
+

  _ALL_TO_ALL_BENCHMARK = (
      (8, 1024),
-     (16, 1024),
-     (32, 1024),
-     (64, 1024),
-     (128, 1024),
-     (256, 1024),
-     (512, 1024),
-     (1024, 1024),
-     (2 * 1024, 1024),
-     (4 * 1024, 1024),
-     (8 * 1024, 1024),
-     (16 * 1024, 1024),
-     (32 * 1024, 1024),
-     (64 * 1024, 1024),
-     (128 * 1024, 1024),
-     (256 * 1024, 1024),
-     (512 * 1024, 1024),
-     (1024 * 1024, 1024),
+     # (16, 1024),
+     # (32, 1024),
+     # (64, 1024),
+     # (128, 1024),
+     # (256, 1024),
+     # (512, 1024),
+     # (1024, 1024),
+     # (2 * 1024, 1024),
+     # (4 * 1024, 1024),
+     # (8 * 1024, 1024),
+     # (16 * 1024, 1024),
+     # (32 * 1024, 1024),
+     # (64 * 1024, 1024),
+     # (128 * 1024, 1024),
+     # (256 * 1024, 1024),
+     # (512 * 1024, 1024),
+     # (1024 * 1024, 1024),
  )


@@ -47,10 +145,12 @@ def benchmark_all_to_all(group, sl, hs):
      def benchmark():
          return all_to_all(x, send_recv_sizes, send_recv_sizes, group)

-     time, std = benchmark_util.benchmark_function(benchmark)
+     # time, std = benchmark_util.benchmark_function(benchmark)
+     time, std = benchmark_function(benchmark)

      if dist.get_rank(group) == 0:
-         benchmark_util.log_benchmark('All-To-All', details, time, std)
+         log_benchmark('All-To-All', details, time, std)
+         # benchmark_util.log_benchmark('All-To-All', details, time, std)


  if __name__ == '__main__':
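With benchmark_util and the _layers.all_to_all import inlined, the timing helpers above can also be exercised on their own. A small sketch, assuming the two helpers are in scope (copied or imported from this file) and a CUDA device is available; the matmul workload is purely illustrative:

    import torch

    a = torch.randn(2048, 2048, device="cuda")

    def square_matmul():
        return a @ a

    # CUDA-event timing with warmup, as implemented in benchmark_function above.
    mean_ms, std_ms = benchmark_function(square_matmul, iterations=20, warmup=5)
    log_benchmark("MatMul", {"shape": "2048 x 2048"}, mean_ms, std_ms)

The all-to-all benchmark itself still needs an initialized torch.distributed process group (typically launched with one process per GPU) before it can run.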
build/torch26-cxx11-cu124-x86_64-linux/megablocks/{_megablocks_63599de.abi3.so → _megablocks_13afbbe_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:05d38f81524501b75940bfad8686f4f502b5c6af1de85fb1fe5b20da765d4c3c
- size 11869392
+ oid sha256:b55d6ee3d41404603fdb75ad9a2949aa92e0224f7056fdbeb4c66934035ebd4b
+ size 11869424
build/torch26-cxx11-cu124-x86_64-linux/megablocks/_ops.py CHANGED
@@ -1,9 +1,9 @@
  import torch
- from . import _megablocks_63599de
- ops = torch.ops._megablocks_63599de
+ from . import _megablocks_13afbbe_dirty
+ ops = torch.ops._megablocks_13afbbe_dirty

  def add_op_namespace_prefix(op_name: str):
      """
      Prefix op by namespace.
      """
-     return f"_megablocks_63599de::{op_name}"
+     return f"_megablocks_13afbbe_dirty::{op_name}"
build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers.py CHANGED
(diff identical to build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers.py above)
build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/all_to_all_benchmark.py CHANGED
(diff identical to build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/all_to_all_benchmark.py above)
build/{torch27-cxx11-cu126-x86_64-linux/megablocks/_megablocks_63599de.abi3.so → torch26-cxx11-cu126-x86_64-linux/megablocks/_megablocks_13afbbe_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:3a243e51490184fb48e02dbc1115545ea69313a3d63058f8423c0c493e90bc5a
+ oid sha256:516c5026180d4a8d013c500ed284a60ecbed4bc6c9dc084b838913f40327d1a6
  size 11931080
build/torch26-cxx11-cu126-x86_64-linux/megablocks/_ops.py CHANGED
@@ -1,9 +1,9 @@
  import torch
- from . import _megablocks_63599de
- ops = torch.ops._megablocks_63599de
+ from . import _megablocks_13afbbe_dirty
+ ops = torch.ops._megablocks_13afbbe_dirty

  def add_op_namespace_prefix(op_name: str):
      """
      Prefix op by namespace.
      """
-     return f"_megablocks_63599de::{op_name}"
+     return f"_megablocks_13afbbe_dirty::{op_name}"
build/torch26-cxx11-cu126-x86_64-linux/megablocks/layers.py CHANGED
(diff identical to build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers.py above)
build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/all_to_all_benchmark.py CHANGED
(diff identical to build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/all_to_all_benchmark.py above)
build/{torch26-cxx11-cu126-x86_64-linux/megablocks/_megablocks_63599de.abi3.so → torch26-cxx98-cu118-x86_64-linux/megablocks/_megablocks_13afbbe_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:0e9e392427d3157216b82014570075137082c5ec5c5bd6b63c1458d509ed4ff3
- size 11931048
+ oid sha256:a5c8c1b700d297741dd86e8c388e03913a30769ceb51b7c12a01245fbdf30128
+ size 10510072
build/torch26-cxx98-cu118-x86_64-linux/megablocks/_ops.py CHANGED
@@ -1,9 +1,9 @@
1
  import torch
2
- from . import _megablocks_63599de
3
- ops = torch.ops._megablocks_63599de
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
- return f"_megablocks_63599de::{op_name}"
 
1
  import torch
2
+ from . import _megablocks_13afbbe_dirty
3
+ ops = torch.ops._megablocks_13afbbe_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
+ return f"_megablocks_13afbbe_dirty::{op_name}"
build/torch26-cxx98-cu118-x86_64-linux/megablocks/layers.py CHANGED
@@ -121,7 +121,15 @@ def scale_grad(
121
 
122
 
123
  # Forward pass for the MLP layer
124
- def mlp_forward(x, w1, w2, w1_bias, w2_bias, gradient_scale=None, alpha: float = 1.702):
 
 
 
 
 
 
 
 
125
  # Scale weights
126
  w1 = scale_grad(w1, gradient_scale)
127
  w2 = scale_grad(w2, gradient_scale)
@@ -144,8 +152,6 @@ def mlp_forward(x, w1, w2, w1_bias, w2_bias, gradient_scale=None, alpha: float =
144
  return torch.bmm(x, w2) + w2_bias[..., None, :]
145
 
146
 
147
- ## START: Load Balancing Loss (unused at the moment)
148
-
149
  # Global variable to store load balancing loss
150
  _LOAD_BALANCING_LOSS = []
151
 
@@ -234,9 +240,6 @@ def batched_load_balancing_loss(args):
234
  return scale * torch.dot(tokens_per_expert, expert_scores)
235
 
236
 
237
- ## END Load Balancing Loss
238
-
239
-
240
  # Calculate the expert capacity based on tokens, top_k, number of experts,
241
  # expert parallel group, capacity factor, and whether expert model parallelism is used.
242
  def expert_capacity(
@@ -410,7 +413,6 @@ def forward_once(
410
  return x, tokens_per_expert
411
 
412
 
413
- # TODO: replace with functional logic once aligned with ref
414
  def parallel_forward_once(
415
  x: torch.Tensor,
416
  expert_weights: torch.Tensor,
@@ -429,15 +431,180 @@ def parallel_forward_once(
429
  moe_expert_model_parallelism: bool = True,
430
  hidden_size: int = 1152,
431
  ):
432
- pass
 
 
 
 
 
 
 
 
433
 
 
 
 
 
 
 
434
 
435
- class MyReplacementLayer(torch.nn.Module):
436
- # def __init__(self):
437
- # super().__init__()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
438
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
439
  def forward(
440
- # self,
441
  x: torch.Tensor,
442
  router_weight: torch.Tensor,
443
  moe_top_k: int,
@@ -446,7 +613,6 @@ class MyReplacementLayer(torch.nn.Module):
446
  moe_normalize_expert_weights: int = None,
447
  uniform_expert_assignment: bool = False,
448
  training: bool = False,
449
- #
450
  w1: torch.Tensor = None,
451
  w2: torch.Tensor = None,
452
  w1_bias: torch.Tensor = None,
@@ -522,7 +688,6 @@ class MyReplacementLayer(torch.nn.Module):
522
  return x, expert_weights, router_scores
523
 
524
 
525
-
526
  class MegaBlocksMoeMLP(torch.nn.Module):
527
 
528
  def forward(
@@ -536,11 +701,21 @@ class MegaBlocksMoeMLP(torch.nn.Module):
536
  w2 = self.experts.down_proj.data
537
  w1_bias = self.experts.gate_up_proj_bias.data
538
  w2_bias = self.experts.down_proj_bias.data
539
- expert_parallel_group = None
540
 
541
- sort_end_bit = max(int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1)
 
 
 
 
 
 
 
 
 
 
 
 
542
  hidden_size = self.experts.hidden_size
543
-
544
  output, expert_weights_out, router_scores = MyReplacementLayer.forward(
545
  x=x,
546
  router_weight=router_weight,
@@ -559,8 +734,8 @@ class MegaBlocksMoeMLP(torch.nn.Module):
559
  sort_end_bit=sort_end_bit,
560
  expert_parallel_group=expert_parallel_group,
561
  moe_capacity_factor=1.0,
562
- moe_expert_model_parallelism=False,
563
- forward_fn=forward_once,
564
  hidden_size=hidden_size,
565
  )
566
- return output, expert_weights_out
 
121
 
122
 
123
  # Forward pass for the MLP layer
124
+ def mlp_forward(
125
+ x: torch.Tensor,
126
+ w1: torch.Tensor,
127
+ w2: torch.Tensor,
128
+ w1_bias: torch.Tensor,
129
+ w2_bias: torch.Tensor,
130
+ gradient_scale: Optional[float] = None,
131
+ alpha: float = 1.702,
132
+ ):
133
  # Scale weights
134
  w1 = scale_grad(w1, gradient_scale)
135
  w2 = scale_grad(w2, gradient_scale)
 
152
  return torch.bmm(x, w2) + w2_bias[..., None, :]
153
 
154
 
 
 
155
  # Global variable to store load balancing loss
156
  _LOAD_BALANCING_LOSS = []
157
 
 
240
  return scale * torch.dot(tokens_per_expert, expert_scores)
241
 
242
 
 
 
 
243
  # Calculate the expert capacity based on tokens, top_k, number of experts,
244
  # expert parallel group, capacity factor, and whether expert model parallelism is used.
245
  def expert_capacity(
 
413
  return x, tokens_per_expert
414
 
415
 
 
416
  def parallel_forward_once(
417
  x: torch.Tensor,
418
  expert_weights: torch.Tensor,
 
431
  moe_expert_model_parallelism: bool = True,
432
  hidden_size: int = 1152,
433
  ):
434
+ # Flatten inputs
435
+ expert_weights = expert_weights.flatten()
436
+ top_experts = top_experts.flatten()
437
+
438
+ with torch.no_grad():
439
+ # Step 1: Local permutation setup
440
+ indices, bin_ids, bins, tokens_per_expert = indices_and_bins(
441
+ top_experts, sort_end_bit, num_experts
442
+ )
443
 
444
+ # Calculate sharding parameters
445
+ world_size = dist.get_world_size(expert_parallel_group)
446
+ hidden_sharding_deg = hidden_sharding_degree(
447
+ world_size, num_experts, hidden_size
448
+ )
449
+ experts_per_rank_val = experts_per_rank(num_experts, world_size)
450
 
451
+ # Replicate token counts for hidden sharding
452
+ repeated_tokens_per_expert = ops.repeat(
453
+ tokens_per_expert, (hidden_sharding_deg,)
454
+ )
455
+
456
+ # Exchange token counts across devices
457
+ parallel_tokens_per_expert = torch.empty_like(repeated_tokens_per_expert)
458
+ # print("world_size:", world_size)
459
+ # print("experts_per_rank_val:", experts_per_rank_val)
460
+
461
+ # This exchange is asynchronous so it can overlap with the local permutation below.
462
+ tpe_handle = dist.all_to_all_single(
463
+ parallel_tokens_per_expert,
464
+ repeated_tokens_per_expert,
465
+ group=expert_parallel_group,
466
+ async_op=True,
467
+ )
468
+
469
+ # Step 2: Local permutation - group tokens by target device
470
+ x = x.view(-1, x.shape[-1]) # [sl * bs, hs]
471
+ x = ops.gather(x, indices, bin_ids, bins, top_k)
472
+
473
+ # Step 3: Compute communication counts and exchange tokens
474
+ with torch.no_grad():
475
+ tpe_handle.wait()
476
+
477
+ # Reshape for per-device calculations
478
+ repeated_tokens_per_expert = repeated_tokens_per_expert.view(
479
+ world_size, experts_per_rank_val
480
+ )
481
+ parallel_tokens_per_expert = parallel_tokens_per_expert.view(
482
+ world_size, experts_per_rank_val
483
+ )
484
+
485
+ # Calculate send/recv counts
486
+ send_counts = repeated_tokens_per_expert.cpu().sum(dim=-1).tolist()
487
+ # recv_counts = parallel_tokens_per_expert.cpu().sum(dim=-1).tolist()
488
+ parallel_tokens_per_expert_cpu = parallel_tokens_per_expert.cpu()
489
+ recv_counts = parallel_tokens_per_expert_cpu.sum(dim=-1).tolist()
490
+ tokens_received = sum(recv_counts)
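A toy illustration of the count bookkeeping above for world_size=2 with two experts per rank; the numbers are made up:

    import torch

    repeated = torch.tensor([[3, 1], [2, 4]])      # row r: local tokens headed to rank r
    send_counts = repeated.sum(dim=-1).tolist()    # [4, 6]
    parallel = torch.tensor([[3, 2], [1, 5]])      # row r: tokens arriving from rank r
    recv_counts = parallel.sum(dim=-1).tolist()    # [5, 6]
    tokens_received = sum(recv_counts)             # 11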
491
+
492
+ # Replicate for hidden sharding
493
+ x = ops.repeat(x, (hidden_sharding_deg, 1))
494
+
495
+ # Cross-device token exchange
496
+ parallel_x, parallel_x_handle = ops.all_to_all(
497
+ x,
498
+ recv_counts,
499
+ send_counts,
500
+ expert_parallel_group,
501
+ async_op=True
502
+ )
503
 
504
+ with torch.no_grad():
505
+ # Step 4: Setup for local expert computation
506
+ replicate_bins = ops.inclusive_cumsum(
507
+ parallel_tokens_per_expert.flatten(),
508
+ 0
509
+ )
510
+ replicate_bins = (
511
+ replicate_bins.view(1) if not len(replicate_bins.size()) else replicate_bins
512
+ )
513
+
514
+ # Create expert indices for received tokens
515
+ parallel_top_expert = torch.remainder(
516
+ torch.arange(
517
+ num_experts * hidden_sharding_deg,
518
+ dtype=torch.int32,
519
+ device=indices.device,
520
+ ),
521
+ experts_per_rank_val,
522
+ )
523
+ parallel_top_expert = ops.replicate(
524
+ parallel_top_expert.unsqueeze(dim=0),
525
+ replicate_bins,
526
+ tokens_received,
527
+ ).flatten()
528
+
529
+ # Sort tokens by expert assignment
530
+ parallel_bin_ids, parallel_indices = ops.sort(
531
+ parallel_top_expert,
532
+ sort_end_bit,
533
+ )
534
+
535
+ # Calculate bins for local experts
536
+ parallel_tokens_per_expert = parallel_tokens_per_expert.sum(
537
+ dim=0, dtype=torch.int
538
+ )
539
+ parallel_bins = ops.inclusive_cumsum(
540
+ parallel_tokens_per_expert,
541
+ 0
542
+ )
543
+ parallel_bins = (
544
+ parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins
545
+ )
546
+
547
+ # Calculate expert capacity
548
+ expert_capacity = expert_capacity_fn(
549
+ tokens_received,
550
+ top_k,
551
+ experts_per_rank_val,
552
+ expert_parallel_group,
553
+ moe_capacity_factor,
554
+ moe_expert_model_parallelism,
555
+ )
556
+ if expert_capacity == 0:
557
+ expert_capacity = torch.max(parallel_tokens_per_expert).item()
558
+
559
+ # Locally permute the tokens and perform the expert computation.
560
+ # Block to make sure that the cross-device permutation is complete.
561
+ # if self.args.mlp_impl == 'grouped':
562
+
563
+ # TODO: don't always assume a grouped MLP implementation
564
+ if True:
565
+ # GroupedMLP requires counts on CPU. We can use the tensor already
566
+ # moved to CPU for the prior all_to_all, which avoids an extra
567
+ # device synchronization.
568
+ parallel_tokens_per_expert = parallel_tokens_per_expert_cpu.sum(
569
+ dim=0,
570
+ dtype=torch.int,
571
+ )
572
+
573
+ # Step 5: Expert computation
574
+ parallel_x_handle.wait()
575
+
576
+ parallel_x = permute_and_compute(
577
+ parallel_x,
578
+ parallel_tokens_per_expert,
579
+ parallel_indices,
580
+ parallel_bin_ids,
581
+ None, # expert_weights
582
+ parallel_bins,
583
+ expert_capacity,
584
+ top_k=1,
585
+ w1=w1,
586
+ w2=w2,
587
+ w1_bias=w1_bias,
588
+ w2_bias=w2_bias,
589
+ gradient_scale=gradient_scale,
590
+ alpha=alpha,
591
+ )
592
+
593
+ # Step 6: Reverse communication - send results back
594
+ x, _ = ops.all_to_all(parallel_x, send_counts, recv_counts, expert_parallel_group)
595
+
596
+ # Step 7: Reduce across hidden sharding dimension
597
+ shape = (hidden_sharding_deg, -1, hidden_size)
598
+ x = x.view(shape).sum(dim=0)
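A toy version of the reduction above: the replicas created by the earlier hidden-sharding repeat are folded back together by summing over the first dimension; the sizes are made up:

    import torch

    hidden_sharding_deg, tokens, hidden = 2, 3, 4
    x = torch.arange(hidden_sharding_deg * tokens * hidden, dtype=torch.float32)
    x = x.view(hidden_sharding_deg, -1, hidden).sum(dim=0)   # shape [tokens, hidden]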
599
+
600
+ # Step 8: Final local unpermutation
601
+ x = ops.scatter(x, indices, bin_ids, expert_weights, bins, top_k)
602
+
603
+ return x, tokens_per_expert.flatten()
604
+
605
+
606
+ class MyReplacementLayer(torch.nn.Module):
607
  def forward(
 
608
  x: torch.Tensor,
609
  router_weight: torch.Tensor,
610
  moe_top_k: int,
 
613
  moe_normalize_expert_weights: int = None,
614
  uniform_expert_assignment: bool = False,
615
  training: bool = False,
 
616
  w1: torch.Tensor = None,
617
  w2: torch.Tensor = None,
618
  w1_bias: torch.Tensor = None,
 
688
  return x, expert_weights, router_scores
689
 
690
 
 
691
  class MegaBlocksMoeMLP(torch.nn.Module):
692
 
693
  def forward(
 
701
  w2 = self.experts.down_proj.data
702
  w1_bias = self.experts.gate_up_proj_bias.data
703
  w2_bias = self.experts.down_proj_bias.data
 
704
 
705
+ # check if the expert_parallel_group attribute is set
706
+ if hasattr(self, "expert_parallel_group"):
707
+ expert_parallel_group = self.expert_parallel_group
708
+ moe_expert_model_parallelism = True
709
+ forward_fn = parallel_forward_once
710
+ else:
711
+ expert_parallel_group = None
712
+ moe_expert_model_parallelism = False
713
+ forward_fn = forward_once
714
+
715
+ sort_end_bit = max(
716
+ int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1
717
+ )
718
  hidden_size = self.experts.hidden_size
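sort_end_bit here is the bit width needed to represent the expert ids (it is later used as the end bit of the radix sort), with a floor of 1; for example, with 128 experts:

    import torch

    moe_num_experts = 128
    sort_end_bit = max(int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1)
    assert sort_end_bit == 7   # 2**7 = 128 distinct expert ids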
 
719
  output, expert_weights_out, router_scores = MyReplacementLayer.forward(
720
  x=x,
721
  router_weight=router_weight,
 
734
  sort_end_bit=sort_end_bit,
735
  expert_parallel_group=expert_parallel_group,
736
  moe_capacity_factor=1.0,
737
+ moe_expert_model_parallelism=moe_expert_model_parallelism,
738
+ forward_fn=forward_fn,
739
  hidden_size=hidden_size,
740
  )
741
+ return output, expert_weights_out
build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/all_to_all_benchmark.py CHANGED
@@ -7,28 +7,126 @@ import torch.distributed as dist
7
  # from megablocks import benchmark_util
8
  # from megablocks.layers.all_to_all import all_to_all
9
 
10
- from .. import benchmark_util
11
- from .._layers.all_to_all import all_to_all
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
  _ALL_TO_ALL_BENCHMARK = (
14
  (8, 1024),
15
- (16, 1024),
16
- (32, 1024),
17
- (64, 1024),
18
- (128, 1024),
19
- (256, 1024),
20
- (512, 1024),
21
- (1024, 1024),
22
- (2 * 1024, 1024),
23
- (4 * 1024, 1024),
24
- (8 * 1024, 1024),
25
- (16 * 1024, 1024),
26
- (32 * 1024, 1024),
27
- (64 * 1024, 1024),
28
- (128 * 1024, 1024),
29
- (256 * 1024, 1024),
30
- (512 * 1024, 1024),
31
- (1024 * 1024, 1024),
32
  )
33
 
34
 
@@ -47,10 +145,12 @@ def benchmark_all_to_all(group, sl, hs):
47
  def benchmark():
48
  return all_to_all(x, send_recv_sizes, send_recv_sizes, group)
49
 
50
- time, std = benchmark_util.benchmark_function(benchmark)
 
51
 
52
  if dist.get_rank(group) == 0:
53
- benchmark_util.log_benchmark('All-To-All', details, time, std)
 
54
 
55
 
56
  if __name__ == '__main__':
 
7
  # from megablocks import benchmark_util
8
  # from megablocks.layers.all_to_all import all_to_all
9
 
10
+ # from .. import benchmark_util
11
+
12
+ # Copyright 2024 Databricks
13
+ # SPDX-License-Identifier: Apache-2.0
14
+
15
+ import numpy as np
16
+ import torch
17
+
18
+
19
+ def log_benchmark(name, arguments, time, std):
20
+ print("=" * 60)
21
+ print(f"{name} Benchmark")
22
+ print("Benchmark Parameters:")
23
+ for key, value in arguments.items():
24
+ print(f"{key} = {value}")
25
+ print("Results:")
26
+ print("mean time = {:.3f}ms, std time = {:.3f}ms".format(time, std))
27
+ print("=" * 60)
28
+
29
+
30
+ def benchmark_function(fn, iterations=100, warmup=10):
31
+ print(f"Benchmarking {fn.__name__} with {iterations} iterations and {warmup} warmup iterations")
32
+ # Warmup iterations.
33
+ for _ in range(warmup):
34
+ fn()
35
+
36
+ times = []
37
+ print(f"Running {iterations} iterations...")
38
+ for i in range(iterations):
39
+ start = torch.cuda.Event(enable_timing=True)
40
+ end = torch.cuda.Event(enable_timing=True)
41
+
42
+ start.record()
43
+ fn()
44
+ end.record()
45
+
46
+ torch.cuda.synchronize()
47
+ times.append(start.elapsed_time(end))
48
+ return np.mean(times), np.std(times)
49
+
50
+
51
+ # from .._layers.all_to_all import all_to_all
52
+
53
+ # Copyright 2024 Databricks
54
+ # SPDX-License-Identifier: Apache-2.0
55
+
56
+ import torch
57
+ import torch.distributed as dist
58
+
59
+
60
+ class AllToAllOp(torch.autograd.Function):
61
+
62
+ @staticmethod
63
+ def forward(ctx, x, output_split_sizes, input_split_sizes, group, async_op):
64
+ out = torch.empty(
65
+ (sum(output_split_sizes),) + x.shape[1:], device=x.device, dtype=x.dtype
66
+ )
67
+
68
+ ctx.input_shape = x.shape
69
+ ctx.output_split_sizes = output_split_sizes
70
+ ctx.input_split_sizes = input_split_sizes
71
+ ctx.group = group
72
+ handle = dist.all_to_all_single(
73
+ out,
74
+ x,
75
+ output_split_sizes=output_split_sizes,
76
+ input_split_sizes=input_split_sizes,
77
+ group=group,
78
+ async_op=async_op,
79
+ )
80
+ return out, handle
81
+
82
+ @staticmethod
83
+ def backward(ctx, grad, _):
84
+ if ctx.needs_input_grad[0]:
85
+ out = torch.empty(
86
+ ctx.input_shape,
87
+ device=grad.device,
88
+ dtype=grad.dtype,
89
+ )
90
+ dist.all_to_all_single(
91
+ out,
92
+ grad,
93
+ output_split_sizes=ctx.input_split_sizes,
94
+ input_split_sizes=ctx.output_split_sizes,
95
+ group=ctx.group,
96
+ )
97
+ return out, None, None, None, None
98
+ return None, None, None, None, None
99
+
100
+
101
+ def all_to_all(x, output_split_sizes, input_split_sizes, group, async_op=False):
102
+ return AllToAllOp.apply(
103
+ x,
104
+ output_split_sizes,
105
+ input_split_sizes,
106
+ group,
107
+ async_op,
108
+ )
109
+
110
 
111
  _ALL_TO_ALL_BENCHMARK = (
112
  (8, 1024),
113
+ # (16, 1024),
114
+ # (32, 1024),
115
+ # (64, 1024),
116
+ # (128, 1024),
117
+ # (256, 1024),
118
+ # (512, 1024),
119
+ # (1024, 1024),
120
+ # (2 * 1024, 1024),
121
+ # (4 * 1024, 1024),
122
+ # (8 * 1024, 1024),
123
+ # (16 * 1024, 1024),
124
+ # (32 * 1024, 1024),
125
+ # (64 * 1024, 1024),
126
+ # (128 * 1024, 1024),
127
+ # (256 * 1024, 1024),
128
+ # (512 * 1024, 1024),
129
+ # (1024 * 1024, 1024),
130
  )
131
 
132
 
 
145
  def benchmark():
146
  return all_to_all(x, send_recv_sizes, send_recv_sizes, group)
147
 
148
+ # time, std = benchmark_util.benchmark_function(benchmark)
149
+ time, std = benchmark_function(benchmark)
150
 
151
  if dist.get_rank(group) == 0:
152
+ log_benchmark('All-To-All', details, time, std)
153
+ # benchmark_util.log_benchmark('All-To-All', details, time, std)
154
 
155
 
156
  if __name__ == '__main__':
build/{torch26-cxx98-cu118-x86_64-linux/megablocks/_megablocks_63599de.abi3.so → torch26-cxx98-cu124-x86_64-linux/megablocks/_megablocks_13afbbe_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:b2451173cb1d000c6d270b59b2aaab1aa0e54025422ba81b1ee990621c90a823
3
- size 10510040
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d915db521f8d37fb887ed8334db60165e5923f8dce817d69f6441c5ba2d210d6
3
+ size 11857952
build/torch26-cxx98-cu124-x86_64-linux/megablocks/_megablocks_63599de.abi3.so DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:f8bfaaeb2a5e226a80403463d15f2c762ac8cb70ca7a44d2156aadfac63ab0d1
3
- size 11857920
 
 
 
 
build/torch26-cxx98-cu124-x86_64-linux/megablocks/_ops.py CHANGED
@@ -1,9 +1,9 @@
1
  import torch
2
- from . import _megablocks_63599de
3
- ops = torch.ops._megablocks_63599de
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
- return f"_megablocks_63599de::{op_name}"
 
1
  import torch
2
+ from . import _megablocks_13afbbe_dirty
3
+ ops = torch.ops._megablocks_13afbbe_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
+ return f"_megablocks_13afbbe_dirty::{op_name}"
build/torch26-cxx98-cu124-x86_64-linux/megablocks/layers.py CHANGED
@@ -121,7 +121,15 @@ def scale_grad(
121
 
122
 
123
  # Forward pass for the MLP layer
124
- def mlp_forward(x, w1, w2, w1_bias, w2_bias, gradient_scale=None, alpha: float = 1.702):
 
 
 
 
 
 
 
 
125
  # Scale weights
126
  w1 = scale_grad(w1, gradient_scale)
127
  w2 = scale_grad(w2, gradient_scale)
@@ -144,8 +152,6 @@ def mlp_forward(x, w1, w2, w1_bias, w2_bias, gradient_scale=None, alpha: float =
144
  return torch.bmm(x, w2) + w2_bias[..., None, :]
145
 
146
 
147
- ## START: Load Balancing Loss (unused at the moment)
148
-
149
  # Global variable to store load balancing loss
150
  _LOAD_BALANCING_LOSS = []
151
 
@@ -234,9 +240,6 @@ def batched_load_balancing_loss(args):
234
  return scale * torch.dot(tokens_per_expert, expert_scores)
235
 
236
 
237
- ## END Load Balancing Loss
238
-
239
-
240
  # Calculate the expert capacity based on tokens, top_k, number of experts,
241
  # expert parallel group, capacity factor, and whether expert model parallelism is used.
242
  def expert_capacity(
@@ -410,7 +413,6 @@ def forward_once(
410
  return x, tokens_per_expert
411
 
412
 
413
- # TODO: replace with functional logic once aligned with ref
414
  def parallel_forward_once(
415
  x: torch.Tensor,
416
  expert_weights: torch.Tensor,
@@ -429,15 +431,180 @@ def parallel_forward_once(
429
  moe_expert_model_parallelism: bool = True,
430
  hidden_size: int = 1152,
431
  ):
432
- pass
 
 
 
 
 
 
 
 
433
 
 
 
 
 
 
 
434
 
435
- class MyReplacementLayer(torch.nn.Module):
436
- # def __init__(self):
437
- # super().__init__()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
438
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
439
  def forward(
440
- # self,
441
  x: torch.Tensor,
442
  router_weight: torch.Tensor,
443
  moe_top_k: int,
@@ -446,7 +613,6 @@ class MyReplacementLayer(torch.nn.Module):
446
  moe_normalize_expert_weights: int = None,
447
  uniform_expert_assignment: bool = False,
448
  training: bool = False,
449
- #
450
  w1: torch.Tensor = None,
451
  w2: torch.Tensor = None,
452
  w1_bias: torch.Tensor = None,
@@ -522,7 +688,6 @@ class MyReplacementLayer(torch.nn.Module):
522
  return x, expert_weights, router_scores
523
 
524
 
525
-
526
  class MegaBlocksMoeMLP(torch.nn.Module):
527
 
528
  def forward(
@@ -536,11 +701,21 @@ class MegaBlocksMoeMLP(torch.nn.Module):
536
  w2 = self.experts.down_proj.data
537
  w1_bias = self.experts.gate_up_proj_bias.data
538
  w2_bias = self.experts.down_proj_bias.data
539
- expert_parallel_group = None
540
 
541
- sort_end_bit = max(int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1)
 
 
 
 
 
 
 
 
 
 
 
 
542
  hidden_size = self.experts.hidden_size
543
-
544
  output, expert_weights_out, router_scores = MyReplacementLayer.forward(
545
  x=x,
546
  router_weight=router_weight,
@@ -559,8 +734,8 @@ class MegaBlocksMoeMLP(torch.nn.Module):
559
  sort_end_bit=sort_end_bit,
560
  expert_parallel_group=expert_parallel_group,
561
  moe_capacity_factor=1.0,
562
- moe_expert_model_parallelism=False,
563
- forward_fn=forward_once,
564
  hidden_size=hidden_size,
565
  )
566
- return output, expert_weights_out
 
121
 
122
 
123
  # Forward pass for the MLP layer
124
+ def mlp_forward(
125
+ x: torch.Tensor,
126
+ w1: torch.Tensor,
127
+ w2: torch.Tensor,
128
+ w1_bias: torch.Tensor,
129
+ w2_bias: torch.Tensor,
130
+ gradient_scale: Optional[float] = None,
131
+ alpha: float = 1.702,
132
+ ):
133
  # Scale weights
134
  w1 = scale_grad(w1, gradient_scale)
135
  w2 = scale_grad(w2, gradient_scale)
 
152
  return torch.bmm(x, w2) + w2_bias[..., None, :]
153
 
154
 
 
 
155
  # Global variable to store load balancing loss
156
  _LOAD_BALANCING_LOSS = []
157
 
 
240
  return scale * torch.dot(tokens_per_expert, expert_scores)
241
 
242
 
 
 
 
243
  # Calculate the expert capacity based on tokens, top_k, number of experts,
244
  # expert parallel group, capacity factor, and whether expert model parallelism is used.
245
  def expert_capacity(
 
413
  return x, tokens_per_expert
414
 
415
 
 
416
  def parallel_forward_once(
417
  x: torch.Tensor,
418
  expert_weights: torch.Tensor,
 
431
  moe_expert_model_parallelism: bool = True,
432
  hidden_size: int = 1152,
433
  ):
434
+ # Flatten inputs
435
+ expert_weights = expert_weights.flatten()
436
+ top_experts = top_experts.flatten()
437
+
438
+ with torch.no_grad():
439
+ # Step 1: Local permutation setup
440
+ indices, bin_ids, bins, tokens_per_expert = indices_and_bins(
441
+ top_experts, sort_end_bit, num_experts
442
+ )
443
 
444
+ # Calculate sharding parameters
445
+ world_size = dist.get_world_size(expert_parallel_group)
446
+ hidden_sharding_deg = hidden_sharding_degree(
447
+ world_size, num_experts, hidden_size
448
+ )
449
+ experts_per_rank_val = experts_per_rank(num_experts, world_size)
450
 
451
+ # Replicate token counts for hidden sharding
452
+ repeated_tokens_per_expert = ops.repeat(
453
+ tokens_per_expert, (hidden_sharding_deg,)
454
+ )
455
+
456
+ # Exchange token counts across devices
457
+ parallel_tokens_per_expert = torch.empty_like(repeated_tokens_per_expert)
458
+ # print("world_size:", world_size)
459
+ # print("experts_per_rank_val:", experts_per_rank_val)
460
+
461
+ # This exchange is asynchronous so it can overlap with the local permutation below.
462
+ tpe_handle = dist.all_to_all_single(
463
+ parallel_tokens_per_expert,
464
+ repeated_tokens_per_expert,
465
+ group=expert_parallel_group,
466
+ async_op=True,
467
+ )
468
+
469
+ # Step 2: Local permutation - group tokens by target device
470
+ x = x.view(-1, x.shape[-1]) # [sl * bs, hs]
471
+ x = ops.gather(x, indices, bin_ids, bins, top_k)
472
+
473
+ # Step 3: Compute communication counts and exchange tokens
474
+ with torch.no_grad():
475
+ tpe_handle.wait()
476
+
477
+ # Reshape for per-device calculations
478
+ repeated_tokens_per_expert = repeated_tokens_per_expert.view(
479
+ world_size, experts_per_rank_val
480
+ )
481
+ parallel_tokens_per_expert = parallel_tokens_per_expert.view(
482
+ world_size, experts_per_rank_val
483
+ )
484
+
485
+ # Calculate send/recv counts
486
+ send_counts = repeated_tokens_per_expert.cpu().sum(dim=-1).tolist()
487
+ # recv_counts = parallel_tokens_per_expert.cpu().sum(dim=-1).tolist()
488
+ parallel_tokens_per_expert_cpu = parallel_tokens_per_expert.cpu()
489
+ recv_counts = parallel_tokens_per_expert_cpu.sum(dim=-1).tolist()
490
+ tokens_received = sum(recv_counts)
491
+
492
+ # Replicate for hidden sharding
493
+ x = ops.repeat(x, (hidden_sharding_deg, 1))
494
+
495
+ # Cross-device token exchange
496
+ parallel_x, parallel_x_handle = ops.all_to_all(
497
+ x,
498
+ recv_counts,
499
+ send_counts,
500
+ expert_parallel_group,
501
+ async_op=True
502
+ )
503
 
504
+ with torch.no_grad():
505
+ # Step 4: Setup for local expert computation
506
+ replicate_bins = ops.inclusive_cumsum(
507
+ parallel_tokens_per_expert.flatten(),
508
+ 0
509
+ )
510
+ replicate_bins = (
511
+ replicate_bins.view(1) if not len(replicate_bins.size()) else replicate_bins
512
+ )
513
+
514
+ # Create expert indices for received tokens
515
+ parallel_top_expert = torch.remainder(
516
+ torch.arange(
517
+ num_experts * hidden_sharding_deg,
518
+ dtype=torch.int32,
519
+ device=indices.device,
520
+ ),
521
+ experts_per_rank_val,
522
+ )
523
+ parallel_top_expert = ops.replicate(
524
+ parallel_top_expert.unsqueeze(dim=0),
525
+ replicate_bins,
526
+ tokens_received,
527
+ ).flatten()
528
+
529
+ # Sort tokens by expert assignment
530
+ parallel_bin_ids, parallel_indices = ops.sort(
531
+ parallel_top_expert,
532
+ sort_end_bit,
533
+ )
534
+
535
+ # Calculate bins for local experts
536
+ parallel_tokens_per_expert = parallel_tokens_per_expert.sum(
537
+ dim=0, dtype=torch.int
538
+ )
539
+ parallel_bins = ops.inclusive_cumsum(
540
+ parallel_tokens_per_expert,
541
+ 0
542
+ )
543
+ parallel_bins = (
544
+ parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins
545
+ )
546
+
547
+ # Calculate expert capacity
548
+ expert_capacity = expert_capacity_fn(
549
+ tokens_received,
550
+ top_k,
551
+ experts_per_rank_val,
552
+ expert_parallel_group,
553
+ moe_capacity_factor,
554
+ moe_expert_model_parallelism,
555
+ )
556
+ if expert_capacity == 0:
557
+ expert_capacity = torch.max(parallel_tokens_per_expert).item()
558
+
559
+ # Locally permute the tokens and perform the expert computation.
560
+ # Block to make sure that the cross-device permutation is complete.
561
+ # if self.args.mlp_impl == 'grouped':
562
+
563
+ # TODO: don't always assume a grouped MLP implementation
564
+ if True:
565
+ # GroupedMLP requires counts on CPU. We can use the tensor already
566
+ # moved to CPU for the prior all_to_all, which avoids an extra
567
+ # device synchronization.
568
+ parallel_tokens_per_expert = parallel_tokens_per_expert_cpu.sum(
569
+ dim=0,
570
+ dtype=torch.int,
571
+ )
572
+
573
+ # Step 5: Expert computation
574
+ parallel_x_handle.wait()
575
+
576
+ parallel_x = permute_and_compute(
577
+ parallel_x,
578
+ parallel_tokens_per_expert,
579
+ parallel_indices,
580
+ parallel_bin_ids,
581
+ None, # expert_weights
582
+ parallel_bins,
583
+ expert_capacity,
584
+ top_k=1,
585
+ w1=w1,
586
+ w2=w2,
587
+ w1_bias=w1_bias,
588
+ w2_bias=w2_bias,
589
+ gradient_scale=gradient_scale,
590
+ alpha=alpha,
591
+ )
592
+
593
+ # Step 6: Reverse communication - send results back
594
+ x, _ = ops.all_to_all(parallel_x, send_counts, recv_counts, expert_parallel_group)
595
+
596
+ # Step 7: Reduce across hidden sharding dimension
597
+ shape = (hidden_sharding_deg, -1, hidden_size)
598
+ x = x.view(shape).sum(dim=0)
599
+
600
+ # Step 8: Final local unpermutation
601
+ x = ops.scatter(x, indices, bin_ids, expert_weights, bins, top_k)
602
+
603
+ return x, tokens_per_expert.flatten()
604
+
605
+
606
+ class MyReplacementLayer(torch.nn.Module):
607
  def forward(
 
608
  x: torch.Tensor,
609
  router_weight: torch.Tensor,
610
  moe_top_k: int,
 
613
  moe_normalize_expert_weights: int = None,
614
  uniform_expert_assignment: bool = False,
615
  training: bool = False,
 
616
  w1: torch.Tensor = None,
617
  w2: torch.Tensor = None,
618
  w1_bias: torch.Tensor = None,
 
688
  return x, expert_weights, router_scores
689
 
690
 
 
691
  class MegaBlocksMoeMLP(torch.nn.Module):
692
 
693
  def forward(
 
701
  w2 = self.experts.down_proj.data
702
  w1_bias = self.experts.gate_up_proj_bias.data
703
  w2_bias = self.experts.down_proj_bias.data
 
704
 
705
+ # check if the expert_parallel_group attribute is set
706
+ if hasattr(self, "expert_parallel_group"):
707
+ expert_parallel_group = self.expert_parallel_group
708
+ moe_expert_model_parallelism = True
709
+ forward_fn = parallel_forward_once
710
+ else:
711
+ expert_parallel_group = None
712
+ moe_expert_model_parallelism = False
713
+ forward_fn = forward_once
714
+
715
+ sort_end_bit = max(
716
+ int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1
717
+ )
718
  hidden_size = self.experts.hidden_size
 
719
  output, expert_weights_out, router_scores = MyReplacementLayer.forward(
720
  x=x,
721
  router_weight=router_weight,
 
734
  sort_end_bit=sort_end_bit,
735
  expert_parallel_group=expert_parallel_group,
736
  moe_capacity_factor=1.0,
737
+ moe_expert_model_parallelism=moe_expert_model_parallelism,
738
+ forward_fn=forward_fn,
739
  hidden_size=hidden_size,
740
  )
741
+ return output, expert_weights_out
build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/all_to_all_benchmark.py CHANGED
@@ -7,28 +7,126 @@ import torch.distributed as dist
7
  # from megablocks import benchmark_util
8
  # from megablocks.layers.all_to_all import all_to_all
9
 
10
- from .. import benchmark_util
11
- from .._layers.all_to_all import all_to_all
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
  _ALL_TO_ALL_BENCHMARK = (
14
  (8, 1024),
15
- (16, 1024),
16
- (32, 1024),
17
- (64, 1024),
18
- (128, 1024),
19
- (256, 1024),
20
- (512, 1024),
21
- (1024, 1024),
22
- (2 * 1024, 1024),
23
- (4 * 1024, 1024),
24
- (8 * 1024, 1024),
25
- (16 * 1024, 1024),
26
- (32 * 1024, 1024),
27
- (64 * 1024, 1024),
28
- (128 * 1024, 1024),
29
- (256 * 1024, 1024),
30
- (512 * 1024, 1024),
31
- (1024 * 1024, 1024),
32
  )
33
 
34
 
@@ -47,10 +145,12 @@ def benchmark_all_to_all(group, sl, hs):
47
  def benchmark():
48
  return all_to_all(x, send_recv_sizes, send_recv_sizes, group)
49
 
50
- time, std = benchmark_util.benchmark_function(benchmark)
 
51
 
52
  if dist.get_rank(group) == 0:
53
- benchmark_util.log_benchmark('All-To-All', details, time, std)
 
54
 
55
 
56
  if __name__ == '__main__':
 
7
  # from megablocks import benchmark_util
8
  # from megablocks.layers.all_to_all import all_to_all
9
 
10
+ # from .. import benchmark_util
11
+
12
+ # Copyright 2024 Databricks
13
+ # SPDX-License-Identifier: Apache-2.0
14
+
15
+ import numpy as np
16
+ import torch
17
+
18
+
19
+ def log_benchmark(name, arguments, time, std):
20
+ print("=" * 60)
21
+ print(f"{name} Benchmark")
22
+ print("Benchmark Parameters:")
23
+ for key, value in arguments.items():
24
+ print(f"{key} = {value}")
25
+ print("Results:")
26
+ print("mean time = {:.3f}ms, std time = {:.3f}ms".format(time, std))
27
+ print("=" * 60)
28
+
29
+
30
+ def benchmark_function(fn, iterations=100, warmup=10):
31
+ print(f"Benchmarking {fn.__name__} with {iterations} iterations and {warmup} warmup iterations")
32
+ # Warmup iterations.
33
+ for _ in range(warmup):
34
+ fn()
35
+
36
+ times = []
37
+ print(f"Running {iterations} iterations...")
38
+ for i in range(iterations):
39
+ start = torch.cuda.Event(enable_timing=True)
40
+ end = torch.cuda.Event(enable_timing=True)
41
+
42
+ start.record()
43
+ fn()
44
+ end.record()
45
+
46
+ torch.cuda.synchronize()
47
+ times.append(start.elapsed_time(end))
48
+ return np.mean(times), np.std(times)
49
+
50
+
51
+ # from .._layers.all_to_all import all_to_all
52
+
53
+ # Copyright 2024 Databricks
54
+ # SPDX-License-Identifier: Apache-2.0
55
+
56
+ import torch
57
+ import torch.distributed as dist
58
+
59
+
60
+ class AllToAllOp(torch.autograd.Function):
61
+
62
+ @staticmethod
63
+ def forward(ctx, x, output_split_sizes, input_split_sizes, group, async_op):
64
+ out = torch.empty(
65
+ (sum(output_split_sizes),) + x.shape[1:], device=x.device, dtype=x.dtype
66
+ )
67
+
68
+ ctx.input_shape = x.shape
69
+ ctx.output_split_sizes = output_split_sizes
70
+ ctx.input_split_sizes = input_split_sizes
71
+ ctx.group = group
72
+ handle = dist.all_to_all_single(
73
+ out,
74
+ x,
75
+ output_split_sizes=output_split_sizes,
76
+ input_split_sizes=input_split_sizes,
77
+ group=group,
78
+ async_op=async_op,
79
+ )
80
+ return out, handle
81
+
82
+ @staticmethod
83
+ def backward(ctx, grad, _):
84
+ if ctx.needs_input_grad[0]:
85
+ out = torch.empty(
86
+ ctx.input_shape,
87
+ device=grad.device,
88
+ dtype=grad.dtype,
89
+ )
90
+ dist.all_to_all_single(
91
+ out,
92
+ grad,
93
+ output_split_sizes=ctx.input_split_sizes,
94
+ input_split_sizes=ctx.output_split_sizes,
95
+ group=ctx.group,
96
+ )
97
+ return out, None, None, None, None
98
+ return None, None, None, None, None
99
+
100
+
101
+ def all_to_all(x, output_split_sizes, input_split_sizes, group, async_op=False):
102
+ return AllToAllOp.apply(
103
+ x,
104
+ output_split_sizes,
105
+ input_split_sizes,
106
+ group,
107
+ async_op,
108
+ )
109
+
110
 
111
  _ALL_TO_ALL_BENCHMARK = (
112
  (8, 1024),
113
+ # (16, 1024),
114
+ # (32, 1024),
115
+ # (64, 1024),
116
+ # (128, 1024),
117
+ # (256, 1024),
118
+ # (512, 1024),
119
+ # (1024, 1024),
120
+ # (2 * 1024, 1024),
121
+ # (4 * 1024, 1024),
122
+ # (8 * 1024, 1024),
123
+ # (16 * 1024, 1024),
124
+ # (32 * 1024, 1024),
125
+ # (64 * 1024, 1024),
126
+ # (128 * 1024, 1024),
127
+ # (256 * 1024, 1024),
128
+ # (512 * 1024, 1024),
129
+ # (1024 * 1024, 1024),
130
  )
131
 
132
 
 
145
  def benchmark():
146
  return all_to_all(x, send_recv_sizes, send_recv_sizes, group)
147
 
148
+ # time, std = benchmark_util.benchmark_function(benchmark)
149
+ time, std = benchmark_function(benchmark)
150
 
151
  if dist.get_rank(group) == 0:
152
+ log_benchmark('All-To-All', details, time, std)
153
+ # benchmark_util.log_benchmark('All-To-All', details, time, std)
154
 
155
 
156
  if __name__ == '__main__':
build/torch26-cxx98-cu126-x86_64-linux/megablocks/_megablocks_13afbbe_dirty.abi3.so ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:94a9a3bb426adceab66b39fe9d179b73e4524167aeb63bed5a67cd7734d31b24
3
+ size 11923704
build/torch26-cxx98-cu126-x86_64-linux/megablocks/_megablocks_63599de.abi3.so DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:637a8c7ef51b1d35911546ef7456854f1ee7cc3278565d2e144e16f733487148
3
- size 11923672
 
 
 
 
build/torch26-cxx98-cu126-x86_64-linux/megablocks/_ops.py CHANGED
@@ -1,9 +1,9 @@
1
  import torch
2
- from . import _megablocks_63599de
3
- ops = torch.ops._megablocks_63599de
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
- return f"_megablocks_63599de::{op_name}"
 
1
  import torch
2
+ from . import _megablocks_13afbbe_dirty
3
+ ops = torch.ops._megablocks_13afbbe_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
+ return f"_megablocks_13afbbe_dirty::{op_name}"
build/torch26-cxx98-cu126-x86_64-linux/megablocks/layers.py CHANGED
@@ -121,7 +121,15 @@ def scale_grad(
121
 
122
 
123
  # Forward pass for the MLP layer
124
- def mlp_forward(x, w1, w2, w1_bias, w2_bias, gradient_scale=None, alpha: float = 1.702):
 
 
 
 
 
 
 
 
125
  # Scale weights
126
  w1 = scale_grad(w1, gradient_scale)
127
  w2 = scale_grad(w2, gradient_scale)
@@ -144,8 +152,6 @@ def mlp_forward(x, w1, w2, w1_bias, w2_bias, gradient_scale=None, alpha: float =
144
  return torch.bmm(x, w2) + w2_bias[..., None, :]
145
 
146
 
147
- ## START: Load Balancing Loss (unused at the moment)
148
-
149
  # Global variable to store load balancing loss
150
  _LOAD_BALANCING_LOSS = []
151
 
@@ -234,9 +240,6 @@ def batched_load_balancing_loss(args):
234
  return scale * torch.dot(tokens_per_expert, expert_scores)
235
 
236
 
237
- ## END Load Balancing Loss
238
-
239
-
240
  # Calculate the expert capacity based on tokens, top_k, number of experts,
241
  # expert parallel group, capacity factor, and whether expert model parallelism is used.
242
  def expert_capacity(
@@ -410,7 +413,6 @@ def forward_once(
410
  return x, tokens_per_expert
411
 
412
 
413
- # TODO: replace with functional logic once aligned with ref
414
  def parallel_forward_once(
415
  x: torch.Tensor,
416
  expert_weights: torch.Tensor,
@@ -429,15 +431,180 @@ def parallel_forward_once(
429
  moe_expert_model_parallelism: bool = True,
430
  hidden_size: int = 1152,
431
  ):
432
- pass
 
 
 
 
 
 
 
 
433
 
 
 
 
 
 
 
434
 
435
- class MyReplacementLayer(torch.nn.Module):
436
- # def __init__(self):
437
- # super().__init__()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
438
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
439
  def forward(
440
- # self,
441
  x: torch.Tensor,
442
  router_weight: torch.Tensor,
443
  moe_top_k: int,
@@ -446,7 +613,6 @@ class MyReplacementLayer(torch.nn.Module):
446
  moe_normalize_expert_weights: int = None,
447
  uniform_expert_assignment: bool = False,
448
  training: bool = False,
449
- #
450
  w1: torch.Tensor = None,
451
  w2: torch.Tensor = None,
452
  w1_bias: torch.Tensor = None,
@@ -522,7 +688,6 @@ class MyReplacementLayer(torch.nn.Module):
522
  return x, expert_weights, router_scores
523
 
524
 
525
-
526
  class MegaBlocksMoeMLP(torch.nn.Module):
527
 
528
  def forward(
@@ -536,11 +701,21 @@ class MegaBlocksMoeMLP(torch.nn.Module):
536
  w2 = self.experts.down_proj.data
537
  w1_bias = self.experts.gate_up_proj_bias.data
538
  w2_bias = self.experts.down_proj_bias.data
539
- expert_parallel_group = None
540
 
541
- sort_end_bit = max(int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1)
 
 
 
 
 
 
 
 
 
 
 
 
542
  hidden_size = self.experts.hidden_size
543
-
544
  output, expert_weights_out, router_scores = MyReplacementLayer.forward(
545
  x=x,
546
  router_weight=router_weight,
@@ -559,8 +734,8 @@ class MegaBlocksMoeMLP(torch.nn.Module):
559
  sort_end_bit=sort_end_bit,
560
  expert_parallel_group=expert_parallel_group,
561
  moe_capacity_factor=1.0,
562
- moe_expert_model_parallelism=False,
563
- forward_fn=forward_once,
564
  hidden_size=hidden_size,
565
  )
566
- return output, expert_weights_out
 
121
 
122
 
123
  # Forward pass for the MLP layer
124
+ def mlp_forward(
125
+ x: torch.Tensor,
126
+ w1: torch.Tensor,
127
+ w2: torch.Tensor,
128
+ w1_bias: torch.Tensor,
129
+ w2_bias: torch.Tensor,
130
+ gradient_scale: Optional[float] = None,
131
+ alpha: float = 1.702,
132
+ ):
133
  # Scale weights
134
  w1 = scale_grad(w1, gradient_scale)
135
  w2 = scale_grad(w2, gradient_scale)
 
152
  return torch.bmm(x, w2) + w2_bias[..., None, :]
153
 
154
 
 
 
155
  # Global variable to store load balancing loss
156
  _LOAD_BALANCING_LOSS = []
157
 
 
240
  return scale * torch.dot(tokens_per_expert, expert_scores)
241
 
242
 
 
 
 
243
  # Calculate the expert capacity based on tokens, top_k, number of experts,
244
  # expert parallel group, capacity factor, and whether expert model parallelism is used.
245
  def expert_capacity(
 
413
  return x, tokens_per_expert
414
 
415
 
 
416
  def parallel_forward_once(
417
  x: torch.Tensor,
418
  expert_weights: torch.Tensor,
 
431
  moe_expert_model_parallelism: bool = True,
432
  hidden_size: int = 1152,
433
  ):
434
+ # Flatten inputs
435
+ expert_weights = expert_weights.flatten()
436
+ top_experts = top_experts.flatten()
437
+
438
+ with torch.no_grad():
439
+ # Step 1: Local permutation setup
440
+ indices, bin_ids, bins, tokens_per_expert = indices_and_bins(
441
+ top_experts, sort_end_bit, num_experts
442
+ )
443
 
444
+ # Calculate sharding parameters
445
+ world_size = dist.get_world_size(expert_parallel_group)
446
+ hidden_sharding_deg = hidden_sharding_degree(
447
+ world_size, num_experts, hidden_size
448
+ )
449
+ experts_per_rank_val = experts_per_rank(num_experts, world_size)
450
 
451
+ # Replicate token counts for hidden sharding
452
+ repeated_tokens_per_expert = ops.repeat(
453
+ tokens_per_expert, (hidden_sharding_deg,)
454
+ )
455
+
456
+ # Exchange token counts across devices
457
+ parallel_tokens_per_expert = torch.empty_like(repeated_tokens_per_expert)
458
+ # print("world_size:", world_size)
459
+ # print("experts_per_rank_val:", experts_per_rank_val)
460
+
461
+ # This exchange is asynchronous so it can overlap with the local permutation below.
462
+ tpe_handle = dist.all_to_all_single(
463
+ parallel_tokens_per_expert,
464
+ repeated_tokens_per_expert,
465
+ group=expert_parallel_group,
466
+ async_op=True,
467
+ )
468
+
469
+ # Step 2: Local permutation - group tokens by target device
470
+ x = x.view(-1, x.shape[-1]) # [sl * bs, hs]
471
+ x = ops.gather(x, indices, bin_ids, bins, top_k)
472
+
473
+ # Step 3: Compute communication counts and exchange tokens
474
+ with torch.no_grad():
475
+ tpe_handle.wait()
476
+
477
+ # Reshape for per-device calculations
478
+ repeated_tokens_per_expert = repeated_tokens_per_expert.view(
479
+ world_size, experts_per_rank_val
480
+ )
481
+ parallel_tokens_per_expert = parallel_tokens_per_expert.view(
482
+ world_size, experts_per_rank_val
483
+ )
484
+
485
+ # Calculate send/recv counts
486
+ send_counts = repeated_tokens_per_expert.cpu().sum(dim=-1).tolist()
487
+ # recv_counts = parallel_tokens_per_expert.cpu().sum(dim=-1).tolist()
488
+ parallel_tokens_per_expert_cpu = parallel_tokens_per_expert.cpu()
489
+ recv_counts = parallel_tokens_per_expert_cpu.sum(dim=-1).tolist()
490
+ tokens_received = sum(recv_counts)
491
+
492
+ # Replicate for hidden sharding
493
+ x = ops.repeat(x, (hidden_sharding_deg, 1))
494
+
495
+ # Cross-device token exchange
496
+ parallel_x, parallel_x_handle = ops.all_to_all(
497
+ x,
498
+ recv_counts,
499
+ send_counts,
500
+ expert_parallel_group,
501
+ async_op=True
502
+ )
503
 
504
+ with torch.no_grad():
505
+ # Step 4: Setup for local expert computation
506
+ replicate_bins = ops.inclusive_cumsum(
507
+ parallel_tokens_per_expert.flatten(),
508
+ 0
509
+ )
510
+ replicate_bins = (
511
+ replicate_bins.view(1) if not len(replicate_bins.size()) else replicate_bins
512
+ )
513
+
514
+ # Create expert indices for received tokens
515
+ parallel_top_expert = torch.remainder(
516
+ torch.arange(
517
+ num_experts * hidden_sharding_deg,
518
+ dtype=torch.int32,
519
+ device=indices.device,
520
+ ),
521
+ experts_per_rank_val,
522
+ )
523
+ parallel_top_expert = ops.replicate(
524
+ parallel_top_expert.unsqueeze(dim=0),
525
+ replicate_bins,
526
+ tokens_received,
527
+ ).flatten()
528
+
529
+ # Sort tokens by expert assignment
530
+ parallel_bin_ids, parallel_indices = ops.sort(
531
+ parallel_top_expert,
532
+ sort_end_bit,
533
+ )
534
+
535
+ # Calculate bins for local experts
536
+ parallel_tokens_per_expert = parallel_tokens_per_expert.sum(
537
+ dim=0, dtype=torch.int
538
+ )
539
+ parallel_bins = ops.inclusive_cumsum(
540
+ parallel_tokens_per_expert,
541
+ 0
542
+ )
543
+ parallel_bins = (
544
+ parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins
545
+ )
546
+
547
+ # Calculate expert capacity
548
+ expert_capacity = expert_capacity_fn(
549
+ tokens_received,
550
+ top_k,
551
+ experts_per_rank_val,
552
+ expert_parallel_group,
553
+ moe_capacity_factor,
554
+ moe_expert_model_parallelism,
555
+ )
556
+ if expert_capacity == 0:
557
+ expert_capacity = torch.max(parallel_tokens_per_expert).item()
558
+
559
+ # Locally permute the tokens and perform the expert computation.
560
+ # Block to make sure that the cross-device permutation is complete.
561
+ # if self.args.mlp_impl == 'grouped':
562
+
563
+ # TODO: don't always assume a grouped MLP implementation
564
+ if True:
565
+ # GroupedMLP requires counts on CPU. We can use the tensor already
566
+ # moved to CPU for the prior all_to_all, which avoids an extra
567
+ # device synchronization.
568
+ parallel_tokens_per_expert = parallel_tokens_per_expert_cpu.sum(
569
+ dim=0,
570
+ dtype=torch.int,
571
+ )
572
+
573
+ # Step 5: Expert computation
574
+ parallel_x_handle.wait()
575
+
576
+ parallel_x = permute_and_compute(
577
+ parallel_x,
578
+ parallel_tokens_per_expert,
579
+ parallel_indices,
580
+ parallel_bin_ids,
581
+ None, # expert_weights
582
+ parallel_bins,
583
+ expert_capacity,
584
+ top_k=1,
585
+ w1=w1,
586
+ w2=w2,
587
+ w1_bias=w1_bias,
588
+ w2_bias=w2_bias,
589
+ gradient_scale=gradient_scale,
590
+ alpha=alpha,
591
+ )
592
+
593
+ # Step 6: Reverse communication - send results back
594
+ x, _ = ops.all_to_all(parallel_x, send_counts, recv_counts, expert_parallel_group)
595
+
596
+ # Step 7: Reduce across hidden sharding dimension
597
+ shape = (hidden_sharding_deg, -1, hidden_size)
598
+ x = x.view(shape).sum(dim=0)
599
+
600
+ # Step 8: Final local unpermutation
601
+ x = ops.scatter(x, indices, bin_ids, expert_weights, bins, top_k)
602
+
603
+ return x, tokens_per_expert.flatten()
604
+
605
+
606
+ class MyReplacementLayer(torch.nn.Module):
607
  def forward(
 
608
  x: torch.Tensor,
609
  router_weight: torch.Tensor,
610
  moe_top_k: int,
 
613
  moe_normalize_expert_weights: int = None,
614
  uniform_expert_assignment: bool = False,
615
  training: bool = False,
 
616
  w1: torch.Tensor = None,
617
  w2: torch.Tensor = None,
618
  w1_bias: torch.Tensor = None,
 
688
  return x, expert_weights, router_scores
689
 
690
 
 
691
  class MegaBlocksMoeMLP(torch.nn.Module):
692
 
693
  def forward(
 
701
  w2 = self.experts.down_proj.data
702
  w1_bias = self.experts.gate_up_proj_bias.data
703
  w2_bias = self.experts.down_proj_bias.data
 
704
 
705
+ # check if the expert_parallel_group attribute is set
706
+ if hasattr(self, "expert_parallel_group"):
707
+ expert_parallel_group = self.expert_parallel_group
708
+ moe_expert_model_parallelism = True
709
+ forward_fn = parallel_forward_once
710
+ else:
711
+ expert_parallel_group = None
712
+ moe_expert_model_parallelism = False
713
+ forward_fn = forward_once
714
+
715
+ sort_end_bit = max(
716
+ int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1
717
+ )
718
  hidden_size = self.experts.hidden_size
 
719
  output, expert_weights_out, router_scores = MyReplacementLayer.forward(
720
  x=x,
721
  router_weight=router_weight,
 
734
  sort_end_bit=sort_end_bit,
735
  expert_parallel_group=expert_parallel_group,
736
  moe_capacity_factor=1.0,
737
+ moe_expert_model_parallelism=moe_expert_model_parallelism,
738
+ forward_fn=forward_fn,
739
  hidden_size=hidden_size,
740
  )
741
+ return output, expert_weights_out
build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/all_to_all_benchmark.py CHANGED
@@ -7,28 +7,126 @@ import torch.distributed as dist
7
  # from megablocks import benchmark_util
8
  # from megablocks.layers.all_to_all import all_to_all
9
 
10
- from .. import benchmark_util
11
- from .._layers.all_to_all import all_to_all
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
  _ALL_TO_ALL_BENCHMARK = (
14
  (8, 1024),
15
- (16, 1024),
16
- (32, 1024),
17
- (64, 1024),
18
- (128, 1024),
19
- (256, 1024),
20
- (512, 1024),
21
- (1024, 1024),
22
- (2 * 1024, 1024),
23
- (4 * 1024, 1024),
24
- (8 * 1024, 1024),
25
- (16 * 1024, 1024),
26
- (32 * 1024, 1024),
27
- (64 * 1024, 1024),
28
- (128 * 1024, 1024),
29
- (256 * 1024, 1024),
30
- (512 * 1024, 1024),
31
- (1024 * 1024, 1024),
32
  )
33
 
34
 
@@ -47,10 +145,12 @@ def benchmark_all_to_all(group, sl, hs):
47
  def benchmark():
48
  return all_to_all(x, send_recv_sizes, send_recv_sizes, group)
49
 
50
- time, std = benchmark_util.benchmark_function(benchmark)
 
51
 
52
  if dist.get_rank(group) == 0:
53
- benchmark_util.log_benchmark('All-To-All', details, time, std)
 
54
 
55
 
56
  if __name__ == '__main__':
 
7
  # from megablocks import benchmark_util
8
  # from megablocks.layers.all_to_all import all_to_all
9
 
10
+ # from .. import benchmark_util
11
+
12
+ # Copyright 2024 Databricks
13
+ # SPDX-License-Identifier: Apache-2.0
14
+
15
+ import numpy as np
16
+ import torch
17
+
18
+
19
+ def log_benchmark(name, arguments, time, std):
20
+ print("=" * 60)
21
+ print(f"{name} Benchmark")
22
+ print("Benchmark Parameters:")
23
+ for key, value in arguments.items():
24
+ print(f"{key} = {value}")
25
+ print("Results:")
26
+ print("mean time = {:.3f}ms, std time = {:.3f}ms".format(time, std))
27
+ print("=" * 60)
28
+
29
+
30
+ def benchmark_function(fn, iterations=100, warmup=10):
31
+ print(f"Benchmarking {fn.__name__} with {iterations} iterations and {warmup} warmup iterations")
32
+ # Warmup iterations.
33
+ for _ in range(warmup):
34
+ fn()
35
+
36
+ times = []
37
+ print(f"Running {iterations} iterations...")
38
+ for i in range(iterations):
39
+ start = torch.cuda.Event(enable_timing=True)
40
+ end = torch.cuda.Event(enable_timing=True)
41
+
42
+ start.record()
43
+ fn()
44
+ end.record()
45
+
46
+ torch.cuda.synchronize()
47
+ times.append(start.elapsed_time(end))
48
+ return np.mean(times), np.std(times)
49
+
50
+
51
+ # from .._layers.all_to_all import all_to_all
52
+
53
+ # Copyright 2024 Databricks
54
+ # SPDX-License-Identifier: Apache-2.0
55
+
56
+ import torch
57
+ import torch.distributed as dist
58
+
59
+
60
+ class AllToAllOp(torch.autograd.Function):
61
+
62
+ @staticmethod
63
+ def forward(ctx, x, output_split_sizes, input_split_sizes, group, async_op):
64
+ out = torch.empty(
65
+ (sum(output_split_sizes),) + x.shape[1:], device=x.device, dtype=x.dtype
66
+ )
67
+
68
+ ctx.input_shape = x.shape
69
+ ctx.output_split_sizes = output_split_sizes
70
+ ctx.input_split_sizes = input_split_sizes
71
+ ctx.group = group
72
+ handle = dist.all_to_all_single(
73
+ out,
74
+ x,
75
+ output_split_sizes=output_split_sizes,
76
+ input_split_sizes=input_split_sizes,
77
+ group=group,
78
+ async_op=async_op,
79
+ )
80
+ return out, handle
81
+
82
+ @staticmethod
83
+ def backward(ctx, grad, _):
84
+ if ctx.needs_input_grad[0]:
85
+ out = torch.empty(
86
+ ctx.input_shape,
87
+ device=grad.device,
88
+ dtype=grad.dtype,
89
+ )
90
+ dist.all_to_all_single(
91
+ out,
92
+ grad,
93
+ output_split_sizes=ctx.input_split_sizes,
94
+ input_split_sizes=ctx.output_split_sizes,
95
+ group=ctx.group,
96
+ )
97
+ return out, None, None, None, None
98
+ return None, None, None, None, None
99
+
100
+
101
+ def all_to_all(x, output_split_sizes, input_split_sizes, group, async_op=False):
102
+ return AllToAllOp.apply(
103
+ x,
104
+ output_split_sizes,
105
+ input_split_sizes,
106
+ group,
107
+ async_op,
108
+ )
109
+
110
 
111
  _ALL_TO_ALL_BENCHMARK = (
112
  (8, 1024),
113
+ # (16, 1024),
114
+ # (32, 1024),
115
+ # (64, 1024),
116
+ # (128, 1024),
117
+ # (256, 1024),
118
+ # (512, 1024),
119
+ # (1024, 1024),
120
+ # (2 * 1024, 1024),
121
+ # (4 * 1024, 1024),
122
+ # (8 * 1024, 1024),
123
+ # (16 * 1024, 1024),
124
+ # (32 * 1024, 1024),
125
+ # (64 * 1024, 1024),
126
+ # (128 * 1024, 1024),
127
+ # (256 * 1024, 1024),
128
+ # (512 * 1024, 1024),
129
+ # (1024 * 1024, 1024),
130
  )
131
 
132
 
 
145
  def benchmark():
146
  return all_to_all(x, send_recv_sizes, send_recv_sizes, group)
147
 
148
+ # time, std = benchmark_util.benchmark_function(benchmark)
149
+ time, std = benchmark_function(benchmark)
150
 
151
  if dist.get_rank(group) == 0:
152
+ log_benchmark('All-To-All', details, time, std)
153
+ # benchmark_util.log_benchmark('All-To-All', details, time, std)
154
 
155
 
156
  if __name__ == '__main__':
build/torch27-cxx11-cu118-x86_64-linux/megablocks/_megablocks_13afbbe_dirty.abi3.so ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:aa9d1964e47ec6ff3c4ec77947f6a2a19868b03cec3618daf0555e011f69924d
3
+ size 10517848
build/torch27-cxx11-cu118-x86_64-linux/megablocks/_megablocks_63599de.abi3.so DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:002a58b415ed9e0f6418b103368c4f57f17fa86a851a02f594a33b097b33da09
3
- size 10517816
 
 
 
 
build/torch27-cxx11-cu118-x86_64-linux/megablocks/_ops.py CHANGED
@@ -1,9 +1,9 @@
1
  import torch
2
- from . import _megablocks_63599de
3
- ops = torch.ops._megablocks_63599de
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
- return f"_megablocks_63599de::{op_name}"
 
1
  import torch
2
+ from . import _megablocks_13afbbe_dirty
3
+ ops = torch.ops._megablocks_13afbbe_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
+ return f"_megablocks_13afbbe_dirty::{op_name}"
build/torch27-cxx11-cu118-x86_64-linux/megablocks/layers.py CHANGED
@@ -121,7 +121,15 @@ def scale_grad(
121
 
122
 
123
  # Forward pass for the MLP layer
124
- def mlp_forward(x, w1, w2, w1_bias, w2_bias, gradient_scale=None, alpha: float = 1.702):
 
 
 
 
 
 
 
 
125
  # Scale weights
126
  w1 = scale_grad(w1, gradient_scale)
127
  w2 = scale_grad(w2, gradient_scale)
@@ -144,8 +152,6 @@ def mlp_forward(x, w1, w2, w1_bias, w2_bias, gradient_scale=None, alpha: float =
144
  return torch.bmm(x, w2) + w2_bias[..., None, :]
145
 
146
 
147
- ## START: Load Balancing Loss (unused at the moment)
148
-
149
  # Global variable to store load balancing loss
150
  _LOAD_BALANCING_LOSS = []
151
 
@@ -234,9 +240,6 @@ def batched_load_balancing_loss(args):
234
  return scale * torch.dot(tokens_per_expert, expert_scores)
235
 
236
 
237
- ## END Load Balancing Loss
238
-
239
-
240
  # Calculate the expert capacity based on tokens, top_k, number of experts,
241
  # expert parallel group, capacity factor, and whether expert model parallelism is used.
242
  def expert_capacity(
@@ -410,7 +413,6 @@ def forward_once(
410
  return x, tokens_per_expert
411
 
412
 
413
- # TODO: replace with functional logic once aligned with ref
414
  def parallel_forward_once(
415
  x: torch.Tensor,
416
  expert_weights: torch.Tensor,
@@ -429,15 +431,180 @@ def parallel_forward_once(
429
  moe_expert_model_parallelism: bool = True,
430
  hidden_size: int = 1152,
431
  ):
432
- pass
 
 
 
 
 
 
 
 
433
 
 
 
 
 
 
 
434
 
435
- class MyReplacementLayer(torch.nn.Module):
436
- # def __init__(self):
437
- # super().__init__()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
438
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
439
  def forward(
440
- # self,
441
  x: torch.Tensor,
442
  router_weight: torch.Tensor,
443
  moe_top_k: int,
@@ -446,7 +613,6 @@ class MyReplacementLayer(torch.nn.Module):
446
  moe_normalize_expert_weights: int = None,
447
  uniform_expert_assignment: bool = False,
448
  training: bool = False,
449
- #
450
  w1: torch.Tensor = None,
451
  w2: torch.Tensor = None,
452
  w1_bias: torch.Tensor = None,
@@ -522,7 +688,6 @@ class MyReplacementLayer(torch.nn.Module):
522
  return x, expert_weights, router_scores
523
 
524
 
525
-
526
  class MegaBlocksMoeMLP(torch.nn.Module):
527
 
528
  def forward(
@@ -536,11 +701,21 @@ class MegaBlocksMoeMLP(torch.nn.Module):
536
  w2 = self.experts.down_proj.data
537
  w1_bias = self.experts.gate_up_proj_bias.data
538
  w2_bias = self.experts.down_proj_bias.data
539
- expert_parallel_group = None
540
 
541
- sort_end_bit = max(int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1)
 
 
 
 
 
 
 
 
 
 
 
 
542
  hidden_size = self.experts.hidden_size
543
-
544
  output, expert_weights_out, router_scores = MyReplacementLayer.forward(
545
  x=x,
546
  router_weight=router_weight,
@@ -559,8 +734,8 @@ class MegaBlocksMoeMLP(torch.nn.Module):
559
  sort_end_bit=sort_end_bit,
560
  expert_parallel_group=expert_parallel_group,
561
  moe_capacity_factor=1.0,
562
- moe_expert_model_parallelism=False,
563
- forward_fn=forward_once,
564
  hidden_size=hidden_size,
565
  )
566
- return output, expert_weights_out
 
121
 
122
 
123
  # Forward pass for the MLP layer
124
+ def mlp_forward(
125
+ x: torch.Tensor,
126
+ w1: torch.Tensor,
127
+ w2: torch.Tensor,
128
+ w1_bias: torch.Tensor,
129
+ w2_bias: torch.Tensor,
130
+ gradient_scale: Optional[float] = None,
131
+ alpha: float = 1.702,
132
+ ):
133
  # Scale weights
134
  w1 = scale_grad(w1, gradient_scale)
135
  w2 = scale_grad(w2, gradient_scale)
 
152
  return torch.bmm(x, w2) + w2_bias[..., None, :]
153
 
154
 
 
 
155
  # Global variable to store load balancing loss
156
  _LOAD_BALANCING_LOSS = []
157
 
 
240
  return scale * torch.dot(tokens_per_expert, expert_scores)
241
 
242
 
 
 
 
243
  # Calculate the expert capacity based on tokens, top_k, number of experts,
244
  # expert parallel group, capacity factor, and whether expert model parallelism is used.
245
  def expert_capacity(
 
413
  return x, tokens_per_expert
414
 
415
 
 
416
  def parallel_forward_once(
417
  x: torch.Tensor,
418
  expert_weights: torch.Tensor,
 
431
  moe_expert_model_parallelism: bool = True,
432
  hidden_size: int = 1152,
433
  ):
434
+ # Flatten inputs
435
+ expert_weights = expert_weights.flatten()
436
+ top_experts = top_experts.flatten()
437
+
438
+ with torch.no_grad():
439
+ # Step 1: Local permutation setup
440
+ indices, bin_ids, bins, tokens_per_expert = indices_and_bins(
441
+ top_experts, sort_end_bit, num_experts
442
+ )
443
 
444
+ # Calculate sharding parameters
445
+ world_size = dist.get_world_size(expert_parallel_group)
446
+ hidden_sharding_deg = hidden_sharding_degree(
447
+ world_size, num_experts, hidden_size
448
+ )
449
+ experts_per_rank_val = experts_per_rank(num_experts, world_size)
450
 
451
+ # Replicate token counts for hidden sharding
452
+ repeated_tokens_per_expert = ops.repeat(
453
+ tokens_per_expert, (hidden_sharding_deg,)
454
+ )
455
+
456
+ # Exchange token counts across devices
457
+ parallel_tokens_per_expert = torch.empty_like(repeated_tokens_per_expert)
458
+ # print("world_size:", world_size)
459
+ # print("experts_per_rank_val:", experts_per_rank_val)
460
+
461
+ # Ensure CUB knows which device to use
462
+ tpe_handle = dist.all_to_all_single(
463
+ parallel_tokens_per_expert,
464
+ repeated_tokens_per_expert,
465
+ group=expert_parallel_group,
466
+ async_op=True,
467
+ )
468
+
469
+ # Step 2: Local permutation - group tokens by target device
470
+ x = x.view(-1, x.shape[-1]) # [sl * bs, hs]
471
+ x = ops.gather(x, indices, bin_ids, bins, top_k)
472
+
473
+ # Step 3: Compute communication counts and exchange tokens
474
+ with torch.no_grad():
475
+ tpe_handle.wait()
476
+
477
+ # Reshape for per-device calculations
478
+ repeated_tokens_per_expert = repeated_tokens_per_expert.view(
479
+ world_size, experts_per_rank_val
480
+ )
481
+ parallel_tokens_per_expert = parallel_tokens_per_expert.view(
482
+ world_size, experts_per_rank_val
483
+ )
484
+
485
+ # Calculate send/recv counts
486
+ send_counts = repeated_tokens_per_expert.cpu().sum(dim=-1).tolist()
487
+ # recv_counts = parallel_tokens_per_expert.cpu().sum(dim=-1).tolist()
488
+ parallel_tokens_per_expert_cpu = parallel_tokens_per_expert.cpu()
489
+ recv_counts = parallel_tokens_per_expert_cpu.sum(dim=-1).tolist()
490
+ tokens_received = sum(recv_counts)
491
+
492
+ # Replicate for hidden sharding
493
+ x = ops.repeat(x, (hidden_sharding_deg, 1))
494
+
495
+ # Cross-device token exchange
496
+ parallel_x, parallel_x_handle = ops.all_to_all(
497
+ x,
498
+ recv_counts,
499
+ send_counts,
500
+ expert_parallel_group,
501
+ async_op=True
502
+ )
503
 
504
+ with torch.no_grad():
505
+ # Step 4: Setup for local expert computation
506
+ replicate_bins = ops.inclusive_cumsum(
507
+ parallel_tokens_per_expert.flatten(),
508
+ 0
509
+ )
510
+ replicate_bins = (
511
+ replicate_bins.view(1) if not len(replicate_bins.size()) else replicate_bins
512
+ )
513
+
514
+ # Create expert indices for received tokens
515
+ parallel_top_expert = torch.remainder(
516
+ torch.arange(
517
+ num_experts * hidden_sharding_deg,
518
+ dtype=torch.int32,
519
+ device=indices.device,
520
+ ),
521
+ experts_per_rank_val,
522
+ )
523
+ parallel_top_expert = ops.replicate(
524
+ parallel_top_expert.unsqueeze(dim=0),
525
+ replicate_bins,
526
+ tokens_received,
527
+ ).flatten()
528
+
529
+ # Sort tokens by expert assignment
530
+ parallel_bin_ids, parallel_indices = ops.sort(
531
+ parallel_top_expert,
532
+ sort_end_bit,
533
+ )
534
+
535
+ # Calculate bins for local experts
536
+ parallel_tokens_per_expert = parallel_tokens_per_expert.sum(
537
+ dim=0, dtype=torch.int
538
+ )
539
+ parallel_bins = ops.inclusive_cumsum(
540
+ parallel_tokens_per_expert,
541
+ 0
542
+ )
543
+ parallel_bins = (
544
+ parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins
545
+ )
546
+
547
+ # Calculate expert capacity
548
+ expert_capacity = expert_capacity_fn(
549
+ tokens_received,
550
+ top_k,
551
+ experts_per_rank_val,
552
+ expert_parallel_group,
553
+ moe_capacity_factor,
554
+ moe_expert_model_parallelism,
555
+ )
556
+ if expert_capacity == 0:
557
+ expert_capacity = torch.max(parallel_tokens_per_expert).item()
558
+
559
+ # Locally permute the tokens and perform the expert computation.
560
+ # Block to make sure that the cross-device permutation is complete.
561
+ # if self.args.mlp_impl == 'grouped':
562
+
563
+ # TODO: don't always assume grouped MLP
564
+ if True:
565
+ # GroupedMLP requires counts on CPU. We can use the tensor already
566
+ # moved to CPU for the prior all_to_all, which avoids an extra
567
+ # device synchronization.
568
+ parallel_tokens_per_expert = parallel_tokens_per_expert_cpu.sum(
569
+ dim=0,
570
+ dtype=torch.int,
571
+ )
572
+
573
+ # Step 5: Expert computation
574
+ parallel_x_handle.wait()
575
+
576
+ parallel_x = permute_and_compute(
577
+ parallel_x,
578
+ parallel_tokens_per_expert,
579
+ parallel_indices,
580
+ parallel_bin_ids,
581
+ None, # expert_weights
582
+ parallel_bins,
583
+ expert_capacity,
584
+ top_k=1,
585
+ w1=w1,
586
+ w2=w2,
587
+ w1_bias=w1_bias,
588
+ w2_bias=w2_bias,
589
+ gradient_scale=gradient_scale,
590
+ alpha=alpha,
591
+ )
592
+
593
+ # Step 6: Reverse communication - send results back
594
+ x, _ = ops.all_to_all(parallel_x, send_counts, recv_counts, expert_parallel_group)
595
+
596
+ # Step 7: Reduce across hidden sharding dimension
597
+ shape = (hidden_sharding_deg, -1, hidden_size)
598
+ x = x.view(shape).sum(dim=0)
599
+
600
+ # Step 8: Final local unpermutation
601
+ x = ops.scatter(x, indices, bin_ids, expert_weights, bins, top_k)
602
+
603
+ return x, tokens_per_expert.flatten()
604
+
605
+
606
+ class MyReplacementLayer(torch.nn.Module):
607
  def forward(
 
608
  x: torch.Tensor,
609
  router_weight: torch.Tensor,
610
  moe_top_k: int,
 
613
  moe_normalize_expert_weights: int = None,
614
  uniform_expert_assignment: bool = False,
615
  training: bool = False,
 
616
  w1: torch.Tensor = None,
617
  w2: torch.Tensor = None,
618
  w1_bias: torch.Tensor = None,
 
688
  return x, expert_weights, router_scores
689
 
690
 
 
691
  class MegaBlocksMoeMLP(torch.nn.Module):
692
 
693
  def forward(
 
701
  w2 = self.experts.down_proj.data
702
  w1_bias = self.experts.gate_up_proj_bias.data
703
  w2_bias = self.experts.down_proj_bias.data
 
704
 
705
+ # check if the expert_parallel_group attribute is set
706
+ if hasattr(self, "expert_parallel_group"):
707
+ expert_parallel_group = self.expert_parallel_group
708
+ moe_expert_model_parallelism = True
709
+ forward_fn = parallel_forward_once
710
+ else:
711
+ expert_parallel_group = None
712
+ moe_expert_model_parallelism = False
713
+ forward_fn = forward_once
714
+
715
+ sort_end_bit = max(
716
+ int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1
717
+ )
718
  hidden_size = self.experts.hidden_size
 
719
  output, expert_weights_out, router_scores = MyReplacementLayer.forward(
720
  x=x,
721
  router_weight=router_weight,
 
734
  sort_end_bit=sort_end_bit,
735
  expert_parallel_group=expert_parallel_group,
736
  moe_capacity_factor=1.0,
737
+ moe_expert_model_parallelism=moe_expert_model_parallelism,
738
+ forward_fn=forward_fn,
739
  hidden_size=hidden_size,
740
  )
741
+ return output, expert_weights_out
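A standalone sketch of the sort_end_bit expression used in MegaBlocksMoeMLP.forward above, with a few concrete expert counts to show what it evaluates to.

import torch

def sort_end_bit(moe_num_experts: int) -> int:
    # Radix-sort bits needed to cover all expert indices, clamped to at least
    # one bit (same expression as in the forward pass above).
    return max(int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1)

assert sort_end_bit(1) == 1    # clamp kicks in for a single expert
assert sort_end_bit(8) == 3    # 8 experts fit in 3 bits
assert sort_end_bit(128) == 7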
build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/all_to_all_benchmark.py CHANGED
@@ -7,28 +7,126 @@ import torch.distributed as dist
7
  # from megablocks import benchmark_util
8
  # from megablocks.layers.all_to_all import all_to_all
9
 
10
- from .. import benchmark_util
11
- from .._layers.all_to_all import all_to_all
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
  _ALL_TO_ALL_BENCHMARK = (
14
  (8, 1024),
15
- (16, 1024),
16
- (32, 1024),
17
- (64, 1024),
18
- (128, 1024),
19
- (256, 1024),
20
- (512, 1024),
21
- (1024, 1024),
22
- (2 * 1024, 1024),
23
- (4 * 1024, 1024),
24
- (8 * 1024, 1024),
25
- (16 * 1024, 1024),
26
- (32 * 1024, 1024),
27
- (64 * 1024, 1024),
28
- (128 * 1024, 1024),
29
- (256 * 1024, 1024),
30
- (512 * 1024, 1024),
31
- (1024 * 1024, 1024),
32
  )
33
 
34
 
@@ -47,10 +145,12 @@ def benchmark_all_to_all(group, sl, hs):
47
  def benchmark():
48
  return all_to_all(x, send_recv_sizes, send_recv_sizes, group)
49
 
50
- time, std = benchmark_util.benchmark_function(benchmark)
 
51
 
52
  if dist.get_rank(group) == 0:
53
- benchmark_util.log_benchmark('All-To-All', details, time, std)
 
54
 
55
 
56
  if __name__ == '__main__':
 
7
  # from megablocks import benchmark_util
8
  # from megablocks.layers.all_to_all import all_to_all
9
 
10
+ # from .. import benchmark_util
11
+
12
+ # Copyright 2024 Databricks
13
+ # SPDX-License-Identifier: Apache-2.0
14
+
15
+ import numpy as np
16
+ import torch
17
+
18
+
19
+ def log_benchmark(name, arguments, time, std):
20
+ print("=" * 60)
21
+ print(f"{name} Benchmark")
22
+ print("Benchmark Parameters:")
23
+ for key, value in arguments.items():
24
+ print(f"{key} = {value}")
25
+ print("Results:")
26
+ print("mean time = {:.3f}ms, std time = {:.3f}ms".format(time, std))
27
+ print("=" * 60)
28
+
29
+
30
+ def benchmark_function(fn, iterations=100, warmup=10):
31
+ print(f"Benchmarking {fn.__name__} with {iterations} iterations and {warmup} warmup iterations")
32
+ # Warmup iterations.
33
+ for _ in range(warmup):
34
+ fn()
35
+
36
+ times = []
37
+ print(f"Running {iterations} iterations...")
38
+ for i in range(iterations):
39
+ start = torch.cuda.Event(enable_timing=True)
40
+ end = torch.cuda.Event(enable_timing=True)
41
+
42
+ start.record()
43
+ fn()
44
+ end.record()
45
+
46
+ torch.cuda.synchronize()
47
+ times.append(start.elapsed_time(end))
48
+ return np.mean(times), np.std(times)
49
+
50
+
51
+ # from .._layers.all_to_all import all_to_all
52
+
53
+ # Copyright 2024 Databricks
54
+ # SPDX-License-Identifier: Apache-2.0
55
+
56
+ import torch
57
+ import torch.distributed as dist
58
+
59
+
60
+ class AllToAllOp(torch.autograd.Function):
61
+
62
+ @staticmethod
63
+ def forward(ctx, x, output_split_sizes, input_split_sizes, group, async_op):
64
+ out = torch.empty(
65
+ (sum(output_split_sizes),) + x.shape[1:], device=x.device, dtype=x.dtype
66
+ )
67
+
68
+ ctx.input_shape = x.shape
69
+ ctx.output_split_sizes = output_split_sizes
70
+ ctx.input_split_sizes = input_split_sizes
71
+ ctx.group = group
72
+ handle = dist.all_to_all_single(
73
+ out,
74
+ x,
75
+ output_split_sizes=output_split_sizes,
76
+ input_split_sizes=input_split_sizes,
77
+ group=group,
78
+ async_op=async_op,
79
+ )
80
+ return out, handle
81
+
82
+ @staticmethod
83
+ def backward(ctx, grad, _):
84
+ if ctx.needs_input_grad[0]:
85
+ out = torch.empty(
86
+ ctx.input_shape,
87
+ device=grad.device,
88
+ dtype=grad.dtype,
89
+ )
90
+ dist.all_to_all_single(
91
+ out,
92
+ grad,
93
+ output_split_sizes=ctx.input_split_sizes,
94
+ input_split_sizes=ctx.output_split_sizes,
95
+ group=ctx.group,
96
+ )
97
+ return out, None, None, None, None
98
+ return None, None, None, None, None
99
+
100
+
101
+ def all_to_all(x, output_split_sizes, input_split_sizes, group, async_op=False):
102
+ return AllToAllOp.apply(
103
+ x,
104
+ output_split_sizes,
105
+ input_split_sizes,
106
+ group,
107
+ async_op,
108
+ )
109
+
110
 
111
  _ALL_TO_ALL_BENCHMARK = (
112
  (8, 1024),
113
+ # (16, 1024),
114
+ # (32, 1024),
115
+ # (64, 1024),
116
+ # (128, 1024),
117
+ # (256, 1024),
118
+ # (512, 1024),
119
+ # (1024, 1024),
120
+ # (2 * 1024, 1024),
121
+ # (4 * 1024, 1024),
122
+ # (8 * 1024, 1024),
123
+ # (16 * 1024, 1024),
124
+ # (32 * 1024, 1024),
125
+ # (64 * 1024, 1024),
126
+ # (128 * 1024, 1024),
127
+ # (256 * 1024, 1024),
128
+ # (512 * 1024, 1024),
129
+ # (1024 * 1024, 1024),
130
  )
131
 
132
 
 
145
  def benchmark():
146
  return all_to_all(x, send_recv_sizes, send_recv_sizes, group)
147
 
148
+ # time, std = benchmark_util.benchmark_function(benchmark)
149
+ time, std = benchmark_function(benchmark)
150
 
151
  if dist.get_rank(group) == 0:
152
+ log_benchmark('All-To-All', details, time, std)
153
+ # benchmark_util.log_benchmark('All-To-All', details, time, std)
154
 
155
 
156
  if __name__ == '__main__':
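The inlined benchmark_function above requires a CUDA device; a CPU-only analogue of the same warmup-then-measure pattern, timed with wall-clock time and shown purely as a sketch, looks like this.

import time
import numpy as np

def benchmark_function_cpu(fn, iterations=100, warmup=10):
    # Same structure as benchmark_function above, but uses time.perf_counter
    # instead of CUDA events, so it runs anywhere.
    for _ in range(warmup):
        fn()
    times = []
    for _ in range(iterations):
        start = time.perf_counter()
        fn()
        times.append((time.perf_counter() - start) * 1e3)  # milliseconds
    return np.mean(times), np.std(times)

mean_ms, std_ms = benchmark_function_cpu(lambda: sum(range(10_000)))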
build/torch27-cxx11-cu126-x86_64-linux/megablocks/_megablocks_13afbbe_dirty.abi3.so ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b204da58db0f8be45dda62abd98b74a8e60f1f983bfc6a128c74ff66f67cf502
3
+ size 11931112
build/torch27-cxx11-cu126-x86_64-linux/megablocks/_ops.py CHANGED
@@ -1,9 +1,9 @@
1
  import torch
2
- from . import _megablocks_63599de
3
- ops = torch.ops._megablocks_63599de
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
- return f"_megablocks_63599de::{op_name}"
 
1
  import torch
2
+ from . import _megablocks_13afbbe_dirty
3
+ ops = torch.ops._megablocks_13afbbe_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
+ return f"_megablocks_13afbbe_dirty::{op_name}"
build/torch27-cxx11-cu126-x86_64-linux/megablocks/layers.py CHANGED
@@ -121,7 +121,15 @@ def scale_grad(
121
 
122
 
123
  # Forward pass for the MLP layer
124
- def mlp_forward(x, w1, w2, w1_bias, w2_bias, gradient_scale=None, alpha: float = 1.702):
 
 
 
 
 
 
 
 
125
  # Scale weights
126
  w1 = scale_grad(w1, gradient_scale)
127
  w2 = scale_grad(w2, gradient_scale)
@@ -144,8 +152,6 @@ def mlp_forward(x, w1, w2, w1_bias, w2_bias, gradient_scale=None, alpha: float =
144
  return torch.bmm(x, w2) + w2_bias[..., None, :]
145
 
146
 
147
- ## START: Load Balancing Loss (unused at the moment)
148
-
149
  # Global variable to store load balancing loss
150
  _LOAD_BALANCING_LOSS = []
151
 
@@ -234,9 +240,6 @@ def batched_load_balancing_loss(args):
234
  return scale * torch.dot(tokens_per_expert, expert_scores)
235
 
236
 
237
- ## END Load Balancing Loss
238
-
239
-
240
  # Calculate the expert capacity based on tokens, top_k, number of experts,
241
  # expert parallel group, capacity factor, and whether expert model parallelism is used.
242
  def expert_capacity(
@@ -410,7 +413,6 @@ def forward_once(
410
  return x, tokens_per_expert
411
 
412
 
413
- # TODO: replace with functional logic once aligned with ref
414
  def parallel_forward_once(
415
  x: torch.Tensor,
416
  expert_weights: torch.Tensor,
@@ -429,15 +431,180 @@ def parallel_forward_once(
429
  moe_expert_model_parallelism: bool = True,
430
  hidden_size: int = 1152,
431
  ):
432
- pass
 
 
 
 
 
 
 
 
433
 
 
 
 
 
 
 
434
 
435
- class MyReplacementLayer(torch.nn.Module):
436
- # def __init__(self):
437
- # super().__init__()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
438
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
439
  def forward(
440
- # self,
441
  x: torch.Tensor,
442
  router_weight: torch.Tensor,
443
  moe_top_k: int,
@@ -446,7 +613,6 @@ class MyReplacementLayer(torch.nn.Module):
446
  moe_normalize_expert_weights: int = None,
447
  uniform_expert_assignment: bool = False,
448
  training: bool = False,
449
- #
450
  w1: torch.Tensor = None,
451
  w2: torch.Tensor = None,
452
  w1_bias: torch.Tensor = None,
@@ -522,7 +688,6 @@ class MyReplacementLayer(torch.nn.Module):
522
  return x, expert_weights, router_scores
523
 
524
 
525
-
526
  class MegaBlocksMoeMLP(torch.nn.Module):
527
 
528
  def forward(
@@ -536,11 +701,21 @@ class MegaBlocksMoeMLP(torch.nn.Module):
536
  w2 = self.experts.down_proj.data
537
  w1_bias = self.experts.gate_up_proj_bias.data
538
  w2_bias = self.experts.down_proj_bias.data
539
- expert_parallel_group = None
540
 
541
- sort_end_bit = max(int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1)
 
 
 
 
 
 
 
 
 
 
 
 
542
  hidden_size = self.experts.hidden_size
543
-
544
  output, expert_weights_out, router_scores = MyReplacementLayer.forward(
545
  x=x,
546
  router_weight=router_weight,
@@ -559,8 +734,8 @@ class MegaBlocksMoeMLP(torch.nn.Module):
559
  sort_end_bit=sort_end_bit,
560
  expert_parallel_group=expert_parallel_group,
561
  moe_capacity_factor=1.0,
562
- moe_expert_model_parallelism=False,
563
- forward_fn=forward_once,
564
  hidden_size=hidden_size,
565
  )
566
- return output, expert_weights_out
 
121
 
122
 
123
  # Forward pass for the MLP layer
124
+ def mlp_forward(
125
+ x: torch.Tensor,
126
+ w1: torch.Tensor,
127
+ w2: torch.Tensor,
128
+ w1_bias: torch.Tensor,
129
+ w2_bias: torch.Tensor,
130
+ gradient_scale: Optional[float] = None,
131
+ alpha: float = 1.702,
132
+ ):
133
  # Scale weights
134
  w1 = scale_grad(w1, gradient_scale)
135
  w2 = scale_grad(w2, gradient_scale)
 
152
  return torch.bmm(x, w2) + w2_bias[..., None, :]
153
 
154
 
 
 
155
  # Global variable to store load balancing loss
156
  _LOAD_BALANCING_LOSS = []
157
 
 
240
  return scale * torch.dot(tokens_per_expert, expert_scores)
241
 
242
 
 
 
 
243
  # Calculate the expert capacity based on tokens, top_k, number of experts,
244
  # expert parallel group, capacity factor, and whether expert model parallelism is used.
245
  def expert_capacity(
 
413
  return x, tokens_per_expert
414
 
415
 
 
416
  def parallel_forward_once(
417
  x: torch.Tensor,
418
  expert_weights: torch.Tensor,
 
431
  moe_expert_model_parallelism: bool = True,
432
  hidden_size: int = 1152,
433
  ):
434
+ # Flatten inputs
435
+ expert_weights = expert_weights.flatten()
436
+ top_experts = top_experts.flatten()
437
+
438
+ with torch.no_grad():
439
+ # Step 1: Local permutation setup
440
+ indices, bin_ids, bins, tokens_per_expert = indices_and_bins(
441
+ top_experts, sort_end_bit, num_experts
442
+ )
443
 
444
+ # Calculate sharding parameters
445
+ world_size = dist.get_world_size(expert_parallel_group)
446
+ hidden_sharding_deg = hidden_sharding_degree(
447
+ world_size, num_experts, hidden_size
448
+ )
449
+ experts_per_rank_val = experts_per_rank(num_experts, world_size)
450
 
451
+ # Replicate token counts for hidden sharding
452
+ repeated_tokens_per_expert = ops.repeat(
453
+ tokens_per_expert, (hidden_sharding_deg,)
454
+ )
455
+
456
+ # Exchange token counts across devices
457
+ parallel_tokens_per_expert = torch.empty_like(repeated_tokens_per_expert)
458
+ # print("world_size:", world_size)
459
+ # print("experts_per_rank_val:", experts_per_rank_val)
460
+
461
+ # Ensure CUB knows which device to use
462
+ tpe_handle = dist.all_to_all_single(
463
+ parallel_tokens_per_expert,
464
+ repeated_tokens_per_expert,
465
+ group=expert_parallel_group,
466
+ async_op=True,
467
+ )
468
+
469
+ # Step 2: Local permutation - group tokens by target device
470
+ x = x.view(-1, x.shape[-1]) # [sl * bs, hs]
471
+ x = ops.gather(x, indices, bin_ids, bins, top_k)
472
+
473
+ # Step 3: Compute communication counts and exchange tokens
474
+ with torch.no_grad():
475
+ tpe_handle.wait()
476
+
477
+ # Reshape for per-device calculations
478
+ repeated_tokens_per_expert = repeated_tokens_per_expert.view(
479
+ world_size, experts_per_rank_val
480
+ )
481
+ parallel_tokens_per_expert = parallel_tokens_per_expert.view(
482
+ world_size, experts_per_rank_val
483
+ )
484
+
485
+ # Calculate send/recv counts
486
+ send_counts = repeated_tokens_per_expert.cpu().sum(dim=-1).tolist()
487
+ # recv_counts = parallel_tokens_per_expert.cpu().sum(dim=-1).tolist()
488
+ parallel_tokens_per_expert_cpu = parallel_tokens_per_expert.cpu()
489
+ recv_counts = parallel_tokens_per_expert_cpu.sum(dim=-1).tolist()
490
+ tokens_received = sum(recv_counts)
491
+
492
+ # Replicate for hidden sharding
493
+ x = ops.repeat(x, (hidden_sharding_deg, 1))
494
+
495
+ # Cross-device token exchange
496
+ parallel_x, parallel_x_handle = ops.all_to_all(
497
+ x,
498
+ recv_counts,
499
+ send_counts,
500
+ expert_parallel_group,
501
+ async_op=True
502
+ )
503
 
504
+ with torch.no_grad():
505
+ # Step 4: Setup for local expert computation
506
+ replicate_bins = ops.inclusive_cumsum(
507
+ parallel_tokens_per_expert.flatten(),
508
+ 0
509
+ )
510
+ replicate_bins = (
511
+ replicate_bins.view(1) if not len(replicate_bins.size()) else replicate_bins
512
+ )
513
+
514
+ # Create expert indices for received tokens
515
+ parallel_top_expert = torch.remainder(
516
+ torch.arange(
517
+ num_experts * hidden_sharding_deg,
518
+ dtype=torch.int32,
519
+ device=indices.device,
520
+ ),
521
+ experts_per_rank_val,
522
+ )
523
+ parallel_top_expert = ops.replicate(
524
+ parallel_top_expert.unsqueeze(dim=0),
525
+ replicate_bins,
526
+ tokens_received,
527
+ ).flatten()
528
+
529
+ # Sort tokens by expert assignment
530
+ parallel_bin_ids, parallel_indices = ops.sort(
531
+ parallel_top_expert,
532
+ sort_end_bit,
533
+ )
534
+
535
+ # Calculate bins for local experts
536
+ parallel_tokens_per_expert = parallel_tokens_per_expert.sum(
537
+ dim=0, dtype=torch.int
538
+ )
539
+ parallel_bins = ops.inclusive_cumsum(
540
+ parallel_tokens_per_expert,
541
+ 0
542
+ )
543
+ parallel_bins = (
544
+ parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins
545
+ )
546
+
547
+ # Calculate expert capacity
548
+ expert_capacity = expert_capacity_fn(
549
+ tokens_received,
550
+ top_k,
551
+ experts_per_rank_val,
552
+ expert_parallel_group,
553
+ moe_capacity_factor,
554
+ moe_expert_model_parallelism,
555
+ )
556
+ if expert_capacity == 0:
557
+ expert_capacity = torch.max(parallel_tokens_per_expert).item()
558
+
559
+ # Locally permute the tokens and perform the expert computation.
560
+ # Block to make sure that the cross-device permutation is complete.
561
+ # if self.args.mlp_impl == 'grouped':
562
+
563
+ # TODO: don't always assume grouped MLP
564
+ if True:
565
+ # GroupedMLP requires counts on CPU. We can use the tensor already
566
+ # moved to CPU for the prior all_to_all, which avoids an extra
567
+ # device synchronization.
568
+ parallel_tokens_per_expert = parallel_tokens_per_expert_cpu.sum(
569
+ dim=0,
570
+ dtype=torch.int,
571
+ )
572
+
573
+ # Step 5: Expert computation
574
+ parallel_x_handle.wait()
575
+
576
+ parallel_x = permute_and_compute(
577
+ parallel_x,
578
+ parallel_tokens_per_expert,
579
+ parallel_indices,
580
+ parallel_bin_ids,
581
+ None, # expert_weights
582
+ parallel_bins,
583
+ expert_capacity,
584
+ top_k=1,
585
+ w1=w1,
586
+ w2=w2,
587
+ w1_bias=w1_bias,
588
+ w2_bias=w2_bias,
589
+ gradient_scale=gradient_scale,
590
+ alpha=alpha,
591
+ )
592
+
593
+ # Step 6: Reverse communication - send results back
594
+ x, _ = ops.all_to_all(parallel_x, send_counts, recv_counts, expert_parallel_group)
595
+
596
+ # Step 7: Reduce across hidden sharding dimension
597
+ shape = (hidden_sharding_deg, -1, hidden_size)
598
+ x = x.view(shape).sum(dim=0)
599
+
600
+ # Step 8: Final local unpermutation
601
+ x = ops.scatter(x, indices, bin_ids, expert_weights, bins, top_k)
602
+
603
+ return x, tokens_per_expert.flatten()
604
+
605
+
606
+ class MyReplacementLayer(torch.nn.Module):
607
  def forward(
 
608
  x: torch.Tensor,
609
  router_weight: torch.Tensor,
610
  moe_top_k: int,
 
613
  moe_normalize_expert_weights: int = None,
614
  uniform_expert_assignment: bool = False,
615
  training: bool = False,
 
616
  w1: torch.Tensor = None,
617
  w2: torch.Tensor = None,
618
  w1_bias: torch.Tensor = None,
 
688
  return x, expert_weights, router_scores
689
 
690
 
 
691
  class MegaBlocksMoeMLP(torch.nn.Module):
692
 
693
  def forward(
 
701
  w2 = self.experts.down_proj.data
702
  w1_bias = self.experts.gate_up_proj_bias.data
703
  w2_bias = self.experts.down_proj_bias.data
 
704
 
705
+ # check if the expert_parallel_group attribute is set
706
+ if hasattr(self, "expert_parallel_group"):
707
+ expert_parallel_group = self.expert_parallel_group
708
+ moe_expert_model_parallelism = True
709
+ forward_fn = parallel_forward_once
710
+ else:
711
+ expert_parallel_group = None
712
+ moe_expert_model_parallelism = False
713
+ forward_fn = forward_once
714
+
715
+ sort_end_bit = max(
716
+ int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1
717
+ )
718
  hidden_size = self.experts.hidden_size
 
719
  output, expert_weights_out, router_scores = MyReplacementLayer.forward(
720
  x=x,
721
  router_weight=router_weight,
 
734
  sort_end_bit=sort_end_bit,
735
  expert_parallel_group=expert_parallel_group,
736
  moe_capacity_factor=1.0,
737
+ moe_expert_model_parallelism=moe_expert_model_parallelism,
738
+ forward_fn=forward_fn,
739
  hidden_size=hidden_size,
740
  )
741
+ return output, expert_weights_out
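A toy, communication-free illustration of the send-count bookkeeping in parallel_forward_once above, assuming 2 ranks, 4 experts (2 per rank), and invented token counts.

import torch

world_size, experts_per_rank_val = 2, 2
# Tokens this rank routes to each of the 4 global experts (invented counts).
repeated_tokens_per_expert = torch.tensor([3, 1, 2, 4])

# Rows correspond to destination ranks, so row sums give per-rank send counts.
send_counts = (
    repeated_tokens_per_expert.view(world_size, experts_per_rank_val)
    .sum(dim=-1)
    .tolist()
)
assert send_counts == [4, 6]  # 3+1 go to rank 0's experts, 2+4 to rank 1's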
build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/all_to_all_benchmark.py CHANGED
@@ -7,28 +7,126 @@ import torch.distributed as dist
7
  # from megablocks import benchmark_util
8
  # from megablocks.layers.all_to_all import all_to_all
9
 
10
- from .. import benchmark_util
11
- from .._layers.all_to_all import all_to_all
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
  _ALL_TO_ALL_BENCHMARK = (
14
  (8, 1024),
15
- (16, 1024),
16
- (32, 1024),
17
- (64, 1024),
18
- (128, 1024),
19
- (256, 1024),
20
- (512, 1024),
21
- (1024, 1024),
22
- (2 * 1024, 1024),
23
- (4 * 1024, 1024),
24
- (8 * 1024, 1024),
25
- (16 * 1024, 1024),
26
- (32 * 1024, 1024),
27
- (64 * 1024, 1024),
28
- (128 * 1024, 1024),
29
- (256 * 1024, 1024),
30
- (512 * 1024, 1024),
31
- (1024 * 1024, 1024),
32
  )
33
 
34
 
@@ -47,10 +145,12 @@ def benchmark_all_to_all(group, sl, hs):
47
  def benchmark():
48
  return all_to_all(x, send_recv_sizes, send_recv_sizes, group)
49
 
50
- time, std = benchmark_util.benchmark_function(benchmark)
 
51
 
52
  if dist.get_rank(group) == 0:
53
- benchmark_util.log_benchmark('All-To-All', details, time, std)
 
54
 
55
 
56
  if __name__ == '__main__':
 
7
  # from megablocks import benchmark_util
8
  # from megablocks.layers.all_to_all import all_to_all
9
 
10
+ # from .. import benchmark_util
11
+
12
+ # Copyright 2024 Databricks
13
+ # SPDX-License-Identifier: Apache-2.0
14
+
15
+ import numpy as np
16
+ import torch
17
+
18
+
19
+ def log_benchmark(name, arguments, time, std):
20
+ print("=" * 60)
21
+ print(f"{name} Benchmark")
22
+ print("Benchmark Parameters:")
23
+ for key, value in arguments.items():
24
+ print(f"{key} = {value}")
25
+ print("Results:")
26
+ print("mean time = {:.3f}ms, std time = {:.3f}ms".format(time, std))
27
+ print("=" * 60)
28
+
29
+
30
+ def benchmark_function(fn, iterations=100, warmup=10):
31
+ print(f"Benchmarking {fn.__name__} with {iterations} iterations and {warmup} warmup iterations")
32
+ # Warmup iterations.
33
+ for _ in range(warmup):
34
+ fn()
35
+
36
+ times = []
37
+ print(f"Running {iterations} iterations...")
38
+ for i in range(iterations):
39
+ start = torch.cuda.Event(enable_timing=True)
40
+ end = torch.cuda.Event(enable_timing=True)
41
+
42
+ start.record()
43
+ fn()
44
+ end.record()
45
+
46
+ torch.cuda.synchronize()
47
+ times.append(start.elapsed_time(end))
48
+ return np.mean(times), np.std(times)
49
+
50
+
51
+ # from .._layers.all_to_all import all_to_all
52
+
53
+ # Copyright 2024 Databricks
54
+ # SPDX-License-Identifier: Apache-2.0
55
+
56
+ import torch
57
+ import torch.distributed as dist
58
+
59
+
60
+ class AllToAllOp(torch.autograd.Function):
61
+
62
+ @staticmethod
63
+ def forward(ctx, x, output_split_sizes, input_split_sizes, group, async_op):
64
+ out = torch.empty(
65
+ (sum(output_split_sizes),) + x.shape[1:], device=x.device, dtype=x.dtype
66
+ )
67
+
68
+ ctx.input_shape = x.shape
69
+ ctx.output_split_sizes = output_split_sizes
70
+ ctx.input_split_sizes = input_split_sizes
71
+ ctx.group = group
72
+ handle = dist.all_to_all_single(
73
+ out,
74
+ x,
75
+ output_split_sizes=output_split_sizes,
76
+ input_split_sizes=input_split_sizes,
77
+ group=group,
78
+ async_op=async_op,
79
+ )
80
+ return out, handle
81
+
82
+ @staticmethod
83
+ def backward(ctx, grad, _):
84
+ if ctx.needs_input_grad[0]:
85
+ out = torch.empty(
86
+ ctx.input_shape,
87
+ device=grad.device,
88
+ dtype=grad.dtype,
89
+ )
90
+ dist.all_to_all_single(
91
+ out,
92
+ grad,
93
+ output_split_sizes=ctx.input_split_sizes,
94
+ input_split_sizes=ctx.output_split_sizes,
95
+ group=ctx.group,
96
+ )
97
+ return out, None, None, None, None
98
+ return None, None, None, None, None
99
+
100
+
101
+ def all_to_all(x, output_split_sizes, input_split_sizes, group, async_op=False):
102
+ return AllToAllOp.apply(
103
+ x,
104
+ output_split_sizes,
105
+ input_split_sizes,
106
+ group,
107
+ async_op,
108
+ )
109
+
110
 
111
  _ALL_TO_ALL_BENCHMARK = (
112
  (8, 1024),
113
+ # (16, 1024),
114
+ # (32, 1024),
115
+ # (64, 1024),
116
+ # (128, 1024),
117
+ # (256, 1024),
118
+ # (512, 1024),
119
+ # (1024, 1024),
120
+ # (2 * 1024, 1024),
121
+ # (4 * 1024, 1024),
122
+ # (8 * 1024, 1024),
123
+ # (16 * 1024, 1024),
124
+ # (32 * 1024, 1024),
125
+ # (64 * 1024, 1024),
126
+ # (128 * 1024, 1024),
127
+ # (256 * 1024, 1024),
128
+ # (512 * 1024, 1024),
129
+ # (1024 * 1024, 1024),
130
  )
131
 
132
 
 
145
  def benchmark():
146
  return all_to_all(x, send_recv_sizes, send_recv_sizes, group)
147
 
148
+ # time, std = benchmark_util.benchmark_function(benchmark)
149
+ time, std = benchmark_function(benchmark)
150
 
151
  if dist.get_rank(group) == 0:
152
+ log_benchmark('All-To-All', details, time, std)
153
+ # benchmark_util.log_benchmark('All-To-All', details, time, std)
154
 
155
 
156
  if __name__ == '__main__':
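The backward of AllToAllOp above swaps the two split-size lists so gradients retrace the forward exchange; a no-communication sketch of that bookkeeping, with invented counts.

# What this rank sends to / receives from ranks 0..2 (invented counts).
input_split_sizes = [4, 2, 3]
output_split_sizes = [1, 5, 2]

# The forward output holds sum(output_split_sizes) rows; backward allocates
# its gradient buffer with the two lists swapped, i.e. sum(input_split_sizes)
# rows, so the gradient exchange exactly mirrors the forward one.
forward_rows = sum(output_split_sizes)
backward_rows = sum(input_split_sizes)
assert (forward_rows, backward_rows) == (8, 9)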
build/torch27-cxx11-cu128-x86_64-linux/megablocks/_megablocks_13afbbe_dirty.abi3.so ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f861a8bffedbbf14341d39355f3f43a7c24fee2b99bb9ea7b3a2b9ad21c7ee28
3
+ size 17892656
build/torch27-cxx11-cu128-x86_64-linux/megablocks/_megablocks_63599de.abi3.so DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:dadccc59929c2fdbdf3b153f564d223013924c7b617d1eb2b3ecdc04470a4a60
3
- size 17892624
 
 
 
 
build/torch27-cxx11-cu128-x86_64-linux/megablocks/_ops.py CHANGED
@@ -1,9 +1,9 @@
1
  import torch
2
- from . import _megablocks_63599de
3
- ops = torch.ops._megablocks_63599de
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
- return f"_megablocks_63599de::{op_name}"
 
1
  import torch
2
+ from . import _megablocks_13afbbe_dirty
3
+ ops = torch.ops._megablocks_13afbbe_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
+ return f"_megablocks_13afbbe_dirty::{op_name}"
build/torch27-cxx11-cu128-x86_64-linux/megablocks/layers.py CHANGED
@@ -121,7 +121,15 @@ def scale_grad(
121
 
122
 
123
  # Forward pass for the MLP layer
124
- def mlp_forward(x, w1, w2, w1_bias, w2_bias, gradient_scale=None, alpha: float = 1.702):
 
 
 
 
 
 
 
 
125
  # Scale weights
126
  w1 = scale_grad(w1, gradient_scale)
127
  w2 = scale_grad(w2, gradient_scale)
@@ -144,8 +152,6 @@ def mlp_forward(x, w1, w2, w1_bias, w2_bias, gradient_scale=None, alpha: float =
144
  return torch.bmm(x, w2) + w2_bias[..., None, :]
145
 
146
 
147
- ## START: Load Balancing Loss (unused at the moment)
148
-
149
  # Global variable to store load balancing loss
150
  _LOAD_BALANCING_LOSS = []
151
 
@@ -234,9 +240,6 @@ def batched_load_balancing_loss(args):
234
  return scale * torch.dot(tokens_per_expert, expert_scores)
235
 
236
 
237
- ## END Load Balancing Loss
238
-
239
-
240
  # Calculate the expert capacity based on tokens, top_k, number of experts,
241
  # expert parallel group, capacity factor, and whether expert model parallelism is used.
242
  def expert_capacity(
@@ -410,7 +413,6 @@ def forward_once(
410
  return x, tokens_per_expert
411
 
412
 
413
- # TODO: replace with functional logic once aligned with ref
414
  def parallel_forward_once(
415
  x: torch.Tensor,
416
  expert_weights: torch.Tensor,
@@ -429,15 +431,180 @@ def parallel_forward_once(
429
  moe_expert_model_parallelism: bool = True,
430
  hidden_size: int = 1152,
431
  ):
432
- pass
 
 
 
 
 
 
 
 
433
 
 
 
 
 
 
 
434
 
435
- class MyReplacementLayer(torch.nn.Module):
436
- # def __init__(self):
437
- # super().__init__()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
438
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
439
  def forward(
440
- # self,
441
  x: torch.Tensor,
442
  router_weight: torch.Tensor,
443
  moe_top_k: int,
@@ -446,7 +613,6 @@ class MyReplacementLayer(torch.nn.Module):
446
  moe_normalize_expert_weights: int = None,
447
  uniform_expert_assignment: bool = False,
448
  training: bool = False,
449
- #
450
  w1: torch.Tensor = None,
451
  w2: torch.Tensor = None,
452
  w1_bias: torch.Tensor = None,
@@ -522,7 +688,6 @@ class MyReplacementLayer(torch.nn.Module):
522
  return x, expert_weights, router_scores
523
 
524
 
525
-
526
  class MegaBlocksMoeMLP(torch.nn.Module):
527
 
528
  def forward(
@@ -536,11 +701,21 @@ class MegaBlocksMoeMLP(torch.nn.Module):
536
  w2 = self.experts.down_proj.data
537
  w1_bias = self.experts.gate_up_proj_bias.data
538
  w2_bias = self.experts.down_proj_bias.data
539
- expert_parallel_group = None
540
 
541
- sort_end_bit = max(int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1)
 
 
 
 
 
 
 
 
 
 
 
 
542
  hidden_size = self.experts.hidden_size
543
-
544
  output, expert_weights_out, router_scores = MyReplacementLayer.forward(
545
  x=x,
546
  router_weight=router_weight,
@@ -559,8 +734,8 @@ class MegaBlocksMoeMLP(torch.nn.Module):
559
  sort_end_bit=sort_end_bit,
560
  expert_parallel_group=expert_parallel_group,
561
  moe_capacity_factor=1.0,
562
- moe_expert_model_parallelism=False,
563
- forward_fn=forward_once,
564
  hidden_size=hidden_size,
565
  )
566
- return output, expert_weights_out
 
121
 
122
 
123
  # Forward pass for the MLP layer
124
+ def mlp_forward(
125
+ x: torch.Tensor,
126
+ w1: torch.Tensor,
127
+ w2: torch.Tensor,
128
+ w1_bias: torch.Tensor,
129
+ w2_bias: torch.Tensor,
130
+ gradient_scale: Optional[float] = None,
131
+ alpha: float = 1.702,
132
+ ):
133
  # Scale weights
134
  w1 = scale_grad(w1, gradient_scale)
135
  w2 = scale_grad(w2, gradient_scale)
 
152
  return torch.bmm(x, w2) + w2_bias[..., None, :]
153
 
154
 
 
 
155
  # Global variable to store load balancing loss
156
  _LOAD_BALANCING_LOSS = []
157
 
 
240
  return scale * torch.dot(tokens_per_expert, expert_scores)
241
 
242
 
 
 
 
243
  # Calculate the expert capacity based on tokens, top_k, number of experts,
244
  # expert parallel group, capacity factor, and whether expert model parallelism is used.
245
  def expert_capacity(
 
413
  return x, tokens_per_expert
414
 
415
 
 
416
  def parallel_forward_once(
417
  x: torch.Tensor,
418
  expert_weights: torch.Tensor,
 
431
  moe_expert_model_parallelism: bool = True,
432
  hidden_size: int = 1152,
433
  ):
434
+ # Flatten inputs
435
+ expert_weights = expert_weights.flatten()
436
+ top_experts = top_experts.flatten()
437
+
438
+ with torch.no_grad():
439
+ # Step 1: Local permutation setup
440
+ indices, bin_ids, bins, tokens_per_expert = indices_and_bins(
441
+ top_experts, sort_end_bit, num_experts
442
+ )
443
 
444
+ # Calculate sharding parameters
445
+ world_size = dist.get_world_size(expert_parallel_group)
446
+ hidden_sharding_deg = hidden_sharding_degree(
447
+ world_size, num_experts, hidden_size
448
+ )
449
+ experts_per_rank_val = experts_per_rank(num_experts, world_size)
450
 
451
+ # Replicate token counts for hidden sharding
452
+ repeated_tokens_per_expert = ops.repeat(
453
+ tokens_per_expert, (hidden_sharding_deg,)
454
+ )
455
+
456
+ # Exchange token counts across devices
457
+ parallel_tokens_per_expert = torch.empty_like(repeated_tokens_per_expert)
458
+ # print("world_size:", world_size)
459
+ # print("experts_per_rank_val:", experts_per_rank_val)
460
+
461
+ # Ensure CUB knows which device to use
462
+ tpe_handle = dist.all_to_all_single(
463
+ parallel_tokens_per_expert,
464
+ repeated_tokens_per_expert,
465
+ group=expert_parallel_group,
466
+ async_op=True,
467
+ )
468
+
469
+ # Step 2: Local permutation - group tokens by target device
470
+ x = x.view(-1, x.shape[-1]) # [sl * bs, hs]
471
+ x = ops.gather(x, indices, bin_ids, bins, top_k)
472
+
473
+ # Step 3: Compute communication counts and exchange tokens
474
+ with torch.no_grad():
475
+ tpe_handle.wait()
476
+
477
+ # Reshape for per-device calculations
478
+ repeated_tokens_per_expert = repeated_tokens_per_expert.view(
479
+ world_size, experts_per_rank_val
480
+ )
481
+ parallel_tokens_per_expert = parallel_tokens_per_expert.view(
482
+ world_size, experts_per_rank_val
483
+ )
484
+
485
+ # Calculate send/recv counts
486
+ send_counts = repeated_tokens_per_expert.cpu().sum(dim=-1).tolist()
487
+ # recv_counts = parallel_tokens_per_expert.cpu().sum(dim=-1).tolist()
488
+ parallel_tokens_per_expert_cpu = parallel_tokens_per_expert.cpu()
489
+ recv_counts = parallel_tokens_per_expert_cpu.sum(dim=-1).tolist()
490
+ tokens_received = sum(recv_counts)
491
+
492
+ # Replicate for hidden sharding
493
+ x = ops.repeat(x, (hidden_sharding_deg, 1))
494
+
495
+ # Cross-device token exchange
496
+ parallel_x, parallel_x_handle = ops.all_to_all(
497
+ x,
498
+ recv_counts,
499
+ send_counts,
500
+ expert_parallel_group,
501
+ async_op=True
502
+ )
503
 
504
+ with torch.no_grad():
505
+ # Step 4: Setup for local expert computation
506
+ replicate_bins = ops.inclusive_cumsum(
507
+ parallel_tokens_per_expert.flatten(),
508
+ 0
509
+ )
510
+ replicate_bins = (
511
+ replicate_bins.view(1) if not len(replicate_bins.size()) else replicate_bins
512
+ )
513
+
514
+ # Create expert indices for received tokens
515
+ parallel_top_expert = torch.remainder(
516
+ torch.arange(
517
+ num_experts * hidden_sharding_deg,
518
+ dtype=torch.int32,
519
+ device=indices.device,
520
+ ),
521
+ experts_per_rank_val,
522
+ )
523
+ parallel_top_expert = ops.replicate(
524
+ parallel_top_expert.unsqueeze(dim=0),
525
+ replicate_bins,
526
+ tokens_received,
527
+ ).flatten()
528
+
529
+ # Sort tokens by expert assignment
530
+ parallel_bin_ids, parallel_indices = ops.sort(
531
+ parallel_top_expert,
532
+ sort_end_bit,
533
+ )
534
+
535
+ # Calculate bins for local experts
536
+ parallel_tokens_per_expert = parallel_tokens_per_expert.sum(
537
+ dim=0, dtype=torch.int
538
+ )
539
+ parallel_bins = ops.inclusive_cumsum(
540
+ parallel_tokens_per_expert,
541
+ 0
542
+ )
543
+ parallel_bins = (
544
+ parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins
545
+ )
546
+
547
+ # Calculate expert capacity
548
+ expert_capacity = expert_capacity_fn(
549
+ tokens_received,
550
+ top_k,
551
+ experts_per_rank_val,
552
+ expert_parallel_group,
553
+ moe_capacity_factor,
554
+ moe_expert_model_parallelism,
555
+ )
556
+ if expert_capacity == 0:
557
+ expert_capacity = torch.max(parallel_tokens_per_expert).item()
558
+
559
+ # Locally permute the tokens and perform the expert computation.
560
+ # Block to make sure that the cross-device permutation is complete.
561
+ # if self.args.mlp_impl == 'grouped':
562
+
563
+ # TODO: don't always assume grouped MLP
564
+ if True:
565
+ # GroupedMLP requires counts on CPU. We can use the tensor already
566
+ # moved to CPU for the prior all_to_all, which avoids an extra
567
+ # device synchronization.
568
+ parallel_tokens_per_expert = parallel_tokens_per_expert_cpu.sum(
569
+ dim=0,
570
+ dtype=torch.int,
571
+ )
572
+
573
+ # Step 5: Expert computation
574
+ parallel_x_handle.wait()
575
+
576
+ parallel_x = permute_and_compute(
577
+ parallel_x,
578
+ parallel_tokens_per_expert,
579
+ parallel_indices,
580
+ parallel_bin_ids,
581
+ None, # expert_weights
582
+ parallel_bins,
583
+ expert_capacity,
584
+ top_k=1,
585
+ w1=w1,
586
+ w2=w2,
587
+ w1_bias=w1_bias,
588
+ w2_bias=w2_bias,
589
+ gradient_scale=gradient_scale,
590
+ alpha=alpha,
591
+ )
592
+
593
+ # Step 6: Reverse communication - send results back
594
+ x, _ = ops.all_to_all(parallel_x, send_counts, recv_counts, expert_parallel_group)
595
+
596
+ # Step 7: Reduce across hidden sharding dimension
597
+ shape = (hidden_sharding_deg, -1, hidden_size)
598
+ x = x.view(shape).sum(dim=0)
599
+
600
+ # Step 8: Final local unpermutation
601
+ x = ops.scatter(x, indices, bin_ids, expert_weights, bins, top_k)
602
+
603
+ return x, tokens_per_expert.flatten()
604
+
605
+
606
+ class MyReplacementLayer(torch.nn.Module):
607
  def forward(
 
608
  x: torch.Tensor,
609
  router_weight: torch.Tensor,
610
  moe_top_k: int,
 
613
  moe_normalize_expert_weights: int = None,
614
  uniform_expert_assignment: bool = False,
615
  training: bool = False,
 
616
  w1: torch.Tensor = None,
617
  w2: torch.Tensor = None,
618
  w1_bias: torch.Tensor = None,
 
688
  return x, expert_weights, router_scores
689
 
690
 
 
691
  class MegaBlocksMoeMLP(torch.nn.Module):
692
 
693
  def forward(
 
701
  w2 = self.experts.down_proj.data
702
  w1_bias = self.experts.gate_up_proj_bias.data
703
  w2_bias = self.experts.down_proj_bias.data
 
704
 
705
+ # check if the expert_parallel_group attribute is set
706
+ if hasattr(self, "expert_parallel_group"):
707
+ expert_parallel_group = self.expert_parallel_group
708
+ moe_expert_model_parallelism = True
709
+ forward_fn = parallel_forward_once
710
+ else:
711
+ expert_parallel_group = None
712
+ moe_expert_model_parallelism = False
713
+ forward_fn = forward_once
714
+
715
+ sort_end_bit = max(
716
+ int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1
717
+ )
718
  hidden_size = self.experts.hidden_size
 
719
  output, expert_weights_out, router_scores = MyReplacementLayer.forward(
720
  x=x,
721
  router_weight=router_weight,
 
734
  sort_end_bit=sort_end_bit,
735
  expert_parallel_group=expert_parallel_group,
736
  moe_capacity_factor=1.0,
737
+ moe_expert_model_parallelism=moe_expert_model_parallelism,
738
+ forward_fn=forward_fn,
739
  hidden_size=hidden_size,
740
  )
741
+ return output, expert_weights_out
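A toy version of the Step 7 reduction in parallel_forward_once above: partial outputs from hidden_sharding_deg replicas of each token are summed back into a single set of rows. Shapes and values are made up for illustration.

import torch

hidden_sharding_deg, tokens, hidden_size = 2, 3, 4
# Pretend each of the hidden_sharding_deg replicas produced all-ones outputs.
x = torch.ones(hidden_sharding_deg * tokens, hidden_size)
reduced = x.view(hidden_sharding_deg, -1, hidden_size).sum(dim=0)
assert reduced.shape == (tokens, hidden_size)
assert torch.all(reduced == hidden_sharding_deg)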
build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/all_to_all_benchmark.py CHANGED
@@ -7,28 +7,126 @@ import torch.distributed as dist
7
  # from megablocks import benchmark_util
8
  # from megablocks.layers.all_to_all import all_to_all
9
 
10
- from .. import benchmark_util
11
- from .._layers.all_to_all import all_to_all
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
  _ALL_TO_ALL_BENCHMARK = (
14
  (8, 1024),
15
- (16, 1024),
16
- (32, 1024),
17
- (64, 1024),
18
- (128, 1024),
19
- (256, 1024),
20
- (512, 1024),
21
- (1024, 1024),
22
- (2 * 1024, 1024),
23
- (4 * 1024, 1024),
24
- (8 * 1024, 1024),
25
- (16 * 1024, 1024),
26
- (32 * 1024, 1024),
27
- (64 * 1024, 1024),
28
- (128 * 1024, 1024),
29
- (256 * 1024, 1024),
30
- (512 * 1024, 1024),
31
- (1024 * 1024, 1024),
32
  )
33
 
34
 
@@ -47,10 +145,12 @@ def benchmark_all_to_all(group, sl, hs):
47
  def benchmark():
48
  return all_to_all(x, send_recv_sizes, send_recv_sizes, group)
49
 
50
- time, std = benchmark_util.benchmark_function(benchmark)
 
51
 
52
  if dist.get_rank(group) == 0:
53
- benchmark_util.log_benchmark('All-To-All', details, time, std)
 
54
 
55
 
56
  if __name__ == '__main__':
 
7
  # from megablocks import benchmark_util
8
  # from megablocks.layers.all_to_all import all_to_all
9
 
10
+ # from .. import benchmark_util
11
+
12
+ # Copyright 2024 Databricks
13
+ # SPDX-License-Identifier: Apache-2.0
14
+
15
+ import numpy as np
16
+ import torch
17
+
18
+
19
+ def log_benchmark(name, arguments, time, std):
20
+ print("=" * 60)
21
+ print(f"{name} Benchmark")
22
+ print("Benchmark Parameters:")
23
+ for key, value in arguments.items():
24
+ print(f"{key} = {value}")
25
+ print("Results:")
26
+ print("mean time = {:.3f}ms, std time = {:.3f}ms".format(time, std))
27
+ print("=" * 60)
28
+
29
+
30
+ def benchmark_function(fn, iterations=100, warmup=10):
31
+ print(f"Benchmarking {fn.__name__} with {iterations} iterations and {warmup} warmup iterations")
32
+ # Warmup iterations.
33
+ for _ in range(warmup):
34
+ fn()
35
+
36
+ times = []
37
+ print(f"Running {iterations} iterations...")
38
+ for i in range(iterations):
39
+ start = torch.cuda.Event(enable_timing=True)
40
+ end = torch.cuda.Event(enable_timing=True)
41
+
42
+ start.record()
43
+ fn()
44
+ end.record()
45
+
46
+ torch.cuda.synchronize()
47
+ times.append(start.elapsed_time(end))
48
+ return np.mean(times), np.std(times)
49
+
50
+
51
+ # from .._layers.all_to_all import all_to_all
52
+
53
+ # Copyright 2024 Databricks
54
+ # SPDX-License-Identifier: Apache-2.0
55
+
56
+ import torch
57
+ import torch.distributed as dist
58
+
59
+
60
+ class AllToAllOp(torch.autograd.Function):
61
+
62
+ @staticmethod
63
+ def forward(ctx, x, output_split_sizes, input_split_sizes, group, async_op):
64
+ out = torch.empty(
65
+ (sum(output_split_sizes),) + x.shape[1:], device=x.device, dtype=x.dtype
66
+ )
67
+
68
+ ctx.input_shape = x.shape
69
+ ctx.output_split_sizes = output_split_sizes
70
+ ctx.input_split_sizes = input_split_sizes
71
+ ctx.group = group
72
+ handle = dist.all_to_all_single(
73
+ out,
74
+ x,
75
+ output_split_sizes=output_split_sizes,
76
+ input_split_sizes=input_split_sizes,
77
+ group=group,
78
+ async_op=async_op,
79
+ )
80
+ return out, handle
81
+
82
+ @staticmethod
83
+ def backward(ctx, grad, _):
84
+ if ctx.needs_input_grad[0]:
85
+ out = torch.empty(
86
+ ctx.input_shape,
87
+ device=grad.device,
88
+ dtype=grad.dtype,
89
+ )
90
+ dist.all_to_all_single(
91
+ out,
92
+ grad,
93
+ output_split_sizes=ctx.input_split_sizes,
94
+ input_split_sizes=ctx.output_split_sizes,
95
+ group=ctx.group,
96
+ )
97
+ return out, None, None, None, None
98
+ return None, None, None, None, None
99
+
100
+
101
+ def all_to_all(x, output_split_sizes, input_split_sizes, group, async_op=False):
102
+ return AllToAllOp.apply(
103
+ x,
104
+ output_split_sizes,
105
+ input_split_sizes,
106
+ group,
107
+ async_op,
108
+ )
109
+
110
 
111
  _ALL_TO_ALL_BENCHMARK = (
112
  (8, 1024),
113
+ # (16, 1024),
114
+ # (32, 1024),
115
+ # (64, 1024),
116
+ # (128, 1024),
117
+ # (256, 1024),
118
+ # (512, 1024),
119
+ # (1024, 1024),
120
+ # (2 * 1024, 1024),
121
+ # (4 * 1024, 1024),
122
+ # (8 * 1024, 1024),
123
+ # (16 * 1024, 1024),
124
+ # (32 * 1024, 1024),
125
+ # (64 * 1024, 1024),
126
+ # (128 * 1024, 1024),
127
+ # (256 * 1024, 1024),
128
+ # (512 * 1024, 1024),
129
+ # (1024 * 1024, 1024),
130
  )
131
 
132
 
 
145
  def benchmark():
146
  return all_to_all(x, send_recv_sizes, send_recv_sizes, group)
147
 
148
+ # time, std = benchmark_util.benchmark_function(benchmark)
149
+ time, std = benchmark_function(benchmark)
150
 
151
  if dist.get_rank(group) == 0:
152
+ log_benchmark('All-To-All', details, time, std)
153
+ # benchmark_util.log_benchmark('All-To-All', details, time, std)
154
 
155
 
156
  if __name__ == '__main__':