{ "cells": [ { "cell_type": "code", "execution_count": null, "id": "1b70151d", "metadata": {}, "outputs": [], "source": [ "import os\n", "import numpy as np\n", "from PIL import Image\n", "import matplotlib.pyplot as plt\n", "import random\n", "import torch\n", "import numpy as np\n", "from torch.utils.data import Dataset, DataLoader, TensorDataset\n", "from sklearn.preprocessing import LabelEncoder\n", "from sklearn.model_selection import train_test_split\n", "from torchvision import transforms\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "from sklearn.metrics import classification_report, confusion_matrix\n", "import seaborn as sns\n", "import torchvision.models as models" ] }, { "cell_type": "code", "execution_count": null, "id": "c37cc27b", "metadata": {}, "outputs": [], "source": [ "def load_images_from_folder(folder_path, image_size=(224, 224)):\n", " images = []\n", " for root, _, files in os.walk(folder_path):\n", " for file in files:\n", " if file.lower().endswith((\".jpg\", \".jpeg\")):\n", " try:\n", " img_path = os.path.join(root, file)\n", " img = Image.open(img_path).convert(\"RGB\")\n", " img = img.resize(image_size)\n", " images.append(np.array(img))\n", " except Exception as e:\n", " print(f\"Failed on {img_path}: {e}\")\n", " return np.array(images)\n", "\n", "def plot_rgb_histogram_subplot(ax, images, class_name):\n", " sample = images[random.randint(0, len(images) - 1)]\n", " colors = ('r', 'g', 'b')\n", " for i, col in enumerate(colors):\n", " hist = np.histogram(sample[:, :, i], bins=256, range=(0, 256))[0]\n", " ax.plot(hist, color=col)\n", " ax.set_title(f\"RGB Histogram – {class_name.capitalize()}\")\n", " ax.set_xlabel(\"Pixel Value\")\n", " ax.set_ylabel(\"Frequency\")\n", "\n", "def augment_rotations(X, y):\n", " X_aug = []\n", " y_aug = []\n", " for k in [1, 2, 3]: # 90, 180, 270 degrees\n", " X_rot = torch.rot90(X, k=k, dims=[2, 3]) # rotate along H and W\n", " X_aug.append(X_rot)\n", " y_aug.append(y.clone()) # 
Same labels for rotated images\n", "    return torch.cat(X_aug), torch.cat(y_aug)" ] }, { "cell_type": "code", "execution_count": null, "id": "3f833049", "metadata": {}, "outputs": [], "source": [ "tomato_diced = \"dataset/Tomato_512/Diced\"\n", "tomato_vines = \"dataset/Tomato_512/On_the_vines\"\n", "tomato_whole = \"dataset/Tomato_512/Whole\"" ] }, { "cell_type": "code", "execution_count": null, "id": "1e913838", "metadata": {}, "outputs": [], "source": [ "\n", "\n", "tomato_diced_images = load_images_from_folder(tomato_diced)\n", "tomato_vines_images = load_images_from_folder(tomato_vines)\n", "tomato_whole_images = load_images_from_folder(tomato_whole)\n", "\n", "print(\"Tomato diced images:\", tomato_diced_images.shape)\n", "print(\"Tomato on-the-vines images:\", tomato_vines_images.shape)\n", "print(\"Tomato whole images:\", tomato_whole_images.shape)\n" ] }, { "cell_type": "code", "execution_count": null, "id": "00149f35", "metadata": {}, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "import random\n", "datasets = {\n", "    \"diced\": tomato_diced_images,\n", "    \"vines\": tomato_vines_images,\n", "    \"whole\": tomato_whole_images\n", "}\n", "\n", "\n", "def show_random_samples(images, class_name, count=5):\n", "    indices = random.sample(range(images.shape[0]), count)\n", "    selected = images[indices]\n", "\n", "    plt.figure(figsize=(10, 2))\n", "    for i, img in enumerate(selected):\n", "        plt.subplot(1, count, i+1)\n", "        plt.imshow(img.astype(np.uint8))\n", "        plt.axis('off')\n", "    plt.suptitle(f\"{class_name.capitalize()} – Random {count} Samples\", fontsize=16)\n", "    plt.show()\n", "\n", "# Display for each class\n", "for class_name, image_array in datasets.items():\n", "    show_random_samples(image_array, class_name)\n" ] }, { "cell_type": "code", "execution_count": null, "id": "46a700fe", "metadata": {}, "outputs": [], "source": [ "fig, axes = plt.subplots(1, len(datasets), figsize=(20, 5))\n", "\n", "for ax, (class_name, images) in zip(axes, 
datasets.items()):\n", " plot_rgb_histogram_subplot(ax, images, class_name)\n", " ax.label_outer() # Hide x labels and tick labels for inner plots\n", "\n", "plt.tight_layout()\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": null, "id": "9b0c6d31", "metadata": {}, "outputs": [], "source": [ "class_names = list(datasets.keys())\n", "num_classes = len(class_names)\n", "\n", "fig, axes = plt.subplots(1, num_classes, figsize=(4 * num_classes, 4)) # 1 row, 4 columns\n", "\n", "for i, (class_name, images) in enumerate(datasets.items()):\n", " avg_img = np.mean(images.astype(np.float32), axis=0)\n", " axes[i].imshow(avg_img.astype(np.uint8))\n", " axes[i].set_title(f\"Average Image – {class_name.capitalize()}\")\n", " axes[i].axis('off')\n", "\n", "plt.tight_layout()\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": null, "id": "dec6064b", "metadata": {}, "outputs": [], "source": [ "import torch\n", "import numpy as np\n", "from torch.utils.data import Dataset, DataLoader, TensorDataset\n", "from sklearn.preprocessing import LabelEncoder\n", "from sklearn.model_selection import train_test_split\n", "from torchvision import transforms\n", "\n", "datasets = {\n", " \"diced\": tomato_diced_images,\n", " \"vines\": tomato_vines_images,\n", " \"whole\": tomato_whole_images\n", "}\n", "\n", "# Combine data\n", "X = np.concatenate([tomato_diced_images, tomato_vines_images, tomato_whole_images], axis=0)\n", "y = (\n", " ['diced'] * len(tomato_diced_images) +\n", " ['vines'] * len(tomato_vines_images) +\n", " ['whole'] * len(tomato_whole_images)\n", ")\n", "\n", "# Normalize and convert to torch tensors\n", "X = X.astype(np.float32) / 255.0\n", "X = np.transpose(X, (0, 3, 1, 2)) # (N, C, H, W)\n", "X_tensor = torch.tensor(X)\n", "\n", "# Encode labels\n", "le = LabelEncoder()\n", "y_encoded = le.fit_transform(y)\n", "y_tensor = torch.tensor(y_encoded)\n", "\n", "# Train/val/test split\n", "X_train, X_temp, y_train, y_temp = train_test_split(X_tensor, 
y_tensor, test_size=0.4, stratify=y_tensor, random_state=42)\n", "X_val, X_test, y_val, y_test = train_test_split(X_temp, y_temp, test_size=0.5, stratify=y_temp, random_state=42)\n" ] }, { "cell_type": "code", "execution_count": null, "id": "f265aea3", "metadata": {}, "outputs": [], "source": [ "batch_size = 32\n", "\n", "X_augmented, y_augmented = augment_rotations(X_train, y_train)\n", "\n", "# Combine original and augmented data\n", "X_train_combined = torch.cat([X_train, X_augmented])\n", "y_train_combined = torch.cat([y_train, y_augmented])\n", "\n", "# Create new training dataset and loader\n", "\n", "train_dataset = TensorDataset(X_train_combined, y_train_combined)  # use augmented data (was X_train/y_train, which silently discarded the augmentation)\n", "val_dataset = TensorDataset(X_val, y_val)\n", "test_dataset = TensorDataset(X_test, y_test)\n", "\n", "# DataLoaders\n", "\n", "train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)\n", "val_loader = DataLoader(val_dataset, batch_size=batch_size)\n", "test_loader = DataLoader(test_dataset, batch_size=batch_size)" ] }, { "cell_type": "code", "execution_count": null, "id": "c469bc8d", "metadata": {}, "outputs": [], "source": [ "print(f\"šŸ”¢ Train Dataset: {len(train_dataset)} samples, {len(train_loader)} batches\")\n", "print(f\"šŸ”¢ Val Dataset: {len(val_dataset)} samples, {len(val_loader)} batches\")\n", "print(f\"šŸ”¢ Test Dataset: {len(test_dataset)} samples, {len(test_loader)} batches\")" ] }, { "cell_type": "code", "execution_count": null, "id": "02440bb8", "metadata": {}, "outputs": [], "source": [ "import torch.nn as nn\n", "import torch.nn.functional as F\n", "\n", "import torch.nn as nn\n", "import torchvision.models as models\n", "\n", "def get_efficientnet_model(num_classes):\n", "    model = models.efficientnet_b0(weights=models.EfficientNet_B0_Weights.DEFAULT)\n", "\n", "    # Replace classifier head with custom head\n", "    model.classifier[1] = nn.Linear(model.classifier[1].in_features, num_classes)\n", "\n", "    return model\n", "\n" ] }, { "cell_type": "code", "execution_count": 
null, "id": "6a516f06", "metadata": {}, "outputs": [], "source": [ "if torch.backends.mps.is_available():\n", " device = torch.device(\"mps\")\n", " print(\"āœ… Using MPS (Apple GPU)\")\n", "else:\n", " device = torch.device(\"cpu\")\n", " print(\"āš ļø MPS not available. Using CPU\")\n", "\n", "model = get_efficientnet_model(num_classes=3).to(device)\n", "optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)\n", "criterion = nn.CrossEntropyLoss()\n" ] }, { "cell_type": "code", "execution_count": null, "id": "245a6709", "metadata": {}, "outputs": [], "source": [ "best_val_acc = 0.0\n", "train_losses = []\n", "val_losses = []\n", "train_accs = []\n", "val_accs = []\n", "epochs_no_improve = 0\n", "early_stop = False\n", "patience = 3\n", "\n", "for epoch in range(10):\n", " if early_stop:\n", " print(f\"Early stopping at epoch {epoch}\")\n", " break\n", " model.train()\n", " total_train_loss = 0\n", " train_correct = 0\n", " train_total = 0\n", "\n", " for batch_x, batch_y in train_loader:\n", " batch_x, batch_y = batch_x.to(device), batch_y.to(device)\n", " preds = model(batch_x)\n", " loss = criterion(preds, batch_y)\n", "\n", " optimizer.zero_grad()\n", " loss.backward()\n", " optimizer.step()\n", "\n", " total_train_loss += loss.item()\n", "\n", " # Track training accuracy\n", " pred_labels = preds.argmax(dim=1)\n", " train_correct += (pred_labels == batch_y).sum().item()\n", " train_total += batch_y.size(0)\n", "\n", " train_accuracy = train_correct / train_total\n", " avg_train_loss = total_train_loss / len(train_loader)\n", " train_losses.append(avg_train_loss)\n", " train_accs.append(train_accuracy)\n", "\n", " \n", " model.eval()\n", " val_correct = val_total = 0\n", "\n", " with torch.no_grad():\n", " for val_x, val_y in val_loader:\n", " val_x, val_y = val_x.to(device), val_y.to(device)\n", " val_preds = model(val_x).argmax(dim=1)\n", " val_correct += (val_preds == val_y).sum().item()\n", " val_total += val_y.size(0)\n", "\n", " val_accuracy = 
val_correct / val_total\n", "    # Average validation loss over the WHOLE val set (previously computed on the last batch only, outside the loop)\n", "    with torch.no_grad(): validation_loss = sum(criterion(model(vx.to(device)), vy.to(device)).item() for vx, vy in val_loader) / len(val_loader)\n", "\n", "    # After calculating val_accuracy\n", "    val_losses.append(validation_loss)\n", "    val_accs.append(val_accuracy)\n", "\n", "    print(f\"Epoch {epoch+1:02d} | Train Loss: {avg_train_loss:.4f} | \"\n", "          f\"Train Acc: {train_accuracy:.4f} | Val Acc: {val_accuracy:.4f}\")\n", "    if val_accuracy > best_val_acc:\n", "        best_val_acc = val_accuracy\n", "        os.makedirs(\"models\", exist_ok=True)\n", "        # Save under models/ so the path matches the torch.load calls in the evaluation cells\n", "        torch.save(model.state_dict(), \"models/best_model_tomato_v1.pth\")\n", "        print(f\"New best model saved at epoch {epoch+1} with val acc {val_accuracy:.4f}\")\n", "        epochs_no_improve = 0\n", "    else:\n", "        epochs_no_improve += 1\n", "        print(f\"No improvement for {epochs_no_improve} epoch(s)\")\n", "\n", "    if epochs_no_improve >= patience:\n", "        print(f\"Validation accuracy did not improve for {patience} consecutive epochs. Stopping early.\")\n", "        early_stop = True\n", "\n" ] }, { "cell_type": "code", "execution_count": null, "id": "3bbab1d8", "metadata": {}, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "\n", "epochs = range(1, len(train_losses) + 1)\n", "\n", "plt.figure(figsize=(12, 5))\n", "\n", "# Plot Loss\n", "plt.subplot(1, 2, 1)\n", "plt.plot(epochs, train_losses, label='Train Loss', marker='o')\n", "plt.plot(epochs, val_losses, label='Validation Loss', marker='s')\n", "plt.xlabel('Epoch')\n", "plt.ylabel('Loss')\n", "plt.title('Loss per Epoch')\n", "plt.legend()\n", "plt.grid(True)\n", "\n", "# Plot Accuracy\n", "plt.subplot(1, 2, 2)\n", "plt.plot(epochs, train_accs, label='Train Accuracy', marker='o')\n", "plt.plot(epochs, val_accs, label='Validation Accuracy', marker='s')\n", "plt.xlabel('Epoch')\n", "plt.ylabel('Accuracy')\n", "plt.title('Accuracy per Epoch')\n", "plt.legend()\n", "plt.grid(True)\n", "\n", "plt.tight_layout()\n", "plt.show()\n" ] }, { "cell_type": "code", "execution_count": null, "id": "930d22bd", "metadata": {}, "outputs": [], "source": [ "model = 
get_efficientnet_model(num_classes=3).to(device)\n", "model.load_state_dict(torch.load(\"models/best_model_tomato_v1.pth\"))\n", "model.eval() \n", "\n", "all_preds = []\n", "all_targets = []\n", "all_images = []\n", "\n", "with torch.no_grad():\n", " for batch_x, batch_y in test_loader:\n", " batch_x = batch_x.to(device)\n", " preds = model(batch_x).argmax(dim=1).cpu()\n", " all_preds.extend(preds.numpy())\n", " all_targets.extend(batch_y.numpy())\n", " all_images.extend(batch_x.cpu())\n", "\n", "test_correct = sum(np.array(all_preds) == np.array(all_targets))\n", "test_total = len(all_targets)\n", "test_accuracy = test_correct / test_total\n", "\n", "print(f\"\\nTest Accuracy: {test_accuracy:.4f}\")\n", "\n", "target_names = le.classes_\n", "print(\"\\nClassification Report:\\n\")\n", "print(classification_report(all_targets, all_preds, target_names=target_names))\n", "\n", "cm = confusion_matrix(all_targets, all_preds)\n", "\n", "plt.figure(figsize=(6, 5))\n", "sns.heatmap(cm, annot=True, fmt=\"d\", cmap=\"Blues\", xticklabels=target_names, yticklabels=target_names)\n", "plt.xlabel(\"Predicted Label\")\n", "plt.ylabel(\"True Label\")\n", "plt.title(\"Confusion Matrix\")\n", "plt.tight_layout()\n", "plt.show()\n" ] }, { "cell_type": "code", "execution_count": null, "id": "4823498a", "metadata": {}, "outputs": [], "source": [ "import torch\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "\n", "all_preds = np.array(all_preds)\n", "all_targets = np.array(all_targets)\n", "all_images = torch.stack(all_images) # shape: [N, C, H, W]\n", "\n", "# Per class FP and FN\n", "for class_idx, class_name in enumerate(target_names):\n", " print(f\"\\nšŸ” Showing False Negatives and False Positives for class: {class_name}\")\n", "\n", " # False Negatives: True label is class_idx, but predicted something else\n", " fn_indices = np.where((all_targets == class_idx) & (all_preds != class_idx))[0]\n", " # False Positives: Predicted class_idx, but true label is 
different\n", " fp_indices = np.where((all_preds == class_idx) & (all_targets != class_idx))[0]\n", "\n", " def show_images(indices, title, max_images=5):\n", " num = min(len(indices), max_images)\n", " if num == 0:\n", " print(f\"āŒ No {title} samples.\")\n", " return\n", "\n", " plt.figure(figsize=(12, 2))\n", " for i, idx in enumerate(indices[:num]):\n", " img = all_images[idx]\n", " img = img.permute(1, 2, 0).numpy() # [C, H, W] → [H, W, C]\n", " plt.subplot(1, num, i + 1)\n", " plt.imshow((img - img.min()) / (img.max() - img.min())) # normalize to [0,1] for display\n", " plt.axis('off')\n", " plt.title(f\"Pred: {target_names[all_preds[idx]]}\\nTrue: {target_names[all_targets[idx]]}\")\n", " plt.suptitle(f\"{title} for {class_name}\")\n", " plt.tight_layout()\n", " plt.show()\n", "\n", " show_images(fn_indices, \"False Negatives\")\n", " show_images(fp_indices, \"False Positives\")\n" ] }, { "cell_type": "code", "execution_count": null, "id": "551cec6b", "metadata": {}, "outputs": [], "source": [ "def visualize_channels(model, image_tensor, max_channels=6):\n", " model.eval()\n", " activations = {}\n", "\n", " def get_activation(name):\n", " def hook(model, input, output):\n", " activations[name] = output.detach().cpu()\n", " return hook\n", "\n", " # Register hooks for all layers in model.features\n", " hooks = []\n", " for i in range(len(model.features)):\n", " layer = model.features[i]\n", " hooks.append(layer.register_forward_hook(get_activation(f\"features_{i}\")))\n", "\n", " with torch.no_grad():\n", " _ = model(image_tensor.unsqueeze(0)) # Add batch dimension: [1, 3, 224, 224]\n", "\n", " for h in hooks:\n", " h.remove()\n", "\n", " for layer_name, fmap in activations.items():\n", " fmap = fmap.squeeze(0) # Shape: [C, H, W]\n", "\n", " # Compute mean activation per channel\n", " channel_scores = fmap.mean(dim=(1, 2)) # [C]\n", "\n", " # Get indices of top-k channels\n", " topk = torch.topk(channel_scores, k=min(max_channels, fmap.shape[0]))\n", " 
top_indices = topk.indices\n", "\n", " # Plot top-k channels\n", " plt.figure(figsize=(max_channels * 2, 2.5))\n", " for idx, ch in enumerate(top_indices):\n", " plt.subplot(1, max_channels, idx + 1)\n", " plt.imshow(fmap[ch], cmap='viridis')\n", " plt.title(f\"{layer_name}\\nch{ch.item()} ({channel_scores[ch]:.2f})\")\n", " plt.axis('off')\n", " plt.tight_layout()\n", " plt.show()\n" ] }, { "cell_type": "code", "execution_count": null, "id": "4b524ecc", "metadata": {}, "outputs": [], "source": [ "model = get_efficientnet_model(num_classes=3)\n", "model.load_state_dict(torch.load(\"models/best_model_tomato_v1.pth\"))\n", "model.eval()" ] }, { "cell_type": "code", "execution_count": null, "id": "8c2c9f8b", "metadata": {}, "outputs": [], "source": [ "\n", "img = Image.open(\"dataset/Tomato_512/Whole/image_0007.jpg\").convert(\"RGB\")\n", "\n", "# Preprocessing (must match model requirements)\n", "transform = transforms.Compose([\n", " transforms.Resize((224, 224)),\n", " transforms.ToTensor()\n", "])\n", "img_tensor = transform(img) # shape: [3, 224, 224]\n", "\n", "# Visualize feature maps\n", "visualize_channels(model, img_tensor, max_channels=16)\n" ] }, { "cell_type": "code", "execution_count": null, "id": "c382875b", "metadata": {}, "outputs": [], "source": [ "img = Image.open(\"dataset/Tomato_512/On_the_vines/image_0578.jpg\").convert(\"RGB\")\n", "\n", "# Preprocessing (must match model requirements)\n", "transform = transforms.Compose([\n", " transforms.Resize((224, 224)),\n", " transforms.ToTensor()\n", "])\n", "img_tensor = transform(img) # shape: [3, 224, 224]\n", "\n", "# Visualize feature maps\n", "visualize_channels(model, img_tensor, max_channels=16)\n" ] }, { "cell_type": "code", "execution_count": null, "id": "3c450913", "metadata": {}, "outputs": [], "source": [ "img = Image.open(\"dataset/Tomato_512/Diced/image_0578.jpg\").convert(\"RGB\")\n", "\n", "# Preprocessing (must match model requirements)\n", "transform = transforms.Compose([\n", " 
transforms.Resize((224, 224)),\n", " transforms.ToTensor()\n", "])\n", "img_tensor = transform(img) # shape: [3, 224, 224]\n", "\n", "# Visualize feature maps\n", "visualize_channels(model, img_tensor, max_channels=16)\n" ] }, { "cell_type": "code", "execution_count": null, "id": "d54e3240", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "myenv", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.21" } }, "nbformat": 4, "nbformat_minor": 5 }