# tirex/models/mixed_stack.py
# Copyright (c) NXAI GmbH.
# This software may be used and distributed according to the terms of the NXAI Community License Agreement.
import os
from dataclasses import dataclass, field
import torch
from torch import nn
from xlstm.blocks.slstm.layer import sLSTMLayer, sLSTMLayerConfig
from xlstm.xlstm_large import xLSTMLargeConfig
from xlstm.xlstm_large.components import RMSNorm
from xlstm.xlstm_large.model import FeedForward, mLSTMBlock, mLSTMStateType


def skip_cuda():
return os.getenv("TIREX_NO_CUDA", "False").lower() in ("true", "1", "t")
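

# Illustrative note (not part of the original module): setting the TIREX_NO_CUDA
# environment variable before the blocks are built forces the pure-PyTorch sLSTM
# backend, e.g. on a CPU-only machine:
#
#   import os
#   os.environ["TIREX_NO_CUDA"] = "1"
#   assert skip_cuda()
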
def init_cell(config: xLSTMLargeConfig, block_idx, num_blocks):
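    """Build the sLSTM layer for one block of the stack.

    The layer takes embedding_dim and num_heads from the xLSTMLargeConfig and uses the
    custom CUDA kernel unless TIREX_NO_CUDA disables it (see skip_cuda), in which case
    the pure-PyTorch "vanilla" backend is selected.
    """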
return sLSTMLayer(
sLSTMLayerConfig(
embedding_dim=config.embedding_dim,
num_heads=config.num_heads,
conv1d_kernel_size=0, # 0 means no convolution included
group_norm_weight=True,
dropout=0,
# CellConfig
backend="vanilla" if skip_cuda() else "cuda",
bias_init="powerlaw_blockdependent",
recurrent_weight_init="zeros",
num_gates=4,
gradient_recurrent_cut=False,
gradient_recurrent_clipval=None,
forward_clipval=None,
batch_size=8, # needed?
_block_idx=block_idx,
_num_blocks=num_blocks,
)
)
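

# State containers threaded through forward(): each sLSTM block returns a
# (conv_state, slstm_state) pair, and the stack keeps one such entry per block index.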
sLSTMLayerStateType = tuple[torch.Tensor, torch.Tensor]
sLSTMStateType = dict[int, sLSTMLayerStateType]


class sLSTMBlock(nn.Module):
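    """Pre-norm residual sLSTM block.

    Applies RMSNorm -> sLSTM layer -> residual add, followed by RMSNorm -> FeedForward
    -> residual add, on inputs of shape (batch, seq_len, embedding_dim).
    """
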
def __init__(self, config: xLSTMLargeConfig, block_idx: int, num_blocks: int):
super().__init__()
self.config = config
self.norm_slstm = RMSNorm(
num_features=config.embedding_dim,
eps=config.norm_eps,
use_weight=True,
use_bias=config.use_bias,
force_float32_reductions=config.norm_reduction_force_float32,
)
self.slstm_layer = init_cell(config, block_idx, num_blocks)
self.norm_ffn = RMSNorm(
num_features=config.embedding_dim,
eps=config.norm_eps,
use_weight=True,
use_bias=config.use_bias,
force_float32_reductions=config.norm_reduction_force_float32,
)
self.ffn = FeedForward(config)
def forward(
self, x: torch.Tensor, state: sLSTMLayerStateType | None = None
) -> tuple[torch.Tensor, sLSTMLayerStateType]:
x_slstm = self.norm_slstm(x)
if state is None:
conv_state, slstm_state = None, None
else:
conv_state, slstm_state = state
x_slstm, state = self.slstm_layer(x_slstm, conv_state, slstm_state, return_last_state=True)
x = x + x_slstm
x_ffn = self.norm_ffn(x)
x_ffn = self.ffn(x_ffn)
x = x + x_ffn
return x, (state["conv_state"], state["slstm_state"])


@dataclass
class xLSTMMixedLargeConfig(xLSTMLargeConfig):
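    """xLSTMLargeConfig extended with block-type selection.

    Block indices listed in slstm_at (or every block when all_slstm is True) become
    sLSTM blocks; the remaining blocks stay mLSTM blocks.
    """
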
slstm_at: list[int] = field(default_factory=list)
all_slstm: bool = True
@property
def block_types(self):
        return ["s" if (i in self.slstm_at or self.all_slstm) else "m" for i in range(self.num_blocks)]
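
    # Example (illustrative): with num_blocks=4, all_slstm=False and slstm_at=[0, 2],
    # block_types evaluates to ["s", "m", "s", "m"]; with the default all_slstm=True
    # every block is "s".

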
class xLSTMMixedLargeBlockStack(nn.Module):
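    """Residual block stack that mixes sLSTM and mLSTM blocks according to
    config.block_types, followed by an optional output RMSNorm."""
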
config_class = xLSTMMixedLargeConfig
def __init__(self, config: xLSTMMixedLargeConfig):
super().__init__()
self.config = config
self.blocks = nn.ModuleList(
[
sLSTMBlock(config, block_idx=i, num_blocks=config.num_blocks) if t == "s" else mLSTMBlock(config)
for i, t in enumerate(config.block_types)
]
)
if self.config.add_out_norm:
self.out_norm = RMSNorm(
num_features=config.embedding_dim,
eps=config.norm_eps,
use_weight=True,
use_bias=config.use_bias,
force_float32_reductions=config.norm_reduction_force_float32,
)
else:
self.out_norm = nn.Identity()

    def forward(
        self, x: torch.Tensor, state: mLSTMStateType | sLSTMStateType | None = None
    ) -> tuple[torch.Tensor, mLSTMStateType | sLSTMStateType]:
if state is None:
state = {i: None for i in range(len(self.blocks))}
for i, block in enumerate(self.blocks):
block_state = state[i]
x, block_state_new = block(x, block_state)
            if block_state is None:
                state[i] = block_state_new
            else:
                # When a previous state was passed in, it is left unchanged here. The
                # in-place update below (copying each state tensor, e.g. the c, n, m
                # tensors of an mLSTM layer state, to avoid allocating new tensors) is
                # kept for reference but is currently disabled:
                # for state_idx in range(len(block_state)):
                #     state[i][state_idx].copy_(block_state_new[state_idx])
                pass
x = self.out_norm(x)
return x, state
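

if __name__ == "__main__":
    # Minimal smoke-test sketch (not part of the original module). It assumes the
    # installed xlstm package exposes xLSTMLargeConfig with the required fields
    # embedding_dim / num_heads / num_blocks / vocab_size; exact fields and kernel
    # availability may differ across xlstm versions.
    os.environ["TIREX_NO_CUDA"] = "1"  # select the pure-PyTorch sLSTM backend

    cfg = xLSTMMixedLargeConfig(
        embedding_dim=64,
        num_heads=4,
        num_blocks=2,
        vocab_size=1,  # required by the base config, unused by the block stack itself
    )
    stack = xLSTMMixedLargeBlockStack(cfg)

    x = torch.randn(2, 16, cfg.embedding_dim)  # (batch, seq_len, embedding_dim)
    y, state = stack(x)
    print(y.shape)        # expected: torch.Size([2, 16, 64])
    print(sorted(state))  # one state entry per block: [0, 1]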