Spaces:
Build error
Build error
Upload train_model(1).py
Browse files- train_model(1).py +153 -0
train_model(1).py
ADDED
@@ -0,0 +1,153 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torchvision import datasets, transforms
|
3 |
+
from torch.utils.data import DataLoader
|
4 |
+
import torch.nn as nn
|
5 |
+
import torch.optim as optim
|
6 |
+
from torchvision.models import convnext_tiny
|
7 |
+
from tqdm import tqdm
|
8 |
+
import matplotlib.pyplot as plt
|
9 |
+
import os
|
10 |
+
|
11 |
+
# Dataset path on RunPod
|
12 |
+
dataset_path = "/workspace/VCR Cleaned"
|
13 |
+
train_dir = os.path.join(dataset_path, "train")
|
14 |
+
val_dir = os.path.join(dataset_path, "val")
|
15 |
+
|
16 |
+
# Transforms
|
17 |
+
train_transform = transforms.Compose([
|
18 |
+
transforms.Resize((512, 512)),
|
19 |
+
transforms.ToTensor(),
|
20 |
+
transforms.Normalize([0.485, 0.456, 0.406],
|
21 |
+
[0.229, 0.224, 0.225])
|
22 |
+
])
|
23 |
+
|
24 |
+
val_transform = transforms.Compose([
|
25 |
+
transforms.Resize((512, 512)),
|
26 |
+
transforms.ToTensor(),
|
27 |
+
transforms.Normalize([0.485, 0.456, 0.406],
|
28 |
+
[0.229, 0.224, 0.225])
|
29 |
+
])
|
30 |
+
|
31 |
+
# Load datasets
|
32 |
+
train_dataset = datasets.ImageFolder(train_dir, transform=train_transform)
|
33 |
+
val_dataset = datasets.ImageFolder(val_dir, transform=val_transform)
|
34 |
+
|
35 |
+
# Verify class mapping
|
36 |
+
print("\nLabel mapping:", train_dataset.class_to_idx)
|
37 |
+
print("Number of classes:", len(train_dataset.classes))
|
38 |
+
|
39 |
+
# Load model
|
40 |
+
model = convnext_tiny(pretrained=True)
|
41 |
+
model.classifier[2] = nn.Linear(768, len(train_dataset.classes))
|
42 |
+
|
43 |
+
# Setup
|
44 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
45 |
+
model = model.to(device)
|
46 |
+
criterion = nn.CrossEntropyLoss()
|
47 |
+
optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.01)
|
48 |
+
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10)
|
49 |
+
|
50 |
+
# Paths
|
51 |
+
checkpoint_path = "/workspace/convnext_checkpoint.pth"
|
52 |
+
best_model_path = "/workspace/convnext_best_model.pth"
|
53 |
+
final_model_path = "/workspace/convnext_final_model.pth"
|
54 |
+
|
55 |
+
# Load checkpoint if available
|
56 |
+
start_epoch = 0
|
57 |
+
train_losses = []
|
58 |
+
val_losses = []
|
59 |
+
val_accuracies = []
|
60 |
+
best_acc = 0.0
|
61 |
+
|
62 |
+
if os.path.exists(checkpoint_path):
|
63 |
+
print("\nLoading checkpoint...")
|
64 |
+
checkpoint = torch.load(checkpoint_path, map_location=device)
|
65 |
+
model.load_state_dict(checkpoint['model_state_dict'])
|
66 |
+
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
|
67 |
+
scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
|
68 |
+
train_losses = checkpoint['train_losses']
|
69 |
+
val_losses = checkpoint['val_losses']
|
70 |
+
val_accuracies = checkpoint['val_accuracies']
|
71 |
+
best_acc = max(val_accuracies) if val_accuracies else 0.0
|
72 |
+
start_epoch = checkpoint['epoch']
|
73 |
+
print(f"Resumed from epoch {start_epoch}")
|
74 |
+
else:
|
75 |
+
print("\nStarting training from scratch")
|
76 |
+
|
77 |
+
# Training loop
|
78 |
+
for epoch in range(start_epoch, 100):
|
79 |
+
model.train()
|
80 |
+
train_loss = 0
|
81 |
+
|
82 |
+
for images, labels in tqdm(DataLoader(train_dataset, batch_size=64, shuffle=True), desc=f"Epoch {epoch+1}"):
|
83 |
+
images, labels = images.to(device), labels.to(device)
|
84 |
+
|
85 |
+
optimizer.zero_grad()
|
86 |
+
outputs = model(images)
|
87 |
+
loss = criterion(outputs, labels)
|
88 |
+
loss.backward()
|
89 |
+
optimizer.step()
|
90 |
+
train_loss += loss.item()
|
91 |
+
|
92 |
+
scheduler.step()
|
93 |
+
|
94 |
+
# Validation
|
95 |
+
model.eval()
|
96 |
+
val_loss = 0
|
97 |
+
correct = 0
|
98 |
+
with torch.no_grad():
|
99 |
+
for images, labels in DataLoader(val_dataset, batch_size=64):
|
100 |
+
images, labels = images.to(device), labels.to(device)
|
101 |
+
outputs = model(images)
|
102 |
+
val_loss += criterion(outputs, labels).item()
|
103 |
+
preds = outputs.argmax(dim=1)
|
104 |
+
correct += (preds == labels).sum().item()
|
105 |
+
|
106 |
+
# Metrics
|
107 |
+
epoch_train_loss = train_loss / len(train_dataset)
|
108 |
+
epoch_val_loss = val_loss / len(val_dataset)
|
109 |
+
epoch_val_acc = correct / len(val_dataset)
|
110 |
+
|
111 |
+
train_losses.append(epoch_train_loss)
|
112 |
+
val_losses.append(epoch_val_loss)
|
113 |
+
val_accuracies.append(epoch_val_acc)
|
114 |
+
|
115 |
+
print(f"\nEpoch {epoch+1}:")
|
116 |
+
print(f" Train Loss: {epoch_train_loss:.4f}")
|
117 |
+
print(f" Val Loss: {epoch_val_loss:.4f}")
|
118 |
+
print(f" Val Acc: {epoch_val_acc:.4f}")
|
119 |
+
|
120 |
+
# Save checkpoint
|
121 |
+
torch.save({
|
122 |
+
'epoch': epoch + 1,
|
123 |
+
'model_state_dict': model.state_dict(),
|
124 |
+
'optimizer_state_dict': optimizer.state_dict(),
|
125 |
+
'scheduler_state_dict': scheduler.state_dict(),
|
126 |
+
'train_losses': train_losses,
|
127 |
+
'val_losses': val_losses,
|
128 |
+
'val_accuracies': val_accuracies
|
129 |
+
}, checkpoint_path)
|
130 |
+
|
131 |
+
# ✅ Save best model
|
132 |
+
if epoch_val_acc > best_acc:
|
133 |
+
best_acc = epoch_val_acc
|
134 |
+
torch.save(model.state_dict(), best_model_path)
|
135 |
+
print(f"✅ Best model saved at epoch {epoch+1} with acc {best_acc:.4f}")
|
136 |
+
|
137 |
+
# ✅ Save final model
|
138 |
+
torch.save(model.state_dict(), final_model_path)
|
139 |
+
print(f"\n✅ Final model saved to {final_model_path}")
|
140 |
+
|
141 |
+
# Plot
|
142 |
+
plt.figure(figsize=(12, 4))
|
143 |
+
plt.subplot(1, 2, 1)
|
144 |
+
plt.plot(train_losses, label='Train Loss')
|
145 |
+
plt.plot(val_losses, label='Val Loss')
|
146 |
+
plt.title("Loss")
|
147 |
+
plt.legend()
|
148 |
+
|
149 |
+
plt.subplot(1, 2, 2)
|
150 |
+
plt.plot(val_accuracies, label='Val Accuracy')
|
151 |
+
plt.title("Validation Accuracy")
|
152 |
+
plt.legend()
|
153 |
+
plt.show()
|