# Uploaded by lorocksUMD ("Upload 32 files", commit e6d4b46, verified).
# Author: David Harwath
import torch
import torch.nn as nn
import torch.nn.functional
import torch.nn.functional
import torch.nn.functional as F
import torch.utils.model_zoo as model_zoo
import torchvision.models as imagemodels
class Davenet(nn.Module):
    """DAVEnet audio encoder: a stack of 2-D convolutions over a spectrogram-like input.

    ``conv1`` uses a kernel 40 rows tall with no padding, so an input with
    exactly 40 rows is collapsed to height 1; ``conv2``..``conv5`` are 1-row
    kernels padded to keep the width, and each is followed by ReLU and a
    (1, 3) / stride (1, 2) max-pool along the width axis.
    """

    def __init__(self, embedding_dim=1024):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.batchnorm1 = nn.BatchNorm2d(1)
        self.conv1 = nn.Conv2d(1, 128, kernel_size=(40, 1), stride=(1, 1), padding=(0, 0))
        # (in_channels, out_channels, kernel width) for conv2..conv5.  Padding of
        # width // 2 on each side keeps the width unchanged for these odd kernels.
        layer_specs = [
            (128, 256, 11),
            (256, 512, 17),
            (512, 512, 17),
            (512, embedding_dim, 17),
        ]
        for index, (c_in, c_out, width) in enumerate(layer_specs, start=2):
            # setattr registers the submodule exactly like a plain attribute
            # assignment, so the state_dict keys stay "conv2".."conv5".
            setattr(
                self,
                f"conv{index}",
                nn.Conv2d(c_in, c_out, kernel_size=(1, width), stride=(1, 1), padding=(0, width // 2)),
            )
        self.pool = nn.MaxPool2d(kernel_size=(1, 3), stride=(1, 2), padding=(0, 1))

    def forward(self, x):
        """Encode ``x`` and drop the collapsed height axis.

        A 3-D input gets a channel axis inserted at dim 1.  Assumes the
        input is (batch, 40, time) or (batch, 1, 40, time) -- TODO confirm with
        callers.  With 40 rows, the result is (batch, embedding_dim, time'),
        where the four max-pools roughly halve the time axis each.
        """
        if x.dim() == 3:
            x = x.unsqueeze(1)
        x = self.batchnorm1(x)
        x = F.relu(self.conv1(x))
        for conv in (self.conv2, self.conv3, self.conv4, self.conv5):
            x = self.pool(F.relu(conv(x)))
        # squeeze(2) is a no-op unless the height axis has size 1.
        return x.squeeze(2)
class Resnet18(imagemodels.ResNet):
    """ResNet-18 trunk whose pooling/classifier head is replaced by a 1x1 conv embedder.

    The output is the layer4 feature map (512 channels) projected to
    ``embedding_dim`` channels; no global average pooling is applied, so the
    spatial grid is preserved.
    """

    def __init__(self, embedding_dim=1024, pretrained=False):
        super(Resnet18, self).__init__(imagemodels.resnet.BasicBlock, [2, 2, 2, 2])
        if pretrained:
            # Load before avgpool/fc are removed so the checkpoint's "fc.*"
            # keys still match a module.  torchvision >= 0.13 exposes checkpoint
            # URLs on the *_Weights enums; the legacy ``resnet.model_urls`` dict
            # was deprecated there and is absent from later releases.
            weights_enum = getattr(imagemodels, "ResNet18_Weights", None)
            if weights_enum is not None:
                url = weights_enum.IMAGENET1K_V1.url
            else:
                url = imagemodels.resnet.model_urls['resnet18']
            self.load_state_dict(model_zoo.load_url(url))
        self.avgpool = None
        self.fc = None
        self.embedder = nn.Conv2d(512, embedding_dim, kernel_size=1, stride=1, padding=0)
        self.embedding_dim = embedding_dim
        self.pretrained = pretrained

    def forward(self, x):
        """Run the ResNet stem and the four residual stages, then the 1x1 embedder.

        Returns a feature map with ``embedding_dim`` channels (spatial size
        depends on the input resolution).
        """
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.embedder(x)
        return x
class Resnet34(imagemodels.ResNet):
    """ResNet-34 trunk whose pooling/classifier head is replaced by a 1x1 conv embedder.

    The output is the layer4 feature map (512 channels) projected to
    ``embedding_dim`` channels; no global average pooling is applied, so the
    spatial grid is preserved.
    """

    def __init__(self, embedding_dim=1024, pretrained=False):
        super(Resnet34, self).__init__(imagemodels.resnet.BasicBlock, [3, 4, 6, 3])
        if pretrained:
            # Load before avgpool/fc are removed so the checkpoint's "fc.*"
            # keys still match a module.  torchvision >= 0.13 exposes checkpoint
            # URLs on the *_Weights enums; the legacy ``resnet.model_urls`` dict
            # was deprecated there and is absent from later releases.
            weights_enum = getattr(imagemodels, "ResNet34_Weights", None)
            if weights_enum is not None:
                url = weights_enum.IMAGENET1K_V1.url
            else:
                url = imagemodels.resnet.model_urls['resnet34']
            self.load_state_dict(model_zoo.load_url(url))
        self.avgpool = None
        self.fc = None
        self.embedder = nn.Conv2d(512, embedding_dim, kernel_size=1, stride=1, padding=0)
        # Recorded for parity with Resnet18, which exposes the same attributes.
        self.embedding_dim = embedding_dim
        self.pretrained = pretrained

    def forward(self, x):
        """Run the ResNet stem and the four residual stages, then the 1x1 embedder.

        Returns a feature map with ``embedding_dim`` channels (spatial size
        depends on the input resolution).
        """
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.embedder(x)
        return x
class Resnet50(imagemodels.ResNet):
    """ResNet-50 trunk whose pooling/classifier head is replaced by a 1x1 conv embedder.

    Uses Bottleneck blocks, so layer4 emits 2048 channels; these are projected
    to ``embedding_dim`` channels without global average pooling, preserving
    the spatial grid.
    """

    def __init__(self, embedding_dim=1024, pretrained=False):
        super(Resnet50, self).__init__(imagemodels.resnet.Bottleneck, [3, 4, 6, 3])
        if pretrained:
            # Load before avgpool/fc are removed so the checkpoint's "fc.*"
            # keys still match a module.  IMAGENET1K_V1 is the checkpoint the
            # legacy ``resnet.model_urls`` dict pointed to; that dict was
            # deprecated in torchvision 0.13 and is absent from later releases.
            weights_enum = getattr(imagemodels, "ResNet50_Weights", None)
            if weights_enum is not None:
                url = weights_enum.IMAGENET1K_V1.url
            else:
                url = imagemodels.resnet.model_urls['resnet50']
            self.load_state_dict(model_zoo.load_url(url))
        self.avgpool = None
        self.fc = None
        self.embedder = nn.Conv2d(2048, embedding_dim, kernel_size=1, stride=1, padding=0)
        # Recorded for parity with Resnet18, which exposes the same attributes.
        self.embedding_dim = embedding_dim
        self.pretrained = pretrained

    def forward(self, x):
        """Run the ResNet stem and the four residual stages, then the 1x1 embedder.

        Returns a feature map with ``embedding_dim`` channels (spatial size
        depends on the input resolution).
        """
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.embedder(x)
        return x
class VGG16(nn.Module):
    """VGG-16 convolutional trunk ending in a 3x3 conv to ``embedding_dim`` channels.

    Takes torchvision's VGG-16 ``features`` stack, drops its last layer (the
    final max-pool), and appends a padded 512 -> ``embedding_dim`` convolution.
    The result is a spatial feature map, not a pooled vector.
    """

    def __init__(self, embedding_dim=1024, pretrained=False):
        super().__init__()
        backbone = imagemodels.vgg16(pretrained=pretrained)
        layers = list(backbone.features.children())
        del layers[-1]  # remove final maxpool
        # Appending before wrapping gives the new conv the next integer name,
        # so state_dict keys are "image_model.0" ... "image_model.<last>".
        layers.append(
            nn.Conv2d(512, embedding_dim, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
        self.image_model = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the truncated VGG-16 trunk plus the embedding conv to ``x``."""
        return self.image_model(x)
def prep(dict, prefix="module."):
    """Return a copy of a state dict with a leading ``prefix`` removed from each key.

    The "module." prefix is typically added by ``nn.DataParallel`` when the
    wrapped model's state dict is saved.  Only a *leading* occurrence is
    stripped, so keys such as ``"submodule.weight"`` are left intact (a plain
    ``str.replace`` would corrupt them to ``"subweight"``).

    Args:
        dict: mapping of parameter names to values.  The name shadows the
            builtin but is kept for backward compatibility with keyword callers.
        prefix: key prefix to strip; defaults to ``"module."``.

    Returns:
        A new dict with the same values and prefix-stripped keys.
    """
    cut = len(prefix)
    return {(k[cut:] if k.startswith(prefix) else k): v for k, v in dict.items()}
class DavenetAudioFeaturizer(nn.Module):
    """Pretrained Davenet audio encoder exposed as a patch-token featurizer.

    Args:
        checkpoint_path: path to the saved Davenet state dict.  The default is
            relative to the current working directory, not to this file.
    """

    def __init__(self, checkpoint_path="../models/davenet_pt_audio.pth"):
        super().__init__()
        self.audio_model = Davenet()
        # map_location="cpu" lets a checkpoint that holds CUDA tensors load on
        # a CPU-only host; load_state_dict copies the values into this module's
        # parameters, which live on the CPU at construction anyway.
        state_dict = torch.load(checkpoint_path, map_location="cpu")
        self.audio_model.load_state_dict(prep(state_dict))

    def forward(self, audio, include_cls):
        """Encode ``audio`` and insert a singleton axis at dim 2.

        Davenet's output is unsqueezed at dim 2.  Assuming 40-row spectrogram
        input, this gives (batch, embedding_dim, 1, time') -- TODO confirm.

        Returns:
            ``(patch_tokens, None)`` when ``include_cls`` is truthy (no CLS
            token is produced), otherwise just ``patch_tokens``.
        """
        patch_tokens = self.audio_model(audio).unsqueeze(2)
        if include_cls:
            return patch_tokens, None
        else:
            return patch_tokens

    def get_last_params(self):
        """Return an empty list (no parameters are reported by this featurizer)."""
        return []
class DavenetImageFeaturizer(nn.Module):
    """Pretrained DAVEnet VGG-16 image encoder exposed as a patch-token featurizer.

    Args:
        checkpoint_path: path to the saved VGG16 state dict.  The default is
            relative to the current working directory, not to this file.
    """

    def __init__(self, checkpoint_path="../models/davenet_pt_image.pth"):
        super().__init__()
        self.image_model = VGG16()
        # map_location="cpu" lets a checkpoint that holds CUDA tensors load on
        # a CPU-only host; load_state_dict copies the values into this module's
        # parameters, which live on the CPU at construction anyway.
        state_dict = torch.load(checkpoint_path, map_location="cpu")
        self.image_model.load_state_dict(prep(state_dict))

    def forward(self, image, include_cls):
        """Encode ``image`` into a spatial feature map.

        The output is VGG16's feature map, presumably (batch, embedding_dim,
        H', W') -- TODO confirm against callers.

        Returns:
            ``(patch_tokens, None)`` when ``include_cls`` is truthy (no CLS
            token is produced), otherwise just ``patch_tokens``.
        """
        patch_tokens = self.image_model(image)
        if include_cls:
            return patch_tokens, None
        else:
            return patch_tokens

    def get_last_params(self):
        """Return an empty list (no parameters are reported by this featurizer)."""
        return []