# Copyright (c) NXAI GmbH.
# This software may be used and distributed according to the terms of the NXAI Community License Agreement.
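"""Mixed xLSTM-Large block stack: an sLSTM block compatible with the xLSTM-Large
configuration, plus a block stack that mixes sLSTM and mLSTM blocks per block type."""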
import os
from dataclasses import dataclass, field

import torch
from torch import nn

from xlstm.blocks.slstm.layer import sLSTMLayer, sLSTMLayerConfig
from xlstm.xlstm_large import xLSTMLargeConfig
from xlstm.xlstm_large.components import RMSNorm
from xlstm.xlstm_large.model import FeedForward, mLSTMBlock, mLSTMStateType


def skip_cuda():
    # TIREX_NO_CUDA=1 (or "true"/"t") selects the vanilla PyTorch sLSTM backend instead of the CUDA kernel.
    return os.getenv("TIREX_NO_CUDA", "False").lower() in ("true", "1", "t")


def init_cell(config: xLSTMLargeConfig, block_idx, num_blocks):
    # Build a single sLSTM layer sized to the embedding dimension and head count of the xLSTM-Large config.
    return sLSTMLayer(
        sLSTMLayerConfig(
            embedding_dim=config.embedding_dim,
            num_heads=config.num_heads,
            conv1d_kernel_size=0,  # 0 means no convolution included
            group_norm_weight=True,
            dropout=0,
            # CellConfig
            backend="vanilla" if skip_cuda() else "cuda",
            bias_init="powerlaw_blockdependent",
            recurrent_weight_init="zeros",
            num_gates=4,
            gradient_recurrent_cut=False,
            gradient_recurrent_clipval=None,
            forward_clipval=None,
            batch_size=8,  # needed?
            _block_idx=block_idx,
            _num_blocks=num_blocks,
        )
    )


# Per-layer sLSTM state as handled by sLSTMBlock: (conv_state, slstm_state).
sLSTMLayerStateType = tuple[torch.Tensor, torch.Tensor]
# State of a whole block stack: block index -> per-layer state.
sLSTMStateType = dict[int, sLSTMLayerStateType]


class sLSTMBlock(nn.Module):
    """Pre-norm residual block: RMSNorm -> sLSTM layer -> residual, then RMSNorm -> FeedForward -> residual."""

    def __init__(self, config: xLSTMLargeConfig, block_idx: int, num_blocks: int):
        super().__init__()
        self.config = config

        self.norm_slstm = RMSNorm(
            num_features=config.embedding_dim,
            eps=config.norm_eps,
            use_weight=True,
            use_bias=config.use_bias,
            force_float32_reductions=config.norm_reduction_force_float32,
        )
        self.slstm_layer = init_cell(config, block_idx, num_blocks)

        self.norm_ffn = RMSNorm(
            num_features=config.embedding_dim,
            eps=config.norm_eps,
            use_weight=True,
            use_bias=config.use_bias,
            force_float32_reductions=config.norm_reduction_force_float32,
        )
        self.ffn = FeedForward(config)

    def forward(
        self, x: torch.Tensor, state: sLSTMLayerStateType | None = None
    ) -> tuple[torch.Tensor, sLSTMLayerStateType]:
        x_slstm = self.norm_slstm(x)
        if state is None:
            conv_state, slstm_state = None, None
        else:
            conv_state, slstm_state = state
        x_slstm, state = self.slstm_layer(x_slstm, conv_state, slstm_state, return_last_state=True)
        x = x + x_slstm  # residual connection around the sLSTM layer

        x_ffn = self.norm_ffn(x)
        x_ffn = self.ffn(x_ffn)
        x = x + x_ffn  # residual connection around the feed-forward block

        return x, (state["conv_state"], state["slstm_state"])


@dataclass
class xLSTMMixedLargeConfig(xLSTMLargeConfig):
    # Indices of the blocks that use an sLSTM layer instead of an mLSTM layer.
    slstm_at: list[int] = field(default_factory=list)
    # If True, every block uses an sLSTM layer, regardless of slstm_at.
    all_slstm: bool = True

    @property
    def block_types(self):
        # "s" marks an sLSTM block, "m" an mLSTM block.
        return ["s" if i in self.slstm_at or self.all_slstm else "m" for i in range(self.num_blocks)]


class xLSTMMixedLargeBlockStack(nn.Module):
    """Stack of mLSTM / sLSTM blocks laid out according to config.block_types, with an optional output norm."""

    config_class = xLSTMMixedLargeConfig

    def __init__(self, config: xLSTMMixedLargeConfig):
        super().__init__()
        self.config = config

        self.blocks = nn.ModuleList(
            [
                sLSTMBlock(config, block_idx=i, num_blocks=config.num_blocks) if t == "s" else mLSTMBlock(config)
                for i, t in enumerate(config.block_types)
            ]
        )

        if self.config.add_out_norm:
            self.out_norm = RMSNorm(
                num_features=config.embedding_dim,
                eps=config.norm_eps,
                use_weight=True,
                use_bias=config.use_bias,
                force_float32_reductions=config.norm_reduction_force_float32,
            )
        else:
            self.out_norm = nn.Identity()

    def forward(
        self, x: torch.Tensor, state: mLSTMStateType | sLSTMStateType | None = None
    ) -> tuple[torch.Tensor, mLSTMStateType]:
        if state is None:
            state = {i: None for i in range(len(self.blocks))}

        for i, block in enumerate(self.blocks):
            block_state = state[i]
            x, block_state_new = block(x, block_state)

            if block_state is None:
                state[i] = block_state_new
            else:
                # NOTE: with a pre-existing state the freshly returned block state is currently
                # discarded; the in-place update below is disabled.
                pass
                ## layer state is a tuple of three tensors: c, n, m
                ## we update the state in place in order to avoid creating new tensors
                # for state_idx in range(len(block_state)):
                #     state[i][state_idx].copy_(block_state_new[state_idx])

        x = self.out_norm(x)

        return x, state
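

# Usage sketch (assumes xLSTMLargeConfig in the installed xlstm version takes embedding_dim,
# num_heads, num_blocks and vocab_size as its required fields; adjust to your version):
#
#   os.environ["TIREX_NO_CUDA"] = "1"  # force the vanilla sLSTM backend (no CUDA kernel needed)
#
#   config = xLSTMMixedLargeConfig(
#       embedding_dim=64,
#       num_heads=4,
#       num_blocks=2,
#       vocab_size=1,      # assumed to be required by xLSTMLargeConfig; unused by the block stack
#       slstm_at=[0],      # block 0 -> sLSTM
#       all_slstm=False,   # remaining blocks -> mLSTM
#   )
#   stack = xLSTMMixedLargeBlockStack(config)
#   x = torch.randn(8, 32, config.embedding_dim)  # (batch, seq_len, embedding_dim)
#   y, state = stack(x)                           # y has the same shape as x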