import numpy as np
import torch
from torch import nn
from model_utils.resnet import ResNet, BottleneckBlock
import torch.nn.functional as F


class DummyAggregationNetwork(nn.Module):  # for testing, return the input
    def __init__(self):
        super(DummyAggregationNetwork, self).__init__()
        # dummy parameter
        self.dummy = nn.Parameter(torch.ones([]))

    def forward(self, batch, pose=None):
        return batch * self.dummy


class AggregationNetwork(nn.Module):
    """
    Module for aggregating feature maps across time and space.
    Design inspired by the Feature Extractor from ODISE (Xu et al., CVPR 2023).
    https://github.com/NVlabs/ODISE/blob/5836c0adfcd8d7fd1f8016ff5604d4a31dd3b145/odise/modeling/backbone/feature_extractor.py
    """
    def __init__(
            self,
            device,
            feature_dims=[640, 1280, 1280, 768],
            projection_dim=384,
            num_norm_groups=32,
            save_timestep=[1],
            kernel_size=[1, 3, 1],
            contrastive_temp=10,
            feat_map_dropout=0.0,
            num_blocks=None,
            bottleneck_channels=None
    ):
        super().__init__()
        self.skip_connection = True
        self.feat_map_dropout = feat_map_dropout
        self.azimuth_embedding = None
        self.pos_embedding = None
        self.bottleneck_layers = nn.ModuleList()
        self.feature_dims = feature_dims
        self.num_blocks = num_blocks if num_blocks is not None else 1
        self.bottleneck_channels = bottleneck_channels if bottleneck_channels is not None else projection_dim // 4
        # For CLIP symmetric cross entropy loss during training
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
        self.self_logit_scale = nn.Parameter(torch.ones([]) * np.log(contrastive_temp))
        self.device = device
        self.save_timestep = save_timestep

        self.mixing_weights_names = []
        for l, feature_dim in enumerate(self.feature_dims):
            bottleneck_layer = nn.Sequential(
                *ResNet.make_stage(
                    BottleneckBlock,
                    num_blocks=self.num_blocks,
                    in_channels=feature_dim,
                    bottleneck_channels=self.bottleneck_channels,
                    out_channels=projection_dim,
                    norm="GN",
                    num_norm_groups=num_norm_groups,
                    kernel_size=kernel_size
                )
            )
            self.bottleneck_layers.append(bottleneck_layer)
            for t in save_timestep:
                # 1-index the layer name following prior work
                self.mixing_weights_names.append(f"timestep-{t}_layer-{l+1}")
        self.last_layer = None
        self.bottleneck_layers = self.bottleneck_layers.to(device)
        mixing_weights = torch.ones(len(self.bottleneck_layers) * len(save_timestep))
        self.mixing_weights = nn.Parameter(mixing_weights.to(device))
        # count number of parameters
        num_params = 0
        for param in self.parameters():
            num_params += param.numel()
        print(f"AggregationNetwork has {num_params} parameters.")

    def load_pretrained_weights(self, pretrained_dict):
        custom_dict = self.state_dict()

        # Handle size mismatch
        if 'mixing_weights' in custom_dict and 'mixing_weights' in pretrained_dict and custom_dict['mixing_weights'].shape != pretrained_dict['mixing_weights'].shape:
            # Keep the first four weights from the pretrained model, and zero-initialize the fifth weight
            custom_dict['mixing_weights'][:4] = pretrained_dict['mixing_weights'][:4]
            custom_dict['mixing_weights'][4] = torch.zeros_like(custom_dict['mixing_weights'][4])
        else:
            custom_dict['mixing_weights'][:4] = pretrained_dict['mixing_weights'][:4]

        # Load the weights that do match
        matching_keys = {k: v for k, v in pretrained_dict.items() if k in custom_dict and k != 'mixing_weights'}
        custom_dict.update(matching_keys)

        # Now load the updated state_dict
        self.load_state_dict(custom_dict, strict=False)

    def forward(self, batch, pose=None):
        """
        Assumes batch is shape (B, C, H, W) where C is the concatenation of all layer features.
""" if self.feat_map_dropout > 0 and self.training: batch = F.dropout(batch, p=self.feat_map_dropout) output_feature = None start = 0 mixing_weights = torch.nn.functional.softmax(self.mixing_weights, dim=0) if self.pos_embedding is not None: #position embedding batch = torch.cat((batch, self.pos_embedding), dim=1) for i in range(len(mixing_weights)): # Share bottleneck layers across timesteps bottleneck_layer = self.bottleneck_layers[i % len(self.feature_dims)] # Chunk the batch according the layer # Account for looping if there are multiple timesteps end = start + self.feature_dims[i % len(self.feature_dims)] feats = batch[:, start:end, :, :] start = end # Downsample the number of channels and weight the layer bottlenecked_feature = bottleneck_layer(feats) bottlenecked_feature = mixing_weights[i] * bottlenecked_feature if output_feature is None: output_feature = bottlenecked_feature else: output_feature += bottlenecked_feature if self.last_layer is not None: output_feature_after = self.last_layer(output_feature) if self.skip_connection: # skip connection output_feature = output_feature + output_feature_after return output_feature def conv1x1(in_planes, out_planes, stride=1): """1x1 convolution without padding""" return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, padding=0, bias=False) def conv3x3(in_planes, out_planes, stride=1): """3x3 convolution with padding""" return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) class BasicBlock(nn.Module): def __init__(self, in_planes, planes, stride=1): super().__init__() self.conv1 = conv3x3(in_planes, planes, stride) self.conv2 = conv3x3(planes, planes) self.bn1 = nn.BatchNorm2d(planes) self.bn2 = nn.BatchNorm2d(planes) self.relu = nn.ReLU(inplace=True) if stride == 1: self.downsample = None else: self.downsample = nn.Sequential( conv1x1(in_planes, planes, stride=stride), nn.BatchNorm2d(planes) ) def forward(self, x): y = x y = self.relu(self.bn1(self.conv1(y))) y = self.bn2(self.conv2(y)) if self.downsample is not None: x = self.downsample(x) return self.relu(x+y)