import torch
import torch.nn as nn
from torchvision.models import resnet50
from torchvision.models import ResNet50_Weights

class DistMult(nn.Module):
    def __init__(self, num_ent_uid, target_list, device, all_locs=None, num_habitat=None, all_timestamps=None):
        super().__init__()
        self.num_ent_uid = num_ent_uid
        self.num_relations = 4
        # Entities, relations, and every non-id modality are embedded into a
        # shared 512-dim space, so any (head, relation) pair can be scored
        # against any candidate tail with a plain dot product.
        self.ent_embedding = nn.Embedding(self.num_ent_uid, 512, sparse=False)
        self.rel_embedding = nn.Embedding(self.num_relations, 512, sparse=False)
        self.location_embedding = MLP(2, 512, 3)  # GPS (lat, lon) -> 512
        self.time_embedding = MLP(1, 512, 3)      # timestamp features -> 512
        # ImageNet-pretrained ResNet-50 whose classifier head is replaced by
        # a projection into the shared embedding space.
        self.image_embedding = resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
        self.image_embedding.fc = nn.Linear(2048, 512)
        self.target_list = target_list
        if all_locs is not None:
            self.all_locs = all_locs.to(device)
        if all_timestamps is not None:
            self.all_timestamps = all_timestamps.to(device)
        self.device = device
        self.init()

    def init(self):
        nn.init.xavier_uniform_(self.ent_embedding.weight.data)
        nn.init.xavier_uniform_(self.rel_embedding.weight.data)
        nn.init.xavier_uniform_(self.image_embedding.fc.weight.data)
    def forward_ce(self, h, r, triple_type=None):
        """Score each (h, r) pair against every candidate tail of the given triple type."""
        emb_h = self.batch_embedding_concat_h(h)   # [batch, hid]
        emb_r = self.rel_embedding(r.squeeze(-1))  # [batch, hid]
        emb_hr = emb_h * emb_r                     # [batch, hid]
        if triple_type == ('image', 'id'):
            # Candidate tails restricted to the entities in target_list.
            score = torch.mm(emb_hr, self.ent_embedding.weight[self.target_list.squeeze(-1)].T)  # [batch, n_target]
        elif triple_type == ('id', 'id'):
            score = torch.mm(emb_hr, self.ent_embedding.weight.T)  # [batch, n_ent]
        elif triple_type == ('image', 'location'):
            # All candidate locations are re-embedded on every forward pass.
            loc_emb = self.location_embedding(self.all_locs)  # [n_loc, hid]
            score = torch.mm(emb_hr, loc_emb.T)               # [batch, n_loc]
        elif triple_type == ('image', 'time'):
            time_emb = self.time_embedding(self.all_timestamps)  # [n_time, hid]
            score = torch.mm(emb_hr, time_emb.T)                 # [batch, n_time]
        else:
            raise NotImplementedError
        return score
    def batch_embedding_concat_h(self, e1):
        """Embed the head, dispatching on its feature width to pick the right encoder."""
        e1_embedded = None
        if len(e1.size()) == 1 or e1.size(1) == 1:  # entity uid
            e1_embedded = self.ent_embedding(e1.squeeze(-1))
        elif e1.size(1) == 15:  # time features
            e1_embedded = self.time_embedding(e1)
        elif e1.size(1) == 2:   # GPS (lat, lon)
            e1_embedded = self.location_embedding(e1)
        elif e1.size(1) == 3:   # image (3-channel tensor)
            e1_embedded = self.image_embedding(e1)
        return e1_embedded
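
# A minimal sketch, not part of the original module: forward_ce above applies
# the DistMult scoring rule, i.e. a triple (h, r, t) is scored as the dot
# product <e_h * e_r, e_t>, computed against all candidate tails at once as a
# single matrix product. The helper name below is hypothetical.
def _distmult_scores(emb_h, emb_r, emb_tails):
    """Return [batch, n_candidates] scores as (emb_h * emb_r) @ emb_tails.T."""
    return torch.mm(emb_h * emb_r, emb_tails.T)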

class MLP(nn.Module):
    def __init__(self,
                 input_dim,
                 output_dim,
                 num_layers=3,
                 p_dropout=0.0,
                 bias=True):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.p_dropout = p_dropout
        # Hidden widths step linearly from input_dim toward output_dim;
        # e.g. MLP(2, 512, 3) uses hidden sizes [172, 342, 512].
        step_size = (input_dim - output_dim) // num_layers
        hidden_dims = [output_dim + (i * step_size)
                       for i in reversed(range(num_layers))]
        mlp = list()
        layer_indim = input_dim
        for hidden_dim in hidden_dims:
            mlp.extend([nn.Linear(layer_indim, hidden_dim, bias=bias),
                        nn.Dropout(p=self.p_dropout, inplace=True),
                        nn.PReLU()])
            layer_indim = hidden_dim
        self.mlp = nn.Sequential(*mlp)
        # initialize weights
        self.init()

    def forward(self, x):
        return self.mlp(x)

    def init(self):
        # Uniform initialization over every parameter, including biases and
        # the PReLU slopes.
        for param in self.parameters():
            nn.init.uniform_(param)
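
if __name__ == "__main__":
    # Minimal smoke test, an illustration rather than part of the original
    # file; the batch size and candidate counts below are assumptions.
    mlp = MLP(2, 512, 3)
    coords = torch.rand(4, 2)  # dummy (lat, lon) pairs
    assert mlp(coords).shape == (4, 512)
    # Score 4 (h, r) pairs against 10 candidate tails with the DistMult rule.
    h, r, tails = torch.rand(4, 512), torch.rand(4, 512), torch.rand(10, 512)
    assert _distmult_scores(h, r, tails).shape == (4, 10)
    # Instantiating DistMult itself triggers a ResNet-50 weight download, so
    # it is left out of this quick check.
    print("smoke test passed")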