import torch
from torch.autograd import Variable  # noqa: F401 -- no longer used here; kept in case other modules import it from this one
import torch.nn as nn
import numpy as np  # noqa: F401 -- no longer used here; kept in case other modules import it from this one
from typing import List
from torch.nn.modules.batchnorm import _BatchNorm
from collections.abc import Iterable


def norm(x, dims: List[int], EPS: float = 1e-8):
    """Standardize ``x`` to zero mean and unit variance over ``dims``.

    Args:
        x: input tensor.
        dims: dimensions over which mean and variance are computed.
        EPS: added to the variance before the square root to avoid division by zero.

    Returns:
        Tensor of the same shape as ``x``. The biased (population) variance is used.
    """
    mean = x.mean(dim=dims, keepdim=True)
    var = torch.var(x, dim=dims, keepdim=True, unbiased=False)
    return (x - mean) / torch.sqrt(var + EPS)


def glob_norm(x, ESP: float = 1e-8):
    """Standardize each sample of ``x`` over all of its non-batch dimensions.

    The epsilon argument is spelled ``ESP`` (sic); the name is kept so that
    keyword callers keep working.
    """
    dims: List[int] = list(range(1, x.dim()))
    return norm(x, dims, ESP)


class MLayerNorm(nn.Module):
    """Base class for layer norms with a per-channel learnable gain and bias.

    Subclasses implement ``forward`` and pass the normalized tensor through
    :meth:`apply_gain_and_bias`.

    Args:
        channel_size: number of channels (size of dim 1 of the input).
    """

    def __init__(self, channel_size):
        super().__init__()
        self.channel_size = channel_size
        self.gamma = nn.Parameter(torch.ones(channel_size), requires_grad=True)
        # Bias starts at zero so a freshly initialised layer is a pure
        # normalization (gamma = 1, beta = 0).
        self.beta = nn.Parameter(torch.zeros(channel_size), requires_grad=True)

    def apply_gain_and_bias(self, normed_x):
        """Scale by ``gamma`` and shift by ``beta`` along dim 1.

        Assumes input of size ``[batch, channel, *]``. The channel axis is
        temporarily moved last so that the 1-D parameters broadcast against it.
        """
        return (self.gamma * normed_x.transpose(1, -1) + self.beta).transpose(1, -1)

    def forward(self, x, EPS: float = 1e-8):
        raise NotImplementedError(
            "MLayerNorm is an abstract base class; use a subclass such as GlobalLN."
        )


class GlobalLN(MLayerNorm):
    """Global layer norm: statistics over all non-batch dims, then per-channel affine."""

    def forward(self, x, EPS: float = 1e-8):
        value = glob_norm(x, EPS)
        return self.apply_gain_and_bias(value)


class ChannelLN(MLayerNorm):
    """Channel-wise layer norm: statistics over dim 1 only, then per-channel affine."""

    def forward(self, x, EPS: float = 1e-8):
        return self.apply_gain_and_bias(norm(x, [1], EPS))


class BatchNorm(_BatchNorm):
    """Wrapper around PyTorch ``_BatchNorm`` accepting 2D, 3D or 4D input.

    This covers the inputs accepted by ``BatchNorm1d`` and ``BatchNorm2d``.
    """

    def _check_input_dim(self, input):
        if input.dim() < 2 or input.dim() > 4:
            raise ValueError(
                "expected 2D, 3D or 4D input (got {}D input)".format(input.dim())
            )


class CumulativeLayerNorm(nn.LayerNorm):
    """Channel-wise layer norm for ``[batch, channel, time]`` input.

    Despite its name, this is not cumulative over time. Each time step is
    normalized independently over the channel axis: the channel axis is moved
    last, ``nn.LayerNorm`` is applied over it, and then the axis is moved back.

    Args:
        dim: number of channels.
        elementwise_affine: whether ``nn.LayerNorm`` learns a gain and bias.
    """

    def __init__(self, dim, elementwise_affine=True):
        super().__init__(dim, elementwise_affine=elementwise_affine)

    def forward(self, x):
        x = torch.transpose(x, 1, -1)  # N x C x L -> N x L x C
        x = super().forward(x)  # normalize over the last axis (C)
        return torch.transpose(x, 1, -1)  # back to N x C x L


