from collections import namedtuple

import torch


def test_megablocks_moe_mlp_import():
    """Test if MegaBlocksMoeMLP can be imported."""
    from megablocks.layers import MegaBlocksMoeMLP

    assert MegaBlocksMoeMLP is not None, "MegaBlocksMoeMLP import failed."


def test_megablocks_moe_mlp_functionality():
    """Test the functionality of MegaBlocksMoeMLP."""
    from megablocks.layers import MegaBlocksMoeMLP

    model = MegaBlocksMoeMLP()

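    # A namedtuple class serves as a lightweight attribute container for the
    # expert weights; its fields are replaced with real tensors below.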
    model.experts = namedtuple(
        "Experts",
        [
            "gate_up_proj",
            "gate_up_proj_bias",
            "down_proj",
            "down_proj_bias",
            "hidden_size",
        ],
    )

    num_experts = 128
    hidden_size = 1152
    intermediate_size = 3072

    ne, hs, isz = num_experts, hidden_size, intermediate_size
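
    # With constant router weights, every expert receives the same score, so
    # expert selection is deterministic for this test.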
    model.router = torch.nn.Linear(hs, ne).cuda()
    model.router.weight.data.fill_(1)

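    # Expert weights are all ones/zeros so the forward pass is predictable.
    # down_proj's input width is isz // 2 (= 1536), which appears to match a
    # split of gate_up_proj's output into gate and up halves.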
    e = model.experts
    e.gate_up_proj = torch.nn.Parameter(torch.ones(ne, hs, isz, device="cuda"))
    e.gate_up_proj_bias = torch.nn.Parameter(torch.zeros(ne, isz, device="cuda"))
    e.down_proj = torch.nn.Parameter(torch.ones(ne, isz // 2, hs, device="cuda"))
    e.down_proj_bias = torch.nn.Parameter(torch.zeros(ne, hs, device="cuda"))
    e.hidden_size = hs
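
    # Single-token input of shape (batch, seq_len, hidden_size).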
    x = torch.randn(1, 1, hs, device="cuda")
    output, expert_weights_out = model(x)

    assert output.shape == (1, 1, hs), "Output shape mismatch."