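"""Tests for MegaBlocksMoeMLP from megablocks.layers.

Covers an import smoke test and a two-process distributed forward pass,
spawned with torch.multiprocessing over the gloo backend, that checks the
layer's output shape.
"""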
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def test_megablocks_moe_mlp_import():
    """Check that MegaBlocksMoeMLP is importable from megablocks.layers."""
    from megablocks.layers import MegaBlocksMoeMLP

    assert MegaBlocksMoeMLP is not None, "MegaBlocksMoeMLP import failed."


def run_distributed_test(rank, world_size):
    """Run a forward pass through MegaBlocksMoeMLP inside one spawned worker."""
    from megablocks.layers import MegaBlocksMoeMLP

    # Rendezvous settings for torch.distributed; every worker targets the same
    # master address and port.
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"
    os.environ["RANK"] = str(rank)
    os.environ["WORLD_SIZE"] = str(world_size)

    dist.init_process_group(
        backend="gloo",
        rank=rank,
        world_size=world_size,
    )

    # All ranks join a single expert-parallel group.
    expert_parallel_group = dist.new_group(range(dist.get_world_size()))

    # Build the layer and attach the process group used for expert parallelism.
    model = MegaBlocksMoeMLP()
    model.expert_parallel_group = expert_parallel_group

    # Minimal stand-in for the experts container the layer reads its weights
    # from; the actual parameters are attached below.
    class Experts:
        def __init__(self):
            self.gate_up_proj = None
            self.gate_up_proj_bias = None
            self.down_proj = None
            self.down_proj_bias = None
            self.hidden_size = None

    model.experts = Experts()

    num_experts = 128
    hidden_size = 1152
    intermediate_size = 3072

    ne, hs, isz = num_experts, hidden_size, intermediate_size

    # Experts are split evenly across the expert-parallel ranks.
    experts_per_rank = ne // world_size

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # All-ones router weights give every expert the same logit, so routing is
    # deterministic across ranks.
    model.router = torch.nn.Linear(hs, ne).to(device)
    model.router.weight.data.fill_(1)

    e = model.experts
    # gate_up_proj produces isz features per expert; down_proj consumes half
    # that width (isz // 2 == 1536), matching the fused gate/up layout.
    e.gate_up_proj = torch.nn.Parameter(
        torch.ones(experts_per_rank, hs, isz, device=device)
    )
    e.gate_up_proj_bias = torch.nn.Parameter(
        torch.zeros(experts_per_rank, isz, device=device)
    )
    e.down_proj = torch.nn.Parameter(
        torch.ones(experts_per_rank, isz // 2, hs, device=device)
    )
    e.down_proj_bias = torch.nn.Parameter(
        torch.zeros(experts_per_rank, hs, device=device)
    )
    e.hidden_size = hs

    # Single-token input; the layer should preserve the (batch, seq, hidden) shape.
    x = torch.randn(1, 1, hs, device=device)
    output, expert_weights_out = model(x)

    assert output.shape == (1, 1, hs), f"Output shape mismatch on rank {rank}."

    print(f"Rank {rank}: Test passed! Output shape: {output.shape}")

    dist.destroy_process_group()


def test_megablocks_moe_mlp_functionality():
    """Spawn two workers and run the distributed forward-pass check on each."""
    world_size = 2

    # mp.spawn passes the worker rank as the first positional argument; args
    # supplies the remaining arguments to run_distributed_test.
    mp.spawn(run_distributed_test, args=(world_size,), nprocs=world_size, join=True)

    print("Multi-process test completed successfully!")


if __name__ == "__main__":
    # Allow running both checks directly without a test runner.
    test_megablocks_moe_mlp_import()
    print("Import test passed!")

    test_megablocks_moe_mlp_functionality()