# Copyright (c) NXAI GmbH.
# This software may be used and distributed according to the terms of the NXAI Community License Agreement.


import os
from dataclasses import dataclass, field

import torch
from torch import nn
from xlstm.blocks.slstm.layer import sLSTMLayer, sLSTMLayerConfig
from xlstm.xlstm_large import xLSTMLargeConfig
from xlstm.xlstm_large.components import RMSNorm
from xlstm.xlstm_large.model import FeedForward, mLSTMBlock, mLSTMStateType


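# The TIREX_NO_CUDA environment variable switches the sLSTM cell to the pure-PyTorch
# "vanilla" backend instead of the fused CUDA kernel (see init_cell below).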
def skip_cuda():
    return os.getenv("TIREX_NO_CUDA", "False").lower() in ("true", "1", "t")


def init_cell(config: xLSTMLargeConfig, block_idx, num_blocks):
    return sLSTMLayer(
        sLSTMLayerConfig(
            embedding_dim=config.embedding_dim,
            num_heads=config.num_heads,
            conv1d_kernel_size=0,  # 0 means no convolution included
            group_norm_weight=True,
            dropout=0,
            # CellConfig
            backend="vanilla" if skip_cuda() else "cuda",
            bias_init="powerlaw_blockdependent",
            recurrent_weight_init="zeros",
            num_gates=4,
            gradient_recurrent_cut=False,
            gradient_recurrent_clipval=None,
            forward_clipval=None,
            batch_size=8,  # needed?
            _block_idx=block_idx,
            _num_blocks=num_blocks,
        )
    )


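# State bookkeeping: each sLSTM layer carries a (conv_state, slstm_state) pair; the block
# stack keeps one such entry per block index.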
sLSTMLayerStateType = tuple[torch.Tensor, torch.Tensor]
sLSTMStateType = dict[int, sLSTMLayerStateType]


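# Pre-norm residual block around an sLSTM layer: x + sLSTM(RMSNorm(x)), followed by
# x + FFN(RMSNorm(x)).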
class sLSTMBlock(nn.Module):
    def __init__(self, config: xLSTMLargeConfig, block_idx: int, num_blocks: int):
        super().__init__()
        self.config = config
        self.norm_slstm = RMSNorm(
            num_features=config.embedding_dim,
            eps=config.norm_eps,
            use_weight=True,
            use_bias=config.use_bias,
            force_float32_reductions=config.norm_reduction_force_float32,
        )
        self.slstm_layer = init_cell(config, block_idx, num_blocks)

        self.norm_ffn = RMSNorm(
            num_features=config.embedding_dim,
            eps=config.norm_eps,
            use_weight=True,
            use_bias=config.use_bias,
            force_float32_reductions=config.norm_reduction_force_float32,
        )
        self.ffn = FeedForward(config)

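    # Runs the pre-norm residual path and unpacks the dict that sLSTMLayer returns with
    # return_last_state=True into a (conv_state, slstm_state) tuple.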
    def forward(
        self, x: torch.Tensor, state: sLSTMLayerStateType | None = None
    ) -> tuple[torch.Tensor, sLSTMLayerStateType]:
        x_slstm = self.norm_slstm(x)
        if state is None:
            conv_state, slstm_state = None, None
        else:
            conv_state, slstm_state = state
        x_slstm, state = self.slstm_layer(x_slstm, conv_state, slstm_state, return_last_state=True)
        x = x + x_slstm

        x_ffn = self.norm_ffn(x)
        x_ffn = self.ffn(x_ffn)
        x = x + x_ffn

        return x, (state["conv_state"], state["slstm_state"])


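# Config extension: `slstm_at` lists the block indices that should use sLSTM blocks;
# with `all_slstm=True` (the default) every block becomes an sLSTM block.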
@dataclass
class xLSTMMixedLargeConfig(xLSTMLargeConfig):
    slstm_at: list[int] = field(default_factory=list)
    all_slstm: bool = True

    @property
    def block_types(self):
        return ["s" if i in self.slstm_at or self.all_slstm else "m" for i in range(self.num_blocks)]


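# Block stack that mixes sLSTM and mLSTM blocks according to `config.block_types` and
# threads a per-block state dict through the forward pass.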
class xLSTMMixedLargeBlockStack(nn.Module):
    config_class = xLSTMMixedLargeConfig

    def __init__(self, config: xLSTMMixedLargeConfig):
        super().__init__()
        self.config = config

        self.blocks = nn.ModuleList(
            [
                sLSTMBlock(config, block_idx=i, num_blocks=config.num_blocks) if t == "s" else mLSTMBlock(config)
                for i, t in enumerate(config.block_types)
            ]
        )

        if self.config.add_out_norm:
            self.out_norm = RMSNorm(
                num_features=config.embedding_dim,
                eps=config.norm_eps,
                use_weight=True,
                use_bias=config.use_bias,
                force_float32_reductions=config.norm_reduction_force_float32,
            )
        else:
            self.out_norm = nn.Identity()

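    # The state dict maps block index -> block state; entries start as None and are
    # filled lazily with the states returned on the first forward pass.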
    def forward(
        self, x: torch.Tensor, state: mLSTMStateType | sLSTMStateType | None = None
    ) -> tuple[torch.Tensor, mLSTMStateType | sLSTMStateType]:
        if state is None:
            state = {i: None for i in range(len(self.blocks))}

        for i, block in enumerate(self.blocks):
            block_state = state[i]
            x, block_state_new = block(x, block_state)

            if block_state is None:
                state[i] = block_state_new
            else:
                pass
                ## layer state is a tuple of three tensors: c, n, m
                ## we update the state in place in order to avoid creating new tensors
                # for state_idx in range(len(block_state)):
                #    state[i][state_idx].copy_(block_state_new[state_idx])

        x = self.out_norm(x)

        return x, state
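

# --------------------------------------------------------------------------------------
# Minimal usage sketch (not part of the module above): build a small all-sLSTM stack and
# push a random batch through it. Assumptions: the installed `xlstm` package exposes
# xLSTMLargeConfig with required fields `embedding_dim`, `num_heads`, `num_blocks`, and
# `vocab_size` (the last is unused by the block stack itself), and the "vanilla" sLSTM
# backend is available, so no CUDA kernels are needed. Adjust for your xlstm version.
if __name__ == "__main__":
    os.environ.setdefault("TIREX_NO_CUDA", "1")  # select the pure-PyTorch sLSTM backend

    config = xLSTMMixedLargeConfig(
        embedding_dim=64,
        num_heads=4,
        num_blocks=2,
        vocab_size=1,  # assumed to be required by xLSTMLargeConfig; not used here
        # all_slstm defaults to True, so both blocks are sLSTM blocks. Setting
        # all_slstm=False together with slstm_at=[0] would mix in an mLSTM block at
        # index 1, which additionally requires the mLSTM kernels.
    )
    print(config.block_types)  # ['s', 's']

    stack = xLSTMMixedLargeBlockStack(config)
    x = torch.randn(8, 32, config.embedding_dim)  # (batch, seq_len, embedding_dim)
    y, state = stack(x)
    print(y.shape)  # same shape as the input: torch.Size([8, 32, 64])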