from functools import reduce import math import torch import torch.nn as nn import numpy as np from .kernel import exported_tdp import torch.nn.functional as F from functools import partial from timm.models.layers import trunc_normal_ class TimeDependentParameter(nn.Module): def __init__(self, shape, init_fn): super().__init__() self.shape = shape w = torch.empty(*shape) init_fn(w) self.param0 = nn.Parameter(w.clone().detach()) self.param1 = nn.Parameter(w.clone().detach()) self.nodecay_weight = nn.Parameter(torch.zeros(*shape)) self.nodecay_bias = nn.Parameter(torch.zeros(*shape)) self.curr_weight = None def forward(self): weight = self.curr_weight # self.curr_weight = None return weight def __repr__(self): return f"TimeDependentParameter(shape={self.shape})" @staticmethod def seed_time(model, log_snr): assert log_snr.dim() == 1 if torch.all(log_snr == log_snr[0]): log_snr = log_snr[0][None] time_condition = log_snr / 4.0 tdp_list = [module for module in model.modules() if isinstance(module, TimeDependentParameter)] for tdp in tdp_list: tdp.curr_weight = exported_tdp(tdp.param0, tdp.param1, tdp.nodecay_weight + 1, tdp.nodecay_bias, time_condition, custom = False) class LayerNorm(nn.Module): def __init__(self, dim, num_groups = 1, eps = 1e-05): super().__init__() self.eps = eps self.dim = dim self.num_groups = num_groups self.weight = TimeDependentParameter((dim, ), nn.init.ones_) self.bias = TimeDependentParameter((dim, ), nn.init.zeros_) def _forward(self, x): weight, bias = self.weight(), self.bias() assert weight.shape[0] == bias.shape[0] assert x.shape[-1] == self.dim if weight.shape[0] == 1: x = F.layer_norm(x, (self.dim, ), weight = weight[0], bias = bias[0], eps = self.eps) else: assert x.shape[0] == weight.shape[0] x = F.layer_norm(x, (self.dim, ), eps = self.eps) x = torch.addcmul(bias[:, None, :], weight[:, None, :], x) return x def forward(self, x): original_shape = x.shape batch_size = x.shape[0] assert self.dim == x.shape[-1] x = x.reshape(batch_size, -1, self.dim) x = self._forward(x) x = x.reshape(*original_shape) return x class Linear(nn.Module): def __init__(self, din, dout, bias = True, weight_init_fn = partial(trunc_normal_, std = 0.02)): super().__init__() self.din = din self.dout = dout self.weight = TimeDependentParameter((din, dout), weight_init_fn) if bias: self.bias = TimeDependentParameter((dout, ), nn.init.zeros_) else: self.bias = None def _forward(self, x): weight = self.weight() bias = self.bias() if self.bias is not None else None # if weight.shape[0] == 1: # B, L, D = x.shape # if bias is not None: # assert weight.shape[0] == bias.shape[0] # x = torch.addmm(bias, x.reshape(B * L, D), weight[0]) # else: # x = torch.matmul(x.reshape(B * L, D), weight[0]) # x = x.reshape(B, L, -1) # else: if bias is not None: x = torch.baddbmm(bias[:, None, :], x, weight) else: x = torch.bmm(x, weight) return x def forward(self, x): original_shape = x.shape batch_size = x.shape[0] x = x.reshape(batch_size, -1, self.din) x = self._forward(x) x = x.reshape(*(list(original_shape[:-1]) + [self.dout])) return x class RMSNorm(nn.Module): def __init__(self, d, p=-1., eps=1e-8, bias=False): """ Root Mean Square Layer Normalization :param d: model size :param p: partial RMSNorm, valid value [0, 1], default -1.0 (disabled) :param eps: epsilon value, default 1e-8 :param bias: whether use bias term for RMSNorm, disabled by default because RMSNorm doesn't enforce re-centering invariance. """ super(RMSNorm, self).__init__() self.eps = eps self.d = d self.p = p self.bias = bias self.scale = nn.Parameter(torch.ones(d)) self.register_parameter("scale", self.scale) if self.bias: self.offset = nn.Parameter(torch.zeros(d)) self.register_parameter("offset", self.offset) def forward(self, x): if self.p < 0. or self.p > 1.: norm_x = x.norm(2, dim=-1, keepdim=True) d_x = self.d else: partial_size = int(self.d * self.p) partial_x, _ = torch.split(x, [partial_size, self.d - partial_size], dim=-1) norm_x = partial_x.norm(2, dim=-1, keepdim=True) d_x = partial_size rms_x = norm_x * d_x ** (-1. / 2) x_normed = x / (rms_x + self.eps) if self.bias: return self.scale * x_normed + self.offset return self.scale * x_normed class TDRMSNorm(nn.Module): def __init__(self, d, p=-1., eps=1e-8, bias=False): """ Root Mean Square Layer Normalization :param d: model size :param p: partial RMSNorm, valid value [0, 1], default -1.0 (disabled) :param eps: epsilon value, default 1e-8 :param bias: whether use bias term for RMSNorm, disabled by default because RMSNorm doesn't enforce re-centering invariance. """ super(TDRMSNorm, self).__init__() self.eps = eps self.d = d self.p = p self.bias = bias # self.scale = nn.Parameter(torch.ones(d)) self.scale = TimeDependentParameter((d, ), nn.init.ones_) # self.register_parameter("scale", self.scale) if self.bias: # self.offset = nn.Parameter(torch.zeros(d)) self.offset = TimeDependentParameter((d, ), nn.init.zeros_) # self.register_parameter("offset", self.offset) def forward(self, x): if self.p < 0. or self.p > 1.: norm_x = x.norm(2, dim=-1, keepdim=True) d_x = self.d else: partial_size = int(self.d * self.p) partial_x, _ = torch.split(x, [partial_size, self.d - partial_size], dim=-1) norm_x = partial_x.norm(2, dim=-1, keepdim=True) d_x = partial_size rms_x = norm_x * d_x ** (-1. / 2) x_normed = x / (rms_x + self.eps) _scale = self.scale() if self.bias: # return self.scale * x_normed + self.offset _offset = self.offset() if _scale.shape[0] == 1: return _scale[0] * x_normed + _offset[0] elif x_normed.dim() == 3: return torch.addcmul(_offset[:, None, :], _scale[:, None, :], x_normed) elif x_normed.dim() == 4: return torch.addcmul(_offset[:, None, None, :], _scale[:, None, None, :], x_normed) else: raise NotImplementedError # return self.scale * x_normed if _scale.shape[0] == 1: return _scale[0] * x_normed elif x_normed.dim() == 3: return _scale[:, None, :] * x_normed elif x_normed.dim() == 4: return _scale[:, None, None, :] * x_normed else: raise NotImplementedError def zero_init(layer): nn.init.zeros_(layer.weight) if layer.bias is not None: nn.init.zeros_(layer.bias) return layer def rms_norm(x, scale, eps): dtype = reduce(torch.promote_types, (x.dtype, scale.dtype, torch.float32)) mean_sq = torch.mean(x.to(dtype)**2, dim=-1, keepdim=True) scale = scale.to(dtype) * torch.rsqrt(mean_sq + eps) return x * scale.to(x.dtype) class AdaRMSNorm(nn.Module): def __init__(self, features, cond_features, eps=1e-6): super().__init__() self.eps = eps self.linear = zero_init(nn.Linear(cond_features, features, bias=False)) def extra_repr(self): return f"eps={self.eps}," def forward(self, x, cond): return rms_norm(x, self.linear(cond)[:, None, :] + 1, self.eps) class QKNorm(nn.Module): def __init__(self, n_heads, eps=1e-6, max_scale=100.0): super().__init__() self.eps = eps self.max_scale = math.log(max_scale) self.scale = nn.Parameter(torch.full((n_heads,), math.log(10.0))) self.proj_() def extra_repr(self): return f"n_heads={self.scale.shape[0]}, eps={self.eps}" @torch.no_grad() def proj_(self): """Modify the scale in-place so it doesn't get "stuck" with zero gradient if it's clamped to the max value.""" self.scale.clamp_(max=self.max_scale) def forward(self, x): self.proj_() scale = torch.exp(0.5 * self.scale - 0.25 * math.log(x.shape[-1])) return rms_norm(x, scale[:, None, None], self.eps)