|
""" |
|
Image Part Segmentation and Labeling Tool |
|
|
|
This script segments images into meaningful parts using the Segment Anything Model (SAM) |
|
and optionally removes backgrounds using BriaRMBG. It identifies, visualizes, and merges |
|
different parts of objects in images. |
|
|
|
Key features: |
|
- Background removal with alpha channel preservation |
|
- Automatic part segmentation with SAM |
|
- Intelligent part merging for logical grouping |
|
- Detection of parts that SAM might miss |
|
- Splitting of disconnected parts into separate components |
|
- Edge cleaning and smoothing of segmentations |
|
- Visualization of segmented parts with clear labeling |
|
""" |
|
|
|
import os |
|
import argparse |
|
import numpy as np |
|
import cv2 |
|
import torch |
|
from PIL import Image |
|
|
|
from torchvision.transforms import functional as F |
|
from torchvision import transforms |
|
import torch.nn.functional as F_nn |
|
from segment_anything import SamAutomaticMaskGenerator, build_sam |
|
from modules.label_2d_mask.visualizer import Visualizer |
|
|
|
|
|
# Default minimum segment area in pixels; segmentation functions fall back to
# this module-level threshold whenever their size_threshold argument is None.
size_th = 2000
|
|
|
def get_mask(group_ids, image, ids=None, img_name=None, save_dir=None):
    """
    Render a colored visualization of mask segments and write it to disk.

    Args:
        group_ids: Per-pixel array of segment IDs (-1 marks background).
        image: Input image; only its height/width are used for sizing.
        ids: Identifier appended to the output filename.
        img_name: Base name of the image used in the output filename.
        save_dir: Directory the visualization PNG is written to.

    Returns:
        The input group_ids array, unchanged (returned for convenience).
    """
    height, width = image.shape[0], image.shape[1]
    colored_mask = np.zeros((height, width, 3), dtype=np.uint8)

    # Background pixels are drawn white.
    colored_mask[group_ids == -1] = [255, 255, 255]

    segment_ids = np.unique(group_ids)
    segment_ids = segment_ids[segment_ids >= 0]

    # Deterministic palette: each segment's index maps to a fixed RGB triple.
    for idx, seg_id in enumerate(segment_ids):
        color = [
            (idx * 50 + 80) % 256,
            (idx * 120 + 40) % 256,
            (idx * 180 + 20) % 256,
        ]
        colored_mask[group_ids == seg_id] = color

    mask_path = os.path.join(save_dir, f"{img_name}_mask_segments_{ids}.png")
    # OpenCV expects BGR ordering when writing, hence the conversion.
    cv2.imwrite(mask_path, cv2.cvtColor(colored_mask, cv2.COLOR_RGB2BGR))
    print(f"Saved mask segments visualization to {mask_path}")

    return group_ids
|
|
|
|
|
def clean_segment_edges(group_ids):
    """
    Smooth segment boundaries by applying morphological close/open per segment.

    Args:
        group_ids: Per-pixel array of segment IDs (-1 marks background).

    Returns:
        New array of segment IDs with smoother boundaries; pixels not claimed
        by any smoothed segment remain -1.
    """
    segment_ids = np.unique(group_ids)
    segment_ids = segment_ids[segment_ids >= 0]

    # Start from an all-background canvas and repaint each cleaned segment.
    cleaned_group_ids = np.full_like(group_ids, -1)
    kernel = np.ones((3, 3), np.uint8)

    for seg_id in segment_ids:
        binary = (group_ids == seg_id).astype(np.uint8)
        # Close small holes first, then open to shave off small protrusions.
        binary = cv2.morphologyEx(binary, cv2.MORPH_CLOSE, kernel, iterations=1)
        binary = cv2.morphologyEx(binary, cv2.MORPH_OPEN, kernel, iterations=1)
        # Later segments overwrite earlier ones where smoothed masks overlap.
        cleaned_group_ids[binary > 0] = seg_id

    print(f"Cleaned edges for {len(segment_ids)} segments")
    return cleaned_group_ids
