# cels / train_cross_classifier.py
# Uploaded by alexandraroze
# Commit b265c62: "fixed config"
import argparse
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import wandb
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader
from tqdm import tqdm
from typing_extensions import Optional
from src.dataset import RandomPairDataset
from src.models import CrossAttentionClassifier, VGGLikeEncode
def visualize_attention(attn_heatmap, epoch: int):
    """Render an attention map as a heatmap and log it to Weights & Biases.

    Args:
        attn_heatmap: Image-like array accepted by ``Axes.imshow``.
        epoch: Epoch number shown in the plot title and the image caption.
    """
    fig, ax = plt.subplots(figsize=(6, 6))
    try:
        im = ax.imshow(attn_heatmap, cmap="hot", interpolation="nearest")
        # Use the explicit figure/axes objects instead of pyplot's implicit
        # "current figure" so this stays correct if other figures are open.
        fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
        ax.set_title(f"Attention Heatmap (Flatten 64x64) | Epoch {epoch}")
        wandb.log({"Flatten Attention Heatmap": wandb.Image(fig, caption=f"Flatten 64x64 | Epoch {epoch}")})
    finally:
        # Always release the figure, even when plotting or logging fails;
        # otherwise figures accumulate across epochs.
        plt.close(fig)
def get_data_loaders(
    num_train_samples: int,
    num_val_samples: int,
    batch_size: int,
    num_workers: int = 0,
    shape_params: Optional[dict] = None,
):
    """Create the training and validation data loaders.

    Both loaders wrap a ``RandomPairDataset`` built with the same
    ``shape_params``; only the sample count and the ``train`` flag differ.

    Args:
        num_train_samples: ``num_samples`` for the training dataset.
        num_val_samples: ``num_samples`` for the validation dataset.
        batch_size: Batch size used by both loaders.
        num_workers: Worker process count used by both loaders.
        shape_params: Forwarded unchanged to ``RandomPairDataset``.

    Returns:
        A ``(train_loader, val_loader)`` tuple. Only the training loader
        shuffles.
    """

    def _wrap(dataset, shuffle: bool) -> DataLoader:
        # Shared loader settings live here so the two loaders cannot drift.
        return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)

    # Datasets are created before either loader, training split first.
    train_set = RandomPairDataset(shape_params=shape_params, num_samples=num_train_samples, train=True)
    val_set = RandomPairDataset(shape_params=shape_params, num_samples=num_val_samples, train=False)

    return _wrap(train_set, True), _wrap(val_set, False)
def build_model(
    path_to_encoder: str,
    lr: float,
    weight_decay: float,
    step_size: int,
    gamma: float,
    device: torch.device
):
    """Build the classifier together with its loss, optimizer and LR scheduler.

    Args:
        path_to_encoder: Path to a saved ``state_dict`` for the encoder.
        lr: Learning rate for Adam.
        weight_decay: Weight decay for Adam.
        step_size: ``StepLR`` period, in scheduler steps.
        gamma: ``StepLR`` multiplicative decay factor.
        device: Device the assembled model is moved to.

    Returns:
        A ``(model, criterion, optimizer, scheduler)`` tuple.
    """
    encoder = VGGLikeEncode(in_channels=1, out_channels=128, feature_dim=32, apply_pooling=False)

    # Deserialize onto the CPU so a checkpoint saved on a GPU can still be
    # loaded on a CPU-only host; the whole model is moved to `device` below.
    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    encoder_state = torch.load(path_to_encoder, map_location="cpu")
    encoder.load_state_dict(encoder_state)

    model = CrossAttentionClassifier(encoder=encoder)
    model = model.to(device)

    criterion = nn.BCEWithLogitsLoss()
    # The optimizer is created after the move so it references the
    # parameters that live on `device`.
    optimizer = optim.Adam(
        model.parameters(),
        lr=lr,
        weight_decay=weight_decay
    )
    scheduler = StepLR(optimizer, step_size=step_size, gamma=gamma)
    return model, criterion, optimizer, scheduler
