4893ce0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
import torch
import torch.nn as nn


class LayerNorm1d(nn.BatchNorm1d):
    """Batch normalization for channel-last sequence tensors.

    Despite the name, this class subclasses ``nn.BatchNorm1d`` and is *not*
    equivalent to ``nn.LayerNorm``: statistics are computed per feature by
    the parent ``BatchNorm1d`` (running statistics are used in eval mode).
    The only difference from the parent is the expected layout: features
    live in the last dimension instead of dimension 1.

    Constructor arguments are those of ``nn.BatchNorm1d``.
    """

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Normalize ``input`` over its last (feature) dimension.

        Args:
            input: Tensor of shape ``(N, L, C)`` or ``(N, C)``, where ``C``
                equals ``num_features``.

        Returns:
            A contiguous tensor with the same shape as ``input``.
        """
        if input.dim() == 2:
            # (N, C) already has features where BatchNorm1d expects them,
            # so there is nothing to transpose.
            return super().forward(input)

        # BatchNorm1d expects (N, C, L): move features to dim 1, normalize,
        # then restore the caller's channel-last layout.
        channels_first = input.transpose(1, 2).contiguous()
        normalized = super().forward(channels_first)
        return normalized.transpose(1, 2).contiguous()