File size: 4,743 Bytes
65bd8af
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import torch
import torch.nn as nn
import esm
import numpy as np
import pdb


# VHSE8 physicochemical descriptors: eight values per standard amino acid,
# keyed by one-letter residue code.
# NOTE(review): confirm these numbers against the published VHSE table
# (Mei et al., 2005); a few rows (e.g. C, G, K) may be shifted by one column.
vhse8_values = {
    'A': [0.15, -1.11, -1.35, -0.92, 0.02, -0.91, 0.36, -0.48],
    'R': [-1.47, 1.45, 1.24, 1.27, 1.55, 1.47, 1.30, 0.83],
    'N': [-0.99, 0.00, 0.69, -0.37, -0.55, 0.85, 0.73, -0.80],
    'D': [-1.15, 0.67, -0.41, -0.01, -2.68, 1.31, 0.03, 0.56],
    'C': [0.18, -1.67, -0.21, 0.00, 1.20, -1.61, -0.19, -0.41],
    'Q': [-0.96, 0.12, 0.18, 0.16, 0.09, 0.42, -0.20, -0.41],
    'E': [-1.18, 0.40, 0.10, 0.36, -2.16, -0.17, 0.91, 0.36],
    'G': [-0.20, -1.53, -2.63, 2.28, -0.53, -1.18, -1.34, 1.10],
    'H': [-0.43, -0.25, 0.37, 0.19, 0.51, 1.28, 0.93, 0.65],
    'I': [1.27, 0.14, 0.30, -1.80, 0.30, -1.61, -0.16, -0.13],
    'L': [1.36, 0.07, 0.26, -0.80, 0.22, -1.37, 0.08, -0.62],
    'K': [-1.17, 0.70, 0.80, 1.64, 0.67, 1.63, 0.13, -0.01],
    'M': [1.01, -0.53, 0.43, 0.00, 0.23, 0.10, -0.86, -0.68],
    'F': [1.52, 0.61, 0.95, -0.16, 0.25, 0.28, -1.33, -0.65],
    'P': [0.22, -0.17, -0.50, -0.05, 0.01, -1.34, 0.19, 3.56],
    'S': [-0.67, -0.86, -1.07, -0.41, -0.32, 0.27, -0.64, 0.11],
    'T': [-0.34, -0.51, -0.55, -1.06, 0.01, -0.01, -0.79, 0.39],
    'W': [1.50, 2.06, 1.79, 0.75, 0.75, 0.13, -1.06, -0.85],
    'Y': [0.61, 1.60, 1.17, 0.73, 0.53, 0.25, -0.96, -0.52],
    'V': [0.76, -0.92, 0.17, -1.91, 0.22, -1.40, -0.24, -0.03],
}

# Token id of each residue in the 33-symbol ESM alphabet, so the lookup table
# below can be indexed directly with the token ids fed to the ESM model.
aa_to_idx = {
    'A': 5, 'R': 10, 'N': 17, 'D': 13, 'C': 23,
    'Q': 16, 'E': 9, 'G': 6, 'H': 21, 'I': 12,
    'L': 4, 'K': 15, 'M': 20, 'F': 18, 'P': 14,
    'S': 8, 'T': 11, 'W': 22, 'Y': 19, 'V': 7,
}

# (33, 8) lookup table: one row per ESM token id. Rows for special or
# non-standard tokens stay all-zero.
vhse8_tensor = torch.zeros(33, 8)
for residue, descriptor in vhse8_values.items():
    vhse8_tensor[aa_to_idx[residue]] = torch.tensor(descriptor)