class CumulateLN(nn.Module):
    """Cumulative (causal) layer norm for ``(batch, channel, time)`` input.

    Time step ``t`` is normalized with the mean and variance of all entries in
    frames ``0..t``, pooled across every channel.

    Args:
        dimension: number of channels (size of dim 1 of the input).
        eps: added to the cumulative variance before the square root.
        trainable: if True, ``gain``/``bias`` are learnable parameters. Otherwise
            they are fixed ones/zeros held in non-persistent buffers. Such buffers
            follow ``module.to(...)`` but are not saved in ``state_dict``.
    """

    def __init__(self, dimension, eps=1e-8, trainable=True):
        super().__init__()
        self.eps = eps
        if trainable:
            self.gain = nn.Parameter(torch.ones(1, dimension, 1))
            self.bias = nn.Parameter(torch.zeros(1, dimension, 1))
        else:
            # Buffers (rather than bare tensors) move with the module. They are
            # non-persistent so the state_dict layout stays as before.
            self.register_buffer("gain", torch.ones(1, dimension, 1), persistent=False)
            self.register_buffer("bias", torch.zeros(1, dimension, 1), persistent=False)

    def forward(self, input):
        channel = input.size(1)
        time_step = input.size(2)

        step_sum = input.sum(1)  # B, T
        step_pow_sum = input.pow(2).sum(1)  # B, T
        cum_sum = torch.cumsum(step_sum, dim=1)  # B, T
        cum_pow_sum = torch.cumsum(step_pow_sum, dim=1)  # B, T

        # Number of entries seen up to and including step t: channel * (t + 1).
        entry_cnt = torch.arange(
            channel,
            channel * (time_step + 1),
            channel,
            dtype=input.dtype,
            device=input.device,
        ).view(1, -1)  # 1, T

        cum_mean = cum_sum / entry_cnt  # B, T
        # Var = E[x^2] - E[x]^2. Floating-point cancellation can make this
        # slightly negative, which would become NaN after sqrt, so clamp at 0.
        cum_var = (cum_pow_sum / entry_cnt - cum_mean.pow(2)).clamp(min=0)  # B, T
        cum_std = (cum_var + self.eps).sqrt()  # B, T

        x = (input - cum_mean.unsqueeze(1)) / cum_std.unsqueeze(1)
        return x * self.gain.to(x) + self.bias.to(x)


class LayerNormalization4D(nn.Module):
    """Layer norm for 4-D input of shape ``(batch, C, *, F)``.

    Args:
        input_dimension: ``(C, F)``. The gain/bias have shape ``[1, C, 1, F]``.
            Statistics are taken over dims ``(1, 3)`` when ``F > 1``, otherwise
            over dim ``1`` only.
        eps: added to the variance before the square root.

    Raises:
        ValueError: if ``input_dimension`` does not have exactly two entries.
    """

    def __init__(self, input_dimension: Iterable, eps: float = 1e-5):
        super().__init__()
        if len(input_dimension) != 2:
            raise ValueError(
                "input_dimension must have exactly 2 entries, got {}".format(
                    len(input_dimension)
                )
            )
        param_size = [1, input_dimension[0], 1, input_dimension[1]]
        self.dim = (1, 3) if param_size[-1] > 1 else (1,)
        self.gamma = nn.Parameter(torch.ones(*param_size, dtype=torch.float32))
        self.beta = nn.Parameter(torch.zeros(*param_size, dtype=torch.float32))
        self.eps = eps

    def forward(self, x: torch.Tensor):
        mu_ = x.mean(dim=self.dim, keepdim=True)
        std_ = torch.sqrt(x.var(dim=self.dim, unbiased=False, keepdim=True) + self.eps)
        return ((x - mu_) / std_) * self.gamma + self.beta


# Aliases.
gLN = GlobalLN
cLN = CumulateLN
LN = CumulativeLayerNorm
bN = BatchNorm
LN4D = LayerNormalization4D


def get(identifier):
    """Return a norm class from a string, or the input itself if it is callable.

    Args:
        identifier (str or Callable or None): the norm identifier. A string must
            name an ``nn.Module`` subclass defined or aliased in this module
            (e.g. ``"gLN"``, ``"cLN"``, ``"LN"``, ``"bN"``, ``"LN4D"``).

    Returns:
        The matching ``nn.Module`` subclass, the callable itself, or ``None``
        when ``identifier`` is ``None``.

    Raises:
        ValueError: if ``identifier`` cannot be interpreted as a norm.
    """
    if identifier is None:
        return None
    if callable(identifier):
        return identifier
    if isinstance(identifier, str):
        cls = globals().get(identifier)
        # Only hand back module classes. A plain globals() lookup would
        # otherwise return unrelated names such as "torch" or "np".
        if isinstance(cls, type) and issubclass(cls, nn.Module):
            return cls
    raise ValueError(
        "Could not interpret normalization identifier: " + str(identifier)
    )