File size: 24,617 Bytes
491eded
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
"""
Image Part Segmentation and Labeling Tool

This script segments images into meaningful parts using the Segment Anything Model (SAM)
and optionally removes backgrounds using BriaRMBG. It identifies, visualizes, and merges
different parts of objects in images.

Key features:
- Background removal with alpha channel preservation
- Automatic part segmentation with SAM
- Intelligent part merging for logical grouping
- Detection of parts that SAM might miss
- Splitting of disconnected parts into separate components
- Edge cleaning and smoothing of segmentations
- Visualization of segmented parts with clear labeling
"""

import os
import argparse
import numpy as np
import cv2
import torch
from PIL import Image

from torchvision.transforms import functional as F
from torchvision import transforms
import torch.nn.functional as F_nn
from segment_anything import SamAutomaticMaskGenerator, build_sam
from modules.label_2d_mask.visualizer import Visualizer

# Default minimum segment area, in pixels. split_disconnected_parts() and
# get_sam_mask() fall back to this value when called with size_threshold=None.
# Fragments produced by splitting are kept down to one fifth of the threshold.
size_th = 2000

def get_mask(group_ids, image, ids=None, img_name=None, save_dir=None):
    """
    Creates and saves a colored visualization of mask segments.
    
    Background pixels (ID -1) are drawn white; every non-negative ID gets a
    color derived from its rank among the sorted unique IDs.
    
    Args:
        group_ids: Array of segment IDs for each pixel
        image: Input image (only its height and width are used)
        ids: Identifier to append to output filename
        img_name: Base name of the image for saving
        save_dir: Directory the PNG is written to. It is created if missing.
                  If None, nothing is rendered or written.
        
    Returns:
        Array of segment IDs (unchanged, just for convenience)
    """
    # Without a target directory there is nowhere to write the debug image;
    # skip instead of failing inside os.path.join(None, ...).
    if save_dir is None:
        print("No save_dir provided, skipping mask segments visualization")
        return group_ids
    
    colored_mask = np.zeros((image.shape[0], image.shape[1], 3), dtype=np.uint8)
    
    colored_mask[group_ids == -1] = [255, 255, 255]
    
    unique_ids = np.unique(group_ids)
    unique_ids = unique_ids[unique_ids >= 0]
    
    for i, unique_id in enumerate(unique_ids):
        color_r = (i * 50 + 80) % 256
        color_g = (i * 120 + 40) % 256
        color_b = (i * 180 + 20) % 256
        
        colored_mask[group_ids == unique_id] = [color_r, color_g, color_b]
    
    # cv2.imwrite does not create missing directories, so make sure it exists.
    # An empty string means "current directory" and needs no creation.
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    
    mask_path = os.path.join(save_dir, f"{img_name}_mask_segments_{ids}.png")
    # imwrite reports failure through its return value; only claim success when it is truthy
    if cv2.imwrite(mask_path, cv2.cvtColor(colored_mask, cv2.COLOR_RGB2BGR)):
        print(f"Saved mask segments visualization to {mask_path}")
    else:
        print(f"Warning: failed to save mask segments visualization to {mask_path}")
    
    return group_ids


def clean_segment_edges(group_ids):
    """
    Smooth the outline of every segment with a 3x3 morphological close
    followed by a 3x3 open.
    
    Segments are processed in ascending ID order and written into a fresh
    array, so where two smoothed segments overlap the higher ID wins.
    
    Args:
        group_ids: Array of segment IDs for each pixel (-1 marks background)
        
    Returns:
        New array of segment IDs with smoothed boundaries; pixels not covered
        by any smoothed segment are -1
    """
    structuring_element = np.ones((3, 3), np.uint8)
    
    # Non-negative IDs only; negative values are background
    all_ids = np.unique(group_ids)
    segment_ids = all_ids[all_ids >= 0]
    
    # Everything starts out as background and is filled in per segment
    result = np.full_like(group_ids, -1)
    
    for current_id in segment_ids:
        binary = (group_ids == current_id).astype(np.uint8)
        
        # Close first (smooth the outline), then open (drop isolated pixels)
        closed = cv2.morphologyEx(binary, cv2.MORPH_CLOSE, structuring_element, iterations=1)
        opened = cv2.morphologyEx(closed, cv2.MORPH_OPEN, structuring_element, iterations=1)
        
        result[opened > 0] = current_id
    
    print(f"Cleaned edges for {len(segment_ids)} segments")
    return result


