# megablocks/tests/parallel_layer_test.py
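"""Tests for the MegaBlocksMoeMLP expert-parallel layer.

Covers two cases: a plain import smoke test, and a two-process forward
pass over the gloo backend with the experts sharded across ranks.
"""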

import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def test_megablocks_moe_mlp_import():
    from megablocks.layers import MegaBlocksMoeMLP

    assert MegaBlocksMoeMLP is not None, "MegaBlocksMoeMLP import failed."
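

# Worker entry point for mp.spawn: torch.multiprocessing passes the rank as
# the first positional argument, followed by the values in ``args``.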
def run_distributed_test(rank, world_size):
    from megablocks.layers import MegaBlocksMoeMLP

    # Minimal single-node rendezvous configuration for the test processes.
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"
    os.environ["RANK"] = str(rank)
    os.environ["WORLD_SIZE"] = str(world_size)

    # gloo works on CPU-only machines, so the test does not require GPUs.
    dist.init_process_group(
        backend="gloo",
        rank=rank,
        world_size=world_size,
    )
    expert_parallel_group = dist.new_group(range(dist.get_world_size()))
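
    # Attach the group the layer should shard its experts over; here it
    # simply spans every rank in the world.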
    model = MegaBlocksMoeMLP()
    model.expert_parallel_group = expert_parallel_group

    # Lightweight container for the per-rank expert weights; the real
    # parameters are filled in below.
    class Experts:
        def __init__(self):
            self.gate_up_proj = None
            self.gate_up_proj_bias = None
            self.down_proj = None
            self.down_proj_bias = None
            self.hidden_size = None

    model.experts = Experts()

    num_experts = 128
    hidden_size = 1152
    intermediate_size = 3072
    ne, hs, isz = num_experts, hidden_size, intermediate_size
    # Experts are split evenly across the ranks of the expert-parallel group.
    experts_per_rank = ne // world_size

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # All-ones router weights keep the routing scores independent of random
    # weight initialization.
    model.router = torch.nn.Linear(hs, ne).to(device)
    model.router.weight.data.fill_(1)

    # Per-rank expert parameters. down_proj consumes isz // 2 (= 1536)
    # features, consistent with a gated MLP in which gate_up_proj's output
    # is chunked into gate and up halves before the activation.
    e = model.experts
    e.gate_up_proj = torch.nn.Parameter(
        torch.ones(experts_per_rank, hs, isz, device=device)
    )
    e.gate_up_proj_bias = torch.nn.Parameter(
        torch.zeros(experts_per_rank, isz, device=device)
    )
    e.down_proj = torch.nn.Parameter(
        torch.ones(experts_per_rank, isz // 2, hs, device=device)
    )
    e.down_proj_bias = torch.nn.Parameter(
        torch.zeros(experts_per_rank, hs, device=device)
    )
    e.hidden_size = hs

    # Single-token forward pass; the layer should preserve the hidden size.
    x = torch.randn(1, 1, hs).to(device)
    output, expert_weights_out = model(x)
    assert output.shape == (1, 1, hs), f"Output shape mismatch on rank {rank}."
    print(f"Rank {rank}: Test passed! Output shape: {output.shape}")

    dist.destroy_process_group()


def test_megablocks_moe_mlp_functionality():
    world_size = 2
    mp.spawn(run_distributed_test, args=(world_size,), nprocs=world_size, join=True)
    print("Multi-process test completed successfully!")


if __name__ == "__main__":
    test_megablocks_moe_mlp_import()
    print("Import test passed!")
    test_megablocks_moe_mlp_functionality()
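
# The file can be run directly (python parallel_layer_test.py) or collected by
# pytest, which picks up both test_-prefixed functions above.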