Spaces:
Running
on
Zero
Running
on
Zero
File size: 7,031 Bytes
79cc514 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 |
import numpy as np
import torch
from torch import nn
from model_utils.resnet import ResNet, BottleneckBlock
import torch.nn.functional as F
class DummyAggregationNetwork(nn.Module):
    """Pass-through stand-in for ``AggregationNetwork``, intended for testing.

    ``forward`` multiplies its input by a single learnable scalar that is
    initialised to 1.0, so the input comes back unchanged until that scalar
    is updated. The ``pose`` argument is accepted and ignored.
    """

    def __init__(self):
        super().__init__()
        # Single 0-d scalar parameter initialised to 1.0 (presumably so the
        # module exposes at least one trainable parameter).
        self.dummy = nn.Parameter(torch.tensor(1.0))

    def forward(self, batch, pose=None):
        scale = self.dummy
        return batch * scale
class AggregationNetwork(nn.Module):
    """
    Module for aggregating feature maps across time and space.

    Design inspired by the Feature Extractor from ODISE (Xu et al., CVPR 2023).
    https://github.com/NVlabs/ODISE/blob/5836c0adfcd8d7fd1f8016ff5604d4a31dd3b145/odise/modeling/backbone/feature_extractor.py

    One bottleneck stage (built with ``ResNet.make_stage`` and
    ``out_channels=projection_dim``) is created per entry of ``feature_dims``.
    That stage is shared by every timestep in ``save_timestep``. ``forward``
    splits the input's channel axis into consecutive chunks of those sizes and
    passes each chunk through its stage. The results are summed using
    softmax-normalised learnable ``mixing_weights``, one weight per
    (timestep, layer) pair.
    """

    def __init__(
        self,
        device,
        feature_dims=[640, 1280, 1280, 768],
        projection_dim=384,
        num_norm_groups=32,
        save_timestep=[1],
        kernel_size=[1, 3, 1],
        contrastive_temp=10,
        feat_map_dropout=0.0,
        num_blocks=None,
        bottleneck_channels=None,
    ):
        """
        Args:
            device: device the bottleneck layers and mixing weights are moved to.
            feature_dims: channel count of each input feature layer, in the order
                the layers appear along the input's channel axis.
            projection_dim: output channels of every bottleneck stage.
            num_norm_groups: group count for the "GN" norm in the bottlenecks.
            save_timestep: timesteps whose features are concatenated in the input;
                one mixing weight is created per (timestep, layer) pair.
            kernel_size: forwarded unchanged to ``ResNet.make_stage``.
            contrastive_temp: initial value (before log) of ``self_logit_scale``.
            feat_map_dropout: dropout probability applied to the input in training.
            num_blocks: blocks per bottleneck stage (default 1).
            bottleneck_channels: bottleneck width (default ``projection_dim // 4``).
        """
        super().__init__()
        # Copy list arguments so instance state never aliases the shared
        # mutable default lists in the signature.
        feature_dims = list(feature_dims)
        save_timestep = list(save_timestep)

        # When last_layer is set, forward() adds its output to the mixed
        # features if True, or uses it instead of them if False.
        self.skip_connection = True
        self.feat_map_dropout = feat_map_dropout
        # Not read within this class; presumably assigned/used externally.
        self.azimuth_embedding = None
        # Optional tensor concatenated onto the input channels in forward().
        self.pos_embedding = None
        self.bottleneck_layers = nn.ModuleList()
        self.feature_dims = feature_dims
        self.num_blocks = num_blocks if num_blocks is not None else 1
        self.bottleneck_channels = (
            bottleneck_channels if bottleneck_channels is not None else projection_dim // 4
        )
        # For CLIP symmetric cross entropy loss during training
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
        self.self_logit_scale = nn.Parameter(torch.ones([]) * np.log(contrastive_temp))
        self.device = device
        self.save_timestep = save_timestep

        for feature_dim in feature_dims:
            bottleneck_layer = nn.Sequential(
                *ResNet.make_stage(
                    BottleneckBlock,
                    num_blocks=self.num_blocks,
                    in_channels=feature_dim,
                    bottleneck_channels=self.bottleneck_channels,
                    out_channels=projection_dim,
                    norm="GN",
                    num_norm_groups=num_norm_groups,
                    kernel_size=kernel_size,
                )
            )
            self.bottleneck_layers.append(bottleneck_layer)

        # Labels listed in the same order forward() consumes mixing_weights.
        self.mixing_weights_names = self._mixing_weight_names(save_timestep, len(feature_dims))
        # Optional module applied to the mixed features in forward().
        self.last_layer = None
        self.bottleneck_layers = self.bottleneck_layers.to(device)
        mixing_weights = torch.ones(len(self.bottleneck_layers) * len(save_timestep))
        self.mixing_weights = nn.Parameter(mixing_weights.to(device))

        num_params = sum(param.numel() for param in self.parameters())
        print(f"AggregationNetwork has {num_params} parameters.")

    @staticmethod
    def _mixing_weight_names(save_timestep, num_layers):
        """Return one label per mixing weight, in the order forward() uses them.

        forward() pairs weight ``i`` with layer ``i % num_layers``, walking the
        input's channel chunks timestep by timestep. The labels are therefore
        timestep-major. Layers are 1-indexed, following prior work.
        """
        return [
            f"timestep-{t}_layer-{layer + 1}"
            for t in save_timestep
            for layer in range(num_layers)
        ]

    def load_pretrained_weights(self, pretrained_dict):
        """Load ``pretrained_dict`` into this module, tolerating a mixing-weight size change.

        For ``mixing_weights``, the overlapping prefix is copied from the
        checkpoint and any extra trailing entries of this model are set to zero.
        For example, a 4-weight checkpoint loaded into a 5-weight model keeps the
        first four weights and zeros the fifth. If the checkpoint has no
        ``mixing_weights`` entry, the current values are kept.

        Every other key that also exists in this module's state dict is copied
        as-is. A shape mismatch on such a key still raises from
        ``load_state_dict``.

        Args:
            pretrained_dict: mapping of state-dict keys to tensors.
        """
        custom_dict = self.state_dict()
        if "mixing_weights" in custom_dict and "mixing_weights" in pretrained_dict:
            current = custom_dict["mixing_weights"]
            pretrained = pretrained_dict["mixing_weights"].to(
                device=current.device, dtype=current.dtype
            )
            merged = torch.zeros_like(current)
            overlap = min(current.shape[0], pretrained.shape[0])
            merged[:overlap] = pretrained[:overlap]
            custom_dict["mixing_weights"] = merged
        # Load the weights whose keys match (mixing_weights handled above).
        matching_keys = {
            k: v
            for k, v in pretrained_dict.items()
            if k in custom_dict and k != "mixing_weights"
        }
        custom_dict.update(matching_keys)
        # custom_dict holds every key of this module, so entries absent from
        # pretrained_dict are reloaded with their current values.
        self.load_state_dict(custom_dict, strict=False)

    def forward(self, batch, pose=None):
        """
        Mix the per-layer (and per-timestep) features into one feature map.

        Assumes batch is shape (B, C, H, W) where C is the concatenation of all
        layer features, repeated for each saved timestep.

        Args:
            batch: input tensor as described above.
            pose: unused; accepted for interface compatibility.

        Returns:
            The bottlenecked chunks summed with softmax-normalised mixing weights.
            If ``last_layer`` is set, its output is added to that sum when
            ``skip_connection`` is True and replaces it otherwise.
        """
        if self.feat_map_dropout > 0 and self.training:
            batch = F.dropout(batch, p=self.feat_map_dropout)
        mixing_weights = F.softmax(self.mixing_weights, dim=0)
        if self.pos_embedding is not None:  # position embedding
            # NOTE(review): the loop below only reads len(mixing_weights) chunks;
            # these appended channels are used only if feature_dims covers them.
            batch = torch.cat((batch, self.pos_embedding), dim=1)

        num_layers = len(self.feature_dims)
        output_feature = None
        start = 0
        for i, weight in enumerate(mixing_weights):
            # Bottleneck layers are shared across timesteps.
            layer_idx = i % num_layers
            end = start + self.feature_dims[layer_idx]
            feats = batch[:, start:end, :, :]
            start = end
            # Downsample the number of channels and weight the layer.
            bottlenecked_feature = weight * self.bottleneck_layers[layer_idx](feats)
            if output_feature is None:
                output_feature = bottlenecked_feature
            else:
                output_feature += bottlenecked_feature

        if self.last_layer is not None:
            output_feature_after = self.last_layer(output_feature)
            if self.skip_connection:
                output_feature = output_feature + output_feature_after
            else:
                output_feature = output_feature_after
        return output_feature
