import torch
from torch.utils.data import DataLoader, Subset
from torch.optim import AdamW
import torch.nn.functional as F
import torch.nn as nn
from datasets import load_from_disk
import esm
import numpy as np
import math
import os
from transformers import AutoTokenizer
from torch.optim.lr_scheduler import CosineAnnealingLR
from transformers import get_linear_schedule_with_warmup
from tqdm import tqdm
from torch.cuda.amp import autocast, GradScaler
import gc
import pdb


os.environ['CUDA_VISIBLE_DEVICES'] = '1'

max_epochs = 30
batch_size = 4
lr = 1e-4
num_layers = 4
num_heads = 4
accumulation_steps = 4
checkpoint_path = '/home/tc415/muPPIt_embedding/checkpoints/generator_0'
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Interpolate the config so the banner always matches the values defined above.
print(f'''
max_epochs = {max_epochs}
batch_size = {batch_size}
lr = {lr}
num_layers = {num_layers}
num_heads = {num_heads}
accumulation_steps = {accumulation_steps}
checkpoint_path = {checkpoint_path}
''')

os.makedirs(checkpoint_path, exist_ok=True)

train_dataset = load_from_disk('/home/tc415/muPPIt_embedding/dataset/train/ppiref_generator')
val_dataset = load_from_disk('/home/tc415/muPPIt_embedding/dataset/val/ppiref_generator')
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
print(len(train_dataset), len(val_dataset))

def collate_fn(batch):
    binders = []
    targets = []

    for b in batch:
        # Drop the BOS/EOS special tokens added by the ESM tokenizer.
        binder = torch.tensor(b['binder_input_ids']['input_ids'][1:-1])
        target = torch.tensor(b['target_input_ids']['input_ids'][1:-1])

        # Skip malformed or empty sequences.
        if binder.dim() == 0 or binder.numel() == 0 or target.dim() == 0 or target.numel() == 0:
            continue
        binders.append(binder)
        targets.append(target)

    try:
        binder_input_ids = torch.nn.utils.rnn.pad_sequence(binders, batch_first=True, padding_value=tokenizer.pad_token_id)
        target_input_ids = torch.nn.utils.rnn.pad_sequence(targets, batch_first=True, padding_value=tokenizer.pad_token_id)
    except Exception:
        # Padding only fails if every sequence in the batch was skipped above;
        # drop into the debugger to inspect the offending batch.
        pdb.set_trace()

    return {
        'binder_input_ids': binder_input_ids.long(),
        'target_input_ids': target_input_ids.long(),
    }

def RoPE(x, seq_dim=0):
    """
    Applies Rotary Positional Encoding to the input embeddings.

    :param x: Input tensor (seq_len, batch_size, embed_dim)
    :param seq_dim: The sequence dimension, usually 0 (first dimension in (seq_len, batch_size, embed_dim))
    :return: Tensor with RoPE applied (seq_len, batch_size, embed_dim)
    """
    seq_len = x.shape[seq_dim]
    d_model = x.shape[-1]

    # Per-pair rotation frequencies: theta_i = 10000^(-2i / d_model).
    theta = torch.arange(0, d_model, 2, dtype=torch.float32) / d_model
    theta = 10000 ** (-theta)
    seq_idx = torch.arange(seq_len, dtype=torch.float32).unsqueeze(1)

    sin_emb = torch.sin(seq_idx * theta)
    cos_emb = torch.cos(seq_idx * theta)

    # Broadcast over the batch dimension: (seq_len, 1, d_model // 2).
    sin_emb = sin_emb.unsqueeze(1)
    cos_emb = cos_emb.unsqueeze(1)

    # Split channels into (even, odd) pairs, rotate each pair by a position-dependent
    # angle, and concatenate the two rotated halves along the feature dimension.
    x1, x2 = x[..., ::2], x[..., 1::2]

    cos_emb = cos_emb.to(x1.device)
    sin_emb = sin_emb.to(x1.device)

    x_rotated = torch.cat([x1 * cos_emb - x2 * sin_emb, x1 * sin_emb + x2 * cos_emb], dim=-1)
    return x_rotated

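# A small sanity-check sketch for RoPE (hypothetical helper, not called anywhere in this
# script): rotating each (even, odd) channel pair should leave the per-position embedding
# norm unchanged and preserve the tensor shape. The shapes below are illustrative and
# assume the (seq_len, batch, embed_dim) layout used in BinderGenerator.forward.
def _rope_norm_check(seq_len=12, batch=2, embed_dim=1280):
    x = torch.randn(seq_len, batch, embed_dim)
    x_rot = RoPE(x)
    assert x_rot.shape == x.shape
    assert torch.allclose(x.norm(dim=-1), x_rot.norm(dim=-1), atol=1e-3)
    return True
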
class BinderGenerator(nn.Module):
    def __init__(self, vocab_size=24, embed_dim=1280, num_heads=4, num_layers=4, lr=1e-4):
        super(BinderGenerator, self).__init__()
        # Frozen ESM-2 encoder, used only to produce residue embeddings.
        self.esm, self.alphabet = esm.pretrained.esm2_t33_650M_UR50D()
        for param in self.esm.parameters():
            param.requires_grad = False

        self.transformer = nn.Transformer(d_model=embed_dim, nhead=num_heads,
                                          num_encoder_layers=num_layers, num_decoder_layers=num_layers)
        self.fc_out = nn.Linear(embed_dim, vocab_size)

        self.criterion = nn.CrossEntropyLoss(ignore_index=self.alphabet.padding_idx)
        self.vocab_size = vocab_size
        self.learning_rate = lr

    def forward(self, binder_tokens, target_tokens):
        with torch.no_grad():
            # Frozen ESM-2 embeddings; padding positions are zeroed out, and contacts
            # are not used, so return_contacts is disabled to save compute.
            binder_pad_mask = (binder_tokens != self.alphabet.padding_idx).int()
            binder_embed = self.esm(binder_tokens, repr_layers=[33], return_contacts=False)["representations"][33] * binder_pad_mask.unsqueeze(-1)

            target_pad_mask = (target_tokens != self.alphabet.padding_idx).int()
            target_embed = self.esm(target_tokens, repr_layers=[33], return_contacts=False)["representations"][33] * target_pad_mask.unsqueeze(-1)

        # nn.Transformer expects (seq_len, batch, embed_dim) by default.
        binder_embed = binder_embed.transpose(0, 1)
        target_embed = target_embed.transpose(0, 1)

        binder_embed = RoPE(binder_embed)
        target_embed = RoPE(target_embed)

        # Causal mask so each decoder position only attends to earlier target tokens,
        # as required by the shifted next-token objective in compute_loss.
        tgt_mask = self.transformer.generate_square_subsequent_mask(target_embed.size(0)).to(target_embed.device)

        output = self.transformer(binder_embed, target_embed, tgt_mask=tgt_mask)
        return self.fc_out(output).transpose(0, 1)

    def compute_loss(self, binder_tokens, target_tokens):
        output = self.forward(binder_tokens, target_tokens)

        # Position t predicts target token t + 1; padding is ignored via ignore_index.
        loss = self.criterion(output[:, :-1, :].reshape(-1, self.vocab_size),
                              target_tokens[:, 1:].reshape(-1))

        return loss

    def step(self, batch, compute_acc=False):
        binder_tokens = batch['binder_input_ids'].to(device)
        target_tokens = batch['target_input_ids'].to(device)

        if not compute_acc:
            return self.compute_loss(binder_tokens, target_tokens)

        output = self.forward(binder_tokens, target_tokens)
        loss = self.criterion(output[:, :-1, :].reshape(-1, self.vocab_size),
                              target_tokens[:, 1:].reshape(-1))

        # Token-level accuracy over the non-padding positions of the shifted target.
        preds = torch.argmax(output[:, :-1, :], dim=-1)
        mask = target_tokens[:, 1:] != self.alphabet.padding_idx
        correct = ((preds == target_tokens[:, 1:]) & mask).sum().item()
        accuracy = correct / mask.sum().item()
        return loss, accuracy

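# Greedy-decoding sketch (hypothetical helper, not called anywhere in this script) showing
# how the trained model could be used at inference time. Per compute_loss above, the
# decoder-side stream is 'target_input_ids', generated conditioned on the binder tokens.
# The seed token (alphabet.cls_idx), stop condition (alphabet.eos_idx) and max_len are
# assumptions, since collate_fn strips the special tokens during training.
@torch.no_grad()
def greedy_decode(model, binder_tokens, max_len=128):
    model.eval()
    binder_tokens = binder_tokens.to(device)
    # Seed the decoder stream with one start token per sequence in the batch.
    generated = torch.full((binder_tokens.size(0), 1), model.alphabet.cls_idx,
                           dtype=torch.long, device=device)
    for _ in range(max_len):
        logits = model(binder_tokens, generated)          # (batch, cur_len, vocab_size)
        next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
        generated = torch.cat([generated, next_token], dim=1)
        if (next_token == model.alphabet.eos_idx).all():  # stop once every sequence ended
            break
    return generated
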
def train(model, optimizer, scheduler, cosine_scheduler, train_dataset, val_dataset, batch_size, max_epochs=10, accumulation_steps=4):
    train_loader = DataLoader(train_dataset, batch_size=batch_size, collate_fn=collate_fn, shuffle=True, num_workers=4)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, collate_fn=collate_fn, shuffle=False, num_workers=4)

    # Single GradScaler for the whole run so the loss scale is not reset every epoch.
    scaler = GradScaler()

    max_val_acc = 0
    for epoch in range(max_epochs):
        print(f"Epoch {epoch + 1}/{max_epochs}")

        model.train()
        running_loss = 0.0
        optimizer.zero_grad()

        for batch_idx, batch in tqdm(enumerate(train_loader), total=len(train_loader)):
            batch = {k: v.cuda(non_blocking=True) for k, v in batch.items()}

            with autocast():
                loss = model.step(batch)

            scaler.scale(loss).backward()

            if (batch_idx + 1) % accumulation_steps == 0:
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()

                # Linear warmup first, then hand over to cosine annealing
                # (warmup_steps is defined at module scope below).
                if scheduler.last_epoch < warmup_steps:
                    scheduler.step()
                else:
                    cosine_scheduler.step()

            running_loss += loss.item()

        print(f"Epoch {epoch}: Training Loss = {running_loss / len(train_loader)}")

        # Keep train_loader alive across epochs; only drop per-epoch state.
        del running_loss
        gc.collect()
        torch.cuda.empty_cache()

        model.eval()
        val_loss = 0.0
        val_acc = 0.0
        with torch.no_grad():
            for batch in tqdm(val_loader, total=len(val_loader)):
                batch = {k: v.cuda(non_blocking=True) for k, v in batch.items()}
                val_loss_batch, val_acc_batch = model.step(batch, compute_acc=True)
                val_loss += val_loss_batch.item()
                # step() returns accuracy as a plain float, not a tensor.
                val_acc += val_acc_batch

        print(f"Epoch {epoch}: Val Loss = {val_loss / len(val_loader)}\tVal Acc = {val_acc / len(val_loader)}")

        if val_acc > max_val_acc:
            max_val_acc = val_acc
            torch.save(model.state_dict(), os.path.join(checkpoint_path, f"epoch={epoch}_acc={round(val_acc / len(val_loader), 2)}"))

model = BinderGenerator(vocab_size=24, embed_dim=1280, num_heads=num_heads, num_layers=num_layers, lr=lr).to(device)
optimizer = AdamW(model.parameters(), lr=model.learning_rate, betas=(0.9, 0.95), weight_decay=1e-5)

# Total number of optimizer steps (gradient accumulation reduces the steps per epoch).
total_steps = len(train_dataset) // (batch_size * accumulation_steps) * max_epochs
warmup_steps = int(0.1 * total_steps)

scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=warmup_steps,
    num_training_steps=total_steps
)
cosine_scheduler = CosineAnnealingLR(optimizer, T_max=total_steps - warmup_steps, eta_min=0.1 * lr)

train(model, optimizer, scheduler, cosine_scheduler, train_dataset, val_dataset,
      batch_size=batch_size, max_epochs=max_epochs, accumulation_steps=accumulation_steps)