Spaces:
Runtime error
Runtime error
| import torch.nn as nn | |
| import torch | |
| import torch.distributed as dist | |
class GlobalAvgPool2d(nn.Module):
    """Global average pooling over the input's spatial dimensions.

    Every dimension after the first two is averaged away, so an input of
    shape ``(N, C, *spatial)`` produces an output of shape ``(N, C)``.
    A 2-D input ``(N, C)`` is accepted and treated as having spatial size 1.
    """

    def __init__(self):
        super(GlobalAvgPool2d, self).__init__()

    def forward(self, inputs):
        """Return the mean of ``inputs`` over all dimensions after dim 1.

        Args:
            inputs: tensor with at least two dimensions; dims 0 and 1 are
                preserved (presumably batch and channels — confirm with
                callers).

        Returns:
            Tensor of shape ``(inputs.size(0), inputs.size(1))``.
        """
        in_size = inputs.size()
        # Spell out the flattened spatial size instead of passing -1: with -1
        # a zero-element input (e.g. an empty batch) is ambiguous and raises.
        # torch.Size(()).numel() == 1, so 2-D inputs still work.
        spatial = torch.Size(in_size[2:]).numel()
        # reshape (unlike view) also accepts non-contiguous inputs such as
        # permuted tensors, copying only when a view is impossible.
        flat = inputs.reshape(in_size[0], in_size[1], spatial)
        return flat.mean(dim=2)
class SingleGPU(nn.Module):
    """Wrap ``module`` so that each forward pass first copies its input to CUDA.

    The wrapped module is stored as ``self.module``; this wrapper does not
    move the module itself to the GPU.
    """

    def __init__(self, module):
        super().__init__()
        self.module = module

    def forward(self, input):
        # non_blocking=True: per PyTorch docs the host-to-device copy is only
        # asynchronous when the source tensor lives in pinned memory.
        on_device = input.cuda(non_blocking=True)
        result = self.module(on_device)
        return result