Spaces:
Runtime error
Runtime error
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
| import torch | |
| from torch import nn | |
class FrozenBatchNorm2d(nn.Module):
    """
    BatchNorm2d where the batch statistics and the affine parameters
    are fixed.

    ``weight``, ``bias``, ``running_mean`` and ``running_var`` are registered
    as buffers, not parameters. They therefore appear in ``state_dict()`` but
    are not returned by ``parameters()``.

    Args:
        n: number of channels; each buffer is created with shape ``(n,)``.
    """

    def __init__(self, n):
        super(FrozenBatchNorm2d, self).__init__()
        self.register_buffer("weight", torch.ones(n))
        self.register_buffer("bias", torch.zeros(n))
        self.register_buffer("running_mean", torch.zeros(n))
        self.register_buffer("running_var", torch.ones(n))

    def forward(self, x):
        """
        Apply the frozen per-channel affine transform ``x * scale + bias``.

        ``scale = weight / sqrt(running_var)`` and
        ``bias = bias - running_mean * scale``. Both are reshaped to
        ``(1, n, 1, 1)`` so they broadcast over dim 1 of ``x``. This presumably
        expects ``(N, n, H, W)`` input -- TODO confirm with callers.

        Args:
            x: input tensor.

        Returns:
            The transformed tensor. If ``x`` is floating point (e.g. float16
            under mixed precision), the result has ``x``'s dtype.
        """
        # Fold the frozen statistics into one per-channel scale and bias,
        # computed in the buffers' own (typically float32) precision.
        scale = self.weight * self.running_var.rsqrt()
        bias = self.bias - self.running_mean * scale
        if x.is_floating_point():
            # Cast only these local results to the input dtype. Reassigning
            # self.weight etc. here would overwrite the registered buffers and
            # permanently downcast the module after a single fp16 call.
            scale = scale.to(x.dtype)
            bias = bias.to(x.dtype)
        scale = scale.reshape(1, -1, 1, 1)
        bias = bias.reshape(1, -1, 1, 1)
        return x * scale + bias

    def extra_repr(self):
        # Shown inside repr(module), e.g. "FrozenBatchNorm2d(64)".
        return "{}".format(self.weight.numel())