{ "cells": [ { "cell_type": "code", "execution_count": null, "id": "1b70151d", "metadata": {}, "outputs": [], "source": [ "import os\n", "import numpy as np\n", "from PIL import Image\n", "import torch\n", "import matplotlib.pyplot as plt\n", "import random\n", "import torch.nn as nn\n", "from PIL import Image\n", "import torch.nn.functional as F\n", "import torchvision.models as models\n", "from sklearn.metrics import classification_report, confusion_matrix\n", "import matplotlib.pyplot as plt\n", "import seaborn as sns\n", "from torchvision import models\n", "from torchvision.models import efficientnet_b0, EfficientNet_B0_Weights" ] }, { "cell_type": "markdown", "id": "445679c7", "metadata": {}, "source": [ "### Functions" ] }, { "cell_type": "code", "execution_count": null, "id": "c37cc27b", "metadata": {}, "outputs": [], "source": [ "def load_images_from_folder(folder_path, image_size=(224, 224)):\n", " images = []\n", " for root, _, files in os.walk(folder_path):\n", " for file in files:\n", " if file.lower().endswith((\".jpg\", \".jpeg\")):\n", " try:\n", " img_path = os.path.join(root, file)\n", " img = Image.open(img_path).convert(\"RGB\")\n", " img = img.resize(image_size)\n", " images.append(np.array(img))\n", " except Exception as e:\n", " print(f\"Failed on {img_path}: {e}\")\n", " return np.array(images)\n", "\n", "def plot_rgb_histogram_subplot(ax, images, class_name):\n", " sample = images[random.randint(0, len(images) - 1)]\n", " colors = ('r', 'g', 'b')\n", " for i, col in enumerate(colors):\n", " hist = np.histogram(sample[:, :, i], bins=256, range=(0, 256))[0]\n", " ax.plot(hist, color=col)\n", " ax.set_title(f\"RGB Histogram – {class_name.capitalize()}\")\n", " ax.set_xlabel(\"Pixel Value\")\n", " ax.set_ylabel(\"Frequency\")\n", " \n", "def augment_rotations(X, y):\n", " X_aug = []\n", " y_aug = []\n", " for k in [1, 2, 3]: \n", " X_rot = torch.rot90(X, k=k, dims=[2, 3])\n", " X_aug.append(X_rot)\n", " y_aug.append(y.clone())\n", " return torch.cat(X_aug), torch.cat(y_aug)\n" ] }, { "cell_type": "markdown", "id": "17c5b6fa", "metadata": {}, "source": [ "### Dataset Location" ] }, { "cell_type": "code", "execution_count": null, "id": "3f833049", "metadata": {}, "outputs": [], "source": [ "onion_folder = \"dataset/Onion_512\"\n", "strawberry_folder = \"dataset/Strawberry_512\"\n", "pear_folder = \"dataset/Pear_512\"\n", "tomato_folder = \"dataset/Tomato_512\"" ] }, { "cell_type": "markdown", "id": "afce8611", "metadata": {}, "source": [ "### loading dataset" ] }, { "cell_type": "code", "execution_count": null, "id": "1e913838", "metadata": {}, "outputs": [], "source": [ "onion_images = load_images_from_folder(onion_folder)\n", "strawberry_images = load_images_from_folder(strawberry_folder)\n", "pear_images = load_images_from_folder(pear_folder)\n", "tomato_images = load_images_from_folder(tomato_folder)\n", "\n", "print(\"onion_images:\", onion_images.shape)\n", "print(\"strawberry_images:\", strawberry_images.shape)\n", "print(\"pear_images:\", pear_images.shape)\n", "print(\"tomato_images:\", tomato_images.shape)\n" ] }, { "cell_type": "markdown", "id": "07f5d12e", "metadata": {}, "source": [ "Each of our classes have got around ~3000 samples" ] }, { "cell_type": "markdown", "id": "80e9ecc3", "metadata": {}, "source": [ "### Visualizing image" ] }, { "cell_type": "code", "execution_count": null, "id": "00149f35", "metadata": {}, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "import random\n", "datasets = {\n", " \"onion\": onion_images,\n", " 
\"strawberry\": strawberry_images,\n", " \"pear\": pear_images,\n", " \"tomato\": tomato_images\n", "}\n", "\n", "\n", "def show_random_samples(images, class_name, count=5):\n", " indices = random.sample(range(images.shape[0]), count)\n", " selected = images[indices]\n", "\n", " plt.figure(figsize=(10, 2))\n", " for i, img in enumerate(selected):\n", " plt.subplot(1, count, i+1)\n", " plt.imshow(img.astype(np.uint8))\n", " plt.axis('off')\n", " plt.suptitle(f\"{class_name.capitalize()} – Random {count} Samples\", fontsize=16)\n", " plt.show()\n", "\n", "for class_name, image_array in datasets.items():\n", " show_random_samples(image_array, class_name)\n" ] }, { "cell_type": "markdown", "id": "ab765929", "metadata": {}, "source": [ "### Getting RGB pixel count per class" ] }, { "cell_type": "code", "execution_count": null, "id": "dcafbe0c", "metadata": {}, "outputs": [], "source": [ "fig, axes = plt.subplots(1, len(datasets), figsize=(20, 5))\n", "\n", "for ax, (class_name, images) in zip(axes, datasets.items()):\n", " plot_rgb_histogram_subplot(ax, images, class_name)\n", " ax.label_outer()\n", "\n", "plt.tight_layout()\n", "plt.show()\n" ] }, { "cell_type": "markdown", "id": "7a565dee", "metadata": {}, "source": [ "## RGB Histogram Analysis: What It Tells Us About the Dataset\n", "\n", "This RGB histogram plot shows the **distribution of pixel intensities** for the **Red, Green, and Blue channels** in one sample image per class (`Onion`, `Strawberry`, `Pear`, `Tomato`). \n", "It’s a **visual summary of color composition** and can reveal important patterns about your dataset.\n", "\n", "---\n", "\n", "### πŸ” General Insights\n", "\n", "#### 1. Class Color Signatures\n", "Each class has a unique RGB distribution:\n", "\n", "- The model can learn to **distinguish classes based on color patterns**.\n", "- **Example:**\n", " - `Tomato`: Strong red peaks.\n", " - `Pear`: Dominant green and blue bands.\n", "\n", "---\n", "\n", "#### 2. Image Quality / Noise\n", "Unusual spikes or flat histograms may indicate:\n", "\n", "- **Overexposed or underexposed images**.\n", "- **Noisy or poor-quality samples** (e.g., background dominates the image).\n", "\n", "---\n", "\n", "#### 3. πŸ“Š Channel Dominance / Balance\n", "Histogram analysis helps decide:\n", "\n", "- Should we **convert to grayscale**? 
\n", " (Useful if R, G, B histograms are nearly identical.)\n", "- As we see in majority of classes the R,G,B variation is distinct(in onion it's almost the same), hence we need RGB channles in input\n", "\n", "---\n", "\n", "### πŸ“ˆ Per-Class Histogram Analysis\n", "\n", "---\n", "\n", "#### πŸ§… Onion\n", "- **Red & Green:** Sharp peaks around 140–150.\n", "- **Blue:** Dominant with a broad peak around 100.\n", "- **Interpretation:**\n", " - Likely represents white/yellow onion layers with subtle shadows.\n", " - Dominant blue may come from lighting or background.\n", "- **Implications:**\n", " - The model may learn to detect **mid-range blue with sharp red-green peaks**.\n", "\n", "---\n", "\n", "#### πŸ“ Strawberry\n", "- **Red:** Strong peaks at ~80 and ~220.\n", "- **Green & Blue:** Broader and less frequent.\n", "- **Interpretation:**\n", " - High red intensity is consistent with strawberry skin.\n", " - Low blue confirms lack of bluish tones.\n", "- **Implications:**\n", " - A **very color-distinct class**.\n", " - The model can learn it easily with minimal augmentation.\n", "\n", "---\n", "\n", "#### 🍐 Pear\n", "- **Green & Blue:** Peaks between 50–120.\n", "- **Red:** Moderate and broad around 100–150.\n", "- **Interpretation:**\n", " - Pear skin includes light green/yellow shades with reflections.\n", " - Background or lighting likely increases blue response.\n", " - All three channels show similar trends, suggesting: Minimal variation in pear color, and Uniform background and illumination conditions\n", "- **Implications:**\n", " - Not much background variation in pear\n", "\n", "---\n", "\n", "#### πŸ… Tomato\n", "- **Red:** Extremely sharp peak at ~120.\n", "- **Green & Blue:** Very low, drop sharply after 100.\n", "- **Interpretation:**\n", " - Strongly saturated red β€” characteristic of ripe tomatoes.\n", "- **Implications:**\n", " - **Highly distinguishable** via color alone.\n", " - Risk of **overfitting to red features** if background is red.\n", "\n", "---\n" ] }, { "cell_type": "markdown", "id": "146e8b61", "metadata": {}, "source": [ "## Average image" ] }, { "cell_type": "code", "execution_count": null, "id": "42fca26e", "metadata": {}, "outputs": [], "source": [ "class_names = list(datasets.keys())\n", "num_classes = len(class_names)\n", "\n", "fig, axes = plt.subplots(1, num_classes, figsize=(4 * num_classes, 4))\n", "\n", "for i, (class_name, images) in enumerate(datasets.items()):\n", " avg_img = np.mean(images.astype(np.float32), axis=0)\n", " axes[i].imshow(avg_img.astype(np.uint8))\n", " axes[i].set_title(f\"Average Image – {class_name.capitalize()}\")\n", " axes[i].axis('off')\n", "\n", "plt.tight_layout()\n", "plt.show()\n" ] }, { "cell_type": "markdown", "id": "5ba3e21f", "metadata": {}, "source": [ "# Dataset Analysis Based on Average Images\n", "\n", "The average images of **Onion**, **Strawberry**, **Pear**, and **Tomato** offer valuable insights into the characteristics of the dataset they were generated from.\n", "\n", "---\n", "\n", "## General Observations\n", "\n", "1. **Blurriness of All Average Images** \n", " - The high level of blur suggests that the objects (fruits/vegetables) vary significantly in position, orientation, and size within the images.\n", " - There is no consistent alignment or cropping β€” objects appear in different parts of the frame across the dataset.\n", "\n", "2. 
**Centered Color Blobs** \n", "   - Each average image displays a dominant color region toward the center:\n", "     - 🧅 Onion: pale pinkish-grey center\n", "     - 🍓 Strawberry: red core\n", "     - 🍐 Pear: yellow-green diffuse center\n", "     - 🍅 Tomato: reddish-orange with surrounding brown-green\n", "   - This suggests that, despite the variation, most objects are somewhat centered in their respective images.\n", "   - For pear and tomato, the shape and color are more distinct and localized in the average image, which suggests that in most of these images the object was centered with little positional variation. In contrast, for onion and strawberry, the increased blurriness and less defined color blobs suggest more positional variation.\n", "\n", "3. **Background Color and Texture** \n", "   - All images share a similar gray-brown background tone.\n", "   - This implies the dataset likely includes a variety of natural or neutral-colored backgrounds (e.g., kitchen settings, markets) rather than standardized white/black backgrounds.\n", "\n", "---\n", "\n", "## Implications for Model Training\n", "\n", "- **Color is a Strong Signal**\n", "  - Dominant colors are preserved in each average image, suggesting that color-based features will play a major role in classification models. Therefore, it is important to retain all three color channels as input features.\n", "\n", "---" ] }, { "cell_type": "code", "execution_count": null, "id": "dec6064b", "metadata": {}, "outputs": [], "source": [ "import torch\n", "import numpy as np\n", "from torch.utils.data import Dataset, DataLoader, TensorDataset\n", "from sklearn.preprocessing import LabelEncoder\n", "from sklearn.model_selection import train_test_split\n", "from torchvision import transforms\n", "\n", "X = np.concatenate([onion_images, strawberry_images, pear_images, tomato_images], axis=0)\n", "y = (\n", "    ['onion'] * len(onion_images) +\n", "    ['strawberry'] * len(strawberry_images) +\n", "    ['pear'] * len(pear_images) +\n", "    ['tomato'] * len(tomato_images)\n", ")\n", "\n", "X = X.astype(np.float32) / 255.0\n", "X = np.transpose(X, (0, 3, 1, 2))  # (N, H, W, C) -> (N, C, H, W)\n", "X_tensor = torch.tensor(X)\n", "\n", "le = LabelEncoder()\n", "y_encoded = le.fit_transform(y)\n", "y_tensor = torch.tensor(y_encoded)\n", "\n", "X_train, X_temp, y_train, y_temp = train_test_split(X_tensor, y_tensor, test_size=0.5, stratify=y_tensor, random_state=42)\n", "X_val, X_test, y_val, y_test = train_test_split(X_temp, y_temp, test_size=0.5, stratify=y_temp, random_state=42)\n" ] }, { "cell_type": "code", "execution_count": null, "id": "f265aea3", "metadata": {}, "outputs": [], "source": [ "batch_size = 32\n", "\n", "train_dataset = TensorDataset(X_train, y_train)\n", "val_dataset = TensorDataset(X_val, y_val)\n", "test_dataset = TensorDataset(X_test, y_test)\n", "\n", "train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)\n", "val_loader = DataLoader(val_dataset, batch_size=batch_size)\n", "test_loader = DataLoader(test_dataset, batch_size=batch_size)" ] }, { "cell_type": "code", "execution_count": null, "id": "36f26386", "metadata": {}, "outputs": [], "source": [ "# drop the extra references; the tensors stay alive inside train_dataset\n", "del X_train, y_train" ] }, { "cell_type": "code", "execution_count": null, "id": "c469bc8d", "metadata": {}, "outputs": [], "source": [ "print(f\"Train Dataset: {len(train_dataset)} samples, {len(train_loader)} batches\")\n", "print(f\"Val Dataset: {len(val_dataset)} samples, {len(val_loader)} batches\")\n", "print(f\"Test Dataset: {len(test_dataset)} samples, {len(test_loader)} batches\")" ] }, { "cell_type": 
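"markdown", "id": "a1b2c3d4", "metadata": {}, "source": [ "### Rotation Augmentation (sketch)\n", "\n", "The `augment_rotations` helper defined in the Functions section is not used in the training run below. The next cell is a minimal sketch of what it produces, demonstrated on a small dummy batch (`dummy_x` / `dummy_y` are illustrative only): each call returns the 90-, 180- and 270-degree rotations of its input, tripling the sample count. Memory permitting, the same call could be applied to the training split before the DataLoaders are built." ] }, { "cell_type": "code", "execution_count": null, "id": "a1b2c3d5", "metadata": {}, "outputs": [], "source": [ "# Minimal sketch of the (otherwise unused) augment_rotations helper defined above.\n", "# The dummy batch is purely illustrative; memory permitting, the same call could be\n", "# applied to X_train / y_train before the DataLoaders are built.\n", "dummy_x = torch.rand(2, 3, 224, 224)  # two fake images in (N, C, H, W) layout\n", "dummy_y = torch.tensor([0, 1])        # matching labels\n", "aug_x, aug_y = augment_rotations(dummy_x, dummy_y)\n", "print(aug_x.shape, aug_y.shape)  # expected: torch.Size([6, 3, 224, 224]) torch.Size([6])" ] }, { "cell_type": 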
"markdown", "id": "50277bca", "metadata": {}, "source": [ "## Model" ] }, { "cell_type": "code", "execution_count": null, "id": "02440bb8", "metadata": {}, "outputs": [], "source": [ "def get_efficientnet_model(num_classes):\n", " model = models.efficientnet_b0(weights=models.EfficientNet_B0_Weights.DEFAULT)\n", " model.classifier[1] = nn.Linear(model.classifier[1].in_features, num_classes)\n", " return model" ] }, { "cell_type": "code", "execution_count": null, "id": "6a516f06", "metadata": {}, "outputs": [], "source": [ "if torch.backends.mps.is_available():\n", " device = torch.device(\"mps\")\n", " print(\"Using MPS (Apple GPU)\")\n", "else:\n", " device = torch.device(\"cpu\")\n", " print(\"MPS not available. Using CPU\")\n", "\n", "model = get_efficientnet_model(num_classes=4).to(device)\n", "optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)\n", "criterion = nn.CrossEntropyLoss()\n" ] }, { "cell_type": "markdown", "id": "eb7d6007", "metadata": {}, "source": [ "## Training" ] }, { "cell_type": "code", "execution_count": null, "id": "245a6709", "metadata": {}, "outputs": [], "source": [ "best_val_acc = 0.0\n", "train_losses = []\n", "val_losses = []\n", "train_accs = []\n", "val_accs = []\n", "epochs_no_improve = 0\n", "early_stop = False\n", "patience = 5\n", "model_name = \"models/best_model_v1.pth\"\n", "\n", "for epoch in range(10):\n", " if early_stop:\n", " print(f\"Early stopping at epoch {epoch}\")\n", " break\n", " model.train()\n", " total_train_loss = 0\n", " train_correct = 0\n", " train_total = 0\n", "\n", " for batch_x, batch_y in train_loader:\n", " batch_x, batch_y = batch_x.to(device), batch_y.to(device)\n", " preds = model(batch_x)\n", " loss = criterion(preds, batch_y)\n", "\n", " optimizer.zero_grad()\n", " loss.backward()\n", " optimizer.step()\n", "\n", " total_train_loss += loss.item()\n", " pred_labels = preds.argmax(dim=1)\n", " train_correct += (pred_labels == batch_y).sum().item()\n", " train_total += batch_y.size(0)\n", "\n", " train_accuracy = train_correct / train_total\n", " avg_train_loss = total_train_loss / len(train_loader)\n", " train_losses.append(avg_train_loss)\n", " train_accs.append(train_accuracy)\n", "\n", " \n", " model.eval()\n", " val_correct = val_total = 0\n", "\n", " with torch.no_grad():\n", " for val_x, val_y in val_loader:\n", " val_x, val_y = val_x.to(device), val_y.to(device)\n", " val_preds = model(val_x).argmax(dim=1)\n", " val_correct += (val_preds == val_y).sum().item()\n", " val_total += val_y.size(0)\n", "\n", " val_accuracy = val_correct / val_total\n", " validation_loss = criterion(model(val_x), val_y).item()\n", "\n", " val_losses.append(validation_loss)\n", " val_accs.append(val_accuracy)\n", "\n", " print(f\"Epoch {epoch+1:02d} | Train Loss: {avg_train_loss:.4f} | \"\n", " f\"Train Acc: {train_accuracy:.4f} | Val Acc: {val_accuracy:.4f}\")\n", " if val_accuracy > best_val_acc:\n", " best_val_acc = val_accuracy\n", " torch.save(model.state_dict(), model_name)\n", " print(f\"New best model saved at epoch {epoch+1} with val acc {val_accuracy:.4f}\")\n", " epochs_no_improve = 0\n", " else:\n", " epochs_no_improve += 1\n", " print(f\"No improvement for {epochs_no_improve} epoch(s)\")\n", "\n", " if epochs_no_improve >= patience:\n", " print(f\"Validation accuracy did not improve for {patience} consecutive epochs. 
Stopping early.\")\n", " early_stop = True\n", "\n" ] }, { "cell_type": "markdown", "id": "abab0422", "metadata": {}, "source": [ "### Loss and Accuracy Plot" ] }, { "cell_type": "code", "execution_count": null, "id": "3bbab1d8", "metadata": {}, "outputs": [], "source": [ "epochs = range(1, len(train_losses) + 1)\n", "\n", "plt.figure(figsize=(12, 5))\n", "\n", "plt.subplot(1, 2, 1)\n", "plt.plot(epochs, train_losses, label='Train Loss', marker='o')\n", "plt.plot(epochs, val_losses, label='Validation Loss', marker='s')\n", "plt.xlabel('Epoch')\n", "plt.ylabel('Loss')\n", "plt.title('Loss per Epoch')\n", "plt.legend()\n", "plt.grid(True)\n", "\n", "plt.subplot(1, 2, 2)\n", "plt.plot(epochs, train_accs, label='Train Accuracy', marker='o')\n", "plt.plot(epochs, val_accs, label='Validation Accuracy', marker='s')\n", "plt.xlabel('Epoch')\n", "plt.ylabel('Accuracy')\n", "plt.title('Accuracy per Epoch')\n", "plt.legend()\n", "plt.grid(True)\n", "\n", "plt.tight_layout()\n", "plt.show()\n" ] }, { "cell_type": "code", "execution_count": null, "id": "930d22bd", "metadata": {}, "outputs": [], "source": [ "model = get_efficientnet_model(num_classes=4).to(device)\n", "model.load_state_dict(torch.load(model_name))\n", "model.eval() \n", "\n", "all_preds = []\n", "all_targets = []\n", "all_images = []\n", "\n", "with torch.no_grad():\n", " for batch_x, batch_y in test_loader:\n", " batch_x = batch_x.to(device)\n", " preds = model(batch_x).argmax(dim=1).cpu()\n", " all_preds.extend(preds.numpy())\n", " all_targets.extend(batch_y.numpy())\n", " all_images.extend(batch_x.cpu())\n", "\n", "test_correct = sum(np.array(all_preds) == np.array(all_targets))\n", "test_total = len(all_targets)\n", "test_accuracy = test_correct / test_total\n", "\n", "print(f\"\\nTest Accuracy: {test_accuracy:.4f}\")\n", "\n", "target_names = le.classes_ \n", "print(\"\\nClassification Report:\\n\")\n", "print(classification_report(all_targets, all_preds, target_names=target_names))\n", "\n", "cm = confusion_matrix(all_targets, all_preds)\n", "\n", "plt.figure(figsize=(6, 5))\n", "sns.heatmap(cm, annot=True, fmt=\"d\", cmap=\"Blues\", xticklabels=target_names, yticklabels=target_names)\n", "plt.xlabel(\"Predicted Label\")\n", "plt.ylabel(\"True Label\")\n", "plt.title(\"Confusion Matrix\")\n", "plt.tight_layout()\n", "plt.show()\n" ] }, { "cell_type": "markdown", "id": "0430479f", "metadata": {}, "source": [ "## Sample FP, FN" ] }, { "cell_type": "code", "execution_count": null, "id": "4823498a", "metadata": {}, "outputs": [], "source": [ "all_preds = np.array(all_preds)\n", "all_targets = np.array(all_targets)\n", "all_images = torch.stack(all_images)\n", "\n", "for class_idx, class_name in enumerate(target_names):\n", " print(f\"\\nShowing False Negatives and False Positives for class: {class_name}\")\n", " fn_indices = np.where((all_targets == class_idx) & (all_preds != class_idx))[0]\n", " fp_indices = np.where((all_preds == class_idx) & (all_targets != class_idx))[0]\n", "\n", " def show_images(indices, title, max_images=5):\n", " num = min(len(indices), max_images)\n", " if num == 0:\n", " print(f\"No {title} samples.\")\n", " return\n", "\n", " plt.figure(figsize=(12, 2))\n", " for i, idx in enumerate(indices[:num]):\n", " img = all_images[idx]\n", " img = img.permute(1, 2, 0).numpy()\n", " plt.subplot(1, num, i + 1)\n", " plt.imshow((img - img.min()) / (img.max() - img.min()))\n", " plt.axis('off')\n", " plt.title(f\"Pred: {target_names[all_preds[idx]]}\\nTrue: {target_names[all_targets[idx]]}\")\n", " 
plt.suptitle(f\"{title} for {class_name}\")\n", " plt.tight_layout()\n", " plt.show()\n", "\n", " show_images(fn_indices, \"False Negatives\")\n", " show_images(fp_indices, \"False Positives\")\n" ] }, { "cell_type": "code", "execution_count": null, "id": "551cec6b", "metadata": {}, "outputs": [], "source": [ "def visualize_channels(model, image_tensor, max_channels=6):\n", " model.eval()\n", " activations = {}\n", "\n", " def get_activation(name):\n", " def hook(model, input, output):\n", " activations[name] = output.detach().cpu()\n", " return hook\n", "\n", " hooks = []\n", " for i in range(len(model.features)):\n", " layer = model.features[i]\n", " hooks.append(layer.register_forward_hook(get_activation(f\"features_{i}\")))\n", "\n", " with torch.no_grad():\n", " _ = model(image_tensor.unsqueeze(0))\n", "\n", " for h in hooks:\n", " h.remove()\n", "\n", " for layer_name, fmap in activations.items():\n", " fmap = fmap.squeeze(0)\n", " num_channels = min(fmap.shape[0], max_channels)\n", "\n", " plt.figure(figsize=(num_channels * 2, 2.5))\n", " for i in range(num_channels):\n", " plt.subplot(1, num_channels, i + 1)\n", " plt.imshow(fmap[i], cmap='viridis')\n", " plt.title(f\"{layer_name} ch{i}\")\n", " plt.axis('off')\n", " plt.tight_layout()\n", " plt.show()\n" ] }, { "cell_type": "code", "execution_count": null, "id": "147d63d5", "metadata": {}, "outputs": [], "source": [ "import torch\n", "import matplotlib.pyplot as plt\n", "\n", "def visualize_channels(model, image_tensor, max_channels=6):\n", " model.eval()\n", " activations = {}\n", "\n", " def get_activation(name):\n", " def hook(model, input, output):\n", " activations[name] = output.detach().cpu()\n", " return hook\n", "\n", " hooks = []\n", " for i in range(len(model.features)):\n", " layer = model.features[i]\n", " hooks.append(layer.register_forward_hook(get_activation(f\"features_{i}\")))\n", "\n", " with torch.no_grad():\n", " _ = model(image_tensor.unsqueeze(0))\n", "\n", " for h in hooks:\n", " h.remove()\n", "\n", " for layer_name, fmap in activations.items():\n", " fmap = fmap.squeeze(0)\n", " channel_scores = fmap.mean(dim=(1, 2)) \n", " topk = torch.topk(channel_scores, k=min(max_channels, fmap.shape[0]))\n", " top_indices = topk.indices\n", " plt.figure(figsize=(max_channels * 2, 2.5))\n", " for idx, ch in enumerate(top_indices):\n", " plt.subplot(1, max_channels, idx + 1)\n", " plt.imshow(fmap[ch], cmap='viridis')\n", " plt.title(f\"{layer_name}\\nch{ch.item()} ({channel_scores[ch]:.2f})\")\n", " plt.axis('off')\n", " plt.tight_layout()\n", " plt.show()\n" ] }, { "cell_type": "code", "execution_count": null, "id": "a6cc824a", "metadata": {}, "outputs": [], "source": [ "model = get_efficientnet_model(num_classes=4)\n", "model.load_state_dict(torch.load(\"models/best_model_v1.pth\"))\n", "model.eval()" ] }, { "cell_type": "markdown", "id": "36044d25", "metadata": {}, "source": [ "### Onion: Visulaize color channel " ] }, { "cell_type": "code", "execution_count": null, "id": "07206168", "metadata": {}, "outputs": [], "source": [ "\n", "img = Image.open(\"dataset/Onion_512/Whole/image_0001.jpg\").convert(\"RGB\")\n", "\n", "transform = transforms.Compose([\n", " transforms.Resize((224, 224)),\n", " transforms.ToTensor()\n", "])\n", "img_tensor = transform(img)\n", "visualize_channels(model, img_tensor, max_channels=16)\n" ] }, { "cell_type": "markdown", "id": "bd24811d", "metadata": {}, "source": [ "### Pear: Visulaize color channel " ] }, { "cell_type": "code", "execution_count": null, "id": "deb3981a", "metadata": {}, 
"outputs": [], "source": [ "img = Image.open(\"dataset/Pear_512/Whole/image_0089.jpg\").convert(\"RGB\")\n", "\n", "transform = transforms.Compose([\n", " transforms.Resize((224, 224)),\n", " transforms.ToTensor()\n", "])\n", "img_tensor = transform(img)\n", "\n", "visualize_channels(model, img_tensor, max_channels=16)\n" ] }, { "cell_type": "markdown", "id": "be9d8f98", "metadata": {}, "source": [ "### Tomato: Visulaize color channel " ] }, { "cell_type": "code", "execution_count": null, "id": "930ebe01", "metadata": {}, "outputs": [], "source": [ "img = Image.open(\"dataset/Tomato_512/Whole/image_0001.jpg\").convert(\"RGB\")\n", "\n", "transform = transforms.Compose([\n", " transforms.Resize((224, 224)),\n", " transforms.ToTensor()\n", "])\n", "img_tensor = transform(img)\n", "\n", "visualize_channels(model, img_tensor, max_channels=16)\n" ] }, { "cell_type": "markdown", "id": "a4769764", "metadata": {}, "source": [ "### Strawberry: Visulaize color channel " ] }, { "cell_type": "code", "execution_count": null, "id": "a45dc523", "metadata": {}, "outputs": [], "source": [ "img = Image.open(\"dataset/Strawberry_512/Whole/image_0388.jpg\").convert(\"RGB\")\n", "transform = transforms.Compose([\n", " transforms.Resize((224, 224)),\n", " transforms.ToTensor()\n", "])\n", "img_tensor = transform(img)\n", "\n", "visualize_channels(model, img_tensor, max_channels=16)\n" ] }, { "cell_type": "code", "execution_count": null, "id": "c695b7b6", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "myenv", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.21" } }, "nbformat": 4, "nbformat_minor": 5 }