# tirex/models/components.py
# Copyright (c) NXAI GmbH.
# This software may be used and distributed according to the terms of the NXAI Community License Agreement.

from dataclasses import dataclass, field
from typing import Any

import torch

SCALER_STATE = "scaler_state"


class ResidualBlock(torch.nn.Module):
    def __init__(
        self,
        in_dim: int,
        h_dim: int,
        out_dim: int,
        dropout: float = 0,
    ) -> None:
        super().__init__()
        self.dropout = torch.nn.Dropout(dropout)
        self.hidden_layer = torch.nn.Linear(in_dim, h_dim)
        self.output_layer = torch.nn.Linear(h_dim, out_dim)
        self.residual_layer = torch.nn.Linear(in_dim, out_dim)
        self.act = torch.nn.ReLU()

    def forward(self, x: torch.Tensor):
        hid = self.act(self.hidden_layer(x))
        out = self.output_layer(hid)
        res = self.residual_layer(x)
        out = out + res
        return out
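
# Usage sketch (illustrative, not part of the original module): ResidualBlock maps the
# last dimension from in_dim to out_dim through hidden_layer -> ReLU -> output_layer,
# plus a linear skip connection. Note that self.dropout is instantiated but not applied
# in forward() as written. The dimensions below are arbitrary example values.
#
#   block = ResidualBlock(in_dim=64, h_dim=256, out_dim=32, dropout=0.1)
#   x = torch.randn(8, 16, 64)   # (batch, tokens, in_dim)
#   y = block(x)                 # -> (8, 16, 32)
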
@dataclass
class StandardScaler:
    eps: float = 1e-5
    nan_loc: float = 0.0

    def scale(
        self,
        x: torch.Tensor,
        loc_scale: tuple[torch.Tensor, torch.Tensor] | None = None,
    ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        if loc_scale is None:
            loc = torch.nan_to_num(torch.nanmean(x, dim=-1, keepdim=True), nan=self.nan_loc)
            scale = torch.nan_to_num(torch.nanmean((x - loc).square(), dim=-1, keepdim=True).sqrt(), nan=1.0)
            scale = torch.where(scale == 0, torch.abs(loc) + self.eps, scale)
        else:
            loc, scale = loc_scale
        return ((x - loc) / scale), (loc, scale)

    def re_scale(self, x: torch.Tensor, loc_scale: tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
        loc, scale = loc_scale
        return x * scale + loc
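
# Usage sketch (illustrative, not part of the original module): scale() standardizes
# along the last dimension while ignoring NaNs and returns the (loc, scale) pair needed
# to invert the transform with re_scale(). The input values are example assumptions.
#
#   scaler = StandardScaler()
#   x = torch.tensor([[1.0, 2.0, float("nan"), 4.0]])
#   x_scaled, loc_scale = scaler.scale(x)          # per-series NaN-aware mean/std over dim=-1
#   x_back = scaler.re_scale(x_scaled, loc_scale)  # recovers x up to float precision; NaNs stay NaN
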
@dataclass
class _Patcher:
    patch_size: int
    patch_stride: int
    left_pad: bool

    def __post_init__(self):
        assert self.patch_size % self.patch_stride == 0

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        assert x.ndim == 2
        length = x.shape[-1]
        if length < self.patch_size or (length % self.patch_stride != 0):
            if length < self.patch_size:
                # Shorter than a single patch: pad up to one full patch_size.
                padding_size = (
                    *x.shape[:-1],
                    self.patch_size - (length % self.patch_size),
                )
            else:
                # Not aligned to the stride: pad up to the next stride boundary.
                padding_size = (
                    *x.shape[:-1],
                    self.patch_stride - (length % self.patch_stride),
                )
            padding = torch.full(size=padding_size, fill_value=torch.nan, dtype=x.dtype, device=x.device)
            if self.left_pad:
                x = torch.concat((padding, x), dim=-1)
            else:
                x = torch.concat((x, padding), dim=-1)
        x = x.unfold(dimension=-1, size=self.patch_size, step=self.patch_stride)
        return x
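
# Usage sketch (illustrative, not part of the original module): the patcher pads a
# (batch, length) tensor with NaN so the length fits patch_size/patch_stride, then
# unfolds it into patches. The numbers below are example assumptions.
#
#   patcher = _Patcher(patch_size=4, patch_stride=4, left_pad=True)
#   x = torch.arange(7, dtype=torch.float32).unsqueeze(0)  # shape (1, 7)
#   patches = patcher(x)                                   # shape (1, 2, 4); one NaN is
#                                                          # left-padded into the first patch
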
@dataclass
class PatchedUniTokenizer:
    patch_size: int
    scaler: Any = field(default_factory=StandardScaler)
    patch_stride: int | None = None

    def __post_init__(self):
        if self.patch_stride is None:
            self.patch_stride = self.patch_size
        self.patcher = _Patcher(self.patch_size, self.patch_stride, left_pad=True)

    def context_input_transform(self, data: torch.Tensor):
        assert data.ndim == 2
        data, scale_state = self.scaler.scale(data)
        return self.patcher(data), {SCALER_STATE: scale_state}

    def output_transform(self, data: torch.Tensor, tokenizer_state: dict):
        data_shape = data.shape
        data = self.scaler.re_scale(data.reshape(data_shape[0], -1), tokenizer_state[SCALER_STATE]).view(*data_shape)
        return data
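
# Usage sketch (illustrative, not part of the original module): the tokenizer standardizes
# each series, cuts it into patches for the model input, and later undoes the scaling on
# the model output. The sizes below are example assumptions, not values from this repo.
#
#   tokenizer = PatchedUniTokenizer(patch_size=32)                 # patch_stride defaults to patch_size
#   context = torch.randn(4, 100)                                  # (batch, context_length)
#   tokens, state = tokenizer.context_input_transform(context)     # tokens: (4, 4, 32) after NaN left-padding
#   # ... run the model on `tokens` to get e.g. `preds` of shape (4, num_patches, 32) ...
#   # preds_rescaled = tokenizer.output_transform(preds, state)
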
class StreamToLogger:
    """Fake file-like stream object that redirects writes to a logger instance."""

    def __init__(self, logger, log_level):
        self.logger = logger
        self.log_level = log_level
        self.linebuf = ""  # Buffer for partial lines

    def write(self, message):
        # Filter out empty messages (often from just a newline)
        if message.strip():
            self.linebuf += message
            # If the message contains a newline, process the full line
            if "\n" in self.linebuf:
                lines = self.linebuf.splitlines(keepends=True)
                for line in lines:
                    if line.endswith("\n"):
                        # Log full lines without the trailing newline (logger adds its own)
                        self.logger.log(self.log_level, line.rstrip("\n"))
                    else:
                        # Keep partial lines in the buffer
                        self.linebuf = line
                        return
                self.linebuf = ""  # All lines processed
            # If no newline, keep buffering

    def flush(self):
        # Log any remaining buffered content when flush is called
        if self.linebuf.strip():
            self.logger.log(self.log_level, self.linebuf.rstrip("\n"))
            self.linebuf = ""
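
# Usage sketch (illustrative, not part of the original module): StreamToLogger can replace
# sys.stdout / sys.stderr so that output written by third-party code is routed through the
# standard logging module. The logger name and level are example assumptions.
#
#   import logging
#   import sys
#
#   logger = logging.getLogger("tirex")
#   sys.stdout = StreamToLogger(logger, logging.INFO)
#   sys.stdout.write("hello world\n")   # emitted as one INFO log record ("hello world")
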