# utils.py -- commit 171a6dc ("Update utils.py") by theodore-ioann, marked verified.
import cv2
import torch
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from supervised import UNet, Segformer, Inception
from sklearn.cluster import KMeans
from sklearn.mixture import GaussianMixture
from torchvision import transforms
from sklearn.metrics import accuracy_score, jaccard_score, f1_score, confusion_matrix, ConfusionMatrixDisplay
def postprocess(masks, mode="open", kernel_size=5, iters=1):
    """Apply one OpenCV morphological operation to every mask in ``masks``.

    Args:
        masks: iterable of 2-D mask arrays; each one is cast to ``uint8``
            before being handed to OpenCV.
        mode: ``"open"``, ``"close"``, ``"erosion"`` or ``"dilation"``.
            Any other value (e.g. ``"none"``) disables post-processing.
        kernel_size: side length of the square, all-ones structuring element.
        iters: ``iterations`` argument forwarded to the OpenCV call.

    Returns:
        A new list of processed ``uint8`` masks for a recognised mode;
        otherwise ``masks`` itself (the same object, not a copy).
    """
    structuring = np.ones((kernel_size, kernel_size), np.uint8)

    # Lambdas keep every cv2 attribute lookup lazy: nothing in cv2 is touched
    # until a mode has been selected and there is at least one mask.
    operations = {
        "open": lambda m: cv2.morphologyEx(m, cv2.MORPH_OPEN, structuring, iterations=iters),
        "close": lambda m: cv2.morphologyEx(m, cv2.MORPH_CLOSE, structuring, iterations=iters),
        "erosion": lambda m: cv2.erode(m, structuring, iterations=iters),
        "dilation": lambda m: cv2.dilate(m, structuring, iterations=iters),
    }
    # Compare with == (not a dict lookup) so any ``mode`` value is accepted.
    selected = next((op for name, op in operations.items() if mode == name), None)
    if selected is None:
        return masks
    return [selected(m.astype(np.uint8)) for m in masks]
def overlay_mask(image, mask, color=(255, 0, 0), alpha=0.5):
    """
    Blend a solid colour into ``image`` wherever ``mask`` is set.

    - image: (H, W, 3) numpy array, RGB
    - mask: (H, W) numpy array holding 0/1 values or 0/255 values
    - color: RGB tuple painted over masked pixels
    - alpha: colour weight in the blend (0 = image only, 1 = colour only)

    Returns a new (H, W, 3) uint8 array; ``image`` is not modified. Blended
    values are truncated (not rounded) by the final uint8 cast.
    """
    # Anything with values above 1 is treated as a 0/255 mask and thresholded.
    binary = (mask > 127).astype(np.uint8) if mask.max() > 1 else mask

    # Solid colour plane with the same shape and dtype as the image.
    tint = np.zeros_like(image)
    for channel in range(3):
        tint[:, :, channel] = color[channel]

    selector = np.repeat(binary[:, :, np.newaxis], 3, axis=2)
    blended = (1 - alpha) * image + alpha * tint
    return np.where(selector, blended, image).astype(np.uint8)
def _load_rgb_image(image_path):
    """Return ``image_path`` as an RGB PIL image.

    Accepts a filesystem path (``str`` / ``os.PathLike``), an existing PIL
    image, or anything ``Image.fromarray`` understands (e.g. a numpy image
    array -- the only kind of input the original implementation accepted).
    """
    import os

    if isinstance(image_path, (str, os.PathLike)):
        # convert() forces the pixel data to load, so the file can be closed here.
        with Image.open(image_path) as img:
            return img.convert('RGB')
    if isinstance(image_path, Image.Image):
        return image_path.convert('RGB')
    return Image.fromarray(image_path).convert('RGB')


