diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/__init__.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/__init__.py index af8d40aa52c706d2aae368cc11ade68cf13f6a47..38075732c6d8fa0e1e6ef493145e1aca3851ae6b 100644 --- a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/__init__.py +++ b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/__init__.py @@ -9,11 +9,13 @@ from .grouped_gemm import backend as gg_backend from .grouped_gemm import ops as gg_ops -from .layers.arguments import Arguments -from .layers.dmoe import ParallelDroplessMLP, dMoE -from .layers.glu import SparseGLU -from .layers.mlp import MLP, SparseMLP -from .layers.moe import MoE, ParallelMLP, get_load_balancing_loss +from ._layers.arguments import Arguments +from ._layers.dmoe import ParallelDroplessMLP, dMoE +from ._layers.glu import SparseGLU +from ._layers.mlp import MLP, SparseMLP +from ._layers.moe import MoE, ParallelMLP, get_load_balancing_loss + +from . import layers # This section contains the direct kernel exports (not inlcuded in the original code) def exclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor: @@ -176,6 +178,7 @@ def argsort(x: torch.Tensor, end_bit: int = 32) -> tuple[torch.Tensor, torch.Ten # Export public API __all__ = [ + "MyReplacementLayer", # Direct kernel exports "exclusive_cumsum", "inclusive_cumsum", diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/__init__.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/_layers/__init__.py similarity index 100% rename from build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/__init__.py rename to build/torch26-cxx11-cu118-x86_64-linux/megablocks/_layers/__init__.py diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/activation_fn.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/_layers/activation_fn.py similarity index 100% rename from build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/activation_fn.py rename to build/torch26-cxx11-cu118-x86_64-linux/megablocks/_layers/activation_fn.py diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/all_to_all.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/_layers/all_to_all.py similarity index 100% rename from build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/all_to_all.py rename to build/torch26-cxx11-cu118-x86_64-linux/megablocks/_layers/all_to_all.py diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/arguments.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/_layers/arguments.py similarity index 100% rename from build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/arguments.py rename to build/torch26-cxx11-cu118-x86_64-linux/megablocks/_layers/arguments.py diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/common.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/_layers/common.py similarity index 100% rename from build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/common.py rename to build/torch26-cxx11-cu118-x86_64-linux/megablocks/_layers/common.py diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/dmlp_registry.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/_layers/dmlp_registry.py similarity index 100% rename from build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/dmlp_registry.py rename to build/torch26-cxx11-cu118-x86_64-linux/megablocks/_layers/dmlp_registry.py diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/dmoe.py 
b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/_layers/dmoe.py similarity index 100% rename from build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/dmoe.py rename to build/torch26-cxx11-cu118-x86_64-linux/megablocks/_layers/dmoe.py diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/gelu.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/_layers/gelu.py similarity index 100% rename from build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/gelu.py rename to build/torch26-cxx11-cu118-x86_64-linux/megablocks/_layers/gelu.py diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/glu.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/_layers/glu.py similarity index 100% rename from build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/glu.py rename to build/torch26-cxx11-cu118-x86_64-linux/megablocks/_layers/glu.py diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/memory_test.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/_layers/memory_test.py similarity index 100% rename from build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/memory_test.py rename to build/torch26-cxx11-cu118-x86_64-linux/megablocks/_layers/memory_test.py diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/mlp.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/_layers/mlp.py similarity index 100% rename from build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/mlp.py rename to build/torch26-cxx11-cu118-x86_64-linux/megablocks/_layers/mlp.py diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/moe.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/_layers/moe.py similarity index 100% rename from build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/moe.py rename to build/torch26-cxx11-cu118-x86_64-linux/megablocks/_layers/moe.py diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/mpu.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/_layers/mpu.py similarity index 100% rename from build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/mpu.py rename to build/torch26-cxx11-cu118-x86_64-linux/megablocks/_layers/mpu.py diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/router.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/_layers/router.py similarity index 100% rename from build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/router.py rename to build/torch26-cxx11-cu118-x86_64-linux/megablocks/_layers/router.py diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/sharedexpert_registry.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/_layers/sharedexpert_registry.py similarity index 100% rename from build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers/sharedexpert_registry.py rename to build/torch26-cxx11-cu118-x86_64-linux/megablocks/_layers/sharedexpert_registry.py diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/_megablocks_6756875_dirty.abi3.so b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/_megablocks_6756875_dirty.abi3.so deleted file mode 100755 index 7e041eea505b2cea718255f91475aaf5ed5262b5..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/_megablocks_6756875_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:ad46e9f244afa886c8a104d75e37f93afd2a0ecf83bfc7a414680fa16d8b78f9 -size 10517608 diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/_megablocks_dabb815.abi3.so 
b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/_megablocks_dabb815.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..779e13d3baac0ee00e89944b954a8bc75fcb432f --- /dev/null +++ b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/_megablocks_dabb815.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7a20cd4dc15095b8504db981c651e516e8a7d8394b99d973d632558637c8dba9 +size 10517576 diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/_ops.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/_ops.py index c9a9faa3cee07c8b7016a71704ea26e47c460d57..3cc5ae109015332d831f2357627d00ef67faa25d 100644 --- a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/_ops.py +++ b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _megablocks_6756875_dirty -ops = torch.ops._megablocks_6756875_dirty +from . import _megablocks_dabb815 +ops = torch.ops._megablocks_dabb815 def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. """ - return f"_megablocks_6756875_dirty::{op_name}" \ No newline at end of file + return f"_megablocks_dabb815::{op_name}" \ No newline at end of file diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers.py new file mode 100644 index 0000000000000000000000000000000000000000..d3737eddff33082db4c8b31b2c68549b13a911b4 --- /dev/null +++ b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers.py @@ -0,0 +1,566 @@ +import torch +import torch.distributed as dist + +from typing import Optional, Any + +from . import _layers +from . import ops + + +# Set the expert model parallel attributes on a tensor +def set_expert_model_parallel_attributes( + tensor: torch.Tensor, + is_parallel: bool, +): + assert not hasattr(tensor, "expert_model_parallel") + setattr(tensor, "expert_model_parallel", is_parallel) + + +# Get the expert model parallel attributes from a tensor +def expert_sharding_degree( + world_size: int, + moe_num_experts: int, +) -> int: + esd = min(world_size, moe_num_experts) + if (moe_num_experts % esd) != 0: + raise ValueError(f"Cannot shard {moe_num_experts} experts {esd} ways.") + return esd + + +# Calculate the hidden sharding degree based on world size and expert sharding degree +def hidden_sharding_degree( + world_size: int, + moe_num_experts: int, + ffn_hidden_size: int, +) -> int: + esd = expert_sharding_degree(world_size, moe_num_experts) + hsd = world_size // esd + if (ffn_hidden_size % hsd) != 0: + raise ValueError(f"Cannot shard {ffn_hidden_size} features {hsd} ways.") + if (esd * hsd) != world_size: + raise ValueError( + f"Invalid sharding. expert_sharding_degree ({esd}) * hidden_sharding_degree ({hsd}) != world_size ({world_size})." 
+ ) + return hsd + + +# Calculate the number of experts per rank based on world size and expert sharding degree +def experts_per_rank( + moe_num_experts: int, + world_size: int, +) -> int: + return moe_num_experts // expert_sharding_degree(world_size, moe_num_experts) + + +# Calculate the number of features per rank based on ffn hidden size and hidden sharding degree +def features_per_rank( + ffn_hidden_size: int, world_size: int, moe_num_experts: int +) -> int: + return ffn_hidden_size // hidden_sharding_degree( + world_size, moe_num_experts, ffn_hidden_size + ) + + +# Apply jitter to the input tensor +def apply_jitter(x: torch.Tensor, moe_jitter_eps: float) -> torch.Tensor: + low = 1.0 - moe_jitter_eps + high = 1.0 + moe_jitter_eps + noise = torch.rand(x.size(), dtype=x.dtype, device=x.device) + return x * (low + noise * (high - low)) + + +# Compute the top-k scores from the logits +def compute_top_k(scores: torch.Tensor, moe_top_k: int): + if moe_top_k == 1: + return scores.max(dim=-1, keepdim=True) + return torch.topk(scores, moe_top_k, dim=-1) + + +# Route tokens to experts and compute expert weights and indices +def route_tokens( + x: torch.Tensor, + router_weight: torch.Tensor, + moe_top_k: int, + moe_num_experts: int, + moe_jitter_eps: float = None, + moe_normalize_expert_weights: int = None, + uniform_expert_assignment: bool = False, + training: bool = False, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + if training and moe_jitter_eps is not None: + x = apply_jitter(x, moe_jitter_eps) + + x_flat = x.view(-1, x.shape[-1]) + logits = torch.nn.functional.linear(x_flat, router_weight) + expert_weights, expert_indices = compute_top_k(logits, moe_top_k) + expert_weights = expert_weights.softmax(dim=-1) + if moe_normalize_expert_weights is not None: + expert_weights = expert_weights / torch.norm( + expert_weights, + p=moe_normalize_expert_weights, + dim=-1, + keepdim=True, + ) + if uniform_expert_assignment: + expert_indices = _layers.router._uniform_expert_assignment( + expert_indices, + moe_num_experts, + ) + + return logits, expert_weights, expert_indices + + +# Scale the gradient of the weights +def scale_grad( + w: torch.Tensor, + gradient_scale: Optional[float] = None, +) -> torch.Tensor: + if gradient_scale is None: + return w + return _layers.mlp.scale_gradient(w, gradient_scale) + + +# Forward pass for the MLP layer +def mlp_forward(x, w1, w2, w1_bias, w2_bias, gradient_scale=None, alpha: float = 1.702): + # Scale weights + w1 = scale_grad(w1, gradient_scale) + w2 = scale_grad(w2, gradient_scale) + w1_bias = scale_grad(w1_bias, gradient_scale) + w2_bias = scale_grad(w2_bias, gradient_scale) + + # Resolve dtensors + w1 = _layers.mlp.resolve_dtensor(w1) + w2 = _layers.mlp.resolve_dtensor(w2) + w1_bias = _layers.mlp.resolve_dtensor(w1_bias) + w2_bias = _layers.mlp.resolve_dtensor(w2_bias) + + # Forward pass + gate_up = torch.bmm(x, w1) + w1_bias[..., None, :] + gate, up = gate_up.chunk(2, dim=-1) + + glu = gate * torch.sigmoid(gate * alpha) + x = (up + 1) * glu + + return torch.bmm(x, w2) + w2_bias[..., None, :] + + +## START: Load Balancing Loss (unused at the moment) + +# Global variable to store load balancing loss +_LOAD_BALANCING_LOSS = [] + + +def save_load_balancing_loss(loss): + global _LOAD_BALANCING_LOSS + _LOAD_BALANCING_LOSS.append(loss) + + +def get_load_balancing_loss(): + global _LOAD_BALANCING_LOSS + return _LOAD_BALANCING_LOSS + + +def clear_load_balancing_loss(): + global _LOAD_BALANCING_LOSS + _LOAD_BALANCING_LOSS.clear() + + +def 
batched_load_balancing_loss(args): + if args.moe_loss_weight == 0: + return 0.0 + + tokens_per_expert, expert_scores = zip(*get_load_balancing_loss()) + num_layers_per_pipeline_stage = args.num_layers // args.pipeline_model_parallel_size + if args.num_layers_per_virtual_pipeline_stage is not None: + num_layers_per_pipeline_stage = args.num_layers_per_virtual_pipeline_stage + + if len(tokens_per_expert) != num_layers_per_pipeline_stage: + raise ValueError( + f"Expected {num_layers_per_pipeline_stage} token_per_experts " + f"but found {len(tokens_per_expert)}.\nnum_layers = " + f"{args.num_layers}\npipeline_model_parallel_size = " + f"{args.pipeline_model_parallel_size}\n" + "num_layers_per_virtual_pipeline_stage" + f" = {args.num_layers_per_virtual_pipeline_stage}", + ) + if len(expert_scores) != num_layers_per_pipeline_stage: + raise ValueError( + f"Expected {num_layers_per_pipeline_stage} expert_scores " + f"but found {len(tokens_per_expert)}.\nnum_layers = " + f"{args.num_layers}\npipeline_model_parallel_size = " + f"{args.pipeline_model_parallel_size}\n" + "num_layers_per_virtual_pipeline_stage" + f" = {args.num_layers_per_virtual_pipeline_stage}", + ) + + # Verify the shape of the tokens_per_expert and expert_scores tensors. + assert all( + (x.ndim == 1 and x.numel() == args.moe_num_experts for x in tokens_per_expert) + ) + + tokens = expert_scores[0].shape[0] + assert all( + ( + ( + x.ndim == 2 + and x.shape[1] == args.moe_num_experts + and x.shape[0] == tokens + ) + for x in expert_scores + ) + ) + + # Concatenate the contributions of each layer and convert to + # the correct types and formats for the dot product. + expert_scores = torch.cat(expert_scores, dim=1) + if args.moe_lbl_in_fp32: + expert_scores = expert_scores.float() + if tokens != 0: + expert_scores = expert_scores.mean(dim=0) + else: + expert_scores = expert_scores.sum(dim=0) + tokens_per_expert = torch.cat(tokens_per_expert).to(expert_scores.dtype) + + expected_values = num_layers_per_pipeline_stage * args.moe_num_experts + assert tokens_per_expert.numel() == expected_values + assert expert_scores.numel() == expected_values + + # Calculate the total scale across all factors. + # + # loss_weight * num_experts / (num_layers * tokens * top_k) + scale_numerator = args.moe_num_experts * args.moe_loss_weight + scale_denominator = args.num_layers * tokens * args.moe_top_k + scale = scale_numerator / scale_denominator + return scale * torch.dot(tokens_per_expert, expert_scores) + + +## END Load Balancing Loss + + +# Calculate the expert capacity based on tokens, top_k, number of experts, +# expert parallel group, capacity factor, and whether expert model parallelism is used. 
+def expert_capacity( + tokens: int, + top_k: int, + num_experts: int, + expert_parallel_group: int, + moe_capacity_factor: float, + moe_expert_model_parallelism: bool, +) -> int: + world_size = ( + dist.get_world_size(expert_parallel_group) + if moe_expert_model_parallelism + else 1 + ) + + tokens_per_expert = top_k * tokens * world_size / num_experts + return int(moe_capacity_factor * tokens_per_expert) + + +def load_balancing_loss( + tokens_per_expert: torch.Tensor, + expert_scores: torch.Tensor, + top_k: int, + num_experts: int, +): + assert len(expert_scores.size()) == 2 + tokens, num_experts = expert_scores.size() + assert num_experts == num_experts + assert len(tokens_per_expert.size()) == 1 + (num_experts,) = tokens_per_expert.size() + assert num_experts == num_experts + scale = num_experts / (tokens * top_k) + return scale * torch.dot( + tokens_per_expert.to(expert_scores.dtype), + expert_scores.mean(dim=0), + ) + + +def indices_and_bins( + top_expert: torch.Tensor, + sort_end_bit: int, + num_experts: int, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + top_expert = top_expert.int() + + # Ensure contiguous memory layout + top_expert = top_expert.contiguous() + + # Ensure CUB knows which device to use + with torch.cuda.device(top_expert.device): + output = ops.sort(top_expert, sort_end_bit) + bin_ids, indices = output + tokens_per_expert = ops.histogram(top_expert, num_experts) + bins = ops.inclusive_cumsum(tokens_per_expert, 0) + + bins = bins.view(1) if not len(bins.size()) else bins + return indices, bin_ids, bins, tokens_per_expert + + +def expert_capacity_fn( + tokens: int, + top_k: int, + num_experts: int, + expert_parallel_group: torch.distributed.ProcessGroup, + moe_capacity_factor: float = 1.0, + moe_expert_model_parallelism: bool = False, +) -> int: + world_size = ( + dist.get_world_size(expert_parallel_group) + if moe_expert_model_parallelism + else 1 + ) + tokens_per_expert = top_k * tokens * world_size / num_experts + return int(moe_capacity_factor * tokens_per_expert) + + +def permute_and_compute( + x, + tokens_per_expert, + indices, + bin_ids, + expert_weights, + bins, + expert_capacity, + top_k, + w1, + w2, + w1_bias, + w2_bias, + gradient_scale, + alpha, +): + """Permute tokens and compute expert outputs.""" + # Route tokens to experts + x = x.view(-1, x.shape[-1]) + + # Ensure CUB knows which device to use + with torch.cuda.device(x.device): + x = ops.binned_gather(x, indices, bins, expert_capacity, top_k) + + # Expert computation + x = mlp_forward(x, w1, w2, w1_bias, w2_bias, gradient_scale, alpha) + + # Ensure CUB knows which device to use + with torch.cuda.device(x.device): + # Route tokens back + out = ops.binned_scatter(x, indices, expert_weights, bins, top_k) + return out + + +def forward_once( + x: torch.Tensor, + expert_weights: torch.Tensor, + top_experts: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + w1_bias: torch.Tensor, + w2_bias: torch.Tensor, + gradient_scale: Optional[float] = None, + alpha: float = 1.702, + sort_end_bit: int = 0, + top_k: int = 4, + num_experts: int = 128, + expert_parallel_group: int = None, + moe_capacity_factor: float = 1.0, + moe_expert_model_parallelism: bool = False, +): + # x: [sl, bs, hs] + # expert_weights: [sl * bs, top-k] + # top_experts: [sl * bs, top-k] + expert_weights = expert_weights.flatten() + top_experts = top_experts.flatten() + + with torch.no_grad(): + indices, bin_ids, bins, tokens_per_expert = indices_and_bins( + top_experts, sort_end_bit, num_experts + ) + + # Calculate 
expert capacity + sl, bs, _ = x.size() + + expert_capacity = expert_capacity_fn( + sl * bs, + top_k, + num_experts, + expert_parallel_group, + moe_capacity_factor, + moe_expert_model_parallelism, + ) + + if expert_capacity == 0: + expert_capacity = torch.max(tokens_per_expert).item() + + x = permute_and_compute( + x, + tokens_per_expert, + indices, + bin_ids, + expert_weights, + bins, + expert_capacity, + top_k, + w1, + w2, + w1_bias, + w2_bias, + gradient_scale, + alpha, + ) + return x, tokens_per_expert + + +# TODO: replace with functional logic once aligned with ref +def parallel_forward_once( + x: torch.Tensor, + expert_weights: torch.Tensor, + top_experts: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + w1_bias: torch.Tensor, + w2_bias: torch.Tensor, + gradient_scale: Optional[float] = None, + alpha: float = 1.702, + sort_end_bit: int = 0, + top_k: int = 4, + num_experts: int = 128, + expert_parallel_group: torch.distributed.ProcessGroup = None, + moe_capacity_factor: float = 1.0, + moe_expert_model_parallelism: bool = True, + hidden_size: int = 1152, +): + pass + + +class MyReplacementLayer(torch.nn.Module): + # def __init__(self): + # super().__init__() + + def forward( + # self, + x: torch.Tensor, + router_weight: torch.Tensor, + moe_top_k: int, + moe_num_experts: int, + moe_jitter_eps: float = None, + moe_normalize_expert_weights: int = None, + uniform_expert_assignment: bool = False, + training: bool = False, + # + w1: torch.Tensor = None, + w2: torch.Tensor = None, + w1_bias: torch.Tensor = None, + w2_bias: torch.Tensor = None, + gradient_scale: Optional[float] = None, + alpha: float = 1.702, + sort_end_bit: int = 0, + expert_parallel_group: torch.distributed.ProcessGroup = None, + moe_capacity_factor: float = 1.0, + moe_expert_model_parallelism: bool = False, + forward_fn: Any = None, + hidden_size: int = None, # Required for parallel forward + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + + # Route tokens to experts + logits, expert_weights, expert_indices = route_tokens( + x, + router_weight, + moe_top_k, + moe_num_experts, + moe_jitter_eps, + moe_normalize_expert_weights, + uniform_expert_assignment, + training, + ) + + # Create router scores for output + router_scores = ( + torch.zeros_like(logits) + .scatter_(1, expert_indices, expert_weights) + .transpose(0, 1) + ) + + in_shape = x.size() + + # Prepare forward function arguments + forward_args = { + "x": x, + "expert_weights": expert_weights, + "top_experts": expert_indices, + "w1": w1, + "w2": w2, + "w1_bias": w1_bias, + "w2_bias": w2_bias, + "gradient_scale": gradient_scale, + "alpha": alpha, + "sort_end_bit": sort_end_bit, + "top_k": moe_top_k, + "num_experts": moe_num_experts, + "expert_parallel_group": expert_parallel_group, + "moe_capacity_factor": moe_capacity_factor, + "moe_expert_model_parallelism": moe_expert_model_parallelism, + } + + # Add hidden_size for parallel forward + if moe_expert_model_parallelism and hidden_size is not None: + forward_args["hidden_size"] = hidden_size + elif moe_expert_model_parallelism and hidden_size is None: + # Infer hidden_size from input shape + forward_args["hidden_size"] = x.shape[-1] + + # Compute expert outputs + x, tokens_per_expert = forward_fn(**forward_args) + + # Save load balancing loss if needed + moe_loss_weight = 0.0 # Can be made configurable + if training and moe_loss_weight > 0: + save_load_balancing_loss((tokens_per_expert, logits)) + + # Restore original shape + x = x.view(in_shape) + + return x, expert_weights, router_scores + + + +class 
MegaBlocksMoeMLP(torch.nn.Module): + + def forward( + self, + x: torch.Tensor, + ) -> torch.Tensor: + router_weight = self.router.weight + moe_top_k = 4 + moe_num_experts = 128 + w1 = self.experts.gate_up_proj.data + w2 = self.experts.down_proj.data + w1_bias = self.experts.gate_up_proj_bias.data + w2_bias = self.experts.down_proj_bias.data + expert_parallel_group = None + + sort_end_bit = max(int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1) + hidden_size = self.experts.hidden_size + + output, expert_weights_out, router_scores = MyReplacementLayer.forward( + x=x, + router_weight=router_weight, + moe_top_k=moe_top_k, + moe_num_experts=moe_num_experts, + moe_jitter_eps=None, + moe_normalize_expert_weights=None, + uniform_expert_assignment=False, + training=False, + w1=w1, + w2=w2, + w1_bias=w1_bias, + w2_bias=w2_bias, + gradient_scale=None, + alpha=1.702, + sort_end_bit=sort_end_bit, + expert_parallel_group=expert_parallel_group, + moe_capacity_factor=1.0, + moe_expert_model_parallelism=False, + forward_fn=forward_once, + hidden_size=hidden_size, + ) + return output, expert_weights_out \ No newline at end of file diff --git a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/all_to_all_benchmark.py b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/all_to_all_benchmark.py index 43d267dbe2570e4a1f59fd561398668cc2bc0920..4c939818edca3345f6344bbc7cef07ffe3cd0181 100644 --- a/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/all_to_all_benchmark.py +++ b/build/torch26-cxx11-cu118-x86_64-linux/megablocks/ops/all_to_all_benchmark.py @@ -8,7 +8,7 @@ import torch.distributed as dist # from megablocks.layers.all_to_all import all_to_all from .. import benchmark_util -from ..layers.all_to_all import all_to_all +from .._layers.all_to_all import all_to_all _ALL_TO_ALL_BENCHMARK = ( (8, 1024), diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/__init__.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/__init__.py index af8d40aa52c706d2aae368cc11ade68cf13f6a47..38075732c6d8fa0e1e6ef493145e1aca3851ae6b 100644 --- a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/__init__.py +++ b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/__init__.py @@ -9,11 +9,13 @@ from .grouped_gemm import backend as gg_backend from .grouped_gemm import ops as gg_ops -from .layers.arguments import Arguments -from .layers.dmoe import ParallelDroplessMLP, dMoE -from .layers.glu import SparseGLU -from .layers.mlp import MLP, SparseMLP -from .layers.moe import MoE, ParallelMLP, get_load_balancing_loss +from ._layers.arguments import Arguments +from ._layers.dmoe import ParallelDroplessMLP, dMoE +from ._layers.glu import SparseGLU +from ._layers.mlp import MLP, SparseMLP +from ._layers.moe import MoE, ParallelMLP, get_load_balancing_loss + +from . 
import layers # This section contains the direct kernel exports (not inlcuded in the original code) def exclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor: @@ -176,6 +178,7 @@ def argsort(x: torch.Tensor, end_bit: int = 32) -> tuple[torch.Tensor, torch.Ten # Export public API __all__ = [ + "MyReplacementLayer", # Direct kernel exports "exclusive_cumsum", "inclusive_cumsum", diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/__init__.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/_layers/__init__.py similarity index 100% rename from build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/__init__.py rename to build/torch26-cxx11-cu124-x86_64-linux/megablocks/_layers/__init__.py diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/activation_fn.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/_layers/activation_fn.py similarity index 100% rename from build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/activation_fn.py rename to build/torch26-cxx11-cu124-x86_64-linux/megablocks/_layers/activation_fn.py diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/all_to_all.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/_layers/all_to_all.py similarity index 100% rename from build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/all_to_all.py rename to build/torch26-cxx11-cu124-x86_64-linux/megablocks/_layers/all_to_all.py diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/arguments.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/_layers/arguments.py similarity index 100% rename from build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/arguments.py rename to build/torch26-cxx11-cu124-x86_64-linux/megablocks/_layers/arguments.py diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/common.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/_layers/common.py similarity index 100% rename from build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/common.py rename to build/torch26-cxx11-cu124-x86_64-linux/megablocks/_layers/common.py diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/dmlp_registry.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/_layers/dmlp_registry.py similarity index 100% rename from build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/dmlp_registry.py rename to build/torch26-cxx11-cu124-x86_64-linux/megablocks/_layers/dmlp_registry.py diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/dmoe.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/_layers/dmoe.py similarity index 100% rename from build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/dmoe.py rename to build/torch26-cxx11-cu124-x86_64-linux/megablocks/_layers/dmoe.py diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/gelu.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/_layers/gelu.py similarity index 100% rename from build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/gelu.py rename to build/torch26-cxx11-cu124-x86_64-linux/megablocks/_layers/gelu.py diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/glu.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/_layers/glu.py similarity index 100% rename from build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/glu.py rename to build/torch26-cxx11-cu124-x86_64-linux/megablocks/_layers/glu.py diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/memory_test.py 
b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/_layers/memory_test.py similarity index 100% rename from build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/memory_test.py rename to build/torch26-cxx11-cu124-x86_64-linux/megablocks/_layers/memory_test.py diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/mlp.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/_layers/mlp.py similarity index 100% rename from build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/mlp.py rename to build/torch26-cxx11-cu124-x86_64-linux/megablocks/_layers/mlp.py diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/moe.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/_layers/moe.py similarity index 100% rename from build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/moe.py rename to build/torch26-cxx11-cu124-x86_64-linux/megablocks/_layers/moe.py diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/mpu.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/_layers/mpu.py similarity index 100% rename from build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/mpu.py rename to build/torch26-cxx11-cu124-x86_64-linux/megablocks/_layers/mpu.py diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/router.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/_layers/router.py similarity index 100% rename from build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/router.py rename to build/torch26-cxx11-cu124-x86_64-linux/megablocks/_layers/router.py diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/sharedexpert_registry.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/_layers/sharedexpert_registry.py similarity index 100% rename from build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers/sharedexpert_registry.py rename to build/torch26-cxx11-cu124-x86_64-linux/megablocks/_layers/sharedexpert_registry.py diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/_megablocks_6756875_dirty.abi3.so b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/_megablocks_6756875_dirty.abi3.so deleted file mode 100755 index 8c2fc7cc25015ef81d73aa2c0231d1af2d6b9c86..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/_megablocks_6756875_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:1419672a07ed370d7107ca54a6b694f234efa8e696644ee4e96c1bf396aff6af -size 11869424 diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/_megablocks_dabb815.abi3.so b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/_megablocks_dabb815.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..e349ce35ea4dd2273fcca6389af58dd2a08bff48 --- /dev/null +++ b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/_megablocks_dabb815.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a3f69e5978b727f08b43112c2321a222719aa824612d452029225a48976dfbb6 +size 11869392 diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/_ops.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/_ops.py index c9a9faa3cee07c8b7016a71704ea26e47c460d57..3cc5ae109015332d831f2357627d00ef67faa25d 100644 --- a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/_ops.py +++ b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _megablocks_6756875_dirty -ops = torch.ops._megablocks_6756875_dirty +from . 
import _megablocks_dabb815 +ops = torch.ops._megablocks_dabb815 def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. """ - return f"_megablocks_6756875_dirty::{op_name}" \ No newline at end of file + return f"_megablocks_dabb815::{op_name}" \ No newline at end of file diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers.py new file mode 100644 index 0000000000000000000000000000000000000000..d3737eddff33082db4c8b31b2c68549b13a911b4 --- /dev/null +++ b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers.py @@ -0,0 +1,566 @@ +import torch +import torch.distributed as dist + +from typing import Optional, Any + +from . import _layers +from . import ops + + +# Set the expert model parallel attributes on a tensor +def set_expert_model_parallel_attributes( + tensor: torch.Tensor, + is_parallel: bool, +): + assert not hasattr(tensor, "expert_model_parallel") + setattr(tensor, "expert_model_parallel", is_parallel) + + +# Get the expert model parallel attributes from a tensor +def expert_sharding_degree( + world_size: int, + moe_num_experts: int, +) -> int: + esd = min(world_size, moe_num_experts) + if (moe_num_experts % esd) != 0: + raise ValueError(f"Cannot shard {moe_num_experts} experts {esd} ways.") + return esd + + +# Calculate the hidden sharding degree based on world size and expert sharding degree +def hidden_sharding_degree( + world_size: int, + moe_num_experts: int, + ffn_hidden_size: int, +) -> int: + esd = expert_sharding_degree(world_size, moe_num_experts) + hsd = world_size // esd + if (ffn_hidden_size % hsd) != 0: + raise ValueError(f"Cannot shard {ffn_hidden_size} features {hsd} ways.") + if (esd * hsd) != world_size: + raise ValueError( + f"Invalid sharding. expert_sharding_degree ({esd}) * hidden_sharding_degree ({hsd}) != world_size ({world_size})." 
+ ) + return hsd + + +# Calculate the number of experts per rank based on world size and expert sharding degree +def experts_per_rank( + moe_num_experts: int, + world_size: int, +) -> int: + return moe_num_experts // expert_sharding_degree(world_size, moe_num_experts) + + +# Calculate the number of features per rank based on ffn hidden size and hidden sharding degree +def features_per_rank( + ffn_hidden_size: int, world_size: int, moe_num_experts: int +) -> int: + return ffn_hidden_size // hidden_sharding_degree( + world_size, moe_num_experts, ffn_hidden_size + ) + + +# Apply jitter to the input tensor +def apply_jitter(x: torch.Tensor, moe_jitter_eps: float) -> torch.Tensor: + low = 1.0 - moe_jitter_eps + high = 1.0 + moe_jitter_eps + noise = torch.rand(x.size(), dtype=x.dtype, device=x.device) + return x * (low + noise * (high - low)) + + +# Compute the top-k scores from the logits +def compute_top_k(scores: torch.Tensor, moe_top_k: int): + if moe_top_k == 1: + return scores.max(dim=-1, keepdim=True) + return torch.topk(scores, moe_top_k, dim=-1) + + +# Route tokens to experts and compute expert weights and indices +def route_tokens( + x: torch.Tensor, + router_weight: torch.Tensor, + moe_top_k: int, + moe_num_experts: int, + moe_jitter_eps: float = None, + moe_normalize_expert_weights: int = None, + uniform_expert_assignment: bool = False, + training: bool = False, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + if training and moe_jitter_eps is not None: + x = apply_jitter(x, moe_jitter_eps) + + x_flat = x.view(-1, x.shape[-1]) + logits = torch.nn.functional.linear(x_flat, router_weight) + expert_weights, expert_indices = compute_top_k(logits, moe_top_k) + expert_weights = expert_weights.softmax(dim=-1) + if moe_normalize_expert_weights is not None: + expert_weights = expert_weights / torch.norm( + expert_weights, + p=moe_normalize_expert_weights, + dim=-1, + keepdim=True, + ) + if uniform_expert_assignment: + expert_indices = _layers.router._uniform_expert_assignment( + expert_indices, + moe_num_experts, + ) + + return logits, expert_weights, expert_indices + + +# Scale the gradient of the weights +def scale_grad( + w: torch.Tensor, + gradient_scale: Optional[float] = None, +) -> torch.Tensor: + if gradient_scale is None: + return w + return _layers.mlp.scale_gradient(w, gradient_scale) + + +# Forward pass for the MLP layer +def mlp_forward(x, w1, w2, w1_bias, w2_bias, gradient_scale=None, alpha: float = 1.702): + # Scale weights + w1 = scale_grad(w1, gradient_scale) + w2 = scale_grad(w2, gradient_scale) + w1_bias = scale_grad(w1_bias, gradient_scale) + w2_bias = scale_grad(w2_bias, gradient_scale) + + # Resolve dtensors + w1 = _layers.mlp.resolve_dtensor(w1) + w2 = _layers.mlp.resolve_dtensor(w2) + w1_bias = _layers.mlp.resolve_dtensor(w1_bias) + w2_bias = _layers.mlp.resolve_dtensor(w2_bias) + + # Forward pass + gate_up = torch.bmm(x, w1) + w1_bias[..., None, :] + gate, up = gate_up.chunk(2, dim=-1) + + glu = gate * torch.sigmoid(gate * alpha) + x = (up + 1) * glu + + return torch.bmm(x, w2) + w2_bias[..., None, :] + + +## START: Load Balancing Loss (unused at the moment) + +# Global variable to store load balancing loss +_LOAD_BALANCING_LOSS = [] + + +def save_load_balancing_loss(loss): + global _LOAD_BALANCING_LOSS + _LOAD_BALANCING_LOSS.append(loss) + + +def get_load_balancing_loss(): + global _LOAD_BALANCING_LOSS + return _LOAD_BALANCING_LOSS + + +def clear_load_balancing_loss(): + global _LOAD_BALANCING_LOSS + _LOAD_BALANCING_LOSS.clear() + + +def 
batched_load_balancing_loss(args): + if args.moe_loss_weight == 0: + return 0.0 + + tokens_per_expert, expert_scores = zip(*get_load_balancing_loss()) + num_layers_per_pipeline_stage = args.num_layers // args.pipeline_model_parallel_size + if args.num_layers_per_virtual_pipeline_stage is not None: + num_layers_per_pipeline_stage = args.num_layers_per_virtual_pipeline_stage + + if len(tokens_per_expert) != num_layers_per_pipeline_stage: + raise ValueError( + f"Expected {num_layers_per_pipeline_stage} token_per_experts " + f"but found {len(tokens_per_expert)}.\nnum_layers = " + f"{args.num_layers}\npipeline_model_parallel_size = " + f"{args.pipeline_model_parallel_size}\n" + "num_layers_per_virtual_pipeline_stage" + f" = {args.num_layers_per_virtual_pipeline_stage}", + ) + if len(expert_scores) != num_layers_per_pipeline_stage: + raise ValueError( + f"Expected {num_layers_per_pipeline_stage} expert_scores " + f"but found {len(tokens_per_expert)}.\nnum_layers = " + f"{args.num_layers}\npipeline_model_parallel_size = " + f"{args.pipeline_model_parallel_size}\n" + "num_layers_per_virtual_pipeline_stage" + f" = {args.num_layers_per_virtual_pipeline_stage}", + ) + + # Verify the shape of the tokens_per_expert and expert_scores tensors. + assert all( + (x.ndim == 1 and x.numel() == args.moe_num_experts for x in tokens_per_expert) + ) + + tokens = expert_scores[0].shape[0] + assert all( + ( + ( + x.ndim == 2 + and x.shape[1] == args.moe_num_experts + and x.shape[0] == tokens + ) + for x in expert_scores + ) + ) + + # Concatenate the contributions of each layer and convert to + # the correct types and formats for the dot product. + expert_scores = torch.cat(expert_scores, dim=1) + if args.moe_lbl_in_fp32: + expert_scores = expert_scores.float() + if tokens != 0: + expert_scores = expert_scores.mean(dim=0) + else: + expert_scores = expert_scores.sum(dim=0) + tokens_per_expert = torch.cat(tokens_per_expert).to(expert_scores.dtype) + + expected_values = num_layers_per_pipeline_stage * args.moe_num_experts + assert tokens_per_expert.numel() == expected_values + assert expert_scores.numel() == expected_values + + # Calculate the total scale across all factors. + # + # loss_weight * num_experts / (num_layers * tokens * top_k) + scale_numerator = args.moe_num_experts * args.moe_loss_weight + scale_denominator = args.num_layers * tokens * args.moe_top_k + scale = scale_numerator / scale_denominator + return scale * torch.dot(tokens_per_expert, expert_scores) + + +## END Load Balancing Loss + + +# Calculate the expert capacity based on tokens, top_k, number of experts, +# expert parallel group, capacity factor, and whether expert model parallelism is used. 
+def expert_capacity( + tokens: int, + top_k: int, + num_experts: int, + expert_parallel_group: int, + moe_capacity_factor: float, + moe_expert_model_parallelism: bool, +) -> int: + world_size = ( + dist.get_world_size(expert_parallel_group) + if moe_expert_model_parallelism + else 1 + ) + + tokens_per_expert = top_k * tokens * world_size / num_experts + return int(moe_capacity_factor * tokens_per_expert) + + +def load_balancing_loss( + tokens_per_expert: torch.Tensor, + expert_scores: torch.Tensor, + top_k: int, + num_experts: int, +): + assert len(expert_scores.size()) == 2 + tokens, num_experts = expert_scores.size() + assert num_experts == num_experts + assert len(tokens_per_expert.size()) == 1 + (num_experts,) = tokens_per_expert.size() + assert num_experts == num_experts + scale = num_experts / (tokens * top_k) + return scale * torch.dot( + tokens_per_expert.to(expert_scores.dtype), + expert_scores.mean(dim=0), + ) + + +def indices_and_bins( + top_expert: torch.Tensor, + sort_end_bit: int, + num_experts: int, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + top_expert = top_expert.int() + + # Ensure contiguous memory layout + top_expert = top_expert.contiguous() + + # Ensure CUB knows which device to use + with torch.cuda.device(top_expert.device): + output = ops.sort(top_expert, sort_end_bit) + bin_ids, indices = output + tokens_per_expert = ops.histogram(top_expert, num_experts) + bins = ops.inclusive_cumsum(tokens_per_expert, 0) + + bins = bins.view(1) if not len(bins.size()) else bins + return indices, bin_ids, bins, tokens_per_expert + + +def expert_capacity_fn( + tokens: int, + top_k: int, + num_experts: int, + expert_parallel_group: torch.distributed.ProcessGroup, + moe_capacity_factor: float = 1.0, + moe_expert_model_parallelism: bool = False, +) -> int: + world_size = ( + dist.get_world_size(expert_parallel_group) + if moe_expert_model_parallelism + else 1 + ) + tokens_per_expert = top_k * tokens * world_size / num_experts + return int(moe_capacity_factor * tokens_per_expert) + + +def permute_and_compute( + x, + tokens_per_expert, + indices, + bin_ids, + expert_weights, + bins, + expert_capacity, + top_k, + w1, + w2, + w1_bias, + w2_bias, + gradient_scale, + alpha, +): + """Permute tokens and compute expert outputs.""" + # Route tokens to experts + x = x.view(-1, x.shape[-1]) + + # Ensure CUB knows which device to use + with torch.cuda.device(x.device): + x = ops.binned_gather(x, indices, bins, expert_capacity, top_k) + + # Expert computation + x = mlp_forward(x, w1, w2, w1_bias, w2_bias, gradient_scale, alpha) + + # Ensure CUB knows which device to use + with torch.cuda.device(x.device): + # Route tokens back + out = ops.binned_scatter(x, indices, expert_weights, bins, top_k) + return out + + +def forward_once( + x: torch.Tensor, + expert_weights: torch.Tensor, + top_experts: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + w1_bias: torch.Tensor, + w2_bias: torch.Tensor, + gradient_scale: Optional[float] = None, + alpha: float = 1.702, + sort_end_bit: int = 0, + top_k: int = 4, + num_experts: int = 128, + expert_parallel_group: int = None, + moe_capacity_factor: float = 1.0, + moe_expert_model_parallelism: bool = False, +): + # x: [sl, bs, hs] + # expert_weights: [sl * bs, top-k] + # top_experts: [sl * bs, top-k] + expert_weights = expert_weights.flatten() + top_experts = top_experts.flatten() + + with torch.no_grad(): + indices, bin_ids, bins, tokens_per_expert = indices_and_bins( + top_experts, sort_end_bit, num_experts + ) + + # Calculate 
expert capacity + sl, bs, _ = x.size() + + expert_capacity = expert_capacity_fn( + sl * bs, + top_k, + num_experts, + expert_parallel_group, + moe_capacity_factor, + moe_expert_model_parallelism, + ) + + if expert_capacity == 0: + expert_capacity = torch.max(tokens_per_expert).item() + + x = permute_and_compute( + x, + tokens_per_expert, + indices, + bin_ids, + expert_weights, + bins, + expert_capacity, + top_k, + w1, + w2, + w1_bias, + w2_bias, + gradient_scale, + alpha, + ) + return x, tokens_per_expert + + +# TODO: replace with functional logic once aligned with ref +def parallel_forward_once( + x: torch.Tensor, + expert_weights: torch.Tensor, + top_experts: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + w1_bias: torch.Tensor, + w2_bias: torch.Tensor, + gradient_scale: Optional[float] = None, + alpha: float = 1.702, + sort_end_bit: int = 0, + top_k: int = 4, + num_experts: int = 128, + expert_parallel_group: torch.distributed.ProcessGroup = None, + moe_capacity_factor: float = 1.0, + moe_expert_model_parallelism: bool = True, + hidden_size: int = 1152, +): + pass + + +class MyReplacementLayer(torch.nn.Module): + # def __init__(self): + # super().__init__() + + def forward( + # self, + x: torch.Tensor, + router_weight: torch.Tensor, + moe_top_k: int, + moe_num_experts: int, + moe_jitter_eps: float = None, + moe_normalize_expert_weights: int = None, + uniform_expert_assignment: bool = False, + training: bool = False, + # + w1: torch.Tensor = None, + w2: torch.Tensor = None, + w1_bias: torch.Tensor = None, + w2_bias: torch.Tensor = None, + gradient_scale: Optional[float] = None, + alpha: float = 1.702, + sort_end_bit: int = 0, + expert_parallel_group: torch.distributed.ProcessGroup = None, + moe_capacity_factor: float = 1.0, + moe_expert_model_parallelism: bool = False, + forward_fn: Any = None, + hidden_size: int = None, # Required for parallel forward + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + + # Route tokens to experts + logits, expert_weights, expert_indices = route_tokens( + x, + router_weight, + moe_top_k, + moe_num_experts, + moe_jitter_eps, + moe_normalize_expert_weights, + uniform_expert_assignment, + training, + ) + + # Create router scores for output + router_scores = ( + torch.zeros_like(logits) + .scatter_(1, expert_indices, expert_weights) + .transpose(0, 1) + ) + + in_shape = x.size() + + # Prepare forward function arguments + forward_args = { + "x": x, + "expert_weights": expert_weights, + "top_experts": expert_indices, + "w1": w1, + "w2": w2, + "w1_bias": w1_bias, + "w2_bias": w2_bias, + "gradient_scale": gradient_scale, + "alpha": alpha, + "sort_end_bit": sort_end_bit, + "top_k": moe_top_k, + "num_experts": moe_num_experts, + "expert_parallel_group": expert_parallel_group, + "moe_capacity_factor": moe_capacity_factor, + "moe_expert_model_parallelism": moe_expert_model_parallelism, + } + + # Add hidden_size for parallel forward + if moe_expert_model_parallelism and hidden_size is not None: + forward_args["hidden_size"] = hidden_size + elif moe_expert_model_parallelism and hidden_size is None: + # Infer hidden_size from input shape + forward_args["hidden_size"] = x.shape[-1] + + # Compute expert outputs + x, tokens_per_expert = forward_fn(**forward_args) + + # Save load balancing loss if needed + moe_loss_weight = 0.0 # Can be made configurable + if training and moe_loss_weight > 0: + save_load_balancing_loss((tokens_per_expert, logits)) + + # Restore original shape + x = x.view(in_shape) + + return x, expert_weights, router_scores + + + +class 
MegaBlocksMoeMLP(torch.nn.Module): + + def forward( + self, + x: torch.Tensor, + ) -> torch.Tensor: + router_weight = self.router.weight + moe_top_k = 4 + moe_num_experts = 128 + w1 = self.experts.gate_up_proj.data + w2 = self.experts.down_proj.data + w1_bias = self.experts.gate_up_proj_bias.data + w2_bias = self.experts.down_proj_bias.data + expert_parallel_group = None + + sort_end_bit = max(int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1) + hidden_size = self.experts.hidden_size + + output, expert_weights_out, router_scores = MyReplacementLayer.forward( + x=x, + router_weight=router_weight, + moe_top_k=moe_top_k, + moe_num_experts=moe_num_experts, + moe_jitter_eps=None, + moe_normalize_expert_weights=None, + uniform_expert_assignment=False, + training=False, + w1=w1, + w2=w2, + w1_bias=w1_bias, + w2_bias=w2_bias, + gradient_scale=None, + alpha=1.702, + sort_end_bit=sort_end_bit, + expert_parallel_group=expert_parallel_group, + moe_capacity_factor=1.0, + moe_expert_model_parallelism=False, + forward_fn=forward_once, + hidden_size=hidden_size, + ) + return output, expert_weights_out \ No newline at end of file diff --git a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/all_to_all_benchmark.py b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/all_to_all_benchmark.py index 43d267dbe2570e4a1f59fd561398668cc2bc0920..4c939818edca3345f6344bbc7cef07ffe3cd0181 100644 --- a/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/all_to_all_benchmark.py +++ b/build/torch26-cxx11-cu124-x86_64-linux/megablocks/ops/all_to_all_benchmark.py @@ -8,7 +8,7 @@ import torch.distributed as dist # from megablocks.layers.all_to_all import all_to_all from .. import benchmark_util -from ..layers.all_to_all import all_to_all +from .._layers.all_to_all import all_to_all _ALL_TO_ALL_BENCHMARK = ( (8, 1024), diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/__init__.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/__init__.py index af8d40aa52c706d2aae368cc11ade68cf13f6a47..38075732c6d8fa0e1e6ef493145e1aca3851ae6b 100644 --- a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/__init__.py +++ b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/__init__.py @@ -9,11 +9,13 @@ from .grouped_gemm import backend as gg_backend from .grouped_gemm import ops as gg_ops -from .layers.arguments import Arguments -from .layers.dmoe import ParallelDroplessMLP, dMoE -from .layers.glu import SparseGLU -from .layers.mlp import MLP, SparseMLP -from .layers.moe import MoE, ParallelMLP, get_load_balancing_loss +from ._layers.arguments import Arguments +from ._layers.dmoe import ParallelDroplessMLP, dMoE +from ._layers.glu import SparseGLU +from ._layers.mlp import MLP, SparseMLP +from ._layers.moe import MoE, ParallelMLP, get_load_balancing_loss + +from . 
import layers # This section contains the direct kernel exports (not inlcuded in the original code) def exclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor: @@ -176,6 +178,7 @@ def argsort(x: torch.Tensor, end_bit: int = 32) -> tuple[torch.Tensor, torch.Ten # Export public API __all__ = [ + "MyReplacementLayer", # Direct kernel exports "exclusive_cumsum", "inclusive_cumsum", diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/layers/__init__.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/_layers/__init__.py similarity index 100% rename from build/torch26-cxx11-cu126-x86_64-linux/megablocks/layers/__init__.py rename to build/torch26-cxx11-cu126-x86_64-linux/megablocks/_layers/__init__.py diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/layers/activation_fn.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/_layers/activation_fn.py similarity index 100% rename from build/torch26-cxx11-cu126-x86_64-linux/megablocks/layers/activation_fn.py rename to build/torch26-cxx11-cu126-x86_64-linux/megablocks/_layers/activation_fn.py diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/layers/all_to_all.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/_layers/all_to_all.py similarity index 100% rename from build/torch26-cxx11-cu126-x86_64-linux/megablocks/layers/all_to_all.py rename to build/torch26-cxx11-cu126-x86_64-linux/megablocks/_layers/all_to_all.py diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/layers/arguments.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/_layers/arguments.py similarity index 100% rename from build/torch26-cxx11-cu126-x86_64-linux/megablocks/layers/arguments.py rename to build/torch26-cxx11-cu126-x86_64-linux/megablocks/_layers/arguments.py diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/layers/common.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/_layers/common.py similarity index 100% rename from build/torch26-cxx11-cu126-x86_64-linux/megablocks/layers/common.py rename to build/torch26-cxx11-cu126-x86_64-linux/megablocks/_layers/common.py diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/layers/dmlp_registry.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/_layers/dmlp_registry.py similarity index 100% rename from build/torch26-cxx11-cu126-x86_64-linux/megablocks/layers/dmlp_registry.py rename to build/torch26-cxx11-cu126-x86_64-linux/megablocks/_layers/dmlp_registry.py diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/layers/dmoe.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/_layers/dmoe.py similarity index 100% rename from build/torch26-cxx11-cu126-x86_64-linux/megablocks/layers/dmoe.py rename to build/torch26-cxx11-cu126-x86_64-linux/megablocks/_layers/dmoe.py diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/layers/gelu.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/_layers/gelu.py similarity index 100% rename from build/torch26-cxx11-cu126-x86_64-linux/megablocks/layers/gelu.py rename to build/torch26-cxx11-cu126-x86_64-linux/megablocks/_layers/gelu.py diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/layers/glu.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/_layers/glu.py similarity index 100% rename from build/torch26-cxx11-cu126-x86_64-linux/megablocks/layers/glu.py rename to build/torch26-cxx11-cu126-x86_64-linux/megablocks/_layers/glu.py diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/layers/memory_test.py 
b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/_layers/memory_test.py similarity index 100% rename from build/torch26-cxx11-cu126-x86_64-linux/megablocks/layers/memory_test.py rename to build/torch26-cxx11-cu126-x86_64-linux/megablocks/_layers/memory_test.py diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/layers/mlp.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/_layers/mlp.py similarity index 100% rename from build/torch26-cxx11-cu126-x86_64-linux/megablocks/layers/mlp.py rename to build/torch26-cxx11-cu126-x86_64-linux/megablocks/_layers/mlp.py diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/layers/moe.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/_layers/moe.py similarity index 100% rename from build/torch26-cxx11-cu126-x86_64-linux/megablocks/layers/moe.py rename to build/torch26-cxx11-cu126-x86_64-linux/megablocks/_layers/moe.py diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/layers/mpu.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/_layers/mpu.py similarity index 100% rename from build/torch26-cxx11-cu126-x86_64-linux/megablocks/layers/mpu.py rename to build/torch26-cxx11-cu126-x86_64-linux/megablocks/_layers/mpu.py diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/layers/router.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/_layers/router.py similarity index 100% rename from build/torch26-cxx11-cu126-x86_64-linux/megablocks/layers/router.py rename to build/torch26-cxx11-cu126-x86_64-linux/megablocks/_layers/router.py diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/layers/sharedexpert_registry.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/_layers/sharedexpert_registry.py similarity index 100% rename from build/torch26-cxx11-cu126-x86_64-linux/megablocks/layers/sharedexpert_registry.py rename to build/torch26-cxx11-cu126-x86_64-linux/megablocks/_layers/sharedexpert_registry.py diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/_megablocks_6756875_dirty.abi3.so b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/_megablocks_6756875_dirty.abi3.so deleted file mode 100755 index ac1dda0ba2b7480d7680927a42a5de7be43762d2..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/_megablocks_6756875_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:746c9b13bc0a8cbf0524a573252beb4a7490dda851e0246aa4e38ac3828c62e7 -size 11931080 diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/_megablocks_dabb815.abi3.so b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/_megablocks_dabb815.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..f6c25f44514edb4a00c0afaed50e5d80b8d07261 --- /dev/null +++ b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/_megablocks_dabb815.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:57540f7b6eae09c2c62826d13dfa2be53eaa37c86206df5914611a3fad9878ba +size 11931048 diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/_ops.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/_ops.py index c9a9faa3cee07c8b7016a71704ea26e47c460d57..3cc5ae109015332d831f2357627d00ef67faa25d 100644 --- a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/_ops.py +++ b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _megablocks_6756875_dirty -ops = torch.ops._megablocks_6756875_dirty +from . 
import _megablocks_dabb815 +ops = torch.ops._megablocks_dabb815 def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. """ - return f"_megablocks_6756875_dirty::{op_name}" \ No newline at end of file + return f"_megablocks_dabb815::{op_name}" \ No newline at end of file diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/layers.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/layers.py new file mode 100644 index 0000000000000000000000000000000000000000..d3737eddff33082db4c8b31b2c68549b13a911b4 --- /dev/null +++ b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/layers.py @@ -0,0 +1,566 @@ +import torch +import torch.distributed as dist + +from typing import Optional, Any + +from . import _layers +from . import ops + + +# Set the expert model parallel attributes on a tensor +def set_expert_model_parallel_attributes( + tensor: torch.Tensor, + is_parallel: bool, +): + assert not hasattr(tensor, "expert_model_parallel") + setattr(tensor, "expert_model_parallel", is_parallel) + + +# Get the expert model parallel attributes from a tensor +def expert_sharding_degree( + world_size: int, + moe_num_experts: int, +) -> int: + esd = min(world_size, moe_num_experts) + if (moe_num_experts % esd) != 0: + raise ValueError(f"Cannot shard {moe_num_experts} experts {esd} ways.") + return esd + + +# Calculate the hidden sharding degree based on world size and expert sharding degree +def hidden_sharding_degree( + world_size: int, + moe_num_experts: int, + ffn_hidden_size: int, +) -> int: + esd = expert_sharding_degree(world_size, moe_num_experts) + hsd = world_size // esd + if (ffn_hidden_size % hsd) != 0: + raise ValueError(f"Cannot shard {ffn_hidden_size} features {hsd} ways.") + if (esd * hsd) != world_size: + raise ValueError( + f"Invalid sharding. expert_sharding_degree ({esd}) * hidden_sharding_degree ({hsd}) != world_size ({world_size})." 
+ ) + return hsd + + +# Calculate the number of experts per rank based on world size and expert sharding degree +def experts_per_rank( + moe_num_experts: int, + world_size: int, +) -> int: + return moe_num_experts // expert_sharding_degree(world_size, moe_num_experts) + + +# Calculate the number of features per rank based on ffn hidden size and hidden sharding degree +def features_per_rank( + ffn_hidden_size: int, world_size: int, moe_num_experts: int +) -> int: + return ffn_hidden_size // hidden_sharding_degree( + world_size, moe_num_experts, ffn_hidden_size + ) + + +# Apply jitter to the input tensor +def apply_jitter(x: torch.Tensor, moe_jitter_eps: float) -> torch.Tensor: + low = 1.0 - moe_jitter_eps + high = 1.0 + moe_jitter_eps + noise = torch.rand(x.size(), dtype=x.dtype, device=x.device) + return x * (low + noise * (high - low)) + + +# Compute the top-k scores from the logits +def compute_top_k(scores: torch.Tensor, moe_top_k: int): + if moe_top_k == 1: + return scores.max(dim=-1, keepdim=True) + return torch.topk(scores, moe_top_k, dim=-1) + + +# Route tokens to experts and compute expert weights and indices +def route_tokens( + x: torch.Tensor, + router_weight: torch.Tensor, + moe_top_k: int, + moe_num_experts: int, + moe_jitter_eps: float = None, + moe_normalize_expert_weights: int = None, + uniform_expert_assignment: bool = False, + training: bool = False, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + if training and moe_jitter_eps is not None: + x = apply_jitter(x, moe_jitter_eps) + + x_flat = x.view(-1, x.shape[-1]) + logits = torch.nn.functional.linear(x_flat, router_weight) + expert_weights, expert_indices = compute_top_k(logits, moe_top_k) + expert_weights = expert_weights.softmax(dim=-1) + if moe_normalize_expert_weights is not None: + expert_weights = expert_weights / torch.norm( + expert_weights, + p=moe_normalize_expert_weights, + dim=-1, + keepdim=True, + ) + if uniform_expert_assignment: + expert_indices = _layers.router._uniform_expert_assignment( + expert_indices, + moe_num_experts, + ) + + return logits, expert_weights, expert_indices + + +# Scale the gradient of the weights +def scale_grad( + w: torch.Tensor, + gradient_scale: Optional[float] = None, +) -> torch.Tensor: + if gradient_scale is None: + return w + return _layers.mlp.scale_gradient(w, gradient_scale) + + +# Forward pass for the MLP layer +def mlp_forward(x, w1, w2, w1_bias, w2_bias, gradient_scale=None, alpha: float = 1.702): + # Scale weights + w1 = scale_grad(w1, gradient_scale) + w2 = scale_grad(w2, gradient_scale) + w1_bias = scale_grad(w1_bias, gradient_scale) + w2_bias = scale_grad(w2_bias, gradient_scale) + + # Resolve dtensors + w1 = _layers.mlp.resolve_dtensor(w1) + w2 = _layers.mlp.resolve_dtensor(w2) + w1_bias = _layers.mlp.resolve_dtensor(w1_bias) + w2_bias = _layers.mlp.resolve_dtensor(w2_bias) + + # Forward pass + gate_up = torch.bmm(x, w1) + w1_bias[..., None, :] + gate, up = gate_up.chunk(2, dim=-1) + + glu = gate * torch.sigmoid(gate * alpha) + x = (up + 1) * glu + + return torch.bmm(x, w2) + w2_bias[..., None, :] + + +## START: Load Balancing Loss (unused at the moment) + +# Global variable to store load balancing loss +_LOAD_BALANCING_LOSS = [] + + +def save_load_balancing_loss(loss): + global _LOAD_BALANCING_LOSS + _LOAD_BALANCING_LOSS.append(loss) + + +def get_load_balancing_loss(): + global _LOAD_BALANCING_LOSS + return _LOAD_BALANCING_LOSS + + +def clear_load_balancing_loss(): + global _LOAD_BALANCING_LOSS + _LOAD_BALANCING_LOSS.clear() + + +def 
batched_load_balancing_loss(args): + if args.moe_loss_weight == 0: + return 0.0 + + tokens_per_expert, expert_scores = zip(*get_load_balancing_loss()) + num_layers_per_pipeline_stage = args.num_layers // args.pipeline_model_parallel_size + if args.num_layers_per_virtual_pipeline_stage is not None: + num_layers_per_pipeline_stage = args.num_layers_per_virtual_pipeline_stage + + if len(tokens_per_expert) != num_layers_per_pipeline_stage: + raise ValueError( + f"Expected {num_layers_per_pipeline_stage} token_per_experts " + f"but found {len(tokens_per_expert)}.\nnum_layers = " + f"{args.num_layers}\npipeline_model_parallel_size = " + f"{args.pipeline_model_parallel_size}\n" + "num_layers_per_virtual_pipeline_stage" + f" = {args.num_layers_per_virtual_pipeline_stage}", + ) + if len(expert_scores) != num_layers_per_pipeline_stage: + raise ValueError( + f"Expected {num_layers_per_pipeline_stage} expert_scores " + f"but found {len(tokens_per_expert)}.\nnum_layers = " + f"{args.num_layers}\npipeline_model_parallel_size = " + f"{args.pipeline_model_parallel_size}\n" + "num_layers_per_virtual_pipeline_stage" + f" = {args.num_layers_per_virtual_pipeline_stage}", + ) + + # Verify the shape of the tokens_per_expert and expert_scores tensors. + assert all( + (x.ndim == 1 and x.numel() == args.moe_num_experts for x in tokens_per_expert) + ) + + tokens = expert_scores[0].shape[0] + assert all( + ( + ( + x.ndim == 2 + and x.shape[1] == args.moe_num_experts + and x.shape[0] == tokens + ) + for x in expert_scores + ) + ) + + # Concatenate the contributions of each layer and convert to + # the correct types and formats for the dot product. + expert_scores = torch.cat(expert_scores, dim=1) + if args.moe_lbl_in_fp32: + expert_scores = expert_scores.float() + if tokens != 0: + expert_scores = expert_scores.mean(dim=0) + else: + expert_scores = expert_scores.sum(dim=0) + tokens_per_expert = torch.cat(tokens_per_expert).to(expert_scores.dtype) + + expected_values = num_layers_per_pipeline_stage * args.moe_num_experts + assert tokens_per_expert.numel() == expected_values + assert expert_scores.numel() == expected_values + + # Calculate the total scale across all factors. + # + # loss_weight * num_experts / (num_layers * tokens * top_k) + scale_numerator = args.moe_num_experts * args.moe_loss_weight + scale_denominator = args.num_layers * tokens * args.moe_top_k + scale = scale_numerator / scale_denominator + return scale * torch.dot(tokens_per_expert, expert_scores) + + +## END Load Balancing Loss + + +# Calculate the expert capacity based on tokens, top_k, number of experts, +# expert parallel group, capacity factor, and whether expert model parallelism is used. 
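# ---------------------------------------------------------------------------
# Editor's sketch (not part of the committed patch): the capacity computation
# implemented by expert_capacity / expert_capacity_fn below reduces to
#     int(moe_capacity_factor * top_k * tokens * world_size / num_experts).
# Minimal standalone check with hypothetical sizes (1024 tokens, top_k=4,
# 128 experts, a single rank, capacity factor 1.0):
_sketch_tokens, _sketch_top_k, _sketch_num_experts = 1024, 4, 128
_sketch_capacity = int(1.0 * _sketch_top_k * _sketch_tokens * 1 / _sketch_num_experts)
assert _sketch_capacity == 32  # each expert gets 32 token slots before dropping
# ---------------------------------------------------------------------------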
+def expert_capacity( + tokens: int, + top_k: int, + num_experts: int, + expert_parallel_group: int, + moe_capacity_factor: float, + moe_expert_model_parallelism: bool, +) -> int: + world_size = ( + dist.get_world_size(expert_parallel_group) + if moe_expert_model_parallelism + else 1 + ) + + tokens_per_expert = top_k * tokens * world_size / num_experts + return int(moe_capacity_factor * tokens_per_expert) + + +def load_balancing_loss( + tokens_per_expert: torch.Tensor, + expert_scores: torch.Tensor, + top_k: int, + num_experts: int, +): + assert len(expert_scores.size()) == 2 + tokens, num_experts = expert_scores.size() + assert num_experts == num_experts + assert len(tokens_per_expert.size()) == 1 + (num_experts,) = tokens_per_expert.size() + assert num_experts == num_experts + scale = num_experts / (tokens * top_k) + return scale * torch.dot( + tokens_per_expert.to(expert_scores.dtype), + expert_scores.mean(dim=0), + ) + + +def indices_and_bins( + top_expert: torch.Tensor, + sort_end_bit: int, + num_experts: int, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + top_expert = top_expert.int() + + # Ensure contiguous memory layout + top_expert = top_expert.contiguous() + + # Ensure CUB knows which device to use + with torch.cuda.device(top_expert.device): + output = ops.sort(top_expert, sort_end_bit) + bin_ids, indices = output + tokens_per_expert = ops.histogram(top_expert, num_experts) + bins = ops.inclusive_cumsum(tokens_per_expert, 0) + + bins = bins.view(1) if not len(bins.size()) else bins + return indices, bin_ids, bins, tokens_per_expert + + +def expert_capacity_fn( + tokens: int, + top_k: int, + num_experts: int, + expert_parallel_group: torch.distributed.ProcessGroup, + moe_capacity_factor: float = 1.0, + moe_expert_model_parallelism: bool = False, +) -> int: + world_size = ( + dist.get_world_size(expert_parallel_group) + if moe_expert_model_parallelism + else 1 + ) + tokens_per_expert = top_k * tokens * world_size / num_experts + return int(moe_capacity_factor * tokens_per_expert) + + +def permute_and_compute( + x, + tokens_per_expert, + indices, + bin_ids, + expert_weights, + bins, + expert_capacity, + top_k, + w1, + w2, + w1_bias, + w2_bias, + gradient_scale, + alpha, +): + """Permute tokens and compute expert outputs.""" + # Route tokens to experts + x = x.view(-1, x.shape[-1]) + + # Ensure CUB knows which device to use + with torch.cuda.device(x.device): + x = ops.binned_gather(x, indices, bins, expert_capacity, top_k) + + # Expert computation + x = mlp_forward(x, w1, w2, w1_bias, w2_bias, gradient_scale, alpha) + + # Ensure CUB knows which device to use + with torch.cuda.device(x.device): + # Route tokens back + out = ops.binned_scatter(x, indices, expert_weights, bins, top_k) + return out + + +def forward_once( + x: torch.Tensor, + expert_weights: torch.Tensor, + top_experts: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + w1_bias: torch.Tensor, + w2_bias: torch.Tensor, + gradient_scale: Optional[float] = None, + alpha: float = 1.702, + sort_end_bit: int = 0, + top_k: int = 4, + num_experts: int = 128, + expert_parallel_group: int = None, + moe_capacity_factor: float = 1.0, + moe_expert_model_parallelism: bool = False, +): + # x: [sl, bs, hs] + # expert_weights: [sl * bs, top-k] + # top_experts: [sl * bs, top-k] + expert_weights = expert_weights.flatten() + top_experts = top_experts.flatten() + + with torch.no_grad(): + indices, bin_ids, bins, tokens_per_expert = indices_and_bins( + top_experts, sort_end_bit, num_experts + ) + + # Calculate 
expert capacity + sl, bs, _ = x.size() + + expert_capacity = expert_capacity_fn( + sl * bs, + top_k, + num_experts, + expert_parallel_group, + moe_capacity_factor, + moe_expert_model_parallelism, + ) + + if expert_capacity == 0: + expert_capacity = torch.max(tokens_per_expert).item() + + x = permute_and_compute( + x, + tokens_per_expert, + indices, + bin_ids, + expert_weights, + bins, + expert_capacity, + top_k, + w1, + w2, + w1_bias, + w2_bias, + gradient_scale, + alpha, + ) + return x, tokens_per_expert + + +# TODO: replace with functional logic once aligned with ref +def parallel_forward_once( + x: torch.Tensor, + expert_weights: torch.Tensor, + top_experts: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + w1_bias: torch.Tensor, + w2_bias: torch.Tensor, + gradient_scale: Optional[float] = None, + alpha: float = 1.702, + sort_end_bit: int = 0, + top_k: int = 4, + num_experts: int = 128, + expert_parallel_group: torch.distributed.ProcessGroup = None, + moe_capacity_factor: float = 1.0, + moe_expert_model_parallelism: bool = True, + hidden_size: int = 1152, +): + pass + + +class MyReplacementLayer(torch.nn.Module): + # def __init__(self): + # super().__init__() + + def forward( + # self, + x: torch.Tensor, + router_weight: torch.Tensor, + moe_top_k: int, + moe_num_experts: int, + moe_jitter_eps: float = None, + moe_normalize_expert_weights: int = None, + uniform_expert_assignment: bool = False, + training: bool = False, + # + w1: torch.Tensor = None, + w2: torch.Tensor = None, + w1_bias: torch.Tensor = None, + w2_bias: torch.Tensor = None, + gradient_scale: Optional[float] = None, + alpha: float = 1.702, + sort_end_bit: int = 0, + expert_parallel_group: torch.distributed.ProcessGroup = None, + moe_capacity_factor: float = 1.0, + moe_expert_model_parallelism: bool = False, + forward_fn: Any = None, + hidden_size: int = None, # Required for parallel forward + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + + # Route tokens to experts + logits, expert_weights, expert_indices = route_tokens( + x, + router_weight, + moe_top_k, + moe_num_experts, + moe_jitter_eps, + moe_normalize_expert_weights, + uniform_expert_assignment, + training, + ) + + # Create router scores for output + router_scores = ( + torch.zeros_like(logits) + .scatter_(1, expert_indices, expert_weights) + .transpose(0, 1) + ) + + in_shape = x.size() + + # Prepare forward function arguments + forward_args = { + "x": x, + "expert_weights": expert_weights, + "top_experts": expert_indices, + "w1": w1, + "w2": w2, + "w1_bias": w1_bias, + "w2_bias": w2_bias, + "gradient_scale": gradient_scale, + "alpha": alpha, + "sort_end_bit": sort_end_bit, + "top_k": moe_top_k, + "num_experts": moe_num_experts, + "expert_parallel_group": expert_parallel_group, + "moe_capacity_factor": moe_capacity_factor, + "moe_expert_model_parallelism": moe_expert_model_parallelism, + } + + # Add hidden_size for parallel forward + if moe_expert_model_parallelism and hidden_size is not None: + forward_args["hidden_size"] = hidden_size + elif moe_expert_model_parallelism and hidden_size is None: + # Infer hidden_size from input shape + forward_args["hidden_size"] = x.shape[-1] + + # Compute expert outputs + x, tokens_per_expert = forward_fn(**forward_args) + + # Save load balancing loss if needed + moe_loss_weight = 0.0 # Can be made configurable + if training and moe_loss_weight > 0: + save_load_balancing_loss((tokens_per_expert, logits)) + + # Restore original shape + x = x.view(in_shape) + + return x, expert_weights, router_scores + + + +class 
MegaBlocksMoeMLP(torch.nn.Module): + + def forward( + self, + x: torch.Tensor, + ) -> torch.Tensor: + router_weight = self.router.weight + moe_top_k = 4 + moe_num_experts = 128 + w1 = self.experts.gate_up_proj.data + w2 = self.experts.down_proj.data + w1_bias = self.experts.gate_up_proj_bias.data + w2_bias = self.experts.down_proj_bias.data + expert_parallel_group = None + + sort_end_bit = max(int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1) + hidden_size = self.experts.hidden_size + + output, expert_weights_out, router_scores = MyReplacementLayer.forward( + x=x, + router_weight=router_weight, + moe_top_k=moe_top_k, + moe_num_experts=moe_num_experts, + moe_jitter_eps=None, + moe_normalize_expert_weights=None, + uniform_expert_assignment=False, + training=False, + w1=w1, + w2=w2, + w1_bias=w1_bias, + w2_bias=w2_bias, + gradient_scale=None, + alpha=1.702, + sort_end_bit=sort_end_bit, + expert_parallel_group=expert_parallel_group, + moe_capacity_factor=1.0, + moe_expert_model_parallelism=False, + forward_fn=forward_once, + hidden_size=hidden_size, + ) + return output, expert_weights_out \ No newline at end of file diff --git a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/all_to_all_benchmark.py b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/all_to_all_benchmark.py index 43d267dbe2570e4a1f59fd561398668cc2bc0920..4c939818edca3345f6344bbc7cef07ffe3cd0181 100644 --- a/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/all_to_all_benchmark.py +++ b/build/torch26-cxx11-cu126-x86_64-linux/megablocks/ops/all_to_all_benchmark.py @@ -8,7 +8,7 @@ import torch.distributed as dist # from megablocks.layers.all_to_all import all_to_all from .. import benchmark_util -from ..layers.all_to_all import all_to_all +from .._layers.all_to_all import all_to_all _ALL_TO_ALL_BENCHMARK = ( (8, 1024), diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/__init__.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/__init__.py index af8d40aa52c706d2aae368cc11ade68cf13f6a47..38075732c6d8fa0e1e6ef493145e1aca3851ae6b 100644 --- a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/__init__.py +++ b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/__init__.py @@ -9,11 +9,13 @@ from .grouped_gemm import backend as gg_backend from .grouped_gemm import ops as gg_ops -from .layers.arguments import Arguments -from .layers.dmoe import ParallelDroplessMLP, dMoE -from .layers.glu import SparseGLU -from .layers.mlp import MLP, SparseMLP -from .layers.moe import MoE, ParallelMLP, get_load_balancing_loss +from ._layers.arguments import Arguments +from ._layers.dmoe import ParallelDroplessMLP, dMoE +from ._layers.glu import SparseGLU +from ._layers.mlp import MLP, SparseMLP +from ._layers.moe import MoE, ParallelMLP, get_load_balancing_loss + +from . 
import layers # This section contains the direct kernel exports (not inlcuded in the original code) def exclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor: @@ -176,6 +178,7 @@ def argsort(x: torch.Tensor, end_bit: int = 32) -> tuple[torch.Tensor, torch.Ten # Export public API __all__ = [ + "MyReplacementLayer", # Direct kernel exports "exclusive_cumsum", "inclusive_cumsum", diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/layers/__init__.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/_layers/__init__.py similarity index 100% rename from build/torch26-cxx98-cu118-x86_64-linux/megablocks/layers/__init__.py rename to build/torch26-cxx98-cu118-x86_64-linux/megablocks/_layers/__init__.py diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/layers/activation_fn.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/_layers/activation_fn.py similarity index 100% rename from build/torch26-cxx98-cu118-x86_64-linux/megablocks/layers/activation_fn.py rename to build/torch26-cxx98-cu118-x86_64-linux/megablocks/_layers/activation_fn.py diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/layers/all_to_all.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/_layers/all_to_all.py similarity index 100% rename from build/torch26-cxx98-cu118-x86_64-linux/megablocks/layers/all_to_all.py rename to build/torch26-cxx98-cu118-x86_64-linux/megablocks/_layers/all_to_all.py diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/layers/arguments.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/_layers/arguments.py similarity index 100% rename from build/torch26-cxx98-cu118-x86_64-linux/megablocks/layers/arguments.py rename to build/torch26-cxx98-cu118-x86_64-linux/megablocks/_layers/arguments.py diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/layers/common.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/_layers/common.py similarity index 100% rename from build/torch26-cxx98-cu118-x86_64-linux/megablocks/layers/common.py rename to build/torch26-cxx98-cu118-x86_64-linux/megablocks/_layers/common.py diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/layers/dmlp_registry.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/_layers/dmlp_registry.py similarity index 100% rename from build/torch26-cxx98-cu118-x86_64-linux/megablocks/layers/dmlp_registry.py rename to build/torch26-cxx98-cu118-x86_64-linux/megablocks/_layers/dmlp_registry.py diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/layers/dmoe.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/_layers/dmoe.py similarity index 100% rename from build/torch26-cxx98-cu118-x86_64-linux/megablocks/layers/dmoe.py rename to build/torch26-cxx98-cu118-x86_64-linux/megablocks/_layers/dmoe.py diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/layers/gelu.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/_layers/gelu.py similarity index 100% rename from build/torch26-cxx98-cu118-x86_64-linux/megablocks/layers/gelu.py rename to build/torch26-cxx98-cu118-x86_64-linux/megablocks/_layers/gelu.py diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/layers/glu.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/_layers/glu.py similarity index 100% rename from build/torch26-cxx98-cu118-x86_64-linux/megablocks/layers/glu.py rename to build/torch26-cxx98-cu118-x86_64-linux/megablocks/_layers/glu.py diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/layers/memory_test.py 
b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/_layers/memory_test.py similarity index 100% rename from build/torch26-cxx98-cu118-x86_64-linux/megablocks/layers/memory_test.py rename to build/torch26-cxx98-cu118-x86_64-linux/megablocks/_layers/memory_test.py diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/layers/mlp.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/_layers/mlp.py similarity index 100% rename from build/torch26-cxx98-cu118-x86_64-linux/megablocks/layers/mlp.py rename to build/torch26-cxx98-cu118-x86_64-linux/megablocks/_layers/mlp.py diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/layers/moe.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/_layers/moe.py similarity index 100% rename from build/torch26-cxx98-cu118-x86_64-linux/megablocks/layers/moe.py rename to build/torch26-cxx98-cu118-x86_64-linux/megablocks/_layers/moe.py diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/layers/mpu.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/_layers/mpu.py similarity index 100% rename from build/torch26-cxx98-cu118-x86_64-linux/megablocks/layers/mpu.py rename to build/torch26-cxx98-cu118-x86_64-linux/megablocks/_layers/mpu.py diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/layers/router.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/_layers/router.py similarity index 100% rename from build/torch26-cxx98-cu118-x86_64-linux/megablocks/layers/router.py rename to build/torch26-cxx98-cu118-x86_64-linux/megablocks/_layers/router.py diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/layers/sharedexpert_registry.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/_layers/sharedexpert_registry.py similarity index 100% rename from build/torch26-cxx98-cu118-x86_64-linux/megablocks/layers/sharedexpert_registry.py rename to build/torch26-cxx98-cu118-x86_64-linux/megablocks/_layers/sharedexpert_registry.py diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/_megablocks_6756875_dirty.abi3.so b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/_megablocks_6756875_dirty.abi3.so deleted file mode 100755 index a8f46771a14d3e54f5cbe5d1ae726dd1888a70c1..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/_megablocks_6756875_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:2fbfef8842d7110f0e496806edee088172cf4fc2e63e1782e8b6a3008bd3d304 -size 10510072 diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/_megablocks_dabb815.abi3.so b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/_megablocks_dabb815.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..b9f0146d0a5676550971b4bf724c153bfd6f200f --- /dev/null +++ b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/_megablocks_dabb815.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:defeb9a48abe98940478c79c5eac52f9fc7c22088abf9d119191559787bb95a9 +size 10510040 diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/_ops.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/_ops.py index c9a9faa3cee07c8b7016a71704ea26e47c460d57..3cc5ae109015332d831f2357627d00ef67faa25d 100644 --- a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/_ops.py +++ b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _megablocks_6756875_dirty -ops = torch.ops._megablocks_6756875_dirty +from . 
import _megablocks_dabb815 +ops = torch.ops._megablocks_dabb815 def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. """ - return f"_megablocks_6756875_dirty::{op_name}" \ No newline at end of file + return f"_megablocks_dabb815::{op_name}" \ No newline at end of file diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/layers.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/layers.py new file mode 100644 index 0000000000000000000000000000000000000000..d3737eddff33082db4c8b31b2c68549b13a911b4 --- /dev/null +++ b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/layers.py @@ -0,0 +1,566 @@ +import torch +import torch.distributed as dist + +from typing import Optional, Any + +from . import _layers +from . import ops + + +# Set the expert model parallel attributes on a tensor +def set_expert_model_parallel_attributes( + tensor: torch.Tensor, + is_parallel: bool, +): + assert not hasattr(tensor, "expert_model_parallel") + setattr(tensor, "expert_model_parallel", is_parallel) + + +# Get the expert model parallel attributes from a tensor +def expert_sharding_degree( + world_size: int, + moe_num_experts: int, +) -> int: + esd = min(world_size, moe_num_experts) + if (moe_num_experts % esd) != 0: + raise ValueError(f"Cannot shard {moe_num_experts} experts {esd} ways.") + return esd + + +# Calculate the hidden sharding degree based on world size and expert sharding degree +def hidden_sharding_degree( + world_size: int, + moe_num_experts: int, + ffn_hidden_size: int, +) -> int: + esd = expert_sharding_degree(world_size, moe_num_experts) + hsd = world_size // esd + if (ffn_hidden_size % hsd) != 0: + raise ValueError(f"Cannot shard {ffn_hidden_size} features {hsd} ways.") + if (esd * hsd) != world_size: + raise ValueError( + f"Invalid sharding. expert_sharding_degree ({esd}) * hidden_sharding_degree ({hsd}) != world_size ({world_size})." 
+ ) + return hsd + + +# Calculate the number of experts per rank based on world size and expert sharding degree +def experts_per_rank( + moe_num_experts: int, + world_size: int, +) -> int: + return moe_num_experts // expert_sharding_degree(world_size, moe_num_experts) + + +# Calculate the number of features per rank based on ffn hidden size and hidden sharding degree +def features_per_rank( + ffn_hidden_size: int, world_size: int, moe_num_experts: int +) -> int: + return ffn_hidden_size // hidden_sharding_degree( + world_size, moe_num_experts, ffn_hidden_size + ) + + +# Apply jitter to the input tensor +def apply_jitter(x: torch.Tensor, moe_jitter_eps: float) -> torch.Tensor: + low = 1.0 - moe_jitter_eps + high = 1.0 + moe_jitter_eps + noise = torch.rand(x.size(), dtype=x.dtype, device=x.device) + return x * (low + noise * (high - low)) + + +# Compute the top-k scores from the logits +def compute_top_k(scores: torch.Tensor, moe_top_k: int): + if moe_top_k == 1: + return scores.max(dim=-1, keepdim=True) + return torch.topk(scores, moe_top_k, dim=-1) + + +# Route tokens to experts and compute expert weights and indices +def route_tokens( + x: torch.Tensor, + router_weight: torch.Tensor, + moe_top_k: int, + moe_num_experts: int, + moe_jitter_eps: float = None, + moe_normalize_expert_weights: int = None, + uniform_expert_assignment: bool = False, + training: bool = False, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + if training and moe_jitter_eps is not None: + x = apply_jitter(x, moe_jitter_eps) + + x_flat = x.view(-1, x.shape[-1]) + logits = torch.nn.functional.linear(x_flat, router_weight) + expert_weights, expert_indices = compute_top_k(logits, moe_top_k) + expert_weights = expert_weights.softmax(dim=-1) + if moe_normalize_expert_weights is not None: + expert_weights = expert_weights / torch.norm( + expert_weights, + p=moe_normalize_expert_weights, + dim=-1, + keepdim=True, + ) + if uniform_expert_assignment: + expert_indices = _layers.router._uniform_expert_assignment( + expert_indices, + moe_num_experts, + ) + + return logits, expert_weights, expert_indices + + +# Scale the gradient of the weights +def scale_grad( + w: torch.Tensor, + gradient_scale: Optional[float] = None, +) -> torch.Tensor: + if gradient_scale is None: + return w + return _layers.mlp.scale_gradient(w, gradient_scale) + + +# Forward pass for the MLP layer +def mlp_forward(x, w1, w2, w1_bias, w2_bias, gradient_scale=None, alpha: float = 1.702): + # Scale weights + w1 = scale_grad(w1, gradient_scale) + w2 = scale_grad(w2, gradient_scale) + w1_bias = scale_grad(w1_bias, gradient_scale) + w2_bias = scale_grad(w2_bias, gradient_scale) + + # Resolve dtensors + w1 = _layers.mlp.resolve_dtensor(w1) + w2 = _layers.mlp.resolve_dtensor(w2) + w1_bias = _layers.mlp.resolve_dtensor(w1_bias) + w2_bias = _layers.mlp.resolve_dtensor(w2_bias) + + # Forward pass + gate_up = torch.bmm(x, w1) + w1_bias[..., None, :] + gate, up = gate_up.chunk(2, dim=-1) + + glu = gate * torch.sigmoid(gate * alpha) + x = (up + 1) * glu + + return torch.bmm(x, w2) + w2_bias[..., None, :] + + +## START: Load Balancing Loss (unused at the moment) + +# Global variable to store load balancing loss +_LOAD_BALANCING_LOSS = [] + + +def save_load_balancing_loss(loss): + global _LOAD_BALANCING_LOSS + _LOAD_BALANCING_LOSS.append(loss) + + +def get_load_balancing_loss(): + global _LOAD_BALANCING_LOSS + return _LOAD_BALANCING_LOSS + + +def clear_load_balancing_loss(): + global _LOAD_BALANCING_LOSS + _LOAD_BALANCING_LOSS.clear() + + +def 
batched_load_balancing_loss(args): + if args.moe_loss_weight == 0: + return 0.0 + + tokens_per_expert, expert_scores = zip(*get_load_balancing_loss()) + num_layers_per_pipeline_stage = args.num_layers // args.pipeline_model_parallel_size + if args.num_layers_per_virtual_pipeline_stage is not None: + num_layers_per_pipeline_stage = args.num_layers_per_virtual_pipeline_stage + + if len(tokens_per_expert) != num_layers_per_pipeline_stage: + raise ValueError( + f"Expected {num_layers_per_pipeline_stage} token_per_experts " + f"but found {len(tokens_per_expert)}.\nnum_layers = " + f"{args.num_layers}\npipeline_model_parallel_size = " + f"{args.pipeline_model_parallel_size}\n" + "num_layers_per_virtual_pipeline_stage" + f" = {args.num_layers_per_virtual_pipeline_stage}", + ) + if len(expert_scores) != num_layers_per_pipeline_stage: + raise ValueError( + f"Expected {num_layers_per_pipeline_stage} expert_scores " + f"but found {len(tokens_per_expert)}.\nnum_layers = " + f"{args.num_layers}\npipeline_model_parallel_size = " + f"{args.pipeline_model_parallel_size}\n" + "num_layers_per_virtual_pipeline_stage" + f" = {args.num_layers_per_virtual_pipeline_stage}", + ) + + # Verify the shape of the tokens_per_expert and expert_scores tensors. + assert all( + (x.ndim == 1 and x.numel() == args.moe_num_experts for x in tokens_per_expert) + ) + + tokens = expert_scores[0].shape[0] + assert all( + ( + ( + x.ndim == 2 + and x.shape[1] == args.moe_num_experts + and x.shape[0] == tokens + ) + for x in expert_scores + ) + ) + + # Concatenate the contributions of each layer and convert to + # the correct types and formats for the dot product. + expert_scores = torch.cat(expert_scores, dim=1) + if args.moe_lbl_in_fp32: + expert_scores = expert_scores.float() + if tokens != 0: + expert_scores = expert_scores.mean(dim=0) + else: + expert_scores = expert_scores.sum(dim=0) + tokens_per_expert = torch.cat(tokens_per_expert).to(expert_scores.dtype) + + expected_values = num_layers_per_pipeline_stage * args.moe_num_experts + assert tokens_per_expert.numel() == expected_values + assert expert_scores.numel() == expected_values + + # Calculate the total scale across all factors. + # + # loss_weight * num_experts / (num_layers * tokens * top_k) + scale_numerator = args.moe_num_experts * args.moe_loss_weight + scale_denominator = args.num_layers * tokens * args.moe_top_k + scale = scale_numerator / scale_denominator + return scale * torch.dot(tokens_per_expert, expert_scores) + + +## END Load Balancing Loss + + +# Calculate the expert capacity based on tokens, top_k, number of experts, +# expert parallel group, capacity factor, and whether expert model parallelism is used. 
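# ---------------------------------------------------------------------------
# Editor's sketch (not part of the committed patch): batched_load_balancing_loss
# above evaluates to
#     moe_loss_weight * num_experts / (num_layers * tokens * top_k)
#         * dot(tokens_per_expert, expert_scores.mean(dim=0)).
# Single-layer sanity check with hypothetical sizes; perfectly balanced routing
# and uniform router probabilities should give back exactly moe_loss_weight:
import torch  # already imported at the top of this module; repeated so the sketch is standalone
_n_exp, _n_tok, _tk, _loss_w = 4, 8, 1, 0.1
_scores = torch.full((_n_tok, _n_exp), 1.0 / _n_exp)    # uniform router probabilities
_counts = torch.full((_n_exp,), _n_tok * _tk / _n_exp)  # tokens routed to each expert
_scale = _loss_w * _n_exp / (1 * _n_tok * _tk)
_lbl = _scale * torch.dot(_counts, _scores.mean(dim=0))
assert abs(_lbl.item() - _loss_w) < 1e-6
# ---------------------------------------------------------------------------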
+def expert_capacity( + tokens: int, + top_k: int, + num_experts: int, + expert_parallel_group: int, + moe_capacity_factor: float, + moe_expert_model_parallelism: bool, +) -> int: + world_size = ( + dist.get_world_size(expert_parallel_group) + if moe_expert_model_parallelism + else 1 + ) + + tokens_per_expert = top_k * tokens * world_size / num_experts + return int(moe_capacity_factor * tokens_per_expert) + + +def load_balancing_loss( + tokens_per_expert: torch.Tensor, + expert_scores: torch.Tensor, + top_k: int, + num_experts: int, +): + assert len(expert_scores.size()) == 2 + tokens, num_experts = expert_scores.size() + assert num_experts == num_experts + assert len(tokens_per_expert.size()) == 1 + (num_experts,) = tokens_per_expert.size() + assert num_experts == num_experts + scale = num_experts / (tokens * top_k) + return scale * torch.dot( + tokens_per_expert.to(expert_scores.dtype), + expert_scores.mean(dim=0), + ) + + +def indices_and_bins( + top_expert: torch.Tensor, + sort_end_bit: int, + num_experts: int, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + top_expert = top_expert.int() + + # Ensure contiguous memory layout + top_expert = top_expert.contiguous() + + # Ensure CUB knows which device to use + with torch.cuda.device(top_expert.device): + output = ops.sort(top_expert, sort_end_bit) + bin_ids, indices = output + tokens_per_expert = ops.histogram(top_expert, num_experts) + bins = ops.inclusive_cumsum(tokens_per_expert, 0) + + bins = bins.view(1) if not len(bins.size()) else bins + return indices, bin_ids, bins, tokens_per_expert + + +def expert_capacity_fn( + tokens: int, + top_k: int, + num_experts: int, + expert_parallel_group: torch.distributed.ProcessGroup, + moe_capacity_factor: float = 1.0, + moe_expert_model_parallelism: bool = False, +) -> int: + world_size = ( + dist.get_world_size(expert_parallel_group) + if moe_expert_model_parallelism + else 1 + ) + tokens_per_expert = top_k * tokens * world_size / num_experts + return int(moe_capacity_factor * tokens_per_expert) + + +def permute_and_compute( + x, + tokens_per_expert, + indices, + bin_ids, + expert_weights, + bins, + expert_capacity, + top_k, + w1, + w2, + w1_bias, + w2_bias, + gradient_scale, + alpha, +): + """Permute tokens and compute expert outputs.""" + # Route tokens to experts + x = x.view(-1, x.shape[-1]) + + # Ensure CUB knows which device to use + with torch.cuda.device(x.device): + x = ops.binned_gather(x, indices, bins, expert_capacity, top_k) + + # Expert computation + x = mlp_forward(x, w1, w2, w1_bias, w2_bias, gradient_scale, alpha) + + # Ensure CUB knows which device to use + with torch.cuda.device(x.device): + # Route tokens back + out = ops.binned_scatter(x, indices, expert_weights, bins, top_k) + return out + + +def forward_once( + x: torch.Tensor, + expert_weights: torch.Tensor, + top_experts: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + w1_bias: torch.Tensor, + w2_bias: torch.Tensor, + gradient_scale: Optional[float] = None, + alpha: float = 1.702, + sort_end_bit: int = 0, + top_k: int = 4, + num_experts: int = 128, + expert_parallel_group: int = None, + moe_capacity_factor: float = 1.0, + moe_expert_model_parallelism: bool = False, +): + # x: [sl, bs, hs] + # expert_weights: [sl * bs, top-k] + # top_experts: [sl * bs, top-k] + expert_weights = expert_weights.flatten() + top_experts = top_experts.flatten() + + with torch.no_grad(): + indices, bin_ids, bins, tokens_per_expert = indices_and_bins( + top_experts, sort_end_bit, num_experts + ) + + # Calculate 
expert capacity + sl, bs, _ = x.size() + + expert_capacity = expert_capacity_fn( + sl * bs, + top_k, + num_experts, + expert_parallel_group, + moe_capacity_factor, + moe_expert_model_parallelism, + ) + + if expert_capacity == 0: + expert_capacity = torch.max(tokens_per_expert).item() + + x = permute_and_compute( + x, + tokens_per_expert, + indices, + bin_ids, + expert_weights, + bins, + expert_capacity, + top_k, + w1, + w2, + w1_bias, + w2_bias, + gradient_scale, + alpha, + ) + return x, tokens_per_expert + + +# TODO: replace with functional logic once aligned with ref +def parallel_forward_once( + x: torch.Tensor, + expert_weights: torch.Tensor, + top_experts: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + w1_bias: torch.Tensor, + w2_bias: torch.Tensor, + gradient_scale: Optional[float] = None, + alpha: float = 1.702, + sort_end_bit: int = 0, + top_k: int = 4, + num_experts: int = 128, + expert_parallel_group: torch.distributed.ProcessGroup = None, + moe_capacity_factor: float = 1.0, + moe_expert_model_parallelism: bool = True, + hidden_size: int = 1152, +): + pass + + +class MyReplacementLayer(torch.nn.Module): + # def __init__(self): + # super().__init__() + + def forward( + # self, + x: torch.Tensor, + router_weight: torch.Tensor, + moe_top_k: int, + moe_num_experts: int, + moe_jitter_eps: float = None, + moe_normalize_expert_weights: int = None, + uniform_expert_assignment: bool = False, + training: bool = False, + # + w1: torch.Tensor = None, + w2: torch.Tensor = None, + w1_bias: torch.Tensor = None, + w2_bias: torch.Tensor = None, + gradient_scale: Optional[float] = None, + alpha: float = 1.702, + sort_end_bit: int = 0, + expert_parallel_group: torch.distributed.ProcessGroup = None, + moe_capacity_factor: float = 1.0, + moe_expert_model_parallelism: bool = False, + forward_fn: Any = None, + hidden_size: int = None, # Required for parallel forward + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + + # Route tokens to experts + logits, expert_weights, expert_indices = route_tokens( + x, + router_weight, + moe_top_k, + moe_num_experts, + moe_jitter_eps, + moe_normalize_expert_weights, + uniform_expert_assignment, + training, + ) + + # Create router scores for output + router_scores = ( + torch.zeros_like(logits) + .scatter_(1, expert_indices, expert_weights) + .transpose(0, 1) + ) + + in_shape = x.size() + + # Prepare forward function arguments + forward_args = { + "x": x, + "expert_weights": expert_weights, + "top_experts": expert_indices, + "w1": w1, + "w2": w2, + "w1_bias": w1_bias, + "w2_bias": w2_bias, + "gradient_scale": gradient_scale, + "alpha": alpha, + "sort_end_bit": sort_end_bit, + "top_k": moe_top_k, + "num_experts": moe_num_experts, + "expert_parallel_group": expert_parallel_group, + "moe_capacity_factor": moe_capacity_factor, + "moe_expert_model_parallelism": moe_expert_model_parallelism, + } + + # Add hidden_size for parallel forward + if moe_expert_model_parallelism and hidden_size is not None: + forward_args["hidden_size"] = hidden_size + elif moe_expert_model_parallelism and hidden_size is None: + # Infer hidden_size from input shape + forward_args["hidden_size"] = x.shape[-1] + + # Compute expert outputs + x, tokens_per_expert = forward_fn(**forward_args) + + # Save load balancing loss if needed + moe_loss_weight = 0.0 # Can be made configurable + if training and moe_loss_weight > 0: + save_load_balancing_loss((tokens_per_expert, logits)) + + # Restore original shape + x = x.view(in_shape) + + return x, expert_weights, router_scores + + + +class 
MegaBlocksMoeMLP(torch.nn.Module): + + def forward( + self, + x: torch.Tensor, + ) -> torch.Tensor: + router_weight = self.router.weight + moe_top_k = 4 + moe_num_experts = 128 + w1 = self.experts.gate_up_proj.data + w2 = self.experts.down_proj.data + w1_bias = self.experts.gate_up_proj_bias.data + w2_bias = self.experts.down_proj_bias.data + expert_parallel_group = None + + sort_end_bit = max(int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1) + hidden_size = self.experts.hidden_size + + output, expert_weights_out, router_scores = MyReplacementLayer.forward( + x=x, + router_weight=router_weight, + moe_top_k=moe_top_k, + moe_num_experts=moe_num_experts, + moe_jitter_eps=None, + moe_normalize_expert_weights=None, + uniform_expert_assignment=False, + training=False, + w1=w1, + w2=w2, + w1_bias=w1_bias, + w2_bias=w2_bias, + gradient_scale=None, + alpha=1.702, + sort_end_bit=sort_end_bit, + expert_parallel_group=expert_parallel_group, + moe_capacity_factor=1.0, + moe_expert_model_parallelism=False, + forward_fn=forward_once, + hidden_size=hidden_size, + ) + return output, expert_weights_out \ No newline at end of file diff --git a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/all_to_all_benchmark.py b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/all_to_all_benchmark.py index 43d267dbe2570e4a1f59fd561398668cc2bc0920..4c939818edca3345f6344bbc7cef07ffe3cd0181 100644 --- a/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/all_to_all_benchmark.py +++ b/build/torch26-cxx98-cu118-x86_64-linux/megablocks/ops/all_to_all_benchmark.py @@ -8,7 +8,7 @@ import torch.distributed as dist # from megablocks.layers.all_to_all import all_to_all from .. import benchmark_util -from ..layers.all_to_all import all_to_all +from .._layers.all_to_all import all_to_all _ALL_TO_ALL_BENCHMARK = ( (8, 1024), diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/__init__.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/__init__.py index af8d40aa52c706d2aae368cc11ade68cf13f6a47..38075732c6d8fa0e1e6ef493145e1aca3851ae6b 100644 --- a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/__init__.py +++ b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/__init__.py @@ -9,11 +9,13 @@ from .grouped_gemm import backend as gg_backend from .grouped_gemm import ops as gg_ops -from .layers.arguments import Arguments -from .layers.dmoe import ParallelDroplessMLP, dMoE -from .layers.glu import SparseGLU -from .layers.mlp import MLP, SparseMLP -from .layers.moe import MoE, ParallelMLP, get_load_balancing_loss +from ._layers.arguments import Arguments +from ._layers.dmoe import ParallelDroplessMLP, dMoE +from ._layers.glu import SparseGLU +from ._layers.mlp import MLP, SparseMLP +from ._layers.moe import MoE, ParallelMLP, get_load_balancing_loss + +from . 
import layers # This section contains the direct kernel exports (not inlcuded in the original code) def exclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor: @@ -176,6 +178,7 @@ def argsort(x: torch.Tensor, end_bit: int = 32) -> tuple[torch.Tensor, torch.Ten # Export public API __all__ = [ + "MyReplacementLayer", # Direct kernel exports "exclusive_cumsum", "inclusive_cumsum", diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/layers/__init__.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/_layers/__init__.py similarity index 100% rename from build/torch26-cxx98-cu124-x86_64-linux/megablocks/layers/__init__.py rename to build/torch26-cxx98-cu124-x86_64-linux/megablocks/_layers/__init__.py diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/layers/activation_fn.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/_layers/activation_fn.py similarity index 100% rename from build/torch26-cxx98-cu124-x86_64-linux/megablocks/layers/activation_fn.py rename to build/torch26-cxx98-cu124-x86_64-linux/megablocks/_layers/activation_fn.py diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/layers/all_to_all.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/_layers/all_to_all.py similarity index 100% rename from build/torch26-cxx98-cu124-x86_64-linux/megablocks/layers/all_to_all.py rename to build/torch26-cxx98-cu124-x86_64-linux/megablocks/_layers/all_to_all.py diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/layers/arguments.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/_layers/arguments.py similarity index 100% rename from build/torch26-cxx98-cu124-x86_64-linux/megablocks/layers/arguments.py rename to build/torch26-cxx98-cu124-x86_64-linux/megablocks/_layers/arguments.py diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/layers/common.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/_layers/common.py similarity index 100% rename from build/torch26-cxx98-cu124-x86_64-linux/megablocks/layers/common.py rename to build/torch26-cxx98-cu124-x86_64-linux/megablocks/_layers/common.py diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/layers/dmlp_registry.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/_layers/dmlp_registry.py similarity index 100% rename from build/torch26-cxx98-cu124-x86_64-linux/megablocks/layers/dmlp_registry.py rename to build/torch26-cxx98-cu124-x86_64-linux/megablocks/_layers/dmlp_registry.py diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/layers/dmoe.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/_layers/dmoe.py similarity index 100% rename from build/torch26-cxx98-cu124-x86_64-linux/megablocks/layers/dmoe.py rename to build/torch26-cxx98-cu124-x86_64-linux/megablocks/_layers/dmoe.py diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/layers/gelu.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/_layers/gelu.py similarity index 100% rename from build/torch26-cxx98-cu124-x86_64-linux/megablocks/layers/gelu.py rename to build/torch26-cxx98-cu124-x86_64-linux/megablocks/_layers/gelu.py diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/layers/glu.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/_layers/glu.py similarity index 100% rename from build/torch26-cxx98-cu124-x86_64-linux/megablocks/layers/glu.py rename to build/torch26-cxx98-cu124-x86_64-linux/megablocks/_layers/glu.py diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/layers/memory_test.py 
b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/_layers/memory_test.py similarity index 100% rename from build/torch26-cxx98-cu124-x86_64-linux/megablocks/layers/memory_test.py rename to build/torch26-cxx98-cu124-x86_64-linux/megablocks/_layers/memory_test.py diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/layers/mlp.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/_layers/mlp.py similarity index 100% rename from build/torch26-cxx98-cu124-x86_64-linux/megablocks/layers/mlp.py rename to build/torch26-cxx98-cu124-x86_64-linux/megablocks/_layers/mlp.py diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/layers/moe.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/_layers/moe.py similarity index 100% rename from build/torch26-cxx98-cu124-x86_64-linux/megablocks/layers/moe.py rename to build/torch26-cxx98-cu124-x86_64-linux/megablocks/_layers/moe.py diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/layers/mpu.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/_layers/mpu.py similarity index 100% rename from build/torch26-cxx98-cu124-x86_64-linux/megablocks/layers/mpu.py rename to build/torch26-cxx98-cu124-x86_64-linux/megablocks/_layers/mpu.py diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/layers/router.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/_layers/router.py similarity index 100% rename from build/torch26-cxx98-cu124-x86_64-linux/megablocks/layers/router.py rename to build/torch26-cxx98-cu124-x86_64-linux/megablocks/_layers/router.py diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/layers/sharedexpert_registry.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/_layers/sharedexpert_registry.py similarity index 100% rename from build/torch26-cxx98-cu124-x86_64-linux/megablocks/layers/sharedexpert_registry.py rename to build/torch26-cxx98-cu124-x86_64-linux/megablocks/_layers/sharedexpert_registry.py diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/_megablocks_6756875_dirty.abi3.so b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/_megablocks_6756875_dirty.abi3.so deleted file mode 100755 index 6c37e4fc559f163d315b5e00282d5a436b9defe6..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/_megablocks_6756875_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:633cb26576e8af1c8cc4a0b9e86287fb1f5c7651452d400e5ea6e9e95c49f1bd -size 11857952 diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/_megablocks_dabb815.abi3.so b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/_megablocks_dabb815.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..9f9e0e7706d6ae183a1e78ed2156b4e74e2a3ffd --- /dev/null +++ b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/_megablocks_dabb815.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b6cc8f982e35bfa07a9121807bacd3c3572d6ecb1495bcb2b6286b967fb20d58 +size 11857920 diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/_ops.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/_ops.py index c9a9faa3cee07c8b7016a71704ea26e47c460d57..3cc5ae109015332d831f2357627d00ef67faa25d 100644 --- a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/_ops.py +++ b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _megablocks_6756875_dirty -ops = torch.ops._megablocks_6756875_dirty +from . 
import _megablocks_dabb815 +ops = torch.ops._megablocks_dabb815 def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. """ - return f"_megablocks_6756875_dirty::{op_name}" \ No newline at end of file + return f"_megablocks_dabb815::{op_name}" \ No newline at end of file diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/layers.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/layers.py new file mode 100644 index 0000000000000000000000000000000000000000..d3737eddff33082db4c8b31b2c68549b13a911b4 --- /dev/null +++ b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/layers.py @@ -0,0 +1,566 @@ +import torch +import torch.distributed as dist + +from typing import Optional, Any + +from . import _layers +from . import ops + + +# Set the expert model parallel attributes on a tensor +def set_expert_model_parallel_attributes( + tensor: torch.Tensor, + is_parallel: bool, +): + assert not hasattr(tensor, "expert_model_parallel") + setattr(tensor, "expert_model_parallel", is_parallel) + + +# Get the expert model parallel attributes from a tensor +def expert_sharding_degree( + world_size: int, + moe_num_experts: int, +) -> int: + esd = min(world_size, moe_num_experts) + if (moe_num_experts % esd) != 0: + raise ValueError(f"Cannot shard {moe_num_experts} experts {esd} ways.") + return esd + + +# Calculate the hidden sharding degree based on world size and expert sharding degree +def hidden_sharding_degree( + world_size: int, + moe_num_experts: int, + ffn_hidden_size: int, +) -> int: + esd = expert_sharding_degree(world_size, moe_num_experts) + hsd = world_size // esd + if (ffn_hidden_size % hsd) != 0: + raise ValueError(f"Cannot shard {ffn_hidden_size} features {hsd} ways.") + if (esd * hsd) != world_size: + raise ValueError( + f"Invalid sharding. expert_sharding_degree ({esd}) * hidden_sharding_degree ({hsd}) != world_size ({world_size})." 
+ ) + return hsd + + +# Calculate the number of experts per rank based on world size and expert sharding degree +def experts_per_rank( + moe_num_experts: int, + world_size: int, +) -> int: + return moe_num_experts // expert_sharding_degree(world_size, moe_num_experts) + + +# Calculate the number of features per rank based on ffn hidden size and hidden sharding degree +def features_per_rank( + ffn_hidden_size: int, world_size: int, moe_num_experts: int +) -> int: + return ffn_hidden_size // hidden_sharding_degree( + world_size, moe_num_experts, ffn_hidden_size + ) + + +# Apply jitter to the input tensor +def apply_jitter(x: torch.Tensor, moe_jitter_eps: float) -> torch.Tensor: + low = 1.0 - moe_jitter_eps + high = 1.0 + moe_jitter_eps + noise = torch.rand(x.size(), dtype=x.dtype, device=x.device) + return x * (low + noise * (high - low)) + + +# Compute the top-k scores from the logits +def compute_top_k(scores: torch.Tensor, moe_top_k: int): + if moe_top_k == 1: + return scores.max(dim=-1, keepdim=True) + return torch.topk(scores, moe_top_k, dim=-1) + + +# Route tokens to experts and compute expert weights and indices +def route_tokens( + x: torch.Tensor, + router_weight: torch.Tensor, + moe_top_k: int, + moe_num_experts: int, + moe_jitter_eps: float = None, + moe_normalize_expert_weights: int = None, + uniform_expert_assignment: bool = False, + training: bool = False, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + if training and moe_jitter_eps is not None: + x = apply_jitter(x, moe_jitter_eps) + + x_flat = x.view(-1, x.shape[-1]) + logits = torch.nn.functional.linear(x_flat, router_weight) + expert_weights, expert_indices = compute_top_k(logits, moe_top_k) + expert_weights = expert_weights.softmax(dim=-1) + if moe_normalize_expert_weights is not None: + expert_weights = expert_weights / torch.norm( + expert_weights, + p=moe_normalize_expert_weights, + dim=-1, + keepdim=True, + ) + if uniform_expert_assignment: + expert_indices = _layers.router._uniform_expert_assignment( + expert_indices, + moe_num_experts, + ) + + return logits, expert_weights, expert_indices + + +# Scale the gradient of the weights +def scale_grad( + w: torch.Tensor, + gradient_scale: Optional[float] = None, +) -> torch.Tensor: + if gradient_scale is None: + return w + return _layers.mlp.scale_gradient(w, gradient_scale) + + +# Forward pass for the MLP layer +def mlp_forward(x, w1, w2, w1_bias, w2_bias, gradient_scale=None, alpha: float = 1.702): + # Scale weights + w1 = scale_grad(w1, gradient_scale) + w2 = scale_grad(w2, gradient_scale) + w1_bias = scale_grad(w1_bias, gradient_scale) + w2_bias = scale_grad(w2_bias, gradient_scale) + + # Resolve dtensors + w1 = _layers.mlp.resolve_dtensor(w1) + w2 = _layers.mlp.resolve_dtensor(w2) + w1_bias = _layers.mlp.resolve_dtensor(w1_bias) + w2_bias = _layers.mlp.resolve_dtensor(w2_bias) + + # Forward pass + gate_up = torch.bmm(x, w1) + w1_bias[..., None, :] + gate, up = gate_up.chunk(2, dim=-1) + + glu = gate * torch.sigmoid(gate * alpha) + x = (up + 1) * glu + + return torch.bmm(x, w2) + w2_bias[..., None, :] + + +## START: Load Balancing Loss (unused at the moment) + +# Global variable to store load balancing loss +_LOAD_BALANCING_LOSS = [] + + +def save_load_balancing_loss(loss): + global _LOAD_BALANCING_LOSS + _LOAD_BALANCING_LOSS.append(loss) + + +def get_load_balancing_loss(): + global _LOAD_BALANCING_LOSS + return _LOAD_BALANCING_LOSS + + +def clear_load_balancing_loss(): + global _LOAD_BALANCING_LOSS + _LOAD_BALANCING_LOSS.clear() + + +def 
batched_load_balancing_loss(args): + if args.moe_loss_weight == 0: + return 0.0 + + tokens_per_expert, expert_scores = zip(*get_load_balancing_loss()) + num_layers_per_pipeline_stage = args.num_layers // args.pipeline_model_parallel_size + if args.num_layers_per_virtual_pipeline_stage is not None: + num_layers_per_pipeline_stage = args.num_layers_per_virtual_pipeline_stage + + if len(tokens_per_expert) != num_layers_per_pipeline_stage: + raise ValueError( + f"Expected {num_layers_per_pipeline_stage} token_per_experts " + f"but found {len(tokens_per_expert)}.\nnum_layers = " + f"{args.num_layers}\npipeline_model_parallel_size = " + f"{args.pipeline_model_parallel_size}\n" + "num_layers_per_virtual_pipeline_stage" + f" = {args.num_layers_per_virtual_pipeline_stage}", + ) + if len(expert_scores) != num_layers_per_pipeline_stage: + raise ValueError( + f"Expected {num_layers_per_pipeline_stage} expert_scores " + f"but found {len(tokens_per_expert)}.\nnum_layers = " + f"{args.num_layers}\npipeline_model_parallel_size = " + f"{args.pipeline_model_parallel_size}\n" + "num_layers_per_virtual_pipeline_stage" + f" = {args.num_layers_per_virtual_pipeline_stage}", + ) + + # Verify the shape of the tokens_per_expert and expert_scores tensors. + assert all( + (x.ndim == 1 and x.numel() == args.moe_num_experts for x in tokens_per_expert) + ) + + tokens = expert_scores[0].shape[0] + assert all( + ( + ( + x.ndim == 2 + and x.shape[1] == args.moe_num_experts + and x.shape[0] == tokens + ) + for x in expert_scores + ) + ) + + # Concatenate the contributions of each layer and convert to + # the correct types and formats for the dot product. + expert_scores = torch.cat(expert_scores, dim=1) + if args.moe_lbl_in_fp32: + expert_scores = expert_scores.float() + if tokens != 0: + expert_scores = expert_scores.mean(dim=0) + else: + expert_scores = expert_scores.sum(dim=0) + tokens_per_expert = torch.cat(tokens_per_expert).to(expert_scores.dtype) + + expected_values = num_layers_per_pipeline_stage * args.moe_num_experts + assert tokens_per_expert.numel() == expected_values + assert expert_scores.numel() == expected_values + + # Calculate the total scale across all factors. + # + # loss_weight * num_experts / (num_layers * tokens * top_k) + scale_numerator = args.moe_num_experts * args.moe_loss_weight + scale_denominator = args.num_layers * tokens * args.moe_top_k + scale = scale_numerator / scale_denominator + return scale * torch.dot(tokens_per_expert, expert_scores) + + +## END Load Balancing Loss + + +# Calculate the expert capacity based on tokens, top_k, number of experts, +# expert parallel group, capacity factor, and whether expert model parallelism is used. 
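# ---------------------------------------------------------------------------
# Editor's sketch (not part of the committed patch): the sharding helpers near
# the top of this module compose as
#     esd = min(world_size, moe_num_experts)
#     hsd = world_size // esd
#     experts_per_rank = moe_num_experts // esd
#     features_per_rank = ffn_hidden_size // hsd
# Quick check with hypothetical sizes (16 ranks, 8 experts, ffn_hidden_size=4096):
_ws, _n_experts, _ffn = 16, 8, 4096
_esd = min(_ws, _n_experts)      # 8-way sharding over the expert dimension
_hsd = _ws // _esd               # the remaining 2-way split lands on the hidden dimension
assert (_esd, _hsd) == (8, 2) and _esd * _hsd == _ws
assert _n_experts // _esd == 1   # one expert per rank
assert _ffn // _hsd == 2048      # each rank holds half of the ffn features
# ---------------------------------------------------------------------------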
+def expert_capacity( + tokens: int, + top_k: int, + num_experts: int, + expert_parallel_group: int, + moe_capacity_factor: float, + moe_expert_model_parallelism: bool, +) -> int: + world_size = ( + dist.get_world_size(expert_parallel_group) + if moe_expert_model_parallelism + else 1 + ) + + tokens_per_expert = top_k * tokens * world_size / num_experts + return int(moe_capacity_factor * tokens_per_expert) + + +def load_balancing_loss( + tokens_per_expert: torch.Tensor, + expert_scores: torch.Tensor, + top_k: int, + num_experts: int, +): + assert len(expert_scores.size()) == 2 + tokens, num_experts = expert_scores.size() + assert num_experts == num_experts + assert len(tokens_per_expert.size()) == 1 + (num_experts,) = tokens_per_expert.size() + assert num_experts == num_experts + scale = num_experts / (tokens * top_k) + return scale * torch.dot( + tokens_per_expert.to(expert_scores.dtype), + expert_scores.mean(dim=0), + ) + + +def indices_and_bins( + top_expert: torch.Tensor, + sort_end_bit: int, + num_experts: int, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + top_expert = top_expert.int() + + # Ensure contiguous memory layout + top_expert = top_expert.contiguous() + + # Ensure CUB knows which device to use + with torch.cuda.device(top_expert.device): + output = ops.sort(top_expert, sort_end_bit) + bin_ids, indices = output + tokens_per_expert = ops.histogram(top_expert, num_experts) + bins = ops.inclusive_cumsum(tokens_per_expert, 0) + + bins = bins.view(1) if not len(bins.size()) else bins + return indices, bin_ids, bins, tokens_per_expert + + +def expert_capacity_fn( + tokens: int, + top_k: int, + num_experts: int, + expert_parallel_group: torch.distributed.ProcessGroup, + moe_capacity_factor: float = 1.0, + moe_expert_model_parallelism: bool = False, +) -> int: + world_size = ( + dist.get_world_size(expert_parallel_group) + if moe_expert_model_parallelism + else 1 + ) + tokens_per_expert = top_k * tokens * world_size / num_experts + return int(moe_capacity_factor * tokens_per_expert) + + +def permute_and_compute( + x, + tokens_per_expert, + indices, + bin_ids, + expert_weights, + bins, + expert_capacity, + top_k, + w1, + w2, + w1_bias, + w2_bias, + gradient_scale, + alpha, +): + """Permute tokens and compute expert outputs.""" + # Route tokens to experts + x = x.view(-1, x.shape[-1]) + + # Ensure CUB knows which device to use + with torch.cuda.device(x.device): + x = ops.binned_gather(x, indices, bins, expert_capacity, top_k) + + # Expert computation + x = mlp_forward(x, w1, w2, w1_bias, w2_bias, gradient_scale, alpha) + + # Ensure CUB knows which device to use + with torch.cuda.device(x.device): + # Route tokens back + out = ops.binned_scatter(x, indices, expert_weights, bins, top_k) + return out + + +def forward_once( + x: torch.Tensor, + expert_weights: torch.Tensor, + top_experts: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + w1_bias: torch.Tensor, + w2_bias: torch.Tensor, + gradient_scale: Optional[float] = None, + alpha: float = 1.702, + sort_end_bit: int = 0, + top_k: int = 4, + num_experts: int = 128, + expert_parallel_group: int = None, + moe_capacity_factor: float = 1.0, + moe_expert_model_parallelism: bool = False, +): + # x: [sl, bs, hs] + # expert_weights: [sl * bs, top-k] + # top_experts: [sl * bs, top-k] + expert_weights = expert_weights.flatten() + top_experts = top_experts.flatten() + + with torch.no_grad(): + indices, bin_ids, bins, tokens_per_expert = indices_and_bins( + top_experts, sort_end_bit, num_experts + ) + + # Calculate 
expert capacity + sl, bs, _ = x.size() + + expert_capacity = expert_capacity_fn( + sl * bs, + top_k, + num_experts, + expert_parallel_group, + moe_capacity_factor, + moe_expert_model_parallelism, + ) + + if expert_capacity == 0: + expert_capacity = torch.max(tokens_per_expert).item() + + x = permute_and_compute( + x, + tokens_per_expert, + indices, + bin_ids, + expert_weights, + bins, + expert_capacity, + top_k, + w1, + w2, + w1_bias, + w2_bias, + gradient_scale, + alpha, + ) + return x, tokens_per_expert + + +# TODO: replace with functional logic once aligned with ref +def parallel_forward_once( + x: torch.Tensor, + expert_weights: torch.Tensor, + top_experts: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + w1_bias: torch.Tensor, + w2_bias: torch.Tensor, + gradient_scale: Optional[float] = None, + alpha: float = 1.702, + sort_end_bit: int = 0, + top_k: int = 4, + num_experts: int = 128, + expert_parallel_group: torch.distributed.ProcessGroup = None, + moe_capacity_factor: float = 1.0, + moe_expert_model_parallelism: bool = True, + hidden_size: int = 1152, +): + pass + + +class MyReplacementLayer(torch.nn.Module): + # def __init__(self): + # super().__init__() + + def forward( + # self, + x: torch.Tensor, + router_weight: torch.Tensor, + moe_top_k: int, + moe_num_experts: int, + moe_jitter_eps: float = None, + moe_normalize_expert_weights: int = None, + uniform_expert_assignment: bool = False, + training: bool = False, + # + w1: torch.Tensor = None, + w2: torch.Tensor = None, + w1_bias: torch.Tensor = None, + w2_bias: torch.Tensor = None, + gradient_scale: Optional[float] = None, + alpha: float = 1.702, + sort_end_bit: int = 0, + expert_parallel_group: torch.distributed.ProcessGroup = None, + moe_capacity_factor: float = 1.0, + moe_expert_model_parallelism: bool = False, + forward_fn: Any = None, + hidden_size: int = None, # Required for parallel forward + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + + # Route tokens to experts + logits, expert_weights, expert_indices = route_tokens( + x, + router_weight, + moe_top_k, + moe_num_experts, + moe_jitter_eps, + moe_normalize_expert_weights, + uniform_expert_assignment, + training, + ) + + # Create router scores for output + router_scores = ( + torch.zeros_like(logits) + .scatter_(1, expert_indices, expert_weights) + .transpose(0, 1) + ) + + in_shape = x.size() + + # Prepare forward function arguments + forward_args = { + "x": x, + "expert_weights": expert_weights, + "top_experts": expert_indices, + "w1": w1, + "w2": w2, + "w1_bias": w1_bias, + "w2_bias": w2_bias, + "gradient_scale": gradient_scale, + "alpha": alpha, + "sort_end_bit": sort_end_bit, + "top_k": moe_top_k, + "num_experts": moe_num_experts, + "expert_parallel_group": expert_parallel_group, + "moe_capacity_factor": moe_capacity_factor, + "moe_expert_model_parallelism": moe_expert_model_parallelism, + } + + # Add hidden_size for parallel forward + if moe_expert_model_parallelism and hidden_size is not None: + forward_args["hidden_size"] = hidden_size + elif moe_expert_model_parallelism and hidden_size is None: + # Infer hidden_size from input shape + forward_args["hidden_size"] = x.shape[-1] + + # Compute expert outputs + x, tokens_per_expert = forward_fn(**forward_args) + + # Save load balancing loss if needed + moe_loss_weight = 0.0 # Can be made configurable + if training and moe_loss_weight > 0: + save_load_balancing_loss((tokens_per_expert, logits)) + + # Restore original shape + x = x.view(in_shape) + + return x, expert_weights, router_scores + + + +class 
MegaBlocksMoeMLP(torch.nn.Module): + + def forward( + self, + x: torch.Tensor, + ) -> torch.Tensor: + router_weight = self.router.weight + moe_top_k = 4 + moe_num_experts = 128 + w1 = self.experts.gate_up_proj.data + w2 = self.experts.down_proj.data + w1_bias = self.experts.gate_up_proj_bias.data + w2_bias = self.experts.down_proj_bias.data + expert_parallel_group = None + + sort_end_bit = max(int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1) + hidden_size = self.experts.hidden_size + + output, expert_weights_out, router_scores = MyReplacementLayer.forward( + x=x, + router_weight=router_weight, + moe_top_k=moe_top_k, + moe_num_experts=moe_num_experts, + moe_jitter_eps=None, + moe_normalize_expert_weights=None, + uniform_expert_assignment=False, + training=False, + w1=w1, + w2=w2, + w1_bias=w1_bias, + w2_bias=w2_bias, + gradient_scale=None, + alpha=1.702, + sort_end_bit=sort_end_bit, + expert_parallel_group=expert_parallel_group, + moe_capacity_factor=1.0, + moe_expert_model_parallelism=False, + forward_fn=forward_once, + hidden_size=hidden_size, + ) + return output, expert_weights_out \ No newline at end of file diff --git a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/all_to_all_benchmark.py b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/all_to_all_benchmark.py index 43d267dbe2570e4a1f59fd561398668cc2bc0920..4c939818edca3345f6344bbc7cef07ffe3cd0181 100644 --- a/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/all_to_all_benchmark.py +++ b/build/torch26-cxx98-cu124-x86_64-linux/megablocks/ops/all_to_all_benchmark.py @@ -8,7 +8,7 @@ import torch.distributed as dist # from megablocks.layers.all_to_all import all_to_all from .. import benchmark_util -from ..layers.all_to_all import all_to_all +from .._layers.all_to_all import all_to_all _ALL_TO_ALL_BENCHMARK = ( (8, 1024), diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/__init__.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/__init__.py index af8d40aa52c706d2aae368cc11ade68cf13f6a47..38075732c6d8fa0e1e6ef493145e1aca3851ae6b 100644 --- a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/__init__.py +++ b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/__init__.py @@ -9,11 +9,13 @@ from .grouped_gemm import backend as gg_backend from .grouped_gemm import ops as gg_ops -from .layers.arguments import Arguments -from .layers.dmoe import ParallelDroplessMLP, dMoE -from .layers.glu import SparseGLU -from .layers.mlp import MLP, SparseMLP -from .layers.moe import MoE, ParallelMLP, get_load_balancing_loss +from ._layers.arguments import Arguments +from ._layers.dmoe import ParallelDroplessMLP, dMoE +from ._layers.glu import SparseGLU +from ._layers.mlp import MLP, SparseMLP +from ._layers.moe import MoE, ParallelMLP, get_load_balancing_loss + +from . 
import layers # This section contains the direct kernel exports (not inlcuded in the original code) def exclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor: @@ -176,6 +178,7 @@ def argsort(x: torch.Tensor, end_bit: int = 32) -> tuple[torch.Tensor, torch.Ten # Export public API __all__ = [ + "MyReplacementLayer", # Direct kernel exports "exclusive_cumsum", "inclusive_cumsum", diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/layers/__init__.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/_layers/__init__.py similarity index 100% rename from build/torch26-cxx98-cu126-x86_64-linux/megablocks/layers/__init__.py rename to build/torch26-cxx98-cu126-x86_64-linux/megablocks/_layers/__init__.py diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/layers/activation_fn.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/_layers/activation_fn.py similarity index 100% rename from build/torch26-cxx98-cu126-x86_64-linux/megablocks/layers/activation_fn.py rename to build/torch26-cxx98-cu126-x86_64-linux/megablocks/_layers/activation_fn.py diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/layers/all_to_all.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/_layers/all_to_all.py similarity index 100% rename from build/torch26-cxx98-cu126-x86_64-linux/megablocks/layers/all_to_all.py rename to build/torch26-cxx98-cu126-x86_64-linux/megablocks/_layers/all_to_all.py diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/layers/arguments.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/_layers/arguments.py similarity index 100% rename from build/torch26-cxx98-cu126-x86_64-linux/megablocks/layers/arguments.py rename to build/torch26-cxx98-cu126-x86_64-linux/megablocks/_layers/arguments.py diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/layers/common.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/_layers/common.py similarity index 100% rename from build/torch26-cxx98-cu126-x86_64-linux/megablocks/layers/common.py rename to build/torch26-cxx98-cu126-x86_64-linux/megablocks/_layers/common.py diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/layers/dmlp_registry.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/_layers/dmlp_registry.py similarity index 100% rename from build/torch26-cxx98-cu126-x86_64-linux/megablocks/layers/dmlp_registry.py rename to build/torch26-cxx98-cu126-x86_64-linux/megablocks/_layers/dmlp_registry.py diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/layers/dmoe.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/_layers/dmoe.py similarity index 100% rename from build/torch26-cxx98-cu126-x86_64-linux/megablocks/layers/dmoe.py rename to build/torch26-cxx98-cu126-x86_64-linux/megablocks/_layers/dmoe.py diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/layers/gelu.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/_layers/gelu.py similarity index 100% rename from build/torch26-cxx98-cu126-x86_64-linux/megablocks/layers/gelu.py rename to build/torch26-cxx98-cu126-x86_64-linux/megablocks/_layers/gelu.py diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/layers/glu.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/_layers/glu.py similarity index 100% rename from build/torch26-cxx98-cu126-x86_64-linux/megablocks/layers/glu.py rename to build/torch26-cxx98-cu126-x86_64-linux/megablocks/_layers/glu.py diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/layers/memory_test.py 
b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/_layers/memory_test.py similarity index 100% rename from build/torch26-cxx98-cu126-x86_64-linux/megablocks/layers/memory_test.py rename to build/torch26-cxx98-cu126-x86_64-linux/megablocks/_layers/memory_test.py diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/layers/mlp.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/_layers/mlp.py similarity index 100% rename from build/torch26-cxx98-cu126-x86_64-linux/megablocks/layers/mlp.py rename to build/torch26-cxx98-cu126-x86_64-linux/megablocks/_layers/mlp.py diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/layers/moe.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/_layers/moe.py similarity index 100% rename from build/torch26-cxx98-cu126-x86_64-linux/megablocks/layers/moe.py rename to build/torch26-cxx98-cu126-x86_64-linux/megablocks/_layers/moe.py diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/layers/mpu.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/_layers/mpu.py similarity index 100% rename from build/torch26-cxx98-cu126-x86_64-linux/megablocks/layers/mpu.py rename to build/torch26-cxx98-cu126-x86_64-linux/megablocks/_layers/mpu.py diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/layers/router.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/_layers/router.py similarity index 100% rename from build/torch26-cxx98-cu126-x86_64-linux/megablocks/layers/router.py rename to build/torch26-cxx98-cu126-x86_64-linux/megablocks/_layers/router.py diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/layers/sharedexpert_registry.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/_layers/sharedexpert_registry.py similarity index 100% rename from build/torch26-cxx98-cu126-x86_64-linux/megablocks/layers/sharedexpert_registry.py rename to build/torch26-cxx98-cu126-x86_64-linux/megablocks/_layers/sharedexpert_registry.py diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/_megablocks_6756875_dirty.abi3.so b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/_megablocks_6756875_dirty.abi3.so deleted file mode 100755 index a6ddf9c7634f8b8889d37a6f6e774c1ae01f3291..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/_megablocks_6756875_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:69198313a732d74196039b0c350fe1d4e14448abdf9b4c2a6215eabd9b0e171c -size 11923704 diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/_megablocks_dabb815.abi3.so b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/_megablocks_dabb815.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..60b4dab63ebafbe4c9318da846c2d6814b91376b --- /dev/null +++ b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/_megablocks_dabb815.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:be9d5e5df42fd6d0db62397eae9c462b9775e952ce7f71fb687c3ea75dfe6a74 +size 11923672 diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/_ops.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/_ops.py index c9a9faa3cee07c8b7016a71704ea26e47c460d57..3cc5ae109015332d831f2357627d00ef67faa25d 100644 --- a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/_ops.py +++ b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _megablocks_6756875_dirty -ops = torch.ops._megablocks_6756875_dirty +from . 
import _megablocks_dabb815 +ops = torch.ops._megablocks_dabb815 def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. """ - return f"_megablocks_6756875_dirty::{op_name}" \ No newline at end of file + return f"_megablocks_dabb815::{op_name}" \ No newline at end of file diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/layers.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/layers.py new file mode 100644 index 0000000000000000000000000000000000000000..d3737eddff33082db4c8b31b2c68549b13a911b4 --- /dev/null +++ b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/layers.py @@ -0,0 +1,566 @@ +import torch +import torch.distributed as dist + +from typing import Optional, Any + +from . import _layers +from . import ops + + +# Set the expert model parallel attributes on a tensor +def set_expert_model_parallel_attributes( + tensor: torch.Tensor, + is_parallel: bool, +): + assert not hasattr(tensor, "expert_model_parallel") + setattr(tensor, "expert_model_parallel", is_parallel) + + +# Get the expert model parallel attributes from a tensor +def expert_sharding_degree( + world_size: int, + moe_num_experts: int, +) -> int: + esd = min(world_size, moe_num_experts) + if (moe_num_experts % esd) != 0: + raise ValueError(f"Cannot shard {moe_num_experts} experts {esd} ways.") + return esd + + +# Calculate the hidden sharding degree based on world size and expert sharding degree +def hidden_sharding_degree( + world_size: int, + moe_num_experts: int, + ffn_hidden_size: int, +) -> int: + esd = expert_sharding_degree(world_size, moe_num_experts) + hsd = world_size // esd + if (ffn_hidden_size % hsd) != 0: + raise ValueError(f"Cannot shard {ffn_hidden_size} features {hsd} ways.") + if (esd * hsd) != world_size: + raise ValueError( + f"Invalid sharding. expert_sharding_degree ({esd}) * hidden_sharding_degree ({hsd}) != world_size ({world_size})." 
+ ) + return hsd + + +# Calculate the number of experts per rank based on world size and expert sharding degree +def experts_per_rank( + moe_num_experts: int, + world_size: int, +) -> int: + return moe_num_experts // expert_sharding_degree(world_size, moe_num_experts) + + +# Calculate the number of features per rank based on ffn hidden size and hidden sharding degree +def features_per_rank( + ffn_hidden_size: int, world_size: int, moe_num_experts: int +) -> int: + return ffn_hidden_size // hidden_sharding_degree( + world_size, moe_num_experts, ffn_hidden_size + ) + + +# Apply jitter to the input tensor +def apply_jitter(x: torch.Tensor, moe_jitter_eps: float) -> torch.Tensor: + low = 1.0 - moe_jitter_eps + high = 1.0 + moe_jitter_eps + noise = torch.rand(x.size(), dtype=x.dtype, device=x.device) + return x * (low + noise * (high - low)) + + +# Compute the top-k scores from the logits +def compute_top_k(scores: torch.Tensor, moe_top_k: int): + if moe_top_k == 1: + return scores.max(dim=-1, keepdim=True) + return torch.topk(scores, moe_top_k, dim=-1) + + +# Route tokens to experts and compute expert weights and indices +def route_tokens( + x: torch.Tensor, + router_weight: torch.Tensor, + moe_top_k: int, + moe_num_experts: int, + moe_jitter_eps: float = None, + moe_normalize_expert_weights: int = None, + uniform_expert_assignment: bool = False, + training: bool = False, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + if training and moe_jitter_eps is not None: + x = apply_jitter(x, moe_jitter_eps) + + x_flat = x.view(-1, x.shape[-1]) + logits = torch.nn.functional.linear(x_flat, router_weight) + expert_weights, expert_indices = compute_top_k(logits, moe_top_k) + expert_weights = expert_weights.softmax(dim=-1) + if moe_normalize_expert_weights is not None: + expert_weights = expert_weights / torch.norm( + expert_weights, + p=moe_normalize_expert_weights, + dim=-1, + keepdim=True, + ) + if uniform_expert_assignment: + expert_indices = _layers.router._uniform_expert_assignment( + expert_indices, + moe_num_experts, + ) + + return logits, expert_weights, expert_indices + + +# Scale the gradient of the weights +def scale_grad( + w: torch.Tensor, + gradient_scale: Optional[float] = None, +) -> torch.Tensor: + if gradient_scale is None: + return w + return _layers.mlp.scale_gradient(w, gradient_scale) + + +# Forward pass for the MLP layer +def mlp_forward(x, w1, w2, w1_bias, w2_bias, gradient_scale=None, alpha: float = 1.702): + # Scale weights + w1 = scale_grad(w1, gradient_scale) + w2 = scale_grad(w2, gradient_scale) + w1_bias = scale_grad(w1_bias, gradient_scale) + w2_bias = scale_grad(w2_bias, gradient_scale) + + # Resolve dtensors + w1 = _layers.mlp.resolve_dtensor(w1) + w2 = _layers.mlp.resolve_dtensor(w2) + w1_bias = _layers.mlp.resolve_dtensor(w1_bias) + w2_bias = _layers.mlp.resolve_dtensor(w2_bias) + + # Forward pass + gate_up = torch.bmm(x, w1) + w1_bias[..., None, :] + gate, up = gate_up.chunk(2, dim=-1) + + glu = gate * torch.sigmoid(gate * alpha) + x = (up + 1) * glu + + return torch.bmm(x, w2) + w2_bias[..., None, :] + + +## START: Load Balancing Loss (unused at the moment) + +# Global variable to store load balancing loss +_LOAD_BALANCING_LOSS = [] + + +def save_load_balancing_loss(loss): + global _LOAD_BALANCING_LOSS + _LOAD_BALANCING_LOSS.append(loss) + + +def get_load_balancing_loss(): + global _LOAD_BALANCING_LOSS + return _LOAD_BALANCING_LOSS + + +def clear_load_balancing_loss(): + global _LOAD_BALANCING_LOSS + _LOAD_BALANCING_LOSS.clear() + + +def 
batched_load_balancing_loss(args): + if args.moe_loss_weight == 0: + return 0.0 + + tokens_per_expert, expert_scores = zip(*get_load_balancing_loss()) + num_layers_per_pipeline_stage = args.num_layers // args.pipeline_model_parallel_size + if args.num_layers_per_virtual_pipeline_stage is not None: + num_layers_per_pipeline_stage = args.num_layers_per_virtual_pipeline_stage + + if len(tokens_per_expert) != num_layers_per_pipeline_stage: + raise ValueError( + f"Expected {num_layers_per_pipeline_stage} token_per_experts " + f"but found {len(tokens_per_expert)}.\nnum_layers = " + f"{args.num_layers}\npipeline_model_parallel_size = " + f"{args.pipeline_model_parallel_size}\n" + "num_layers_per_virtual_pipeline_stage" + f" = {args.num_layers_per_virtual_pipeline_stage}", + ) + if len(expert_scores) != num_layers_per_pipeline_stage: + raise ValueError( + f"Expected {num_layers_per_pipeline_stage} expert_scores " + f"but found {len(tokens_per_expert)}.\nnum_layers = " + f"{args.num_layers}\npipeline_model_parallel_size = " + f"{args.pipeline_model_parallel_size}\n" + "num_layers_per_virtual_pipeline_stage" + f" = {args.num_layers_per_virtual_pipeline_stage}", + ) + + # Verify the shape of the tokens_per_expert and expert_scores tensors. + assert all( + (x.ndim == 1 and x.numel() == args.moe_num_experts for x in tokens_per_expert) + ) + + tokens = expert_scores[0].shape[0] + assert all( + ( + ( + x.ndim == 2 + and x.shape[1] == args.moe_num_experts + and x.shape[0] == tokens + ) + for x in expert_scores + ) + ) + + # Concatenate the contributions of each layer and convert to + # the correct types and formats for the dot product. + expert_scores = torch.cat(expert_scores, dim=1) + if args.moe_lbl_in_fp32: + expert_scores = expert_scores.float() + if tokens != 0: + expert_scores = expert_scores.mean(dim=0) + else: + expert_scores = expert_scores.sum(dim=0) + tokens_per_expert = torch.cat(tokens_per_expert).to(expert_scores.dtype) + + expected_values = num_layers_per_pipeline_stage * args.moe_num_experts + assert tokens_per_expert.numel() == expected_values + assert expert_scores.numel() == expected_values + + # Calculate the total scale across all factors. + # + # loss_weight * num_experts / (num_layers * tokens * top_k) + scale_numerator = args.moe_num_experts * args.moe_loss_weight + scale_denominator = args.num_layers * tokens * args.moe_top_k + scale = scale_numerator / scale_denominator + return scale * torch.dot(tokens_per_expert, expert_scores) + + +## END Load Balancing Loss + + +# Calculate the expert capacity based on tokens, top_k, number of experts, +# expert parallel group, capacity factor, and whether expert model parallelism is used. 
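The capacity formula used by the expert_capacity / expert_capacity_fn helpers that follow is plain arithmetic; here is a small worked example with hypothetical sizes (not values taken from this diff).

# Hypothetical sizes: 4096 tokens (sl * bs), top_k = 4, 128 experts, no expert parallelism.
tokens, top_k, num_experts = 4096, 4, 128
world_size = 1                      # dist.get_world_size(group) would replace this when parallel
moe_capacity_factor = 1.0

tokens_per_expert = top_k * tokens * world_size / num_experts   # 128.0
expert_capacity = int(moe_capacity_factor * tokens_per_expert)  # 128 slots per expert
print(expert_capacity)

Note that forward_once falls back to max(tokens_per_expert) whenever this comes out as 0 (for example with a very small batch), so tokens are not silently dropped in that case.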
+def expert_capacity( + tokens: int, + top_k: int, + num_experts: int, + expert_parallel_group: int, + moe_capacity_factor: float, + moe_expert_model_parallelism: bool, +) -> int: + world_size = ( + dist.get_world_size(expert_parallel_group) + if moe_expert_model_parallelism + else 1 + ) + + tokens_per_expert = top_k * tokens * world_size / num_experts + return int(moe_capacity_factor * tokens_per_expert) + + +def load_balancing_loss( + tokens_per_expert: torch.Tensor, + expert_scores: torch.Tensor, + top_k: int, + num_experts: int, +): + assert len(expert_scores.size()) == 2 + tokens, num_experts = expert_scores.size() + assert num_experts == num_experts + assert len(tokens_per_expert.size()) == 1 + (num_experts,) = tokens_per_expert.size() + assert num_experts == num_experts + scale = num_experts / (tokens * top_k) + return scale * torch.dot( + tokens_per_expert.to(expert_scores.dtype), + expert_scores.mean(dim=0), + ) + + +def indices_and_bins( + top_expert: torch.Tensor, + sort_end_bit: int, + num_experts: int, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + top_expert = top_expert.int() + + # Ensure contiguous memory layout + top_expert = top_expert.contiguous() + + # Ensure CUB knows which device to use + with torch.cuda.device(top_expert.device): + output = ops.sort(top_expert, sort_end_bit) + bin_ids, indices = output + tokens_per_expert = ops.histogram(top_expert, num_experts) + bins = ops.inclusive_cumsum(tokens_per_expert, 0) + + bins = bins.view(1) if not len(bins.size()) else bins + return indices, bin_ids, bins, tokens_per_expert + + +def expert_capacity_fn( + tokens: int, + top_k: int, + num_experts: int, + expert_parallel_group: torch.distributed.ProcessGroup, + moe_capacity_factor: float = 1.0, + moe_expert_model_parallelism: bool = False, +) -> int: + world_size = ( + dist.get_world_size(expert_parallel_group) + if moe_expert_model_parallelism + else 1 + ) + tokens_per_expert = top_k * tokens * world_size / num_experts + return int(moe_capacity_factor * tokens_per_expert) + + +def permute_and_compute( + x, + tokens_per_expert, + indices, + bin_ids, + expert_weights, + bins, + expert_capacity, + top_k, + w1, + w2, + w1_bias, + w2_bias, + gradient_scale, + alpha, +): + """Permute tokens and compute expert outputs.""" + # Route tokens to experts + x = x.view(-1, x.shape[-1]) + + # Ensure CUB knows which device to use + with torch.cuda.device(x.device): + x = ops.binned_gather(x, indices, bins, expert_capacity, top_k) + + # Expert computation + x = mlp_forward(x, w1, w2, w1_bias, w2_bias, gradient_scale, alpha) + + # Ensure CUB knows which device to use + with torch.cuda.device(x.device): + # Route tokens back + out = ops.binned_scatter(x, indices, expert_weights, bins, top_k) + return out + + +def forward_once( + x: torch.Tensor, + expert_weights: torch.Tensor, + top_experts: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + w1_bias: torch.Tensor, + w2_bias: torch.Tensor, + gradient_scale: Optional[float] = None, + alpha: float = 1.702, + sort_end_bit: int = 0, + top_k: int = 4, + num_experts: int = 128, + expert_parallel_group: int = None, + moe_capacity_factor: float = 1.0, + moe_expert_model_parallelism: bool = False, +): + # x: [sl, bs, hs] + # expert_weights: [sl * bs, top-k] + # top_experts: [sl * bs, top-k] + expert_weights = expert_weights.flatten() + top_experts = top_experts.flatten() + + with torch.no_grad(): + indices, bin_ids, bins, tokens_per_expert = indices_and_bins( + top_experts, sort_end_bit, num_experts + ) + + # Calculate 
expert capacity + sl, bs, _ = x.size() + + expert_capacity = expert_capacity_fn( + sl * bs, + top_k, + num_experts, + expert_parallel_group, + moe_capacity_factor, + moe_expert_model_parallelism, + ) + + if expert_capacity == 0: + expert_capacity = torch.max(tokens_per_expert).item() + + x = permute_and_compute( + x, + tokens_per_expert, + indices, + bin_ids, + expert_weights, + bins, + expert_capacity, + top_k, + w1, + w2, + w1_bias, + w2_bias, + gradient_scale, + alpha, + ) + return x, tokens_per_expert + + +# TODO: replace with functional logic once aligned with ref +def parallel_forward_once( + x: torch.Tensor, + expert_weights: torch.Tensor, + top_experts: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + w1_bias: torch.Tensor, + w2_bias: torch.Tensor, + gradient_scale: Optional[float] = None, + alpha: float = 1.702, + sort_end_bit: int = 0, + top_k: int = 4, + num_experts: int = 128, + expert_parallel_group: torch.distributed.ProcessGroup = None, + moe_capacity_factor: float = 1.0, + moe_expert_model_parallelism: bool = True, + hidden_size: int = 1152, +): + pass + + +class MyReplacementLayer(torch.nn.Module): + # def __init__(self): + # super().__init__() + + def forward( + # self, + x: torch.Tensor, + router_weight: torch.Tensor, + moe_top_k: int, + moe_num_experts: int, + moe_jitter_eps: float = None, + moe_normalize_expert_weights: int = None, + uniform_expert_assignment: bool = False, + training: bool = False, + # + w1: torch.Tensor = None, + w2: torch.Tensor = None, + w1_bias: torch.Tensor = None, + w2_bias: torch.Tensor = None, + gradient_scale: Optional[float] = None, + alpha: float = 1.702, + sort_end_bit: int = 0, + expert_parallel_group: torch.distributed.ProcessGroup = None, + moe_capacity_factor: float = 1.0, + moe_expert_model_parallelism: bool = False, + forward_fn: Any = None, + hidden_size: int = None, # Required for parallel forward + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + + # Route tokens to experts + logits, expert_weights, expert_indices = route_tokens( + x, + router_weight, + moe_top_k, + moe_num_experts, + moe_jitter_eps, + moe_normalize_expert_weights, + uniform_expert_assignment, + training, + ) + + # Create router scores for output + router_scores = ( + torch.zeros_like(logits) + .scatter_(1, expert_indices, expert_weights) + .transpose(0, 1) + ) + + in_shape = x.size() + + # Prepare forward function arguments + forward_args = { + "x": x, + "expert_weights": expert_weights, + "top_experts": expert_indices, + "w1": w1, + "w2": w2, + "w1_bias": w1_bias, + "w2_bias": w2_bias, + "gradient_scale": gradient_scale, + "alpha": alpha, + "sort_end_bit": sort_end_bit, + "top_k": moe_top_k, + "num_experts": moe_num_experts, + "expert_parallel_group": expert_parallel_group, + "moe_capacity_factor": moe_capacity_factor, + "moe_expert_model_parallelism": moe_expert_model_parallelism, + } + + # Add hidden_size for parallel forward + if moe_expert_model_parallelism and hidden_size is not None: + forward_args["hidden_size"] = hidden_size + elif moe_expert_model_parallelism and hidden_size is None: + # Infer hidden_size from input shape + forward_args["hidden_size"] = x.shape[-1] + + # Compute expert outputs + x, tokens_per_expert = forward_fn(**forward_args) + + # Save load balancing loss if needed + moe_loss_weight = 0.0 # Can be made configurable + if training and moe_loss_weight > 0: + save_load_balancing_loss((tokens_per_expert, logits)) + + # Restore original shape + x = x.view(in_shape) + + return x, expert_weights, router_scores + + + +class 
MegaBlocksMoeMLP(torch.nn.Module): + + def forward( + self, + x: torch.Tensor, + ) -> torch.Tensor: + router_weight = self.router.weight + moe_top_k = 4 + moe_num_experts = 128 + w1 = self.experts.gate_up_proj.data + w2 = self.experts.down_proj.data + w1_bias = self.experts.gate_up_proj_bias.data + w2_bias = self.experts.down_proj_bias.data + expert_parallel_group = None + + sort_end_bit = max(int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1) + hidden_size = self.experts.hidden_size + + output, expert_weights_out, router_scores = MyReplacementLayer.forward( + x=x, + router_weight=router_weight, + moe_top_k=moe_top_k, + moe_num_experts=moe_num_experts, + moe_jitter_eps=None, + moe_normalize_expert_weights=None, + uniform_expert_assignment=False, + training=False, + w1=w1, + w2=w2, + w1_bias=w1_bias, + w2_bias=w2_bias, + gradient_scale=None, + alpha=1.702, + sort_end_bit=sort_end_bit, + expert_parallel_group=expert_parallel_group, + moe_capacity_factor=1.0, + moe_expert_model_parallelism=False, + forward_fn=forward_once, + hidden_size=hidden_size, + ) + return output, expert_weights_out \ No newline at end of file diff --git a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/all_to_all_benchmark.py b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/all_to_all_benchmark.py index 43d267dbe2570e4a1f59fd561398668cc2bc0920..4c939818edca3345f6344bbc7cef07ffe3cd0181 100644 --- a/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/all_to_all_benchmark.py +++ b/build/torch26-cxx98-cu126-x86_64-linux/megablocks/ops/all_to_all_benchmark.py @@ -8,7 +8,7 @@ import torch.distributed as dist # from megablocks.layers.all_to_all import all_to_all from .. import benchmark_util -from ..layers.all_to_all import all_to_all +from .._layers.all_to_all import all_to_all _ALL_TO_ALL_BENCHMARK = ( (8, 1024), diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/__init__.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/__init__.py index af8d40aa52c706d2aae368cc11ade68cf13f6a47..38075732c6d8fa0e1e6ef493145e1aca3851ae6b 100644 --- a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/__init__.py +++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/__init__.py @@ -9,11 +9,13 @@ from .grouped_gemm import backend as gg_backend from .grouped_gemm import ops as gg_ops -from .layers.arguments import Arguments -from .layers.dmoe import ParallelDroplessMLP, dMoE -from .layers.glu import SparseGLU -from .layers.mlp import MLP, SparseMLP -from .layers.moe import MoE, ParallelMLP, get_load_balancing_loss +from ._layers.arguments import Arguments +from ._layers.dmoe import ParallelDroplessMLP, dMoE +from ._layers.glu import SparseGLU +from ._layers.mlp import MLP, SparseMLP +from ._layers.moe import MoE, ParallelMLP, get_load_balancing_loss + +from . 
import layers # This section contains the direct kernel exports (not inlcuded in the original code) def exclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor: @@ -176,6 +178,7 @@ def argsort(x: torch.Tensor, end_bit: int = 32) -> tuple[torch.Tensor, torch.Ten # Export public API __all__ = [ + "MyReplacementLayer", # Direct kernel exports "exclusive_cumsum", "inclusive_cumsum", diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/layers/__init__.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/__init__.py similarity index 100% rename from build/torch27-cxx11-cu118-x86_64-linux/megablocks/layers/__init__.py rename to build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/__init__.py diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/layers/activation_fn.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/activation_fn.py similarity index 100% rename from build/torch27-cxx11-cu118-x86_64-linux/megablocks/layers/activation_fn.py rename to build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/activation_fn.py diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/layers/all_to_all.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/all_to_all.py similarity index 100% rename from build/torch27-cxx11-cu118-x86_64-linux/megablocks/layers/all_to_all.py rename to build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/all_to_all.py diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/layers/arguments.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/arguments.py similarity index 100% rename from build/torch27-cxx11-cu118-x86_64-linux/megablocks/layers/arguments.py rename to build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/arguments.py diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/layers/common.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/common.py similarity index 100% rename from build/torch27-cxx11-cu118-x86_64-linux/megablocks/layers/common.py rename to build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/common.py diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/layers/dmlp_registry.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/dmlp_registry.py similarity index 100% rename from build/torch27-cxx11-cu118-x86_64-linux/megablocks/layers/dmlp_registry.py rename to build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/dmlp_registry.py diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/layers/dmoe.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/dmoe.py similarity index 100% rename from build/torch27-cxx11-cu118-x86_64-linux/megablocks/layers/dmoe.py rename to build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/dmoe.py diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/layers/gelu.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/gelu.py similarity index 100% rename from build/torch27-cxx11-cu118-x86_64-linux/megablocks/layers/gelu.py rename to build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/gelu.py diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/layers/glu.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/glu.py similarity index 100% rename from build/torch27-cxx11-cu118-x86_64-linux/megablocks/layers/glu.py rename to build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/glu.py diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/layers/memory_test.py 
b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/memory_test.py similarity index 100% rename from build/torch27-cxx11-cu118-x86_64-linux/megablocks/layers/memory_test.py rename to build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/memory_test.py diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/layers/mlp.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/mlp.py similarity index 100% rename from build/torch27-cxx11-cu118-x86_64-linux/megablocks/layers/mlp.py rename to build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/mlp.py diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/layers/moe.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/moe.py similarity index 100% rename from build/torch27-cxx11-cu118-x86_64-linux/megablocks/layers/moe.py rename to build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/moe.py diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/layers/mpu.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/mpu.py similarity index 100% rename from build/torch27-cxx11-cu118-x86_64-linux/megablocks/layers/mpu.py rename to build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/mpu.py diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/layers/router.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/router.py similarity index 100% rename from build/torch27-cxx11-cu118-x86_64-linux/megablocks/layers/router.py rename to build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/router.py diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/layers/sharedexpert_registry.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/sharedexpert_registry.py similarity index 100% rename from build/torch27-cxx11-cu118-x86_64-linux/megablocks/layers/sharedexpert_registry.py rename to build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/sharedexpert_registry.py diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_megablocks_6756875_dirty.abi3.so b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_megablocks_6756875_dirty.abi3.so deleted file mode 100755 index 2edf877dce4fdf86bffbaa772e48aeb7f456199b..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_megablocks_6756875_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:f535bc4e3bf6c970117d2dbf7f7b4dcdabf1fea035590851c2a558ca0e95497f -size 10517848 diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_megablocks_dabb815.abi3.so b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_megablocks_dabb815.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..69f059780cd3ae0839b2d079bc031ccaad4a4da6 --- /dev/null +++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_megablocks_dabb815.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9c297e2817e6b2dd9af4f1166e448844f785ef45ed769d0289766bc9169767df +size 10517816 diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_ops.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_ops.py index c9a9faa3cee07c8b7016a71704ea26e47c460d57..3cc5ae109015332d831f2357627d00ef67faa25d 100644 --- a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_ops.py +++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _megablocks_6756875_dirty -ops = torch.ops._megablocks_6756875_dirty +from . 
import _megablocks_dabb815 +ops = torch.ops._megablocks_dabb815 def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. """ - return f"_megablocks_6756875_dirty::{op_name}" \ No newline at end of file + return f"_megablocks_dabb815::{op_name}" \ No newline at end of file diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/layers.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/layers.py new file mode 100644 index 0000000000000000000000000000000000000000..d3737eddff33082db4c8b31b2c68549b13a911b4 --- /dev/null +++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/layers.py @@ -0,0 +1,566 @@ +import torch +import torch.distributed as dist + +from typing import Optional, Any + +from . import _layers +from . import ops + + +# Set the expert model parallel attributes on a tensor +def set_expert_model_parallel_attributes( + tensor: torch.Tensor, + is_parallel: bool, +): + assert not hasattr(tensor, "expert_model_parallel") + setattr(tensor, "expert_model_parallel", is_parallel) + + +# Get the expert model parallel attributes from a tensor +def expert_sharding_degree( + world_size: int, + moe_num_experts: int, +) -> int: + esd = min(world_size, moe_num_experts) + if (moe_num_experts % esd) != 0: + raise ValueError(f"Cannot shard {moe_num_experts} experts {esd} ways.") + return esd + + +# Calculate the hidden sharding degree based on world size and expert sharding degree +def hidden_sharding_degree( + world_size: int, + moe_num_experts: int, + ffn_hidden_size: int, +) -> int: + esd = expert_sharding_degree(world_size, moe_num_experts) + hsd = world_size // esd + if (ffn_hidden_size % hsd) != 0: + raise ValueError(f"Cannot shard {ffn_hidden_size} features {hsd} ways.") + if (esd * hsd) != world_size: + raise ValueError( + f"Invalid sharding. expert_sharding_degree ({esd}) * hidden_sharding_degree ({hsd}) != world_size ({world_size})." 
+ ) + return hsd + + +# Calculate the number of experts per rank based on world size and expert sharding degree +def experts_per_rank( + moe_num_experts: int, + world_size: int, +) -> int: + return moe_num_experts // expert_sharding_degree(world_size, moe_num_experts) + + +# Calculate the number of features per rank based on ffn hidden size and hidden sharding degree +def features_per_rank( + ffn_hidden_size: int, world_size: int, moe_num_experts: int +) -> int: + return ffn_hidden_size // hidden_sharding_degree( + world_size, moe_num_experts, ffn_hidden_size + ) + + +# Apply jitter to the input tensor +def apply_jitter(x: torch.Tensor, moe_jitter_eps: float) -> torch.Tensor: + low = 1.0 - moe_jitter_eps + high = 1.0 + moe_jitter_eps + noise = torch.rand(x.size(), dtype=x.dtype, device=x.device) + return x * (low + noise * (high - low)) + + +# Compute the top-k scores from the logits +def compute_top_k(scores: torch.Tensor, moe_top_k: int): + if moe_top_k == 1: + return scores.max(dim=-1, keepdim=True) + return torch.topk(scores, moe_top_k, dim=-1) + + +# Route tokens to experts and compute expert weights and indices +def route_tokens( + x: torch.Tensor, + router_weight: torch.Tensor, + moe_top_k: int, + moe_num_experts: int, + moe_jitter_eps: float = None, + moe_normalize_expert_weights: int = None, + uniform_expert_assignment: bool = False, + training: bool = False, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + if training and moe_jitter_eps is not None: + x = apply_jitter(x, moe_jitter_eps) + + x_flat = x.view(-1, x.shape[-1]) + logits = torch.nn.functional.linear(x_flat, router_weight) + expert_weights, expert_indices = compute_top_k(logits, moe_top_k) + expert_weights = expert_weights.softmax(dim=-1) + if moe_normalize_expert_weights is not None: + expert_weights = expert_weights / torch.norm( + expert_weights, + p=moe_normalize_expert_weights, + dim=-1, + keepdim=True, + ) + if uniform_expert_assignment: + expert_indices = _layers.router._uniform_expert_assignment( + expert_indices, + moe_num_experts, + ) + + return logits, expert_weights, expert_indices + + +# Scale the gradient of the weights +def scale_grad( + w: torch.Tensor, + gradient_scale: Optional[float] = None, +) -> torch.Tensor: + if gradient_scale is None: + return w + return _layers.mlp.scale_gradient(w, gradient_scale) + + +# Forward pass for the MLP layer +def mlp_forward(x, w1, w2, w1_bias, w2_bias, gradient_scale=None, alpha: float = 1.702): + # Scale weights + w1 = scale_grad(w1, gradient_scale) + w2 = scale_grad(w2, gradient_scale) + w1_bias = scale_grad(w1_bias, gradient_scale) + w2_bias = scale_grad(w2_bias, gradient_scale) + + # Resolve dtensors + w1 = _layers.mlp.resolve_dtensor(w1) + w2 = _layers.mlp.resolve_dtensor(w2) + w1_bias = _layers.mlp.resolve_dtensor(w1_bias) + w2_bias = _layers.mlp.resolve_dtensor(w2_bias) + + # Forward pass + gate_up = torch.bmm(x, w1) + w1_bias[..., None, :] + gate, up = gate_up.chunk(2, dim=-1) + + glu = gate * torch.sigmoid(gate * alpha) + x = (up + 1) * glu + + return torch.bmm(x, w2) + w2_bias[..., None, :] + + +## START: Load Balancing Loss (unused at the moment) + +# Global variable to store load balancing loss +_LOAD_BALANCING_LOSS = [] + + +def save_load_balancing_loss(loss): + global _LOAD_BALANCING_LOSS + _LOAD_BALANCING_LOSS.append(loss) + + +def get_load_balancing_loss(): + global _LOAD_BALANCING_LOSS + return _LOAD_BALANCING_LOSS + + +def clear_load_balancing_loss(): + global _LOAD_BALANCING_LOSS + _LOAD_BALANCING_LOSS.clear() + + +def 
batched_load_balancing_loss(args): + if args.moe_loss_weight == 0: + return 0.0 + + tokens_per_expert, expert_scores = zip(*get_load_balancing_loss()) + num_layers_per_pipeline_stage = args.num_layers // args.pipeline_model_parallel_size + if args.num_layers_per_virtual_pipeline_stage is not None: + num_layers_per_pipeline_stage = args.num_layers_per_virtual_pipeline_stage + + if len(tokens_per_expert) != num_layers_per_pipeline_stage: + raise ValueError( + f"Expected {num_layers_per_pipeline_stage} token_per_experts " + f"but found {len(tokens_per_expert)}.\nnum_layers = " + f"{args.num_layers}\npipeline_model_parallel_size = " + f"{args.pipeline_model_parallel_size}\n" + "num_layers_per_virtual_pipeline_stage" + f" = {args.num_layers_per_virtual_pipeline_stage}", + ) + if len(expert_scores) != num_layers_per_pipeline_stage: + raise ValueError( + f"Expected {num_layers_per_pipeline_stage} expert_scores " + f"but found {len(tokens_per_expert)}.\nnum_layers = " + f"{args.num_layers}\npipeline_model_parallel_size = " + f"{args.pipeline_model_parallel_size}\n" + "num_layers_per_virtual_pipeline_stage" + f" = {args.num_layers_per_virtual_pipeline_stage}", + ) + + # Verify the shape of the tokens_per_expert and expert_scores tensors. + assert all( + (x.ndim == 1 and x.numel() == args.moe_num_experts for x in tokens_per_expert) + ) + + tokens = expert_scores[0].shape[0] + assert all( + ( + ( + x.ndim == 2 + and x.shape[1] == args.moe_num_experts + and x.shape[0] == tokens + ) + for x in expert_scores + ) + ) + + # Concatenate the contributions of each layer and convert to + # the correct types and formats for the dot product. + expert_scores = torch.cat(expert_scores, dim=1) + if args.moe_lbl_in_fp32: + expert_scores = expert_scores.float() + if tokens != 0: + expert_scores = expert_scores.mean(dim=0) + else: + expert_scores = expert_scores.sum(dim=0) + tokens_per_expert = torch.cat(tokens_per_expert).to(expert_scores.dtype) + + expected_values = num_layers_per_pipeline_stage * args.moe_num_experts + assert tokens_per_expert.numel() == expected_values + assert expert_scores.numel() == expected_values + + # Calculate the total scale across all factors. + # + # loss_weight * num_experts / (num_layers * tokens * top_k) + scale_numerator = args.moe_num_experts * args.moe_loss_weight + scale_denominator = args.num_layers * tokens * args.moe_top_k + scale = scale_numerator / scale_denominator + return scale * torch.dot(tokens_per_expert, expert_scores) + + +## END Load Balancing Loss + + +# Calculate the expert capacity based on tokens, top_k, number of experts, +# expert parallel group, capacity factor, and whether expert model parallelism is used. 
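For the routing side of this file (route_tokens and the router_scores construction in MyReplacementLayer), the following standalone sketch with hypothetical shapes traces the logits -> top-k -> softmax -> dense score matrix path.

import torch

# Hypothetical shapes, only to illustrate the routing math.
tokens, hidden, num_experts, top_k = 6, 16, 8, 2
x = torch.randn(tokens, hidden)
router_weight = torch.randn(num_experts, hidden)

logits = torch.nn.functional.linear(x, router_weight)           # [tokens, num_experts]
expert_weights, expert_indices = torch.topk(logits, top_k, dim=-1)
expert_weights = expert_weights.softmax(dim=-1)                 # normalize over the k picks

# Dense [num_experts, tokens] scores, as returned to the caller via router_scores.
router_scores = (
    torch.zeros_like(logits).scatter_(1, expert_indices, expert_weights).transpose(0, 1)
)
print(router_scores.shape)  # torch.Size([8, 6])

The softmax here is taken over the k selected logits only, matching route_tokens above; normalizing over all num_experts logits before the top-k selection would give different weights.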
+def expert_capacity( + tokens: int, + top_k: int, + num_experts: int, + expert_parallel_group: int, + moe_capacity_factor: float, + moe_expert_model_parallelism: bool, +) -> int: + world_size = ( + dist.get_world_size(expert_parallel_group) + if moe_expert_model_parallelism + else 1 + ) + + tokens_per_expert = top_k * tokens * world_size / num_experts + return int(moe_capacity_factor * tokens_per_expert) + + +def load_balancing_loss( + tokens_per_expert: torch.Tensor, + expert_scores: torch.Tensor, + top_k: int, + num_experts: int, +): + assert len(expert_scores.size()) == 2 + tokens, num_experts = expert_scores.size() + assert num_experts == num_experts + assert len(tokens_per_expert.size()) == 1 + (num_experts,) = tokens_per_expert.size() + assert num_experts == num_experts + scale = num_experts / (tokens * top_k) + return scale * torch.dot( + tokens_per_expert.to(expert_scores.dtype), + expert_scores.mean(dim=0), + ) + + +def indices_and_bins( + top_expert: torch.Tensor, + sort_end_bit: int, + num_experts: int, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + top_expert = top_expert.int() + + # Ensure contiguous memory layout + top_expert = top_expert.contiguous() + + # Ensure CUB knows which device to use + with torch.cuda.device(top_expert.device): + output = ops.sort(top_expert, sort_end_bit) + bin_ids, indices = output + tokens_per_expert = ops.histogram(top_expert, num_experts) + bins = ops.inclusive_cumsum(tokens_per_expert, 0) + + bins = bins.view(1) if not len(bins.size()) else bins + return indices, bin_ids, bins, tokens_per_expert + + +def expert_capacity_fn( + tokens: int, + top_k: int, + num_experts: int, + expert_parallel_group: torch.distributed.ProcessGroup, + moe_capacity_factor: float = 1.0, + moe_expert_model_parallelism: bool = False, +) -> int: + world_size = ( + dist.get_world_size(expert_parallel_group) + if moe_expert_model_parallelism + else 1 + ) + tokens_per_expert = top_k * tokens * world_size / num_experts + return int(moe_capacity_factor * tokens_per_expert) + + +def permute_and_compute( + x, + tokens_per_expert, + indices, + bin_ids, + expert_weights, + bins, + expert_capacity, + top_k, + w1, + w2, + w1_bias, + w2_bias, + gradient_scale, + alpha, +): + """Permute tokens and compute expert outputs.""" + # Route tokens to experts + x = x.view(-1, x.shape[-1]) + + # Ensure CUB knows which device to use + with torch.cuda.device(x.device): + x = ops.binned_gather(x, indices, bins, expert_capacity, top_k) + + # Expert computation + x = mlp_forward(x, w1, w2, w1_bias, w2_bias, gradient_scale, alpha) + + # Ensure CUB knows which device to use + with torch.cuda.device(x.device): + # Route tokens back + out = ops.binned_scatter(x, indices, expert_weights, bins, top_k) + return out + + +def forward_once( + x: torch.Tensor, + expert_weights: torch.Tensor, + top_experts: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + w1_bias: torch.Tensor, + w2_bias: torch.Tensor, + gradient_scale: Optional[float] = None, + alpha: float = 1.702, + sort_end_bit: int = 0, + top_k: int = 4, + num_experts: int = 128, + expert_parallel_group: int = None, + moe_capacity_factor: float = 1.0, + moe_expert_model_parallelism: bool = False, +): + # x: [sl, bs, hs] + # expert_weights: [sl * bs, top-k] + # top_experts: [sl * bs, top-k] + expert_weights = expert_weights.flatten() + top_experts = top_experts.flatten() + + with torch.no_grad(): + indices, bin_ids, bins, tokens_per_expert = indices_and_bins( + top_experts, sort_end_bit, num_experts + ) + + # Calculate 
expert capacity + sl, bs, _ = x.size() + + expert_capacity = expert_capacity_fn( + sl * bs, + top_k, + num_experts, + expert_parallel_group, + moe_capacity_factor, + moe_expert_model_parallelism, + ) + + if expert_capacity == 0: + expert_capacity = torch.max(tokens_per_expert).item() + + x = permute_and_compute( + x, + tokens_per_expert, + indices, + bin_ids, + expert_weights, + bins, + expert_capacity, + top_k, + w1, + w2, + w1_bias, + w2_bias, + gradient_scale, + alpha, + ) + return x, tokens_per_expert + + +# TODO: replace with functional logic once aligned with ref +def parallel_forward_once( + x: torch.Tensor, + expert_weights: torch.Tensor, + top_experts: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + w1_bias: torch.Tensor, + w2_bias: torch.Tensor, + gradient_scale: Optional[float] = None, + alpha: float = 1.702, + sort_end_bit: int = 0, + top_k: int = 4, + num_experts: int = 128, + expert_parallel_group: torch.distributed.ProcessGroup = None, + moe_capacity_factor: float = 1.0, + moe_expert_model_parallelism: bool = True, + hidden_size: int = 1152, +): + pass + + +class MyReplacementLayer(torch.nn.Module): + # def __init__(self): + # super().__init__() + + def forward( + # self, + x: torch.Tensor, + router_weight: torch.Tensor, + moe_top_k: int, + moe_num_experts: int, + moe_jitter_eps: float = None, + moe_normalize_expert_weights: int = None, + uniform_expert_assignment: bool = False, + training: bool = False, + # + w1: torch.Tensor = None, + w2: torch.Tensor = None, + w1_bias: torch.Tensor = None, + w2_bias: torch.Tensor = None, + gradient_scale: Optional[float] = None, + alpha: float = 1.702, + sort_end_bit: int = 0, + expert_parallel_group: torch.distributed.ProcessGroup = None, + moe_capacity_factor: float = 1.0, + moe_expert_model_parallelism: bool = False, + forward_fn: Any = None, + hidden_size: int = None, # Required for parallel forward + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + + # Route tokens to experts + logits, expert_weights, expert_indices = route_tokens( + x, + router_weight, + moe_top_k, + moe_num_experts, + moe_jitter_eps, + moe_normalize_expert_weights, + uniform_expert_assignment, + training, + ) + + # Create router scores for output + router_scores = ( + torch.zeros_like(logits) + .scatter_(1, expert_indices, expert_weights) + .transpose(0, 1) + ) + + in_shape = x.size() + + # Prepare forward function arguments + forward_args = { + "x": x, + "expert_weights": expert_weights, + "top_experts": expert_indices, + "w1": w1, + "w2": w2, + "w1_bias": w1_bias, + "w2_bias": w2_bias, + "gradient_scale": gradient_scale, + "alpha": alpha, + "sort_end_bit": sort_end_bit, + "top_k": moe_top_k, + "num_experts": moe_num_experts, + "expert_parallel_group": expert_parallel_group, + "moe_capacity_factor": moe_capacity_factor, + "moe_expert_model_parallelism": moe_expert_model_parallelism, + } + + # Add hidden_size for parallel forward + if moe_expert_model_parallelism and hidden_size is not None: + forward_args["hidden_size"] = hidden_size + elif moe_expert_model_parallelism and hidden_size is None: + # Infer hidden_size from input shape + forward_args["hidden_size"] = x.shape[-1] + + # Compute expert outputs + x, tokens_per_expert = forward_fn(**forward_args) + + # Save load balancing loss if needed + moe_loss_weight = 0.0 # Can be made configurable + if training and moe_loss_weight > 0: + save_load_balancing_loss((tokens_per_expert, logits)) + + # Restore original shape + x = x.view(in_shape) + + return x, expert_weights, router_scores + + + +class 
MegaBlocksMoeMLP(torch.nn.Module): + + def forward( + self, + x: torch.Tensor, + ) -> torch.Tensor: + router_weight = self.router.weight + moe_top_k = 4 + moe_num_experts = 128 + w1 = self.experts.gate_up_proj.data + w2 = self.experts.down_proj.data + w1_bias = self.experts.gate_up_proj_bias.data + w2_bias = self.experts.down_proj_bias.data + expert_parallel_group = None + + sort_end_bit = max(int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1) + hidden_size = self.experts.hidden_size + + output, expert_weights_out, router_scores = MyReplacementLayer.forward( + x=x, + router_weight=router_weight, + moe_top_k=moe_top_k, + moe_num_experts=moe_num_experts, + moe_jitter_eps=None, + moe_normalize_expert_weights=None, + uniform_expert_assignment=False, + training=False, + w1=w1, + w2=w2, + w1_bias=w1_bias, + w2_bias=w2_bias, + gradient_scale=None, + alpha=1.702, + sort_end_bit=sort_end_bit, + expert_parallel_group=expert_parallel_group, + moe_capacity_factor=1.0, + moe_expert_model_parallelism=False, + forward_fn=forward_once, + hidden_size=hidden_size, + ) + return output, expert_weights_out \ No newline at end of file diff --git a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/all_to_all_benchmark.py b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/all_to_all_benchmark.py index 43d267dbe2570e4a1f59fd561398668cc2bc0920..4c939818edca3345f6344bbc7cef07ffe3cd0181 100644 --- a/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/all_to_all_benchmark.py +++ b/build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/all_to_all_benchmark.py @@ -8,7 +8,7 @@ import torch.distributed as dist # from megablocks.layers.all_to_all import all_to_all from .. import benchmark_util -from ..layers.all_to_all import all_to_all +from .._layers.all_to_all import all_to_all _ALL_TO_ALL_BENCHMARK = ( (8, 1024), diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/__init__.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/__init__.py index af8d40aa52c706d2aae368cc11ade68cf13f6a47..38075732c6d8fa0e1e6ef493145e1aca3851ae6b 100644 --- a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/__init__.py +++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/__init__.py @@ -9,11 +9,13 @@ from .grouped_gemm import backend as gg_backend from .grouped_gemm import ops as gg_ops -from .layers.arguments import Arguments -from .layers.dmoe import ParallelDroplessMLP, dMoE -from .layers.glu import SparseGLU -from .layers.mlp import MLP, SparseMLP -from .layers.moe import MoE, ParallelMLP, get_load_balancing_loss +from ._layers.arguments import Arguments +from ._layers.dmoe import ParallelDroplessMLP, dMoE +from ._layers.glu import SparseGLU +from ._layers.mlp import MLP, SparseMLP +from ._layers.moe import MoE, ParallelMLP, get_load_balancing_loss + +from . 
import layers # This section contains the direct kernel exports (not inlcuded in the original code) def exclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor: @@ -176,6 +178,7 @@ def argsort(x: torch.Tensor, end_bit: int = 32) -> tuple[torch.Tensor, torch.Ten # Export public API __all__ = [ + "MyReplacementLayer", # Direct kernel exports "exclusive_cumsum", "inclusive_cumsum", diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/layers/__init__.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_layers/__init__.py similarity index 100% rename from build/torch27-cxx11-cu126-x86_64-linux/megablocks/layers/__init__.py rename to build/torch27-cxx11-cu126-x86_64-linux/megablocks/_layers/__init__.py diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/layers/activation_fn.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_layers/activation_fn.py similarity index 100% rename from build/torch27-cxx11-cu126-x86_64-linux/megablocks/layers/activation_fn.py rename to build/torch27-cxx11-cu126-x86_64-linux/megablocks/_layers/activation_fn.py diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/layers/all_to_all.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_layers/all_to_all.py similarity index 100% rename from build/torch27-cxx11-cu126-x86_64-linux/megablocks/layers/all_to_all.py rename to build/torch27-cxx11-cu126-x86_64-linux/megablocks/_layers/all_to_all.py diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/layers/arguments.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_layers/arguments.py similarity index 100% rename from build/torch27-cxx11-cu126-x86_64-linux/megablocks/layers/arguments.py rename to build/torch27-cxx11-cu126-x86_64-linux/megablocks/_layers/arguments.py diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/layers/common.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_layers/common.py similarity index 100% rename from build/torch27-cxx11-cu126-x86_64-linux/megablocks/layers/common.py rename to build/torch27-cxx11-cu126-x86_64-linux/megablocks/_layers/common.py diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/layers/dmlp_registry.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_layers/dmlp_registry.py similarity index 100% rename from build/torch27-cxx11-cu126-x86_64-linux/megablocks/layers/dmlp_registry.py rename to build/torch27-cxx11-cu126-x86_64-linux/megablocks/_layers/dmlp_registry.py diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/layers/dmoe.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_layers/dmoe.py similarity index 100% rename from build/torch27-cxx11-cu126-x86_64-linux/megablocks/layers/dmoe.py rename to build/torch27-cxx11-cu126-x86_64-linux/megablocks/_layers/dmoe.py diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/layers/gelu.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_layers/gelu.py similarity index 100% rename from build/torch27-cxx11-cu126-x86_64-linux/megablocks/layers/gelu.py rename to build/torch27-cxx11-cu126-x86_64-linux/megablocks/_layers/gelu.py diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/layers/glu.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_layers/glu.py similarity index 100% rename from build/torch27-cxx11-cu126-x86_64-linux/megablocks/layers/glu.py rename to build/torch27-cxx11-cu126-x86_64-linux/megablocks/_layers/glu.py diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/layers/memory_test.py 
b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_layers/memory_test.py similarity index 100% rename from build/torch27-cxx11-cu126-x86_64-linux/megablocks/layers/memory_test.py rename to build/torch27-cxx11-cu126-x86_64-linux/megablocks/_layers/memory_test.py diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/layers/mlp.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_layers/mlp.py similarity index 100% rename from build/torch27-cxx11-cu126-x86_64-linux/megablocks/layers/mlp.py rename to build/torch27-cxx11-cu126-x86_64-linux/megablocks/_layers/mlp.py diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/layers/moe.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_layers/moe.py similarity index 100% rename from build/torch27-cxx11-cu126-x86_64-linux/megablocks/layers/moe.py rename to build/torch27-cxx11-cu126-x86_64-linux/megablocks/_layers/moe.py diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/layers/mpu.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_layers/mpu.py similarity index 100% rename from build/torch27-cxx11-cu126-x86_64-linux/megablocks/layers/mpu.py rename to build/torch27-cxx11-cu126-x86_64-linux/megablocks/_layers/mpu.py diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/layers/router.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_layers/router.py similarity index 100% rename from build/torch27-cxx11-cu126-x86_64-linux/megablocks/layers/router.py rename to build/torch27-cxx11-cu126-x86_64-linux/megablocks/_layers/router.py diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/layers/sharedexpert_registry.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_layers/sharedexpert_registry.py similarity index 100% rename from build/torch27-cxx11-cu126-x86_64-linux/megablocks/layers/sharedexpert_registry.py rename to build/torch27-cxx11-cu126-x86_64-linux/megablocks/_layers/sharedexpert_registry.py diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_megablocks_6756875_dirty.abi3.so b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_megablocks_6756875_dirty.abi3.so deleted file mode 100755 index 229aa059ff8f5d788ce34d0d0707467f9051e603..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_megablocks_6756875_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:4708535fe6c5c8d59883ad81b5ce6b0083c3ed0ae71f14749ac2415899797f36 -size 11931112 diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_megablocks_dabb815.abi3.so b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_megablocks_dabb815.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..96ff9ba92b272f12146d8eea4f96fe28724fc82d --- /dev/null +++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_megablocks_dabb815.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7fbf97bdf349597f84616a22e3bc8c25cba2d77aef15dfa84bde284ffa51fe38 +size 11931080 diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_ops.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_ops.py index c9a9faa3cee07c8b7016a71704ea26e47c460d57..3cc5ae109015332d831f2357627d00ef67faa25d 100644 --- a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_ops.py +++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _megablocks_6756875_dirty -ops = torch.ops._megablocks_6756875_dirty +from . 
import _megablocks_dabb815 +ops = torch.ops._megablocks_dabb815 def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. """ - return f"_megablocks_6756875_dirty::{op_name}" \ No newline at end of file + return f"_megablocks_dabb815::{op_name}" \ No newline at end of file diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/layers.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/layers.py new file mode 100644 index 0000000000000000000000000000000000000000..d3737eddff33082db4c8b31b2c68549b13a911b4 --- /dev/null +++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/layers.py @@ -0,0 +1,566 @@ +import torch +import torch.distributed as dist + +from typing import Optional, Any + +from . import _layers +from . import ops + + +# Set the expert model parallel attributes on a tensor +def set_expert_model_parallel_attributes( + tensor: torch.Tensor, + is_parallel: bool, +): + assert not hasattr(tensor, "expert_model_parallel") + setattr(tensor, "expert_model_parallel", is_parallel) + + +# Get the expert model parallel attributes from a tensor +def expert_sharding_degree( + world_size: int, + moe_num_experts: int, +) -> int: + esd = min(world_size, moe_num_experts) + if (moe_num_experts % esd) != 0: + raise ValueError(f"Cannot shard {moe_num_experts} experts {esd} ways.") + return esd + + +# Calculate the hidden sharding degree based on world size and expert sharding degree +def hidden_sharding_degree( + world_size: int, + moe_num_experts: int, + ffn_hidden_size: int, +) -> int: + esd = expert_sharding_degree(world_size, moe_num_experts) + hsd = world_size // esd + if (ffn_hidden_size % hsd) != 0: + raise ValueError(f"Cannot shard {ffn_hidden_size} features {hsd} ways.") + if (esd * hsd) != world_size: + raise ValueError( + f"Invalid sharding. expert_sharding_degree ({esd}) * hidden_sharding_degree ({hsd}) != world_size ({world_size})." 
+ ) + return hsd + + +# Calculate the number of experts per rank based on world size and expert sharding degree +def experts_per_rank( + moe_num_experts: int, + world_size: int, +) -> int: + return moe_num_experts // expert_sharding_degree(world_size, moe_num_experts) + + +# Calculate the number of features per rank based on ffn hidden size and hidden sharding degree +def features_per_rank( + ffn_hidden_size: int, world_size: int, moe_num_experts: int +) -> int: + return ffn_hidden_size // hidden_sharding_degree( + world_size, moe_num_experts, ffn_hidden_size + ) + + +# Apply jitter to the input tensor +def apply_jitter(x: torch.Tensor, moe_jitter_eps: float) -> torch.Tensor: + low = 1.0 - moe_jitter_eps + high = 1.0 + moe_jitter_eps + noise = torch.rand(x.size(), dtype=x.dtype, device=x.device) + return x * (low + noise * (high - low)) + + +# Compute the top-k scores from the logits +def compute_top_k(scores: torch.Tensor, moe_top_k: int): + if moe_top_k == 1: + return scores.max(dim=-1, keepdim=True) + return torch.topk(scores, moe_top_k, dim=-1) + + +# Route tokens to experts and compute expert weights and indices +def route_tokens( + x: torch.Tensor, + router_weight: torch.Tensor, + moe_top_k: int, + moe_num_experts: int, + moe_jitter_eps: float = None, + moe_normalize_expert_weights: int = None, + uniform_expert_assignment: bool = False, + training: bool = False, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + if training and moe_jitter_eps is not None: + x = apply_jitter(x, moe_jitter_eps) + + x_flat = x.view(-1, x.shape[-1]) + logits = torch.nn.functional.linear(x_flat, router_weight) + expert_weights, expert_indices = compute_top_k(logits, moe_top_k) + expert_weights = expert_weights.softmax(dim=-1) + if moe_normalize_expert_weights is not None: + expert_weights = expert_weights / torch.norm( + expert_weights, + p=moe_normalize_expert_weights, + dim=-1, + keepdim=True, + ) + if uniform_expert_assignment: + expert_indices = _layers.router._uniform_expert_assignment( + expert_indices, + moe_num_experts, + ) + + return logits, expert_weights, expert_indices + + +# Scale the gradient of the weights +def scale_grad( + w: torch.Tensor, + gradient_scale: Optional[float] = None, +) -> torch.Tensor: + if gradient_scale is None: + return w + return _layers.mlp.scale_gradient(w, gradient_scale) + + +# Forward pass for the MLP layer +def mlp_forward(x, w1, w2, w1_bias, w2_bias, gradient_scale=None, alpha: float = 1.702): + # Scale weights + w1 = scale_grad(w1, gradient_scale) + w2 = scale_grad(w2, gradient_scale) + w1_bias = scale_grad(w1_bias, gradient_scale) + w2_bias = scale_grad(w2_bias, gradient_scale) + + # Resolve dtensors + w1 = _layers.mlp.resolve_dtensor(w1) + w2 = _layers.mlp.resolve_dtensor(w2) + w1_bias = _layers.mlp.resolve_dtensor(w1_bias) + w2_bias = _layers.mlp.resolve_dtensor(w2_bias) + + # Forward pass + gate_up = torch.bmm(x, w1) + w1_bias[..., None, :] + gate, up = gate_up.chunk(2, dim=-1) + + glu = gate * torch.sigmoid(gate * alpha) + x = (up + 1) * glu + + return torch.bmm(x, w2) + w2_bias[..., None, :] + + +## START: Load Balancing Loss (unused at the moment) + +# Global variable to store load balancing loss +_LOAD_BALANCING_LOSS = [] + + +def save_load_balancing_loss(loss): + global _LOAD_BALANCING_LOSS + _LOAD_BALANCING_LOSS.append(loss) + + +def get_load_balancing_loss(): + global _LOAD_BALANCING_LOSS + return _LOAD_BALANCING_LOSS + + +def clear_load_balancing_loss(): + global _LOAD_BALANCING_LOSS + _LOAD_BALANCING_LOSS.clear() + + +def 
batched_load_balancing_loss(args): + if args.moe_loss_weight == 0: + return 0.0 + + tokens_per_expert, expert_scores = zip(*get_load_balancing_loss()) + num_layers_per_pipeline_stage = args.num_layers // args.pipeline_model_parallel_size + if args.num_layers_per_virtual_pipeline_stage is not None: + num_layers_per_pipeline_stage = args.num_layers_per_virtual_pipeline_stage + + if len(tokens_per_expert) != num_layers_per_pipeline_stage: + raise ValueError( + f"Expected {num_layers_per_pipeline_stage} token_per_experts " + f"but found {len(tokens_per_expert)}.\nnum_layers = " + f"{args.num_layers}\npipeline_model_parallel_size = " + f"{args.pipeline_model_parallel_size}\n" + "num_layers_per_virtual_pipeline_stage" + f" = {args.num_layers_per_virtual_pipeline_stage}", + ) + if len(expert_scores) != num_layers_per_pipeline_stage: + raise ValueError( + f"Expected {num_layers_per_pipeline_stage} expert_scores " + f"but found {len(tokens_per_expert)}.\nnum_layers = " + f"{args.num_layers}\npipeline_model_parallel_size = " + f"{args.pipeline_model_parallel_size}\n" + "num_layers_per_virtual_pipeline_stage" + f" = {args.num_layers_per_virtual_pipeline_stage}", + ) + + # Verify the shape of the tokens_per_expert and expert_scores tensors. + assert all( + (x.ndim == 1 and x.numel() == args.moe_num_experts for x in tokens_per_expert) + ) + + tokens = expert_scores[0].shape[0] + assert all( + ( + ( + x.ndim == 2 + and x.shape[1] == args.moe_num_experts + and x.shape[0] == tokens + ) + for x in expert_scores + ) + ) + + # Concatenate the contributions of each layer and convert to + # the correct types and formats for the dot product. + expert_scores = torch.cat(expert_scores, dim=1) + if args.moe_lbl_in_fp32: + expert_scores = expert_scores.float() + if tokens != 0: + expert_scores = expert_scores.mean(dim=0) + else: + expert_scores = expert_scores.sum(dim=0) + tokens_per_expert = torch.cat(tokens_per_expert).to(expert_scores.dtype) + + expected_values = num_layers_per_pipeline_stage * args.moe_num_experts + assert tokens_per_expert.numel() == expected_values + assert expert_scores.numel() == expected_values + + # Calculate the total scale across all factors. + # + # loss_weight * num_experts / (num_layers * tokens * top_k) + scale_numerator = args.moe_num_experts * args.moe_loss_weight + scale_denominator = args.num_layers * tokens * args.moe_top_k + scale = scale_numerator / scale_denominator + return scale * torch.dot(tokens_per_expert, expert_scores) + + +## END Load Balancing Loss + + +# Calculate the expert capacity based on tokens, top_k, number of experts, +# expert parallel group, capacity factor, and whether expert model parallelism is used. 
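# For illustration (assumed example values): with a single expert-parallel rank,
# sl * bs = 4096 tokens, top_k = 4, num_experts = 128 and moe_capacity_factor = 1.0,
# the formula below works out to
#   tokens_per_expert = 4 * 4096 * 1 / 128 = 128.0
# so the returned capacity is int(1.0 * 128.0) = 128 slots per expert, e.g.
#   expert_capacity(tokens=4096, top_k=4, num_experts=128,
#                   expert_parallel_group=None, moe_capacity_factor=1.0,
#                   moe_expert_model_parallelism=False)  # -> 128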
+def expert_capacity( + tokens: int, + top_k: int, + num_experts: int, + expert_parallel_group: int, + moe_capacity_factor: float, + moe_expert_model_parallelism: bool, +) -> int: + world_size = ( + dist.get_world_size(expert_parallel_group) + if moe_expert_model_parallelism + else 1 + ) + + tokens_per_expert = top_k * tokens * world_size / num_experts + return int(moe_capacity_factor * tokens_per_expert) + + +def load_balancing_loss( + tokens_per_expert: torch.Tensor, + expert_scores: torch.Tensor, + top_k: int, + num_experts: int, +): + assert len(expert_scores.size()) == 2 + tokens, num_experts = expert_scores.size() + assert num_experts == num_experts + assert len(tokens_per_expert.size()) == 1 + (num_experts,) = tokens_per_expert.size() + assert num_experts == num_experts + scale = num_experts / (tokens * top_k) + return scale * torch.dot( + tokens_per_expert.to(expert_scores.dtype), + expert_scores.mean(dim=0), + ) + + +def indices_and_bins( + top_expert: torch.Tensor, + sort_end_bit: int, + num_experts: int, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + top_expert = top_expert.int() + + # Ensure contiguous memory layout + top_expert = top_expert.contiguous() + + # Ensure CUB knows which device to use + with torch.cuda.device(top_expert.device): + output = ops.sort(top_expert, sort_end_bit) + bin_ids, indices = output + tokens_per_expert = ops.histogram(top_expert, num_experts) + bins = ops.inclusive_cumsum(tokens_per_expert, 0) + + bins = bins.view(1) if not len(bins.size()) else bins + return indices, bin_ids, bins, tokens_per_expert + + +def expert_capacity_fn( + tokens: int, + top_k: int, + num_experts: int, + expert_parallel_group: torch.distributed.ProcessGroup, + moe_capacity_factor: float = 1.0, + moe_expert_model_parallelism: bool = False, +) -> int: + world_size = ( + dist.get_world_size(expert_parallel_group) + if moe_expert_model_parallelism + else 1 + ) + tokens_per_expert = top_k * tokens * world_size / num_experts + return int(moe_capacity_factor * tokens_per_expert) + + +def permute_and_compute( + x, + tokens_per_expert, + indices, + bin_ids, + expert_weights, + bins, + expert_capacity, + top_k, + w1, + w2, + w1_bias, + w2_bias, + gradient_scale, + alpha, +): + """Permute tokens and compute expert outputs.""" + # Route tokens to experts + x = x.view(-1, x.shape[-1]) + + # Ensure CUB knows which device to use + with torch.cuda.device(x.device): + x = ops.binned_gather(x, indices, bins, expert_capacity, top_k) + + # Expert computation + x = mlp_forward(x, w1, w2, w1_bias, w2_bias, gradient_scale, alpha) + + # Ensure CUB knows which device to use + with torch.cuda.device(x.device): + # Route tokens back + out = ops.binned_scatter(x, indices, expert_weights, bins, top_k) + return out + + +def forward_once( + x: torch.Tensor, + expert_weights: torch.Tensor, + top_experts: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + w1_bias: torch.Tensor, + w2_bias: torch.Tensor, + gradient_scale: Optional[float] = None, + alpha: float = 1.702, + sort_end_bit: int = 0, + top_k: int = 4, + num_experts: int = 128, + expert_parallel_group: int = None, + moe_capacity_factor: float = 1.0, + moe_expert_model_parallelism: bool = False, +): + # x: [sl, bs, hs] + # expert_weights: [sl * bs, top-k] + # top_experts: [sl * bs, top-k] + expert_weights = expert_weights.flatten() + top_experts = top_experts.flatten() + + with torch.no_grad(): + indices, bin_ids, bins, tokens_per_expert = indices_and_bins( + top_experts, sort_end_bit, num_experts + ) + + # Calculate 
expert capacity + sl, bs, _ = x.size() + + expert_capacity = expert_capacity_fn( + sl * bs, + top_k, + num_experts, + expert_parallel_group, + moe_capacity_factor, + moe_expert_model_parallelism, + ) + + if expert_capacity == 0: + expert_capacity = torch.max(tokens_per_expert).item() + + x = permute_and_compute( + x, + tokens_per_expert, + indices, + bin_ids, + expert_weights, + bins, + expert_capacity, + top_k, + w1, + w2, + w1_bias, + w2_bias, + gradient_scale, + alpha, + ) + return x, tokens_per_expert + + +# TODO: replace with functional logic once aligned with ref +def parallel_forward_once( + x: torch.Tensor, + expert_weights: torch.Tensor, + top_experts: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + w1_bias: torch.Tensor, + w2_bias: torch.Tensor, + gradient_scale: Optional[float] = None, + alpha: float = 1.702, + sort_end_bit: int = 0, + top_k: int = 4, + num_experts: int = 128, + expert_parallel_group: torch.distributed.ProcessGroup = None, + moe_capacity_factor: float = 1.0, + moe_expert_model_parallelism: bool = True, + hidden_size: int = 1152, +): + pass + + +class MyReplacementLayer(torch.nn.Module): + # def __init__(self): + # super().__init__() + + def forward( + # self, + x: torch.Tensor, + router_weight: torch.Tensor, + moe_top_k: int, + moe_num_experts: int, + moe_jitter_eps: float = None, + moe_normalize_expert_weights: int = None, + uniform_expert_assignment: bool = False, + training: bool = False, + # + w1: torch.Tensor = None, + w2: torch.Tensor = None, + w1_bias: torch.Tensor = None, + w2_bias: torch.Tensor = None, + gradient_scale: Optional[float] = None, + alpha: float = 1.702, + sort_end_bit: int = 0, + expert_parallel_group: torch.distributed.ProcessGroup = None, + moe_capacity_factor: float = 1.0, + moe_expert_model_parallelism: bool = False, + forward_fn: Any = None, + hidden_size: int = None, # Required for parallel forward + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + + # Route tokens to experts + logits, expert_weights, expert_indices = route_tokens( + x, + router_weight, + moe_top_k, + moe_num_experts, + moe_jitter_eps, + moe_normalize_expert_weights, + uniform_expert_assignment, + training, + ) + + # Create router scores for output + router_scores = ( + torch.zeros_like(logits) + .scatter_(1, expert_indices, expert_weights) + .transpose(0, 1) + ) + + in_shape = x.size() + + # Prepare forward function arguments + forward_args = { + "x": x, + "expert_weights": expert_weights, + "top_experts": expert_indices, + "w1": w1, + "w2": w2, + "w1_bias": w1_bias, + "w2_bias": w2_bias, + "gradient_scale": gradient_scale, + "alpha": alpha, + "sort_end_bit": sort_end_bit, + "top_k": moe_top_k, + "num_experts": moe_num_experts, + "expert_parallel_group": expert_parallel_group, + "moe_capacity_factor": moe_capacity_factor, + "moe_expert_model_parallelism": moe_expert_model_parallelism, + } + + # Add hidden_size for parallel forward + if moe_expert_model_parallelism and hidden_size is not None: + forward_args["hidden_size"] = hidden_size + elif moe_expert_model_parallelism and hidden_size is None: + # Infer hidden_size from input shape + forward_args["hidden_size"] = x.shape[-1] + + # Compute expert outputs + x, tokens_per_expert = forward_fn(**forward_args) + + # Save load balancing loss if needed + moe_loss_weight = 0.0 # Can be made configurable + if training and moe_loss_weight > 0: + save_load_balancing_loss((tokens_per_expert, logits)) + + # Restore original shape + x = x.view(in_shape) + + return x, expert_weights, router_scores + + + +class 
MegaBlocksMoeMLP(torch.nn.Module): + + def forward( + self, + x: torch.Tensor, + ) -> torch.Tensor: + router_weight = self.router.weight + moe_top_k = 4 + moe_num_experts = 128 + w1 = self.experts.gate_up_proj.data + w2 = self.experts.down_proj.data + w1_bias = self.experts.gate_up_proj_bias.data + w2_bias = self.experts.down_proj_bias.data + expert_parallel_group = None + + sort_end_bit = max(int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1) + hidden_size = self.experts.hidden_size + + output, expert_weights_out, router_scores = MyReplacementLayer.forward( + x=x, + router_weight=router_weight, + moe_top_k=moe_top_k, + moe_num_experts=moe_num_experts, + moe_jitter_eps=None, + moe_normalize_expert_weights=None, + uniform_expert_assignment=False, + training=False, + w1=w1, + w2=w2, + w1_bias=w1_bias, + w2_bias=w2_bias, + gradient_scale=None, + alpha=1.702, + sort_end_bit=sort_end_bit, + expert_parallel_group=expert_parallel_group, + moe_capacity_factor=1.0, + moe_expert_model_parallelism=False, + forward_fn=forward_once, + hidden_size=hidden_size, + ) + return output, expert_weights_out \ No newline at end of file diff --git a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/all_to_all_benchmark.py b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/all_to_all_benchmark.py index 43d267dbe2570e4a1f59fd561398668cc2bc0920..4c939818edca3345f6344bbc7cef07ffe3cd0181 100644 --- a/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/all_to_all_benchmark.py +++ b/build/torch27-cxx11-cu126-x86_64-linux/megablocks/ops/all_to_all_benchmark.py @@ -8,7 +8,7 @@ import torch.distributed as dist # from megablocks.layers.all_to_all import all_to_all from .. import benchmark_util -from ..layers.all_to_all import all_to_all +from .._layers.all_to_all import all_to_all _ALL_TO_ALL_BENCHMARK = ( (8, 1024), diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/__init__.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/__init__.py index af8d40aa52c706d2aae368cc11ade68cf13f6a47..38075732c6d8fa0e1e6ef493145e1aca3851ae6b 100644 --- a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/__init__.py +++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/__init__.py @@ -9,11 +9,13 @@ from .grouped_gemm import backend as gg_backend from .grouped_gemm import ops as gg_ops -from .layers.arguments import Arguments -from .layers.dmoe import ParallelDroplessMLP, dMoE -from .layers.glu import SparseGLU -from .layers.mlp import MLP, SparseMLP -from .layers.moe import MoE, ParallelMLP, get_load_balancing_loss +from ._layers.arguments import Arguments +from ._layers.dmoe import ParallelDroplessMLP, dMoE +from ._layers.glu import SparseGLU +from ._layers.mlp import MLP, SparseMLP +from ._layers.moe import MoE, ParallelMLP, get_load_balancing_loss + +from . 
import layers # This section contains the direct kernel exports (not inlcuded in the original code) def exclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor: @@ -176,6 +178,7 @@ def argsort(x: torch.Tensor, end_bit: int = 32) -> tuple[torch.Tensor, torch.Ten # Export public API __all__ = [ + "MyReplacementLayer", # Direct kernel exports "exclusive_cumsum", "inclusive_cumsum", diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/layers/__init__.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_layers/__init__.py similarity index 100% rename from build/torch27-cxx11-cu128-x86_64-linux/megablocks/layers/__init__.py rename to build/torch27-cxx11-cu128-x86_64-linux/megablocks/_layers/__init__.py diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/layers/activation_fn.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_layers/activation_fn.py similarity index 100% rename from build/torch27-cxx11-cu128-x86_64-linux/megablocks/layers/activation_fn.py rename to build/torch27-cxx11-cu128-x86_64-linux/megablocks/_layers/activation_fn.py diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/layers/all_to_all.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_layers/all_to_all.py similarity index 100% rename from build/torch27-cxx11-cu128-x86_64-linux/megablocks/layers/all_to_all.py rename to build/torch27-cxx11-cu128-x86_64-linux/megablocks/_layers/all_to_all.py diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/layers/arguments.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_layers/arguments.py similarity index 100% rename from build/torch27-cxx11-cu128-x86_64-linux/megablocks/layers/arguments.py rename to build/torch27-cxx11-cu128-x86_64-linux/megablocks/_layers/arguments.py diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/layers/common.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_layers/common.py similarity index 100% rename from build/torch27-cxx11-cu128-x86_64-linux/megablocks/layers/common.py rename to build/torch27-cxx11-cu128-x86_64-linux/megablocks/_layers/common.py diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/layers/dmlp_registry.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_layers/dmlp_registry.py similarity index 100% rename from build/torch27-cxx11-cu128-x86_64-linux/megablocks/layers/dmlp_registry.py rename to build/torch27-cxx11-cu128-x86_64-linux/megablocks/_layers/dmlp_registry.py diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/layers/dmoe.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_layers/dmoe.py similarity index 100% rename from build/torch27-cxx11-cu128-x86_64-linux/megablocks/layers/dmoe.py rename to build/torch27-cxx11-cu128-x86_64-linux/megablocks/_layers/dmoe.py diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/layers/gelu.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_layers/gelu.py similarity index 100% rename from build/torch27-cxx11-cu128-x86_64-linux/megablocks/layers/gelu.py rename to build/torch27-cxx11-cu128-x86_64-linux/megablocks/_layers/gelu.py diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/layers/glu.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_layers/glu.py similarity index 100% rename from build/torch27-cxx11-cu128-x86_64-linux/megablocks/layers/glu.py rename to build/torch27-cxx11-cu128-x86_64-linux/megablocks/_layers/glu.py diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/layers/memory_test.py 
b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_layers/memory_test.py similarity index 100% rename from build/torch27-cxx11-cu128-x86_64-linux/megablocks/layers/memory_test.py rename to build/torch27-cxx11-cu128-x86_64-linux/megablocks/_layers/memory_test.py diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/layers/mlp.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_layers/mlp.py similarity index 100% rename from build/torch27-cxx11-cu128-x86_64-linux/megablocks/layers/mlp.py rename to build/torch27-cxx11-cu128-x86_64-linux/megablocks/_layers/mlp.py diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/layers/moe.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_layers/moe.py similarity index 100% rename from build/torch27-cxx11-cu128-x86_64-linux/megablocks/layers/moe.py rename to build/torch27-cxx11-cu128-x86_64-linux/megablocks/_layers/moe.py diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/layers/mpu.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_layers/mpu.py similarity index 100% rename from build/torch27-cxx11-cu128-x86_64-linux/megablocks/layers/mpu.py rename to build/torch27-cxx11-cu128-x86_64-linux/megablocks/_layers/mpu.py diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/layers/router.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_layers/router.py similarity index 100% rename from build/torch27-cxx11-cu128-x86_64-linux/megablocks/layers/router.py rename to build/torch27-cxx11-cu128-x86_64-linux/megablocks/_layers/router.py diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/layers/sharedexpert_registry.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_layers/sharedexpert_registry.py similarity index 100% rename from build/torch27-cxx11-cu128-x86_64-linux/megablocks/layers/sharedexpert_registry.py rename to build/torch27-cxx11-cu128-x86_64-linux/megablocks/_layers/sharedexpert_registry.py diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_megablocks_6756875_dirty.abi3.so b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_megablocks_6756875_dirty.abi3.so deleted file mode 100755 index 50000f9b7cb45cf014efe881e2c53c60f3497663..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_megablocks_6756875_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:ae05d16e382a38c3a908f350e2347206319b5a0a9f9cd2447464b6409dc10fce -size 17892656 diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_megablocks_dabb815.abi3.so b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_megablocks_dabb815.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..a723817e44e7aec8375f871ad1bf14ee404c7d1d --- /dev/null +++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_megablocks_dabb815.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:bac66fb2798cffcb10563bb1ad157a5e1d231dcf1fc615825b6c8e6f6b297d20 +size 17892624 diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_ops.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_ops.py index c9a9faa3cee07c8b7016a71704ea26e47c460d57..3cc5ae109015332d831f2357627d00ef67faa25d 100644 --- a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_ops.py +++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _megablocks_6756875_dirty -ops = torch.ops._megablocks_6756875_dirty +from . 
import _megablocks_dabb815 +ops = torch.ops._megablocks_dabb815 def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. """ - return f"_megablocks_6756875_dirty::{op_name}" \ No newline at end of file + return f"_megablocks_dabb815::{op_name}" \ No newline at end of file diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/layers.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/layers.py new file mode 100644 index 0000000000000000000000000000000000000000..d3737eddff33082db4c8b31b2c68549b13a911b4 --- /dev/null +++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/layers.py @@ -0,0 +1,566 @@ +import torch +import torch.distributed as dist + +from typing import Optional, Any + +from . import _layers +from . import ops + + +# Set the expert model parallel attributes on a tensor +def set_expert_model_parallel_attributes( + tensor: torch.Tensor, + is_parallel: bool, +): + assert not hasattr(tensor, "expert_model_parallel") + setattr(tensor, "expert_model_parallel", is_parallel) + + +# Get the expert model parallel attributes from a tensor +def expert_sharding_degree( + world_size: int, + moe_num_experts: int, +) -> int: + esd = min(world_size, moe_num_experts) + if (moe_num_experts % esd) != 0: + raise ValueError(f"Cannot shard {moe_num_experts} experts {esd} ways.") + return esd + + +# Calculate the hidden sharding degree based on world size and expert sharding degree +def hidden_sharding_degree( + world_size: int, + moe_num_experts: int, + ffn_hidden_size: int, +) -> int: + esd = expert_sharding_degree(world_size, moe_num_experts) + hsd = world_size // esd + if (ffn_hidden_size % hsd) != 0: + raise ValueError(f"Cannot shard {ffn_hidden_size} features {hsd} ways.") + if (esd * hsd) != world_size: + raise ValueError( + f"Invalid sharding. expert_sharding_degree ({esd}) * hidden_sharding_degree ({hsd}) != world_size ({world_size})." 
+ ) + return hsd + + +# Calculate the number of experts per rank based on world size and expert sharding degree +def experts_per_rank( + moe_num_experts: int, + world_size: int, +) -> int: + return moe_num_experts // expert_sharding_degree(world_size, moe_num_experts) + + +# Calculate the number of features per rank based on ffn hidden size and hidden sharding degree +def features_per_rank( + ffn_hidden_size: int, world_size: int, moe_num_experts: int +) -> int: + return ffn_hidden_size // hidden_sharding_degree( + world_size, moe_num_experts, ffn_hidden_size + ) + + +# Apply jitter to the input tensor +def apply_jitter(x: torch.Tensor, moe_jitter_eps: float) -> torch.Tensor: + low = 1.0 - moe_jitter_eps + high = 1.0 + moe_jitter_eps + noise = torch.rand(x.size(), dtype=x.dtype, device=x.device) + return x * (low + noise * (high - low)) + + +# Compute the top-k scores from the logits +def compute_top_k(scores: torch.Tensor, moe_top_k: int): + if moe_top_k == 1: + return scores.max(dim=-1, keepdim=True) + return torch.topk(scores, moe_top_k, dim=-1) + + +# Route tokens to experts and compute expert weights and indices +def route_tokens( + x: torch.Tensor, + router_weight: torch.Tensor, + moe_top_k: int, + moe_num_experts: int, + moe_jitter_eps: float = None, + moe_normalize_expert_weights: int = None, + uniform_expert_assignment: bool = False, + training: bool = False, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + if training and moe_jitter_eps is not None: + x = apply_jitter(x, moe_jitter_eps) + + x_flat = x.view(-1, x.shape[-1]) + logits = torch.nn.functional.linear(x_flat, router_weight) + expert_weights, expert_indices = compute_top_k(logits, moe_top_k) + expert_weights = expert_weights.softmax(dim=-1) + if moe_normalize_expert_weights is not None: + expert_weights = expert_weights / torch.norm( + expert_weights, + p=moe_normalize_expert_weights, + dim=-1, + keepdim=True, + ) + if uniform_expert_assignment: + expert_indices = _layers.router._uniform_expert_assignment( + expert_indices, + moe_num_experts, + ) + + return logits, expert_weights, expert_indices + + +# Scale the gradient of the weights +def scale_grad( + w: torch.Tensor, + gradient_scale: Optional[float] = None, +) -> torch.Tensor: + if gradient_scale is None: + return w + return _layers.mlp.scale_gradient(w, gradient_scale) + + +# Forward pass for the MLP layer +def mlp_forward(x, w1, w2, w1_bias, w2_bias, gradient_scale=None, alpha: float = 1.702): + # Scale weights + w1 = scale_grad(w1, gradient_scale) + w2 = scale_grad(w2, gradient_scale) + w1_bias = scale_grad(w1_bias, gradient_scale) + w2_bias = scale_grad(w2_bias, gradient_scale) + + # Resolve dtensors + w1 = _layers.mlp.resolve_dtensor(w1) + w2 = _layers.mlp.resolve_dtensor(w2) + w1_bias = _layers.mlp.resolve_dtensor(w1_bias) + w2_bias = _layers.mlp.resolve_dtensor(w2_bias) + + # Forward pass + gate_up = torch.bmm(x, w1) + w1_bias[..., None, :] + gate, up = gate_up.chunk(2, dim=-1) + + glu = gate * torch.sigmoid(gate * alpha) + x = (up + 1) * glu + + return torch.bmm(x, w2) + w2_bias[..., None, :] + + +## START: Load Balancing Loss (unused at the moment) + +# Global variable to store load balancing loss +_LOAD_BALANCING_LOSS = [] + + +def save_load_balancing_loss(loss): + global _LOAD_BALANCING_LOSS + _LOAD_BALANCING_LOSS.append(loss) + + +def get_load_balancing_loss(): + global _LOAD_BALANCING_LOSS + return _LOAD_BALANCING_LOSS + + +def clear_load_balancing_loss(): + global _LOAD_BALANCING_LOSS + _LOAD_BALANCING_LOSS.clear() + + +def 
batched_load_balancing_loss(args): + if args.moe_loss_weight == 0: + return 0.0 + + tokens_per_expert, expert_scores = zip(*get_load_balancing_loss()) + num_layers_per_pipeline_stage = args.num_layers // args.pipeline_model_parallel_size + if args.num_layers_per_virtual_pipeline_stage is not None: + num_layers_per_pipeline_stage = args.num_layers_per_virtual_pipeline_stage + + if len(tokens_per_expert) != num_layers_per_pipeline_stage: + raise ValueError( + f"Expected {num_layers_per_pipeline_stage} token_per_experts " + f"but found {len(tokens_per_expert)}.\nnum_layers = " + f"{args.num_layers}\npipeline_model_parallel_size = " + f"{args.pipeline_model_parallel_size}\n" + "num_layers_per_virtual_pipeline_stage" + f" = {args.num_layers_per_virtual_pipeline_stage}", + ) + if len(expert_scores) != num_layers_per_pipeline_stage: + raise ValueError( + f"Expected {num_layers_per_pipeline_stage} expert_scores " + f"but found {len(tokens_per_expert)}.\nnum_layers = " + f"{args.num_layers}\npipeline_model_parallel_size = " + f"{args.pipeline_model_parallel_size}\n" + "num_layers_per_virtual_pipeline_stage" + f" = {args.num_layers_per_virtual_pipeline_stage}", + ) + + # Verify the shape of the tokens_per_expert and expert_scores tensors. + assert all( + (x.ndim == 1 and x.numel() == args.moe_num_experts for x in tokens_per_expert) + ) + + tokens = expert_scores[0].shape[0] + assert all( + ( + ( + x.ndim == 2 + and x.shape[1] == args.moe_num_experts + and x.shape[0] == tokens + ) + for x in expert_scores + ) + ) + + # Concatenate the contributions of each layer and convert to + # the correct types and formats for the dot product. + expert_scores = torch.cat(expert_scores, dim=1) + if args.moe_lbl_in_fp32: + expert_scores = expert_scores.float() + if tokens != 0: + expert_scores = expert_scores.mean(dim=0) + else: + expert_scores = expert_scores.sum(dim=0) + tokens_per_expert = torch.cat(tokens_per_expert).to(expert_scores.dtype) + + expected_values = num_layers_per_pipeline_stage * args.moe_num_experts + assert tokens_per_expert.numel() == expected_values + assert expert_scores.numel() == expected_values + + # Calculate the total scale across all factors. + # + # loss_weight * num_experts / (num_layers * tokens * top_k) + scale_numerator = args.moe_num_experts * args.moe_loss_weight + scale_denominator = args.num_layers * tokens * args.moe_top_k + scale = scale_numerator / scale_denominator + return scale * torch.dot(tokens_per_expert, expert_scores) + + +## END Load Balancing Loss + + +# Calculate the expert capacity based on tokens, top_k, number of experts, +# expert parallel group, capacity factor, and whether expert model parallelism is used. 
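# For illustration (assumed example values): under expert model parallelism the
# capacity scales with the expert-parallel world size. With 8 ranks, 1024 local
# tokens, top_k = 4, num_experts = 128 and moe_capacity_factor = 0.5:
#   tokens_per_expert = 4 * 1024 * 8 / 128 = 256.0
#   capacity          = int(0.5 * 256.0)   = 128
# If the rounding ever yields a capacity of 0, forward_once further below falls
# back to max(tokens_per_expert), i.e. enough slots for the busiest expert.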
+def expert_capacity( + tokens: int, + top_k: int, + num_experts: int, + expert_parallel_group: int, + moe_capacity_factor: float, + moe_expert_model_parallelism: bool, +) -> int: + world_size = ( + dist.get_world_size(expert_parallel_group) + if moe_expert_model_parallelism + else 1 + ) + + tokens_per_expert = top_k * tokens * world_size / num_experts + return int(moe_capacity_factor * tokens_per_expert) + + +def load_balancing_loss( + tokens_per_expert: torch.Tensor, + expert_scores: torch.Tensor, + top_k: int, + num_experts: int, +): + assert len(expert_scores.size()) == 2 + tokens, num_experts = expert_scores.size() + assert num_experts == num_experts + assert len(tokens_per_expert.size()) == 1 + (num_experts,) = tokens_per_expert.size() + assert num_experts == num_experts + scale = num_experts / (tokens * top_k) + return scale * torch.dot( + tokens_per_expert.to(expert_scores.dtype), + expert_scores.mean(dim=0), + ) + + +def indices_and_bins( + top_expert: torch.Tensor, + sort_end_bit: int, + num_experts: int, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + top_expert = top_expert.int() + + # Ensure contiguous memory layout + top_expert = top_expert.contiguous() + + # Ensure CUB knows which device to use + with torch.cuda.device(top_expert.device): + output = ops.sort(top_expert, sort_end_bit) + bin_ids, indices = output + tokens_per_expert = ops.histogram(top_expert, num_experts) + bins = ops.inclusive_cumsum(tokens_per_expert, 0) + + bins = bins.view(1) if not len(bins.size()) else bins + return indices, bin_ids, bins, tokens_per_expert + + +def expert_capacity_fn( + tokens: int, + top_k: int, + num_experts: int, + expert_parallel_group: torch.distributed.ProcessGroup, + moe_capacity_factor: float = 1.0, + moe_expert_model_parallelism: bool = False, +) -> int: + world_size = ( + dist.get_world_size(expert_parallel_group) + if moe_expert_model_parallelism + else 1 + ) + tokens_per_expert = top_k * tokens * world_size / num_experts + return int(moe_capacity_factor * tokens_per_expert) + + +def permute_and_compute( + x, + tokens_per_expert, + indices, + bin_ids, + expert_weights, + bins, + expert_capacity, + top_k, + w1, + w2, + w1_bias, + w2_bias, + gradient_scale, + alpha, +): + """Permute tokens and compute expert outputs.""" + # Route tokens to experts + x = x.view(-1, x.shape[-1]) + + # Ensure CUB knows which device to use + with torch.cuda.device(x.device): + x = ops.binned_gather(x, indices, bins, expert_capacity, top_k) + + # Expert computation + x = mlp_forward(x, w1, w2, w1_bias, w2_bias, gradient_scale, alpha) + + # Ensure CUB knows which device to use + with torch.cuda.device(x.device): + # Route tokens back + out = ops.binned_scatter(x, indices, expert_weights, bins, top_k) + return out + + +def forward_once( + x: torch.Tensor, + expert_weights: torch.Tensor, + top_experts: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + w1_bias: torch.Tensor, + w2_bias: torch.Tensor, + gradient_scale: Optional[float] = None, + alpha: float = 1.702, + sort_end_bit: int = 0, + top_k: int = 4, + num_experts: int = 128, + expert_parallel_group: int = None, + moe_capacity_factor: float = 1.0, + moe_expert_model_parallelism: bool = False, +): + # x: [sl, bs, hs] + # expert_weights: [sl * bs, top-k] + # top_experts: [sl * bs, top-k] + expert_weights = expert_weights.flatten() + top_experts = top_experts.flatten() + + with torch.no_grad(): + indices, bin_ids, bins, tokens_per_expert = indices_and_bins( + top_experts, sort_end_bit, num_experts + ) + + # Calculate 
expert capacity + sl, bs, _ = x.size() + + expert_capacity = expert_capacity_fn( + sl * bs, + top_k, + num_experts, + expert_parallel_group, + moe_capacity_factor, + moe_expert_model_parallelism, + ) + + if expert_capacity == 0: + expert_capacity = torch.max(tokens_per_expert).item() + + x = permute_and_compute( + x, + tokens_per_expert, + indices, + bin_ids, + expert_weights, + bins, + expert_capacity, + top_k, + w1, + w2, + w1_bias, + w2_bias, + gradient_scale, + alpha, + ) + return x, tokens_per_expert + + +# TODO: replace with functional logic once aligned with ref +def parallel_forward_once( + x: torch.Tensor, + expert_weights: torch.Tensor, + top_experts: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + w1_bias: torch.Tensor, + w2_bias: torch.Tensor, + gradient_scale: Optional[float] = None, + alpha: float = 1.702, + sort_end_bit: int = 0, + top_k: int = 4, + num_experts: int = 128, + expert_parallel_group: torch.distributed.ProcessGroup = None, + moe_capacity_factor: float = 1.0, + moe_expert_model_parallelism: bool = True, + hidden_size: int = 1152, +): + pass + + +class MyReplacementLayer(torch.nn.Module): + # def __init__(self): + # super().__init__() + + def forward( + # self, + x: torch.Tensor, + router_weight: torch.Tensor, + moe_top_k: int, + moe_num_experts: int, + moe_jitter_eps: float = None, + moe_normalize_expert_weights: int = None, + uniform_expert_assignment: bool = False, + training: bool = False, + # + w1: torch.Tensor = None, + w2: torch.Tensor = None, + w1_bias: torch.Tensor = None, + w2_bias: torch.Tensor = None, + gradient_scale: Optional[float] = None, + alpha: float = 1.702, + sort_end_bit: int = 0, + expert_parallel_group: torch.distributed.ProcessGroup = None, + moe_capacity_factor: float = 1.0, + moe_expert_model_parallelism: bool = False, + forward_fn: Any = None, + hidden_size: int = None, # Required for parallel forward + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + + # Route tokens to experts + logits, expert_weights, expert_indices = route_tokens( + x, + router_weight, + moe_top_k, + moe_num_experts, + moe_jitter_eps, + moe_normalize_expert_weights, + uniform_expert_assignment, + training, + ) + + # Create router scores for output + router_scores = ( + torch.zeros_like(logits) + .scatter_(1, expert_indices, expert_weights) + .transpose(0, 1) + ) + + in_shape = x.size() + + # Prepare forward function arguments + forward_args = { + "x": x, + "expert_weights": expert_weights, + "top_experts": expert_indices, + "w1": w1, + "w2": w2, + "w1_bias": w1_bias, + "w2_bias": w2_bias, + "gradient_scale": gradient_scale, + "alpha": alpha, + "sort_end_bit": sort_end_bit, + "top_k": moe_top_k, + "num_experts": moe_num_experts, + "expert_parallel_group": expert_parallel_group, + "moe_capacity_factor": moe_capacity_factor, + "moe_expert_model_parallelism": moe_expert_model_parallelism, + } + + # Add hidden_size for parallel forward + if moe_expert_model_parallelism and hidden_size is not None: + forward_args["hidden_size"] = hidden_size + elif moe_expert_model_parallelism and hidden_size is None: + # Infer hidden_size from input shape + forward_args["hidden_size"] = x.shape[-1] + + # Compute expert outputs + x, tokens_per_expert = forward_fn(**forward_args) + + # Save load balancing loss if needed + moe_loss_weight = 0.0 # Can be made configurable + if training and moe_loss_weight > 0: + save_load_balancing_loss((tokens_per_expert, logits)) + + # Restore original shape + x = x.view(in_shape) + + return x, expert_weights, router_scores + + + +class 
MegaBlocksMoeMLP(torch.nn.Module): + + def forward( + self, + x: torch.Tensor, + ) -> torch.Tensor: + router_weight = self.router.weight + moe_top_k = 4 + moe_num_experts = 128 + w1 = self.experts.gate_up_proj.data + w2 = self.experts.down_proj.data + w1_bias = self.experts.gate_up_proj_bias.data + w2_bias = self.experts.down_proj_bias.data + expert_parallel_group = None + + sort_end_bit = max(int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1) + hidden_size = self.experts.hidden_size + + output, expert_weights_out, router_scores = MyReplacementLayer.forward( + x=x, + router_weight=router_weight, + moe_top_k=moe_top_k, + moe_num_experts=moe_num_experts, + moe_jitter_eps=None, + moe_normalize_expert_weights=None, + uniform_expert_assignment=False, + training=False, + w1=w1, + w2=w2, + w1_bias=w1_bias, + w2_bias=w2_bias, + gradient_scale=None, + alpha=1.702, + sort_end_bit=sort_end_bit, + expert_parallel_group=expert_parallel_group, + moe_capacity_factor=1.0, + moe_expert_model_parallelism=False, + forward_fn=forward_once, + hidden_size=hidden_size, + ) + return output, expert_weights_out \ No newline at end of file diff --git a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/all_to_all_benchmark.py b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/all_to_all_benchmark.py index 43d267dbe2570e4a1f59fd561398668cc2bc0920..4c939818edca3345f6344bbc7cef07ffe3cd0181 100644 --- a/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/all_to_all_benchmark.py +++ b/build/torch27-cxx11-cu128-x86_64-linux/megablocks/ops/all_to_all_benchmark.py @@ -8,7 +8,7 @@ import torch.distributed as dist # from megablocks.layers.all_to_all import all_to_all from .. import benchmark_util -from ..layers.all_to_all import all_to_all +from .._layers.all_to_all import all_to_all _ALL_TO_ALL_BENCHMARK = ( (8, 1024),