# NOTE(review): removed a scraped page-status artifact ("Spaces: Sleeping")
# that preceded the code and made this file unparsable as Python.
import argparse | |
import matplotlib.pyplot as plt | |
import torch | |
import torch.nn as nn | |
import torch.optim as optim | |
import wandb | |
from torch.optim.lr_scheduler import StepLR | |
from torch.utils.data import DataLoader | |
from tqdm import tqdm | |
from typing_extensions import Optional | |
from src.dataset import RandomPairDataset | |
from src.models import CrossAttentionClassifier, VGGLikeEncode | |
def visualize_attention(attn_heatmap, epoch: int):
    """Render a single attention heatmap and log it to wandb as an image.

    Args:
        attn_heatmap: 2D array of attention weights to display.
        epoch: current epoch number, embedded in the title/caption.
    """
    figure, axis = plt.subplots(figsize=(6, 6))
    image = axis.imshow(attn_heatmap, cmap="hot", interpolation="nearest")
    plt.colorbar(image, fraction=0.046, pad=0.04)
    plt.title(f"Attention Heatmap (Flatten 64x64) | Epoch {epoch}")
    # Log the rendered figure to wandb, then close it to free matplotlib memory.
    wandb.log({"Flatten Attention Heatmap": wandb.Image(figure, caption=f"Flatten 64x64 | Epoch {epoch}")})
    plt.close(figure)
def get_data_loaders(
    num_train_samples: int,
    num_val_samples: int,
    batch_size: int,
    num_workers: int = 0,
    shape_params: Optional[dict] = None,
):
    """Build the training and validation DataLoaders.

    Args:
        num_train_samples: number of samples in the training dataset.
        num_val_samples: number of samples in the validation dataset.
        batch_size: batch size for both loaders.
        num_workers: DataLoader worker processes (0 = load in main process).
        shape_params: optional parameters forwarded to RandomPairDataset.

    Returns:
        (train_loader, val_loader); only the training loader shuffles.
    """
    def _make_loader(num_samples: int, train: bool) -> DataLoader:
        # `train` selects the dataset split and whether the loader shuffles.
        dataset = RandomPairDataset(
            shape_params=shape_params,
            num_samples=num_samples,
            train=train,
        )
        return DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=train,
            num_workers=num_workers,
        )

    return _make_loader(num_train_samples, True), _make_loader(num_val_samples, False)
def build_model(
    path_to_encoder: str,
    lr: float,
    weight_decay: float,
    step_size: int,
    gamma: float,
    device: torch.device
):
    """Assemble the classifier, loss, optimizer and LR scheduler.

    Loads pretrained encoder weights from `path_to_encoder`, wraps the
    encoder in a CrossAttentionClassifier, and moves the model to `device`.

    Args:
        path_to_encoder: path to the saved encoder state dict.
        lr: Adam learning rate.
        weight_decay: Adam weight decay (L2 penalty).
        step_size: epochs between StepLR decays.
        gamma: multiplicative LR decay factor.
        device: device the model is moved to.

    Returns:
        (model, criterion, optimizer, scheduler)
    """
    encoder = VGGLikeEncode(in_channels=1, out_channels=128, feature_dim=32, apply_pooling=False)
    # map_location prevents a crash when the checkpoint was saved on a GPU
    # but this run executes on CPU (or a different device).
    encoder.load_state_dict(torch.load(path_to_encoder, map_location=device))
    model = CrossAttentionClassifier(encoder=encoder)
    model = model.to(device)
    # BCEWithLogitsLoss: model outputs raw logits, not probabilities.
    criterion = nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(
        model.parameters(),
        lr=lr,
        weight_decay=weight_decay
    )
    scheduler = StepLR(optimizer, step_size=step_size, gamma=gamma)
    return model, criterion, optimizer, scheduler
def train_epoch(
    model: nn.Module,
    criterion: nn.Module,
    optimizer: optim.Optimizer,
    train_loader: DataLoader,
    device: torch.device
):
    """Run one optimization pass over the training loader.

    Returns:
        (mean_loss, accuracy) averaged over the whole training dataset.
    """
    model.train()
    loss_sum = 0.0
    num_correct = 0
    num_seen = 0
    for left, right, targets in tqdm(train_loader, desc="Training", leave=False):
        left = left.to(device)
        right = right.to(device)
        targets = targets.to(device)

        optimizer.zero_grad()
        logits, _ = model(left, right)  # attention weights unused during training
        batch_loss = criterion(logits, targets)
        batch_loss.backward()
        optimizer.step()

        # Weight the running loss by batch size so the mean is per-sample.
        loss_sum += batch_loss.item() * left.size(0)
        predictions = (torch.sigmoid(logits) > 0.5).float()
        num_correct += (predictions == targets).sum().item()
        num_seen += targets.size(0)

    return loss_sum / len(train_loader.dataset), num_correct / num_seen
def validate(
    model: nn.Module,
    criterion: nn.Module,
    val_loader: DataLoader,
    device: torch.device
):
    """Evaluate the model on the validation loader.

    Returns:
        (mean_loss, accuracy) averaged over the whole validation dataset.
    """
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    # no_grad: eval() alone does not disable autograd — without this every
    # validation batch builds a gradient graph, wasting memory and time.
    with torch.no_grad():
        for img1, img2, labels in tqdm(val_loader, desc="Validation", leave=False):
            img1, img2, labels = img1.to(device), img2.to(device), labels.to(device)
            logits, attn_weights = model(img1, img2)
            loss = criterion(logits, labels)
            # Weight by batch size so the final mean is per-sample.
            running_loss += loss.item() * img1.size(0)
            preds = (torch.sigmoid(logits) > 0.5).float()
            correct += (preds == labels).sum().item()
            total += labels.size(0)
    epoch_loss = running_loss / len(val_loader.dataset)
    epoch_acc = correct / total
    return epoch_loss, epoch_acc
