Spaces:
Runtime error
Runtime error
# flake8: noqa | |
from torch.autograd import Function, Variable | |
from torch.nn.modules.module import Module | |
import channelnorm_cuda | |
class ChannelNormFunction(Function):
    """Autograd bridge to the compiled ``channelnorm_cuda`` kernels.

    This class only allocates the result/gradient buffers and registers the
    tensors needed for the backward pass; all arithmetic happens inside the
    extension. Use it through ``ChannelNormFunction.apply(input1, norm_deg)``.

    NOTE(review): the name suggests a per-pixel norm of degree ``norm_deg``
    taken across the channel axis -- confirm against the extension source.
    """

    @staticmethod
    def forward(ctx, input1, norm_deg=2):
        """Run the forward kernel.

        Args:
            ctx: Autograd context used to stash state for ``backward``.
            input1: Contiguous 4-D tensor; its size is unpacked as
                ``(b, _, h, w)``.
            norm_deg: Value forwarded unchanged to the extension.

        Returns:
            A tensor of size ``(b, 1, h, w)`` created from ``input1`` (same
            dtype/device) and filled in by ``channelnorm_cuda.forward``.

        Raises:
            AssertionError: If ``input1`` is not contiguous.
        """
        assert input1.is_contiguous(), 'input1 must be contiguous'
        b, _, h, w = input1.size()
        # The kernel writes into a pre-allocated, zero-initialised buffer.
        output = input1.new_zeros((b, 1, h, w))
        channelnorm_cuda.forward(input1, output, norm_deg)
        # Both the input and the result are handed to the backward kernel.
        ctx.save_for_backward(input1, output)
        ctx.norm_deg = norm_deg
        return output

    @staticmethod
    def backward(ctx, grad_output):
        """Run the backward kernel.

        Args:
            ctx: Autograd context populated by ``forward``.
            grad_output: Gradient with respect to the forward result.

        Returns:
            A ``(grad_input1, None)`` tuple; ``norm_deg`` receives no
            gradient.
        """
        input1, output = ctx.saved_tensors
        grad_input1 = input1.new_zeros(input1.size())
        # Autograd does not guarantee a contiguous incoming gradient, while
        # forward() requires contiguous memory for the same extension --
        # presumably the backward kernel has the same requirement.
        grad_output = grad_output.contiguous()
        channelnorm_cuda.backward(input1, output, grad_output, grad_input1,
                                  ctx.norm_deg)
        return grad_input1, None
class ChannelNorm(Module):
    """Module wrapper that applies ``ChannelNormFunction`` to its input.

    Args:
        norm_deg: Value stored on the module and forwarded unchanged to
            ``ChannelNormFunction.apply`` on every call. Defaults to 2.
    """

    def __init__(self, norm_deg=2):
        super().__init__()
        self.norm_deg = norm_deg

    def forward(self, input1):
        """Return ``ChannelNormFunction.apply(input1, self.norm_deg)``."""
        degree = self.norm_deg
        return ChannelNormFunction.apply(input1, degree)