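# Repository page header captured together with the file (not part of the program):
#   uploader: fisherman611 -- commit message: "Upload 3 files"
#   commit: 324b9ef (verified)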
import os
import sys
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torch.utils.data import DataLoader
import time
import wandb
from datetime import datetime
from tqdm.auto import tqdm
from models.can.can import CAN, create_can_model
from models.can.can_dataloader import create_dataloaders_for_can, Vocabulary
import albumentations as A
import cv2
import random
import json
# Load the JSON configuration once at import time; every constant below is
# read from its "can" section.
# NOTE: the path is relative to the current working directory, so the script
# has to be launched from the directory that contains config.json.
with open("config.json", "r", encoding="utf-8") as json_file:
    cfg = json.load(json_file)

CAN_CONFIG = cfg["can"]

# Global constants
BASE_DIR = CAN_CONFIG["base_dir"]
SEED = CAN_CONFIG["seed"]
CHECKPOINT_DIR = CAN_CONFIG["checkpoint_dir"]
# Boolean switches are compared against 1, i.e. the config stores them as 0/1.
PRETRAINED_BACKBONE = CAN_CONFIG["pretrained_backbone"] == 1
BACKBONE_TYPE = CAN_CONFIG["backbone_type"]
# Checkpoints trained with a pretrained backbone get a "p_" file-name prefix.
CHECKPOINT_NAME = (f'p_{BACKBONE_TYPE}_can_best.pth'
                   if PRETRAINED_BACKBONE else
                   f'{BACKBONE_TYPE}_can_best.pth')
BATCH_SIZE = CAN_CONFIG["batch_size"]
HIDDEN_SIZE = CAN_CONFIG["hidden_size"]
EMBEDDING_DIM = CAN_CONFIG["embedding_dim"]
USE_COVERAGE = CAN_CONFIG["use_coverage"] == 1
LAMBDA_COUNT = CAN_CONFIG["lambda_count"]
LR = CAN_CONFIG["lr"]
EPOCHS = CAN_CONFIG["epochs"]
GRAD_CLIP = CAN_CONFIG["grad_clip"]
PRINT_FREQ = CAN_CONFIG["print_freq"]
# T / T_MULT are passed to the scheduler as T_0 / T_mult in main().
T = CAN_CONFIG["t"]
T_MULT = CAN_CONFIG["t_mult"]
# wandb project name; pretrained-backbone runs get a "-pretrained" suffix.
PROJECT_NAME = (f'final-hmer-can-{BACKBONE_TYPE}-pretrained'
                if PRETRAINED_BACKBONE else
                f'final-hmer-can-{BACKBONE_TYPE}')
NUM_WORKERS = CAN_CONFIG["num_workers"]
# Use the GPU when one is available, otherwise fall back to the CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
class RandomMorphology(A.ImageOnlyTransform):
    """Image-only augmentation that randomly erodes or dilates the image.

    Each time the transform is applied, one of the two morphological
    operations is picked with Python's ``random.choice`` and run for a single
    iteration with a square all-ones kernel.

    Args:
        p: Probability of applying the transform, forwarded to the
            albumentations base class.
        kernel_size: Side length of the square structuring element.
    """

    def __init__(self, p=0.5, kernel_size=3):
        # Pass ``p`` by keyword: the positional order of the base-class
        # constructor is not the same in every albumentations release (older
        # ones take ``always_apply`` first), so a positional ``p`` can end up
        # bound to the wrong parameter and the probability would be ignored.
        super().__init__(p=p)
        self.kernel_size = kernel_size

    def apply(self, img, **params):
        """Return ``img`` after one random erosion or dilation pass."""
        op = random.choice(['erode', 'dilate'])
        kernel = np.ones((self.kernel_size, self.kernel_size), np.uint8)
        if op == 'erode':
            return cv2.erode(img, kernel, iterations=1)
        return cv2.dilate(img, kernel, iterations=1)
# Custom transforms for CAN model (grayscale images)
# ToTensorV2 is imported explicitly from its submodule: depending on the
# albumentations version and on what else has already been imported,
# ``A.pytorch`` is not guaranteed to exist as an attribute of the package.
from albumentations.pytorch import ToTensorV2

train_transforms = A.Compose([
    # Small rotations; borders are filled by replicating edge pixels.
    A.Rotate(limit=5, p=0.25, border_mode=cv2.BORDER_REPLICATE),
    A.ElasticTransform(alpha=100,
                       sigma=7,
                       p=0.5,
                       interpolation=cv2.INTER_CUBIC),
    # Random erosion/dilation with a 2x2 kernel.
    RandomMorphology(p=0.5, kernel_size=2),
    A.Normalize(mean=[0.0], std=[1.0]),  # For grayscale
    ToTensorV2(),
])
def train_epoch(model,
                train_loader,
                optimizer,
                device,
                grad_clip=5.0,
                lambda_count=0.01,
                print_freq=10,
                teacher_forcing_ratio=0.5):
    """
    Train the model for one epoch

    Args:
        model: Module called as ``model(images, captions,
            teacher_forcing_ratio=...)`` returning ``(outputs,
            count_vectors)`` and exposing ``calculate_loss(...)`` that
            returns ``(loss, cls_loss, counting_loss)``.
        train_loader: Sized iterable yielding ``(images, captions,
            caption_lengths, count_targets)`` batches.
        optimizer: Optimizer updating ``model``'s parameters.
        device: Device the batch tensors are moved to.
        grad_clip: Maximum gradient norm; a falsy value disables clipping.
        lambda_count: Forwarded to ``model.calculate_loss``.
        print_freq: Print the batch losses every ``print_freq`` batches; a
            falsy value (0 or None) disables the per-batch printout.
        teacher_forcing_ratio: Forwarded to the model's forward pass.

    Returns:
        Tuple ``(total_loss, cls_loss, count_loss)``, each averaged over the
        number of batches.

    Raises:
        ZeroDivisionError: If ``train_loader`` yields no batches.
    """
    model.train()

    total_loss = 0.0
    total_cls_loss = 0.0
    total_count_loss = 0.0
    batch_count = 0
    num_batches = len(train_loader)

    # caption_lengths is part of every batch but is not needed here.
    for i, (images, captions, _, count_targets) in tqdm(
            enumerate(train_loader), total=num_batches):
        batch_count += 1
        images = images.to(device)
        captions = captions.to(device)
        count_targets = count_targets.to(device)

        # Forward pass
        outputs, count_vectors = model(
            images, captions, teacher_forcing_ratio=teacher_forcing_ratio)

        # Calculate loss
        loss, cls_loss, counting_loss = model.calculate_loss(
            outputs=outputs,
            targets=captions,
            count_vectors=count_vectors,
            count_targets=count_targets,
            lambda_count=lambda_count)

        # Backward pass
        optimizer.zero_grad()
        loss.backward()

        # Clip gradients
        if grad_clip:
            nn.utils.clip_grad_norm_(model.parameters(), grad_clip)

        # Update weights
        optimizer.step()

        # Track losses (read each scalar once and reuse it below)
        loss_value = loss.item()
        cls_value = cls_loss.item()
        count_value = counting_loss.item()
        total_loss += loss_value
        total_cls_loss += cls_value
        total_count_loss += count_value

        # Print progress; checking print_freq first avoids a modulo-by-zero
        # when printing has been disabled with print_freq=0.
        if print_freq and i > 0 and i % print_freq == 0:
            print(
                f'Batch {i}/{num_batches}, Loss: {loss_value:.4f}, '
                f'Cls Loss: {cls_value:.4f}, Count Loss: {count_value:.4f}'
            )

    return (total_loss / batch_count, total_cls_loss / batch_count,
            total_count_loss / batch_count)
def validate(model, val_loader, device, lambda_count=0.01):
    """
    Evaluate the model on the validation data

    The model is switched to eval mode and every batch is run without
    gradient tracking and with ``teacher_forcing_ratio=0.0``.

    Args:
        model: Module called as ``model(images, captions,
            teacher_forcing_ratio=...)`` and exposing ``calculate_loss(...)``.
        val_loader: Sized iterable yielding ``(images, captions,
            caption_lengths, count_targets)`` batches.
        device: Device the batch tensors are moved to.
        lambda_count: Forwarded to ``model.calculate_loss``.

    Returns:
        Tuple ``(total_loss, cls_loss, count_loss)``, each averaged over the
        number of batches.

    Raises:
        ZeroDivisionError: If ``val_loader`` yields no batches.
    """
    model.eval()

    # Running sums, in order: total loss, classification loss, counting loss.
    running = [0.0, 0.0, 0.0]
    seen = 0

    with torch.no_grad():
        progress = tqdm(enumerate(val_loader), total=len(val_loader))
        for _, batch in progress:
            imgs, caps, _lengths, counts = batch
            seen += 1

            imgs = imgs.to(device)
            caps = caps.to(device)
            counts = counts.to(device)

            # No teacher forcing in validation
            preds, count_preds = model(imgs, caps, teacher_forcing_ratio=0.0)

            batch_losses = model.calculate_loss(outputs=preds,
                                                targets=caps,
                                                count_vectors=count_preds,
                                                count_targets=counts,
                                                lambda_count=lambda_count)
            batch_total, batch_cls, batch_count_loss = batch_losses

            running[0] += batch_total.item()
            running[1] += batch_cls.item()
            running[2] += batch_count_loss.item()

    mean_total = running[0] / seen
    mean_cls = running[1] / seen
    mean_count = running[2] / seen
    return mean_total, mean_cls, mean_count
def main():
    """Train the CAN model using the module-level configuration constants.

    Seeds the random number generators, builds the dataloaders, the model,
    an Adam optimizer and a cosine-annealing-with-warm-restarts scheduler,
    then runs the train/validate loop for the configured number of epochs.
    Per-epoch metrics are logged to wandb, and a checkpoint is written to the
    checkpoint directory whenever the validation loss improves.
    """
    # Configuration
    dataset_dir = BASE_DIR
    seed = SEED
    checkpoints_dir = CHECKPOINT_DIR
    checkpoint_name = CHECKPOINT_NAME
    batch_size = BATCH_SIZE

    # Model parameters
    hidden_size = HIDDEN_SIZE
    embedding_dim = EMBEDDING_DIM
    use_coverage = USE_COVERAGE
    lambda_count = LAMBDA_COUNT

    # Training parameters
    lr = LR
    epochs = EPOCHS
    grad_clip = GRAD_CLIP
    print_freq = PRINT_FREQ

    # Scheduler parameters
    T_0 = T
    T_mult = T_MULT

    # Set random seeds. Python's own ``random`` module is seeded as well:
    # RandomMorphology in this file draws from it, so leaving it unseeded
    # would make runs non-reproducible.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    # Create checkpoint directory
    os.makedirs(checkpoints_dir, exist_ok=True)

    # Set device
    device = DEVICE
    print(f'Using device: {device}')

    # Create dataloaders
    train_loader, val_loader, test_loader, vocab = create_dataloaders_for_can(
        base_dir=dataset_dir, batch_size=batch_size, num_workers=NUM_WORKERS)

    print(f"Training samples: {len(train_loader.dataset)}")
    print(f"Validation samples: {len(val_loader.dataset)}")
    print(f"Test samples: {len(test_loader.dataset)}")
    print(f"Vocabulary size: {len(vocab)}")

    # Create model
    model = create_can_model(num_classes=len(vocab),
                             hidden_size=hidden_size,
                             embedding_dim=embedding_dim,
                             use_coverage=use_coverage,
                             pretrained_backbone=PRETRAINED_BACKBONE,
                             backbone_type=BACKBONE_TYPE).to(device)

    # Create optimizer
    optimizer = optim.Adam(model.parameters(), lr=lr)

    # Create learning rate scheduler
    scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer,
                                                               T_0=T_0,
                                                               T_mult=T_mult)

    # Initialize wandb; the run is named after the start timestamp.
    run_name = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    wandb.init(project=PROJECT_NAME,
               name=run_name,
               config={
                   'seed': seed,
                   'batch_size': batch_size,
                   'hidden_size': hidden_size,
                   'embedding_dim': embedding_dim,
                   'use_coverage': use_coverage,
                   'lambda_count': lambda_count,
                   'lr': lr,
                   'epochs': epochs,
                   'grad_clip': grad_clip,
                   'T_0': T_0,
                   'T_mult': T_mult
               })

    # Training loop
    best_val_loss = float('inf')
    checkpoint_path = os.path.join(checkpoints_dir, checkpoint_name)

    for epoch in tqdm(range(epochs)):
        # Learning rate in effect for this epoch (read before scheduler.step()).
        curr_lr = scheduler.get_last_lr()[0]

        print(f'Epoch {epoch+1:03}/{epochs:03}')
        t1 = time.time()

        # Train
        train_loss, train_cls_loss, train_count_loss = train_epoch(
            model=model,
            train_loader=train_loader,
            optimizer=optimizer,
            device=device,
            grad_clip=grad_clip,
            lambda_count=lambda_count,
            print_freq=print_freq)

        # Validate
        val_loss, val_cls_loss, val_count_loss = validate(
            model=model,
            val_loader=val_loader,
            device=device,
            lambda_count=lambda_count)

        # Update learning rate
        scheduler.step()
        t2 = time.time()

        # Print stats
        print(
            f'Train - Total Loss: {train_loss:.4f}, Class Loss: {train_cls_loss:.4f}, Count Loss: {train_count_loss:.4f}'
        )
        print(
            f'Val - Total Loss: {val_loss:.4f}, Class Loss: {val_cls_loss:.4f}, Count Loss: {val_count_loss:.4f}'
        )
        print(f'Time: {t2 - t1:.2f}s, Learning Rate: {curr_lr:.6f}')

        # Log metrics to wandb
        wandb.log({
            'train_loss': train_loss,
            'train_cls_loss': train_cls_loss,
            'train_count_loss': train_count_loss,
            'val_loss': val_loss,
            'val_cls_loss': val_cls_loss,
            'val_count_loss': val_count_loss,
            'learning_rate': curr_lr,
            'epoch': epoch
        })

        # Save checkpoint only when the validation loss improves, so the file
        # always holds the best model seen so far.
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            checkpoint = {
                'epoch': epoch,
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'val_loss': val_loss,
                'vocab': vocab
            }
            torch.save(checkpoint, checkpoint_path)
            print('Model saved!')

    print('Training completed!')
# Start training only when this file is executed as a script, not when it is
# imported as a module.
if __name__ == "__main__":
    main()