SemanticSegmentationModel/semantic-segmentation/SemanticModel/.ipynb_checkpoints/visualization-checkpoint.py
import cv2 | |
import numpy as np | |
import matplotlib.pyplot as plt | |
import torch | |
def plot_predictions(model, images, masks, device, num_samples=4):
    """Visualize model predictions against ground truth.

    Each row of the figure shows one sample: the input image, the ground
    truth ``masks[idx].argmax(dim=0)`` and the prediction
    ``predictions[idx].argmax(dim=0)``. Both label maps are drawn with the
    'tab20' colormap.

    Args:
        model: model exposing ``eval()`` and ``predict(x)``. It is switched to
            eval mode and left in that mode.
        images: batch of channel-first images. Each sample is permuted with
            ``(1, 2, 0)`` for display.
        masks: batch of per-class masks whose class axis (dim 0 of each
            sample) is arg-maxed. Presumably one-hot ``(N, C, H, W)`` — TODO
            confirm.
        device: device the displayed images are moved to before prediction.
        num_samples: number of rows to draw. It is clamped to the number of
            samples actually available in ``images`` / ``masks``.

    Returns:
        The matplotlib Figure.

    Raises:
        ValueError: if there is nothing to draw (``num_samples < 1`` or an
            empty batch).
    """
    n_rows = min(num_samples, len(images), len(masks))
    if n_rows < 1:
        raise ValueError('plot_predictions needs at least one sample to plot')

    with torch.no_grad():
        model.eval()
        # Only the rows that are displayed need a forward pass.
        predictions = model.predict(images[:n_rows].to(device))

    # squeeze=False keeps `axes` 2-D even when n_rows == 1, so the
    # axes[idx, col] indexing below also works for a single sample.
    fig, axes = plt.subplots(n_rows, 3, figsize=(12, 4 * n_rows), squeeze=False)
    for idx in range(n_rows):
        # Original image
        img = images[idx].permute(1, 2, 0).cpu().numpy()
        axes[idx, 0].imshow(img)
        axes[idx, 0].set_title('Original Image')

        # Ground truth
        truth = masks[idx].argmax(dim=0).cpu().numpy()
        axes[idx, 1].imshow(truth, cmap='tab20')
        axes[idx, 1].set_title('Ground Truth')

        # Prediction
        pred = predictions[idx].argmax(dim=0).cpu().numpy()
        axes[idx, 2].imshow(pred, cmap='tab20')
        axes[idx, 2].set_title('Prediction')

        for ax in axes[idx]:
            ax.axis('off')

    plt.tight_layout()
    return fig
def create_overlay_mask(image, mask, alpha=0.5, color_map=None):
    """Blend a colour-coded segmentation mask onto *image*.

    Args:
        image: array used as the blend base. The coloured layer copies its
            shape and dtype via ``np.zeros_like``, and colour triples are
            written per pixel, so it is presumably ``(H, W, 3)`` — TODO
            confirm.
        mask: label array compared element-wise against every key of
            *color_map*.
        alpha: weight of the coloured layer. The image is weighted by
            ``1 - alpha``.
        color_map: dict mapping label -> colour triple. The default covers
            labels 0-3 (black, red, green, blue). Labels absent from the map
            stay zero in the coloured layer.

    Returns:
        A new array holding ``alpha * coloured + (1 - alpha) * image``, as
        computed by ``cv2.addWeighted``. *image* itself is not modified.
        Because label 0 maps to black by default, background pixels come out
        darkened by the ``1 - alpha`` factor.
    """
    palette = color_map if color_map is not None else {
        0: [0, 0, 0],        # background
        1: [255, 0, 0],      # class 1 (red)
        2: [0, 255, 0],      # class 2 (green)
        3: [0, 0, 255],      # class 3 (blue)
    }

    painted = np.zeros_like(image)
    for class_id in palette:
        painted[mask == class_id] = palette[class_id]

    # Blend in place into a private copy of the image (dst == src2).
    blended = image.copy()
    cv2.addWeighted(painted, alpha, blended, 1 - alpha, 0, blended)
    return blended
def plot_training_history(history):
    """Draw loss curves (left) and IoU curves (right) in one figure.

    Args:
        history: mapping with sequences under ``'train_loss'``,
            ``'val_loss'`` and ``'mean_iou'``, plus ``'class_ious'``, a mapping
            of class name -> sequence of IoU values. Each sequence is plotted
            against its index, and the x axis is labelled 'Epoch'. A missing
            key raises KeyError.

    Returns:
        The matplotlib Figure.
    """
    fig, axes = plt.subplots(1, 2, figsize=(12, 4))
    loss_ax, iou_ax = axes

    # Left panel: training vs validation loss.
    loss_curves = (
        ('train_loss', 'Training Loss'),
        ('val_loss', 'Validation Loss'),
    )
    for history_key, curve_label in loss_curves:
        loss_ax.plot(history[history_key], label=curve_label)
    loss_ax.set_xlabel('Epoch')
    loss_ax.set_ylabel('Loss')
    loss_ax.set_title('Training and Validation Loss')
    loss_ax.legend()

    # Right panel: mean IoU followed by one curve per class.
    iou_ax.plot(history['mean_iou'], label='Mean IoU')
    per_class = history['class_ious']
    for name, values in per_class.items():
        iou_ax.plot(values, label=f'{name} IoU')
    iou_ax.set_xlabel('Epoch')
    iou_ax.set_ylabel('IoU')
    iou_ax.set_title('IoU Metrics')
    iou_ax.legend()

    plt.tight_layout()
    return fig
