In [None]:
import os
import numpy as np
from PIL import Image
import torch
import matplotlib.pyplot as plt
import random
import torch.nn as nn
from PIL import Image
import torch.nn.functional as F
import torchvision.models as models
from sklearn.metrics import classification_report, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
from torchvision import models
from torchvision.models import efficientnet_b0, EfficientNet_B0_Weights

### Functions

In [None]:
def load_images_from_folder(folder_path, image_size=(224, 224)):
    """Recursively load every JPEG under ``folder_path`` as a resized RGB array.

    Args:
        folder_path: Directory to walk; sub-directories are included.
        image_size: ``(width, height)`` handed to ``PIL.Image.resize``.

    Returns:
        ``uint8`` array of shape ``(n_images, height, width, 3)``.
        Directories and files are visited in sorted order so the result is
        reproducible across runs and file systems. When nothing could be
        loaded the array is empty but keeps the trailing image dimensions,
        so it can still be concatenated with the other classes.

    Files that fail to open or decode are reported on stdout and skipped.
    """
    width, height = image_size
    images = []
    for root, dirs, files in os.walk(folder_path):
        dirs.sort()  # in-place sort fixes the order os.walk descends in
        for file in sorted(files):
            if not file.lower().endswith((".jpg", ".jpeg")):
                continue
            # Built before the try so the error message can always name the file.
            img_path = os.path.join(root, file)
            try:
                # The context manager closes the file handle once decoded.
                with Image.open(img_path) as img:
                    resized = img.convert("RGB").resize(image_size)
                images.append(np.array(resized))
            except Exception as e:
                print(f"Failed on {img_path}: {e}")
    if not images:
        return np.empty((0, height, width, 3), dtype=np.uint8)
    return np.array(images)

def plot_rgb_histogram_subplot(ax, images, class_name):
    """Draw the per-channel intensity histogram of one randomly picked image.

    Args:
        ax: Axes-like object to draw on (needs ``plot``, ``set_title``,
            ``set_xlabel`` and ``set_ylabel``).
        images: Sequence of images indexed as ``image[:, :, channel]`` with
            channels 0, 1, 2 plotted in red, green and blue respectively.
        class_name: Label used in the panel title (capitalised).
    """
    picked = images[random.randint(0, len(images) - 1)]
    for channel, colour in enumerate("rgb"):
        # 256 unit-wide bins covering the values 0..255.
        counts, _ = np.histogram(picked[:, :, channel], bins=256, range=(0, 256))
        ax.plot(counts, color=colour)
    ax.set_title(f"RGB Histogram – {class_name.capitalize()}")
    ax.set_xlabel("Pixel Value")
    ax.set_ylabel("Frequency")
 
def augment_rotations(X, y):
    """Return the 90, 180 and 270 degree rotations of a batch plus matching labels.

    Args:
        X: Tensor rotated in the plane of dimensions 2 and 3.
        y: Label tensor; one copy is emitted per rotation.

    Returns:
        ``(X_rotated, y_repeated)`` concatenated along dimension 0. The
        unrotated input is not part of the output.
    """
    quarter_turns = (1, 2, 3)
    rotated = [torch.rot90(X, k=turns, dims=[2, 3]) for turns in quarter_turns]
    labels = [y.clone() for _ in quarter_turns]
    return torch.cat(rotated), torch.cat(labels)


### Dataset Location

In [None]:
# Root directory that holds one sub-folder per class.
DATASET_ROOT = "dataset"

onion_folder = f"{DATASET_ROOT}/Onion_512"
strawberry_folder = f"{DATASET_ROOT}/Strawberry_512"
pear_folder = f"{DATASET_ROOT}/Pear_512"
tomato_folder = f"{DATASET_ROOT}/Tomato_512"

### loading dataset

In [None]:
# Load each class in turn (onion, strawberry, pear, tomato).
onion_images, strawberry_images, pear_images, tomato_images = (
    load_images_from_folder(folder)
    for folder in (onion_folder, strawberry_folder, pear_folder, tomato_folder)
)

