# color-recognition1 / train_model.py
# Author: Ayesha352 · commit d3186fe (verified)
# "Rename train_model(1).py to train_model.py"
import os

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.models import convnext_tiny
from torchvision.models import ConvNeXt_Tiny_Weights
from tqdm import tqdm
# Dataset path on RunPod
dataset_path = "/workspace/VCR Cleaned"
# One ImageFolder-style directory tree per split, both under dataset_path.
train_dir, val_dir = (os.path.join(dataset_path, split) for split in ("train", "val"))
# Transforms
# Standard ImageNet per-channel mean / std used for normalisation.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]


def _build_transform():
    """Return a fresh resize(512x512) -> tensor -> normalise pipeline."""
    return transforms.Compose([
        transforms.Resize((512, 512)),
        transforms.ToTensor(),
        transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
    ])


# No augmentation is applied to training data, so both splits get an
# identical (but separately constructed) pipeline.
train_transform = _build_transform()
val_transform = _build_transform()
# Load datasets
train_dataset = datasets.ImageFolder(train_dir, transform=train_transform)
val_dataset = datasets.ImageFolder(val_dir, transform=val_transform)
# Verify class mapping
print("\nLabel mapping:", train_dataset.class_to_idx)
print("Number of classes:", len(train_dataset.classes))
# Each split's ImageFolder builds its own class_to_idx from its own directory.
# A val split with a missing, extra or renamed class folder would therefore use
# different label indices than the model was trained with, silently corrupting
# validation loss/accuracy. Fail fast instead.
if val_dataset.class_to_idx != train_dataset.class_to_idx:
    raise ValueError(
        "Train/val class mappings differ:\n"
        f"  train: {train_dataset.class_to_idx}\n"
        f"  val:   {val_dataset.class_to_idx}"
    )
# Load model
# ImageNet-pretrained ConvNeXt-Tiny. `weights=IMAGENET1K_V1` is the explicit
# replacement for the deprecated `pretrained=True` flag (same checkpoint).
model = convnext_tiny(weights=ConvNeXt_Tiny_Weights.IMAGENET1K_V1)
# Replace the final linear layer with one sized to our classes, reusing the
# existing layer's input width rather than hard-coding it.
model.classifier[2] = nn.Linear(model.classifier[2].in_features, len(train_dataset.classes))
# Setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.01)
# scheduler.step() is called once per epoch and training runs to epoch 100, so
# T_max must be 100 for a single cosine decay over the run. With T_max=10 the LR
# bottoms out at epoch 10 and climbs back to 1e-4 by epoch 20, cycling every 20
# epochs. NOTE: resuming from a checkpoint saved with the old value restores
# T_max=10 via scheduler.load_state_dict.
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)
# Paths
# Output artifacts: resumable training checkpoint, weights of the best epoch by
# validation accuracy, and weights after the last epoch.
_OUTPUT_DIR = "/workspace"
checkpoint_path, best_model_path, final_model_path = (
    f"{_OUTPUT_DIR}/convnext_{stem}.pth"
    for stem in ("checkpoint", "best_model", "final_model")
)
# Load checkpoint if available
# Default resume state for a fresh run; overwritten below when a checkpoint exists.
start_epoch = 0
best_acc = 0.0
train_losses, val_losses, val_accuracies = [], [], []

if not os.path.exists(checkpoint_path):
    print("\nStarting training from scratch")
else:
    print("\nLoading checkpoint...")
    checkpoint = torch.load(checkpoint_path, map_location=device)
    # Restore model/optimizer/scheduler state so training continues seamlessly.
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    # Restore metric history; the best accuracy so far is re-derived from it.
    train_losses, val_losses, val_accuracies = (
        checkpoint['train_losses'],
        checkpoint['val_losses'],
        checkpoint['val_accuracies'],
    )
    best_acc = max(val_accuracies, default=0.0)
    # The stored epoch is the number of completed epochs, i.e. the next index to run.
    start_epoch = checkpoint['epoch']
    print(f"Resumed from epoch {start_epoch}")
# Training loop
def _train_one_epoch(model, loader, criterion, optimizer, device, desc):
    """Run one optimisation pass over `loader`.

    Returns the loss summed over samples (not batches), so dividing by the
    dataset size yields the per-sample mean.
    """
    model.train()
    total_loss = 0.0
    for images, labels in tqdm(loader, desc=desc):
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        # criterion uses the default 'mean' reduction, so loss is a per-batch
        # average; weight it by the batch size (the last batch may be smaller).
        total_loss += loss.item() * images.size(0)
    return total_loss


def _evaluate(model, loader, criterion, device):
    """Return (loss summed over samples, number of correct top-1 predictions)."""
    model.eval()
    total_loss = 0.0
    n_correct = 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            total_loss += criterion(outputs, labels).item() * images.size(0)
            n_correct += (outputs.argmax(dim=1) == labels).sum().item()
    return total_loss, n_correct


# Build the loaders once; iterating a shuffle=True DataLoader again draws a new
# order each epoch, so recreating it every epoch is unnecessary.
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64)

# NOTE: loss values stored in checkpoints written by the previous version of
# this loop were divided by the batch size a second time (~64x too small).
for epoch in range(start_epoch, 100):
    train_loss = _train_one_epoch(
        model, train_loader, criterion, optimizer, device, desc=f"Epoch {epoch+1}"
    )
    scheduler.step()
    # Validation
    val_loss, correct = _evaluate(model, val_loader, criterion, device)
    # Metrics
    epoch_train_loss = train_loss / len(train_dataset)
    epoch_val_loss = val_loss / len(val_dataset)
    epoch_val_acc = correct / len(val_dataset)
    train_losses.append(epoch_train_loss)
    val_losses.append(epoch_val_loss)
    val_accuracies.append(epoch_val_acc)
    print(f"\nEpoch {epoch+1}:")
    print(f" Train Loss: {epoch_train_loss:.4f}")
    print(f" Val Loss: {epoch_val_loss:.4f}")
    print(f" Val Acc: {epoch_val_acc:.4f}")
    # Save checkpoint ('epoch' = completed epochs = next start index on resume)
    torch.save({
        'epoch': epoch + 1,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'train_losses': train_losses,
        'val_losses': val_losses,
        'val_accuracies': val_accuracies
    }, checkpoint_path)
    # ✅ Save best model
    if epoch_val_acc > best_acc:
        best_acc = epoch_val_acc
        torch.save(model.state_dict(), best_model_path)
        print(f"✅ Best model saved at epoch {epoch+1} with acc {best_acc:.4f}")

# ✅ Save final model
torch.save(model.state_dict(), final_model_path)
print(f"\n✅ Final model saved to {final_model_path}")
# Plot
# 1-based epoch numbers on the x-axis, matching the "Epoch N" log lines.
epoch_axis = range(1, len(train_losses) + 1)
plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.plot(epoch_axis, train_losses, label='Train Loss')
plt.plot(epoch_axis, val_losses, label='Val Loss')
plt.title("Loss")
plt.xlabel("Epoch")
plt.legend()
plt.subplot(1, 2, 2)
plt.plot(epoch_axis, val_accuracies, label='Val Accuracy')
plt.title("Validation Accuracy")
plt.xlabel("Epoch")
plt.legend()
plt.tight_layout()
# This script targets a remote RunPod box where an interactive window may not
# be available, so also persist the figure next to the model files.
plot_path = "/workspace/convnext_training_curves.png"
plt.savefig(plot_path)
print(f"Training curves saved to {plot_path}")
plt.show()