class muPPIt(torch.nn.Module):
    """Score how far a target mutation moves a binder/target interaction profile.

    The binder is encoded by a frozen ESM-2 model (``esm2_t33_650M_UR50D``,
    layer-33 representations) concatenated with per-token VHSE8 descriptors.
    That binder embedding is joined along the sequence axis (dim 1) with the
    wild-type and, separately, the mutant target embedding. Each joint
    sequence then goes through:

    * one shared self-attention layer with a residual connection;
    * LayerNorm;
    * a small MLP that maps every position to a scalar.

    ``forward`` returns the Euclidean distance between the wild-type and
    mutant scalar profiles.

    Args:
        d_node: Feature size of the joint sequence.
            NOTE(review): this must match the last dimension of the binder
            embedding (ESM-2 hidden size + 8, presumably 1280 + 8 = 1288)
            and of ``wt_embed`` / ``mut_embed``. Confirm against the data
            pipeline.
        num_heads: Number of attention heads (must divide ``d_node``).
        margin: Offset subtracted from the distance in ``get_log_probs``.
        lr: Stored as ``self.learning_rate``; not used inside this class.
        device: Stored as ``self.device``. ``forward`` itself follows the
            device of ``wt_embed``.
        easy_index_path: ``.npy`` file whose contents become
            ``self.easy_example_indices`` (defaults to the original location).
        hard_index_path: ``.npy`` file whose contents become
            ``self.hard_example_indices`` (defaults to the original location).
    """

    def __init__(self, d_node, num_heads, margin, lr, device,
                 easy_index_path='/home/tc415/muPPIt_embedding/dataset/ppiref_index.npy',
                 hard_index_path='/home/tc415/muPPIt_embedding/dataset/skempi_index.npy'):
        super(muPPIt, self).__init__()

        # Pretrained ESM-2 is used only as a frozen feature extractor.
        self.esm, self.alphabet = esm.pretrained.esm2_t33_650M_UR50D()
        for param in self.esm.parameters():
            param.requires_grad = False

        # batch_first defaults to False, hence the transposes in forward().
        self.attention = torch.nn.MultiheadAttention(embed_dim=d_node, num_heads=num_heads)
        self.layer_norm = torch.nn.LayerNorm(d_node)

        # Per-position projection d_node -> d_node // 2 -> 1.
        self.map = torch.nn.Sequential(
            torch.nn.Linear(d_node, d_node // 2),
            torch.nn.SiLU(),
            torch.nn.Linear(d_node // 2, 1)
        )

        # Kaiming-uniform weights, zero biases. With a=0 the leaky_relu gain
        # equals the ReLU gain.
        for layer in self.map:
            if isinstance(layer, nn.Linear):
                nn.init.kaiming_uniform_(layer.weight, a=0, mode='fan_in', nonlinearity='leaky_relu')
                if layer.bias is not None:
                    nn.init.zeros_(layer.bias)

        self.margin = margin
        self.learning_rate = lr
        self.loss_threshold = 20  # Set a threshold for identifying hard examples

        self.device = device

        # Easy and hard example tracking. These are not used inside this
        # class; presumably the training loop consumes them.
        self.easy_example_indices = np.load(easy_index_path).tolist()
        self.hard_example_indices = np.load(hard_index_path).tolist()

    @staticmethod
    def _pairwise_l2(x, y):
        """Return the Euclidean distance between ``x`` and ``y`` along the last dim.

        ``torch.linalg.vector_norm`` produces the same value as
        ``sqrt(sum((x - y) ** 2))``, but its backward pass defines the
        gradient at zero distance as 0. Differentiating ``sqrt`` at 0 instead
        yields NaN (0 * inf), which would poison every parameter whenever a
        wild-type/mutant pair maps to identical profiles.
        """
        return torch.linalg.vector_norm(x - y, ord=2, dim=-1)

    def forward(self, binder_tokens, wt_embed, mut_embed):
        """Return the distance between the wild-type and mutant interaction profiles.

        Args:
            binder_tokens: ESM-alphabet token ids of the binder, laid out as
                (batch, L1).
            wt_embed: Wild-type target embedding, concatenated after the
                binder along dim 1. Presumably (batch, L2, d_node); confirm.
            mut_embed: Mutant target embedding with the same layout as
                ``wt_embed``.

        Returns:
            One distance per batch element: the L2 norm over the
            (L1 + L2) per-position scalars of wild-type minus mutant.
        """
        device = wt_embed.device

        # Use a local device copy rather than reassigning the module-level
        # table. This avoids a process-wide side effect, and concurrent
        # replicas on different devices (e.g. nn.DataParallel threads)
        # cannot race on it.
        descriptor_table = vhse8_tensor.to(device)

        with torch.no_grad():
            binder_embed = self.esm(binder_tokens, repr_layers=[33], return_contacts=False)["representations"][33]
            binder_vhse8 = descriptor_table[binder_tokens]
            binder_embed = torch.concat([binder_embed, binder_vhse8], dim=-1)

        binder_wt = torch.concat([binder_embed, wt_embed], dim=1)
        binder_mut = torch.concat([binder_embed, mut_embed], dim=1)

        # (batch, seq, feat) -> (seq, batch, feat) for MultiheadAttention.
        binder_wt = binder_wt.transpose(0, 1)
        binder_mut = binder_mut.transpose(0, 1)

        # NOTE(review): no key_padding_mask is passed, so any padding
        # positions take part in attention. Confirm the inputs are unpadded
        # or that this is intended.
        binder_wt_attn, _ = self.attention(binder_wt, binder_wt, binder_wt)
        binder_mut_attn, _ = self.attention(binder_mut, binder_mut, binder_mut)

        # Residual connection around the attention block.
        binder_wt_attn = binder_wt + binder_wt_attn
        binder_mut_attn = binder_mut + binder_mut_attn

        binder_wt_attn = binder_wt_attn.transpose(0, 1)
        binder_mut_attn = binder_mut_attn.transpose(0, 1)

        binder_wt_attn = self.layer_norm(binder_wt_attn)
        binder_mut_attn = self.layer_norm(binder_mut_attn)

        mapped_binder_wt = self.map(binder_wt_attn).squeeze(-1)      # B*(L1+L2)
        mapped_binder_mut = self.map(binder_mut_attn).squeeze(-1)     # B*(L1+L2)

        distance = self._pairwise_l2(mapped_binder_wt, mapped_binder_mut)
        return distance

    def get_log_probs(self, binder_tokens, wt_embed, mut_embed):
        """Return ``log(sigmoid(distance - margin))`` for each batch element.

        The 1e-8 added inside the log keeps the result finite (never below
        log(1e-8), about -18.4) when the sigmoid underflows for distances
        far below the margin.
        """
        distance = self.forward(binder_tokens, wt_embed, mut_embed)
        log_prob = torch.log(1e-8 + torch.sigmoid(distance - self.margin))

        return log_prob