"""Tests for MegaBlocksMoeMLP: an import smoke test plus a two-process
distributed forward pass with expert parallelism."""

import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def test_megablocks_moe_mlp_import():
    # The layer should be importable from the installed megablocks package.
    from megablocks.layers import MegaBlocksMoeMLP

    assert MegaBlocksMoeMLP is not None, "MegaBlocksMoeMLP import failed."


def run_distributed_test(rank, world_size):
    from megablocks.layers import MegaBlocksMoeMLP

    # Minimal single-node rendezvous settings for the spawned workers.
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"
    os.environ["RANK"] = str(rank)
    os.environ["WORLD_SIZE"] = str(world_size)

    # gloo keeps the test runnable on CPU-only machines.
    dist.init_process_group(
        backend="gloo",
        rank=rank,
        world_size=world_size,
    )

    # All ranks participate in the expert-parallel group.
    expert_parallel_group = dist.new_group(list(range(dist.get_world_size())))

    model = MegaBlocksMoeMLP()
    model.expert_parallel_group = expert_parallel_group

    # Bare container for the per-rank expert weights the layer reads.
    class Experts:
        def __init__(self):
            self.gate_up_proj = None
            self.gate_up_proj_bias = None
            self.down_proj = None
            self.down_proj_bias = None
            self.hidden_size = None

    model.experts = Experts()

    num_experts = 128
    hidden_size = 1152
    intermediate_size = 3072

    ne, hs, isz = num_experts, hidden_size, intermediate_size

    # Experts are sharded evenly across the ranks.
    experts_per_rank = ne // world_size

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Constant router weights make expert selection deterministic.
    model.router = torch.nn.Linear(hs, ne).to(device)
    model.router.weight.data.fill_(1)

    e = model.experts
    e.gate_up_proj = torch.nn.Parameter(
        torch.ones(experts_per_rank, hs, isz, device=device)
    )
    e.gate_up_proj_bias = torch.nn.Parameter(
        torch.zeros(experts_per_rank, isz, device=device)
    )
    # down_proj input width is 1536 = intermediate_size // 2.
    e.down_proj = torch.nn.Parameter(
        torch.ones(experts_per_rank, 1536, hs, device=device)
    )
    e.down_proj_bias = torch.nn.Parameter(
        torch.zeros(experts_per_rank, hs, device=device)
    )
    e.hidden_size = hs
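
    # Per-rank parameter shapes set above:
    #   gate_up_proj      (experts_per_rank, hidden_size, intermediate_size)
    #   gate_up_proj_bias (experts_per_rank, intermediate_size)
    #   down_proj         (experts_per_rank, 1536, hidden_size)
    #   down_proj_bias    (experts_per_rank, hidden_size)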

    # Single-token input; the layer should preserve (batch, seq, hidden_size).
    x = torch.randn(1, 1, hs, device=device)
    output, expert_weights_out = model(x)

    assert output.shape == (1, 1, hs), f"Output shape mismatch on rank {rank}."

    print(f"Rank {rank}: Test passed! Output shape: {output.shape}")

    dist.destroy_process_group()


def test_megablocks_moe_mlp_functionality():
    world_size = 2

    # mp.spawn passes the worker rank as the first argument to run_distributed_test.
    mp.spawn(run_distributed_test, args=(world_size,), nprocs=world_size, join=True)

    print("Multi-process test completed successfully!")


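# Running this file directly executes both tests; e.g. (filename assumed):
#   python test_megablocks_moe_mlp.py
# Under pytest, the two test_* functions are collected automatically.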
if __name__ == "__main__":
    test_megablocks_moe_mlp_import()
    print("Import test passed!")

    test_megablocks_moe_mlp_functionality()