import torch
import torch.nn as nn
from lightning import LightningModule
class GraphDTA(LightningModule):
    """Drug–target affinity model: a drug GNN plus a 1-D CNN protein encoder.

    From GraphDTA (Nguyen et al., 2020; https://doi.org/10.1093/bioinformatics/btaa921).

    Args:
        gnn: graph neural network producing a per-drug embedding vector.
        num_features_protein: vocabulary size of the protein token embedding.
        n_filters: number of output channels of the 1-D convolution.
        embed_dim: dimensionality of each protein token embedding.
        output_dim: size of the protein representation after ``fc1_xt``.
        dropout: dropout probability applied between the dense layers.
    """
    def __init__(
        self,
        gnn: nn.Module,
        num_features_protein: int,
        n_filters: int,
        embed_dim: int,
        output_dim: int,
        dropout: float
    ):
        super().__init__()
        self.gnn = gnn
        # protein sequence encoder (1d conv)
        self.embedding_xt = nn.Embedding(num_features_protein, embed_dim)
        self.conv_xt = nn.LazyConv1d(out_channels=n_filters, kernel_size=8)
        # FIX: the flattened conv size was hard-coded as 32 * 121, which
        # silently breaks for any n_filters != 32 or embed_dim != 128 even
        # though both are constructor parameters.  LazyLinear infers the
        # input size on first forward, matching the lazy style of conv_xt.
        self.fc1_xt = nn.LazyLinear(output_dim)
        # combined layers
        # FIX: the previous nn.Linear(256, 1024) assumed
        # gnn-output-dim + output_dim == 256; inferred lazily instead
        # (identical weights/shape for the original configuration).
        self.fc1 = nn.LazyLinear(1024)
        self.fc2 = nn.Linear(1024, 512)
        # activation and regularization
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout)
    # protein input feedforward
    def conv_forward_xt(self, v_p):
        """Encode an integer-token protein sequence into a fixed-size vector."""
        v_p = self.embedding_xt(v_p.long())  # (batch, seq_len, embed_dim)
        # NOTE(review): the conv sees seq_len as the channel axis and slides
        # over the embedding axis, mirroring the original GraphDTA code —
        # confirm this layout is intended (no transpose before conv).
        v_p = self.conv_xt(v_p)
        # FIX: flatten(start_dim=1) replaces view(-1, 32 * 121): it keeps the
        # batch dimension intact and works for any n_filters/embed_dim.
        v_p = v_p.flatten(start_dim=1)
        v_p = self.fc1_xt(v_p)
        return v_p
    def forward(self, v_d, v_p):
        """Return a joint 512-d drug–protein representation.

        Args:
            v_d: drug input consumed by ``self.gnn``.
            v_p: integer-encoded protein sequence, shape (batch, seq_len).
        """
        v_d = self.gnn(v_d)
        v_p = self.conv_forward_xt(v_p)
        # concat drug and protein embeddings along the feature axis
        v_f = torch.cat((v_d, v_p), 1)
        # dense layers
        v_f = self.fc1(v_f)
        v_f = self.relu(v_f)
        v_f = self.dropout(v_f)
        v_f = self.fc2(v_f)
        v_f = self.relu(v_f)
        v_f = self.dropout(v_f)
        # NOTE(review): no output head here — `self.out` is commented out and
        # never defined; callers presumably attach their own regression head.
        # v_f = self.out(v_f)
        return v_f