def _predict_label_mask(model, image, original_np, input_size, device):
    """Return an ``(input_size, input_size)`` integer label map for one image.

    Supervised networks (``UNet`` / ``Segformer`` / ``Inception``) are run on a
    resized tensor and the per-pixel argmax over the class axis is returned.
    ``KMeans`` / ``GaussianMixture`` estimators are re-fitted on the RGB pixels
    of ``original_np`` and every pixel gets its cluster index.

    Raises:
        TypeError: if ``model`` is none of the supported types.
    """
    if isinstance(model, (UNet, Segformer, Inception)):
        transform = transforms.Compose([
            transforms.Resize((input_size, input_size)),
            transforms.ToTensor(),
        ])
        input_tensor = transform(image).unsqueeze(0).to(device)

        # Predict in eval mode (dropout off, BatchNorm uses and does not update
        # its running stats), then restore whatever mode the caller had set.
        is_module = isinstance(model, torch.nn.Module)
        was_training = model.training if is_module else False
        if is_module:
            model.eval()
        try:
            with torch.no_grad():
                output = model(input_tensor)
        finally:
            if is_module:
                model.train(was_training)

        if isinstance(output, dict):
            # `a or b` would evaluate a multi-element tensor's truth value and
            # raise, so fall back to "out" only when "logits" is really absent.
            logits = output.get("logits")
            output = logits if logits is not None else output.get("out")

        if output.dim() == 3:  # unbatched (C, H, W) -> (1, C, H, W)
            output = output.unsqueeze(0)

        # Some heads (e.g. Hugging Face Segformer) emit logits at a lower
        # resolution than the input; upsample so the mask lines up with the image.
        if tuple(output.shape[-2:]) != (input_size, input_size):
            output = torch.nn.functional.interpolate(
                output, size=(input_size, input_size), mode="bilinear", align_corners=False
            )

        # NOTE(review): a single-channel (binary-logit) output yields an
        # all-zero argmax; confirm no model here returns one channel.
        return torch.argmax(output[0], dim=0).cpu().numpy()

    if isinstance(model, (KMeans, GaussianMixture)):
        pixels = original_np.reshape(-1, 3)
        # Cluster indices are arbitrary: "foreground" may end up as either label.
        model.fit(pixels)
        return model.predict(pixels).reshape(input_size, input_size)

    raise TypeError(
        "model must be a UNet, Segformer, Inception, KMeans or GaussianMixture "
        f"instance, got {type(model).__name__}"
    )


def predict_and_visualize_single(model, image_path, postprocess_mode='none', alpha=0.5, device='cpu',
                                 *, input_size=128, output_size=256):
    """Segment one image and return a black/white mask plus a red overlay.

    Args:
        model: a ``UNet`` / ``Segformer`` / ``Inception`` network, or a
            ``KMeans`` / ``GaussianMixture`` estimator. Estimators are
            re-fitted on this image's pixels, so they are mutated.
        image_path: despite the name, historically an image array passed to
            ``Image.fromarray``; a file path or a PIL image is also accepted.
        postprocess_mode: mode forwarded to ``postprocess``; ``'none'`` skips
            morphological post-processing.
        alpha: overlay opacity forwarded to ``overlay_mask``.
        device: torch device the input tensor is moved to (presumably the
            caller has already placed the model there).
        input_size: side length the image is resized to before prediction.
        output_size: side length of both returned images.

    Returns:
        ``(bw_mask, overlay)``. ``bw_mask`` is a 2-D uint8 image equal to
        ``label * 255`` (nearest-neighbour resized). ``overlay`` is the resized
        RGB image with labelled pixels tinted red (bilinear resized).
        NOTE: labels above 1 wrap in the uint8 ``* 255`` and are hidden by
        ``overlay_mask``'s ``> 127`` threshold, so a binary label map is expected.

    Raises:
        TypeError: if ``model`` is not one of the supported types.
    """
    image = _load_rgb_image(image_path)
    original_np = np.array(image.resize((input_size, input_size)))

    pred_mask = _predict_label_mask(model, image, original_np, input_size, device)
    if postprocess_mode != 'none':
        pred_mask = postprocess([pred_mask], mode=postprocess_mode)[0]

    # Build the overlay once at model resolution, then upscale both outputs.
    overlay = overlay_mask(original_np, pred_mask, color=(255, 0, 0), alpha=alpha)
    out_hw = (output_size, output_size)
    bw_mask = cv2.resize(pred_mask.astype(np.uint8) * 255, out_hw, interpolation=cv2.INTER_NEAREST)
    overlay = cv2.resize(overlay, out_hw, interpolation=cv2.INTER_LINEAR)
    return bw_mask, overlay