File size: 3,963 Bytes
6ae852e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import torch
import torch.nn as nn

# TODO: this is a simple model; refactor it so it is configured entirely through a config file.


class HyperAttentionDTI(nn.Module):
    """Two-branch CNN encoder with pairwise drug/protein attention.

    Drug and protein token sequences are embedded, passed through three stacked
    1-D convolutions each, re-weighted by an attention map computed from every
    (drug position, protein position) pair, max-pooled over the sequence axis,
    concatenated and projected through three fully connected layers. No final
    scoring layer is applied: ``forward`` returns 512-dimensional features.

    Args:
        protein_kernel: kernel sizes of the three protein convolutions.
        drug_kernel: kernel sizes of the three drug convolutions.
        conv: channel count of the first convolution; the second and third
            convolutions use ``2 * conv`` and ``4 * conv`` channels.
        char_dim: embedding dimension shared by both token vocabularies.
        protein_max_len: protein sequence length the model is built for; the
            max-pool window is sized so that exactly this length pools to one
            value per channel.
        drug_max_len: same as ``protein_max_len``, for drug sequences.
        drug_vocab_size: number of drug token ids (id 0 is the padding index).
        protein_vocab_size: number of protein token ids (id 0 is the padding index).
        dropout: dropout probability applied before each fully connected layer.

    Raises:
        ValueError: if ``drug_max_len`` or ``protein_max_len`` is too short for
            its kernel stack (the convolutions would leave nothing to pool).
    """

    def __init__(
            self,
            protein_kernel=(4, 8, 12),
            drug_kernel=(4, 6, 8),
            conv=40,
            char_dim=64,
            protein_max_len=1000,
            drug_max_len=100,
            drug_vocab_size=63,
            protein_vocab_size=26,
            dropout=0.1,
    ):
        super().__init__()

        self.drug_embed = nn.Embedding(drug_vocab_size, char_dim, padding_idx=0)
        self.drug_cnn = nn.Sequential(
            nn.Conv1d(in_channels=char_dim, out_channels=conv, kernel_size=drug_kernel[0]),
            nn.ReLU(),
            nn.Conv1d(in_channels=conv, out_channels=conv * 2, kernel_size=drug_kernel[1]),
            nn.ReLU(),
            nn.Conv1d(in_channels=conv * 2, out_channels=conv * 4, kernel_size=drug_kernel[2]),
            nn.ReLU(),
        )
        self.drug_max_pool = nn.MaxPool1d(
            self._pool_kernel(drug_max_len, drug_kernel, "drug"))

        self.protein_embed = nn.Embedding(protein_vocab_size, char_dim, padding_idx=0)
        self.protein_cnn = nn.Sequential(
            nn.Conv1d(in_channels=char_dim, out_channels=conv, kernel_size=protein_kernel[0]),
            nn.ReLU(),
            nn.Conv1d(in_channels=conv, out_channels=conv * 2, kernel_size=protein_kernel[1]),
            nn.ReLU(),
            nn.Conv1d(in_channels=conv * 2, out_channels=conv * 4, kernel_size=protein_kernel[2]),
            nn.ReLU(),
        )
        self.protein_max_pool = nn.MaxPool1d(
            self._pool_kernel(protein_max_len, protein_kernel, "protein"))

        self.attention_layer = nn.Linear(conv * 4, conv * 4)
        self.protein_attention_layer = nn.Linear(conv * 4, conv * 4)
        self.drug_attention_layer = nn.Linear(conv * 4, conv * 4)

        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)
        self.relu = nn.ReLU()
        self.tanh = nn.Tanh()  # currently unused in forward
        self.sigmoid = nn.Sigmoid()
        self.leaky_relu = nn.LeakyReLU()
        # Input is the concatenation of the two pooled branches (4*conv each).
        self.fc1 = nn.Linear(conv * 8, 1024)
        self.fc2 = nn.Linear(1024, 1024)
        self.fc3 = nn.Linear(1024, 512)
        # self.out = nn.Linear(512, 1)

    @staticmethod
    def _pool_kernel(max_len, kernels, which):
        """Return the max-pool window that collapses the conv stack's output to length 1.

        Each Conv1d here uses the default stride 1 and no padding, so a kernel
        of size k shortens the sequence by k - 1.

        Raises:
            ValueError: if the resulting window would be smaller than 1.
        """
        size = max_len - kernels[0] - kernels[1] - kernels[2] + 3
        if size < 1:
            raise ValueError(
                f"{which}_max_len={max_len} is too short for {which} kernels "
                f"{tuple(kernels[:3])}: the convolutions leave {size} positions"
            )
        return size

    def forward(self, drug, protein):
        """Encode a batch of drug/protein token-id sequences.

        Args:
            drug: 2-D tensor of drug token ids (batch, length); cast to long.
                NOTE(review): the pooling window is fixed at construction, so
                sequences are expected to be padded to exactly ``drug_max_len``.
            protein: 2-D tensor of protein token ids (batch, length); cast to
                long. Same padding expectation, with ``protein_max_len``.

        Returns:
            Tensor of shape (batch, 512) after the third LeakyReLU-activated
            fully connected layer.
        """
        # Embedding gives (B, L, char_dim); Conv1d wants channels first.
        drugembed = self.drug_embed(drug.long()).permute(0, 2, 1)
        proteinembed = self.protein_embed(protein.long()).permute(0, 2, 1)

        drug_conv = self.drug_cnn(drugembed)          # (B, 4*conv, Ld)
        protein_conv = self.protein_cnn(proteinembed)  # (B, 4*conv, Lp)

        drug_att = self.drug_attention_layer(drug_conv.permute(0, 2, 1))           # (B, Ld, C)
        protein_att = self.protein_attention_layer(protein_conv.permute(0, 2, 1))  # (B, Lp, C)

        # Pairwise sum over every (drug, protein) position via broadcasting:
        # (B, Ld, 1, C) + (B, 1, Lp, C) -> (B, Ld, Lp, C). Identical values to
        # repeating both operands first, without materialising the two copies.
        atten_matrix = self.attention_layer(
            self.relu(drug_att.unsqueeze(2) + protein_att.unsqueeze(1)))
        # Average out the partner axis and return to channels-first layout.
        compound_atte = self.sigmoid(torch.mean(atten_matrix, 2).permute(0, 2, 1))  # (B, C, Ld)
        protein_atte = self.sigmoid(torch.mean(atten_matrix, 1).permute(0, 2, 1))   # (B, C, Lp)

        # Residual-style re-weighting: half of the raw features plus the gated features.
        drug_conv = drug_conv * 0.5 + drug_conv * compound_atte
        protein_conv = protein_conv * 0.5 + protein_conv * protein_atte

        drug_conv = self.drug_max_pool(drug_conv).squeeze(2)
        protein_conv = self.protein_max_pool(protein_conv).squeeze(2)

        preds = torch.cat([drug_conv, protein_conv], dim=1)
        preds = self.dropout1(preds)
        preds = self.leaky_relu(self.fc1(preds))
        preds = self.dropout2(preds)
        preds = self.leaky_relu(self.fc2(preds))
        preds = self.dropout3(preds)
        preds = self.leaky_relu(self.fc3(preds))
        # preds = self.out(preds)

        return preds