drbh committed
Commit 13afbbe · 1 Parent(s): 9354548

fix: add parallel forward functional logic

Files changed (1):
  1. torch-ext/megablocks/layers.py (+195 -20)
torch-ext/megablocks/layers.py CHANGED
@@ -121,7 +121,15 @@ def scale_grad(


# Forward pass for the MLP layer
-def mlp_forward(x, w1, w2, w1_bias, w2_bias, gradient_scale=None, alpha: float = 1.702):
+def mlp_forward(
+    x: torch.Tensor,
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    w1_bias: torch.Tensor,
+    w2_bias: torch.Tensor,
+    gradient_scale: Optional[float] = None,
+    alpha: float = 1.702,
+):
    # Scale weights
    w1 = scale_grad(w1, gradient_scale)
    w2 = scale_grad(w2, gradient_scale)
@@ -144,8 +152,6 @@ def mlp_forward(x, w1, w2, w1_bias, w2_bias, gradient_scale=None, alpha: float = 1.702):
    return torch.bmm(x, w2) + w2_bias[..., None, :]


-## START: Load Balancing Loss (unused at the moment)
-
# Global variable to store load balancing loss
_LOAD_BALANCING_LOSS = []

@@ -234,9 +240,6 @@ def batched_load_balancing_loss(args):
    return scale * torch.dot(tokens_per_expert, expert_scores)


-## END Load Balancing Loss
-
-
# Calculate the expert capacity based on tokens, top_k, number of experts,
# expert parallel group, capacity factor, and whether expert model parallelism is used.
def expert_capacity(
@@ -410,7 +413,6 @@ def forward_once(
    return x, tokens_per_expert


-# TODO: replace with functional logic once aligned with ref
def parallel_forward_once(
    x: torch.Tensor,
    expert_weights: torch.Tensor,
@@ -429,15 +431,180 @@ def parallel_forward_once(
    moe_expert_model_parallelism: bool = True,
    hidden_size: int = 1152,
):
-    pass
+    # Flatten inputs
+    expert_weights = expert_weights.flatten()
+    top_experts = top_experts.flatten()
+
+    with torch.no_grad():
+        # Step 1: Local permutation setup
+        indices, bin_ids, bins, tokens_per_expert = indices_and_bins(
+            top_experts, sort_end_bit, num_experts
+        )
+
+        # Calculate sharding parameters
+        world_size = dist.get_world_size(expert_parallel_group)
+        hidden_sharding_deg = hidden_sharding_degree(
+            world_size, num_experts, hidden_size
+        )
+        experts_per_rank_val = experts_per_rank(num_experts, world_size)
+
+        # Replicate token counts for hidden sharding
+        repeated_tokens_per_expert = ops.repeat(
+            tokens_per_expert, (hidden_sharding_deg,)
+        )
+
+        # Exchange token counts across devices
+        parallel_tokens_per_expert = torch.empty_like(repeated_tokens_per_expert)
+        tpe_handle = dist.all_to_all_single(
+            parallel_tokens_per_expert,
+            repeated_tokens_per_expert,
+            group=expert_parallel_group,
+            async_op=True,
+        )
+
+    # Step 2: Local permutation - group tokens by target device
+    x = x.view(-1, x.shape[-1])  # [sl * bs, hs]
+    x = ops.gather(x, indices, bin_ids, bins, top_k)
+
+    # Step 3: Compute communication counts and exchange tokens
+    with torch.no_grad():
+        tpe_handle.wait()
+
+        # Reshape for per-device calculations
+        repeated_tokens_per_expert = repeated_tokens_per_expert.view(
+            world_size, experts_per_rank_val
+        )
+        parallel_tokens_per_expert = parallel_tokens_per_expert.view(
+            world_size, experts_per_rank_val
+        )
+
+        # Calculate send/recv counts
+        send_counts = repeated_tokens_per_expert.cpu().sum(dim=-1).tolist()
+        parallel_tokens_per_expert_cpu = parallel_tokens_per_expert.cpu()
+        recv_counts = parallel_tokens_per_expert_cpu.sum(dim=-1).tolist()
+        tokens_received = sum(recv_counts)
+
+    # Replicate for hidden sharding
+    x = ops.repeat(x, (hidden_sharding_deg, 1))
+
+    # Cross-device token exchange
+    parallel_x, parallel_x_handle = ops.all_to_all(
+        x,
+        recv_counts,
+        send_counts,
+        expert_parallel_group,
+        async_op=True,
+    )
+
+    with torch.no_grad():
+        # Step 4: Setup for local expert computation
+        replicate_bins = ops.inclusive_cumsum(
+            parallel_tokens_per_expert.flatten(),
+            0,
+        )
+        replicate_bins = (
+            replicate_bins.view(1) if not len(replicate_bins.size()) else replicate_bins
+        )
+
+        # Create expert indices for received tokens
+        parallel_top_expert = torch.remainder(
+            torch.arange(
+                num_experts * hidden_sharding_deg,
+                dtype=torch.int32,
+                device=indices.device,
+            ),
+            experts_per_rank_val,
+        )
+        parallel_top_expert = ops.replicate(
+            parallel_top_expert.unsqueeze(dim=0),
+            replicate_bins,
+            tokens_received,
+        ).flatten()
+
+        # Sort tokens by expert assignment
+        parallel_bin_ids, parallel_indices = ops.sort(
+            parallel_top_expert,
+            sort_end_bit,
+        )
+
+        # Calculate bins for local experts
+        parallel_tokens_per_expert = parallel_tokens_per_expert.sum(
+            dim=0, dtype=torch.int
+        )
+        parallel_bins = ops.inclusive_cumsum(
+            parallel_tokens_per_expert,
+            0,
+        )
+        parallel_bins = (
+            parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins
+        )
+
+        # Calculate expert capacity
+        expert_capacity = expert_capacity_fn(
+            tokens_received,
+            top_k,
+            experts_per_rank_val,
+            expert_parallel_group,
+            moe_capacity_factor,
+            moe_expert_model_parallelism,
+        )
+        if expert_capacity == 0:
+            expert_capacity = torch.max(parallel_tokens_per_expert).item()
+
+    # Locally permute the tokens and perform the expert computation.
+    # Block to make sure that the cross-device permutation is complete.
+    # TODO: don't always assume a grouped MLP
+    if True:
+        # GroupedMLP requires counts on CPU. We can use the tensor already
+        # moved to CPU for the prior all_to_all, which avoids an extra
+        # device synchronization.
+        parallel_tokens_per_expert = parallel_tokens_per_expert_cpu.sum(
+            dim=0,
+            dtype=torch.int,
+        )
+
+    # Step 5: Expert computation
+    parallel_x_handle.wait()
+    parallel_x = permute_and_compute(
+        parallel_x,
+        parallel_tokens_per_expert,
+        parallel_indices,
+        parallel_bin_ids,
+        None,  # expert_weights
+        parallel_bins,
+        expert_capacity,
+        top_k=1,
+        w1=w1,
+        w2=w2,
+        w1_bias=w1_bias,
+        w2_bias=w2_bias,
+        gradient_scale=gradient_scale,
+        alpha=alpha,
+    )
+
+    # Step 6: Reverse communication - send results back
+    x, _ = ops.all_to_all(parallel_x, send_counts, recv_counts, expert_parallel_group)
+
+    # Step 7: Reduce across hidden sharding dimension
+    shape = (hidden_sharding_deg, -1, hidden_size)
+    x = x.view(shape).sum(dim=0)
+
+    # Step 8: Final local unpermutation
+    x = ops.scatter(x, indices, bin_ids, expert_weights, bins, top_k)
+
+    return x, tokens_per_expert.flatten()


-class MyReplacementLayer(torch.nn.Module):
-    # def __init__(self):
-    #     super().__init__()
-
+class MyReplacementLayer(torch.nn.Module):
    def forward(
-        # self,
        x: torch.Tensor,
        router_weight: torch.Tensor,
        moe_top_k: int,
@@ -446,7 +613,6 @@ class MyReplacementLayer(torch.nn.Module):
        moe_normalize_expert_weights: int = None,
        uniform_expert_assignment: bool = False,
        training: bool = False,
-        #
        w1: torch.Tensor = None,
        w2: torch.Tensor = None,
        w1_bias: torch.Tensor = None,
@@ -522,7 +688,6 @@ class MyReplacementLayer(torch.nn.Module):
        return x, expert_weights, router_scores


-
class MegaBlocksMoeMLP(torch.nn.Module):

    def forward(
@@ -536,11 +701,21 @@ class MegaBlocksMoeMLP(torch.nn.Module):
        w2 = self.experts.down_proj.data
        w1_bias = self.experts.gate_up_proj_bias.data
        w2_bias = self.experts.down_proj_bias.data
-        expert_parallel_group = None

-        sort_end_bit = max(int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1)
+        # check if the expert_parallel_group attribute is set
+        if hasattr(self, "expert_parallel_group"):
+            expert_parallel_group = self.expert_parallel_group
+            moe_expert_model_parallelism = True
+            forward_fn = parallel_forward_once
+        else:
+            expert_parallel_group = None
+            moe_expert_model_parallelism = False
+            forward_fn = forward_once
+
+        sort_end_bit = max(
+            int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1
+        )
        hidden_size = self.experts.hidden_size
-
        output, expert_weights_out, router_scores = MyReplacementLayer.forward(
            x=x,
            router_weight=router_weight,
@@ -559,8 +734,8 @@ class MegaBlocksMoeMLP(torch.nn.Module):
            sort_end_bit=sort_end_bit,
            expert_parallel_group=expert_parallel_group,
            moe_capacity_factor=1.0,
-            moe_expert_model_parallelism=False,
-            forward_fn=forward_once,
+            moe_expert_model_parallelism=moe_expert_model_parallelism,
+            forward_fn=forward_fn,
            hidden_size=hidden_size,
        )
-        return output, expert_weights_out
+        return output, expert_weights_out
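
Note on Step 3 of parallel_forward_once above: the send/recv split sizes for the token all_to_all are just per-rank sums of the per-expert token counts. A minimal standalone sketch of that arithmetic, with made-up counts for 2 ranks × 2 local experts and no torch.distributed calls (illustration only, not part of the diff):

import torch

# repeated_tokens_per_expert[r, e]: tokens this rank routes to expert e on rank r.
repeated_tokens_per_expert = torch.tensor([[3, 1], [2, 4]])
# parallel_tokens_per_expert[r, e]: tokens rank r routes to our local expert e
# (what the async all_to_all_single delivers in the real code).
parallel_tokens_per_expert = torch.tensor([[3, 2], [1, 5]])

# One split size per peer rank, matching how ops.all_to_all is called in the diff.
send_counts = repeated_tokens_per_expert.sum(dim=-1).tolist()  # [4, 6]
recv_counts = parallel_tokens_per_expert.sum(dim=-1).tolist()  # [5, 6]
tokens_received = sum(recv_counts)                             # 11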
 
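Usage-wise, the hasattr dispatch added to MegaBlocksMoeMLP.forward means the expert-parallel path is selected purely by attaching an expert_parallel_group attribute to the module. A rough driver sketch under that assumption (run_moe, the choice of process group, and the single-tensor forward(x) signature are placeholders for illustration, not part of this file):

import torch
import torch.distributed as dist

def run_moe(mlp, x, expert_parallel=False):
    # `mlp` is assumed to be a MegaBlocksMoeMLP whose .router and .experts
    # already hold the gate_up_proj / down_proj tensors that forward() reads.
    if expert_parallel and dist.is_initialized():
        # With the attribute set, forward() routes through parallel_forward_once.
        mlp.expert_parallel_group = dist.group.WORLD
    elif hasattr(mlp, "expert_parallel_group"):
        # Without the attribute, forward() falls back to the single-device forward_once.
        del mlp.expert_parallel_group
    output, expert_weights = mlp(x)
    return output, expert_weights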