# Spaces: Running on Zero
# (Banner text captured from the hosting page; it is not part of the module.)
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
def carafe_forward(
    features: torch.Tensor,
    masks: torch.Tensor,
    kernel_size: int,
    group_size: int,
    scale_factor: int
) -> torch.Tensor:
    """
    Pure-PyTorch implementation of the CARAFE upsampling operator.

    Each output pixel ``(oh, ow)`` is a weighted sum over the
    ``kernel_size x kernel_size`` neighbourhood centred on the input pixel
    ``(oh // scale_factor, ow // scale_factor)``; positions outside the input
    are treated as zeros. The weights come from ``masks`` and are shared by
    all channels that belong to the same group.

    Args:
        features (Tensor): Input feature map of shape (N, C, H, W). It does not
            need to be contiguous.
        masks (Tensor): Reassembly kernel weights of shape
            (N, group_size*kernel_size*kernel_size, H_out, W_out), where
            H_out = H*scale_factor and W_out = W*scale_factor. The channel axis
            is group-major: channel ``g*kernel_size**2 + k`` is weight ``k`` of
            group ``g``.
        kernel_size (int): Spatial size of the reassembly kernel. Must be odd
            so that the window can be centred on the source pixel.
        group_size (int): Number of channel groups. Must divide C.
        scale_factor (int): The upsampling factor.

    Returns:
        Tensor: Upsampled feature map of shape
        (N, C, H*scale_factor, W*scale_factor).

    Raises:
        ValueError: If ``features`` is not 4-D, if ``kernel_size`` is not a
            positive odd integer, if ``scale_factor`` or ``group_size`` is not
            positive, if ``group_size`` does not divide C, or if ``masks`` does
            not have the expected shape.
    """
    if features.dim() != 4:
        raise ValueError(
            f"features must be 4-D (N, C, H, W), got {features.dim()}-D"
        )
    if kernel_size < 1 or kernel_size % 2 == 0:
        raise ValueError(
            f"kernel_size must be a positive odd integer, got {kernel_size}"
        )
    if scale_factor < 1:
        raise ValueError(f"scale_factor must be >= 1, got {scale_factor}")

    N, C, H, W = features.size()
    if group_size < 1 or C % group_size != 0:
        raise ValueError(
            f"group_size ({group_size}) must be positive and divide the "
            f"number of channels ({C})"
        )

    out_H, out_W = H * scale_factor, W * scale_factor
    kernel_area = kernel_size * kernel_size
    channels_per_group = C // group_size

    # Check the exact shape, not just the element count: a mask that was not
    # pixel-shuffled has the same number of elements as a valid one and would
    # otherwise be silently reinterpreted.
    expected_mask_shape = (N, group_size * kernel_area, out_H, out_W)
    if tuple(masks.shape) != expected_mask_shape:
        raise ValueError(
            f"masks must have shape {expected_mask_shape}, "
            f"got {tuple(masks.shape)}"
        )

    # Fold the group axis into the batch axis. reshape() (rather than view())
    # keeps this working for non-contiguous inputs such as channels_last.
    grouped = features.reshape(N * group_size, channels_per_group, H, W)

    # Extract one zero-padded k x k window per input pixel:
    # (N*group_size, channels_per_group*kernel_area, H*W).
    patches = F.unfold(
        grouped, kernel_size=kernel_size, padding=(kernel_size - 1) // 2
    )
    patches = patches.view(
        N, group_size, channels_per_group, kernel_area, H * W
    )

    # Flat index of the source pixel for every output pixel:
    # (oh, ow) -> (oh // scale_factor) * W + (ow // scale_factor).
    device = features.device
    src_rows = torch.div(
        torch.arange(out_H, device=device), scale_factor, rounding_mode='floor'
    )
    src_cols = torch.div(
        torch.arange(out_W, device=device), scale_factor, rounding_mode='floor'
    )
    base_idx = (src_rows.unsqueeze(1) * W + src_cols.unsqueeze(0)).reshape(-1)

    # Pick the window belonging to each output pixel. A 1-D index along the
    # last axis is enough, so no expanded 5-D index tensor is required.
    gathered = patches.index_select(4, base_idx).view(
        N, group_size, channels_per_group, kernel_area, out_H, out_W
    )

    # Insert a singleton channel axis; broadcasting shares each group's
    # weights across that group's channels.
    weights = masks.reshape(N, group_size, 1, kernel_area, out_H, out_W)

    # Weighted sum over the kernel axis -> (N, group, ch_per_group, out_H, out_W).
    out = (gathered * weights).sum(dim=3)
    return out.reshape(N, C, out_H, out_W)
class CARAFE(nn.Module):
    """
    Content-Aware ReAssembly of FEatures (CARAFE) as a stateless module.

    The module has no learnable parameters: it stores the three operator
    hyper-parameters and, on every call, hands the feature map and the
    caller-supplied reassembly masks to ``carafe_forward``.

    Args:
        kernel_size (int): Reassembly kernel size.
        group_size (int): Group size for channel grouping (must divide the
            number of channels).
        scale_factor (int): Upsample ratio.
    """

    def __init__(self, kernel_size: int, group_size: int, scale_factor: int):
        super().__init__()
        self.kernel_size = kernel_size
        self.group_size = group_size
        self.scale_factor = scale_factor

    def forward(self, features: torch.Tensor, masks: torch.Tensor) -> torch.Tensor:
        """Upsample ``features`` using the reassembly weights in ``masks``."""
        return carafe_forward(
            features,
            masks,
            kernel_size=self.kernel_size,
            group_size=self.group_size,
            scale_factor=self.scale_factor,
        )
