{ "cells": [ { "cell_type": "code", "execution_count": null, "id": "1b70151d", "metadata": {}, "outputs": [], "source": [ "import os\n", "import numpy as np\n", "from PIL import Image\n", "import matplotlib.pyplot as plt\n", "import random\n", "import torch\n", "from torch.utils.data import Dataset, DataLoader, TensorDataset\n", "from sklearn.preprocessing import LabelEncoder\n", "from sklearn.model_selection import train_test_split\n", "from torchvision import transforms\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "from sklearn.metrics import classification_report, confusion_matrix\n", "import seaborn as sns\n", "import torchvision.models as models" ] }, { "cell_type": "code", "execution_count": null, "id": "c37cc27b", "metadata": {}, "outputs": [], "source": [ "def load_images_from_folder(folder_path, image_size=(224, 224)):\n", " \"\"\"Recursively load every .jpg/.jpeg under folder_path as an RGB array resized to image_size; unreadable files are reported and skipped.\"\"\"\n", " images = []\n", " for root, _, files in os.walk(folder_path):\n", " for file in files:\n", " if file.lower().endswith((\".jpg\", \".jpeg\")):\n", " try:\n", " img_path = os.path.join(root, file)\n", " img = Image.open(img_path).convert(\"RGB\")\n", " img = img.resize(image_size)\n", " images.append(np.array(img))\n", " except Exception as e:\n", " print(f\"Failed on {img_path}: {e}\")\n", " return np.array(images)\n", "\n", "def plot_rgb_histogram_subplot(ax, images, class_name):\n", " \"\"\"Plot R/G/B pixel-value histograms of one randomly sampled image from `images` onto the given matplotlib axes.\"\"\"\n", " sample = images[random.randint(0, len(images) - 1)]\n", " colors = ('r', 'g', 'b')\n", " for i, col in enumerate(colors):\n", " hist = np.histogram(sample[:, :, i], bins=256, range=(0, 256))[0]\n", " ax.plot(hist, color=col)\n", " ax.set_title(f\"RGB Histogram – {class_name.capitalize()}\")\n", " ax.set_xlabel(\"Pixel Value\")\n", " ax.set_ylabel(\"Frequency\")\n", "\n", "def augment_rotations(X, y):\n", " \"\"\"Return (X_rot, y_rot): the NCHW batch X rotated by 90/180/270 degrees (concatenated), with labels y repeated once per rotation.\"\"\"\n", " X_aug = []\n", " y_aug = []\n", " for k in [1, 2, 3]: \n", " X_rot = torch.rot90(X, k=k, dims=[2, 3]) \n", " X_aug.append(X_rot)\n", " y_aug.append(y.clone()) \n", " return torch.cat(X_aug), torch.cat(y_aug)" ] }, { "cell_type": 
"code", "execution_count": null, "id": "3f833049", "metadata": {}, "outputs": [], "source": [ "# One directory per tomato preparation class.\n", "tomato_diced = \"dataset/Tomato_512/Diced\"\n", "tomato_vines = \"dataset/Tomato_512/On_the_vines\"\n", "tomato_whole = \"dataset/Tomato_512/Whole\"" ] }, { "cell_type": "code", "execution_count": null, "id": "1e913838", "metadata": {}, "outputs": [], "source": [ "\n", "\n", "# Load each class into a stacked (N, 224, 224, 3) uint8 array.\n", "tomato_diced_images = load_images_from_folder(tomato_diced)\n", "tomato_vines_images = load_images_from_folder(tomato_vines)\n", "tomato_whole_images = load_images_from_folder(tomato_whole)\n", "\n", "# Bug fix: labels previously said \"Strawberry halved/sliced/whole\" (copy-paste\n", "# from another notebook) although the arrays hold tomato classes.\n", "print(\"Tomato diced images:\", tomato_diced_images.shape)\n", "print(\"Tomato vines images:\", tomato_vines_images.shape)\n", "print(\"Tomato whole images:\", tomato_whole_images.shape)\n" ] }, { "cell_type": "code", "execution_count": null, "id": "00149f35", "metadata": {}, "outputs": [], "source": [ "\n", "datasets = {\n", " \"diced\": tomato_diced_images,\n", " \"vines\": tomato_vines_images,\n", " \"whole\": tomato_whole_images\n", "}\n", "\n", "\n", "def show_random_samples(images, class_name, count=5):\n", " \"\"\"Display `count` randomly chosen images of one class in a single row.\"\"\"\n", " indices = random.sample(range(images.shape[0]), count)\n", " selected = images[indices]\n", "\n", " plt.figure(figsize=(10, 2))\n", " for i, img in enumerate(selected):\n", " plt.subplot(1, count, i+1)\n", " plt.imshow(img.astype(np.uint8))\n", " plt.axis('off')\n", " plt.suptitle(f\"{class_name.capitalize()} – Random {count} Samples\", fontsize=16)\n", " plt.show()\n", "\n", "for class_name, image_array in datasets.items():\n", " show_random_samples(image_array, class_name)\n" ] }, { "cell_type": "code", "execution_count": null, "id": "46a700fe", "metadata": {}, "outputs": [], "source": [ "# One RGB-histogram panel per class, sharing the figure.\n", "fig, axes = plt.subplots(1, len(datasets), figsize=(20, 5))\n", "\n", "for ax, (class_name, images) in zip(axes, datasets.items()):\n", " plot_rgb_histogram_subplot(ax, images, class_name)\n", " ax.label_outer() \n", "\n", "plt.tight_layout()\n", "plt.show()" ] }, { "cell_type": "code", 
"execution_count": null, "id": "9b0c6d31", "metadata": {}, "outputs": [], "source": [ "class_names = list(datasets.keys())\n", "num_classes = len(class_names)\n", "\n", "fig, axes = plt.subplots(1, num_classes, figsize=(4 * num_classes, 4)) \n", "\n", "# Per-class mean image: a quick sanity check of framing/colour bias.\n", "for i, (class_name, images) in enumerate(datasets.items()):\n", " avg_img = np.mean(images.astype(np.float32), axis=0)\n", " axes[i].imshow(avg_img.astype(np.uint8))\n", " axes[i].set_title(f\"Average Image – {class_name.capitalize()}\")\n", " axes[i].axis('off')\n", "\n", "plt.tight_layout()\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": null, "id": "dec6064b", "metadata": {}, "outputs": [], "source": [ "datasets = {\n", " \"diced\": tomato_diced_images,\n", " \"vines\": tomato_vines_images,\n", " \"whole\": tomato_whole_images\n", "}\n", "\n", "# Stack all classes, scale to [0, 1], and convert HWC -> CHW for PyTorch.\n", "X = np.concatenate([tomato_diced_images, tomato_vines_images, tomato_whole_images], axis=0)\n", "y = (\n", " ['diced'] * len(tomato_diced_images) +\n", " ['vines'] * len(tomato_vines_images) +\n", " ['whole'] * len(tomato_whole_images)\n", ")\n", "\n", "X = X.astype(np.float32) / 255.0\n", "X = np.transpose(X, (0, 3, 1, 2))\n", "X_tensor = torch.tensor(X)\n", "\n", "le = LabelEncoder()\n", "y_encoded = le.fit_transform(y)\n", "y_tensor = torch.tensor(y_encoded)\n", "\n", "# 60/20/20 stratified train/val/test split.\n", "X_train, X_temp, y_train, y_temp = train_test_split(X_tensor, y_tensor, test_size=0.4, stratify=y_tensor, random_state=42)\n", "X_val, X_test, y_val, y_test = train_test_split(X_temp, y_temp, test_size=0.5, stratify=y_temp, random_state=42)\n" ] }, { "cell_type": "code", "execution_count": null, "id": "f265aea3", "metadata": {}, "outputs": [], "source": [ "batch_size = 32\n", "\n", "# Rotation augmentation is applied to the training split only.\n", "X_augmented, y_augmented = augment_rotations(X_train, y_train)\n", "\n", "X_train_combined = torch.cat([X_train, X_augmented])\n", "y_train_combined = torch.cat([y_train, y_augmented])\n", "\n", "# Bug fix: the combined (original + rotated) tensors were built but never\n", "# used - the dataset was constructed from X_train/y_train alone, silently\n", "# discarding the augmentation. Train on the combined set.\n", "train_dataset = TensorDataset(X_train_combined, y_train_combined)\n", "val_dataset = TensorDataset(X_val, y_val)\n", "test_dataset = 
TensorDataset(X_test, y_test)\n", "\n", "\n", "# Shuffle only the training loader; order is irrelevant for the eval splits.\n", "train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)\n", "val_loader = DataLoader(val_dataset, batch_size=batch_size)\n", "test_loader = DataLoader(test_dataset, batch_size=batch_size)" ] }, { "cell_type": "code", "execution_count": null, "id": "c469bc8d", "metadata": {}, "outputs": [], "source": [ "print(f\"Train Dataset: {len(train_dataset)} samples, {len(train_loader)} batches\")\n", "print(f\"Val Dataset: {len(val_dataset)} samples, {len(val_loader)} batches\")\n", "print(f\"Test Dataset: {len(test_dataset)} samples, {len(test_loader)} batches\")" ] }, { "cell_type": "code", "execution_count": null, "id": "02440bb8", "metadata": {}, "outputs": [], "source": [ "def get_efficientnet_model(num_classes):\n", " \"\"\"Return an ImageNet-pretrained EfficientNet-B0 with its final linear layer replaced by a fresh `num_classes`-way head.\"\"\"\n", " model = models.efficientnet_b0(weights=models.EfficientNet_B0_Weights.DEFAULT)\n", " model.classifier[1] = nn.Linear(model.classifier[1].in_features, num_classes)\n", " return model\n", "\n" ] }, { "cell_type": "code", "execution_count": null, "id": "6a516f06", "metadata": {}, "outputs": [], "source": [ "# Prefer the Apple-silicon GPU (MPS backend) when available, else CPU.\n", "if torch.backends.mps.is_available():\n", " device = torch.device(\"mps\")\n", " print(\"Using MPS (Apple GPU)\")\n", "else:\n", " device = torch.device(\"cpu\")\n", " print(\"MPS not available. 
Using CPU\")\n", "\n", "model = get_efficientnet_model(num_classes=3).to(device)\n", "optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)\n", "criterion = nn.CrossEntropyLoss()\n" ] }, { "cell_type": "code", "execution_count": null, "id": "245a6709", "metadata": {}, "outputs": [], "source": [ "# Training loop with accuracy-based early stopping (patience epochs without improvement).\n", "best_val_acc = 0.0\n", "train_losses = []\n", "val_losses = []\n", "train_accs = []\n", "val_accs = []\n", "epochs_no_improve = 0\n", "early_stop = False\n", "patience = 3\n", "\n", "for epoch in range(10):\n", " if early_stop:\n", " print(f\"Early stopping at epoch {epoch}\")\n", " break\n", " model.train()\n", " total_train_loss = 0\n", " train_correct = 0\n", " train_total = 0\n", "\n", " for batch_x, batch_y in train_loader:\n", " batch_x, batch_y = batch_x.to(device), batch_y.to(device)\n", " preds = model(batch_x)\n", " loss = criterion(preds, batch_y)\n", "\n", " optimizer.zero_grad()\n", " loss.backward()\n", " optimizer.step()\n", "\n", " total_train_loss += loss.item()\n", "\n", " pred_labels = preds.argmax(dim=1)\n", " train_correct += (pred_labels == batch_y).sum().item()\n", " train_total += batch_y.size(0)\n", "\n", " train_accuracy = train_correct / train_total\n", " avg_train_loss = total_train_loss / len(train_loader)\n", " train_losses.append(avg_train_loss)\n", " train_accs.append(train_accuracy)\n", "\n", " \n", " model.eval()\n", " val_correct = val_total = 0\n", " total_val_loss = 0.0\n", "\n", " with torch.no_grad():\n", " for val_x, val_y in val_loader:\n", " val_x, val_y = val_x.to(device), val_y.to(device)\n", " val_logits = model(val_x)\n", " # Bug fix: accumulate the loss over every validation batch.\n", " # Previously the loss was computed only on the LAST batch (and\n", " # with a redundant extra forward pass), so the plotted val-loss\n", " # curve was noise.\n", " total_val_loss += criterion(val_logits, val_y).item()\n", " val_preds = val_logits.argmax(dim=1)\n", " val_correct += (val_preds == val_y).sum().item()\n", " val_total += val_y.size(0)\n", "\n", " val_accuracy = val_correct / val_total\n", " validation_loss = total_val_loss / len(val_loader)\n", "\n", " val_losses.append(validation_loss)\n", " val_accs.append(val_accuracy)\n", "\n", " print(f\"Epoch {epoch+1:02d} | Train Loss: {avg_train_loss:.4f} | \"\n", " f\"Train Acc: {train_accuracy:.4f} | Val Acc: 
{val_accuracy:.4f}\")\n", " if val_accuracy > best_val_acc:\n", " best_val_acc = val_accuracy\n", " # Bug fix: save to the same path the evaluation cells load from\n", " # (\"models/best_model_tomato_v1.pth\"). Previously the checkpoint was\n", " # written to the CWD, so the later torch.load(\"models/...\") failed on\n", " # a fresh run.\n", " os.makedirs(\"models\", exist_ok=True)\n", " torch.save(model.state_dict(), \"models/best_model_tomato_v1.pth\")\n", " print(f\"New best model saved at epoch {epoch+1} with val acc {val_accuracy:.4f}\")\n", " epochs_no_improve = 0\n", " else:\n", " epochs_no_improve += 1\n", " print(f\"No improvement for {epochs_no_improve} epoch(s)\")\n", "\n", " if epochs_no_improve >= patience:\n", " print(f\"Validation accuracy did not improve for {patience} consecutive epochs. Stopping early.\")\n", " early_stop = True\n", "\n" ] }, { "cell_type": "code", "execution_count": null, "id": "3bbab1d8", "metadata": {}, "outputs": [], "source": [ "\n", "\n", "# Training curves: loss and accuracy per epoch for both splits.\n", "epochs = range(1, len(train_losses) + 1)\n", "\n", "plt.figure(figsize=(12, 5))\n", "\n", "plt.subplot(1, 2, 1)\n", "plt.plot(epochs, train_losses, label='Train Loss', marker='o')\n", "plt.plot(epochs, val_losses, label='Validation Loss', marker='s')\n", "plt.xlabel('Epoch')\n", "plt.ylabel('Loss')\n", "plt.title('Loss per Epoch')\n", "plt.legend()\n", "plt.grid(True)\n", "\n", "plt.subplot(1, 2, 2)\n", "plt.plot(epochs, train_accs, label='Train Accuracy', marker='o')\n", "plt.plot(epochs, val_accs, label='Validation Accuracy', marker='s')\n", "plt.xlabel('Epoch')\n", "plt.ylabel('Accuracy')\n", "plt.title('Accuracy per Epoch')\n", "plt.legend()\n", "plt.grid(True)\n", "\n", "plt.tight_layout()\n", "plt.show()\n" ] }, { "cell_type": "code", "execution_count": null, "id": "930d22bd", "metadata": {}, "outputs": [], "source": [ "# Reload the best checkpoint and collect test-set predictions.\n", "model = get_efficientnet_model(num_classes=3).to(device)\n", "model.load_state_dict(torch.load(\"models/best_model_tomato_v1.pth\"))\n", "model.eval() \n", "\n", "all_preds = []\n", "all_targets = []\n", "all_images = []\n", "\n", "with torch.no_grad():\n", " for batch_x, batch_y in test_loader:\n", " batch_x = batch_x.to(device)\n", " preds = model(batch_x).argmax(dim=1).cpu()\n", " all_preds.extend(preds.numpy())\n", " all_targets.extend(batch_y.numpy())\n", 
" all_images.extend(batch_x.cpu())\n", "\n", "test_correct = sum(np.array(all_preds) == np.array(all_targets))\n", "test_total = len(all_targets)\n", "test_accuracy = test_correct / test_total\n", "\n", "print(f\"\\nTest Accuracy: {test_accuracy:.4f}\")\n", "\n", "target_names = le.classes_\n", "print(\"\\nClassification Report:\\n\")\n", "print(classification_report(all_targets, all_preds, target_names=target_names))\n", "\n", "cm = confusion_matrix(all_targets, all_preds)\n", "\n", "plt.figure(figsize=(6, 5))\n", "sns.heatmap(cm, annot=True, fmt=\"d\", cmap=\"Blues\", xticklabels=target_names, yticklabels=target_names)\n", "plt.xlabel(\"Predicted Label\")\n", "plt.ylabel(\"True Label\")\n", "plt.title(\"Confusion Matrix\")\n", "plt.tight_layout()\n", "plt.show()\n" ] }, { "cell_type": "code", "execution_count": null, "id": "4823498a", "metadata": {}, "outputs": [], "source": [ "\n", "# Per class: show false negatives (missed) and false positives (wrongly claimed).\n", "all_preds = np.array(all_preds)\n", "all_targets = np.array(all_targets)\n", "all_images = torch.stack(all_images) \n", "\n", "for class_idx, class_name in enumerate(target_names):\n", " print(f\"\\nShowing False Negatives and False Positives for class: {class_name}\")\n", " fn_indices = np.where((all_targets == class_idx) & (all_preds != class_idx))[0]\n", " fp_indices = np.where((all_preds == class_idx) & (all_targets != class_idx))[0]\n", "\n", " def show_images(indices, title, max_images=5):\n", " \"\"\"Show up to max_images test images at `indices`, min-max normalized for display.\"\"\"\n", " num = min(len(indices), max_images)\n", " if num == 0:\n", " print(f\"No {title} samples.\")\n", " return\n", "\n", " plt.figure(figsize=(12, 2))\n", " for i, idx in enumerate(indices[:num]):\n", " img = all_images[idx]\n", " img = img.permute(1, 2, 0).numpy()\n", " plt.subplot(1, num, i + 1)\n", " plt.imshow((img - img.min()) / (img.max() - img.min())) \n", " plt.axis('off')\n", " plt.title(f\"Pred: {target_names[all_preds[idx]]}\\nTrue: {target_names[all_targets[idx]]}\")\n", " plt.suptitle(f\"{title} for {class_name}\")\n", " plt.tight_layout()\n", " plt.show()\n", "\n", " 
show_images(fn_indices, \"False Negatives\")\n", " show_images(fp_indices, \"False Positives\")\n" ] }, { "cell_type": "code", "execution_count": null, "id": "551cec6b", "metadata": {}, "outputs": [], "source": [ "def visualize_channels(model, image_tensor, max_channels=6):\n", " \"\"\"Run one forward pass on image_tensor (CHW) and, for every block in model.features, plot the max_channels feature maps with the highest mean activation.\"\"\"\n", " model.eval()\n", " activations = {}\n", "\n", " def get_activation(name):\n", " def hook(model, input, output):\n", " activations[name] = output.detach().cpu()\n", " return hook\n", "\n", " # Register a forward hook on every top-level feature block.\n", " hooks = []\n", " for i in range(len(model.features)):\n", " layer = model.features[i]\n", " hooks.append(layer.register_forward_hook(get_activation(f\"features_{i}\")))\n", "\n", " with torch.no_grad():\n", " _ = model(image_tensor.unsqueeze(0)) \n", "\n", " # Remove hooks so later forward passes are unaffected.\n", " for h in hooks:\n", " h.remove()\n", "\n", " for layer_name, fmap in activations.items():\n", " fmap = fmap.squeeze(0) \n", "\n", " # Rank channels by mean activation and keep the strongest ones.\n", " channel_scores = fmap.mean(dim=(1, 2)) \n", "\n", " topk = torch.topk(channel_scores, k=min(max_channels, fmap.shape[0]))\n", " top_indices = topk.indices\n", "\n", " plt.figure(figsize=(max_channels * 2, 2.5))\n", " for idx, ch in enumerate(top_indices):\n", " plt.subplot(1, max_channels, idx + 1)\n", " plt.imshow(fmap[ch], cmap='viridis')\n", " plt.title(f\"{layer_name}\\nch{ch.item()} ({channel_scores[ch]:.2f})\")\n", " plt.axis('off')\n", " plt.tight_layout()\n", " plt.show()\n" ] }, { "cell_type": "code", "execution_count": null, "id": "4b524ecc", "metadata": {}, "outputs": [], "source": [ "# NOTE(review): this reload stays on CPU (no .to(device)); the image tensors fed below are CPU too, so that is consistent.\n", "model = get_efficientnet_model(num_classes=3)\n", "model.load_state_dict(torch.load(\"models/best_model_tomato_v1.pth\"))\n", "model.eval()" ] }, { "cell_type": "code", "execution_count": null, "id": "8c2c9f8b", "metadata": {}, "outputs": [], "source": [ "\n", "img = Image.open(\"dataset/Tomato_512/Whole/image_0007.jpg\").convert(\"RGB\")\n", "\n", "transform = transforms.Compose([\n", " transforms.Resize((224, 224)),\n", " transforms.ToTensor()\n", "])\n", "img_tensor = transform(img) \n", "\n", "visualize_channels(model, 
img_tensor, max_channels=16)\n" ] }, { "cell_type": "code", "execution_count": null, "id": "c382875b", "metadata": {}, "outputs": [], "source": [ "img = Image.open(\"dataset/Tomato_512/On_the_vines/image_0578.jpg\").convert(\"RGB\")\n", "\n", "transform = transforms.Compose([\n", " transforms.Resize((224, 224)),\n", " transforms.ToTensor()\n", "])\n", "img_tensor = transform(img) \n", "\n", "visualize_channels(model, img_tensor, max_channels=16)\n" ] }, { "cell_type": "code", "execution_count": null, "id": "3c450913", "metadata": {}, "outputs": [], "source": [ "# NOTE(review): same file number (image_0578) as the vines example above - confirm this file actually exists under Diced and was not a copy-paste slip.\n", "img = Image.open(\"dataset/Tomato_512/Diced/image_0578.jpg\").convert(\"RGB\")\n", "\n", "transform = transforms.Compose([\n", " transforms.Resize((224, 224)),\n", " transforms.ToTensor()\n", "])\n", "img_tensor = transform(img) \n", "\n", "visualize_channels(model, img_tensor, max_channels=16)\n" ] }, { "cell_type": "code", "execution_count": null, "id": "d54e3240", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "myenv", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.21" } }, "nbformat": 4, "nbformat_minor": 5 }