# Report the array shape of every class as a quick sanity check.
for label, loaded in (
    ("onion_images", onion_images),
    ("strawberry_images", strawberry_images),
    ("pear_images", pear_images),
    ("tomato_images", tomato_images),
):
    print(f"{label}:", loaded.shape)


Each class has roughly 3,000 samples.

### Visualizing image

In [None]:
import matplotlib.pyplot as plt
import random
# Class name -> image array; insertion order drives the order of every plot below.
datasets = dict(
    onion=onion_images,
    strawberry=strawberry_images,
    pear=pear_images,
    tomato=tomato_images,
)


def show_random_samples(images, class_name, count=5):
    """Display up to ``count`` randomly chosen images of one class in a row.

    Args:
        images: Array indexed along axis 0 by image.
        class_name: Label used in the figure title (capitalised).
        count: Requested number of images. It is capped at the number of
            images available, so a small class no longer makes
            ``random.sample`` raise ``ValueError``.

    Returns:
        None. When there is nothing to show a short message is printed and
        no figure is created.
    """
    available = images.shape[0]
    count = min(count, available)
    if count <= 0:
        print(f"No {class_name} samples to display.")
        return

    indices = random.sample(range(available), count)
    selected = images[indices]

    plt.figure(figsize=(10, 2))
    for i, img in enumerate(selected):
        plt.subplot(1, count, i + 1)
        plt.imshow(img.astype(np.uint8))
        plt.axis('off')
    plt.suptitle(f"{class_name.capitalize()} – Random {count} Samples", fontsize=16)
    plt.show()

# Show one row of random samples for every class in `datasets`.
for class_name in datasets:
    show_random_samples(datasets[class_name], class_name)


### Getting RGB pixel count per class

In [None]:
fig, axes = plt.subplots(1, len(datasets), figsize=(20, 5))

# One panel per class, in the insertion order of `datasets`.
for panel, class_name in zip(axes, datasets):
    plot_rgb_histogram_subplot(panel, datasets[class_name], class_name)
    # Keep axis labels only on the outer edge of the grid.
    panel.label_outer()

plt.tight_layout()
plt.show()


## RGB Histogram Analysis: What It Tells Us About the Dataset

This RGB histogram plot shows the **distribution of pixel intensities** for the **Red, Green, and Blue channels** in one sample image per class (`Onion`, `Strawberry`, `Pear`, `Tomato`). 
It’s a **visual summary of color composition** and can reveal important patterns about your dataset.

---

### 🔍 General Insights

#### 1. Class Color Signatures
Each class has a unique RGB distribution:

- The model can learn to **distinguish classes based on color patterns**.
- **Example:**
 - `Tomato`: Strong red peaks.
 - `Pear`: Dominant green and blue bands.

---

#### 2. Image Quality / Noise
Unusual spikes or flat histograms may indicate:

- **Overexposed or underexposed images**.
- **Noisy or poor-quality samples** (e.g., background dominates the image).

---

#### 3. 📊 Channel Dominance / Balance
Histogram analysis helps decide:

- Should we **convert to grayscale**? 
 (Useful if R, G, B histograms are nearly identical.)
- As we can see, in the majority of classes the R, G, and B distributions are distinct (in onion they are almost the same); hence we need all three RGB channels in the input.

---

### 📈 Per-Class Histogram Analysis

---

#### 🧅 Onion
- **Red & Green:** Sharp peaks around 140–150.
- **Blue:** Dominant with a broad peak around 100.
- **Interpretation:**
 - Likely represents white/yellow onion layers with subtle shadows.
 - Dominant blue may come from lighting or background.
- **Implications:**
 - The model may learn to detect **mid-range blue with sharp red-green peaks**.

---

#### 🍓 Strawberry
- **Red:** Strong peaks at ~80 and ~220.
- **Green & Blue:** Broader and less frequent.
- **Interpretation:**
 - High red intensity is consistent with strawberry skin.
 - Low blue confirms lack of bluish tones.
