from typing import List, Iterable
import torch
import torch.nn as nn
import torch.nn.functional as F

from matanyone.model.group_modules import *
class UpsampleBlock(nn.Module):
def __init__(self, in_dim: int, out_dim: int, scale_factor: int = 2):
super().__init__()
self.out_conv = ResBlock(in_dim, out_dim)
self.scale_factor = scale_factor
def forward(self, in_g: torch.Tensor, skip_f: torch.Tensor) -> torch.Tensor:
g = F.interpolate(in_g,
scale_factor=self.scale_factor,
mode='bilinear')
g = self.out_conv(g)
g = g + skip_f
return g
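
# Shape sketch (illustrative sizes, not values from this repo's configs): with
# in_dim=256, out_dim=128, scale_factor=2, an input in_g of (B, 256, H, W) is
# interpolated to (B, 256, 2H, 2W), projected to (B, 128, 2H, 2W) by the ResBlock,
# so skip_f must already be (B, 128, 2H, 2W) for the final addition.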
class MaskUpsampleBlock(nn.Module):
def __init__(self, in_dim: int, out_dim: int, scale_factor: int = 2):
super().__init__()
self.distributor = MainToGroupDistributor(method='add')
self.out_conv = GroupResBlock(in_dim, out_dim)
self.scale_factor = scale_factor
def forward(self, in_g: torch.Tensor, skip_f: torch.Tensor) -> torch.Tensor:
g = upsample_groups(in_g, ratio=self.scale_factor)
g = self.distributor(skip_f, g)
g = self.out_conv(g)
return g
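
# Group-aware variant of UpsampleBlock: tensors here carry an object axis,
# (batch, num_objects, C, H, W). upsample_groups and MainToGroupDistributor are
# imported from matanyone.model.group_modules; with method='add' the distributor
# adds the skip feature onto every object group before the GroupResBlock.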
class DecoderFeatureProcessor(nn.Module):
def __init__(self, decoder_dims: List[int], out_dims: List[int]):
super().__init__()
self.transforms = nn.ModuleList([
nn.Conv2d(d_dim, p_dim, kernel_size=1) for d_dim, p_dim in zip(decoder_dims, out_dims)
])
def forward(self, multi_scale_features: Iterable[torch.Tensor]) -> List[torch.Tensor]:
outputs = [func(x) for x, func in zip(multi_scale_features, self.transforms)]
return outputs
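
# One 1x1 projection per scale; note that zip() truncates silently, so the caller
# is expected to pass exactly len(decoder_dims) feature maps in matching order.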
# @torch.jit.script
def _recurrent_update(h: torch.Tensor, values: torch.Tensor) -> torch.Tensor:
# h: batch_size * num_objects * hidden_dim * h * w
# values: batch_size * num_objects * (hidden_dim*3) * h * w
dim = values.shape[2] // 3
forget_gate = torch.sigmoid(values[:, :, :dim])
update_gate = torch.sigmoid(values[:, :, dim:dim * 2])
new_value = torch.tanh(values[:, :, dim * 2:])
new_h = forget_gate * h * (1 - update_gate) + update_gate * new_value
return new_h
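
# Written out, this is a GRU-style update with an extra forget gate:
#   f = sigmoid(W_f x),  u = sigmoid(W_u x),  c = tanh(W_c x)
#   h' = f * (1 - u) * h + u * c
# With f fixed at 1 it reduces to the standard GRU interpolation between h and c.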
class SensoryUpdater_fullscale(nn.Module):
# Used in the decoder, multi-scale feature + GRU
def __init__(self, g_dims: List[int], mid_dim: int, sensory_dim: int):
super().__init__()
self.g16_conv = GConv2d(g_dims[0], mid_dim, kernel_size=1)
self.g8_conv = GConv2d(g_dims[1], mid_dim, kernel_size=1)
self.g4_conv = GConv2d(g_dims[2], mid_dim, kernel_size=1)
self.g2_conv = GConv2d(g_dims[3], mid_dim, kernel_size=1)
self.g1_conv = GConv2d(g_dims[4], mid_dim, kernel_size=1)
self.transform = GConv2d(mid_dim + sensory_dim, sensory_dim * 3, kernel_size=3, padding=1)
nn.init.xavier_normal_(self.transform.weight)
def forward(self, g: torch.Tensor, h: torch.Tensor) -> torch.Tensor:
g = self.g16_conv(g[0]) + self.g8_conv(downsample_groups(g[1], ratio=1/2)) + \
self.g4_conv(downsample_groups(g[2], ratio=1/4)) + \
self.g2_conv(downsample_groups(g[3], ratio=1/8)) + \
self.g1_conv(downsample_groups(g[4], ratio=1/16))
with torch.cuda.amp.autocast(enabled=False):
g = g.float()
h = h.float()
values = self.transform(torch.cat([g, h], dim=2))
new_h = _recurrent_update(h, values)
return new_h
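
# As the conv names suggest, the five inputs are features at strides 16/8/4/2/1;
# everything is pooled down to the stride-16 grid of g[0] before fusion, and the
# recurrent update runs in fp32 (autocast disabled) so the sigmoid/tanh gates are
# not evaluated in half precision.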
class SensoryUpdater(nn.Module):
# Used in the decoder, multi-scale feature + GRU
def __init__(self, g_dims: List[int], mid_dim: int, sensory_dim: int):
super().__init__()
self.g16_conv = GConv2d(g_dims[0], mid_dim, kernel_size=1)
self.g8_conv = GConv2d(g_dims[1], mid_dim, kernel_size=1)
self.g4_conv = GConv2d(g_dims[2], mid_dim, kernel_size=1)
self.transform = GConv2d(mid_dim + sensory_dim, sensory_dim * 3, kernel_size=3, padding=1)
nn.init.xavier_normal_(self.transform.weight)
def forward(self, g: torch.Tensor, h: torch.Tensor) -> torch.Tensor:
g = self.g16_conv(g[0]) + self.g8_conv(downsample_groups(g[1], ratio=1/2)) + \
self.g4_conv(downsample_groups(g[2], ratio=1/4))
with torch.cuda.amp.autocast(enabled=False):
g = g.float()
h = h.float()
values = self.transform(torch.cat([g, h], dim=2))
new_h = _recurrent_update(h, values)
return new_h
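
# Three-scale counterpart of SensoryUpdater_fullscale (strides 16/8/4 only).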
class SensoryDeepUpdater(nn.Module):
def __init__(self, f_dim: int, sensory_dim: int):
super().__init__()
self.transform = GConv2d(f_dim + sensory_dim, sensory_dim * 3, kernel_size=3, padding=1)
nn.init.xavier_normal_(self.transform.weight)
def forward(self, g: torch.Tensor, h: torch.Tensor) -> torch.Tensor:
with torch.cuda.amp.autocast(enabled=False):
g = g.float()
h = h.float()
values = self.transform(torch.cat([g, h], dim=2))
new_h = _recurrent_update(h, values)
return new_h
class ResBlock(nn.Module):
def __init__(self, in_dim: int, out_dim: int):
super().__init__()
if in_dim == out_dim:
self.downsample = nn.Identity()
else:
self.downsample = nn.Conv2d(in_dim, out_dim, kernel_size=1)
self.conv1 = nn.Conv2d(in_dim, out_dim, kernel_size=3, padding=1)
self.conv2 = nn.Conv2d(out_dim, out_dim, kernel_size=3, padding=1)
def forward(self, g: torch.Tensor) -> torch.Tensor:
out_g = self.conv1(F.relu(g))
out_g = self.conv2(F.relu(out_g))
g = self.downsample(g)
return out_g + g
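
# Pre-activation residual block: ReLU comes before each conv, and the 1x1
# "downsample" is purely a channel projection (spatial size is unchanged),
# applied only when in_dim != out_dim.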
class PPM(nn.Module):
    # Pyramid Pooling Module (PSPNet-style): pool at several bin sizes, project
    # each pooled map with a 1x1 conv, upsample back, and fuse with the input.
    def __init__(self, in_dim, reduction_dim, bins):
        super().__init__()
self.features = []
        for bin_size in bins:
            self.features.append(nn.Sequential(
                nn.AdaptiveAvgPool2d(bin_size),
                nn.Conv2d(in_dim, reduction_dim, kernel_size=1, bias=False),
                nn.PReLU()
            ))
self.features = nn.ModuleList(self.features)
self.fuse = nn.Sequential(
            nn.Conv2d(in_dim + reduction_dim * len(bins), in_dim, kernel_size=3, padding=1, bias=False),
nn.PReLU())
def forward(self, x):
x_size = x.size()
out = [x]
for f in self.features:
out.append(F.interpolate(f(x), x_size[2:], mode='bilinear', align_corners=True))
out_feat = self.fuse(torch.cat(out, 1))
        return out_feat
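
# Minimal smoke test for the pieces above that do not depend on group_modules.
# The sizes below are illustrative assumptions, not values from the original
# training configuration.
if __name__ == "__main__":
    torch.manual_seed(0)

    # ResBlock with a channel projection (in_dim != out_dim)
    block = ResBlock(64, 32)
    x = torch.randn(2, 64, 16, 16)
    assert block(x).shape == (2, 32, 16, 16)

    # PPM with the usual four pyramid bins; output keeps the input channel count
    ppm = PPM(in_dim=64, reduction_dim=16, bins=(1, 2, 3, 6))
    assert ppm(x).shape == (2, 64, 16, 16)

    # _recurrent_update: values packs (forget, update, candidate) along dim 2
    h = torch.randn(2, 3, 8, 16, 16)         # batch, objects, hidden_dim, H, W
    values = torch.randn(2, 3, 24, 16, 16)   # hidden_dim * 3
    assert _recurrent_update(h, values).shape == h.shape
    print("smoke test passed")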