import torch
import torch.nn as nn
import torch.nn.functional as F

def carafe_forward(
    features: torch.Tensor,
    masks: torch.Tensor,
    kernel_size: int,
    group_size: int,
    scale_factor: int
) -> torch.Tensor:
    """
    Pure-PyTorch implementation of the CARAFE upsampling operator.
    
    Args:
        features (Tensor): Input feature map of shape (N, C, H, W).
        masks (Tensor): Reassembly kernel weights of shape
            (N, kernel_size*kernel_size*group_size, H_out, W_out),
            where H_out = H*scale_factor and W_out = W*scale_factor.
            The weights are assumed to be already normalized over the
            kernel dimension (e.g., via a softmax).
        kernel_size (int): The spatial size of the reassembly kernel.
        group_size (int): The group size to divide channels. Must divide C.
        scale_factor (int): The upsampling factor.
    
    Returns:
        Tensor: Upsampled feature map of shape (N, C, H*scale_factor, W*scale_factor).
    """
    N, C, H, W = features.size()
    assert C % group_size == 0, "group_size must evenly divide the number of channels"
    assert kernel_size % 2 == 1, "kernel_size must be odd so the unfold padding preserves spatial size"
    out_H, out_W = H * scale_factor, W * scale_factor
    num_channels = C // group_size  # channels per group

    # Reshape features to (N, group_size, num_channels, H, W)
    features = features.view(N, group_size, num_channels, H, W)
    # Merge batch and group dims for unfolding
    features_reshaped = features.view(N * group_size, num_channels, H, W)
    # Extract local patches; use padding so that output spatial dims match input
    patches = F.unfold(features_reshaped, kernel_size=kernel_size, 
                       padding=(kernel_size - 1) // 2)
    # patches shape: (N*group_size, num_channels*kernel_size*kernel_size, H*W)
    # Reshape to (N, group_size, num_channels, kernel_size*kernel_size, H, W)
    patches = patches.view(N, group_size, num_channels, kernel_size * kernel_size, H, W)
    # Flatten spatial dimensions: now (N, group_size, num_channels, kernel_size*kernel_size, H*W)
    patches = patches.view(N, group_size, num_channels, kernel_size * kernel_size, H * W)

    # For each output pixel location, determine the corresponding base input index.
    # For an output coordinate (oh, ow), the corresponding input index is:
    #   h = oh // scale_factor, w = ow // scale_factor, linear index = h * W + w.
    device = features.device
    # Create coordinate indices for output
    h_idx = torch.div(torch.arange(out_H, device=device), scale_factor, rounding_mode='floor')  # (out_H,)
    w_idx = torch.div(torch.arange(out_W, device=device), scale_factor, rounding_mode='floor')  # (out_W,)
    # Form a 2D grid of base indices (shape: out_H x out_W)
    h_idx = h_idx.unsqueeze(1).expand(out_H, out_W)  # (out_H, out_W)
    w_idx = w_idx.unsqueeze(0).expand(out_H, out_W)  # (out_H, out_W)
    base_idx = (h_idx * W + w_idx).view(-1)  # (out_H*out_W,)

    # Expand base_idx so that it can index the last dimension of patches:
    # Desired shape for gathering: (N, group_size, num_channels, kernel_size*kernel_size, out_H*out_W)
    base_idx = base_idx.view(1, 1, 1, 1, -1).expand(N, group_size, num_channels, kernel_size * kernel_size, -1)
    # Gather patches corresponding to each output location
    gathered_patches = torch.gather(patches, -1, base_idx)
    # Reshape gathered patches to (N, group_size, num_channels, kernel_size*kernel_size, out_H, out_W)
    gathered_patches = gathered_patches.view(N, group_size, num_channels, kernel_size * kernel_size, out_H, out_W)

    # Reshape masks to separate groups.
    # Expected mask shape: (N, kernel_size*kernel_size*group_size, out_H, out_W)
    # Reshape to: (N, group_size, kernel_size*kernel_size, out_H, out_W)
    masks = masks.view(N, group_size, kernel_size * kernel_size, out_H, out_W)
    # For multiplication, add a channel dimension so that masks shape becomes
    # (N, group_size, 1, kernel_size*kernel_size, out_H, out_W)
    masks = masks.unsqueeze(2)
    # Expand masks to match gathered_patches: (N, group_size, num_channels, kernel_size*kernel_size, out_H, out_W)
    masks = masks.expand(-1, -1, num_channels, -1, -1, -1)

    # Multiply patches with masks and sum over the kernel dimension.
    # This yields the reassembled features for each output location.
    out = (gathered_patches * masks).sum(dim=3)  # shape: (N, group_size, num_channels, out_H, out_W)
    # Reshape back to (N, C, out_H, out_W)
    out = out.view(N, C, out_H, out_W)
    return out
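
# Minimal sanity check (our addition, not part of the CARAFE API): with a uniform
# reassembly mask, where every kernel weight equals 1/kernel_size**2, CARAFE
# degenerates to a zero-padded box blur followed by nearest-neighbor upsampling,
# so the result can be reproduced with a depthwise conv2d and F.interpolate.
def _check_uniform_mask_equivalence() -> None:
    torch.manual_seed(0)
    N, C, H, W, k, s = 2, 8, 6, 6, 3, 2
    x = torch.randn(N, C, H, W)
    # Uniform mask of shape (N, k*k*group_size, H*s, W*s) with group_size = 1.
    masks = torch.full((N, k * k, H * s, W * s), 1.0 / (k * k))
    out = carafe_forward(x, masks, kernel_size=k, group_size=1, scale_factor=s)
    # Reference: per-channel box filter with zero padding, then nearest upsampling.
    box = torch.full((C, 1, k, k), 1.0 / (k * k))
    ref = F.conv2d(x, box, padding=(k - 1) // 2, groups=C)
    ref = F.interpolate(ref, scale_factor=s, mode='nearest')
    assert torch.allclose(out, ref, atol=1e-5), \
        "uniform-mask CARAFE should equal box blur + nearest upsampling"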


class CARAFE(nn.Module):
    """
    CARAFE: Content-Aware ReAssembly of Features

    This PyTorch module implements the CARAFE upsampling operator in pure Python.
    Given an input feature map and its corresponding reassembly masks, the module
    reassembles features from local patches to produce a higher-resolution output.

    Args:
        kernel_size (int): Reassembly kernel size.
        group_size (int): Group size for channel grouping (must divide number of channels).
        scale_factor (int): Upsample ratio.
    """
    def __init__(self, kernel_size: int, group_size: int, scale_factor: int):
        super(CARAFE, self).__init__()
        self.kernel_size = kernel_size
        self.group_size = group_size
        self.scale_factor = scale_factor

    def forward(self, features: torch.Tensor, masks: torch.Tensor) -> torch.Tensor:
        return carafe_forward(features, masks, self.kernel_size, self.group_size, self.scale_factor)
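

# Minimal usage sketch for the bare CARAFE module (the helper name and shapes
# below are illustrative, not part of the original API): the caller must supply
# masks predicted at the output resolution and already normalized over the
# kernel dimension, e.g. with a softmax.
def _example_bare_carafe() -> torch.Tensor:
    up = CARAFE(kernel_size=5, group_size=1, scale_factor=2)
    feats = torch.randn(1, 16, 8, 8)
    # 5*5*1 = 25 mask channels, predicted at the 2x output resolution.
    raw = torch.randn(1, 25, 16, 16)
    masks = F.softmax(raw.view(1, 1, 25, 16, 16), dim=2).view(1, 25, 16, 16)
    out = up(feats, masks)  # -> (1, 16, 16, 16)
    return out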


class CARAFEPack(nn.Module):
    """
    A unified package of the CARAFE upsampler that contains:
    1) A channel compressor.
    2) A content encoder that predicts reassembly masks.
    3) The CARAFE operator.

    This is modeled after the official CARAFE package.

    Args:
        channels (int): Number of input feature channels.
        scale_factor (int): Upsample ratio.
        up_kernel (int): Kernel size for the CARAFE operator.
        up_group (int): Group size for the CARAFE operator.
        encoder_kernel (int): Kernel size of the content encoder.
        encoder_dilation (int): Dilation rate for the content encoder.
        compressed_channels (int): Output channels for the channel compressor.
    """
    def __init__(
        self,
        channels: int,
        scale_factor: int,
        up_kernel: int = 5,
        up_group: int = 1,
        encoder_kernel: int = 3,
        encoder_dilation: int = 1,
        compressed_channels: int = 64
    ):
        super(CARAFEPack, self).__init__()
        self.channels = channels
        self.scale_factor = scale_factor
        self.up_kernel = up_kernel
        self.up_group = up_group
        self.encoder_kernel = encoder_kernel
        self.encoder_dilation = encoder_dilation
        self.compressed_channels = compressed_channels

        # Compress input channels.
        self.channel_compressor = nn.Conv2d(channels, compressed_channels, kernel_size=1)
        # Predict reassembly masks.
        self.content_encoder = nn.Conv2d(
            compressed_channels,
            up_kernel * up_kernel * up_group * scale_factor * scale_factor,
            kernel_size=encoder_kernel,
            padding=int((encoder_kernel - 1) * encoder_dilation / 2),
            dilation=encoder_dilation
        )
        # Initialize weights (using Xavier for conv layers).
        nn.init.xavier_uniform_(self.channel_compressor.weight)
        nn.init.xavier_uniform_(self.content_encoder.weight)
        if self.channel_compressor.bias is not None:
            nn.init.constant_(self.channel_compressor.bias, 0)
        if self.content_encoder.bias is not None:
            nn.init.constant_(self.content_encoder.bias, 0)

    def kernel_normalizer(self, mask: torch.Tensor) -> torch.Tensor:
        """
        Normalize and reshape the mask.
        Applies pixel shuffle to upsample the predicted kernel weights and then
        applies softmax normalization across the kernel dimension.
        
        Args:
            mask (Tensor): Predicted mask of shape (N, out_channels, H, W).
            
        Returns:
            Tensor: Normalized mask of shape (N, up_group * up_kernel^2, H*scale, W*scale).
        """
        # Pixel shuffle to rearrange and upsample the mask.
        mask = F.pixel_shuffle(mask, self.scale_factor)
        N, mask_c, H, W = mask.size()
        # Determine the number of channels per kernel
        mask_channel = mask_c // (self.up_kernel ** 2)
        mask = mask.view(N, mask_channel, self.up_kernel ** 2, H, W)
        mask = F.softmax(mask, dim=2)
        mask = mask.view(N, mask_channel * self.up_kernel ** 2, H, W).contiguous()
        return mask

    def feature_reassemble(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        return carafe_forward(x, mask, self.up_kernel, self.up_group, self.scale_factor)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        compressed_x = self.channel_compressor(x)
        mask = self.content_encoder(compressed_x)
        mask = self.kernel_normalizer(mask)
        out = self.feature_reassemble(x, mask)
        return out


# === Example Usage ===
if __name__ == '__main__':
    # Run on GPU when available, otherwise fall back to CPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Create a dummy input: batch size 2, 64 channels, 32x32 spatial resolution.
    x = torch.randn(2, 64, 32, 32, device=device)
    # Define CARAFEPack with upsample ratio 2, kernel size 5 and group size 1.
    upsampler = CARAFEPack(channels=64, scale_factor=2, up_kernel=5, up_group=1).to(device)
    # Get upsampled feature map.
    out = upsampler(x)
    print("Input shape: ", x.shape)
    print("Output shape:", out.shape)  # Expected shape: (2, 64, 64, 64)