def prepare_image(image, bg_color=None, rmbg_net=None):
    """
    Predict a foreground matte with a background-removal network and attach
    it to the image as its alpha channel.
    
    Args:
        image: PIL image. It is modified in place (via putalpha) and returned.
        bg_color: Unused; kept so existing callers keep working.
        rmbg_net: Callable network. It is called with a normalized
                  1x3x1024x1024 tensor, and the last element of its output is
                  passed through a sigmoid to obtain the matte.
    
    Returns:
        The same PIL image object, now carrying the predicted alpha channel
    
    Raises:
        TypeError: If rmbg_net is None
    """
    if rmbg_net is None:
        raise TypeError("prepare_image() requires a background-removal network (rmbg_net)")
    
    image_size = (1024, 1024)
    transform_image = transforms.Compose([
        transforms.Resize(image_size),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    
    # Feed the input on whatever device the network lives on; fall back to
    # 'cuda' when the network does not expose parameters.
    try:
        device = next(rmbg_net.parameters()).device
    except (AttributeError, StopIteration, TypeError):
        device = 'cuda'
    
    # The normalization above is defined for exactly three channels, so the
    # network always sees an RGB view; the alpha is still attached to `image`.
    network_input = image if image.mode == "RGB" else image.convert("RGB")
    input_images = transform_image(network_input).unsqueeze(0).to(device)

    # Prediction
    with torch.no_grad():
        preds = rmbg_net(input_images)[-1].sigmoid().cpu()
    pred = preds[0].squeeze()
    pred_pil = transforms.ToPILImage()(pred)
    
    # Bring the matte back to the original resolution before attaching it
    mask = pred_pil.resize(image.size)
    image.putalpha(mask)

    return image


def resize_and_pad_to_square(image, target_size=518):
    """
    Resize image to have longest side equal to target_size and pad shorter side 
    to create a square image.
    
    The resized image is centered on a white canvas. RGBA input yields an RGBA
    canvas whose padding is fully transparent; every other mode yields RGB.
    
    Args:
        image: PIL image or numpy array. A 3-channel array is treated as BGR
               and converted to RGB; other arrays are passed to PIL as is.
        target_size: Target square size, defaults to 518
    
    Returns:
        PIL Image resized and padded to square (target_size x target_size)
    """
    # Ensure image is a PIL Image object
    if isinstance(image, np.ndarray):
        if image.ndim == 3 and image.shape[2] == 3:
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = Image.fromarray(image)
    
    width, height = image.size
    
    # Scale so the longer side becomes target_size. The shorter side is
    # truncated to an int and clamped to 1 so extreme aspect ratios
    # (e.g. 1000x1) do not produce a zero-sized resize.
    if width > height:
        new_width = target_size
        new_height = max(1, int(height * (target_size / width)))
    else:
        new_height = target_size
        new_width = max(1, int(width * (target_size / height)))
    
    # Resize image while maintaining aspect ratio
    resized_image = image.resize((new_width, new_height), Image.LANCZOS)
    
    # Create new square image with proper mode (with or without alpha channel)
    mode = "RGBA" if image.mode == "RGBA" else "RGB"
    background_color = (255, 255, 255, 0) if mode == "RGBA" else (255, 255, 255)
    square_image = Image.new(mode, (target_size, target_size), background_color)
    
    # Calculate position to paste resized image (centered)
    paste_x = (target_size - new_width) // 2
    paste_y = (target_size - new_height) // 2
    
    # Paste resized image onto square background
    if mode == "RGBA":
        # NOTE(review): the image is used as its own paste mask, so partially
        # transparent pixels are blended with the transparent white canvas -
        # confirm this attenuation of soft edges is intended.
        square_image.paste(resized_image, (paste_x, paste_y), resized_image)
    else:
        square_image.paste(resized_image, (paste_x, paste_y))
    
    return square_image


def split_disconnected_parts(group_ids, size_threshold=None):
    """
    Give every 8-connected region of every part its own ID.
    
    Parts are visited in ascending ID order and output IDs are handed out
    consecutively starting at 0. A part that is already one connected region
    is always kept. When a part falls apart into several regions, only those
    with at least size_threshold / 5 pixels are kept; the rest become
    background (-1).
    
    Args:
        group_ids: Array of segment IDs for each pixel (-1 marks background)
        size_threshold: Minimum size threshold for considering a segment (in pixels).
                       If None, uses the global size_th variable.
        
    Returns:
        New array in which each kept connected component has a unique ID
    """
    min_pixels = size_th if size_threshold is None else size_threshold
    
    # Output starts as pure background
    relabeled = np.full_like(group_ids, -1)
    
    # Candidate parts: every non-negative ID
    part_ids = np.unique(group_ids)
    part_ids = part_ids[part_ids >= 0]
    
    fresh_id = 0
    regions_from_splits = 0
    
    for part_id in part_ids:
        binary_part = (group_ids == part_id).astype(np.uint8)
        component_count, component_map = cv2.connectedComponents(binary_part, connectivity=8)
        
        # component_count includes label 0, which is everything outside the part
        if component_count == 1:
            continue
        
        if component_count == 2:
            # Already a single connected region: keep it unconditionally
            relabeled[component_map == 1] = fresh_id
            fresh_id += 1
            continue
        
        print(f"Part {part_id} has {component_count-1} disconnected regions, splitting...")
        kept_here = 0
        
        for component in range(1, component_count):
            pixels = component_map == component
            pixel_count = np.sum(pixels)
            
            # Fragments below a fifth of the threshold are discarded
            if pixel_count >= min_pixels / 5:
                relabeled[pixels] = fresh_id
                fresh_id += 1
                kept_here += 1
            else:
                print(f"  Skipping small disconnected region ({pixel_count} pixels)")
        
        regions_from_splits += kept_here
    
    if regions_from_splits > 0:
        print(f"Split disconnected parts: original {len(part_ids)} parts -> {fresh_id} connected parts")
    else:
        print("No parts needed splitting - all parts are already connected")
    
    return relabeled

# -------------------------------------------------------
# MAIN SEGMENTATION FUNCTION
# -------------------------------------------------------

def _alpha_channel(rgba_image):
    """Return the 4th channel of rgba_image as a 2-D array, or None if there is none."""
    if rgba_image is None:
        return None
    rgba_array = np.array(rgba_image)
    # Check the rank first so 2-D (grayscale) input does not raise on shape[2]
    if rgba_array.ndim == 3 and rgba_array.shape[2] == 4:
        return rgba_array[:, :, 3]
    return None


def _assign_sam_masks(image, mask_generator, alpha, size_threshold):
    """Run the mask generator and paint accepted masks into a fresh ID map; return (group_ids, next free ID)."""
    masks = mask_generator.generate(image)
    group_ids = np.full((image.shape[0], image.shape[1]), -1, dtype=int)
    group_counter = 0

    # Areas with very low alpha are background
    background_mask = None if alpha is None else alpha <= 10

    # Largest first: masks painted later overwrite the pixels of earlier ones
    area_sorted_masks = sorted(masks, key=lambda x: x["area"], reverse=True)

    for i, candidate in enumerate(area_sorted_masks):
        if candidate["area"] < size_threshold:
            print(f"Skipping mask {i}, area too small: {candidate['area']} < {size_threshold}")
            continue

        mask = candidate["segmentation"]

        if background_mask is not None:
            # Share of this mask that lies on background pixels
            background_pixels_in_mask = np.sum(mask & background_mask)
            mask_area = np.sum(mask)
            background_ratio = background_pixels_in_mask / mask_area

            # Skip mask if background proportion is too high (>10%)
            if background_ratio > 0.1:
                print(f"  Skipping mask {i}, background ratio: {background_ratio:.2f}")
                continue

        group_ids[mask] = group_counter
        print(f"Assigned mask {i} with area {candidate['area']} to group {group_counter}")
        group_counter += 1

    # Split disconnected parts immediately after SAM segmentation
    print("Splitting disconnected parts in initial segmentation...")
    group_ids = split_disconnected_parts(group_ids, size_threshold)

    # Splitting renumbers the parts, so refresh the next free ID
    if np.max(group_ids) >= 0:
        group_counter = np.max(group_ids) + 1
    print(f"After early splitting, now have {len(np.unique(group_ids))-1} regions (excluding background)")

    return group_ids, group_counter


def _add_undetected_regions(group_ids, alpha, group_counter, size_threshold):
    """Add foreground areas not covered by any existing part to group_ids (in place) as new parts."""
    # Non-transparent pixels are foreground
    alpha_mask = alpha > 0

    existing_parts_mask = (group_ids != -1)
    kernel = np.ones((4, 4), np.uint8)

    # Grow the existing parts a little so thin gaps along their borders are
    # not reported as undetected areas
    dilated_parts = cv2.dilate(existing_parts_mask.astype(np.uint8), kernel)
    undetected_mask = alpha_mask & (~dilated_parts.astype(bool))

    # Process only if there are enough undetected pixels
    if not np.sum(undetected_mask) > size_threshold:
        return

    print(f"Found undetected parts with {np.sum(undetected_mask)} pixels")

    num_labels, labels = cv2.connectedComponents(
        undetected_mask.astype(np.uint8),
        connectivity=8
    )
    print(f"  Found {num_labels-1} initial regions")

    # Union-Find over the component labels, used to merge nearby regions
    parent = list(range(num_labels))

    def find(x):
        """Return the root of x, compressing the path on the way (iterative)."""
        root = x
        while parent[root] != root:
            root = parent[root]
        while parent[x] != root:
            parent[x], x = root, parent[x]
        return root

    def union(x, y):
        """Join the sets of x and y; the smaller root ID becomes the parent."""
        root_x = find(x)
        root_y = find(y)
        if root_x != root_y:
            if root_x < root_y:
                parent[root_y] = root_x
            else:
                parent[root_x] = root_y

    # Pixel count per label; index k holds the area of label k + 1
    areas = np.bincount(labels.ravel(), minlength=num_labels)[1:]

    # Only regions of at least a fifth of the threshold take part in merging
    valid_regions = np.where(areas >= size_threshold/5)[0] + 1

    # Existing parts act as barriers: regions touching only across a part stay apart
    barrier_mask = existing_parts_mask

    dilated_regions = {}
    for i in valid_regions:
        region_mask = (labels == i).astype(np.uint8)
        dilated_regions[i] = cv2.dilate(region_mask, kernel, iterations=2)

    # Check every pair of valid regions for proximity (overlap after dilation)
    for idx, i in enumerate(valid_regions[:-1]):
        for j in valid_regions[idx+1:]:
            overlap = dilated_regions[i] & dilated_regions[j]
            overlap_size = np.sum(overlap)

            # Merge if significant overlap and not separated by existing parts
            if overlap_size > 40 and not np.any(overlap & barrier_mask):
                overlap_ratio_i = overlap_size / areas[i-1]
                overlap_ratio_j = overlap_size / areas[j-1]

                if max(overlap_ratio_i, overlap_ratio_j) > 0.03:
                    union(i, j)
                    print(f"  Merging regions {i} and {j} (overlap: {overlap_size} px)")

    # Relabel every component with the root of its merged set
    merged_labels = np.zeros_like(labels)
    for label in range(1, num_labels):
        merged_labels[labels == label] = find(label)

    unique_merged_regions = np.unique(merged_labels[merged_labels > 0])
    print(f"  After merging: {len(unique_merged_regions)} connected regions")

    # Add regions to group_ids if they're large enough
    group_counter_start = group_counter
    for label in unique_merged_regions:
        region_mask = merged_labels == label
        region_size = np.sum(region_mask)

        if region_size > size_threshold:
            print(f"  Adding region with ID {label} ({region_size} pixels) as group {group_counter}")
            group_ids[region_mask] = group_counter
            group_counter += 1
        else:
            print(f"  Skipping small region with ID {label} ({region_size} pixels < {size_threshold})")

    print(f"  Added {group_counter - group_counter_start} regions that weren't detected by SAM")


def _apply_merge_groups(group_ids, merge_groups):
    """Relabel the IDs of each merge group to the group's first ID (in place), then return a map renumbered from 0."""
    # Groups are applied one after another on the same array, so a later
    # group sees the relabeling done by earlier groups.
    for group in merge_groups:
        if len(group) == 0:
            continue

        group_mask = np.zeros_like(group_ids, dtype=bool)
        orig_ids_first = group[0]

        for orig_id in group:
            mask = (group_ids == orig_id)
            pixels = np.sum(mask)
            if pixels > 0:
                print(f"  Including original ID {orig_id} ({pixels} pixels)")
                group_mask = group_mask | mask
            else:
                print(f"  Warning: Original ID {orig_id} does not exist")

        # Set all pixels in this group to the first ID in the group
        if np.any(group_mask):
            print(f"  Merging {np.sum(group_mask)} pixels to ID {orig_ids_first}")
            group_ids[group_mask] = orig_ids_first

    # Reassign IDs to be continuous from 0 (background -1 stays as is)
    unique_ids = np.unique(group_ids)
    unique_ids = unique_ids[unique_ids != -1]

    new_group_ids = np.full_like(group_ids, -1)
    for new_id, old_id in enumerate(unique_ids):
        new_group_ids[group_ids == old_id] = new_id

    print(f"ID reassignment complete: {len(unique_ids)} groups now have sequential IDs from 0 to {len(unique_ids)-1}")
    return new_group_ids


def _render_segments(visual, group_ids):
    """Draw background, filled regions, numeric labels and borders with `visual`; return the rendered image."""
    label_mode = '1'
    anno_mode = ['Mask', 'Mark']
    edge_kernel = np.ones((3, 3), np.uint8)

    # NOTE(review): if nothing gets drawn below, get_image() is called on
    # `visual` itself - confirm the visualizer type supports that.
    vis_mask = visual

    # First draw background areas (ID -1)
    background_mask = (group_ids == -1)
    if np.any(background_mask):
        vis_mask = visual.draw_binary_mask(background_mask, color=[1.0, 1.0, 1.0], alpha=0.0)

    # Then draw each segment with unique colors and labels
    for unique_id in np.unique(group_ids):
        if unique_id == -1:  # Skip background
            continue
        mask = (group_ids == unique_id)
        area = int(np.count_nonzero(mask))

        print(f"Labeling region {unique_id}, area: {area} pixels")
        if area < 30:  # Skip very small regions
            continue

        # Use different colors for different IDs to enhance visual distinction
        color_r = (unique_id * 50 + 80) % 200 / 255.0 + 0.2
        color_g = (unique_id * 120 + 40) % 200 / 255.0 + 0.2
        color_b = (unique_id * 180 + 20) % 200 / 255.0 + 0.2
        color = [color_r, color_g, color_b]

        # Larger regions are drawn slightly more opaque, within [0.1, 0.3]
        adaptive_alpha = min(0.3, max(0.1, 0.1 + area / 100000))

        # Border of the region: dilated minus eroded mask
        mask_u8 = mask.astype(np.uint8)
        dilated = cv2.dilate(mask_u8, edge_kernel, iterations=1)
        eroded = cv2.erode(mask_u8, edge_kernel, iterations=1)
        edge = dilated.astype(bool) & (~eroded.astype(bool))

        label = f"{unique_id}"

        # First draw the main body of the region
        vis_mask = visual.draw_binary_mask_with_number(
            mask,
            text=label,
            label_mode=label_mode,
            alpha=adaptive_alpha,
            anno_mode=anno_mode,
            color=color,
            font_size=20
        )

        # Enhance edges (add border effect for all parts)
        edge_color = [min(c*1.3, 1.0) for c in color]  # Slightly brighter edge color
        vis_mask = visual.draw_binary_mask(
            edge,
            alpha=0.8,  # Lower transparency for edges to make them more visible
            color=edge_color
        )

    return vis_mask.get_image()


def get_sam_mask(image, mask_generator, visual, merge_groups=None, existing_group_ids=None,
                check_undetected=True, rgba_image=None, img_name=None, skip_split=False, save_dir=None, size_threshold=None):
    """
    Generate and process SAM masks for the image, with optional merging and undetected region detection.
    
    Steps, in order:
      1. Obtain a per-pixel ID map: a copy of existing_group_ids if given,
         otherwise masks from mask_generator (filtered by size and background
         ratio, then split into connected parts).
      2. If check_undetected is set and rgba_image has an alpha channel, add
         sufficiently large foreground areas not covered by any part.
      3. If a new ID map was generated and save_dir is given, save a debug
         visualization via get_mask().
      4. If merge_groups is given, merge the listed IDs and renumber from 0;
         then split disconnected parts unless skip_split is set. Without
         merge_groups, disconnected parts are always split.
      5. Render the result with `visual`.
    
    Args:
        image: Image array; its first two dimensions define the ID map size.
               It is passed to mask_generator.generate().
        mask_generator: Object whose generate(image) returns an iterable of
                       dicts with "area" and "segmentation" entries.
        visual: Visualizer providing draw_binary_mask() and
               draw_binary_mask_with_number().
        merge_groups: Optional sequence of ID groups. All IDs of a group are
                     relabeled to the group's first ID. Groups are applied
                     sequentially; empty groups are ignored.
        existing_group_ids: Optional ID map to start from instead of running
                           the mask generator. It is copied, not modified.
        check_undetected: Whether to look for foreground areas missed so far.
        rgba_image: Optional image whose 4th channel is used as alpha.
        img_name: Base name used for the debug visualization file.
        skip_split: When merging, skip the final split of disconnected parts.
        save_dir: Directory for the debug visualization; None disables it.
        size_threshold: Minimum size threshold for considering a segment (in pixels). 
                       If None, uses the global size_th variable.
    
    Returns:
        Tuple (group_ids, im): the final per-pixel ID map (-1 is background)
        and the image returned by the visualizer's get_image().
    """
    # Use provided threshold or fall back to global variable
    if size_threshold is None:
        size_threshold = size_th

    exist_group = existing_group_ids is not None

    # Use existing group IDs if provided, otherwise generate new ones with SAM
    if exist_group:
        group_ids = existing_group_ids.copy()
        group_counter = np.max(group_ids) + 1
    else:
        group_ids, group_counter = _assign_sam_masks(
            image, mask_generator, _alpha_channel(rgba_image), size_threshold
        )

    # Check for undetected parts using RGBA information
    if check_undetected and rgba_image is not None:
        print("Checking for undetected parts using RGBA image...")
        alpha = _alpha_channel(rgba_image)
        if alpha is not None:
            print(f"Image has alpha channel, checking for undetected parts...")
            _add_undetected_regions(group_ids, alpha, group_counter, size_threshold)

    # Save debug visualization of initial segmentation (needs a target directory)
    if not exist_group and save_dir is not None:
        get_mask(group_ids, image, ids=2, img_name=img_name, save_dir=save_dir)

    if merge_groups is not None:
        group_ids = _apply_merge_groups(group_ids, merge_groups)
        print(f"Merging complete, now have {len(np.unique(group_ids))-1} regions (excluding background)")

        # Skip splitting disconnected parts if requested
        if not skip_split:
            group_ids = split_disconnected_parts(group_ids, size_threshold)
            print(f"After splitting disconnected parts, now have {len(np.unique(group_ids))-1} regions (excluding background)")
    else:
        # Always split disconnected parts for initial segmentation
        group_ids = split_disconnected_parts(group_ids, size_threshold)
        print(f"After splitting disconnected parts, now have {len(np.unique(group_ids))-1} regions (excluding background)")

    # Create visualization with clear labeling
    im = _render_segments(visual, group_ids)

    return group_ids, im