ZIP / models /utils /carafe.py
Yiming-M's picture
2025-07-31 18:59 🐣
a7dedf9
import torch
import torch.nn as nn
import torch.nn.functional as F
def carafe_forward(
features: torch.Tensor,
masks: torch.Tensor,
kernel_size: int,
group_size: int,
scale_factor: int
) -> torch.Tensor:
"""
Pure-PyTorch implementation of the CARAFE upsampling operator.
Args:
features (Tensor): Input feature map of shape (N, C, H, W).
masks (Tensor): Reassembly kernel weights of shape
(N, kernel_size*kernel_size*group_size, H_out, W_out),
where H_out = H*scale_factor and W_out = W*scale_factor.
kernel_size (int): The spatial size of the reassembly kernel.
group_size (int): The group size to divide channels. Must divide C.
scale_factor (int): The upsampling factor.
Returns:
Tensor: Upsampled feature map of shape (N, C, H*scale_factor, W*scale_factor).
"""
N, C, H, W = features.size()
out_H, out_W = H * scale_factor, W * scale_factor
num_channels = C // group_size # channels per group
# Reshape features to (N, group_size, num_channels, H, W)
features = features.view(N, group_size, num_channels, H, W)
# Merge batch and group dims for unfolding
features_reshaped = features.view(N * group_size, num_channels, H, W)
# Extract local patches; use padding so that output spatial dims match input
patches = F.unfold(features_reshaped, kernel_size=kernel_size,
padding=(kernel_size - 1) // 2)
# patches shape: (N*group_size, num_channels*kernel_size*kernel_size, H*W)
# Reshape to (N, group_size, num_channels, kernel_size*kernel_size, H, W)
patches = patches.view(N, group_size, num_channels, kernel_size * kernel_size, H, W)
# Flatten spatial dimensions: now (N, group_size, num_channels, kernel_size*kernel_size, H*W)
patches = patches.view(N, group_size, num_channels, kernel_size * kernel_size, H * W)
# For each output pixel location, determine the corresponding base input index.
# For an output coordinate (oh, ow), the corresponding input index is:
# h = oh // scale_factor, w = ow // scale_factor, linear index = h * W + w.
device = features.device
# Create coordinate indices for output
h_idx = torch.div(torch.arange(out_H, device=device), scale_factor, rounding_mode='floor') # (out_H,)
w_idx = torch.div(torch.arange(out_W, device=device), scale_factor, rounding_mode='floor') # (out_W,)
# Form a 2D grid of base indices (shape: out_H x out_W)
h_idx = h_idx.unsqueeze(1).expand(out_H, out_W) # (out_H, out_W)
w_idx = w_idx.unsqueeze(0).expand(out_H, out_W) # (out_H, out_W)
base_idx = (h_idx * W + w_idx).view(-1) # (out_H*out_W,)
# Expand base_idx so that it can index the last dimension of patches:
# Desired shape for gathering: (N, group_size, num_channels, kernel_size*kernel_size, out_H*out_W)
base_idx = base_idx.view(1, 1, 1, 1, -1).expand(N, group_size, num_channels, kernel_size * kernel_size, -1)
# Gather patches corresponding to each output location
gathered_patches = torch.gather(patches, -1, base_idx)
# Reshape gathered patches to (N, group_size, num_channels, kernel_size*kernel_size, out_H, out_W)
gathered_patches = gathered_patches.view(N, group_size, num_channels, kernel_size * kernel_size, out_H, out_W)
# Reshape masks to separate groups.
# Expected mask shape: (N, kernel_size*kernel_size*group_size, out_H, out_W)
# Reshape to: (N, group_size, kernel_size*kernel_size, out_H, out_W)
masks = masks.view(N, group_size, kernel_size * kernel_size, out_H, out_W)
# For multiplication, add a channel dimension so that masks shape becomes
# (N, group_size, 1, kernel_size*kernel_size, out_H, out_W)
masks = masks.unsqueeze(2)
# Expand masks to match gathered_patches: (N, group_size, num_channels, kernel_size*kernel_size, out_H, out_W)
masks = masks.expand(-1, -1, num_channels, -1, -1, -1)
# Multiply patches with masks and sum over the kernel dimension.
# This yields the reassembled features for each output location.
out = (gathered_patches * masks).sum(dim=3) # shape: (N, group_size, num_channels, out_H, out_W)
# Reshape back to (N, C, out_H, out_W)
out = out.view(N, C, out_H, out_W)
return out
class CARAFE(nn.Module):
"""
CARAFE: Content-Aware ReAssembly of Features
This PyTorch module implements the CARAFE upsampling operator in pure Python.
Given an input feature map and its corresponding reassembly masks, the module
reassembles features from local patches to produce a higher-resolution output.
Args:
kernel_size (int): Reassembly kernel size.
group_size (int): Group size for channel grouping (must divide number of channels).
scale_factor (int): Upsample ratio.
"""
def __init__(self, kernel_size: int, group_size: int, scale_factor: int):
super(CARAFE, self).__init__()
self.kernel_size = kernel_size
self.group_size = group_size
self.scale_factor = scale_factor
def forward(self, features: torch.Tensor, masks: torch.Tensor) -> torch.Tensor:
return carafe_forward(features, masks, self.kernel_size, self.group_size, self.scale_factor)
class CARAFEPack(nn.Module):
"""
A unified package of the CARAFE upsampler that contains:
1) A channel compressor.
2) A content encoder that predicts reassembly masks.
3) The CARAFE operator.
This is modeled after the official CARAFE package.
Args:
channels (int): Number of input feature channels.
scale_factor (int): Upsample ratio.
up_kernel (int): Kernel size for the CARAFE operator.
up_group (int): Group size for the CARAFE operator.
encoder_kernel (int): Kernel size of the content encoder.
encoder_dilation (int): Dilation rate for the content encoder.
compressed_channels (int): Output channels for the channel compressor.
"""
def __init__(
self,
channels: int,
scale_factor: int,
up_kernel: int = 5,
up_group: int = 1,
encoder_kernel: int = 3,
encoder_dilation: int = 1,
compressed_channels: int = 64
):
super(CARAFEPack, self).__init__()
self.channels = channels
self.scale_factor = scale_factor
self.up_kernel = up_kernel
self.up_group = up_group
self.encoder_kernel = encoder_kernel
self.encoder_dilation = encoder_dilation
self.compressed_channels = compressed_channels
# Compress input channels.
self.channel_compressor = nn.Conv2d(channels, compressed_channels, kernel_size=1)
# Predict reassembly masks.
self.content_encoder = nn.Conv2d(
compressed_channels,
up_kernel * up_kernel * up_group * scale_factor * scale_factor,
kernel_size=encoder_kernel,
padding=int((encoder_kernel - 1) * encoder_dilation / 2),
dilation=encoder_dilation
)
# Initialize weights (using Xavier for conv layers).
nn.init.xavier_uniform_(self.channel_compressor.weight)
nn.init.xavier_uniform_(self.content_encoder.weight)
if self.channel_compressor.bias is not None:
nn.init.constant_(self.channel_compressor.bias, 0)
if self.content_encoder.bias is not None:
nn.init.constant_(self.content_encoder.bias, 0)
def kernel_normalizer(self, mask: torch.Tensor) -> torch.Tensor:
"""
Normalize and reshape the mask.
Applies pixel shuffle to upsample the predicted kernel weights and then
applies softmax normalization across the kernel dimension.
Args:
mask (Tensor): Predicted mask of shape (N, out_channels, H, W).
Returns:
Tensor: Normalized mask of shape (N, up_group * up_kernel^2, H*scale, W*scale).
"""
# Pixel shuffle to rearrange and upsample the mask.
mask = F.pixel_shuffle(mask, self.scale_factor)
N, mask_c, H, W = mask.size()
# Determine the number of channels per kernel
mask_channel = mask_c // (self.up_kernel ** 2)
mask = mask.view(N, mask_channel, self.up_kernel ** 2, H, W)
mask = F.softmax(mask, dim=2)
mask = mask.view(N, mask_channel * self.up_kernel ** 2, H, W).contiguous()
return mask
def feature_reassemble(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
return carafe_forward(x, mask, self.up_kernel, self.up_group, self.scale_factor)
def forward(self, x: torch.Tensor) -> torch.Tensor:
compressed_x = self.channel_compressor(x)
mask = self.content_encoder(compressed_x)
mask = self.kernel_normalizer(mask)
out = self.feature_reassemble(x, mask)
return out
# === Example Usage ===
if __name__ == '__main__':
# Create dummy input: batch size 2, 64 channels, 32x32 spatial resolution.
x = torch.randn(2, 64, 32, 32).cuda() # assuming GPU available
# Define CARAFEPack with upsample ratio 2.
# For example, use kernel size 5, group size 1.
upsampler = CARAFEPack(channels=64, scale_factor=2, up_kernel=5, up_group=1).cuda()
# Get upsampled feature map.
out = upsampler(x)
print("Input shape: ", x.shape)
print("Output shape:", out.shape) # Expected shape: (2, 64, 64, 64)