# scalelsd/ssl/backbones/multi_task_head.py
import torch
import torch.nn as nn
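

# `CosineSineLayer` is used by AngleDistanceHead below but is neither defined
# nor imported in this file. The stub here is a minimal sketch of a plausible
# implementation, assuming the layer maps its input to a 2-channel unit-norm
# (cos, sin) field; the actual ScaleLSD implementation may differ.
class CosineSineLayer(nn.Module):
    def __init__(self, input_channels):
        super().__init__()
        # Predict a 2-channel (cos, sin) map from the input features.
        self.conv = nn.Conv2d(input_channels, 2, kernel_size=1)

    def forward(self, x):
        out = self.conv(x)
        # Normalize so that each (cos, sin) pair lies on the unit circle.
        return out / out.norm(dim=1, keepdim=True).clamp(min=1e-6)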


class MultitaskHead(nn.Module):
    """Runs one small conv head per output group and concatenates the results."""

    def __init__(self, input_channels, num_class, head_size):
        super().__init__()

        # Each head reduces the feature width to a quarter before predicting.
        m = int(input_channels / 4)
        heads = []
        # head_size is a nested list; flattening it with sum(head_size, [])
        # gives the output channel count of each individual head.
        for output_channels in sum(head_size, []):
            heads.append(
                nn.Sequential(
                    nn.Conv2d(input_channels, m, kernel_size=3, padding=1),
                    nn.ReLU(inplace=True),
                    nn.Conv2d(m, output_channels, kernel_size=1),
                )
            )
        self.heads = nn.ModuleList(heads)
        assert num_class == sum(sum(head_size, []))

    def forward(self, x):
        return torch.cat([head(x) for head in self.heads], dim=1)
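
# Worked example (hypothetical numbers): head_size = [[2], [1], [2]] yields
# heads with 2, 1 and 2 output channels; for a (B, C, H, W) input, forward()
# concatenates them into a (B, 5, H, W) tensor, so num_class must equal 5.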


class AngleDistanceHead(nn.Module):
    """Like MultitaskHead, but each 2-channel output is predicted as a unit
    (cos, sin) angle vector via CosineSineLayer instead of a plain 1x1 conv."""

    def __init__(self, input_channels, num_class, head_size):
        super().__init__()

        m = int(input_channels / 4)
        heads = []
        for output_channels in sum(head_size, []):
            if output_channels != 2:
                # Regular head: 3x3 conv -> ReLU -> 1x1 conv.
                heads.append(
                    nn.Sequential(
                        nn.Conv2d(input_channels, m, kernel_size=3, padding=1),
                        nn.ReLU(inplace=True),
                        nn.Conv2d(m, output_channels, kernel_size=1),
                    )
                )
            else:
                # Angle head: the 2 output channels form a (cos, sin) pair.
                heads.append(
                    nn.Sequential(
                        nn.Conv2d(input_channels, m, kernel_size=3, padding=1),
                        nn.ReLU(inplace=True),
                        CosineSineLayer(m),
                    )
                )
        self.heads = nn.ModuleList(heads)
        assert num_class == sum(sum(head_size, []))

    def forward(self, x):
        return torch.cat([head(x) for head in self.heads], dim=1)
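

# Minimal smoke test (a sketch; input_channels=256 and the head_size below
# are hypothetical values chosen only so that num_class matches the total
# channel count, not values taken from the ScaleLSD configuration).
if __name__ == "__main__":
    head_size = [[3], [1], [1], [2], [2]]  # flattens to [3, 1, 1, 2, 2] -> 9
    x = torch.randn(1, 256, 128, 128)

    mt = MultitaskHead(256, num_class=9, head_size=head_size)
    ad = AngleDistanceHead(256, num_class=9, head_size=head_size)

    print(mt(x).shape)  # torch.Size([1, 9, 128, 128])
    print(ad(x).shape)  # torch.Size([1, 9, 128, 128])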