Ayesha352 committed on
Commit
ebda3d1
·
verified ·
1 Parent(s): 215bcc2

Upload train_model(1).py

Browse files
Files changed (1) hide show
  1. train_model(1).py +153 -0
train_model(1).py ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torchvision import datasets, transforms
3
+ from torch.utils.data import DataLoader
4
+ import torch.nn as nn
5
+ import torch.optim as optim
6
+ from torchvision.models import convnext_tiny
7
+ from tqdm import tqdm
8
+ import matplotlib.pyplot as plt
9
+ import os
10
+
11
# --- Data pipeline ----------------------------------------------------------
# Dataset root as mounted on the RunPod volume; expects ImageFolder layout
# (one sub-directory per class) under train/ and val/.
dataset_path = "/workspace/VCR Cleaned"
train_dir = os.path.join(dataset_path, "train")
val_dir = os.path.join(dataset_path, "val")

# ImageNet normalisation statistics — the pretrained backbone expects them.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]


def _build_transform():
    """Resize to 512x512, convert to tensor, normalise with ImageNet stats."""
    return transforms.Compose([
        transforms.Resize((512, 512)),
        transforms.ToTensor(),
        transforms.Normalize(_IMAGENET_MEAN, _IMAGENET_STD),
    ])


# Train and val pipelines are intentionally identical (no augmentation).
train_transform = _build_transform()
val_transform = _build_transform()

# Load datasets from disk.
train_dataset = datasets.ImageFolder(train_dir, transform=train_transform)
val_dataset = datasets.ImageFolder(val_dir, transform=val_transform)

# Sanity-check the class -> index mapping before training starts.
print("\nLabel mapping:", train_dataset.class_to_idx)
print("Number of classes:", len(train_dataset.classes))
38
+
39
# --- Model ------------------------------------------------------------------
# ConvNeXt-Tiny pretrained on ImageNet, with the final classifier layer
# replaced so its output width matches the number of dataset classes.
model = convnext_tiny(pretrained=True)
# Read the head's input width from the model instead of hard-coding 768,
# so this line keeps working if the backbone variant is ever swapped.
in_features = model.classifier[2].in_features
model.classifier[2] = nn.Linear(in_features, len(train_dataset.classes))

# --- Training setup ---------------------------------------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.01)
# NOTE(review): T_max=10 means the cosine schedule completes a half-period
# every 10 epochs while the loop runs 100 — confirm this is intentional.
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10)
49
+
50
# --- Checkpoint / output paths ----------------------------------------------
checkpoint_path = "/workspace/convnext_checkpoint.pth"
best_model_path = "/workspace/convnext_best_model.pth"
final_model_path = "/workspace/convnext_final_model.pth"

# Training history; repopulated from the checkpoint when resuming.
start_epoch = 0
train_losses, val_losses, val_accuracies = [], [], []
best_acc = 0.0

if os.path.exists(checkpoint_path):
    # Resume model/optimizer/scheduler state plus the metric history.
    print("\nLoading checkpoint...")
    ckpt = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(ckpt['model_state_dict'])
    optimizer.load_state_dict(ckpt['optimizer_state_dict'])
    scheduler.load_state_dict(ckpt['scheduler_state_dict'])
    train_losses = ckpt['train_losses']
    val_losses = ckpt['val_losses']
    val_accuracies = ckpt['val_accuracies']
    # Recover the running best accuracy from the restored history.
    best_acc = max(val_accuracies) if val_accuracies else 0.0
    start_epoch = ckpt['epoch']
    print(f"Resumed from epoch {start_epoch}")
else:
    print("\nStarting training from scratch")
76
+
77
# --- Training loop ----------------------------------------------------------
# Build the loaders once instead of re-creating them every epoch; the train
# loader still reshuffles at the start of each iteration pass.
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64)

for epoch in range(start_epoch, 100):
    # -- Train pass --
    model.train()
    train_loss = 0.0

    for images, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}"):
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()

    scheduler.step()

    # -- Validation pass --
    model.eval()
    val_loss = 0.0
    correct = 0
    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            val_loss += criterion(outputs, labels).item()
            preds = outputs.argmax(dim=1)
            correct += (preds == labels).sum().item()

    # -- Metrics --
    # BUG FIX: the accumulated losses are per-batch *means* (CrossEntropyLoss
    # defaults to reduction='mean'), so the epoch average must divide by the
    # number of batches, not the number of samples — the old division by
    # len(dataset) under-reported the losses by roughly a factor of the
    # batch size. Accuracy is genuinely per-sample, so it keeps len(dataset).
    epoch_train_loss = train_loss / len(train_loader)
    epoch_val_loss = val_loss / len(val_loader)
    epoch_val_acc = correct / len(val_dataset)

    train_losses.append(epoch_train_loss)
    val_losses.append(epoch_val_loss)
    val_accuracies.append(epoch_val_acc)

    print(f"\nEpoch {epoch+1}:")
    print(f" Train Loss: {epoch_train_loss:.4f}")
    print(f" Val Loss: {epoch_val_loss:.4f}")
    print(f" Val Acc: {epoch_val_acc:.4f}")

    # -- Persist full training state so a preempted run can resume --
    torch.save({
        'epoch': epoch + 1,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'train_losses': train_losses,
        'val_losses': val_losses,
        'val_accuracies': val_accuracies
    }, checkpoint_path)

    # -- Keep a copy of the weights with the best validation accuracy --
    if epoch_val_acc > best_acc:
        best_acc = epoch_val_acc
        torch.save(model.state_dict(), best_model_path)
        print(f"✅ Best model saved at epoch {epoch+1} with acc {best_acc:.4f}")
136
+
137
# ✅ Persist the weights as they stand at the end of the run.
torch.save(model.state_dict(), final_model_path)
print(f"\n✅ Final model saved to {final_model_path}")

# Plot loss curves and validation accuracy side by side.
plt.figure(figsize=(12, 4))

ax_loss = plt.subplot(1, 2, 1)
ax_loss.plot(train_losses, label='Train Loss')
ax_loss.plot(val_losses, label='Val Loss')
ax_loss.set_title("Loss")
ax_loss.legend()

ax_acc = plt.subplot(1, 2, 2)
ax_acc.plot(val_accuracies, label='Val Accuracy')
ax_acc.set_title("Validation Accuracy")
ax_acc.legend()

plt.show()