- **Implications:**
 - A **very color-distinct class**.
 - The model can learn it easily with minimal augmentation.

---

#### 🍐 Pear
- **Green & Blue:** Peaks between 50–120.
- **Red:** Moderate and broad around 100–150.
- **Interpretation:**
 - Pear skin includes light green/yellow shades with reflections.
 - Background or lighting likely increases blue response.
 - All three channels show similar trends, suggesting minimal variation in pear color and uniform background and illumination conditions.
- **Implications:**
 - Not much background variation in pear

---

#### 🍅 Tomato
- **Red:** Extremely sharp peak at ~120.
- **Green & Blue:** Very low, drop sharply after 100.
- **Interpretation:**
 - Strongly saturated red — characteristic of ripe tomatoes.
- **Implications:**
 - **Highly distinguishable** via color alone.
 - Risk of **overfitting to red features** if background is red.

---


## Average image

In [None]:
class_names = list(datasets.keys())
num_classes = len(class_names)

fig, axes = plt.subplots(1, num_classes, figsize=(4 * num_classes, 4))

for i, (class_name, images) in enumerate(datasets.items()):
    # Reduce the stored array directly instead of first materialising a
    # float32 copy of every image in the class (integer inputs are averaged
    # with a floating-point accumulator by np.mean).
    avg_img = np.mean(images, axis=0)
    # Truncate back to uint8 for display, as before.
    axes[i].imshow(avg_img.astype(np.uint8))
    axes[i].set_title(f"Average Image – {class_name.capitalize()}")
    axes[i].axis('off')

plt.tight_layout()
plt.show()


# Dataset Analysis Based on Average Images

The average images of **Onion**, **Strawberry**, **Pear**, and **Tomato** offer valuable insights into the characteristics of the dataset they were generated from.

---

## General Observations

1. **Blurriness of All Average Images** 
 - The high level of blur suggests that the objects (fruits/vegetables) vary significantly in position, orientation, and size within the images.
 - There is no consistent alignment or cropping — objects appear in different parts of the frame across the dataset.

2. **Centered Color Blobs** 
 - Each average image displays a dominant color region toward the center:
 - 🧅 Onion: pale pinkish-grey center
 - 🍓 Strawberry: red core
 - 🍐 Pear: yellow-green diffuse center
 - 🍅 Tomato: reddish-orange with surrounding brown-green
 - This suggests that despite variation, most objects are somewhat centered in their respective images.
 - For pear and tomato, the shape and color are more distinct and localized in the average image. This suggests that in most of these images, the required object was centered with less positional variation. In contrast, for onion and strawberry, the increased blurriness and less defined color blobs suggest more positional variation.

3. **Background Color and Texture** 
 - All images share a similar gray-brown background tone.
 - This implies the dataset likely includes a variety of natural or neutral-colored backgrounds (e.g., kitchen settings, markets) rather than standardized white/black backgrounds.

---

## Implications for Model Training

- **Color is a Strong Signal**
 - Dominant colors are preserved in each average image, suggesting that color-based features will play a major role in classification models. Therefore, it is important to retain all three color channels as input features.

---

In [None]:
import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader, TensorDataset
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
from torchvision import transforms

# Stack all classes into one array; the labels below follow the same class order.
X = np.concatenate([onion_images, strawberry_images, pear_images, tomato_images], axis=0)
y = (
    ['onion'] * len(onion_images) +
    ['strawberry'] * len(strawberry_images) +
    ['pear'] * len(pear_images) +
    ['tomato'] * len(tomato_images)
)

# Scale pixel values by 1/255. Dividing in place yields the same float32
# values as `X.astype(np.float32) / 255.0` without a second full-size temporary.
X = X.astype(np.float32)
X /= 255.0
# Move the channel axis from last to second position (channels-first).
X = np.transpose(X, (0, 3, 1, 2))
X_tensor = torch.tensor(X)

# Map the string labels to integer class ids; `le` is reused later for the names.
le = LabelEncoder()
y_encoded = le.fit_transform(y)
y_tensor = torch.tensor(y_encoded)

