kernel
drbh committed on
Commit
2b84d84
·
1 Parent(s): dabb815

feat: bump builds

This view is limited to 50 files because it contains too many changes. See raw diff
Files changed (50)
  1. build/torch26-cxx11-cu118-x86_64-linux/megablocks/__init__.py +8 -5
  2. build/torch26-cxx11-cu118-x86_64-linux/megablocks/{layers β†’ _layers}/__init__.py +0 -0
  3. build/torch26-cxx11-cu118-x86_64-linux/megablocks/{layers β†’ _layers}/activation_fn.py +0 -0
  4. build/torch26-cxx11-cu118-x86_64-linux/megablocks/{layers β†’ _layers}/all_to_all.py +0 -0
  5. build/torch26-cxx11-cu118-x86_64-linux/megablocks/{layers β†’ _layers}/arguments.py +0 -0
  6. build/torch26-cxx11-cu118-x86_64-linux/megablocks/{layers β†’ _layers}/common.py +0 -0
  7. build/torch26-cxx11-cu118-x86_64-linux/megablocks/{layers β†’ _layers}/dmlp_registry.py +0 -0
  8. build/torch26-cxx11-cu118-x86_64-linux/megablocks/{layers β†’ _layers}/dmoe.py +0 -0
  9. build/torch26-cxx11-cu118-x86_64-linux/megablocks/{layers β†’ _layers}/gelu.py +0 -0
  10. build/torch26-cxx11-cu118-x86_64-linux/megablocks/{layers β†’ _layers}/glu.py +0 -0
  11. build/torch26-cxx11-cu118-x86_64-linux/megablocks/{layers β†’ _layers}/memory_test.py +0 -0
  12. build/torch26-cxx11-cu118-x86_64-linux/megablocks/{layers β†’ _layers}/mlp.py +0 -0
  13. build/torch26-cxx11-cu118-x86_64-linux/megablocks/{layers β†’ _layers}/moe.py +0 -0
  14. build/torch26-cxx11-cu118-x86_64-linux/megablocks/{layers β†’ _layers}/mpu.py +0 -0
  15. build/torch26-cxx11-cu118-x86_64-linux/megablocks/{layers β†’ _layers}/router.py +0 -0
  16. build/torch26-cxx11-cu118-x86_64-linux/megablocks/{layers β†’ _layers}/sharedexpert_registry.py +0 -0
  17. build/torch26-cxx11-cu118-x86_64-linux/megablocks/{_megablocks_6756875_dirty.abi3.so β†’ _megablocks_dabb815.abi3.so} +2 -2
  18. build/torch26-cxx11-cu118-x86_64-linux/megablocks/_ops.py +3 -3
  19. build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers.py +566 -0
  20. build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/all_to_all_benchmark.py +1 -1
  21. build/torch26-cxx11-cu124-x86_64-linux/megablocks/__init__.py +8 -5
  22. build/torch26-cxx11-cu124-x86_64-linux/megablocks/{layers β†’ _layers}/__init__.py +0 -0
  23. build/torch26-cxx11-cu124-x86_64-linux/megablocks/{layers β†’ _layers}/activation_fn.py +0 -0
  24. build/torch26-cxx11-cu124-x86_64-linux/megablocks/{layers β†’ _layers}/all_to_all.py +0 -0
  25. build/torch26-cxx11-cu124-x86_64-linux/megablocks/{layers β†’ _layers}/arguments.py +0 -0
  26. build/torch26-cxx11-cu124-x86_64-linux/megablocks/{layers β†’ _layers}/common.py +0 -0
  27. build/torch26-cxx11-cu124-x86_64-linux/megablocks/{layers β†’ _layers}/dmlp_registry.py +0 -0
  28. build/torch26-cxx11-cu124-x86_64-linux/megablocks/{layers β†’ _layers}/dmoe.py +0 -0
  29. build/torch26-cxx11-cu124-x86_64-linux/megablocks/{layers β†’ _layers}/gelu.py +0 -0
  30. build/torch26-cxx11-cu124-x86_64-linux/megablocks/{layers β†’ _layers}/glu.py +0 -0
  31. build/torch26-cxx11-cu124-x86_64-linux/megablocks/{layers β†’ _layers}/memory_test.py +0 -0
  32. build/torch26-cxx11-cu124-x86_64-linux/megablocks/{layers β†’ _layers}/mlp.py +0 -0
  33. build/torch26-cxx11-cu124-x86_64-linux/megablocks/{layers β†’ _layers}/moe.py +0 -0
  34. build/torch26-cxx11-cu124-x86_64-linux/megablocks/{layers β†’ _layers}/mpu.py +0 -0
  35. build/torch26-cxx11-cu124-x86_64-linux/megablocks/{layers β†’ _layers}/router.py +0 -0
  36. build/torch26-cxx11-cu124-x86_64-linux/megablocks/{layers β†’ _layers}/sharedexpert_registry.py +0 -0
  37. build/torch26-cxx11-cu124-x86_64-linux/megablocks/{_megablocks_6756875_dirty.abi3.so β†’ _megablocks_dabb815.abi3.so} +2 -2
  38. build/torch26-cxx11-cu124-x86_64-linux/megablocks/_ops.py +3 -3
  39. build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers.py +566 -0
  40. build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/all_to_all_benchmark.py +1 -1
  41. build/torch26-cxx11-cu126-x86_64-linux/megablocks/__init__.py +8 -5
  42. build/torch26-cxx11-cu126-x86_64-linux/megablocks/{layers β†’ _layers}/__init__.py +0 -0
  43. build/torch26-cxx11-cu126-x86_64-linux/megablocks/{layers β†’ _layers}/activation_fn.py +0 -0
  44. build/torch26-cxx11-cu126-x86_64-linux/megablocks/{layers β†’ _layers}/all_to_all.py +0 -0
  45. build/torch26-cxx11-cu126-x86_64-linux/megablocks/{layers β†’ _layers}/arguments.py +0 -0
  46. build/torch26-cxx11-cu126-x86_64-linux/megablocks/{layers β†’ _layers}/common.py +0 -0
  47. build/torch26-cxx11-cu126-x86_64-linux/megablocks/{layers β†’ _layers}/dmlp_registry.py +0 -0
  48. build/torch26-cxx11-cu126-x86_64-linux/megablocks/{layers β†’ _layers}/dmoe.py +0 -0
  49. build/torch26-cxx11-cu126-x86_64-linux/megablocks/{layers β†’ _layers}/gelu.py +0 -0
  50. build/torch26-cxx11-cu126-x86_64-linux/megablocks/{layers β†’ _layers}/glu.py +0 -0
build/torch26-cxx11-cu118-x86_64-linux/megablocks/__init__.py CHANGED
@@ -9,11 +9,13 @@ from .grouped_gemm import backend as gg_backend
 from .grouped_gemm import ops as gg_ops
 
 
-from .layers.arguments import Arguments
-from .layers.dmoe import ParallelDroplessMLP, dMoE
-from .layers.glu import SparseGLU
-from .layers.mlp import MLP, SparseMLP
-from .layers.moe import MoE, ParallelMLP, get_load_balancing_loss
+from ._layers.arguments import Arguments
+from ._layers.dmoe import ParallelDroplessMLP, dMoE
+from ._layers.glu import SparseGLU
+from ._layers.mlp import MLP, SparseMLP
+from ._layers.moe import MoE, ParallelMLP, get_load_balancing_loss
+
+from . import layers
 
 # This section contains the direct kernel exports (not inlcuded in the original code)
 def exclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor:
@@ -176,6 +178,7 @@ def argsort(x: torch.Tensor, end_bit: int = 32) -> tuple[torch.Tensor, torch.Ten
 
 # Export public API
 __all__ = [
+    "MyReplacementLayer",
     # Direct kernel exports
     "exclusive_cumsum",
     "inclusive_cumsum",
build/torch26-cxx11-cu118-x86_64-linux/megablocks/{layers β†’ _layers}/__init__.py RENAMED
File without changes
build/torch26-cxx11-cu118-x86_64-linux/megablocks/{layers β†’ _layers}/activation_fn.py RENAMED
File without changes
build/torch26-cxx11-cu118-x86_64-linux/megablocks/{layers β†’ _layers}/all_to_all.py RENAMED
File without changes
build/torch26-cxx11-cu118-x86_64-linux/megablocks/{layers β†’ _layers}/arguments.py RENAMED
File without changes
build/torch26-cxx11-cu118-x86_64-linux/megablocks/{layers β†’ _layers}/common.py RENAMED
File without changes
build/torch26-cxx11-cu118-x86_64-linux/megablocks/{layers β†’ _layers}/dmlp_registry.py RENAMED
File without changes
build/torch26-cxx11-cu118-x86_64-linux/megablocks/{layers β†’ _layers}/dmoe.py RENAMED
File without changes
build/torch26-cxx11-cu118-x86_64-linux/megablocks/{layers β†’ _layers}/gelu.py RENAMED
File without changes
build/torch26-cxx11-cu118-x86_64-linux/megablocks/{layers β†’ _layers}/glu.py RENAMED
File without changes
build/torch26-cxx11-cu118-x86_64-linux/megablocks/{layers β†’ _layers}/memory_test.py RENAMED
File without changes
build/torch26-cxx11-cu118-x86_64-linux/megablocks/{layers β†’ _layers}/mlp.py RENAMED
File without changes
build/torch26-cxx11-cu118-x86_64-linux/megablocks/{layers β†’ _layers}/moe.py RENAMED
File without changes
build/torch26-cxx11-cu118-x86_64-linux/megablocks/{layers β†’ _layers}/mpu.py RENAMED
File without changes
build/torch26-cxx11-cu118-x86_64-linux/megablocks/{layers β†’ _layers}/router.py RENAMED
File without changes
build/torch26-cxx11-cu118-x86_64-linux/megablocks/{layers β†’ _layers}/sharedexpert_registry.py RENAMED
File without changes
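The renames above make the implementation package private: everything under megablocks.layers moves to megablocks._layers, while the public names stay re-exported from the package root, and megablocks.layers now resolves to the new functional layers.py module added further down in this commit. A minimal sketch of the resulting import surface, illustrative only and not part of the diff:

# Re-exported symbols keep their old import path at the package root.
from megablocks import Arguments, MoE, dMoE

# megablocks.layers is now the new functional module (layers.py below),
# not the renamed _layers package.
from megablocks import layers
print(hasattr(layers, "MegaBlocksMoeMLP"))  # expected: True for this build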
build/torch26-cxx11-cu118-x86_64-linux/megablocks/{_megablocks_6756875_dirty.abi3.so β†’ _megablocks_dabb815.abi3.so} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:ad46e9f244afa886c8a104d75e37f93afd2a0ecf83bfc7a414680fa16d8b78f9
-size 10517608
+oid sha256:7a20cd4dc15095b8504db981c651e516e8a7d8394b99d973d632558637c8dba9
+size 10517576
build/torch26-cxx11-cu118-x86_64-linux/megablocks/_ops.py CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import _megablocks_6756875_dirty
-ops = torch.ops._megablocks_6756875_dirty
+from . import _megablocks_dabb815
+ops = torch.ops._megablocks_dabb815
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_megablocks_6756875_dirty::{op_name}"
+    return f"_megablocks_dabb815::{op_name}"
build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers.py ADDED
@@ -0,0 +1,566 @@
+import torch
+import torch.distributed as dist
+
+from typing import Optional, Any
+
+from . import _layers
+from . import ops
+
+
+# Set the expert model parallel attributes on a tensor
+def set_expert_model_parallel_attributes(
+    tensor: torch.Tensor,
+    is_parallel: bool,
+):
+    assert not hasattr(tensor, "expert_model_parallel")
+    setattr(tensor, "expert_model_parallel", is_parallel)
+
+
+# Get the expert model parallel attributes from a tensor
+def expert_sharding_degree(
+    world_size: int,
+    moe_num_experts: int,
+) -> int:
+    esd = min(world_size, moe_num_experts)
+    if (moe_num_experts % esd) != 0:
+        raise ValueError(f"Cannot shard {moe_num_experts} experts {esd} ways.")
+    return esd
+
+
+# Calculate the hidden sharding degree based on world size and expert sharding degree
+def hidden_sharding_degree(
+    world_size: int,
+    moe_num_experts: int,
+    ffn_hidden_size: int,
+) -> int:
+    esd = expert_sharding_degree(world_size, moe_num_experts)
+    hsd = world_size // esd
+    if (ffn_hidden_size % hsd) != 0:
+        raise ValueError(f"Cannot shard {ffn_hidden_size} features {hsd} ways.")
+    if (esd * hsd) != world_size:
+        raise ValueError(
+            f"Invalid sharding. expert_sharding_degree ({esd}) * hidden_sharding_degree ({hsd}) != world_size ({world_size})."
+        )
+    return hsd
+
+
+# Calculate the number of experts per rank based on world size and expert sharding degree
+def experts_per_rank(
+    moe_num_experts: int,
+    world_size: int,
+) -> int:
+    return moe_num_experts // expert_sharding_degree(world_size, moe_num_experts)
+
+
+# Calculate the number of features per rank based on ffn hidden size and hidden sharding degree
+def features_per_rank(
+    ffn_hidden_size: int, world_size: int, moe_num_experts: int
+) -> int:
+    return ffn_hidden_size // hidden_sharding_degree(
+        world_size, moe_num_experts, ffn_hidden_size
+    )
+
+
+# Apply jitter to the input tensor
+def apply_jitter(x: torch.Tensor, moe_jitter_eps: float) -> torch.Tensor:
+    low = 1.0 - moe_jitter_eps
+    high = 1.0 + moe_jitter_eps
+    noise = torch.rand(x.size(), dtype=x.dtype, device=x.device)
+    return x * (low + noise * (high - low))
+
+
+# Compute the top-k scores from the logits
+def compute_top_k(scores: torch.Tensor, moe_top_k: int):
+    if moe_top_k == 1:
+        return scores.max(dim=-1, keepdim=True)
+    return torch.topk(scores, moe_top_k, dim=-1)
+
+
+# Route tokens to experts and compute expert weights and indices
+def route_tokens(
+    x: torch.Tensor,
+    router_weight: torch.Tensor,
+    moe_top_k: int,
+    moe_num_experts: int,
+    moe_jitter_eps: float = None,
+    moe_normalize_expert_weights: int = None,
+    uniform_expert_assignment: bool = False,
+    training: bool = False,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    if training and moe_jitter_eps is not None:
+        x = apply_jitter(x, moe_jitter_eps)
+
+    x_flat = x.view(-1, x.shape[-1])
+    logits = torch.nn.functional.linear(x_flat, router_weight)
+    expert_weights, expert_indices = compute_top_k(logits, moe_top_k)
+    expert_weights = expert_weights.softmax(dim=-1)
+    if moe_normalize_expert_weights is not None:
+        expert_weights = expert_weights / torch.norm(
+            expert_weights,
+            p=moe_normalize_expert_weights,
+            dim=-1,
+            keepdim=True,
+        )
+    if uniform_expert_assignment:
+        expert_indices = _layers.router._uniform_expert_assignment(
+            expert_indices,
+            moe_num_experts,
+        )
+
+    return logits, expert_weights, expert_indices
+
+
+# Scale the gradient of the weights
+def scale_grad(
+    w: torch.Tensor,
+    gradient_scale: Optional[float] = None,
+) -> torch.Tensor:
+    if gradient_scale is None:
+        return w
+    return _layers.mlp.scale_gradient(w, gradient_scale)
+
+
+# Forward pass for the MLP layer
+def mlp_forward(x, w1, w2, w1_bias, w2_bias, gradient_scale=None, alpha: float = 1.702):
+    # Scale weights
+    w1 = scale_grad(w1, gradient_scale)
+    w2 = scale_grad(w2, gradient_scale)
+    w1_bias = scale_grad(w1_bias, gradient_scale)
+    w2_bias = scale_grad(w2_bias, gradient_scale)
+
+    # Resolve dtensors
+    w1 = _layers.mlp.resolve_dtensor(w1)
+    w2 = _layers.mlp.resolve_dtensor(w2)
+    w1_bias = _layers.mlp.resolve_dtensor(w1_bias)
+    w2_bias = _layers.mlp.resolve_dtensor(w2_bias)
+
+    # Forward pass
+    gate_up = torch.bmm(x, w1) + w1_bias[..., None, :]
+    gate, up = gate_up.chunk(2, dim=-1)
+
+    glu = gate * torch.sigmoid(gate * alpha)
+    x = (up + 1) * glu
+
+    return torch.bmm(x, w2) + w2_bias[..., None, :]
+
+
+## START: Load Balancing Loss (unused at the moment)
+
+# Global variable to store load balancing loss
+_LOAD_BALANCING_LOSS = []
+
+
+def save_load_balancing_loss(loss):
+    global _LOAD_BALANCING_LOSS
+    _LOAD_BALANCING_LOSS.append(loss)
+
+
+def get_load_balancing_loss():
+    global _LOAD_BALANCING_LOSS
+    return _LOAD_BALANCING_LOSS
+
+
+def clear_load_balancing_loss():
+    global _LOAD_BALANCING_LOSS
+    _LOAD_BALANCING_LOSS.clear()
+
+
+def batched_load_balancing_loss(args):
+    if args.moe_loss_weight == 0:
+        return 0.0
+
+    tokens_per_expert, expert_scores = zip(*get_load_balancing_loss())
+    num_layers_per_pipeline_stage = args.num_layers // args.pipeline_model_parallel_size
+    if args.num_layers_per_virtual_pipeline_stage is not None:
+        num_layers_per_pipeline_stage = args.num_layers_per_virtual_pipeline_stage
+
+    if len(tokens_per_expert) != num_layers_per_pipeline_stage:
+        raise ValueError(
+            f"Expected {num_layers_per_pipeline_stage} token_per_experts "
+            f"but found {len(tokens_per_expert)}.\nnum_layers = "
+            f"{args.num_layers}\npipeline_model_parallel_size = "
+            f"{args.pipeline_model_parallel_size}\n"
+            "num_layers_per_virtual_pipeline_stage"
+            f" = {args.num_layers_per_virtual_pipeline_stage}",
+        )
+    if len(expert_scores) != num_layers_per_pipeline_stage:
+        raise ValueError(
+            f"Expected {num_layers_per_pipeline_stage} expert_scores "
+            f"but found {len(tokens_per_expert)}.\nnum_layers = "
+            f"{args.num_layers}\npipeline_model_parallel_size = "
+            f"{args.pipeline_model_parallel_size}\n"
+            "num_layers_per_virtual_pipeline_stage"
+            f" = {args.num_layers_per_virtual_pipeline_stage}",
+        )
+
+    # Verify the shape of the tokens_per_expert and expert_scores tensors.
+    assert all(
+        (x.ndim == 1 and x.numel() == args.moe_num_experts for x in tokens_per_expert)
+    )
+
+    tokens = expert_scores[0].shape[0]
+    assert all(
+        (
+            (
+                x.ndim == 2
+                and x.shape[1] == args.moe_num_experts
+                and x.shape[0] == tokens
+            )
+            for x in expert_scores
+        )
+    )
+
+    # Concatenate the contributions of each layer and convert to
+    # the correct types and formats for the dot product.
+    expert_scores = torch.cat(expert_scores, dim=1)
+    if args.moe_lbl_in_fp32:
+        expert_scores = expert_scores.float()
+    if tokens != 0:
+        expert_scores = expert_scores.mean(dim=0)
+    else:
+        expert_scores = expert_scores.sum(dim=0)
+    tokens_per_expert = torch.cat(tokens_per_expert).to(expert_scores.dtype)
+
+    expected_values = num_layers_per_pipeline_stage * args.moe_num_experts
+    assert tokens_per_expert.numel() == expected_values
+    assert expert_scores.numel() == expected_values
+
+    # Calculate the total scale across all factors.
+    #
+    # loss_weight * num_experts / (num_layers * tokens * top_k)
+    scale_numerator = args.moe_num_experts * args.moe_loss_weight
+    scale_denominator = args.num_layers * tokens * args.moe_top_k
+    scale = scale_numerator / scale_denominator
+    return scale * torch.dot(tokens_per_expert, expert_scores)
+
+
+## END Load Balancing Loss
+
+
+# Calculate the expert capacity based on tokens, top_k, number of experts,
+# expert parallel group, capacity factor, and whether expert model parallelism is used.
+def expert_capacity(
+    tokens: int,
+    top_k: int,
+    num_experts: int,
+    expert_parallel_group: int,
+    moe_capacity_factor: float,
+    moe_expert_model_parallelism: bool,
+) -> int:
+    world_size = (
+        dist.get_world_size(expert_parallel_group)
+        if moe_expert_model_parallelism
+        else 1
+    )
+
+    tokens_per_expert = top_k * tokens * world_size / num_experts
+    return int(moe_capacity_factor * tokens_per_expert)
+
+
+def load_balancing_loss(
+    tokens_per_expert: torch.Tensor,
+    expert_scores: torch.Tensor,
+    top_k: int,
+    num_experts: int,
+):
+    assert len(expert_scores.size()) == 2
+    tokens, num_experts = expert_scores.size()
+    assert num_experts == num_experts
+    assert len(tokens_per_expert.size()) == 1
+    (num_experts,) = tokens_per_expert.size()
+    assert num_experts == num_experts
+    scale = num_experts / (tokens * top_k)
+    return scale * torch.dot(
+        tokens_per_expert.to(expert_scores.dtype),
+        expert_scores.mean(dim=0),
+    )
+
+
+def indices_and_bins(
+    top_expert: torch.Tensor,
+    sort_end_bit: int,
+    num_experts: int,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+    top_expert = top_expert.int()
+
+    # Ensure contiguous memory layout
+    top_expert = top_expert.contiguous()
+
+    # Ensure CUB knows which device to use
+    with torch.cuda.device(top_expert.device):
+        output = ops.sort(top_expert, sort_end_bit)
+        bin_ids, indices = output
+        tokens_per_expert = ops.histogram(top_expert, num_experts)
+        bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+
+    bins = bins.view(1) if not len(bins.size()) else bins
+    return indices, bin_ids, bins, tokens_per_expert
+
+
+def expert_capacity_fn(
+    tokens: int,
+    top_k: int,
+    num_experts: int,
+    expert_parallel_group: torch.distributed.ProcessGroup,
+    moe_capacity_factor: float = 1.0,
+    moe_expert_model_parallelism: bool = False,
+) -> int:
+    world_size = (
+        dist.get_world_size(expert_parallel_group)
+        if moe_expert_model_parallelism
+        else 1
+    )
+    tokens_per_expert = top_k * tokens * world_size / num_experts
+    return int(moe_capacity_factor * tokens_per_expert)
+
+
+def permute_and_compute(
+    x,
+    tokens_per_expert,
+    indices,
+    bin_ids,
+    expert_weights,
+    bins,
+    expert_capacity,
+    top_k,
+    w1,
+    w2,
+    w1_bias,
+    w2_bias,
+    gradient_scale,
+    alpha,
+):
+    """Permute tokens and compute expert outputs."""
+    # Route tokens to experts
+    x = x.view(-1, x.shape[-1])
+
+    # Ensure CUB knows which device to use
+    with torch.cuda.device(x.device):
+        x = ops.binned_gather(x, indices, bins, expert_capacity, top_k)
+
+    # Expert computation
+    x = mlp_forward(x, w1, w2, w1_bias, w2_bias, gradient_scale, alpha)
+
+    # Ensure CUB knows which device to use
+    with torch.cuda.device(x.device):
+        # Route tokens back
+        out = ops.binned_scatter(x, indices, expert_weights, bins, top_k)
+    return out
+
+
+def forward_once(
+    x: torch.Tensor,
+    expert_weights: torch.Tensor,
+    top_experts: torch.Tensor,
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    w1_bias: torch.Tensor,
+    w2_bias: torch.Tensor,
+    gradient_scale: Optional[float] = None,
+    alpha: float = 1.702,
+    sort_end_bit: int = 0,
+    top_k: int = 4,
+    num_experts: int = 128,
+    expert_parallel_group: int = None,
+    moe_capacity_factor: float = 1.0,
+    moe_expert_model_parallelism: bool = False,
+):
+    # x: [sl, bs, hs]
+    # expert_weights: [sl * bs, top-k]
+    # top_experts: [sl * bs, top-k]
+    expert_weights = expert_weights.flatten()
+    top_experts = top_experts.flatten()
+
+    with torch.no_grad():
+        indices, bin_ids, bins, tokens_per_expert = indices_and_bins(
+            top_experts, sort_end_bit, num_experts
+        )
+
+        # Calculate expert capacity
+        sl, bs, _ = x.size()
+
+        expert_capacity = expert_capacity_fn(
+            sl * bs,
+            top_k,
+            num_experts,
+            expert_parallel_group,
+            moe_capacity_factor,
+            moe_expert_model_parallelism,
+        )
+
+        if expert_capacity == 0:
+            expert_capacity = torch.max(tokens_per_expert).item()
+
+    x = permute_and_compute(
+        x,
+        tokens_per_expert,
+        indices,
+        bin_ids,
+        expert_weights,
+        bins,
+        expert_capacity,
+        top_k,
+        w1,
+        w2,
+        w1_bias,
+        w2_bias,
+        gradient_scale,
+        alpha,
+    )
+    return x, tokens_per_expert
+
+
+# TODO: replace with functional logic once aligned with ref
+def parallel_forward_once(
+    x: torch.Tensor,
+    expert_weights: torch.Tensor,
+    top_experts: torch.Tensor,
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    w1_bias: torch.Tensor,
+    w2_bias: torch.Tensor,
+    gradient_scale: Optional[float] = None,
+    alpha: float = 1.702,
+    sort_end_bit: int = 0,
+    top_k: int = 4,
+    num_experts: int = 128,
+    expert_parallel_group: torch.distributed.ProcessGroup = None,
+    moe_capacity_factor: float = 1.0,
+    moe_expert_model_parallelism: bool = True,
+    hidden_size: int = 1152,
+):
+    pass
+
+
+class MyReplacementLayer(torch.nn.Module):
+    # def __init__(self):
+    #     super().__init__()
+
+    def forward(
+        # self,
+        x: torch.Tensor,
+        router_weight: torch.Tensor,
+        moe_top_k: int,
+        moe_num_experts: int,
+        moe_jitter_eps: float = None,
+        moe_normalize_expert_weights: int = None,
+        uniform_expert_assignment: bool = False,
+        training: bool = False,
+        #
+        w1: torch.Tensor = None,
+        w2: torch.Tensor = None,
+        w1_bias: torch.Tensor = None,
+        w2_bias: torch.Tensor = None,
+        gradient_scale: Optional[float] = None,
+        alpha: float = 1.702,
+        sort_end_bit: int = 0,
+        expert_parallel_group: torch.distributed.ProcessGroup = None,
+        moe_capacity_factor: float = 1.0,
+        moe_expert_model_parallelism: bool = False,
+        forward_fn: Any = None,
+        hidden_size: int = None,  # Required for parallel forward
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+
+        # Route tokens to experts
+        logits, expert_weights, expert_indices = route_tokens(
+            x,
+            router_weight,
+            moe_top_k,
+            moe_num_experts,
+            moe_jitter_eps,
+            moe_normalize_expert_weights,
+            uniform_expert_assignment,
+            training,
+        )
+
+        # Create router scores for output
+        router_scores = (
+            torch.zeros_like(logits)
+            .scatter_(1, expert_indices, expert_weights)
+            .transpose(0, 1)
+        )
+
+        in_shape = x.size()
+
+        # Prepare forward function arguments
+        forward_args = {
+            "x": x,
+            "expert_weights": expert_weights,
+            "top_experts": expert_indices,
+            "w1": w1,
+            "w2": w2,
+            "w1_bias": w1_bias,
+            "w2_bias": w2_bias,
+            "gradient_scale": gradient_scale,
+            "alpha": alpha,
+            "sort_end_bit": sort_end_bit,
+            "top_k": moe_top_k,
+            "num_experts": moe_num_experts,
+            "expert_parallel_group": expert_parallel_group,
+            "moe_capacity_factor": moe_capacity_factor,
+            "moe_expert_model_parallelism": moe_expert_model_parallelism,
+        }
+
+        # Add hidden_size for parallel forward
+        if moe_expert_model_parallelism and hidden_size is not None:
+            forward_args["hidden_size"] = hidden_size
+        elif moe_expert_model_parallelism and hidden_size is None:
+            # Infer hidden_size from input shape
+            forward_args["hidden_size"] = x.shape[-1]
+
+        # Compute expert outputs
+        x, tokens_per_expert = forward_fn(**forward_args)
+
+        # Save load balancing loss if needed
+        moe_loss_weight = 0.0  # Can be made configurable
+        if training and moe_loss_weight > 0:
+            save_load_balancing_loss((tokens_per_expert, logits))
+
+        # Restore original shape
+        x = x.view(in_shape)
+
+        return x, expert_weights, router_scores
+
+
+
+class MegaBlocksMoeMLP(torch.nn.Module):
+
+    def forward(
+        self,
+        x: torch.Tensor,
+    ) -> torch.Tensor:
+        router_weight = self.router.weight
+        moe_top_k = 4
+        moe_num_experts = 128
+        w1 = self.experts.gate_up_proj.data
+        w2 = self.experts.down_proj.data
+        w1_bias = self.experts.gate_up_proj_bias.data
+        w2_bias = self.experts.down_proj_bias.data
+        expert_parallel_group = None
+
+        sort_end_bit = max(int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1)
+        hidden_size = self.experts.hidden_size
+
+        output, expert_weights_out, router_scores = MyReplacementLayer.forward(
+            x=x,
+            router_weight=router_weight,
+            moe_top_k=moe_top_k,
+            moe_num_experts=moe_num_experts,
+            moe_jitter_eps=None,
+            moe_normalize_expert_weights=None,
+            uniform_expert_assignment=False,
+            training=False,
+            w1=w1,
+            w2=w2,
+            w1_bias=w1_bias,
+            w2_bias=w2_bias,
+            gradient_scale=None,
+            alpha=1.702,
+            sort_end_bit=sort_end_bit,
+            expert_parallel_group=expert_parallel_group,
+            moe_capacity_factor=1.0,
+            moe_expert_model_parallelism=False,
+            forward_fn=forward_once,
+            hidden_size=hidden_size,
+        )
+        return output, expert_weights_out
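The new layers.py is purely functional: route_tokens picks the top-k experts per token, forward_once sorts and bins tokens, runs the batched gate/up/down expert MLP, and scatters the results back, and MyReplacementLayer.forward simply wires router and expert weights into that pipeline. The snippet below is a minimal usage sketch of the single-device path; it is not part of the commit, all sizes are illustrative, and it assumes a CUDA device with the compiled megablocks ops available (the ops.* sort/gather/scatter kernels are GPU-only).

import torch
from megablocks import layers  # the module added by this commit

# Illustrative sizes only; none of these values come from the diff.
sl, bs, hidden, ffn = 4, 2, 64, 128
num_experts, top_k = 8, 2
device = "cuda"

x = torch.randn(sl, bs, hidden, device=device)
router_weight = torch.randn(num_experts, hidden, device=device)

# Batched expert weights: the gate_up projection is 2*ffn wide
# because mlp_forward chunks it into gate and up halves.
w1 = torch.randn(num_experts, hidden, 2 * ffn, device=device)
w1_bias = torch.zeros(num_experts, 2 * ffn, device=device)
w2 = torch.randn(num_experts, ffn, hidden, device=device)
w2_bias = torch.zeros(num_experts, hidden, device=device)

sort_end_bit = max(int(torch.ceil(torch.log2(torch.tensor(num_experts)))), 1)

out, expert_weights, router_scores = layers.MyReplacementLayer.forward(
    x=x,
    router_weight=router_weight,
    moe_top_k=top_k,
    moe_num_experts=num_experts,
    w1=w1,
    w2=w2,
    w1_bias=w1_bias,
    w2_bias=w2_bias,
    sort_end_bit=sort_end_bit,
    forward_fn=layers.forward_once,
)
print(out.shape)  # same shape as the input x: [sl, bs, hidden]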
build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/all_to_all_benchmark.py CHANGED
@@ -8,7 +8,7 @@ import torch.distributed as dist
 # from megablocks.layers.all_to_all import all_to_all
 
 from .. import benchmark_util
-from ..layers.all_to_all import all_to_all
+from .._layers.all_to_all import all_to_all
 
 _ALL_TO_ALL_BENCHMARK = (
     (8, 1024),
build/torch26-cxx11-cu124-x86_64-linux/megablocks/__init__.py CHANGED
@@ -9,11 +9,13 @@ from .grouped_gemm import backend as gg_backend
 from .grouped_gemm import ops as gg_ops
 
 
-from .layers.arguments import Arguments
-from .layers.dmoe import ParallelDroplessMLP, dMoE
-from .layers.glu import SparseGLU
-from .layers.mlp import MLP, SparseMLP
-from .layers.moe import MoE, ParallelMLP, get_load_balancing_loss
+from ._layers.arguments import Arguments
+from ._layers.dmoe import ParallelDroplessMLP, dMoE
+from ._layers.glu import SparseGLU
+from ._layers.mlp import MLP, SparseMLP
+from ._layers.moe import MoE, ParallelMLP, get_load_balancing_loss
+
+from . import layers
 
 # This section contains the direct kernel exports (not inlcuded in the original code)
 def exclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor:
@@ -176,6 +178,7 @@ def argsort(x: torch.Tensor, end_bit: int = 32) -> tuple[torch.Tensor, torch.Ten
 
 # Export public API
 __all__ = [
+    "MyReplacementLayer",
     # Direct kernel exports
     "exclusive_cumsum",
     "inclusive_cumsum",
build/torch26-cxx11-cu124-x86_64-linux/megablocks/{layers β†’ _layers}/__init__.py RENAMED
File without changes
build/torch26-cxx11-cu124-x86_64-linux/megablocks/{layers β†’ _layers}/activation_fn.py RENAMED
File without changes
build/torch26-cxx11-cu124-x86_64-linux/megablocks/{layers β†’ _layers}/all_to_all.py RENAMED
File without changes
build/torch26-cxx11-cu124-x86_64-linux/megablocks/{layers β†’ _layers}/arguments.py RENAMED
File without changes
build/torch26-cxx11-cu124-x86_64-linux/megablocks/{layers β†’ _layers}/common.py RENAMED
File without changes
build/torch26-cxx11-cu124-x86_64-linux/megablocks/{layers β†’ _layers}/dmlp_registry.py RENAMED
File without changes
build/torch26-cxx11-cu124-x86_64-linux/megablocks/{layers β†’ _layers}/dmoe.py RENAMED
File without changes
build/torch26-cxx11-cu124-x86_64-linux/megablocks/{layers β†’ _layers}/gelu.py RENAMED
File without changes
build/torch26-cxx11-cu124-x86_64-linux/megablocks/{layers β†’ _layers}/glu.py RENAMED
File without changes
build/torch26-cxx11-cu124-x86_64-linux/megablocks/{layers β†’ _layers}/memory_test.py RENAMED
File without changes
build/torch26-cxx11-cu124-x86_64-linux/megablocks/{layers β†’ _layers}/mlp.py RENAMED
File without changes
build/torch26-cxx11-cu124-x86_64-linux/megablocks/{layers β†’ _layers}/moe.py RENAMED
File without changes
build/torch26-cxx11-cu124-x86_64-linux/megablocks/{layers β†’ _layers}/mpu.py RENAMED
File without changes
build/torch26-cxx11-cu124-x86_64-linux/megablocks/{layers β†’ _layers}/router.py RENAMED
File without changes
build/torch26-cxx11-cu124-x86_64-linux/megablocks/{layers β†’ _layers}/sharedexpert_registry.py RENAMED
File without changes
build/torch26-cxx11-cu124-x86_64-linux/megablocks/{_megablocks_6756875_dirty.abi3.so β†’ _megablocks_dabb815.abi3.so} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:1419672a07ed370d7107ca54a6b694f234efa8e696644ee4e96c1bf396aff6af
-size 11869424
+oid sha256:a3f69e5978b727f08b43112c2321a222719aa824612d452029225a48976dfbb6
+size 11869392
build/torch26-cxx11-cu124-x86_64-linux/megablocks/_ops.py CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import _megablocks_6756875_dirty
-ops = torch.ops._megablocks_6756875_dirty
+from . import _megablocks_dabb815
+ops = torch.ops._megablocks_dabb815
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_megablocks_6756875_dirty::{op_name}"
+    return f"_megablocks_dabb815::{op_name}"
build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers.py ADDED
@@ -0,0 +1,566 @@
+import torch
+import torch.distributed as dist
+
+from typing import Optional, Any
+
+from . import _layers
+from . import ops
+
+
+# Set the expert model parallel attributes on a tensor
+def set_expert_model_parallel_attributes(
+    tensor: torch.Tensor,
+    is_parallel: bool,
+):
+    assert not hasattr(tensor, "expert_model_parallel")
+    setattr(tensor, "expert_model_parallel", is_parallel)
+
+
+# Get the expert model parallel attributes from a tensor
+def expert_sharding_degree(
+    world_size: int,
+    moe_num_experts: int,
+) -> int:
+    esd = min(world_size, moe_num_experts)
+    if (moe_num_experts % esd) != 0:
+        raise ValueError(f"Cannot shard {moe_num_experts} experts {esd} ways.")
+    return esd
+
+
+# Calculate the hidden sharding degree based on world size and expert sharding degree
+def hidden_sharding_degree(
+    world_size: int,
+    moe_num_experts: int,
+    ffn_hidden_size: int,
+) -> int:
+    esd = expert_sharding_degree(world_size, moe_num_experts)
+    hsd = world_size // esd
+    if (ffn_hidden_size % hsd) != 0:
+        raise ValueError(f"Cannot shard {ffn_hidden_size} features {hsd} ways.")
+    if (esd * hsd) != world_size:
+        raise ValueError(
+            f"Invalid sharding. expert_sharding_degree ({esd}) * hidden_sharding_degree ({hsd}) != world_size ({world_size})."
+        )
+    return hsd
+
+
+# Calculate the number of experts per rank based on world size and expert sharding degree
+def experts_per_rank(
+    moe_num_experts: int,
+    world_size: int,
+) -> int:
+    return moe_num_experts // expert_sharding_degree(world_size, moe_num_experts)
+
+
+# Calculate the number of features per rank based on ffn hidden size and hidden sharding degree
+def features_per_rank(
+    ffn_hidden_size: int, world_size: int, moe_num_experts: int
+) -> int:
+    return ffn_hidden_size // hidden_sharding_degree(
+        world_size, moe_num_experts, ffn_hidden_size
+    )
+
+
+# Apply jitter to the input tensor
+def apply_jitter(x: torch.Tensor, moe_jitter_eps: float) -> torch.Tensor:
+    low = 1.0 - moe_jitter_eps
+    high = 1.0 + moe_jitter_eps
+    noise = torch.rand(x.size(), dtype=x.dtype, device=x.device)
+    return x * (low + noise * (high - low))
+
+
+# Compute the top-k scores from the logits
+def compute_top_k(scores: torch.Tensor, moe_top_k: int):
+    if moe_top_k == 1:
+        return scores.max(dim=-1, keepdim=True)
+    return torch.topk(scores, moe_top_k, dim=-1)
+
+
+# Route tokens to experts and compute expert weights and indices
+def route_tokens(
+    x: torch.Tensor,
+    router_weight: torch.Tensor,
+    moe_top_k: int,
+    moe_num_experts: int,
+    moe_jitter_eps: float = None,
+    moe_normalize_expert_weights: int = None,
+    uniform_expert_assignment: bool = False,
+    training: bool = False,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    if training and moe_jitter_eps is not None:
+        x = apply_jitter(x, moe_jitter_eps)
+
+    x_flat = x.view(-1, x.shape[-1])
+    logits = torch.nn.functional.linear(x_flat, router_weight)
+    expert_weights, expert_indices = compute_top_k(logits, moe_top_k)
+    expert_weights = expert_weights.softmax(dim=-1)
+    if moe_normalize_expert_weights is not None:
+        expert_weights = expert_weights / torch.norm(
+            expert_weights,
+            p=moe_normalize_expert_weights,
+            dim=-1,
+            keepdim=True,
+        )
+    if uniform_expert_assignment:
+        expert_indices = _layers.router._uniform_expert_assignment(
+            expert_indices,
+            moe_num_experts,
+        )
+
+    return logits, expert_weights, expert_indices
+
+
+# Scale the gradient of the weights
+def scale_grad(
+    w: torch.Tensor,
+    gradient_scale: Optional[float] = None,
+) -> torch.Tensor:
+    if gradient_scale is None:
+        return w
+    return _layers.mlp.scale_gradient(w, gradient_scale)
+
+
+# Forward pass for the MLP layer
+def mlp_forward(x, w1, w2, w1_bias, w2_bias, gradient_scale=None, alpha: float = 1.702):
+    # Scale weights
+    w1 = scale_grad(w1, gradient_scale)
+    w2 = scale_grad(w2, gradient_scale)
+    w1_bias = scale_grad(w1_bias, gradient_scale)
+    w2_bias = scale_grad(w2_bias, gradient_scale)
+
+    # Resolve dtensors
+    w1 = _layers.mlp.resolve_dtensor(w1)
+    w2 = _layers.mlp.resolve_dtensor(w2)
+    w1_bias = _layers.mlp.resolve_dtensor(w1_bias)
+    w2_bias = _layers.mlp.resolve_dtensor(w2_bias)
+
+    # Forward pass
+    gate_up = torch.bmm(x, w1) + w1_bias[..., None, :]
+    gate, up = gate_up.chunk(2, dim=-1)
+
+    glu = gate * torch.sigmoid(gate * alpha)
+    x = (up + 1) * glu
+
+    return torch.bmm(x, w2) + w2_bias[..., None, :]
+
+
+## START: Load Balancing Loss (unused at the moment)
+
+# Global variable to store load balancing loss
+_LOAD_BALANCING_LOSS = []
+
+
+def save_load_balancing_loss(loss):
+    global _LOAD_BALANCING_LOSS
+    _LOAD_BALANCING_LOSS.append(loss)
+
+
+def get_load_balancing_loss():
+    global _LOAD_BALANCING_LOSS
+    return _LOAD_BALANCING_LOSS
+
+
+def clear_load_balancing_loss():
+    global _LOAD_BALANCING_LOSS
+    _LOAD_BALANCING_LOSS.clear()
+
+
+def batched_load_balancing_loss(args):
+    if args.moe_loss_weight == 0:
+        return 0.0
+
+    tokens_per_expert, expert_scores = zip(*get_load_balancing_loss())
+    num_layers_per_pipeline_stage = args.num_layers // args.pipeline_model_parallel_size
+    if args.num_layers_per_virtual_pipeline_stage is not None:
+        num_layers_per_pipeline_stage = args.num_layers_per_virtual_pipeline_stage
+
+    if len(tokens_per_expert) != num_layers_per_pipeline_stage:
+        raise ValueError(
+            f"Expected {num_layers_per_pipeline_stage} token_per_experts "
+            f"but found {len(tokens_per_expert)}.\nnum_layers = "
+            f"{args.num_layers}\npipeline_model_parallel_size = "
+            f"{args.pipeline_model_parallel_size}\n"
+            "num_layers_per_virtual_pipeline_stage"
+            f" = {args.num_layers_per_virtual_pipeline_stage}",
+        )
+    if len(expert_scores) != num_layers_per_pipeline_stage:
+        raise ValueError(
+            f"Expected {num_layers_per_pipeline_stage} expert_scores "
+            f"but found {len(tokens_per_expert)}.\nnum_layers = "
+            f"{args.num_layers}\npipeline_model_parallel_size = "
+            f"{args.pipeline_model_parallel_size}\n"
+            "num_layers_per_virtual_pipeline_stage"
+            f" = {args.num_layers_per_virtual_pipeline_stage}",
+        )
+
+    # Verify the shape of the tokens_per_expert and expert_scores tensors.
+    assert all(
+        (x.ndim == 1 and x.numel() == args.moe_num_experts for x in tokens_per_expert)
+    )
+
+    tokens = expert_scores[0].shape[0]
+    assert all(
+        (
+            (
+                x.ndim == 2
+                and x.shape[1] == args.moe_num_experts
+                and x.shape[0] == tokens
+            )
+            for x in expert_scores
+        )
+    )
+
+    # Concatenate the contributions of each layer and convert to
+    # the correct types and formats for the dot product.
+    expert_scores = torch.cat(expert_scores, dim=1)
+    if args.moe_lbl_in_fp32:
+        expert_scores = expert_scores.float()
+    if tokens != 0:
+        expert_scores = expert_scores.mean(dim=0)
+    else:
+        expert_scores = expert_scores.sum(dim=0)
+    tokens_per_expert = torch.cat(tokens_per_expert).to(expert_scores.dtype)
+
+    expected_values = num_layers_per_pipeline_stage * args.moe_num_experts
+    assert tokens_per_expert.numel() == expected_values
+    assert expert_scores.numel() == expected_values
+
+    # Calculate the total scale across all factors.
+    #
+    # loss_weight * num_experts / (num_layers * tokens * top_k)
+    scale_numerator = args.moe_num_experts * args.moe_loss_weight
+    scale_denominator = args.num_layers * tokens * args.moe_top_k
+    scale = scale_numerator / scale_denominator
+    return scale * torch.dot(tokens_per_expert, expert_scores)
+
+
+## END Load Balancing Loss
+
+
+# Calculate the expert capacity based on tokens, top_k, number of experts,
+# expert parallel group, capacity factor, and whether expert model parallelism is used.
+def expert_capacity(
+    tokens: int,
+    top_k: int,
+    num_experts: int,
+    expert_parallel_group: int,
+    moe_capacity_factor: float,
+    moe_expert_model_parallelism: bool,
+) -> int:
+    world_size = (
+        dist.get_world_size(expert_parallel_group)
+        if moe_expert_model_parallelism
+        else 1
+    )
+
+    tokens_per_expert = top_k * tokens * world_size / num_experts
+    return int(moe_capacity_factor * tokens_per_expert)
+
+
+def load_balancing_loss(
+    tokens_per_expert: torch.Tensor,
+    expert_scores: torch.Tensor,
+    top_k: int,
+    num_experts: int,
+):
+    assert len(expert_scores.size()) == 2
+    tokens, num_experts = expert_scores.size()
+    assert num_experts == num_experts
+    assert len(tokens_per_expert.size()) == 1
+    (num_experts,) = tokens_per_expert.size()
+    assert num_experts == num_experts
+    scale = num_experts / (tokens * top_k)
+    return scale * torch.dot(
+        tokens_per_expert.to(expert_scores.dtype),
+        expert_scores.mean(dim=0),
+    )
+
+
+def indices_and_bins(
+    top_expert: torch.Tensor,
+    sort_end_bit: int,
+    num_experts: int,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+    top_expert = top_expert.int()
+
+    # Ensure contiguous memory layout
+    top_expert = top_expert.contiguous()
+
+    # Ensure CUB knows which device to use
+    with torch.cuda.device(top_expert.device):
+        output = ops.sort(top_expert, sort_end_bit)
+        bin_ids, indices = output
+        tokens_per_expert = ops.histogram(top_expert, num_experts)
+        bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+
+    bins = bins.view(1) if not len(bins.size()) else bins
+    return indices, bin_ids, bins, tokens_per_expert
+
+
+def expert_capacity_fn(
+    tokens: int,
+    top_k: int,
+    num_experts: int,
+    expert_parallel_group: torch.distributed.ProcessGroup,
+    moe_capacity_factor: float = 1.0,
+    moe_expert_model_parallelism: bool = False,
+) -> int:
+    world_size = (
+        dist.get_world_size(expert_parallel_group)
+        if moe_expert_model_parallelism
+        else 1
+    )
+    tokens_per_expert = top_k * tokens * world_size / num_experts
+    return int(moe_capacity_factor * tokens_per_expert)
+
+
+def permute_and_compute(
+    x,
+    tokens_per_expert,
+    indices,
+    bin_ids,
+    expert_weights,
+    bins,
+    expert_capacity,
+    top_k,
+    w1,
+    w2,
+    w1_bias,
+    w2_bias,
+    gradient_scale,
+    alpha,
+):
+    """Permute tokens and compute expert outputs."""
+    # Route tokens to experts
+    x = x.view(-1, x.shape[-1])
+
+    # Ensure CUB knows which device to use
+    with torch.cuda.device(x.device):
+        x = ops.binned_gather(x, indices, bins, expert_capacity, top_k)
+
+    # Expert computation
+    x = mlp_forward(x, w1, w2, w1_bias, w2_bias, gradient_scale, alpha)
+
+    # Ensure CUB knows which device to use
+    with torch.cuda.device(x.device):
+        # Route tokens back
+        out = ops.binned_scatter(x, indices, expert_weights, bins, top_k)
+    return out
+
+
+def forward_once(
+    x: torch.Tensor,
+    expert_weights: torch.Tensor,
+    top_experts: torch.Tensor,
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    w1_bias: torch.Tensor,
+    w2_bias: torch.Tensor,
+    gradient_scale: Optional[float] = None,
+    alpha: float = 1.702,
+    sort_end_bit: int = 0,
+    top_k: int = 4,
+    num_experts: int = 128,
+    expert_parallel_group: int = None,
+    moe_capacity_factor: float = 1.0,
+    moe_expert_model_parallelism: bool = False,
+):
+    # x: [sl, bs, hs]
+    # expert_weights: [sl * bs, top-k]
+    # top_experts: [sl * bs, top-k]
+    expert_weights = expert_weights.flatten()
+    top_experts = top_experts.flatten()
+
+    with torch.no_grad():
+        indices, bin_ids, bins, tokens_per_expert = indices_and_bins(
+            top_experts, sort_end_bit, num_experts
+        )
+
+        # Calculate expert capacity
+        sl, bs, _ = x.size()
+
+        expert_capacity = expert_capacity_fn(
+            sl * bs,
+            top_k,
+            num_experts,
+            expert_parallel_group,
+            moe_capacity_factor,
+            moe_expert_model_parallelism,
+        )
+
+        if expert_capacity == 0:
+            expert_capacity = torch.max(tokens_per_expert).item()
+
+    x = permute_and_compute(
+        x,
+        tokens_per_expert,
+        indices,
+        bin_ids,
+        expert_weights,
+        bins,
+        expert_capacity,
+        top_k,
+        w1,
+        w2,
+        w1_bias,
+        w2_bias,
+        gradient_scale,
+        alpha,
+    )
+    return x, tokens_per_expert
+
+
+# TODO: replace with functional logic once aligned with ref
+def parallel_forward_once(
+    x: torch.Tensor,
+    expert_weights: torch.Tensor,
+    top_experts: torch.Tensor,
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    w1_bias: torch.Tensor,
+    w2_bias: torch.Tensor,
+    gradient_scale: Optional[float] = None,
+    alpha: float = 1.702,
+    sort_end_bit: int = 0,
+    top_k: int = 4,
+    num_experts: int = 128,
+    expert_parallel_group: torch.distributed.ProcessGroup = None,
+    moe_capacity_factor: float = 1.0,
+    moe_expert_model_parallelism: bool = True,
+    hidden_size: int = 1152,
+):
+    pass
+
+
+class MyReplacementLayer(torch.nn.Module):
+    # def __init__(self):
+    #     super().__init__()
+
+    def forward(
+        # self,
+        x: torch.Tensor,
+        router_weight: torch.Tensor,
+        moe_top_k: int,
+        moe_num_experts: int,
+        moe_jitter_eps: float = None,
+        moe_normalize_expert_weights: int = None,
+        uniform_expert_assignment: bool = False,
+        training: bool = False,
+        #
+        w1: torch.Tensor = None,
+        w2: torch.Tensor = None,
+        w1_bias: torch.Tensor = None,
+        w2_bias: torch.Tensor = None,
+        gradient_scale: Optional[float] = None,
+        alpha: float = 1.702,
+        sort_end_bit: int = 0,
+        expert_parallel_group: torch.distributed.ProcessGroup = None,
+        moe_capacity_factor: float = 1.0,
+        moe_expert_model_parallelism: bool = False,
+        forward_fn: Any = None,
+        hidden_size: int = None,  # Required for parallel forward
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+
+        # Route tokens to experts
+        logits, expert_weights, expert_indices = route_tokens(
+            x,
+            router_weight,
+            moe_top_k,
+            moe_num_experts,
+            moe_jitter_eps,
+            moe_normalize_expert_weights,
+            uniform_expert_assignment,
+            training,
+        )
+
+        # Create router scores for output
+        router_scores = (
+            torch.zeros_like(logits)
+            .scatter_(1, expert_indices, expert_weights)
+            .transpose(0, 1)
+        )
+
+        in_shape = x.size()
+
+        # Prepare forward function arguments
+        forward_args = {
+            "x": x,
+            "expert_weights": expert_weights,
+            "top_experts": expert_indices,
+            "w1": w1,
+            "w2": w2,
+            "w1_bias": w1_bias,
+            "w2_bias": w2_bias,
+            "gradient_scale": gradient_scale,
+            "alpha": alpha,
+            "sort_end_bit": sort_end_bit,
+            "top_k": moe_top_k,
+            "num_experts": moe_num_experts,
+            "expert_parallel_group": expert_parallel_group,
+            "moe_capacity_factor": moe_capacity_factor,
+            "moe_expert_model_parallelism": moe_expert_model_parallelism,
+        }
+
+        # Add hidden_size for parallel forward
+        if moe_expert_model_parallelism and hidden_size is not None:
+            forward_args["hidden_size"] = hidden_size
+        elif moe_expert_model_parallelism and hidden_size is None:
+            # Infer hidden_size from input shape
+            forward_args["hidden_size"] = x.shape[-1]
+
+        # Compute expert outputs
+        x, tokens_per_expert = forward_fn(**forward_args)
+
+        # Save load balancing loss if needed
+        moe_loss_weight = 0.0  # Can be made configurable
+        if training and moe_loss_weight > 0:
+            save_load_balancing_loss((tokens_per_expert, logits))
+
+        # Restore original shape
+        x = x.view(in_shape)
+
+        return x, expert_weights, router_scores
+
+
+
+class MegaBlocksMoeMLP(torch.nn.Module):
+
+    def forward(
+        self,
+        x: torch.Tensor,
+    ) -> torch.Tensor:
+        router_weight = self.router.weight
+        moe_top_k = 4
+        moe_num_experts = 128
+        w1 = self.experts.gate_up_proj.data
+        w2 = self.experts.down_proj.data
+        w1_bias = self.experts.gate_up_proj_bias.data
+        w2_bias = self.experts.down_proj_bias.data
+        expert_parallel_group = None
+
+        sort_end_bit = max(int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1)
+        hidden_size = self.experts.hidden_size
+
+        output, expert_weights_out, router_scores = MyReplacementLayer.forward(
+            x=x,
+            router_weight=router_weight,
+            moe_top_k=moe_top_k,
+            moe_num_experts=moe_num_experts,
+            moe_jitter_eps=None,
+            moe_normalize_expert_weights=None,
+            uniform_expert_assignment=False,
+            training=False,
+            w1=w1,
+            w2=w2,
+            w1_bias=w1_bias,
+            w2_bias=w2_bias,
+            gradient_scale=None,
+            alpha=1.702,
+            sort_end_bit=sort_end_bit,
+            expert_parallel_group=expert_parallel_group,
+            moe_capacity_factor=1.0,
+            moe_expert_model_parallelism=False,
+            forward_fn=forward_once,
+            hidden_size=hidden_size,
+        )
+        return output, expert_weights_out
build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/all_to_all_benchmark.py CHANGED
@@ -8,7 +8,7 @@ import torch.distributed as dist
 # from megablocks.layers.all_to_all import all_to_all
 
 from .. import benchmark_util
-from ..layers.all_to_all import all_to_all
+from .._layers.all_to_all import all_to_all
 
 _ALL_TO_ALL_BENCHMARK = (
     (8, 1024),
build/torch26-cxx11-cu126-x86_64-linux/megablocks/__init__.py CHANGED
@@ -9,11 +9,13 @@ from .grouped_gemm import backend as gg_backend
 from .grouped_gemm import ops as gg_ops
 
 
-from .layers.arguments import Arguments
-from .layers.dmoe import ParallelDroplessMLP, dMoE
-from .layers.glu import SparseGLU
-from .layers.mlp import MLP, SparseMLP
-from .layers.moe import MoE, ParallelMLP, get_load_balancing_loss
+from ._layers.arguments import Arguments
+from ._layers.dmoe import ParallelDroplessMLP, dMoE
+from ._layers.glu import SparseGLU
+from ._layers.mlp import MLP, SparseMLP
+from ._layers.moe import MoE, ParallelMLP, get_load_balancing_loss
+
+from . import layers
 
 # This section contains the direct kernel exports (not inlcuded in the original code)
 def exclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor:
@@ -176,6 +178,7 @@ def argsort(x: torch.Tensor, end_bit: int = 32) -> tuple[torch.Tensor, torch.Ten
 
 # Export public API
 __all__ = [
+    "MyReplacementLayer",
     # Direct kernel exports
     "exclusive_cumsum",
    "inclusive_cumsum",
build/torch26-cxx11-cu126-x86_64-linux/megablocks/{layers β†’ _layers}/__init__.py RENAMED
File without changes
build/torch26-cxx11-cu126-x86_64-linux/megablocks/{layers β†’ _layers}/activation_fn.py RENAMED
File without changes
build/torch26-cxx11-cu126-x86_64-linux/megablocks/{layers β†’ _layers}/all_to_all.py RENAMED
File without changes
build/torch26-cxx11-cu126-x86_64-linux/megablocks/{layers β†’ _layers}/arguments.py RENAMED
File without changes
build/torch26-cxx11-cu126-x86_64-linux/megablocks/{layers β†’ _layers}/common.py RENAMED
File without changes
build/torch26-cxx11-cu126-x86_64-linux/megablocks/{layers β†’ _layers}/dmlp_registry.py RENAMED
File without changes
build/torch26-cxx11-cu126-x86_64-linux/megablocks/{layers β†’ _layers}/dmoe.py RENAMED
File without changes
build/torch26-cxx11-cu126-x86_64-linux/megablocks/{layers β†’ _layers}/gelu.py RENAMED
File without changes
build/torch26-cxx11-cu126-x86_64-linux/megablocks/{layers β†’ _layers}/glu.py RENAMED
File without changes