"""Smoke tests for megablocks' MegaBlocksMoeMLP: an import check plus a
two-process expert-parallel forward pass driven by torch.multiprocessing."""

import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def test_megablocks_moe_mlp_import():
    from megablocks.layers import MegaBlocksMoeMLP

    assert MegaBlocksMoeMLP is not None, "MegaBlocksMoeMLP import failed."


def run_distributed_test(rank, world_size):
    from megablocks.layers import MegaBlocksMoeMLP

    # Single-node rendezvous for the process group.
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"
    os.environ["RANK"] = str(rank)
    os.environ["WORLD_SIZE"] = str(world_size)

    dist.init_process_group(backend="gloo", rank=rank, world_size=world_size)

    # All ranks join a single expert-parallel group.
    expert_parallel_group = dist.new_group(range(dist.get_world_size()))

    model = MegaBlocksMoeMLP()
    model.expert_parallel_group = expert_parallel_group

    # Minimal container holding the per-expert weights the layer reads.
    class Experts:
        def __init__(self):
            self.gate_up_proj = None
            self.gate_up_proj_bias = None
            self.down_proj = None
            self.down_proj_bias = None
            self.hidden_size = None

    model.experts = Experts()

    num_experts = 128
    hidden_size = 1152
    intermediate_size = 3072
    ne, hs, isz = num_experts, hidden_size, intermediate_size
    experts_per_rank = ne // world_size
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Constant router weights give every expert the same logit, so routing
    # is deterministic across ranks.
    model.router = torch.nn.Linear(hs, ne).to(device)
    model.router.weight.data.fill_(1)

    e = model.experts
    e.gate_up_proj = torch.nn.Parameter(
        torch.ones(experts_per_rank, hs, isz, device=device)
    )
    e.gate_up_proj_bias = torch.nn.Parameter(
        torch.zeros(experts_per_rank, isz, device=device)
    )
    # down_proj takes isz // 2 (= 1536) input features per expert, matching
    # the half-width output of a gated (gate/up) activation.
    e.down_proj = torch.nn.Parameter(
        torch.ones(experts_per_rank, isz // 2, hs, device=device)
    )
    e.down_proj_bias = torch.nn.Parameter(
        torch.zeros(experts_per_rank, hs, device=device)
    )
    e.hidden_size = hs

    # Forward a single token and check that the hidden size is preserved.
    x = torch.randn(1, 1, hs).to(device)
    output, expert_weights_out = model(x)
    assert output.shape == (1, 1, hs), f"Output shape mismatch on rank {rank}."
    print(f"Rank {rank}: Test passed! Output shape: {output.shape}")

    dist.destroy_process_group()


def test_megablocks_moe_mlp_functionality():
    # mp.spawn passes the rank as the first argument to run_distributed_test.
    world_size = 2
    mp.spawn(run_distributed_test, args=(world_size,), nprocs=world_size, join=True)
    print("Multi-process test completed successfully!")


if __name__ == "__main__":
    test_megablocks_moe_mlp_import()
    print("Import test passed!")
    test_megablocks_moe_mlp_functionality()
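

# Usage sketch (assumptions: this file is saved as test_megablocks_moe_mlp.py,
# megablocks exposes MegaBlocksMoeMLP under megablocks.layers, and the host can
# spawn two processes for the gloo backend):
#
#   pytest test_megablocks_moe_mlp.py -s   # run via pytest, -s shows the prints
#   python test_megablocks_moe_mlp.py      # or run the __main__ block directly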