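"""Smoke tests for the MegaBlocksMoeMLP layer from the megablocks package."""
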
from collections import namedtuple

import torch


def test_megablocks_moe_mlp_import():
    """Test if MegaBlocksMoeMLP can be imported."""
    from megablocks.layers import MegaBlocksMoeMLP

    assert MegaBlocksMoeMLP is not None, "MegaBlocksMoeMLP import failed."


def test_megablocks_moe_mlp_functionality():
    """Test the functionality of MegaBlocksMoeMLP."""
    from megablocks.layers import MegaBlocksMoeMLP

    # Create a bare instance of MegaBlocksMoeMLP and attach the attributes its
    # forward pass expects. The namedtuple *class* is used as a lightweight
    # container: the expert weights are assigned to it as class attributes below.
    model = MegaBlocksMoeMLP()

    model.experts = namedtuple(
        "Experts",
        [
            "gate_up_proj",
            "gate_up_proj_bias",
            "down_proj",
            "down_proj_bias",
            "hidden_size",
        ],
    )

    num_experts = 128
    hidden_size = 1152
    intermediate_size = 3072

    # Shorter names for reading convenience
    ne, hs, isz = num_experts, hidden_size, intermediate_size

    # The router maps each token's hidden state to one logit per expert;
    # constant weights keep the routing setup simple for this smoke test.
    model.router = torch.nn.Linear(hs, ne).cuda()
    model.router.weight.data.fill_(1)

    e = model.experts
    # Per-expert weights: gate_up_proj produces isz features, of which the
    # down projection consumes half (isz // 2 = 1536).
    e.gate_up_proj = torch.nn.Parameter(torch.ones(ne, hs, isz, device="cuda"))
    e.gate_up_proj_bias = torch.nn.Parameter(torch.zeros(ne, isz, device="cuda"))
    e.down_proj = torch.nn.Parameter(torch.ones(ne, isz // 2, hs, device="cuda"))
    e.down_proj_bias = torch.nn.Parameter(torch.zeros(ne, hs, device="cuda"))
    e.hidden_size = hs

    # Run a forward pass on a single dummy token.
    x = torch.randn(1, 1, hs, device="cuda")
    output, expert_weights_out = model(x)

    assert output.shape == (1, 1, hs), "Output shape mismatch."
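

# Optional convenience entry point: assuming a CUDA device and a megablocks
# install are available, running this file directly executes both tests as a
# quick smoke check.
if __name__ == "__main__":
    test_megablocks_moe_mlp_import()
    test_megablocks_moe_mlp_functionality()
    print("MegaBlocksMoeMLP smoke tests passed.")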