vardaan123 commited on
Commit
758a536
·
1 Parent(s): c97ecfa

Create model.py

Browse files
Files changed (1) hide show
  1. model.py +124 -0
model.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch.nn import functional as F
4
+ from torch import Tensor
5
+ from typing import Tuple
6
+
7
+ from torchvision.models import resnet18, resnet50
8
+ from torchvision.models import ResNet18_Weights, ResNet50_Weights
9
+
10
+ class DistMult(nn.Module):
11
+ def __init__(self, args, num_ent_uid, target_list, device, all_locs=None, num_habitat=None, all_timestamps=None):
12
+ super(DistMult, self).__init__()
13
+ self.args = args
14
+ self.num_ent_uid = num_ent_uid
15
+
16
+ self.num_relations = 4
17
+
18
+ self.ent_embedding = torch.nn.Embedding(self.num_ent_uid, args.embedding_dim, sparse=False)
19
+ self.rel_embedding = torch.nn.Embedding(self.num_relations, args.embedding_dim, sparse=False)
20
+
21
+ self.location_embedding = MLP(args.location_input_dim, args.embedding_dim, args.mlp_location_numlayer)
22
+
23
+ self.time_embedding = MLP(args.time_input_dim, args.embedding_dim, args.mlp_time_numlayer)
24
+
25
+ if self.args.img_embed_model == 'resnet50':
26
+ self.image_embedding = resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
27
+ self.image_embedding.fc = nn.Linear(2048, args.embedding_dim)
28
+ else:
29
+ self.image_embedding = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
30
+ self.image_embedding.fc = nn.Linear(512, args.embedding_dim)
31
+
32
+ self.target_list = target_list
33
+
34
+ if all_locs is not None:
35
+ self.all_locs = all_locs.to(device)
36
+ if all_timestamps is not None:
37
+ self.all_timestamps = all_timestamps.to(device)
38
+
39
+ self.args = args
40
+ self.device = device
41
+
42
+ self.init()
43
+
44
+ def init(self):
45
+ nn.init.xavier_uniform_(self.ent_embedding.weight.data)
46
+ nn.init.xavier_uniform_(self.rel_embedding.weight.data)
47
+ nn.init.xavier_uniform_(self.image_embedding.fc.weight.data)
48
+
49
+ def forward_ce(self, h, r, triple_type=None):
50
+ emb_h = self.batch_embedding_concat_h(h) # [batch, hid]
51
+
52
+ emb_r = self.rel_embedding(r.squeeze(-1)) # [batch, hid]
53
+
54
+ emb_hr = emb_h * emb_r # [batch, hid]
55
+
56
+ if triple_type == ('image', 'id'):
57
+ score = torch.mm(emb_hr, self.ent_embedding.weight[self.target_list.squeeze(-1)].T) # [batch, n_ent]
58
+ elif triple_type == ('id', 'id'):
59
+ score = torch.mm(emb_hr, self.ent_embedding.weight.T) # [batch, n_ent]
60
+ elif triple_type == ('image', 'location'):
61
+ loc_emb = self.location_embedding(self.all_locs) # computed for each batch
62
+ score = torch.mm(emb_hr, loc_emb.T)
63
+ elif triple_type == ('image', 'time'):
64
+ time_emb = self.time_embedding(self.all_timestamps)
65
+ score = torch.mm(emb_hr, time_emb.T)
66
+ else:
67
+ raise NotImplementedError
68
+
69
+ return score
70
+
71
+ def batch_embedding_concat_h(self, e1):
72
+ e1_embedded = None
73
+
74
+ if len(e1.size())==1 or e1.size(1) == 1: # uid
75
+ # print('ent_embedding = {}'.format(self.ent_embedding.weight.size()))
76
+ e1_embedded = self.ent_embedding(e1.squeeze(-1))
77
+ elif e1.size(1) == 15: # time
78
+ e1_embedded = self.time_embedding(e1)
79
+ elif e1.size(1) == 2: # GPS
80
+ e1_embedded = self.location_embedding(e1)
81
+ elif e1.size(1) == 3: # Image
82
+ e1_embedded = self.image_embedding(e1)
83
+
84
+ return e1_embedded
85
+
86
+
87
+ class MLP(nn.Module):
88
+ def __init__(self,
89
+ input_dim,
90
+ output_dim,
91
+ num_layers=3,
92
+ p_dropout=0.0,
93
+ bias=True):
94
+
95
+ super().__init__()
96
+
97
+ self.input_dim = input_dim
98
+ self.output_dim = output_dim
99
+
100
+ self.p_dropout = p_dropout
101
+ step_size = (input_dim - output_dim) // num_layers
102
+ hidden_dims = [output_dim + (i * step_size)
103
+ for i in reversed(range(num_layers))]
104
+
105
+ mlp = list()
106
+ layer_indim = input_dim
107
+ for hidden_dim in hidden_dims:
108
+ mlp.extend([nn.Linear(layer_indim, hidden_dim, bias),
109
+ nn.Dropout(p=self.p_dropout, inplace=True),
110
+ nn.PReLU()])
111
+
112
+ layer_indim = hidden_dim
113
+
114
+ self.mlp = nn.Sequential(*mlp)
115
+
116
+ # initialize weights
117
+ self.init()
118
+
119
+ def forward(self, x):
120
+ return self.mlp(x)
121
+
122
+ def init(self):
123
+ for param in self.parameters():
124
+ nn.init.uniform_(param)