Nikita committed on
Commit 14d91dc · 1 Parent(s): 810611f

added tirex as model

.DS_Store CHANGED
Binary files a/.DS_Store and b/.DS_Store differ
 
tirex/__init__.py ADDED
@@ -0,0 +1,8 @@
+ # Copyright (c) NXAI GmbH.
+ # This software may be used and distributed according to the terms of the NXAI Community License Agreement.
+
+ from .api_adapter.forecast import ForecastModel
+ from .base import load_model
+ from .models.tirex import TiRexZero
+
+ __all__ = ["load_model", "ForecastModel"]
tirex/api_adapter/__init__.py ADDED
@@ -0,0 +1,2 @@
+ # Copyright (c) NXAI GmbH.
+ # This software may be used and distributed according to the terms of the NXAI Community License Agreement.
tirex/api_adapter/forecast.py ADDED
@@ -0,0 +1,209 @@
+ # Copyright (c) NXAI GmbH.
+ # This software may be used and distributed according to the terms of the NXAI Community License Agreement.
+
+ from abc import ABC, abstractmethod
+ from typing import Literal
+
+ import torch
+
+ from .standard_adapter import ContextType, get_batches
+
+ try:
+     from .gluon import format_gluonts_output, get_gluon_batches
+
+     _GLUONTS_AVAILABLE = True
+ except ImportError:
+     _GLUONTS_AVAILABLE = False
+
+ try:
+     from .hf_data import get_hfdata_batches
+
+     _HF_DATASETS_AVAILABLE = True
+ except ImportError:
+     _HF_DATASETS_AVAILABLE = False
+
+
+ DEF_TARGET_COLUMN = "target"
+ DEF_META_COLUMNS = ("start", "item_id")
+
+
+ def _format_output(
+     quantiles: torch.Tensor,
+     means: torch.Tensor,
+     sample_meta: list[dict],
+     quantile_levels: list[float],
+     output_type: Literal["torch", "numpy", "gluonts"],
+ ):
+     if output_type == "torch":
+         return quantiles.cpu(), means.cpu()
+     elif output_type == "numpy":
+         return quantiles.cpu().numpy(), means.cpu().numpy()
+     elif output_type == "gluonts":
+         if not _GLUONTS_AVAILABLE:
+             raise ValueError('output_type "gluonts" requires GluonTS, but GluonTS is not available (not installed)!')
+         return format_gluonts_output(quantiles, means, sample_meta, quantile_levels)
+     else:
+         raise ValueError(f"Invalid output type: {output_type}")
+
+
+ def _as_generator(batches, fc_func, quantile_levels, output_type, **predict_kwargs):
+     for batch_ctx, batch_meta in batches:
+         quantiles, mean = fc_func(batch_ctx, **predict_kwargs)
+         yield _format_output(
+             quantiles=quantiles,
+             means=mean,
+             sample_meta=batch_meta,
+             quantile_levels=quantile_levels,
+             output_type=output_type,
+         )
+
+
+ def _gen_forecast(fc_func, batches, output_type, quantile_levels, yield_per_batch, **predict_kwargs):
+     if yield_per_batch:
+         return _as_generator(batches, fc_func, quantile_levels, output_type, **predict_kwargs)
+
+     prediction_q = []
+     prediction_m = []
+     sample_meta = []
+     for batch_ctx, batch_meta in batches:
+         quantiles, mean = fc_func(batch_ctx, **predict_kwargs)
+         prediction_q.append(quantiles)
+         prediction_m.append(mean)
+         sample_meta.extend(batch_meta)
+
+     prediction_q = torch.cat(prediction_q, dim=0)
+     prediction_m = torch.cat(prediction_m, dim=0)
+
+     return _format_output(
+         quantiles=prediction_q,
+         means=prediction_m,
+         sample_meta=sample_meta,
+         quantile_levels=quantile_levels,
+         output_type=output_type,
+     )
+
+
+ def _common_forecast_doc():
+     common_doc = """
+     This method takes historical context data as input and outputs probabilistic forecasts.
+
+     Args:
+         output_type (Literal["torch", "numpy", "gluonts"], optional):
+             Specifies the desired format of the returned forecasts:
+             - "torch": Returns forecasts as `torch.Tensor` objects [batch_dim, forecast_len, |quantile_levels|]
+             - "numpy": Returns forecasts as `numpy.ndarray` objects [batch_dim, forecast_len, |quantile_levels|]
+             - "gluonts": Returns forecasts as a list of GluonTS `Forecast` objects.
+             Defaults to "torch".
+
+         batch_size (int, optional): The number of time series instances to process concurrently by the model.
+             Defaults to 512. Must be >= 1.
+
+         quantile_levels (List[float], optional): Quantile levels for which predictions should be generated.
+             Defaults to (0.1, 0.2, ..., 0.9).
+
+         yield_per_batch (bool, optional): If `True`, the method acts as a generator, yielding
+             forecasts batch by batch as they are computed.
+             Defaults to `False`.
+
+         **predict_kwargs: Additional keyword arguments that are passed directly to the underlying
+             prediction mechanism of the pre-trained model. Refer to the model's
+             internal prediction method documentation for available options.
+
+     Returns:
+         The return type depends on `output_type` and `yield_per_batch`:
+         - If `yield_per_batch` is `True`: An iterator that yields forecasts. Each yielded item
+           corresponds to one batch of forecasts in the format specified by `output_type`.
+         - If `yield_per_batch` is `False`: A single object containing all forecasts.
+             - If `output_type="torch"`: `Tuple[torch.Tensor, torch.Tensor]` (quantiles, mean).
+             - If `output_type="numpy"`: `Tuple[numpy.ndarray, numpy.ndarray]` (quantiles, mean).
+             - If `output_type="gluonts"`: A `List[gluonts.model.forecast.Forecast]` of all forecasts.
+     """
+     return common_doc
+
+
+ class ForecastModel(ABC):
+     @abstractmethod
+     def _forecast_quantiles(self, batch, **predict_kwargs):
+         pass
+
+     def forecast(
+         self,
+         context: ContextType,
+         output_type: Literal["torch", "numpy", "gluonts"] = "torch",
+         batch_size: int = 512,
+         quantile_levels: list[float] = (0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9),
+         yield_per_batch: bool = False,
+         **predict_kwargs,
+     ):
+         f"""
+         {_common_forecast_doc()}
+         Args:
+             context (ContextType): The historical "context" data of the time series:
+                 - `torch.Tensor`: 1D `[context_length]` or 2D `[batch_dim, context_length]` tensor
+                 - `np.ndarray`: 1D `[context_length]` or 2D `[batch_dim, context_length]` array
+                 - `List[torch.Tensor]`: List of 1D tensors (samples with different lengths get padded per batch)
+                 - `List[np.ndarray]`: List of 1D arrays (samples with different lengths get padded per batch)
+         """
+         assert batch_size >= 1, "Batch size must be >= 1"
+         batches = get_batches(context, batch_size)
+         return _gen_forecast(
+             self._forecast_quantiles, batches, output_type, quantile_levels, yield_per_batch, **predict_kwargs
+         )
+
+     def forecast_gluon(
+         self,
+         gluonDataset,
+         output_type: Literal["torch", "numpy", "gluonts"] = "torch",
+         batch_size: int = 512,
+         quantile_levels: list[float] = (0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9),
+         yield_per_batch: bool = False,
+         data_kwargs: dict = {},
+         **predict_kwargs,
+     ):
+         f"""
+         {_common_forecast_doc()}
+
+         Args:
+             gluonDataset (gluonts.dataset.common.Dataset): A GluonTS dataset object containing the
+                 historical time series data.
+
+             data_kwargs (dict, optional): Additional keyword arguments passed to the
+                 GluonTS data processing function.
+         """
+         assert batch_size >= 1, "Batch size must be >= 1"
+         if not _GLUONTS_AVAILABLE:
+             raise ValueError("forecast_gluon requires GluonTS, but GluonTS is not available (not installed)!")
+         batches = get_gluon_batches(gluonDataset, batch_size, **data_kwargs)
+         return _gen_forecast(
+             self._forecast_quantiles, batches, output_type, quantile_levels, yield_per_batch, **predict_kwargs
+         )
+
+     def forecast_hfdata(
+         self,
+         hf_dataset,
+         output_type: Literal["torch", "numpy", "gluonts"] = "torch",
+         batch_size: int = 512,
+         quantile_levels: list[float] = (0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9),
+         yield_per_batch: bool = False,
+         data_kwargs: dict = {},
+         **predict_kwargs,
+     ):
+         f"""
+         {_common_forecast_doc()}
+
+         Args:
+             hf_dataset (datasets.Dataset): A Hugging Face `Dataset` object containing the
+                 historical time series data.
+
+             data_kwargs (dict, optional): Additional keyword arguments passed to the
+                 Hugging Face `datasets` processing function.
+         """
+         assert batch_size >= 1, "Batch size must be >= 1"
+         if not _HF_DATASETS_AVAILABLE:
+             raise ValueError(
+                 "forecast_hfdata requires Hugging Face datasets, but datasets is not available (not installed)!"
+             )
+         batches = get_hfdata_batches(hf_dataset, batch_size, **data_kwargs)
+         return _gen_forecast(
+             self._forecast_quantiles, batches, output_type, quantile_levels, yield_per_batch, **predict_kwargs
+         )
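
Example usage of the forecast API added above — a minimal sketch, assuming the published NX-AI/TiRex checkpoint and the default settings (shapes in the comments follow the docstring):

import torch
from tirex import ForecastModel, load_model

model: ForecastModel = load_model("NX-AI/TiRex")
# Ragged input: shorter series are left-padded with NaN within each batch.
context = [torch.rand(128), torch.rand(256)]
quantiles, mean = model.forecast(context, output_type="torch", batch_size=2)
print(quantiles.shape)  # [2, forecast_len, 9] with the default 9 quantile levels
print(mean.shape)       # [2, forecast_len]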
tirex/api_adapter/gluon.py ADDED
@@ -0,0 +1,48 @@
+ # Copyright (c) NXAI GmbH.
+ # This software may be used and distributed according to the terms of the NXAI Community License Agreement.
+
+ import pandas as pd
+ import torch
+ from gluonts.dataset.common import Dataset
+ from gluonts.dataset.field_names import FieldName
+ from gluonts.model.forecast import QuantileForecast
+
+ from .standard_adapter import _batch_pad_iterable
+
+ DEF_TARGET_COLUMN = FieldName.TARGET  # "target"
+ DEF_META_COLUMNS = (FieldName.START, FieldName.ITEM_ID)
+
+
+ def _get_gluon_ts_map(**gluon_kwargs):
+     target_col = gluon_kwargs.get("target_column", DEF_TARGET_COLUMN)
+     meta_columns = gluon_kwargs.get("meta_columns", DEF_META_COLUMNS)
+
+     def extract_gluon(series):
+         ctx = torch.Tensor(series[target_col])
+         meta = {k: series[k] for k in meta_columns if k in series}
+         meta["length"] = len(ctx)
+         return ctx, meta
+
+     return extract_gluon
+
+
+ def get_gluon_batches(gluonDataset: Dataset, batch_size: int, **gluon_kwargs):
+     return _batch_pad_iterable(map(_get_gluon_ts_map(**gluon_kwargs), gluonDataset), batch_size)
+
+
+ def format_gluonts_output(quantile_forecasts: torch.Tensor, mean_forecasts, meta: list[dict], quantile_levels):
+     forecasts = []
+     for i in range(quantile_forecasts.shape[0]):
+         start_date = meta[i].get(FieldName.START, pd.Period("01-01-2000", freq=meta[i].get("freq", "h")))
+         start_date += meta[i].get("length", 0)
+         forecasts.append(
+             QuantileForecast(
+                 forecast_arrays=torch.cat((quantile_forecasts[i], mean_forecasts[i].unsqueeze(1)), dim=1)
+                 .T.cpu()
+                 .numpy(),
+                 start_date=start_date,
+                 item_id=meta[i].get(FieldName.ITEM_ID, None),
+                 forecast_keys=list(map(str, quantile_levels)) + ["mean"],
+             )
+         )
+     return forecasts
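
For the GluonTS path, a sketch under the assumption that GluonTS is installed and `model` is a loaded TiRex handle as in the earlier example; `ListDataset` is standard GluonTS:

import numpy as np
from gluonts.dataset.common import ListDataset

dataset = ListDataset(
    [{"start": "2020-01-01 00:00", "target": np.random.rand(100), "item_id": "series_0"}],
    freq="h",
)
# Returns a list of QuantileForecast objects; the forecast_keys are the quantile
# levels as strings plus "mean", matching format_gluonts_output above.
forecasts = model.forecast_gluon(dataset, output_type="gluonts")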
tirex/api_adapter/hf_data.py ADDED
@@ -0,0 +1,38 @@
+ # Copyright (c) NXAI GmbH.
+ # This software may be used and distributed according to the terms of the NXAI Community License Agreement.
+
+ import datasets
+ import torch
+
+ from .standard_adapter import _batch_pad_iterable
+
+ DEF_TARGET_COLUMN = "target"
+
+
+ def _get_hf_map(dataset: datasets.Dataset, **hf_kwargs):
+     target_col = hf_kwargs.get("target_column", DEF_TARGET_COLUMN)
+     meta_columns = hf_kwargs.get("meta_columns", ())
+
+     columns_to_pass = [target_col] + list(meta_columns)
+     remove_cols = [col for col in dataset.column_names if col not in columns_to_pass]
+     dataset = (
+         dataset.with_format("torch")
+         .remove_columns(remove_cols)
+         .cast_column(target_col, datasets.Sequence(datasets.Value("float32")))
+     )
+
+     def yield_batch_tuples(sample: dict) -> tuple[torch.Tensor, dict]:
+         context_data = sample[target_col]
+         if context_data.ndim > 1:
+             context_data = context_data.squeeze()
+         assert context_data.ndim == 1
+         meta = {k: sample[k] for k in meta_columns if k in sample}
+         meta["length"] = len(context_data)
+         return context_data, meta
+
+     return dataset, yield_batch_tuples
+
+
+ def get_hfdata_batches(hf_dataset: datasets.Dataset, batch_size: int, **hf_kwargs):
+     dataset, map_func = _get_hf_map(hf_dataset, **hf_kwargs)
+     return _batch_pad_iterable(map(map_func, dataset), batch_size)
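
A corresponding sketch for the Hugging Face datasets path (assumes the `datasets` package is installed; `model` as before):

import datasets

ds = datasets.Dataset.from_dict(
    {"target": [[1.0, 2.0, 3.0], [4.0, 5.0]], "item_id": ["a", "b"]}
)
# meta_columns is forwarded to _get_hf_map via data_kwargs; all other columns are dropped.
quantiles, mean = model.forecast_hfdata(ds, data_kwargs={"meta_columns": ("item_id",)})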
tirex/api_adapter/standard_adapter.py ADDED
@@ -0,0 +1,67 @@
+ # Copyright (c) NXAI GmbH.
+ # This software may be used and distributed according to the terms of the NXAI Community License Agreement.
+
+ import itertools
+ from collections.abc import Iterable, Iterator, Sequence
+ from typing import Union
+
+ import numpy as np
+ import torch
+
+ ContextType = Union[
+     torch.Tensor,
+     np.ndarray,
+     list[torch.Tensor],
+     list[np.ndarray],
+ ]
+
+
+ def _batched_slice(full_batch, full_meta: list[dict] | None, batch_size: int) -> Iterator[tuple[Sequence, list[dict]]]:
+     if len(full_batch) <= batch_size:
+         yield full_batch, full_meta if full_meta is not None else [{} for _ in range(len(full_batch))]
+     else:
+         for i in range(0, len(full_batch), batch_size):
+             batch = full_batch[i : i + batch_size]
+             yield batch, (full_meta[i : i + batch_size] if full_meta is not None else [{} for _ in range(len(batch))])
+
+
+ def _batched(iterable: Iterable, n: int):
+     it = iter(iterable)
+     while batch := tuple(itertools.islice(it, n)):
+         yield batch
+
+
+ def _batch_pad_iterable(iterable: Iterable[tuple[torch.Tensor, dict]], batch_size: int):
+     for batch in _batched(iterable, batch_size):
+         # ctx_it_len, ctx_it_data, it_meta = itertools.tee(batch, 3)
+         max_len = max(len(el[0]) for el in batch)
+         padded_batch = []
+         meta = []
+         for el in batch:
+             sample = el[0]
+             assert isinstance(sample, torch.Tensor)
+             assert sample.ndim == 1
+             assert len(sample) > 0, "Each sample needs to have a length > 0"
+             padding = torch.full(size=(max_len - len(sample),), fill_value=torch.nan, device=sample.device)
+             padded_batch.append(torch.cat((padding, sample)))
+             meta.append(el[1])
+         yield torch.stack(padded_batch), meta
+
+
+ def get_batches(context: ContextType, batch_size: int):
+     batches = None
+     if isinstance(context, torch.Tensor):
+         if context.ndim == 1:
+             context = context.unsqueeze(0)
+         assert context.ndim == 2
+         batches = _batched_slice(context, None, batch_size)
+     elif isinstance(context, np.ndarray):
+         if context.ndim == 1:
+             context = np.expand_dims(context, axis=0)
+         assert context.ndim == 2
+         batches = map(lambda x: (torch.Tensor(x[0]), x[1]), _batched_slice(context, None, batch_size))
+     elif isinstance(context, (list, Iterable)):
+         batches = _batch_pad_iterable(map(lambda x: (torch.Tensor(x), None), context), batch_size)
+     if batches is None:
+         raise ValueError(f"Context type {type(context)} not supported! Supported Types: {ContextType}")
+     return batches
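
The padding behavior of get_batches can be checked in isolation — a small sketch:

import torch
from tirex.api_adapter.standard_adapter import get_batches

batches = get_batches([torch.tensor([1.0, 2.0]), torch.tensor([3.0])], batch_size=2)
batch, meta = next(iter(batches))
print(batch)  # tensor([[1., 2.], [nan, 3.]]) -- the shorter series is left-padded with NaN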
tirex/base.py ADDED
@@ -0,0 +1,73 @@
+ # Copyright (c) NXAI GmbH.
+ # This software may be used and distributed according to the terms of the NXAI Community License Agreement.
+
+ import os
+ from abc import ABC, abstractmethod
+ from typing import TypeVar
+
+ from huggingface_hub import hf_hub_download
+
+ T = TypeVar("T", bound="PretrainedModel")
+
+
+ def parse_hf_repo_id(path):
+     parts = path.split("/")
+     return "/".join(parts[0:2])
+
+
+ class PretrainedModel(ABC):
+     REGISTRY: dict[str, type["PretrainedModel"]] = {}
+
+     def __init_subclass__(cls, **kwargs):
+         super().__init_subclass__(**kwargs)
+         cls.REGISTRY[cls.register_name()] = cls
+
+     @classmethod
+     def from_pretrained(cls: type[T], path, device: str = "cuda:0", hf_kwargs=None, ckp_kwargs=None) -> T:
+         if hf_kwargs is None:
+             hf_kwargs = {}
+         if ckp_kwargs is None:
+             ckp_kwargs = {}
+         if os.path.exists(path):
+             print("Loading weights from local directory")
+             checkpoint_path = path
+         else:
+             repo_id = parse_hf_repo_id(path)
+             checkpoint_path = hf_hub_download(repo_id=repo_id, filename="model.ckpt", **hf_kwargs)
+         model = cls.load_from_checkpoint(checkpoint_path, map_location=device, **ckp_kwargs)
+         model.after_load_from_checkpoint()
+         return model
+
+     @classmethod
+     @abstractmethod
+     def register_name(cls) -> str:
+         pass
+
+     def after_load_from_checkpoint(self):
+         pass
+
+
+ def load_model(path: str, device: str = "cuda:0", hf_kwargs=None, ckp_kwargs=None) -> PretrainedModel:
+     """Load a TiRex model from the Hugging Face Hub or a local checkpoint.
+
+     Args:
+         path (str): Hugging Face path to the model (e.g. NX-AI/TiRex) or a local checkpoint path.
+         device (str, optional): The device on which to load the model (e.g., "cuda:0", "cpu").
+             If you want to use "cpu" you need to deactivate the sLSTM CUDA kernels (check the repository FAQ!).
+         hf_kwargs (dict, optional): Keyword arguments to pass to the Hugging Face Hub download method.
+         ckp_kwargs (dict, optional): Keyword arguments to pass when loading the checkpoint.
+
+     Returns:
+         PretrainedModel: The loaded model.
+
+     Examples:
+         model: ForecastModel = load_model("NX-AI/TiRex")
+     """
+     try:
+         _, model_id = parse_hf_repo_id(path).split("/")
+     except ValueError:
+         raise ValueError(f"Invalid model path {path}")
+     model_cls = PretrainedModel.REGISTRY.get(model_id, None)
+     if model_cls is None:
+         raise ValueError(f"Invalid model id {model_id}")
+     return model_cls.from_pretrained(path, device=device, hf_kwargs=hf_kwargs, ckp_kwargs=ckp_kwargs)
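
The registry wiring works through __init_subclass__: merely defining a subclass makes it loadable by name. A hypothetical subclass to illustrate (MyModel is not part of this commit):

from tirex.base import PretrainedModel

class MyModel(PretrainedModel):
    @classmethod
    def register_name(cls) -> str:
        return "MyModel"

# Defining the class is enough: PretrainedModel.REGISTRY now maps "MyModel" -> MyModel,
# so load_model("some-org/MyModel") would dispatch to MyModel.from_pretrained(...).
assert PretrainedModel.REGISTRY["MyModel"] is MyModel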
tirex/models/__init__.py ADDED
@@ -0,0 +1,2 @@
+ # Copyright (c) NXAI GmbH.
+ # This software may be used and distributed according to the terms of the NXAI Community License Agreement.
tirex/models/components.py ADDED
@@ -0,0 +1,147 @@
+ # Copyright (c) NXAI GmbH.
+ # This software may be used and distributed according to the terms of the NXAI Community License Agreement.
+
+
+ from dataclasses import dataclass, field
+ from typing import Any
+
+ import torch
+
+ SCALER_STATE = "scaler_state"
+
+
+ class ResidualBlock(torch.nn.Module):
+     def __init__(
+         self,
+         in_dim: int,
+         h_dim: int,
+         out_dim: int,
+         dropout: float = 0,
+     ) -> None:
+         super().__init__()
+         self.dropout = torch.nn.Dropout(dropout)
+         self.hidden_layer = torch.nn.Linear(in_dim, h_dim)
+         self.output_layer = torch.nn.Linear(h_dim, out_dim)
+         self.residual_layer = torch.nn.Linear(in_dim, out_dim)
+         self.act = torch.nn.ReLU()
+
+     def forward(self, x: torch.Tensor):
+         hid = self.act(self.hidden_layer(x))
+         out = self.output_layer(hid)
+         res = self.residual_layer(x)
+         out = out + res
+         return out
+
+
+ @dataclass
+ class StandardScaler:
+     eps: float = 1e-5
+     nan_loc: float = 0.0
+
+     def scale(
+         self,
+         x: torch.Tensor,
+         loc_scale: tuple[torch.Tensor, torch.Tensor] | None = None,
+     ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
+         if loc_scale is None:
+             loc = torch.nan_to_num(torch.nanmean(x, dim=-1, keepdim=True), nan=self.nan_loc)
+             scale = torch.nan_to_num(torch.nanmean((x - loc).square(), dim=-1, keepdim=True).sqrt(), nan=1.0)
+             scale = torch.where(scale == 0, torch.abs(loc) + self.eps, scale)
+         else:
+             loc, scale = loc_scale
+
+         return ((x - loc) / scale), (loc, scale)
+
+     def re_scale(self, x: torch.Tensor, loc_scale: tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
+         loc, scale = loc_scale
+         return x * scale + loc
+
+
+ @dataclass
+ class _Patcher:
+     patch_size: int
+     patch_stride: int
+     left_pad: bool
+
+     def __post_init__(self):
+         assert self.patch_size % self.patch_stride == 0
+
+     def __call__(self, x: torch.Tensor) -> torch.Tensor:
+         assert x.ndim == 2
+         length = x.shape[-1]
+
+         if length < self.patch_size or (length % self.patch_stride != 0):
+             if length < self.patch_size:
+                 padding_size = (
+                     *x.shape[:-1],
+                     self.patch_size - (length % self.patch_size),
+                 )
+             else:
+                 padding_size = (
+                     *x.shape[:-1],
+                     self.patch_stride - (length % self.patch_stride),
+                 )
+             padding = torch.full(size=padding_size, fill_value=torch.nan, dtype=x.dtype, device=x.device)
+             if self.left_pad:
+                 x = torch.concat((padding, x), dim=-1)
+             else:
+                 x = torch.concat((x, padding), dim=-1)
+
+         x = x.unfold(dimension=-1, size=self.patch_size, step=self.patch_stride)
+         return x
+
+
+ @dataclass
+ class PatchedUniTokenizer:
+     patch_size: int
+     scaler: Any = field(default_factory=StandardScaler)
+     patch_stride: int | None = None
+
+     def __post_init__(self):
+         if self.patch_stride is None:
+             self.patch_stride = self.patch_size
+         self.patcher = _Patcher(self.patch_size, self.patch_stride, left_pad=True)
+
+     def context_input_transform(self, data: torch.Tensor):
+         assert data.ndim == 2
+         data, scale_state = self.scaler.scale(data)
+         return self.patcher(data), {SCALER_STATE: scale_state}
+
+     def output_transform(self, data: torch.Tensor, tokenizer_state: dict):
+         data_shape = data.shape
+         data = self.scaler.re_scale(data.reshape(data_shape[0], -1), tokenizer_state[SCALER_STATE]).view(*data_shape)
+         return data
+
+
+ class StreamToLogger:
+     """Fake file-like stream object that redirects writes to a logger
+     instance."""
+
+     def __init__(self, logger, log_level):
+         self.logger = logger
+         self.log_level = log_level
+         self.linebuf = ""  # Buffer for partial lines
+
+     def write(self, message):
+         # Filter out empty messages (often from just a newline)
+         if message.strip():
+             self.linebuf += message
+             # If the message contains a newline, process the full line
+             if "\n" in self.linebuf:
+                 lines = self.linebuf.splitlines(keepends=True)
+                 for line in lines:
+                     if line.endswith("\n"):
+                         # Log full lines without the trailing newline (logger adds its own)
+                         self.logger.log(self.log_level, line.rstrip("\n"))
+                     else:
+                         # Keep partial lines in buffer
+                         self.linebuf = line
+                         return
+                 self.linebuf = ""  # All lines processed
+         # If no newline, keep buffering
+
+     def flush(self):
+         # Log any remaining buffered content when flush is called
+         if self.linebuf.strip():
+             self.logger.log(self.log_level, self.linebuf.rstrip("\n"))
+         self.linebuf = ""
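
How the tokenizer cuts a series into patches — a sketch of the scale-then-patch round trip (toy sizes, not the model's configured patch size):

import torch
from tirex.models.components import PatchedUniTokenizer

tok = PatchedUniTokenizer(patch_size=4)
x = torch.arange(10, dtype=torch.float32).unsqueeze(0)  # [1, 10]
patches, state = tok.context_input_transform(x)
print(patches.shape)  # [1, 3, 4]: left-padded with NaN to length 12, then cut into 3 patches
restored = tok.output_transform(patches, state)  # undoes the scaling; NaN padding stays NaN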
tirex/models/mixed_stack.py ADDED
@@ -0,0 +1,143 @@
+ # Copyright (c) NXAI GmbH.
+ # This software may be used and distributed according to the terms of the NXAI Community License Agreement.
+
+
+ import os
+ from dataclasses import dataclass, field
+
+ import torch
+ from torch import nn
+ from xlstm.blocks.slstm.layer import sLSTMLayer, sLSTMLayerConfig
+ from xlstm.xlstm_large import xLSTMLargeConfig
+ from xlstm.xlstm_large.components import RMSNorm
+ from xlstm.xlstm_large.model import FeedForward, mLSTMBlock, mLSTMStateType
+
+
+ def skip_cuda():
+     return os.getenv("TIREX_NO_CUDA", "False").lower() in ("true", "1", "t")
+
+
+ def init_cell(config: xLSTMLargeConfig, block_idx, num_blocks):
+     return sLSTMLayer(
+         sLSTMLayerConfig(
+             embedding_dim=config.embedding_dim,
+             num_heads=config.num_heads,
+             conv1d_kernel_size=0,  # 0 means no convolution included
+             group_norm_weight=True,
+             dropout=0,
+             # CellConfig
+             backend="vanilla" if skip_cuda() else "cuda",
+             bias_init="powerlaw_blockdependent",
+             recurrent_weight_init="zeros",
+             num_gates=4,
+             gradient_recurrent_cut=False,
+             gradient_recurrent_clipval=None,
+             forward_clipval=None,
+             batch_size=8,  # needed?
+             _block_idx=block_idx,
+             _num_blocks=num_blocks,
+         )
+     )
+
+
+ sLSTMLayerStateType = tuple[torch.Tensor, torch.Tensor]
+ sLSTMStateType = dict[int, sLSTMLayerStateType]
+
+
+ class sLSTMBlock(nn.Module):
+     def __init__(self, config: xLSTMLargeConfig, block_idx: int, num_blocks: int):
+         super().__init__()
+         self.config = config
+         self.norm_slstm = RMSNorm(
+             num_features=config.embedding_dim,
+             eps=config.norm_eps,
+             use_weight=True,
+             use_bias=config.use_bias,
+             force_float32_reductions=config.norm_reduction_force_float32,
+         )
+         self.slstm_layer = init_cell(config, block_idx, num_blocks)
+
+         self.norm_ffn = RMSNorm(
+             num_features=config.embedding_dim,
+             eps=config.norm_eps,
+             use_weight=True,
+             use_bias=config.use_bias,
+             force_float32_reductions=config.norm_reduction_force_float32,
+         )
+         self.ffn = FeedForward(config)
+
+     def forward(
+         self, x: torch.Tensor, state: sLSTMLayerStateType | None = None
+     ) -> tuple[torch.Tensor, sLSTMLayerStateType]:
+         x_slstm = self.norm_slstm(x)
+         if state is None:
+             conv_state, slstm_state = None, None
+         else:
+             conv_state, slstm_state = state
+         x_slstm, state = self.slstm_layer(x_slstm, conv_state, slstm_state, return_last_state=True)
+         x = x + x_slstm
+
+         x_ffn = self.norm_ffn(x)
+         x_ffn = self.ffn(x_ffn)
+         x = x + x_ffn
+
+         return x, (state["conv_state"], state["slstm_state"])
+
+
+ @dataclass
+ class xLSTMMixedLargeConfig(xLSTMLargeConfig):
+     slstm_at: list[int] = field(default_factory=list)
+     all_slstm: bool = True
+
+     @property
+     def block_types(self):
+         return ["s" if i in self.slstm_at or self.all_slstm else "m" for i in range(self.num_blocks)]
+
+
+ class xLSTMMixedLargeBlockStack(nn.Module):
+     config_class = xLSTMMixedLargeConfig
+
+     def __init__(self, config: xLSTMMixedLargeConfig):
+         super().__init__()
+         self.config = config
+
+         self.blocks = nn.ModuleList(
+             [
+                 sLSTMBlock(config, block_idx=i, num_blocks=config.num_blocks) if t == "s" else mLSTMBlock(config)
+                 for i, t in enumerate(config.block_types)
+             ]
+         )
+
+         if self.config.add_out_norm:
+             self.out_norm = RMSNorm(
+                 num_features=config.embedding_dim,
+                 eps=config.norm_eps,
+                 use_weight=True,
+                 use_bias=config.use_bias,
+                 force_float32_reductions=config.norm_reduction_force_float32,
+             )
+         else:
+             self.out_norm = nn.Identity()
+
+     def forward(
+         self, x: torch.Tensor, state: mLSTMStateType | sLSTMStateType | None = None
+     ) -> tuple[torch.Tensor, mLSTMStateType]:
+         if state is None:
+             state = {i: None for i in range(len(self.blocks))}
+
+         for i, block in enumerate(self.blocks):
+             block_state = state[i]
+             x, block_state_new = block(x, block_state)
+
+             if block_state is None:
+                 state[i] = block_state_new
+             else:
+                 pass
+                 ## layer state is a tuple of three tensors: c, n, m
+                 ## we update the state in place in order to avoid creating new tensors
+                 # for state_idx in range(len(block_state)):
+                 #     state[i][state_idx].copy_(block_state_new[state_idx])
+
+         x = self.out_norm(x)
+
+         return x, state
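
skip_cuda() gates the sLSTM backend on the TIREX_NO_CUDA environment variable, so it must be set before the model is built. A sketch for running without the CUDA kernels (with the caveats from the warnings in tirex.py):

import os
os.environ["TIREX_NO_CUDA"] = "1"  # selects backend="vanilla" in init_cell

from tirex import load_model
model = load_model("NX-AI/TiRex", device="cpu")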
tirex/models/predict_utils.py ADDED
@@ -0,0 +1,72 @@
+ # Copyright (c) NXAI GmbH.
+ # This software may be used and distributed according to the terms of the NXAI Community License Agreement.
+
+
+ import logging
+ from abc import abstractmethod
+
+ import torch
+
+ from ..api_adapter.forecast import ForecastModel
+
+ LOGGER = logging.getLogger()
+
+
+ class TensorQuantileUniPredictMixin(ForecastModel):
+     @abstractmethod
+     def _forecast_tensor(
+         self,
+         context: torch.Tensor,
+         prediction_length: int | None = None,
+         **predict_kwargs,
+     ) -> torch.Tensor:
+         pass
+
+     @property
+     @abstractmethod
+     def quantiles(self):
+         pass
+
+     def _forecast_quantiles(
+         self,
+         context: torch.Tensor,
+         prediction_length: int | None = None,
+         quantile_levels: list[float] = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9],
+         output_device: str = "cpu",
+         auto_cast: bool = False,
+         **predict_kwargs,
+     ) -> tuple[torch.Tensor, torch.Tensor]:
+         with torch.autocast(device_type=self.device.type, enabled=auto_cast):
+             predictions = self._forecast_tensor(
+                 context=context, prediction_length=prediction_length, **predict_kwargs
+             ).detach()
+         predictions = predictions.to(torch.device(output_device)).swapaxes(1, 2)
+
+         training_quantile_levels = list(self.quantiles)
+
+         if set(quantile_levels).issubset(set(training_quantile_levels)):
+             quantiles = predictions[..., [training_quantile_levels.index(q) for q in quantile_levels]]
+         else:
+             if min(quantile_levels) < min(training_quantile_levels) or max(quantile_levels) > max(
+                 training_quantile_levels
+             ):
+                 logging.warning(
+                     f"Requested quantile levels ({quantile_levels}) fall outside the range of "
+                     f"quantiles the model was trained on ({training_quantile_levels}). "
+                     "Predictions for out-of-range quantiles will be clamped to the nearest "
+                     "boundary of the trained quantiles (i.e., minimum or maximum trained level). "
+                     "This can significantly impact prediction accuracy, especially for extreme quantiles."
+                 )
+             # Interpolate quantiles
+             augmented_predictions = torch.cat(
+                 [predictions[..., [0]], predictions, predictions[..., [-1]]],
+                 dim=-1,
+             )
+             quantiles = torch.quantile(
+                 augmented_predictions,
+                 q=torch.tensor(quantile_levels, dtype=augmented_predictions.dtype),
+                 dim=-1,
+             ).permute(1, 2, 0)
+         # median as mean
+         mean = predictions[:, :, training_quantile_levels.index(0.5)]
+         return quantiles, mean
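
The interpolation branch duplicates the outermost trained quantiles before calling torch.quantile, so a requested level maps linearly onto the trained grid and out-of-range levels clamp to the boundary. A toy check with the default 9 levels:

import torch

preds = torch.linspace(1.0, 9.0, 9).view(1, 1, 9)  # predictions for levels 0.1 .. 0.9
augmented = torch.cat([preds[..., [0]], preds, preds[..., [-1]]], dim=-1)  # 11 values
q = torch.quantile(augmented, q=torch.tensor([0.15]), dim=-1).permute(1, 2, 0)
print(q)  # ~1.5: halfway between the 0.1 prediction (1.0) and the 0.2 prediction (2.0)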
tirex/models/tirex.py ADDED
@@ -0,0 +1,231 @@
+ # Copyright (c) NXAI GmbH.
+ # This software may be used and distributed according to the terms of the NXAI Community License Agreement.
+
+ import logging
+ import warnings
+ from contextlib import redirect_stdout
+ from dataclasses import dataclass
+
+ import lightning as L
+ import torch
+ from dacite import Config, from_dict
+
+ from ..base import PretrainedModel
+ from .components import PatchedUniTokenizer, ResidualBlock, StreamToLogger
+ from .mixed_stack import skip_cuda, xLSTMMixedLargeBlockStack, xLSTMMixedLargeConfig
+ from .predict_utils import TensorQuantileUniPredictMixin
+
+ LOGGER = logging.getLogger()
+
+
+ @dataclass
+ class TiRexZeroConfig:
+     input_patch_size: int
+     output_patch_size: int
+     quantiles: list[float]
+     block_kwargs: dict
+     input_ff_dim: int
+
+
+ class TiRexZero(L.LightningModule, PretrainedModel, TensorQuantileUniPredictMixin):
+     def __init__(self, model_config: dict, train_ctx_len=None):
+         super().__init__()
+         self.model_config: TiRexZeroConfig = from_dict(TiRexZeroConfig, model_config, config=Config(strict=True))
+         assert self.model_config.input_patch_size == self.model_config.output_patch_size
+         self.train_ctx_len = train_ctx_len
+
+         # Block Stack
+         self.nan_mask_value = 0
+         self.block_stack, resolved_config = self.init_block(self.model_config.block_kwargs)
+         self.model_config.block_kwargs = resolved_config
+
+         # Input Layer
+         self.input_patch_embedding = ResidualBlock(
+             in_dim=self.model_config.input_patch_size * 2,
+             h_dim=self.model_config.input_ff_dim,
+             out_dim=self.model_config.block_kwargs.embedding_dim,
+         )
+         self.tokenizer = PatchedUniTokenizer(
+             patch_size=self.model_config.input_patch_size,
+         )
+
+         # Output Layer
+         self.num_quantiles = len(self.model_config.quantiles)
+         quantiles = torch.tensor(self.model_config.quantiles)
+         self.register_buffer("quantiles", quantiles, persistent=False)
+
+         self.output_patch_embedding = ResidualBlock(
+             in_dim=self.model_config.block_kwargs.embedding_dim,
+             h_dim=self.model_config.input_ff_dim,
+             out_dim=self.num_quantiles * self.model_config.output_patch_size,
+         )
+
+         self.save_hyperparameters()
+
+     @classmethod
+     def register_name(cls):
+         return "TiRex"
+
+     def init_block(self, block_kwargs):
+         config = from_dict(xLSTMMixedLargeConfig, block_kwargs)
+         log_redirect = StreamToLogger(LOGGER, logging.INFO)
+         with redirect_stdout(log_redirect):  # avoid excessive print statements of sLSTM compile
+             model = xLSTMMixedLargeBlockStack(config)
+         return model, config
+
+     @property
+     def quantiles(self):
+         return self.model.quantiles
+
+     def _forward_model_tokenized(
+         self,
+         input_token,
+         input_mask=None,
+         rollouts=1,
+     ):
+         input_mask = (
+             input_mask.to(input_token.dtype)
+             if input_mask is not None
+             else torch.isnan(input_token).logical_not().to(input_token.dtype)
+         )
+         assert rollouts >= 1
+         bs, numb_ctx_token, token_dim = input_token.shape
+         if rollouts > 1:
+             input_token = torch.cat(
+                 (
+                     input_token,
+                     torch.full(
+                         (bs, rollouts - 1, token_dim),
+                         fill_value=torch.nan,
+                         device=input_token.device,
+                         dtype=input_token.dtype,
+                     ),
+                 ),
+                 dim=1,
+             )
+             input_mask = torch.cat(
+                 (
+                     input_mask,
+                     torch.full(
+                         (bs, rollouts - 1, token_dim),
+                         fill_value=False,
+                         device=input_mask.device,
+                         dtype=input_mask.dtype,
+                     ),
+                 ),
+                 dim=1,
+             )
+         input_token = torch.nan_to_num(input_token, nan=self.nan_mask_value)
+         input_embeds = self.input_patch_embedding(torch.cat((input_token, input_mask), dim=2))
+
+         # hidden_states = []
+         # for rollout in range(rollout):
+         x = self.block_stack(input_embeds)
+         if isinstance(x, tuple):
+             hidden_states = x[0]
+         else:
+             hidden_states = x
+
+         quantile_preds = self.output_patch_embedding(hidden_states)
+         quantile_preds = torch.unflatten(quantile_preds, -1, (self.num_quantiles, self.model_config.output_patch_size))
+         quantile_preds = torch.transpose(quantile_preds, 1, 2)  # switch quantile and num_token dimension
+         # quantile_preds: [batch_size, num_quantiles, num_token, output_patch_size]
+
+         return quantile_preds, hidden_states
+
+     @torch.inference_mode()
+     def _forecast_tensor(
+         self,
+         context: torch.Tensor,
+         prediction_length: int | None = None,
+         max_context: int | None = None,
+         max_accelerated_rollout_steps: int = 1,
+     ) -> torch.Tensor:
+         predictions = []
+         if prediction_length is None:
+             prediction_length = self.tokenizer.patch_size
+         remaining = -(prediction_length // -self.tokenizer.patch_size)  # ceil division
+         if max_context is None:
+             max_context = self.train_ctx_len
+         min_context = max(self.train_ctx_len, max_context)
+
+         context = context.to(
+             device=self.device,
+             dtype=torch.float32,
+         )
+         while remaining > 0:
+             if context.shape[-1] > max_context:
+                 context = context[..., -max_context:]
+             if context.shape[-1] < min_context:
+                 pad = torch.full(
+                     (context.shape[0], min_context - context.shape[-1]),
+                     fill_value=torch.nan,
+                     device=context.device,
+                     dtype=context.dtype,
+                 )
+                 context = torch.concat((pad, context), dim=1)
+             tokenized_tensor, tokenizer_state = self.tokenizer.context_input_transform(context)
+             fut_rollouts = min(remaining, max_accelerated_rollout_steps)
+             with torch.no_grad():
+                 prediction, _ = self._forward_model_tokenized(input_token=tokenized_tensor, rollouts=fut_rollouts)
+             prediction = prediction[:, :, -fut_rollouts:, :].to(tokenized_tensor)  # predicted token
+             # [bs, num_quantiles, num_predicted_token, output_patch_size]
+             prediction = self.tokenizer.output_transform(prediction, tokenizer_state)
+             prediction = prediction.flatten(start_dim=2)
+
+             predictions.append(prediction)
+             remaining -= fut_rollouts
+
+             if remaining <= 0:
+                 break
+
+             context = torch.cat([context, torch.full_like(prediction[:, 0, :], fill_value=torch.nan)], dim=-1)
+
+         return torch.cat(predictions, dim=-1)[..., :prediction_length].to(
+             dtype=torch.float32,
+         )
+
+     def on_load_checkpoint(self, checkpoint: dict) -> None:
+         state_dict = checkpoint["state_dict"]
+         load_vanilla_kernel = skip_cuda()
+         if load_vanilla_kernel:
+             warnings.warn(
+                 "You use TiRex without sLSTM CUDA kernels! This might slow down the model considerably and might degrade forecasting results! "
+                 "Set the environment variable TIREX_NO_CUDA to 0 to avoid this!"
+             )
+         block_kwargs = self.model_config.block_kwargs
+         head_dim = block_kwargs.embedding_dim // block_kwargs.num_heads
+         num_gates = 4
+         new_state_dict = {}
+         for k, v in state_dict.items():
+             if "slstm_layer.slstm_cell._recurrent_kernel_" in k:
+                 new_state_dict[k] = (
+                     v.reshape(
+                         block_kwargs.num_heads,
+                         head_dim,
+                         num_gates,
+                         head_dim,
+                     )
+                     .permute(0, 2, 3, 1)
+                     .reshape(
+                         block_kwargs.num_heads,
+                         num_gates * head_dim,
+                         head_dim,
+                     )
+                 )
+                 # new_state_dict[k] = v.permute(0, 2, 1)
+             elif "slstm_layer.slstm_cell._bias_" in k:
+                 new_state_dict[k] = (
+                     v.reshape(block_kwargs.num_heads, num_gates, head_dim).permute(1, 0, 2).reshape(-1)
+                 )
+             else:
+                 new_state_dict[k] = v
+         checkpoint["state_dict"] = new_state_dict
+
+     def after_load_from_checkpoint(self):
+         if not skip_cuda() and self.device.type != "cuda":
+             warnings.warn(
+                 f"You use TiRex with sLSTM CUDA kernels BUT DO NOT LOAD THE MODEL ON A CUDA DEVICE (device type is {self.device.type})! "
+                 "This is not supported and calls to the model will likely lead to an error unless you move your model to a CUDA device! "
+                 "If you want to run TiRex on CPU you need to disable the sLSTM CUDA kernels, but be aware of the downsides (see FAQ)."
+             )
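
End-to-end, _forecast_tensor rolls the model out patch by patch: remaining = -(prediction_length // -patch_size) is ceil division, and each loop step appends one predicted patch to the output while extending the context with NaN placeholders for the next step. A closing sketch of a longer-horizon forecast (assumes the NX-AI/TiRex checkpoint on a CUDA device):

import torch
from tirex import load_model

model = load_model("NX-AI/TiRex")
ctx = torch.rand(1, 512)
# prediction_length is forwarded through **predict_kwargs to _forecast_tensor;
# the rollout runs ceil(64 / output_patch_size) steps and truncates to 64.
quantiles, mean = model.forecast(ctx, prediction_length=64)
print(quantiles.shape)  # [1, 64, 9]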