import gc
import math
import os
import pdb

import esm
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from datasets import load_from_disk
from torch.cuda.amp import GradScaler, autocast
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader, Subset
from tqdm import tqdm
from transformers import AutoTokenizer, get_linear_schedule_with_warmup

os.environ['CUDA_VISIBLE_DEVICES'] = '0'

max_epochs = 30
batch_size = 4
lr = 1e-3
accumulation_steps = 4
n_heads = 4
d_k = 64
d_v = 64
checkpoint_path = f'/home/tc415/muPPIt_embedding/checkpoints/lr{lr}_dk{d_k}_dv{d_v}_h{n_heads}'
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

print(f'''
max_epochs = {max_epochs}
batch_size = {batch_size}
lr = {lr}
d_k = {d_k}
d_v = {d_v}
accumulation_steps = {accumulation_steps}
num_heads = {n_heads}
checkpoint_path = {checkpoint_path}
''')

os.makedirs(checkpoint_path, exist_ok=True)

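# VHSE8: eight principal-component descriptors of amino-acid physicochemical properties
# (hydrophobic, steric, and electronic terms). aa_to_idx maps each amino acid to its
# token index in the ESM-2 alphabet, so vhse8_tensor below can be indexed directly with
# ESM-2 token ids; special tokens (cls, pad, eos, ...) keep all-zero descriptor rows.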
vhse8_values = {
    'A': [0.15, -1.11, -1.35, -0.92, 0.02, -0.91, 0.36, -0.48],
    'R': [-1.47, 1.45, 1.24, 1.27, 1.55, 1.47, 1.30, 0.83],
    'N': [-0.99, 0.00, 0.69, -0.37, -0.55, 0.85, 0.73, -0.80],
    'D': [-1.15, 0.67, -0.41, -0.01, -2.68, 1.31, 0.03, 0.56],
    'C': [0.18, -1.67, -0.21, 0.00, 1.20, -1.61, -0.19, -0.41],
    'Q': [-0.96, 0.12, 0.18, 0.16, 0.09, 0.42, -0.20, -0.41],
    'E': [-1.18, 0.40, 0.10, 0.36, -2.16, -0.17, 0.91, 0.36],
    'G': [-0.20, -1.53, -2.63, 2.28, -0.53, -1.18, -1.34, 1.10],
    'H': [-0.43, -0.25, 0.37, 0.19, 0.51, 1.28, 0.93, 0.65],
    'I': [1.27, 0.14, 0.30, -1.80, 0.30, -1.61, -0.16, -0.13],
    'L': [1.36, 0.07, 0.26, -0.80, 0.22, -1.37, 0.08, -0.62],
    'K': [-1.17, 0.70, 0.80, 1.64, 0.67, 1.63, 0.13, -0.01],
    'M': [1.01, -0.53, 0.43, 0.00, 0.23, 0.10, -0.86, -0.68],
    'F': [1.52, 0.61, 0.95, -0.16, 0.25, 0.28, -1.33, -0.65],
    'P': [0.22, -0.17, -0.50, -0.05, 0.01, -1.34, 0.19, 3.56],
    'S': [-0.67, -0.86, -1.07, -0.41, -0.32, 0.27, -0.64, 0.11],
    'T': [-0.34, -0.51, -0.55, -1.06, 0.01, -0.01, -0.79, 0.39],
    'W': [1.50, 2.06, 1.79, 0.75, 0.75, 0.13, -1.06, -0.85],
    'Y': [0.61, 1.60, 1.17, 0.73, 0.53, 0.25, -0.96, -0.52],
    'V': [0.76, -0.92, 0.17, -1.91, 0.22, -1.40, -0.24, -0.03],
}

aa_to_idx = {'A': 5, 'R': 10, 'N': 17, 'D': 13, 'C': 23, 'Q': 16, 'E': 9, 'G': 6, 'H': 21, 'I': 12, 'L': 4, 'K': 15, 'M': 20, 'F': 18, 'P': 14, 'S': 8, 'T': 11, 'W': 22, 'Y': 19, 'V': 7}

vhse8_tensor = torch.zeros(33, 8)
for aa, values in vhse8_values.items():
    aa_index = aa_to_idx[aa]
    vhse8_tensor[aa_index] = torch.tensor(values)

vhse8_tensor = vhse8_tensor.to(device)
vhse8_tensor.requires_grad = False

train_dataset = load_from_disk('/home/tc415/muPPIt_embedding/dataset/train/ppiref_skempi_classifier')
val_dataset = load_from_disk('/home/tc415/muPPIt_embedding/dataset/val/ppiref_skempi_classifier')
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")


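# Collate function: converts each record's pre-tokenized id lists into tensors, drops
# malformed or empty entries, pads each group to the batch max length with the ESM-2 pad
# token, and stacks the scalar wild-type/mutant affinities ('wt_aff', 'mut_aff').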
def collate_fn(batch):
    binders = []
    mutants = []
    wildtypes = []
    wt_affs = []
    mut_affs = []

    for b in batch:
        binder = torch.tensor(b['binder_input_ids']['input_ids'])
        mutant = torch.tensor(b['mutant_input_ids']['input_ids'])
        wildtype = torch.tensor(b['wildtype_input_ids']['input_ids'])

        # Skip malformed or empty examples.
        if binder.dim() == 0 or binder.numel() == 0 or mutant.dim() == 0 or mutant.numel() == 0 \
                or wildtype.dim() == 0 or wildtype.numel() == 0:
            continue

        binders.append(binder)
        mutants.append(mutant)
        wildtypes.append(wildtype)
        wt_affs.append(b['wt_aff'])
        mut_affs.append(b['mut_aff'])

    try:
        binder_input_ids = torch.nn.utils.rnn.pad_sequence(binders, batch_first=True, padding_value=tokenizer.pad_token_id)
        mutant_input_ids = torch.nn.utils.rnn.pad_sequence(mutants, batch_first=True, padding_value=tokenizer.pad_token_id)
        wildtype_input_ids = torch.nn.utils.rnn.pad_sequence(wildtypes, batch_first=True, padding_value=tokenizer.pad_token_id)
    except Exception:
        pdb.set_trace()

    wt_affs = torch.tensor(wt_affs)
    mut_affs = torch.tensor(mut_affs)

    return {
        'binder_input_ids': binder_input_ids.int(),
        'mutant_input_ids': mutant_input_ids.int(),
        'wildtype_input_ids': wildtype_input_ids.int(),
        'wt_affs': wt_affs,
        'mut_affs': mut_affs,
    }


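# muPPIt: predicts which target (wild-type or mutant) a binder prefers. Per-residue
# features are frozen ESM-2 embeddings concatenated with VHSE8 descriptors; the binder
# cross-attends to each target, the attention outputs are pooled over the binder length,
# and an MLP maps the difference of the two pooled vectors to two logits.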
class muPPIt(torch.nn.Module):
    def __init__(self, d_node, d_k, d_v, n_heads, lr):
        super(muPPIt, self).__init__()

        # Frozen ESM-2 650M backbone used purely as a feature extractor.
        self.esm, self.alphabet = esm.pretrained.esm2_t33_650M_UR50D()
        for param in self.esm.parameters():
            param.requires_grad = False

        # Projections for multi-head cross-attention between binder and target embeddings.
        self.q = nn.Linear(d_node, n_heads * d_k)
        self.k = nn.Linear(d_node, n_heads * d_k)
        self.v = nn.Linear(d_node, n_heads * d_v)
        self.layer_norm = torch.nn.LayerNorm(n_heads * d_v)

        # MLP head mapping the pooled attention difference to two logits.
        self.map = torch.nn.Sequential(
            torch.nn.Linear(n_heads * d_v, (n_heads * d_v) // 2),
            torch.nn.SiLU(),
            torch.nn.Linear((n_heads * d_v) // 2, (n_heads * d_v) // 4),
            torch.nn.SiLU(),
            torch.nn.Linear((n_heads * d_v) // 4, 2)
        )

        for layer in self.map:
            if isinstance(layer, nn.Linear):
                nn.init.kaiming_uniform_(layer.weight, a=0, mode='fan_in', nonlinearity='leaky_relu')
                if layer.bias is not None:
                    nn.init.zeros_(layer.bias)

        self.learning_rate = lr
        self.n_heads = n_heads
        self.d_k = d_k
        self.d_v = d_v
        self.d_node = d_node

        # Index lists for curriculum learning: PPIRef examples are treated as easy,
        # SKEMPI examples as hard (mixed in gradually during training).
        self.easy_example_indices = np.load('/home/tc415/muPPIt_embedding/dataset/ppiref_index_classifier.npy').tolist()
        self.hard_example_indices = np.load('/home/tc415/muPPIt_embedding/dataset/skempi_index_classifier.npy').tolist()

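    # Inputs are padded token-id tensors for the binder and the two targets; returns
    # logits of shape (B, 2) whose softmax is matched against the normalized affinities
    # of the two targets, in the order they are passed (see compute_loss).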
    def forward(self, binder_tokens, wt_tokens, mut_tokens):
        with torch.no_grad():
            # Frozen ESM-2 embeddings (layer 33), zeroed at padding positions and
            # concatenated with the per-residue VHSE8 descriptors.
            binder_pad_mask = (binder_tokens != self.alphabet.padding_idx).int()
            binder_embed = self.esm(binder_tokens, repr_layers=[33], return_contacts=False)["representations"][33] * binder_pad_mask.unsqueeze(-1)
            binder_vhse8 = vhse8_tensor[binder_tokens]
            binder_embed = torch.concat([binder_embed, binder_vhse8], dim=-1)

            mut_pad_mask = (mut_tokens != self.alphabet.padding_idx).int()
            mut_embed = self.esm(mut_tokens, repr_layers=[33], return_contacts=False)["representations"][33] * mut_pad_mask.unsqueeze(-1)
            mut_vhse8 = vhse8_tensor[mut_tokens]
            mut_embed = torch.concat([mut_embed, mut_vhse8], dim=-1)

            wt_pad_mask = (wt_tokens != self.alphabet.padding_idx).int()
            wt_embed = self.esm(wt_tokens, repr_layers=[33], return_contacts=False)["representations"][33] * wt_pad_mask.unsqueeze(-1)
            wt_vhse8 = vhse8_tensor[wt_tokens]
            wt_embed = torch.concat([wt_embed, wt_vhse8], dim=-1)

        # Cross-attend the binder against each target, pool over the binder length,
        # and classify the difference of the two pooled vectors.
        binder_wt_reciprocal = self.cross_attention(binder_embed, wt_embed)
        binder_mut_reciprocal = self.cross_attention(binder_embed, mut_embed)

        binder_wt_reciprocal = self.layer_norm(binder_wt_reciprocal).mean(dim=1)
        binder_mut_reciprocal = self.layer_norm(binder_mut_reciprocal).mean(dim=1)

        difference = binder_wt_reciprocal - binder_mut_reciprocal

        logits = self.map(difference)

        return logits

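    # Multi-head cross-attention: queries come from embed_1 (the binder), keys and
    # values from embed_2 (the target). Returns a tensor of shape (B, L1, n_heads * d_v).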
    def cross_attention(self, embed_1, embed_2):
        B, L1, _ = embed_1.shape
        _, L2, _ = embed_2.shape

        Q = self.q(embed_1).view(B, L1, self.n_heads, self.d_k)
        K = self.k(embed_2).view(B, L2, self.n_heads, self.d_k)
        V = self.v(embed_2).view(B, L2, self.n_heads, self.d_v)

        Q = Q.transpose(1, 2)
        K = K.transpose(1, 2)
        V = V.transpose(1, 2)

        attention_scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.d_k ** 0.5)
        attention_weights = F.softmax(attention_scores, dim=-1)

        output = torch.matmul(attention_weights, V)

        output = output.transpose(1, 2).contiguous().view(B, L1, self.n_heads * self.d_v)

        return output

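    # Loss: KL divergence between the softmax prediction and the affinity-derived target
    # distribution, computed for both target orderings, plus a symmetry term that pushes
    # the wild-type probability to be independent of the input order. Also returns the
    # number of correct argmax predictions (two per example, one per ordering).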
    def compute_loss(self, binder_tokens, wt_tokens, mut_tokens, wt_affs, mut_affs):
        # Normalize the two affinities into a two-class target distribution.
        p_wt = wt_affs / (wt_affs + mut_affs)
        p_mut = mut_affs / (wt_affs + mut_affs)

        logits = self.forward(binder_tokens, wt_tokens, mut_tokens)
        log_prediction = F.log_softmax(logits, dim=-1)  # numerically stable under autocast
        prediction = log_prediction.exp()
        target = torch.stack((p_wt, p_mut), dim=1)
        loss = F.kl_div(log_prediction, target, reduction='batchmean')

        # Same pair with the targets swapped: the model should give the mirrored answer.
        logits_rev = self.forward(binder_tokens, mut_tokens, wt_tokens)
        log_prediction_rev = F.log_softmax(logits_rev, dim=-1)
        prediction_rev = log_prediction_rev.exp()
        target_rev = torch.stack((p_mut, p_wt), dim=1)
        loss_rev = F.kl_div(log_prediction_rev, target_rev, reduction='batchmean')

        # Symmetry penalty on the wild-type probability across the two orderings.
        loss_diff = abs((prediction[:, 0] - prediction_rev[:, 1]).mean())

        loss_sum = loss + loss_rev + loss_diff

        correct = (torch.argmax(prediction, dim=-1) == torch.argmax(target, dim=-1)).sum() + \
                  (torch.argmax(prediction_rev, dim=-1) == torch.argmax(target_rev, dim=-1)).sum()

        return loss_sum, correct

    def step(self, batch):
        binder_tokens = batch['binder_input_ids'].to(device)
        mut_tokens = batch['mutant_input_ids'].to(device)
        wt_tokens = batch['wildtype_input_ids'].to(device)
        wt_affs = batch['wt_affs'].to(device)
        mut_affs = batch['mut_affs'].to(device)

        loss, correct = self.compute_loss(binder_tokens, wt_tokens, mut_tokens, wt_affs, mut_affs)

        return loss, correct


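# Training loop with a simple curriculum: the first three epochs use only the easy
# (PPIRef) indices, after which a linearly growing fraction of the hard (SKEMPI) indices
# is mixed in. Uses mixed precision with gradient accumulation; the linear-warmup
# scheduler runs for the first warmup_steps optimizer updates, then cosine annealing
# takes over. The checkpoint with the best validation accuracy is saved.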
def train(model, optimizer, scheduler, cosine_scheduler, train_dataset, val_dataset, batch_size, max_epochs=10, accumulation_steps=4):
    val_loader = DataLoader(val_dataset, batch_size=batch_size, collate_fn=collate_fn, shuffle=False, num_workers=4)

    scaler = GradScaler()
    max_val_acc = 0.0
    for epoch in range(max_epochs):
        print(f"Epoch {epoch + 1}/{max_epochs}")

        # Curriculum: warm up on easy examples only, then mix in a growing share of hard ones.
        if epoch < 3:
            train_subset = Subset(train_dataset, model.easy_example_indices)
        else:
            num_hard_examples = int((epoch / max_epochs) * len(model.hard_example_indices))
            selected_hard_indices = model.hard_example_indices[:num_hard_examples]
            combined_indices = model.easy_example_indices + selected_hard_indices
            train_subset = Subset(train_dataset, combined_indices)

        train_loader = DataLoader(train_subset, batch_size=batch_size, collate_fn=collate_fn, shuffle=True, num_workers=4)

        model.train()
        running_loss = 0.0
        optimizer.zero_grad()

        for batch_idx, batch in tqdm(enumerate(train_loader), total=len(train_loader)):
            batch = {k: v.cuda(non_blocking=True) for k, v in batch.items()}

            with autocast():
                loss, _ = model.step(batch)

            # Scale the loss so accumulated gradients match a full-batch update.
            scaler.scale(loss / accumulation_steps).backward()

            if (batch_idx + 1) % accumulation_steps == 0:
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()

                # Linear warmup first, then cosine annealing (warmup_steps is a module-level global).
                if scheduler.last_epoch < warmup_steps:
                    scheduler.step()
                else:
                    cosine_scheduler.step()

            running_loss += loss.item()

        print(f"Epoch {epoch + 1}: Training Loss = {running_loss / len(train_loader)}")

        del train_loader, running_loss
        gc.collect()
        torch.cuda.empty_cache()

        model.eval()
        val_loss = 0.0
        val_correct = 0
        val_acc = 0.0
        with torch.no_grad():
            for batch in tqdm(val_loader, total=len(val_loader)):
                batch = {k: v.cuda(non_blocking=True) for k, v in batch.items()}
                val_loss_batch, val_correct_batch = model.step(batch)
                val_loss += val_loss_batch.item()
                val_correct += val_correct_batch.item()

        # compute_loss scores each example twice (once per ordering), hence the factor of 2.
        val_acc = val_correct / (2 * len(val_dataset))
        print(f"Epoch {epoch + 1}: Val Loss = {val_loss / len(val_loader)}\tVal Acc = {val_acc}")

        if val_acc > max_val_acc:
            max_val_acc = val_acc
            torch.save(model.state_dict(), os.path.join(checkpoint_path, f"epoch={epoch}_acc={round(val_acc, 2)}"))


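# d_node = 1288 = 1280 (ESM-2 650M embedding dim) + 8 (VHSE8 descriptors).
# total_steps counts optimizer updates (one per accumulation_steps batches);
# 10% of them are used for linear warmup before cosine annealing.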
model = muPPIt(d_node=1288, d_k=d_k, d_v=d_v, n_heads=n_heads, lr=lr).to(device)

optimizer = AdamW(model.parameters(), lr=model.learning_rate, betas=(0.9, 0.98), weight_decay=1e-5)

total_steps = len(train_dataset) // (batch_size * accumulation_steps) * max_epochs
warmup_steps = int(0.1 * total_steps)

scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=warmup_steps,
    num_training_steps=total_steps
)
cosine_scheduler = CosineAnnealingLR(optimizer, T_max=total_steps - warmup_steps, eta_min=0.1 * lr)

train(model, optimizer, scheduler, cosine_scheduler, train_dataset, val_dataset,
      batch_size=batch_size, max_epochs=max_epochs, accumulation_steps=accumulation_steps)