Commit 14d91dc · Parent(s): 810611f
Nikita committed

added tirex as model

Browse files:
- .DS_Store +0 -0
- tirex/__init__.py +8 -0
- tirex/api_adapter/__init__.py +2 -0
- tirex/api_adapter/forecast.py +209 -0
- tirex/api_adapter/gluon.py +48 -0
- tirex/api_adapter/hf_data.py +38 -0
- tirex/api_adapter/standard_adapter.py +67 -0
- tirex/base.py +73 -0
- tirex/models/__init__.py +2 -0
- tirex/models/components.py +147 -0
- tirex/models/mixed_stack.py +143 -0
- tirex/models/predict_utils.py +72 -0
- tirex/models/tirex.py +231 -0
.DS_Store
CHANGED
Binary files a/.DS_Store and b/.DS_Store differ
tirex/__init__.py
ADDED
@@ -0,0 +1,8 @@
# Copyright (c) NXAI GmbH.
# This software may be used and distributed according to the terms of the NXAI Community License Agreement.

from .api_adapter.forecast import ForecastModel
from .base import load_model
from .models.tirex import TiRexZero

__all__ = ["load_model", "ForecastModel"]
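A minimal usage sketch of the public API exported here (hypothetical values; assumes the package and its dependencies are installed and a CUDA device is available, see the CPU caveat in tirex/base.py below):

    import torch
    from tirex import load_model, ForecastModel

    model: ForecastModel = load_model("NX-AI/TiRex")
    context = torch.rand(4, 512)  # 4 series with 512 historical steps each
    quantiles, mean = model.forecast(context, prediction_length=64)
    print(quantiles.shape, mean.shape)  # torch.Size([4, 64, 9]) torch.Size([4, 64])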
tirex/api_adapter/__init__.py
ADDED
@@ -0,0 +1,2 @@
# Copyright (c) NXAI GmbH.
# This software may be used and distributed according to the terms of the NXAI Community License Agreement.
tirex/api_adapter/forecast.py
ADDED
@@ -0,0 +1,209 @@
# Copyright (c) NXAI GmbH.
# This software may be used and distributed according to the terms of the NXAI Community License Agreement.

from abc import ABC, abstractmethod
from typing import Literal

import torch

from .standard_adapter import ContextType, get_batches

try:
    from .gluon import format_gluonts_output, get_gluon_batches

    _GLUONTS_AVAILABLE = True
except ImportError:
    _GLUONTS_AVAILABLE = False

try:
    from .hf_data import get_hfdata_batches

    _HF_DATASETS_AVAILABLE = True
except ImportError:
    _HF_DATASETS_AVAILABLE = False


DEF_TARGET_COLUMN = "target"
DEF_META_COLUMNS = ("start", "item_id")


def _format_output(
    quantiles: torch.Tensor,
    means: torch.Tensor,
    sample_meta: list[dict],
    quantile_levels: list[float],
    output_type: Literal["torch", "numpy", "gluonts"],
):
    if output_type == "torch":
        return quantiles.cpu(), means.cpu()
    elif output_type == "numpy":
        return quantiles.cpu().numpy(), means.cpu().numpy()
    elif output_type == "gluonts":
        if not _GLUONTS_AVAILABLE:
            raise ValueError("output_type gluonts needs GluonTS but GluonTS is not available (not installed)!")
        return format_gluonts_output(quantiles, means, sample_meta, quantile_levels)
    else:
        raise ValueError(f"Invalid output type: {output_type}")


def _as_generator(batches, fc_func, quantile_levels, output_type, **predict_kwargs):
    for batch_ctx, batch_meta in batches:
        quantiles, mean = fc_func(batch_ctx, **predict_kwargs)
        yield _format_output(
            quantiles=quantiles,
            means=mean,
            sample_meta=batch_meta,
            quantile_levels=quantile_levels,
            output_type=output_type,
        )


def _gen_forecast(fc_func, batches, output_type, quantile_levels, yield_per_batch, **predict_kwargs):
    if yield_per_batch:
        return _as_generator(batches, fc_func, quantile_levels, output_type, **predict_kwargs)

    prediction_q = []
    prediction_m = []
    sample_meta = []
    for batch_ctx, batch_meta in batches:
        quantiles, mean = fc_func(batch_ctx, **predict_kwargs)
        prediction_q.append(quantiles)
        prediction_m.append(mean)
        sample_meta.extend(batch_meta)

    prediction_q = torch.cat(prediction_q, dim=0)
    prediction_m = torch.cat(prediction_m, dim=0)

    return _format_output(
        quantiles=prediction_q,
        means=prediction_m,
        sample_meta=sample_meta,
        quantile_levels=quantile_levels,
        output_type=output_type,
    )


def _common_forecast_doc():
    common_doc = """
    This method takes historical context data as input and outputs probabilistic forecasts.

    Args:
        output_type (Literal["torch", "numpy", "gluonts"], optional):
            Specifies the desired format of the returned forecasts:
            - "torch": Returns forecasts as `torch.Tensor` objects [batch_dim, forecast_len, |quantile_levels|]
            - "numpy": Returns forecasts as `numpy.ndarray` objects [batch_dim, forecast_len, |quantile_levels|]
            - "gluonts": Returns forecasts as a list of GluonTS `Forecast` objects.
            Defaults to "torch".

        batch_size (int, optional): The number of time series instances to process concurrently by the model.
            Defaults to 512. Must be >= 1.

        quantile_levels (List[float], optional): Quantile levels for which predictions should be generated.
            Defaults to (0.1, 0.2, ..., 0.9).

        yield_per_batch (bool, optional): If `True`, the method acts as a generator, yielding
            forecasts batch by batch as they are computed.
            Defaults to `False`.

        **predict_kwargs: Additional keyword arguments that are passed directly to the underlying
            prediction mechanism of the pre-trained model. Refer to the model's
            internal prediction method documentation for available options.

    Returns:
        The return type depends on `output_type` and `yield_per_batch`:
        - If `yield_per_batch` is `True`: An iterator that yields forecasts. Each yielded item
          corresponds to a batch of forecasts in the format specified by `output_type`.
        - If `yield_per_batch` is `False`: A single object containing all forecasts.
            - If `output_type="torch"`: `Tuple[torch.Tensor, torch.Tensor]` (quantiles, mean).
            - If `output_type="numpy"`: `Tuple[numpy.ndarray, numpy.ndarray]` (quantiles, mean).
            - If `output_type="gluonts"`: A `List[gluonts.model.forecast.Forecast]` of all forecasts.
    """
    return common_doc


class ForecastModel(ABC):
    @abstractmethod
    def _forecast_quantiles(self, batch, **predict_kwargs):
        pass

    def forecast(
        self,
        context: ContextType,
        output_type: Literal["torch", "numpy", "gluonts"] = "torch",
        batch_size: int = 512,
        quantile_levels: list[float] = (0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9),
        yield_per_batch: bool = False,
        **predict_kwargs,
    ):
        f"""
        {_common_forecast_doc()}
        Args:
            context (ContextType): The historical "context" data of the time series:
                - `torch.Tensor`: 1D `[context_length]` or 2D `[batch_dim, context_length]` tensor
                - `np.ndarray`: 1D `[context_length]` or 2D `[batch_dim, context_length]` array
                - `List[torch.Tensor]`: List of 1D tensors (samples with different lengths get padded per batch)
                - `List[np.ndarray]`: List of 1D arrays (samples with different lengths get padded per batch)
        """
        assert batch_size >= 1, "Batch size must be >= 1"
        batches = get_batches(context, batch_size)
        return _gen_forecast(
            self._forecast_quantiles, batches, output_type, quantile_levels, yield_per_batch, **predict_kwargs
        )

    def forecast_gluon(
        self,
        gluonDataset,
        output_type: Literal["torch", "numpy", "gluonts"] = "torch",
        batch_size: int = 512,
        quantile_levels: list[float] = (0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9),
        yield_per_batch: bool = False,
        data_kwargs: dict = {},
        **predict_kwargs,
    ):
        f"""
        {_common_forecast_doc()}

        Args:
            gluonDataset (gluonts.dataset.common.Dataset): A GluonTS dataset object containing the
                historical time series data.

            data_kwargs (dict, optional): Additional keyword arguments passed to the
                GluonTS data processing function.
        """
        assert batch_size >= 1, "Batch size must be >= 1"
        if not _GLUONTS_AVAILABLE:
            raise ValueError("forecast_gluon needs GluonTS but GluonTS is not available (not installed)!")
        batches = get_gluon_batches(gluonDataset, batch_size, **data_kwargs)
        return _gen_forecast(
            self._forecast_quantiles, batches, output_type, quantile_levels, yield_per_batch, **predict_kwargs
        )

    def forecast_hfdata(
        self,
        hf_dataset,
        output_type: Literal["torch", "numpy", "gluonts"] = "torch",
        batch_size: int = 512,
        quantile_levels: list[float] = (0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9),
        yield_per_batch: bool = False,
        data_kwargs: dict = {},
        **predict_kwargs,
    ):
        f"""
        {_common_forecast_doc()}

        Args:
            hf_dataset (datasets.Dataset): A Hugging Face `Dataset` object containing the
                historical time series data.

            data_kwargs (dict, optional): Additional keyword arguments passed to the
                datasets data processing function.
        """
        assert batch_size >= 1, "Batch size must be >= 1"
        if not _HF_DATASETS_AVAILABLE:
            raise ValueError(
                "forecast_hfdata needs Hugging Face datasets but datasets is not available (not installed)!"
            )
        batches = get_hfdata_batches(hf_dataset, batch_size, **data_kwargs)
        return _gen_forecast(
            self._forecast_quantiles, batches, output_type, quantile_levels, yield_per_batch, **predict_kwargs
        )
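A sketch of the streaming path with ragged inputs (made-up data; `model` is any `ForecastModel`, e.g. the one loaded in the sketch above). Lists of 1D tensors exercise the per-batch NaN padding performed by the standard adapter:

    import torch

    series = [torch.rand(300), torch.rand(512), torch.rand(128)]  # different lengths
    for q_batch, mean_batch in model.forecast(series, batch_size=2, yield_per_batch=True):
        # shorter series are NaN-left-padded to the longest series in their batch
        print(q_batch.shape, mean_batch.shape)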
tirex/api_adapter/gluon.py
ADDED
@@ -0,0 +1,48 @@
# Copyright (c) NXAI GmbH.
# This software may be used and distributed according to the terms of the NXAI Community License Agreement.

import pandas as pd
import torch
from gluonts.dataset.common import Dataset
from gluonts.dataset.field_names import FieldName
from gluonts.model.forecast import QuantileForecast

from .standard_adapter import _batch_pad_iterable

DEF_TARGET_COLUMN = FieldName.TARGET  # "target"
DEF_META_COLUMNS = (FieldName.START, FieldName.ITEM_ID)


def _get_gluon_ts_map(**gluon_kwargs):
    target_col = gluon_kwargs.get("target_column", DEF_TARGET_COLUMN)
    meta_columns = gluon_kwargs.get("meta_columns", DEF_META_COLUMNS)

    def extract_gluon(series):
        ctx = torch.Tensor(series[target_col])
        meta = {k: series[k] for k in meta_columns if k in series}
        meta["length"] = len(ctx)
        return ctx, meta

    return extract_gluon


def get_gluon_batches(gluonDataset: Dataset, batch_size: int, **gluon_kwargs):
    return _batch_pad_iterable(map(_get_gluon_ts_map(**gluon_kwargs), gluonDataset), batch_size)


def format_gluonts_output(quantile_forecasts: torch.Tensor, mean_forecasts, meta: list[dict], quantile_levels):
    forecasts = []
    for i in range(quantile_forecasts.shape[0]):
        start_date = meta[i].get(FieldName.START, pd.Period("01-01-2000", freq=meta[i].get("freq", "h")))
        start_date += meta[i].get("length", 0)
        forecasts.append(
            QuantileForecast(
                forecast_arrays=torch.cat((quantile_forecasts[i], mean_forecasts[i].unsqueeze(1)), dim=1)
                .T.cpu()
                .numpy(),
                start_date=start_date,
                item_id=meta[i].get(FieldName.ITEM_ID, None),
                forecast_keys=list(map(str, quantile_levels)) + ["mean"],
            )
        )
    return forecasts
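A hedged sketch of running the GluonTS path end to end (`ListDataset` and the field names are standard GluonTS; the values are made up):

    from gluonts.dataset.common import ListDataset

    ds = ListDataset(
        [{"start": "2024-01-01 00:00", "target": [1.0, 2.0, 3.0] * 64, "item_id": "sensor_0"}],
        freq="h",
    )
    forecasts = model.forecast_gluon(ds, output_type="gluonts")
    print(forecasts[0].quantile(0.5))  # median forecast as a numpy array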
tirex/api_adapter/hf_data.py
ADDED
@@ -0,0 +1,38 @@
# Copyright (c) NXAI GmbH.
# This software may be used and distributed according to the terms of the NXAI Community License Agreement.

import datasets
import torch

from .standard_adapter import _batch_pad_iterable

DEF_TARGET_COLUMN = "target"


def _get_hf_map(dataset: datasets.Dataset, **hf_kwargs):
    target_col = hf_kwargs.get("target_column", DEF_TARGET_COLUMN)
    meta_columns = hf_kwargs.get("meta_columns", ())

    columns_to_pass = [target_col] + list(meta_columns)
    remove_cols = [col for col in dataset.column_names if col not in columns_to_pass]
    dataset = (
        dataset.with_format("torch")
        .remove_columns(remove_cols)
        .cast_column(target_col, datasets.Sequence(datasets.Value("float32")))
    )

    def yield_batch_tuples(sample: dict) -> tuple[torch.Tensor, dict]:
        context_data = sample[target_col]
        if context_data.ndim > 1:
            context_data = context_data.squeeze()
        assert context_data.ndim == 1
        meta = {k: sample[k] for k in meta_columns if k in sample}
        meta["length"] = len(context_data)
        return context_data, meta

    return dataset, yield_batch_tuples


def get_hfdata_batches(hf_dataset: datasets.Dataset, batch_size: int, **hf_kwargs):
    dataset, map_func = _get_hf_map(hf_dataset, **hf_kwargs)
    return _batch_pad_iterable(map(map_func, dataset), batch_size)
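A hedged sketch of the Hugging Face path (made-up values; ragged rows are fine because batching pads per batch):

    import datasets

    ds = datasets.Dataset.from_dict({"target": [[1.0, 2.0, 3.0] * 64, [4.0, 5.0] * 100]})
    quantiles, mean = model.forecast_hfdata(ds, batch_size=2)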
tirex/api_adapter/standard_adapter.py
ADDED
@@ -0,0 +1,67 @@
# Copyright (c) NXAI GmbH.
# This software may be used and distributed according to the terms of the NXAI Community License Agreement.

import itertools
from collections.abc import Iterable, Iterator, Sequence
from typing import Union

import numpy as np
import torch

ContextType = Union[
    torch.Tensor,
    np.ndarray,
    list[torch.Tensor],
    list[np.ndarray],
]


def _batched_slice(full_batch, full_meta: list[dict] | None, batch_size: int) -> Iterator[tuple[Sequence, list[dict]]]:
    if len(full_batch) <= batch_size:
        yield full_batch, full_meta if full_meta is not None else [{} for _ in range(len(full_batch))]
    else:
        for i in range(0, len(full_batch), batch_size):
            batch = full_batch[i : i + batch_size]
            yield batch, (full_meta[i : i + batch_size] if full_meta is not None else [{} for _ in range(len(batch))])


def _batched(iterable: Iterable, n: int):
    it = iter(iterable)
    while batch := tuple(itertools.islice(it, n)):
        yield batch


def _batch_pad_iterable(iterable: Iterable[tuple[torch.Tensor, dict]], batch_size: int):
    for batch in _batched(iterable, batch_size):
        max_len = max(len(el[0]) for el in batch)
        padded_batch = []
        meta = []
        for el in batch:
            sample = el[0]
            assert isinstance(sample, torch.Tensor)
            assert sample.ndim == 1
            assert len(sample) > 0, "Each sample needs to have a length > 0"
            # left-pad shorter samples with NaN so the whole batch shares max_len
            padding = torch.full(size=(max_len - len(sample),), fill_value=torch.nan, device=sample.device)
            padded_batch.append(torch.cat((padding, sample)))
            meta.append(el[1])
        yield torch.stack(padded_batch), meta


def get_batches(context: ContextType, batch_size: int):
    batches = None
    if isinstance(context, torch.Tensor):
        if context.ndim == 1:
            context = context.unsqueeze(0)
        assert context.ndim == 2
        batches = _batched_slice(context, None, batch_size)
    elif isinstance(context, np.ndarray):
        if context.ndim == 1:
            context = np.expand_dims(context, axis=0)
        assert context.ndim == 2
        batches = map(lambda x: (torch.Tensor(x[0]), x[1]), _batched_slice(context, None, batch_size))
    elif isinstance(context, (list, Iterable)):
        batches = _batch_pad_iterable(map(lambda x: (torch.Tensor(x), None), context), batch_size)
    if batches is None:
        raise ValueError(f"Context type {type(context)} not supported! Supported Types: {ContextType}")
    return batches
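To make the padding convention concrete, a standalone check of `get_batches` (no model involved; the import path assumes the package layout of this commit):

    import torch
    from tirex.api_adapter.standard_adapter import get_batches

    batches = get_batches([torch.rand(5), torch.rand(8), torch.rand(3)], batch_size=2)
    for ctx, meta in batches:
        print(ctx.shape)  # first batch: [2, 8] (short sample NaN-left-padded), then [1, 3]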
tirex/base.py
ADDED
@@ -0,0 +1,73 @@
# Copyright (c) NXAI GmbH.
# This software may be used and distributed according to the terms of the NXAI Community License Agreement.

import os
from abc import ABC, abstractmethod
from typing import TypeVar

from huggingface_hub import hf_hub_download

T = TypeVar("T", bound="PretrainedModel")


def parse_hf_repo_id(path):
    parts = path.split("/")
    return "/".join(parts[0:2])


class PretrainedModel(ABC):
    REGISTRY: dict[str, "PretrainedModel"] = {}

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        cls.REGISTRY[cls.register_name()] = cls

    @classmethod
    def from_pretrained(cls: type[T], path, device: str = "cuda:0", hf_kwargs=None, ckp_kwargs=None) -> T:
        if hf_kwargs is None:
            hf_kwargs = {}
        if ckp_kwargs is None:
            ckp_kwargs = {}
        if os.path.exists(path):
            print("Loading weights from local directory")
            checkpoint_path = path
        else:
            repo_id = parse_hf_repo_id(path)
            checkpoint_path = hf_hub_download(repo_id=repo_id, filename="model.ckpt", **hf_kwargs)
        model = cls.load_from_checkpoint(checkpoint_path, map_location=device, **ckp_kwargs)
        model.after_load_from_checkpoint()
        return model

    @classmethod
    @abstractmethod
    def register_name(cls) -> str:
        pass

    def after_load_from_checkpoint(self):
        pass


def load_model(path: str, device: str = "cuda:0", hf_kwargs=None, ckp_kwargs=None) -> PretrainedModel:
    """Loads a TiRex model. This function attempts to load the specified model.

    Args:
        path (str): Hugging Face path to the model (e.g. NX-AI/TiRex).
        device (str, optional): The device on which to load the model (e.g., "cuda:0", "cpu").
            If you want to use "cpu" you need to deactivate the sLSTM CUDA kernels (check the repository FAQ!).
        hf_kwargs (dict, optional): Keyword arguments to pass to the Hugging Face Hub download method.
        ckp_kwargs (dict, optional): Keyword arguments to pass when loading the checkpoint.

    Returns:
        PretrainedModel: The loaded model.

    Examples:
        model: ForecastModel = load_model("NX-AI/TiRex")
    """
    try:
        _, model_id = parse_hf_repo_id(path).split("/")
    except ValueError:
        raise ValueError(f"Invalid model path {path}")
    model_cls = PretrainedModel.REGISTRY.get(model_id, None)
    if model_cls is None:
        raise ValueError(f"Invalid model id {model_id}")
    return model_cls.from_pretrained(path, device=device, hf_kwargs=hf_kwargs, ckp_kwargs=ckp_kwargs)
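The registry wires `load_model` to concrete classes through `__init_subclass__`: the model-id segment of the path must match a subclass's `register_name()`. A toy illustration (the subclass below is hypothetical, not part of this commit):

    class MyModel(PretrainedModel):
        @classmethod
        def register_name(cls) -> str:
            return "MyModel"

    print(PretrainedModel.REGISTRY["MyModel"])  # <class 'MyModel'>
    # load_model("some-org/MyModel") would now dispatch to MyModel.from_pretrained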
tirex/models/__init__.py
ADDED
@@ -0,0 +1,2 @@
# Copyright (c) NXAI GmbH.
# This software may be used and distributed according to the terms of the NXAI Community License Agreement.
tirex/models/components.py
ADDED
@@ -0,0 +1,147 @@
# Copyright (c) NXAI GmbH.
# This software may be used and distributed according to the terms of the NXAI Community License Agreement.


from dataclasses import dataclass, field
from typing import Any

import torch

SCALER_STATE = "scaler_state"


class ResidualBlock(torch.nn.Module):
    def __init__(
        self,
        in_dim: int,
        h_dim: int,
        out_dim: int,
        dropout: float = 0,
    ) -> None:
        super().__init__()
        self.dropout = torch.nn.Dropout(dropout)
        self.hidden_layer = torch.nn.Linear(in_dim, h_dim)
        self.output_layer = torch.nn.Linear(h_dim, out_dim)
        self.residual_layer = torch.nn.Linear(in_dim, out_dim)
        self.act = torch.nn.ReLU()

    def forward(self, x: torch.Tensor):
        hid = self.act(self.hidden_layer(x))
        out = self.output_layer(hid)
        res = self.residual_layer(x)
        out = out + res
        return out


@dataclass
class StandardScaler:
    eps: float = 1e-5
    nan_loc: float = 0.0

    def scale(
        self,
        x: torch.Tensor,
        loc_scale: tuple[torch.Tensor, torch.Tensor] | None = None,
    ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        if loc_scale is None:
            loc = torch.nan_to_num(torch.nanmean(x, dim=-1, keepdim=True), nan=self.nan_loc)
            scale = torch.nan_to_num(torch.nanmean((x - loc).square(), dim=-1, keepdim=True).sqrt(), nan=1.0)
            scale = torch.where(scale == 0, torch.abs(loc) + self.eps, scale)
        else:
            loc, scale = loc_scale

        return ((x - loc) / scale), (loc, scale)

    def re_scale(self, x: torch.Tensor, loc_scale: tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
        loc, scale = loc_scale
        return x * scale + loc


@dataclass
class _Patcher:
    patch_size: int
    patch_stride: int
    left_pad: bool

    def __post_init__(self):
        assert self.patch_size % self.patch_stride == 0

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        assert x.ndim == 2
        length = x.shape[-1]

        if length < self.patch_size or (length % self.patch_stride != 0):
            if length < self.patch_size:
                padding_size = (
                    *x.shape[:-1],
                    self.patch_size - (length % self.patch_size),
                )
            else:
                padding_size = (
                    *x.shape[:-1],
                    self.patch_stride - (length % self.patch_stride),
                )
            padding = torch.full(size=padding_size, fill_value=torch.nan, dtype=x.dtype, device=x.device)
            if self.left_pad:
                x = torch.concat((padding, x), dim=-1)
            else:
                x = torch.concat((x, padding), dim=-1)

        x = x.unfold(dimension=-1, size=self.patch_size, step=self.patch_stride)
        return x


@dataclass
class PatchedUniTokenizer:
    patch_size: int
    scaler: Any = field(default_factory=StandardScaler)
    patch_stride: int | None = None

    def __post_init__(self):
        if self.patch_stride is None:
            self.patch_stride = self.patch_size
        self.patcher = _Patcher(self.patch_size, self.patch_stride, left_pad=True)

    def context_input_transform(self, data: torch.Tensor):
        assert data.ndim == 2
        data, scale_state = self.scaler.scale(data)
        return self.patcher(data), {SCALER_STATE: scale_state}

    def output_transform(self, data: torch.Tensor, tokenizer_state: dict):
        data_shape = data.shape
        data = self.scaler.re_scale(data.reshape(data_shape[0], -1), tokenizer_state[SCALER_STATE]).view(*data_shape)
        return data


class StreamToLogger:
    """Fake file-like stream object that redirects writes to a logger instance."""

    def __init__(self, logger, log_level):
        self.logger = logger
        self.log_level = log_level
        self.linebuf = ""  # Buffer for partial lines

    def write(self, message):
        # Filter out empty messages (often from just a newline)
        if message.strip():
            self.linebuf += message
            # If the message contains a newline, process the full line
            if "\n" in self.linebuf:
                lines = self.linebuf.splitlines(keepends=True)
                for line in lines:
                    if line.endswith("\n"):
                        # Log full lines without the trailing newline (logger adds its own)
                        self.logger.log(self.log_level, line.rstrip("\n"))
                    else:
                        # Keep partial lines in buffer
                        self.linebuf = line
                        return
                self.linebuf = ""  # All lines processed
        # If no newline, keep buffering

    def flush(self):
        # Log any remaining buffered content when flush is called
        if self.linebuf.strip():
            self.logger.log(self.log_level, self.linebuf.rstrip("\n"))
        self.linebuf = ""
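A short sketch of the tokenizer round trip: each series is standardized, NaN-left-padded to a multiple of the patch size, and sliced into patches; `output_transform` undoes the standardization (made-up values):

    import torch

    tok = PatchedUniTokenizer(patch_size=32)
    data = torch.rand(2, 100)  # 100 is not a multiple of 32 -> NaN-left-padded to 128
    patches, state = tok.context_input_transform(data)
    print(patches.shape)  # torch.Size([2, 4, 32]): 4 patches of 32 per series
    restored = tok.output_transform(patches, state)  # back on the original scale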
tirex/models/mixed_stack.py
ADDED
@@ -0,0 +1,143 @@
# Copyright (c) NXAI GmbH.
# This software may be used and distributed according to the terms of the NXAI Community License Agreement.


import os
from dataclasses import dataclass, field

import torch
from torch import nn
from xlstm.blocks.slstm.layer import sLSTMLayer, sLSTMLayerConfig
from xlstm.xlstm_large import xLSTMLargeConfig
from xlstm.xlstm_large.components import RMSNorm
from xlstm.xlstm_large.model import FeedForward, mLSTMBlock, mLSTMStateType


def skip_cuda():
    return os.getenv("TIREX_NO_CUDA", "False").lower() in ("true", "1", "t")


def init_cell(config: xLSTMLargeConfig, block_idx, num_blocks):
    return sLSTMLayer(
        sLSTMLayerConfig(
            embedding_dim=config.embedding_dim,
            num_heads=config.num_heads,
            conv1d_kernel_size=0,  # 0 means no convolution included
            group_norm_weight=True,
            dropout=0,
            # CellConfig
            backend="vanilla" if skip_cuda() else "cuda",
            bias_init="powerlaw_blockdependent",
            recurrent_weight_init="zeros",
            num_gates=4,
            gradient_recurrent_cut=False,
            gradient_recurrent_clipval=None,
            forward_clipval=None,
            batch_size=8,  # needed?
            _block_idx=block_idx,
            _num_blocks=num_blocks,
        )
    )


sLSTMLayerStateType = tuple[torch.Tensor, torch.Tensor]
sLSTMStateType = dict[int, sLSTMLayerStateType]


class sLSTMBlock(nn.Module):
    def __init__(self, config: xLSTMLargeConfig, block_idx: int, num_blocks: int):
        super().__init__()
        self.config = config
        self.norm_slstm = RMSNorm(
            num_features=config.embedding_dim,
            eps=config.norm_eps,
            use_weight=True,
            use_bias=config.use_bias,
            force_float32_reductions=config.norm_reduction_force_float32,
        )
        self.slstm_layer = init_cell(config, block_idx, num_blocks)

        self.norm_ffn = RMSNorm(
            num_features=config.embedding_dim,
            eps=config.norm_eps,
            use_weight=True,
            use_bias=config.use_bias,
            force_float32_reductions=config.norm_reduction_force_float32,
        )
        self.ffn = FeedForward(config)

    def forward(
        self, x: torch.Tensor, state: sLSTMLayerStateType | None = None
    ) -> tuple[torch.Tensor, sLSTMLayerStateType]:
        x_slstm = self.norm_slstm(x)
        if state is None:
            conv_state, slstm_state = None, None
        else:
            conv_state, slstm_state = state
        x_slstm, state = self.slstm_layer(x_slstm, conv_state, slstm_state, return_last_state=True)
        x = x + x_slstm

        x_ffn = self.norm_ffn(x)
        x_ffn = self.ffn(x_ffn)
        x = x + x_ffn

        return x, (state["conv_state"], state["slstm_state"])


@dataclass
class xLSTMMixedLargeConfig(xLSTMLargeConfig):
    slstm_at: list[int] = field(default_factory=list)
    all_slstm: bool = True

    @property
    def block_types(self):
        return ["s" if i in self.slstm_at or self.all_slstm else "m" for i in range(self.num_blocks)]


class xLSTMMixedLargeBlockStack(nn.Module):
    config_class = xLSTMMixedLargeConfig

    def __init__(self, config: xLSTMMixedLargeConfig):
        super().__init__()
        self.config = config

        self.blocks = nn.ModuleList(
            [
                sLSTMBlock(config, block_idx=i, num_blocks=config.num_blocks) if t == "s" else mLSTMBlock(config)
                for i, t in enumerate(config.block_types)
            ]
        )

        if self.config.add_out_norm:
            self.out_norm = RMSNorm(
                num_features=config.embedding_dim,
                eps=config.norm_eps,
                use_weight=True,
                use_bias=config.use_bias,
                force_float32_reductions=config.norm_reduction_force_float32,
            )
        else:
            self.out_norm = nn.Identity()

    def forward(
        self, x: torch.Tensor, state: mLSTMStateType | sLSTMStateType | None = None
    ) -> tuple[torch.Tensor, mLSTMStateType]:
        if state is None:
            state = {i: None for i in range(len(self.blocks))}

        for i, block in enumerate(self.blocks):
            block_state = state[i]
            x, block_state_new = block(x, block_state)

            if block_state is None:
                state[i] = block_state_new
            else:
                pass
                ## layer state is a tuple of three tensors: c, n, m
                ## we update the state in place in order to avoid creating new tensors
                # for state_idx in range(len(block_state)):
                #     state[i][state_idx].copy_(block_state_new[state_idx])

        x = self.out_norm(x)

        return x, state
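A hedged instantiation sketch; the required fields of `xLSTMLargeConfig` depend on the installed xlstm version, so the config values below are assumptions for illustration only:

    import os
    import torch

    os.environ["TIREX_NO_CUDA"] = "1"  # force the vanilla sLSTM backend for this demo
    # assumed minimal config; consult the installed xlstm package for the full signature
    config = xLSTMMixedLargeConfig(embedding_dim=256, num_heads=4, num_blocks=2, vocab_size=1)
    stack = xLSTMMixedLargeBlockStack(config)
    y, state = stack(torch.rand(2, 16, 256))  # [batch, num_tokens, embedding_dim]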
tirex/models/predict_utils.py
ADDED
@@ -0,0 +1,72 @@
# Copyright (c) NXAI GmbH.
# This software may be used and distributed according to the terms of the NXAI Community License Agreement.


import logging
from abc import abstractmethod

import torch

from ..api_adapter.forecast import ForecastModel

LOGGER = logging.getLogger()


class TensorQuantileUniPredictMixin(ForecastModel):
    @abstractmethod
    def _forecast_tensor(
        self,
        context: torch.Tensor,
        prediction_length: int | None = None,
        **predict_kwargs,
    ) -> torch.Tensor:
        pass

    @property
    @abstractmethod
    def quantiles(self):
        pass

    def _forecast_quantiles(
        self,
        context: torch.Tensor,
        prediction_length: int | None = None,
        quantile_levels: list[float] = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9],
        output_device: str = "cpu",
        auto_cast: bool = False,
        **predict_kwargs,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        with torch.autocast(device_type=self.device.type, enabled=auto_cast):
            predictions = self._forecast_tensor(
                context=context, prediction_length=prediction_length, **predict_kwargs
            ).detach()
        predictions = predictions.to(torch.device(output_device)).swapaxes(1, 2)

        training_quantile_levels = list(self.quantiles)

        if set(quantile_levels).issubset(set(training_quantile_levels)):
            quantiles = predictions[..., [training_quantile_levels.index(q) for q in quantile_levels]]
        else:
            if min(quantile_levels) < min(training_quantile_levels) or max(quantile_levels) > max(
                training_quantile_levels
            ):
                logging.warning(
                    f"Requested quantile levels ({quantile_levels}) fall outside the range of "
                    f"quantiles the model was trained on ({training_quantile_levels}). "
                    "Predictions for out-of-range quantiles will be clamped to the nearest "
                    "boundary of the trained quantiles (i.e., minimum or maximum trained level). "
                    "This can significantly impact prediction accuracy, especially for extreme quantiles."
                )
            # Interpolate quantiles
            augmented_predictions = torch.cat(
                [predictions[..., [0]], predictions, predictions[..., [-1]]],
                dim=-1,
            )
            quantiles = torch.quantile(
                augmented_predictions,
                q=torch.tensor(quantile_levels, dtype=augmented_predictions.dtype),
                dim=-1,
            ).permute(1, 2, 0)
        # median as mean
        mean = predictions[:, :, training_quantile_levels.index(0.5)]
        return quantiles, mean
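The clamp-and-interpolate branch can be sanity-checked in isolation; this standalone snippet mirrors the logic above on made-up numbers:

    import torch

    trained_levels = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
    preds = torch.arange(1.0, 10.0).reshape(1, 1, 9)  # values at the trained levels
    augmented = torch.cat([preds[..., [0]], preds, preds[..., [-1]]], dim=-1)
    q = torch.quantile(augmented, q=torch.tensor([0.05, 0.95]), dim=-1).permute(1, 2, 0)
    print(q)  # tensor([[[1., 9.]]]): out-of-range levels clamp to the outer trained quantiles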
tirex/models/tirex.py
ADDED
@@ -0,0 +1,231 @@
# Copyright (c) NXAI GmbH.
# This software may be used and distributed according to the terms of the NXAI Community License Agreement.

import logging
import warnings
from contextlib import redirect_stdout
from dataclasses import dataclass

import lightning as L
import torch
from dacite import Config, from_dict

from ..base import PretrainedModel
from .components import PatchedUniTokenizer, ResidualBlock, StreamToLogger
from .mixed_stack import skip_cuda, xLSTMMixedLargeBlockStack, xLSTMMixedLargeConfig
from .predict_utils import TensorQuantileUniPredictMixin

LOGGER = logging.getLogger()


@dataclass
class TiRexZeroConfig:
    input_patch_size: int
    output_patch_size: int
    quantiles: list[float]
    block_kwargs: dict
    input_ff_dim: int


class TiRexZero(L.LightningModule, PretrainedModel, TensorQuantileUniPredictMixin):
    def __init__(self, model_config: dict, train_ctx_len=None):
        super().__init__()
        self.model_config: TiRexZeroConfig = from_dict(TiRexZeroConfig, model_config, config=Config(strict=True))
        assert self.model_config.input_patch_size == self.model_config.output_patch_size
        self.train_ctx_len = train_ctx_len

        # Block Stack
        self.nan_mask_value = 0
        self.block_stack, resolved_config = self.init_block(self.model_config.block_kwargs)
        self.model_config.block_kwargs = resolved_config

        # Input Layer
        self.input_patch_embedding = ResidualBlock(
            in_dim=self.model_config.input_patch_size * 2,
            h_dim=self.model_config.input_ff_dim,
            out_dim=self.model_config.block_kwargs.embedding_dim,
        )
        self.tokenizer = PatchedUniTokenizer(
            patch_size=self.model_config.input_patch_size,
        )

        # Output Layer
        self.num_quantiles = len(self.model_config.quantiles)
        quantiles = torch.tensor(self.model_config.quantiles)
        self.register_buffer("quantiles", quantiles, persistent=False)

        self.output_patch_embedding = ResidualBlock(
            in_dim=self.model_config.block_kwargs.embedding_dim,
            h_dim=self.model_config.input_ff_dim,
            out_dim=self.num_quantiles * self.model_config.output_patch_size,
        )

        self.save_hyperparameters()

    @classmethod
    def register_name(cls):
        return "TiRex"

    def init_block(self, block_kwargs):
        config = from_dict(xLSTMMixedLargeConfig, block_kwargs)
        log_redirect = StreamToLogger(LOGGER, logging.INFO)
        with redirect_stdout(log_redirect):  # avoid excessive print statements of sLSTM compile
            model = xLSTMMixedLargeBlockStack(config)
        return model, config

    @property
    def quantiles(self):
        return self.model.quantiles

    def _forward_model_tokenized(
        self,
        input_token,
        input_mask=None,
        rollouts=1,
    ):
        input_mask = (
            input_mask.to(input_token.dtype)
            if input_mask is not None
            else torch.isnan(input_token).logical_not().to(input_token.dtype)
        )
        assert rollouts >= 1
        bs, numb_ctx_token, token_dim = input_token.shape
        if rollouts > 1:
            input_token = torch.cat(
                (
                    input_token,
                    torch.full(
                        (bs, rollouts - 1, token_dim),
                        fill_value=torch.nan,
                        device=input_token.device,
                        dtype=input_token.dtype,
                    ),
                ),
                dim=1,
            )
            input_mask = torch.cat(
                (
                    input_mask,
                    torch.full(
                        (bs, rollouts - 1, token_dim),
                        fill_value=False,
                        device=input_mask.device,
                        dtype=input_mask.dtype,
                    ),
                ),
                dim=1,
            )
        input_token = torch.nan_to_num(input_token, nan=self.nan_mask_value)
        input_embeds = self.input_patch_embedding(torch.cat((input_token, input_mask), dim=2))

        # hidden_states = []
        # for rollout in range(rollout):
        x = self.block_stack(input_embeds)
        if isinstance(x, tuple):
            hidden_states = x[0]
        else:
            hidden_states = x

        quantile_preds = self.output_patch_embedding(hidden_states)
        quantile_preds = torch.unflatten(quantile_preds, -1, (self.num_quantiles, self.model_config.output_patch_size))
        quantile_preds = torch.transpose(quantile_preds, 1, 2)  # switch quantile and num_token dimension
        # quantile_preds: [batch_size, num_quantiles, num_token, output_patch_size]

        return quantile_preds, hidden_states

    @torch.inference_mode()
    def _forecast_tensor(
        self,
        context: torch.Tensor,
        prediction_length: int | None = None,
        max_context: int | None = None,
        max_accelerated_rollout_steps: int = 1,
    ) -> torch.Tensor:
        predictions = []
        if prediction_length is None:
            prediction_length = self.tokenizer.patch_size
        remaining = -(prediction_length // -self.tokenizer.patch_size)  # ceil division
        if max_context is None:
            max_context = self.train_ctx_len
        min_context = max(self.train_ctx_len, max_context)

        context = context.to(
            device=self.device,
            dtype=torch.float32,
        )
        while remaining > 0:
            if context.shape[-1] > max_context:
                context = context[..., -max_context:]
            if context.shape[-1] < min_context:
                pad = torch.full(
                    (context.shape[0], min_context - context.shape[-1]),
                    fill_value=torch.nan,
                    device=context.device,
                    dtype=context.dtype,
                )
                context = torch.concat((pad, context), dim=1)
            tokenized_tensor, tokenizer_state = self.tokenizer.context_input_transform(context)
            fut_rollouts = min(remaining, max_accelerated_rollout_steps)
            with torch.no_grad():
                prediction, _ = self._forward_model_tokenized(input_token=tokenized_tensor, rollouts=fut_rollouts)
            prediction = prediction[:, :, -fut_rollouts:, :].to(tokenized_tensor)  # predicted token
            # [bs, num_quantiles, num_predicted_token, output_patch_size]
            prediction = self.tokenizer.output_transform(prediction, tokenizer_state)
            prediction = prediction.flatten(start_dim=2)

            predictions.append(prediction)
            remaining -= fut_rollouts

            if remaining <= 0:
                break

            context = torch.cat([context, torch.full_like(prediction[:, 0, :], fill_value=torch.nan)], dim=-1)

        return torch.cat(predictions, dim=-1)[..., :prediction_length].to(
            dtype=torch.float32,
        )

    def on_load_checkpoint(self, checkpoint: dict) -> None:
        state_dict = checkpoint["state_dict"]
        load_vanilla_kernel = skip_cuda()
        if load_vanilla_kernel:
            warnings.warn(
                "You use TiRex without sLSTM CUDA kernels! This might slow down the model considerably and might degrade forecasting results! "
                "Set the environment variable TIREX_NO_CUDA to 0 to avoid this!"
            )
        block_kwargs = self.model_config.block_kwargs
        head_dim = block_kwargs.embedding_dim // block_kwargs.num_heads
        num_gates = 4
        new_state_dict = {}
        for k, v in state_dict.items():
            if "slstm_layer.slstm_cell._recurrent_kernel_" in k:
                new_state_dict[k] = (
                    v.reshape(
                        block_kwargs.num_heads,
                        head_dim,
                        num_gates,
                        head_dim,
                    )
                    .permute(0, 2, 3, 1)
                    .reshape(
                        block_kwargs.num_heads,
                        num_gates * head_dim,
                        head_dim,
                    )
                )
                # new_state_dict[k] = v.permute(0, 2, 1)
            elif "slstm_layer.slstm_cell._bias_" in k:
                new_state_dict[k] = (
                    v.reshape(block_kwargs.num_heads, num_gates, head_dim).permute(1, 0, 2).reshape(-1)
                )
            else:
                new_state_dict[k] = v
        checkpoint["state_dict"] = new_state_dict

    def after_load_from_checkpoint(self):
        if not skip_cuda() and self.device.type != "cuda":
            warnings.warn(
                f"You use TiRex with sLSTM CUDA kernels BUT DO NOT LOAD THE MODEL ON A CUDA DEVICE (device type is {self.device.type})! "
                "This is not supported and calls to the model will likely lead to an error if you don't move your model to a CUDA device! "
                "If you want to run TiRex on CPU you need to disable the sLSTM CUDA kernels, but be aware of the downsides (see FAQ)."
            )
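One detail worth noting in `_forecast_tensor`: `-(a // -b)` is ceiling division, so the loop performs `ceil(prediction_length / patch_size)` autoregressive rollouts and the concatenated output is trimmed back to `prediction_length`. A quick standalone check:

    patch_size = 32
    for prediction_length in (1, 32, 33, 64, 100):
        rollouts = -(prediction_length // -patch_size)
        print(prediction_length, rollouts)  # 1->1, 32->1, 33->2, 64->2, 100->4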