sakshamlakhera committed
Commit
1265dde
·
1 Parent(s): b274faf

fixing scripts

scripts/CV/Part1.ipynb CHANGED
@@ -66,10 +66,10 @@
66
  "def augment_rotations(X, y):\n",
67
  " X_aug = []\n",
68
  " y_aug = []\n",
69
- " for k in [1, 2, 3]: # 90, 180, 270 degrees\n",
70
- " X_rot = torch.rot90(X, k=k, dims=[2, 3]) # rotate along H and W\n",
71
  " X_aug.append(X_rot)\n",
72
- " y_aug.append(y.clone()) # Same labels for rotated images\n",
73
  " return torch.cat(X_aug), torch.cat(y_aug)\n"
74
  ]
75
  },
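For reference, a self-contained sketch of the rotation augmentation these hunks touch; the batch shape and class count below are hypothetical.

```python
import torch

def augment_rotations(X, y):
    """Return 90/180/270-degree rotations of X (N, C, H, W) with copied labels."""
    X_aug, y_aug = [], []
    for k in [1, 2, 3]:  # k quarter-turns: 90, 180, 270 degrees
        X_aug.append(torch.rot90(X, k=k, dims=[2, 3]))  # rotate in the H/W plane
        y_aug.append(y.clone())  # labels are rotation-invariant
    return torch.cat(X_aug), torch.cat(y_aug)

X = torch.randn(8, 3, 224, 224)   # hypothetical: 8 RGB images
y = torch.randint(0, 4, (8,))     # hypothetical: 4 classes
X_rot, y_rot = augment_rotations(X, y)
print(X_rot.shape, y_rot.shape)   # [24, 3, 224, 224], [24]
```

Note the function returns only the rotated copies (3x the input); the notebooks concatenate them with the originals afterwards.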
@@ -165,7 +165,6 @@
165
  " plt.suptitle(f\"{class_name.capitalize()} – Random {count} Samples\", fontsize=16)\n",
166
  " plt.show()\n",
167
  "\n",
168
- "# Display for each class\n",
169
  "for class_name, image_array in datasets.items():\n",
170
  " show_random_samples(image_array, class_name)\n"
171
  ]
@@ -189,7 +188,7 @@
189
  "\n",
190
  "for ax, (class_name, images) in zip(axes, datasets.items()):\n",
191
  " plot_rgb_histogram_subplot(ax, images, class_name)\n",
192
- " ax.label_outer() # Hide x labels and tick labels for inner plots\n",
193
  "\n",
194
  "plt.tight_layout()\n",
195
  "plt.show()\n"
@@ -305,7 +304,7 @@
305
  "class_names = list(datasets.keys())\n",
306
  "num_classes = len(class_names)\n",
307
  "\n",
308
- "fig, axes = plt.subplots(1, num_classes, figsize=(4 * num_classes, 4)) # 1 row, 4 columns\n",
309
  "\n",
310
  "for i, (class_name, images) in enumerate(datasets.items()):\n",
311
  " avg_img = np.mean(images.astype(np.float32), axis=0)\n",
@@ -371,7 +370,6 @@
371
  "from sklearn.model_selection import train_test_split\n",
372
  "from torchvision import transforms\n",
373
  "\n",
374
- "# Combine data\n",
375
  "X = np.concatenate([onion_images, strawberry_images, pear_images, tomato_images], axis=0)\n",
376
  "y = (\n",
377
  " ['onion'] * len(onion_images) +\n",
@@ -380,16 +378,14 @@
380
  " ['tomato'] * len(tomato_images)\n",
381
  ")\n",
382
  "\n",
383
- "# Normalizing image\n",
384
  "X = X.astype(np.float32) / 255.0\n",
385
- "X = np.transpose(X, (0, 3, 1, 2)) # (N, C, H, W)\n",
386
  "X_tensor = torch.tensor(X)\n",
387
  "\n",
388
  "le = LabelEncoder()\n",
389
  "y_encoded = le.fit_transform(y)\n",
390
  "y_tensor = torch.tensor(y_encoded)\n",
391
  "\n",
392
- "# splitting data into 50:25:25 (train, validation, test)\n",
393
  "X_train, X_temp, y_train, y_temp = train_test_split(X_tensor, y_tensor, test_size=0.5, stratify=y_tensor, random_state=42)\n",
394
  "X_val, X_test, y_val, y_test = train_test_split(X_temp, y_temp, test_size=0.5, stratify=y_temp, random_state=42)\n"
395
  ]
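The preprocessing and split above can be exercised end to end; a runnable sketch with hypothetical array sizes (the chained `train_test_split` calls give the 50:25:25 train/val/test ratio):

```python
import numpy as np
import torch
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split

# Hypothetical stand-in for the stacked image arrays: (N, H, W, C) uint8
X = np.random.randint(0, 256, (40, 64, 64, 3), dtype=np.uint8)
y = ["onion"] * 10 + ["strawberry"] * 10 + ["pear"] * 10 + ["tomato"] * 10

X = X.astype(np.float32) / 255.0   # scale pixels to [0, 1]
X = np.transpose(X, (0, 3, 1, 2))  # (N, H, W, C) -> (N, C, H, W) for PyTorch
X_tensor = torch.tensor(X)

le = LabelEncoder()                # classes_ ends up alphabetical
y_tensor = torch.tensor(le.fit_transform(y))

# 50% train, then halve the remainder: 50:25:25 overall, stratified per class
X_train, X_temp, y_train, y_temp = train_test_split(
    X_tensor, y_tensor, test_size=0.5, stratify=y_tensor, random_state=42)
X_val, X_test, y_val, y_test = train_test_split(
    X_temp, y_temp, test_size=0.5, stratify=y_temp, random_state=42)
print(len(X_train), len(X_val), len(X_test))  # 20 10 10
```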
@@ -403,13 +399,10 @@
403
  "source": [
404
  "batch_size = 32\n",
405
  "\n",
406
- "# Create new training dataset and loader\n",
407
  "train_dataset = TensorDataset(X_train, y_train)\n",
408
  "val_dataset = TensorDataset(X_val, y_val)\n",
409
  "test_dataset = TensorDataset(X_test, y_test)\n",
410
  "\n",
411
- "# DataLoaders\n",
412
- "\n",
413
  "train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)\n",
414
  "val_loader = DataLoader(val_dataset, batch_size=batch_size)\n",
415
  "test_loader = DataLoader(test_dataset, batch_size=batch_size)"
@@ -432,9 +425,9 @@
432
  "metadata": {},
433
  "outputs": [],
434
  "source": [
435
- "print(f\"🔢 Train Dataset: {len(train_dataset)} samples, {len(train_loader)} batches\")\n",
436
- "print(f\"🔢 Val Dataset: {len(val_dataset)} samples, {len(val_loader)} batches\")\n",
437
- "print(f\"🔢 Test Dataset: {len(test_dataset)} samples, {len(test_loader)} batches\")"
438
  ]
439
  },
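The printed sample/batch counts are related by the batch size: with the default `drop_last=False`, `len(loader)` equals `ceil(len(dataset) / batch_size)`. A quick check under hypothetical sizes:

```python
import math
import torch
from torch.utils.data import TensorDataset, DataLoader

dataset = TensorDataset(torch.randn(100, 3, 8, 8), torch.zeros(100, dtype=torch.long))
loader = DataLoader(dataset, batch_size=32)
assert len(loader) == math.ceil(len(dataset) / 32)  # 100 samples -> 4 batches
print(f"Train Dataset: {len(dataset)} samples, {len(loader)} batches")
```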
440
  {
@@ -521,8 +514,6 @@
521
  " optimizer.step()\n",
522
  "\n",
523
  " total_train_loss += loss.item()\n",
524
- "\n",
525
- " # Track training accuracy\n",
526
  " pred_labels = preds.argmax(dim=1)\n",
527
  " train_correct += (pred_labels == batch_y).sum().item()\n",
528
  " train_total += batch_y.size(0)\n",
@@ -546,7 +537,6 @@
546
  " val_accuracy = val_correct / val_total\n",
547
  " validation_loss = criterion(model(val_x), val_y).item()\n",
548
  "\n",
549
- " # After calculating val_accuracy\n",
550
  " val_losses.append(validation_loss)\n",
551
  " val_accs.append(val_accuracy)\n",
552
  "\n",
@@ -637,7 +627,7 @@
637
  "\n",
638
  "print(f\"\\nTest Accuracy: {test_accuracy:.4f}\")\n",
639
  "\n",
640
- "target_names = le.classes_ # ['onion', 'pear', 'strawberry', 'tomato']\n",
641
  "print(\"\\nClassification Report:\\n\")\n",
642
  "print(classification_report(all_targets, all_preds, target_names=target_names))\n",
643
  "\n",
@@ -726,7 +716,7 @@
726
  " h.remove()\n",
727
  "\n",
728
  " for layer_name, fmap in activations.items():\n",
729
- " fmap = fmap.squeeze(0) # [C, H, W]\n",
730
  " num_channels = min(fmap.shape[0], max_channels)\n",
731
  "\n",
732
  " plt.figure(figsize=(num_channels * 2, 2.5))\n",
@@ -758,29 +748,22 @@
758
  " activations[name] = output.detach().cpu()\n",
759
  " return hook\n",
760
  "\n",
761
- " # Register hooks for all layers in model.features\n",
762
  " hooks = []\n",
763
  " for i in range(len(model.features)):\n",
764
  " layer = model.features[i]\n",
765
  " hooks.append(layer.register_forward_hook(get_activation(f\"features_{i}\")))\n",
766
  "\n",
767
  " with torch.no_grad():\n",
768
- " _ = model(image_tensor.unsqueeze(0)) # Add batch dimension: [1, 3, 224, 224]\n",
769
  "\n",
770
  " for h in hooks:\n",
771
  " h.remove()\n",
772
  "\n",
773
  " for layer_name, fmap in activations.items():\n",
774
- " fmap = fmap.squeeze(0) # Shape: [C, H, W]\n",
775
- "\n",
776
- " # Compute mean activation per channel\n",
777
- " channel_scores = fmap.mean(dim=(1, 2)) # [C]\n",
778
- "\n",
779
- " # Get indices of top-k channels\n",
780
  " topk = torch.topk(channel_scores, k=min(max_channels, fmap.shape[0]))\n",
781
  " top_indices = topk.indices\n",
782
- "\n",
783
- " # Plot top-k channels\n",
784
  " plt.figure(figsize=(max_channels * 2, 2.5))\n",
785
  " for idx, ch in enumerate(top_indices):\n",
786
  " plt.subplot(1, max_channels, idx + 1)\n",
@@ -821,14 +804,11 @@
821
  "\n",
822
  "img = Image.open(\"dataset/Onion_512/Whole/image_0001.jpg\").convert(\"RGB\")\n",
823
  "\n",
824
- "# Preprocessing (must match model requirements)\n",
825
  "transform = transforms.Compose([\n",
826
  " transforms.Resize((224, 224)),\n",
827
  " transforms.ToTensor()\n",
828
  "])\n",
829
- "img_tensor = transform(img) # shape: [3, 224, 224]\n",
830
- "\n",
831
- "# Visualize feature maps\n",
832
  "visualize_channels(model, img_tensor, max_channels=16)\n"
833
  ]
834
  },
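The visualization cells repeat the same preprocessing; note that `transforms.ToTensor()` both scales pixels to [0, 1] and reorders to [C, H, W], so no separate normalization step is needed here. A sketch with a synthetic image standing in for the JPEGs:

```python
from PIL import Image
from torchvision import transforms

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),  # PIL [H, W, C] uint8 -> float tensor [C, H, W] in [0, 1]
])
img = Image.new("RGB", (512, 512))   # hypothetical; the notebook opens a JPEG
img_tensor = transform(img)
print(img_tensor.shape)              # torch.Size([3, 224, 224])
```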
@@ -849,14 +829,12 @@
849
  "source": [
850
  "img = Image.open(\"dataset/Pear_512/Whole/image_0089.jpg\").convert(\"RGB\")\n",
851
  "\n",
852
- "# Preprocessing (must match model requirements)\n",
853
  "transform = transforms.Compose([\n",
854
  " transforms.Resize((224, 224)),\n",
855
  " transforms.ToTensor()\n",
856
  "])\n",
857
- "img_tensor = transform(img) # shape: [3, 224, 224]\n",
858
  "\n",
859
- "# Visualize feature maps\n",
860
  "visualize_channels(model, img_tensor, max_channels=16)\n"
861
  ]
862
  },
@@ -877,14 +855,12 @@
877
  "source": [
878
  "img = Image.open(\"dataset/Tomato_512/Whole/image_0001.jpg\").convert(\"RGB\")\n",
879
  "\n",
880
- "# Preprocessing (must match model requirements)\n",
881
  "transform = transforms.Compose([\n",
882
  " transforms.Resize((224, 224)),\n",
883
  " transforms.ToTensor()\n",
884
  "])\n",
885
- "img_tensor = transform(img) # shape: [3, 224, 224]\n",
886
  "\n",
887
- "# Visualize feature maps\n",
888
  "visualize_channels(model, img_tensor, max_channels=16)\n"
889
  ]
890
  },
 
66
  "def augment_rotations(X, y):\n",
67
  " X_aug = []\n",
68
  " y_aug = []\n",
69
+ " for k in [1, 2, 3]: \n",
70
+ " X_rot = torch.rot90(X, k=k, dims=[2, 3])\n",
71
  " X_aug.append(X_rot)\n",
72
+ " y_aug.append(y.clone())\n",
73
  " return torch.cat(X_aug), torch.cat(y_aug)\n"
74
  ]
75
  },
 
165
  " plt.suptitle(f\"{class_name.capitalize()} – Random {count} Samples\", fontsize=16)\n",
166
  " plt.show()\n",
167
  "\n",
 
168
  "for class_name, image_array in datasets.items():\n",
169
  " show_random_samples(image_array, class_name)\n"
170
  ]
 
188
  "\n",
189
  "for ax, (class_name, images) in zip(axes, datasets.items()):\n",
190
  " plot_rgb_histogram_subplot(ax, images, class_name)\n",
191
+ " ax.label_outer()\n",
192
  "\n",
193
  "plt.tight_layout()\n",
194
  "plt.show()\n"
 
304
  "class_names = list(datasets.keys())\n",
305
  "num_classes = len(class_names)\n",
306
  "\n",
307
+ "fig, axes = plt.subplots(1, num_classes, figsize=(4 * num_classes, 4))\n",
308
  "\n",
309
  "for i, (class_name, images) in enumerate(datasets.items()):\n",
310
  " avg_img = np.mean(images.astype(np.float32), axis=0)\n",
 
370
  "from sklearn.model_selection import train_test_split\n",
371
  "from torchvision import transforms\n",
372
  "\n",
 
373
  "X = np.concatenate([onion_images, strawberry_images, pear_images, tomato_images], axis=0)\n",
374
  "y = (\n",
375
  " ['onion'] * len(onion_images) +\n",
 
378
  " ['tomato'] * len(tomato_images)\n",
379
  ")\n",
380
  "\n",
 
381
  "X = X.astype(np.float32) / 255.0\n",
382
+ "X = np.transpose(X, (0, 3, 1, 2)) \n",
383
  "X_tensor = torch.tensor(X)\n",
384
  "\n",
385
  "le = LabelEncoder()\n",
386
  "y_encoded = le.fit_transform(y)\n",
387
  "y_tensor = torch.tensor(y_encoded)\n",
388
  "\n",
 
389
  "X_train, X_temp, y_train, y_temp = train_test_split(X_tensor, y_tensor, test_size=0.5, stratify=y_tensor, random_state=42)\n",
390
  "X_val, X_test, y_val, y_test = train_test_split(X_temp, y_temp, test_size=0.5, stratify=y_temp, random_state=42)\n"
391
  ]
 
399
  "source": [
400
  "batch_size = 32\n",
401
  "\n",
 
402
  "train_dataset = TensorDataset(X_train, y_train)\n",
403
  "val_dataset = TensorDataset(X_val, y_val)\n",
404
  "test_dataset = TensorDataset(X_test, y_test)\n",
405
  "\n",
 
 
406
  "train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)\n",
407
  "val_loader = DataLoader(val_dataset, batch_size=batch_size)\n",
408
  "test_loader = DataLoader(test_dataset, batch_size=batch_size)"
 
425
  "metadata": {},
426
  "outputs": [],
427
  "source": [
428
+ "print(f\"Train Dataset: {len(train_dataset)} samples, {len(train_loader)} batches\")\n",
429
+ "print(f\"Val Dataset: {len(val_dataset)} samples, {len(val_loader)} batches\")\n",
430
+ "print(f\"Test Dataset: {len(test_dataset)} samples, {len(test_loader)} batches\")"
431
  ]
432
  },
433
  {
 
514
  " optimizer.step()\n",
515
  "\n",
516
  " total_train_loss += loss.item()\n",
 
 
517
  " pred_labels = preds.argmax(dim=1)\n",
518
  " train_correct += (pred_labels == batch_y).sum().item()\n",
519
  " train_total += batch_y.size(0)\n",
 
537
  " val_accuracy = val_correct / val_total\n",
538
  " validation_loss = criterion(model(val_x), val_y).item()\n",
539
  "\n",
 
540
  " val_losses.append(validation_loss)\n",
541
  " val_accs.append(val_accuracy)\n",
542
  "\n",
 
627
  "\n",
628
  "print(f\"\\nTest Accuracy: {test_accuracy:.4f}\")\n",
629
  "\n",
630
+ "target_names = le.classes_ \n",
631
  "print(\"\\nClassification Report:\\n\")\n",
632
  "print(classification_report(all_targets, all_preds, target_names=target_names))\n",
633
  "\n",
 
716
  " h.remove()\n",
717
  "\n",
718
  " for layer_name, fmap in activations.items():\n",
719
+ " fmap = fmap.squeeze(0)\n",
720
  " num_channels = min(fmap.shape[0], max_channels)\n",
721
  "\n",
722
  " plt.figure(figsize=(num_channels * 2, 2.5))\n",
 
748
  " activations[name] = output.detach().cpu()\n",
749
  " return hook\n",
750
  "\n",
 
751
  " hooks = []\n",
752
  " for i in range(len(model.features)):\n",
753
  " layer = model.features[i]\n",
754
  " hooks.append(layer.register_forward_hook(get_activation(f\"features_{i}\")))\n",
755
  "\n",
756
  " with torch.no_grad():\n",
757
+ " _ = model(image_tensor.unsqueeze(0))\n",
758
  "\n",
759
  " for h in hooks:\n",
760
  " h.remove()\n",
761
  "\n",
762
  " for layer_name, fmap in activations.items():\n",
763
+ " fmap = fmap.squeeze(0)\n",
764
+ " channel_scores = fmap.mean(dim=(1, 2)) \n",
 
 
 
 
765
  " topk = torch.topk(channel_scores, k=min(max_channels, fmap.shape[0]))\n",
766
  " top_indices = topk.indices\n",
 
 
767
  " plt.figure(figsize=(max_channels * 2, 2.5))\n",
768
  " for idx, ch in enumerate(top_indices):\n",
769
  " plt.subplot(1, max_channels, idx + 1)\n",
 
804
  "\n",
805
  "img = Image.open(\"dataset/Onion_512/Whole/image_0001.jpg\").convert(\"RGB\")\n",
806
  "\n",
 
807
  "transform = transforms.Compose([\n",
808
  " transforms.Resize((224, 224)),\n",
809
  " transforms.ToTensor()\n",
810
  "])\n",
811
+ "img_tensor = transform(img)\n",
 
 
812
  "visualize_channels(model, img_tensor, max_channels=16)\n"
813
  ]
814
  },
 
829
  "source": [
830
  "img = Image.open(\"dataset/Pear_512/Whole/image_0089.jpg\").convert(\"RGB\")\n",
831
  "\n",
 
832
  "transform = transforms.Compose([\n",
833
  " transforms.Resize((224, 224)),\n",
834
  " transforms.ToTensor()\n",
835
  "])\n",
836
+ "img_tensor = transform(img)\n",
837
  "\n",
 
838
  "visualize_channels(model, img_tensor, max_channels=16)\n"
839
  ]
840
  },
 
855
  "source": [
856
  "img = Image.open(\"dataset/Tomato_512/Whole/image_0001.jpg\").convert(\"RGB\")\n",
857
  "\n",
 
858
  "transform = transforms.Compose([\n",
859
  " transforms.Resize((224, 224)),\n",
860
  " transforms.ToTensor()\n",
861
  "])\n",
862
+ "img_tensor = transform(img)\n",
863
  "\n",
 
864
  "visualize_channels(model, img_tensor, max_channels=16)\n"
865
  ]
866
  },
scripts/CV/compression.ipynb CHANGED
@@ -18,8 +18,8 @@
18
  "import os\n",
19
  "from PIL import Image, ImageOps\n",
20
  "\n",
21
- "input_root = 'Tomato' # Root folder with raw images\n",
22
- "output_root = 'Tomato_512' # Output root folder\n",
23
  "os.makedirs(output_root, exist_ok=True)\n",
24
  "\n",
25
  "def process_image(input_path, output_path, size=(512, 512)):\n",
@@ -27,14 +27,12 @@
27
  " with Image.open(input_path) as img:\n",
28
  " img = img.convert(\"RGB\")\n",
29
  "\n",
30
- " # Resize while preserving aspect ratio, then pad to 512x512\n",
31
  " img = ImageOps.fit(img, size, Image.LANCZOS, centering=(0.5, 0.5))\n",
32
  " os.makedirs(os.path.dirname(output_path), exist_ok=True)\n",
33
  " img.save(output_path, \"JPEG\", quality=95)\n",
34
  " except Exception as e:\n",
35
- " print(f\"Error processing {input_path}: {e}\")\n",
36
  "\n",
37
- "# Recursively walk through input_root\n",
38
  "for root, _, files in os.walk(input_root):\n",
39
  " for file in files:\n",
40
  " if file.lower().endswith((\".jpg\", \".jpeg\")):\n",
@@ -43,7 +41,7 @@
43
  " output_path = os.path.join(output_root, rel_path)\n",
44
  " process_image(input_path, output_path)\n",
45
  "\n",
46
- "print(\"All images processed and saved in\", output_root)\n"
47
  ]
48
  },
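All four cells in this notebook run the same resize-and-save routine over a class folder. One detail worth flagging: `ImageOps.fit` resizes to cover the target size and then center-crops; it does not pad, despite the removed comment's wording. A consolidated sketch (the `rel_path` computation is elided in the diff; `os.path.relpath` is a plausible stand-in):

```python
import os
from PIL import Image, ImageOps

input_root, output_root = "Tomato", "Tomato_512"   # folder names from the notebook

def process_image(input_path, output_path, size=(512, 512)):
    """Resize to cover `size`, center-crop to square, save as high-quality JPEG."""
    try:
        with Image.open(input_path) as img:
            img = img.convert("RGB")
            img = ImageOps.fit(img, size, Image.LANCZOS, centering=(0.5, 0.5))
            os.makedirs(os.path.dirname(output_path), exist_ok=True)
            img.save(output_path, "JPEG", quality=95)
    except Exception as e:
        print(f"Error processing {input_path}: {e}")

for root, _, files in os.walk(input_root):
    for file in files:
        if file.lower().endswith((".jpg", ".jpeg")):
            input_path = os.path.join(root, file)
            rel_path = os.path.relpath(input_path, input_root)   # assumption
            process_image(input_path, os.path.join(output_root, rel_path))

print("All images processed and saved in", output_root)
```

Parameterizing `input_root`/`output_root` this way would also let the four near-identical cells collapse into one loop over class names.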
49
  {
@@ -56,23 +54,20 @@
56
  "import os\n",
57
  "from PIL import Image, ImageOps\n",
58
  "\n",
59
- "input_root = 'Onion' # Root folder with raw images\n",
60
- "output_root = 'Onion_512' # Output root folder\n",
61
  "os.makedirs(output_root, exist_ok=True)\n",
62
  "\n",
63
  "def process_image(input_path, output_path, size=(512, 512)):\n",
64
  " try:\n",
65
  " with Image.open(input_path) as img:\n",
66
  " img = img.convert(\"RGB\")\n",
67
- "\n",
68
- " # Resize while preserving aspect ratio, then pad to 512x512\n",
69
  " img = ImageOps.fit(img, size, Image.LANCZOS, centering=(0.5, 0.5))\n",
70
  " os.makedirs(os.path.dirname(output_path), exist_ok=True)\n",
71
  " img.save(output_path, \"JPEG\", quality=95)\n",
72
  " except Exception as e:\n",
73
- " print(f\"Error processing {input_path}: {e}\")\n",
74
  "\n",
75
- "# Recursively walk through input_root\n",
76
  "for root, _, files in os.walk(input_root):\n",
77
  " for file in files:\n",
78
  " if file.lower().endswith((\".jpg\", \".jpeg\")):\n",
@@ -81,7 +76,7 @@
81
  " output_path = os.path.join(output_root, rel_path)\n",
82
  " process_image(input_path, output_path)\n",
83
  "\n",
84
- "print(\"All images processed and saved in\", output_root)\n"
85
  ]
86
  },
87
  {
@@ -94,23 +89,20 @@
94
  "import os\n",
95
  "from PIL import Image, ImageOps\n",
96
  "\n",
97
- "input_root = 'Pear' # Root folder with raw images\n",
98
- "output_root = 'Pear_512' # Output root folder\n",
99
  "os.makedirs(output_root, exist_ok=True)\n",
100
  "\n",
101
  "def process_image(input_path, output_path, size=(512, 512)):\n",
102
  " try:\n",
103
  " with Image.open(input_path) as img:\n",
104
  " img = img.convert(\"RGB\")\n",
105
- "\n",
106
- " # Resize while preserving aspect ratio, then pad to 512x512\n",
107
  " img = ImageOps.fit(img, size, Image.LANCZOS, centering=(0.5, 0.5))\n",
108
  " os.makedirs(os.path.dirname(output_path), exist_ok=True)\n",
109
  " img.save(output_path, \"JPEG\", quality=95)\n",
110
  " except Exception as e:\n",
111
- " print(f\"Error processing {input_path}: {e}\")\n",
112
  "\n",
113
- "# Recursively walk through input_root\n",
114
  "for root, _, files in os.walk(input_root):\n",
115
  " for file in files:\n",
116
  " if file.lower().endswith((\".jpg\", \".jpeg\")):\n",
@@ -119,7 +111,7 @@
119
  " output_path = os.path.join(output_root, rel_path)\n",
120
  " process_image(input_path, output_path)\n",
121
  "\n",
122
- "print(\"All images processed and saved in\", output_root)\n"
123
  ]
124
  },
125
  {
@@ -132,23 +124,21 @@
132
  "import os\n",
133
  "from PIL import Image, ImageOps\n",
134
  "\n",
135
- "input_root = 'Strawberry' # Root folder with raw images\n",
136
- "output_root = 'Strawberry_512' # Output root folder\n",
137
  "os.makedirs(output_root, exist_ok=True)\n",
138
  "\n",
139
  "def process_image(input_path, output_path, size=(512, 512)):\n",
140
  " try:\n",
141
  " with Image.open(input_path) as img:\n",
142
  " img = img.convert(\"RGB\")\n",
143
- "\n",
144
- " # Resize while preserving aspect ratio, then pad to 512x512\n",
145
  " img = ImageOps.fit(img, size, Image.LANCZOS, centering=(0.5, 0.5))\n",
146
  " os.makedirs(os.path.dirname(output_path), exist_ok=True)\n",
147
  " img.save(output_path, \"JPEG\", quality=95)\n",
148
  " except Exception as e:\n",
149
- " print(f\"Error processing {input_path}: {e}\")\n",
 
150
  "\n",
151
- "# Recursively walk through input_root\n",
152
  "for root, _, files in os.walk(input_root):\n",
153
  " for file in files:\n",
154
  " if file.lower().endswith((\".jpg\", \".jpeg\")):\n",
@@ -157,16 +147,8 @@
157
  " output_path = os.path.join(output_root, rel_path)\n",
158
  " process_image(input_path, output_path)\n",
159
  "\n",
160
- "print(\"All images processed and saved in\", output_root)\n"
161
  ]
162
- },
163
- {
164
- "cell_type": "code",
165
- "execution_count": null,
166
- "id": "fd49ae48",
167
- "metadata": {},
168
- "outputs": [],
169
- "source": []
170
  }
171
  ],
172
  "metadata": {
 
18
  "import os\n",
19
  "from PIL import Image, ImageOps\n",
20
  "\n",
21
+ "input_root = 'Tomato' \n",
22
+ "output_root = 'Tomato_512' \n",
23
  "os.makedirs(output_root, exist_ok=True)\n",
24
  "\n",
25
  "def process_image(input_path, output_path, size=(512, 512)):\n",
 
27
  " with Image.open(input_path) as img:\n",
28
  " img = img.convert(\"RGB\")\n",
29
  "\n",
 
30
  " img = ImageOps.fit(img, size, Image.LANCZOS, centering=(0.5, 0.5))\n",
31
  " os.makedirs(os.path.dirname(output_path), exist_ok=True)\n",
32
  " img.save(output_path, \"JPEG\", quality=95)\n",
33
  " except Exception as e:\n",
34
+ " print(f\"Error processing {input_path}: {e}\")\n",
35
  "\n",
 
36
  "for root, _, files in os.walk(input_root):\n",
37
  " for file in files:\n",
38
  " if file.lower().endswith((\".jpg\", \".jpeg\")):\n",
 
41
  " output_path = os.path.join(output_root, rel_path)\n",
42
  " process_image(input_path, output_path)\n",
43
  "\n",
44
+ "print(\"All images processed and saved in\", output_root)\n"
45
  ]
46
  },
47
  {
 
54
  "import os\n",
55
  "from PIL import Image, ImageOps\n",
56
  "\n",
57
+ "input_root = 'Onion' \n",
58
+ "output_root = 'Onion_512' \n",
59
  "os.makedirs(output_root, exist_ok=True)\n",
60
  "\n",
61
  "def process_image(input_path, output_path, size=(512, 512)):\n",
62
  " try:\n",
63
  " with Image.open(input_path) as img:\n",
64
  " img = img.convert(\"RGB\")\n",
 
 
65
  " img = ImageOps.fit(img, size, Image.LANCZOS, centering=(0.5, 0.5))\n",
66
  " os.makedirs(os.path.dirname(output_path), exist_ok=True)\n",
67
  " img.save(output_path, \"JPEG\", quality=95)\n",
68
  " except Exception as e:\n",
69
+ " print(f\"Error processing {input_path}: {e}\")\n",
70
  "\n",
 
71
  "for root, _, files in os.walk(input_root):\n",
72
  " for file in files:\n",
73
  " if file.lower().endswith((\".jpg\", \".jpeg\")):\n",
 
76
  " output_path = os.path.join(output_root, rel_path)\n",
77
  " process_image(input_path, output_path)\n",
78
  "\n",
79
+ "print(\"All images processed and saved in\", output_root)\n"
80
  ]
81
  },
82
  {
 
89
  "import os\n",
90
  "from PIL import Image, ImageOps\n",
91
  "\n",
92
+ "input_root = 'Pear' \n",
93
+ "output_root = 'Pear_512' \n",
94
  "os.makedirs(output_root, exist_ok=True)\n",
95
  "\n",
96
  "def process_image(input_path, output_path, size=(512, 512)):\n",
97
  " try:\n",
98
  " with Image.open(input_path) as img:\n",
99
  " img = img.convert(\"RGB\")\n",
 
 
100
  " img = ImageOps.fit(img, size, Image.LANCZOS, centering=(0.5, 0.5))\n",
101
  " os.makedirs(os.path.dirname(output_path), exist_ok=True)\n",
102
  " img.save(output_path, \"JPEG\", quality=95)\n",
103
  " except Exception as e:\n",
104
+ " print(f\"Error processing {input_path}: {e}\")\n",
105
  "\n",
 
106
  "for root, _, files in os.walk(input_root):\n",
107
  " for file in files:\n",
108
  " if file.lower().endswith((\".jpg\", \".jpeg\")):\n",
 
111
  " output_path = os.path.join(output_root, rel_path)\n",
112
  " process_image(input_path, output_path)\n",
113
  "\n",
114
+ "print(\"All images processed and saved in\", output_root)\n"
115
  ]
116
  },
117
  {
 
124
  "import os\n",
125
  "from PIL import Image, ImageOps\n",
126
  "\n",
127
+ "input_root = 'Strawberry' \n",
128
+ "output_root = 'Strawberry_512' \n",
129
  "os.makedirs(output_root, exist_ok=True)\n",
130
  "\n",
131
  "def process_image(input_path, output_path, size=(512, 512)):\n",
132
  " try:\n",
133
  " with Image.open(input_path) as img:\n",
134
  " img = img.convert(\"RGB\")\n",
 
 
135
  " img = ImageOps.fit(img, size, Image.LANCZOS, centering=(0.5, 0.5))\n",
136
  " os.makedirs(os.path.dirname(output_path), exist_ok=True)\n",
137
  " img.save(output_path, \"JPEG\", quality=95)\n",
138
  " except Exception as e:\n",
139
+ " print(f\"Error processing {input_path}: {e}\")\n",
140
+ "\n",
141
  "\n",
 
142
  "for root, _, files in os.walk(input_root):\n",
143
  " for file in files:\n",
144
  " if file.lower().endswith((\".jpg\", \".jpeg\")):\n",
 
147
  " output_path = os.path.join(output_root, rel_path)\n",
148
  " process_image(input_path, output_path)\n",
149
  "\n",
150
+ "print(\"All images processed and saved in\", output_root)\n"
151
  ]
 
 
 
 
 
 
 
 
152
  }
153
  ],
154
  "metadata": {
scripts/CV/script_onion.ipynb CHANGED
@@ -59,10 +59,10 @@
59
  "def augment_rotations(X, y):\n",
60
  " X_aug = []\n",
61
  " y_aug = []\n",
62
- " for k in [1, 2, 3]: # 90, 180, 270 degrees\n",
63
- " X_rot = torch.rot90(X, k=k, dims=[2, 3]) # rotate along H and W\n",
64
  " X_aug.append(X_rot)\n",
65
- " y_aug.append(y.clone()) # Same labels for rotated images\n",
66
  " return torch.cat(X_aug), torch.cat(y_aug)"
67
  ]
68
  },
@@ -120,7 +120,6 @@
120
  " plt.suptitle(f\"{class_name.capitalize()} – Random {count} Samples\", fontsize=16)\n",
121
  " plt.show()\n",
122
  "\n",
123
- "# Display for each class\n",
124
  "for class_name, image_array in datasets.items():\n",
125
  " show_random_samples(image_array, class_name)\n"
126
  ]
@@ -136,7 +135,7 @@
136
  "\n",
137
  "for ax, (class_name, images) in zip(axes, datasets.items()):\n",
138
  " plot_rgb_histogram_subplot(ax, images, class_name)\n",
139
- " ax.label_outer() # Hide x labels and tick labels for inner plots\n",
140
  "\n",
141
  "plt.tight_layout()\n",
142
  "plt.show()"
@@ -152,7 +151,7 @@
152
  "class_names = list(datasets.keys())\n",
153
  "num_classes = len(class_names)\n",
154
  "\n",
155
- "fig, axes = plt.subplots(1, num_classes, figsize=(4 * num_classes, 4)) # 1 row, 4 columns\n",
156
  "\n",
157
  "for i, (class_name, images) in enumerate(datasets.items()):\n",
158
  " avg_img = np.mean(images.astype(np.float32), axis=0)\n",
@@ -209,7 +208,6 @@
209
  "\n",
210
  "X_augmented, y_augmented = augment_rotations(X_train, y_train)\n",
211
  "\n",
212
- "# Combine original and augmented data\n",
213
  "X_train_combined = torch.cat([X_train, X_augmented])\n",
214
  "y_train_combined = torch.cat([y_train, y_augmented])\n",
215
  "\n",
@@ -230,9 +228,9 @@
230
  "metadata": {},
231
  "outputs": [],
232
  "source": [
233
- "print(f\"🔢 Train Dataset: {len(train_dataset)} samples, {len(train_loader)} batches\")\n",
234
- "print(f\"🔢 Val Dataset: {len(val_dataset)} samples, {len(val_loader)} batches\")\n",
235
- "print(f\"🔢 Test Dataset: {len(test_dataset)} samples, {len(test_loader)} batches\")"
236
  ]
237
  },
238
  {
@@ -480,29 +478,24 @@
480
  " activations[name] = output.detach().cpu()\n",
481
  " return hook\n",
482
  "\n",
483
- " # Register hooks for all layers in model.features\n",
484
  " hooks = []\n",
485
  " for i in range(len(model.features)):\n",
486
  " layer = model.features[i]\n",
487
  " hooks.append(layer.register_forward_hook(get_activation(f\"features_{i}\")))\n",
488
  "\n",
489
  " with torch.no_grad():\n",
490
- " _ = model(image_tensor.unsqueeze(0)) # Add batch dimension: [1, 3, 224, 224]\n",
491
  "\n",
492
  " for h in hooks:\n",
493
  " h.remove()\n",
494
  "\n",
495
  " for layer_name, fmap in activations.items():\n",
496
- " fmap = fmap.squeeze(0) # Shape: [C, H, W]\n",
 
497
  "\n",
498
- " # Compute mean activation per channel\n",
499
- " channel_scores = fmap.mean(dim=(1, 2)) # [C]\n",
500
- "\n",
501
- " # Get indices of top-k channels\n",
502
  " topk = torch.topk(channel_scores, k=min(max_channels, fmap.shape[0]))\n",
503
  " top_indices = topk.indices\n",
504
  "\n",
505
- " # Plot top-k channels\n",
506
  " plt.figure(figsize=(max_channels * 2, 2.5))\n",
507
  " for idx, ch in enumerate(top_indices):\n",
508
  " plt.subplot(1, max_channels, idx + 1)\n",
@@ -535,14 +528,12 @@
535
  "\n",
536
  "img = Image.open(\"dataset/Onion_512/Whole/image_0001.jpg\").convert(\"RGB\")\n",
537
  "\n",
538
- "# Preprocessing (must match model requirements)\n",
539
  "transform = transforms.Compose([\n",
540
  " transforms.Resize((224, 224)),\n",
541
  " transforms.ToTensor()\n",
542
  "])\n",
543
- "img_tensor = transform(img) # shape: [3, 224, 224]\n",
544
  "\n",
545
- "# Visualize feature maps\n",
546
  "visualize_channels(model, img_tensor, max_channels=16)\n"
547
  ]
548
  },
@@ -556,14 +547,11 @@
556
  "\n",
557
  "img = Image.open(\"dataset/Onion_512/Halved/image_0880.jpg\").convert(\"RGB\")\n",
558
  "\n",
559
- "# Preprocessing (must match model requirements)\n",
560
  "transform = transforms.Compose([\n",
561
  " transforms.Resize((224, 224)),\n",
562
  " transforms.ToTensor()\n",
563
  "])\n",
564
- "img_tensor = transform(img) # shape: [3, 224, 224]\n",
565
- "\n",
566
- "# Visualize feature maps\n",
567
  "visualize_channels(model, img_tensor, max_channels=16)\n"
568
  ]
569
  },
@@ -576,14 +564,12 @@
576
  "source": [
577
  "img = Image.open(\"dataset/Onion_512/Sliced/image_0772.jpg\").convert(\"RGB\")\n",
578
  "\n",
579
- "# Preprocessing (must match model requirements)\n",
580
  "transform = transforms.Compose([\n",
581
  " transforms.Resize((224, 224)),\n",
582
  " transforms.ToTensor()\n",
583
  "])\n",
584
- "img_tensor = transform(img) # shape: [3, 224, 224]\n",
585
  "\n",
586
- "# Visualize feature maps\n",
587
  "visualize_channels(model, img_tensor, max_channels=16)\n"
588
  ]
589
  },
 
59
  "def augment_rotations(X, y):\n",
60
  " X_aug = []\n",
61
  " y_aug = []\n",
62
+ " for k in [1, 2, 3]:\n",
63
+ " X_rot = torch.rot90(X, k=k, dims=[2, 3])\n",
64
  " X_aug.append(X_rot)\n",
65
+ " y_aug.append(y.clone()) \n",
66
  " return torch.cat(X_aug), torch.cat(y_aug)"
67
  ]
68
  },
 
120
  " plt.suptitle(f\"{class_name.capitalize()} – Random {count} Samples\", fontsize=16)\n",
121
  " plt.show()\n",
122
  "\n",
 
123
  "for class_name, image_array in datasets.items():\n",
124
  " show_random_samples(image_array, class_name)\n"
125
  ]
 
135
  "\n",
136
  "for ax, (class_name, images) in zip(axes, datasets.items()):\n",
137
  " plot_rgb_histogram_subplot(ax, images, class_name)\n",
138
+ " ax.label_outer()\n",
139
  "\n",
140
  "plt.tight_layout()\n",
141
  "plt.show()"
 
151
  "class_names = list(datasets.keys())\n",
152
  "num_classes = len(class_names)\n",
153
  "\n",
154
+ "fig, axes = plt.subplots(1, num_classes, figsize=(4 * num_classes, 4)) \n",
155
  "\n",
156
  "for i, (class_name, images) in enumerate(datasets.items()):\n",
157
  " avg_img = np.mean(images.astype(np.float32), axis=0)\n",
 
208
  "\n",
209
  "X_augmented, y_augmented = augment_rotations(X_train, y_train)\n",
210
  "\n",
 
211
  "X_train_combined = torch.cat([X_train, X_augmented])\n",
212
  "y_train_combined = torch.cat([y_train, y_augmented])\n",
213
  "\n",
 
228
  "metadata": {},
229
  "outputs": [],
230
  "source": [
231
+ "print(f\"Train Dataset: {len(train_dataset)} samples, {len(train_loader)} batches\")\n",
232
+ "print(f\"Val Dataset: {len(val_dataset)} samples, {len(val_loader)} batches\")\n",
233
+ "print(f\"Test Dataset: {len(test_dataset)} samples, {len(test_loader)} batches\")"
234
  ]
235
  },
236
  {
 
478
  " activations[name] = output.detach().cpu()\n",
479
  " return hook\n",
480
  "\n",
 
481
  " hooks = []\n",
482
  " for i in range(len(model.features)):\n",
483
  " layer = model.features[i]\n",
484
  " hooks.append(layer.register_forward_hook(get_activation(f\"features_{i}\")))\n",
485
  "\n",
486
  " with torch.no_grad():\n",
487
+ " _ = model(image_tensor.unsqueeze(0)) \n",
488
  "\n",
489
  " for h in hooks:\n",
490
  " h.remove()\n",
491
  "\n",
492
  " for layer_name, fmap in activations.items():\n",
493
+ " fmap = fmap.squeeze(0) \n",
494
+ " channel_scores = fmap.mean(dim=(1, 2))\n",
495
  "\n",
 
 
 
 
496
  " topk = torch.topk(channel_scores, k=min(max_channels, fmap.shape[0]))\n",
497
  " top_indices = topk.indices\n",
498
  "\n",
 
499
  " plt.figure(figsize=(max_channels * 2, 2.5))\n",
500
  " for idx, ch in enumerate(top_indices):\n",
501
  " plt.subplot(1, max_channels, idx + 1)\n",
 
528
  "\n",
529
  "img = Image.open(\"dataset/Onion_512/Whole/image_0001.jpg\").convert(\"RGB\")\n",
530
  "\n",
 
531
  "transform = transforms.Compose([\n",
532
  " transforms.Resize((224, 224)),\n",
533
  " transforms.ToTensor()\n",
534
  "])\n",
535
+ "img_tensor = transform(img)\n",
536
  "\n",
 
537
  "visualize_channels(model, img_tensor, max_channels=16)\n"
538
  ]
539
  },
 
547
  "\n",
548
  "img = Image.open(\"dataset/Onion_512/Halved/image_0880.jpg\").convert(\"RGB\")\n",
549
  "\n",
 
550
  "transform = transforms.Compose([\n",
551
  " transforms.Resize((224, 224)),\n",
552
  " transforms.ToTensor()\n",
553
  "])\n",
554
+ "img_tensor = transform(img)\n",
 
 
555
  "visualize_channels(model, img_tensor, max_channels=16)\n"
556
  ]
557
  },
 
564
  "source": [
565
  "img = Image.open(\"dataset/Onion_512/Sliced/image_0772.jpg\").convert(\"RGB\")\n",
566
  "\n",
 
567
  "transform = transforms.Compose([\n",
568
  " transforms.Resize((224, 224)),\n",
569
  " transforms.ToTensor()\n",
570
  "])\n",
571
+ "img_tensor = transform(img) \n",
572
  "\n",
 
573
  "visualize_channels(model, img_tensor, max_channels=16)\n"
574
  ]
575
  },
scripts/CV/script_pear.ipynb CHANGED
@@ -59,10 +59,10 @@
59
  "def augment_rotations(X, y):\n",
60
  " X_aug = []\n",
61
  " y_aug = []\n",
62
- " for k in [1, 2, 3]: # 90, 180, 270 degrees\n",
63
- " X_rot = torch.rot90(X, k=k, dims=[2, 3]) # rotate along H and W\n",
64
  " X_aug.append(X_rot)\n",
65
- " y_aug.append(y.clone()) # Same labels for rotated images\n",
66
  " return torch.cat(X_aug), torch.cat(y_aug)"
67
  ]
68
  },
@@ -122,7 +122,6 @@
122
  " plt.suptitle(f\"{class_name.capitalize()} – Random {count} Samples\", fontsize=16)\n",
123
  " plt.show()\n",
124
  "\n",
125
- "# Display for each class\n",
126
  "for class_name, image_array in datasets.items():\n",
127
  " show_random_samples(image_array, class_name)\n"
128
  ]
@@ -138,7 +137,7 @@
138
  "\n",
139
  "for ax, (class_name, images) in zip(axes, datasets.items()):\n",
140
  " plot_rgb_histogram_subplot(ax, images, class_name)\n",
141
- " ax.label_outer() # Hide x labels and tick labels for inner plots\n",
142
  "\n",
143
  "plt.tight_layout()\n",
144
  "plt.show()"
@@ -154,7 +153,7 @@
154
  "class_names = list(datasets.keys())\n",
155
  "num_classes = len(class_names)\n",
156
  "\n",
157
- "fig, axes = plt.subplots(1, num_classes, figsize=(4 * num_classes, 4)) # 1 row, 4 columns\n",
158
  "\n",
159
  "for i, (class_name, images) in enumerate(datasets.items()):\n",
160
  " avg_img = np.mean(images.astype(np.float32), axis=0)\n",
@@ -180,7 +179,6 @@
180
  " \"whole\": pear_whole_images\n",
181
  "}\n",
182
  "\n",
183
- "# Combine data\n",
184
  "X = np.concatenate([pear_halved_images, pear_sliced_images, pear_whole_images], axis=0)\n",
185
  "y = (\n",
186
  " ['halved'] * len(pear_halved_images) +\n",
@@ -188,17 +186,14 @@
188
  " ['whole'] * len(pear_whole_images)\n",
189
  ")\n",
190
  "\n",
191
- "# Normalize and convert to torch tensors\n",
192
  "X = X.astype(np.float32) / 255.0\n",
193
- "X = np.transpose(X, (0, 3, 1, 2)) # (N, C, H, W)\n",
194
  "X_tensor = torch.tensor(X)\n",
195
  "\n",
196
- "# Encode labels\n",
197
  "le = LabelEncoder()\n",
198
  "y_encoded = le.fit_transform(y)\n",
199
  "y_tensor = torch.tensor(y_encoded)\n",
200
  "\n",
201
- "# Train/val/test split\n",
202
  "X_train, X_temp, y_train, y_temp = train_test_split(X_tensor, y_tensor, test_size=0.4, stratify=y_tensor, random_state=42)\n",
203
  "X_val, X_test, y_val, y_test = train_test_split(X_temp, y_temp, test_size=0.5, stratify=y_temp, random_state=42)\n"
204
  ]
@@ -215,7 +210,6 @@
215
  "val_dataset = TensorDataset(X_val, y_val)\n",
216
  "test_dataset = TensorDataset(X_test, y_test)\n",
217
  "\n",
218
- "# DataLoaders\n",
219
  "batch_size = 32\n",
220
  "train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)\n",
221
  "val_loader = DataLoader(val_dataset, batch_size=batch_size)\n",
@@ -229,9 +223,9 @@
229
  "metadata": {},
230
  "outputs": [],
231
  "source": [
232
- "print(f\"🔢 Train Dataset: {len(train_dataset)} samples, {len(train_loader)} batches\")\n",
233
- "print(f\"🔢 Val Dataset: {len(val_dataset)} samples, {len(val_loader)} batches\")\n",
234
- "print(f\"🔢 Test Dataset: {len(test_dataset)} samples, {len(test_loader)} batches\")"
235
  ]
236
  },
237
  {
@@ -249,8 +243,6 @@
249
  "\n",
250
  "def get_efficientnet_model(num_classes):\n",
251
  " model = models.efficientnet_b0(weights=models.EfficientNet_B0_Weights.DEFAULT)\n",
252
- "\n",
253
- " # Replace classifier head with custom head\n",
254
  " model.classifier[1] = nn.Linear(model.classifier[1].in_features, num_classes)\n",
255
  "\n",
256
  " return model\n",
@@ -266,10 +258,10 @@
266
  "source": [
267
  "if torch.backends.mps.is_available():\n",
268
  " device = torch.device(\"mps\")\n",
269
- " print(\"Using MPS (Apple GPU)\")\n",
270
  "else:\n",
271
  " device = torch.device(\"cpu\")\n",
272
- " print(\"⚠️ MPS not available. Using CPU\")\n",
273
  "\n",
274
  "model = get_efficientnet_model(num_classes=3).to(device)\n",
275
  "optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)\n",
@@ -312,7 +304,6 @@
312
  "\n",
313
  " total_train_loss += loss.item()\n",
314
  "\n",
315
- " # Track training accuracy\n",
316
  " pred_labels = preds.argmax(dim=1)\n",
317
  " train_correct += (pred_labels == batch_y).sum().item()\n",
318
  " train_total += batch_y.size(0)\n",
@@ -367,7 +358,6 @@
367
  "\n",
368
  "plt.figure(figsize=(12, 5))\n",
369
  "\n",
370
- "# Plot Loss\n",
371
  "plt.subplot(1, 2, 1)\n",
372
  "plt.plot(epochs, train_losses, label='Train Loss', marker='o')\n",
373
  "plt.plot(epochs, val_losses, label='Validation Loss', marker='s')\n",
@@ -377,7 +367,6 @@
377
  "plt.legend()\n",
378
  "plt.grid(True)\n",
379
  "\n",
380
- "# Plot Accuracy\n",
381
  "plt.subplot(1, 2, 2)\n",
382
  "plt.plot(epochs, train_accs, label='Train Accuracy', marker='o')\n",
383
  "plt.plot(epochs, val_accs, label='Validation Accuracy', marker='s')\n",
@@ -490,29 +479,24 @@
490
  " activations[name] = output.detach().cpu()\n",
491
  " return hook\n",
492
  "\n",
493
- " # Register hooks for all layers in model.features\n",
494
  " hooks = []\n",
495
  " for i in range(len(model.features)):\n",
496
  " layer = model.features[i]\n",
497
  " hooks.append(layer.register_forward_hook(get_activation(f\"features_{i}\")))\n",
498
  "\n",
499
  " with torch.no_grad():\n",
500
- " _ = model(image_tensor.unsqueeze(0)) # Add batch dimension: [1, 3, 224, 224]\n",
501
  "\n",
502
  " for h in hooks:\n",
503
  " h.remove()\n",
504
  "\n",
505
  " for layer_name, fmap in activations.items():\n",
506
- " fmap = fmap.squeeze(0) # Shape: [C, H, W]\n",
507
- "\n",
508
- " # Compute mean activation per channel\n",
509
- " channel_scores = fmap.mean(dim=(1, 2)) # [C]\n",
510
  "\n",
511
- " # Get indices of top-k channels\n",
512
  " topk = torch.topk(channel_scores, k=min(max_channels, fmap.shape[0]))\n",
513
  " top_indices = topk.indices\n",
514
  "\n",
515
- " # Plot top-k channels\n",
516
  " plt.figure(figsize=(max_channels * 2, 2.5))\n",
517
  " for idx, ch in enumerate(top_indices):\n",
518
  " plt.subplot(1, max_channels, idx + 1)\n",
@@ -545,14 +529,11 @@
545
  "\n",
546
  "img = Image.open(\"dataset/Pear_512/Whole/image_0007.jpg\").convert(\"RGB\")\n",
547
  "\n",
548
- "# Preprocessing (must match model requirements)\n",
549
  "transform = transforms.Compose([\n",
550
  " transforms.Resize((224, 224)),\n",
551
  " transforms.ToTensor()\n",
552
  "])\n",
553
- "img_tensor = transform(img) # shape: [3, 224, 224]\n",
554
- "\n",
555
- "# Visualize feature maps\n",
556
  "visualize_channels(model, img_tensor, max_channels=16)\n"
557
  ]
558
  },
@@ -565,14 +546,11 @@
565
  "source": [
566
  "img = Image.open(\"dataset/Pear_512/Halved/image_0578.jpg\").convert(\"RGB\")\n",
567
  "\n",
568
- "# Preprocessing (must match model requirements)\n",
569
  "transform = transforms.Compose([\n",
570
  " transforms.Resize((224, 224)),\n",
571
  " transforms.ToTensor()\n",
572
  "])\n",
573
- "img_tensor = transform(img) # shape: [3, 224, 224]\n",
574
- "\n",
575
- "# Visualize feature maps\n",
576
  "visualize_channels(model, img_tensor, max_channels=16)\n"
577
  ]
578
  },
@@ -584,15 +562,11 @@
584
  "outputs": [],
585
  "source": [
586
  "img = Image.open(\"dataset/Pear_512/Sliced/image_0007.jpg\").convert(\"RGB\")\n",
587
- "\n",
588
- "# Preprocessing (must match model requirements)\n",
589
  "transform = transforms.Compose([\n",
590
  " transforms.Resize((224, 224)),\n",
591
  " transforms.ToTensor()\n",
592
  "])\n",
593
- "img_tensor = transform(img) # shape: [3, 224, 224]\n",
594
- "\n",
595
- "# Visualize feature maps\n",
596
  "visualize_channels(model, img_tensor, max_channels=16)\n"
597
  ]
598
  },
 
59
  "def augment_rotations(X, y):\n",
60
  " X_aug = []\n",
61
  " y_aug = []\n",
62
+ " for k in [1, 2, 3]: \n",
63
+ " X_rot = torch.rot90(X, k=k, dims=[2, 3]) \n",
64
  " X_aug.append(X_rot)\n",
65
+ " y_aug.append(y.clone()) \n",
66
  " return torch.cat(X_aug), torch.cat(y_aug)"
67
  ]
68
  },
 
122
  " plt.suptitle(f\"{class_name.capitalize()} – Random {count} Samples\", fontsize=16)\n",
123
  " plt.show()\n",
124
  "\n",
 
125
  "for class_name, image_array in datasets.items():\n",
126
  " show_random_samples(image_array, class_name)\n"
127
  ]
 
137
  "\n",
138
  "for ax, (class_name, images) in zip(axes, datasets.items()):\n",
139
  " plot_rgb_histogram_subplot(ax, images, class_name)\n",
140
+ " ax.label_outer() \n",
141
  "\n",
142
  "plt.tight_layout()\n",
143
  "plt.show()"
 
153
  "class_names = list(datasets.keys())\n",
154
  "num_classes = len(class_names)\n",
155
  "\n",
156
+ "fig, axes = plt.subplots(1, num_classes, figsize=(4 * num_classes, 4)) \n",
157
  "\n",
158
  "for i, (class_name, images) in enumerate(datasets.items()):\n",
159
  " avg_img = np.mean(images.astype(np.float32), axis=0)\n",
 
179
  " \"whole\": pear_whole_images\n",
180
  "}\n",
181
  "\n",
 
182
  "X = np.concatenate([pear_halved_images, pear_sliced_images, pear_whole_images], axis=0)\n",
183
  "y = (\n",
184
  " ['halved'] * len(pear_halved_images) +\n",
 
186
  " ['whole'] * len(pear_whole_images)\n",
187
  ")\n",
188
  "\n",
 
189
  "X = X.astype(np.float32) / 255.0\n",
190
+ "X = np.transpose(X, (0, 3, 1, 2)) \n",
191
  "X_tensor = torch.tensor(X)\n",
192
  "\n",
 
193
  "le = LabelEncoder()\n",
194
  "y_encoded = le.fit_transform(y)\n",
195
  "y_tensor = torch.tensor(y_encoded)\n",
196
  "\n",
 
197
  "X_train, X_temp, y_train, y_temp = train_test_split(X_tensor, y_tensor, test_size=0.4, stratify=y_tensor, random_state=42)\n",
198
  "X_val, X_test, y_val, y_test = train_test_split(X_temp, y_temp, test_size=0.5, stratify=y_temp, random_state=42)\n"
199
  ]
 
210
  "val_dataset = TensorDataset(X_val, y_val)\n",
211
  "test_dataset = TensorDataset(X_test, y_test)\n",
212
  "\n",
 
213
  "batch_size = 32\n",
214
  "train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)\n",
215
  "val_loader = DataLoader(val_dataset, batch_size=batch_size)\n",
 
223
  "metadata": {},
224
  "outputs": [],
225
  "source": [
226
+ "print(f\"Train Dataset: {len(train_dataset)} samples, {len(train_loader)} batches\")\n",
227
+ "print(f\"Val Dataset: {len(val_dataset)} samples, {len(val_loader)} batches\")\n",
228
+ "print(f\"Test Dataset: {len(test_dataset)} samples, {len(test_loader)} batches\")"
229
  ]
230
  },
231
  {
 
243
  "\n",
244
  "def get_efficientnet_model(num_classes):\n",
245
  " model = models.efficientnet_b0(weights=models.EfficientNet_B0_Weights.DEFAULT)\n",
 
 
246
  " model.classifier[1] = nn.Linear(model.classifier[1].in_features, num_classes)\n",
247
  "\n",
248
  " return model\n",
 
258
  "source": [
259
  "if torch.backends.mps.is_available():\n",
260
  " device = torch.device(\"mps\")\n",
261
+ " print(\"Using MPS (Apple GPU)\")\n",
262
  "else:\n",
263
  " device = torch.device(\"cpu\")\n",
264
+ " print(\"MPS not available. Using CPU\")\n",
265
  "\n",
266
  "model = get_efficientnet_model(num_classes=3).to(device)\n",
267
  "optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)\n",
 
304
  "\n",
305
  " total_train_loss += loss.item()\n",
306
  "\n",
 
307
  " pred_labels = preds.argmax(dim=1)\n",
308
  " train_correct += (pred_labels == batch_y).sum().item()\n",
309
  " train_total += batch_y.size(0)\n",
 
358
  "\n",
359
  "plt.figure(figsize=(12, 5))\n",
360
  "\n",
 
361
  "plt.subplot(1, 2, 1)\n",
362
  "plt.plot(epochs, train_losses, label='Train Loss', marker='o')\n",
363
  "plt.plot(epochs, val_losses, label='Validation Loss', marker='s')\n",
 
367
  "plt.legend()\n",
368
  "plt.grid(True)\n",
369
  "\n",
 
370
  "plt.subplot(1, 2, 2)\n",
371
  "plt.plot(epochs, train_accs, label='Train Accuracy', marker='o')\n",
372
  "plt.plot(epochs, val_accs, label='Validation Accuracy', marker='s')\n",
 
479
  " activations[name] = output.detach().cpu()\n",
480
  " return hook\n",
481
  "\n",
 
482
  " hooks = []\n",
483
  " for i in range(len(model.features)):\n",
484
  " layer = model.features[i]\n",
485
  " hooks.append(layer.register_forward_hook(get_activation(f\"features_{i}\")))\n",
486
  "\n",
487
  " with torch.no_grad():\n",
488
+ " _ = model(image_tensor.unsqueeze(0)) \n",
489
  "\n",
490
  " for h in hooks:\n",
491
  " h.remove()\n",
492
  "\n",
493
  " for layer_name, fmap in activations.items():\n",
494
+ " fmap = fmap.squeeze(0) \n",
495
+ " channel_scores = fmap.mean(dim=(1, 2)) \n",
 
 
496
  "\n",
 
497
  " topk = torch.topk(channel_scores, k=min(max_channels, fmap.shape[0]))\n",
498
  " top_indices = topk.indices\n",
499
  "\n",
 
500
  " plt.figure(figsize=(max_channels * 2, 2.5))\n",
501
  " for idx, ch in enumerate(top_indices):\n",
502
  " plt.subplot(1, max_channels, idx + 1)\n",
 
529
  "\n",
530
  "img = Image.open(\"dataset/Pear_512/Whole/image_0007.jpg\").convert(\"RGB\")\n",
531
  "\n",
 
532
  "transform = transforms.Compose([\n",
533
  " transforms.Resize((224, 224)),\n",
534
  " transforms.ToTensor()\n",
535
  "])\n",
536
+ "img_tensor = transform(img) \n",
 
 
537
  "visualize_channels(model, img_tensor, max_channels=16)\n"
538
  ]
539
  },
 
546
  "source": [
547
  "img = Image.open(\"dataset/Pear_512/Halved/image_0578.jpg\").convert(\"RGB\")\n",
548
  "\n",
 
549
  "transform = transforms.Compose([\n",
550
  " transforms.Resize((224, 224)),\n",
551
  " transforms.ToTensor()\n",
552
  "])\n",
553
+ "img_tensor = transform(img) \n",
 
 
554
  "visualize_channels(model, img_tensor, max_channels=16)\n"
555
  ]
556
  },
 
562
  "outputs": [],
563
  "source": [
564
  "img = Image.open(\"dataset/Pear_512/Sliced/image_0007.jpg\").convert(\"RGB\")\n",
 
 
565
  "transform = transforms.Compose([\n",
566
  " transforms.Resize((224, 224)),\n",
567
  " transforms.ToTensor()\n",
568
  "])\n",
569
+ "img_tensor = transform(img) \n",
 
 
570
  "visualize_channels(model, img_tensor, max_channels=16)\n"
571
  ]
572
  },
scripts/CV/script_strawberry.ipynb CHANGED
@@ -59,10 +59,10 @@
59
  "def augment_rotations(X, y):\n",
60
  " X_aug = []\n",
61
  " y_aug = []\n",
62
- " for k in [1, 2, 3]: # 90, 180, 270 degrees\n",
63
- " X_rot = torch.rot90(X, k=k, dims=[2, 3]) # rotate along H and W\n",
64
  " X_aug.append(X_rot)\n",
65
- " y_aug.append(y.clone()) # Same labels for rotated images\n",
66
  " return torch.cat(X_aug), torch.cat(y_aug)"
67
  ]
68
  },
@@ -124,7 +124,6 @@
124
  " plt.suptitle(f\"{class_name.capitalize()} – Random {count} Samples\", fontsize=16)\n",
125
  " plt.show()\n",
126
  "\n",
127
- "# Display for each class\n",
128
  "for class_name, image_array in datasets.items():\n",
129
  " show_random_samples(image_array, class_name)\n"
130
  ]
@@ -140,7 +139,7 @@
140
  "\n",
141
  "for ax, (class_name, images) in zip(axes, datasets.items()):\n",
142
  " plot_rgb_histogram_subplot(ax, images, class_name)\n",
143
- " ax.label_outer() # Hide x labels and tick labels for inner plots\n",
144
  "\n",
145
  "plt.tight_layout()\n",
146
  "plt.show()"
@@ -156,7 +155,7 @@
156
  "class_names = list(datasets.keys())\n",
157
  "num_classes = len(class_names)\n",
158
  "\n",
159
- "fig, axes = plt.subplots(1, num_classes, figsize=(4 * num_classes, 4)) # 1 row, 4 columns\n",
160
  "\n",
161
  "for i, (class_name, images) in enumerate(datasets.items()):\n",
162
  " avg_img = np.mean(images.astype(np.float32), axis=0)\n",
@@ -181,7 +180,6 @@
181
  " \"whole\": strawberry_whole_images\n",
182
  "}\n",
183
  "\n",
184
- "# Combine data\n",
185
  "X = np.concatenate([strawberry_hulled_images, strawberry_sliced_images, strawberry_whole_images], axis=0)\n",
186
  "y = (\n",
187
  " ['hulled'] * len(strawberry_hulled_images) +\n",
@@ -189,17 +187,14 @@
189
  " ['whole'] * len(strawberry_whole_images)\n",
190
  ")\n",
191
  "\n",
192
- "# Normalize and convert to torch tensors\n",
193
  "X = X.astype(np.float32) / 255.0\n",
194
- "X = np.transpose(X, (0, 3, 1, 2)) # (N, C, H, W)\n",
195
  "X_tensor = torch.tensor(X)\n",
196
  "\n",
197
- "# Encode labels\n",
198
  "le = LabelEncoder()\n",
199
  "y_encoded = le.fit_transform(y)\n",
200
  "y_tensor = torch.tensor(y_encoded)\n",
201
  "\n",
202
- "# Train/val/test split\n",
203
  "X_train, X_temp, y_train, y_temp = train_test_split(X_tensor, y_tensor, test_size=0.5, stratify=y_tensor, random_state=42)\n",
204
  "X_val, X_test, y_val, y_test = train_test_split(X_temp, y_temp, test_size=0.5, stratify=y_temp, random_state=42)\n"
205
  ]
@@ -215,11 +210,9 @@
215
  "\n",
216
  "X_augmented, y_augmented = augment_rotations(X_train, y_train)\n",
217
  "\n",
218
- "# Combine original and augmented data\n",
219
  "X_train_combined = torch.cat([X_train, X_augmented])\n",
220
  "y_train_combined = torch.cat([y_train, y_augmented])\n",
221
  "\n",
222
- "\n",
223
  "train_dataset = TensorDataset(X_train_combined, y_train_combined)\n",
224
  "val_dataset = TensorDataset(X_val, y_val)\n",
225
  "test_dataset = TensorDataset(X_test, y_test)\n",
@@ -236,9 +229,9 @@
236
  "metadata": {},
237
  "outputs": [],
238
  "source": [
239
- "print(f\"🔢 Train Dataset: {len(train_dataset)} samples, {len(train_loader)} batches\")\n",
240
- "print(f\"🔢 Val Dataset: {len(val_dataset)} samples, {len(val_loader)} batches\")\n",
241
- "print(f\"🔢 Test Dataset: {len(test_dataset)} samples, {len(test_loader)} batches\")"
242
  ]
243
  },
244
  {
@@ -341,7 +334,6 @@
341
  " val_accuracy = val_correct / val_total\n",
342
  " validation_loss = criterion(model(val_x), val_y).item()\n",
343
  "\n",
344
- " # After calculating val_accuracy\n",
345
  " val_losses.append(validation_loss)\n",
346
  " val_accs.append(val_accuracy)\n",
347
  "\n",
@@ -449,29 +441,25 @@
449
  "source": [
450
  "all_preds = np.array(all_preds)\n",
451
  "all_targets = np.array(all_targets)\n",
452
- "all_images = torch.stack(all_images) # shape: [N, C, H, W]\n",
453
  "\n",
454
- "# Per class FP and FN\n",
455
  "for class_idx, class_name in enumerate(target_names):\n",
456
- " print(f\"\\n🔍 Showing False Negatives and False Positives for class: {class_name}\")\n",
457
- "\n",
458
- " # False Negatives: True label is class_idx, but predicted something else\n",
459
  " fn_indices = np.where((all_targets == class_idx) & (all_preds != class_idx))[0]\n",
460
- " # False Positives: Predicted class_idx, but true label is different\n",
461
  " fp_indices = np.where((all_preds == class_idx) & (all_targets != class_idx))[0]\n",
462
  "\n",
463
  " def show_images(indices, title, max_images=5):\n",
464
  " num = min(len(indices), max_images)\n",
465
  " if num == 0:\n",
466
- " print(f\"No {title} samples.\")\n",
467
  " return\n",
468
  "\n",
469
  " plt.figure(figsize=(12, 2))\n",
470
  " for i, idx in enumerate(indices[:num]):\n",
471
  " img = all_images[idx]\n",
472
- " img = img.permute(1, 2, 0).numpy() # [C, H, W] → [H, W, C]\n",
473
  " plt.subplot(1, num, i + 1)\n",
474
- " plt.imshow((img - img.min()) / (img.max() - img.min())) # normalize to [0,1] for display\n",
475
  " plt.axis('off')\n",
476
  " plt.title(f\"Pred: {target_names[all_preds[idx]]}\\nTrue: {target_names[all_targets[idx]]}\")\n",
477
  " plt.suptitle(f\"{title} for {class_name}\")\n",
@@ -498,29 +486,25 @@
498
  " activations[name] = output.detach().cpu()\n",
499
  " return hook\n",
500
  "\n",
501
- " # Register hooks for all layers in model.features\n",
502
  " hooks = []\n",
503
  " for i in range(len(model.features)):\n",
504
  " layer = model.features[i]\n",
505
  " hooks.append(layer.register_forward_hook(get_activation(f\"features_{i}\")))\n",
506
  "\n",
507
  " with torch.no_grad():\n",
508
- " _ = model(image_tensor.unsqueeze(0)) # Add batch dimension: [1, 3, 224, 224]\n",
509
  "\n",
510
  " for h in hooks:\n",
511
  " h.remove()\n",
512
  "\n",
513
  " for layer_name, fmap in activations.items():\n",
514
- " fmap = fmap.squeeze(0) # Shape: [C, H, W]\n",
515
  "\n",
516
- " # Compute mean activation per channel\n",
517
- " channel_scores = fmap.mean(dim=(1, 2)) # [C]\n",
518
  "\n",
519
- " # Get indices of top-k channels\n",
520
  " topk = torch.topk(channel_scores, k=min(max_channels, fmap.shape[0]))\n",
521
  " top_indices = topk.indices\n",
522
  "\n",
523
- " # Plot top-k channels\n",
524
  " plt.figure(figsize=(max_channels * 2, 2.5))\n",
525
  " for idx, ch in enumerate(top_indices):\n",
526
  " plt.subplot(1, max_channels, idx + 1)\n",
@@ -553,14 +537,12 @@
553
  "\n",
554
  "img = Image.open(\"dataset/Strawberry_512/Whole/image_0017.jpg\").convert(\"RGB\")\n",
555
  "\n",
556
- "# Preprocessing (must match model requirements)\n",
557
  "transform = transforms.Compose([\n",
558
  " transforms.Resize((224, 224)),\n",
559
  " transforms.ToTensor()\n",
560
  "])\n",
561
- "img_tensor = transform(img) # shape: [3, 224, 224]\n",
562
  "\n",
563
- "# Visualize feature maps\n",
564
  "visualize_channels(model, img_tensor, max_channels=16)\n"
565
  ]
566
  },
@@ -574,14 +556,12 @@
574
  "\n",
575
  "img = Image.open(\"dataset/Strawberry_512/Hulled/image_0001.jpg\").convert(\"RGB\")\n",
576
  "\n",
577
- "# Preprocessing (must match model requirements)\n",
578
  "transform = transforms.Compose([\n",
579
  " transforms.Resize((224, 224)),\n",
580
  " transforms.ToTensor()\n",
581
  "])\n",
582
- "img_tensor = transform(img) # shape: [3, 224, 224]\n",
583
  "\n",
584
- "# Visualize feature maps\n",
585
  "visualize_channels(model, img_tensor, max_channels=16)\n"
586
  ]
587
  },
@@ -595,15 +575,13 @@
595
  "\n",
596
  "img = Image.open(\"dataset/Strawberry_512/Sliced/image_0001.jpg\").convert(\"RGB\")\n",
597
  "\n",
598
- "# Preprocessing (must match model requirements)\n",
599
  "transform = transforms.Compose([\n",
600
  " transforms.Resize((224, 224)),\n",
601
  " transforms.ToTensor()\n",
602
  "])\n",
603
- "img_tensor = transform(img) # shape: [3, 224, 224]\n",
604
  "\n",
605
- "# Visualize feature maps\n",
606
- "visualize_channels(model, img_tensor, max_channels=16)\n"
607
  ]
608
  },
609
  {
 
59
  "def augment_rotations(X, y):\n",
60
  " X_aug = []\n",
61
  " y_aug = []\n",
62
+ " for k in [1, 2, 3]: \n",
63
+ " X_rot = torch.rot90(X, k=k, dims=[2, 3]) \n",
64
  " X_aug.append(X_rot)\n",
65
+ " y_aug.append(y.clone()) \n",
66
  " return torch.cat(X_aug), torch.cat(y_aug)"
67
  ]
68
  },
 
124
  " plt.suptitle(f\"{class_name.capitalize()} – Random {count} Samples\", fontsize=16)\n",
125
  " plt.show()\n",
126
  "\n",
 
127
  "for class_name, image_array in datasets.items():\n",
128
  " show_random_samples(image_array, class_name)\n"
129
  ]
 
139
  "\n",
140
  "for ax, (class_name, images) in zip(axes, datasets.items()):\n",
141
  " plot_rgb_histogram_subplot(ax, images, class_name)\n",
142
+ " ax.label_outer() \n",
143
  "\n",
144
  "plt.tight_layout()\n",
145
  "plt.show()"
 
155
  "class_names = list(datasets.keys())\n",
156
  "num_classes = len(class_names)\n",
157
  "\n",
158
+ "fig, axes = plt.subplots(1, num_classes, figsize=(4 * num_classes, 4)) \n",
159
  "\n",
160
  "for i, (class_name, images) in enumerate(datasets.items()):\n",
161
  " avg_img = np.mean(images.astype(np.float32), axis=0)\n",
 
180
  " \"whole\": strawberry_whole_images\n",
181
  "}\n",
182
  "\n",
 
183
  "X = np.concatenate([strawberry_hulled_images, strawberry_sliced_images, strawberry_whole_images], axis=0)\n",
184
  "y = (\n",
185
  " ['hulled'] * len(strawberry_hulled_images) +\n",
 
187
  " ['whole'] * len(strawberry_whole_images)\n",
188
  ")\n",
189
  "\n",
 
190
  "X = X.astype(np.float32) / 255.0\n",
191
+ "X = np.transpose(X, (0, 3, 1, 2)) \n",
192
  "X_tensor = torch.tensor(X)\n",
193
  "\n",
 
194
  "le = LabelEncoder()\n",
195
  "y_encoded = le.fit_transform(y)\n",
196
  "y_tensor = torch.tensor(y_encoded)\n",
197
  "\n",
 
198
  "X_train, X_temp, y_train, y_temp = train_test_split(X_tensor, y_tensor, test_size=0.5, stratify=y_tensor, random_state=42)\n",
199
  "X_val, X_test, y_val, y_test = train_test_split(X_temp, y_temp, test_size=0.5, stratify=y_temp, random_state=42)\n"
200
  ]
 
210
  "\n",
211
  "X_augmented, y_augmented = augment_rotations(X_train, y_train)\n",
212
  "\n",
 
213
  "X_train_combined = torch.cat([X_train, X_augmented])\n",
214
  "y_train_combined = torch.cat([y_train, y_augmented])\n",
215
  "\n",
 
216
  "train_dataset = TensorDataset(X_train_combined, y_train_combined)\n",
217
  "val_dataset = TensorDataset(X_val, y_val)\n",
218
  "test_dataset = TensorDataset(X_test, y_test)\n",
 
229
  "metadata": {},
230
  "outputs": [],
231
  "source": [
232
+ "print(f\"Train Dataset: {len(train_dataset)} samples, {len(train_loader)} batches\")\n",
233
+ "print(f\"Val Dataset: {len(val_dataset)} samples, {len(val_loader)} batches\")\n",
234
+ "print(f\"Test Dataset: {len(test_dataset)} samples, {len(test_loader)} batches\")"
235
  ]
236
  },
237
  {
 
334
  " val_accuracy = val_correct / val_total\n",
335
  " validation_loss = criterion(model(val_x), val_y).item()\n",
336
  "\n",
 
337
  " val_losses.append(validation_loss)\n",
338
  " val_accs.append(val_accuracy)\n",
339
  "\n",
 
  "source": [
  "all_preds = np.array(all_preds)\n",
  "all_targets = np.array(all_targets)\n",
+ "all_images = torch.stack(all_images)\n",
  "\n",
  "for class_idx, class_name in enumerate(target_names):\n",
+ " print(f\"\\nShowing False Negatives and False Positives for class: {class_name}\")\n",
  " fn_indices = np.where((all_targets == class_idx) & (all_preds != class_idx))[0]\n",
  " fp_indices = np.where((all_preds == class_idx) & (all_targets != class_idx))[0]\n",
  "\n",
  " def show_images(indices, title, max_images=5):\n",
  " num = min(len(indices), max_images)\n",
  " if num == 0:\n",
+ " print(f\"No {title} samples.\")\n",
  " return\n",
  "\n",
  " plt.figure(figsize=(12, 2))\n",
  " for i, idx in enumerate(indices[:num]):\n",
  " img = all_images[idx]\n",
+ " img = img.permute(1, 2, 0).numpy()\n",
  " plt.subplot(1, num, i + 1)\n",
+ " plt.imshow((img - img.min()) / (img.max() - img.min()))\n",
  " plt.axis('off')\n",
  " plt.title(f\"Pred: {target_names[all_preds[idx]]}\\nTrue: {target_names[all_targets[idx]]}\")\n",
  " plt.suptitle(f\"{title} for {class_name}\")\n",
 
  " activations[name] = output.detach().cpu()\n",
  " return hook\n",
  "\n",
  " hooks = []\n",
  " for i in range(len(model.features)):\n",
  " layer = model.features[i]\n",
  " hooks.append(layer.register_forward_hook(get_activation(f\"features_{i}\")))\n",
  "\n",
  " with torch.no_grad():\n",
+ " _ = model(image_tensor.unsqueeze(0))\n",
  "\n",
  " for h in hooks:\n",
  " h.remove()\n",
  "\n",
  " for layer_name, fmap in activations.items():\n",
+ " fmap = fmap.squeeze(0)\n",
  "\n",
+ " channel_scores = fmap.mean(dim=(1, 2))\n",
  "\n",
  " topk = torch.topk(channel_scores, k=min(max_channels, fmap.shape[0]))\n",
  " top_indices = topk.indices\n",
  "\n",
  " plt.figure(figsize=(max_channels * 2, 2.5))\n",
  " for idx, ch in enumerate(top_indices):\n",
  " plt.subplot(1, max_channels, idx + 1)\n",
 
  "\n",
  "img = Image.open(\"dataset/Strawberry_512/Whole/image_0017.jpg\").convert(\"RGB\")\n",
  "\n",
  "transform = transforms.Compose([\n",
  " transforms.Resize((224, 224)),\n",
  " transforms.ToTensor()\n",
  "])\n",
+ "img_tensor = transform(img)\n",
  "\n",
  "visualize_channels(model, img_tensor, max_channels=16)\n"
  ]
  },
 
  "\n",
  "img = Image.open(\"dataset/Strawberry_512/Hulled/image_0001.jpg\").convert(\"RGB\")\n",
  "\n",
  "transform = transforms.Compose([\n",
  " transforms.Resize((224, 224)),\n",
  " transforms.ToTensor()\n",
  "])\n",
+ "img_tensor = transform(img)\n",
  "\n",
  "visualize_channels(model, img_tensor, max_channels=16)\n"
  ]
  },
 
  "\n",
  "img = Image.open(\"dataset/Strawberry_512/Sliced/image_0001.jpg\").convert(\"RGB\")\n",
  "\n",
  "transform = transforms.Compose([\n",
  " transforms.Resize((224, 224)),\n",
  " transforms.ToTensor()\n",
  "])\n",
+ "img_tensor = transform(img)\n",
  "\n",
+ "visualize_channels(model, img_tensor, max_channels=16)"
  ]
  },
  {
scripts/CV/script_tomato.ipynb CHANGED
@@ -13,7 +13,6 @@
  "import matplotlib.pyplot as plt\n",
  "import random\n",
  "import torch\n",
- "import numpy as np\n",
  "from torch.utils.data import Dataset, DataLoader, TensorDataset\n",
  "from sklearn.preprocessing import LabelEncoder\n",
  "from sklearn.model_selection import train_test_split\n",
@@ -59,10 +58,10 @@
  "def augment_rotations(X, y):\n",
  " X_aug = []\n",
  " y_aug = []\n",
- " for k in [1, 2, 3]: # 90, 180, 270 degrees\n",
- " X_rot = torch.rot90(X, k=k, dims=[2, 3]) # rotate along H and W\n",
+ " for k in [1, 2, 3]:\n",
+ " X_rot = torch.rot90(X, k=k, dims=[2, 3])\n",
  " X_aug.append(X_rot)\n",
- " y_aug.append(y.clone()) # Same labels for rotated images\n",
+ " y_aug.append(y.clone())\n",
  " return torch.cat(X_aug), torch.cat(y_aug)"
  ]
  },
@@ -103,8 +102,7 @@
  "metadata": {},
  "outputs": [],
  "source": [
- "import matplotlib.pyplot as plt\n",
- "import random\n",
+ "\n",
  "datasets = {\n",
  " \"diced\": tomato_diced_images,\n",
  " \"vines\": tomato_vines_images,\n",
@@ -124,7 +122,6 @@
  " plt.suptitle(f\"{class_name.capitalize()} – Random {count} Samples\", fontsize=16)\n",
  " plt.show()\n",
  "\n",
- "# Display for each class\n",
  "for class_name, image_array in datasets.items():\n",
  " show_random_samples(image_array, class_name)\n"
  ]
@@ -140,7 +137,7 @@
  "\n",
  "for ax, (class_name, images) in zip(axes, datasets.items()):\n",
  " plot_rgb_histogram_subplot(ax, images, class_name)\n",
- " ax.label_outer() # Hide x labels and tick labels for inner plots\n",
+ " ax.label_outer()\n",
  "\n",
  "plt.tight_layout()\n",
  "plt.show()"
@@ -156,7 +153,7 @@
  "class_names = list(datasets.keys())\n",
  "num_classes = len(class_names)\n",
  "\n",
- "fig, axes = plt.subplots(1, num_classes, figsize=(4 * num_classes, 4)) # 1 row, 4 columns\n",
+ "fig, axes = plt.subplots(1, num_classes, figsize=(4 * num_classes, 4))\n",
  "\n",
  "for i, (class_name, images) in enumerate(datasets.items()):\n",
  " avg_img = np.mean(images.astype(np.float32), axis=0)\n",
@@ -175,20 +172,12 @@
  "metadata": {},
  "outputs": [],
  "source": [
- "import torch\n",
- "import numpy as np\n",
- "from torch.utils.data import Dataset, DataLoader, TensorDataset\n",
- "from sklearn.preprocessing import LabelEncoder\n",
- "from sklearn.model_selection import train_test_split\n",
- "from torchvision import transforms\n",
- "\n",
  "datasets = {\n",
  " \"diced\": tomato_diced_images,\n",
  " \"vines\": tomato_vines_images,\n",
  " \"whole\": tomato_whole_images\n",
  "}\n",
  "\n",
- "# Combine data\n",
  "X = np.concatenate([tomato_diced_images, tomato_vines_images, tomato_whole_images], axis=0)\n",
  "y = (\n",
  " ['diced'] * len(tomato_diced_images) +\n",
@@ -196,17 +185,14 @@
  " ['whole'] * len(tomato_whole_images)\n",
  ")\n",
  "\n",
- "# Normalize and convert to torch tensors\n",
  "X = X.astype(np.float32) / 255.0\n",
- "X = np.transpose(X, (0, 3, 1, 2)) # (N, C, H, W)\n",
+ "X = np.transpose(X, (0, 3, 1, 2))\n",
  "X_tensor = torch.tensor(X)\n",
  "\n",
- "# Encode labels\n",
  "le = LabelEncoder()\n",
  "y_encoded = le.fit_transform(y)\n",
  "y_tensor = torch.tensor(y_encoded)\n",
  "\n",
- "# Train/val/test split\n",
  "X_train, X_temp, y_train, y_temp = train_test_split(X_tensor, y_tensor, test_size=0.4, stratify=y_tensor, random_state=42)\n",
  "X_val, X_test, y_val, y_test = train_test_split(X_temp, y_temp, test_size=0.5, stratify=y_temp, random_state=42)\n"
@@ -222,17 +208,13 @@
  "\n",
  "X_augmented, y_augmented = augment_rotations(X_train, y_train)\n",
  "\n",
- "# Combine original and augmented data\n",
  "X_train_combined = torch.cat([X_train, X_augmented])\n",
  "y_train_combined = torch.cat([y_train, y_augmented])\n",
  "\n",
- "# Create new training dataset and loader\n",
- "\n",
  "train_dataset = TensorDataset(X_train, y_train)\n",
  "val_dataset = TensorDataset(X_val, y_val)\n",
  "test_dataset = TensorDataset(X_test, y_test)\n",
  "\n",
- "# DataLoaders\n",
  "\n",
  "train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)\n",
  "val_loader = DataLoader(val_dataset, batch_size=batch_size)\n",
@@ -246,9 +228,9 @@
  "metadata": {},
  "outputs": [],
  "source": [
- "print(f\"🔢 Train Dataset: {len(train_dataset)} samples, {len(train_loader)} batches\")\n",
- "print(f\"🔢 Val Dataset: {len(val_dataset)} samples, {len(val_loader)} batches\")\n",
- "print(f\"🔢 Test Dataset: {len(test_dataset)} samples, {len(test_loader)} batches\")"
+ "print(f\"Train Dataset: {len(train_dataset)} samples, {len(train_loader)} batches\")\n",
+ "print(f\"Val Dataset: {len(val_dataset)} samples, {len(val_loader)} batches\")\n",
+ "print(f\"Test Dataset: {len(test_dataset)} samples, {len(test_loader)} batches\")"
  ]
  },
  {
@@ -258,18 +240,9 @@
  "metadata": {},
  "outputs": [],
  "source": [
- "import torch.nn as nn\n",
- "import torch.nn.functional as F\n",
- "\n",
- "import torch.nn as nn\n",
- "import torchvision.models as models\n",
- "\n",
  "def get_efficientnet_model(num_classes):\n",
  " model = models.efficientnet_b0(weights=models.EfficientNet_B0_Weights.DEFAULT)\n",
- "\n",
- " # Replace classifier head with custom head\n",
  " model.classifier[1] = nn.Linear(model.classifier[1].in_features, num_classes)\n",
- "\n",
  " return model\n",
  "\n"
  ]
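
A quick shape check for the classifier-head swap above (a minimal sketch;
weights=None avoids downloading the pretrained weights):

    import torch
    import torch.nn as nn
    import torchvision.models as models

    model = models.efficientnet_b0(weights=None)
    model.classifier[1] = nn.Linear(model.classifier[1].in_features, 3)
    out = model(torch.rand(1, 3, 224, 224))
    print(out.shape)  # torch.Size([1, 3])
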
@@ -283,10 +256,10 @@
  "source": [
  "if torch.backends.mps.is_available():\n",
  " device = torch.device(\"mps\")\n",
- " print(\"Using MPS (Apple GPU)\")\n",
+ " print(\"Using MPS (Apple GPU)\")\n",
  "else:\n",
  " device = torch.device(\"cpu\")\n",
- " print(\"⚠️ MPS not available. Using CPU\")\n",
+ " print(\"MPS not available. Using CPU\")\n",
  "\n",
  "model = get_efficientnet_model(num_classes=3).to(device)\n",
  "optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)\n",
@@ -329,7 +302,6 @@
  "\n",
  " total_train_loss += loss.item()\n",
  "\n",
- " # Track training accuracy\n",
  " pred_labels = preds.argmax(dim=1)\n",
  " train_correct += (pred_labels == batch_y).sum().item()\n",
  " train_total += batch_y.size(0)\n",
@@ -353,7 +325,6 @@
  " val_accuracy = val_correct / val_total\n",
  " validation_loss = criterion(model(val_x), val_y).item()\n",
  "\n",
- " # After calculating val_accuracy\n",
  " val_losses.append(validation_loss)\n",
  " val_accs.append(val_accuracy)\n",
  "\n",
@@ -381,13 +352,12 @@
  "metadata": {},
  "outputs": [],
  "source": [
- "import matplotlib.pyplot as plt\n",
+ "\n",
  "\n",
  "epochs = range(1, len(train_losses) + 1)\n",
  "\n",
  "plt.figure(figsize=(12, 5))\n",
  "\n",
- "# Plot Loss\n",
  "plt.subplot(1, 2, 1)\n",
  "plt.plot(epochs, train_losses, label='Train Loss', marker='o')\n",
  "plt.plot(epochs, val_losses, label='Validation Loss', marker='s')\n",
@@ -397,7 +367,6 @@
  "plt.legend()\n",
  "plt.grid(True)\n",
  "\n",
- "# Plot Accuracy\n",
  "plt.subplot(1, 2, 2)\n",
  "plt.plot(epochs, train_accs, label='Train Accuracy', marker='o')\n",
  "plt.plot(epochs, val_accs, label='Validation Accuracy', marker='s')\n",
@@ -462,35 +431,28 @@
  "metadata": {},
  "outputs": [],
  "source": [
- "import torch\n",
- "import numpy as np\n",
- "import matplotlib.pyplot as plt\n",
  "\n",
  "all_preds = np.array(all_preds)\n",
  "all_targets = np.array(all_targets)\n",
- "all_images = torch.stack(all_images) # shape: [N, C, H, W]\n",
+ "all_images = torch.stack(all_images)\n",
  "\n",
- "# Per class FP and FN\n",
  "for class_idx, class_name in enumerate(target_names):\n",
- " print(f\"\\n🔍 Showing False Negatives and False Positives for class: {class_name}\")\n",
- "\n",
- " # False Negatives: True label is class_idx, but predicted something else\n",
+ " print(f\"\\nShowing False Negatives and False Positives for class: {class_name}\")\n",
  " fn_indices = np.where((all_targets == class_idx) & (all_preds != class_idx))[0]\n",
- " # False Positives: Predicted class_idx, but true label is different\n",
  " fp_indices = np.where((all_preds == class_idx) & (all_targets != class_idx))[0]\n",
  "\n",
  " def show_images(indices, title, max_images=5):\n",
  " num = min(len(indices), max_images)\n",
  " if num == 0:\n",
- " print(f\"No {title} samples.\")\n",
+ " print(f\"No {title} samples.\")\n",
  " return\n",
  "\n",
  " plt.figure(figsize=(12, 2))\n",
  " for i, idx in enumerate(indices[:num]):\n",
  " img = all_images[idx]\n",
- " img = img.permute(1, 2, 0).numpy() # [C, H, W] → [H, W, C]\n",
+ " img = img.permute(1, 2, 0).numpy()\n",
  " plt.subplot(1, num, i + 1)\n",
- " plt.imshow((img - img.min()) / (img.max() - img.min())) # normalize to [0,1] for display\n",
+ " plt.imshow((img - img.min()) / (img.max() - img.min()))\n",
  " plt.axis('off')\n",
  " plt.title(f\"Pred: {target_names[all_preds[idx]]}\\nTrue: {target_names[all_targets[idx]]}\")\n",
  " plt.suptitle(f\"{title} for {class_name}\")\n",
@@ -517,29 +479,25 @@
  " activations[name] = output.detach().cpu()\n",
  " return hook\n",
  "\n",
- " # Register hooks for all layers in model.features\n",
  " hooks = []\n",
  " for i in range(len(model.features)):\n",
  " layer = model.features[i]\n",
  " hooks.append(layer.register_forward_hook(get_activation(f\"features_{i}\")))\n",
  "\n",
  " with torch.no_grad():\n",
- " _ = model(image_tensor.unsqueeze(0)) # Add batch dimension: [1, 3, 224, 224]\n",
+ " _ = model(image_tensor.unsqueeze(0))\n",
  "\n",
  " for h in hooks:\n",
  " h.remove()\n",
  "\n",
  " for layer_name, fmap in activations.items():\n",
- " fmap = fmap.squeeze(0) # Shape: [C, H, W]\n",
+ " fmap = fmap.squeeze(0)\n",
  "\n",
- " # Compute mean activation per channel\n",
- " channel_scores = fmap.mean(dim=(1, 2)) # [C]\n",
+ " channel_scores = fmap.mean(dim=(1, 2))\n",
  "\n",
- " # Get indices of top-k channels\n",
  " topk = torch.topk(channel_scores, k=min(max_channels, fmap.shape[0]))\n",
  " top_indices = topk.indices\n",
  "\n",
- " # Plot top-k channels\n",
  " plt.figure(figsize=(max_channels * 2, 2.5))\n",
  " for idx, ch in enumerate(top_indices):\n",
  " plt.subplot(1, max_channels, idx + 1)\n",
@@ -572,14 +530,12 @@
  "\n",
  "img = Image.open(\"dataset/Tomato_512/Whole/image_0007.jpg\").convert(\"RGB\")\n",
  "\n",
- "# Preprocessing (must match model requirements)\n",
  "transform = transforms.Compose([\n",
  " transforms.Resize((224, 224)),\n",
  " transforms.ToTensor()\n",
  "])\n",
- "img_tensor = transform(img) # shape: [3, 224, 224]\n",
+ "img_tensor = transform(img)\n",
  "\n",
- "# Visualize feature maps\n",
  "visualize_channels(model, img_tensor, max_channels=16)\n"
@@ -592,14 +548,12 @@
  "source": [
  "img = Image.open(\"dataset/Tomato_512/On_the_vines/image_0578.jpg\").convert(\"RGB\")\n",
  "\n",
- "# Preprocessing (must match model requirements)\n",
  "transform = transforms.Compose([\n",
  " transforms.Resize((224, 224)),\n",
  " transforms.ToTensor()\n",
  "])\n",
- "img_tensor = transform(img) # shape: [3, 224, 224]\n",
+ "img_tensor = transform(img)\n",
  "\n",
- "# Visualize feature maps\n",
  "visualize_channels(model, img_tensor, max_channels=16)\n"
@@ -612,14 +566,12 @@
  "source": [
  "img = Image.open(\"dataset/Tomato_512/Diced/image_0578.jpg\").convert(\"RGB\")\n",
  "\n",
- "# Preprocessing (must match model requirements)\n",
  "transform = transforms.Compose([\n",
  " transforms.Resize((224, 224)),\n",
  " transforms.ToTensor()\n",
  "])\n",
- "img_tensor = transform(img) # shape: [3, 224, 224]\n",
+ "img_tensor = transform(img)\n",
  "\n",
- "# Visualize feature maps\n",
  "visualize_channels(model, img_tensor, max_channels=16)\n"
  ]
  },
 