|
import cv2 |
|
import torch |
|
import numpy as np |
|
from PIL import Image |
|
import matplotlib.pyplot as plt |
|
from supervised import UNet, Segformer, Inception |
|
from sklearn.cluster import KMeans |
|
from sklearn.mixture import GaussianMixture |
|
from torchvision import transforms |
|
from sklearn.metrics import accuracy_score, jaccard_score, f1_score, confusion_matrix, ConfusionMatrixDisplay |
|
|
|
def postprocess(masks, mode="open", kernel_size=5, iters=1):
    """Apply one morphological operation to every mask in ``masks``.

    Args:
        masks: Iterable of numpy arrays. Each one is cast to ``uint8``
            before it is handed to OpenCV.
        mode: ``"open"``, ``"close"``, ``"erosion"`` or ``"dilation"``.
            Any other value (for example ``"none"``) is a pass-through.
        kernel_size: Side length of the square, all-ones structuring element.
        iters: Number of times the operation is applied.

    Returns:
        A new list with one processed ``uint8`` mask per input mask, or the
        ``masks`` argument itself, untouched, when ``mode`` is not one of
        the four supported names.
    """
    kernel = np.ones((kernel_size, kernel_size), np.uint8)

    # One entry per supported mode. The lambdas close over ``kernel`` and
    # ``iters`` so the per-mask loop below is written exactly once.
    operations = {
        "open": lambda m: cv2.morphologyEx(m, cv2.MORPH_OPEN, kernel, iterations=iters),
        "close": lambda m: cv2.morphologyEx(m, cv2.MORPH_CLOSE, kernel, iterations=iters),
        "erosion": lambda m: cv2.erode(m, kernel, iterations=iters),
        "dilation": lambda m: cv2.dilate(m, kernel, iterations=iters),
    }

    operation = operations.get(mode)
    if operation is None:
        # Unknown mode: deliberately hand the input back unchanged.
        return masks

    return [operation(mask.astype(np.uint8)) for mask in masks]
|
|
|
def overlay_mask(image, mask, color=(255, 0, 0), alpha=0.5):
    """Blend a solid color into the pixels of ``image`` selected by ``mask``.

    Args:
        image: (H, W, 3) numpy array, RGB. It is not modified.
        mask: (H, W) numpy array holding 0/1 values or 0/255 values.
        color: RGB tuple used to tint the masked pixels.
        alpha: Blend weight of ``color`` (0 = transparent, 1 = opaque).

    Returns:
        A new (H, W, 3) ``uint8`` array. Pixels where ``mask`` is non-zero
        hold ``(1 - alpha) * image + alpha * color``; all other pixels are
        copied from ``image``.
    """
    # A mask whose maximum exceeds 1 is treated as 0/255 and binarised at
    # the midpoint. The ``mask.size`` guard keeps zero-size inputs from
    # raising inside ``max()``.
    if mask.size and mask.max() > 1:
        mask = (mask > 127).astype(np.uint8)

    # Broadcasting a 3-element tint and an (H, W, 1) condition yields the
    # same values as materialising full-size color and mask arrays.
    tint = np.array([color[0], color[1], color[2]], dtype=image.dtype)
    blended = (1 - alpha) * image + alpha * tint
    overlay = np.where(mask[:, :, np.newaxis], blended, image)

    return overlay.astype(np.uint8)
|
|
|
def predict_and_visualize_single(model, image_path, postprocess_mode='none', alpha=0.5, device='cpu'):
    """Segment one image with ``model`` and build two visualisations.

    Args:
        model: A ``UNet``, ``Segformer`` or ``Inception`` network (called on
            a 1x3x128x128 tensor), or a ``KMeans`` / ``GaussianMixture``
            estimator (fitted on the pixels of this image on every call).
        image_path: Either an object exposing ``__array_interface__`` (for
            example a numpy array) or anything ``PIL.Image.open`` accepts,
            such as a file path.
        postprocess_mode: Mode forwarded to ``postprocess``; ``'none'``
            skips post-processing entirely.
        alpha: Blend weight forwarded to ``overlay_mask``.
        device: Device the input tensor is moved to for network models.

    Returns:
        ``(bw_mask, overlay)``: the predicted mask scaled by 255 as a
        256x256 ``uint8`` image (nearest-neighbour upscaled), and the red
        overlay on the resized input as a 256x256 image (bilinear upscaled).

    Raises:
        TypeError: If ``model`` is none of the supported types.
        ValueError: If a network returns a dict with neither ``"logits"``
            nor ``"out"``.
    """
    # Array-like inputs keep the original ``fromarray`` path; anything else
    # (e.g. an actual file path) is opened from disk instead of crashing.
    if hasattr(image_path, "__array_interface__"):
        image = Image.fromarray(image_path).convert('RGB')
    else:
        with Image.open(image_path) as opened:
            image = opened.convert('RGB')

    original_np = np.array(image.resize((128, 128)))

    if isinstance(model, (UNet, Segformer, Inception)):
        transform = transforms.Compose([
            transforms.Resize((128, 128)),
            transforms.ToTensor()
        ])
        input_tensor = transform(image).unsqueeze(0).to(device)

        with torch.no_grad():
            output = model(input_tensor)

        if isinstance(output, dict):
            # Do not chain with ``or``: the truth value of a multi-element
            # tensor is ambiguous and raises at runtime.
            logits = output.get("logits")
            output = logits if logits is not None else output.get("out")
            if output is None:
                raise ValueError('model output has neither "logits" nor "out"')

        pred_mask = torch.argmax(output.squeeze(), dim=0).cpu().numpy()
    elif isinstance(model, (KMeans, GaussianMixture)):
        # Clustering models are fitted per image on its RGB pixel values.
        pixels = original_np.reshape(-1, 3)
        model.fit(pixels)
        pred_mask = model.predict(pixels).reshape(128, 128)
    else:
        raise TypeError(f"Unsupported model type: {type(model).__name__}")

    if postprocess_mode != 'none':
        pred_mask = postprocess([pred_mask], mode=postprocess_mode)[0]

    bw_mask = cv2.resize(
        pred_mask.astype(np.uint8) * 255,
        (256, 256),
        interpolation=cv2.INTER_NEAREST,
    )
    overlay = cv2.resize(
        overlay_mask(original_np, pred_mask, color=(255, 0, 0), alpha=alpha),
        (256, 256),
        interpolation=cv2.INTER_LINEAR,
    )

    return bw_mask, overlay