Spaces:
Runtime error
Runtime error
| # Copyright (c) OpenMMLab. All rights reserved. | |
| import torch | |
def channel_shuffle(x, groups):
    """Channel Shuffle operation.

    This function enables cross-group information flow for multiple groups
    convolution layers. The channel dimension (dim 1) is split into
    ``groups`` groups, the group axis and the within-group axis are swapped,
    and the result is flattened back to the input shape.

    Args:
        x (Tensor): The input tensor. Dim 0 is treated as the batch
            dimension and dim 1 as the channel dimension; any number of
            trailing dimensions is accepted and left untouched.
        groups (int): The number of groups to divide the input tensor
            in the channel dimension.

    Returns:
        Tensor: The output tensor after channel shuffle operation, with the
            same shape as ``x``.

    Raises:
        AssertionError: If the number of channels is not divisible by
            ``groups``.
    """
    original_shape = x.size()
    batch_size, num_channels = original_shape[0], original_shape[1]
    trailing_dims = original_shape[2:]
    assert (num_channels % groups == 0), ('num_channels should be '
                                          'divisible by groups')
    channels_per_group = num_channels // groups

    # Split channels into (groups, channels_per_group) and swap the two axes
    # so that consecutive output channels come from different groups.
    x = x.view(batch_size, groups, channels_per_group, *trailing_dims)
    x = torch.transpose(x, 1, 2).contiguous()

    # Restore the exact input shape instead of inferring the channel size
    # with -1: inference is ambiguous (and raises) for zero-element tensors
    # such as an empty batch.
    x = x.view(original_shape)

    return x