Commit 758a536
Parent(s): c97ecfa
Create model.py
model.py
ADDED
@@ -0,0 +1,124 @@
import torch
import torch.nn as nn

from torchvision.models import resnet18, resnet50
from torchvision.models import ResNet18_Weights, ResNet50_Weights


class DistMult(nn.Module):
    """DistMult knowledge-graph scorer whose head entities may be entity ids,
    images, GPS coordinates, or timestamps. `num_habitat` is accepted but
    unused in this file."""

    def __init__(self, args, num_ent_uid, target_list, device,
                 all_locs=None, num_habitat=None, all_timestamps=None):
        super().__init__()
        self.args = args
        self.num_ent_uid = num_ent_uid

        self.num_relations = 4

        self.ent_embedding = nn.Embedding(self.num_ent_uid, args.embedding_dim, sparse=False)
        self.rel_embedding = nn.Embedding(self.num_relations, args.embedding_dim, sparse=False)

        # Project raw GPS coordinates and timestamp encodings into the same
        # space as the entity embeddings.
        self.location_embedding = MLP(args.location_input_dim, args.embedding_dim,
                                      args.mlp_location_numlayer)
        self.time_embedding = MLP(args.time_input_dim, args.embedding_dim,
                                  args.mlp_time_numlayer)

        # Image encoder: a pretrained ResNet whose final fc layer is replaced
        # so images also land in the shared embedding space.
        if self.args.img_embed_model == 'resnet50':
            self.image_embedding = resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
            self.image_embedding.fc = nn.Linear(2048, args.embedding_dim)
        else:
            self.image_embedding = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
            self.image_embedding.fc = nn.Linear(512, args.embedding_dim)

        self.target_list = target_list

        if all_locs is not None:
            self.all_locs = all_locs.to(device)
        if all_timestamps is not None:
            self.all_timestamps = all_timestamps.to(device)

        self.device = device

        self.init()

    def init(self):
        nn.init.xavier_uniform_(self.ent_embedding.weight.data)
        nn.init.xavier_uniform_(self.rel_embedding.weight.data)
        nn.init.xavier_uniform_(self.image_embedding.fc.weight.data)

    def forward_ce(self, h, r, triple_type=None):
        """Score (h, r, ?) against every candidate tail of the given triple
        type, for use with a cross-entropy loss."""
        emb_h = self.batch_embedding_concat_h(h)   # [batch, hid]
        emb_r = self.rel_embedding(r.squeeze(-1))  # [batch, hid]
        emb_hr = emb_h * emb_r                     # [batch, hid]

        if triple_type == ('image', 'id'):
            # Restrict the candidate tails to the target entity list.
            score = torch.mm(emb_hr, self.ent_embedding.weight[self.target_list.squeeze(-1)].T)  # [batch, n_target]
        elif triple_type == ('id', 'id'):
            score = torch.mm(emb_hr, self.ent_embedding.weight.T)  # [batch, n_ent]
        elif triple_type == ('image', 'location'):
            loc_emb = self.location_embedding(self.all_locs)  # recomputed for each batch
            score = torch.mm(emb_hr, loc_emb.T)
        elif triple_type == ('image', 'time'):
            time_emb = self.time_embedding(self.all_timestamps)
            score = torch.mm(emb_hr, time_emb.T)
        else:
            raise NotImplementedError

        return score

    def batch_embedding_concat_h(self, e1):
        """Embed the head entity, dispatching on its tensor shape."""
        e1_embedded = None

        if len(e1.size()) == 1 or e1.size(1) == 1:  # entity uid
            e1_embedded = self.ent_embedding(e1.squeeze(-1))
        elif e1.size(1) == 15:  # timestamp encoding
            e1_embedded = self.time_embedding(e1)
        elif e1.size(1) == 2:   # GPS coordinates
            e1_embedded = self.location_embedding(e1)
        elif e1.size(1) == 3:   # image (channels-first)
            e1_embedded = self.image_embedding(e1)

        return e1_embedded
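
The elementwise product followed by the matmul in forward_ce is the standard DistMult trilinear score, score(h, r, t) = sum_d e_h[d] * e_r[d] * e_t[d], computed for every candidate tail at once. A minimal sanity check of that equivalence (dimensions are illustrative assumptions, not values from this file):

    import torch

    emb_h, emb_r = torch.randn(8, 128), torch.randn(8, 128)  # head and relation batches
    candidates = torch.randn(500, 128)                       # candidate tail embeddings
    scores = torch.mm(emb_h * emb_r, candidates.T)           # [8, 500], as in forward_ce
    # Entry (i, j) is the trilinear product of head i, relation i, and candidate j.
    assert torch.allclose(scores[3, 7],
                          (emb_h[3] * emb_r[3] * candidates[7]).sum(), atol=1e-4)
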
class MLP(nn.Module):
    """Simple MLP whose hidden widths step linearly from input_dim to output_dim."""

    def __init__(self,
                 input_dim,
                 output_dim,
                 num_layers=3,
                 p_dropout=0.0,
                 bias=True):
        super().__init__()

        self.input_dim = input_dim
        self.output_dim = output_dim
        self.p_dropout = p_dropout

        # Interpolate the hidden widths between input_dim and output_dim,
        # e.g. 2 -> 44 -> 86 -> 128 for input_dim=2, output_dim=128, num_layers=3.
        step_size = (input_dim - output_dim) // num_layers
        hidden_dims = [output_dim + (i * step_size)
                       for i in reversed(range(num_layers))]

        mlp = list()
        layer_indim = input_dim
        for hidden_dim in hidden_dims:
            mlp.extend([nn.Linear(layer_indim, hidden_dim, bias),
                        nn.Dropout(p=self.p_dropout, inplace=True),
                        nn.PReLU()])
            layer_indim = hidden_dim

        self.mlp = nn.Sequential(*mlp)

        # initialize weights
        self.init()

    def forward(self, x):
        return self.mlp(x)

    def init(self):
        for param in self.parameters():
            nn.init.uniform_(param)
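
A hypothetical end-to-end usage sketch: the `args` attribute names below are read off the constructor, but every value, shape, and entity count is an illustrative assumption, not part of the committed file.

    from types import SimpleNamespace
    import torch

    args = SimpleNamespace(
        embedding_dim=128,
        location_input_dim=2,      # assumed: (lat, lon)
        time_input_dim=15,         # assumed: encoded timestamp features
        mlp_location_numlayer=3,
        mlp_time_numlayer=3,
        img_embed_model='resnet18',
    )

    device = torch.device('cpu')
    target_list = torch.arange(500).unsqueeze(-1)   # candidate tail entity ids
    all_locs = torch.randn(200, 2)                  # candidate GPS points
    all_timestamps = torch.randn(50, 15)            # candidate time encodings

    model = DistMult(args, num_ent_uid=1000, target_list=target_list, device=device,
                     all_locs=all_locs, all_timestamps=all_timestamps)

    images = torch.randn(4, 3, 224, 224)            # h: a batch of images
    r = torch.zeros(4, 1, dtype=torch.long)         # relation ids
    scores = model.forward_ce(images, r, triple_type=('image', 'id'))
    print(scores.shape)                             # torch.Size([4, 500])

Each branch of forward_ce scores the batch against a different precomputed candidate set (the entity table, all_locs, all_timestamps), so a single matmul ranks every candidate tail under one cross-entropy objective.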