import numpy as np
import torch
from torch import nn
from model_utils.resnet import ResNet, BottleneckBlock
import torch.nn.functional as F

class DummyAggregationNetwork(nn.Module):  # for testing: returns the input scaled by a dummy parameter
    def __init__(self):
        super(DummyAggregationNetwork, self).__init__()
        # dummy parameter
        self.dummy = nn.Parameter(torch.ones([]))
    def forward(self, batch, pose=None):
        return batch * self.dummy

class AggregationNetwork(nn.Module):
    """
    Module for aggregating feature maps across time and space.
    Design inspired by the Feature Extractor from ODISE (Xu et al., CVPR 2023).
    https://github.com/NVlabs/ODISE/blob/5836c0adfcd8d7fd1f8016ff5604d4a31dd3b145/odise/modeling/backbone/feature_extractor.py
    """
    def __init__(
            self, 
            device, 
            feature_dims=[640,1280,1280,768],
            projection_dim=384,
            num_norm_groups=32,
            save_timestep=[1],
            kernel_size=[1, 3, 1],
            contrastive_temp=10,
            feat_map_dropout=0.0,
            num_blocks=None,
            bottleneck_channels=None
        ):
        super().__init__()
        self.skip_connection = True
        self.feat_map_dropout = feat_map_dropout
        self.azimuth_embedding = None
        self.pos_embedding = None
        self.bottleneck_layers = nn.ModuleList()
        self.feature_dims = feature_dims
        self.num_blocks = num_blocks if num_blocks is not None else 1
        self.bottleneck_channels = bottleneck_channels if bottleneck_channels is not None else projection_dim//4
        # For CLIP symmetric cross entropy loss during training
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
        self.self_logit_scale = nn.Parameter(torch.ones([]) * np.log(contrastive_temp))
        self.device = device
        self.save_timestep = save_timestep
        
        self.mixing_weights_names = []
        for l, feature_dim in enumerate(self.feature_dims):
            bottleneck_layer = nn.Sequential(
                *ResNet.make_stage(
                    BottleneckBlock,
                    num_blocks=self.num_blocks,
                    in_channels=feature_dim,
                    bottleneck_channels=self.bottleneck_channels,
                    out_channels=projection_dim,
                    norm="GN",
                    num_norm_groups=num_norm_groups,
                    kernel_size=kernel_size
                )
            )
            self.bottleneck_layers.append(bottleneck_layer)
            for t in save_timestep:
                # 1-index the layer name following prior work
                self.mixing_weights_names.append(f"timestep-{t}_layer-{l+1}")
        self.last_layer = None
        self.bottleneck_layers = self.bottleneck_layers.to(device)
        mixing_weights = torch.ones(len(self.bottleneck_layers) * len(save_timestep))
        self.mixing_weights = nn.Parameter(mixing_weights.to(device))
        # Count the number of parameters
        num_params = sum(param.numel() for param in self.parameters())
        print(f"AggregationNetwork has {num_params} parameters.")
    
    def load_pretrained_weights(self, pretrained_dict):
        custom_dict = self.state_dict()

        # Handle a size mismatch in the mixing weights (e.g. an extra layer
        # added on top of a four-layer pretrained model)
        if 'mixing_weights' in custom_dict and 'mixing_weights' in pretrained_dict and custom_dict['mixing_weights'].shape != pretrained_dict['mixing_weights'].shape:
            # Keep the first four weights from the pretrained model and zero-initialize the fifth weight
            custom_dict['mixing_weights'][:4] = pretrained_dict['mixing_weights'][:4]
            custom_dict['mixing_weights'][4] = torch.zeros_like(custom_dict['mixing_weights'][4])
        else:
            custom_dict['mixing_weights'] = pretrained_dict['mixing_weights']

        # Load the weights that do match
        matching_keys = {k: v for k, v in pretrained_dict.items() if k in custom_dict and k != 'mixing_weights'}
        custom_dict.update(matching_keys)

        # Now load the updated state_dict
        self.load_state_dict(custom_dict, strict=False)
        
    def forward(self, batch, pose=None):
        """
        Assumes batch is shape (B, C, H, W) where C is the concatenation of all layer features.
        """
        if self.feat_map_dropout > 0 and self.training:
            batch = F.dropout(batch, p=self.feat_map_dropout)
        
        output_feature = None
        start = 0
        mixing_weights = torch.nn.functional.softmax(self.mixing_weights, dim=0)
        if self.pos_embedding is not None:  # optional positional embedding
            batch = torch.cat((batch, self.pos_embedding), dim=1)
        for i in range(len(mixing_weights)):
            # Share bottleneck layers across timesteps
            bottleneck_layer = self.bottleneck_layers[i % len(self.feature_dims)]
            # Chunk the batch according to the layer
            # Account for looping if there are multiple timesteps
            end = start + self.feature_dims[i % len(self.feature_dims)]
            feats = batch[:, start:end, :, :]
            start = end
            # Downsample the number of channels and weight the layer
            bottlenecked_feature = bottleneck_layer(feats)
            bottlenecked_feature = mixing_weights[i] * bottlenecked_feature
            if output_feature is None:
                output_feature = bottlenecked_feature
            else:
                output_feature += bottlenecked_feature
        
        if self.last_layer is not None:
            output_feature_after = self.last_layer(output_feature)
            if self.skip_connection:
                # Skip connection around the last layer
                output_feature = output_feature + output_feature_after
        return output_feature
    
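# A minimal, hedged usage sketch (illustration only, not part of the original
# training code). It assumes the default hyperparameters above:
# feature_dims=[640, 1280, 1280, 768] and save_timestep=[1], so the
# concatenated input has sum(feature_dims) = 3968 channels. The batch size
# of 2 and the 16x16 spatial size are arbitrary choices, and the bottleneck
# blocks are assumed to preserve spatial resolution.
def _demo_aggregation_network():
    device = "cuda" if torch.cuda.is_available() else "cpu"
    net = AggregationNetwork(device=device)
    dummy_batch = torch.randn(2, 3968, 16, 16, device=device)
    output = net(dummy_batch)
    # Each channel chunk is bottlenecked to projection_dim=384 channels,
    # weighted by the softmaxed mixing weights, and summed.
    print(output.shape)  # expected: torch.Size([2, 384, 16, 16])
    # Loading a checkpoint would look like the following; the path is a
    # placeholder, not a file shipped with this repository.
    # pretrained_dict = torch.load("aggregation_network.pth", map_location=device)
    # net.load_pretrained_weights(pretrained_dict)
    return output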

def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution without padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, padding=0, bias=False)


def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)


class BasicBlock(nn.Module):
    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        self.conv1 = conv3x3(in_planes, planes, stride)
        self.conv2 = conv3x3(planes, planes)
        self.bn1 = nn.BatchNorm2d(planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)

        if stride == 1:
            self.downsample = None
        else:
            self.downsample = nn.Sequential(
                conv1x1(in_planes, planes, stride=stride),
                nn.BatchNorm2d(planes)
            )

    def forward(self, x):
        y = x
        y = self.relu(self.bn1(self.conv1(y)))
        y = self.bn2(self.conv2(y))

        if self.downsample is not None:
            x = self.downsample(x)

        return self.relu(x + y)
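

# Hedged usage sketch (illustration only): BasicBlock preserves the spatial
# size at stride=1 and halves it at stride=2, where the 1x1-conv downsample
# branch matches the residual's channel count and resolution.
def _demo_basic_block():
    block = BasicBlock(in_planes=64, planes=128, stride=2)
    x = torch.randn(1, 64, 32, 32)
    y = block(x)
    print(y.shape)  # expected: torch.Size([1, 128, 16, 16])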