# Stratified split: half of the data for training, the remaining half divided
# evenly into validation and test (50 / 25 / 25).
X_train, X_temp, y_train, y_temp = train_test_split(X_tensor, y_tensor, test_size=0.5, stratify=y_tensor, random_state=42)
X_val, X_test, y_val, y_test = train_test_split(X_temp, y_temp, test_size=0.5, stratify=y_temp, random_state=42)


In [None]:
batch_size = 32

# Wrap each split's (features, labels) pair in a TensorDataset.
train_dataset, val_dataset, test_dataset = (
    TensorDataset(features, labels)
    for features, labels in ((X_train, y_train), (X_val, y_val), (X_test, y_test))
)

# Only the training loader shuffles; the evaluation loaders keep a fixed order.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader, test_loader = (
    DataLoader(split, batch_size=batch_size) for split in (val_dataset, test_dataset)
)

In [None]:
# Drop the standalone names for the training split. `train_dataset` still holds
# references to the same tensors, so this only tidies the namespace. Using
# pop() instead of `del` keeps the cell re-runnable (no NameError the second time).
for _name in ("X_train", "y_train"):
    globals().pop(_name, None)

In [None]:
# Print sample and batch counts for every split.
for split_name, split_dataset, split_loader in (
    ("Train", train_dataset, train_loader),
    ("Val", val_dataset, val_loader),
    ("Test", test_dataset, test_loader),
):
    print(f"{split_name} Dataset: {len(split_dataset)} samples, {len(split_loader)} batches")

## Model

In [None]:
def get_efficientnet_model(num_classes, pretrained=True):
    """Build an EfficientNet-B0 whose final linear layer outputs ``num_classes`` logits.

    Args:
        num_classes: Output size of the replacement classification layer.
        pretrained: When True (the default, matching the previous behaviour)
            the network starts from ``EfficientNet_B0_Weights.DEFAULT``.
            Pass False to skip loading those weights, e.g. when a saved
            ``state_dict`` is loaded immediately afterwards.

    Returns:
        The model, not yet moved to any device.
    """
    weights = models.EfficientNet_B0_Weights.DEFAULT if pretrained else None
    model = models.efficientnet_b0(weights=weights)
    # classifier[1] is the final Linear layer; swap it for one sized to our classes.
    head = model.classifier[1]
    model.classifier[1] = nn.Linear(head.in_features, num_classes)
    return model

In [None]:
# Pick the fastest available backend: CUDA, then Apple MPS, then CPU.
# getattr guards against torch builds that do not expose `torch.backends.mps`.
_mps_backend = getattr(torch.backends, "mps", None)
if torch.cuda.is_available():
    device = torch.device("cuda")
    print("Using CUDA GPU")
elif _mps_backend is not None and _mps_backend.is_available():
    device = torch.device("mps")
    print("Using MPS (Apple GPU)")
else:
    device = torch.device("cpu")
    print("MPS not available. Using CPU")

model = get_efficientnet_model(num_classes=4).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()


## Training

In [None]:
best_val_acc = 0.0
train_losses = []
val_losses = []
train_accs = []
val_accs = []
epochs_no_improve = 0
early_stop = False
patience = 5          # epochs without validation-accuracy improvement before stopping
num_epochs = 10       # upper bound on training epochs
model_name = "models/best_model_v1.pth"

# torch.save does not create missing directories, so make sure the target exists.
os.makedirs(os.path.dirname(model_name) or ".", exist_ok=True)

