Spaces:
Sleeping
Sleeping
File size: 1,404 Bytes
719d0db |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 |
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
class MLPDecoder(nn.Module):
    """MLP classification head mapping graph embeddings to class log-probabilities.

    Parameters
    ----------
    emb_dim : int
        Dimensionality of the input graph embedding (and of each hidden layer).
    num_mlp_layers : int
        Number of hidden Linear+ReLU layers before the final output projection.
    num_classes : int
        Number of output classes.
    dropout : float
        Dropout probability applied before every linear layer.
    """

    def __init__(self, emb_dim, num_mlp_layers, num_classes, dropout):
        super().__init__()
        self.num_mlp_layers = num_mlp_layers
        # Decoder (MLP): num_mlp_layers hidden layers, then one output layer
        # appended last (accessed as self.mlp[-1] in forward).
        self.mlp = nn.ModuleList()
        for _ in range(num_mlp_layers):
            self.mlp.append(nn.Linear(emb_dim, emb_dim, bias=True))
        self.mlp.append(nn.Linear(emb_dim, num_classes, bias=True))
        # Dropout (single module reused before every layer)
        self.dropout = nn.Dropout(dropout)
        # Initializing weights
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize every parameter from U(-1/sqrt(d), 1/sqrt(d)) with d = param.size(-1).

        NOTE(review): for weight matrices d is the fan-in, but for bias
        vectors d is the bias's own length (out_features) — kept as-is to
        preserve the original initialization scheme exactly.
        """
        for param in self.parameters():
            stdv = 1. / math.sqrt(param.size(-1))
            param.data.uniform_(-stdv, stdv)

    def forward(self, graph_emb):
        """
        Parameters
        ----------
        graph_emb : torch.Tensor [batch_size x emb_dim]

        Returns
        -------
        probs : torch.Tensor [batch_size x num_classes]
            Log-probabilities of classes (output of log_softmax; exponentiate
            to recover probabilities).
        """
        # ----------
        # Decoding
        # ----------
        h = graph_emb
        for i in range(self.num_mlp_layers):
            h = self.dropout(h)
            h = torch.relu(self.mlp[i](h))
        h = self.dropout(h)
        logits = self.mlp[-1](h)
        probs = F.log_softmax(logits, dim=-1)
        return probs