# =============================================================================
# core/mamba.py
# =============================================================================
import torch
import torch.nn as nn
import torch.nn.functional as F

from core.stateSpace import StateSpaceModel
from utils.conv_layer import Mamba1DConv


class RMSNorm(nn.Module):
    def __init__(self, d_model: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(d_model))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        norm = x.norm(dim=-1, keepdim=True) * (x.shape[-1] ** -0.5)
        return x / (norm + self.eps) * self.weight
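
# Note: ``x.norm(dim=-1, keepdim=True) * d ** -0.5`` equals sqrt(mean(x**2)),
# the root-mean-square over the last dimension, so forward() computes
# x / (RMS(x) + eps) * weight. A quick self-check (illustrative only):
#
#     x = torch.randn(2, 8)
#     rms = x.pow(2).mean(dim=-1, keepdim=True).sqrt()
#     assert torch.allclose(x.norm(dim=-1, keepdim=True) * 8 ** -0.5, rms)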


class MambaBlock(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config

        # Projections
        self.in_proj = nn.Linear(config.d_model, config.d_inner * 2, bias=config.bias)
        self.out_proj = nn.Linear(config.d_inner, config.d_model, bias=config.bias)

        # Convolution for local context
        self.conv1d = Mamba1DConv(config.d_inner, config.d_conv, config.conv_bias)

        # State space model
        self.ssm = StateSpaceModel(
            d_inner=config.d_inner,
            d_state=config.d_state,
            dt_rank=config.dt_rank,
            bias=config.bias,
        )

        # Activation
        self.act = F.silu

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: [batch, seq_len, d_model]
        Returns:
            output: [batch, seq_len, d_model]
        """
        batch_size, seq_len, d_model = x.shape

        # Input projection
        xz = self.in_proj(x)        # [batch, seq_len, 2 * d_inner]
        x, z = xz.chunk(2, dim=-1)  # each [batch, seq_len, d_inner]

        # Apply convolution
        x = self.act(self.conv1d(x))

        # Apply state space model
        y = self.ssm(x)

        # Apply gating with z
        y = y * self.act(z)

        # Output projection
        output = self.out_proj(y)
        return output
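
# The block follows Mamba's gated layout: one half of in_proj feeds
# conv -> SiLU -> SSM, while the other half becomes a SiLU gate applied to
# the SSM output before out_proj. d_inner is typically an expansion of
# d_model (2x in the Mamba paper), though that choice lives in the config,
# not in this file.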


class MambaLayer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.norm = RMSNorm(config.d_model)
        self.mamba_block = MambaBlock(config)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Pre-norm architecture
        residual = x
        x = self.norm(x)
        x = self.mamba_block(x)
        return x + residual
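

# -----------------------------------------------------------------------------
# Usage sketch (illustrative only). ``MambaConfig`` is a hypothetical
# container defined here just for this demo, mirroring the attributes the
# classes above read (d_model, d_inner, d_state, d_conv, dt_rank, bias,
# conv_bias); the project may define its own config elsewhere. Running it
# also assumes core.stateSpace and utils.conv_layer are importable.
# -----------------------------------------------------------------------------
from dataclasses import dataclass


@dataclass
class MambaConfig:  # hypothetical defaults, chosen for illustration
    d_model: int = 256
    d_inner: int = 512   # 2 * d_model, the expansion used in the Mamba paper
    d_state: int = 16
    d_conv: int = 4
    dt_rank: int = 16
    bias: bool = False
    conv_bias: bool = True


if __name__ == "__main__":
    config = MambaConfig()
    layer = MambaLayer(config)
    x = torch.randn(2, 128, config.d_model)  # [batch, seq_len, d_model]
    y = layer(x)
    assert y.shape == x.shape  # the residual block preserves shape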