# Copyright (c) NXAI GmbH.
# This software may be used and distributed according to the terms of the NXAI Community License Agreement.
from dataclasses import dataclass, field
from typing import Any

import torch

SCALER_STATE = "scaler_state"
class ResidualBlock(torch.nn.Module):
    """Two-layer feed-forward block with a linear residual connection."""

    def __init__(
        self,
        in_dim: int,
        h_dim: int,
        out_dim: int,
        dropout: float = 0,
    ) -> None:
        super().__init__()
        self.dropout = torch.nn.Dropout(dropout)  # defined here but not applied in forward below
        self.hidden_layer = torch.nn.Linear(in_dim, h_dim)
        self.output_layer = torch.nn.Linear(h_dim, out_dim)
        self.residual_layer = torch.nn.Linear(in_dim, out_dim)
        self.act = torch.nn.ReLU()

    def forward(self, x: torch.Tensor):
        hid = self.act(self.hidden_layer(x))
        out = self.output_layer(hid)
        res = self.residual_layer(x)
        out = out + res
        return out
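# Illustrative usage sketch (the dimensions below are arbitrary example values, not
# taken from this module):
#     block = ResidualBlock(in_dim=32, h_dim=64, out_dim=32)
#     y = block(torch.randn(8, 32))  # -> tensor of shape (8, 32)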
@dataclass
class StandardScaler:
    """Standardizes each series along its last dimension, ignoring NaN entries."""

    eps: float = 1e-5
    nan_loc: float = 0.0

    def scale(
        self,
        x: torch.Tensor,
        loc_scale: tuple[torch.Tensor, torch.Tensor] | None = None,
    ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        if loc_scale is None:
            # NaN-aware mean and standard deviation; a zero scale falls back to |loc| + eps
            loc = torch.nan_to_num(torch.nanmean(x, dim=-1, keepdim=True), nan=self.nan_loc)
            scale = torch.nan_to_num(torch.nanmean((x - loc).square(), dim=-1, keepdim=True).sqrt(), nan=1.0)
            scale = torch.where(scale == 0, torch.abs(loc) + self.eps, scale)
        else:
            loc, scale = loc_scale
        return ((x - loc) / scale), (loc, scale)

    def re_scale(self, x: torch.Tensor, loc_scale: tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
        loc, scale = loc_scale
        return x * scale + loc
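# Illustrative round trip (example shapes only): `scale` normalizes each series along its
# last dimension and returns the (loc, scale) state; `re_scale` inverts that mapping.
#     scaler = StandardScaler()
#     x = torch.randn(4, 128)
#     x_scaled, loc_scale = scaler.scale(x)
#     x_restored = scaler.re_scale(x_scaled, loc_scale)  # approximately equal to x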
@dataclass
class _Patcher:
    """Splits a (batch, length) tensor into patches, NaN-padding to a full patch grid."""

    patch_size: int
    patch_stride: int
    left_pad: bool

    def __post_init__(self):
        assert self.patch_size % self.patch_stride == 0

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        assert x.ndim == 2
        length = x.shape[-1]
        if length < self.patch_size or (length % self.patch_stride != 0):
            if length < self.patch_size:
                padding_size = (
                    *x.shape[:-1],
                    self.patch_size - (length % self.patch_size),
                )
            else:
                padding_size = (
                    *x.shape[:-1],
                    self.patch_stride - (length % self.patch_stride),
                )
            padding = torch.full(size=padding_size, fill_value=torch.nan, dtype=x.dtype, device=x.device)
            if self.left_pad:
                x = torch.concat((padding, x), dim=-1)
            else:
                x = torch.concat((x, padding), dim=-1)
        x = x.unfold(dimension=-1, size=self.patch_size, step=self.patch_stride)
        return x
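# Illustrative behavior (example values): a (batch, length) tensor whose length is not a
# multiple of the stride is NaN-padded (on the left here) and unfolded into patches.
#     patcher = _Patcher(patch_size=16, patch_stride=16, left_pad=True)
#     patches = patcher(torch.randn(2, 100))  # -> shape (2, 7, 16); first patch starts with 12 NaNs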
@dataclass
class PatchedUniTokenizer:
    """Scales a context window and cuts it into patches; inverts the scaling on model outputs."""

    patch_size: int
    scaler: Any = field(default_factory=StandardScaler)
    patch_stride: int | None = None

    def __post_init__(self):
        if self.patch_stride is None:
            self.patch_stride = self.patch_size
        self.patcher = _Patcher(self.patch_size, self.patch_stride, left_pad=True)

    def context_input_transform(self, data: torch.Tensor):
        assert data.ndim == 2
        data, scale_state = self.scaler.scale(data)
        return self.patcher(data), {SCALER_STATE: scale_state}

    def output_transform(self, data: torch.Tensor, tokenizer_state: dict):
        data_shape = data.shape
        data = self.scaler.re_scale(data.reshape(data_shape[0], -1), tokenizer_state[SCALER_STATE]).view(*data_shape)
        return data
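# Illustrative end-to-end use (example shapes): the context is scaled and patched for the
# model, and model outputs are mapped back to the original scale via the stored scaler state.
#     tokenizer = PatchedUniTokenizer(patch_size=32)
#     patches, state = tokenizer.context_input_transform(torch.randn(2, 96))  # -> (2, 3, 32)
#     preds = tokenizer.output_transform(torch.randn(2, 1, 32), state)  # back on the original scale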
class StreamToLogger:
    """Fake file-like stream object that redirects writes to a logger
    instance."""

    def __init__(self, logger, log_level):
        self.logger = logger
        self.log_level = log_level
        self.linebuf = ""  # Buffer for partial lines

    def write(self, message):
        # Filter out empty messages (often from just a newline)
        if message.strip():
            self.linebuf += message
            # If the message contains a newline, process the full line
            if "\n" in self.linebuf:
                lines = self.linebuf.splitlines(keepends=True)
                for line in lines:
                    if line.endswith("\n"):
                        # Log full lines without the trailing newline (logger adds its own)
                        self.logger.log(self.log_level, line.rstrip("\n"))
                    else:
                        # Keep partial lines in buffer
                        self.linebuf = line
                        return
                self.linebuf = ""  # All lines processed
            # If no newline, keep buffering

    def flush(self):
        # Log any remaining buffered content when flush is called
        if self.linebuf.strip():
            self.logger.log(self.log_level, self.linebuf.rstrip("\n"))
            self.linebuf = ""
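# Illustrative use (assumes a standard library `logging` setup as shown): redirecting
# sys.stdout through StreamToLogger makes print() output appear in the log stream.
#     import logging, sys
#     logging.basicConfig(level=logging.INFO)
#     sys.stdout = StreamToLogger(logging.getLogger("stdout"), logging.INFO)
#     print("this line is captured by the logger")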