import torch
from torch.utils.data import DataLoader
from torch.optim import AdamW
import torch.nn as nn
from datasets import load_from_disk
import esm
import os
import gc
import pdb
from transformers import AutoTokenizer, get_linear_schedule_with_warmup
from torch.optim.lr_scheduler import CosineAnnealingLR
from tqdm import tqdm
from torch.cuda.amp import autocast, GradScaler

os.environ['CUDA_VISIBLE_DEVICES'] = '1'

##################### Hyper-parameters #############################################
max_epochs = 30
batch_size = 4
lr = 1e-4
num_layers = 4
num_heads = 4
accumulation_steps = 4
checkpoint_path = '/home/tc415/muPPIt_embedding/checkpoints/generator_0'
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

print(f'''
max_epochs = {max_epochs}
batch_size = {batch_size}
lr = {lr}
num_layers = {num_layers}
num_heads = {num_heads}
accumulation_steps = {accumulation_steps}
checkpoint_path = {checkpoint_path}
''')
####################################################################################

os.makedirs(checkpoint_path, exist_ok=True)

train_dataset = load_from_disk('/home/tc415/muPPIt_embedding/dataset/train/ppiref_generator')
val_dataset = load_from_disk('/home/tc415/muPPIt_embedding/dataset/val/ppiref_generator')

tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")

print(len(train_dataset), len(val_dataset))


def collate_fn(batch):
    """Pad variable-length binder/target token sequences into a single batch."""
    binders = []
    targets = []

    for b in batch:
        # Strip the BOS/EOS tokens added by the tokenizer.
        binder = torch.tensor(b['binder_input_ids']['input_ids'][1:-1])
        target = torch.tensor(b['target_input_ids']['input_ids'][1:-1])

        # Skip empty or scalar entries.
        if binder.dim() == 0 or binder.numel() == 0 or target.dim() == 0 or target.numel() == 0:
            continue

        binders.append(binder)  # shape: L1
        targets.append(target)  # shape: L2

    # Pad every sequence in the batch to the longest one.
    try:
        binder_input_ids = torch.nn.utils.rnn.pad_sequence(binders, batch_first=True,
                                                           padding_value=tokenizer.pad_token_id)
        target_input_ids = torch.nn.utils.rnn.pad_sequence(targets, batch_first=True,
                                                           padding_value=tokenizer.pad_token_id)
    except Exception:
        pdb.set_trace()  # debugging aid: drop into the debugger if padding fails

    return {
        'binder_input_ids': binder_input_ids.long(),
        'target_input_ids': target_input_ids.long(),
    }


def RoPE(x, seq_dim=0):
    """
    Applies Rotary Positional Encoding (RoPE) to the input embeddings.

    :param x: Input tensor of shape (seq_len, batch_size, embed_dim)
    :param seq_dim: The sequence dimension, usually 0 in (seq_len, batch_size, embed_dim)
    :return: Tensor with RoPE applied, same shape as x
    """
    seq_len = x.shape[seq_dim]
    d_model = x.shape[-1]

    # Inverse frequencies: theta_i = 10000^(-2i / d_model).
    theta = torch.arange(0, d_model, 2, dtype=torch.float32) / d_model
    theta = 10000 ** (-theta)
    seq_idx = torch.arange(seq_len, dtype=torch.float32).unsqueeze(1)

    # Sine/cosine tables for each position.
    sin_emb = torch.sin(seq_idx * theta)
    cos_emb = torch.cos(seq_idx * theta)

    sin_emb = sin_emb.unsqueeze(1)  # [seq_len, 1, embed_dim//2]
    cos_emb = cos_emb.unsqueeze(1)  # [seq_len, 1, embed_dim//2]

    # Split the embedding into even- and odd-indexed channels.
    x1, x2 = x[..., ::2], x[..., 1::2]

    cos_emb = cos_emb.to(x1.device)
    sin_emb = sin_emb.to(x1.device)

    # Apply the rotation.
    x_rotated = torch.cat([x1 * cos_emb - x2 * sin_emb,
                           x1 * sin_emb + x2 * cos_emb], dim=-1)
    return x_rotated
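
# Optional sanity check for RoPE (a minimal sketch, not called anywhere in the
# training run): it confirms that the rotation preserves tensor shape and
# per-position vector norms, which is the property rotary encodings rely on.
# The sizes below are arbitrary placeholders.
def _rope_sanity_check():
    x = torch.randn(7, 2, 1280)           # (seq_len, batch_size, embed_dim)
    y = RoPE(x)
    assert y.shape == x.shape              # RoPE never changes the shape
    # Rotations are norm-preserving, so per-token L2 norms should match.
    assert torch.allclose(x.norm(dim=-1), y.norm(dim=-1), atol=1e-4)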
class BinderGenerator(nn.Module):
    def __init__(self, vocab_size=24, embed_dim=1280, num_heads=4, num_layers=4, lr=1e-4):
        super(BinderGenerator, self).__init__()
        # Frozen ESM-2 650M encoder, used only to embed the input sequences.
        self.esm, self.alphabet = esm.pretrained.esm2_t33_650M_UR50D()
        for param in self.esm.parameters():
            param.requires_grad = False

        self.transformer = nn.Transformer(d_model=embed_dim, nhead=num_heads,
                                          num_encoder_layers=num_layers, num_decoder_layers=num_layers)
        self.fc_out = nn.Linear(embed_dim, vocab_size)

        self.criterion = nn.CrossEntropyLoss(ignore_index=self.alphabet.padding_idx)
        self.vocab_size = vocab_size
        self.learning_rate = lr

    def forward(self, binder_tokens, target_tokens):
        with torch.no_grad():
            # Embed both chains with the frozen ESM encoder and zero out padding positions.
            binder_pad_mask = (binder_tokens != self.alphabet.padding_idx).int()
            binder_embed = self.esm(binder_tokens, repr_layers=[33])["representations"][33] * binder_pad_mask.unsqueeze(-1)

            target_pad_mask = (target_tokens != self.alphabet.padding_idx).int()
            target_embed = self.esm(target_tokens, repr_layers=[33])["representations"][33] * target_pad_mask.unsqueeze(-1)

        # nn.Transformer expects (seq_len, batch_size, embed_dim).
        binder_embed = binder_embed.transpose(0, 1)
        target_embed = target_embed.transpose(0, 1)

        binder_embed = RoPE(binder_embed)  # [src_len, batch_size, embed_dim]
        target_embed = RoPE(target_embed)  # [tgt_len, batch_size, embed_dim]

        # Causal mask so each decoder position can only attend to earlier
        # target positions (required for the next-token training objective).
        tgt_mask = self.transformer.generate_square_subsequent_mask(target_embed.size(0)).to(target_embed.device)

        output = self.transformer(binder_embed, target_embed, tgt_mask=tgt_mask)  # [tgt_len, batch_size, embed_dim]
        return self.fc_out(output).transpose(0, 1)                                # [batch_size, tgt_len, vocab_size]

    def compute_loss(self, binder_tokens, target_tokens):
        output = self.forward(binder_tokens, target_tokens)
        # Shift by one: logits at position t predict the token at position t + 1.
        loss = self.criterion(output[:, :-1, :].reshape(-1, self.vocab_size),
                              target_tokens[:, 1:].reshape(-1))
        return loss

    def step(self, batch, compute_acc=False):
        binder_tokens = batch['binder_input_ids'].to(device)
        target_tokens = batch['target_input_ids'].to(device)

        if not compute_acc:
            return self.compute_loss(binder_tokens, target_tokens)

        # Compute the logits once so loss and accuracy share the same forward pass.
        output = self.forward(binder_tokens, target_tokens)  # [batch_size, tgt_len, vocab_size]
        loss = self.criterion(output[:, :-1, :].reshape(-1, self.vocab_size),
                              target_tokens[:, 1:].reshape(-1))

        # Token-level accuracy over non-padding positions of the shifted targets.
        preds = torch.argmax(output[:, :-1, :], dim=-1)
        mask = target_tokens[:, 1:] != self.alphabet.padding_idx
        correct = ((preds == target_tokens[:, 1:]) & mask).sum().item()
        accuracy = correct / mask.sum().item()
        return loss, accuracy
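
# Illustrative only (a sketch, never called by the training script): run a
# dummy batch through an already-constructed BinderGenerator to confirm the
# expected output shape. Token ids 4-23 are the standard amino acids in the
# ESM alphabet; the batch size and lengths (2, 50, 40) are arbitrary
# placeholders. E.g. call _forward_shape_example(model) after the model is
# built below.
def _forward_shape_example(mdl):
    dummy_binder = torch.randint(4, 24, (2, 50), device=device)  # encoder (src) tokens
    dummy_target = torch.randint(4, 24, (2, 40), device=device)  # decoder (tgt) tokens
    mdl.eval()
    with torch.no_grad():
        logits = mdl(dummy_binder, dummy_target)
    print(logits.shape)  # expected: torch.Size([2, 40, 24]) = [batch, tgt_len, vocab]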
def train(model, optimizer, scheduler, cosine_scheduler, train_dataset, val_dataset,
          batch_size, max_epochs=10, accumulation_steps=4):
    train_loader = DataLoader(train_dataset, batch_size=batch_size,
                              collate_fn=collate_fn, shuffle=True, num_workers=4)
    val_loader = DataLoader(val_dataset, batch_size=batch_size,
                            collate_fn=collate_fn, shuffle=False, num_workers=4)

    max_val_acc = 0
    scaler = GradScaler()  # one scaler for the whole run, not one per epoch

    for epoch in range(max_epochs):
        print(f"Epoch {epoch + 1}/{max_epochs}")

        model.train()
        running_loss = 0.0
        optimizer.zero_grad()

        for batch_idx, batch in tqdm(enumerate(train_loader), total=len(train_loader)):
            # Transfer the batch to the GPU.
            batch = {k: v.cuda(non_blocking=True) for k, v in batch.items()}

            with autocast():
                # Normalize by accumulation_steps so the accumulated gradient
                # matches a single step on the effective batch.
                loss = model.step(batch) / accumulation_steps

            scaler.scale(loss).backward()

            if (batch_idx + 1) % accumulation_steps == 0:
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()

                # Linear warmup first, then hand over to cosine annealing.
                # warmup_steps is defined at module scope below, before train() is called.
                if scheduler.last_epoch < warmup_steps:
                    scheduler.step()
                else:
                    cosine_scheduler.step()

            running_loss += loss.item() * accumulation_steps  # undo the normalization for logging

        print(f"Epoch {epoch}: Training Loss = {running_loss / len(train_loader)}")

        gc.collect()
        torch.cuda.empty_cache()

        model.eval()
        val_loss = 0.0
        val_acc = 0.0
        with torch.no_grad():
            for batch in tqdm(val_loader, total=len(val_loader)):
                batch = {k: v.cuda(non_blocking=True) for k, v in batch.items()}
                val_loss_batch, val_acc_batch = model.step(batch, compute_acc=True)
                val_loss += val_loss_batch.item()
                val_acc += val_acc_batch  # accuracy is already a Python float

        # Average the per-batch metrics over the number of validation batches.
        val_loss /= len(val_loader)
        val_acc /= len(val_loader)
        print(f"Epoch {epoch}: Val Loss = {val_loss}\tVal Acc = {val_acc}")

        if val_acc > max_val_acc:
            max_val_acc = val_acc
            torch.save(model.state_dict(),
                       os.path.join(checkpoint_path, f"epoch={epoch}_acc={round(val_acc, 2)}.pt"))


model = BinderGenerator(vocab_size=24, embed_dim=1280, num_heads=num_heads, num_layers=num_layers, lr=lr).to(device)

optimizer = AdamW(model.parameters(), lr=model.learning_rate, betas=(0.9, 0.95), weight_decay=1e-5)

# One optimizer step is taken every accumulation_steps batches.
total_steps = len(train_dataset) // (batch_size * accumulation_steps) * max_epochs
warmup_steps = int(0.1 * total_steps)

scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=warmup_steps,
    num_training_steps=total_steps
)
cosine_scheduler = CosineAnnealingLR(optimizer, T_max=total_steps - warmup_steps, eta_min=0.1 * lr)

train(model, optimizer, scheduler, cosine_scheduler, train_dataset, val_dataset,
      batch_size=batch_size, max_epochs=max_epochs, accumulation_steps=accumulation_steps)
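
# Optional: a minimal sketch (not called anywhere above) of how a checkpoint
# written by train() could be reloaded for later evaluation or generation.
# The filename below is a hypothetical example -- substitute whichever file
# train() actually saved under checkpoint_path.
def _load_checkpoint_example():
    ckpt_file = os.path.join(checkpoint_path, "epoch=0_acc=0.5.pt")  # hypothetical name
    mdl = BinderGenerator(vocab_size=24, embed_dim=1280,
                          num_heads=num_heads, num_layers=num_layers, lr=lr)
    mdl.load_state_dict(torch.load(ckpt_file, map_location=device))
    return mdl.to(device).eval()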