def train_epoch(
    model: nn.Module,
    criterion: nn.Module,
    optimizer: optim.Optimizer,
    train_loader: DataLoader,
    device: torch.device
):
    """Run one optimization pass over ``train_loader``.

    Args:
        model: Called as ``model(img1, img2)``; must return a
            ``(logits, attention_weights)`` pair.
        criterion: Loss called as ``criterion(logits, labels)``.
        optimizer: Optimizer stepped once per batch.
        train_loader: Yields ``(img1, img2, labels)`` batches.
        device: Device every batch is moved to.

    Returns:
        ``(epoch_loss, epoch_acc)``: the sample-weighted mean loss and the
        fraction of predictions (sigmoid > 0.5) equal to the labels.

    Raises:
        ValueError: If the loader yields no samples.
    """
    model.train()
    loss_sum = 0.0
    num_correct = 0
    num_seen = 0

    for img1, img2, labels in tqdm(train_loader, desc="Training", leave=False):
        img1, img2, labels = img1.to(device), img2.to(device), labels.to(device)
        batch_size = labels.size(0)

        optimizer.zero_grad()
        logits, _ = model(img1, img2)  # attention weights are not needed here
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()

        # Weight by batch size so a smaller final batch does not skew the mean.
        loss_sum += loss.item() * batch_size
        preds = (torch.sigmoid(logits) > 0.5).float()
        num_correct += (preds == labels).sum().item()
        num_seen += batch_size

    if num_seen == 0:
        raise ValueError("train_loader yielded no samples; cannot compute epoch metrics")

    # Normalize loss and accuracy by the samples actually processed, so both
    # stay consistent even if the loader drops batches or uses a sampler.
    return loss_sum / num_seen, num_correct / num_seen
@torch.no_grad()
def validate(
    model: nn.Module,
    criterion: nn.Module,
    val_loader: DataLoader,
    device: torch.device
):
    """Evaluate ``model`` on ``val_loader`` with gradients disabled.

    Args:
        model: Called as ``model(img1, img2)``; must return a
            ``(logits, attention_weights)`` pair. Left in eval mode on return.
        criterion: Loss called as ``criterion(logits, labels)``.
        val_loader: Yields ``(img1, img2, labels)`` batches.
        device: Device every batch is moved to.

    Returns:
        ``(epoch_loss, epoch_acc)``: the sample-weighted mean loss and the
        fraction of predictions (sigmoid > 0.5) equal to the labels.

    Raises:
        ValueError: If the loader yields no samples.
    """
    model.eval()
    loss_sum = 0.0
    num_correct = 0
    num_seen = 0

    for img1, img2, labels in tqdm(val_loader, desc="Validation", leave=False):
        img1, img2, labels = img1.to(device), img2.to(device), labels.to(device)
        batch_size = labels.size(0)

        logits, _ = model(img1, img2)  # attention weights are not needed here
        loss = criterion(logits, labels)

        # Weight by batch size so a smaller final batch does not skew the mean.
        loss_sum += loss.item() * batch_size
        preds = (torch.sigmoid(logits) > 0.5).float()
        num_correct += (preds == labels).sum().item()
        num_seen += batch_size

    if num_seen == 0:
        raise ValueError("val_loader yielded no samples; cannot compute epoch metrics")

    # Normalize loss and accuracy by the samples actually processed, so both
    # stay consistent even if the loader drops batches or uses a sampler.
    return loss_sum / num_seen, num_correct / num_seen