def visualize_predictions_on_batch(model, batch_images, batch_size=8):
    """Create grid visualization for a batch of predictions.

    Up to *batch_size* images are drawn, four per row. Each shows its
    arg-maxed predicted label map overlaid via ``create_overlay_mask``.

    Args:
        model: object exposing ``predict(batch_images)``. Each per-sample
            output is arg-maxed over dim 0, so presumably ``(N, C, H, W)``
            class scores — TODO confirm. The images are passed as-is, so
            they must already be on the model's device.
        batch_images: batch of channel-first images. Each sample is permuted
            with ``(1, 2, 0)`` for display.
        batch_size: maximum number of images to draw.

    Returns:
        The matplotlib Figure.
    """
    with torch.no_grad():
        predictions = model.predict(batch_images)

    n_show = min(batch_size, len(batch_images))
    ncols = 4
    # Grow the grid downwards instead of overflowing a fixed 2x4 layout
    # (plt.subplot raises for an index > nrows * ncols). At least two rows
    # are kept, so the layout for <= 8 images is unchanged.
    nrows = max(2, (n_show + ncols - 1) // ncols)
    fig = plt.figure(figsize=(15, 2.5 * nrows))
    for idx in range(n_show):
        plt.subplot(nrows, ncols, idx + 1)
        img = batch_images[idx].permute(1, 2, 0).cpu().numpy()
        mask = predictions[idx].argmax(dim=0).cpu().numpy()
        # NOTE(review): create_overlay_mask's default colours are 0-255. If
        # these images are floats in [0, 1], the blended colour channels
        # exceed 1 and imshow clips them — confirm the expected image range.
        overlay = create_overlay_mask(img, mask)
        plt.imshow(overlay)
        plt.axis('off')

    plt.tight_layout()
    return fig
def save_visualization(fig, save_path):
    """Write *fig* to *save_path*, then close it.

    The image is written at 300 dpi with a tight bounding box. Afterwards the
    figure is closed with ``plt.close(fig)``.
    """
    savefig_options = {'bbox_inches': 'tight', 'dpi': 300}
    fig.savefig(save_path, **savefig_options)
    plt.close(fig)
def _pascal_voc_color(index):
    """Return the PASCAL VOC colour-map entry for *index* as ``[r, g, b]``.

    The low bits of *index* are spread across the channels (bit 0 -> red,
    bit 1 -> green, bit 2 -> blue, then the next three bits one level lower,
    and so on). Every index below 2**24 therefore maps to a different colour.
    """
    red = green = blue = 0
    for shift in range(7, -1, -1):
        red |= (index & 1) << shift
        green |= ((index >> 1) & 1) << shift
        blue |= ((index >> 2) & 1) << shift
        index >>= 3
    return [red, green, blue]


def generate_color_mapping(num_classes):
    """Generate distinct colors for segmentation classes.

    The first ten colours form a fixed palette, starting with black for the
    background. When more classes are requested, further colours are drawn
    from the PASCAL VOC colour-map sequence, skipping any colour already in
    the list. The result therefore always has exactly ``num_classes`` unique
    entries; previously it silently stopped at ten.

    Args:
        num_classes: number of colours wanted. 0 gives an empty list.

    Returns:
        A new list of ``[r, g, b]`` lists with 0-255 ints. Index ``k`` is the
        colour for class ``k``.

    Raises:
        ValueError: if ``num_classes`` is negative or exceeds 2**24, the
            number of distinct 24-bit colours.
    """
    if num_classes < 0:
        raise ValueError(f'num_classes must be non-negative, got {num_classes}')
    if num_classes > (1 << 24):
        raise ValueError(
            f'cannot produce {num_classes} distinct 24-bit colours')

    colors = [
        [0, 0, 0],        # Background (black)
        [255, 0, 0],      # Red
        [0, 255, 0],      # Green
        [0, 0, 255],      # Blue
        [255, 255, 0],    # Yellow
        [255, 0, 255],    # Magenta
        [0, 255, 255],    # Cyan
        [128, 0, 0],      # Dark Red
        [0, 128, 0],      # Dark Green
        [0, 0, 128],      # Dark Blue
    ]
    if num_classes <= len(colors):
        return colors[:num_classes]

    # Extend deterministically, skipping VOC colours that duplicate the
    # hand-picked palette, so class ids >= 10 still get a unique colour.
    used = {tuple(color) for color in colors}
    voc_index = 1
    while len(colors) < num_classes:
        candidate = _pascal_voc_color(voc_index)
        voc_index += 1
        key = tuple(candidate)
        if key not in used:
            used.add(key)
            colors.append(candidate)
    return colors