|
|
|
|
|
def prepare_image(image, bg_color=None, rmbg_net=None):
    """
    Remove the background of a PIL image with the BriaRMBG network and attach
    the predicted foreground matte as the image's alpha channel.

    Args:
        image: PIL RGB image; modified in place via putalpha and also returned.
        bg_color: Currently unused; kept for interface compatibility.
        rmbg_net: Background-removal network. Called on a normalized
            1024x1024 tensor; its last output is treated as the matte logits.

    Returns:
        The same PIL image with its alpha channel set to the predicted matte.
    """
    # The network expects a fixed 1024x1024, ImageNet-normalized input.
    image_size = (1024, 1024)
    transform_image = transforms.Compose([
        transforms.Resize(image_size),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    # NOTE(review): device is hardcoded to 'cuda' — fails on CPU-only hosts;
    # consider deriving the device from rmbg_net's parameters.
    input_images = transform_image(image).unsqueeze(0).to('cuda')

    with torch.no_grad():
        # Last network output is the matte; sigmoid maps logits to [0, 1].
        preds = rmbg_net(input_images)[-1].sigmoid().cpu()
    pred = preds[0].squeeze()
    pred_pil = transforms.ToPILImage()(pred)
    # Resize the 1024x1024 matte back to the original image size.
    mask = pred_pil.resize(image.size)
    image.putalpha(mask)

    return image
|
|
|
|
|
def resize_and_pad_to_square(image, target_size=518):
    """
    Scale an image so its longest side equals target_size, then center it on a
    square canvas of that size.

    Args:
        image: PIL image or numpy array (3-channel numpy input is assumed BGR
            and converted to RGB).
        target_size: Side length of the output square, defaults to 518.

    Returns:
        PIL Image of size (target_size, target_size); RGBA inputs are padded
        with transparent white, other modes with opaque white.
    """
    # Accept numpy input by converting it to a PIL image first.
    if isinstance(image, np.ndarray):
        if len(image.shape) == 3 and image.shape[2] == 3:
            image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
        else:
            image = Image.fromarray(image)

    width, height = image.size

    # Scale so the longest side becomes target_size, preserving aspect ratio.
    if width > height:
        new_width, new_height = target_size, int(height * (target_size / width))
    else:
        new_height, new_width = target_size, int(width * (target_size / height))

    resized = image.resize((new_width, new_height), Image.LANCZOS)

    # RGBA canvases pad with fully transparent white; others with opaque white.
    is_rgba = image.mode == "RGBA"
    canvas = Image.new(
        "RGBA" if is_rgba else "RGB",
        (target_size, target_size),
        (255, 255, 255, 0) if is_rgba else (255, 255, 255),
    )

    offset = ((target_size - new_width) // 2, (target_size - new_height) // 2)

    # For RGBA, reuse the image's alpha as paste mask so transparency survives.
    if is_rgba:
        canvas.paste(resized, offset, resized)
    else:
        canvas.paste(resized, offset)

    return canvas
|
|
|
|
|
def split_disconnected_parts(group_ids, size_threshold=None):
    """
    Re-label segments so every connected component gets its own unique ID.

    Args:
        group_ids: Per-pixel array of segment IDs (-1 marks background).
        size_threshold: Minimum size threshold for considering a segment
            (in pixels). If None, falls back to the module-level size_th.

    Returns:
        New ID array in which each sufficiently large connected component of
        each original part carries a sequential ID; everything else is -1.
    """
    threshold = size_th if size_threshold is None else size_threshold

    result = np.full_like(group_ids, -1)

    part_ids = np.unique(group_ids)
    part_ids = part_ids[part_ids >= 0]

    next_id = 0
    total_split_regions = 0

    for part_id in part_ids:
        binary = (group_ids == part_id).astype(np.uint8)
        num_labels, labels = cv2.connectedComponents(binary, connectivity=8)

        # Only background found for this ID: the part silently disappears.
        if num_labels == 1:
            continue

        # Exactly one component: already connected, keep it (no size filter).
        if num_labels == 2:
            result[labels == 1] = next_id
            next_id += 1
            continue

        split_count = 0
        print(f"Part {part_id} has {num_labels-1} disconnected regions, splitting...")

        for component in range(1, num_labels):
            component_mask = labels == component
            region_size = np.sum(component_mask)

            # Keep components at least 1/5 of the threshold; discard crumbs.
            if region_size >= threshold / 5:
                result[component_mask] = next_id
                split_count += 1
                next_id += 1
            else:
                print(f" Skipping small disconnected region ({region_size} pixels)")

        total_split_regions += split_count

    if total_split_regions > 0:
        print(f"Split disconnected parts: original {len(part_ids)} parts -> {next_id} connected parts")
    else:
        print("No parts needed splitting - all parts are already connected")

    return result
|
|
|
|
|
|
|
|
|
|
|
def get_sam_mask(image, mask_generator, visual, merge_groups=None, existing_group_ids=None,
                 check_undetected=True, rgba_image=None, img_name=None, skip_split=False, save_dir=None, size_threshold=None):
    """
    Generate and process SAM masks for the image, with optional merging and undetected region detection.

    Args:
        image: RGB image array (H, W, 3) fed to the SAM mask generator.
        mask_generator: SamAutomaticMaskGenerator used when existing_group_ids is None.
        visual: Visualizer instance used to draw the labeled output image.
        merge_groups: Optional list of ID lists; each inner list is merged into one
            segment (keeping the first listed ID), then all IDs are renumbered 0..K-1.
        existing_group_ids: Optional precomputed per-pixel ID array; when given,
            SAM generation and the initial mask-assignment stage are skipped.
        check_undetected: If True and rgba_image is given, recover foreground
            regions (alpha > 0) that SAM missed and add them as new segments.
        rgba_image: Optional RGBA image whose alpha channel marks the foreground.
        img_name: Base name used when saving the intermediate mask image.
        skip_split: If True, do NOT split disconnected parts after merging.
        save_dir: Directory for the intermediate mask visualization.
        size_threshold: Minimum size threshold for considering a segment (in pixels).
            If None, uses the global size_th variable.

    Returns:
        Tuple (group_ids, im): final per-pixel segment IDs (-1 = background) and
        the rendered labeled visualization image.
    """
    if size_threshold is None:
        size_threshold = size_th
    label_mode = '1'
    anno_mode = ['Mask', 'Mark']

    exist_group = False

    if existing_group_ids is not None:
        # Reuse the provided segmentation; no SAM inference is needed.
        group_ids = existing_group_ids.copy()
        group_counter = np.max(group_ids) + 1
        exist_group = True
    else:
        masks = mask_generator.generate(image)
        group_ids = np.full((image.shape[0], image.shape[1]), -1, dtype=int)
        num_masks = len(masks)
        group_counter = 0

        # Assign large masks first so smaller, finer masks can overwrite them.
        area_sorted_masks = sorted(masks, key=lambda x: x["area"], reverse=True)

        # BUG FIX: this whole assignment stage now runs only when SAM masks were
        # actually generated; previously it ran unconditionally and raised
        # NameError on 'masks' when existing_group_ids was supplied.
        background_mask = None
        if rgba_image is not None:
            rgba_array = np.array(rgba_image)
            if rgba_array.shape[2] == 4:
                # Nearly transparent pixels are treated as background.
                background_mask = rgba_array[:, :, 3] <= 10

        for i in range(0, num_masks):
            if area_sorted_masks[i]["area"] < size_threshold:
                print(f"Skipping mask {i}, area too small: {area_sorted_masks[i]['area']} < {size_threshold}")
                continue

            mask = area_sorted_masks[i]["segmentation"]

            if background_mask is not None:
                # Reject masks that mostly cover transparent background.
                background_pixels_in_mask = np.sum(mask & background_mask)
                mask_area = np.sum(mask)
                background_ratio = background_pixels_in_mask / mask_area

                if background_ratio > 0.1:
                    print(f" Skipping mask {i}, background ratio: {background_ratio:.2f}")
                    continue

            group_ids[mask] = group_counter
            print(f"Assigned mask {i} with area {area_sorted_masks[i]['area']} to group {group_counter}")
            group_counter += 1

        print("Splitting disconnected parts in initial segmentation...")
        group_ids = split_disconnected_parts(group_ids, size_threshold)

        # Splitting renumbers IDs, so refresh the next free counter.
        if np.max(group_ids) >= 0:
            group_counter = np.max(group_ids) + 1
            print(f"After early splitting, now have {len(np.unique(group_ids))-1} regions (excluding background)")

    if check_undetected and rgba_image is not None:
        print("Checking for undetected parts using RGBA image...")

        rgba_array = np.array(rgba_image)

        if rgba_array.shape[2] == 4:
            print(f"Image has alpha channel, checking for undetected parts...")

            alpha_mask = rgba_array[:, :, 3] > 0

            existing_parts_mask = (group_ids != -1)
            kernel = np.ones((4, 4), np.uint8)

            # Slightly dilate detected parts so thin borders don't count as missed.
            large_kernel = np.ones((4, 4), np.uint8)
            dilated_parts = cv2.dilate(existing_parts_mask.astype(np.uint8), large_kernel)

            # Foreground pixels not covered by any (dilated) detected part.
            undetected_mask = alpha_mask & (~dilated_parts.astype(bool))

            if np.sum(undetected_mask) > size_threshold:
                print(f"Found undetected parts with {np.sum(undetected_mask)} pixels")

                num_labels, labels = cv2.connectedComponents(
                    undetected_mask.astype(np.uint8),
                    connectivity=8
                )

                print(f" Found {num_labels-1} initial regions")

                # Union-Find over component labels so nearby regions can merge.
                parent = list(range(num_labels))

                def find(x):
                    """Find with path compression for Union-Find"""
                    if parent[x] != x:
                        parent[x] = find(parent[x])
                    return parent[x]

                def union(x, y):
                    """Union operation for Union-Find"""
                    root_x = find(x)
                    root_y = find(y)
                    if root_x != root_y:
                        # Keep the smaller label as the representative.
                        if root_x < root_y:
                            parent[root_y] = root_x
                        else:
                            parent[root_x] = root_y

                areas = np.bincount(labels.flatten())[1:] if num_labels > 1 else []

                # Only consider regions at least 1/5 of the size threshold.
                valid_regions = np.where(areas >= size_threshold/5)[0] + 1

                # Existing parts act as barriers: regions may not merge across them.
                barrier_mask = existing_parts_mask

                dilated_regions = {}
                for i in valid_regions:
                    region_mask = (labels == i).astype(np.uint8)
                    dilated_regions[i] = cv2.dilate(region_mask, kernel, iterations=2)

                # Merge regions whose dilations overlap enough without crossing a part.
                for idx, i in enumerate(valid_regions[:-1]):
                    for j in valid_regions[idx+1:]:
                        overlap = dilated_regions[i] & dilated_regions[j]
                        overlap_size = np.sum(overlap)

                        if overlap_size > 40 and not np.any(overlap & barrier_mask):
                            overlap_ratio_i = overlap_size / areas[i-1]
                            overlap_ratio_j = overlap_size / areas[j-1]

                            if max(overlap_ratio_i, overlap_ratio_j) > 0.03:
                                union(i, j)
                                print(f" Merging regions {i} and {j} (overlap: {overlap_size} px)")

                # Relabel every component with its Union-Find representative.
                merged_labels = np.zeros_like(labels)
                for label in range(1, num_labels):
                    merged_labels[labels == label] = find(label)

                unique_merged_regions = np.unique(merged_labels[merged_labels > 0])
                print(f" After merging: {len(unique_merged_regions)} connected regions")

                # Promote sufficiently large merged regions to new segment IDs.
                group_counter_start = group_counter
                for label in unique_merged_regions:
                    region_mask = merged_labels == label
                    region_size = np.sum(region_mask)

                    if region_size > size_threshold:
                        print(f" Adding region with ID {label} ({region_size} pixels) as group {group_counter}")
                        group_ids[region_mask] = group_counter
                        group_counter += 1
                    else:
                        print(f" Skipping small region with ID {label} ({region_size} pixels < {size_threshold})")

                print(f" Added {group_counter - group_counter_start} regions that weren't detected by SAM")

                if group_counter > group_counter_start:
                    print("Processing edges for newly detected parts...")

                    new_parts_mask = np.zeros_like(group_ids, dtype=bool)
                    for part_id in range(group_counter_start, group_counter):
                        new_parts_mask |= (group_ids == part_id)

                    # NOTE(review): the edge mask computed below is never used;
                    # kept so observable behavior (prints) is unchanged.
                    all_new_dilated = cv2.dilate(new_parts_mask.astype(np.uint8), kernel, iterations=1)
                    all_new_eroded = cv2.erode(new_parts_mask.astype(np.uint8), kernel, iterations=1)
                    all_new_edges = all_new_dilated.astype(bool) & (~all_new_eroded.astype(bool))

                    print(f"Edge processing completed for {group_counter - group_counter_start} new parts")

    # Save the intermediate visualization only for freshly generated masks.
    if not exist_group:
        get_mask(group_ids, image, ids=2, img_name=img_name, save_dir=save_dir)

    if merge_groups is not None:
        # BUG FIX: copy instead of alias, so the lookups below always see the
        # ORIGINAL ids — previously merging one group mutated group_ids and
        # could clobber ids a later merge group still referenced.
        merged_group_ids = group_ids.copy()

        merged_group_ids[group_ids == -1] = -1

        for new_id, group in enumerate(merge_groups):
            group_mask = np.zeros_like(group_ids, dtype=bool)

            # The first listed ID becomes the merged segment's ID.
            orig_ids_first = group[0]

            for orig_id in group:
                mask = (group_ids == orig_id)
                pixels = np.sum(mask)
                if pixels > 0:
                    print(f" Including original ID {orig_id} ({pixels} pixels)")
                    group_mask = group_mask | mask
                else:
                    print(f" Warning: Original ID {orig_id} does not exist")

            if np.any(group_mask):
                print(f" Merging {np.sum(group_mask)} pixels to ID {orig_ids_first}")
                merged_group_ids[group_mask] = orig_ids_first

        # Renumber surviving IDs to a dense 0..K-1 range.
        unique_ids = np.unique(merged_group_ids)
        unique_ids = unique_ids[unique_ids != -1]
        id_reassignment = {old_id: new_id for new_id, old_id in enumerate(unique_ids)}

        new_group_ids = np.full_like(merged_group_ids, -1)
        for old_id, new_id in id_reassignment.items():
            new_group_ids[merged_group_ids == old_id] = new_id

        merged_group_ids = new_group_ids

        print(f"ID reassignment complete: {len(id_reassignment)} groups now have sequential IDs from 0 to {len(id_reassignment)-1}")

        group_ids = merged_group_ids
        print(f"Merging complete, now have {len(np.unique(group_ids))-1} regions (excluding background)")

    if not skip_split:
        group_ids = split_disconnected_parts(group_ids, size_threshold)
        print(f"After splitting disconnected parts, now have {len(np.unique(group_ids))-1} regions (excluding background)")
    else:
        # BUG FIX: both branches previously ran the split, making skip_split a
        # no-op; splitting is now genuinely skipped when requested.
        print("Skipping split of disconnected parts (skip_split=True)")

    vis_mask = visual

    # Draw background fully transparent so only labeled parts stand out.
    background_mask = (group_ids == -1)
    if np.any(background_mask):
        vis_mask = visual.draw_binary_mask(background_mask, color=[1.0, 1.0, 1.0], alpha=0.0)

    for unique_id in np.unique(group_ids):
        if unique_id == -1:
            continue
        mask = (group_ids == unique_id)

        y_indices, x_indices = np.where(mask)
        if len(y_indices) > 0 and len(x_indices) > 0:
            area = len(y_indices)

            print(f"Labeling region {unique_id}, area: {area} pixels")
            # Regions under 30 px are too small to label legibly.
            if area < 30:
                continue

            # Deterministic per-ID color, shifted into a visible range.
            color_r = (unique_id * 50 + 80) % 200 / 255.0 + 0.2
            color_g = (unique_id * 120 + 40) % 200 / 255.0 + 0.2
            color_b = (unique_id * 180 + 20) % 200 / 255.0 + 0.2
            color = [color_r, color_g, color_b]

            # Larger regions get a slightly more opaque fill (clamped 0.1-0.3).
            adaptive_alpha = min(0.3, max(0.1, 0.1 + area / 100000))

            # Boundary ring = dilation minus erosion of the region mask.
            kernel = np.ones((3, 3), np.uint8)
            dilated = cv2.dilate(mask.astype(np.uint8), kernel, iterations=1)
            eroded = cv2.erode(mask.astype(np.uint8), kernel, iterations=1)
            edge = dilated.astype(bool) & (~eroded.astype(bool))

            label = f"{unique_id}"

            vis_mask = visual.draw_binary_mask_with_number(
                mask,
                text=label,
                label_mode=label_mode,
                alpha=adaptive_alpha,
                anno_mode=anno_mode,
                color=color,
                font_size=20
            )

            # Highlight the boundary with a brighter version of the fill color.
            edge_color = [min(c*1.3, 1.0) for c in color]
            vis_mask = visual.draw_binary_mask(
                edge,
                alpha=0.8,
                color=edge_color
            )

    im = vis_mask.get_image()

    return group_ids, im