def train(
    model: nn.Module,
    criterion: nn.Module,
    optimizer: optim.Optimizer,
    scheduler,
    train_loader: DataLoader,
    val_loader: DataLoader,
    device: torch.device,
    num_epochs: int = 30,
    save_path: str = "best_attention_classifier.pth"
):
    """Train for ``num_epochs`` epochs, keeping the best checkpoint.

    Each epoch runs a training pass and a validation pass, steps the LR
    scheduler, logs metrics to Weights & Biases, saves the model's
    ``state_dict`` to ``save_path`` whenever the validation loss improves,
    and logs attention statistics plus a heatmap for one validation batch.

    Args:
        model: Model to train; called as ``model(img1, img2)`` and expected
            to return ``(logits, attention_weights)``.
        criterion: Loss passed through to the per-epoch routines.
        optimizer: Optimizer passed through to the training routine.
        scheduler: LR scheduler; ``step()`` is called once per epoch.
        train_loader: Loader for the training pass.
        val_loader: Loader for the validation pass and the attention sample.
        device: Device used for all tensors.
        num_epochs: Number of epochs to run.
        save_path: Destination file for the best ``state_dict``.
    """
    best_val_loss = float("inf")

    print("Start training...")
    for epoch in range(num_epochs):
        # One-based epoch number, used for every human-facing label so the
        # console output, the metrics and the heatmap caption all agree.
        epoch_num = epoch + 1
        print(f"Epoch {epoch_num}/{num_epochs}")

        train_loss, train_acc = train_epoch(model, criterion, optimizer, train_loader, device)
        val_loss, val_acc = validate(model, criterion, val_loader, device)
        scheduler.step()

        # Read after scheduler.step(): this is the rate the next epoch will use.
        current_lr = optimizer.param_groups[0]["lr"]

        wandb.log({
            "epoch": epoch_num,
            "train_loss": train_loss,
            "train_acc": train_acc,
            "val_loss": val_loss,
            "val_acc": val_acc,
            "lr": current_lr,
        })

        print(
            f"learning rate: {current_lr:.6f}, "
            f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, "
            f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}"
        )

        # Checkpoint only on a strict improvement of the validation loss.
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), save_path)

        # Inspect the attention weights on the first validation batch.
        with torch.no_grad():
            sample_img1, sample_img2, _ = next(iter(val_loader))
            sample_img1, sample_img2 = sample_img1.to(device), sample_img2.to(device)
            _, sample_attn_weights = model(sample_img1, sample_img2)

            wandb.log({
                "attention_std": sample_attn_weights.std().item(),
                "attention_mean": sample_attn_weights.mean().item(),
            })

            # Only the first sample of the batch is plotted.
            attn_heatmap = sample_attn_weights[0].detach().cpu().numpy()
            visualize_attention(attn_heatmap, epoch_num)
def main(config):
    """Run one complete training job described by ``config``.

    Starts a Weights & Biases run, builds the data loaders and the model
    from the matching ``config`` entries, trains, and closes the run.

    Args:
        config: Mapping providing ``shape_params``, ``num_train_samples``,
            ``num_val_samples``, ``batch_size``, ``path_to_encoder``, ``lr``,
            ``weight_decay``, ``step_size``, ``gamma``, ``num_epochs`` and
            ``save_path``.
    """
    wandb.init(project="cross_attention_classifier", config=config)

    # Each of these config entries maps one-to-one onto a keyword argument.
    loader_keys = ("shape_params", "num_train_samples", "num_val_samples", "batch_size")
    loader_kwargs = {key: config[key] for key in loader_keys}
    train_loader, val_loader = get_data_loaders(**loader_kwargs)

    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    model_keys = ("path_to_encoder", "lr", "weight_decay", "step_size", "gamma")
    model_kwargs = {key: config[key] for key in model_keys}
    model, criterion, optimizer, scheduler = build_model(device=device, **model_kwargs)

    train(
        model,
        criterion,
        optimizer,
        scheduler,
        train_loader,
        val_loader,
        device,
        num_epochs=config["num_epochs"],
        save_path=config["save_path"],
    )

    wandb.finish()
if __name__ == "__main__":
    # Command-line entry point: parse the hyper-parameters into a plain dict
    # and hand it to main().
    import json

    def _parse_shape_params(text: str) -> dict:
        """Parse the --shape_params value, which must be a JSON object."""
        value = json.loads(text)
        if not isinstance(value, dict):
            raise argparse.ArgumentTypeError("shape_params must be a JSON object")
        return value

    parser = argparse.ArgumentParser(description="Train classifier model")
    parser.add_argument("--path_to_encoder", type=str, default="best_byol.pth")
    parser.add_argument("--batch_size", type=int, default=256)
    parser.add_argument("--lr", type=float, default=8e-5)
    parser.add_argument("--weight_decay", type=float, default=1e-4)
    parser.add_argument("--step_size", type=int, default=10)
    parser.add_argument("--gamma", type=float, default=0.1)
    parser.add_argument("--num_epochs", type=int, default=10)
    parser.add_argument("--num_train_samples", type=int, default=10000)
    parser.add_argument("--num_val_samples", type=int, default=2000)
    parser.add_argument("--save_path", type=str, default="best_attention_classifier.pth")
    # Previously "shape_params" was unconditionally set to {} behind a check
    # that could never be false. It is now an optional flag whose default
    # still yields {} (argparse runs string defaults through `type`).
    parser.add_argument(
        "--shape_params",
        type=_parse_shape_params,
        default="{}",
        help="JSON object forwarded to the dataset as shape_params (default: {})",
    )
    args = parser.parse_args()

    # Every parsed option is a config entry; the option names are the keys.
    config = dict(vars(args))

    main(config)