kernel
drbh committed
Commit e47036a · 1 Parent(s): 0e97a7c

fix: improve expert parallel implementation and refactors

tests/parallel_layer_test.py ADDED
@@ -0,0 +1,94 @@
+import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
+import os
+
+
+def test_megablocks_moe_mlp_import():
+    from megablocks.layers import MegaBlocksMoeMLP
+
+    assert MegaBlocksMoeMLP is not None, "MegaBlocksMoeMLP import failed."
+
+
+def run_distributed_test(rank, world_size):
+    from megablocks.layers import MegaBlocksMoeMLP
+
+    os.environ["MASTER_ADDR"] = "localhost"
+    os.environ["MASTER_PORT"] = "12355"
+    os.environ["RANK"] = str(rank)
+    os.environ["WORLD_SIZE"] = str(world_size)
+
+    dist.init_process_group(
+        backend="gloo",
+        rank=rank,
+        world_size=world_size,
+    )
+
+    expert_parallel_group = torch.distributed.new_group(
+        range(torch.distributed.get_world_size())
+    )
+
+    model = MegaBlocksMoeMLP()
+    model.expert_parallel_group = expert_parallel_group
+
+    class Experts:
+        def __init__(self):
+            self.gate_up_proj = None
+            self.gate_up_proj_bias = None
+            self.down_proj = None
+            self.down_proj_bias = None
+            self.hidden_size = None
+
+    model.experts = Experts()
+
+    num_experts = 128
+    hidden_size = 1152
+    intermediate_size = 3072
+
+    ne, hs, isz = num_experts, hidden_size, intermediate_size
+
+    experts_per_rank = ne // world_size
+
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+
+    model.router = torch.nn.Linear(hs, ne).to(device)
+    model.router.weight.data.fill_(1)
+
+    e = model.experts
+    e.gate_up_proj = torch.nn.Parameter(
+        torch.ones(experts_per_rank, hs, isz, device=device)
+    )
+    e.gate_up_proj_bias = torch.nn.Parameter(
+        torch.zeros(experts_per_rank, isz, device=device)
+    )
+    e.down_proj = torch.nn.Parameter(
+        torch.ones(experts_per_rank, 1536, hs, device=device)
+    )
+    e.down_proj_bias = torch.nn.Parameter(
+        torch.zeros(experts_per_rank, hs, device=device)
+    )
+    e.hidden_size = hs
+
+    x = torch.randn(1, 1, 1152).to(device)
+    output, expert_weights_out = model(x)
+
+    assert output.shape == (1, 1, 1152), f"Output shape mismatch on rank {rank}."
+
+    print(f"Rank {rank}: Test passed! Output shape: {output.shape}")
+
+    dist.destroy_process_group()
+
+
+def test_megablocks_moe_mlp_functionality():
+    world_size = 2
+
+    mp.spawn(run_distributed_test, args=(world_size,), nprocs=world_size, join=True)
+
+    print("Multi-process test completed successfully!")
+
+
+if __name__ == "__main__":
+    test_megablocks_moe_mlp_import()
+    print("Import test passed!")
+
+    test_megablocks_moe_mlp_functionality()
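
For reference, a minimal single-process sketch (not part of this commit) that exercises the same MegaBlocksMoeMLP surface as the test above. With no expert_parallel_group attribute set, forward() falls back to the non-parallel forward_once path; the shapes and attribute wiring below mirror the distributed test and are assumptions, not a shipped example.

import torch
from megablocks.layers import MegaBlocksMoeMLP

model = MegaBlocksMoeMLP()
model.router = torch.nn.Linear(1152, 128)

class Experts:  # same ad-hoc container the test builds
    pass

e = Experts()
# Single process, so all 128 experts live on this rank.
e.gate_up_proj = torch.nn.Parameter(torch.randn(128, 1152, 3072) * 0.02)
e.gate_up_proj_bias = torch.nn.Parameter(torch.zeros(128, 3072))
e.down_proj = torch.nn.Parameter(torch.randn(128, 1536, 1152) * 0.02)
e.down_proj_bias = torch.nn.Parameter(torch.zeros(128, 1152))
e.hidden_size = 1152
model.experts = e

x = torch.randn(1, 1, 1152)
out, expert_weights = model(x)  # no expert_parallel_group -> forward_once path
print(out.shape)                # expected: torch.Size([1, 1, 1152])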
torch-ext/megablocks/layers.py CHANGED
@@ -333,7 +333,6 @@ def permute_and_compute(
     gradient_scale,
     alpha,
 ):
-    """Permute tokens and compute expert outputs."""
     # Route tokens to experts
     x = x.view(-1, x.shape[-1])

@@ -367,6 +366,7 @@ def forward_once(
     expert_parallel_group: int = None,
     moe_capacity_factor: float = 1.0,
     moe_expert_model_parallelism: bool = False,
+    mlp_impl: Optional[str] = None,
 ):
     # x: [sl, bs, hs]
     # expert_weights: [sl * bs, top-k]
@@ -430,11 +430,15 @@ def parallel_forward_once(
     moe_capacity_factor: float = 1.0,
     moe_expert_model_parallelism: bool = True,
     hidden_size: int = 1152,
+    mlp_impl: Optional[str] = "grouped",
 ):
     # Flatten inputs
     expert_weights = expert_weights.flatten()
     top_experts = top_experts.flatten()

+    # TODO: remove debugging var
+    # my_rank = dist.get_rank(expert_parallel_group) if expert_parallel_group else 0
+
     with torch.no_grad():
         # Step 1: Local permutation setup
         indices, bin_ids, bins, tokens_per_expert = indices_and_bins(
@@ -455,9 +459,7 @@ def parallel_forward_once(

         # Exchange token counts across devices
         parallel_tokens_per_expert = torch.empty_like(repeated_tokens_per_expert)
-        # print("world_size:", world_size)
-        # print("experts_per_rank_val:", experts_per_rank_val)
-
+
         # Ensure CUB knows which device to use
         tpe_handle = dist.all_to_all_single(
             parallel_tokens_per_expert,
@@ -493,20 +495,13 @@ def parallel_forward_once(
    x = ops.repeat(x, (hidden_sharding_deg, 1))

    # Cross-device token exchange
-    parallel_x, parallel_x_handle = ops.all_to_all(
-        x,
-        recv_counts,
-        send_counts,
-        expert_parallel_group,
-        async_op=True
+    parallel_x, parallel_x_handle = _layers.all_to_all.all_to_all(
+        x, recv_counts, send_counts, expert_parallel_group, async_op=True
    )

    with torch.no_grad():
        # Step 4: Setup for local expert computation
-        replicate_bins = ops.inclusive_cumsum(
-            parallel_tokens_per_expert.flatten(),
-            0
-        )
+        replicate_bins = ops.inclusive_cumsum(parallel_tokens_per_expert.flatten(), 0)
        replicate_bins = (
            replicate_bins.view(1) if not len(replicate_bins.size()) else replicate_bins
        )
@@ -528,7 +523,7 @@ def parallel_forward_once(

        # Sort tokens by expert assignment
        parallel_bin_ids, parallel_indices = ops.sort(
-            parallel_top_expert,
+            parallel_top_expert,
            sort_end_bit,
        )

@@ -536,10 +531,7 @@ def parallel_forward_once(
        parallel_tokens_per_expert = parallel_tokens_per_expert.sum(
            dim=0, dtype=torch.int
        )
-        parallel_bins = ops.inclusive_cumsum(
-            parallel_tokens_per_expert,
-            0
-        )
+        parallel_bins = ops.inclusive_cumsum(parallel_tokens_per_expert, 0)
        parallel_bins = (
            parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins
        )
@@ -558,10 +550,7 @@ def parallel_forward_once(

    # Locally permute the tokens and perform the expert computation.
    # Block to make sure that the cross-device permutation is complete.
-    # if self.args.mlp_impl == 'grouped':
-
-    # TODO: dont always assume grouped MLP
-    if True:
+    if mlp_impl == "grouped":
        # GroupedMLP requires counts on CPU. We can use the tensor already
        # moved to CPU for the prior all_to_all, which avoids an extra
        # device synchronization.
@@ -591,7 +580,9 @@ def parallel_forward_once(
    )

    # Step 6: Reverse communication - send results back
-    x, _ = ops.all_to_all(parallel_x, send_counts, recv_counts, expert_parallel_group)
+    x, _ = _layers.all_to_all.all_to_all(
+        parallel_x, send_counts, recv_counts, expert_parallel_group
+    )

    # Step 7: Reduce across hidden sharding dimension
    shape = (hidden_sharding_deg, -1, hidden_size)
@@ -603,139 +594,135 @@ def parallel_forward_once(
    return x, tokens_per_expert.flatten()


-class MyReplacementLayer(torch.nn.Module):
-    def forward(
-        x: torch.Tensor,
-        router_weight: torch.Tensor,
-        moe_top_k: int,
-        moe_num_experts: int,
-        moe_jitter_eps: float = None,
-        moe_normalize_expert_weights: int = None,
-        uniform_expert_assignment: bool = False,
-        training: bool = False,
-        w1: torch.Tensor = None,
-        w2: torch.Tensor = None,
-        w1_bias: torch.Tensor = None,
-        w2_bias: torch.Tensor = None,
-        gradient_scale: Optional[float] = None,
-        alpha: float = 1.702,
-        sort_end_bit: int = 0,
-        expert_parallel_group: torch.distributed.ProcessGroup = None,
-        moe_capacity_factor: float = 1.0,
-        moe_expert_model_parallelism: bool = False,
-        forward_fn: Any = None,
-        hidden_size: int = None,  # Required for parallel forward
-    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-
-        # Route tokens to experts
-        logits, expert_weights, expert_indices = route_tokens(
-            x,
-            router_weight,
-            moe_top_k,
-            moe_num_experts,
-            moe_jitter_eps,
-            moe_normalize_expert_weights,
-            uniform_expert_assignment,
-            training,
-        )
+def moe_forward(
+    x: torch.Tensor,
+    router_weight: torch.Tensor,
+    moe_top_k: int,
+    moe_num_experts: int,
+    moe_jitter_eps: float = None,
+    moe_normalize_expert_weights: int = None,
+    uniform_expert_assignment: bool = False,
+    training: bool = False,
+    w1: torch.Tensor = None,
+    w2: torch.Tensor = None,
+    w1_bias: torch.Tensor = None,
+    w2_bias: torch.Tensor = None,
+    gradient_scale: Optional[float] = None,
+    alpha: float = 1.702,
+    sort_end_bit: int = 0,
+    expert_parallel_group: torch.distributed.ProcessGroup = None,
+    moe_capacity_factor: float = 1.0,
+    moe_expert_model_parallelism: bool = False,
+    forward_fn: Any = None,
+    hidden_size: int = None,
+    mlp_impl: str = "grouped",
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+
+    # Route tokens to experts
+    logits, expert_weights, expert_indices = route_tokens(
+        x,
+        router_weight,
+        moe_top_k,
+        moe_num_experts,
+        moe_jitter_eps,
+        moe_normalize_expert_weights,
+        uniform_expert_assignment,
+        training,
+    )

-        # Create router scores for output
-        router_scores = (
-            torch.zeros_like(logits)
-            .scatter_(1, expert_indices, expert_weights)
-            .transpose(0, 1)
-        )
-
-        in_shape = x.size()
-
-        # Prepare forward function arguments
-        forward_args = {
-            "x": x,
-            "expert_weights": expert_weights,
-            "top_experts": expert_indices,
-            "w1": w1,
-            "w2": w2,
-            "w1_bias": w1_bias,
-            "w2_bias": w2_bias,
-            "gradient_scale": gradient_scale,
-            "alpha": alpha,
-            "sort_end_bit": sort_end_bit,
-            "top_k": moe_top_k,
-            "num_experts": moe_num_experts,
-            "expert_parallel_group": expert_parallel_group,
-            "moe_capacity_factor": moe_capacity_factor,
-            "moe_expert_model_parallelism": moe_expert_model_parallelism,
-        }
-
-        # Add hidden_size for parallel forward
-        if moe_expert_model_parallelism and hidden_size is not None:
-            forward_args["hidden_size"] = hidden_size
-        elif moe_expert_model_parallelism and hidden_size is None:
-            # Infer hidden_size from input shape
-            forward_args["hidden_size"] = x.shape[-1]
-
-        # Compute expert outputs
-        x, tokens_per_expert = forward_fn(**forward_args)
-
-        # Save load balancing loss if needed
-        moe_loss_weight = 0.0  # Can be made configurable
-        if training and moe_loss_weight > 0:
-            save_load_balancing_loss((tokens_per_expert, logits))
-
-        # Restore original shape
-        x = x.view(in_shape)
-
-        return x, expert_weights, router_scores
+    # Create router scores for output
+    router_scores = (
+        torch.zeros_like(logits)
+        .scatter_(1, expert_indices, expert_weights)
+        .transpose(0, 1)
+    )
+
+    in_shape = x.size()
+
+    # Prepare forward function arguments
+    forward_args = {
+        "x": x,
+        "expert_weights": expert_weights,
+        "top_experts": expert_indices,
+        "w1": w1,
+        "w2": w2,
+        "w1_bias": w1_bias,
+        "w2_bias": w2_bias,
+        "gradient_scale": gradient_scale,
+        "alpha": alpha,
+        "sort_end_bit": sort_end_bit,
+        "top_k": moe_top_k,
+        "num_experts": moe_num_experts,
+        "expert_parallel_group": expert_parallel_group,
+        "moe_capacity_factor": moe_capacity_factor,
+        "moe_expert_model_parallelism": moe_expert_model_parallelism,
+        "mlp_impl": mlp_impl,
+    }
+
+    # Add hidden_size for parallel forward
+    if moe_expert_model_parallelism and hidden_size is not None:
+        forward_args["hidden_size"] = hidden_size
+    elif moe_expert_model_parallelism and hidden_size is None:
+        # Infer hidden_size from input shape
+        forward_args["hidden_size"] = x.shape[-1]
+
+    # Compute expert outputs
+    x, tokens_per_expert = forward_fn(**forward_args)
+
+    # Save load balancing loss if needed
+    moe_loss_weight = 0.0  # Can be made configurable
+    if training and moe_loss_weight > 0:
+        save_load_balancing_loss((tokens_per_expert, logits))
+
+    # Restore original shape
+    x = x.view(in_shape)
+
+    return x, expert_weights, router_scores


 class MegaBlocksMoeMLP(torch.nn.Module):

-    def forward(
-        self,
-        x: torch.Tensor,
-    ) -> torch.Tensor:
-        router_weight = self.router.weight
-        moe_top_k = 4
-        moe_num_experts = 128
-        w1 = self.experts.gate_up_proj.data
-        w2 = self.experts.down_proj.data
-        w1_bias = self.experts.gate_up_proj_bias.data
-        w2_bias = self.experts.down_proj_bias.data
-
-        # check if the expert_parallel_group attribute is set
-        if hasattr(self, "expert_parallel_group"):
-            expert_parallel_group = self.expert_parallel_group
-            moe_expert_model_parallelism = True
-            forward_fn = parallel_forward_once
-        else:
-            expert_parallel_group = None
-            moe_expert_model_parallelism = False
-            forward_fn = forward_once
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        moe_top_k = getattr(self, "moe_top_k", 4)
+        moe_num_experts = getattr(self, "moe_num_experts", 128)
+        gradient_scale = getattr(self, "gradient_scale", None)
+        alpha = getattr(self, "alpha", 1.702)
+        moe_capacity_factor = getattr(self, "moe_capacity_factor", 1.0)
+        moe_jitter_eps = getattr(self, "moe_jitter_eps", None)
+        moe_normalize_expert_weights = getattr(
+            self, "moe_normalize_expert_weights", None
+        )
+        uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)

+        has_parallel = hasattr(self, "expert_parallel_group")
+        expert_parallel_group = getattr(self, "expert_parallel_group", None)
+        forward_fn = parallel_forward_once if has_parallel else forward_once
        sort_end_bit = max(
            int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1
        )
-        hidden_size = self.experts.hidden_size
-        output, expert_weights_out, router_scores = MyReplacementLayer.forward(
+        mlp_impl = getattr(self, "mlp_impl", "grouped")  # or sparse
+
+        output, expert_weights_out, _ = moe_forward(
            x=x,
-            router_weight=router_weight,
+            router_weight=self.router.weight,
            moe_top_k=moe_top_k,
            moe_num_experts=moe_num_experts,
-            moe_jitter_eps=None,
-            moe_normalize_expert_weights=None,
-            uniform_expert_assignment=False,
-            training=False,
-            w1=w1,
-            w2=w2,
-            w1_bias=w1_bias,
-            w2_bias=w2_bias,
-            gradient_scale=None,
-            alpha=1.702,
+            moe_jitter_eps=moe_jitter_eps,
+            moe_normalize_expert_weights=moe_normalize_expert_weights,
+            uniform_expert_assignment=uniform_expert_assignment,
+            training=self.training,
+            w1=self.experts.gate_up_proj,
+            w2=self.experts.down_proj,
+            w1_bias=self.experts.gate_up_proj_bias,
+            w2_bias=self.experts.down_proj_bias,
+            gradient_scale=gradient_scale,
+            alpha=alpha,
            sort_end_bit=sort_end_bit,
            expert_parallel_group=expert_parallel_group,
-            moe_capacity_factor=1.0,
-            moe_expert_model_parallelism=moe_expert_model_parallelism,
+            moe_capacity_factor=moe_capacity_factor,
+            moe_expert_model_parallelism=has_parallel,
            forward_fn=forward_fn,
-            hidden_size=hidden_size,
+            hidden_size=self.experts.hidden_size,
+            mlp_impl=mlp_impl,
        )
        return output, expert_weights_out
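
After this refactor, MegaBlocksMoeMLP no longer hard-codes its hyperparameters; forward() reads them from optional instance attributes via getattr and falls back to the defaults visible in the diff. A hedged configuration sketch (attribute names taken from the getattr calls above; the router and experts container are assumed to be attached as in the new test file):

import torch
from megablocks.layers import MegaBlocksMoeMLP

model = MegaBlocksMoeMLP()
# Optional knobs read via getattr() in forward(); omitted attributes keep their defaults.
model.moe_top_k = 4                        # default 4
model.moe_num_experts = 128                # default 128
model.alpha = 1.702                        # activation scaling, default 1.702
model.moe_normalize_expert_weights = None  # default None
model.mlp_impl = "grouped"                 # or "sparse"; default "grouped"

# Attaching a process group switches forward() to the expert-parallel path
# (parallel_forward_once); without it the layer uses forward_once:
# model.expert_parallel_group = torch.distributed.new_group(...)
# model.router and model.experts must still be set up as in tests/parallel_layer_test.py.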