def train(
    model: nn.Module,
    criterion: nn.Module,
    optimizer: optim.Optimizer,
    scheduler,
    train_loader: DataLoader,
    val_loader: DataLoader,
    device: torch.device,
    num_epochs: int = 30,
    save_path: str = "best_attention_classifier.pth",
    patience: Optional[int] = None,
):
    """Full training loop: optimize, validate, log to wandb, checkpoint.

    Saves the model state dict to `save_path` whenever validation loss
    improves. After each epoch, logs attention statistics and a heatmap
    for the first sample of one validation batch.

    Args:
        model: classifier returning (logits, attention_weights).
        criterion: loss applied to logits vs labels.
        optimizer: optimizer stepped once per batch.
        scheduler: LR scheduler stepped once per epoch.
        train_loader / val_loader: data loaders.
        device: device inputs are moved to.
        num_epochs: maximum number of epochs.
        save_path: checkpoint path for the best (lowest val loss) model.
        patience: if set, stop early after this many consecutive epochs
            without validation-loss improvement. None (default) preserves
            the original behavior of always running all epochs.
    """
    best_val_loss = float("inf")
    epochs_no_improve = 0
    print("Start training...")
    for epoch in range(num_epochs):
        print(f"Epoch {epoch + 1}/{num_epochs}")
        train_loss, train_acc = train_epoch(model, criterion, optimizer, train_loader, device)
        val_loss, val_acc = validate(model, criterion, val_loader, device)
        scheduler.step()
        wandb.log({
            "epoch": epoch + 1,
            "train_loss": train_loss,
            "train_acc": train_acc,
            "val_loss": val_loss,
            "val_acc": val_acc,
            "lr": optimizer.param_groups[0]["lr"],
        })
        print(
            f"learning rate: {optimizer.param_groups[0]['lr']:.6f}, "
            f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, "
            f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}"
        )
        # Checkpoint on improvement; otherwise count stagnant epochs.
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), save_path)
            epochs_no_improve = 0
        else:
            epochs_no_improve += 1
        # Per-epoch attention diagnostics on one fixed validation batch.
        with torch.no_grad():
            sample_img1, sample_img2, sample_labels = next(iter(val_loader))
            sample_img1, sample_img2 = sample_img1.to(device), sample_img2.to(device)
            _, sample_attn_weights = model(sample_img1, sample_img2)
            wandb.log({
                "attention_std": sample_attn_weights.std().item(),
                "attention_mean": sample_attn_weights.mean().item(),
            })
            attn_heatmap = sample_attn_weights[0].detach().cpu().numpy()
            visualize_attention(attn_heatmap, epoch)
        # Optional early stopping (no-op when patience is None).
        if patience is not None and epochs_no_improve >= patience:
            print(f"Early stopping at epoch {epoch + 1}: no improvement for {patience} epochs")
            break
def main(config):
    """Wire together data, model, and training from a config dict."""
    wandb.init(project="cross_attention_classifier", config=config)

    # Data pipelines.
    train_loader, val_loader = get_data_loaders(
        shape_params=config["shape_params"],
        num_train_samples=config["num_train_samples"],
        num_val_samples=config["num_val_samples"],
        batch_size=config["batch_size"],
    )

    # Prefer GPU when one is available.
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    # Model plus its training companions.
    model, criterion, optimizer, scheduler = build_model(
        path_to_encoder=config["path_to_encoder"],
        lr=config["lr"],
        weight_decay=config["weight_decay"],
        step_size=config["step_size"],
        gamma=config["gamma"],
        device=device,
    )

    train(
        model=model,
        criterion=criterion,
        optimizer=optimizer,
        scheduler=scheduler,
        train_loader=train_loader,
        val_loader=val_loader,
        device=device,
        num_epochs=config["num_epochs"],
        save_path=config["save_path"],
    )
    wandb.finish()
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Train classifier model")
    parser.add_argument("--path_to_encoder", type=str, default="best_byol.pth")
    parser.add_argument("--batch_size", type=int, default=256)
    parser.add_argument("--lr", type=float, default=8e-5)
    parser.add_argument("--weight_decay", type=float, default=1e-4)
    parser.add_argument("--step_size", type=int, default=10)
    parser.add_argument("--gamma", type=float, default=0.1)
    parser.add_argument("--num_epochs", type=int, default=10)
    parser.add_argument("--num_train_samples", type=int, default=10000)
    parser.add_argument("--num_val_samples", type=int, default=2000)
    parser.add_argument("--save_path", type=str, default="best_attention_classifier.pth")
    args = parser.parse_args()

    # Build the run config directly from the parsed CLI namespace instead of
    # hand-copying every field (the copy was drift-prone, and the original
    # `if "shape_params" not in config` guard was always true).
    config = vars(args)
    # shape_params has no CLI flag; an empty dict defers to dataset defaults.
    config.setdefault("shape_params", {})
    main(config)