import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim
from torchvision.models import convnext_tiny, ConvNeXt_Tiny_Weights
from tqdm import tqdm
import matplotlib.pyplot as plt
import os
# Dataset path on RunPod
dataset_path = "/workspace/VCR Cleaned"
train_dir = os.path.join(dataset_path, "train")
val_dir = os.path.join(dataset_path, "val")
# Transforms (ImageNet mean/std)
train_transform = transforms.Compose([
    transforms.Resize((512, 512)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225])
])
val_transform = transforms.Compose([
    transforms.Resize((512, 512)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225])
])
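# Note: the train transform applies no augmentation, so the train and val
# pipelines are currently identical. A common addition (an assumption, not
# part of the original script) would be light augmentation on the train
# side only, e.g.:
#   train_transform = transforms.Compose([
#       transforms.Resize((512, 512)),
#       transforms.RandomHorizontalFlip(),
#       transforms.ToTensor(),
#       transforms.Normalize([0.485, 0.456, 0.406],
#                            [0.229, 0.224, 0.225])
#   ])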
# Load datasets
train_dataset = datasets.ImageFolder(train_dir, transform=train_transform)
val_dataset = datasets.ImageFolder(val_dir, transform=val_transform)
# Build the loaders once, rather than re-creating them inside every epoch
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64)
# Verify class mapping
print("\nLabel mapping:", train_dataset.class_to_idx)
print("Number of classes:", len(train_dataset.classes))
# Load model (the `pretrained=True` flag is deprecated in recent torchvision)
model = convnext_tiny(weights=ConvNeXt_Tiny_Weights.IMAGENET1K_V1)
model.classifier[2] = nn.Linear(768, len(train_dataset.classes))  # 768 = convnext_tiny embedding dim
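# A slightly more robust variant (not in the original) reads the width off
# the existing head instead of hard-coding 768, so the same line also works
# for the larger ConvNeXt sizes:
#   in_features = model.classifier[2].in_features
#   model.classifier[2] = nn.Linear(in_features, len(train_dataset.classes))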
# Setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.01)
# Note: T_max=10 while the loop below runs up to 100 epochs, so the cosine
# schedule will cycle rather than anneal once; T_max=100 would give a single
# decay over the full run, if that is the intent.
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10)
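# Optional speed-up (an assumption, not part of the original): automatic
# mixed precision cuts memory and time at 512x512 on recent GPUs. A minimal
# sketch of the train step under AMP (API as of recent PyTorch; older
# versions use torch.cuda.amp.GradScaler):
#   scaler = torch.amp.GradScaler("cuda")
#   ...
#   optimizer.zero_grad()
#   with torch.autocast(device_type="cuda", dtype=torch.float16):
#       outputs = model(images)
#       loss = criterion(outputs, labels)
#   scaler.scale(loss).backward()
#   scaler.step(optimizer)
#   scaler.update()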
# Paths
checkpoint_path = "/workspace/convnext_checkpoint.pth"
best_model_path = "/workspace/convnext_best_model.pth"
final_model_path = "/workspace/convnext_final_model.pth"
# Load checkpoint if available
start_epoch = 0
train_losses = []
val_losses = []
val_accuracies = []
best_acc = 0.0
if os.path.exists(checkpoint_path):
    print("\nLoading checkpoint...")
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    train_losses = checkpoint['train_losses']
    val_losses = checkpoint['val_losses']
    val_accuracies = checkpoint['val_accuracies']
    best_acc = max(val_accuracies) if val_accuracies else 0.0
    start_epoch = checkpoint['epoch']
    print(f"Resumed from epoch {start_epoch}")
else:
    print("\nStarting training from scratch")
# Training loop
for epoch in range(start_epoch, 100):
    model.train()
    train_loss = 0
    for images, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}"):
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
    scheduler.step()
    # Validation
    model.eval()
    val_loss = 0
    correct = 0
    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            val_loss += criterion(outputs, labels).item()
            preds = outputs.argmax(dim=1)
            correct += (preds == labels).sum().item()
    # Metrics: loss.item() is already a per-batch mean, so average over the
    # number of batches, not the dataset size (dividing by len(dataset)
    # understates the loss by a factor of the batch size)
    epoch_train_loss = train_loss / len(train_loader)
    epoch_val_loss = val_loss / len(val_loader)
    epoch_val_acc = correct / len(val_dataset)
    train_losses.append(epoch_train_loss)
    val_losses.append(epoch_val_loss)
    val_accuracies.append(epoch_val_acc)
print(f"\nEpoch {epoch+1}:") | |
print(f" Train Loss: {epoch_train_loss:.4f}") | |
print(f" Val Loss: {epoch_val_loss:.4f}") | |
print(f" Val Acc: {epoch_val_acc:.4f}") | |
    # Save checkpoint
    torch.save({
        'epoch': epoch + 1,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'train_losses': train_losses,
        'val_losses': val_losses,
        'val_accuracies': val_accuracies
    }, checkpoint_path)
    # ✅ Save best model
    if epoch_val_acc > best_acc:
        best_acc = epoch_val_acc
        torch.save(model.state_dict(), best_model_path)
        print(f"✅ Best model saved at epoch {epoch+1} with acc {best_acc:.4f}")
# ✅ Save final model
torch.save(model.state_dict(), final_model_path)
print(f"\n✅ Final model saved to {final_model_path}")
# Plot
plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.plot(train_losses, label='Train Loss')
plt.plot(val_losses, label='Val Loss')
plt.title("Loss")
plt.legend()
plt.subplot(1, 2, 2)
plt.plot(val_accuracies, label='Val Accuracy')
plt.title("Validation Accuracy")
plt.legend()
plt.show()
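# On a headless RunPod pod, plt.show() has nowhere to render; saving the
# figure before (or instead of) plt.show() is the safer option (the path
# below is an assumption):
#   plt.savefig("/workspace/convnext_training_curves.png", dpi=150)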