class CARAFEPack(nn.Module):
    """
    Self-contained CARAFE upsampler that predicts its own reassembly masks.

    The pipeline has three stages:
        1) a 1x1 convolution that compresses the input channels,
        2) a content encoder convolution that predicts the raw masks,
        3) the CARAFE operator, which reassembles the *original* features
           using the normalized masks.

    This is modeled after the official CARAFE package.

    Args:
        channels (int): Number of input feature channels.
        scale_factor (int): Upsample ratio.
        up_kernel (int): Kernel size for the CARAFE operator.
        up_group (int): Group size for the CARAFE operator.
        encoder_kernel (int): Kernel size of the content encoder.
        encoder_dilation (int): Dilation rate for the content encoder.
        compressed_channels (int): Output channels for the channel compressor.
    """

    def __init__(
        self,
        channels: int,
        scale_factor: int,
        up_kernel: int = 5,
        up_group: int = 1,
        encoder_kernel: int = 3,
        encoder_dilation: int = 1,
        compressed_channels: int = 64
    ):
        super().__init__()

        # Keep every hyper-parameter on the instance for later use.
        self.channels = channels
        self.scale_factor = scale_factor
        self.up_kernel = up_kernel
        self.up_group = up_group
        self.encoder_kernel = encoder_kernel
        self.encoder_dilation = encoder_dilation
        self.compressed_channels = compressed_channels

        # Stage 1: 1x1 convolution that reduces the channel count.
        self.channel_compressor = nn.Conv2d(
            channels, compressed_channels, kernel_size=1
        )

        # Stage 2: one k*k kernel per group for each of the scale*scale
        # sub-pixel positions; pixel shuffle later moves those positions
        # into the spatial axes.
        mask_channels = (
            up_kernel * up_kernel * up_group * scale_factor * scale_factor
        )
        encoder_padding = int((encoder_kernel - 1) * encoder_dilation / 2)
        self.content_encoder = nn.Conv2d(
            compressed_channels,
            mask_channels,
            kernel_size=encoder_kernel,
            padding=encoder_padding,
            dilation=encoder_dilation
        )

        # Xavier-uniform weights first (compressor, then encoder), then
        # zero the biases where present.
        convs = (self.channel_compressor, self.content_encoder)
        for conv in convs:
            nn.init.xavier_uniform_(conv.weight)
        for conv in convs:
            if conv.bias is not None:
                nn.init.constant_(conv.bias, 0)

    def kernel_normalizer(self, mask: torch.Tensor) -> torch.Tensor:
        """
        Rearrange the predicted mask to output resolution and normalize it.

        Pixel shuffle moves the scale*scale sub-pixel channels into the
        spatial axes; a softmax is then taken over the up_kernel**2 weights
        of every kernel.

        Args:
            mask (Tensor): Predicted mask of shape (N, out_channels, H, W).

        Returns:
            Tensor: Normalized mask of shape
            (N, up_group * up_kernel^2, H*scale, W*scale).
        """
        shuffled = F.pixel_shuffle(mask, self.scale_factor)
        batch, total_channels, height, width = shuffled.size()

        kernel_area = self.up_kernel ** 2
        num_kernels = total_channels // kernel_area

        # Isolate the kernel axis so the softmax runs over it alone.
        per_kernel = shuffled.view(batch, num_kernels, kernel_area, height, width)
        normalized = F.softmax(per_kernel, dim=2)

        # Collapse back to the flat (kernel-major) channel layout.
        flat = normalized.view(batch, num_kernels * kernel_area, height, width)
        return flat.contiguous()

    def feature_reassemble(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """Apply the CARAFE operator to ``x`` with the normalized ``mask``."""
        return carafe_forward(
            x, mask, self.up_kernel, self.up_group, self.scale_factor
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Predict masks from ``x`` and return the upsampled feature map."""
        raw_mask = self.content_encoder(self.channel_compressor(x))
        normalized_mask = self.kernel_normalizer(raw_mask)
        return self.feature_reassemble(x, normalized_mask)
# === Example Usage ===
if __name__ == '__main__':
    # Use the GPU when one is present, otherwise fall back to the CPU so the
    # demo also runs on machines without CUDA.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Create dummy input: batch size 2, 64 channels, 32x32 spatial resolution.
    x = torch.randn(2, 64, 32, 32).to(device)
    # Define CARAFEPack with upsample ratio 2.
    # For example, use kernel size 5, group size 1.
    upsampler = CARAFEPack(channels=64, scale_factor=2, up_kernel=5, up_group=1).to(device)
    # Get upsampled feature map.
    out = upsampler(x)
    print("Input shape: ", x.shape)
    print("Output shape:", out.shape)  # Expected shape: (2, 64, 64, 64)