kernel
drbh committed on
Commit 5268e56 · 1 Parent(s): e47036a

feat: bump build

Files changed (41)
  1. build/torch26-cxx11-cu118-x86_64-linux/megablocks/{_megablocks_13afbbe_dirty.abi3.so → _megablocks_e47036a.abi3.so} +2 -2
  2. build/torch26-cxx11-cu118-x86_64-linux/megablocks/_ops.py +3 -3
  3. build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers.py +129 -142
  4. build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/all_to_all_benchmark.py +21 -121
  5. build/torch26-cxx11-cu124-x86_64-linux/megablocks/{_megablocks_13afbbe_dirty.abi3.so → _megablocks_e47036a.abi3.so} +2 -2
  6. build/torch26-cxx11-cu124-x86_64-linux/megablocks/_ops.py +3 -3
  7. build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers.py +129 -142
  8. build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/all_to_all_benchmark.py +21 -121
  9. build/{torch26-cxx98-cu118-x86_64-linux/megablocks/_megablocks_13afbbe_dirty.abi3.so → torch26-cxx11-cu126-x86_64-linux/megablocks/_megablocks_e47036a.abi3.so} +2 -2
  10. build/torch26-cxx11-cu126-x86_64-linux/megablocks/_ops.py +3 -3
  11. build/torch26-cxx11-cu126-x86_64-linux/megablocks/layers.py +129 -142
  12. build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/all_to_all_benchmark.py +21 -121
  13. build/torch26-cxx98-cu118-x86_64-linux/megablocks/_megablocks_e47036a.abi3.so +3 -0
  14. build/torch26-cxx98-cu118-x86_64-linux/megablocks/_ops.py +3 -3
  15. build/torch26-cxx98-cu118-x86_64-linux/megablocks/layers.py +129 -142
  16. build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/all_to_all_benchmark.py +21 -121
  17. build/torch26-cxx98-cu124-x86_64-linux/megablocks/_megablocks_13afbbe_dirty.abi3.so +0 -3
  18. build/torch26-cxx98-cu124-x86_64-linux/megablocks/_megablocks_e47036a.abi3.so +3 -0
  19. build/torch26-cxx98-cu124-x86_64-linux/megablocks/_ops.py +3 -3
  20. build/torch26-cxx98-cu124-x86_64-linux/megablocks/layers.py +129 -142
  21. build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/all_to_all_benchmark.py +21 -121
  22. build/torch26-cxx98-cu126-x86_64-linux/megablocks/_megablocks_13afbbe_dirty.abi3.so +0 -3
  23. build/torch26-cxx98-cu126-x86_64-linux/megablocks/_megablocks_e47036a.abi3.so +3 -0
  24. build/torch26-cxx98-cu126-x86_64-linux/megablocks/_ops.py +3 -3
  25. build/torch26-cxx98-cu126-x86_64-linux/megablocks/layers.py +129 -142
  26. build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/all_to_all_benchmark.py +21 -121
  27. build/torch27-cxx11-cu118-x86_64-linux/megablocks/_megablocks_13afbbe_dirty.abi3.so +0 -3
  28. build/torch27-cxx11-cu118-x86_64-linux/megablocks/_megablocks_e47036a.abi3.so +3 -0
  29. build/torch27-cxx11-cu118-x86_64-linux/megablocks/_ops.py +3 -3
  30. build/torch27-cxx11-cu118-x86_64-linux/megablocks/layers.py +129 -142
  31. build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/all_to_all_benchmark.py +21 -121
  32. build/torch27-cxx11-cu126-x86_64-linux/megablocks/_megablocks_13afbbe_dirty.abi3.so +0 -3
  33. build/{torch26-cxx11-cu126-x86_64-linux/megablocks/_megablocks_13afbbe_dirty.abi3.so → torch27-cxx11-cu126-x86_64-linux/megablocks/_megablocks_e47036a.abi3.so} +1 -1
  34. build/torch27-cxx11-cu126-x86_64-linux/megablocks/_ops.py +3 -3
  35. build/torch27-cxx11-cu126-x86_64-linux/megablocks/layers.py +129 -142
  36. build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/all_to_all_benchmark.py +21 -121
  37. build/torch27-cxx11-cu128-x86_64-linux/megablocks/_megablocks_13afbbe_dirty.abi3.so +0 -3
  38. build/torch27-cxx11-cu128-x86_64-linux/megablocks/_megablocks_e47036a.abi3.so +3 -0
  39. build/torch27-cxx11-cu128-x86_64-linux/megablocks/_ops.py +3 -3
  40. build/torch27-cxx11-cu128-x86_64-linux/megablocks/layers.py +129 -142
  41. build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/all_to_all_benchmark.py +21 -121
build/torch26-cxx11-cu118-x86_64-linux/megablocks/{_megablocks_13afbbe_dirty.abi3.so → _megablocks_e47036a.abi3.so} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:5683ac8b3e98fc8b8ab19f964b0dbfb9a980b6135220b0a0c1b50180665ce341
-size 10517608
+oid sha256:9b5370545c29afcc1d1d0cab8d5fce563647e26aec84eb17a66f66e10ddc92c9
+size 10517576
build/torch26-cxx11-cu118-x86_64-linux/megablocks/_ops.py CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import _megablocks_13afbbe_dirty
-ops = torch.ops._megablocks_13afbbe_dirty
+from . import _megablocks_e47036a
+ops = torch.ops._megablocks_e47036a
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_megablocks_13afbbe_dirty::{op_name}"
+    return f"_megablocks_e47036a::{op_name}"
build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers.py CHANGED
@@ -333,7 +333,6 @@ def permute_and_compute(
     gradient_scale,
     alpha,
 ):
-    """Permute tokens and compute expert outputs."""
     # Route tokens to experts
     x = x.view(-1, x.shape[-1])
 
@@ -367,6 +366,7 @@ def forward_once(
     expert_parallel_group: int = None,
     moe_capacity_factor: float = 1.0,
     moe_expert_model_parallelism: bool = False,
+    mlp_impl: Optional[str] = None,
 ):
     # x: [sl, bs, hs]
     # expert_weights: [sl * bs, top-k]
@@ -430,11 +430,15 @@ def parallel_forward_once(
     moe_capacity_factor: float = 1.0,
     moe_expert_model_parallelism: bool = True,
     hidden_size: int = 1152,
+    mlp_impl: Optional[str] = "grouped",
 ):
     # Flatten inputs
     expert_weights = expert_weights.flatten()
     top_experts = top_experts.flatten()
 
+    # TODO: remove debugging var
+    # my_rank = dist.get_rank(expert_parallel_group) if expert_parallel_group else 0
+
     with torch.no_grad():
         # Step 1: Local permutation setup
         indices, bin_ids, bins, tokens_per_expert = indices_and_bins(
@@ -455,9 +459,7 @@ def parallel_forward_once(
 
         # Exchange token counts across devices
         parallel_tokens_per_expert = torch.empty_like(repeated_tokens_per_expert)
-        # print("world_size:", world_size)
-        # print("experts_per_rank_val:", experts_per_rank_val)
-
+
         # Ensure CUB knows which device to use
         tpe_handle = dist.all_to_all_single(
             parallel_tokens_per_expert,
@@ -493,20 +495,13 @@ def parallel_forward_once(
     x = ops.repeat(x, (hidden_sharding_deg, 1))
 
     # Cross-device token exchange
-    parallel_x, parallel_x_handle = ops.all_to_all(
-        x,
-        recv_counts,
-        send_counts,
-        expert_parallel_group,
-        async_op=True
+    parallel_x, parallel_x_handle = _layers.all_to_all.all_to_all(
+        x, recv_counts, send_counts, expert_parallel_group, async_op=True
     )
 
     with torch.no_grad():
         # Step 4: Setup for local expert computation
-        replicate_bins = ops.inclusive_cumsum(
-            parallel_tokens_per_expert.flatten(),
-            0
-        )
+        replicate_bins = ops.inclusive_cumsum(parallel_tokens_per_expert.flatten(), 0)
         replicate_bins = (
             replicate_bins.view(1) if not len(replicate_bins.size()) else replicate_bins
         )
@@ -528,7 +523,7 @@ def parallel_forward_once(
 
         # Sort tokens by expert assignment
         parallel_bin_ids, parallel_indices = ops.sort(
-            parallel_top_expert,
+            parallel_top_expert,
             sort_end_bit,
         )
 
@@ -536,10 +531,7 @@ def parallel_forward_once(
         parallel_tokens_per_expert = parallel_tokens_per_expert.sum(
            dim=0, dtype=torch.int
        )
-        parallel_bins = ops.inclusive_cumsum(
-            parallel_tokens_per_expert,
-            0
-        )
+        parallel_bins = ops.inclusive_cumsum(parallel_tokens_per_expert, 0)
        parallel_bins = (
            parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins
        )
@@ -558,10 +550,7 @@ def parallel_forward_once(
 
     # Locally permute the tokens and perform the expert computation.
     # Block to make sure that the cross-device permutation is complete.
-    # if self.args.mlp_impl == 'grouped':
-
-    # TODO: dont always assume grouped MLP
-    if True:
+    if mlp_impl == "grouped":
         # GroupedMLP requires counts on CPU. We can use the tensor already
         # moved to CPU for the prior all_to_all, which avoids an extra
         # device synchronization.
@@ -591,7 +580,9 @@ def parallel_forward_once(
     )
 
     # Step 6: Reverse communication - send results back
-    x, _ = ops.all_to_all(parallel_x, send_counts, recv_counts, expert_parallel_group)
+    x, _ = _layers.all_to_all.all_to_all(
+        parallel_x, send_counts, recv_counts, expert_parallel_group
+    )
 
     # Step 7: Reduce across hidden sharding dimension
     shape = (hidden_sharding_deg, -1, hidden_size)
@@ -603,139 +594,135 @@ def parallel_forward_once(
     return x, tokens_per_expert.flatten()
 
 
-class MyReplacementLayer(torch.nn.Module):
-    def forward(
-        x: torch.Tensor,
-        router_weight: torch.Tensor,
-        moe_top_k: int,
-        moe_num_experts: int,
-        moe_jitter_eps: float = None,
-        moe_normalize_expert_weights: int = None,
-        uniform_expert_assignment: bool = False,
-        training: bool = False,
-        w1: torch.Tensor = None,
-        w2: torch.Tensor = None,
-        w1_bias: torch.Tensor = None,
-        w2_bias: torch.Tensor = None,
-        gradient_scale: Optional[float] = None,
-        alpha: float = 1.702,
-        sort_end_bit: int = 0,
-        expert_parallel_group: torch.distributed.ProcessGroup = None,
-        moe_capacity_factor: float = 1.0,
-        moe_expert_model_parallelism: bool = False,
-        forward_fn: Any = None,
-        hidden_size: int = None,  # Required for parallel forward
-    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-
-        # Route tokens to experts
-        logits, expert_weights, expert_indices = route_tokens(
-            x,
-            router_weight,
-            moe_top_k,
-            moe_num_experts,
-            moe_jitter_eps,
-            moe_normalize_expert_weights,
-            uniform_expert_assignment,
-            training,
-        )
-
-        # Create router scores for output
-        router_scores = (
-            torch.zeros_like(logits)
-            .scatter_(1, expert_indices, expert_weights)
-            .transpose(0, 1)
-        )
-
-        in_shape = x.size()
-
-        # Prepare forward function arguments
-        forward_args = {
-            "x": x,
-            "expert_weights": expert_weights,
-            "top_experts": expert_indices,
-            "w1": w1,
-            "w2": w2,
-            "w1_bias": w1_bias,
-            "w2_bias": w2_bias,
-            "gradient_scale": gradient_scale,
-            "alpha": alpha,
-            "sort_end_bit": sort_end_bit,
-            "top_k": moe_top_k,
-            "num_experts": moe_num_experts,
-            "expert_parallel_group": expert_parallel_group,
-            "moe_capacity_factor": moe_capacity_factor,
-            "moe_expert_model_parallelism": moe_expert_model_parallelism,
-        }
-
-        # Add hidden_size for parallel forward
-        if moe_expert_model_parallelism and hidden_size is not None:
-            forward_args["hidden_size"] = hidden_size
-        elif moe_expert_model_parallelism and hidden_size is None:
-            # Infer hidden_size from input shape
-            forward_args["hidden_size"] = x.shape[-1]
-
-        # Compute expert outputs
-        x, tokens_per_expert = forward_fn(**forward_args)
-
-        # Save load balancing loss if needed
-        moe_loss_weight = 0.0  # Can be made configurable
-        if training and moe_loss_weight > 0:
-            save_load_balancing_loss((tokens_per_expert, logits))
-
-        # Restore original shape
-        x = x.view(in_shape)
-
-        return x, expert_weights, router_scores
+def moe_forward(
+    x: torch.Tensor,
+    router_weight: torch.Tensor,
+    moe_top_k: int,
+    moe_num_experts: int,
+    moe_jitter_eps: float = None,
+    moe_normalize_expert_weights: int = None,
+    uniform_expert_assignment: bool = False,
+    training: bool = False,
+    w1: torch.Tensor = None,
+    w2: torch.Tensor = None,
+    w1_bias: torch.Tensor = None,
+    w2_bias: torch.Tensor = None,
+    gradient_scale: Optional[float] = None,
+    alpha: float = 1.702,
+    sort_end_bit: int = 0,
+    expert_parallel_group: torch.distributed.ProcessGroup = None,
+    moe_capacity_factor: float = 1.0,
+    moe_expert_model_parallelism: bool = False,
+    forward_fn: Any = None,
+    hidden_size: int = None,
+    mlp_impl: str = "grouped",
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+
+    # Route tokens to experts
+    logits, expert_weights, expert_indices = route_tokens(
+        x,
+        router_weight,
+        moe_top_k,
+        moe_num_experts,
+        moe_jitter_eps,
+        moe_normalize_expert_weights,
+        uniform_expert_assignment,
+        training,
+    )
+
+    # Create router scores for output
+    router_scores = (
+        torch.zeros_like(logits)
+        .scatter_(1, expert_indices, expert_weights)
+        .transpose(0, 1)
+    )
+
+    in_shape = x.size()
+
+    # Prepare forward function arguments
+    forward_args = {
+        "x": x,
+        "expert_weights": expert_weights,
+        "top_experts": expert_indices,
+        "w1": w1,
+        "w2": w2,
+        "w1_bias": w1_bias,
+        "w2_bias": w2_bias,
+        "gradient_scale": gradient_scale,
+        "alpha": alpha,
+        "sort_end_bit": sort_end_bit,
+        "top_k": moe_top_k,
+        "num_experts": moe_num_experts,
+        "expert_parallel_group": expert_parallel_group,
+        "moe_capacity_factor": moe_capacity_factor,
+        "moe_expert_model_parallelism": moe_expert_model_parallelism,
+        "mlp_impl": mlp_impl,
+    }
+
+    # Add hidden_size for parallel forward
+    if moe_expert_model_parallelism and hidden_size is not None:
+        forward_args["hidden_size"] = hidden_size
+    elif moe_expert_model_parallelism and hidden_size is None:
+        # Infer hidden_size from input shape
+        forward_args["hidden_size"] = x.shape[-1]
+
+    # Compute expert outputs
+    x, tokens_per_expert = forward_fn(**forward_args)
+
+    # Save load balancing loss if needed
+    moe_loss_weight = 0.0  # Can be made configurable
+    if training and moe_loss_weight > 0:
+        save_load_balancing_loss((tokens_per_expert, logits))
+
+    # Restore original shape
+    x = x.view(in_shape)
+
+    return x, expert_weights, router_scores
 
 
 class MegaBlocksMoeMLP(torch.nn.Module):
 
-    def forward(
-        self,
-        x: torch.Tensor,
-    ) -> torch.Tensor:
-        router_weight = self.router.weight
-        moe_top_k = 4
-        moe_num_experts = 128
-        w1 = self.experts.gate_up_proj.data
-        w2 = self.experts.down_proj.data
-        w1_bias = self.experts.gate_up_proj_bias.data
-        w2_bias = self.experts.down_proj_bias.data
-
-        # check if the expert_parallel_group attribute is set
-        if hasattr(self, "expert_parallel_group"):
-            expert_parallel_group = self.expert_parallel_group
-            moe_expert_model_parallelism = True
-            forward_fn = parallel_forward_once
-        else:
-            expert_parallel_group = None
-            moe_expert_model_parallelism = False
-            forward_fn = forward_once
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        moe_top_k = getattr(self, "moe_top_k", 4)
+        moe_num_experts = getattr(self, "moe_num_experts", 128)
+        gradient_scale = getattr(self, "gradient_scale", None)
+        alpha = getattr(self, "alpha", 1.702)
+        moe_capacity_factor = getattr(self, "moe_capacity_factor", 1.0)
+        moe_jitter_eps = getattr(self, "moe_jitter_eps", None)
+        moe_normalize_expert_weights = getattr(
+            self, "moe_normalize_expert_weights", None
+        )
+        uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)
 
+        has_parallel = hasattr(self, "expert_parallel_group")
+        expert_parallel_group = getattr(self, "expert_parallel_group", None)
+        forward_fn = parallel_forward_once if has_parallel else forward_once
         sort_end_bit = max(
             int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1
         )
-        hidden_size = self.experts.hidden_size
-        output, expert_weights_out, router_scores = MyReplacementLayer.forward(
+        mlp_impl = getattr(self, "mlp_impl", "grouped")  # or sparse
+
+        output, expert_weights_out, _ = moe_forward(
             x=x,
-            router_weight=router_weight,
+            router_weight=self.router.weight,
             moe_top_k=moe_top_k,
             moe_num_experts=moe_num_experts,
-            moe_jitter_eps=None,
-            moe_normalize_expert_weights=None,
-            uniform_expert_assignment=False,
-            training=False,
-            w1=w1,
-            w2=w2,
-            w1_bias=w1_bias,
-            w2_bias=w2_bias,
-            gradient_scale=None,
-            alpha=1.702,
+            moe_jitter_eps=moe_jitter_eps,
+            moe_normalize_expert_weights=moe_normalize_expert_weights,
+            uniform_expert_assignment=uniform_expert_assignment,
+            training=self.training,
+            w1=self.experts.gate_up_proj,
+            w2=self.experts.down_proj,
+            w1_bias=self.experts.gate_up_proj_bias,
+            w2_bias=self.experts.down_proj_bias,
+            gradient_scale=gradient_scale,
+            alpha=alpha,
             sort_end_bit=sort_end_bit,
             expert_parallel_group=expert_parallel_group,
-            moe_capacity_factor=1.0,
-            moe_expert_model_parallelism=moe_expert_model_parallelism,
+            moe_capacity_factor=moe_capacity_factor,
+            moe_expert_model_parallelism=has_parallel,
             forward_fn=forward_fn,
-            hidden_size=hidden_size,
+            hidden_size=self.experts.hidden_size,
+            mlp_impl=mlp_impl,
         )
         return output, expert_weights_out
build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/all_to_all_benchmark.py CHANGED
@@ -7,126 +7,28 @@ import torch.distributed as dist
 # from megablocks import benchmark_util
 # from megablocks.layers.all_to_all import all_to_all
 
-# from .. import benchmark_util
-
-# Copyright 2024 Databricks
-# SPDX-License-Identifier: Apache-2.0
-
-import numpy as np
-import torch
-
-
-def log_benchmark(name, arguments, time, std):
-    print("=" * 60)
-    print(f"{name} Benchmark")
-    print("Benchmark Parameters:")
-    for key, value in arguments.items():
-        print(f"{key} = {value}")
-    print("Results:")
-    print("mean time = {:.3f}ms, std time = {:.3f}ms".format(time, std))
-    print("=" * 60)
-
-
-def benchmark_function(fn, iterations=100, warmup=10):
-    print(f"Benchmarking {fn.__name__} with {iterations} iterations and {warmup} warmup iterations")
-    # Warmup iterations.
-    for _ in range(warmup):
-        fn()
-
-    times = []
-    print(f"Running {iterations} iterations...")
-    for i in range(iterations):
-        start = torch.cuda.Event(enable_timing=True)
-        end = torch.cuda.Event(enable_timing=True)
-
-        start.record()
-        fn()
-        end.record()
-
-        torch.cuda.synchronize()
-        times.append(start.elapsed_time(end))
-    return np.mean(times), np.std(times)
-
-
-# from .._layers.all_to_all import all_to_all
-
-# Copyright 2024 Databricks
-# SPDX-License-Identifier: Apache-2.0
-
-import torch
-import torch.distributed as dist
-
-
-class AllToAllOp(torch.autograd.Function):
-
-    @staticmethod
-    def forward(ctx, x, output_split_sizes, input_split_sizes, group, async_op):
-        out = torch.empty(
-            (sum(output_split_sizes),) + x.shape[1:], device=x.device, dtype=x.dtype
-        )
-
-        ctx.input_shape = x.shape
-        ctx.output_split_sizes = output_split_sizes
-        ctx.input_split_sizes = input_split_sizes
-        ctx.group = group
-        handle = dist.all_to_all_single(
-            out,
-            x,
-            output_split_sizes=output_split_sizes,
-            input_split_sizes=input_split_sizes,
-            group=group,
-            async_op=async_op,
-        )
-        return out, handle
-
-    @staticmethod
-    def backward(ctx, grad, _):
-        if ctx.needs_input_grad[0]:
-            out = torch.empty(
-                ctx.input_shape,
-                device=grad.device,
-                dtype=grad.dtype,
-            )
-            dist.all_to_all_single(
-                out,
-                grad,
-                output_split_sizes=ctx.input_split_sizes,
-                input_split_sizes=ctx.output_split_sizes,
-                group=ctx.group,
-            )
-            return out, None, None, None, None
-        return None, None, None, None, None
-
-
-def all_to_all(x, output_split_sizes, input_split_sizes, group, async_op=False):
-    return AllToAllOp.apply(
-        x,
-        output_split_sizes,
-        input_split_sizes,
-        group,
-        async_op,
-    )
-
+from .. import benchmark_util
+from .._layers.all_to_all import all_to_all
 
 _ALL_TO_ALL_BENCHMARK = (
     (8, 1024),
-    # (16, 1024),
-    # (32, 1024),
-    # (64, 1024),
-    # (128, 1024),
-    # (256, 1024),
-    # (512, 1024),
-    # (1024, 1024),
-    # (2 * 1024, 1024),
-    # (4 * 1024, 1024),
-    # (8 * 1024, 1024),
-    # (16 * 1024, 1024),
-    # (32 * 1024, 1024),
-    # (64 * 1024, 1024),
-    # (128 * 1024, 1024),
-    # (256 * 1024, 1024),
-    # (512 * 1024, 1024),
-    # (1024 * 1024, 1024),
+    (16, 1024),
+    (32, 1024),
+    (64, 1024),
+    (128, 1024),
+    (256, 1024),
+    (512, 1024),
+    (1024, 1024),
+    (2 * 1024, 1024),
+    (4 * 1024, 1024),
+    (8 * 1024, 1024),
+    (16 * 1024, 1024),
+    (32 * 1024, 1024),
+    (64 * 1024, 1024),
+    (128 * 1024, 1024),
+    (256 * 1024, 1024),
+    (512 * 1024, 1024),
+    (1024 * 1024, 1024),
 )
 
 
@@ -145,12 +47,10 @@ def benchmark_all_to_all(group, sl, hs):
     def benchmark():
         return all_to_all(x, send_recv_sizes, send_recv_sizes, group)
 
-    # time, std = benchmark_util.benchmark_function(benchmark)
-    time, std = benchmark_function(benchmark)
+    time, std = benchmark_util.benchmark_function(benchmark)
 
     if dist.get_rank(group) == 0:
-        log_benchmark('All-To-All', details, time, std)
-        # benchmark_util.log_benchmark('All-To-All', details, time, std)
+        benchmark_util.log_benchmark('All-To-All', details, time, std)
 
 
 if __name__ == '__main__':
build/torch26-cxx11-cu124-x86_64-linux/megablocks/{_megablocks_13afbbe_dirty.abi3.so → _megablocks_e47036a.abi3.so} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:b55d6ee3d41404603fdb75ad9a2949aa92e0224f7056fdbeb4c66934035ebd4b
-size 11869424
+oid sha256:f98b46218f277c881a492ffaf78791cd3af6b2c7707ea5883538008366b18569
+size 11869392
build/torch26-cxx11-cu124-x86_64-linux/megablocks/_ops.py CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import _megablocks_13afbbe_dirty
-ops = torch.ops._megablocks_13afbbe_dirty
+from . import _megablocks_e47036a
+ops = torch.ops._megablocks_e47036a
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_megablocks_13afbbe_dirty::{op_name}"
+    return f"_megablocks_e47036a::{op_name}"
build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers.py CHANGED
@@ -333,7 +333,6 @@ def permute_and_compute(
333
  gradient_scale,
334
  alpha,
335
  ):
336
- """Permute tokens and compute expert outputs."""
337
  # Route tokens to experts
338
  x = x.view(-1, x.shape[-1])
339
 
@@ -367,6 +366,7 @@ def forward_once(
367
  expert_parallel_group: int = None,
368
  moe_capacity_factor: float = 1.0,
369
  moe_expert_model_parallelism: bool = False,
 
370
  ):
371
  # x: [sl, bs, hs]
372
  # expert_weights: [sl * bs, top-k]
@@ -430,11 +430,15 @@ def parallel_forward_once(
430
  moe_capacity_factor: float = 1.0,
431
  moe_expert_model_parallelism: bool = True,
432
  hidden_size: int = 1152,
 
433
  ):
434
  # Flatten inputs
435
  expert_weights = expert_weights.flatten()
436
  top_experts = top_experts.flatten()
437
 
 
 
 
438
  with torch.no_grad():
439
  # Step 1: Local permutation setup
440
  indices, bin_ids, bins, tokens_per_expert = indices_and_bins(
@@ -455,9 +459,7 @@ def parallel_forward_once(
455
 
456
  # Exchange token counts across devices
457
  parallel_tokens_per_expert = torch.empty_like(repeated_tokens_per_expert)
458
- # print("world_size:", world_size)
459
- # print("experts_per_rank_val:", experts_per_rank_val)
460
-
461
  # Ensure CUB knows which device to use
462
  tpe_handle = dist.all_to_all_single(
463
  parallel_tokens_per_expert,
@@ -493,20 +495,13 @@ def parallel_forward_once(
493
  x = ops.repeat(x, (hidden_sharding_deg, 1))
494
 
495
  # Cross-device token exchange
496
- parallel_x, parallel_x_handle = ops.all_to_all(
497
- x,
498
- recv_counts,
499
- send_counts,
500
- expert_parallel_group,
501
- async_op=True
502
  )
503
 
504
  with torch.no_grad():
505
  # Step 4: Setup for local expert computation
506
- replicate_bins = ops.inclusive_cumsum(
507
- parallel_tokens_per_expert.flatten(),
508
- 0
509
- )
510
  replicate_bins = (
511
  replicate_bins.view(1) if not len(replicate_bins.size()) else replicate_bins
512
  )
@@ -528,7 +523,7 @@ def parallel_forward_once(
528
 
529
  # Sort tokens by expert assignment
530
  parallel_bin_ids, parallel_indices = ops.sort(
531
- parallel_top_expert,
532
  sort_end_bit,
533
  )
534
 
@@ -536,10 +531,7 @@ def parallel_forward_once(
536
  parallel_tokens_per_expert = parallel_tokens_per_expert.sum(
537
  dim=0, dtype=torch.int
538
  )
539
- parallel_bins = ops.inclusive_cumsum(
540
- parallel_tokens_per_expert,
541
- 0
542
- )
543
  parallel_bins = (
544
  parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins
545
  )
@@ -558,10 +550,7 @@ def parallel_forward_once(
558
 
559
  # Locally permute the tokens and perform the expert computation.
560
  # Block to make sure that the cross-device permutation is complete.
561
- # if self.args.mlp_impl == 'grouped':
562
-
563
- # TODO: dont always assume grouped MLP
564
- if True:
565
  # GroupedMLP requires counts on CPU. We can use the tensor already
566
  # moved to CPU for the prior all_to_all, which avoids an extra
567
  # device synchronization.
@@ -591,7 +580,9 @@ def parallel_forward_once(
591
  )
592
 
593
  # Step 6: Reverse communication - send results back
594
- x, _ = ops.all_to_all(parallel_x, send_counts, recv_counts, expert_parallel_group)
 
 
595
 
596
  # Step 7: Reduce across hidden sharding dimension
597
  shape = (hidden_sharding_deg, -1, hidden_size)
@@ -603,139 +594,135 @@ def parallel_forward_once(
603
  return x, tokens_per_expert.flatten()
604
 
605
 
606
- class MyReplacementLayer(torch.nn.Module):
607
- def forward(
608
- x: torch.Tensor,
609
- router_weight: torch.Tensor,
610
- moe_top_k: int,
611
- moe_num_experts: int,
612
- moe_jitter_eps: float = None,
613
- moe_normalize_expert_weights: int = None,
614
- uniform_expert_assignment: bool = False,
615
- training: bool = False,
616
- w1: torch.Tensor = None,
617
- w2: torch.Tensor = None,
618
- w1_bias: torch.Tensor = None,
619
- w2_bias: torch.Tensor = None,
620
- gradient_scale: Optional[float] = None,
621
- alpha: float = 1.702,
622
- sort_end_bit: int = 0,
623
- expert_parallel_group: torch.distributed.ProcessGroup = None,
624
- moe_capacity_factor: float = 1.0,
625
- moe_expert_model_parallelism: bool = False,
626
- forward_fn: Any = None,
627
- hidden_size: int = None, # Required for parallel forward
628
- ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
629
-
630
- # Route tokens to experts
631
- logits, expert_weights, expert_indices = route_tokens(
632
- x,
633
- router_weight,
634
- moe_top_k,
635
- moe_num_experts,
636
- moe_jitter_eps,
637
- moe_normalize_expert_weights,
638
- uniform_expert_assignment,
639
- training,
640
- )
641
 
642
- # Create router scores for output
643
- router_scores = (
644
- torch.zeros_like(logits)
645
- .scatter_(1, expert_indices, expert_weights)
646
- .transpose(0, 1)
647
- )
 
 
 
 
 
648
 
649
- in_shape = x.size()
650
-
651
- # Prepare forward function arguments
652
- forward_args = {
653
- "x": x,
654
- "expert_weights": expert_weights,
655
- "top_experts": expert_indices,
656
- "w1": w1,
657
- "w2": w2,
658
- "w1_bias": w1_bias,
659
- "w2_bias": w2_bias,
660
- "gradient_scale": gradient_scale,
661
- "alpha": alpha,
662
- "sort_end_bit": sort_end_bit,
663
- "top_k": moe_top_k,
664
- "num_experts": moe_num_experts,
665
- "expert_parallel_group": expert_parallel_group,
666
- "moe_capacity_factor": moe_capacity_factor,
667
- "moe_expert_model_parallelism": moe_expert_model_parallelism,
668
- }
669
-
670
- # Add hidden_size for parallel forward
671
- if moe_expert_model_parallelism and hidden_size is not None:
672
- forward_args["hidden_size"] = hidden_size
673
- elif moe_expert_model_parallelism and hidden_size is None:
674
- # Infer hidden_size from input shape
675
- forward_args["hidden_size"] = x.shape[-1]
676
-
677
- # Compute expert outputs
678
- x, tokens_per_expert = forward_fn(**forward_args)
679
-
680
- # Save load balancing loss if needed
681
- moe_loss_weight = 0.0 # Can be made configurable
682
- if training and moe_loss_weight > 0:
683
- save_load_balancing_loss((tokens_per_expert, logits))
684
-
685
- # Restore original shape
686
- x = x.view(in_shape)
687
-
688
- return x, expert_weights, router_scores
 
 
 
 
 
 
 
 
689
 
690
 
691
  class MegaBlocksMoeMLP(torch.nn.Module):
692
 
693
- def forward(
694
- self,
695
- x: torch.Tensor,
696
- ) -> torch.Tensor:
697
- router_weight = self.router.weight
698
- moe_top_k = 4
699
- moe_num_experts = 128
700
- w1 = self.experts.gate_up_proj.data
701
- w2 = self.experts.down_proj.data
702
- w1_bias = self.experts.gate_up_proj_bias.data
703
- w2_bias = self.experts.down_proj_bias.data
704
-
705
- # check if the expert_parallel_group attribute is set
706
- if hasattr(self, "expert_parallel_group"):
707
- expert_parallel_group = self.expert_parallel_group
708
- moe_expert_model_parallelism = True
709
- forward_fn = parallel_forward_once
710
- else:
711
- expert_parallel_group = None
712
- moe_expert_model_parallelism = False
713
- forward_fn = forward_once
714
 
 
 
 
715
  sort_end_bit = max(
716
  int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1
717
  )
718
- hidden_size = self.experts.hidden_size
719
- output, expert_weights_out, router_scores = MyReplacementLayer.forward(
 
720
  x=x,
721
- router_weight=router_weight,
722
  moe_top_k=moe_top_k,
723
  moe_num_experts=moe_num_experts,
724
- moe_jitter_eps=None,
725
- moe_normalize_expert_weights=None,
726
- uniform_expert_assignment=False,
727
- training=False,
728
- w1=w1,
729
- w2=w2,
730
- w1_bias=w1_bias,
731
- w2_bias=w2_bias,
732
- gradient_scale=None,
733
- alpha=1.702,
734
  sort_end_bit=sort_end_bit,
735
  expert_parallel_group=expert_parallel_group,
736
- moe_capacity_factor=1.0,
737
- moe_expert_model_parallelism=moe_expert_model_parallelism,
738
  forward_fn=forward_fn,
739
- hidden_size=hidden_size,
 
740
  )
741
  return output, expert_weights_out
 
333
  gradient_scale,
334
  alpha,
335
  ):
 
336
  # Route tokens to experts
337
  x = x.view(-1, x.shape[-1])
338
 
 
366
  expert_parallel_group: int = None,
367
  moe_capacity_factor: float = 1.0,
368
  moe_expert_model_parallelism: bool = False,
369
+ mlp_impl: Optional[str] = None,
370
  ):
371
  # x: [sl, bs, hs]
372
  # expert_weights: [sl * bs, top-k]
 
430
  moe_capacity_factor: float = 1.0,
431
  moe_expert_model_parallelism: bool = True,
432
  hidden_size: int = 1152,
433
+ mlp_impl: Optional[str] = "grouped",
434
  ):
435
  # Flatten inputs
436
  expert_weights = expert_weights.flatten()
437
  top_experts = top_experts.flatten()
438
 
439
+ # TODO: remove debugging var
440
+ # my_rank = dist.get_rank(expert_parallel_group) if expert_parallel_group else 0
441
+
442
  with torch.no_grad():
443
  # Step 1: Local permutation setup
444
  indices, bin_ids, bins, tokens_per_expert = indices_and_bins(
 
459
 
460
  # Exchange token counts across devices
461
  parallel_tokens_per_expert = torch.empty_like(repeated_tokens_per_expert)
462
+
 
 
463
  # Ensure CUB knows which device to use
464
  tpe_handle = dist.all_to_all_single(
465
  parallel_tokens_per_expert,
 
495
  x = ops.repeat(x, (hidden_sharding_deg, 1))
496
 
497
  # Cross-device token exchange
498
+ parallel_x, parallel_x_handle = _layers.all_to_all.all_to_all(
499
+ x, recv_counts, send_counts, expert_parallel_group, async_op=True
 
 
 
 
500
  )
501
 
502
  with torch.no_grad():
503
  # Step 4: Setup for local expert computation
504
+ replicate_bins = ops.inclusive_cumsum(parallel_tokens_per_expert.flatten(), 0)
 
 
 
505
  replicate_bins = (
506
  replicate_bins.view(1) if not len(replicate_bins.size()) else replicate_bins
507
  )
 
523
 
524
  # Sort tokens by expert assignment
525
  parallel_bin_ids, parallel_indices = ops.sort(
526
+ parallel_top_expert,
527
  sort_end_bit,
528
  )
529
 
 
531
  parallel_tokens_per_expert = parallel_tokens_per_expert.sum(
532
  dim=0, dtype=torch.int
533
  )
534
+ parallel_bins = ops.inclusive_cumsum(parallel_tokens_per_expert, 0)
 
 
 
535
  parallel_bins = (
536
  parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins
537
  )
 
550
 
551
  # Locally permute the tokens and perform the expert computation.
552
  # Block to make sure that the cross-device permutation is complete.
553
+ if mlp_impl == "grouped":
 
 
 
554
  # GroupedMLP requires counts on CPU. We can use the tensor already
555
  # moved to CPU for the prior all_to_all, which avoids an extra
556
  # device synchronization.
 
580
  )
581
 
582
  # Step 6: Reverse communication - send results back
583
+ x, _ = _layers.all_to_all.all_to_all(
584
+ parallel_x, send_counts, recv_counts, expert_parallel_group
585
+ )
586
 
587
  # Step 7: Reduce across hidden sharding dimension
588
  shape = (hidden_sharding_deg, -1, hidden_size)
 
594
  return x, tokens_per_expert.flatten()
595
 
596
 
597
+ def moe_forward(
598
+ x: torch.Tensor,
599
+ router_weight: torch.Tensor,
600
+ moe_top_k: int,
601
+ moe_num_experts: int,
602
+ moe_jitter_eps: float = None,
603
+ moe_normalize_expert_weights: int = None,
604
+ uniform_expert_assignment: bool = False,
605
+ training: bool = False,
606
+ w1: torch.Tensor = None,
607
+ w2: torch.Tensor = None,
608
+ w1_bias: torch.Tensor = None,
609
+ w2_bias: torch.Tensor = None,
610
+ gradient_scale: Optional[float] = None,
611
+ alpha: float = 1.702,
612
+ sort_end_bit: int = 0,
613
+ expert_parallel_group: torch.distributed.ProcessGroup = None,
614
+ moe_capacity_factor: float = 1.0,
615
+ moe_expert_model_parallelism: bool = False,
616
+ forward_fn: Any = None,
617
+ hidden_size: int = None,
618
+ mlp_impl: str = "grouped",
619
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
 
 
 
 
 
 
 
 
 
 
 
 
620
 
621
+ # Route tokens to experts
622
+ logits, expert_weights, expert_indices = route_tokens(
623
+ x,
624
+ router_weight,
625
+ moe_top_k,
626
+ moe_num_experts,
627
+ moe_jitter_eps,
628
+ moe_normalize_expert_weights,
629
+ uniform_expert_assignment,
630
+ training,
631
+ )
632
 
633
+ # Create router scores for output
634
+ router_scores = (
635
+ torch.zeros_like(logits)
636
+ .scatter_(1, expert_indices, expert_weights)
637
+ .transpose(0, 1)
638
+ )
639
+
640
+ in_shape = x.size()
641
+
642
+ # Prepare forward function arguments
643
+ forward_args = {
644
+ "x": x,
645
+ "expert_weights": expert_weights,
646
+ "top_experts": expert_indices,
647
+ "w1": w1,
648
+ "w2": w2,
649
+ "w1_bias": w1_bias,
650
+ "w2_bias": w2_bias,
651
+ "gradient_scale": gradient_scale,
652
+ "alpha": alpha,
653
+ "sort_end_bit": sort_end_bit,
654
+ "top_k": moe_top_k,
655
+ "num_experts": moe_num_experts,
656
+ "expert_parallel_group": expert_parallel_group,
657
+ "moe_capacity_factor": moe_capacity_factor,
658
+ "moe_expert_model_parallelism": moe_expert_model_parallelism,
659
+ "mlp_impl": mlp_impl,
660
+ }
661
+
662
+ # Add hidden_size for parallel forward
663
+ if moe_expert_model_parallelism and hidden_size is not None:
664
+ forward_args["hidden_size"] = hidden_size
665
+ elif moe_expert_model_parallelism and hidden_size is None:
666
+ # Infer hidden_size from input shape
667
+ forward_args["hidden_size"] = x.shape[-1]
668
+
669
+ # Compute expert outputs
670
+ x, tokens_per_expert = forward_fn(**forward_args)
671
+
672
+ # Save load balancing loss if needed
673
+ moe_loss_weight = 0.0 # Can be made configurable
674
+ if training and moe_loss_weight > 0:
675
+ save_load_balancing_loss((tokens_per_expert, logits))
676
+
677
+ # Restore original shape
678
+ x = x.view(in_shape)
679
+
680
+ return x, expert_weights, router_scores
681
 
682
 
683
  class MegaBlocksMoeMLP(torch.nn.Module):
684
 
685
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
686
+ moe_top_k = getattr(self, "moe_top_k", 4)
687
+ moe_num_experts = getattr(self, "moe_num_experts", 128)
688
+ gradient_scale = getattr(self, "gradient_scale", None)
689
+ alpha = getattr(self, "alpha", 1.702)
690
+ moe_capacity_factor = getattr(self, "moe_capacity_factor", 1.0)
691
+ moe_jitter_eps = getattr(self, "moe_jitter_eps", None)
692
+ moe_normalize_expert_weights = getattr(
693
+ self, "moe_normalize_expert_weights", None
694
+ )
695
+ uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)
 
 
 
 
 
 
 
 
 
 
696
 
697
+ has_parallel = hasattr(self, "expert_parallel_group")
698
+ expert_parallel_group = getattr(self, "expert_parallel_group", None)
699
+ forward_fn = parallel_forward_once if has_parallel else forward_once
700
  sort_end_bit = max(
701
  int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1
702
  )
703
+ mlp_impl = getattr(self, "mlp_impl", "grouped") # or sparse
704
+
705
+ output, expert_weights_out, _ = moe_forward(
706
  x=x,
707
+ router_weight=self.router.weight,
708
  moe_top_k=moe_top_k,
709
  moe_num_experts=moe_num_experts,
710
+ moe_jitter_eps=moe_jitter_eps,
711
+ moe_normalize_expert_weights=moe_normalize_expert_weights,
712
+ uniform_expert_assignment=uniform_expert_assignment,
713
+ training=self.training,
714
+ w1=self.experts.gate_up_proj,
715
+ w2=self.experts.down_proj,
716
+ w1_bias=self.experts.gate_up_proj_bias,
717
+ w2_bias=self.experts.down_proj_bias,
718
+ gradient_scale=gradient_scale,
719
+ alpha=alpha,
720
  sort_end_bit=sort_end_bit,
721
  expert_parallel_group=expert_parallel_group,
722
+ moe_capacity_factor=moe_capacity_factor,
723
+ moe_expert_model_parallelism=has_parallel,
724
  forward_fn=forward_fn,
725
+ hidden_size=self.experts.hidden_size,
726
+ mlp_impl=mlp_impl,
727
  )
728
  return output, expert_weights_out
build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/all_to_all_benchmark.py CHANGED
@@ -7,126 +7,28 @@ import torch.distributed as dist
7
  # from megablocks import benchmark_util
8
  # from megablocks.layers.all_to_all import all_to_all
9
 
10
- # from .. import benchmark_util
11
-
12
- # Copyright 2024 Databricks
13
- # SPDX-License-Identifier: Apache-2.0
14
-
15
- import numpy as np
16
- import torch
17
-
18
-
19
- def log_benchmark(name, arguments, time, std):
20
- print("=" * 60)
21
- print(f"{name} Benchmark")
22
- print("Benchmark Parameters:")
23
- for key, value in arguments.items():
24
- print(f"{key} = {value}")
25
- print("Results:")
26
- print("mean time = {:.3f}ms, std time = {:.3f}ms".format(time, std))
27
- print("=" * 60)
28
-
29
-
30
- def benchmark_function(fn, iterations=100, warmup=10):
31
- print(f"Benchmarking {fn.__name__} with {iterations} iterations and {warmup} warmup iterations")
32
- # Warmup iterations.
33
- for _ in range(warmup):
34
- fn()
35
-
36
- times = []
37
- print(f"Running {iterations} iterations...")
38
- for i in range(iterations):
39
- start = torch.cuda.Event(enable_timing=True)
40
- end = torch.cuda.Event(enable_timing=True)
41
-
42
- start.record()
43
- fn()
44
- end.record()
45
-
46
- torch.cuda.synchronize()
47
- times.append(start.elapsed_time(end))
48
- return np.mean(times), np.std(times)
49
-
50
-
51
- # from .._layers.all_to_all import all_to_all
52
-
53
- # Copyright 2024 Databricks
54
- # SPDX-License-Identifier: Apache-2.0
55
-
56
- import torch
57
- import torch.distributed as dist
58
-
59
-
60
- class AllToAllOp(torch.autograd.Function):
61
-
62
- @staticmethod
63
- def forward(ctx, x, output_split_sizes, input_split_sizes, group, async_op):
64
- out = torch.empty(
65
- (sum(output_split_sizes),) + x.shape[1:], device=x.device, dtype=x.dtype
66
- )
67
-
68
- ctx.input_shape = x.shape
69
- ctx.output_split_sizes = output_split_sizes
70
- ctx.input_split_sizes = input_split_sizes
71
- ctx.group = group
72
- handle = dist.all_to_all_single(
73
- out,
74
- x,
75
- output_split_sizes=output_split_sizes,
76
- input_split_sizes=input_split_sizes,
77
- group=group,
78
- async_op=async_op,
79
- )
80
- return out, handle
81
-
82
- @staticmethod
83
- def backward(ctx, grad, _):
84
- if ctx.needs_input_grad[0]:
85
- out = torch.empty(
86
- ctx.input_shape,
87
- device=grad.device,
88
- dtype=grad.dtype,
89
- )
90
- dist.all_to_all_single(
91
- out,
92
- grad,
93
- output_split_sizes=ctx.input_split_sizes,
94
- input_split_sizes=ctx.output_split_sizes,
95
- group=ctx.group,
96
- )
97
- return out, None, None, None, None
98
- return None, None, None, None, None
99
-
100
-
101
- def all_to_all(x, output_split_sizes, input_split_sizes, group, async_op=False):
102
- return AllToAllOp.apply(
103
- x,
104
- output_split_sizes,
105
- input_split_sizes,
106
- group,
107
- async_op,
108
- )
109
-
110
 
111
  _ALL_TO_ALL_BENCHMARK = (
112
  (8, 1024),
113
- # (16, 1024),
114
- # (32, 1024),
115
- # (64, 1024),
116
- # (128, 1024),
117
- # (256, 1024),
118
- # (512, 1024),
119
- # (1024, 1024),
120
- # (2 * 1024, 1024),
121
- # (4 * 1024, 1024),
122
- # (8 * 1024, 1024),
123
- # (16 * 1024, 1024),
124
- # (32 * 1024, 1024),
125
- # (64 * 1024, 1024),
126
- # (128 * 1024, 1024),
127
- # (256 * 1024, 1024),
128
- # (512 * 1024, 1024),
129
- # (1024 * 1024, 1024),
130
  )
131
 
132
 
@@ -145,12 +47,10 @@ def benchmark_all_to_all(group, sl, hs):
145
  def benchmark():
146
  return all_to_all(x, send_recv_sizes, send_recv_sizes, group)
147
 
148
- # time, std = benchmark_util.benchmark_function(benchmark)
149
- time, std = benchmark_function(benchmark)
150
 
151
  if dist.get_rank(group) == 0:
152
- log_benchmark('All-To-All', details, time, std)
153
- # benchmark_util.log_benchmark('All-To-All', details, time, std)
154
 
155
 
156
  if __name__ == '__main__':
 
7
  # from megablocks import benchmark_util
8
  # from megablocks.layers.all_to_all import all_to_all
9
 
10
+ from .. import benchmark_util
11
+ from .._layers.all_to_all import all_to_all
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
  _ALL_TO_ALL_BENCHMARK = (
14
  (8, 1024),
15
+ (16, 1024),
16
+ (32, 1024),
17
+ (64, 1024),
18
+ (128, 1024),
19
+ (256, 1024),
20
+ (512, 1024),
21
+ (1024, 1024),
22
+ (2 * 1024, 1024),
23
+ (4 * 1024, 1024),
24
+ (8 * 1024, 1024),
25
+ (16 * 1024, 1024),
26
+ (32 * 1024, 1024),
27
+ (64 * 1024, 1024),
28
+ (128 * 1024, 1024),
29
+ (256 * 1024, 1024),
30
+ (512 * 1024, 1024),
31
+ (1024 * 1024, 1024),
32
  )
33
 
34
 
 
47
  def benchmark():
48
  return all_to_all(x, send_recv_sizes, send_recv_sizes, group)
49
 
50
+ time, std = benchmark_util.benchmark_function(benchmark)
 
51
 
52
  if dist.get_rank(group) == 0:
53
+ benchmark_util.log_benchmark('All-To-All', details, time, std)
 
54
 
55
 
56
  if __name__ == '__main__':
build/{torch26-cxx98-cu118-x86_64-linux/megablocks/_megablocks_13afbbe_dirty.abi3.so → torch26-cxx11-cu126-x86_64-linux/megablocks/_megablocks_e47036a.abi3.so} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:a5c8c1b700d297741dd86e8c388e03913a30769ceb51b7c12a01245fbdf30128
-size 10510072
+oid sha256:a8dc9a6a46b46860607a5e7f10374e80dee9e8a65782119e085c6700e199b85a
+size 11931048
build/torch26-cxx11-cu126-x86_64-linux/megablocks/_ops.py CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import _megablocks_13afbbe_dirty
-ops = torch.ops._megablocks_13afbbe_dirty
+from . import _megablocks_e47036a
+ops = torch.ops._megablocks_e47036a
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_megablocks_13afbbe_dirty::{op_name}"
+    return f"_megablocks_e47036a::{op_name}"
build/torch26-cxx11-cu126-x86_64-linux/megablocks/layers.py CHANGED
@@ -333,7 +333,6 @@ def permute_and_compute(
333
  gradient_scale,
334
  alpha,
335
  ):
336
- """Permute tokens and compute expert outputs."""
337
  # Route tokens to experts
338
  x = x.view(-1, x.shape[-1])
339
 
@@ -367,6 +366,7 @@ def forward_once(
367
  expert_parallel_group: int = None,
368
  moe_capacity_factor: float = 1.0,
369
  moe_expert_model_parallelism: bool = False,
 
370
  ):
371
  # x: [sl, bs, hs]
372
  # expert_weights: [sl * bs, top-k]
@@ -430,11 +430,15 @@ def parallel_forward_once(
430
  moe_capacity_factor: float = 1.0,
431
  moe_expert_model_parallelism: bool = True,
432
  hidden_size: int = 1152,
 
433
  ):
434
  # Flatten inputs
435
  expert_weights = expert_weights.flatten()
436
  top_experts = top_experts.flatten()
437
 
 
 
 
438
  with torch.no_grad():
439
  # Step 1: Local permutation setup
440
  indices, bin_ids, bins, tokens_per_expert = indices_and_bins(
@@ -455,9 +459,7 @@ def parallel_forward_once(
455
 
456
  # Exchange token counts across devices
457
  parallel_tokens_per_expert = torch.empty_like(repeated_tokens_per_expert)
458
- # print("world_size:", world_size)
459
- # print("experts_per_rank_val:", experts_per_rank_val)
460
-
461
  # Ensure CUB knows which device to use
462
  tpe_handle = dist.all_to_all_single(
463
  parallel_tokens_per_expert,
@@ -493,20 +495,13 @@ def parallel_forward_once(
493
  x = ops.repeat(x, (hidden_sharding_deg, 1))
494
 
495
  # Cross-device token exchange
496
- parallel_x, parallel_x_handle = ops.all_to_all(
497
- x,
498
- recv_counts,
499
- send_counts,
500
- expert_parallel_group,
501
- async_op=True
502
  )
503
 
504
  with torch.no_grad():
505
  # Step 4: Setup for local expert computation
506
- replicate_bins = ops.inclusive_cumsum(
507
- parallel_tokens_per_expert.flatten(),
508
- 0
509
- )
510
  replicate_bins = (
511
  replicate_bins.view(1) if not len(replicate_bins.size()) else replicate_bins
512
  )
@@ -528,7 +523,7 @@ def parallel_forward_once(
528
 
529
  # Sort tokens by expert assignment
530
  parallel_bin_ids, parallel_indices = ops.sort(
531
- parallel_top_expert,
532
  sort_end_bit,
533
  )
534
 
@@ -536,10 +531,7 @@ def parallel_forward_once(
536
  parallel_tokens_per_expert = parallel_tokens_per_expert.sum(
537
  dim=0, dtype=torch.int
538
  )
539
- parallel_bins = ops.inclusive_cumsum(
540
- parallel_tokens_per_expert,
541
- 0
542
- )
543
  parallel_bins = (
544
  parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins
545
  )
@@ -558,10 +550,7 @@ def parallel_forward_once(
558
 
559
  # Locally permute the tokens and perform the expert computation.
560
  # Block to make sure that the cross-device permutation is complete.
561
- # if self.args.mlp_impl == 'grouped':
562
-
563
- # TODO: dont always assume grouped MLP
564
- if True:
565
  # GroupedMLP requires counts on CPU. We can use the tensor already
566
  # moved to CPU for the prior all_to_all, which avoids an extra
567
  # device synchronization.
@@ -591,7 +580,9 @@ def parallel_forward_once(
591
  )
592
 
593
  # Step 6: Reverse communication - send results back
594
- x, _ = ops.all_to_all(parallel_x, send_counts, recv_counts, expert_parallel_group)
 
 
595
 
596
  # Step 7: Reduce across hidden sharding dimension
597
  shape = (hidden_sharding_deg, -1, hidden_size)
@@ -603,139 +594,135 @@ def parallel_forward_once(
603
  return x, tokens_per_expert.flatten()
604
 
605
 
606
- class MyReplacementLayer(torch.nn.Module):
607
- def forward(
608
- x: torch.Tensor,
609
- router_weight: torch.Tensor,
610
- moe_top_k: int,
611
- moe_num_experts: int,
612
- moe_jitter_eps: float = None,
613
- moe_normalize_expert_weights: int = None,
614
- uniform_expert_assignment: bool = False,
615
- training: bool = False,
616
- w1: torch.Tensor = None,
617
- w2: torch.Tensor = None,
618
- w1_bias: torch.Tensor = None,
619
- w2_bias: torch.Tensor = None,
620
- gradient_scale: Optional[float] = None,
621
- alpha: float = 1.702,
622
- sort_end_bit: int = 0,
623
- expert_parallel_group: torch.distributed.ProcessGroup = None,
624
- moe_capacity_factor: float = 1.0,
625
- moe_expert_model_parallelism: bool = False,
626
- forward_fn: Any = None,
627
- hidden_size: int = None, # Required for parallel forward
628
- ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
629
-
630
- # Route tokens to experts
631
- logits, expert_weights, expert_indices = route_tokens(
632
- x,
633
- router_weight,
634
- moe_top_k,
635
- moe_num_experts,
636
- moe_jitter_eps,
637
- moe_normalize_expert_weights,
638
- uniform_expert_assignment,
639
- training,
640
- )
641
 
642
- # Create router scores for output
643
- router_scores = (
644
- torch.zeros_like(logits)
645
- .scatter_(1, expert_indices, expert_weights)
646
- .transpose(0, 1)
647
- )
 
 
 
 
 
648
 
649
- in_shape = x.size()
650
-
651
- # Prepare forward function arguments
652
- forward_args = {
653
- "x": x,
654
- "expert_weights": expert_weights,
655
- "top_experts": expert_indices,
656
- "w1": w1,
657
- "w2": w2,
658
- "w1_bias": w1_bias,
659
- "w2_bias": w2_bias,
660
- "gradient_scale": gradient_scale,
661
- "alpha": alpha,
662
- "sort_end_bit": sort_end_bit,
663
- "top_k": moe_top_k,
664
- "num_experts": moe_num_experts,
665
- "expert_parallel_group": expert_parallel_group,
666
- "moe_capacity_factor": moe_capacity_factor,
667
- "moe_expert_model_parallelism": moe_expert_model_parallelism,
668
- }
669
-
670
- # Add hidden_size for parallel forward
671
- if moe_expert_model_parallelism and hidden_size is not None:
672
- forward_args["hidden_size"] = hidden_size
673
- elif moe_expert_model_parallelism and hidden_size is None:
674
- # Infer hidden_size from input shape
675
- forward_args["hidden_size"] = x.shape[-1]
676
-
677
- # Compute expert outputs
678
- x, tokens_per_expert = forward_fn(**forward_args)
679
-
680
- # Save load balancing loss if needed
681
- moe_loss_weight = 0.0 # Can be made configurable
682
- if training and moe_loss_weight > 0:
683
- save_load_balancing_loss((tokens_per_expert, logits))
684
-
685
- # Restore original shape
686
- x = x.view(in_shape)
687
-
688
- return x, expert_weights, router_scores
 
 
 
 
 
 
 
 
689
 
690
 
691
  class MegaBlocksMoeMLP(torch.nn.Module):
692
 
693
- def forward(
694
- self,
695
- x: torch.Tensor,
696
- ) -> torch.Tensor:
697
- router_weight = self.router.weight
698
- moe_top_k = 4
699
- moe_num_experts = 128
700
- w1 = self.experts.gate_up_proj.data
701
- w2 = self.experts.down_proj.data
702
- w1_bias = self.experts.gate_up_proj_bias.data
703
- w2_bias = self.experts.down_proj_bias.data
704
-
705
- # check if the expert_parallel_group attribute is set
706
- if hasattr(self, "expert_parallel_group"):
707
- expert_parallel_group = self.expert_parallel_group
708
- moe_expert_model_parallelism = True
709
- forward_fn = parallel_forward_once
710
- else:
711
- expert_parallel_group = None
712
- moe_expert_model_parallelism = False
713
- forward_fn = forward_once
714
 
 
 
 
715
  sort_end_bit = max(
716
  int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1
717
  )
718
- hidden_size = self.experts.hidden_size
719
- output, expert_weights_out, router_scores = MyReplacementLayer.forward(
 
720
  x=x,
721
- router_weight=router_weight,
722
  moe_top_k=moe_top_k,
723
  moe_num_experts=moe_num_experts,
724
- moe_jitter_eps=None,
725
- moe_normalize_expert_weights=None,
726
- uniform_expert_assignment=False,
727
- training=False,
728
- w1=w1,
729
- w2=w2,
730
- w1_bias=w1_bias,
731
- w2_bias=w2_bias,
732
- gradient_scale=None,
733
- alpha=1.702,
734
  sort_end_bit=sort_end_bit,
735
  expert_parallel_group=expert_parallel_group,
736
- moe_capacity_factor=1.0,
737
- moe_expert_model_parallelism=moe_expert_model_parallelism,
738
  forward_fn=forward_fn,
739
- hidden_size=hidden_size,
 
740
  )
741
  return output, expert_weights_out
 
333
  gradient_scale,
334
  alpha,
335
  ):
 
336
  # Route tokens to experts
337
  x = x.view(-1, x.shape[-1])
338
 
 
366
  expert_parallel_group: int = None,
367
  moe_capacity_factor: float = 1.0,
368
  moe_expert_model_parallelism: bool = False,
369
+ mlp_impl: Optional[str] = None,
370
  ):
371
  # x: [sl, bs, hs]
372
  # expert_weights: [sl * bs, top-k]
 
430
  moe_capacity_factor: float = 1.0,
431
  moe_expert_model_parallelism: bool = True,
432
  hidden_size: int = 1152,
433
+ mlp_impl: Optional[str] = "grouped",
434
  ):
435
  # Flatten inputs
436
  expert_weights = expert_weights.flatten()
437
  top_experts = top_experts.flatten()
438
 
439
+ # TODO: remove debugging var
440
+ # my_rank = dist.get_rank(expert_parallel_group) if expert_parallel_group else 0
441
+
442
  with torch.no_grad():
443
  # Step 1: Local permutation setup
444
  indices, bin_ids, bins, tokens_per_expert = indices_and_bins(
 
459
 
460
  # Exchange token counts across devices
461
  parallel_tokens_per_expert = torch.empty_like(repeated_tokens_per_expert)
462
+
 
 
463
  # Ensure CUB knows which device to use
464
  tpe_handle = dist.all_to_all_single(
465
  parallel_tokens_per_expert,
 
495
  x = ops.repeat(x, (hidden_sharding_deg, 1))
496
 
497
  # Cross-device token exchange
498
+ parallel_x, parallel_x_handle = _layers.all_to_all.all_to_all(
499
+ x, recv_counts, send_counts, expert_parallel_group, async_op=True
 
 
 
 
500
  )
501
 
502
  with torch.no_grad():
503
  # Step 4: Setup for local expert computation
504
+ replicate_bins = ops.inclusive_cumsum(parallel_tokens_per_expert.flatten(), 0)
 
 
 
505
  replicate_bins = (
506
  replicate_bins.view(1) if not len(replicate_bins.size()) else replicate_bins
507
  )
 
523
 
524
  # Sort tokens by expert assignment
525
  parallel_bin_ids, parallel_indices = ops.sort(
526
+ parallel_top_expert,
527
  sort_end_bit,
528
  )
529
 
 
531
  parallel_tokens_per_expert = parallel_tokens_per_expert.sum(
532
  dim=0, dtype=torch.int
533
  )
534
+ parallel_bins = ops.inclusive_cumsum(parallel_tokens_per_expert, 0)
 
 
 
535
  parallel_bins = (
536
  parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins
537
  )
 
550
 
551
  # Locally permute the tokens and perform the expert computation.
552
  # Block to make sure that the cross-device permutation is complete.
553
+ if mlp_impl == "grouped":
 
 
 
554
  # GroupedMLP requires counts on CPU. We can use the tensor already
555
  # moved to CPU for the prior all_to_all, which avoids an extra
556
  # device synchronization.
 
580
  )
581
 
582
  # Step 6: Reverse communication - send results back
583
+ x, _ = _layers.all_to_all.all_to_all(
584
+ parallel_x, send_counts, recv_counts, expert_parallel_group
585
+ )
586
 
587
  # Step 7: Reduce across hidden sharding dimension
588
  shape = (hidden_sharding_deg, -1, hidden_size)
 
594
  return x, tokens_per_expert.flatten()
595
 
596
 
597
+ def moe_forward(
598
+ x: torch.Tensor,
599
+ router_weight: torch.Tensor,
600
+ moe_top_k: int,
601
+ moe_num_experts: int,
602
+ moe_jitter_eps: float = None,
603
+ moe_normalize_expert_weights: int = None,
604
+ uniform_expert_assignment: bool = False,
605
+ training: bool = False,
606
+ w1: torch.Tensor = None,
607
+ w2: torch.Tensor = None,
608
+ w1_bias: torch.Tensor = None,
609
+ w2_bias: torch.Tensor = None,
610
+ gradient_scale: Optional[float] = None,
611
+ alpha: float = 1.702,
612
+ sort_end_bit: int = 0,
613
+ expert_parallel_group: torch.distributed.ProcessGroup = None,
614
+ moe_capacity_factor: float = 1.0,
615
+ moe_expert_model_parallelism: bool = False,
616
+ forward_fn: Any = None,
617
+ hidden_size: int = None,
618
+ mlp_impl: str = "grouped",
619
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
620
 
621
+ # Route tokens to experts
622
+ logits, expert_weights, expert_indices = route_tokens(
623
+ x,
624
+ router_weight,
625
+ moe_top_k,
626
+ moe_num_experts,
627
+ moe_jitter_eps,
628
+ moe_normalize_expert_weights,
629
+ uniform_expert_assignment,
630
+ training,
631
+ )
632
 
633
+ # Create router scores for output
634
+ router_scores = (
635
+ torch.zeros_like(logits)
636
+ .scatter_(1, expert_indices, expert_weights)
637
+ .transpose(0, 1)
638
+ )
639
+
640
+ in_shape = x.size()
641
+
642
+ # Prepare forward function arguments
643
+ forward_args = {
644
+ "x": x,
645
+ "expert_weights": expert_weights,
646
+ "top_experts": expert_indices,
647
+ "w1": w1,
648
+ "w2": w2,
649
+ "w1_bias": w1_bias,
650
+ "w2_bias": w2_bias,
651
+ "gradient_scale": gradient_scale,
652
+ "alpha": alpha,
653
+ "sort_end_bit": sort_end_bit,
654
+ "top_k": moe_top_k,
655
+ "num_experts": moe_num_experts,
656
+ "expert_parallel_group": expert_parallel_group,
657
+ "moe_capacity_factor": moe_capacity_factor,
658
+ "moe_expert_model_parallelism": moe_expert_model_parallelism,
659
+ "mlp_impl": mlp_impl,
660
+ }
661
+
662
+ # Add hidden_size for parallel forward
663
+ if moe_expert_model_parallelism and hidden_size is not None:
664
+ forward_args["hidden_size"] = hidden_size
665
+ elif moe_expert_model_parallelism and hidden_size is None:
666
+ # Infer hidden_size from input shape
667
+ forward_args["hidden_size"] = x.shape[-1]
668
+
669
+ # Compute expert outputs
670
+ x, tokens_per_expert = forward_fn(**forward_args)
671
+
672
+ # Save load balancing loss if needed
673
+ moe_loss_weight = 0.0 # Can be made configurable
674
+ if training and moe_loss_weight > 0:
675
+ save_load_balancing_loss((tokens_per_expert, logits))
676
+
677
+ # Restore original shape
678
+ x = x.view(in_shape)
679
+
680
+ return x, expert_weights, router_scores
681
 
682
 
683
  class MegaBlocksMoeMLP(torch.nn.Module):
684
 
685
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
686
+ moe_top_k = getattr(self, "moe_top_k", 4)
687
+ moe_num_experts = getattr(self, "moe_num_experts", 128)
688
+ gradient_scale = getattr(self, "gradient_scale", None)
689
+ alpha = getattr(self, "alpha", 1.702)
690
+ moe_capacity_factor = getattr(self, "moe_capacity_factor", 1.0)
691
+ moe_jitter_eps = getattr(self, "moe_jitter_eps", None)
692
+ moe_normalize_expert_weights = getattr(
693
+ self, "moe_normalize_expert_weights", None
694
+ )
695
+ uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)
696
 
697
+ has_parallel = hasattr(self, "expert_parallel_group")
698
+ expert_parallel_group = getattr(self, "expert_parallel_group", None)
699
+ forward_fn = parallel_forward_once if has_parallel else forward_once
700
  sort_end_bit = max(
701
  int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1
702
  )
703
+ mlp_impl = getattr(self, "mlp_impl", "grouped") # or sparse
704
+
705
+ output, expert_weights_out, _ = moe_forward(
706
  x=x,
707
+ router_weight=self.router.weight,
708
  moe_top_k=moe_top_k,
709
  moe_num_experts=moe_num_experts,
710
+ moe_jitter_eps=moe_jitter_eps,
711
+ moe_normalize_expert_weights=moe_normalize_expert_weights,
712
+ uniform_expert_assignment=uniform_expert_assignment,
713
+ training=self.training,
714
+ w1=self.experts.gate_up_proj,
715
+ w2=self.experts.down_proj,
716
+ w1_bias=self.experts.gate_up_proj_bias,
717
+ w2_bias=self.experts.down_proj_bias,
718
+ gradient_scale=gradient_scale,
719
+ alpha=alpha,
720
  sort_end_bit=sort_end_bit,
721
  expert_parallel_group=expert_parallel_group,
722
+ moe_capacity_factor=moe_capacity_factor,
723
+ moe_expert_model_parallelism=has_parallel,
724
  forward_fn=forward_fn,
725
+ hidden_size=self.experts.hidden_size,
726
+ mlp_impl=mlp_impl,
727
  )
728
  return output, expert_weights_out
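Note: the router-score construction in `moe_forward` above turns the sparse top-k routing decision into a dense [num_experts, num_tokens] matrix via `scatter_` plus a transpose. A minimal standalone sketch with made-up values (illustrative only, not part of this commit):

import torch

# Toy routing: 3 tokens, 4 experts, top-2; indices and weights are hypothetical.
logits = torch.zeros(3, 4)
expert_indices = torch.tensor([[0, 2], [1, 3], [2, 0]])
expert_weights = torch.tensor([[0.7, 0.3], [0.6, 0.4], [0.9, 0.1]])

# Same construction as in moe_forward: dense per-expert scores for each token.
router_scores = (
    torch.zeros_like(logits)
    .scatter_(1, expert_indices, expert_weights)
    .transpose(0, 1)
)
print(router_scores.shape)  # torch.Size([4, 3])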
build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/all_to_all_benchmark.py CHANGED
@@ -7,126 +7,28 @@ import torch.distributed as dist
7
  # from megablocks import benchmark_util
8
  # from megablocks.layers.all_to_all import all_to_all
9
 
10
- # from .. import benchmark_util
11
-
12
- # Copyright 2024 Databricks
13
- # SPDX-License-Identifier: Apache-2.0
14
-
15
- import numpy as np
16
- import torch
17
-
18
-
19
- def log_benchmark(name, arguments, time, std):
20
- print("=" * 60)
21
- print(f"{name} Benchmark")
22
- print("Benchmark Parameters:")
23
- for key, value in arguments.items():
24
- print(f"{key} = {value}")
25
- print("Results:")
26
- print("mean time = {:.3f}ms, std time = {:.3f}ms".format(time, std))
27
- print("=" * 60)
28
-
29
-
30
- def benchmark_function(fn, iterations=100, warmup=10):
31
- print(f"Benchmarking {fn.__name__} with {iterations} iterations and {warmup} warmup iterations")
32
- # Warmup iterations.
33
- for _ in range(warmup):
34
- fn()
35
-
36
- times = []
37
- print(f"Running {iterations} iterations...")
38
- for i in range(iterations):
39
- start = torch.cuda.Event(enable_timing=True)
40
- end = torch.cuda.Event(enable_timing=True)
41
-
42
- start.record()
43
- fn()
44
- end.record()
45
-
46
- torch.cuda.synchronize()
47
- times.append(start.elapsed_time(end))
48
- return np.mean(times), np.std(times)
49
-
50
-
51
- # from .._layers.all_to_all import all_to_all
52
-
53
- # Copyright 2024 Databricks
54
- # SPDX-License-Identifier: Apache-2.0
55
-
56
- import torch
57
- import torch.distributed as dist
58
-
59
-
60
- class AllToAllOp(torch.autograd.Function):
61
-
62
- @staticmethod
63
- def forward(ctx, x, output_split_sizes, input_split_sizes, group, async_op):
64
- out = torch.empty(
65
- (sum(output_split_sizes),) + x.shape[1:], device=x.device, dtype=x.dtype
66
- )
67
-
68
- ctx.input_shape = x.shape
69
- ctx.output_split_sizes = output_split_sizes
70
- ctx.input_split_sizes = input_split_sizes
71
- ctx.group = group
72
- handle = dist.all_to_all_single(
73
- out,
74
- x,
75
- output_split_sizes=output_split_sizes,
76
- input_split_sizes=input_split_sizes,
77
- group=group,
78
- async_op=async_op,
79
- )
80
- return out, handle
81
-
82
- @staticmethod
83
- def backward(ctx, grad, _):
84
- if ctx.needs_input_grad[0]:
85
- out = torch.empty(
86
- ctx.input_shape,
87
- device=grad.device,
88
- dtype=grad.dtype,
89
- )
90
- dist.all_to_all_single(
91
- out,
92
- grad,
93
- output_split_sizes=ctx.input_split_sizes,
94
- input_split_sizes=ctx.output_split_sizes,
95
- group=ctx.group,
96
- )
97
- return out, None, None, None, None
98
- return None, None, None, None, None
99
-
100
-
101
- def all_to_all(x, output_split_sizes, input_split_sizes, group, async_op=False):
102
- return AllToAllOp.apply(
103
- x,
104
- output_split_sizes,
105
- input_split_sizes,
106
- group,
107
- async_op,
108
- )
109
-
110
 
111
  _ALL_TO_ALL_BENCHMARK = (
112
  (8, 1024),
113
- # (16, 1024),
114
- # (32, 1024),
115
- # (64, 1024),
116
- # (128, 1024),
117
- # (256, 1024),
118
- # (512, 1024),
119
- # (1024, 1024),
120
- # (2 * 1024, 1024),
121
- # (4 * 1024, 1024),
122
- # (8 * 1024, 1024),
123
- # (16 * 1024, 1024),
124
- # (32 * 1024, 1024),
125
- # (64 * 1024, 1024),
126
- # (128 * 1024, 1024),
127
- # (256 * 1024, 1024),
128
- # (512 * 1024, 1024),
129
- # (1024 * 1024, 1024),
130
  )
131
 
132
 
@@ -145,12 +47,10 @@ def benchmark_all_to_all(group, sl, hs):
145
  def benchmark():
146
  return all_to_all(x, send_recv_sizes, send_recv_sizes, group)
147
 
148
- # time, std = benchmark_util.benchmark_function(benchmark)
149
- time, std = benchmark_function(benchmark)
150
 
151
  if dist.get_rank(group) == 0:
152
- log_benchmark('All-To-All', details, time, std)
153
- # benchmark_util.log_benchmark('All-To-All', details, time, std)
154
 
155
 
156
  if __name__ == '__main__':
 
7
  # from megablocks import benchmark_util
8
  # from megablocks.layers.all_to_all import all_to_all
9
 
10
+ from .. import benchmark_util
11
+ from .._layers.all_to_all import all_to_all
 
12
 
13
  _ALL_TO_ALL_BENCHMARK = (
14
  (8, 1024),
15
+ (16, 1024),
16
+ (32, 1024),
17
+ (64, 1024),
18
+ (128, 1024),
19
+ (256, 1024),
20
+ (512, 1024),
21
+ (1024, 1024),
22
+ (2 * 1024, 1024),
23
+ (4 * 1024, 1024),
24
+ (8 * 1024, 1024),
25
+ (16 * 1024, 1024),
26
+ (32 * 1024, 1024),
27
+ (64 * 1024, 1024),
28
+ (128 * 1024, 1024),
29
+ (256 * 1024, 1024),
30
+ (512 * 1024, 1024),
31
+ (1024 * 1024, 1024),
32
  )
33
 
34
 
 
47
  def benchmark():
48
  return all_to_all(x, send_recv_sizes, send_recv_sizes, group)
49
 
50
+ time, std = benchmark_util.benchmark_function(benchmark)
 
51
 
52
  if dist.get_rank(group) == 0:
53
+ benchmark_util.log_benchmark('All-To-All', details, time, std)
 
54
 
55
 
56
  if __name__ == '__main__':
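The helper removed from this file timed the collective with CUDA events; the call sites now go back to `benchmark_util`. For reference, the same timing pattern as a self-contained sketch (assumes a CUDA device and that `fn` enqueues work on the current stream):

import numpy as np
import torch

def time_cuda(fn, iterations=100, warmup=10):
    # Warm up so allocator and launch overheads do not skew the measurement.
    for _ in range(warmup):
        fn()
    times = []
    for _ in range(iterations):
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        fn()
        end.record()
        torch.cuda.synchronize()
        times.append(start.elapsed_time(end))  # milliseconds
    return float(np.mean(times)), float(np.std(times))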
build/torch26-cxx98-cu118-x86_64-linux/megablocks/_megablocks_e47036a.abi3.so ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b8b7431f955f32750e9006f0c06c4ec0f99ca4a7ecd51585d8626d119db69084
3
+ size 10510040
build/torch26-cxx98-cu118-x86_64-linux/megablocks/_ops.py CHANGED
@@ -1,9 +1,9 @@
1
  import torch
2
- from . import _megablocks_13afbbe_dirty
3
- ops = torch.ops._megablocks_13afbbe_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
- return f"_megablocks_13afbbe_dirty::{op_name}"
 
1
  import torch
2
+ from . import _megablocks_e47036a
3
+ ops = torch.ops._megablocks_e47036a
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
+ return f"_megablocks_e47036a::{op_name}"
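With the rename, every registered kernel now lives under the `_megablocks_e47036a` namespace. A short usage sketch, assuming the wheel built from this commit is importable as `megablocks`:

from megablocks import _ops  # loads the _megablocks_e47036a extension

print(_ops.add_op_namespace_prefix("sort"))
# -> "_megablocks_e47036a::sort"; the kernel itself is reachable as
# torch.ops._megablocks_e47036a.sort, which layers.py calls via `ops.sort`.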
build/torch26-cxx98-cu118-x86_64-linux/megablocks/layers.py CHANGED
@@ -333,7 +333,6 @@ def permute_and_compute(
333
  gradient_scale,
334
  alpha,
335
  ):
336
- """Permute tokens and compute expert outputs."""
337
  # Route tokens to experts
338
  x = x.view(-1, x.shape[-1])
339
 
@@ -367,6 +366,7 @@ def forward_once(
367
  expert_parallel_group: int = None,
368
  moe_capacity_factor: float = 1.0,
369
  moe_expert_model_parallelism: bool = False,
 
370
  ):
371
  # x: [sl, bs, hs]
372
  # expert_weights: [sl * bs, top-k]
@@ -430,11 +430,15 @@ def parallel_forward_once(
430
  moe_capacity_factor: float = 1.0,
431
  moe_expert_model_parallelism: bool = True,
432
  hidden_size: int = 1152,
 
433
  ):
434
  # Flatten inputs
435
  expert_weights = expert_weights.flatten()
436
  top_experts = top_experts.flatten()
437
 
 
 
 
438
  with torch.no_grad():
439
  # Step 1: Local permutation setup
440
  indices, bin_ids, bins, tokens_per_expert = indices_and_bins(
@@ -455,9 +459,7 @@ def parallel_forward_once(
455
 
456
  # Exchange token counts across devices
457
  parallel_tokens_per_expert = torch.empty_like(repeated_tokens_per_expert)
458
- # print("world_size:", world_size)
459
- # print("experts_per_rank_val:", experts_per_rank_val)
460
-
461
  # Ensure CUB knows which device to use
462
  tpe_handle = dist.all_to_all_single(
463
  parallel_tokens_per_expert,
@@ -493,20 +495,13 @@ def parallel_forward_once(
493
  x = ops.repeat(x, (hidden_sharding_deg, 1))
494
 
495
  # Cross-device token exchange
496
- parallel_x, parallel_x_handle = ops.all_to_all(
497
- x,
498
- recv_counts,
499
- send_counts,
500
- expert_parallel_group,
501
- async_op=True
502
  )
503
 
504
  with torch.no_grad():
505
  # Step 4: Setup for local expert computation
506
- replicate_bins = ops.inclusive_cumsum(
507
- parallel_tokens_per_expert.flatten(),
508
- 0
509
- )
510
  replicate_bins = (
511
  replicate_bins.view(1) if not len(replicate_bins.size()) else replicate_bins
512
  )
@@ -528,7 +523,7 @@ def parallel_forward_once(
528
 
529
  # Sort tokens by expert assignment
530
  parallel_bin_ids, parallel_indices = ops.sort(
531
- parallel_top_expert,
532
  sort_end_bit,
533
  )
534
 
@@ -536,10 +531,7 @@ def parallel_forward_once(
536
  parallel_tokens_per_expert = parallel_tokens_per_expert.sum(
537
  dim=0, dtype=torch.int
538
  )
539
- parallel_bins = ops.inclusive_cumsum(
540
- parallel_tokens_per_expert,
541
- 0
542
- )
543
  parallel_bins = (
544
  parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins
545
  )
@@ -558,10 +550,7 @@ def parallel_forward_once(
558
 
559
  # Locally permute the tokens and perform the expert computation.
560
  # Block to make sure that the cross-device permutation is complete.
561
- # if self.args.mlp_impl == 'grouped':
562
-
563
- # TODO: dont always assume grouped MLP
564
- if True:
565
  # GroupedMLP requires counts on CPU. We can use the tensor already
566
  # moved to CPU for the prior all_to_all, which avoids an extra
567
  # device synchronization.
@@ -591,7 +580,9 @@ def parallel_forward_once(
591
  )
592
 
593
  # Step 6: Reverse communication - send results back
594
- x, _ = ops.all_to_all(parallel_x, send_counts, recv_counts, expert_parallel_group)
 
 
595
 
596
  # Step 7: Reduce across hidden sharding dimension
597
  shape = (hidden_sharding_deg, -1, hidden_size)
@@ -603,139 +594,135 @@ def parallel_forward_once(
603
  return x, tokens_per_expert.flatten()
604
 
605
 
606
- class MyReplacementLayer(torch.nn.Module):
607
- def forward(
608
- x: torch.Tensor,
609
- router_weight: torch.Tensor,
610
- moe_top_k: int,
611
- moe_num_experts: int,
612
- moe_jitter_eps: float = None,
613
- moe_normalize_expert_weights: int = None,
614
- uniform_expert_assignment: bool = False,
615
- training: bool = False,
616
- w1: torch.Tensor = None,
617
- w2: torch.Tensor = None,
618
- w1_bias: torch.Tensor = None,
619
- w2_bias: torch.Tensor = None,
620
- gradient_scale: Optional[float] = None,
621
- alpha: float = 1.702,
622
- sort_end_bit: int = 0,
623
- expert_parallel_group: torch.distributed.ProcessGroup = None,
624
- moe_capacity_factor: float = 1.0,
625
- moe_expert_model_parallelism: bool = False,
626
- forward_fn: Any = None,
627
- hidden_size: int = None, # Required for parallel forward
628
- ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
629
-
630
- # Route tokens to experts
631
- logits, expert_weights, expert_indices = route_tokens(
632
- x,
633
- router_weight,
634
- moe_top_k,
635
- moe_num_experts,
636
- moe_jitter_eps,
637
- moe_normalize_expert_weights,
638
- uniform_expert_assignment,
639
- training,
640
- )
641
 
642
- # Create router scores for output
643
- router_scores = (
644
- torch.zeros_like(logits)
645
- .scatter_(1, expert_indices, expert_weights)
646
- .transpose(0, 1)
647
- )
 
 
 
 
 
648
 
649
- in_shape = x.size()
650
-
651
- # Prepare forward function arguments
652
- forward_args = {
653
- "x": x,
654
- "expert_weights": expert_weights,
655
- "top_experts": expert_indices,
656
- "w1": w1,
657
- "w2": w2,
658
- "w1_bias": w1_bias,
659
- "w2_bias": w2_bias,
660
- "gradient_scale": gradient_scale,
661
- "alpha": alpha,
662
- "sort_end_bit": sort_end_bit,
663
- "top_k": moe_top_k,
664
- "num_experts": moe_num_experts,
665
- "expert_parallel_group": expert_parallel_group,
666
- "moe_capacity_factor": moe_capacity_factor,
667
- "moe_expert_model_parallelism": moe_expert_model_parallelism,
668
- }
669
-
670
- # Add hidden_size for parallel forward
671
- if moe_expert_model_parallelism and hidden_size is not None:
672
- forward_args["hidden_size"] = hidden_size
673
- elif moe_expert_model_parallelism and hidden_size is None:
674
- # Infer hidden_size from input shape
675
- forward_args["hidden_size"] = x.shape[-1]
676
-
677
- # Compute expert outputs
678
- x, tokens_per_expert = forward_fn(**forward_args)
679
-
680
- # Save load balancing loss if needed
681
- moe_loss_weight = 0.0 # Can be made configurable
682
- if training and moe_loss_weight > 0:
683
- save_load_balancing_loss((tokens_per_expert, logits))
684
-
685
- # Restore original shape
686
- x = x.view(in_shape)
687
-
688
- return x, expert_weights, router_scores
 
689
 
690
 
691
  class MegaBlocksMoeMLP(torch.nn.Module):
692
 
693
- def forward(
694
- self,
695
- x: torch.Tensor,
696
- ) -> torch.Tensor:
697
- router_weight = self.router.weight
698
- moe_top_k = 4
699
- moe_num_experts = 128
700
- w1 = self.experts.gate_up_proj.data
701
- w2 = self.experts.down_proj.data
702
- w1_bias = self.experts.gate_up_proj_bias.data
703
- w2_bias = self.experts.down_proj_bias.data
704
-
705
- # check if the expert_parallel_group attribute is set
706
- if hasattr(self, "expert_parallel_group"):
707
- expert_parallel_group = self.expert_parallel_group
708
- moe_expert_model_parallelism = True
709
- forward_fn = parallel_forward_once
710
- else:
711
- expert_parallel_group = None
712
- moe_expert_model_parallelism = False
713
- forward_fn = forward_once
714
 
 
 
 
715
  sort_end_bit = max(
716
  int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1
717
  )
718
- hidden_size = self.experts.hidden_size
719
- output, expert_weights_out, router_scores = MyReplacementLayer.forward(
 
720
  x=x,
721
- router_weight=router_weight,
722
  moe_top_k=moe_top_k,
723
  moe_num_experts=moe_num_experts,
724
- moe_jitter_eps=None,
725
- moe_normalize_expert_weights=None,
726
- uniform_expert_assignment=False,
727
- training=False,
728
- w1=w1,
729
- w2=w2,
730
- w1_bias=w1_bias,
731
- w2_bias=w2_bias,
732
- gradient_scale=None,
733
- alpha=1.702,
734
  sort_end_bit=sort_end_bit,
735
  expert_parallel_group=expert_parallel_group,
736
- moe_capacity_factor=1.0,
737
- moe_expert_model_parallelism=moe_expert_model_parallelism,
738
  forward_fn=forward_fn,
739
- hidden_size=hidden_size,
 
740
  )
741
  return output, expert_weights_out
 
333
  gradient_scale,
334
  alpha,
335
  ):
 
336
  # Route tokens to experts
337
  x = x.view(-1, x.shape[-1])
338
 
 
366
  expert_parallel_group: int = None,
367
  moe_capacity_factor: float = 1.0,
368
  moe_expert_model_parallelism: bool = False,
369
+ mlp_impl: Optional[str] = None,
370
  ):
371
  # x: [sl, bs, hs]
372
  # expert_weights: [sl * bs, top-k]
 
430
  moe_capacity_factor: float = 1.0,
431
  moe_expert_model_parallelism: bool = True,
432
  hidden_size: int = 1152,
433
+ mlp_impl: Optional[str] = "grouped",
434
  ):
435
  # Flatten inputs
436
  expert_weights = expert_weights.flatten()
437
  top_experts = top_experts.flatten()
438
 
439
+ # TODO: remove debugging var
440
+ # my_rank = dist.get_rank(expert_parallel_group) if expert_parallel_group else 0
441
+
442
  with torch.no_grad():
443
  # Step 1: Local permutation setup
444
  indices, bin_ids, bins, tokens_per_expert = indices_and_bins(
 
459
 
460
  # Exchange token counts across devices
461
  parallel_tokens_per_expert = torch.empty_like(repeated_tokens_per_expert)
462
+
 
 
463
  # Ensure CUB knows which device to use
464
  tpe_handle = dist.all_to_all_single(
465
  parallel_tokens_per_expert,
 
495
  x = ops.repeat(x, (hidden_sharding_deg, 1))
496
 
497
  # Cross-device token exchange
498
+ parallel_x, parallel_x_handle = _layers.all_to_all.all_to_all(
499
+ x, recv_counts, send_counts, expert_parallel_group, async_op=True
 
 
 
 
500
  )
501
 
502
  with torch.no_grad():
503
  # Step 4: Setup for local expert computation
504
+ replicate_bins = ops.inclusive_cumsum(parallel_tokens_per_expert.flatten(), 0)
 
 
 
505
  replicate_bins = (
506
  replicate_bins.view(1) if not len(replicate_bins.size()) else replicate_bins
507
  )
 
523
 
524
  # Sort tokens by expert assignment
525
  parallel_bin_ids, parallel_indices = ops.sort(
526
+ parallel_top_expert,
527
  sort_end_bit,
528
  )
529
 
 
531
  parallel_tokens_per_expert = parallel_tokens_per_expert.sum(
532
  dim=0, dtype=torch.int
533
  )
534
+ parallel_bins = ops.inclusive_cumsum(parallel_tokens_per_expert, 0)
 
 
 
535
  parallel_bins = (
536
  parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins
537
  )
 
550
 
551
  # Locally permute the tokens and perform the expert computation.
552
  # Block to make sure that the cross-device permutation is complete.
553
+ if mlp_impl == "grouped":
 
 
 
554
  # GroupedMLP requires counts on CPU. We can use the tensor already
555
  # moved to CPU for the prior all_to_all, which avoids an extra
556
  # device synchronization.
 
580
  )
581
 
582
  # Step 6: Reverse communication - send results back
583
+ x, _ = _layers.all_to_all.all_to_all(
584
+ parallel_x, send_counts, recv_counts, expert_parallel_group
585
+ )
586
 
587
  # Step 7: Reduce across hidden sharding dimension
588
  shape = (hidden_sharding_deg, -1, hidden_size)
 
594
  return x, tokens_per_expert.flatten()
595
 
596
 
597
+ def moe_forward(
598
+ x: torch.Tensor,
599
+ router_weight: torch.Tensor,
600
+ moe_top_k: int,
601
+ moe_num_experts: int,
602
+ moe_jitter_eps: float = None,
603
+ moe_normalize_expert_weights: int = None,
604
+ uniform_expert_assignment: bool = False,
605
+ training: bool = False,
606
+ w1: torch.Tensor = None,
607
+ w2: torch.Tensor = None,
608
+ w1_bias: torch.Tensor = None,
609
+ w2_bias: torch.Tensor = None,
610
+ gradient_scale: Optional[float] = None,
611
+ alpha: float = 1.702,
612
+ sort_end_bit: int = 0,
613
+ expert_parallel_group: torch.distributed.ProcessGroup = None,
614
+ moe_capacity_factor: float = 1.0,
615
+ moe_expert_model_parallelism: bool = False,
616
+ forward_fn: Any = None,
617
+ hidden_size: int = None,
618
+ mlp_impl: str = "grouped",
619
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
620
 
621
+ # Route tokens to experts
622
+ logits, expert_weights, expert_indices = route_tokens(
623
+ x,
624
+ router_weight,
625
+ moe_top_k,
626
+ moe_num_experts,
627
+ moe_jitter_eps,
628
+ moe_normalize_expert_weights,
629
+ uniform_expert_assignment,
630
+ training,
631
+ )
632
 
633
+ # Create router scores for output
634
+ router_scores = (
635
+ torch.zeros_like(logits)
636
+ .scatter_(1, expert_indices, expert_weights)
637
+ .transpose(0, 1)
638
+ )
639
+
640
+ in_shape = x.size()
641
+
642
+ # Prepare forward function arguments
643
+ forward_args = {
644
+ "x": x,
645
+ "expert_weights": expert_weights,
646
+ "top_experts": expert_indices,
647
+ "w1": w1,
648
+ "w2": w2,
649
+ "w1_bias": w1_bias,
650
+ "w2_bias": w2_bias,
651
+ "gradient_scale": gradient_scale,
652
+ "alpha": alpha,
653
+ "sort_end_bit": sort_end_bit,
654
+ "top_k": moe_top_k,
655
+ "num_experts": moe_num_experts,
656
+ "expert_parallel_group": expert_parallel_group,
657
+ "moe_capacity_factor": moe_capacity_factor,
658
+ "moe_expert_model_parallelism": moe_expert_model_parallelism,
659
+ "mlp_impl": mlp_impl,
660
+ }
661
+
662
+ # Add hidden_size for parallel forward
663
+ if moe_expert_model_parallelism and hidden_size is not None:
664
+ forward_args["hidden_size"] = hidden_size
665
+ elif moe_expert_model_parallelism and hidden_size is None:
666
+ # Infer hidden_size from input shape
667
+ forward_args["hidden_size"] = x.shape[-1]
668
+
669
+ # Compute expert outputs
670
+ x, tokens_per_expert = forward_fn(**forward_args)
671
+
672
+ # Save load balancing loss if needed
673
+ moe_loss_weight = 0.0 # Can be made configurable
674
+ if training and moe_loss_weight > 0:
675
+ save_load_balancing_loss((tokens_per_expert, logits))
676
+
677
+ # Restore original shape
678
+ x = x.view(in_shape)
679
+
680
+ return x, expert_weights, router_scores
681
 
682
 
683
  class MegaBlocksMoeMLP(torch.nn.Module):
684
 
685
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
686
+ moe_top_k = getattr(self, "moe_top_k", 4)
687
+ moe_num_experts = getattr(self, "moe_num_experts", 128)
688
+ gradient_scale = getattr(self, "gradient_scale", None)
689
+ alpha = getattr(self, "alpha", 1.702)
690
+ moe_capacity_factor = getattr(self, "moe_capacity_factor", 1.0)
691
+ moe_jitter_eps = getattr(self, "moe_jitter_eps", None)
692
+ moe_normalize_expert_weights = getattr(
693
+ self, "moe_normalize_expert_weights", None
694
+ )
695
+ uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)
696
 
697
+ has_parallel = hasattr(self, "expert_parallel_group")
698
+ expert_parallel_group = getattr(self, "expert_parallel_group", None)
699
+ forward_fn = parallel_forward_once if has_parallel else forward_once
700
  sort_end_bit = max(
701
  int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1
702
  )
703
+ mlp_impl = getattr(self, "mlp_impl", "grouped") # or sparse
704
+
705
+ output, expert_weights_out, _ = moe_forward(
706
  x=x,
707
+ router_weight=self.router.weight,
708
  moe_top_k=moe_top_k,
709
  moe_num_experts=moe_num_experts,
710
+ moe_jitter_eps=moe_jitter_eps,
711
+ moe_normalize_expert_weights=moe_normalize_expert_weights,
712
+ uniform_expert_assignment=uniform_expert_assignment,
713
+ training=self.training,
714
+ w1=self.experts.gate_up_proj,
715
+ w2=self.experts.down_proj,
716
+ w1_bias=self.experts.gate_up_proj_bias,
717
+ w2_bias=self.experts.down_proj_bias,
718
+ gradient_scale=gradient_scale,
719
+ alpha=alpha,
720
  sort_end_bit=sort_end_bit,
721
  expert_parallel_group=expert_parallel_group,
722
+ moe_capacity_factor=moe_capacity_factor,
723
+ moe_expert_model_parallelism=has_parallel,
724
  forward_fn=forward_fn,
725
+ hidden_size=self.experts.hidden_size,
726
+ mlp_impl=mlp_impl,
727
  )
728
  return output, expert_weights_out
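Step 7 in `parallel_forward_once` above folds the hidden-sharded replicas back into one set of tokens by reshaping and summing over the sharding dimension. A toy-shape sketch of that reduce (sizes are illustrative, not the real configuration):

import torch

hidden_sharding_deg, num_tokens, hidden_size = 2, 4, 8  # hypothetical sizes
x = torch.randn(hidden_sharding_deg * num_tokens, hidden_size)

# Same reshape-and-reduce as Step 7: collapse the replicas by summing them.
x = x.view(hidden_sharding_deg, -1, hidden_size).sum(dim=0)
print(x.shape)  # torch.Size([4, 8])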
build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/all_to_all_benchmark.py CHANGED
@@ -7,126 +7,28 @@ import torch.distributed as dist
7
  # from megablocks import benchmark_util
8
  # from megablocks.layers.all_to_all import all_to_all
9
 
10
- # from .. import benchmark_util
11
-
12
- # Copyright 2024 Databricks
13
- # SPDX-License-Identifier: Apache-2.0
14
-
15
- import numpy as np
16
- import torch
17
-
18
-
19
- def log_benchmark(name, arguments, time, std):
20
- print("=" * 60)
21
- print(f"{name} Benchmark")
22
- print("Benchmark Parameters:")
23
- for key, value in arguments.items():
24
- print(f"{key} = {value}")
25
- print("Results:")
26
- print("mean time = {:.3f}ms, std time = {:.3f}ms".format(time, std))
27
- print("=" * 60)
28
-
29
-
30
- def benchmark_function(fn, iterations=100, warmup=10):
31
- print(f"Benchmarking {fn.__name__} with {iterations} iterations and {warmup} warmup iterations")
32
- # Warmup iterations.
33
- for _ in range(warmup):
34
- fn()
35
-
36
- times = []
37
- print(f"Running {iterations} iterations...")
38
- for i in range(iterations):
39
- start = torch.cuda.Event(enable_timing=True)
40
- end = torch.cuda.Event(enable_timing=True)
41
-
42
- start.record()
43
- fn()
44
- end.record()
45
-
46
- torch.cuda.synchronize()
47
- times.append(start.elapsed_time(end))
48
- return np.mean(times), np.std(times)
49
-
50
-
51
- # from .._layers.all_to_all import all_to_all
52
-
53
- # Copyright 2024 Databricks
54
- # SPDX-License-Identifier: Apache-2.0
55
-
56
- import torch
57
- import torch.distributed as dist
58
-
59
-
60
- class AllToAllOp(torch.autograd.Function):
61
-
62
- @staticmethod
63
- def forward(ctx, x, output_split_sizes, input_split_sizes, group, async_op):
64
- out = torch.empty(
65
- (sum(output_split_sizes),) + x.shape[1:], device=x.device, dtype=x.dtype
66
- )
67
-
68
- ctx.input_shape = x.shape
69
- ctx.output_split_sizes = output_split_sizes
70
- ctx.input_split_sizes = input_split_sizes
71
- ctx.group = group
72
- handle = dist.all_to_all_single(
73
- out,
74
- x,
75
- output_split_sizes=output_split_sizes,
76
- input_split_sizes=input_split_sizes,
77
- group=group,
78
- async_op=async_op,
79
- )
80
- return out, handle
81
-
82
- @staticmethod
83
- def backward(ctx, grad, _):
84
- if ctx.needs_input_grad[0]:
85
- out = torch.empty(
86
- ctx.input_shape,
87
- device=grad.device,
88
- dtype=grad.dtype,
89
- )
90
- dist.all_to_all_single(
91
- out,
92
- grad,
93
- output_split_sizes=ctx.input_split_sizes,
94
- input_split_sizes=ctx.output_split_sizes,
95
- group=ctx.group,
96
- )
97
- return out, None, None, None, None
98
- return None, None, None, None, None
99
-
100
-
101
- def all_to_all(x, output_split_sizes, input_split_sizes, group, async_op=False):
102
- return AllToAllOp.apply(
103
- x,
104
- output_split_sizes,
105
- input_split_sizes,
106
- group,
107
- async_op,
108
- )
109
-
110
 
111
  _ALL_TO_ALL_BENCHMARK = (
112
  (8, 1024),
113
- # (16, 1024),
114
- # (32, 1024),
115
- # (64, 1024),
116
- # (128, 1024),
117
- # (256, 1024),
118
- # (512, 1024),
119
- # (1024, 1024),
120
- # (2 * 1024, 1024),
121
- # (4 * 1024, 1024),
122
- # (8 * 1024, 1024),
123
- # (16 * 1024, 1024),
124
- # (32 * 1024, 1024),
125
- # (64 * 1024, 1024),
126
- # (128 * 1024, 1024),
127
- # (256 * 1024, 1024),
128
- # (512 * 1024, 1024),
129
- # (1024 * 1024, 1024),
130
  )
131
 
132
 
@@ -145,12 +47,10 @@ def benchmark_all_to_all(group, sl, hs):
145
  def benchmark():
146
  return all_to_all(x, send_recv_sizes, send_recv_sizes, group)
147
 
148
- # time, std = benchmark_util.benchmark_function(benchmark)
149
- time, std = benchmark_function(benchmark)
150
 
151
  if dist.get_rank(group) == 0:
152
- log_benchmark('All-To-All', details, time, std)
153
- # benchmark_util.log_benchmark('All-To-All', details, time, std)
154
 
155
 
156
  if __name__ == '__main__':
 
7
  # from megablocks import benchmark_util
8
  # from megablocks.layers.all_to_all import all_to_all
9
 
10
+ from .. import benchmark_util
11
+ from .._layers.all_to_all import all_to_all
12
 
13
  _ALL_TO_ALL_BENCHMARK = (
14
  (8, 1024),
15
+ (16, 1024),
16
+ (32, 1024),
17
+ (64, 1024),
18
+ (128, 1024),
19
+ (256, 1024),
20
+ (512, 1024),
21
+ (1024, 1024),
22
+ (2 * 1024, 1024),
23
+ (4 * 1024, 1024),
24
+ (8 * 1024, 1024),
25
+ (16 * 1024, 1024),
26
+ (32 * 1024, 1024),
27
+ (64 * 1024, 1024),
28
+ (128 * 1024, 1024),
29
+ (256 * 1024, 1024),
30
+ (512 * 1024, 1024),
31
+ (1024 * 1024, 1024),
32
  )
33
 
34
 
 
47
  def benchmark():
48
  return all_to_all(x, send_recv_sizes, send_recv_sizes, group)
49
 
50
+ time, std = benchmark_util.benchmark_function(benchmark)
 
51
 
52
  if dist.get_rank(group) == 0:
53
+ benchmark_util.log_benchmark('All-To-All', details, time, std)
 
54
 
55
 
56
  if __name__ == '__main__':
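The benchmark exchanges an evenly split buffer between all ranks. A minimal sketch of that call pattern built directly on `torch.distributed.all_to_all_single`, assuming an initialized process group (for example launched with `torchrun`) and a first dimension divisible by the world size:

import torch
import torch.distributed as dist

def even_all_to_all(x: torch.Tensor, group=None) -> torch.Tensor:
    world_size = dist.get_world_size(group)
    split = x.shape[0] // world_size  # even split, as in the benchmark
    out = torch.empty_like(x)
    dist.all_to_all_single(
        out,
        x,
        output_split_sizes=[split] * world_size,
        input_split_sizes=[split] * world_size,
        group=group,
    )
    return out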
build/torch26-cxx98-cu124-x86_64-linux/megablocks/_megablocks_13afbbe_dirty.abi3.so DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:d915db521f8d37fb887ed8334db60165e5923f8dce817d69f6441c5ba2d210d6
3
- size 11857952
 
 
 
 
build/torch26-cxx98-cu124-x86_64-linux/megablocks/_megablocks_e47036a.abi3.so ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4c5245e311ba8b23b577be785d6dbd902a75b467ca70274c256042dd21ed235c
3
+ size 11857920
build/torch26-cxx98-cu124-x86_64-linux/megablocks/_ops.py CHANGED
@@ -1,9 +1,9 @@
1
  import torch
2
- from . import _megablocks_13afbbe_dirty
3
- ops = torch.ops._megablocks_13afbbe_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
- return f"_megablocks_13afbbe_dirty::{op_name}"
 
1
  import torch
2
+ from . import _megablocks_e47036a
3
+ ops = torch.ops._megablocks_e47036a
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
+ return f"_megablocks_e47036a::{op_name}"
build/torch26-cxx98-cu124-x86_64-linux/megablocks/layers.py CHANGED
@@ -333,7 +333,6 @@ def permute_and_compute(
333
  gradient_scale,
334
  alpha,
335
  ):
336
- """Permute tokens and compute expert outputs."""
337
  # Route tokens to experts
338
  x = x.view(-1, x.shape[-1])
339
 
@@ -367,6 +366,7 @@ def forward_once(
367
  expert_parallel_group: int = None,
368
  moe_capacity_factor: float = 1.0,
369
  moe_expert_model_parallelism: bool = False,
 
370
  ):
371
  # x: [sl, bs, hs]
372
  # expert_weights: [sl * bs, top-k]
@@ -430,11 +430,15 @@ def parallel_forward_once(
430
  moe_capacity_factor: float = 1.0,
431
  moe_expert_model_parallelism: bool = True,
432
  hidden_size: int = 1152,
 
433
  ):
434
  # Flatten inputs
435
  expert_weights = expert_weights.flatten()
436
  top_experts = top_experts.flatten()
437
 
 
 
 
438
  with torch.no_grad():
439
  # Step 1: Local permutation setup
440
  indices, bin_ids, bins, tokens_per_expert = indices_and_bins(
@@ -455,9 +459,7 @@ def parallel_forward_once(
455
 
456
  # Exchange token counts across devices
457
  parallel_tokens_per_expert = torch.empty_like(repeated_tokens_per_expert)
458
- # print("world_size:", world_size)
459
- # print("experts_per_rank_val:", experts_per_rank_val)
460
-
461
  # Ensure CUB knows which device to use
462
  tpe_handle = dist.all_to_all_single(
463
  parallel_tokens_per_expert,
@@ -493,20 +495,13 @@ def parallel_forward_once(
493
  x = ops.repeat(x, (hidden_sharding_deg, 1))
494
 
495
  # Cross-device token exchange
496
- parallel_x, parallel_x_handle = ops.all_to_all(
497
- x,
498
- recv_counts,
499
- send_counts,
500
- expert_parallel_group,
501
- async_op=True
502
  )
503
 
504
  with torch.no_grad():
505
  # Step 4: Setup for local expert computation
506
- replicate_bins = ops.inclusive_cumsum(
507
- parallel_tokens_per_expert.flatten(),
508
- 0
509
- )
510
  replicate_bins = (
511
  replicate_bins.view(1) if not len(replicate_bins.size()) else replicate_bins
512
  )
@@ -528,7 +523,7 @@ def parallel_forward_once(
528
 
529
  # Sort tokens by expert assignment
530
  parallel_bin_ids, parallel_indices = ops.sort(
531
- parallel_top_expert,
532
  sort_end_bit,
533
  )
534
 
@@ -536,10 +531,7 @@ def parallel_forward_once(
536
  parallel_tokens_per_expert = parallel_tokens_per_expert.sum(
537
  dim=0, dtype=torch.int
538
  )
539
- parallel_bins = ops.inclusive_cumsum(
540
- parallel_tokens_per_expert,
541
- 0
542
- )
543
  parallel_bins = (
544
  parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins
545
  )
@@ -558,10 +550,7 @@ def parallel_forward_once(
558
 
559
  # Locally permute the tokens and perform the expert computation.
560
  # Block to make sure that the cross-device permutation is complete.
561
- # if self.args.mlp_impl == 'grouped':
562
-
563
- # TODO: dont always assume grouped MLP
564
- if True:
565
  # GroupedMLP requires counts on CPU. We can use the tensor already
566
  # moved to CPU for the prior all_to_all, which avoids an extra
567
  # device synchronization.
@@ -591,7 +580,9 @@ def parallel_forward_once(
591
  )
592
 
593
  # Step 6: Reverse communication - send results back
594
- x, _ = ops.all_to_all(parallel_x, send_counts, recv_counts, expert_parallel_group)
 
 
595
 
596
  # Step 7: Reduce across hidden sharding dimension
597
  shape = (hidden_sharding_deg, -1, hidden_size)
@@ -603,139 +594,135 @@ def parallel_forward_once(
603
  return x, tokens_per_expert.flatten()
604
 
605
 
606
- class MyReplacementLayer(torch.nn.Module):
607
- def forward(
608
- x: torch.Tensor,
609
- router_weight: torch.Tensor,
610
- moe_top_k: int,
611
- moe_num_experts: int,
612
- moe_jitter_eps: float = None,
613
- moe_normalize_expert_weights: int = None,
614
- uniform_expert_assignment: bool = False,
615
- training: bool = False,
616
- w1: torch.Tensor = None,
617
- w2: torch.Tensor = None,
618
- w1_bias: torch.Tensor = None,
619
- w2_bias: torch.Tensor = None,
620
- gradient_scale: Optional[float] = None,
621
- alpha: float = 1.702,
622
- sort_end_bit: int = 0,
623
- expert_parallel_group: torch.distributed.ProcessGroup = None,
624
- moe_capacity_factor: float = 1.0,
625
- moe_expert_model_parallelism: bool = False,
626
- forward_fn: Any = None,
627
- hidden_size: int = None, # Required for parallel forward
628
- ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
629
-
630
- # Route tokens to experts
631
- logits, expert_weights, expert_indices = route_tokens(
632
- x,
633
- router_weight,
634
- moe_top_k,
635
- moe_num_experts,
636
- moe_jitter_eps,
637
- moe_normalize_expert_weights,
638
- uniform_expert_assignment,
639
- training,
640
- )
641
 
642
- # Create router scores for output
643
- router_scores = (
644
- torch.zeros_like(logits)
645
- .scatter_(1, expert_indices, expert_weights)
646
- .transpose(0, 1)
647
- )
 
 
 
 
 
648
 
649
- in_shape = x.size()
650
-
651
- # Prepare forward function arguments
652
- forward_args = {
653
- "x": x,
654
- "expert_weights": expert_weights,
655
- "top_experts": expert_indices,
656
- "w1": w1,
657
- "w2": w2,
658
- "w1_bias": w1_bias,
659
- "w2_bias": w2_bias,
660
- "gradient_scale": gradient_scale,
661
- "alpha": alpha,
662
- "sort_end_bit": sort_end_bit,
663
- "top_k": moe_top_k,
664
- "num_experts": moe_num_experts,
665
- "expert_parallel_group": expert_parallel_group,
666
- "moe_capacity_factor": moe_capacity_factor,
667
- "moe_expert_model_parallelism": moe_expert_model_parallelism,
668
- }
669
-
670
- # Add hidden_size for parallel forward
671
- if moe_expert_model_parallelism and hidden_size is not None:
672
- forward_args["hidden_size"] = hidden_size
673
- elif moe_expert_model_parallelism and hidden_size is None:
674
- # Infer hidden_size from input shape
675
- forward_args["hidden_size"] = x.shape[-1]
676
-
677
- # Compute expert outputs
678
- x, tokens_per_expert = forward_fn(**forward_args)
679
-
680
- # Save load balancing loss if needed
681
- moe_loss_weight = 0.0 # Can be made configurable
682
- if training and moe_loss_weight > 0:
683
- save_load_balancing_loss((tokens_per_expert, logits))
684
-
685
- # Restore original shape
686
- x = x.view(in_shape)
687
-
688
- return x, expert_weights, router_scores
 
689
 
690
 
691
  class MegaBlocksMoeMLP(torch.nn.Module):
692
 
693
- def forward(
694
- self,
695
- x: torch.Tensor,
696
- ) -> torch.Tensor:
697
- router_weight = self.router.weight
698
- moe_top_k = 4
699
- moe_num_experts = 128
700
- w1 = self.experts.gate_up_proj.data
701
- w2 = self.experts.down_proj.data
702
- w1_bias = self.experts.gate_up_proj_bias.data
703
- w2_bias = self.experts.down_proj_bias.data
704
-
705
- # check if the expert_parallel_group attribute is set
706
- if hasattr(self, "expert_parallel_group"):
707
- expert_parallel_group = self.expert_parallel_group
708
- moe_expert_model_parallelism = True
709
- forward_fn = parallel_forward_once
710
- else:
711
- expert_parallel_group = None
712
- moe_expert_model_parallelism = False
713
- forward_fn = forward_once
714
 
 
 
 
715
  sort_end_bit = max(
716
  int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1
717
  )
718
- hidden_size = self.experts.hidden_size
719
- output, expert_weights_out, router_scores = MyReplacementLayer.forward(
 
720
  x=x,
721
- router_weight=router_weight,
722
  moe_top_k=moe_top_k,
723
  moe_num_experts=moe_num_experts,
724
- moe_jitter_eps=None,
725
- moe_normalize_expert_weights=None,
726
- uniform_expert_assignment=False,
727
- training=False,
728
- w1=w1,
729
- w2=w2,
730
- w1_bias=w1_bias,
731
- w2_bias=w2_bias,
732
- gradient_scale=None,
733
- alpha=1.702,
734
  sort_end_bit=sort_end_bit,
735
  expert_parallel_group=expert_parallel_group,
736
- moe_capacity_factor=1.0,
737
- moe_expert_model_parallelism=moe_expert_model_parallelism,
738
  forward_fn=forward_fn,
739
- hidden_size=hidden_size,
 
740
  )
741
  return output, expert_weights_out
 
333
  gradient_scale,
334
  alpha,
335
  ):
 
336
  # Route tokens to experts
337
  x = x.view(-1, x.shape[-1])
338
 
 
366
  expert_parallel_group: int = None,
367
  moe_capacity_factor: float = 1.0,
368
  moe_expert_model_parallelism: bool = False,
369
+ mlp_impl: Optional[str] = None,
370
  ):
371
  # x: [sl, bs, hs]
372
  # expert_weights: [sl * bs, top-k]
 
430
  moe_capacity_factor: float = 1.0,
431
  moe_expert_model_parallelism: bool = True,
432
  hidden_size: int = 1152,
433
+ mlp_impl: Optional[str] = "grouped",
434
  ):
435
  # Flatten inputs
436
  expert_weights = expert_weights.flatten()
437
  top_experts = top_experts.flatten()
438
 
439
+ # TODO: remove debugging var
440
+ # my_rank = dist.get_rank(expert_parallel_group) if expert_parallel_group else 0
441
+
442
  with torch.no_grad():
443
  # Step 1: Local permutation setup
444
  indices, bin_ids, bins, tokens_per_expert = indices_and_bins(
 
459
 
460
  # Exchange token counts across devices
461
  parallel_tokens_per_expert = torch.empty_like(repeated_tokens_per_expert)
462
+
 
 
463
  # Ensure CUB knows which device to use
464
  tpe_handle = dist.all_to_all_single(
465
  parallel_tokens_per_expert,
 
495
  x = ops.repeat(x, (hidden_sharding_deg, 1))
496
 
497
  # Cross-device token exchange
498
+ parallel_x, parallel_x_handle = _layers.all_to_all.all_to_all(
499
+ x, recv_counts, send_counts, expert_parallel_group, async_op=True
 
 
 
 
500
  )
501
 
502
  with torch.no_grad():
503
  # Step 4: Setup for local expert computation
504
+ replicate_bins = ops.inclusive_cumsum(parallel_tokens_per_expert.flatten(), 0)
 
 
 
505
  replicate_bins = (
506
  replicate_bins.view(1) if not len(replicate_bins.size()) else replicate_bins
507
  )
 
523
 
524
  # Sort tokens by expert assignment
525
  parallel_bin_ids, parallel_indices = ops.sort(
526
+ parallel_top_expert,
527
  sort_end_bit,
528
  )
529
 
 
531
  parallel_tokens_per_expert = parallel_tokens_per_expert.sum(
532
  dim=0, dtype=torch.int
533
  )
534
+ parallel_bins = ops.inclusive_cumsum(parallel_tokens_per_expert, 0)
 
 
 
535
  parallel_bins = (
536
  parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins
537
  )
 
550
 
551
  # Locally permute the tokens and perform the expert computation.
552
  # Block to make sure that the cross-device permutation is complete.
553
+ if mlp_impl == "grouped":
 
 
 
554
  # GroupedMLP requires counts on CPU. We can use the tensor already
555
  # moved to CPU for the prior all_to_all, which avoids an extra
556
  # device synchronization.
 
580
  )
581
 
582
  # Step 6: Reverse communication - send results back
583
+ x, _ = _layers.all_to_all.all_to_all(
584
+ parallel_x, send_counts, recv_counts, expert_parallel_group
585
+ )
586
 
587
  # Step 7: Reduce across hidden sharding dimension
588
  shape = (hidden_sharding_deg, -1, hidden_size)
 
594
  return x, tokens_per_expert.flatten()
595
 
596
 
597
+ def moe_forward(
598
+ x: torch.Tensor,
599
+ router_weight: torch.Tensor,
600
+ moe_top_k: int,
601
+ moe_num_experts: int,
602
+ moe_jitter_eps: float = None,
603
+ moe_normalize_expert_weights: int = None,
604
+ uniform_expert_assignment: bool = False,
605
+ training: bool = False,
606
+ w1: torch.Tensor = None,
607
+ w2: torch.Tensor = None,
608
+ w1_bias: torch.Tensor = None,
609
+ w2_bias: torch.Tensor = None,
610
+ gradient_scale: Optional[float] = None,
611
+ alpha: float = 1.702,
612
+ sort_end_bit: int = 0,
613
+ expert_parallel_group: torch.distributed.ProcessGroup = None,
614
+ moe_capacity_factor: float = 1.0,
615
+ moe_expert_model_parallelism: bool = False,
616
+ forward_fn: Any = None,
617
+ hidden_size: int = None,
618
+ mlp_impl: str = "grouped",
619
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
620
 
621
+ # Route tokens to experts
622
+ logits, expert_weights, expert_indices = route_tokens(
623
+ x,
624
+ router_weight,
625
+ moe_top_k,
626
+ moe_num_experts,
627
+ moe_jitter_eps,
628
+ moe_normalize_expert_weights,
629
+ uniform_expert_assignment,
630
+ training,
631
+ )
632
 
633
+ # Create router scores for output
634
+ router_scores = (
635
+ torch.zeros_like(logits)
636
+ .scatter_(1, expert_indices, expert_weights)
637
+ .transpose(0, 1)
638
+ )
639
+
640
+ in_shape = x.size()
641
+
642
+ # Prepare forward function arguments
643
+ forward_args = {
644
+ "x": x,
645
+ "expert_weights": expert_weights,
646
+ "top_experts": expert_indices,
647
+ "w1": w1,
648
+ "w2": w2,
649
+ "w1_bias": w1_bias,
650
+ "w2_bias": w2_bias,
651
+ "gradient_scale": gradient_scale,
652
+ "alpha": alpha,
653
+ "sort_end_bit": sort_end_bit,
654
+ "top_k": moe_top_k,
655
+ "num_experts": moe_num_experts,
656
+ "expert_parallel_group": expert_parallel_group,
657
+ "moe_capacity_factor": moe_capacity_factor,
658
+ "moe_expert_model_parallelism": moe_expert_model_parallelism,
659
+ "mlp_impl": mlp_impl,
660
+ }
661
+
662
+ # Add hidden_size for parallel forward
663
+ if moe_expert_model_parallelism and hidden_size is not None:
664
+ forward_args["hidden_size"] = hidden_size
665
+ elif moe_expert_model_parallelism and hidden_size is None:
666
+ # Infer hidden_size from input shape
667
+ forward_args["hidden_size"] = x.shape[-1]
668
+
669
+ # Compute expert outputs
670
+ x, tokens_per_expert = forward_fn(**forward_args)
671
+
672
+ # Save load balancing loss if needed
673
+ moe_loss_weight = 0.0 # Can be made configurable
674
+ if training and moe_loss_weight > 0:
675
+ save_load_balancing_loss((tokens_per_expert, logits))
676
+
677
+ # Restore original shape
678
+ x = x.view(in_shape)
679
+
680
+ return x, expert_weights, router_scores
681
 
682
 
683
  class MegaBlocksMoeMLP(torch.nn.Module):
684
 
685
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
686
+ moe_top_k = getattr(self, "moe_top_k", 4)
687
+ moe_num_experts = getattr(self, "moe_num_experts", 128)
688
+ gradient_scale = getattr(self, "gradient_scale", None)
689
+ alpha = getattr(self, "alpha", 1.702)
690
+ moe_capacity_factor = getattr(self, "moe_capacity_factor", 1.0)
691
+ moe_jitter_eps = getattr(self, "moe_jitter_eps", None)
692
+ moe_normalize_expert_weights = getattr(
693
+ self, "moe_normalize_expert_weights", None
694
+ )
695
+ uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)
696
 
697
+ has_parallel = hasattr(self, "expert_parallel_group")
698
+ expert_parallel_group = getattr(self, "expert_parallel_group", None)
699
+ forward_fn = parallel_forward_once if has_parallel else forward_once
700
  sort_end_bit = max(
701
  int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1
702
  )
703
+ mlp_impl = getattr(self, "mlp_impl", "grouped") # or sparse
704
+
705
+ output, expert_weights_out, _ = moe_forward(
706
  x=x,
707
+ router_weight=self.router.weight,
708
  moe_top_k=moe_top_k,
709
  moe_num_experts=moe_num_experts,
710
+ moe_jitter_eps=moe_jitter_eps,
711
+ moe_normalize_expert_weights=moe_normalize_expert_weights,
712
+ uniform_expert_assignment=uniform_expert_assignment,
713
+ training=self.training,
714
+ w1=self.experts.gate_up_proj,
715
+ w2=self.experts.down_proj,
716
+ w1_bias=self.experts.gate_up_proj_bias,
717
+ w2_bias=self.experts.down_proj_bias,
718
+ gradient_scale=gradient_scale,
719
+ alpha=alpha,
720
  sort_end_bit=sort_end_bit,
721
  expert_parallel_group=expert_parallel_group,
722
+ moe_capacity_factor=moe_capacity_factor,
723
+ moe_expert_model_parallelism=has_parallel,
724
  forward_fn=forward_fn,
725
+ hidden_size=self.experts.hidden_size,
726
+ mlp_impl=mlp_impl,
727
  )
728
  return output, expert_weights_out
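The `ops.inclusive_cumsum` calls above convert per-expert token counts into bin end-offsets that bound each expert's slice after permutation. A small CPU sketch with `torch.cumsum` standing in for the custom op (counts are hypothetical):

import torch

tokens_per_expert = torch.tensor([3, 0, 5, 2])  # 4 experts, made-up counts
bins = torch.cumsum(tokens_per_expert, dim=0)
print(bins)  # tensor([ 3,  3,  8, 10]): end offset of each expert's token range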
build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/all_to_all_benchmark.py CHANGED
@@ -7,126 +7,28 @@ import torch.distributed as dist
7
  # from megablocks import benchmark_util
8
  # from megablocks.layers.all_to_all import all_to_all
9
 
10
- # from .. import benchmark_util
11
-
12
- # Copyright 2024 Databricks
13
- # SPDX-License-Identifier: Apache-2.0
14
-
15
- import numpy as np
16
- import torch
17
-
18
-
19
- def log_benchmark(name, arguments, time, std):
20
- print("=" * 60)
21
- print(f"{name} Benchmark")
22
- print("Benchmark Parameters:")
23
- for key, value in arguments.items():
24
- print(f"{key} = {value}")
25
- print("Results:")
26
- print("mean time = {:.3f}ms, std time = {:.3f}ms".format(time, std))
27
- print("=" * 60)
28
-
29
-
30
- def benchmark_function(fn, iterations=100, warmup=10):
31
- print(f"Benchmarking {fn.__name__} with {iterations} iterations and {warmup} warmup iterations")
32
- # Warmup iterations.
33
- for _ in range(warmup):
34
- fn()
35
-
36
- times = []
37
- print(f"Running {iterations} iterations...")
38
- for i in range(iterations):
39
- start = torch.cuda.Event(enable_timing=True)
40
- end = torch.cuda.Event(enable_timing=True)
41
-
42
- start.record()
43
- fn()
44
- end.record()
45
-
46
- torch.cuda.synchronize()
47
- times.append(start.elapsed_time(end))
48
- return np.mean(times), np.std(times)
49
-
50
-
51
- # from .._layers.all_to_all import all_to_all
52
-
53
- # Copyright 2024 Databricks
54
- # SPDX-License-Identifier: Apache-2.0
55
-
56
- import torch
57
- import torch.distributed as dist
58
-
59
-
60
- class AllToAllOp(torch.autograd.Function):
61
-
62
- @staticmethod
63
- def forward(ctx, x, output_split_sizes, input_split_sizes, group, async_op):
64
- out = torch.empty(
65
- (sum(output_split_sizes),) + x.shape[1:], device=x.device, dtype=x.dtype
66
- )
67
-
68
- ctx.input_shape = x.shape
69
- ctx.output_split_sizes = output_split_sizes
70
- ctx.input_split_sizes = input_split_sizes
71
- ctx.group = group
72
- handle = dist.all_to_all_single(
73
- out,
74
- x,
75
- output_split_sizes=output_split_sizes,
76
- input_split_sizes=input_split_sizes,
77
- group=group,
78
- async_op=async_op,
79
- )
80
- return out, handle
81
-
82
- @staticmethod
83
- def backward(ctx, grad, _):
84
- if ctx.needs_input_grad[0]:
85
- out = torch.empty(
86
- ctx.input_shape,
87
- device=grad.device,
88
- dtype=grad.dtype,
89
- )
90
- dist.all_to_all_single(
91
- out,
92
- grad,
93
- output_split_sizes=ctx.input_split_sizes,
94
- input_split_sizes=ctx.output_split_sizes,
95
- group=ctx.group,
96
- )
97
- return out, None, None, None, None
98
- return None, None, None, None, None
99
-
100
-
101
- def all_to_all(x, output_split_sizes, input_split_sizes, group, async_op=False):
102
- return AllToAllOp.apply(
103
- x,
104
- output_split_sizes,
105
- input_split_sizes,
106
- group,
107
- async_op,
108
- )
109
-
110
 
111
  _ALL_TO_ALL_BENCHMARK = (
112
  (8, 1024),
113
- # (16, 1024),
114
- # (32, 1024),
115
- # (64, 1024),
116
- # (128, 1024),
117
- # (256, 1024),
118
- # (512, 1024),
119
- # (1024, 1024),
120
- # (2 * 1024, 1024),
121
- # (4 * 1024, 1024),
122
- # (8 * 1024, 1024),
123
- # (16 * 1024, 1024),
124
- # (32 * 1024, 1024),
125
- # (64 * 1024, 1024),
126
- # (128 * 1024, 1024),
127
- # (256 * 1024, 1024),
128
- # (512 * 1024, 1024),
129
- # (1024 * 1024, 1024),
130
  )
131
 
132
 
@@ -145,12 +47,10 @@ def benchmark_all_to_all(group, sl, hs):
145
  def benchmark():
146
  return all_to_all(x, send_recv_sizes, send_recv_sizes, group)
147
 
148
- # time, std = benchmark_util.benchmark_function(benchmark)
149
- time, std = benchmark_function(benchmark)
150
 
151
  if dist.get_rank(group) == 0:
152
- log_benchmark('All-To-All', details, time, std)
153
- # benchmark_util.log_benchmark('All-To-All', details, time, std)
154
 
155
 
156
  if __name__ == '__main__':
 
7
  # from megablocks import benchmark_util
8
  # from megablocks.layers.all_to_all import all_to_all
9
 
10
+ from .. import benchmark_util
11
+ from .._layers.all_to_all import all_to_all
12
 
13
  _ALL_TO_ALL_BENCHMARK = (
14
  (8, 1024),
15
+ (16, 1024),
16
+ (32, 1024),
17
+ (64, 1024),
18
+ (128, 1024),
19
+ (256, 1024),
20
+ (512, 1024),
21
+ (1024, 1024),
22
+ (2 * 1024, 1024),
23
+ (4 * 1024, 1024),
24
+ (8 * 1024, 1024),
25
+ (16 * 1024, 1024),
26
+ (32 * 1024, 1024),
27
+ (64 * 1024, 1024),
28
+ (128 * 1024, 1024),
29
+ (256 * 1024, 1024),
30
+ (512 * 1024, 1024),
31
+ (1024 * 1024, 1024),
32
  )
33
 
34
 
 
47
  def benchmark():
48
  return all_to_all(x, send_recv_sizes, send_recv_sizes, group)
49
 
50
+ time, std = benchmark_util.benchmark_function(benchmark)
 
51
 
52
  if dist.get_rank(group) == 0:
53
+ benchmark_util.log_benchmark('All-To-All', details, time, std)
 
54
 
55
 
56
  if __name__ == '__main__':
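For a sense of scale, each `(sequence_length, hidden_size)` entry in `_ALL_TO_ALL_BENCHMARK` moves roughly `sl * hs * bytes_per_element` per rank. A quick back-of-the-envelope sketch, assuming fp16 activations (2 bytes per element):

for sl, hs in ((8, 1024), (1024, 1024), (1024 * 1024, 1024)):
    mib = sl * hs * 2 / 2**20  # fp16 payload in MiB
    print(f"sl={sl:>8}, hs={hs}: ~{mib:.2f} MiB per all_to_all")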
build/torch26-cxx98-cu126-x86_64-linux/megablocks/_megablocks_13afbbe_dirty.abi3.so DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:94a9a3bb426adceab66b39fe9d179b73e4524167aeb63bed5a67cd7734d31b24
3
- size 11923704
 
 
 
 
build/torch26-cxx98-cu126-x86_64-linux/megablocks/_megablocks_e47036a.abi3.so ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4816243ebcbc505f3e9229dd55ef52c7e0d804a4f4ef67f9fe3e70932cf08027
3
+ size 11923672
build/torch26-cxx98-cu126-x86_64-linux/megablocks/_ops.py CHANGED
@@ -1,9 +1,9 @@
1
  import torch
2
- from . import _megablocks_13afbbe_dirty
3
- ops = torch.ops._megablocks_13afbbe_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
- return f"_megablocks_13afbbe_dirty::{op_name}"
 
1
  import torch
2
+ from . import _megablocks_e47036a
3
+ ops = torch.ops._megablocks_e47036a
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
+ return f"_megablocks_e47036a::{op_name}"
build/torch26-cxx98-cu126-x86_64-linux/megablocks/layers.py CHANGED
@@ -333,7 +333,6 @@ def permute_and_compute(
333
  gradient_scale,
334
  alpha,
335
  ):
336
- """Permute tokens and compute expert outputs."""
337
  # Route tokens to experts
338
  x = x.view(-1, x.shape[-1])
339
 
@@ -367,6 +366,7 @@ def forward_once(
367
  expert_parallel_group: int = None,
368
  moe_capacity_factor: float = 1.0,
369
  moe_expert_model_parallelism: bool = False,
 
370
  ):
371
  # x: [sl, bs, hs]
372
  # expert_weights: [sl * bs, top-k]
@@ -430,11 +430,15 @@ def parallel_forward_once(
430
  moe_capacity_factor: float = 1.0,
431
  moe_expert_model_parallelism: bool = True,
432
  hidden_size: int = 1152,
 
433
  ):
434
  # Flatten inputs
435
  expert_weights = expert_weights.flatten()
436
  top_experts = top_experts.flatten()
437
 
 
 
 
438
  with torch.no_grad():
439
  # Step 1: Local permutation setup
440
  indices, bin_ids, bins, tokens_per_expert = indices_and_bins(
@@ -455,9 +459,7 @@ def parallel_forward_once(
455
 
456
  # Exchange token counts across devices
457
  parallel_tokens_per_expert = torch.empty_like(repeated_tokens_per_expert)
458
- # print("world_size:", world_size)
459
- # print("experts_per_rank_val:", experts_per_rank_val)
460
-
461
  # Ensure CUB knows which device to use
462
  tpe_handle = dist.all_to_all_single(
463
  parallel_tokens_per_expert,
@@ -493,20 +495,13 @@ def parallel_forward_once(
493
  x = ops.repeat(x, (hidden_sharding_deg, 1))
494
 
495
  # Cross-device token exchange
496
- parallel_x, parallel_x_handle = ops.all_to_all(
497
- x,
498
- recv_counts,
499
- send_counts,
500
- expert_parallel_group,
501
- async_op=True
502
  )
503
 
504
  with torch.no_grad():
505
  # Step 4: Setup for local expert computation
506
- replicate_bins = ops.inclusive_cumsum(
507
- parallel_tokens_per_expert.flatten(),
508
- 0
509
- )
510
  replicate_bins = (
511
  replicate_bins.view(1) if not len(replicate_bins.size()) else replicate_bins
512
  )
@@ -528,7 +523,7 @@ def parallel_forward_once(
528
 
529
  # Sort tokens by expert assignment
530
  parallel_bin_ids, parallel_indices = ops.sort(
531
- parallel_top_expert,
532
  sort_end_bit,
533
  )
534
 
@@ -536,10 +531,7 @@ def parallel_forward_once(
536
  parallel_tokens_per_expert = parallel_tokens_per_expert.sum(
537
  dim=0, dtype=torch.int
538
  )
539
- parallel_bins = ops.inclusive_cumsum(
540
- parallel_tokens_per_expert,
541
- 0
542
- )
543
  parallel_bins = (
544
  parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins
545
  )
@@ -558,10 +550,7 @@ def parallel_forward_once(
558
 
559
  # Locally permute the tokens and perform the expert computation.
560
  # Block to make sure that the cross-device permutation is complete.
561
- # if self.args.mlp_impl == 'grouped':
562
-
563
- # TODO: dont always assume grouped MLP
564
- if True:
565
  # GroupedMLP requires counts on CPU. We can use the tensor already
566
  # moved to CPU for the prior all_to_all, which avoids an extra
567
  # device synchronization.
@@ -591,7 +580,9 @@ def parallel_forward_once(
591
  )
592
 
593
  # Step 6: Reverse communication - send results back
594
- x, _ = ops.all_to_all(parallel_x, send_counts, recv_counts, expert_parallel_group)
 
 
595
 
596
  # Step 7: Reduce across hidden sharding dimension
597
  shape = (hidden_sharding_deg, -1, hidden_size)
@@ -603,139 +594,135 @@ def parallel_forward_once(
603
  return x, tokens_per_expert.flatten()
604
 
605
 
606
- class MyReplacementLayer(torch.nn.Module):
607
- def forward(
608
- x: torch.Tensor,
609
- router_weight: torch.Tensor,
610
- moe_top_k: int,
611
- moe_num_experts: int,
612
- moe_jitter_eps: float = None,
613
- moe_normalize_expert_weights: int = None,
614
- uniform_expert_assignment: bool = False,
615
- training: bool = False,
616
- w1: torch.Tensor = None,
617
- w2: torch.Tensor = None,
618
- w1_bias: torch.Tensor = None,
619
- w2_bias: torch.Tensor = None,
620
- gradient_scale: Optional[float] = None,
621
- alpha: float = 1.702,
622
- sort_end_bit: int = 0,
623
- expert_parallel_group: torch.distributed.ProcessGroup = None,
624
- moe_capacity_factor: float = 1.0,
625
- moe_expert_model_parallelism: bool = False,
626
- forward_fn: Any = None,
627
- hidden_size: int = None, # Required for parallel forward
628
- ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
629
-
630
- # Route tokens to experts
631
- logits, expert_weights, expert_indices = route_tokens(
632
- x,
633
- router_weight,
634
- moe_top_k,
635
- moe_num_experts,
636
- moe_jitter_eps,
637
- moe_normalize_expert_weights,
638
- uniform_expert_assignment,
639
- training,
640
- )
641
 
642
- # Create router scores for output
643
- router_scores = (
644
- torch.zeros_like(logits)
645
- .scatter_(1, expert_indices, expert_weights)
646
- .transpose(0, 1)
647
- )
 
 
 
 
 
648
 
649
- in_shape = x.size()
650
-
651
- # Prepare forward function arguments
652
- forward_args = {
653
- "x": x,
654
- "expert_weights": expert_weights,
655
- "top_experts": expert_indices,
656
- "w1": w1,
657
- "w2": w2,
658
- "w1_bias": w1_bias,
659
- "w2_bias": w2_bias,
660
- "gradient_scale": gradient_scale,
661
- "alpha": alpha,
662
- "sort_end_bit": sort_end_bit,
663
- "top_k": moe_top_k,
664
- "num_experts": moe_num_experts,
665
- "expert_parallel_group": expert_parallel_group,
666
- "moe_capacity_factor": moe_capacity_factor,
667
- "moe_expert_model_parallelism": moe_expert_model_parallelism,
668
- }
669
-
670
- # Add hidden_size for parallel forward
671
- if moe_expert_model_parallelism and hidden_size is not None:
672
- forward_args["hidden_size"] = hidden_size
673
- elif moe_expert_model_parallelism and hidden_size is None:
674
- # Infer hidden_size from input shape
675
- forward_args["hidden_size"] = x.shape[-1]
676
-
677
- # Compute expert outputs
678
- x, tokens_per_expert = forward_fn(**forward_args)
679
-
680
- # Save load balancing loss if needed
681
- moe_loss_weight = 0.0 # Can be made configurable
682
- if training and moe_loss_weight > 0:
683
- save_load_balancing_loss((tokens_per_expert, logits))
684
-
685
- # Restore original shape
686
- x = x.view(in_shape)
687
-
688
- return x, expert_weights, router_scores
 
 
 
 
 
 
 
 
689
 
690
 
691
  class MegaBlocksMoeMLP(torch.nn.Module):
692
 
693
- def forward(
694
- self,
695
- x: torch.Tensor,
696
- ) -> torch.Tensor:
697
- router_weight = self.router.weight
698
- moe_top_k = 4
699
- moe_num_experts = 128
700
- w1 = self.experts.gate_up_proj.data
701
- w2 = self.experts.down_proj.data
702
- w1_bias = self.experts.gate_up_proj_bias.data
703
- w2_bias = self.experts.down_proj_bias.data
704
-
705
- # check if the expert_parallel_group attribute is set
706
- if hasattr(self, "expert_parallel_group"):
707
- expert_parallel_group = self.expert_parallel_group
708
- moe_expert_model_parallelism = True
709
- forward_fn = parallel_forward_once
710
- else:
711
- expert_parallel_group = None
712
- moe_expert_model_parallelism = False
713
- forward_fn = forward_once
714
 
 
 
 
715
  sort_end_bit = max(
716
  int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1
717
  )
718
- hidden_size = self.experts.hidden_size
719
- output, expert_weights_out, router_scores = MyReplacementLayer.forward(
 
720
  x=x,
721
- router_weight=router_weight,
722
  moe_top_k=moe_top_k,
723
  moe_num_experts=moe_num_experts,
724
- moe_jitter_eps=None,
725
- moe_normalize_expert_weights=None,
726
- uniform_expert_assignment=False,
727
- training=False,
728
- w1=w1,
729
- w2=w2,
730
- w1_bias=w1_bias,
731
- w2_bias=w2_bias,
732
- gradient_scale=None,
733
- alpha=1.702,
734
  sort_end_bit=sort_end_bit,
735
  expert_parallel_group=expert_parallel_group,
736
- moe_capacity_factor=1.0,
737
- moe_expert_model_parallelism=moe_expert_model_parallelism,
738
  forward_fn=forward_fn,
739
- hidden_size=hidden_size,
 
740
  )
741
  return output, expert_weights_out
 
333
  gradient_scale,
334
  alpha,
335
  ):
 
336
  # Route tokens to experts
337
  x = x.view(-1, x.shape[-1])
338
 
 
366
  expert_parallel_group: int = None,
367
  moe_capacity_factor: float = 1.0,
368
  moe_expert_model_parallelism: bool = False,
369
+ mlp_impl: Optional[str] = None,
370
  ):
371
  # x: [sl, bs, hs]
372
  # expert_weights: [sl * bs, top-k]
 
430
  moe_capacity_factor: float = 1.0,
431
  moe_expert_model_parallelism: bool = True,
432
  hidden_size: int = 1152,
433
+ mlp_impl: Optional[str] = "grouped",
434
  ):
435
  # Flatten inputs
436
  expert_weights = expert_weights.flatten()
437
  top_experts = top_experts.flatten()
438
 
439
+ # TODO: remove debugging var
440
+ # my_rank = dist.get_rank(expert_parallel_group) if expert_parallel_group else 0
441
+
442
  with torch.no_grad():
443
  # Step 1: Local permutation setup
444
  indices, bin_ids, bins, tokens_per_expert = indices_and_bins(
 
459
 
460
  # Exchange token counts across devices
461
  parallel_tokens_per_expert = torch.empty_like(repeated_tokens_per_expert)
462
+
 
 
463
  # Ensure CUB knows which device to use
464
  tpe_handle = dist.all_to_all_single(
465
  parallel_tokens_per_expert,
 
495
  x = ops.repeat(x, (hidden_sharding_deg, 1))
496
 
497
  # Cross-device token exchange
498
+ parallel_x, parallel_x_handle = _layers.all_to_all.all_to_all(
499
+ x, recv_counts, send_counts, expert_parallel_group, async_op=True
 
 
 
 
500
  )
501
 
502
  with torch.no_grad():
503
  # Step 4: Setup for local expert computation
504
+ replicate_bins = ops.inclusive_cumsum(parallel_tokens_per_expert.flatten(), 0)
 
 
 
505
  replicate_bins = (
506
  replicate_bins.view(1) if not len(replicate_bins.size()) else replicate_bins
507
  )
 
523
 
524
  # Sort tokens by expert assignment
525
  parallel_bin_ids, parallel_indices = ops.sort(
526
+ parallel_top_expert,
527
  sort_end_bit,
528
  )
529
 
 
531
  parallel_tokens_per_expert = parallel_tokens_per_expert.sum(
532
  dim=0, dtype=torch.int
533
  )
534
+ parallel_bins = ops.inclusive_cumsum(parallel_tokens_per_expert, 0)
 
 
 
535
  parallel_bins = (
536
  parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins
537
  )
 
550
 
551
  # Locally permute the tokens and perform the expert computation.
552
  # Block to make sure that the cross-device permutation is complete.
553
+ if mlp_impl == "grouped":
 
 
 
554
  # GroupedMLP requires counts on CPU. We can use the tensor already
555
  # moved to CPU for the prior all_to_all, which avoids an extra
556
  # device synchronization.
 
580
  )
581
 
582
  # Step 6: Reverse communication - send results back
583
+ x, _ = _layers.all_to_all.all_to_all(
584
+ parallel_x, send_counts, recv_counts, expert_parallel_group
585
+ )
586
 
587
  # Step 7: Reduce across hidden sharding dimension
588
  shape = (hidden_sharding_deg, -1, hidden_size)
 
594
  return x, tokens_per_expert.flatten()
595
 
596
 
597
+ def moe_forward(
598
+ x: torch.Tensor,
599
+ router_weight: torch.Tensor,
600
+ moe_top_k: int,
601
+ moe_num_experts: int,
602
+ moe_jitter_eps: float = None,
603
+ moe_normalize_expert_weights: int = None,
604
+ uniform_expert_assignment: bool = False,
605
+ training: bool = False,
606
+ w1: torch.Tensor = None,
607
+ w2: torch.Tensor = None,
608
+ w1_bias: torch.Tensor = None,
609
+ w2_bias: torch.Tensor = None,
610
+ gradient_scale: Optional[float] = None,
611
+ alpha: float = 1.702,
612
+ sort_end_bit: int = 0,
613
+ expert_parallel_group: torch.distributed.ProcessGroup = None,
614
+ moe_capacity_factor: float = 1.0,
615
+ moe_expert_model_parallelism: bool = False,
616
+ forward_fn: Any = None,
617
+ hidden_size: int = None,
618
+ mlp_impl: str = "grouped",
619
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
 
 
 
 
 
 
 
 
 
 
 
 
620
 
621
+ # Route tokens to experts
622
+ logits, expert_weights, expert_indices = route_tokens(
623
+ x,
624
+ router_weight,
625
+ moe_top_k,
626
+ moe_num_experts,
627
+ moe_jitter_eps,
628
+ moe_normalize_expert_weights,
629
+ uniform_expert_assignment,
630
+ training,
631
+ )
632
 
633
+ # Create router scores for output
634
+ router_scores = (
635
+ torch.zeros_like(logits)
636
+ .scatter_(1, expert_indices, expert_weights)
637
+ .transpose(0, 1)
638
+ )
639
+
640
+ in_shape = x.size()
641
+
642
+ # Prepare forward function arguments
643
+ forward_args = {
644
+ "x": x,
645
+ "expert_weights": expert_weights,
646
+ "top_experts": expert_indices,
647
+ "w1": w1,
648
+ "w2": w2,
649
+ "w1_bias": w1_bias,
650
+ "w2_bias": w2_bias,
651
+ "gradient_scale": gradient_scale,
652
+ "alpha": alpha,
653
+ "sort_end_bit": sort_end_bit,
654
+ "top_k": moe_top_k,
655
+ "num_experts": moe_num_experts,
656
+ "expert_parallel_group": expert_parallel_group,
657
+ "moe_capacity_factor": moe_capacity_factor,
658
+ "moe_expert_model_parallelism": moe_expert_model_parallelism,
659
+ "mlp_impl": mlp_impl,
660
+ }
661
+
662
+ # Add hidden_size for parallel forward
663
+ if moe_expert_model_parallelism and hidden_size is not None:
664
+ forward_args["hidden_size"] = hidden_size
665
+ elif moe_expert_model_parallelism and hidden_size is None:
666
+ # Infer hidden_size from input shape
667
+ forward_args["hidden_size"] = x.shape[-1]
668
+
669
+ # Compute expert outputs
670
+ x, tokens_per_expert = forward_fn(**forward_args)
671
+
672
+ # Save load balancing loss if needed
673
+ moe_loss_weight = 0.0 # Can be made configurable
674
+ if training and moe_loss_weight > 0:
675
+ save_load_balancing_loss((tokens_per_expert, logits))
676
+
677
+ # Restore original shape
678
+ x = x.view(in_shape)
679
+
680
+ return x, expert_weights, router_scores
681
 
682
 
683
  class MegaBlocksMoeMLP(torch.nn.Module):
684
 
685
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
686
+ moe_top_k = getattr(self, "moe_top_k", 4)
687
+ moe_num_experts = getattr(self, "moe_num_experts", 128)
688
+ gradient_scale = getattr(self, "gradient_scale", None)
689
+ alpha = getattr(self, "alpha", 1.702)
690
+ moe_capacity_factor = getattr(self, "moe_capacity_factor", 1.0)
691
+ moe_jitter_eps = getattr(self, "moe_jitter_eps", None)
692
+ moe_normalize_expert_weights = getattr(
693
+ self, "moe_normalize_expert_weights", None
694
+ )
695
+ uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)
 
 
 
 
 
 
 
 
 
 
696
 
697
+ has_parallel = hasattr(self, "expert_parallel_group")
698
+ expert_parallel_group = getattr(self, "expert_parallel_group", None)
699
+ forward_fn = parallel_forward_once if has_parallel else forward_once
700
  sort_end_bit = max(
701
  int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1
702
  )
703
+ mlp_impl = getattr(self, "mlp_impl", "grouped") # or sparse
704
+
705
+ output, expert_weights_out, _ = moe_forward(
706
  x=x,
707
+ router_weight=self.router.weight,
708
  moe_top_k=moe_top_k,
709
  moe_num_experts=moe_num_experts,
710
+ moe_jitter_eps=moe_jitter_eps,
711
+ moe_normalize_expert_weights=moe_normalize_expert_weights,
712
+ uniform_expert_assignment=uniform_expert_assignment,
713
+ training=self.training,
714
+ w1=self.experts.gate_up_proj,
715
+ w2=self.experts.down_proj,
716
+ w1_bias=self.experts.gate_up_proj_bias,
717
+ w2_bias=self.experts.down_proj_bias,
718
+ gradient_scale=gradient_scale,
719
+ alpha=alpha,
720
  sort_end_bit=sort_end_bit,
721
  expert_parallel_group=expert_parallel_group,
722
+ moe_capacity_factor=moe_capacity_factor,
723
+ moe_expert_model_parallelism=has_parallel,
724
  forward_fn=forward_fn,
725
+ hidden_size=self.experts.hidden_size,
726
+ mlp_impl=mlp_impl,
727
  )
728
  return output, expert_weights_out
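Note: the sort_end_bit expression above is the number of radix bits needed to sort expert ids. A small worked example of the same expression (the helper name is illustrative, not from the diff):

import torch

def compute_sort_end_bit(moe_num_experts: int) -> int:
    # Bits required to cover expert ids in [0, moe_num_experts), clamped to >= 1.
    return max(int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1)

# 128 experts -> log2(128) = 7 bits; 100 experts -> ceil(6.64) = 7 bits; 1 expert -> clamped to 1.
assert compute_sort_end_bit(128) == 7
assert compute_sort_end_bit(100) == 7
assert compute_sort_end_bit(1) == 1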
build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/all_to_all_benchmark.py CHANGED
@@ -7,126 +7,28 @@ import torch.distributed as dist
7
  # from megablocks import benchmark_util
8
  # from megablocks.layers.all_to_all import all_to_all
9
 
10
- # from .. import benchmark_util
11
-
12
- # Copyright 2024 Databricks
13
- # SPDX-License-Identifier: Apache-2.0
14
-
15
- import numpy as np
16
- import torch
17
-
18
-
19
- def log_benchmark(name, arguments, time, std):
20
- print("=" * 60)
21
- print(f"{name} Benchmark")
22
- print("Benchmark Parameters:")
23
- for key, value in arguments.items():
24
- print(f"{key} = {value}")
25
- print("Results:")
26
- print("mean time = {:.3f}ms, std time = {:.3f}ms".format(time, std))
27
- print("=" * 60)
28
-
29
-
30
- def benchmark_function(fn, iterations=100, warmup=10):
31
- print(f"Benchmarking {fn.__name__} with {iterations} iterations and {warmup} warmup iterations")
32
- # Warmup iterations.
33
- for _ in range(warmup):
34
- fn()
35
-
36
- times = []
37
- print(f"Running {iterations} iterations...")
38
- for i in range(iterations):
39
- start = torch.cuda.Event(enable_timing=True)
40
- end = torch.cuda.Event(enable_timing=True)
41
-
42
- start.record()
43
- fn()
44
- end.record()
45
-
46
- torch.cuda.synchronize()
47
- times.append(start.elapsed_time(end))
48
- return np.mean(times), np.std(times)
49
-
50
-
51
- # from .._layers.all_to_all import all_to_all
52
-
53
- # Copyright 2024 Databricks
54
- # SPDX-License-Identifier: Apache-2.0
55
-
56
- import torch
57
- import torch.distributed as dist
58
-
59
-
60
- class AllToAllOp(torch.autograd.Function):
61
-
62
- @staticmethod
63
- def forward(ctx, x, output_split_sizes, input_split_sizes, group, async_op):
64
- out = torch.empty(
65
- (sum(output_split_sizes),) + x.shape[1:], device=x.device, dtype=x.dtype
66
- )
67
-
68
- ctx.input_shape = x.shape
69
- ctx.output_split_sizes = output_split_sizes
70
- ctx.input_split_sizes = input_split_sizes
71
- ctx.group = group
72
- handle = dist.all_to_all_single(
73
- out,
74
- x,
75
- output_split_sizes=output_split_sizes,
76
- input_split_sizes=input_split_sizes,
77
- group=group,
78
- async_op=async_op,
79
- )
80
- return out, handle
81
-
82
- @staticmethod
83
- def backward(ctx, grad, _):
84
- if ctx.needs_input_grad[0]:
85
- out = torch.empty(
86
- ctx.input_shape,
87
- device=grad.device,
88
- dtype=grad.dtype,
89
- )
90
- dist.all_to_all_single(
91
- out,
92
- grad,
93
- output_split_sizes=ctx.input_split_sizes,
94
- input_split_sizes=ctx.output_split_sizes,
95
- group=ctx.group,
96
- )
97
- return out, None, None, None, None
98
- return None, None, None, None, None
99
-
100
-
101
- def all_to_all(x, output_split_sizes, input_split_sizes, group, async_op=False):
102
- return AllToAllOp.apply(
103
- x,
104
- output_split_sizes,
105
- input_split_sizes,
106
- group,
107
- async_op,
108
- )
109
-
110
 
111
  _ALL_TO_ALL_BENCHMARK = (
112
  (8, 1024),
113
- # (16, 1024),
114
- # (32, 1024),
115
- # (64, 1024),
116
- # (128, 1024),
117
- # (256, 1024),
118
- # (512, 1024),
119
- # (1024, 1024),
120
- # (2 * 1024, 1024),
121
- # (4 * 1024, 1024),
122
- # (8 * 1024, 1024),
123
- # (16 * 1024, 1024),
124
- # (32 * 1024, 1024),
125
- # (64 * 1024, 1024),
126
- # (128 * 1024, 1024),
127
- # (256 * 1024, 1024),
128
- # (512 * 1024, 1024),
129
- # (1024 * 1024, 1024),
130
  )
131
 
132
 
@@ -145,12 +47,10 @@ def benchmark_all_to_all(group, sl, hs):
145
  def benchmark():
146
  return all_to_all(x, send_recv_sizes, send_recv_sizes, group)
147
 
148
- # time, std = benchmark_util.benchmark_function(benchmark)
149
- time, std = benchmark_function(benchmark)
150
 
151
  if dist.get_rank(group) == 0:
152
- log_benchmark('All-To-All', details, time, std)
153
- # benchmark_util.log_benchmark('All-To-All', details, time, std)
154
 
155
 
156
  if __name__ == '__main__':
 
7
  # from megablocks import benchmark_util
8
  # from megablocks.layers.all_to_all import all_to_all
9
 
10
+ from .. import benchmark_util
11
+ from .._layers.all_to_all import all_to_all

12
 
13
  _ALL_TO_ALL_BENCHMARK = (
14
  (8, 1024),
15
+ (16, 1024),
16
+ (32, 1024),
17
+ (64, 1024),
18
+ (128, 1024),
19
+ (256, 1024),
20
+ (512, 1024),
21
+ (1024, 1024),
22
+ (2 * 1024, 1024),
23
+ (4 * 1024, 1024),
24
+ (8 * 1024, 1024),
25
+ (16 * 1024, 1024),
26
+ (32 * 1024, 1024),
27
+ (64 * 1024, 1024),
28
+ (128 * 1024, 1024),
29
+ (256 * 1024, 1024),
30
+ (512 * 1024, 1024),
31
+ (1024 * 1024, 1024),
32
  )
33
 
34
 
 
47
  def benchmark():
48
  return all_to_all(x, send_recv_sizes, send_recv_sizes, group)
49
 
50
+ time, std = benchmark_util.benchmark_function(benchmark)
 
51
 
52
  if dist.get_rank(group) == 0:
53
+ benchmark_util.log_benchmark('All-To-All', details, time, std)
 
54
 
55
 
56
  if __name__ == '__main__':
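Note: the benchmark now delegates timing to benchmark_util.benchmark_function / log_benchmark instead of the inlined copies deleted above. A minimal sketch of that CUDA-event timing pattern, mirroring the removed inline version (the exact benchmark_util internals are an assumption, as they are not shown in this diff):

import numpy as np
import torch

def benchmark_function(fn, iterations=100, warmup=10):
    # Warm up so allocator and kernel-launch overheads settle before timing.
    for _ in range(warmup):
        fn()
    times = []
    for _ in range(iterations):
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        fn()
        end.record()
        torch.cuda.synchronize()  # wait so elapsed_time reflects the full call
        times.append(start.elapsed_time(end))  # milliseconds
    return np.mean(times), np.std(times)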
build/torch27-cxx11-cu118-x86_64-linux/megablocks/_megablocks_13afbbe_dirty.abi3.so DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:aa9d1964e47ec6ff3c4ec77947f6a2a19868b03cec3618daf0555e011f69924d
3
- size 10517848
 
 
 
 
build/torch27-cxx11-cu118-x86_64-linux/megablocks/_megablocks_e47036a.abi3.so ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3c5a31570d353a8c392d72f018198c32c6522af0f5a1345426fd7cff5965c1cd
3
+ size 10517816
build/torch27-cxx11-cu118-x86_64-linux/megablocks/_ops.py CHANGED
@@ -1,9 +1,9 @@
1
  import torch
2
- from . import _megablocks_13afbbe_dirty
3
- ops = torch.ops._megablocks_13afbbe_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
- return f"_megablocks_13afbbe_dirty::{op_name}"
 
1
  import torch
2
+ from . import _megablocks_e47036a
3
+ ops = torch.ops._megablocks_e47036a
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
+ return f"_megablocks_e47036a::{op_name}"
build/torch27-cxx11-cu118-x86_64-linux/megablocks/layers.py CHANGED
@@ -333,7 +333,6 @@ def permute_and_compute(
333
  gradient_scale,
334
  alpha,
335
  ):
336
- """Permute tokens and compute expert outputs."""
337
  # Route tokens to experts
338
  x = x.view(-1, x.shape[-1])
339
 
@@ -367,6 +366,7 @@ def forward_once(
367
  expert_parallel_group: int = None,
368
  moe_capacity_factor: float = 1.0,
369
  moe_expert_model_parallelism: bool = False,
 
370
  ):
371
  # x: [sl, bs, hs]
372
  # expert_weights: [sl * bs, top-k]
@@ -430,11 +430,15 @@ def parallel_forward_once(
430
  moe_capacity_factor: float = 1.0,
431
  moe_expert_model_parallelism: bool = True,
432
  hidden_size: int = 1152,
 
433
  ):
434
  # Flatten inputs
435
  expert_weights = expert_weights.flatten()
436
  top_experts = top_experts.flatten()
437
 
 
 
 
438
  with torch.no_grad():
439
  # Step 1: Local permutation setup
440
  indices, bin_ids, bins, tokens_per_expert = indices_and_bins(
@@ -455,9 +459,7 @@ def parallel_forward_once(
455
 
456
  # Exchange token counts across devices
457
  parallel_tokens_per_expert = torch.empty_like(repeated_tokens_per_expert)
458
- # print("world_size:", world_size)
459
- # print("experts_per_rank_val:", experts_per_rank_val)
460
-
461
  # Ensure CUB knows which device to use
462
  tpe_handle = dist.all_to_all_single(
463
  parallel_tokens_per_expert,
@@ -493,20 +495,13 @@ def parallel_forward_once(
493
  x = ops.repeat(x, (hidden_sharding_deg, 1))
494
 
495
  # Cross-device token exchange
496
- parallel_x, parallel_x_handle = ops.all_to_all(
497
- x,
498
- recv_counts,
499
- send_counts,
500
- expert_parallel_group,
501
- async_op=True
502
  )
503
 
504
  with torch.no_grad():
505
  # Step 4: Setup for local expert computation
506
- replicate_bins = ops.inclusive_cumsum(
507
- parallel_tokens_per_expert.flatten(),
508
- 0
509
- )
510
  replicate_bins = (
511
  replicate_bins.view(1) if not len(replicate_bins.size()) else replicate_bins
512
  )
@@ -528,7 +523,7 @@ def parallel_forward_once(
528
 
529
  # Sort tokens by expert assignment
530
  parallel_bin_ids, parallel_indices = ops.sort(
531
- parallel_top_expert,
532
  sort_end_bit,
533
  )
534
 
@@ -536,10 +531,7 @@ def parallel_forward_once(
536
  parallel_tokens_per_expert = parallel_tokens_per_expert.sum(
537
  dim=0, dtype=torch.int
538
  )
539
- parallel_bins = ops.inclusive_cumsum(
540
- parallel_tokens_per_expert,
541
- 0
542
- )
543
  parallel_bins = (
544
  parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins
545
  )
@@ -558,10 +550,7 @@ def parallel_forward_once(
558
 
559
  # Locally permute the tokens and perform the expert computation.
560
  # Block to make sure that the cross-device permutation is complete.
561
- # if self.args.mlp_impl == 'grouped':
562
-
563
- # TODO: dont always assume grouped MLP
564
- if True:
565
  # GroupedMLP requires counts on CPU. We can use the tensor already
566
  # moved to CPU for the prior all_to_all, which avoids an extra
567
  # device synchronization.
@@ -591,7 +580,9 @@ def parallel_forward_once(
591
  )
592
 
593
  # Step 6: Reverse communication - send results back
594
- x, _ = ops.all_to_all(parallel_x, send_counts, recv_counts, expert_parallel_group)
 
 
595
 
596
  # Step 7: Reduce across hidden sharding dimension
597
  shape = (hidden_sharding_deg, -1, hidden_size)
@@ -603,139 +594,135 @@ def parallel_forward_once(
603
  return x, tokens_per_expert.flatten()
604
 
605
 
606
- class MyReplacementLayer(torch.nn.Module):
607
- def forward(
608
- x: torch.Tensor,
609
- router_weight: torch.Tensor,
610
- moe_top_k: int,
611
- moe_num_experts: int,
612
- moe_jitter_eps: float = None,
613
- moe_normalize_expert_weights: int = None,
614
- uniform_expert_assignment: bool = False,
615
- training: bool = False,
616
- w1: torch.Tensor = None,
617
- w2: torch.Tensor = None,
618
- w1_bias: torch.Tensor = None,
619
- w2_bias: torch.Tensor = None,
620
- gradient_scale: Optional[float] = None,
621
- alpha: float = 1.702,
622
- sort_end_bit: int = 0,
623
- expert_parallel_group: torch.distributed.ProcessGroup = None,
624
- moe_capacity_factor: float = 1.0,
625
- moe_expert_model_parallelism: bool = False,
626
- forward_fn: Any = None,
627
- hidden_size: int = None, # Required for parallel forward
628
- ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
629
-
630
- # Route tokens to experts
631
- logits, expert_weights, expert_indices = route_tokens(
632
- x,
633
- router_weight,
634
- moe_top_k,
635
- moe_num_experts,
636
- moe_jitter_eps,
637
- moe_normalize_expert_weights,
638
- uniform_expert_assignment,
639
- training,
640
- )
641
 
642
- # Create router scores for output
643
- router_scores = (
644
- torch.zeros_like(logits)
645
- .scatter_(1, expert_indices, expert_weights)
646
- .transpose(0, 1)
647
- )
 
 
 
 
 
648
 
649
- in_shape = x.size()
650
-
651
- # Prepare forward function arguments
652
- forward_args = {
653
- "x": x,
654
- "expert_weights": expert_weights,
655
- "top_experts": expert_indices,
656
- "w1": w1,
657
- "w2": w2,
658
- "w1_bias": w1_bias,
659
- "w2_bias": w2_bias,
660
- "gradient_scale": gradient_scale,
661
- "alpha": alpha,
662
- "sort_end_bit": sort_end_bit,
663
- "top_k": moe_top_k,
664
- "num_experts": moe_num_experts,
665
- "expert_parallel_group": expert_parallel_group,
666
- "moe_capacity_factor": moe_capacity_factor,
667
- "moe_expert_model_parallelism": moe_expert_model_parallelism,
668
- }
669
-
670
- # Add hidden_size for parallel forward
671
- if moe_expert_model_parallelism and hidden_size is not None:
672
- forward_args["hidden_size"] = hidden_size
673
- elif moe_expert_model_parallelism and hidden_size is None:
674
- # Infer hidden_size from input shape
675
- forward_args["hidden_size"] = x.shape[-1]
676
-
677
- # Compute expert outputs
678
- x, tokens_per_expert = forward_fn(**forward_args)
679
-
680
- # Save load balancing loss if needed
681
- moe_loss_weight = 0.0 # Can be made configurable
682
- if training and moe_loss_weight > 0:
683
- save_load_balancing_loss((tokens_per_expert, logits))
684
-
685
- # Restore original shape
686
- x = x.view(in_shape)
687
-
688
- return x, expert_weights, router_scores
 
 
 
 
 
 
 
 
689
 
690
 
691
  class MegaBlocksMoeMLP(torch.nn.Module):
692
 
693
- def forward(
694
- self,
695
- x: torch.Tensor,
696
- ) -> torch.Tensor:
697
- router_weight = self.router.weight
698
- moe_top_k = 4
699
- moe_num_experts = 128
700
- w1 = self.experts.gate_up_proj.data
701
- w2 = self.experts.down_proj.data
702
- w1_bias = self.experts.gate_up_proj_bias.data
703
- w2_bias = self.experts.down_proj_bias.data
704
-
705
- # check if the expert_parallel_group attribute is set
706
- if hasattr(self, "expert_parallel_group"):
707
- expert_parallel_group = self.expert_parallel_group
708
- moe_expert_model_parallelism = True
709
- forward_fn = parallel_forward_once
710
- else:
711
- expert_parallel_group = None
712
- moe_expert_model_parallelism = False
713
- forward_fn = forward_once
714
 
 
 
 
715
  sort_end_bit = max(
716
  int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1
717
  )
718
- hidden_size = self.experts.hidden_size
719
- output, expert_weights_out, router_scores = MyReplacementLayer.forward(
 
720
  x=x,
721
- router_weight=router_weight,
722
  moe_top_k=moe_top_k,
723
  moe_num_experts=moe_num_experts,
724
- moe_jitter_eps=None,
725
- moe_normalize_expert_weights=None,
726
- uniform_expert_assignment=False,
727
- training=False,
728
- w1=w1,
729
- w2=w2,
730
- w1_bias=w1_bias,
731
- w2_bias=w2_bias,
732
- gradient_scale=None,
733
- alpha=1.702,
734
  sort_end_bit=sort_end_bit,
735
  expert_parallel_group=expert_parallel_group,
736
- moe_capacity_factor=1.0,
737
- moe_expert_model_parallelism=moe_expert_model_parallelism,
738
  forward_fn=forward_fn,
739
- hidden_size=hidden_size,
 
740
  )
741
  return output, expert_weights_out
 
333
  gradient_scale,
334
  alpha,
335
  ):
 
336
  # Route tokens to experts
337
  x = x.view(-1, x.shape[-1])
338
 
 
366
  expert_parallel_group: int = None,
367
  moe_capacity_factor: float = 1.0,
368
  moe_expert_model_parallelism: bool = False,
369
+ mlp_impl: Optional[str] = None,
370
  ):
371
  # x: [sl, bs, hs]
372
  # expert_weights: [sl * bs, top-k]
 
430
  moe_capacity_factor: float = 1.0,
431
  moe_expert_model_parallelism: bool = True,
432
  hidden_size: int = 1152,
433
+ mlp_impl: Optional[str] = "grouped",
434
  ):
435
  # Flatten inputs
436
  expert_weights = expert_weights.flatten()
437
  top_experts = top_experts.flatten()
438
 
439
+ # TODO: remove debugging var
440
+ # my_rank = dist.get_rank(expert_parallel_group) if expert_parallel_group else 0
441
+
442
  with torch.no_grad():
443
  # Step 1: Local permutation setup
444
  indices, bin_ids, bins, tokens_per_expert = indices_and_bins(
 
459
 
460
  # Exchange token counts across devices
461
  parallel_tokens_per_expert = torch.empty_like(repeated_tokens_per_expert)
462
+
 
 
463
  # Ensure CUB knows which device to use
464
  tpe_handle = dist.all_to_all_single(
465
  parallel_tokens_per_expert,
 
495
  x = ops.repeat(x, (hidden_sharding_deg, 1))
496
 
497
  # Cross-device token exchange
498
+ parallel_x, parallel_x_handle = _layers.all_to_all.all_to_all(
499
+ x, recv_counts, send_counts, expert_parallel_group, async_op=True
 
 
 
 
500
  )
501
 
502
  with torch.no_grad():
503
  # Step 4: Setup for local expert computation
504
+ replicate_bins = ops.inclusive_cumsum(parallel_tokens_per_expert.flatten(), 0)
 
 
 
505
  replicate_bins = (
506
  replicate_bins.view(1) if not len(replicate_bins.size()) else replicate_bins
507
  )
 
523
 
524
  # Sort tokens by expert assignment
525
  parallel_bin_ids, parallel_indices = ops.sort(
526
+ parallel_top_expert,
527
  sort_end_bit,
528
  )
529
 
 
531
  parallel_tokens_per_expert = parallel_tokens_per_expert.sum(
532
  dim=0, dtype=torch.int
533
  )
534
+ parallel_bins = ops.inclusive_cumsum(parallel_tokens_per_expert, 0)
 
 
 
535
  parallel_bins = (
536
  parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins
537
  )
 
550
 
551
  # Locally permute the tokens and perform the expert computation.
552
  # Block to make sure that the cross-device permutation is complete.
553
+ if mlp_impl == "grouped":
 
 
 
554
  # GroupedMLP requires counts on CPU. We can use the tensor already
555
  # moved to CPU for the prior all_to_all, which avoids an extra
556
  # device synchronization.
 
580
  )
581
 
582
  # Step 6: Reverse communication - send results back
583
+ x, _ = _layers.all_to_all.all_to_all(
584
+ parallel_x, send_counts, recv_counts, expert_parallel_group
585
+ )
586
 
587
  # Step 7: Reduce across hidden sharding dimension
588
  shape = (hidden_sharding_deg, -1, hidden_size)
 
594
  return x, tokens_per_expert.flatten()
595
 
596
 
597
+ def moe_forward(
598
+ x: torch.Tensor,
599
+ router_weight: torch.Tensor,
600
+ moe_top_k: int,
601
+ moe_num_experts: int,
602
+ moe_jitter_eps: float = None,
603
+ moe_normalize_expert_weights: int = None,
604
+ uniform_expert_assignment: bool = False,
605
+ training: bool = False,
606
+ w1: torch.Tensor = None,
607
+ w2: torch.Tensor = None,
608
+ w1_bias: torch.Tensor = None,
609
+ w2_bias: torch.Tensor = None,
610
+ gradient_scale: Optional[float] = None,
611
+ alpha: float = 1.702,
612
+ sort_end_bit: int = 0,
613
+ expert_parallel_group: torch.distributed.ProcessGroup = None,
614
+ moe_capacity_factor: float = 1.0,
615
+ moe_expert_model_parallelism: bool = False,
616
+ forward_fn: Any = None,
617
+ hidden_size: int = None,
618
+ mlp_impl: str = "grouped",
619
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
 
 
 
 
 
 
 
 
 
 
 
 
620
 
621
+ # Route tokens to experts
622
+ logits, expert_weights, expert_indices = route_tokens(
623
+ x,
624
+ router_weight,
625
+ moe_top_k,
626
+ moe_num_experts,
627
+ moe_jitter_eps,
628
+ moe_normalize_expert_weights,
629
+ uniform_expert_assignment,
630
+ training,
631
+ )
632
 
633
+ # Create router scores for output
634
+ router_scores = (
635
+ torch.zeros_like(logits)
636
+ .scatter_(1, expert_indices, expert_weights)
637
+ .transpose(0, 1)
638
+ )
639
+
640
+ in_shape = x.size()
641
+
642
+ # Prepare forward function arguments
643
+ forward_args = {
644
+ "x": x,
645
+ "expert_weights": expert_weights,
646
+ "top_experts": expert_indices,
647
+ "w1": w1,
648
+ "w2": w2,
649
+ "w1_bias": w1_bias,
650
+ "w2_bias": w2_bias,
651
+ "gradient_scale": gradient_scale,
652
+ "alpha": alpha,
653
+ "sort_end_bit": sort_end_bit,
654
+ "top_k": moe_top_k,
655
+ "num_experts": moe_num_experts,
656
+ "expert_parallel_group": expert_parallel_group,
657
+ "moe_capacity_factor": moe_capacity_factor,
658
+ "moe_expert_model_parallelism": moe_expert_model_parallelism,
659
+ "mlp_impl": mlp_impl,
660
+ }
661
+
662
+ # Add hidden_size for parallel forward
663
+ if moe_expert_model_parallelism and hidden_size is not None:
664
+ forward_args["hidden_size"] = hidden_size
665
+ elif moe_expert_model_parallelism and hidden_size is None:
666
+ # Infer hidden_size from input shape
667
+ forward_args["hidden_size"] = x.shape[-1]
668
+
669
+ # Compute expert outputs
670
+ x, tokens_per_expert = forward_fn(**forward_args)
671
+
672
+ # Save load balancing loss if needed
673
+ moe_loss_weight = 0.0 # Can be made configurable
674
+ if training and moe_loss_weight > 0:
675
+ save_load_balancing_loss((tokens_per_expert, logits))
676
+
677
+ # Restore original shape
678
+ x = x.view(in_shape)
679
+
680
+ return x, expert_weights, router_scores
681
 
682
 
683
  class MegaBlocksMoeMLP(torch.nn.Module):
684
 
685
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
686
+ moe_top_k = getattr(self, "moe_top_k", 4)
687
+ moe_num_experts = getattr(self, "moe_num_experts", 128)
688
+ gradient_scale = getattr(self, "gradient_scale", None)
689
+ alpha = getattr(self, "alpha", 1.702)
690
+ moe_capacity_factor = getattr(self, "moe_capacity_factor", 1.0)
691
+ moe_jitter_eps = getattr(self, "moe_jitter_eps", None)
692
+ moe_normalize_expert_weights = getattr(
693
+ self, "moe_normalize_expert_weights", None
694
+ )
695
+ uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)
 
 
 
 
 
 
 
 
 
 
696
 
697
+ has_parallel = hasattr(self, "expert_parallel_group")
698
+ expert_parallel_group = getattr(self, "expert_parallel_group", None)
699
+ forward_fn = parallel_forward_once if has_parallel else forward_once
700
  sort_end_bit = max(
701
  int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1
702
  )
703
+ mlp_impl = getattr(self, "mlp_impl", "grouped") # or sparse
704
+
705
+ output, expert_weights_out, _ = moe_forward(
706
  x=x,
707
+ router_weight=self.router.weight,
708
  moe_top_k=moe_top_k,
709
  moe_num_experts=moe_num_experts,
710
+ moe_jitter_eps=moe_jitter_eps,
711
+ moe_normalize_expert_weights=moe_normalize_expert_weights,
712
+ uniform_expert_assignment=uniform_expert_assignment,
713
+ training=self.training,
714
+ w1=self.experts.gate_up_proj,
715
+ w2=self.experts.down_proj,
716
+ w1_bias=self.experts.gate_up_proj_bias,
717
+ w2_bias=self.experts.down_proj_bias,
718
+ gradient_scale=gradient_scale,
719
+ alpha=alpha,
720
  sort_end_bit=sort_end_bit,
721
  expert_parallel_group=expert_parallel_group,
722
+ moe_capacity_factor=moe_capacity_factor,
723
+ moe_expert_model_parallelism=has_parallel,
724
  forward_fn=forward_fn,
725
+ hidden_size=self.experts.hidden_size,
726
+ mlp_impl=mlp_impl,
727
  )
728
  return output, expert_weights_out
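Note: in moe_forward above, router_scores is built by scattering the top-k weights back into a dense [num_experts, num_tokens] map. A toy example with made-up shapes:

import torch

# 3 tokens, 4 experts, top-2 routing (toy values).
logits = torch.zeros(3, 4)
expert_indices = torch.tensor([[0, 2], [1, 3], [2, 0]])
expert_weights = torch.tensor([[0.7, 0.3], [0.6, 0.4], [0.9, 0.1]])

router_scores = (
    torch.zeros_like(logits)
    .scatter_(1, expert_indices, expert_weights)
    .transpose(0, 1)
)
# Each column holds one token's routing weights at its selected experts, zeros elsewhere.
print(router_scores.shape)  # torch.Size([4, 3])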
build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/all_to_all_benchmark.py CHANGED
@@ -7,126 +7,28 @@ import torch.distributed as dist
7
  # from megablocks import benchmark_util
8
  # from megablocks.layers.all_to_all import all_to_all
9
 
10
- # from .. import benchmark_util
11
-
12
- # Copyright 2024 Databricks
13
- # SPDX-License-Identifier: Apache-2.0
14
-
15
- import numpy as np
16
- import torch
17
-
18
-
19
- def log_benchmark(name, arguments, time, std):
20
- print("=" * 60)
21
- print(f"{name} Benchmark")
22
- print("Benchmark Parameters:")
23
- for key, value in arguments.items():
24
- print(f"{key} = {value}")
25
- print("Results:")
26
- print("mean time = {:.3f}ms, std time = {:.3f}ms".format(time, std))
27
- print("=" * 60)
28
-
29
-
30
- def benchmark_function(fn, iterations=100, warmup=10):
31
- print(f"Benchmarking {fn.__name__} with {iterations} iterations and {warmup} warmup iterations")
32
- # Warmup iterations.
33
- for _ in range(warmup):
34
- fn()
35
-
36
- times = []
37
- print(f"Running {iterations} iterations...")
38
- for i in range(iterations):
39
- start = torch.cuda.Event(enable_timing=True)
40
- end = torch.cuda.Event(enable_timing=True)
41
-
42
- start.record()
43
- fn()
44
- end.record()
45
-
46
- torch.cuda.synchronize()
47
- times.append(start.elapsed_time(end))
48
- return np.mean(times), np.std(times)
49
-
50
-
51
- # from .._layers.all_to_all import all_to_all
52
-
53
- # Copyright 2024 Databricks
54
- # SPDX-License-Identifier: Apache-2.0
55
-
56
- import torch
57
- import torch.distributed as dist
58
-
59
-
60
- class AllToAllOp(torch.autograd.Function):
61
-
62
- @staticmethod
63
- def forward(ctx, x, output_split_sizes, input_split_sizes, group, async_op):
64
- out = torch.empty(
65
- (sum(output_split_sizes),) + x.shape[1:], device=x.device, dtype=x.dtype
66
- )
67
-
68
- ctx.input_shape = x.shape
69
- ctx.output_split_sizes = output_split_sizes
70
- ctx.input_split_sizes = input_split_sizes
71
- ctx.group = group
72
- handle = dist.all_to_all_single(
73
- out,
74
- x,
75
- output_split_sizes=output_split_sizes,
76
- input_split_sizes=input_split_sizes,
77
- group=group,
78
- async_op=async_op,
79
- )
80
- return out, handle
81
-
82
- @staticmethod
83
- def backward(ctx, grad, _):
84
- if ctx.needs_input_grad[0]:
85
- out = torch.empty(
86
- ctx.input_shape,
87
- device=grad.device,
88
- dtype=grad.dtype,
89
- )
90
- dist.all_to_all_single(
91
- out,
92
- grad,
93
- output_split_sizes=ctx.input_split_sizes,
94
- input_split_sizes=ctx.output_split_sizes,
95
- group=ctx.group,
96
- )
97
- return out, None, None, None, None
98
- return None, None, None, None, None
99
-
100
-
101
- def all_to_all(x, output_split_sizes, input_split_sizes, group, async_op=False):
102
- return AllToAllOp.apply(
103
- x,
104
- output_split_sizes,
105
- input_split_sizes,
106
- group,
107
- async_op,
108
- )
109
-
110
 
111
  _ALL_TO_ALL_BENCHMARK = (
112
  (8, 1024),
113
- # (16, 1024),
114
- # (32, 1024),
115
- # (64, 1024),
116
- # (128, 1024),
117
- # (256, 1024),
118
- # (512, 1024),
119
- # (1024, 1024),
120
- # (2 * 1024, 1024),
121
- # (4 * 1024, 1024),
122
- # (8 * 1024, 1024),
123
- # (16 * 1024, 1024),
124
- # (32 * 1024, 1024),
125
- # (64 * 1024, 1024),
126
- # (128 * 1024, 1024),
127
- # (256 * 1024, 1024),
128
- # (512 * 1024, 1024),
129
- # (1024 * 1024, 1024),
130
  )
131
 
132
 
@@ -145,12 +47,10 @@ def benchmark_all_to_all(group, sl, hs):
145
  def benchmark():
146
  return all_to_all(x, send_recv_sizes, send_recv_sizes, group)
147
 
148
- # time, std = benchmark_util.benchmark_function(benchmark)
149
- time, std = benchmark_function(benchmark)
150
 
151
  if dist.get_rank(group) == 0:
152
- log_benchmark('All-To-All', details, time, std)
153
- # benchmark_util.log_benchmark('All-To-All', details, time, std)
154
 
155
 
156
  if __name__ == '__main__':
 
7
  # from megablocks import benchmark_util
8
  # from megablocks.layers.all_to_all import all_to_all
9
 
10
+ from .. import benchmark_util
11
+ from .._layers.all_to_all import all_to_all

12
 
13
  _ALL_TO_ALL_BENCHMARK = (
14
  (8, 1024),
15
+ (16, 1024),
16
+ (32, 1024),
17
+ (64, 1024),
18
+ (128, 1024),
19
+ (256, 1024),
20
+ (512, 1024),
21
+ (1024, 1024),
22
+ (2 * 1024, 1024),
23
+ (4 * 1024, 1024),
24
+ (8 * 1024, 1024),
25
+ (16 * 1024, 1024),
26
+ (32 * 1024, 1024),
27
+ (64 * 1024, 1024),
28
+ (128 * 1024, 1024),
29
+ (256 * 1024, 1024),
30
+ (512 * 1024, 1024),
31
+ (1024 * 1024, 1024),
32
  )
33
 
34
 
 
47
  def benchmark():
48
  return all_to_all(x, send_recv_sizes, send_recv_sizes, group)
49
 
50
+ time, std = benchmark_util.benchmark_function(benchmark)
 
51
 
52
  if dist.get_rank(group) == 0:
53
+ benchmark_util.log_benchmark('All-To-All', details, time, std)
 
54
 
55
 
56
  if __name__ == '__main__':
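Note: parallel_forward_once above launches the token exchange with async_op=True and only blocks once the result is needed. A usage sketch of that pattern (the handle.wait() placement is an assumption about the collapsed region of the diff):

from megablocks._layers.all_to_all import all_to_all  # path as used in this build

def exchange_tokens(x, send_counts, recv_counts, group):
    # Kick off the cross-device permutation without blocking the host.
    parallel_x, handle = all_to_all(x, recv_counts, send_counts, group, async_op=True)
    # ... independent local work (e.g. computing bins) can overlap here ...
    handle.wait()  # block before touching parallel_x
    return parallel_x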
build/torch27-cxx11-cu126-x86_64-linux/megablocks/_megablocks_13afbbe_dirty.abi3.so DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:b204da58db0f8be45dda62abd98b74a8e60f1f983bfc6a128c74ff66f67cf502
3
- size 11931112
 
 
 
 
build/{torch26-cxx11-cu126-x86_64-linux/megablocks/_megablocks_13afbbe_dirty.abi3.so → torch27-cxx11-cu126-x86_64-linux/megablocks/_megablocks_e47036a.abi3.so} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:516c5026180d4a8d013c500ed284a60ecbed4bc6c9dc084b838913f40327d1a6
3
  size 11931080
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8122201f56a5575ec964e0450cac06affb44498cbf5b6b32870676a436821c15
3
  size 11931080
build/torch27-cxx11-cu126-x86_64-linux/megablocks/_ops.py CHANGED
@@ -1,9 +1,9 @@
1
  import torch
2
- from . import _megablocks_13afbbe_dirty
3
- ops = torch.ops._megablocks_13afbbe_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
- return f"_megablocks_13afbbe_dirty::{op_name}"
 
1
  import torch
2
+ from . import _megablocks_e47036a
3
+ ops = torch.ops._megablocks_e47036a
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
+ return f"_megablocks_e47036a::{op_name}"
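Note: the ops handle resolved here is what layers.py calls into (e.g. ops.inclusive_cumsum, ops.sort, ops.repeat). A stand-in sketch of the binning step, using torch.cumsum in place of the custom kernel (the equivalence is an assumption for illustration only):

import torch

# Token counts per expert -> inclusive prefix sums used as bin boundaries.
tokens_per_expert = torch.tensor([3, 0, 5, 2], dtype=torch.int)
bins = torch.cumsum(tokens_per_expert, dim=0)          # stand-in for ops.inclusive_cumsum(..., 0)
bins = bins.view(1) if not len(bins.size()) else bins  # same 0-dim guard as in layers.py
print(bins.tolist())  # [3, 3, 8, 10]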
build/torch27-cxx11-cu126-x86_64-linux/megablocks/layers.py CHANGED
@@ -333,7 +333,6 @@ def permute_and_compute(
333
  gradient_scale,
334
  alpha,
335
  ):
336
- """Permute tokens and compute expert outputs."""
337
  # Route tokens to experts
338
  x = x.view(-1, x.shape[-1])
339
 
@@ -367,6 +366,7 @@ def forward_once(
367
  expert_parallel_group: int = None,
368
  moe_capacity_factor: float = 1.0,
369
  moe_expert_model_parallelism: bool = False,
 
370
  ):
371
  # x: [sl, bs, hs]
372
  # expert_weights: [sl * bs, top-k]
@@ -430,11 +430,15 @@ def parallel_forward_once(
430
  moe_capacity_factor: float = 1.0,
431
  moe_expert_model_parallelism: bool = True,
432
  hidden_size: int = 1152,
 
433
  ):
434
  # Flatten inputs
435
  expert_weights = expert_weights.flatten()
436
  top_experts = top_experts.flatten()
437
 
 
 
 
438
  with torch.no_grad():
439
  # Step 1: Local permutation setup
440
  indices, bin_ids, bins, tokens_per_expert = indices_and_bins(
@@ -455,9 +459,7 @@ def parallel_forward_once(
455
 
456
  # Exchange token counts across devices
457
  parallel_tokens_per_expert = torch.empty_like(repeated_tokens_per_expert)
458
- # print("world_size:", world_size)
459
- # print("experts_per_rank_val:", experts_per_rank_val)
460
-
461
  # Ensure CUB knows which device to use
462
  tpe_handle = dist.all_to_all_single(
463
  parallel_tokens_per_expert,
@@ -493,20 +495,13 @@ def parallel_forward_once(
493
  x = ops.repeat(x, (hidden_sharding_deg, 1))
494
 
495
  # Cross-device token exchange
496
- parallel_x, parallel_x_handle = ops.all_to_all(
497
- x,
498
- recv_counts,
499
- send_counts,
500
- expert_parallel_group,
501
- async_op=True
502
  )
503
 
504
  with torch.no_grad():
505
  # Step 4: Setup for local expert computation
506
- replicate_bins = ops.inclusive_cumsum(
507
- parallel_tokens_per_expert.flatten(),
508
- 0
509
- )
510
  replicate_bins = (
511
  replicate_bins.view(1) if not len(replicate_bins.size()) else replicate_bins
512
  )
@@ -528,7 +523,7 @@ def parallel_forward_once(
528
 
529
  # Sort tokens by expert assignment
530
  parallel_bin_ids, parallel_indices = ops.sort(
531
- parallel_top_expert,
532
  sort_end_bit,
533
  )
534
 
@@ -536,10 +531,7 @@ def parallel_forward_once(
536
  parallel_tokens_per_expert = parallel_tokens_per_expert.sum(
537
  dim=0, dtype=torch.int
538
  )
539
- parallel_bins = ops.inclusive_cumsum(
540
- parallel_tokens_per_expert,
541
- 0
542
- )
543
  parallel_bins = (
544
  parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins
545
  )
@@ -558,10 +550,7 @@ def parallel_forward_once(
558
 
559
  # Locally permute the tokens and perform the expert computation.
560
  # Block to make sure that the cross-device permutation is complete.
561
- # if self.args.mlp_impl == 'grouped':
562
-
563
- # TODO: dont always assume grouped MLP
564
- if True:
565
  # GroupedMLP requires counts on CPU. We can use the tensor already
566
  # moved to CPU for the prior all_to_all, which avoids an extra
567
  # device synchronization.
@@ -591,7 +580,9 @@ def parallel_forward_once(
591
  )
592
 
593
  # Step 6: Reverse communication - send results back
594
- x, _ = ops.all_to_all(parallel_x, send_counts, recv_counts, expert_parallel_group)
 
 
595
 
596
  # Step 7: Reduce across hidden sharding dimension
597
  shape = (hidden_sharding_deg, -1, hidden_size)
@@ -603,139 +594,135 @@ def parallel_forward_once(
603
  return x, tokens_per_expert.flatten()
604
 
605
 
606
- class MyReplacementLayer(torch.nn.Module):
607
- def forward(
608
- x: torch.Tensor,
609
- router_weight: torch.Tensor,
610
- moe_top_k: int,
611
- moe_num_experts: int,
612
- moe_jitter_eps: float = None,
613
- moe_normalize_expert_weights: int = None,
614
- uniform_expert_assignment: bool = False,
615
- training: bool = False,
616
- w1: torch.Tensor = None,
617
- w2: torch.Tensor = None,
618
- w1_bias: torch.Tensor = None,
619
- w2_bias: torch.Tensor = None,
620
- gradient_scale: Optional[float] = None,
621
- alpha: float = 1.702,
622
- sort_end_bit: int = 0,
623
- expert_parallel_group: torch.distributed.ProcessGroup = None,
624
- moe_capacity_factor: float = 1.0,
625
- moe_expert_model_parallelism: bool = False,
626
- forward_fn: Any = None,
627
- hidden_size: int = None, # Required for parallel forward
628
- ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
629
-
630
- # Route tokens to experts
631
- logits, expert_weights, expert_indices = route_tokens(
632
- x,
633
- router_weight,
634
- moe_top_k,
635
- moe_num_experts,
636
- moe_jitter_eps,
637
- moe_normalize_expert_weights,
638
- uniform_expert_assignment,
639
- training,
640
- )
641
 
642
- # Create router scores for output
643
- router_scores = (
644
- torch.zeros_like(logits)
645
- .scatter_(1, expert_indices, expert_weights)
646
- .transpose(0, 1)
647
- )
 
 
 
 
 
648
 
649
- in_shape = x.size()
650
-
651
- # Prepare forward function arguments
652
- forward_args = {
653
- "x": x,
654
- "expert_weights": expert_weights,
655
- "top_experts": expert_indices,
656
- "w1": w1,
657
- "w2": w2,
658
- "w1_bias": w1_bias,
659
- "w2_bias": w2_bias,
660
- "gradient_scale": gradient_scale,
661
- "alpha": alpha,
662
- "sort_end_bit": sort_end_bit,
663
- "top_k": moe_top_k,
664
- "num_experts": moe_num_experts,
665
- "expert_parallel_group": expert_parallel_group,
666
- "moe_capacity_factor": moe_capacity_factor,
667
- "moe_expert_model_parallelism": moe_expert_model_parallelism,
668
- }
669
-
670
- # Add hidden_size for parallel forward
671
- if moe_expert_model_parallelism and hidden_size is not None:
672
- forward_args["hidden_size"] = hidden_size
673
- elif moe_expert_model_parallelism and hidden_size is None:
674
- # Infer hidden_size from input shape
675
- forward_args["hidden_size"] = x.shape[-1]
676
-
677
- # Compute expert outputs
678
- x, tokens_per_expert = forward_fn(**forward_args)
679
-
680
- # Save load balancing loss if needed
681
- moe_loss_weight = 0.0 # Can be made configurable
682
- if training and moe_loss_weight > 0:
683
- save_load_balancing_loss((tokens_per_expert, logits))
684
-
685
- # Restore original shape
686
- x = x.view(in_shape)
687
-
688
- return x, expert_weights, router_scores
 
 
 
 
 
 
 
 
689
 
690
 
691
  class MegaBlocksMoeMLP(torch.nn.Module):
692
 
693
- def forward(
694
- self,
695
- x: torch.Tensor,
696
- ) -> torch.Tensor:
697
- router_weight = self.router.weight
698
- moe_top_k = 4
699
- moe_num_experts = 128
700
- w1 = self.experts.gate_up_proj.data
701
- w2 = self.experts.down_proj.data
702
- w1_bias = self.experts.gate_up_proj_bias.data
703
- w2_bias = self.experts.down_proj_bias.data
704
-
705
- # check if the expert_parallel_group attribute is set
706
- if hasattr(self, "expert_parallel_group"):
707
- expert_parallel_group = self.expert_parallel_group
708
- moe_expert_model_parallelism = True
709
- forward_fn = parallel_forward_once
710
- else:
711
- expert_parallel_group = None
712
- moe_expert_model_parallelism = False
713
- forward_fn = forward_once
714
 
 
 
 
715
  sort_end_bit = max(
716
  int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1
717
  )
718
- hidden_size = self.experts.hidden_size
719
- output, expert_weights_out, router_scores = MyReplacementLayer.forward(
 
720
  x=x,
721
- router_weight=router_weight,
722
  moe_top_k=moe_top_k,
723
  moe_num_experts=moe_num_experts,
724
- moe_jitter_eps=None,
725
- moe_normalize_expert_weights=None,
726
- uniform_expert_assignment=False,
727
- training=False,
728
- w1=w1,
729
- w2=w2,
730
- w1_bias=w1_bias,
731
- w2_bias=w2_bias,
732
- gradient_scale=None,
733
- alpha=1.702,
734
  sort_end_bit=sort_end_bit,
735
  expert_parallel_group=expert_parallel_group,
736
- moe_capacity_factor=1.0,
737
- moe_expert_model_parallelism=moe_expert_model_parallelism,
738
  forward_fn=forward_fn,
739
- hidden_size=hidden_size,
 
740
  )
741
  return output, expert_weights_out
 
333
  gradient_scale,
334
  alpha,
335
  ):
 
336
  # Route tokens to experts
337
  x = x.view(-1, x.shape[-1])
338
 
 
366
  expert_parallel_group: int = None,
367
  moe_capacity_factor: float = 1.0,
368
  moe_expert_model_parallelism: bool = False,
369
+ mlp_impl: Optional[str] = None,
370
  ):
371
  # x: [sl, bs, hs]
372
  # expert_weights: [sl * bs, top-k]
 
430
  moe_capacity_factor: float = 1.0,
431
  moe_expert_model_parallelism: bool = True,
432
  hidden_size: int = 1152,
433
+ mlp_impl: Optional[str] = "grouped",
434
  ):
435
  # Flatten inputs
436
  expert_weights = expert_weights.flatten()
437
  top_experts = top_experts.flatten()
438
 
439
+ # TODO: remove debugging var
440
+ # my_rank = dist.get_rank(expert_parallel_group) if expert_parallel_group else 0
441
+
442
  with torch.no_grad():
443
  # Step 1: Local permutation setup
444
  indices, bin_ids, bins, tokens_per_expert = indices_and_bins(
 
459
 
460
  # Exchange token counts across devices
461
  parallel_tokens_per_expert = torch.empty_like(repeated_tokens_per_expert)
462
+
 
 
463
  # Ensure CUB knows which device to use
464
  tpe_handle = dist.all_to_all_single(
465
  parallel_tokens_per_expert,
 
495
  x = ops.repeat(x, (hidden_sharding_deg, 1))
496
 
497
  # Cross-device token exchange
498
+ parallel_x, parallel_x_handle = _layers.all_to_all.all_to_all(
499
+ x, recv_counts, send_counts, expert_parallel_group, async_op=True
 
 
 
 
500
  )
501
 
502
  with torch.no_grad():
503
  # Step 4: Setup for local expert computation
504
+ replicate_bins = ops.inclusive_cumsum(parallel_tokens_per_expert.flatten(), 0)
 
 
 
505
  replicate_bins = (
506
  replicate_bins.view(1) if not len(replicate_bins.size()) else replicate_bins
507
  )
 
523
 
524
  # Sort tokens by expert assignment
525
  parallel_bin_ids, parallel_indices = ops.sort(
526
+ parallel_top_expert,
527
  sort_end_bit,
528
  )
529
 
 
531
  parallel_tokens_per_expert = parallel_tokens_per_expert.sum(
532
  dim=0, dtype=torch.int
533
  )
534
+ parallel_bins = ops.inclusive_cumsum(parallel_tokens_per_expert, 0)
 
 
 
535
  parallel_bins = (
536
  parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins
537
  )
 
550
 
551
  # Locally permute the tokens and perform the expert computation.
552
  # Block to make sure that the cross-device permutation is complete.
553
+ if mlp_impl == "grouped":
 
 
 
554
  # GroupedMLP requires counts on CPU. We can use the tensor already
555
  # moved to CPU for the prior all_to_all, which avoids an extra
556
  # device synchronization.
 
580
  )
581
 
582
  # Step 6: Reverse communication - send results back
583
+ x, _ = _layers.all_to_all.all_to_all(
584
+ parallel_x, send_counts, recv_counts, expert_parallel_group
585
+ )
586
 
587
  # Step 7: Reduce across hidden sharding dimension
588
  shape = (hidden_sharding_deg, -1, hidden_size)
 
594
  return x, tokens_per_expert.flatten()
595
 
596
 
597
+ def moe_forward(
598
+ x: torch.Tensor,
599
+ router_weight: torch.Tensor,
600
+ moe_top_k: int,
601
+ moe_num_experts: int,
602
+ moe_jitter_eps: float = None,
603
+ moe_normalize_expert_weights: int = None,
604
+ uniform_expert_assignment: bool = False,
605
+ training: bool = False,
606
+ w1: torch.Tensor = None,
607
+ w2: torch.Tensor = None,
608
+ w1_bias: torch.Tensor = None,
609
+ w2_bias: torch.Tensor = None,
610
+ gradient_scale: Optional[float] = None,
611
+ alpha: float = 1.702,
612
+ sort_end_bit: int = 0,
613
+ expert_parallel_group: torch.distributed.ProcessGroup = None,
614
+ moe_capacity_factor: float = 1.0,
615
+ moe_expert_model_parallelism: bool = False,
616
+ forward_fn: Any = None,
617
+ hidden_size: int = None,
618
+ mlp_impl: str = "grouped",
619
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
 
 
 
 
 
 
 
 
 
 
 
 
620
 
621
+ # Route tokens to experts
622
+ logits, expert_weights, expert_indices = route_tokens(
623
+ x,
624
+ router_weight,
625
+ moe_top_k,
626
+ moe_num_experts,
627
+ moe_jitter_eps,
628
+ moe_normalize_expert_weights,
629
+ uniform_expert_assignment,
630
+ training,
631
+ )
632
 
633
+ # Create router scores for output
634
+ router_scores = (
635
+ torch.zeros_like(logits)
636
+ .scatter_(1, expert_indices, expert_weights)
637
+ .transpose(0, 1)
638
+ )
639
+
640
+ in_shape = x.size()
641
+
642
+ # Prepare forward function arguments
643
+ forward_args = {
644
+ "x": x,
645
+ "expert_weights": expert_weights,
646
+ "top_experts": expert_indices,
647
+ "w1": w1,
648
+ "w2": w2,
649
+ "w1_bias": w1_bias,
650
+ "w2_bias": w2_bias,
651
+ "gradient_scale": gradient_scale,
652
+ "alpha": alpha,
653
+ "sort_end_bit": sort_end_bit,
654
+ "top_k": moe_top_k,
655
+ "num_experts": moe_num_experts,
656
+ "expert_parallel_group": expert_parallel_group,
657
+ "moe_capacity_factor": moe_capacity_factor,
658
+ "moe_expert_model_parallelism": moe_expert_model_parallelism,
659
+ "mlp_impl": mlp_impl,
660
+ }
661
+
662
+ # Add hidden_size for parallel forward
663
+ if moe_expert_model_parallelism and hidden_size is not None:
664
+ forward_args["hidden_size"] = hidden_size
665
+ elif moe_expert_model_parallelism and hidden_size is None:
666
+ # Infer hidden_size from input shape
667
+ forward_args["hidden_size"] = x.shape[-1]
668
+
669
+ # Compute expert outputs
670
+ x, tokens_per_expert = forward_fn(**forward_args)
671
+
672
+ # Save load balancing loss if needed
673
+ moe_loss_weight = 0.0 # Can be made configurable
674
+ if training and moe_loss_weight > 0:
675
+ save_load_balancing_loss((tokens_per_expert, logits))
676
+
677
+ # Restore original shape
678
+ x = x.view(in_shape)
679
+
680
+ return x, expert_weights, router_scores
681
 
682
 
683
  class MegaBlocksMoeMLP(torch.nn.Module):
684
 
685
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
686
+ moe_top_k = getattr(self, "moe_top_k", 4)
687
+ moe_num_experts = getattr(self, "moe_num_experts", 128)
688
+ gradient_scale = getattr(self, "gradient_scale", None)
689
+ alpha = getattr(self, "alpha", 1.702)
690
+ moe_capacity_factor = getattr(self, "moe_capacity_factor", 1.0)
691
+ moe_jitter_eps = getattr(self, "moe_jitter_eps", None)
692
+ moe_normalize_expert_weights = getattr(
693
+ self, "moe_normalize_expert_weights", None
694
+ )
695
+ uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)
 
 
 
 
 
 
 
 
 
 
696
 
697
+ has_parallel = hasattr(self, "expert_parallel_group")
698
+ expert_parallel_group = getattr(self, "expert_parallel_group", None)
699
+ forward_fn = parallel_forward_once if has_parallel else forward_once
700
  sort_end_bit = max(
701
  int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1
702
  )
703
+ mlp_impl = getattr(self, "mlp_impl", "grouped") # or sparse
704
+
705
+ output, expert_weights_out, _ = moe_forward(
706
  x=x,
707
+ router_weight=self.router.weight,
708
  moe_top_k=moe_top_k,
709
  moe_num_experts=moe_num_experts,
710
+ moe_jitter_eps=moe_jitter_eps,
711
+ moe_normalize_expert_weights=moe_normalize_expert_weights,
712
+ uniform_expert_assignment=uniform_expert_assignment,
713
+ training=self.training,
714
+ w1=self.experts.gate_up_proj,
715
+ w2=self.experts.down_proj,
716
+ w1_bias=self.experts.gate_up_proj_bias,
717
+ w2_bias=self.experts.down_proj_bias,
718
+ gradient_scale=gradient_scale,
719
+ alpha=alpha,
720
  sort_end_bit=sort_end_bit,
721
  expert_parallel_group=expert_parallel_group,
722
+ moe_capacity_factor=moe_capacity_factor,
723
+ moe_expert_model_parallelism=has_parallel,
724
  forward_fn=forward_fn,
725
+ hidden_size=self.experts.hidden_size,
726
+ mlp_impl=mlp_impl,
727
  )
728
  return output, expert_weights_out
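A minimal, self-contained sketch of the configuration pattern used by the new MegaBlocksMoeMLP.forward above: every MoE hyperparameter is read with getattr and a default, so it can be attached to the module after construction; the sort_end_bit derivation is reproduced as-is. MoEConfigHolder below is a hypothetical stand-in for illustration, not a megablocks class.

```python
import torch

class MoEConfigHolder(torch.nn.Module):
    # Hypothetical stand-in; only demonstrates the getattr-with-default
    # pattern and the sort_end_bit derivation used in the forward above.
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        moe_top_k = getattr(self, "moe_top_k", 4)
        moe_num_experts = getattr(self, "moe_num_experts", 128)
        # Number of bits needed to radix-sort expert ids (at least 1).
        sort_end_bit = max(
            int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1
        )
        print(moe_top_k, moe_num_experts, sort_end_bit)
        return x

m = MoEConfigHolder()
m.moe_num_experts = 64  # attributes may be set after construction
m.moe_top_k = 2
_ = m(torch.randn(4, 8))  # prints: 2 64 6
```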
build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/all_to_all_benchmark.py CHANGED
@@ -7,126 +7,28 @@ import torch.distributed as dist
7
  # from megablocks import benchmark_util
8
  # from megablocks.layers.all_to_all import all_to_all
9
 
10
- # from .. import benchmark_util
11
-
12
- # Copyright 2024 Databricks
13
- # SPDX-License-Identifier: Apache-2.0
14
-
15
- import numpy as np
16
- import torch
17
-
18
-
19
- def log_benchmark(name, arguments, time, std):
20
- print("=" * 60)
21
- print(f"{name} Benchmark")
22
- print("Benchmark Parameters:")
23
- for key, value in arguments.items():
24
- print(f"{key} = {value}")
25
- print("Results:")
26
- print("mean time = {:.3f}ms, std time = {:.3f}ms".format(time, std))
27
- print("=" * 60)
28
-
29
-
30
- def benchmark_function(fn, iterations=100, warmup=10):
31
- print(f"Benchmarking {fn.__name__} with {iterations} iterations and {warmup} warmup iterations")
32
- # Warmup iterations.
33
- for _ in range(warmup):
34
- fn()
35
-
36
- times = []
37
- print(f"Running {iterations} iterations...")
38
- for i in range(iterations):
39
- start = torch.cuda.Event(enable_timing=True)
40
- end = torch.cuda.Event(enable_timing=True)
41
-
42
- start.record()
43
- fn()
44
- end.record()
45
-
46
- torch.cuda.synchronize()
47
- times.append(start.elapsed_time(end))
48
- return np.mean(times), np.std(times)
49
-
50
-
51
- # from .._layers.all_to_all import all_to_all
52
-
53
- # Copyright 2024 Databricks
54
- # SPDX-License-Identifier: Apache-2.0
55
-
56
- import torch
57
- import torch.distributed as dist
58
-
59
-
60
- class AllToAllOp(torch.autograd.Function):
61
-
62
- @staticmethod
63
- def forward(ctx, x, output_split_sizes, input_split_sizes, group, async_op):
64
- out = torch.empty(
65
- (sum(output_split_sizes),) + x.shape[1:], device=x.device, dtype=x.dtype
66
- )
67
-
68
- ctx.input_shape = x.shape
69
- ctx.output_split_sizes = output_split_sizes
70
- ctx.input_split_sizes = input_split_sizes
71
- ctx.group = group
72
- handle = dist.all_to_all_single(
73
- out,
74
- x,
75
- output_split_sizes=output_split_sizes,
76
- input_split_sizes=input_split_sizes,
77
- group=group,
78
- async_op=async_op,
79
- )
80
- return out, handle
81
-
82
- @staticmethod
83
- def backward(ctx, grad, _):
84
- if ctx.needs_input_grad[0]:
85
- out = torch.empty(
86
- ctx.input_shape,
87
- device=grad.device,
88
- dtype=grad.dtype,
89
- )
90
- dist.all_to_all_single(
91
- out,
92
- grad,
93
- output_split_sizes=ctx.input_split_sizes,
94
- input_split_sizes=ctx.output_split_sizes,
95
- group=ctx.group,
96
- )
97
- return out, None, None, None, None
98
- return None, None, None, None, None
99
-
100
-
101
- def all_to_all(x, output_split_sizes, input_split_sizes, group, async_op=False):
102
- return AllToAllOp.apply(
103
- x,
104
- output_split_sizes,
105
- input_split_sizes,
106
- group,
107
- async_op,
108
- )
109
-
110
 
111
  _ALL_TO_ALL_BENCHMARK = (
112
  (8, 1024),
113
- # (16, 1024),
114
- # (32, 1024),
115
- # (64, 1024),
116
- # (128, 1024),
117
- # (256, 1024),
118
- # (512, 1024),
119
- # (1024, 1024),
120
- # (2 * 1024, 1024),
121
- # (4 * 1024, 1024),
122
- # (8 * 1024, 1024),
123
- # (16 * 1024, 1024),
124
- # (32 * 1024, 1024),
125
- # (64 * 1024, 1024),
126
- # (128 * 1024, 1024),
127
- # (256 * 1024, 1024),
128
- # (512 * 1024, 1024),
129
- # (1024 * 1024, 1024),
130
  )
131
 
132
 
@@ -145,12 +47,10 @@ def benchmark_all_to_all(group, sl, hs):
145
  def benchmark():
146
  return all_to_all(x, send_recv_sizes, send_recv_sizes, group)
147
 
148
- # time, std = benchmark_util.benchmark_function(benchmark)
149
- time, std = benchmark_function(benchmark)
150
 
151
  if dist.get_rank(group) == 0:
152
- log_benchmark('All-To-All', details, time, std)
153
- # benchmark_util.log_benchmark('All-To-All', details, time, std)
154
 
155
 
156
  if __name__ == '__main__':
 
7
  # from megablocks import benchmark_util
8
  # from megablocks.layers.all_to_all import all_to_all
9
 
10
+ from .. import benchmark_util
11
+ from .._layers.all_to_all import all_to_all
12
 
13
  _ALL_TO_ALL_BENCHMARK = (
14
  (8, 1024),
15
+ (16, 1024),
16
+ (32, 1024),
17
+ (64, 1024),
18
+ (128, 1024),
19
+ (256, 1024),
20
+ (512, 1024),
21
+ (1024, 1024),
22
+ (2 * 1024, 1024),
23
+ (4 * 1024, 1024),
24
+ (8 * 1024, 1024),
25
+ (16 * 1024, 1024),
26
+ (32 * 1024, 1024),
27
+ (64 * 1024, 1024),
28
+ (128 * 1024, 1024),
29
+ (256 * 1024, 1024),
30
+ (512 * 1024, 1024),
31
+ (1024 * 1024, 1024),
32
  )
33
 
34
 
 
47
  def benchmark():
48
  return all_to_all(x, send_recv_sizes, send_recv_sizes, group)
49
 
50
+ time, std = benchmark_util.benchmark_function(benchmark)
 
51
 
52
  if dist.get_rank(group) == 0:
53
+ benchmark_util.log_benchmark('All-To-All', details, time, std)
 
54
 
55
 
56
  if __name__ == '__main__':
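The hunk above drops the inlined CUDA-event timing helpers in favor of megablocks' own benchmark_util. For reference, a self-contained sketch of that event-based timing technique; the names below are local to this sketch, not the benchmark_util API.

```python
import numpy as np
import torch

def time_cuda(fn, iterations=100, warmup=10):
    # Warmup so allocator/JIT effects do not skew the measurement.
    for _ in range(warmup):
        fn()
    times = []
    for _ in range(iterations):
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        fn()
        end.record()
        torch.cuda.synchronize()               # wait so elapsed_time() is valid
        times.append(start.elapsed_time(end))  # milliseconds
    return float(np.mean(times)), float(np.std(times))

if torch.cuda.is_available():
    x = torch.randn(1024, 1024, device="cuda")
    mean_ms, std_ms = time_cuda(lambda: x @ x)
    print(f"mean = {mean_ms:.3f} ms, std = {std_ms:.3f} ms")
```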
build/torch27-cxx11-cu128-x86_64-linux/megablocks/_megablocks_13afbbe_dirty.abi3.so DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:f861a8bffedbbf14341d39355f3f43a7c24fee2b99bb9ea7b3a2b9ad21c7ee28
3
- size 17892656
build/torch27-cxx11-cu128-x86_64-linux/megablocks/_megablocks_e47036a.abi3.so ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7cb62a4fdccfdac00069689f7a08d2fba56393adb0ca9e8cb7c085f6db919d55
3
+ size 17892624
build/torch27-cxx11-cu128-x86_64-linux/megablocks/_ops.py CHANGED
@@ -1,9 +1,9 @@
1
  import torch
2
- from . import _megablocks_13afbbe_dirty
3
- ops = torch.ops._megablocks_13afbbe_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
- return f"_megablocks_13afbbe_dirty::{op_name}"
 
1
  import torch
2
+ from . import _megablocks_e47036a
3
+ ops = torch.ops._megablocks_e47036a
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
+ return f"_megablocks_e47036a::{op_name}"
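The _ops.py change only re-points the op namespace at the rebuilt extension (_megablocks_e47036a). A hedged illustration of why ops are addressed through a namespace prefix, using a hypothetical namespace registered from Python rather than the compiled .so:

```python
import torch

# Hypothetical namespace for illustration; the real ops live inside the
# compiled _megablocks_e47036a extension and are loaded from the .so.
lib = torch.library.Library("demo_megablocks_demo0", "DEF")
lib.define("scale(Tensor x, float alpha) -> Tensor")

def _scale_impl(x, alpha):
    return x * alpha

lib.impl("scale", _scale_impl, "CompositeExplicitAutograd")

def add_op_namespace_prefix(op_name: str) -> str:
    # Mirrors the helper in _ops.py, but for the demo namespace above.
    return f"demo_megablocks_demo0::{op_name}"

print(add_op_namespace_prefix("scale"))                       # demo_megablocks_demo0::scale
print(torch.ops.demo_megablocks_demo0.scale(torch.ones(3), 2.0))
```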
build/torch27-cxx11-cu128-x86_64-linux/megablocks/layers.py CHANGED
@@ -333,7 +333,6 @@ def permute_and_compute(
333
  gradient_scale,
334
  alpha,
335
  ):
336
- """Permute tokens and compute expert outputs."""
337
  # Route tokens to experts
338
  x = x.view(-1, x.shape[-1])
339
 
@@ -367,6 +366,7 @@ def forward_once(
367
  expert_parallel_group: int = None,
368
  moe_capacity_factor: float = 1.0,
369
  moe_expert_model_parallelism: bool = False,
 
370
  ):
371
  # x: [sl, bs, hs]
372
  # expert_weights: [sl * bs, top-k]
@@ -430,11 +430,15 @@ def parallel_forward_once(
430
  moe_capacity_factor: float = 1.0,
431
  moe_expert_model_parallelism: bool = True,
432
  hidden_size: int = 1152,
 
433
  ):
434
  # Flatten inputs
435
  expert_weights = expert_weights.flatten()
436
  top_experts = top_experts.flatten()
437
 
 
 
 
438
  with torch.no_grad():
439
  # Step 1: Local permutation setup
440
  indices, bin_ids, bins, tokens_per_expert = indices_and_bins(
@@ -455,9 +459,7 @@ def parallel_forward_once(
455
 
456
  # Exchange token counts across devices
457
  parallel_tokens_per_expert = torch.empty_like(repeated_tokens_per_expert)
458
- # print("world_size:", world_size)
459
- # print("experts_per_rank_val:", experts_per_rank_val)
460
-
461
  # Ensure CUB knows which device to use
462
  tpe_handle = dist.all_to_all_single(
463
  parallel_tokens_per_expert,
@@ -493,20 +495,13 @@ def parallel_forward_once(
493
  x = ops.repeat(x, (hidden_sharding_deg, 1))
494
 
495
  # Cross-device token exchange
496
- parallel_x, parallel_x_handle = ops.all_to_all(
497
- x,
498
- recv_counts,
499
- send_counts,
500
- expert_parallel_group,
501
- async_op=True
502
  )
503
 
504
  with torch.no_grad():
505
  # Step 4: Setup for local expert computation
506
- replicate_bins = ops.inclusive_cumsum(
507
- parallel_tokens_per_expert.flatten(),
508
- 0
509
- )
510
  replicate_bins = (
511
  replicate_bins.view(1) if not len(replicate_bins.size()) else replicate_bins
512
  )
@@ -528,7 +523,7 @@ def parallel_forward_once(
528
 
529
  # Sort tokens by expert assignment
530
  parallel_bin_ids, parallel_indices = ops.sort(
531
- parallel_top_expert,
532
  sort_end_bit,
533
  )
534
 
@@ -536,10 +531,7 @@ def parallel_forward_once(
536
  parallel_tokens_per_expert = parallel_tokens_per_expert.sum(
537
  dim=0, dtype=torch.int
538
  )
539
- parallel_bins = ops.inclusive_cumsum(
540
- parallel_tokens_per_expert,
541
- 0
542
- )
543
  parallel_bins = (
544
  parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins
545
  )
@@ -558,10 +550,7 @@ def parallel_forward_once(
558
 
559
  # Locally permute the tokens and perform the expert computation.
560
  # Block to make sure that the cross-device permutation is complete.
561
- # if self.args.mlp_impl == 'grouped':
562
-
563
- # TODO: dont always assume grouped MLP
564
- if True:
565
  # GroupedMLP requires counts on CPU. We can use the tensor already
566
  # moved to CPU for the prior all_to_all, which avoids an extra
567
  # device synchronization.
@@ -591,7 +580,9 @@ def parallel_forward_once(
591
  )
592
 
593
  # Step 6: Reverse communication - send results back
594
- x, _ = ops.all_to_all(parallel_x, send_counts, recv_counts, expert_parallel_group)
 
 
595
 
596
  # Step 7: Reduce across hidden sharding dimension
597
  shape = (hidden_sharding_deg, -1, hidden_size)
@@ -603,139 +594,135 @@ def parallel_forward_once(
603
  return x, tokens_per_expert.flatten()
604
 
605
 
606
- class MyReplacementLayer(torch.nn.Module):
607
- def forward(
608
- x: torch.Tensor,
609
- router_weight: torch.Tensor,
610
- moe_top_k: int,
611
- moe_num_experts: int,
612
- moe_jitter_eps: float = None,
613
- moe_normalize_expert_weights: int = None,
614
- uniform_expert_assignment: bool = False,
615
- training: bool = False,
616
- w1: torch.Tensor = None,
617
- w2: torch.Tensor = None,
618
- w1_bias: torch.Tensor = None,
619
- w2_bias: torch.Tensor = None,
620
- gradient_scale: Optional[float] = None,
621
- alpha: float = 1.702,
622
- sort_end_bit: int = 0,
623
- expert_parallel_group: torch.distributed.ProcessGroup = None,
624
- moe_capacity_factor: float = 1.0,
625
- moe_expert_model_parallelism: bool = False,
626
- forward_fn: Any = None,
627
- hidden_size: int = None, # Required for parallel forward
628
- ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
629
-
630
- # Route tokens to experts
631
- logits, expert_weights, expert_indices = route_tokens(
632
- x,
633
- router_weight,
634
- moe_top_k,
635
- moe_num_experts,
636
- moe_jitter_eps,
637
- moe_normalize_expert_weights,
638
- uniform_expert_assignment,
639
- training,
640
- )
641
 
642
- # Create router scores for output
643
- router_scores = (
644
- torch.zeros_like(logits)
645
- .scatter_(1, expert_indices, expert_weights)
646
- .transpose(0, 1)
647
- )
 
 
 
 
 
648
 
649
- in_shape = x.size()
650
-
651
- # Prepare forward function arguments
652
- forward_args = {
653
- "x": x,
654
- "expert_weights": expert_weights,
655
- "top_experts": expert_indices,
656
- "w1": w1,
657
- "w2": w2,
658
- "w1_bias": w1_bias,
659
- "w2_bias": w2_bias,
660
- "gradient_scale": gradient_scale,
661
- "alpha": alpha,
662
- "sort_end_bit": sort_end_bit,
663
- "top_k": moe_top_k,
664
- "num_experts": moe_num_experts,
665
- "expert_parallel_group": expert_parallel_group,
666
- "moe_capacity_factor": moe_capacity_factor,
667
- "moe_expert_model_parallelism": moe_expert_model_parallelism,
668
- }
669
-
670
- # Add hidden_size for parallel forward
671
- if moe_expert_model_parallelism and hidden_size is not None:
672
- forward_args["hidden_size"] = hidden_size
673
- elif moe_expert_model_parallelism and hidden_size is None:
674
- # Infer hidden_size from input shape
675
- forward_args["hidden_size"] = x.shape[-1]
676
-
677
- # Compute expert outputs
678
- x, tokens_per_expert = forward_fn(**forward_args)
679
-
680
- # Save load balancing loss if needed
681
- moe_loss_weight = 0.0 # Can be made configurable
682
- if training and moe_loss_weight > 0:
683
- save_load_balancing_loss((tokens_per_expert, logits))
684
-
685
- # Restore original shape
686
- x = x.view(in_shape)
687
-
688
- return x, expert_weights, router_scores
689
 
690
 
691
  class MegaBlocksMoeMLP(torch.nn.Module):
692
 
693
- def forward(
694
- self,
695
- x: torch.Tensor,
696
- ) -> torch.Tensor:
697
- router_weight = self.router.weight
698
- moe_top_k = 4
699
- moe_num_experts = 128
700
- w1 = self.experts.gate_up_proj.data
701
- w2 = self.experts.down_proj.data
702
- w1_bias = self.experts.gate_up_proj_bias.data
703
- w2_bias = self.experts.down_proj_bias.data
704
-
705
- # check if the expert_parallel_group attribute is set
706
- if hasattr(self, "expert_parallel_group"):
707
- expert_parallel_group = self.expert_parallel_group
708
- moe_expert_model_parallelism = True
709
- forward_fn = parallel_forward_once
710
- else:
711
- expert_parallel_group = None
712
- moe_expert_model_parallelism = False
713
- forward_fn = forward_once
714
 
 
 
 
715
  sort_end_bit = max(
716
  int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1
717
  )
718
- hidden_size = self.experts.hidden_size
719
- output, expert_weights_out, router_scores = MyReplacementLayer.forward(
 
720
  x=x,
721
- router_weight=router_weight,
722
  moe_top_k=moe_top_k,
723
  moe_num_experts=moe_num_experts,
724
- moe_jitter_eps=None,
725
- moe_normalize_expert_weights=None,
726
- uniform_expert_assignment=False,
727
- training=False,
728
- w1=w1,
729
- w2=w2,
730
- w1_bias=w1_bias,
731
- w2_bias=w2_bias,
732
- gradient_scale=None,
733
- alpha=1.702,
734
  sort_end_bit=sort_end_bit,
735
  expert_parallel_group=expert_parallel_group,
736
- moe_capacity_factor=1.0,
737
- moe_expert_model_parallelism=moe_expert_model_parallelism,
738
  forward_fn=forward_fn,
739
- hidden_size=hidden_size,
 
740
  )
741
  return output, expert_weights_out
 
333
  gradient_scale,
334
  alpha,
335
  ):
 
336
  # Route tokens to experts
337
  x = x.view(-1, x.shape[-1])
338
 
 
366
  expert_parallel_group: int = None,
367
  moe_capacity_factor: float = 1.0,
368
  moe_expert_model_parallelism: bool = False,
369
+ mlp_impl: Optional[str] = None,
370
  ):
371
  # x: [sl, bs, hs]
372
  # expert_weights: [sl * bs, top-k]
 
430
  moe_capacity_factor: float = 1.0,
431
  moe_expert_model_parallelism: bool = True,
432
  hidden_size: int = 1152,
433
+ mlp_impl: Optional[str] = "grouped",
434
  ):
435
  # Flatten inputs
436
  expert_weights = expert_weights.flatten()
437
  top_experts = top_experts.flatten()
438
 
439
+ # TODO: remove debugging var
440
+ # my_rank = dist.get_rank(expert_parallel_group) if expert_parallel_group else 0
441
+
442
  with torch.no_grad():
443
  # Step 1: Local permutation setup
444
  indices, bin_ids, bins, tokens_per_expert = indices_and_bins(
 
459
 
460
  # Exchange token counts across devices
461
  parallel_tokens_per_expert = torch.empty_like(repeated_tokens_per_expert)
462
+
 
 
463
  # Ensure CUB knows which device to use
464
  tpe_handle = dist.all_to_all_single(
465
  parallel_tokens_per_expert,
 
495
  x = ops.repeat(x, (hidden_sharding_deg, 1))
496
 
497
  # Cross-device token exchange
498
+ parallel_x, parallel_x_handle = _layers.all_to_all.all_to_all(
499
+ x, recv_counts, send_counts, expert_parallel_group, async_op=True
 
 
 
 
500
  )
501
 
502
  with torch.no_grad():
503
  # Step 4: Setup for local expert computation
504
+ replicate_bins = ops.inclusive_cumsum(parallel_tokens_per_expert.flatten(), 0)
 
 
 
505
  replicate_bins = (
506
  replicate_bins.view(1) if not len(replicate_bins.size()) else replicate_bins
507
  )
 
523
 
524
  # Sort tokens by expert assignment
525
  parallel_bin_ids, parallel_indices = ops.sort(
526
+ parallel_top_expert,
527
  sort_end_bit,
528
  )
529
 
 
531
  parallel_tokens_per_expert = parallel_tokens_per_expert.sum(
532
  dim=0, dtype=torch.int
533
  )
534
+ parallel_bins = ops.inclusive_cumsum(parallel_tokens_per_expert, 0)
 
 
 
535
  parallel_bins = (
536
  parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins
537
  )
 
550
 
551
  # Locally permute the tokens and perform the expert computation.
552
  # Block to make sure that the cross-device permutation is complete.
553
+ if mlp_impl == "grouped":
 
 
 
554
  # GroupedMLP requires counts on CPU. We can use the tensor already
555
  # moved to CPU for the prior all_to_all, which avoids an extra
556
  # device synchronization.
 
580
  )
581
 
582
  # Step 6: Reverse communication - send results back
583
+ x, _ = _layers.all_to_all.all_to_all(
584
+ parallel_x, send_counts, recv_counts, expert_parallel_group
585
+ )
586
 
587
  # Step 7: Reduce across hidden sharding dimension
588
  shape = (hidden_sharding_deg, -1, hidden_size)
 
594
  return x, tokens_per_expert.flatten()
595
 
596
 
597
+ def moe_forward(
598
+ x: torch.Tensor,
599
+ router_weight: torch.Tensor,
600
+ moe_top_k: int,
601
+ moe_num_experts: int,
602
+ moe_jitter_eps: float = None,
603
+ moe_normalize_expert_weights: int = None,
604
+ uniform_expert_assignment: bool = False,
605
+ training: bool = False,
606
+ w1: torch.Tensor = None,
607
+ w2: torch.Tensor = None,
608
+ w1_bias: torch.Tensor = None,
609
+ w2_bias: torch.Tensor = None,
610
+ gradient_scale: Optional[float] = None,
611
+ alpha: float = 1.702,
612
+ sort_end_bit: int = 0,
613
+ expert_parallel_group: torch.distributed.ProcessGroup = None,
614
+ moe_capacity_factor: float = 1.0,
615
+ moe_expert_model_parallelism: bool = False,
616
+ forward_fn: Any = None,
617
+ hidden_size: int = None,
618
+ mlp_impl: str = "grouped",
619
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
620
 
621
+ # Route tokens to experts
622
+ logits, expert_weights, expert_indices = route_tokens(
623
+ x,
624
+ router_weight,
625
+ moe_top_k,
626
+ moe_num_experts,
627
+ moe_jitter_eps,
628
+ moe_normalize_expert_weights,
629
+ uniform_expert_assignment,
630
+ training,
631
+ )
632
 
633
+ # Create router scores for output
634
+ router_scores = (
635
+ torch.zeros_like(logits)
636
+ .scatter_(1, expert_indices, expert_weights)
637
+ .transpose(0, 1)
638
+ )
639
+
640
+ in_shape = x.size()
641
+
642
+ # Prepare forward function arguments
643
+ forward_args = {
644
+ "x": x,
645
+ "expert_weights": expert_weights,
646
+ "top_experts": expert_indices,
647
+ "w1": w1,
648
+ "w2": w2,
649
+ "w1_bias": w1_bias,
650
+ "w2_bias": w2_bias,
651
+ "gradient_scale": gradient_scale,
652
+ "alpha": alpha,
653
+ "sort_end_bit": sort_end_bit,
654
+ "top_k": moe_top_k,
655
+ "num_experts": moe_num_experts,
656
+ "expert_parallel_group": expert_parallel_group,
657
+ "moe_capacity_factor": moe_capacity_factor,
658
+ "moe_expert_model_parallelism": moe_expert_model_parallelism,
659
+ "mlp_impl": mlp_impl,
660
+ }
661
+
662
+ # Add hidden_size for parallel forward
663
+ if moe_expert_model_parallelism and hidden_size is not None:
664
+ forward_args["hidden_size"] = hidden_size
665
+ elif moe_expert_model_parallelism and hidden_size is None:
666
+ # Infer hidden_size from input shape
667
+ forward_args["hidden_size"] = x.shape[-1]
668
+
669
+ # Compute expert outputs
670
+ x, tokens_per_expert = forward_fn(**forward_args)
671
+
672
+ # Save load balancing loss if needed
673
+ moe_loss_weight = 0.0 # Can be made configurable
674
+ if training and moe_loss_weight > 0:
675
+ save_load_balancing_loss((tokens_per_expert, logits))
676
+
677
+ # Restore original shape
678
+ x = x.view(in_shape)
679
+
680
+ return x, expert_weights, router_scores
681
 
682
 
683
  class MegaBlocksMoeMLP(torch.nn.Module):
684
 
685
+ def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
686
+ moe_top_k = getattr(self, "moe_top_k", 4)
687
+ moe_num_experts = getattr(self, "moe_num_experts", 128)
688
+ gradient_scale = getattr(self, "gradient_scale", None)
689
+ alpha = getattr(self, "alpha", 1.702)
690
+ moe_capacity_factor = getattr(self, "moe_capacity_factor", 1.0)
691
+ moe_jitter_eps = getattr(self, "moe_jitter_eps", None)
692
+ moe_normalize_expert_weights = getattr(
693
+ self, "moe_normalize_expert_weights", None
694
+ )
695
+ uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)
696
 
697
+ has_parallel = hasattr(self, "expert_parallel_group")
698
+ expert_parallel_group = getattr(self, "expert_parallel_group", None)
699
+ forward_fn = parallel_forward_once if has_parallel else forward_once
700
  sort_end_bit = max(
701
  int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1
702
  )
703
+ mlp_impl = getattr(self, "mlp_impl", "grouped") # or sparse
704
+
705
+ output, expert_weights_out, _ = moe_forward(
706
  x=x,
707
+ router_weight=self.router.weight,
708
  moe_top_k=moe_top_k,
709
  moe_num_experts=moe_num_experts,
710
+ moe_jitter_eps=moe_jitter_eps,
711
+ moe_normalize_expert_weights=moe_normalize_expert_weights,
712
+ uniform_expert_assignment=uniform_expert_assignment,
713
+ training=self.training,
714
+ w1=self.experts.gate_up_proj,
715
+ w2=self.experts.down_proj,
716
+ w1_bias=self.experts.gate_up_proj_bias,
717
+ w2_bias=self.experts.down_proj_bias,
718
+ gradient_scale=gradient_scale,
719
+ alpha=alpha,
720
  sort_end_bit=sort_end_bit,
721
  expert_parallel_group=expert_parallel_group,
722
+ moe_capacity_factor=moe_capacity_factor,
723
+ moe_expert_model_parallelism=has_parallel,
724
  forward_fn=forward_fn,
725
+ hidden_size=self.experts.hidden_size,
726
+ mlp_impl=mlp_impl,
727
  )
728
  return output, expert_weights_out
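A self-contained sketch of the routing bookkeeping that moe_forward performs around route_tokens: top-k expert selection followed by scattering the chosen weights back into a dense router_scores matrix. Jitter, weight normalization and uniform assignment are omitted; shapes are illustrative.

```python
import torch

num_tokens, hidden, num_experts, top_k = 6, 8, 4, 2
x = torch.randn(num_tokens, hidden)
router_weight = torch.randn(num_experts, hidden)

# Router logits and top-k expert selection per token.
logits = x @ router_weight.t()                            # [tokens, experts]
expert_weights, expert_indices = logits.softmax(dim=-1).topk(top_k, dim=-1)

# Dense score matrix: zeros everywhere except the selected experts,
# then transposed to [experts, tokens] as in the code above.
router_scores = (
    torch.zeros_like(logits)
    .scatter_(1, expert_indices, expert_weights)
    .transpose(0, 1)
)
print(router_scores.shape)  # torch.Size([4, 6])
```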
build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/all_to_all_benchmark.py CHANGED
@@ -7,126 +7,28 @@ import torch.distributed as dist
7
  # from megablocks import benchmark_util
8
  # from megablocks.layers.all_to_all import all_to_all
9
 
10
- # from .. import benchmark_util
11
-
12
- # Copyright 2024 Databricks
13
- # SPDX-License-Identifier: Apache-2.0
14
-
15
- import numpy as np
16
- import torch
17
-
18
-
19
- def log_benchmark(name, arguments, time, std):
20
- print("=" * 60)
21
- print(f"{name} Benchmark")
22
- print("Benchmark Parameters:")
23
- for key, value in arguments.items():
24
- print(f"{key} = {value}")
25
- print("Results:")
26
- print("mean time = {:.3f}ms, std time = {:.3f}ms".format(time, std))
27
- print("=" * 60)
28
-
29
-
30
- def benchmark_function(fn, iterations=100, warmup=10):
31
- print(f"Benchmarking {fn.__name__} with {iterations} iterations and {warmup} warmup iterations")
32
- # Warmup iterations.
33
- for _ in range(warmup):
34
- fn()
35
-
36
- times = []
37
- print(f"Running {iterations} iterations...")
38
- for i in range(iterations):
39
- start = torch.cuda.Event(enable_timing=True)
40
- end = torch.cuda.Event(enable_timing=True)
41
-
42
- start.record()
43
- fn()
44
- end.record()
45
-
46
- torch.cuda.synchronize()
47
- times.append(start.elapsed_time(end))
48
- return np.mean(times), np.std(times)
49
-
50
-
51
- # from .._layers.all_to_all import all_to_all
52
-
53
- # Copyright 2024 Databricks
54
- # SPDX-License-Identifier: Apache-2.0
55
-
56
- import torch
57
- import torch.distributed as dist
58
-
59
-
60
- class AllToAllOp(torch.autograd.Function):
61
-
62
- @staticmethod
63
- def forward(ctx, x, output_split_sizes, input_split_sizes, group, async_op):
64
- out = torch.empty(
65
- (sum(output_split_sizes),) + x.shape[1:], device=x.device, dtype=x.dtype
66
- )
67
-
68
- ctx.input_shape = x.shape
69
- ctx.output_split_sizes = output_split_sizes
70
- ctx.input_split_sizes = input_split_sizes
71
- ctx.group = group
72
- handle = dist.all_to_all_single(
73
- out,
74
- x,
75
- output_split_sizes=output_split_sizes,
76
- input_split_sizes=input_split_sizes,
77
- group=group,
78
- async_op=async_op,
79
- )
80
- return out, handle
81
-
82
- @staticmethod
83
- def backward(ctx, grad, _):
84
- if ctx.needs_input_grad[0]:
85
- out = torch.empty(
86
- ctx.input_shape,
87
- device=grad.device,
88
- dtype=grad.dtype,
89
- )
90
- dist.all_to_all_single(
91
- out,
92
- grad,
93
- output_split_sizes=ctx.input_split_sizes,
94
- input_split_sizes=ctx.output_split_sizes,
95
- group=ctx.group,
96
- )
97
- return out, None, None, None, None
98
- return None, None, None, None, None
99
-
100
-
101
- def all_to_all(x, output_split_sizes, input_split_sizes, group, async_op=False):
102
- return AllToAllOp.apply(
103
- x,
104
- output_split_sizes,
105
- input_split_sizes,
106
- group,
107
- async_op,
108
- )
109
-
110
 
111
  _ALL_TO_ALL_BENCHMARK = (
112
  (8, 1024),
113
- # (16, 1024),
114
- # (32, 1024),
115
- # (64, 1024),
116
- # (128, 1024),
117
- # (256, 1024),
118
- # (512, 1024),
119
- # (1024, 1024),
120
- # (2 * 1024, 1024),
121
- # (4 * 1024, 1024),
122
- # (8 * 1024, 1024),
123
- # (16 * 1024, 1024),
124
- # (32 * 1024, 1024),
125
- # (64 * 1024, 1024),
126
- # (128 * 1024, 1024),
127
- # (256 * 1024, 1024),
128
- # (512 * 1024, 1024),
129
- # (1024 * 1024, 1024),
130
  )
131
 
132
 
@@ -145,12 +47,10 @@ def benchmark_all_to_all(group, sl, hs):
145
  def benchmark():
146
  return all_to_all(x, send_recv_sizes, send_recv_sizes, group)
147
 
148
- # time, std = benchmark_util.benchmark_function(benchmark)
149
- time, std = benchmark_function(benchmark)
150
 
151
  if dist.get_rank(group) == 0:
152
- log_benchmark('All-To-All', details, time, std)
153
- # benchmark_util.log_benchmark('All-To-All', details, time, std)
154
 
155
 
156
  if __name__ == '__main__':
 
7
  # from megablocks import benchmark_util
8
  # from megablocks.layers.all_to_all import all_to_all
9
 
10
+ from .. import benchmark_util
11
+ from .._layers.all_to_all import all_to_all
12
 
13
  _ALL_TO_ALL_BENCHMARK = (
14
  (8, 1024),
15
+ (16, 1024),
16
+ (32, 1024),
17
+ (64, 1024),
18
+ (128, 1024),
19
+ (256, 1024),
20
+ (512, 1024),
21
+ (1024, 1024),
22
+ (2 * 1024, 1024),
23
+ (4 * 1024, 1024),
24
+ (8 * 1024, 1024),
25
+ (16 * 1024, 1024),
26
+ (32 * 1024, 1024),
27
+ (64 * 1024, 1024),
28
+ (128 * 1024, 1024),
29
+ (256 * 1024, 1024),
30
+ (512 * 1024, 1024),
31
+ (1024 * 1024, 1024),
32
  )
33
 
34
 
 
47
  def benchmark():
48
  return all_to_all(x, send_recv_sizes, send_recv_sizes, group)
49
 
50
+ time, std = benchmark_util.benchmark_function(benchmark)
 
51
 
52
  if dist.get_rank(group) == 0:
53
+ benchmark_util.log_benchmark('All-To-All', details, time, std)
 
54
 
55
 
56
  if __name__ == '__main__':
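The benchmark now imports all_to_all from .._layers.all_to_all instead of carrying its own copy. Below is a hedged sketch of the kind of symmetric all_to_all_single exchange it times; the even-split helper is an assumption for illustration, not the script's actual code, and the distributed branch only runs under an initialized process group (e.g. via torchrun on a multi-GPU host).

```python
import torch
import torch.distributed as dist

def even_split_sizes(sl: int, world_size: int) -> list[int]:
    # Assumption: each rank sends/receives an equal slice of the sequence.
    assert sl % world_size == 0, "sketch assumes sl divides evenly"
    return [sl // world_size] * world_size

if dist.is_available() and dist.is_initialized():
    group = dist.group.WORLD
    world_size = dist.get_world_size(group)
    sl, hs = 8, 1024
    sizes = even_split_sizes(sl, world_size)
    x = torch.randn(sl, hs, device="cuda")
    out = torch.empty_like(x)
    dist.all_to_all_single(
        out, x,
        output_split_sizes=sizes,
        input_split_sizes=sizes,
        group=group,
    )
else:
    print(even_split_sizes(8, 2))  # [4, 4] — rows exchanged with each peer
```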