# Copyright (c) NXAI GmbH.
# This software may be used and distributed according to the terms of the NXAI Community License Agreement.
from dataclasses import dataclass, field
from typing import Any

import torch

SCALER_STATE = "scaler_state"


class ResidualBlock(torch.nn.Module):
    """Two-layer MLP with a linear residual (skip) projection."""

    def __init__(
        self,
        in_dim: int,
        h_dim: int,
        out_dim: int,
        dropout: float = 0,
    ) -> None:
        super().__init__()
        self.dropout = torch.nn.Dropout(dropout)
        self.hidden_layer = torch.nn.Linear(in_dim, h_dim)
        self.output_layer = torch.nn.Linear(h_dim, out_dim)
        self.residual_layer = torch.nn.Linear(in_dim, out_dim)
        self.act = torch.nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hid = self.act(self.hidden_layer(x))
        # Dropout regularizes the non-linear branch before the residual addition.
        out = self.dropout(self.output_layer(hid))
        res = self.residual_layer(x)
        out = out + res
        return out
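

# Usage sketch (sizes are illustrative, not taken from the original module):
#
#   block = ResidualBlock(in_dim=64, h_dim=256, out_dim=128, dropout=0.1)
#   y = block(torch.randn(32, 64))  # y.shape == torch.Size([32, 128])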


@dataclass
class StandardScaler:
    """Per-series standardization that tolerates NaNs and constant inputs."""

    eps: float = 1e-5
    nan_loc: float = 0.0

    def scale(
        self,
        x: torch.Tensor,
        loc_scale: tuple[torch.Tensor, torch.Tensor] | None = None,
    ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        if loc_scale is None:
            # NaN-aware statistics along the time axis; all-NaN series fall
            # back to `nan_loc` and a unit scale.
            loc = torch.nan_to_num(torch.nanmean(x, dim=-1, keepdim=True), nan=self.nan_loc)
            scale = torch.nan_to_num(torch.nanmean((x - loc).square(), dim=-1, keepdim=True).sqrt(), nan=1.0)
            # Constant series would yield a zero scale; substitute a safe value.
            scale = torch.where(scale == 0, torch.abs(loc) + self.eps, scale)
        else:
            loc, scale = loc_scale
        return ((x - loc) / scale), (loc, scale)

    def re_scale(self, x: torch.Tensor, loc_scale: tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
        loc, scale = loc_scale
        return x * scale + loc
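

# Usage sketch: a scale/re_scale round trip (values are illustrative):
#
#   scaler = StandardScaler()
#   x = torch.tensor([[1.0, 2.0, 3.0], [5.0, float("nan"), 7.0]])
#   z, loc_scale = scaler.scale(x)          # loc and scale have shape (2, 1)
#   x_back = scaler.re_scale(z, loc_scale)  # recovers x; NaNs stay NaN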


@dataclass
class _Patcher:
    patch_size: int
    patch_stride: int
    left_pad: bool

    def __post_init__(self):
        assert self.patch_size % self.patch_stride == 0

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        assert x.ndim == 2
        length = x.shape[-1]
        if length < self.patch_size or (length % self.patch_stride != 0):
            if length < self.patch_size:
                # Pad short series up to one full patch.
                padding_size = (
                    *x.shape[:-1],
                    self.patch_size - (length % self.patch_size),
                )
            else:
                # Pad so the length becomes a multiple of the stride.
                padding_size = (
                    *x.shape[:-1],
                    self.patch_stride - (length % self.patch_stride),
                )
            padding = torch.full(size=padding_size, fill_value=torch.nan, dtype=x.dtype, device=x.device)
            if self.left_pad:
                x = torch.concat((padding, x), dim=-1)
            else:
                x = torch.concat((x, padding), dim=-1)
        # Slice the (padded) series into patches of `patch_size` every `patch_stride` steps.
        x = x.unfold(dimension=-1, size=self.patch_size, step=self.patch_stride)
        return x
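

# Usage sketch: with patch_size=4 and stride=4, a length-10 series is
# left-padded with two NaNs and unfolded into three contiguous patches:
#
#   patcher = _Patcher(patch_size=4, patch_stride=4, left_pad=True)
#   patches = patcher(torch.arange(20.0).reshape(2, 10))
#   patches.shape  # torch.Size([2, 3, 4]); each first patch starts with 2 NaNs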


@dataclass
class PatchedUniTokenizer:
    patch_size: int
    scaler: Any = field(default_factory=StandardScaler)
    patch_stride: int | None = None

    def __post_init__(self):
        if self.patch_stride is None:
            self.patch_stride = self.patch_size
        self.patcher = _Patcher(self.patch_size, self.patch_stride, left_pad=True)

    def context_input_transform(self, data: torch.Tensor):
        assert data.ndim == 2
        # Standardize per series, then cut the context into patches; the
        # scaling statistics are returned as state for the inverse transform.
        data, scale_state = self.scaler.scale(data)
        return self.patcher(data), {SCALER_STATE: scale_state}

    def output_transform(self, data: torch.Tensor, tokenizer_state: dict):
        # Undo the per-series scaling on (batch, ..., values)-shaped outputs.
        data_shape = data.shape
        data = self.scaler.re_scale(data.reshape(data_shape[0], -1), tokenizer_state[SCALER_STATE]).view(*data_shape)
        return data
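

# Usage sketch: a full round trip through the tokenizer (shapes illustrative):
#
#   tokenizer = PatchedUniTokenizer(patch_size=32)
#   patches, state = tokenizer.context_input_transform(torch.randn(4, 128))
#   patches.shape  # torch.Size([4, 4, 32])
#   preds = tokenizer.output_transform(torch.randn(4, 2, 32), state)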


class StreamToLogger:
    """Fake file-like stream object that redirects writes to a logger
    instance."""

    def __init__(self, logger, log_level):
        self.logger = logger
        self.log_level = log_level
        self.linebuf = ""  # Buffer for partial lines

    def write(self, message):
        # Ignore pure-whitespace writes unless they complete a buffered partial line
        if not message.strip() and not self.linebuf:
            return
        self.linebuf += message
        if "\n" not in self.linebuf:
            return  # No full line yet; keep buffering
        lines = self.linebuf.splitlines(keepends=True)
        self.linebuf = ""  # Re-buffer only a trailing partial line below
        for line in lines:
            if line.endswith("\n"):
                # Log full lines without the trailing newline (the logger adds its own)
                self.logger.log(self.log_level, line.rstrip("\n"))
            else:
                # Only the last element can lack a newline; keep it buffered
                self.linebuf = line

    def flush(self):
        # Log any remaining buffered content when flush is called
        if self.linebuf.strip():
            self.logger.log(self.log_level, self.linebuf.rstrip("\n"))
        self.linebuf = ""