sakshamlakhera committed
Commit · 1265dde
Parent(s): b274faf

fixing scripts
- scripts/CV/Part1.ipynb +17 -41
- scripts/CV/compression.ipynb +17 -35
- scripts/CV/script_onion.ipynb +14 -28
- scripts/CV/script_pear.ipynb +17 -43
- scripts/CV/script_strawberry.ipynb +21 -43
- scripts/CV/script_tomato.ipynb +24 -72
scripts/CV/Part1.ipynb
CHANGED
@@ -66,10 +66,10 @@
  "def augment_rotations(X, y):\n",
  " X_aug = []\n",
  " y_aug = []\n",
- " for k in [1, 2, 3]:
- " X_rot = torch.rot90(X, k=k, dims=[2, 3])
+ " for k in [1, 2, 3]: \n",
+ " X_rot = torch.rot90(X, k=k, dims=[2, 3])\n",
  " X_aug.append(X_rot)\n",
- " y_aug.append(y.clone())
+ " y_aug.append(y.clone())\n",
  " return torch.cat(X_aug), torch.cat(y_aug)\n"
  ]
  },
@@ -165,7 +165,6 @@
  " plt.suptitle(f\"{class_name.capitalize()} – Random {count} Samples\", fontsize=16)\n",
  " plt.show()\n",
  "\n",
- "# Display for each class\n",
  "for class_name, image_array in datasets.items():\n",
  " show_random_samples(image_array, class_name)\n"
  ]
@@ -189,7 +188,7 @@
  "\n",
  "for ax, (class_name, images) in zip(axes, datasets.items()):\n",
  " plot_rgb_histogram_subplot(ax, images, class_name)\n",
- " ax.label_outer()
+ " ax.label_outer()\n",
  "\n",
  "plt.tight_layout()\n",
  "plt.show()\n"
@@ -305,7 +304,7 @@
  "class_names = list(datasets.keys())\n",
  "num_classes = len(class_names)\n",
  "\n",
- "fig, axes = plt.subplots(1, num_classes, figsize=(4 * num_classes, 4))
+ "fig, axes = plt.subplots(1, num_classes, figsize=(4 * num_classes, 4))\n",
  "\n",
  "for i, (class_name, images) in enumerate(datasets.items()):\n",
  " avg_img = np.mean(images.astype(np.float32), axis=0)\n",
@@ -371,7 +370,6 @@
  "from sklearn.model_selection import train_test_split\n",
  "from torchvision import transforms\n",
  "\n",
- "# Combine data\n",
  "X = np.concatenate([onion_images, strawberry_images, pear_images, tomato_images], axis=0)\n",
  "y = (\n",
  " ['onion'] * len(onion_images) +\n",
@@ -380,16 +378,14 @@
  " ['tomato'] * len(tomato_images)\n",
  ")\n",
  "\n",
- "# Normalizing image\n",
  "X = X.astype(np.float32) / 255.0\n",
- "X = np.transpose(X, (0, 3, 1, 2))
+ "X = np.transpose(X, (0, 3, 1, 2)) \n",
  "X_tensor = torch.tensor(X)\n",
  "\n",
  "le = LabelEncoder()\n",
  "y_encoded = le.fit_transform(y)\n",
  "y_tensor = torch.tensor(y_encoded)\n",
  "\n",
- "# splitting data into 50:25:25 (train, validation, test)\n",
  "X_train, X_temp, y_train, y_temp = train_test_split(X_tensor, y_tensor, test_size=0.5, stratify=y_tensor, random_state=42)\n",
  "X_val, X_test, y_val, y_test = train_test_split(X_temp, y_temp, test_size=0.5, stratify=y_temp, random_state=42)\n"
  ]
@@ -403,13 +399,10 @@
  "source": [
  "batch_size = 32\n",
  "\n",
- "# Create new training dataset and loader\n",
  "train_dataset = TensorDataset(X_train, y_train)\n",
  "val_dataset = TensorDataset(X_val, y_val)\n",
  "test_dataset = TensorDataset(X_test, y_test)\n",
  "\n",
- "# DataLoaders\n",
- "\n",
  "train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)\n",
  "val_loader = DataLoader(val_dataset, batch_size=batch_size)\n",
  "test_loader = DataLoader(test_dataset, batch_size=batch_size)"
@@ -432,9 +425,9 @@
  "metadata": {},
  "outputs": [],
  "source": [
- "print(f\"
- "print(f\"
- "print(f\"
+ "print(f\"Train Dataset: {len(train_dataset)} samples, {len(train_loader)} batches\")\n",
+ "print(f\"Val Dataset: {len(val_dataset)} samples, {len(val_loader)} batches\")\n",
+ "print(f\"Test Dataset: {len(test_dataset)} samples, {len(test_loader)} batches\")"
  ]
  },
  {
@@ -521,8 +514,6 @@
  " optimizer.step()\n",
  "\n",
  " total_train_loss += loss.item()\n",
- "\n",
- " # Track training accuracy\n",
  " pred_labels = preds.argmax(dim=1)\n",
  " train_correct += (pred_labels == batch_y).sum().item()\n",
  " train_total += batch_y.size(0)\n",
@@ -546,7 +537,6 @@
  " val_accuracy = val_correct / val_total\n",
  " validation_loss = criterion(model(val_x), val_y).item()\n",
  "\n",
- " # After calculating val_accuracy\n",
  " val_losses.append(validation_loss)\n",
  " val_accs.append(val_accuracy)\n",
  "\n",
@@ -637,7 +627,7 @@
  "\n",
  "print(f\"\\nTest Accuracy: {test_accuracy:.4f}\")\n",
  "\n",
- "target_names = le.classes_
+ "target_names = le.classes_ \n",
  "print(\"\\nClassification Report:\\n\")\n",
  "print(classification_report(all_targets, all_preds, target_names=target_names))\n",
  "\n",
@@ -726,7 +716,7 @@
  " h.remove()\n",
  "\n",
  " for layer_name, fmap in activations.items():\n",
- " fmap = fmap.squeeze(0)
+ " fmap = fmap.squeeze(0)\n",
  " num_channels = min(fmap.shape[0], max_channels)\n",
  "\n",
  " plt.figure(figsize=(num_channels * 2, 2.5))\n",
@@ -758,29 +748,22 @@
  " activations[name] = output.detach().cpu()\n",
  " return hook\n",
  "\n",
- " # Register hooks for all layers in model.features\n",
  " hooks = []\n",
  " for i in range(len(model.features)):\n",
  " layer = model.features[i]\n",
  " hooks.append(layer.register_forward_hook(get_activation(f\"features_{i}\")))\n",
  "\n",
  " with torch.no_grad():\n",
- " _ = model(image_tensor.unsqueeze(0))
+ " _ = model(image_tensor.unsqueeze(0))\n",
  "\n",
  " for h in hooks:\n",
  " h.remove()\n",
  "\n",
  " for layer_name, fmap in activations.items():\n",
- " fmap = fmap.squeeze(0)
- "\n",
- " # Compute mean activation per channel\n",
- " channel_scores = fmap.mean(dim=(1, 2)) # [C]\n",
- "\n",
- " # Get indices of top-k channels\n",
+ " fmap = fmap.squeeze(0)\n",
+ " channel_scores = fmap.mean(dim=(1, 2)) \n",
  " topk = torch.topk(channel_scores, k=min(max_channels, fmap.shape[0]))\n",
  " top_indices = topk.indices\n",
- "\n",
- " # Plot top-k channels\n",
  " plt.figure(figsize=(max_channels * 2, 2.5))\n",
  " for idx, ch in enumerate(top_indices):\n",
  " plt.subplot(1, max_channels, idx + 1)\n",
@@ -821,14 +804,11 @@
  "\n",
  "img = Image.open(\"dataset/Onion_512/Whole/image_0001.jpg\").convert(\"RGB\")\n",
  "\n",
- "# Preprocessing (must match model requirements)\n",
  "transform = transforms.Compose([\n",
  " transforms.Resize((224, 224)),\n",
  " transforms.ToTensor()\n",
  "])\n",
- "img_tensor = transform(img)
- "\n",
- "# Visualize feature maps\n",
+ "img_tensor = transform(img)\n",
  "visualize_channels(model, img_tensor, max_channels=16)\n"
  ]
  },
@@ -849,14 +829,12 @@
  "source": [
  "img = Image.open(\"dataset/Pear_512/Whole/image_0089.jpg\").convert(\"RGB\")\n",
  "\n",
- "# Preprocessing (must match model requirements)\n",
  "transform = transforms.Compose([\n",
  " transforms.Resize((224, 224)),\n",
  " transforms.ToTensor()\n",
  "])\n",
- "img_tensor = transform(img)
+ "img_tensor = transform(img)\n",
  "\n",
- "# Visualize feature maps\n",
  "visualize_channels(model, img_tensor, max_channels=16)\n"
  ]
  },
@@ -877,14 +855,12 @@
  "source": [
  "img = Image.open(\"dataset/Tomato_512/Whole/image_0001.jpg\").convert(\"RGB\")\n",
  "\n",
- "# Preprocessing (must match model requirements)\n",
  "transform = transforms.Compose([\n",
  " transforms.Resize((224, 224)),\n",
  " transforms.ToTensor()\n",
  "])\n",
- "img_tensor = transform(img)
+ "img_tensor = transform(img)\n",
  "\n",
- "# Visualize feature maps\n",
  "visualize_channels(model, img_tensor, max_channels=16)\n"
  ]
  },
scripts/CV/compression.ipynb
CHANGED
@@ -18,8 +18,8 @@
  "import os\n",
  "from PIL import Image, ImageOps\n",
  "\n",
- "input_root = 'Tomato'
- "output_root = 'Tomato_512'
+ "input_root = 'Tomato' \n",
+ "output_root = 'Tomato_512' \n",
  "os.makedirs(output_root, exist_ok=True)\n",
  "\n",
  "def process_image(input_path, output_path, size=(512, 512)):\n",
@@ -27,14 +27,12 @@
  " with Image.open(input_path) as img:\n",
  " img = img.convert(\"RGB\")\n",
  "\n",
- " # Resize while preserving aspect ratio, then pad to 512x512\n",
  " img = ImageOps.fit(img, size, Image.LANCZOS, centering=(0.5, 0.5))\n",
  " os.makedirs(os.path.dirname(output_path), exist_ok=True)\n",
  " img.save(output_path, \"JPEG\", quality=95)\n",
  " except Exception as e:\n",
- " print(f\"
+ " print(f\"Error processing {input_path}: {e}\")\n",
  "\n",
- "# Recursively walk through input_root\n",
  "for root, _, files in os.walk(input_root):\n",
  " for file in files:\n",
  " if file.lower().endswith((\".jpg\", \".jpeg\")):\n",
@@ -43,7 +41,7 @@
  " output_path = os.path.join(output_root, rel_path)\n",
  " process_image(input_path, output_path)\n",
  "\n",
- "print(\"
+ "print(\"All images processed and saved in\", output_root)\n"
  ]
  },
  {
@@ -56,23 +54,20 @@
  "import os\n",
  "from PIL import Image, ImageOps\n",
  "\n",
- "input_root = 'Onion'
- "output_root = 'Onion_512'
+ "input_root = 'Onion' \n",
+ "output_root = 'Onion_512' \n",
  "os.makedirs(output_root, exist_ok=True)\n",
  "\n",
  "def process_image(input_path, output_path, size=(512, 512)):\n",
  " try:\n",
  " with Image.open(input_path) as img:\n",
  " img = img.convert(\"RGB\")\n",
- "\n",
- " # Resize while preserving aspect ratio, then pad to 512x512\n",
  " img = ImageOps.fit(img, size, Image.LANCZOS, centering=(0.5, 0.5))\n",
  " os.makedirs(os.path.dirname(output_path), exist_ok=True)\n",
  " img.save(output_path, \"JPEG\", quality=95)\n",
  " except Exception as e:\n",
- " print(f\"
+ " print(f\"Error processing {input_path}: {e}\")\n",
  "\n",
- "# Recursively walk through input_root\n",
  "for root, _, files in os.walk(input_root):\n",
  " for file in files:\n",
  " if file.lower().endswith((\".jpg\", \".jpeg\")):\n",
@@ -81,7 +76,7 @@
  " output_path = os.path.join(output_root, rel_path)\n",
  " process_image(input_path, output_path)\n",
  "\n",
- "print(\"
+ "print(\"All images processed and saved in\", output_root)\n"
  ]
  },
  {
@@ -94,23 +89,20 @@
  "import os\n",
  "from PIL import Image, ImageOps\n",
  "\n",
- "input_root = 'Pear'
- "output_root = 'Pear_512'
+ "input_root = 'Pear' \n",
+ "output_root = 'Pear_512' \n",
  "os.makedirs(output_root, exist_ok=True)\n",
  "\n",
  "def process_image(input_path, output_path, size=(512, 512)):\n",
  " try:\n",
  " with Image.open(input_path) as img:\n",
  " img = img.convert(\"RGB\")\n",
- "\n",
- " # Resize while preserving aspect ratio, then pad to 512x512\n",
  " img = ImageOps.fit(img, size, Image.LANCZOS, centering=(0.5, 0.5))\n",
  " os.makedirs(os.path.dirname(output_path), exist_ok=True)\n",
  " img.save(output_path, \"JPEG\", quality=95)\n",
  " except Exception as e:\n",
- " print(f\"
+ " print(f\"Error processing {input_path}: {e}\")\n",
  "\n",
- "# Recursively walk through input_root\n",
  "for root, _, files in os.walk(input_root):\n",
  " for file in files:\n",
  " if file.lower().endswith((\".jpg\", \".jpeg\")):\n",
@@ -119,7 +111,7 @@
  " output_path = os.path.join(output_root, rel_path)\n",
  " process_image(input_path, output_path)\n",
  "\n",
- "print(\"
+ "print(\"All images processed and saved in\", output_root)\n"
  ]
  },
  {
@@ -132,23 +124,21 @@
  "import os\n",
  "from PIL import Image, ImageOps\n",
  "\n",
- "input_root = 'Strawberry'
- "output_root = 'Strawberry_512'
+ "input_root = 'Strawberry' \n",
+ "output_root = 'Strawberry_512' \n",
  "os.makedirs(output_root, exist_ok=True)\n",
  "\n",
  "def process_image(input_path, output_path, size=(512, 512)):\n",
  " try:\n",
  " with Image.open(input_path) as img:\n",
  " img = img.convert(\"RGB\")\n",
- "\n",
- " # Resize while preserving aspect ratio, then pad to 512x512\n",
  " img = ImageOps.fit(img, size, Image.LANCZOS, centering=(0.5, 0.5))\n",
  " os.makedirs(os.path.dirname(output_path), exist_ok=True)\n",
  " img.save(output_path, \"JPEG\", quality=95)\n",
  " except Exception as e:\n",
- " print(f\"
+ " print(f\"Error processing {input_path}: {e}\")\n",
+ "\n",
  "\n",
- "# Recursively walk through input_root\n",
  "for root, _, files in os.walk(input_root):\n",
  " for file in files:\n",
  " if file.lower().endswith((\".jpg\", \".jpeg\")):\n",
@@ -157,16 +147,8 @@
  " output_path = os.path.join(output_root, rel_path)\n",
  " process_image(input_path, output_path)\n",
  "\n",
- "print(\"
+ "print(\"All images processed and saved in\", output_root)\n"
  ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "fd49ae48",
- "metadata": {},
- "outputs": [],
- "source": []
  }
  ],
  "metadata": {
scripts/CV/script_onion.ipynb
CHANGED
@@ -59,10 +59,10 @@
  "def augment_rotations(X, y):\n",
  " X_aug = []\n",
  " y_aug = []\n",
- " for k in [1, 2, 3]
- " X_rot = torch.rot90(X, k=k, dims=[2, 3])
+ " for k in [1, 2, 3]:\n",
+ " X_rot = torch.rot90(X, k=k, dims=[2, 3])\n",
  " X_aug.append(X_rot)\n",
- " y_aug.append(y.clone())
+ " y_aug.append(y.clone()) \n",
  " return torch.cat(X_aug), torch.cat(y_aug)"
  ]
  },
@@ -120,7 +120,6 @@
  " plt.suptitle(f\"{class_name.capitalize()} – Random {count} Samples\", fontsize=16)\n",
  " plt.show()\n",
  "\n",
- "# Display for each class\n",
  "for class_name, image_array in datasets.items():\n",
  " show_random_samples(image_array, class_name)\n"
  ]
@@ -136,7 +135,7 @@
  "\n",
  "for ax, (class_name, images) in zip(axes, datasets.items()):\n",
  " plot_rgb_histogram_subplot(ax, images, class_name)\n",
- " ax.label_outer()
+ " ax.label_outer()\n",
  "\n",
  "plt.tight_layout()\n",
  "plt.show()"
@@ -152,7 +151,7 @@
  "class_names = list(datasets.keys())\n",
  "num_classes = len(class_names)\n",
  "\n",
- "fig, axes = plt.subplots(1, num_classes, figsize=(4 * num_classes, 4))
+ "fig, axes = plt.subplots(1, num_classes, figsize=(4 * num_classes, 4)) \n",
  "\n",
  "for i, (class_name, images) in enumerate(datasets.items()):\n",
  " avg_img = np.mean(images.astype(np.float32), axis=0)\n",
@@ -209,7 +208,6 @@
  "\n",
  "X_augmented, y_augmented = augment_rotations(X_train, y_train)\n",
  "\n",
- "# Combine original and augmented data\n",
  "X_train_combined = torch.cat([X_train, X_augmented])\n",
  "y_train_combined = torch.cat([y_train, y_augmented])\n",
  "\n",
@@ -230,9 +228,9 @@
  "metadata": {},
  "outputs": [],
  "source": [
- "print(f\"
- "print(f\"
- "print(f\"
+ "print(f\"Train Dataset: {len(train_dataset)} samples, {len(train_loader)} batches\")\n",
+ "print(f\"Val Dataset: {len(val_dataset)} samples, {len(val_loader)} batches\")\n",
+ "print(f\"Test Dataset: {len(test_dataset)} samples, {len(test_loader)} batches\")"
  ]
  },
  {
@@ -480,29 +478,24 @@
  " activations[name] = output.detach().cpu()\n",
  " return hook\n",
  "\n",
- " # Register hooks for all layers in model.features\n",
  " hooks = []\n",
  " for i in range(len(model.features)):\n",
  " layer = model.features[i]\n",
  " hooks.append(layer.register_forward_hook(get_activation(f\"features_{i}\")))\n",
  "\n",
  " with torch.no_grad():\n",
- " _ = model(image_tensor.unsqueeze(0))
+ " _ = model(image_tensor.unsqueeze(0)) \n",
  "\n",
  " for h in hooks:\n",
  " h.remove()\n",
  "\n",
  " for layer_name, fmap in activations.items():\n",
- " fmap = fmap.squeeze(0)
+ " fmap = fmap.squeeze(0) \n",
+ " channel_scores = fmap.mean(dim=(1, 2))\n",
  "\n",
- " # Compute mean activation per channel\n",
- " channel_scores = fmap.mean(dim=(1, 2)) # [C]\n",
- "\n",
- " # Get indices of top-k channels\n",
  " topk = torch.topk(channel_scores, k=min(max_channels, fmap.shape[0]))\n",
  " top_indices = topk.indices\n",
  "\n",
- " # Plot top-k channels\n",
  " plt.figure(figsize=(max_channels * 2, 2.5))\n",
  " for idx, ch in enumerate(top_indices):\n",
  " plt.subplot(1, max_channels, idx + 1)\n",
@@ -535,14 +528,12 @@
  "\n",
  "img = Image.open(\"dataset/Onion_512/Whole/image_0001.jpg\").convert(\"RGB\")\n",
  "\n",
- "# Preprocessing (must match model requirements)\n",
  "transform = transforms.Compose([\n",
  " transforms.Resize((224, 224)),\n",
  " transforms.ToTensor()\n",
  "])\n",
- "img_tensor = transform(img)
+ "img_tensor = transform(img)\n",
  "\n",
- "# Visualize feature maps\n",
  "visualize_channels(model, img_tensor, max_channels=16)\n"
  ]
  },
@@ -556,14 +547,11 @@
  "\n",
  "img = Image.open(\"dataset/Onion_512/Halved/image_0880.jpg\").convert(\"RGB\")\n",
  "\n",
- "# Preprocessing (must match model requirements)\n",
  "transform = transforms.Compose([\n",
  " transforms.Resize((224, 224)),\n",
  " transforms.ToTensor()\n",
  "])\n",
- "img_tensor = transform(img)
- "\n",
- "# Visualize feature maps\n",
+ "img_tensor = transform(img)\n",
  "visualize_channels(model, img_tensor, max_channels=16)\n"
  ]
  },
@@ -576,14 +564,12 @@
  "source": [
  "img = Image.open(\"dataset/Onion_512/Sliced/image_0772.jpg\").convert(\"RGB\")\n",
  "\n",
- "# Preprocessing (must match model requirements)\n",
  "transform = transforms.Compose([\n",
  " transforms.Resize((224, 224)),\n",
  " transforms.ToTensor()\n",
  "])\n",
- "img_tensor = transform(img)
+ "img_tensor = transform(img) \n",
  "\n",
- "# Visualize feature maps\n",
  "visualize_channels(model, img_tensor, max_channels=16)\n"
  ]
  },
scripts/CV/script_pear.ipynb
CHANGED
@@ -59,10 +59,10 @@
  "def augment_rotations(X, y):\n",
  " X_aug = []\n",
  " y_aug = []\n",
- " for k in [1, 2, 3]:
- " X_rot = torch.rot90(X, k=k, dims=[2, 3])
+ " for k in [1, 2, 3]: \n",
+ " X_rot = torch.rot90(X, k=k, dims=[2, 3]) \n",
  " X_aug.append(X_rot)\n",
- " y_aug.append(y.clone())
+ " y_aug.append(y.clone()) \n",
  " return torch.cat(X_aug), torch.cat(y_aug)"
  ]
  },
@@ -122,7 +122,6 @@
  " plt.suptitle(f\"{class_name.capitalize()} – Random {count} Samples\", fontsize=16)\n",
  " plt.show()\n",
  "\n",
- "# Display for each class\n",
  "for class_name, image_array in datasets.items():\n",
  " show_random_samples(image_array, class_name)\n"
  ]
@@ -138,7 +137,7 @@
  "\n",
  "for ax, (class_name, images) in zip(axes, datasets.items()):\n",
  " plot_rgb_histogram_subplot(ax, images, class_name)\n",
- " ax.label_outer()
+ " ax.label_outer() \n",
  "\n",
  "plt.tight_layout()\n",
  "plt.show()"
@@ -154,7 +153,7 @@
  "class_names = list(datasets.keys())\n",
  "num_classes = len(class_names)\n",
  "\n",
- "fig, axes = plt.subplots(1, num_classes, figsize=(4 * num_classes, 4))
+ "fig, axes = plt.subplots(1, num_classes, figsize=(4 * num_classes, 4)) \n",
  "\n",
  "for i, (class_name, images) in enumerate(datasets.items()):\n",
  " avg_img = np.mean(images.astype(np.float32), axis=0)\n",
@@ -180,7 +179,6 @@
  " \"whole\": pear_whole_images\n",
  "}\n",
  "\n",
- "# Combine data\n",
  "X = np.concatenate([pear_halved_images, pear_sliced_images, pear_whole_images], axis=0)\n",
  "y = (\n",
  " ['halved'] * len(pear_halved_images) +\n",
@@ -188,17 +186,14 @@
  " ['whole'] * len(pear_whole_images)\n",
  ")\n",
  "\n",
- "# Normalize and convert to torch tensors\n",
  "X = X.astype(np.float32) / 255.0\n",
- "X = np.transpose(X, (0, 3, 1, 2))
+ "X = np.transpose(X, (0, 3, 1, 2)) \n",
  "X_tensor = torch.tensor(X)\n",
  "\n",
- "# Encode labels\n",
  "le = LabelEncoder()\n",
  "y_encoded = le.fit_transform(y)\n",
  "y_tensor = torch.tensor(y_encoded)\n",
  "\n",
- "# Train/val/test split\n",
  "X_train, X_temp, y_train, y_temp = train_test_split(X_tensor, y_tensor, test_size=0.4, stratify=y_tensor, random_state=42)\n",
  "X_val, X_test, y_val, y_test = train_test_split(X_temp, y_temp, test_size=0.5, stratify=y_temp, random_state=42)\n"
  ]
@@ -215,7 +210,6 @@
  "val_dataset = TensorDataset(X_val, y_val)\n",
  "test_dataset = TensorDataset(X_test, y_test)\n",
  "\n",
- "# DataLoaders\n",
  "batch_size = 32\n",
  "train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)\n",
  "val_loader = DataLoader(val_dataset, batch_size=batch_size)\n",
@@ -229,9 +223,9 @@
  "metadata": {},
  "outputs": [],
  "source": [
- "print(f\"
- "print(f\"
- "print(f\"
+ "print(f\"Train Dataset: {len(train_dataset)} samples, {len(train_loader)} batches\")\n",
+ "print(f\"Val Dataset: {len(val_dataset)} samples, {len(val_loader)} batches\")\n",
+ "print(f\"Test Dataset: {len(test_dataset)} samples, {len(test_loader)} batches\")"
  ]
  },
  {
@@ -249,8 +243,6 @@
  "\n",
  "def get_efficientnet_model(num_classes):\n",
  " model = models.efficientnet_b0(weights=models.EfficientNet_B0_Weights.DEFAULT)\n",
- "\n",
- " # Replace classifier head with custom head\n",
  " model.classifier[1] = nn.Linear(model.classifier[1].in_features, num_classes)\n",
  "\n",
  " return model\n",
@@ -266,10 +258,10 @@
  "source": [
  "if torch.backends.mps.is_available():\n",
  " device = torch.device(\"mps\")\n",
- " print(\"
+ " print(\"Using MPS (Apple GPU)\")\n",
  "else:\n",
  " device = torch.device(\"cpu\")\n",
- " print(\"
+ " print(\"MPS not available. Using CPU\")\n",
  "\n",
  "model = get_efficientnet_model(num_classes=3).to(device)\n",
  "optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)\n",
@@ -312,7 +304,6 @@
  "\n",
  " total_train_loss += loss.item()\n",
  "\n",
- " # Track training accuracy\n",
  " pred_labels = preds.argmax(dim=1)\n",
  " train_correct += (pred_labels == batch_y).sum().item()\n",
  " train_total += batch_y.size(0)\n",
@@ -367,7 +358,6 @@
  "\n",
  "plt.figure(figsize=(12, 5))\n",
  "\n",
- "# Plot Loss\n",
  "plt.subplot(1, 2, 1)\n",
  "plt.plot(epochs, train_losses, label='Train Loss', marker='o')\n",
  "plt.plot(epochs, val_losses, label='Validation Loss', marker='s')\n",
@@ -377,7 +367,6 @@
  "plt.legend()\n",
  "plt.grid(True)\n",
  "\n",
- "# Plot Accuracy\n",
  "plt.subplot(1, 2, 2)\n",
  "plt.plot(epochs, train_accs, label='Train Accuracy', marker='o')\n",
  "plt.plot(epochs, val_accs, label='Validation Accuracy', marker='s')\n",
@@ -490,29 +479,24 @@
  " activations[name] = output.detach().cpu()\n",
  " return hook\n",
  "\n",
- " # Register hooks for all layers in model.features\n",
  " hooks = []\n",
  " for i in range(len(model.features)):\n",
  " layer = model.features[i]\n",
  " hooks.append(layer.register_forward_hook(get_activation(f\"features_{i}\")))\n",
  "\n",
  " with torch.no_grad():\n",
- " _ = model(image_tensor.unsqueeze(0))
+ " _ = model(image_tensor.unsqueeze(0)) \n",
  "\n",
  " for h in hooks:\n",
  " h.remove()\n",
  "\n",
  " for layer_name, fmap in activations.items():\n",
- " fmap = fmap.squeeze(0)
- "\n",
- " # Compute mean activation per channel\n",
- " channel_scores = fmap.mean(dim=(1, 2)) # [C]\n",
+ " fmap = fmap.squeeze(0) \n",
+ " channel_scores = fmap.mean(dim=(1, 2)) \n",
  "\n",
- " # Get indices of top-k channels\n",
  " topk = torch.topk(channel_scores, k=min(max_channels, fmap.shape[0]))\n",
  " top_indices = topk.indices\n",
  "\n",
- " # Plot top-k channels\n",
  " plt.figure(figsize=(max_channels * 2, 2.5))\n",
  " for idx, ch in enumerate(top_indices):\n",
  " plt.subplot(1, max_channels, idx + 1)\n",
@@ -545,14 +529,11 @@
  "\n",
  "img = Image.open(\"dataset/Pear_512/Whole/image_0007.jpg\").convert(\"RGB\")\n",
  "\n",
- "# Preprocessing (must match model requirements)\n",
  "transform = transforms.Compose([\n",
  " transforms.Resize((224, 224)),\n",
  " transforms.ToTensor()\n",
  "])\n",
- "img_tensor = transform(img)
- "\n",
- "# Visualize feature maps\n",
+ "img_tensor = transform(img) \n",
  "visualize_channels(model, img_tensor, max_channels=16)\n"
  ]
  },
@@ -565,14 +546,11 @@
  "source": [
  "img = Image.open(\"dataset/Pear_512/Halved/image_0578.jpg\").convert(\"RGB\")\n",
  "\n",
- "# Preprocessing (must match model requirements)\n",
  "transform = transforms.Compose([\n",
  " transforms.Resize((224, 224)),\n",
  " transforms.ToTensor()\n",
  "])\n",
- "img_tensor = transform(img)
- "\n",
- "# Visualize feature maps\n",
+ "img_tensor = transform(img) \n",
  "visualize_channels(model, img_tensor, max_channels=16)\n"
  ]
  },
@@ -584,15 +562,11 @@
  "outputs": [],
  "source": [
  "img = Image.open(\"dataset/Pear_512/Sliced/image_0007.jpg\").convert(\"RGB\")\n",
- "\n",
- "# Preprocessing (must match model requirements)\n",
  "transform = transforms.Compose([\n",
  " transforms.Resize((224, 224)),\n",
  " transforms.ToTensor()\n",
  "])\n",
- "img_tensor = transform(img)
- "\n",
- "# Visualize feature maps\n",
+ "img_tensor = transform(img) \n",
  "visualize_channels(model, img_tensor, max_channels=16)\n"
  ]
  },
scripts/CV/script_strawberry.ipynb
CHANGED
@@ -59,10 +59,10 @@
|
|
59 |
"def augment_rotations(X, y):\n",
|
60 |
" X_aug = []\n",
|
61 |
" y_aug = []\n",
|
62 |
-
" for k in [1, 2, 3]:
|
63 |
-
" X_rot = torch.rot90(X, k=k, dims=[2, 3])
|
64 |
" X_aug.append(X_rot)\n",
|
65 |
-
" y_aug.append(y.clone())
|
66 |
" return torch.cat(X_aug), torch.cat(y_aug)"
|
67 |
]
|
68 |
},
|
@@ -124,7 +124,6 @@
|
|
124 |
" plt.suptitle(f\"{class_name.capitalize()} – Random {count} Samples\", fontsize=16)\n",
|
125 |
" plt.show()\n",
|
126 |
"\n",
|
127 |
-
"# Display for each class\n",
|
128 |
"for class_name, image_array in datasets.items():\n",
|
129 |
" show_random_samples(image_array, class_name)\n"
|
130 |
]
|
@@ -140,7 +139,7 @@
|
|
140 |
"\n",
|
141 |
"for ax, (class_name, images) in zip(axes, datasets.items()):\n",
|
142 |
" plot_rgb_histogram_subplot(ax, images, class_name)\n",
|
143 |
-
" ax.label_outer()
|
144 |
"\n",
|
145 |
"plt.tight_layout()\n",
|
146 |
"plt.show()"
|
@@ -156,7 +155,7 @@
|
|
156 |
"class_names = list(datasets.keys())\n",
|
157 |
"num_classes = len(class_names)\n",
|
158 |
"\n",
|
159 |
-
"fig, axes = plt.subplots(1, num_classes, figsize=(4 * num_classes, 4))
|
160 |
"\n",
|
161 |
"for i, (class_name, images) in enumerate(datasets.items()):\n",
|
162 |
" avg_img = np.mean(images.astype(np.float32), axis=0)\n",
|
@@ -181,7 +180,6 @@
|
|
181 |
" \"whole\": strawberry_whole_images\n",
|
182 |
"}\n",
|
183 |
"\n",
|
184 |
-
"# Combine data\n",
|
185 |
"X = np.concatenate([strawberry_hulled_images, strawberry_sliced_images, strawberry_whole_images], axis=0)\n",
|
186 |
"y = (\n",
|
187 |
" ['hulled'] * len(strawberry_hulled_images) +\n",
|
@@ -189,17 +187,14 @@
|
|
189 |
" ['whole'] * len(strawberry_whole_images)\n",
|
190 |
")\n",
|
191 |
"\n",
|
192 |
-
"# Normalize and convert to torch tensors\n",
|
193 |
"X = X.astype(np.float32) / 255.0\n",
|
194 |
-
"X = np.transpose(X, (0, 3, 1, 2))
|
195 |
"X_tensor = torch.tensor(X)\n",
|
196 |
"\n",
|
197 |
-
"# Encode labels\n",
|
198 |
"le = LabelEncoder()\n",
|
199 |
"y_encoded = le.fit_transform(y)\n",
|
200 |
"y_tensor = torch.tensor(y_encoded)\n",
|
201 |
"\n",
|
202 |
-
"# Train/val/test split\n",
|
203 |
"X_train, X_temp, y_train, y_temp = train_test_split(X_tensor, y_tensor, test_size=0.5, stratify=y_tensor, random_state=42)\n",
|
204 |
"X_val, X_test, y_val, y_test = train_test_split(X_temp, y_temp, test_size=0.5, stratify=y_temp, random_state=42)\n"
|
205 |
]
|
@@ -215,11 +210,9 @@
|
|
215 |
"\n",
|
216 |
"X_augmented, y_augmented = augment_rotations(X_train, y_train)\n",
|
217 |
"\n",
|
218 |
-
"# Combine original and augmented data\n",
|
219 |
"X_train_combined = torch.cat([X_train, X_augmented])\n",
|
220 |
"y_train_combined = torch.cat([y_train, y_augmented])\n",
|
221 |
"\n",
|
222 |
-
"\n",
|
223 |
"train_dataset = TensorDataset(X_train_combined, y_train_combined)\n",
|
224 |
"val_dataset = TensorDataset(X_val, y_val)\n",
|
225 |
"test_dataset = TensorDataset(X_test, y_test)\n",
|
@@ -236,9 +229,9 @@
|
|
236 |
"metadata": {},
|
237 |
"outputs": [],
|
238 |
"source": [
|
239 |
-
"print(f\"
|
240 |
-
"print(f\"
|
241 |
-
"print(f\"
|
242 |
]
|
243 |
},
|
244 |
{
|
@@ -341,7 +334,6 @@
|
|
341 |
" val_accuracy = val_correct / val_total\n",
|
342 |
" validation_loss = criterion(model(val_x), val_y).item()\n",
|
343 |
"\n",
|
344 |
-
" # After calculating val_accuracy\n",
|
345 |
" val_losses.append(validation_loss)\n",
|
346 |
" val_accs.append(val_accuracy)\n",
|
347 |
"\n",
|
@@ -449,29 +441,25 @@
|
|
449 |
"source": [
|
450 |
"all_preds = np.array(all_preds)\n",
|
451 |
"all_targets = np.array(all_targets)\n",
|
452 |
-
"all_images = torch.stack(all_images)
|
453 |
"\n",
|
454 |
-
"# Per class FP and FN\n",
|
455 |
"for class_idx, class_name in enumerate(target_names):\n",
|
456 |
-
" print(f\"\\
|
457 |
-
"\n",
|
458 |
-
" # False Negatives: True label is class_idx, but predicted something else\n",
|
459 |
" fn_indices = np.where((all_targets == class_idx) & (all_preds != class_idx))[0]\n",
|
460 |
-
" # False Positives: Predicted class_idx, but true label is different\n",
|
461 |
" fp_indices = np.where((all_preds == class_idx) & (all_targets != class_idx))[0]\n",
|
462 |
"\n",
|
463 |
" def show_images(indices, title, max_images=5):\n",
|
464 |
" num = min(len(indices), max_images)\n",
|
465 |
" if num == 0:\n",
|
466 |
-
" print(f\"
|
467 |
" return\n",
|
468 |
"\n",
|
469 |
" plt.figure(figsize=(12, 2))\n",
|
470 |
" for i, idx in enumerate(indices[:num]):\n",
|
471 |
" img = all_images[idx]\n",
|
472 |
-
" img = img.permute(1, 2, 0).numpy()
|
473 |
" plt.subplot(1, num, i + 1)\n",
|
474 |
-
" plt.imshow((img - img.min()) / (img.max() - img.min()))
|
475 |
" plt.axis('off')\n",
|
476 |
" plt.title(f\"Pred: {target_names[all_preds[idx]]}\\nTrue: {target_names[all_targets[idx]]}\")\n",
|
477 |
" plt.suptitle(f\"{title} for {class_name}\")\n",
|
@@ -498,29 +486,25 @@
|
|
498 |
" activations[name] = output.detach().cpu()\n",
|
499 |
" return hook\n",
|
500 |
"\n",
|
501 |
-
" # Register hooks for all layers in model.features\n",
|
502 |
" hooks = []\n",
|
503 |
" for i in range(len(model.features)):\n",
|
504 |
" layer = model.features[i]\n",
|
505 |
" hooks.append(layer.register_forward_hook(get_activation(f\"features_{i}\")))\n",
|
506 |
"\n",
|
507 |
" with torch.no_grad():\n",
|
508 |
-
" _ = model(image_tensor.unsqueeze(0))
|
509 |
"\n",
|
510 |
" for h in hooks:\n",
|
511 |
" h.remove()\n",
|
512 |
"\n",
|
513 |
" for layer_name, fmap in activations.items():\n",
|
514 |
-
" fmap = fmap.squeeze(0)
|
515 |
"\n",
|
516 |
-
"
|
517 |
-
" channel_scores = fmap.mean(dim=(1, 2)) # [C]\n",
|
518 |
"\n",
|
519 |
-
" # Get indices of top-k channels\n",
|
520 |
" topk = torch.topk(channel_scores, k=min(max_channels, fmap.shape[0]))\n",
|
521 |
" top_indices = topk.indices\n",
|
522 |
"\n",
|
523 |
-
" # Plot top-k channels\n",
|
524 |
" plt.figure(figsize=(max_channels * 2, 2.5))\n",
|
525 |
" for idx, ch in enumerate(top_indices):\n",
|
526 |
" plt.subplot(1, max_channels, idx + 1)\n",
|
@@ -553,14 +537,12 @@
|
|
553 |
"\n",
|
554 |
"img = Image.open(\"dataset/Strawberry_512/Whole/image_0017.jpg\").convert(\"RGB\")\n",
|
555 |
"\n",
|
556 |
-
"# Preprocessing (must match model requirements)\n",
|
557 |
"transform = transforms.Compose([\n",
|
558 |
" transforms.Resize((224, 224)),\n",
|
559 |
" transforms.ToTensor()\n",
|
560 |
"])\n",
|
561 |
-
"img_tensor = transform(img)
|
562 |
"\n",
|
563 |
-
"# Visualize feature maps\n",
|
564 |
"visualize_channels(model, img_tensor, max_channels=16)\n"
|
565 |
]
|
566 |
},
|
@@ -574,14 +556,12 @@
|
|
574 |
"\n",
|
575 |
"img = Image.open(\"dataset/Strawberry_512/Hulled/image_0001.jpg\").convert(\"RGB\")\n",
|
576 |
"\n",
|
577 |
-
"# Preprocessing (must match model requirements)\n",
|
578 |
"transform = transforms.Compose([\n",
|
579 |
" transforms.Resize((224, 224)),\n",
|
580 |
" transforms.ToTensor()\n",
|
581 |
"])\n",
|
582 |
-
"img_tensor = transform(img)
|
583 |
"\n",
|
584 |
-
"# Visualize feature maps\n",
|
585 |
"visualize_channels(model, img_tensor, max_channels=16)\n"
|
586 |
]
|
587 |
},
|
@@ -595,15 +575,13 @@
|
|
595 |
"\n",
|
596 |
"img = Image.open(\"dataset/Strawberry_512/Sliced/image_0001.jpg\").convert(\"RGB\")\n",
|
597 |
"\n",
|
598 |
-
"# Preprocessing (must match model requirements)\n",
|
599 |
"transform = transforms.Compose([\n",
|
600 |
" transforms.Resize((224, 224)),\n",
|
601 |
" transforms.ToTensor()\n",
|
602 |
"])\n",
|
603 |
-
"img_tensor = transform(img)
|
604 |
"\n",
|
605 |
-
"
|
606 |
-
"visualize_channels(model, img_tensor, max_channels=16)\n"
|
607 |
]
|
608 |
},
|
609 |
{
|
|
|
59 |
"def augment_rotations(X, y):\n",
|
60 |
" X_aug = []\n",
|
61 |
" y_aug = []\n",
|
62 |
+
" for k in [1, 2, 3]: \n",
|
63 |
+
" X_rot = torch.rot90(X, k=k, dims=[2, 3]) \n",
|
64 |
" X_aug.append(X_rot)\n",
|
65 |
+
" y_aug.append(y.clone()) \n",
|
66 |
" return torch.cat(X_aug), torch.cat(y_aug)"
|
67 |
]
|
68 |
},
|
|
|
124 |
" plt.suptitle(f\"{class_name.capitalize()} – Random {count} Samples\", fontsize=16)\n",
|
125 |
" plt.show()\n",
|
126 |
"\n",
|
|
|
127 |
"for class_name, image_array in datasets.items():\n",
|
128 |
" show_random_samples(image_array, class_name)\n"
|
129 |
]
|
|
|
139 |
"\n",
|
140 |
"for ax, (class_name, images) in zip(axes, datasets.items()):\n",
|
141 |
" plot_rgb_histogram_subplot(ax, images, class_name)\n",
|
142 |
+
" ax.label_outer() \n",
|
143 |
"\n",
|
144 |
"plt.tight_layout()\n",
|
145 |
"plt.show()"
|
|
|
155 |
"class_names = list(datasets.keys())\n",
|
156 |
"num_classes = len(class_names)\n",
|
157 |
"\n",
|
158 |
+
"fig, axes = plt.subplots(1, num_classes, figsize=(4 * num_classes, 4)) \n",
|
159 |
"\n",
|
160 |
"for i, (class_name, images) in enumerate(datasets.items()):\n",
|
161 |
" avg_img = np.mean(images.astype(np.float32), axis=0)\n",
|
|
|
180 |
" \"whole\": strawberry_whole_images\n",
|
181 |
"}\n",
|
182 |
"\n",
|
|
|
183 |
"X = np.concatenate([strawberry_hulled_images, strawberry_sliced_images, strawberry_whole_images], axis=0)\n",
|
184 |
"y = (\n",
|
185 |
" ['hulled'] * len(strawberry_hulled_images) +\n",
|
|
|
187 |
" ['whole'] * len(strawberry_whole_images)\n",
|
188 |
")\n",
|
189 |
"\n",
|
|
|
190 |
"X = X.astype(np.float32) / 255.0\n",
|
191 |
+
"X = np.transpose(X, (0, 3, 1, 2)) \n",
|
192 |
"X_tensor = torch.tensor(X)\n",
|
193 |
"\n",
|
|
|
194 |
"le = LabelEncoder()\n",
|
195 |
"y_encoded = le.fit_transform(y)\n",
|
196 |
"y_tensor = torch.tensor(y_encoded)\n",
|
197 |
"\n",
|
|
|
198 |
"X_train, X_temp, y_train, y_temp = train_test_split(X_tensor, y_tensor, test_size=0.5, stratify=y_tensor, random_state=42)\n",
|
199 |
"X_val, X_test, y_val, y_test = train_test_split(X_temp, y_temp, test_size=0.5, stratify=y_temp, random_state=42)\n"
|
200 |
]
|
|
|
210 |
"\n",
|
211 |
"X_augmented, y_augmented = augment_rotations(X_train, y_train)\n",
|
212 |
"\n",
|
|
|
213 |
"X_train_combined = torch.cat([X_train, X_augmented])\n",
|
214 |
"y_train_combined = torch.cat([y_train, y_augmented])\n",
|
215 |
"\n",
|
|
|
216 |
"train_dataset = TensorDataset(X_train_combined, y_train_combined)\n",
|
217 |
"val_dataset = TensorDataset(X_val, y_val)\n",
|
218 |
"test_dataset = TensorDataset(X_test, y_test)\n",
|
|
|
229 |
"metadata": {},
|
230 |
"outputs": [],
|
231 |
"source": [
|
232 |
+
"print(f\"Train Dataset: {len(train_dataset)} samples, {len(train_loader)} batches\")\n",
|
233 |
+
"print(f\"Val Dataset: {len(val_dataset)} samples, {len(val_loader)} batches\")\n",
|
234 |
+
"print(f\"Test Dataset: {len(test_dataset)} samples, {len(test_loader)} batches\")"
|
235 |
]
|
236 |
},
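A small self-contained sketch (dummy sizes, not the notebook's data) of how those printed counts relate: with a TensorDataset and a DataLoader, the number of batches is ceil(len(dataset) / batch_size).

import math
import torch
from torch.utils.data import TensorDataset, DataLoader

dataset = TensorDataset(torch.rand(100, 3, 64, 64), torch.randint(0, 3, (100,)))
loader = DataLoader(dataset, batch_size=32, shuffle=True)
print(f"Dataset: {len(dataset)} samples, {len(loader)} batches")   # 100 samples, 4 batches
assert len(loader) == math.ceil(len(dataset) / 32)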
{

" val_accuracy = val_correct / val_total\n",
" validation_loss = criterion(model(val_x), val_y).item()\n",
"\n",
" val_losses.append(validation_loss)\n",
" val_accs.append(val_accuracy)\n",
"\n",
"source": [
|
442 |
"all_preds = np.array(all_preds)\n",
|
443 |
"all_targets = np.array(all_targets)\n",
|
444 |
+
"all_images = torch.stack(all_images) \n",
|
445 |
"\n",
|
|
|
446 |
"for class_idx, class_name in enumerate(target_names):\n",
|
447 |
+
" print(f\"\\nShowing False Negatives and False Positives for class: {class_name}\")\n",
|
|
|
|
|
448 |
" fn_indices = np.where((all_targets == class_idx) & (all_preds != class_idx))[0]\n",
|
|
|
449 |
" fp_indices = np.where((all_preds == class_idx) & (all_targets != class_idx))[0]\n",
|
450 |
"\n",
|
451 |
" def show_images(indices, title, max_images=5):\n",
|
452 |
" num = min(len(indices), max_images)\n",
|
453 |
" if num == 0:\n",
|
454 |
+
" print(f\"No {title} samples.\")\n",
|
455 |
" return\n",
|
456 |
"\n",
|
457 |
" plt.figure(figsize=(12, 2))\n",
|
458 |
" for i, idx in enumerate(indices[:num]):\n",
|
459 |
" img = all_images[idx]\n",
|
460 |
+
" img = img.permute(1, 2, 0).numpy()\n",
|
461 |
" plt.subplot(1, num, i + 1)\n",
|
462 |
+
" plt.imshow((img - img.min()) / (img.max() - img.min()))\n",
|
463 |
" plt.axis('off')\n",
|
464 |
" plt.title(f\"Pred: {target_names[all_preds[idx]]}\\nTrue: {target_names[all_targets[idx]]}\")\n",
|
465 |
" plt.suptitle(f\"{title} for {class_name}\")\n",
|
|
|
486 |
" activations[name] = output.detach().cpu()\n",
|
487 |
" return hook\n",
|
488 |
"\n",
|
|
|
489 |
" hooks = []\n",
|
490 |
" for i in range(len(model.features)):\n",
|
491 |
" layer = model.features[i]\n",
|
492 |
" hooks.append(layer.register_forward_hook(get_activation(f\"features_{i}\")))\n",
|
493 |
"\n",
|
494 |
" with torch.no_grad():\n",
|
495 |
+
" _ = model(image_tensor.unsqueeze(0)) \n",
|
496 |
"\n",
|
497 |
" for h in hooks:\n",
|
498 |
" h.remove()\n",
|
499 |
"\n",
|
500 |
" for layer_name, fmap in activations.items():\n",
|
501 |
+
" fmap = fmap.squeeze(0) \n",
|
502 |
"\n",
|
503 |
+
" channel_scores = fmap.mean(dim=(1, 2))\n",
|
|
|
504 |
"\n",
|
|
|
505 |
" topk = torch.topk(channel_scores, k=min(max_channels, fmap.shape[0]))\n",
|
506 |
" top_indices = topk.indices\n",
|
507 |
"\n",
|
|
|
508 |
" plt.figure(figsize=(max_channels * 2, 2.5))\n",
|
509 |
" for idx, ch in enumerate(top_indices):\n",
|
510 |
" plt.subplot(1, max_channels, idx + 1)\n",
|
|
|
537 |
"\n",
|
538 |
"img = Image.open(\"dataset/Strawberry_512/Whole/image_0017.jpg\").convert(\"RGB\")\n",
|
539 |
"\n",
|
|
|
540 |
"transform = transforms.Compose([\n",
|
541 |
" transforms.Resize((224, 224)),\n",
|
542 |
" transforms.ToTensor()\n",
|
543 |
"])\n",
|
544 |
+
"img_tensor = transform(img) \n",
|
545 |
"\n",
|
|
|
546 |
"visualize_channels(model, img_tensor, max_channels=16)\n"
|
547 |
]
|
548 |
},
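A minimal standalone sketch of the forward-hook pattern visualize_channels relies on: a single hooked block and a random input, purely for illustration; the printed shape is indicative, not taken from the notebook.

import torch
import torchvision.models as models

model = models.efficientnet_b0(weights=None).eval()
activations = {}

def get_activation(name):
    def hook(module, inputs, output):
        activations[name] = output.detach().cpu()
    return hook

handle = model.features[2].register_forward_hook(get_activation("features_2"))
with torch.no_grad():
    _ = model(torch.rand(1, 3, 224, 224))   # stand-in for the transformed strawberry image
handle.remove()
print(activations["features_2"].shape)       # e.g. torch.Size([1, 24, 56, 56])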
"\n",
"img = Image.open(\"dataset/Strawberry_512/Hulled/image_0001.jpg\").convert(\"RGB\")\n",
"\n",
"transform = transforms.Compose([\n",
" transforms.Resize((224, 224)),\n",
" transforms.ToTensor()\n",
"])\n",
+ "img_tensor = transform(img) \n",
"\n",
"visualize_channels(model, img_tensor, max_channels=16)\n"
]
},
"\n",
"img = Image.open(\"dataset/Strawberry_512/Sliced/image_0001.jpg\").convert(\"RGB\")\n",
"\n",
"transform = transforms.Compose([\n",
" transforms.Resize((224, 224)),\n",
" transforms.ToTensor()\n",
"])\n",
+ "img_tensor = transform(img) \n",
"\n",
+ "visualize_channels(model, img_tensor, max_channels=16)"
]
},
{
scripts/CV/script_tomato.ipynb
CHANGED
@@ -13,7 +13,6 @@
"import matplotlib.pyplot as plt\n",
"import random\n",
"import torch\n",
- "import numpy as np\n",
"from torch.utils.data import Dataset, DataLoader, TensorDataset\n",
"from sklearn.preprocessing import LabelEncoder\n",
"from sklearn.model_selection import train_test_split\n",
@@ -59,10 +58,10 @@
"def augment_rotations(X, y):\n",
" X_aug = []\n",
" y_aug = []\n",
- " for k in [1, 2, 3]:
- " X_rot = torch.rot90(X, k=k, dims=[2, 3])
" X_aug.append(X_rot)\n",
- " y_aug.append(y.clone())
" return torch.cat(X_aug), torch.cat(y_aug)"
]
},
@@ -103,8 +102,7 @@
"metadata": {},
"outputs": [],
"source": [
- "
- "import random\n",
"datasets = {\n",
" \"diced\": tomato_diced_images,\n",
" \"vines\": tomato_vines_images,\n",
@@ -124,7 +122,6 @@
" plt.suptitle(f\"{class_name.capitalize()} – Random {count} Samples\", fontsize=16)\n",
" plt.show()\n",
"\n",
- "# Display for each class\n",
"for class_name, image_array in datasets.items():\n",
" show_random_samples(image_array, class_name)\n"
]
@@ -140,7 +137,7 @@
"\n",
"for ax, (class_name, images) in zip(axes, datasets.items()):\n",
" plot_rgb_histogram_subplot(ax, images, class_name)\n",
- " ax.label_outer()
"\n",
"plt.tight_layout()\n",
"plt.show()"
@@ -156,7 +153,7 @@
"class_names = list(datasets.keys())\n",
"num_classes = len(class_names)\n",
"\n",
- "fig, axes = plt.subplots(1, num_classes, figsize=(4 * num_classes, 4))
"\n",
"for i, (class_name, images) in enumerate(datasets.items()):\n",
" avg_img = np.mean(images.astype(np.float32), axis=0)\n",
@@ -175,20 +172,12 @@
"metadata": {},
"outputs": [],
"source": [
- "import torch\n",
- "import numpy as np\n",
- "from torch.utils.data import Dataset, DataLoader, TensorDataset\n",
- "from sklearn.preprocessing import LabelEncoder\n",
- "from sklearn.model_selection import train_test_split\n",
- "from torchvision import transforms\n",
- "\n",
"datasets = {\n",
" \"diced\": tomato_diced_images,\n",
" \"vines\": tomato_vines_images,\n",
" \"whole\": tomato_whole_images\n",
"}\n",
"\n",
- "# Combine data\n",
"X = np.concatenate([tomato_diced_images, tomato_vines_images, tomato_whole_images], axis=0)\n",
"y = (\n",
" ['diced'] * len(tomato_diced_images) +\n",
@@ -196,17 +185,14 @@
" ['whole'] * len(tomato_whole_images)\n",
")\n",
"\n",
- "# Normalize and convert to torch tensors\n",
"X = X.astype(np.float32) / 255.0\n",
- "X = np.transpose(X, (0, 3, 1, 2))
"X_tensor = torch.tensor(X)\n",
"\n",
- "# Encode labels\n",
"le = LabelEncoder()\n",
"y_encoded = le.fit_transform(y)\n",
"y_tensor = torch.tensor(y_encoded)\n",
"\n",
- "# Train/val/test split\n",
"X_train, X_temp, y_train, y_temp = train_test_split(X_tensor, y_tensor, test_size=0.4, stratify=y_tensor, random_state=42)\n",
"X_val, X_test, y_val, y_test = train_test_split(X_temp, y_temp, test_size=0.5, stratify=y_temp, random_state=42)\n"
]
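A quick sketch with dummy arrays (not the notebook's images) of what the two-step split above yields: test_size=0.4 followed by test_size=0.5 gives roughly a 60/20/20 train/val/test split, stratified by label.

import numpy as np
from sklearn.model_selection import train_test_split

X = np.arange(100).reshape(-1, 1)
y = np.array([0, 1, 2, 3] * 25)
X_train, X_tmp, y_train, y_tmp = train_test_split(X, y, test_size=0.4, stratify=y, random_state=42)
X_val, X_test, y_val, y_test = train_test_split(X_tmp, y_tmp, test_size=0.5, stratify=y_tmp, random_state=42)
print(len(X_train), len(X_val), len(X_test))   # 60 20 20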
@@ -222,17 +208,13 @@
"\n",
"X_augmented, y_augmented = augment_rotations(X_train, y_train)\n",
"\n",
- "# Combine original and augmented data\n",
"X_train_combined = torch.cat([X_train, X_augmented])\n",
"y_train_combined = torch.cat([y_train, y_augmented])\n",
"\n",
- "# Create new training dataset and loader\n",
- "\n",
"train_dataset = TensorDataset(X_train, y_train)\n",
"val_dataset = TensorDataset(X_val, y_val)\n",
"test_dataset = TensorDataset(X_test, y_test)\n",
"\n",
- "# DataLoaders\n",
"\n",
"train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)\n",
"val_loader = DataLoader(val_dataset, batch_size=batch_size)\n",
@@ -246,9 +228,9 @@
"metadata": {},
"outputs": [],
"source": [
- "print(f\"
- "print(f\"
- "print(f\"
]
},
{
@@ -258,18 +240,9 @@
"metadata": {},
"outputs": [],
"source": [
- "import torch.nn as nn\n",
- "import torch.nn.functional as F\n",
- "\n",
- "import torch.nn as nn\n",
- "import torchvision.models as models\n",
- "\n",
"def get_efficientnet_model(num_classes):\n",
" model = models.efficientnet_b0(weights=models.EfficientNet_B0_Weights.DEFAULT)\n",
- "\n",
- " # Replace classifier head with custom head\n",
" model.classifier[1] = nn.Linear(model.classifier[1].in_features, num_classes)\n",
- "\n",
" return model\n",
"\n"
]
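A minimal sketch of the head swap get_efficientnet_model performs: keep the pretrained EfficientNet-B0 features and replace only the final Linear layer with a 3-way classifier (the class count matches the tomato classes; everything else is illustrative).

import torch.nn as nn
import torchvision.models as models

model = models.efficientnet_b0(weights=models.EfficientNet_B0_Weights.DEFAULT)
in_features = model.classifier[1].in_features       # 1280 for EfficientNet-B0
model.classifier[1] = nn.Linear(in_features, 3)     # diced / vines / whole
print(model.classifier)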
@@ -283,10 +256,10 @@
"source": [
"if torch.backends.mps.is_available():\n",
" device = torch.device(\"mps\")\n",
- " print(\"
"else:\n",
" device = torch.device(\"cpu\")\n",
- " print(\"
"\n",
"model = get_efficientnet_model(num_classes=3).to(device)\n",
"optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)\n",
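A standalone sketch of the device selection above; the CUDA branch is an added assumption and is not in the notebook, which only checks for MPS.

import torch

if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")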
@@ -329,7 +302,6 @@
"\n",
" total_train_loss += loss.item()\n",
"\n",
- " # Track training accuracy\n",
" pred_labels = preds.argmax(dim=1)\n",
" train_correct += (pred_labels == batch_y).sum().item()\n",
" train_total += batch_y.size(0)\n",
@@ -353,7 +325,6 @@
" val_accuracy = val_correct / val_total\n",
" validation_loss = criterion(model(val_x), val_y).item()\n",
"\n",
- " # After calculating val_accuracy\n",
" val_losses.append(validation_loss)\n",
" val_accs.append(val_accuracy)\n",
"\n",
@@ -381,13 +352,12 @@
"metadata": {},
"outputs": [],
"source": [
- "
"\n",
"epochs = range(1, len(train_losses) + 1)\n",
"\n",
"plt.figure(figsize=(12, 5))\n",
"\n",
- "# Plot Loss\n",
"plt.subplot(1, 2, 1)\n",
"plt.plot(epochs, train_losses, label='Train Loss', marker='o')\n",
"plt.plot(epochs, val_losses, label='Validation Loss', marker='s')\n",
@@ -397,7 +367,6 @@
"plt.legend()\n",
"plt.grid(True)\n",
"\n",
- "# Plot Accuracy\n",
"plt.subplot(1, 2, 2)\n",
"plt.plot(epochs, train_accs, label='Train Accuracy', marker='o')\n",
"plt.plot(epochs, val_accs, label='Validation Accuracy', marker='s')\n",
@@ -462,35 +431,28 @@
"metadata": {},
"outputs": [],
"source": [
- "import torch\n",
- "import numpy as np\n",
- "import matplotlib.pyplot as plt\n",
"\n",
"all_preds = np.array(all_preds)\n",
"all_targets = np.array(all_targets)\n",
- "all_images = torch.stack(all_images)
"\n",
- "# Per class FP and FN\n",
"for class_idx, class_name in enumerate(target_names):\n",
- " print(f\"\\
- "\n",
- " # False Negatives: True label is class_idx, but predicted something else\n",
" fn_indices = np.where((all_targets == class_idx) & (all_preds != class_idx))[0]\n",
- " # False Positives: Predicted class_idx, but true label is different\n",
" fp_indices = np.where((all_preds == class_idx) & (all_targets != class_idx))[0]\n",
"\n",
" def show_images(indices, title, max_images=5):\n",
" num = min(len(indices), max_images)\n",
" if num == 0:\n",
- " print(f\"
" return\n",
"\n",
" plt.figure(figsize=(12, 2))\n",
" for i, idx in enumerate(indices[:num]):\n",
" img = all_images[idx]\n",
- " img = img.permute(1, 2, 0).numpy()
" plt.subplot(1, num, i + 1)\n",
- " plt.imshow((img - img.min()) / (img.max() - img.min()))
" plt.axis('off')\n",
" plt.title(f\"Pred: {target_names[all_preds[idx]]}\\nTrue: {target_names[all_targets[idx]]}\")\n",
" plt.suptitle(f\"{title} for {class_name}\")\n",
@@ -517,29 +479,25 @@
" activations[name] = output.detach().cpu()\n",
" return hook\n",
"\n",
- " # Register hooks for all layers in model.features\n",
" hooks = []\n",
" for i in range(len(model.features)):\n",
" layer = model.features[i]\n",
" hooks.append(layer.register_forward_hook(get_activation(f\"features_{i}\")))\n",
"\n",
" with torch.no_grad():\n",
- " _ = model(image_tensor.unsqueeze(0))
"\n",
" for h in hooks:\n",
" h.remove()\n",
"\n",
" for layer_name, fmap in activations.items():\n",
- " fmap = fmap.squeeze(0)
"\n",
- "
- " channel_scores = fmap.mean(dim=(1, 2)) # [C]\n",
"\n",
- " # Get indices of top-k channels\n",
" topk = torch.topk(channel_scores, k=min(max_channels, fmap.shape[0]))\n",
" top_indices = topk.indices\n",
"\n",
- " # Plot top-k channels\n",
" plt.figure(figsize=(max_channels * 2, 2.5))\n",
" for idx, ch in enumerate(top_indices):\n",
" plt.subplot(1, max_channels, idx + 1)\n",
@@ -572,14 +530,12 @@
"\n",
"img = Image.open(\"dataset/Tomato_512/Whole/image_0007.jpg\").convert(\"RGB\")\n",
"\n",
- "# Preprocessing (must match model requirements)\n",
"transform = transforms.Compose([\n",
" transforms.Resize((224, 224)),\n",
" transforms.ToTensor()\n",
"])\n",
- "img_tensor = transform(img)
"\n",
- "# Visualize feature maps\n",
"visualize_channels(model, img_tensor, max_channels=16)\n"
]
},
@@ -592,14 +548,12 @@
"source": [
"img = Image.open(\"dataset/Tomato_512/On_the_vines/image_0578.jpg\").convert(\"RGB\")\n",
"\n",
- "# Preprocessing (must match model requirements)\n",
"transform = transforms.Compose([\n",
" transforms.Resize((224, 224)),\n",
" transforms.ToTensor()\n",
"])\n",
- "img_tensor = transform(img)
"\n",
- "# Visualize feature maps\n",
"visualize_channels(model, img_tensor, max_channels=16)\n"
]
},
@@ -612,14 +566,12 @@
"source": [
"img = Image.open(\"dataset/Tomato_512/Diced/image_0578.jpg\").convert(\"RGB\")\n",
"\n",
- "# Preprocessing (must match model requirements)\n",
"transform = transforms.Compose([\n",
" transforms.Resize((224, 224)),\n",
" transforms.ToTensor()\n",
"])\n",
- "img_tensor = transform(img)
"\n",
- "# Visualize feature maps\n",
"visualize_channels(model, img_tensor, max_channels=16)\n"
]
},
|
|
|
13 |
"import matplotlib.pyplot as plt\n",
|
14 |
"import random\n",
|
15 |
"import torch\n",
|
|
|
16 |
"from torch.utils.data import Dataset, DataLoader, TensorDataset\n",
|
17 |
"from sklearn.preprocessing import LabelEncoder\n",
|
18 |
"from sklearn.model_selection import train_test_split\n",
|
|
|
58 |
"def augment_rotations(X, y):\n",
|
59 |
" X_aug = []\n",
|
60 |
" y_aug = []\n",
|
61 |
+
" for k in [1, 2, 3]: \n",
|
62 |
+
" X_rot = torch.rot90(X, k=k, dims=[2, 3]) \n",
|
63 |
" X_aug.append(X_rot)\n",
|
64 |
+
" y_aug.append(y.clone()) \n",
|
65 |
" return torch.cat(X_aug), torch.cat(y_aug)"
|
66 |
]
|
67 |
},
|
|
|
102 |
"metadata": {},
|
103 |
"outputs": [],
|
104 |
"source": [
|
105 |
+
"\n",
|
|
|
106 |
"datasets = {\n",
|
107 |
" \"diced\": tomato_diced_images,\n",
|
108 |
" \"vines\": tomato_vines_images,\n",
|
|
|
122 |
" plt.suptitle(f\"{class_name.capitalize()} – Random {count} Samples\", fontsize=16)\n",
|
123 |
" plt.show()\n",
|
124 |
"\n",
|
|
|
125 |
"for class_name, image_array in datasets.items():\n",
|
126 |
" show_random_samples(image_array, class_name)\n"
|
127 |
]
|
|
|
137 |
"\n",
|
138 |
"for ax, (class_name, images) in zip(axes, datasets.items()):\n",
|
139 |
" plot_rgb_histogram_subplot(ax, images, class_name)\n",
|
140 |
+
" ax.label_outer() \n",
|
141 |
"\n",
|
142 |
"plt.tight_layout()\n",
|
143 |
"plt.show()"
|
|
|
153 |
"class_names = list(datasets.keys())\n",
|
154 |
"num_classes = len(class_names)\n",
|
155 |
"\n",
|
156 |
+
"fig, axes = plt.subplots(1, num_classes, figsize=(4 * num_classes, 4)) \n",
|
157 |
"\n",
|
158 |
"for i, (class_name, images) in enumerate(datasets.items()):\n",
|
159 |
" avg_img = np.mean(images.astype(np.float32), axis=0)\n",
|
|
|
172 |
"metadata": {},
|
173 |
"outputs": [],
|
174 |
"source": [
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
175 |
"datasets = {\n",
|
176 |
" \"diced\": tomato_diced_images,\n",
|
177 |
" \"vines\": tomato_vines_images,\n",
|
178 |
" \"whole\": tomato_whole_images\n",
|
179 |
"}\n",
|
180 |
"\n",
|
|
|
181 |
"X = np.concatenate([tomato_diced_images, tomato_vines_images, tomato_whole_images], axis=0)\n",
|
182 |
"y = (\n",
|
183 |
" ['diced'] * len(tomato_diced_images) +\n",
|
|
|
185 |
" ['whole'] * len(tomato_whole_images)\n",
|
186 |
")\n",
|
187 |
"\n",
|
|
|
188 |
"X = X.astype(np.float32) / 255.0\n",
|
189 |
+
"X = np.transpose(X, (0, 3, 1, 2))\n",
|
190 |
"X_tensor = torch.tensor(X)\n",
|
191 |
"\n",
|
|
|
192 |
"le = LabelEncoder()\n",
|
193 |
"y_encoded = le.fit_transform(y)\n",
|
194 |
"y_tensor = torch.tensor(y_encoded)\n",
|
195 |
"\n",
|
|
|
196 |
"X_train, X_temp, y_train, y_temp = train_test_split(X_tensor, y_tensor, test_size=0.4, stratify=y_tensor, random_state=42)\n",
|
197 |
"X_val, X_test, y_val, y_test = train_test_split(X_temp, y_temp, test_size=0.5, stratify=y_temp, random_state=42)\n"
|
198 |
]
|
|
|
208 |
"\n",
|
209 |
"X_augmented, y_augmented = augment_rotations(X_train, y_train)\n",
|
210 |
"\n",
|
|
|
211 |
"X_train_combined = torch.cat([X_train, X_augmented])\n",
|
212 |
"y_train_combined = torch.cat([y_train, y_augmented])\n",
|
213 |
"\n",
|
|
|
|
|
214 |
"train_dataset = TensorDataset(X_train, y_train)\n",
|
215 |
"val_dataset = TensorDataset(X_val, y_val)\n",
|
216 |
"test_dataset = TensorDataset(X_test, y_test)\n",
|
217 |
"\n",
|
|
|
218 |
"\n",
|
219 |
"train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)\n",
|
220 |
"val_loader = DataLoader(val_dataset, batch_size=batch_size)\n",
|
|
|
228 |
"metadata": {},
|
229 |
"outputs": [],
|
230 |
"source": [
|
231 |
+
"print(f\"Train Dataset: {len(train_dataset)} samples, {len(train_loader)} batches\")\n",
|
232 |
+
"print(f\"Val Dataset: {len(val_dataset)} samples, {len(val_loader)} batches\")\n",
|
233 |
+
"print(f\"Test Dataset: {len(test_dataset)} samples, {len(test_loader)} batches\")"
|
234 |
]
|
235 |
},
|
236 |
{
|
|
|
240 |
"metadata": {},
|
241 |
"outputs": [],
|
242 |
"source": [
|
|
|
|
|
|
|
|
|
|
|
|
|
243 |
"def get_efficientnet_model(num_classes):\n",
|
244 |
" model = models.efficientnet_b0(weights=models.EfficientNet_B0_Weights.DEFAULT)\n",
|
|
|
|
|
245 |
" model.classifier[1] = nn.Linear(model.classifier[1].in_features, num_classes)\n",
|
|
|
246 |
" return model\n",
|
247 |
"\n"
|
248 |
]
|
|
|
256 |
"source": [
|
257 |
"if torch.backends.mps.is_available():\n",
|
258 |
" device = torch.device(\"mps\")\n",
|
259 |
+
" print(\"Using MPS (Apple GPU)\")\n",
|
260 |
"else:\n",
|
261 |
" device = torch.device(\"cpu\")\n",
|
262 |
+
" print(\"MPS not available. Using CPU\")\n",
|
263 |
"\n",
|
264 |
"model = get_efficientnet_model(num_classes=3).to(device)\n",
|
265 |
"optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)\n",
|
|
|
302 |
"\n",
|
303 |
" total_train_loss += loss.item()\n",
|
304 |
"\n",
|
|
|
305 |
" pred_labels = preds.argmax(dim=1)\n",
|
306 |
" train_correct += (pred_labels == batch_y).sum().item()\n",
|
307 |
" train_total += batch_y.size(0)\n",
|
|
|
325 |
" val_accuracy = val_correct / val_total\n",
|
326 |
" validation_loss = criterion(model(val_x), val_y).item()\n",
|
327 |
"\n",
|
|
|
328 |
" val_losses.append(validation_loss)\n",
|
329 |
" val_accs.append(val_accuracy)\n",
|
330 |
"\n",
|
|
|
352 |
"metadata": {},
|
353 |
"outputs": [],
|
354 |
"source": [
|
355 |
+
"\n",
|
356 |
"\n",
|
357 |
"epochs = range(1, len(train_losses) + 1)\n",
|
358 |
"\n",
|
359 |
"plt.figure(figsize=(12, 5))\n",
|
360 |
"\n",
|
|
|
361 |
"plt.subplot(1, 2, 1)\n",
|
362 |
"plt.plot(epochs, train_losses, label='Train Loss', marker='o')\n",
|
363 |
"plt.plot(epochs, val_losses, label='Validation Loss', marker='s')\n",
|
|
|
367 |
"plt.legend()\n",
|
368 |
"plt.grid(True)\n",
|
369 |
"\n",
|
|
|
370 |
"plt.subplot(1, 2, 2)\n",
|
371 |
"plt.plot(epochs, train_accs, label='Train Accuracy', marker='o')\n",
|
372 |
"plt.plot(epochs, val_accs, label='Validation Accuracy', marker='s')\n",
|
|
|
431 |
"metadata": {},
|
432 |
"outputs": [],
|
433 |
"source": [
|
|
|
|
|
|
|
434 |
"\n",
|
435 |
"all_preds = np.array(all_preds)\n",
|
436 |
"all_targets = np.array(all_targets)\n",
|
437 |
+
"all_images = torch.stack(all_images) \n",
|
438 |
"\n",
|
|
|
439 |
"for class_idx, class_name in enumerate(target_names):\n",
|
440 |
+
" print(f\"\\nShowing False Negatives and False Positives for class: {class_name}\")\n",
|
|
|
|
|
441 |
" fn_indices = np.where((all_targets == class_idx) & (all_preds != class_idx))[0]\n",
|
|
|
442 |
" fp_indices = np.where((all_preds == class_idx) & (all_targets != class_idx))[0]\n",
|
443 |
"\n",
|
444 |
" def show_images(indices, title, max_images=5):\n",
|
445 |
" num = min(len(indices), max_images)\n",
|
446 |
" if num == 0:\n",
|
447 |
+
" print(f\"No {title} samples.\")\n",
|
448 |
" return\n",
|
449 |
"\n",
|
450 |
" plt.figure(figsize=(12, 2))\n",
|
451 |
" for i, idx in enumerate(indices[:num]):\n",
|
452 |
" img = all_images[idx]\n",
|
453 |
+
" img = img.permute(1, 2, 0).numpy()\n",
|
454 |
" plt.subplot(1, num, i + 1)\n",
|
455 |
+
" plt.imshow((img - img.min()) / (img.max() - img.min())) \n",
|
456 |
" plt.axis('off')\n",
|
457 |
" plt.title(f\"Pred: {target_names[all_preds[idx]]}\\nTrue: {target_names[all_targets[idx]]}\")\n",
|
458 |
" plt.suptitle(f\"{title} for {class_name}\")\n",
|
|
|
479 |
" activations[name] = output.detach().cpu()\n",
|
480 |
" return hook\n",
|
481 |
"\n",
|
|
|
482 |
" hooks = []\n",
|
483 |
" for i in range(len(model.features)):\n",
|
484 |
" layer = model.features[i]\n",
|
485 |
" hooks.append(layer.register_forward_hook(get_activation(f\"features_{i}\")))\n",
|
486 |
"\n",
|
487 |
" with torch.no_grad():\n",
|
488 |
+
" _ = model(image_tensor.unsqueeze(0)) \n",
|
489 |
"\n",
|
490 |
" for h in hooks:\n",
|
491 |
" h.remove()\n",
|
492 |
"\n",
|
493 |
" for layer_name, fmap in activations.items():\n",
|
494 |
+
" fmap = fmap.squeeze(0) \n",
|
495 |
"\n",
|
496 |
+
" channel_scores = fmap.mean(dim=(1, 2)) \n",
|
|
|
497 |
"\n",
|
|
|
498 |
" topk = torch.topk(channel_scores, k=min(max_channels, fmap.shape[0]))\n",
|
499 |
" top_indices = topk.indices\n",
|
500 |
"\n",
|
|
|
501 |
" plt.figure(figsize=(max_channels * 2, 2.5))\n",
|
502 |
" for idx, ch in enumerate(top_indices):\n",
|
503 |
" plt.subplot(1, max_channels, idx + 1)\n",
|
|
|
530 |
"\n",
|
531 |
"img = Image.open(\"dataset/Tomato_512/Whole/image_0007.jpg\").convert(\"RGB\")\n",
|
532 |
"\n",
|
|
|
533 |
"transform = transforms.Compose([\n",
|
534 |
" transforms.Resize((224, 224)),\n",
|
535 |
" transforms.ToTensor()\n",
|
536 |
"])\n",
|
537 |
+
"img_tensor = transform(img) \n",
|
538 |
"\n",
|
|
|
539 |
"visualize_channels(model, img_tensor, max_channels=16)\n"
|
540 |
]
|
541 |
},
|
|
|
548 |
"source": [
|
549 |
"img = Image.open(\"dataset/Tomato_512/On_the_vines/image_0578.jpg\").convert(\"RGB\")\n",
|
550 |
"\n",
|
|
|
551 |
"transform = transforms.Compose([\n",
|
552 |
" transforms.Resize((224, 224)),\n",
|
553 |
" transforms.ToTensor()\n",
|
554 |
"])\n",
|
555 |
+
"img_tensor = transform(img) \n",
|
556 |
"\n",
|
|
|
557 |
"visualize_channels(model, img_tensor, max_channels=16)\n"
|
558 |
]
|
559 |
},
|
|
|
566 |
"source": [
|
567 |
"img = Image.open(\"dataset/Tomato_512/Diced/image_0578.jpg\").convert(\"RGB\")\n",
|
568 |
"\n",
|
|
|
569 |
"transform = transforms.Compose([\n",
|
570 |
" transforms.Resize((224, 224)),\n",
|
571 |
" transforms.ToTensor()\n",
|
572 |
"])\n",
|
573 |
+
"img_tensor = transform(img) \n",
|
574 |
"\n",
|
|
|
575 |
"visualize_channels(model, img_tensor, max_channels=16)\n"
|
576 |
]
|
577 |
},
|