# Copyright (c) NXAI GmbH.
# This software may be used and distributed according to the terms of the NXAI Community License Agreement.

import logging
import warnings
from contextlib import redirect_stdout
from dataclasses import dataclass

import lightning as L
import torch
from dacite import Config, from_dict

from ..base import PretrainedModel
from .components import PatchedUniTokenizer, ResidualBlock, StreamToLogger
from .mixed_stack import skip_cuda, xLSTMMixedLargeBlockStack, xLSTMMixedLargeConfig
from .predict_utils import TensorQuantileUniPredictMixin

LOGGER = logging.getLogger()
@dataclass
class TiRexZeroConfig:
    input_patch_size: int
    output_patch_size: int
    quantiles: list[float]
    block_kwargs: dict
    input_ff_dim: int
class TiRexZero(L.LightningModule, PretrainedModel, TensorQuantileUniPredictMixin):
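    """TiRex forecasting model producing quantile predictions over patched time series.

    The input series is split into fixed-size patches, embedded through a residual
    input block, processed by an xLSTM block stack, and projected to one value per
    quantile level for every output patch.
    """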
    def __init__(self, model_config: dict, train_ctx_len=None):
        super().__init__()
        self.model_config: TiRexZeroConfig = from_dict(TiRexZeroConfig, model_config, config=Config(strict=True))
        assert self.model_config.input_patch_size == self.model_config.output_patch_size
        self.train_ctx_len = train_ctx_len

        # Block Stack
        self.nan_mask_value = 0
        self.block_stack, resolved_config = self.init_block(self.model_config.block_kwargs)
        self.model_config.block_kwargs = resolved_config

        # Input Layer
        self.input_patch_embedding = ResidualBlock(
            in_dim=self.model_config.input_patch_size * 2,
            h_dim=self.model_config.input_ff_dim,
            out_dim=self.model_config.block_kwargs.embedding_dim,
        )
        self.tokenizer = PatchedUniTokenizer(
            patch_size=self.model_config.input_patch_size,
        )

        # Output Layer
        self.num_quantiles = len(self.model_config.quantiles)
        quantiles = torch.tensor(self.model_config.quantiles)
        self.register_buffer("quantiles", quantiles, persistent=False)
        self.output_patch_embedding = ResidualBlock(
            in_dim=self.model_config.block_kwargs.embedding_dim,
            h_dim=self.model_config.input_ff_dim,
            out_dim=self.num_quantiles * self.model_config.output_patch_size,
        )

        self.save_hyperparameters()
    @classmethod
    def register_name(cls):
        return "TiRex"
    def init_block(self, block_kwargs):
        config = from_dict(xLSTMMixedLargeConfig, block_kwargs)
        log_redirect = StreamToLogger(LOGGER, logging.INFO)
        with redirect_stdout(log_redirect):  # avoid excessive print statements of sLSTM compile
            model = xLSTMMixedLargeBlockStack(config)
        return model, config
    def quantiles(self):
        return self.model.quantiles
    def _forward_model_tokenized(
        self,
        input_token,
        input_mask=None,
        rollouts=1,
    ):
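        """Run the model on already tokenized (patched) input.

        input_token: [batch_size, num_ctx_token, token_dim]; NaNs mark missing values.
        For rollouts > 1, NaN placeholder tokens are appended so the model predicts
        several future patches in one forward pass.

        Returns (quantile_preds, hidden_states) with quantile_preds of shape
        [batch_size, num_quantiles, num_token, output_patch_size].
        """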
        input_mask = (
            input_mask.to(input_token.dtype)
            if input_mask is not None
            else torch.isnan(input_token).logical_not().to(input_token.dtype)
        )
        assert rollouts >= 1
        bs, numb_ctx_token, token_dim = input_token.shape
        if rollouts > 1:
            input_token = torch.cat(
                (
                    input_token,
                    torch.full(
                        (bs, rollouts - 1, token_dim),
                        fill_value=torch.nan,
                        device=input_token.device,
                        dtype=input_token.dtype,
                    ),
                ),
                dim=1,
            )
            input_mask = torch.cat(
                (
                    input_mask,
                    torch.full(
                        (bs, rollouts - 1, token_dim),
                        fill_value=False,
                        device=input_mask.device,
                        dtype=input_mask.dtype,
                    ),
                ),
                dim=1,
            )
        input_token = torch.nan_to_num(input_token, nan=self.nan_mask_value)
        input_embeds = self.input_patch_embedding(torch.cat((input_token, input_mask), dim=2))

        # hidden_states = []
        # for rollout in range(rollout):
        x = self.block_stack(input_embeds)
        if isinstance(x, tuple):
            hidden_states = x[0]
        else:
            hidden_states = x

        quantile_preds = self.output_patch_embedding(hidden_states)
        quantile_preds = torch.unflatten(quantile_preds, -1, (self.num_quantiles, self.model_config.output_patch_size))
        quantile_preds = torch.transpose(quantile_preds, 1, 2)  # switch quantile and num_token dimensions
        # quantile_preds: [batch_size, num_quantiles, num_token, output_patch_size]
        return quantile_preds, hidden_states
    def _forecast_tensor(
        self,
        context: torch.Tensor,
        prediction_length: int | None = None,
        max_context: int | None = None,
        max_accelerated_rollout_steps: int = 1,
    ) -> torch.Tensor:
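        """Forecast `prediction_length` future values for each series in `context`.

        The context is truncated and/or NaN-padded to the configured context window,
        tokenized, and passed through the model; the loop rolls forward patch by patch
        (up to `max_accelerated_rollout_steps` patches per forward pass) until enough
        values are predicted. Returns a tensor of shape
        [batch_size, num_quantiles, prediction_length].
        """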
        predictions = []
        if prediction_length is None:
            prediction_length = self.tokenizer.patch_size
        remaining = -(prediction_length // -self.tokenizer.patch_size)

        if max_context is None:
            max_context = self.train_ctx_len
        min_context = max(self.train_ctx_len, max_context)

        context = context.to(
            device=self.device,
            dtype=torch.float32,
        )

        while remaining > 0:
            if context.shape[-1] > max_context:
                context = context[..., -max_context:]
            if context.shape[-1] < min_context:
                pad = torch.full(
                    (context.shape[0], min_context - context.shape[-1]),
                    fill_value=torch.nan,
                    device=context.device,
                    dtype=context.dtype,
                )
                context = torch.concat((pad, context), dim=1)

            tokenized_tensor, tokenizer_state = self.tokenizer.context_input_transform(context)
            fut_rollouts = min(remaining, max_accelerated_rollout_steps)
            with torch.no_grad():
                prediction, _ = self._forward_model_tokenized(input_token=tokenized_tensor, rollouts=fut_rollouts)
            prediction = prediction[:, :, -fut_rollouts:, :].to(tokenized_tensor)  # predicted token
            # [bs, num_quantiles, num_predicted_token, output_patch_size]
            prediction = self.tokenizer.output_transform(prediction, tokenizer_state)
            prediction = prediction.flatten(start_dim=2)
            predictions.append(prediction)
            remaining -= fut_rollouts

            if remaining <= 0:
                break

            context = torch.cat([context, torch.full_like(prediction[:, 0, :], fill_value=torch.nan)], dim=-1)

        return torch.cat(predictions, dim=-1)[..., :prediction_length].to(
            dtype=torch.float32,
        )
    def on_load_checkpoint(self, checkpoint: dict) -> None:
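        """Remap sLSTM weights when the CUDA kernels are not used.

        When falling back to the vanilla sLSTM implementation (skip_cuda()), the
        recurrent kernel and bias tensors in the checkpoint are reshaped and permuted
        into the per-gate layout expected by the vanilla kernel.
        """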
        state_dict = checkpoint["state_dict"]
        load_vanilla_kernel = skip_cuda()
        if load_vanilla_kernel:
            warnings.warn(
                "You use TiRex without sLSTM CUDA kernels! This might slow down the model considerably and might degrade forecasting results!"
                " Set the environment variable TIREX_NO_CUDA to 0 to avoid this!"
            )
            block_kwargs = self.model_config.block_kwargs
            head_dim = block_kwargs.embedding_dim // block_kwargs.num_heads
            num_gates = 4
            new_state_dict = {}
            for k, v in state_dict.items():
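                # Reorder the fused recurrent kernel from (heads, head_dim, gates, head_dim)
                # to (heads, gates * head_dim, head_dim) for the vanilla kernel layout.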
if "slstm_layer.slstm_cell._recurrent_kernel_" in k: | |
new_state_dict[k] = ( | |
v.reshape( | |
block_kwargs.num_heads, | |
head_dim, | |
num_gates, | |
head_dim, | |
) | |
.permute(0, 2, 3, 1) | |
.reshape( | |
block_kwargs.num_heads, | |
num_gates * head_dim, | |
head_dim, | |
) | |
) | |
# new_state_dict[k] = v.permute(0, 2, 1) | |
elif "slstm_layer.slstm_cell._bias_" in k: | |
new_state_dict[k] = ( | |
v.reshape(block_kwargs.num_heads, num_gates, head_dim).permute(1, 0, 2).reshape(-1) | |
) | |
else: | |
new_state_dict[k] = v | |
checkpoint["state_dict"] = new_state_dict | |
    def after_load_from_checkpoint(self):
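        """Warn if the CUDA sLSTM kernels are enabled but the model is not on a CUDA device."""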
        if not skip_cuda() and self.device.type != "cuda":
            warnings.warn(
                f"You use TiRex with sLSTM CUDA kernels BUT THE MODEL IS NOT ON A CUDA DEVICE (device type is {self.device.type})!"
                " This is not supported and calls to the model will likely lead to an error unless you move your model to a CUDA device!"
                " If you want to run TiRex on CPU you need to disable the sLSTM CUDA kernels, but be aware of the downsides (see FAQ)."
            )