# Copyright (c) NXAI GmbH.
# This software may be used and distributed according to the terms of the NXAI Community License Agreement.


from dataclasses import dataclass, field
from typing import Any

import torch

SCALER_STATE = "scaler_state"


class ResidualBlock(torch.nn.Module):
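    """Two-layer feed-forward block with a linear residual (skip) projection.

    The transformed path is ``Linear -> ReLU -> Linear`` with dropout, and its
    output is added to a linear projection of the input.
    """
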
    def __init__(
        self,
        in_dim: int,
        h_dim: int,
        out_dim: int,
        dropout: float = 0,
    ) -> None:
        super().__init__()
        self.dropout = torch.nn.Dropout(dropout)
        self.hidden_layer = torch.nn.Linear(in_dim, h_dim)
        self.output_layer = torch.nn.Linear(h_dim, out_dim)
        self.residual_layer = torch.nn.Linear(in_dim, out_dim)
        self.act = torch.nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hid = self.act(self.hidden_layer(x))
        # Dropout on the transformed path, then add the linear skip connection.
        out = self.dropout(self.output_layer(hid))
        res = self.residual_layer(x)
        return out + res


@dataclass
class StandardScaler:
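    """NaN-aware standard scaler operating along the last (time) dimension.

    ``scale`` standardizes ``x`` and returns the ``(loc, scale)`` statistics;
    ``re_scale`` inverts the transform. All-NaN series fall back to ``nan_loc``
    and a unit scale, and a zero scale is replaced by ``|loc| + eps``.
    """
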
    eps: float = 1e-5
    nan_loc: float = 0.0

    def scale(
        self,
        x: torch.Tensor,
        loc_scale: tuple[torch.Tensor, torch.Tensor] | None = None,
    ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        if loc_scale is None:
            loc = torch.nan_to_num(torch.nanmean(x, dim=-1, keepdim=True), nan=self.nan_loc)
            scale = torch.nan_to_num(torch.nanmean((x - loc).square(), dim=-1, keepdim=True).sqrt(), nan=1.0)
            scale = torch.where(scale == 0, torch.abs(loc) + self.eps, scale)
        else:
            loc, scale = loc_scale

        return ((x - loc) / scale), (loc, scale)

    def re_scale(self, x: torch.Tensor, loc_scale: tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
        loc, scale = loc_scale
        return x * scale + loc


@dataclass
class _Patcher:
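    """Split the last dimension of a 2-D tensor into patches.

    Patches of ``patch_size`` are taken every ``patch_stride`` steps; the series
    is NaN-padded (on the left if ``left_pad``, otherwise on the right) so that
    every value is covered. ``patch_size`` must be a multiple of ``patch_stride``.
    """
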
    patch_size: int
    patch_stride: int
    left_pad: bool

    def __post_init__(self):
        assert self.patch_size % self.patch_stride == 0

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        assert x.ndim == 2
        length = x.shape[-1]

        if length < self.patch_size or (length % self.patch_stride != 0):
            if length < self.patch_size:
                padding_size = (
                    *x.shape[:-1],
                    self.patch_size - (length % self.patch_size),
                )
            else:
                padding_size = (
                    *x.shape[:-1],
                    self.patch_stride - (length % self.patch_stride),
                )
            padding = torch.full(size=padding_size, fill_value=torch.nan, dtype=x.dtype, device=x.device)
            if self.left_pad:
                x = torch.concat((padding, x), dim=-1)
            else:
                x = torch.concat((x, padding), dim=-1)

        x = x.unfold(dimension=-1, size=self.patch_size, step=self.patch_stride)
        return x


@dataclass
class PatchedUniTokenizer:
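    """Scale a batch of series and split each one into left-NaN-padded patches.

    ``patch_stride`` defaults to ``patch_size`` (non-overlapping patches).
    ``context_input_transform`` returns the patches plus the scaler state;
    ``output_transform`` maps model outputs back to the original scale.
    """
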
    patch_size: int
    scaler: Any = field(default_factory=StandardScaler)
    patch_stride: int | None = None

    def __post_init__(self):
        if self.patch_stride is None:
            self.patch_stride = self.patch_size
        self.patcher = _Patcher(self.patch_size, self.patch_stride, left_pad=True)

    def context_input_transform(self, data: torch.Tensor):
        assert data.ndim == 2
        data, scale_state = self.scaler.scale(data)
        return self.patcher(data), {SCALER_STATE: scale_state}

    def output_transform(self, data: torch.Tensor, tokenizer_state: dict):
        data_shape = data.shape
        data = self.scaler.re_scale(data.reshape(data_shape[0], -1), tokenizer_state[SCALER_STATE]).view(*data_shape)
        return data


class StreamToLogger:
    """Fake file-like stream object that redirects writes to a logger
    instance."""

    def __init__(self, logger, log_level):
        self.logger = logger
        self.log_level = log_level
        self.linebuf = ""  # Buffer for partial lines

    def write(self, message):
        self.linebuf += message
        # Emit every complete line; keep a trailing partial line buffered.
        if "\n" in self.linebuf:
            lines = self.linebuf.splitlines(keepends=True)
            self.linebuf = ""
            for line in lines:
                if line.endswith("\n"):
                    # Log complete, non-blank lines without the trailing newline
                    # (the logger adds its own).
                    if line.strip():
                        self.logger.log(self.log_level, line.rstrip("\n"))
                else:
                    # Keep the incomplete final line until more data arrives.
                    self.linebuf = line

    def flush(self):
        # Log any remaining buffered content when flush is called
        if self.linebuf.strip():
            self.logger.log(self.log_level, self.linebuf.rstrip("\n"))
            self.linebuf = ""