Spaces:
Running
on
Zero
Running
on
Zero
import torch | |
import torch.nn as nn | |
class MultitaskHead(nn.Module):
    """Bank of parallel convolutional heads whose outputs are concatenated.

    One head (3x3 conv -> ReLU -> 1x1 conv) is built per entry of the
    flattened ``head_size``. ``forward`` runs every head on the same input and
    concatenates the results along dim 1.

    Args:
        input_channels: Channel count of the input feature map. Each head's
            hidden width is ``int(input_channels / 4)``.
        num_class: Expected total number of output channels. It must equal
            the sum of all entries in ``head_size``.
        head_size: Nested sequence (e.g. list of lists) of ints. Each int is
            the output channel count of one head. The grouping is flattened;
            it does not otherwise affect construction here.

    Raises:
        ValueError: If ``num_class`` differs from the total of ``head_size``.
    """

    def __init__(self, input_channels, num_class, head_size):
        super().__init__()
        # Flatten the nested spec once; ``sum(lists, [])`` is quadratic and
        # only accepts list groups.
        out_channels_per_head = [c for group in head_size for c in group]

        # Use an explicit exception rather than ``assert``: asserts are removed
        # under ``python -O``. Checking before building the heads also avoids
        # constructing layers for a config that will be rejected.
        total = sum(out_channels_per_head)
        if num_class != total:
            raise ValueError(
                f"num_class ({num_class}) must equal the sum of "
                f"head_size entries ({total})"
            )

        m = int(input_channels / 4)
        # Layers are created head by head, in the same order as before.
        # Parameter initialisation (and thus seeded reproducibility) is
        # therefore unchanged.
        self.heads = nn.ModuleList(
            nn.Sequential(
                nn.Conv2d(input_channels, m, kernel_size=3, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(m, output_channels, kernel_size=1),
            )
            for output_channels in out_channels_per_head
        )

    def forward(self, x):
        """Apply every head to ``x`` and concatenate the outputs along dim 1."""
        return torch.cat([head(x) for head in self.heads], dim=1)
class AngleDistanceHead(nn.Module):
    """Multitask head bank where 2-channel heads end in ``CosineSineLayer``.

    One branch is built per entry of the flattened ``head_size``. Every branch
    starts with a 3x3 conv to ``int(input_channels / 4)`` channels followed by
    a ReLU. Branches that request exactly 2 output channels then end with
    ``CosineSineLayer(m)``; all other branches end with a 1x1 conv to their
    requested channel count. ``forward`` concatenates the branch outputs
    along dim 1.

    Args:
        input_channels: Channel count of the input feature map.
        num_class: Expected total number of output channels. It must equal
            the sum of all entries in ``head_size``.
        head_size: Nested sequence (e.g. list of lists) of ints giving the
            output channel count of each branch.

    Raises:
        ValueError: If ``num_class`` differs from the total of ``head_size``.
    """

    def __init__(self, input_channels, num_class, head_size):
        super().__init__()
        # Flatten the nested spec once; ``sum(lists, [])`` is quadratic and
        # only accepts list groups.
        out_channels_per_head = [c for group in head_size for c in group]

        # Use an explicit exception rather than ``assert``: asserts are removed
        # under ``python -O``. Validating first also avoids building branches
        # for a config that will be rejected.
        total = sum(out_channels_per_head)
        if num_class != total:
            raise ValueError(
                f"num_class ({num_class}) must equal the sum of "
                f"head_size entries ({total})"
            )

        m = int(input_channels / 4)
        self.heads = nn.ModuleList(
            self._make_head(input_channels, m, output_channels)
            for output_channels in out_channels_per_head
        )

    @staticmethod
    def _make_head(input_channels, m, output_channels):
        """Build one branch: shared 3x3 conv + ReLU prefix, then its output layer."""
        # The prefix is constructed before the output layer, matching the
        # original order. This keeps parameter initialisation (RNG
        # consumption) and state_dict keys (``.0``/``.1``/``.2``) unchanged.
        layers = [
            nn.Conv2d(input_channels, m, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        ]
        if output_channels == 2:
            # NOTE(review): CosineSineLayer is defined outside this view.
            # Presumably it emits 2 channels so the num_class check stays
            # meaningful; confirm.
            layers.append(CosineSineLayer(m))
        else:
            layers.append(nn.Conv2d(m, output_channels, kernel_size=1))
        return nn.Sequential(*layers)

    def forward(self, x):
        """Apply every branch to ``x`` and concatenate the outputs along dim 1."""
        return torch.cat([head(x) for head in self.heads], dim=1)