dhruv2842 commited on
Commit
c35b02c
·
verified ·
1 Parent(s): b050424

Delete densenet_withglam.ipynb

Browse files
Files changed (1) hide show
  1. densenet_withglam.ipynb +0 -1731
densenet_withglam.ipynb DELETED
@@ -1,1731 +0,0 @@
1
- {
2
- "cells": [
3
- {
4
- "cell_type": "code",
5
- "execution_count": 1,
6
- "metadata": {},
7
- "outputs": [],
8
- "source": [
9
- "import torch\n",
10
- "import torch.nn as nn\n",
11
- "import torch.optim as optim\n",
12
- "from torchvision import datasets, models, transforms\n",
13
- "from torch.utils.data import DataLoader, random_split\n",
14
- "from sklearn.metrics import (\n",
15
- " accuracy_score, f1_score, precision_score, recall_score,\n",
16
- " roc_auc_score, roc_curve, confusion_matrix, auc, classification_report\n",
17
- ")\n",
18
- "from sklearn.preprocessing import label_binarize\n",
19
- "import numpy as np\n",
20
- "import matplotlib.pyplot as plt\n",
21
- "import itertools\n",
22
- "import random\n",
23
- "import json\n",
24
- "from tqdm import tqdm\n",
25
- "from timm import create_model\n",
26
- "import os"
27
- ]
28
- },
29
- {
30
- "cell_type": "code",
31
- "execution_count": null,
32
- "metadata": {},
33
- "outputs": [],
34
- "source": [
35
- "\n",
36
- "checkpoint = torch.load('densenet169_seed40_best.pt', map_location='cpu')\n",
37
- "state_dict = checkpoint['state_dict']\n",
38
- "\n",
39
- "for k in list(state_dict.keys())[:10]:\n",
40
- " print(k)"
41
- ]
42
- },
43
- {
44
- "cell_type": "code",
45
- "execution_count": 6,
46
- "metadata": {},
47
- "outputs": [
48
- {
49
- "name": "stderr",
50
- "output_type": "stream",
51
- "text": [
52
- "/tmp/ipykernel_65719/145292541.py:1: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
53
- " checkpoint = torch.load('densenet169_seed40_best.pt', map_location='cpu')\n"
54
- ]
55
- },
56
- {
57
- "name": "stdout",
58
- "output_type": "stream",
59
- "text": [
60
- "features.0.conv0.weight\n",
61
- "features.0.norm0.weight\n",
62
- "features.0.norm0.bias\n",
63
- "features.0.norm0.running_mean\n",
64
- "features.0.norm0.running_var\n",
65
- "features.0.norm0.num_batches_tracked\n",
66
- "features.0.denseblock1.denselayer1.norm1.weight\n",
67
- "features.0.denseblock1.denselayer1.norm1.bias\n",
68
- "features.0.denseblock1.denselayer1.norm1.running_mean\n",
69
- "features.0.denseblock1.denselayer1.norm1.running_var\n"
70
- ]
71
- }
72
- ],
73
- "source": [
74
- "\n",
75
- "checkpoint = torch.load('densenet169_seed40_best.pt', map_location='cpu')\n",
76
- "state_dict = checkpoint['state_dict']\n",
77
- "\n",
78
- "for k in list(state_dict.keys())[:10]:\n",
79
- " print(k)"
80
- ]
81
- },
82
- {
83
- "cell_type": "code",
84
- "execution_count": 2,
85
- "metadata": {},
86
- "outputs": [],
87
- "source": [
88
- "import torch\n",
89
- "import torch.nn as nn\n",
90
- "import torch.nn.functional as F\n",
91
- "\n",
92
- "\n",
93
- "class GLAM(nn.Module):\n",
94
- " \"\"\"\n",
95
- " Global-Local Attention Module (GLAM) that produces a refined feature map.\n",
96
- " \"\"\"\n",
97
- " def __init__(self, in_channels, reduction_ratio=8):\n",
98
- " super(GLAM, self).__init__()\n",
99
- " \n",
100
- " # --- Local Channel Attention ---\n",
101
- " self.local_channel_conv = nn.Conv2d(in_channels, in_channels // reduction_ratio, kernel_size=1)\n",
102
- " self.local_channel_act = nn.Sigmoid()\n",
103
- " self.local_channel_expand = nn.Conv2d(in_channels // reduction_ratio, in_channels, kernel_size=1)\n",
104
- " \n",
105
- " # --- Local Spatial Attention ---\n",
106
- " # 3-dilated, 5-dilated conv merges\n",
107
- " self.local_spatial_conv3 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=3, dilation=3)\n",
108
- " self.local_spatial_conv5 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=5, dilation=5)\n",
109
- " self.local_spatial_merge = nn.Conv2d(in_channels * 3, in_channels, kernel_size=1)\n",
110
- " self.local_spatial_act = nn.Sigmoid()\n",
111
- " \n",
112
- " # --- Global Channel Attention ---\n",
113
- " self.global_avg_pool = nn.AdaptiveAvgPool2d(1)\n",
114
- " self.global_channel_fc1 = nn.Linear(in_channels, in_channels // reduction_ratio)\n",
115
- " self.global_channel_fc2 = nn.Linear(in_channels // reduction_ratio, in_channels)\n",
116
- " self.global_channel_act = nn.Sigmoid()\n",
117
- " \n",
118
- " # --- Global Spatial Attention ---\n",
119
- " self.global_spatial_conv = nn.Conv2d(in_channels, 1, kernel_size=1)\n",
120
- " self.global_spatial_softmax = nn.Softmax(dim=-1)\n",
121
- " \n",
122
- " \n",
123
- " # --- Weighted paramerers initialization ---\n",
124
- " self.local_attention_weight = nn.Parameter(torch.tensor(1.0)) \n",
125
- " self.global_attention_weight = nn.Parameter(torch.tensor(1.0))\n",
126
- " \n",
127
- "\n",
128
- " def forward(self, x):\n",
129
- " # Local Channel Attention\n",
130
- " lca = self.local_channel_conv(x) \n",
131
- " lca = self.local_channel_act(lca) \n",
132
- " lca = self.local_channel_expand(lca) \n",
133
- " lca_out = lca * x \n",
134
- "\n",
135
- " # Local Spatial Attention\n",
136
- " lsa3 = self.local_spatial_conv3(x)\n",
137
- " lsa5 = self.local_spatial_conv5(x)\n",
138
- " lsa_cat = torch.cat([x, lsa3, lsa5], dim=1)\n",
139
- " lsa = self.local_spatial_merge(lsa_cat)\n",
140
- " lsa = self.local_spatial_act(lsa)\n",
141
- " lsa_out = lsa * lca_out\n",
142
- " lsa_out = lsa_out + lca_out\n",
143
- "\n",
144
- " # Global Channel Attention\n",
145
- " B, C, H, W = x.size()\n",
146
- " gca = self.global_avg_pool(x).view(B, C) \n",
147
- " gca = F.relu(self.global_channel_fc1(gca), inplace=True)\n",
148
- " gca = self.global_channel_fc2(gca)\n",
149
- " gca = self.global_channel_act(gca)\n",
150
- " gca = gca.view(B, C, 1, 1)\n",
151
- " gca_out = gca * x\n",
152
- "\n",
153
- " # Global Spatial Attention\n",
154
- " gsa = self.global_spatial_conv(x) # [B, 1, H, W]\n",
155
- " gsa = gsa.view(B, -1) # [B, H*W]\n",
156
- " gsa = self.global_spatial_softmax(gsa)\n",
157
- " gsa = gsa.view(B, 1, H, W)\n",
158
- " gsa_out = gsa * gca_out\n",
159
- " gsa_out = gsa_out + gca_out\n",
160
- "\n",
161
- " # Fuse\n",
162
- " out = lsa_out*self.local_attention_weight + gsa_out*self.global_attention_weight + x\n",
163
- " return out\n"
164
- ]
165
- },
166
- {
167
- "cell_type": "code",
168
- "execution_count": 3,
169
- "metadata": {},
170
- "outputs": [],
171
- "source": [
172
- "def get_model_with_attention(model_name, num_classes):\n",
173
- " \"\"\"Get a pretrained model and attach GLAM attention block.\"\"\"\n",
174
- " # Ensure GLAM is imported or defined in the notebook\n",
175
- " # from your_glam_module import GLAM\n",
176
- "\n",
177
- " if model_name == 'mobilenet_v2':\n",
178
- " model = models.mobilenet_v2(pretrained=True)\n",
179
- " in_channels = 1280\n",
180
- " model.features.add_module(\"attention\", GLAM(in_channels))\n",
181
- " model.classifier[1] = nn.Linear(in_channels, num_classes)\n",
182
- "\n",
183
- " elif model_name == 'mobilenet_v3':\n",
184
- " model = models.mobilenet_v3_large(pretrained=True)\n",
185
- " in_channels = 960\n",
186
- " model.features.add_module(\"attention\", GLAM(in_channels))\n",
187
- " model.classifier = nn.Sequential(\n",
188
- " nn.Linear(in_channels, 1280),\n",
189
- " nn.Hardswish(),\n",
190
- " nn.Dropout(p=0.2),\n",
191
- " nn.Linear(1280, num_classes)\n",
192
- " )\n",
193
- "\n",
194
- " elif model_name == 'efficientnet':\n",
195
- " model = models.efficientnet_b0(pretrained=True)\n",
196
- " in_channels = model.classifier[1].in_features\n",
197
- " model.features = nn.Sequential(model.features, GLAM(in_channels))\n",
198
- " model.classifier[1] = nn.Linear(in_channels, num_classes)\n",
199
- "\n",
200
- " elif model_name == 'densenet121':\n",
201
- " model = models.densenet121(pretrained=True)\n",
202
- " in_channels = model.classifier.in_features\n",
203
- " model.features = nn.Sequential(model.features, nn.ReLU(inplace=True), GLAM(in_channels))\n",
204
- " model.classifier = nn.Linear(in_channels, num_classes)\n",
205
- "\n",
206
- " elif model_name == 'densenet161':\n",
207
- " model = models.densenet161(pretrained=True)\n",
208
- " in_channels = model.classifier.in_features\n",
209
- " model.features = nn.Sequential(model.features, nn.ReLU(inplace=True), GLAM(in_channels))\n",
210
- " model.classifier = nn.Linear(in_channels, num_classes)\n",
211
- "\n",
212
- " elif model_name == 'densenet169':\n",
213
- " model = models.densenet169(pretrained=True)\n",
214
- " in_channels = model.classifier.in_features\n",
215
- " model.features = nn.Sequential(model.features, nn.ReLU(inplace=True), GLAM(in_channels))\n",
216
- " model.classifier = nn.Linear(in_channels, num_classes)\n",
217
- "\n",
218
- " elif model_name == 'vgg16':\n",
219
- " model = models.vgg16(pretrained=True)\n",
220
- " in_channels = 512\n",
221
- " model.features = nn.Sequential(model.features, GLAM(in_channels))\n",
222
- " model.classifier[6] = nn.Linear(4096, num_classes)\n",
223
- "\n",
224
- " elif model_name == 'resnet18':\n",
225
- " model = models.resnet18(pretrained=True)\n",
226
- " in_channels = model.fc.in_features\n",
227
- " model.layer4.add_module(\"attention\", GLAM(in_channels))\n",
228
- " model.fc = nn.Linear(in_channels, num_classes)\n",
229
- "\n",
230
- " elif model_name == 'resnet50':\n",
231
- " model = models.resnet50(pretrained=True)\n",
232
- " in_channels = model.fc.in_features\n",
233
- " model.layer4.add_module(\"attention\", GLAM(in_channels))\n",
234
- " model.fc = nn.Linear(in_channels, num_classes)\n",
235
- "\n",
236
- " else:\n",
237
- " raise ValueError(f\"Unsupported model name: {model_name}\")\n",
238
- "\n",
239
- " return model\n"
240
- ]
241
- },
242
- {
243
- "cell_type": "code",
244
- "execution_count": 4,
245
- "metadata": {},
246
- "outputs": [],
247
- "source": [
248
- "class EarlyStopping:\n",
249
- " \"\"\"Early stopping class with checkpointing support.\"\"\"\n",
250
- " def __init__(self, patience=30, verbose=False, delta=0.0, path='checkpoint.pt', minEpochs=10):\n",
251
- " self.patience = patience\n",
252
- " self.verbose = verbose\n",
253
- " self.delta = delta\n",
254
- " self.path = path\n",
255
- " self.counter = 0\n",
256
- " self.best_loss = None\n",
257
- " self.early_stop = False\n",
258
- " self.val_loss_min = np.Inf\n",
259
- " self.minEpochs = minEpochs\n",
260
- "\n",
261
- " def __call__(self, val_loss, model, epochNum):\n",
262
- " \"\"\"Check if early stopping should be activated.\"\"\"\n",
263
- " if epochNum < self.minEpochs:\n",
264
- " return\n",
265
- " if self.best_loss is None:\n",
266
- " self.best_loss = val_loss\n",
267
- " self.save_checkpoint(val_loss, model, epochNum)\n",
268
- " elif val_loss > self.best_loss - self.delta:\n",
269
- " self.counter += 1\n",
270
- " if self.verbose:\n",
271
- " print(f'EarlyStopping counter: {self.counter} out of {self.patience}')\n",
272
- " if self.counter >= self.patience:\n",
273
- " self.early_stop = True\n",
274
- " else:\n",
275
- " self.best_loss = val_loss\n",
276
- " self.save_checkpoint(val_loss, model, epochNum)\n",
277
- " self.counter = 0\n",
278
- "\n",
279
- " def save_checkpoint(self, val_loss, model, epochNum):\n",
280
- " \"\"\"Save checkpoint dict with state_dict and relevant metadata.\"\"\"\n",
281
- " if self.verbose:\n",
282
- " print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving checkpoint...')\n",
283
- " checkpoint = {\n",
284
- " 'state_dict': model.state_dict(),\n",
285
- " 'val_loss': val_loss,\n",
286
- " 'epoch': epochNum\n",
287
- " }\n",
288
- " torch.save(checkpoint, self.path)\n",
289
- " self.val_loss_min = val_loss\n",
290
- "\n",
291
- "\n",
292
- "class ModelTrainer:\n",
293
- " \"\"\"Trainer for PyTorch model with early stopping, training, and evaluation.\"\"\"\n",
294
- " def __init__(\n",
295
- " self, model, train_loader, val_loader, test_loader,\n",
296
- " criterion, optimizer, device, model_name='model', seed=42\n",
297
- " ):\n",
298
- " self.model = model\n",
299
- " self.train_loader = train_loader\n",
300
- " self.val_loader = val_loader\n",
301
- " self.test_loader = test_loader\n",
302
- " self.criterion = criterion\n",
303
- " self.optimizer = optimizer\n",
304
- " self.device = device\n",
305
- " self.model_name = f\"{model_name}_seed{seed}\"\n",
306
- " self.history = {\n",
307
- " 'train_loss': [], 'val_loss': [],\n",
308
- " 'train_acc': [], 'val_acc': [],\n",
309
- " 'best_epoch': 0\n",
310
- " }\n",
311
- "\n",
312
- " def train_epoch(self):\n",
313
- " \"\"\"Perform one training epoch.\"\"\"\n",
314
- " self.model.train()\n",
315
- " running_loss = 0.0\n",
316
- " running_corrects = 0\n",
317
- " total_samples = 0\n",
318
- " for inputs, labels in tqdm(self.train_loader, desc='Training'):\n",
319
- " inputs = inputs.to(self.device)\n",
320
- " labels = labels.to(self.device)\n",
321
- "\n",
322
- " self.optimizer.zero_grad()\n",
323
- " outputs = self.model(inputs)\n",
324
- "\n",
325
- " loss = self.criterion(outputs, labels)\n",
326
- "\n",
327
- " loss.backward()\n",
328
- " self.optimizer.step()\n",
329
- "\n",
330
- " _, preds = torch.max(outputs, 1)\n",
331
- " running_loss += loss.item() * inputs.size(0)\n",
332
- " running_corrects += torch.sum(preds == labels.data).item()\n",
333
- " total_samples += inputs.size(0)\n",
334
- "\n",
335
- " return running_loss / total_samples, running_corrects / total_samples, None\n",
336
- "\n",
337
- " def validate_epoch(self):\n",
338
- " \"\"\"Perform one validation epoch.\"\"\"\n",
339
- " self.model.eval()\n",
340
- " running_loss = 0.0\n",
341
- " running_corrects = 0\n",
342
- " total_samples = 0\n",
343
- "\n",
344
- " with torch.no_grad():\n",
345
- " for inputs, labels in self.val_loader:\n",
346
- " inputs = inputs.to(self.device)\n",
347
- " labels = labels.to(self.device)\n",
348
- "\n",
349
- " outputs = self.model(inputs)\n",
350
- "\n",
351
- " loss = self.criterion(outputs, labels)\n",
352
- "\n",
353
- " _, preds = torch.max(outputs, 1)\n",
354
- " running_loss += loss.item() * inputs.size(0)\n",
355
- " running_corrects += torch.sum(preds == labels.data).item()\n",
356
- " total_samples += inputs.size(0)\n",
357
- "\n",
358
- " return running_loss / total_samples, running_corrects / total_samples, None\n",
359
- "\n",
360
- " def train(self, num_epochs=100, patience=20):\n",
361
- " \"\"\"Train the model for a number of epochs with early stopping.\"\"\"\n",
362
- " early_stopping = EarlyStopping(\n",
363
- " patience=patience,\n",
364
- " verbose=True,\n",
365
- " path=f'{self.model_name}_best.pt',\n",
366
- " minEpochs=num_epochs // 4\n",
367
- " )\n",
368
- " for epoch in range(num_epochs):\n",
369
- " print(f'\\nEpoch {epoch + 1}/{num_epochs}')\n",
370
- " train_loss, train_acc, _ = self.train_epoch()\n",
371
- " val_loss, val_acc, _ = self.validate_epoch()\n",
372
- "\n",
373
- " self.history['train_loss'].append(train_loss)\n",
374
- " self.history['train_acc'].append(train_acc)\n",
375
- " self.history['val_loss'].append(val_loss)\n",
376
- " self.history['val_acc'].append(val_acc)\n",
377
- "\n",
378
- " print(f'Train Loss: {train_loss:.4f} | Acc: {train_acc:.4f}')\n",
379
- " print(f'Val Loss: {val_loss:.4f} | Acc: {val_acc:.4f}')\n",
380
- "\n",
381
- " early_stopping(val_loss, self.model, epoch)\n",
382
- " if early_stopping.early_stop:\n",
383
- " print(\"Early stopping triggered.\")\n",
384
- " self.history['best_epoch'] = epoch\n",
385
- " break\n",
386
- "\n",
387
- " # Load best checkpoint\n",
388
- " checkpoint = torch.load(f'{self.model_name}_best.pt')\n",
389
- " self.model.load_state_dict(checkpoint['state_dict'])\n",
390
- " return self.history\n",
391
- "\n",
392
- " def evaluate(self, save_root=\"./eval_plots\"):\n",
393
- " \"\"\"Evaluate the model and save metrics and ROC curves.\"\"\"\n",
394
- " self.model.eval()\n",
395
- " y_true, y_pred, y_prob = [], [], []\n",
396
- " pbar = tqdm(self.test_loader, desc='Testing')\n",
397
- " with torch.no_grad():\n",
398
- " for inputs, labels in pbar:\n",
399
- " inputs = inputs.to(self.device)\n",
400
- " labels = labels.to(self.device)\n",
401
- "\n",
402
- " logits = self.model(inputs)\n",
403
- " probs = torch.softmax(logits, dim=1)\n",
404
- " preds = probs.argmax(dim=1)\n",
405
- "\n",
406
- " y_true.extend(labels.cpu().numpy())\n",
407
- " y_pred.extend(preds.cpu().numpy())\n",
408
- " y_prob.extend(probs.cpu().numpy())\n",
409
- "\n",
410
- " y_true = np.array(y_true)\n",
411
- " y_pred = np.array(y_pred)\n",
412
- " y_prob = np.array(y_prob)\n",
413
- " num_classes = y_prob.shape[1]\n",
414
- " y_true_bin = label_binarize(y_true, classes=list(range(num_classes)))\n",
415
- "\n",
416
- " # Metrics\n",
417
- " accuracy = accuracy_score(y_true, y_pred)\n",
418
- " f1 = f1_score(y_true, y_pred, average='weighted')\n",
419
- " precision = precision_score(y_true, y_pred, average='weighted')\n",
420
- " recall = recall_score(y_true, y_pred, average='weighted')\n",
421
- " auc_macro = roc_auc_score(y_true_bin, y_prob, average='macro', multi_class='ovr')\n",
422
- "\n",
423
- " print(f\"Accuracy: {accuracy:.4f}\")\n",
424
- " print(f\"F1 Score: {f1:.4f}\")\n",
425
- " print(f\"Precision: {precision:.4f}\")\n",
426
- " print(f\"Recall: {recall:.4f}\")\n",
427
- " print(f\"AUC (macro): {auc_macro:.4f}\")\n",
428
- "\n",
429
- " # Save directory\n",
430
- " save_dir = os.path.join(save_root, self.model_name)\n",
431
- " os.makedirs(save_dir, exist_ok=True)\n",
432
- "\n",
433
- " # Confusion Matrix\n",
434
- " cm = confusion_matrix(y_true, y_pred)\n",
435
- " plt.figure(figsize=(8, 6))\n",
436
- " plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)\n",
437
- " plt.title(\"Confusion Matrix\")\n",
438
- " plt.colorbar()\n",
439
- " ticks = np.arange(num_classes)\n",
440
- " plt.xticks(ticks, ticks)\n",
441
- " plt.yticks(ticks, ticks)\n",
442
- " thresh = cm.max() / 2.\n",
443
- " for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):\n",
444
- " plt.text(j, i, format(cm[i, j], 'd'),\n",
445
- " ha=\"center\",\n",
446
- " color=\"white\" if cm[i, j] > thresh else \"black\")\n",
447
- " plt.ylabel('True label')\n",
448
- " plt.xlabel('Predicted label')\n",
449
- " plt.tight_layout()\n",
450
- " plt.savefig(os.path.join(save_dir, \"confusion_matrix.jpg\"), format='jpg')\n",
451
- " plt.close()\n",
452
- "\n",
453
- " # Class-wise ROC\n",
454
- " fpr, tpr, roc_auc_values = {}, {}, {}\n",
455
- " for i in range(num_classes):\n",
456
- " fpr[i], tpr[i], _ = roc_curve(y_true_bin[:, i], y_prob[:, i])\n",
457
- " roc_auc_values[i] = auc(fpr[i], tpr[i])\n",
458
- "\n",
459
- " plt.figure()\n",
460
- " for i in range(num_classes):\n",
461
- " plt.plot(fpr[i], tpr[i], label=f\"Class {i} (AUC = {roc_auc_values[i]:.2f})\")\n",
462
- " plt.plot([0, 1], [0, 1], linestyle='--')\n",
463
- " plt.title(\"Class-wise ROC Curves\")\n",
464
- " plt.xlabel(\"False Positive Rate\")\n",
465
- " plt.ylabel(\"True Positive Rate\")\n",
466
- " plt.legend(loc=\"lower right\")\n",
467
- " plt.tight_layout()\n",
468
- " plt.savefig(os.path.join(save_dir, \"classwise_auc_roc.jpg\"), format='jpg')\n",
469
- " plt.close()\n",
470
- "\n",
471
- " # Overall ROC\n",
472
- " fpr_mac, tpr_mac, _ = roc_curve(y_true_bin.ravel(), y_prob.ravel())\n",
473
- " roc_auc_mac = auc(fpr_mac, tpr_mac)\n",
474
- " plt.figure()\n",
475
- " plt.plot(fpr_mac, tpr_mac, label=f\"Overall (macro) ROC (AUC = {roc_auc_mac:.2f})\")\n",
476
- " plt.plot([0, 1], [0, 1], linestyle='--')\n",
477
- " plt.title(\"Overall (Macro) ROC Curve\")\n",
478
- " plt.xlabel(\"False Positive Rate\")\n",
479
- " plt.ylabel(\"True Positive Rate\")\n",
480
- " plt.legend(loc=\"lower right\")\n",
481
- " plt.tight_layout()\n",
482
- " plt.savefig(os.path.join(save_dir, \"overall_auc_roc.jpg\"), format='jpg')\n",
483
- " plt.close()\n",
484
- "\n",
485
- " # Classification Report\n",
486
- " report = classification_report(y_true, y_pred)\n",
487
- " print(\"\\nClassification Report:\\n\", report)\n",
488
- "\n",
489
- " return {\n",
490
- " 'accuracy': accuracy,\n",
491
- " 'f1_score': f1,\n",
492
- " 'precision': precision,\n",
493
- " 'recall': recall,\n",
494
- " 'auc_macro': auc_macro,\n",
495
- " 'confusion_matrix': cm.tolist(),\n",
496
- " 'classification_report': report,\n",
497
- " 'plots_dir': save_dir\n",
498
- " }\n",
499
- "\n",
500
- "\n",
501
- "def set_seed(seed):\n",
502
- " \"\"\"Set random seed for reproducibility across Python, PyTorch, and NumPy.\"\"\"\n",
503
- " torch.manual_seed(seed)\n",
504
- " torch.cuda.manual_seed(seed)\n",
505
- " torch.cuda.manual_seed_all(seed)\n",
506
- " np.random.seed(seed)\n",
507
- " random.seed(seed)\n",
508
- " torch.backends.cudnn.deterministic = True\n",
509
- " torch.backends.cudnn.benchmark = False"
510
- ]
511
- },
512
- {
513
- "cell_type": "code",
514
- "execution_count": 5,
515
- "metadata": {},
516
- "outputs": [
517
- {
518
- "name": "stdout",
519
- "output_type": "stream",
520
- "text": [
521
- "🖥️ Using device: cuda:7\n",
522
- "\n",
523
- "🚀 Training densenet169 with seed 40\n"
524
- ]
525
- },
526
- {
527
- "name": "stderr",
528
- "output_type": "stream",
529
- "text": [
530
- "/raid/home/sbag/anaconda3/envs/yashwantvenv/lib/python3.8/site-packages/torchvision/models/_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.\n",
531
- " warnings.warn(\n",
532
- "/raid/home/sbag/anaconda3/envs/yashwantvenv/lib/python3.8/site-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=DenseNet169_Weights.IMAGENET1K_V1`. You can also use `weights=DenseNet169_Weights.DEFAULT` to get the most up-to-date weights.\n",
533
- " warnings.warn(msg)\n"
534
- ]
535
- },
536
- {
537
- "name": "stdout",
538
- "output_type": "stream",
539
- "text": [
540
- "\n",
541
- "Epoch 1/100\n"
542
- ]
543
- },
544
- {
545
- "name": "stderr",
546
- "output_type": "stream",
547
- "text": [
548
- "Training: 100%|██████████| 62/62 [00:40<00:00, 1.52it/s]\n"
549
- ]
550
- },
551
- {
552
- "name": "stdout",
553
- "output_type": "stream",
554
- "text": [
555
- "Train Loss: 0.6409 | Acc: 0.7247\n",
556
- "Val Loss: 0.5123 | Acc: 0.8097\n",
557
- "\n",
558
- "Epoch 2/100\n"
559
- ]
560
- },
561
- {
562
- "name": "stderr",
563
- "output_type": "stream",
564
- "text": [
565
- "Training: 100%|██████████| 62/62 [00:40<00:00, 1.55it/s]\n"
566
- ]
567
- },
568
- {
569
- "name": "stdout",
570
- "output_type": "stream",
571
- "text": [
572
- "Train Loss: 0.2327 | Acc: 0.9251\n",
573
- "Val Loss: 0.4905 | Acc: 0.7773\n",
574
- "\n",
575
- "Epoch 3/100\n"
576
- ]
577
- },
578
- {
579
- "name": "stderr",
580
- "output_type": "stream",
581
- "text": [
582
- "Training: 100%|██████████| 62/62 [00:40<00:00, 1.55it/s]\n"
583
- ]
584
- },
585
- {
586
- "name": "stdout",
587
- "output_type": "stream",
588
- "text": [
589
- "Train Loss: 0.1046 | Acc: 0.9717\n",
590
- "Val Loss: 0.5231 | Acc: 0.8057\n",
591
- "\n",
592
- "Epoch 4/100\n"
593
- ]
594
- },
595
- {
596
- "name": "stderr",
597
- "output_type": "stream",
598
- "text": [
599
- "Training: 100%|██████████| 62/62 [00:40<00:00, 1.55it/s]\n"
600
- ]
601
- },
602
- {
603
- "name": "stdout",
604
- "output_type": "stream",
605
- "text": [
606
- "Train Loss: 0.0758 | Acc: 0.9868\n",
607
- "Val Loss: 0.5464 | Acc: 0.8097\n",
608
- "\n",
609
- "Epoch 5/100\n"
610
- ]
611
- },
612
- {
613
- "name": "stderr",
614
- "output_type": "stream",
615
- "text": [
616
- "Training: 100%|██████████| 62/62 [00:40<00:00, 1.55it/s]\n"
617
- ]
618
- },
619
- {
620
- "name": "stdout",
621
- "output_type": "stream",
622
- "text": [
623
- "Train Loss: 0.0505 | Acc: 0.9879\n",
624
- "Val Loss: 0.4652 | Acc: 0.8381\n",
625
- "\n",
626
- "Epoch 6/100\n"
627
- ]
628
- },
629
- {
630
- "name": "stderr",
631
- "output_type": "stream",
632
- "text": [
633
- "Training: 100%|██████████| 62/62 [00:40<00:00, 1.54it/s]\n"
634
- ]
635
- },
636
- {
637
- "name": "stdout",
638
- "output_type": "stream",
639
- "text": [
640
- "Train Loss: 0.0384 | Acc: 0.9909\n",
641
- "Val Loss: 0.5930 | Acc: 0.8016\n",
642
- "\n",
643
- "Epoch 7/100\n"
644
- ]
645
- },
646
- {
647
- "name": "stderr",
648
- "output_type": "stream",
649
- "text": [
650
- "Training: 100%|██████████| 62/62 [00:40<00:00, 1.53it/s]\n"
651
- ]
652
- },
653
- {
654
- "name": "stdout",
655
- "output_type": "stream",
656
- "text": [
657
- "Train Loss: 0.0368 | Acc: 0.9889\n",
658
- "Val Loss: 0.6190 | Acc: 0.7733\n",
659
- "\n",
660
- "Epoch 8/100\n"
661
- ]
662
- },
663
- {
664
- "name": "stderr",
665
- "output_type": "stream",
666
- "text": [
667
- "Training: 100%|██████████| 62/62 [00:40<00:00, 1.54it/s]\n"
668
- ]
669
- },
670
- {
671
- "name": "stdout",
672
- "output_type": "stream",
673
- "text": [
674
- "Train Loss: 0.0680 | Acc: 0.9706\n",
675
- "Val Loss: 0.6185 | Acc: 0.8057\n",
676
- "\n",
677
- "Epoch 9/100\n"
678
- ]
679
- },
680
- {
681
- "name": "stderr",
682
- "output_type": "stream",
683
- "text": [
684
- "Training: 100%|██████████| 62/62 [00:39<00:00, 1.55it/s]\n"
685
- ]
686
- },
687
- {
688
- "name": "stdout",
689
- "output_type": "stream",
690
- "text": [
691
- "Train Loss: 0.1453 | Acc: 0.9393\n",
692
- "Val Loss: 0.7820 | Acc: 0.7652\n",
693
- "\n",
694
- "Epoch 10/100\n"
695
- ]
696
- },
697
- {
698
- "name": "stderr",
699
- "output_type": "stream",
700
- "text": [
701
- "Training: 100%|██████████| 62/62 [00:40<00:00, 1.55it/s]\n"
702
- ]
703
- },
704
- {
705
- "name": "stdout",
706
- "output_type": "stream",
707
- "text": [
708
- "Train Loss: 0.1284 | Acc: 0.9565\n",
709
- "Val Loss: 0.6574 | Acc: 0.7895\n",
710
- "\n",
711
- "Epoch 11/100\n"
712
- ]
713
- },
714
- {
715
- "name": "stderr",
716
- "output_type": "stream",
717
- "text": [
718
- "Training: 100%|██████████| 62/62 [00:39<00:00, 1.56it/s]\n"
719
- ]
720
- },
721
- {
722
- "name": "stdout",
723
- "output_type": "stream",
724
- "text": [
725
- "Train Loss: 0.0861 | Acc: 0.9636\n",
726
- "Val Loss: 0.6034 | Acc: 0.7854\n",
727
- "\n",
728
- "Epoch 12/100\n"
729
- ]
730
- },
731
- {
732
- "name": "stderr",
733
- "output_type": "stream",
734
- "text": [
735
- "Training: 100%|██████████| 62/62 [00:39<00:00, 1.56it/s]\n"
736
- ]
737
- },
738
- {
739
- "name": "stdout",
740
- "output_type": "stream",
741
- "text": [
742
- "Train Loss: 0.0318 | Acc: 0.9929\n",
743
- "Val Loss: 0.5974 | Acc: 0.8300\n",
744
- "\n",
745
- "Epoch 13/100\n"
746
- ]
747
- },
748
- {
749
- "name": "stderr",
750
- "output_type": "stream",
751
- "text": [
752
- "Training: 100%|██████████| 62/62 [00:40<00:00, 1.52it/s]\n"
753
- ]
754
- },
755
- {
756
- "name": "stdout",
757
- "output_type": "stream",
758
- "text": [
759
- "Train Loss: 0.0133 | Acc: 0.9949\n",
760
- "Val Loss: 0.6431 | Acc: 0.7976\n",
761
- "\n",
762
- "Epoch 14/100\n"
763
- ]
764
- },
765
- {
766
- "name": "stderr",
767
- "output_type": "stream",
768
- "text": [
769
- "Training: 100%|██████████| 62/62 [00:45<00:00, 1.36it/s]\n"
770
- ]
771
- },
772
- {
773
- "name": "stdout",
774
- "output_type": "stream",
775
- "text": [
776
- "Train Loss: 0.0130 | Acc: 0.9949\n",
777
- "Val Loss: 0.6782 | Acc: 0.8340\n",
778
- "\n",
779
- "Epoch 15/100\n"
780
- ]
781
- },
782
- {
783
- "name": "stderr",
784
- "output_type": "stream",
785
- "text": [
786
- "Training: 100%|██████████| 62/62 [00:45<00:00, 1.36it/s]\n"
787
- ]
788
- },
789
- {
790
- "name": "stdout",
791
- "output_type": "stream",
792
- "text": [
793
- "Train Loss: 0.0108 | Acc: 0.9960\n",
794
- "Val Loss: 0.6482 | Acc: 0.8138\n",
795
- "\n",
796
- "Epoch 16/100\n"
797
- ]
798
- },
799
- {
800
- "name": "stderr",
801
- "output_type": "stream",
802
- "text": [
803
- "Training: 100%|██████████| 62/62 [00:40<00:00, 1.54it/s]\n"
804
- ]
805
- },
806
- {
807
- "name": "stdout",
808
- "output_type": "stream",
809
- "text": [
810
- "Train Loss: 0.0106 | Acc: 0.9929\n",
811
- "Val Loss: 0.6366 | Acc: 0.8138\n",
812
- "\n",
813
- "Epoch 17/100\n"
814
- ]
815
- },
816
- {
817
- "name": "stderr",
818
- "output_type": "stream",
819
- "text": [
820
- "Training: 100%|██████████| 62/62 [00:40<00:00, 1.55it/s]\n"
821
- ]
822
- },
823
- {
824
- "name": "stdout",
825
- "output_type": "stream",
826
- "text": [
827
- "Train Loss: 0.0208 | Acc: 0.9919\n",
828
- "Val Loss: 0.8402 | Acc: 0.7854\n",
829
- "\n",
830
- "Epoch 18/100\n"
831
- ]
832
- },
833
- {
834
- "name": "stderr",
835
- "output_type": "stream",
836
- "text": [
837
- "Training: 100%|██████████| 62/62 [00:39<00:00, 1.55it/s]\n"
838
- ]
839
- },
840
- {
841
- "name": "stdout",
842
- "output_type": "stream",
843
- "text": [
844
- "Train Loss: 0.0180 | Acc: 0.9919\n",
845
- "Val Loss: 0.7927 | Acc: 0.8097\n",
846
- "\n",
847
- "Epoch 19/100\n"
848
- ]
849
- },
850
- {
851
- "name": "stderr",
852
- "output_type": "stream",
853
- "text": [
854
- "Training: 100%|██████████| 62/62 [00:44<00:00, 1.39it/s]\n"
855
- ]
856
- },
857
- {
858
- "name": "stdout",
859
- "output_type": "stream",
860
- "text": [
861
- "Train Loss: 0.0113 | Acc: 0.9960\n",
862
- "Val Loss: 0.6292 | Acc: 0.8421\n",
863
- "\n",
864
- "Epoch 20/100\n"
865
- ]
866
- },
867
- {
868
- "name": "stderr",
869
- "output_type": "stream",
870
- "text": [
871
- "Training: 100%|██████████| 62/62 [00:44<00:00, 1.40it/s]\n"
872
- ]
873
- },
874
- {
875
- "name": "stdout",
876
- "output_type": "stream",
877
- "text": [
878
- "Train Loss: 0.0297 | Acc: 0.9879\n",
879
- "Val Loss: 0.7917 | Acc: 0.7935\n",
880
- "\n",
881
- "Epoch 21/100\n"
882
- ]
883
- },
884
- {
885
- "name": "stderr",
886
- "output_type": "stream",
887
- "text": [
888
- "Training: 100%|██████████| 62/62 [00:40<00:00, 1.54it/s]\n"
889
- ]
890
- },
891
- {
892
- "name": "stdout",
893
- "output_type": "stream",
894
- "text": [
895
- "Train Loss: 0.0325 | Acc: 0.9858\n",
896
- "Val Loss: 0.8198 | Acc: 0.8259\n",
897
- "\n",
898
- "Epoch 22/100\n"
899
- ]
900
- },
901
- {
902
- "name": "stderr",
903
- "output_type": "stream",
904
- "text": [
905
- "Training: 100%|██████████| 62/62 [00:40<00:00, 1.54it/s]\n"
906
- ]
907
- },
908
- {
909
- "name": "stdout",
910
- "output_type": "stream",
911
- "text": [
912
- "Train Loss: 0.0692 | Acc: 0.9717\n",
913
- "Val Loss: 0.9598 | Acc: 0.7733\n",
914
- "\n",
915
- "Epoch 23/100\n"
916
- ]
917
- },
918
- {
919
- "name": "stderr",
920
- "output_type": "stream",
921
- "text": [
922
- "Training: 100%|██████████| 62/62 [00:42<00:00, 1.46it/s]\n"
923
- ]
924
- },
925
- {
926
- "name": "stdout",
927
- "output_type": "stream",
928
- "text": [
929
- "Train Loss: 0.1032 | Acc: 0.9626\n",
930
- "Val Loss: 0.7010 | Acc: 0.7976\n",
931
- "\n",
932
- "Epoch 24/100\n"
933
- ]
934
- },
935
- {
936
- "name": "stderr",
937
- "output_type": "stream",
938
- "text": [
939
- "Training: 100%|██████████| 62/62 [00:40<00:00, 1.54it/s]\n"
940
- ]
941
- },
942
- {
943
- "name": "stdout",
944
- "output_type": "stream",
945
- "text": [
946
- "Train Loss: 0.0464 | Acc: 0.9818\n",
947
- "Val Loss: 0.7204 | Acc: 0.7895\n",
948
- "\n",
949
- "Epoch 25/100\n"
950
- ]
951
- },
952
- {
953
- "name": "stderr",
954
- "output_type": "stream",
955
- "text": [
956
- "Training: 100%|██████████| 62/62 [00:40<00:00, 1.54it/s]\n"
957
- ]
958
- },
959
- {
960
- "name": "stdout",
961
- "output_type": "stream",
962
- "text": [
963
- "Train Loss: 0.0126 | Acc: 0.9970\n",
964
- "Val Loss: 0.7141 | Acc: 0.7935\n",
965
- "\n",
966
- "Epoch 26/100\n"
967
- ]
968
- },
969
- {
970
- "name": "stderr",
971
- "output_type": "stream",
972
- "text": [
973
- "Training: 100%|██████████| 62/62 [00:42<00:00, 1.45it/s]\n"
974
- ]
975
- },
976
- {
977
- "name": "stdout",
978
- "output_type": "stream",
979
- "text": [
980
- "Train Loss: 0.0298 | Acc: 0.9889\n",
981
- "Val Loss: 0.8210 | Acc: 0.8057\n",
982
- "Validation loss decreased (inf --> 0.820993). Saving checkpoint...\n",
983
- "\n",
984
- "Epoch 27/100\n"
985
- ]
986
- },
987
- {
988
- "name": "stderr",
989
- "output_type": "stream",
990
- "text": [
991
- "Training: 100%|██████████| 62/62 [00:33<00:00, 1.84it/s]\n"
992
- ]
993
- },
994
- {
995
- "name": "stdout",
996
- "output_type": "stream",
997
- "text": [
998
- "Train Loss: 0.0346 | Acc: 0.9868\n",
999
- "Val Loss: 0.8775 | Acc: 0.8138\n",
1000
- "EarlyStopping counter: 1 out of 30\n",
1001
- "\n",
1002
- "Epoch 28/100\n"
1003
- ]
1004
- },
1005
- {
1006
- "name": "stderr",
1007
- "output_type": "stream",
1008
- "text": [
1009
- "Training: 100%|██████████| 62/62 [00:39<00:00, 1.58it/s]\n"
1010
- ]
1011
- },
1012
- {
1013
- "name": "stdout",
1014
- "output_type": "stream",
1015
- "text": [
1016
- "Train Loss: 0.0334 | Acc: 0.9868\n",
1017
- "Val Loss: 0.6443 | Acc: 0.8178\n",
1018
- "Validation loss decreased (0.820993 --> 0.644336). Saving checkpoint...\n",
1019
- "\n",
1020
- "Epoch 29/100\n"
1021
- ]
1022
- },
1023
- {
1024
- "name": "stderr",
1025
- "output_type": "stream",
1026
- "text": [
1027
- "Training: 100%|██████████| 62/62 [00:34<00:00, 1.80it/s]\n"
1028
- ]
1029
- },
1030
- {
1031
- "name": "stdout",
1032
- "output_type": "stream",
1033
- "text": [
1034
- "Train Loss: 0.0138 | Acc: 0.9919\n",
1035
- "Val Loss: 0.7253 | Acc: 0.8178\n",
1036
- "EarlyStopping counter: 1 out of 30\n",
1037
- "\n",
1038
- "Epoch 30/100\n"
1039
- ]
1040
- },
1041
- {
1042
- "name": "stderr",
1043
- "output_type": "stream",
1044
- "text": [
1045
- "Training: 100%|██████████| 62/62 [00:33<00:00, 1.88it/s]\n"
1046
- ]
1047
- },
1048
- {
1049
- "name": "stdout",
1050
- "output_type": "stream",
1051
- "text": [
1052
- "Train Loss: 0.0188 | Acc: 0.9919\n",
1053
- "Val Loss: 0.6810 | Acc: 0.7854\n",
1054
- "EarlyStopping counter: 2 out of 30\n",
1055
- "\n",
1056
- "Epoch 31/100\n"
1057
- ]
1058
- },
1059
- {
1060
- "name": "stderr",
1061
- "output_type": "stream",
1062
- "text": [
1063
- "Training: 100%|██████████| 62/62 [00:34<00:00, 1.81it/s]\n"
1064
- ]
1065
- },
1066
- {
1067
- "name": "stdout",
1068
- "output_type": "stream",
1069
- "text": [
1070
- "Train Loss: 0.0103 | Acc: 0.9949\n",
1071
- "Val Loss: 0.7952 | Acc: 0.8097\n",
1072
- "EarlyStopping counter: 3 out of 30\n",
1073
- "\n",
1074
- "Epoch 32/100\n"
1075
- ]
1076
- },
1077
- {
1078
- "name": "stderr",
1079
- "output_type": "stream",
1080
- "text": [
1081
- "Training: 100%|██████████| 62/62 [00:34<00:00, 1.80it/s]\n"
1082
- ]
1083
- },
1084
- {
1085
- "name": "stdout",
1086
- "output_type": "stream",
1087
- "text": [
1088
- "Train Loss: 0.0128 | Acc: 0.9939\n",
1089
- "Val Loss: 0.7404 | Acc: 0.8178\n",
1090
- "EarlyStopping counter: 4 out of 30\n",
1091
- "\n",
1092
- "Epoch 33/100\n"
1093
- ]
1094
- },
1095
- {
1096
- "name": "stderr",
1097
- "output_type": "stream",
1098
- "text": [
1099
- "Training: 100%|██████████| 62/62 [00:33<00:00, 1.84it/s]\n"
1100
- ]
1101
- },
1102
- {
1103
- "name": "stdout",
1104
- "output_type": "stream",
1105
- "text": [
1106
- "Train Loss: 0.0144 | Acc: 0.9919\n",
1107
- "Val Loss: 0.7201 | Acc: 0.8178\n",
1108
- "EarlyStopping counter: 5 out of 30\n",
1109
- "\n",
1110
- "Epoch 34/100\n"
1111
- ]
1112
- },
1113
- {
1114
- "name": "stderr",
1115
- "output_type": "stream",
1116
- "text": [
1117
- "Training: 100%|██████████| 62/62 [00:33<00:00, 1.84it/s]\n"
1118
- ]
1119
- },
1120
- {
1121
- "name": "stdout",
1122
- "output_type": "stream",
1123
- "text": [
1124
- "Train Loss: 0.0151 | Acc: 0.9949\n",
1125
- "Val Loss: 0.7246 | Acc: 0.8178\n",
1126
- "EarlyStopping counter: 6 out of 30\n",
1127
- "\n",
1128
- "Epoch 35/100\n"
1129
- ]
1130
- },
1131
- {
1132
- "name": "stderr",
1133
- "output_type": "stream",
1134
- "text": [
1135
- "Training: 100%|██████████| 62/62 [00:33<00:00, 1.85it/s]\n"
1136
- ]
1137
- },
1138
- {
1139
- "name": "stdout",
1140
- "output_type": "stream",
1141
- "text": [
1142
- "Train Loss: 0.0101 | Acc: 0.9960\n",
1143
- "Val Loss: 0.7968 | Acc: 0.8219\n",
1144
- "EarlyStopping counter: 7 out of 30\n",
1145
- "\n",
1146
- "Epoch 36/100\n"
1147
- ]
1148
- },
1149
- {
1150
- "name": "stderr",
1151
- "output_type": "stream",
1152
- "text": [
1153
- "Training: 100%|██████████| 62/62 [00:36<00:00, 1.71it/s]\n"
1154
- ]
1155
- },
1156
- {
1157
- "name": "stdout",
1158
- "output_type": "stream",
1159
- "text": [
1160
- "Train Loss: 0.0093 | Acc: 0.9939\n",
1161
- "Val Loss: 0.8325 | Acc: 0.8097\n",
1162
- "EarlyStopping counter: 8 out of 30\n",
1163
- "\n",
1164
- "Epoch 37/100\n"
1165
- ]
1166
- },
1167
- {
1168
- "name": "stderr",
1169
- "output_type": "stream",
1170
- "text": [
1171
- "Training: 100%|██████████| 62/62 [00:37<00:00, 1.65it/s]\n"
1172
- ]
1173
- },
1174
- {
1175
- "name": "stdout",
1176
- "output_type": "stream",
1177
- "text": [
1178
- "Train Loss: 0.0095 | Acc: 0.9960\n",
1179
- "Val Loss: 0.8379 | Acc: 0.8178\n",
1180
- "EarlyStopping counter: 9 out of 30\n",
1181
- "\n",
1182
- "Epoch 38/100\n"
1183
- ]
1184
- },
1185
- {
1186
- "name": "stderr",
1187
- "output_type": "stream",
1188
- "text": [
1189
- "Training: 100%|██████████| 62/62 [00:34<00:00, 1.82it/s]\n"
1190
- ]
1191
- },
1192
- {
1193
- "name": "stdout",
1194
- "output_type": "stream",
1195
- "text": [
1196
- "Train Loss: 0.0634 | Acc: 0.9757\n",
1197
- "Val Loss: 0.8605 | Acc: 0.7571\n",
1198
- "EarlyStopping counter: 10 out of 30\n",
1199
- "\n",
1200
- "Epoch 39/100\n"
1201
- ]
1202
- },
1203
- {
1204
- "name": "stderr",
1205
- "output_type": "stream",
1206
- "text": [
1207
- "Training: 100%|██████████| 62/62 [00:41<00:00, 1.49it/s]\n"
1208
- ]
1209
- },
1210
- {
1211
- "name": "stdout",
1212
- "output_type": "stream",
1213
- "text": [
1214
- "Train Loss: 0.1095 | Acc: 0.9615\n",
1215
- "Val Loss: 1.1914 | Acc: 0.7773\n",
1216
- "EarlyStopping counter: 11 out of 30\n",
1217
- "\n",
1218
- "Epoch 40/100\n"
1219
- ]
1220
- },
1221
- {
1222
- "name": "stderr",
1223
- "output_type": "stream",
1224
- "text": [
1225
- "Training: 100%|██████████| 62/62 [00:36<00:00, 1.69it/s]\n"
1226
- ]
1227
- },
1228
- {
1229
- "name": "stdout",
1230
- "output_type": "stream",
1231
- "text": [
1232
- "Train Loss: 0.0696 | Acc: 0.9777\n",
1233
- "Val Loss: 0.8333 | Acc: 0.8097\n",
1234
- "EarlyStopping counter: 12 out of 30\n",
1235
- "\n",
1236
- "Epoch 41/100\n"
1237
- ]
1238
- },
1239
- {
1240
- "name": "stderr",
1241
- "output_type": "stream",
1242
- "text": [
1243
- "Training: 100%|██████████| 62/62 [00:34<00:00, 1.79it/s]\n"
1244
- ]
1245
- },
1246
- {
1247
- "name": "stdout",
1248
- "output_type": "stream",
1249
- "text": [
1250
- "Train Loss: 0.0380 | Acc: 0.9828\n",
1251
- "Val Loss: 0.7910 | Acc: 0.8381\n",
1252
- "EarlyStopping counter: 13 out of 30\n",
1253
- "\n",
1254
- "Epoch 42/100\n"
1255
- ]
1256
- },
1257
- {
1258
- "name": "stderr",
1259
- "output_type": "stream",
1260
- "text": [
1261
- "Training: 100%|██████████| 62/62 [00:34<00:00, 1.78it/s]\n"
1262
- ]
1263
- },
1264
- {
1265
- "name": "stdout",
1266
- "output_type": "stream",
1267
- "text": [
1268
- "Train Loss: 0.0196 | Acc: 0.9919\n",
1269
- "Val Loss: 0.7824 | Acc: 0.8178\n",
1270
- "EarlyStopping counter: 14 out of 30\n",
1271
- "\n",
1272
- "Epoch 43/100\n"
1273
- ]
1274
- },
1275
- {
1276
- "name": "stderr",
1277
- "output_type": "stream",
1278
- "text": [
1279
- "Training: 100%|██████████| 62/62 [00:42<00:00, 1.45it/s]\n"
1280
- ]
1281
- },
1282
- {
1283
- "name": "stdout",
1284
- "output_type": "stream",
1285
- "text": [
1286
- "Train Loss: 0.0163 | Acc: 0.9939\n",
1287
- "Val Loss: 0.9460 | Acc: 0.8178\n",
1288
- "EarlyStopping counter: 15 out of 30\n",
1289
- "\n",
1290
- "Epoch 44/100\n"
1291
- ]
1292
- },
1293
- {
1294
- "name": "stderr",
1295
- "output_type": "stream",
1296
- "text": [
1297
- "Training: 100%|██████████| 62/62 [00:34<00:00, 1.80it/s]\n"
1298
- ]
1299
- },
1300
- {
1301
- "name": "stdout",
1302
- "output_type": "stream",
1303
- "text": [
1304
- "Train Loss: 0.0086 | Acc: 0.9949\n",
1305
- "Val Loss: 0.8867 | Acc: 0.8259\n",
1306
- "EarlyStopping counter: 16 out of 30\n",
1307
- "\n",
1308
- "Epoch 45/100\n"
1309
- ]
1310
- },
1311
- {
1312
- "name": "stderr",
1313
- "output_type": "stream",
1314
- "text": [
1315
- "Training: 100%|██████████| 62/62 [00:33<00:00, 1.83it/s]\n"
1316
- ]
1317
- },
1318
- {
1319
- "name": "stdout",
1320
- "output_type": "stream",
1321
- "text": [
1322
- "Train Loss: 0.0068 | Acc: 0.9949\n",
1323
- "Val Loss: 0.8065 | Acc: 0.8300\n",
1324
- "EarlyStopping counter: 17 out of 30\n",
1325
- "\n",
1326
- "Epoch 46/100\n"
1327
- ]
1328
- },
1329
- {
1330
- "name": "stderr",
1331
- "output_type": "stream",
1332
- "text": [
1333
- "Training: 100%|██████████| 62/62 [00:34<00:00, 1.82it/s]\n"
1334
- ]
1335
- },
1336
- {
1337
- "name": "stdout",
1338
- "output_type": "stream",
1339
- "text": [
1340
- "Train Loss: 0.0067 | Acc: 0.9939\n",
1341
- "Val Loss: 0.9346 | Acc: 0.8462\n",
1342
- "EarlyStopping counter: 18 out of 30\n",
1343
- "\n",
1344
- "Epoch 47/100\n"
1345
- ]
1346
- },
1347
- {
1348
- "name": "stderr",
1349
- "output_type": "stream",
1350
- "text": [
1351
- "Training: 100%|██████████| 62/62 [00:34<00:00, 1.81it/s]\n"
1352
- ]
1353
- },
1354
- {
1355
- "name": "stdout",
1356
- "output_type": "stream",
1357
- "text": [
1358
- "Train Loss: 0.0056 | Acc: 0.9970\n",
1359
- "Val Loss: 0.8019 | Acc: 0.8340\n",
1360
- "EarlyStopping counter: 19 out of 30\n",
1361
- "\n",
1362
- "Epoch 48/100\n"
1363
- ]
1364
- },
1365
- {
1366
- "name": "stderr",
1367
- "output_type": "stream",
1368
- "text": [
1369
- "Training: 100%|██████████| 62/62 [00:33<00:00, 1.83it/s]\n"
1370
- ]
1371
- },
1372
- {
1373
- "name": "stdout",
1374
- "output_type": "stream",
1375
- "text": [
1376
- "Train Loss: 0.0056 | Acc: 0.9970\n",
1377
- "Val Loss: 0.8673 | Acc: 0.8421\n",
1378
- "EarlyStopping counter: 20 out of 30\n",
1379
- "\n",
1380
- "Epoch 49/100\n"
1381
- ]
1382
- },
1383
- {
1384
- "name": "stderr",
1385
- "output_type": "stream",
1386
- "text": [
1387
- "Training: 100%|██████████| 62/62 [00:34<00:00, 1.82it/s]\n"
1388
- ]
1389
- },
1390
- {
1391
- "name": "stdout",
1392
- "output_type": "stream",
1393
- "text": [
1394
- "Train Loss: 0.0051 | Acc: 0.9949\n",
1395
- "Val Loss: 0.8609 | Acc: 0.8421\n",
1396
- "EarlyStopping counter: 21 out of 30\n",
1397
- "\n",
1398
- "Epoch 50/100\n"
1399
- ]
1400
- },
1401
- {
1402
- "name": "stderr",
1403
- "output_type": "stream",
1404
- "text": [
1405
- "Training: 100%|██████████| 62/62 [00:35<00:00, 1.74it/s]\n"
1406
- ]
1407
- },
1408
- {
1409
- "name": "stdout",
1410
- "output_type": "stream",
1411
- "text": [
1412
- "Train Loss: 0.0051 | Acc: 0.9939\n",
1413
- "Val Loss: 0.8720 | Acc: 0.8421\n",
1414
- "EarlyStopping counter: 22 out of 30\n",
1415
- "\n",
1416
- "Epoch 51/100\n"
1417
- ]
1418
- },
1419
- {
1420
- "name": "stderr",
1421
- "output_type": "stream",
1422
- "text": [
1423
- "Training: 100%|██████████| 62/62 [00:34<00:00, 1.80it/s]\n"
1424
- ]
1425
- },
1426
- {
1427
- "name": "stdout",
1428
- "output_type": "stream",
1429
- "text": [
1430
- "Train Loss: 0.0051 | Acc: 0.9970\n",
1431
- "Val Loss: 0.8415 | Acc: 0.8340\n",
1432
- "EarlyStopping counter: 23 out of 30\n",
1433
- "\n",
1434
- "Epoch 52/100\n"
1435
- ]
1436
- },
1437
- {
1438
- "name": "stderr",
1439
- "output_type": "stream",
1440
- "text": [
1441
- "Training: 100%|██████████| 62/62 [00:41<00:00, 1.51it/s]\n"
1442
- ]
1443
- },
1444
- {
1445
- "name": "stdout",
1446
- "output_type": "stream",
1447
- "text": [
1448
- "Train Loss: 0.0048 | Acc: 0.9970\n",
1449
- "Val Loss: 0.8468 | Acc: 0.8421\n",
1450
- "EarlyStopping counter: 24 out of 30\n",
1451
- "\n",
1452
- "Epoch 53/100\n"
1453
- ]
1454
- },
1455
- {
1456
- "name": "stderr",
1457
- "output_type": "stream",
1458
- "text": [
1459
- "Training: 100%|██████████| 62/62 [00:35<00:00, 1.73it/s]\n"
1460
- ]
1461
- },
1462
- {
1463
- "name": "stdout",
1464
- "output_type": "stream",
1465
- "text": [
1466
- "Train Loss: 0.0049 | Acc: 0.9960\n",
1467
- "Val Loss: 0.8831 | Acc: 0.8502\n",
1468
- "EarlyStopping counter: 25 out of 30\n",
1469
- "\n",
1470
- "Epoch 54/100\n"
1471
- ]
1472
- },
1473
- {
1474
- "name": "stderr",
1475
- "output_type": "stream",
1476
- "text": [
1477
- "Training: 100%|██████████| 62/62 [00:34<00:00, 1.80it/s]\n"
1478
- ]
1479
- },
1480
- {
1481
- "name": "stdout",
1482
- "output_type": "stream",
1483
- "text": [
1484
- "Train Loss: 0.0060 | Acc: 0.9960\n",
1485
- "Val Loss: 0.9027 | Acc: 0.8340\n",
1486
- "EarlyStopping counter: 26 out of 30\n",
1487
- "\n",
1488
- "Epoch 55/100\n"
1489
- ]
1490
- },
1491
- {
1492
- "name": "stderr",
1493
- "output_type": "stream",
1494
- "text": [
1495
- "Training: 100%|██████████| 62/62 [00:34<00:00, 1.79it/s]\n"
1496
- ]
1497
- },
1498
- {
1499
- "name": "stdout",
1500
- "output_type": "stream",
1501
- "text": [
1502
- "Train Loss: 0.0138 | Acc: 0.9919\n",
1503
- "Val Loss: 0.8650 | Acc: 0.8300\n",
1504
- "EarlyStopping counter: 27 out of 30\n",
1505
- "\n",
1506
- "Epoch 56/100\n"
1507
- ]
1508
- },
1509
- {
1510
- "name": "stderr",
1511
- "output_type": "stream",
1512
- "text": [
1513
- "Training: 100%|██████████| 62/62 [00:34<00:00, 1.80it/s]\n"
1514
- ]
1515
- },
1516
- {
1517
- "name": "stdout",
1518
- "output_type": "stream",
1519
- "text": [
1520
- "Train Loss: 0.0283 | Acc: 0.9909\n",
1521
- "Val Loss: 1.3316 | Acc: 0.7854\n",
1522
- "EarlyStopping counter: 28 out of 30\n",
1523
- "\n",
1524
- "Epoch 57/100\n"
1525
- ]
1526
- },
1527
- {
1528
- "name": "stderr",
1529
- "output_type": "stream",
1530
- "text": [
1531
- "Training: 100%|██████████| 62/62 [00:34<00:00, 1.79it/s]\n"
1532
- ]
1533
- },
1534
- {
1535
- "name": "stdout",
1536
- "output_type": "stream",
1537
- "text": [
1538
- "Train Loss: 0.0984 | Acc: 0.9666\n",
1539
- "Val Loss: 0.7924 | Acc: 0.7935\n",
1540
- "EarlyStopping counter: 29 out of 30\n",
1541
- "\n",
1542
- "Epoch 58/100\n"
1543
- ]
1544
- },
1545
- {
1546
- "name": "stderr",
1547
- "output_type": "stream",
1548
- "text": [
1549
- "Training: 100%|██████████| 62/62 [00:34<00:00, 1.80it/s]\n"
1550
- ]
1551
- },
1552
- {
1553
- "name": "stdout",
1554
- "output_type": "stream",
1555
- "text": [
1556
- "Train Loss: 0.0883 | Acc: 0.9666\n",
1557
- "Val Loss: 0.8648 | Acc: 0.7935\n",
1558
- "EarlyStopping counter: 30 out of 30\n",
1559
- "Early stopping triggered.\n"
1560
- ]
1561
- },
1562
- {
1563
- "name": "stderr",
1564
- "output_type": "stream",
1565
- "text": [
1566
- "/tmp/ipykernel_65719/2726642889.py:141: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
1567
- " checkpoint = torch.load(f'{self.model_name}_best.pt')\n",
1568
- "Testing: 100%|██████████| 20/20 [00:09<00:00, 2.15it/s]\n"
1569
- ]
1570
- },
1571
- {
1572
- "name": "stdout",
1573
- "output_type": "stream",
1574
- "text": [
1575
- "Accuracy: 0.8317\n",
1576
- "F1 Score: 0.8266\n",
1577
- "Precision: 0.8237\n",
1578
- "Recall: 0.8317\n",
1579
- "AUC (macro): 0.9197\n",
1580
- "\n",
1581
- "Classification Report:\n",
1582
- " precision recall f1-score support\n",
1583
- "\n",
1584
- " 0 0.86 0.91 0.88 99\n",
1585
- " 1 0.59 0.49 0.54 53\n",
1586
- " 2 0.88 0.90 0.89 157\n",
1587
- "\n",
1588
- " accuracy 0.83 309\n",
1589
- " macro avg 0.78 0.77 0.77 309\n",
1590
- "weighted avg 0.82 0.83 0.83 309\n",
1591
- "\n",
1592
- "\n",
1593
- "✅ Results saved to glaucoma_classification_glam_model_results_40.json\n"
1594
- ]
1595
- }
1596
- ],
1597
- "source": [
1598
- "def set_seed(seed):\n",
1599
- " torch.manual_seed(seed)\n",
1600
- " torch.cuda.manual_seed(seed)\n",
1601
- " torch.cuda.manual_seed_all(seed)\n",
1602
- " np.random.seed(seed)\n",
1603
- " random.seed(seed)\n",
1604
- " torch.backends.cudnn.deterministic = True\n",
1605
- " torch.backends.cudnn.benchmark = False\n",
1606
- "\n",
1607
- "# ✅ Set device\n",
1608
- "device = torch.device(\"cuda:7\" if torch.cuda.is_available() else \"cpu\")\n",
1609
- "print(f'🖥️ Using device: {device}')\n",
1610
- "\n",
1611
- "# ✅ Define data transforms\n",
1612
- "data_transforms = {\n",
1613
- " 'train': transforms.Compose([\n",
1614
- " transforms.RandomResizedCrop(224),\n",
1615
- " transforms.RandomHorizontalFlip(),\n",
1616
- " transforms.RandomRotation(15),\n",
1617
- " transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1),\n",
1618
- " transforms.ToTensor(),\n",
1619
- " transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])\n",
1620
- " ]),\n",
1621
- " 'val': transforms.Compose([\n",
1622
- " transforms.Resize(256),\n",
1623
- " transforms.CenterCrop(224),\n",
1624
- " transforms.ToTensor(),\n",
1625
- " transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])\n",
1626
- " ]),\n",
1627
- "}\n",
1628
- "\n",
1629
- "# ✅ Define architecture and attention variants\n",
1630
- "architectures = [\"densenet169\"]\n",
1631
- "\n",
1632
- "models_to_test=architectures\n",
1633
- "seeds_to_test = [40]\n",
1634
- "num_classes = 3\n",
1635
- "results_file = 'glaucoma_classification_glam_model_results_40.json'\n",
1636
- "\n",
1637
- "# ✅ Store all results\n",
1638
- "all_results = {}\n",
1639
- "\n",
1640
- "# ✅ Main training loop\n",
1641
- "for model_name in models_to_test:\n",
1642
- " model_results = {}\n",
1643
- " for seed in seeds_to_test:\n",
1644
- " print(f\"\\n🚀 Training {model_name} with seed {seed}\")\n",
1645
- " set_seed(seed)\n",
1646
- "\n",
1647
- " try:\n",
1648
- " # Load dataset\n",
1649
- " dataset = datasets.ImageFolder(\n",
1650
- " root=\"/raid/home/sbag/B.tech/Final Results/processed_data\",\n",
1651
- " transform=data_transforms['train']\n",
1652
- " )\n",
1653
- "\n",
1654
- " # Split data\n",
1655
- " total_size = len(dataset)\n",
1656
- " train_size = int(0.8 * total_size)\n",
1657
- " test_size = total_size - train_size\n",
1658
- " train_val_dataset, test_dataset = random_split(dataset, [train_size, test_size])\n",
1659
- "\n",
1660
- " val_size = int(0.2 * train_size)\n",
1661
- " train_size = train_size - val_size\n",
1662
- " train_dataset, val_dataset = random_split(train_val_dataset, [train_size, val_size])\n",
1663
- "\n",
1664
- " val_dataset.dataset.transform = data_transforms['val']\n",
1665
- " test_dataset.dataset.transform = data_transforms['val']\n",
1666
- "\n",
1667
- " # Create loaders\n",
1668
- " train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=4, pin_memory=True)\n",
1669
- " val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False, num_workers=4, pin_memory=True)\n",
1670
- " test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False, num_workers=4, pin_memory=True)\n",
1671
- "\n",
1672
- " # Get model with attention\n",
1673
- " model = get_model_with_attention(model_name, num_classes=num_classes).to(device)\n",
1674
- "\n",
1675
- " # Loss and optimizer\n",
1676
- " criterion = nn.CrossEntropyLoss()\n",
1677
- " optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)\n",
1678
- "\n",
1679
- " # Train\n",
1680
- " trainer = ModelTrainer(\n",
1681
- " model=model,\n",
1682
- " train_loader=train_loader,\n",
1683
- " val_loader=val_loader,\n",
1684
- " test_loader=test_loader,\n",
1685
- " criterion=criterion,\n",
1686
- " optimizer=optimizer,\n",
1687
- " device=device,\n",
1688
- " model_name=model_name,\n",
1689
- " seed=seed\n",
1690
- " )\n",
1691
- "\n",
1692
- " history = trainer.train(num_epochs=100, patience=30)\n",
1693
- " results = trainer.evaluate()\n",
1694
- " model_results[f\"seed_{seed}\"] = results\n",
1695
- "\n",
1696
- " except Exception as e:\n",
1697
- " print(f\"❌ Error while training {model_name} (seed={seed}): {e}\")\n",
1698
- " model_results[f\"seed_{seed}\"] = {\"error\": str(e)}\n",
1699
- "\n",
1700
- " all_results[model_name] = model_results\n",
1701
- "\n",
1702
- "# ✅ Save all results\n",
1703
- "with open(results_file, 'w') as f:\n",
1704
- " json.dump(all_results, f, indent=4)\n",
1705
- "\n",
1706
- "print(f\"\\n✅ Results saved to {results_file}\")"
1707
- ]
1708
- }
1709
- ],
1710
- "metadata": {
1711
- "kernelspec": {
1712
- "display_name": "yashwantvenv",
1713
- "language": "python",
1714
- "name": "python3"
1715
- },
1716
- "language_info": {
1717
- "codemirror_mode": {
1718
- "name": "ipython",
1719
- "version": 3
1720
- },
1721
- "file_extension": ".py",
1722
- "mimetype": "text/x-python",
1723
- "name": "python",
1724
- "nbconvert_exporter": "python",
1725
- "pygments_lexer": "ipython3",
1726
- "version": "3.8.19"
1727
- }
1728
- },
1729
- "nbformat": 4,
1730
- "nbformat_minor": 2
1731
- }