{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "7d2f4c10",
   "metadata": {},
   "source": [
    "# Strawberry preparation classifier\n",
    "\n",
    "Fine-tunes an ImageNet-pretrained EfficientNet-B0 to classify strawberry images as **hulled**, **sliced** or **whole**, then inspects the misclassified test images and the intermediate feature maps of the best checkpoint."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1b70151d",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import random\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import seaborn as sns\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torchvision.models as models\n",
    "from PIL import Image\n",
    "from sklearn.metrics import classification_report, confusion_matrix\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.preprocessing import LabelEncoder\n",
    "from torch.utils.data import Dataset, DataLoader, TensorDataset\n",
    "from torchvision import transforms"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5eed0042",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seed every RNG the notebook draws from (sample pickers, data split,\n",
    "# shuffling, weight init) so a fresh Restart & Run All starts from the\n",
    "# same state. GPU kernels may still introduce small run-to-run differences.\n",
    "SEED = 42\n",
    "random.seed(SEED)\n",
    "np.random.seed(SEED)\n",
    "torch.manual_seed(SEED);"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b3e91a52",
   "metadata": {},
   "source": [
    "## Helpers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c37cc27b",
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_images_from_folder(folder_path, image_size=(224, 224)):\n",
    "    \"\"\"Load every JPEG under ``folder_path`` as a resized RGB array.\n",
    "\n",
    "    Directories and files are visited in sorted order so the returned array\n",
    "    (and therefore the seeded train/val/test split) does not depend on the\n",
    "    order in which the filesystem happens to list entries.\n",
    "\n",
    "    Args:\n",
    "        folder_path: Root directory, searched recursively.\n",
    "        image_size: ``(width, height)`` passed to ``PIL.Image.resize``.\n",
    "\n",
    "    Returns:\n",
    "        ``np.ndarray`` stacking one ``(height, width, 3)`` array per image.\n",
    "        Files that cannot be read are reported and skipped.\n",
    "    \"\"\"\n",
    "    images = []\n",
    "    for root, dirs, files in os.walk(folder_path):\n",
    "        dirs.sort()  # in-place sort fixes the order os.walk descends in\n",
    "        for file in sorted(files):\n",
    "            if not file.lower().endswith((\".jpg\", \".jpeg\")):\n",
    "                continue\n",
    "            img_path = os.path.join(root, file)\n",
    "            try:\n",
    "                # The context manager releases the file handle once decoded.\n",
    "                with Image.open(img_path) as img:\n",
    "                    resized = img.convert(\"RGB\").resize(image_size)\n",
    "                images.append(np.array(resized))\n",
    "            except Exception as e:\n",
    "                print(f\"Failed on {img_path}: {e}\")\n",
    "    return np.array(images)\n",
    "\n",
    "\n",
    "def plot_rgb_histogram_subplot(ax, images, class_name):\n",
    "    \"\"\"Draw the per-channel pixel histogram of one randomly chosen image.\n",
    "\n",
    "    Args:\n",
    "        ax: Matplotlib axes to draw on.\n",
    "        images: Sequence of images indexed as ``image[:, :, channel]``.\n",
    "        class_name: Label used in the subplot title.\n",
    "    \"\"\"\n",
    "    sample = images[random.randint(0, len(images) - 1)]\n",
    "    for channel, color in enumerate((\"r\", \"g\", \"b\")):\n",
    "        hist, _ = np.histogram(sample[:, :, channel], bins=256, range=(0, 256))\n",
    "        ax.plot(hist, color=color)\n",
    "    ax.set_title(f\"RGB Histogram – {class_name.capitalize()}\")\n",
    "    ax.set_xlabel(\"Pixel Value\")\n",
    "    ax.set_ylabel(\"Frequency\")\n",
    "\n",
    "\n",
    "def augment_rotations(X, y):\n",
    "    \"\"\"Return 90, 180 and 270 degree rotated copies of a batch.\n",
    "\n",
    "    Args:\n",
    "        X: Image tensor; the rotation is applied in the plane of dims 2 and 3.\n",
    "        y: Label tensor aligned with the first dimension of ``X``.\n",
    "\n",
    "    Returns:\n",
    "        ``(X_rotated, y_rotated)`` holding the three rotations concatenated\n",
    "        along dim 0. The unrotated originals are not included.\n",
    "    \"\"\"\n",
    "    rotated = [torch.rot90(X, k=k, dims=[2, 3]) for k in (1, 2, 3)]\n",
    "    return torch.cat(rotated), torch.cat([y.clone() for _ in rotated])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c4d80e63",
   "metadata": {},
   "source": [
    "## Load and explore the data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3f833049",
   "metadata": {},
   "outputs": [],
   "source": [
    "strawberry_hulled = \"dataset/Strawberry_512/Hulled\"\n",
    "strawberry_sliced = \"dataset/Strawberry_512/Sliced\"\n",
    "strawberry_whole = \"dataset/Strawberry_512/Whole\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1e913838",
   "metadata": {},
   "outputs": [],
   "source": [
    "strawberry_hulled_images = load_images_from_folder(strawberry_hulled)\n",
    "strawberry_sliced_images = load_images_from_folder(strawberry_sliced)\n",
    "strawberry_whole_images = load_images_from_folder(strawberry_whole)\n",
    "\n",
    "print(\"Strawberry hulled images:\", strawberry_hulled_images.shape)\n",
    "print(\"Strawberry sliced images:\", strawberry_sliced_images.shape)\n",
    "print(\"Strawberry whole images:\", strawberry_whole_images.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "00149f35",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Single source of truth for class name -> images; the keys double as labels.\n",
    "datasets = {\n",
    "    \"hulled\": strawberry_hulled_images,\n",
    "    \"sliced\": strawberry_sliced_images,\n",
    "    \"whole\": strawberry_whole_images,\n",
    "}\n",
    "\n",
    "\n",
    "def show_random_samples(images, class_name, count=5):\n",
    "    \"\"\"Show up to ``count`` randomly chosen images of one class in a row.\"\"\"\n",
    "    # random.sample raises if asked for more items than exist.\n",
    "    count = min(count, len(images))\n",
    "    indices = random.sample(range(len(images)), count)\n",
    "\n",
    "    fig, axes = plt.subplots(1, count, figsize=(10, 2), squeeze=False)\n",
    "    for ax, img in zip(axes[0], images[indices]):\n",
    "        ax.imshow(img.astype(np.uint8))\n",
    "        ax.axis(\"off\")\n",
    "    fig.suptitle(f\"{class_name.capitalize()} – Random {count} Samples\", fontsize=16)\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "for class_name, image_array in datasets.items():\n",
    "    show_random_samples(image_array, class_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a7c574a7",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, axes = plt.subplots(1, len(datasets), figsize=(20, 5))\n",
    "\n",
    "for ax, (class_name, images) in zip(axes, datasets.items()):\n",
    "    plot_rgb_histogram_subplot(ax, images, class_name)\n",
    "    ax.label_outer()\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a69bf6f5",
   "metadata": {},
   "outputs": [],
   "source": [
    "class_names = list(datasets.keys())\n",
    "num_classes = len(class_names)\n",
    "\n",
    "fig, axes = plt.subplots(1, num_classes, figsize=(4 * num_classes, 4))\n",
    "\n",
    "for ax, (class_name, images) in zip(axes, datasets.items()):\n",
    "    # Reduce directly over the image axis instead of first materialising a\n",
    "    # float copy of the whole class.\n",
    "    avg_img = images.mean(axis=0)\n",
    "    ax.imshow(avg_img.astype(np.uint8))\n",
    "    ax.set_title(f\"Average Image – {class_name.capitalize()}\")\n",
    "    ax.axis(\"off\")\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d5a7f174",
   "metadata": {},
   "source": [
    "## Train / validation / test split"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dec6064b",
   "metadata": {},
   "outputs": [],
   "source": [
    "X = np.concatenate(list(datasets.values()), axis=0)\n",
    "y = [name for name, images in datasets.items() for _ in range(len(images))]\n",
    "\n",
    "# Scale to [0, 1] and move channels first: (N, H, W, C) -> (N, C, H, W).\n",
    "X = X.astype(np.float32) / 255.0\n",
    "X = np.transpose(X, (0, 3, 1, 2))\n",
    "X_tensor = torch.tensor(X)\n",
    "\n",
    "le = LabelEncoder()\n",
    "y_encoded = le.fit_transform(y)\n",
    "# CrossEntropyLoss expects int64 class indices; do not rely on the platform\n",
    "# default integer width of the encoder output.\n",
    "y_tensor = torch.tensor(y_encoded, dtype=torch.long)\n",
    "\n",
    "# 50% train, 25% validation, 25% test, each stratified by class.\n",
    "X_train, X_temp, y_train, y_temp = train_test_split(\n",
    "    X_tensor, y_tensor, test_size=0.5, stratify=y_tensor, random_state=SEED\n",
    ")\n",
    "X_val, X_test, y_val, y_test = train_test_split(\n",
    "    X_temp, y_temp, test_size=0.5, stratify=y_temp, random_state=SEED\n",
    ")\n",
    "\n",
    "assert len(X_train) + len(X_val) + len(X_test) == len(X_tensor)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f265aea3",
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_size = 32\n",
    "\n",
    "# Only the training split is augmented; validation and test stay untouched.\n",
    "X_augmented, y_augmented = augment_rotations(X_train, y_train)\n",
    "\n",
    "X_train_combined = torch.cat([X_train, X_augmented])\n",
    "y_train_combined = torch.cat([y_train, y_augmented])\n",
    "\n",
    "train_dataset = TensorDataset(X_train_combined, y_train_combined)\n",
    "val_dataset = TensorDataset(X_val, y_val)\n",
    "test_dataset = TensorDataset(X_test, y_test)\n",
    "\n",
    "train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)\n",
    "val_loader = DataLoader(val_dataset, batch_size=batch_size)\n",
    "test_loader = DataLoader(test_dataset, batch_size=batch_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c469bc8d",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(f\"Train Dataset: {len(train_dataset)} samples, {len(train_loader)} batches\")\n",
    "print(f\"Val Dataset: {len(val_dataset)} samples, {len(val_loader)} batches\")\n",
    "print(f\"Test Dataset: {len(test_dataset)} samples, {len(test_loader)} batches\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e6b8a285",
   "metadata": {},
   "source": [
    "## Model and training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "02440bb8",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_efficientnet_model(num_classes, pretrained=True):\n",
    "    \"\"\"Build an EfficientNet-B0 whose final linear layer has ``num_classes`` outputs.\n",
    "\n",
    "    Args:\n",
    "        num_classes: Number of output classes of the replaced classifier head.\n",
    "        pretrained: Load the default torchvision weights when True. Pass False\n",
    "            when a checkpoint is loaded right afterwards, which avoids an\n",
    "            unnecessary weight download.\n",
    "\n",
    "    Returns:\n",
    "        The ``torch.nn.Module`` with its classifier head replaced.\n",
    "    \"\"\"\n",
    "    weights = models.EfficientNet_B0_Weights.DEFAULT if pretrained else None\n",
    "    model = models.efficientnet_b0(weights=weights)\n",
    "    model.classifier[1] = nn.Linear(model.classifier[1].in_features, num_classes)\n",
    "    return model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6a516f06",
   "metadata": {},
   "outputs": [],
   "source": [
    "if torch.cuda.is_available():\n",
    "    device = torch.device(\"cuda\")\n",
    "    print(\"Using CUDA GPU\")\n",
    "elif torch.backends.mps.is_available():\n",
    "    device = torch.device(\"mps\")\n",
    "    print(\"Using MPS (Apple GPU)\")\n",
    "else:\n",
    "    device = torch.device(\"cpu\")\n",
    "    print(\"No GPU backend available. Using CPU\")\n",
    "\n",
    "# Derive the head size from the encoder instead of hard-coding it.\n",
    "num_classes = len(le.classes_)\n",
    "model = get_efficientnet_model(num_classes=num_classes).to(device)\n",
    "criterion = nn.CrossEntropyLoss()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "245a6709",
   "metadata": {},
   "outputs": [],
   "source": [
    "optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)\n",
    "patience = 6\n",
    "model_name = \"models/best_model_strawberry_v1.pth\"\n",
    "# torch.save does not create missing parent directories.\n",
    "os.makedirs(os.path.dirname(model_name), exist_ok=True)\n",
    "\n",
    "# Start below any reachable accuracy so the first epoch always writes a\n",
    "# checkpoint for the evaluation cells to load.\n",
    "best_val_acc = -1.0\n",
    "epochs_no_improve = 0\n",
    "train_losses, val_losses = [], []\n",
    "train_accs, val_accs = [], []\n",
    "\n",
    "for epoch in range(30):\n",
    "    model.train()\n",
    "    total_train_loss = 0.0\n",
    "    train_correct = 0\n",
    "    train_total = 0\n",
    "\n",
    "    for batch_x, batch_y in train_loader:\n",
    "        batch_x, batch_y = batch_x.to(device), batch_y.to(device)\n",
    "        preds = model(batch_x)\n",
    "        loss = criterion(preds, batch_y)\n",
    "\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        # criterion returns the batch mean; weight it by the batch size so a\n",
    "        # smaller final batch does not skew the epoch average.\n",
    "        total_train_loss += loss.item() * batch_y.size(0)\n",
    "        train_correct += (preds.argmax(dim=1) == batch_y).sum().item()\n",
    "        train_total += batch_y.size(0)\n",
    "\n",
    "    avg_train_loss = total_train_loss / train_total\n",
    "    train_accuracy = train_correct / train_total\n",
    "    train_losses.append(avg_train_loss)\n",
    "    train_accs.append(train_accuracy)\n",
    "\n",
    "    model.eval()\n",
    "    total_val_loss = 0.0\n",
    "    val_correct = 0\n",
    "    val_total = 0\n",
    "\n",
    "    # The loss is accumulated over every validation batch (not just the last\n",
    "    # one) and reuses the logits of the accuracy pass, all without autograd.\n",
    "    with torch.no_grad():\n",
    "        for val_x, val_y in val_loader:\n",
    "            val_x, val_y = val_x.to(device), val_y.to(device)\n",
    "            val_logits = model(val_x)\n",
    "            total_val_loss += criterion(val_logits, val_y).item() * val_y.size(0)\n",
    "            val_correct += (val_logits.argmax(dim=1) == val_y).sum().item()\n",
    "            val_total += val_y.size(0)\n",
    "\n",
    "    validation_loss = total_val_loss / val_total\n",
    "    val_accuracy = val_correct / val_total\n",
    "    val_losses.append(validation_loss)\n",
    "    val_accs.append(val_accuracy)\n",
    "\n",
    "    print(f\"Epoch {epoch+1:02d} | Train Loss: {avg_train_loss:.4f} | \"\n",
    "          f\"Train Acc: {train_accuracy:.4f} | Val Loss: {validation_loss:.4f} | \"\n",
    "          f\"Val Acc: {val_accuracy:.4f}\")\n",
    "\n",
    "    if val_accuracy > best_val_acc:\n",
    "        best_val_acc = val_accuracy\n",
    "        epochs_no_improve = 0\n",
    "        torch.save(model.state_dict(), model_name)\n",
    "        print(f\"New best model saved at epoch {epoch+1} with val acc {val_accuracy:.4f}\")\n",
    "    else:\n",
    "        epochs_no_improve += 1\n",
    "        print(f\"No improvement for {epochs_no_improve} epoch(s)\")\n",
    "        if epochs_no_improve >= patience:\n",
    "            print(f\"Validation accuracy did not improve for {patience} consecutive epochs. Stopping early.\")\n",
    "            break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3bbab1d8",
   "metadata": {},
   "outputs": [],
   "source": [
    "epochs = range(1, len(train_losses) + 1)\n",
    "\n",
    "fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(12, 5))\n",
    "\n",
    "loss_ax.plot(epochs, train_losses, label=\"Train Loss\", marker=\"o\")\n",
    "loss_ax.plot(epochs, val_losses, label=\"Validation Loss\", marker=\"s\")\n",
    "loss_ax.set(xlabel=\"Epoch\", ylabel=\"Loss\", title=\"Loss per Epoch\")\n",
    "loss_ax.legend()\n",
    "loss_ax.grid(True)\n",
    "\n",
    "acc_ax.plot(epochs, train_accs, label=\"Train Accuracy\", marker=\"o\")\n",
    "acc_ax.plot(epochs, val_accs, label=\"Validation Accuracy\", marker=\"s\")\n",
    "acc_ax.set(xlabel=\"Epoch\", ylabel=\"Accuracy\", title=\"Accuracy per Epoch\")\n",
    "acc_ax.legend()\n",
    "acc_ax.grid(True)\n",
    "\n",
    "fig.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f7c9b396",
   "metadata": {},
   "source": [
    "## Evaluation on the test split"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "930d22bd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rebuild the architecture without fetching ImageNet weights: every\n",
    "# parameter is overwritten by the checkpoint. map_location lets a checkpoint\n",
    "# saved on one backend be restored on whichever device is in use now.\n",
    "model = get_efficientnet_model(num_classes=num_classes, pretrained=False).to(device)\n",
    "model.load_state_dict(torch.load(model_name, map_location=device))\n",
    "model.eval()\n",
    "\n",
    "pred_batches = []\n",
    "target_batches = []\n",
    "\n",
    "with torch.no_grad():\n",
    "    for batch_x, batch_y in test_loader:\n",
    "        pred_batches.append(model(batch_x.to(device)).argmax(dim=1).cpu())\n",
    "        target_batches.append(batch_y)\n",
    "\n",
    "all_preds = torch.cat(pred_batches).numpy()\n",
    "all_targets = torch.cat(target_batches).numpy()\n",
    "# test_loader is not shuffled, so row i of X_test matches prediction i.\n",
    "all_images = X_test\n",
    "\n",
    "test_accuracy = (all_preds == all_targets).mean()\n",
    "\n",
    "print(f\"\\nTest Accuracy: {test_accuracy:.4f}\")\n",
    "\n",
    "target_names = le.classes_\n",
    "print(\"\\nClassification Report:\\n\")\n",
    "print(classification_report(all_targets, all_preds, target_names=target_names))\n",
    "\n",
    "cm = confusion_matrix(all_targets, all_preds)\n",
    "\n",
    "plt.figure(figsize=(6, 5))\n",
    "sns.heatmap(cm, annot=True, fmt=\"d\", cmap=\"Blues\", xticklabels=target_names, yticklabels=target_names)\n",
    "plt.xlabel(\"Predicted Label\")\n",
    "plt.ylabel(\"True Label\")\n",
    "plt.title(\"Confusion Matrix\")\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4823498a",
   "metadata": {},
   "outputs": [],
   "source": [
    "def show_images(indices, title, class_name, max_images=5):\n",
    "    \"\"\"Plot up to ``max_images`` test images selected by ``indices``.\n",
    "\n",
    "    Args:\n",
    "        indices: Positions into ``all_images`` / ``all_preds`` / ``all_targets``.\n",
    "        title: Error category shown in the figure title.\n",
    "        class_name: Class the error category refers to.\n",
    "        max_images: Upper bound on the number of images drawn.\n",
    "    \"\"\"\n",
    "    num = min(len(indices), max_images)\n",
    "    if num == 0:\n",
    "        print(f\"No {title} samples.\")\n",
    "        return\n",
    "\n",
    "    fig, axes = plt.subplots(1, num, figsize=(12, 2), squeeze=False)\n",
    "    for ax, idx in zip(axes[0], indices[:num]):\n",
    "        img = all_images[idx].permute(1, 2, 0).numpy()\n",
    "        # Pixels were scaled to [0, 1] when the tensors were built; min-max\n",
    "        # stretching would alter the colours and divide by zero on a\n",
    "        # constant image.\n",
    "        ax.imshow(np.clip(img, 0.0, 1.0))\n",
    "        ax.axis(\"off\")\n",
    "        ax.set_title(f\"Pred: {target_names[all_preds[idx]]}\\nTrue: {target_names[all_targets[idx]]}\")\n",
    "    fig.suptitle(f\"{title} for {class_name}\")\n",
    "    fig.tight_layout()\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "for class_idx, class_name in enumerate(target_names):\n",
    "    print(f\"\\nShowing False Negatives and False Positives for class: {class_name}\")\n",
    "    fn_indices = np.where((all_targets == class_idx) & (all_preds != class_idx))[0]\n",
    "    fp_indices = np.where((all_preds == class_idx) & (all_targets != class_idx))[0]\n",
    "\n",
    "    show_images(fn_indices, \"False Negatives\", class_name)\n",
    "    show_images(fp_indices, \"False Positives\", class_name)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "08dac4a7",
   "metadata": {},
   "source": [
    "## Feature-map visualisation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "551cec6b",
   "metadata": {},
   "outputs": [],
   "source": [
    "def visualize_channels(model, image_tensor, max_channels=6):\n",
    "    \"\"\"Plot the most active channels of every stage in ``model.features``.\n",
    "\n",
    "    A forward hook on each stage captures its output for one image; per stage,\n",
    "    the channels with the highest mean activation are drawn in one row.\n",
    "\n",
    "    Args:\n",
    "        model: Network exposing an iterable ``features`` module.\n",
    "        image_tensor: Single image tensor without a batch dimension.\n",
    "        max_channels: Maximum number of channels drawn per stage.\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    activations = {}\n",
    "\n",
    "    def get_activation(name):\n",
    "        def hook(module, inputs, output):\n",
    "            activations[name] = output.detach().cpu()\n",
    "        return hook\n",
    "\n",
    "    hooks = [\n",
    "        layer.register_forward_hook(get_activation(f\"features_{i}\"))\n",
    "        for i, layer in enumerate(model.features)\n",
    "    ]\n",
    "\n",
    "    # Feed the image on the model's own device, and always detach the hooks,\n",
    "    # even if the forward pass raises, so they cannot pile up on the model.\n",
    "    model_device = next(model.parameters()).device\n",
    "    try:\n",
    "        with torch.no_grad():\n",
    "            model(image_tensor.unsqueeze(0).to(model_device))\n",
    "    finally:\n",
    "        for h in hooks:\n",
    "            h.remove()\n",
    "\n",
    "    for layer_name, fmap in activations.items():\n",
    "        fmap = fmap.squeeze(0)\n",
    "        channel_scores = fmap.mean(dim=(1, 2))\n",
    "        k = min(max_channels, fmap.shape[0])\n",
    "        top_indices = torch.topk(channel_scores, k=k).indices.tolist()\n",
    "\n",
    "        plt.figure(figsize=(max_channels * 2, 2.5))\n",
    "        for position, ch in enumerate(top_indices, start=1):\n",
    "            plt.subplot(1, max_channels, position)\n",
    "            plt.imshow(fmap[ch], cmap=\"viridis\")\n",
    "            plt.title(f\"{layer_name}\\nch{ch} ({channel_scores[ch]:.2f})\")\n",
    "            plt.axis(\"off\")\n",
    "        plt.tight_layout()\n",
    "        plt.show()\n",
    "\n",
    "\n",
    "def load_image_tensor(image_path, image_size=(224, 224)):\n",
    "    \"\"\"Load one image file as an RGB tensor resized to ``image_size``.\n",
    "\n",
    "    Args:\n",
    "        image_path: Path of the image file.\n",
    "        image_size: Target size handed to ``transforms.Resize``.\n",
    "\n",
    "    Returns:\n",
    "        The output of ``transforms.ToTensor`` for the resized image.\n",
    "    \"\"\"\n",
    "    transform = transforms.Compose([\n",
    "        transforms.Resize(image_size),\n",
    "        transforms.ToTensor(),\n",
    "    ])\n",
    "    with Image.open(image_path) as img:\n",
    "        return transform(img.convert(\"RGB\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6a0c0cdb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# CPU copy of the best checkpoint for the feature-map plots below.\n",
    "model = get_efficientnet_model(num_classes=num_classes, pretrained=False)\n",
    "model.load_state_dict(torch.load(model_name, map_location=\"cpu\"))\n",
    "model.eval();"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "71770f98",
   "metadata": {},
   "outputs": [],
   "source": [
    "img_tensor = load_image_tensor(\"dataset/Strawberry_512/Whole/image_0017.jpg\")\n",
    "\n",
    "visualize_channels(model, img_tensor, max_channels=16)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fda0bcc9",
   "metadata": {},
   "outputs": [],
   "source": [
    "img_tensor = load_image_tensor(\"dataset/Strawberry_512/Hulled/image_0001.jpg\")\n",
    "\n",
    "visualize_channels(model, img_tensor, max_channels=16)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4334ee87",
   "metadata": {},
   "outputs": [],
   "source": [
    "img_tensor = load_image_tensor(\"dataset/Strawberry_512/Sliced/image_0001.jpg\")\n",
    "\n",
    "visualize_channels(model, img_tensor, max_channels=16)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "myenv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.21"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}