def conv1x1(in_planes, out_planes, stride=1):
    """Build a bias-free 1x1 ``nn.Conv2d`` with no padding.

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels.
        stride: convolution stride (default 1).
    """
    layer = nn.Conv2d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=1,
        stride=stride,
        padding=0,
        bias=False,
    )
    return layer
def conv3x3(in_planes, out_planes, stride=1):
    """Build a bias-free 3x3 ``nn.Conv2d`` with padding 1.

    With stride 1 the padding keeps the spatial size unchanged.

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels.
        stride: convolution stride (default 1).
    """
    layer = nn.Conv2d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )
    return layer
class BasicBlock(nn.Module):
    """Two-convolution residual block with batch normalisation.

    The main path is ``conv3x3 -> BN -> ReLU -> conv3x3 -> BN``. It is added
    to a shortcut and passed through a final ReLU. The shortcut is the
    identity when the input already has the output shape (stride 1 and equal
    channel counts). Otherwise it is a ``conv1x1 + BN`` projection with the
    same stride.

    Args:
        in_planes: number of input channels.
        planes: number of output channels.
        stride: stride of the first 3x3 convolution (and of the projection
            shortcut, when one is used).
    """

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        self.conv1 = conv3x3(in_planes, planes, stride)
        self.conv2 = conv3x3(planes, planes)
        self.bn1 = nn.BatchNorm2d(planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        # An identity shortcut only works when neither the spatial size
        # (stride) nor the channel count changes. Otherwise project x so the
        # residual addition in forward() has matching shapes.
        if stride == 1 and in_planes == planes:
            self.downsample = None
        else:
            self.downsample = nn.Sequential(
                conv1x1(in_planes, planes, stride=stride),
                nn.BatchNorm2d(planes),
            )

    def forward(self, x):
        y = self.relu(self.bn1(self.conv1(x)))
        y = self.bn2(self.conv2(y))
        if self.downsample is not None:
            x = self.downsample(x)
        return self.relu(x + y)
|