import os

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.models import convnext_tiny, ConvNeXt_Tiny_Weights
from tqdm import tqdm

# Dataset path on RunPod
dataset_path = "/workspace/VCR Cleaned"
train_dir = os.path.join(dataset_path, "train")
val_dir = os.path.join(dataset_path, "val")

# Transforms (ImageNet normalization statistics)
train_transform = transforms.Compose([
    transforms.Resize((512, 512)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
val_transform = transforms.Compose([
    transforms.Resize((512, 512)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# Load datasets
train_dataset = datasets.ImageFolder(train_dir, transform=train_transform)
val_dataset = datasets.ImageFolder(val_dir, transform=val_transform)

# Build the DataLoaders once, instead of recreating them every epoch
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64)

# Verify class mapping
print("\nLabel mapping:", train_dataset.class_to_idx)
print("Number of classes:", len(train_dataset.classes))

# Load pretrained model and replace the classification head
model = convnext_tiny(weights=ConvNeXt_Tiny_Weights.DEFAULT)
model.classifier[2] = nn.Linear(model.classifier[2].in_features, len(train_dataset.classes))

# Setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.01)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10)

# Paths
checkpoint_path = "/workspace/convnext_checkpoint.pth"
best_model_path = "/workspace/convnext_best_model.pth"
final_model_path = "/workspace/convnext_final_model.pth"

# Load checkpoint if available
start_epoch = 0
train_losses = []
val_losses = []
val_accuracies = []
best_acc = 0.0

if os.path.exists(checkpoint_path):
    print("\nLoading checkpoint...")
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    train_losses = checkpoint['train_losses']
    val_losses = checkpoint['val_losses']
    val_accuracies = checkpoint['val_accuracies']
    best_acc = max(val_accuracies) if val_accuracies else 0.0
    start_epoch = checkpoint['epoch']
    print(f"Resumed from epoch {start_epoch}")
else:
    print("\nStarting training from scratch")

# Training loop
for epoch in range(start_epoch, 100):
    model.train()
    train_loss = 0
    for images, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}"):
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
    scheduler.step()

    # Validation
    model.eval()
    val_loss = 0
    correct = 0
    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            val_loss += criterion(outputs, labels).item()
            preds = outputs.argmax(dim=1)
            correct += (preds == labels).sum().item()

    # Metrics: losses are averaged per batch, accuracy per sample
    epoch_train_loss = train_loss / len(train_loader)
    epoch_val_loss = val_loss / len(val_loader)
    epoch_val_acc = correct / len(val_dataset)
    train_losses.append(epoch_train_loss)
    val_losses.append(epoch_val_loss)
    val_accuracies.append(epoch_val_acc)

    print(f"\nEpoch {epoch+1}:")
    print(f"  Train Loss: {epoch_train_loss:.4f}")
    print(f"  Val Loss: {epoch_val_loss:.4f}")
    print(f"  Val Acc: {epoch_val_acc:.4f}")

    # Save checkpoint (model, optimizer, scheduler, and history) for resuming
    torch.save({
        'epoch': epoch + 1,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'train_losses': train_losses,
        'val_losses': val_losses,
        'val_accuracies': val_accuracies
    }, checkpoint_path)

    # ✅ Save best model
    if epoch_val_acc > best_acc:
        best_acc = epoch_val_acc
        torch.save(model.state_dict(), best_model_path)
        print(f"✅ Best model saved at epoch {epoch+1} with acc {best_acc:.4f}")

# ✅ Save final model
torch.save(model.state_dict(), final_model_path)
print(f"\n✅ Final model saved to {final_model_path}")

# Plot loss and accuracy curves
plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.plot(train_losses, label='Train Loss')
plt.plot(val_losses, label='Val Loss')
plt.title("Loss")
plt.legend()
plt.subplot(1, 2, 2)
plt.plot(val_accuracies, label='Val Accuracy')
plt.title("Validation Accuracy")
plt.legend()
plt.show()
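
# The snippet below is a minimal, optional sketch (not part of the original script)
# showing how the saved best weights could be reloaded for inference. It assumes the
# same class list as during training; `val_transform`, `best_model_path`, `device`,
# and `train_dataset.classes` are reused from the script above.
def load_best_model(num_classes, path=best_model_path):
    # Rebuild the architecture with a matching head, then load the saved weights
    m = convnext_tiny(weights=None)
    m.classifier[2] = nn.Linear(m.classifier[2].in_features, num_classes)
    m.load_state_dict(torch.load(path, map_location=device))
    return m.to(device).eval()

# Example usage on a single image (hypothetical file name):
# from PIL import Image
# img = val_transform(Image.open("example.jpg").convert("RGB")).unsqueeze(0).to(device)
# with torch.no_grad():
#     pred = load_best_model(len(train_dataset.classes))(img).argmax(dim=1)
#     print(train_dataset.classes[pred.item()])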