# muPPIt/binder_generator_train.py
import torch
from torch.utils.data import DataLoader, Subset
from torch.optim import AdamW
import torch.nn.functional as F
import torch.nn as nn
from datasets import load_from_disk
import esm
import numpy as np
import math
import os
from transformers import AutoTokenizer
from torch.optim.lr_scheduler import CosineAnnealingLR
from transformers import get_linear_schedule_with_warmup
from tqdm import tqdm
from torch.cuda.amp import autocast, GradScaler
import gc
import pdb
os.environ['CUDA_VISIBLE_DEVICES'] = '1'
##################### Hyper-parameters #############################################
max_epochs = 30
batch_size = 4
lr = 1e-4
num_layers = 4
num_heads = 4
accumulation_steps = 4
checkpoint_path = '/home/tc415/muPPIt_embedding/checkpoints/generator_0'

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Echo the run configuration so it is captured in the training log.
print(f'''
max_epochs = {max_epochs}
batch_size = {batch_size}
lr = {lr}
num_layers = {num_layers}
num_heads = {num_heads}
accumulation_steps = {accumulation_steps}
checkpoint_path = '{checkpoint_path}'
''')
####################################################################################
os.makedirs(checkpoint_path, exist_ok=True)
train_dataset = load_from_disk('/home/tc415/muPPIt_embedding/dataset/train/ppiref_generator')
val_dataset = load_from_disk('/home/tc415/muPPIt_embedding/dataset/val/ppiref_generator')
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
print(f"train size = {len(train_dataset)}, val size = {len(val_dataset)}")
def collate_fn(batch):
    """Pad variable-length binder/target token sequences into a single batch."""
    binders = []
    targets = []
    global tokenizer

    for b in batch:
        # Strip the CLS/EOS tokens added by the ESM tokenizer.
        binder = torch.tensor(b['binder_input_ids']['input_ids'][1:-1])
        target = torch.tensor(b['target_input_ids']['input_ids'][1:-1])

        # Skip malformed or empty sequences.
        if binder.dim() == 0 or binder.numel() == 0 or target.dim() == 0 or target.numel() == 0:
            continue

        binders.append(binder)  # shape: 1*L1 -> L1
        targets.append(target)  # shape: 1*L2 -> L2

    # Collate the tensors using torch's pad_sequence
    try:
        binder_input_ids = torch.nn.utils.rnn.pad_sequence(binders, batch_first=True, padding_value=tokenizer.pad_token_id)
        target_input_ids = torch.nn.utils.rnn.pad_sequence(targets, batch_first=True, padding_value=tokenizer.pad_token_id)
    except Exception:
        pdb.set_trace()

    # Return the collated batch
    return {
        'binder_input_ids': binder_input_ids.long(),
        'target_input_ids': target_input_ids.long(),
    }
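# --- Usage sketch (not part of the original training flow) ---
# A minimal, hedged illustration of the record layout collate_fn assumes: each
# dataset row stores tokenizer output dicts under 'binder_input_ids' and
# 'target_input_ids'. The sequences below are arbitrary placeholders.
def _collate_fn_smoke_test():
    fake_batch = [
        {'binder_input_ids': tokenizer("MKTAYIAKQR"), 'target_input_ids': tokenizer("GSHMLEDP")},
        {'binder_input_ids': tokenizer("ACDEFGHIK"), 'target_input_ids': tokenizer("LMNPQRST")},
    ]
    out = collate_fn(fake_batch)
    # Both tensors are padded to the longest sequence in the batch.
    print(out['binder_input_ids'].shape, out['target_input_ids'].shape)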
def RoPE(x, seq_dim=0):
    """
    Applies Rotary Positional Encoding to the input embeddings.
    :param x: Input tensor (seq_len, batch_size, embed_dim)
    :param seq_dim: The sequence dimension, usually 0 (first dimension in (seq_len, batch_size, embed_dim))
    :return: Tensor with RoPE applied (seq_len, batch_size, embed_dim)
    """
    seq_len = x.shape[seq_dim]
    d_model = x.shape[-1]

    # Create the positions and the sine-cosine rotational matrices
    theta = torch.arange(0, d_model, 2, dtype=torch.float32) / d_model
    theta = 10000 ** (-theta)  # scaling factor for RoPE
    seq_idx = torch.arange(seq_len, dtype=torch.float32).unsqueeze(1)

    # Compute sine and cosine embedding for each position
    sin_emb = torch.sin(seq_idx * theta)
    cos_emb = torch.cos(seq_idx * theta)

    sin_emb = sin_emb.unsqueeze(1)  # [seq_len, 1, embed_dim//2]
    cos_emb = cos_emb.unsqueeze(1)  # [seq_len, 1, embed_dim//2]

    x1, x2 = x[..., ::2], x[..., 1::2]  # Split embedding into even and odd indices

    cos_emb = cos_emb.to(x1.device)
    sin_emb = sin_emb.to(x1.device)

    # Apply rotary transformation
    x_rotated = torch.cat([x1 * cos_emb - x2 * sin_emb, x1 * sin_emb + x2 * cos_emb], dim=-1)
    return x_rotated
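# --- Sanity-check sketch (illustrative only) ---
# RoPE is a pure rotation of each (even, odd) channel pair, so it should preserve
# both the tensor shape and the per-position vector norm. The shapes below are
# chosen arbitrarily for illustration.
def _rope_sanity_check():
    x = torch.randn(16, 2, 1280)  # [seq_len, batch_size, embed_dim]
    y = RoPE(x)
    assert y.shape == x.shape
    assert torch.allclose(x.norm(dim=-1), y.norm(dim=-1), rtol=1e-4, atol=1e-3)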
class BinderGenerator(nn.Module):
    def __init__(self, vocab_size=24, embed_dim=1280, num_heads=4, num_layers=4, lr=1e-4):
        super(BinderGenerator, self).__init__()
        # Frozen ESM-2 650M encoder provides residue-level embeddings.
        self.esm, self.alphabet = esm.pretrained.esm2_t33_650M_UR50D()
        for param in self.esm.parameters():
            param.requires_grad = False

        self.transformer = nn.Transformer(d_model=embed_dim, nhead=num_heads, num_encoder_layers=num_layers, num_decoder_layers=num_layers)
        self.fc_out = nn.Linear(embed_dim, vocab_size)

        self.criterion = nn.CrossEntropyLoss(ignore_index=self.alphabet.padding_idx)
        self.vocab_size = vocab_size
        self.learning_rate = lr

    def forward(self, binder_tokens, target_tokens):
        with torch.no_grad():
            # Zero out embeddings at padding positions; contacts are not needed here.
            binder_pad_mask = (binder_tokens != self.alphabet.padding_idx).int()
            binder_embed = self.esm(binder_tokens, repr_layers=[33], return_contacts=False)["representations"][33] * binder_pad_mask.unsqueeze(-1)

            target_pad_mask = (target_tokens != self.alphabet.padding_idx).int()
            target_embed = self.esm(target_tokens, repr_layers=[33], return_contacts=False)["representations"][33] * target_pad_mask.unsqueeze(-1)

        binder_embed = binder_embed.transpose(0, 1)
        target_embed = target_embed.transpose(0, 1)

        binder_embed = RoPE(binder_embed)  # [src_len, batch_size, embed_dim]
        target_embed = RoPE(target_embed)  # [tgt_len, batch_size, embed_dim]

        # Causal mask so each target position only attends to earlier positions,
        # matching the shifted next-token loss below.
        tgt_mask = self.transformer.generate_square_subsequent_mask(target_embed.size(0)).to(target_embed.device)

        output = self.transformer(binder_embed, target_embed, tgt_mask=tgt_mask)  # [tgt_len, batch_size, embed_dim]
        return self.fc_out(output).transpose(0, 1)  # [batch_size, tgt_len, vocab_size]

    def compute_loss(self, binder_tokens, target_tokens):
        output = self.forward(binder_tokens, target_tokens)
        # Shifted next-token prediction: position i predicts target token i+1.
        loss = self.criterion(output[:, :-1, :].reshape(-1, self.vocab_size), target_tokens[:, 1:].reshape(-1))
        return loss

    def step(self, batch, compute_acc=False):
        binder_tokens = batch['binder_input_ids'].to(device)
        target_tokens = batch['target_input_ids'].to(device)

        if not compute_acc:
            return self.compute_loss(binder_tokens, target_tokens)

        # Validation path: reuse the logits for both loss and token-level accuracy,
        # counting only non-padding target positions.
        output = self.forward(binder_tokens, target_tokens)
        loss = self.criterion(output[:, :-1, :].reshape(-1, self.vocab_size), target_tokens[:, 1:].reshape(-1))

        preds = torch.argmax(output[:, :-1, :], dim=-1)
        mask = target_tokens[:, 1:] != self.alphabet.padding_idx
        correct = ((preds == target_tokens[:, 1:]) & mask).sum().item()
        accuracy = correct / mask.sum().item()
        return loss, accuracy
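# --- Inference sketch (illustrative only; the original file contains no sampling code) ---
# A minimal greedy decoder under the assumption that generation starts from the
# alphabet's CLS token, stops at EOS or max_len, and that output class ids
# 0..vocab_size-1 line up with the ESM alphabet indices. binder_tokens is a
# padded LongTensor on the same device as the model.
@torch.no_grad()
def _greedy_generate(model, binder_tokens, max_len=50):
    model.eval()
    generated = torch.full((binder_tokens.size(0), 1), model.alphabet.cls_idx,
                           dtype=torch.long, device=binder_tokens.device)
    for _ in range(max_len):
        logits = model(binder_tokens, generated)              # [B, cur_len, vocab]
        next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
        generated = torch.cat([generated, next_token], dim=1)
        if (next_token == model.alphabet.eos_idx).all():
            break
    return generated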
def train(model, optimizer, scheduler, cosine_scheduler, train_dataset, val_dataset, batch_size, max_epochs=10, accumulation_steps=4):
    train_loader = DataLoader(train_dataset, batch_size=batch_size, collate_fn=collate_fn, shuffle=True, num_workers=4)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, collate_fn=collate_fn, shuffle=False, num_workers=4)

    max_val_acc = 0
    scaler = GradScaler()  # one scaler for the whole run so the loss scale is not reset every epoch

    for epoch in range(max_epochs):
        print(f"Epoch {epoch + 1}/{max_epochs}")

        model.train()
        running_loss = 0.0
        optimizer.zero_grad()

        for batch_idx, batch in tqdm(enumerate(train_loader), total=len(train_loader)):
            batch = {k: v.cuda(non_blocking=True) for k, v in batch.items()}  # Transfer batch to GPU

            with autocast():
                loss = model.step(batch)

            # Average gradients over the accumulation window.
            scaler.scale(loss / accumulation_steps).backward()

            if (batch_idx + 1) % accumulation_steps == 0:
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()

                # Linear warmup first, then hand over to cosine annealing
                # (warmup_steps is defined at module level below).
                if scheduler.last_epoch < warmup_steps:
                    scheduler.step()
                else:
                    cosine_scheduler.step()

            running_loss += loss.item()

        print(f"Epoch {epoch}: Training Loss = {running_loss / len(train_loader)}")
        gc.collect()
        torch.cuda.empty_cache()

        model.eval()
        val_loss = 0.0
        val_acc = 0.0
        with torch.no_grad():
            for batch in tqdm(val_loader, total=len(val_loader)):
                batch = {k: v.cuda(non_blocking=True) for k, v in batch.items()}
                val_loss_batch, val_acc_batch = model.step(batch, compute_acc=True)
                val_loss += val_loss_batch.item()
                val_acc += val_acc_batch  # already a Python float

        avg_val_acc = val_acc / len(val_loader)
        print(f"Epoch {epoch}: Val Loss = {val_loss / len(val_loader)}\tVal Acc = {avg_val_acc}")

        if avg_val_acc > max_val_acc:
            max_val_acc = avg_val_acc
            torch.save(model.state_dict(), os.path.join(checkpoint_path, f"epoch={epoch}_acc={round(avg_val_acc, 2)}.pt"))
model = BinderGenerator(vocab_size=24, embed_dim=1280, num_heads=num_heads, num_layers=num_layers, lr=lr).to(device)

optimizer = AdamW(model.parameters(), lr=model.learning_rate, betas=(0.9, 0.95), weight_decay=1e-5)

# One optimizer step per `batch_size * accumulation_steps` samples.
total_steps = len(train_dataset) // (batch_size * accumulation_steps) * max_epochs
warmup_steps = int(0.1 * total_steps)

# Linear warmup for the first 10% of steps, then cosine annealing down to 0.1*lr.
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=warmup_steps,
    num_training_steps=total_steps
)
cosine_scheduler = CosineAnnealingLR(optimizer, T_max=total_steps - warmup_steps, eta_min=0.1 * lr)

train(model, optimizer, scheduler, cosine_scheduler, train_dataset, val_dataset, batch_size=batch_size, max_epochs=max_epochs, accumulation_steps=accumulation_steps)
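# --- Checkpoint-loading sketch (illustrative only) ---
# To evaluate a saved checkpoint later, something along these lines should work;
# the exact filename depends on the epoch/accuracy reached during training and is
# a placeholder here.
#
#   model = BinderGenerator(vocab_size=24, embed_dim=1280, num_heads=num_heads, num_layers=num_layers).to(device)
#   state = torch.load(os.path.join(checkpoint_path, 'epoch=0_acc=0.5.pt'), map_location=device)
#   model.load_state_dict(state)
#   model.eval()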