for epoch in range(num_epochs):
    if early_stop:
        print(f"Early stopping at epoch {epoch}")
        break

    # ---- training pass ----
    model.train()
    total_train_loss = 0
    train_correct = 0
    train_total = 0

    for batch_x, batch_y in train_loader:
        batch_x, batch_y = batch_x.to(device), batch_y.to(device)
        preds = model(batch_x)
        loss = criterion(preds, batch_y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_train_loss += loss.item()
        pred_labels = preds.argmax(dim=1)
        train_correct += (pred_labels == batch_y).sum().item()
        train_total += batch_y.size(0)

    train_accuracy = train_correct / train_total
    avg_train_loss = total_train_loss / len(train_loader)  # mean of per-batch losses
    train_losses.append(avg_train_loss)
    train_accs.append(train_accuracy)

    # ---- validation pass ----
    model.eval()
    val_correct = val_total = 0
    total_val_loss = 0.0

    with torch.no_grad():
        for val_x, val_y in val_loader:
            val_x, val_y = val_x.to(device), val_y.to(device)
            val_logits = model(val_x)
            # Accumulate the loss over every batch (previously only the last
            # batch was scored). The per-batch loss is weighted by batch size
            # so the epoch value is a per-sample average.
            total_val_loss += criterion(val_logits, val_y).item() * val_y.size(0)
            val_preds = val_logits.argmax(dim=1)
            val_correct += (val_preds == val_y).sum().item()
            val_total += val_y.size(0)

    val_accuracy = val_correct / val_total
    validation_loss = total_val_loss / val_total

    val_losses.append(validation_loss)
    val_accs.append(val_accuracy)

    print(f"Epoch {epoch+1:02d} | Train Loss: {avg_train_loss:.4f} | "
          f"Train Acc: {train_accuracy:.4f} | Val Acc: {val_accuracy:.4f}")

    # ---- checkpointing and early stopping on validation accuracy ----
    if val_accuracy > best_val_acc:
        best_val_acc = val_accuracy
        torch.save(model.state_dict(), model_name)
        print(f"New best model saved at epoch {epoch+1} with val acc {val_accuracy:.4f}")
        epochs_no_improve = 0
    else:
        epochs_no_improve += 1
        print(f"No improvement for {epochs_no_improve} epoch(s)")

    if epochs_no_improve >= patience:
        print(f"Validation accuracy did not improve for {patience} consecutive epochs. Stopping early.")
        # The flag is checked at the top of the next iteration.
        early_stop = True



### Loss and Accuracy Plot

In [None]:
epochs = range(1, len(train_losses) + 1)

# (subplot position, metric name, training series, validation series)
panels = (
    (1, 'Loss', train_losses, val_losses),
    (2, 'Accuracy', train_accs, val_accs),
)

plt.figure(figsize=(12, 5))

# Both panels share the same layout; only the metric and its two series differ.
for position, metric, train_series, val_series in panels:
    plt.subplot(1, 2, position)
    plt.plot(epochs, train_series, label=f'Train {metric}', marker='o')
    plt.plot(epochs, val_series, label=f'Validation {metric}', marker='s')
    plt.xlabel('Epoch')
    plt.ylabel(metric)
    plt.title(f'{metric} per Epoch')
    plt.legend()
    plt.grid(True)

plt.tight_layout()
plt.show()


In [None]:
# Rebuild the architecture and restore the best checkpoint written during training.
model = get_efficientnet_model(num_classes=4).to(device)
# map_location makes loading independent of the device the checkpoint was saved from.
model.load_state_dict(torch.load(model_name, map_location=device))
model.eval()

all_preds = []
all_targets = []
all_images = []  # kept on the CPU for the misclassification plots further down

with torch.no_grad():
    for batch_x, batch_y in test_loader:
        batch_x = batch_x.to(device)
        preds = model(batch_x).argmax(dim=1).cpu()
        all_preds.extend(preds.numpy())
        all_targets.extend(batch_y.numpy())
        all_images.extend(batch_x.cpu())

# Vectorised count instead of Python's sum() over a NumPy boolean array.
test_total = len(all_targets)
test_correct = int(np.sum(np.array(all_preds) == np.array(all_targets)))
test_accuracy = test_correct / test_total

print(f"\nTest Accuracy: {test_accuracy:.4f}")

target_names = le.classes_
# Pin the label set so the report and matrix always have one row per class,
# even if a class happens to be absent from the predictions or targets.
label_ids = list(range(len(target_names)))
print("\nClassification Report:\n")
print(classification_report(all_targets, all_preds, labels=label_ids, target_names=target_names))

cm = confusion_matrix(all_targets, all_preds, labels=label_ids)

plt.figure(figsize=(6, 5))
sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", xticklabels=target_names, yticklabels=target_names)
plt.xlabel("Predicted Label")
plt.ylabel("True Label")
plt.title("Confusion Matrix")
plt.tight_layout()
plt.show()


## Sample FP, FN

In [None]:
all_preds = np.array(all_preds)
all_targets = np.array(all_targets)
# Stack only once so re-running this cell does not call torch.stack on a tensor.
if not torch.is_tensor(all_images):
    all_images = torch.stack(all_images)


def show_images(indices, title, max_images=5, class_label=None):
    """Plot up to ``max_images`` test images with predicted and true labels.

    Args:
        indices: Positions into ``all_images`` / ``all_preds`` / ``all_targets``.
        title: Group name used in the figure title and the "none found" message.
        max_images: Maximum number of images drawn in the row.
        class_label: Class name shown in the figure title. Defaults to the
            module-level ``class_name`` so existing three-argument calls
            behave as before.
    """
    if class_label is None:
        class_label = class_name
    num = min(len(indices), max_images)
    if num == 0:
        print(f"No {title} samples.")
        return

    plt.figure(figsize=(12, 2))
    for i, idx in enumerate(indices[:num]):
        # Stored channels-first; imshow wants channels last.
        img = all_images[idx].permute(1, 2, 0).numpy()
        value_range = img.max() - img.min()
        # A constant image has zero range; show it as-is instead of dividing by zero.
        shown = (img - img.min()) / value_range if value_range > 0 else img
        plt.subplot(1, num, i + 1)
        plt.imshow(shown)
        plt.axis('off')
        plt.title(f"Pred: {target_names[all_preds[idx]]}\nTrue: {target_names[all_targets[idx]]}")
    plt.suptitle(f"{title} for {class_label}")
    plt.tight_layout()
    plt.show()


for class_idx, class_name in enumerate(target_names):
    print(f"\nShowing False Negatives and False Positives for class: {class_name}")
    # False negative: truly this class but predicted as another.
    fn_indices = np.where((all_targets == class_idx) & (all_preds != class_idx))[0]
    # False positive: predicted as this class but truly another.
    fp_indices = np.where((all_preds == class_idx) & (all_targets != class_idx))[0]

    show_images(fn_indices, "False Negatives", class_label=class_name)
    show_images(fp_indices, "False Positives", class_label=class_name)


In [None]:
def visualize_channels(model, image_tensor, max_channels=6):
    """Plot the first ``max_channels`` feature maps from each stage of ``model.features``.

    NOTE: a later cell defines another ``visualize_channels`` (which picks the
    channels with the highest mean activation); once that cell has run it
    replaces this definition.

    Args:
        model: Module exposing an indexable ``features`` container.
        image_tensor: Single image without a batch dimension; one is added
            before the forward pass.
        max_channels: Maximum number of channels drawn per stage.
    """
    model.eval()
    activations = {}

    def get_activation(name):
        def hook(module, inputs, output):
            activations[name] = output.detach().cpu()
        return hook

    hooks = [
        layer.register_forward_hook(get_activation(f"features_{i}"))
        for i, layer in enumerate(model.features)
    ]

    # Run the image on whatever device the model lives on.
    first_param = next(model.parameters(), None)
    batch = image_tensor.unsqueeze(0)
    if first_param is not None:
        batch = batch.to(first_param.device)

    try:
        with torch.no_grad():
            _ = model(batch)
    finally:
        # Always detach the hooks, even if the forward pass raised.
        for h in hooks:
            h.remove()

    for layer_name, fmap in activations.items():
        fmap = fmap.squeeze(0)
        num_channels = min(fmap.shape[0], max_channels)

        plt.figure(figsize=(num_channels * 2, 2.5))
        for i in range(num_channels):
            plt.subplot(1, num_channels, i + 1)
            plt.imshow(fmap[i], cmap='viridis')
            plt.title(f"{layer_name} ch{i}")
            plt.axis('off')
        plt.tight_layout()
        plt.show()


In [None]:
import torch
import matplotlib.pyplot as plt

def visualize_channels(model, image_tensor, max_channels=6):
    """Plot, for each stage of ``model.features``, the channels with the highest mean activation.

    Args:
        model: Module exposing an indexable ``features`` container.
        image_tensor: Single image without a batch dimension; one is added
            before the forward pass.
        max_channels: Maximum number of channels drawn per stage. Stages with
            fewer channels get a correspondingly narrower figure.
    """
    model.eval()
    activations = {}

    def get_activation(name):
        def hook(module, inputs, output):
            activations[name] = output.detach().cpu()
        return hook

    hooks = [
        layer.register_forward_hook(get_activation(f"features_{i}"))
        for i, layer in enumerate(model.features)
    ]

    # Run the image on whatever device the model lives on.
    first_param = next(model.parameters(), None)
    batch = image_tensor.unsqueeze(0)
    if first_param is not None:
        batch = batch.to(first_param.device)

    try:
        with torch.no_grad():
            _ = model(batch)
    finally:
        # Always detach the hooks, even if the forward pass raised.
        for h in hooks:
            h.remove()

    for layer_name, fmap in activations.items():
        fmap = fmap.squeeze(0)
        # Rank channels by their spatial mean and keep the strongest ones.
        channel_scores = fmap.mean(dim=(1, 2))
        k = min(max_channels, fmap.shape[0])
        top_indices = torch.topk(channel_scores, k=k).indices

        # Size the grid by the number of channels actually drawn.
        plt.figure(figsize=(k * 2, 2.5))
        for idx, ch in enumerate(top_indices):
            plt.subplot(1, k, idx + 1)
            plt.imshow(fmap[ch], cmap='viridis')
            plt.title(f"{layer_name}\nch{ch.item()} ({channel_scores[ch].item():.2f})")
            plt.axis('off')
        plt.tight_layout()
        plt.show()


In [None]:
model = get_efficientnet_model(num_classes=4)
# The model stays on the CPU here, so map the checkpoint tensors to the CPU as
# well; otherwise loading depends on the device the checkpoint was saved from.
model.load_state_dict(torch.load("models/best_model_v1.pth", map_location="cpu"))
# The trailing semicolon keeps the long module repr out of the cell output.
model.eval();

### Onion: Visualize feature-map channels

In [None]:

# Resize to 224x224 and convert to a tensor before the forward pass.
transform = transforms.Compose(
    [transforms.Resize((224, 224)), transforms.ToTensor()]
)

img = Image.open("dataset/Onion_512/Whole/image_0001.jpg").convert("RGB")
img_tensor = transform(img)

visualize_channels(model, img_tensor, max_channels=16)


### Pear: Visualize feature-map channels

In [None]:
# Resize to 224x224 and convert to a tensor before the forward pass.
transform = transforms.Compose(
    [transforms.Resize((224, 224)), transforms.ToTensor()]
)

img = Image.open("dataset/Pear_512/Whole/image_0089.jpg").convert("RGB")
img_tensor = transform(img)

visualize_channels(model, img_tensor, max_channels=16)


### Tomato: Visualize feature-map channels

In [None]:
# Resize to 224x224 and convert to a tensor before the forward pass.
transform = transforms.Compose(
    [transforms.Resize((224, 224)), transforms.ToTensor()]
)

img = Image.open("dataset/Tomato_512/Whole/image_0001.jpg").convert("RGB")
img_tensor = transform(img)

visualize_channels(model, img_tensor, max_channels=16)


### Strawberry: Visualize feature-map channels

In [None]:
# Resize to 224x224 and convert to a tensor before the forward pass.
transform = transforms.Compose(
    [transforms.Resize((224, 224)), transforms.ToTensor()]
)

img = Image.open("dataset/Strawberry_512/Whole/image_0388.jpg").convert("RGB")
img_tensor = transform(img)

visualize_channels(model, img_tensor, max_channels=16)
