Spaces:
Sleeping
Sleeping
File size: 9,250 Bytes
eb4d305 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 |
import numpy as np
from typing import Dict, Optional
from mmcv.transforms.base import BaseTransform
from mmdet.registry import TRANSFORMS
from mmdet.datasets.transforms.loading import LoadAnnotations
import logging
from mmdet.structures.mask import BitmapMasks
# Module-level logger named after this module, so its output can be
# filtered or configured per module.
logger = logging.getLogger(__name__)
@TRANSFORMS.register_module()
class FlexibleLoadAnnotations(LoadAnnotations):
    """Annotation loader that tolerates datasets mixing mask and bbox-only labels.

    Every instance is kept. Instances that carry a usable mask (polygon
    list, RLE dict, or an ``x0..x3`` / ``y0..y3`` corner polygon) keep it
    under ``instance['mask']``; all other instances get an empty list there
    so that the parent loader can still index ``instance['mask']``.

    NOTE(review): how the parent treats an empty ``mask`` list (it may flag
    the instance as ignored) depends on the installed mmdet version —
    confirm this matches the intended bbox-only training behaviour.

    Args:
        with_bbox: Forwarded to ``LoadAnnotations``.
        with_mask: Forwarded to ``LoadAnnotations``; when false,
            ``_load_masks`` is a no-op.
        with_seg: Forwarded to ``LoadAnnotations``.
        poly2mask: Forwarded to ``LoadAnnotations``.
        **kwargs: Any further ``LoadAnnotations`` options.

    Attributes:
        mask_stats: Running counters (``total``, ``with_masks``,
            ``without_masks``) over all instances seen by this object.
    """

    # Emit the running statistics each time this many further instances
    # have been counted.
    _STATS_LOG_INTERVAL = 1000
    # A polygon needs at least 3 points, i.e. 6 flat coordinates.
    _MIN_POLYGON_COORDS = 6
    # Corner polygons are stored under the keys x0..x3 / y0..y3.
    _MAX_POLYGON_POINTS = 4

    def __init__(self,
                 with_bbox: bool = True,
                 with_mask: bool = True,
                 with_seg: bool = False,
                 poly2mask: bool = True,
                 **kwargs):
        super().__init__(
            with_bbox=with_bbox,
            with_mask=with_mask,
            with_seg=with_seg,
            poly2mask=poly2mask,
            **kwargs
        )
        self.mask_stats = {'total': 0, 'with_masks': 0, 'without_masks': 0}
        # Instance count at which the next statistics line is logged.
        self._next_stats_log = self._STATS_LOG_INTERVAL

    @classmethod
    def _polygon_to_segmentation(cls, polygon) -> Optional[list]:
        """Convert an ``{'x0': .., 'y0': .., ...}`` dict to ``[[x0, y0, ...]]``.

        Returns ``None`` when ``polygon`` is not a non-empty dict or yields
        fewer than three complete points.
        """
        if not polygon or not isinstance(polygon, dict):
            return None
        coords = []
        for idx in range(cls._MAX_POLYGON_POINTS):
            x_key, y_key = f'x{idx}', f'y{idx}'
            # A point is only used when both of its coordinates are present.
            if x_key in polygon and y_key in polygon:
                coords.extend((polygon[x_key], polygon[y_key]))
        if len(coords) < cls._MIN_POLYGON_COORDS:
            return None
        return [coords]

    def _resolve_instance_mask(self, instance: dict):
        """Return the mask to store for ``instance`` (``[]`` when it has none).

        Looks at ``instance['mask']``, then ``instance['segmentation']``,
        then ``instance['polygon']``. When a corner polygon is converted,
        the result is also written to ``instance['segmentation']``.
        """
        raw = instance.get('mask') or instance.get('segmentation')

        # RLE-style dict: hand it to the parent untouched instead of
        # discarding it.
        if isinstance(raw, dict) and 'counts' in raw and 'size' in raw:
            return raw

        if raw and isinstance(raw, list):
            first = raw[0]
            if isinstance(first, list):
                # Nested [[x1, y1, ...], ...]: usable when the first
                # polygon has at least three points.
                if len(first) >= self._MIN_POLYGON_COORDS:
                    return raw
                return []
            if isinstance(first, (int, float, np.number)):
                # Flat [x1, y1, ...]: wrap it so that it becomes a list
                # holding a single polygon.
                if len(raw) >= self._MIN_POLYGON_COORDS:
                    return [raw]
                return []
            # Any other element type is passed through unchanged.
            return raw

        converted = self._polygon_to_segmentation(instance.get('polygon'))
        if converted is None:
            return []
        instance['segmentation'] = converted
        return converted

    def _load_masks(self, results: dict) -> dict:
        """Load mask annotations, tolerating instances without masks.

        The first matching source wins:

        1. ``results['ann_info']`` with a ``segmentation`` list or a corner
           ``polygon`` dict;
        2. ``results['instances']`` (every instance is kept and receives a
           ``mask`` entry);
        3. a top-level ``results['segmentation']`` list.

        Args:
            results: Pipeline result dict; updated in place.

        Returns:
            The same ``results`` object. It is returned unchanged when
            ``with_mask`` is false or ``results`` is not a dict.
        """
        if not self.with_mask or not isinstance(results, dict):
            return results

        ann_info = results.get('ann_info')
        if isinstance(ann_info, dict):
            # NOTE(review): whether the parent reads ann_info['masks']
            # depends on the mmdet version — confirm.
            segmentation = ann_info.get('segmentation')
            if segmentation and isinstance(segmentation, list):
                ann_info['masks'] = segmentation
                # The parent works in place; its return value is not the
                # results dict, so return ``results`` explicitly.
                super()._load_masks(results)
                return results

            converted = self._polygon_to_segmentation(ann_info.get('polygon'))
            if converted is not None:
                ann_info['segmentation'] = converted
                ann_info['masks'] = converted
                super()._load_masks(results)
                return results

        instances = results.get('instances')
        if isinstance(instances, list):
            for instance in instances:
                self.mask_stats['total'] += 1
                mask = self._resolve_instance_mask(instance)
                # Always set the key so the parent can index it.
                instance['mask'] = mask
                if mask:
                    self.mask_stats['with_masks'] += 1
                else:
                    self.mask_stats['without_masks'] += 1

            if instances:
                super()._load_masks(results)
            else:
                # No instances at all: publish an empty mask container.
                # Slice to two values in case the shape carries channels.
                h, w = results.get('img_shape', (0, 0))[:2]
                results['gt_masks'] = BitmapMasks([], h, w)
                results['gt_ignore_flags'] = np.array([], dtype=bool)
            return results

        segmentation = results.get('segmentation')
        if segmentation and isinstance(segmentation, list):
            results['masks'] = segmentation
            super()._load_masks(results)
        return results

    def _log_mask_stats(self) -> None:
        """Log the counters once per ``_STATS_LOG_INTERVAL`` instances."""
        total = self.mask_stats['total']
        interval = self._STATS_LOG_INTERVAL
        # getattr keeps this working on objects restored without __init__.
        threshold = getattr(self, '_next_stats_log', interval)
        # ``total`` grows by a variable amount per image, so test for
        # crossing the threshold rather than for an exact multiple.
        if total < threshold:
            return
        logger.info(
            "Mask stats - total: %s, with_masks: %s, without_masks: %s",
            total,
            self.mask_stats['with_masks'],
            self.mask_stats['without_masks'],
        )
        self._next_stats_log = (total // interval + 1) * interval

    def transform(self, results: dict) -> dict:
        """Load all configured annotations into ``results``.

        Args:
            results: Pipeline result dict.

        Returns:
            The updated result dict, or an empty dict (after logging an
            error) when ``results`` is not a dict.
        """
        if not isinstance(results, dict):
            logger.error("Expected dict, got %s", type(results))
            return {}

        results = super().transform(results)

        # The parent transform is expected to dispatch to the overridden
        # ``_load_masks`` already. Only load here when that did not produce
        # masks, so instances are neither processed nor counted twice.
        if self.with_mask and 'gt_masks' not in results:
            self._load_masks(results)

        self._log_mask_stats()
        return results

    def __repr__(self) -> str:
        """Return the class name followed by the four loading switches."""
        return (f'{self.__class__.__name__}('
                f'with_bbox={self.with_bbox}, '
                f'with_mask={self.with_mask}, '
                f'with_seg={self.with_seg}, '
                f'poly2mask={self.poly2mask})')
|