# Captured from a Hugging Face Space page (Space status at capture time: Sleeping).
import numpy as np | |
from typing import Dict, Optional | |
from mmcv.transforms.base import BaseTransform | |
from mmdet.registry import TRANSFORMS | |
from mmdet.datasets.transforms.loading import LoadAnnotations | |
import logging | |
from mmdet.structures.mask import BitmapMasks | |
logger = logging.getLogger(__name__) | |
# NOTE(review): TRANSFORMS is imported in this file but this class is not
# decorated with @TRANSFORMS.register_module(); confirm it is registered
# elsewhere before referring to it as type='FlexibleLoadAnnotations' in a config.
class FlexibleLoadAnnotations(LoadAnnotations):
    """Annotation loader for datasets that mix mask-annotated and box-only instances.

    Every instance is kept. Before delegating to ``LoadAnnotations``, each
    instance's ``'mask'`` entry is normalised to a form the parent's mask
    processing accepts:

    * a list of polygons ``[[x1, y1, x2, y2, ...], ...]`` is kept as-is when at
      least one part has three or more points;
    * a single flat polygon ``[x1, y1, x2, y2, ...]`` is wrapped in a list;
    * a COCO RLE dict (with ``'counts'`` and ``'size'``) is kept as-is;
    * otherwise an ``instance['polygon']`` dict with ``x0``/``y0``,
      ``x1``/``y1``, ... keys is converted to a single polygon;
    * anything else becomes ``[]`` so the parent never hits a missing
      ``'mask'`` key.

    NOTE(review): in mmdet 3.x the parent flags an instance whose mask list is
    empty with ``ignore_flag = 1``, which may exclude its box from training as
    well. Confirm this matches the "keep for bbox training" intent.

    ``mask_stats`` counts instances seen with and without a usable mask. It is
    logged each time the running total crosses a multiple of
    ``stats_log_interval``.
    """

    # Log mask statistics whenever the instance total crosses a multiple of
    # this value (<= 0 disables the periodic log).
    stats_log_interval: int = 1000

    def __init__(self,
                 with_bbox: bool = True,
                 with_mask: bool = True,
                 with_seg: bool = False,
                 poly2mask: bool = True,
                 **kwargs):
        super().__init__(
            with_bbox=with_bbox,
            with_mask=with_mask,
            with_seg=with_seg,
            poly2mask=poly2mask,
            **kwargs
        )
        # Running per-instance counters, accumulated across transform calls.
        self.mask_stats = {'total': 0, 'with_masks': 0, 'without_masks': 0}

    @staticmethod
    def _polygon_dict_to_coords(polygon: dict) -> Optional[list]:
        """Flatten ``{'x0': .., 'y0': .., 'x1': .., ...}`` into ``[x0, y0, x1, y1, ...]``.

        Points are read in index order. An index missing either coordinate is
        skipped. Returns ``None`` when a coordinate is not numeric or fewer
        than three points (six coordinates) are found.
        """
        coords = []
        # Always probe the first 4 indices (the original 4-point layout);
        # dicts that carry more points are scanned up to len(polygon).
        for j in range(max(4, len(polygon))):
            x_key = f'x{j}'
            y_key = f'y{j}'
            if x_key in polygon and y_key in polygon:
                try:
                    coords.extend([float(polygon[x_key]), float(polygon[y_key])])
                except (TypeError, ValueError) as e:
                    logger.debug(f"Polygon conversion failed: {e}")
                    return None
        return coords if len(coords) >= 6 else None

    @staticmethod
    def _normalize_segmentation(segmentation) -> Optional[object]:
        """Return ``segmentation`` in a parent-compatible form, or ``None`` if unusable."""
        if isinstance(segmentation, dict):
            # COCO RLE: passed through for the parent to decode.
            if 'counts' in segmentation and 'size' in segmentation:
                return segmentation
            return None
        if not isinstance(segmentation, list) or not segmentation:
            return None
        if isinstance(segmentation[0], (int, float, np.number)):
            # A single flat polygon; the parent expects a list of polygons.
            return [segmentation] if len(segmentation) >= 6 else None
        # A list of polygon parts: usable if any part has >= 3 points.
        # The per-part filtering itself is left to the parent.
        if any(hasattr(part, '__len__') and len(part) >= 6
               for part in segmentation):
            return segmentation
        return None

    def _mirror_ann_info_masks(self, ann_info: dict) -> None:
        """Populate ``ann_info['masks']`` (and ``'segmentation'``) in place.

        Uses a non-empty list ``ann_info['segmentation']`` when present,
        otherwise a convertible ``ann_info['polygon']`` dict.

        NOTE(review): mmdet 3.x ``LoadAnnotations`` appears to read masks from
        ``results['instances']``, not ``ann_info``. This mirroring is kept for
        any downstream consumer; confirm whether it is still needed.
        """
        segmentation = ann_info.get('segmentation')
        if segmentation and isinstance(segmentation, list):
            ann_info['masks'] = segmentation
            return
        polygon = ann_info.get('polygon')
        if polygon and isinstance(polygon, dict):
            coords = self._polygon_dict_to_coords(polygon)
            if coords is not None:
                segmentation = [coords]
                ann_info['segmentation'] = segmentation
                ann_info['masks'] = segmentation

    def _normalize_instance_mask(self, instance: dict) -> None:
        """Set ``instance['mask']`` to a parent-compatible value and update ``mask_stats``."""
        self.mask_stats['total'] += 1
        # COCO loading may store the annotation under 'mask' or 'segmentation'.
        mask = self._normalize_segmentation(
            instance.get('mask') or instance.get('segmentation'))
        if mask is None:
            polygon = instance.get('polygon')
            if polygon and isinstance(polygon, dict):
                coords = self._polygon_dict_to_coords(polygon)
                if coords is not None:
                    mask = [coords]
                    instance['segmentation'] = mask
        if mask is None:
            # Keep the instance; an empty mask avoids a KeyError on
            # instance['mask'] in the parent's mask processing.
            instance['mask'] = []
            self.mask_stats['without_masks'] += 1
        else:
            instance['mask'] = mask
            self.mask_stats['with_masks'] += 1

    def _load_masks(self, results: dict) -> dict:
        """Normalise mask annotations, then let the parent build ``gt_masks``.

        In mmdet 3.x, ``LoadAnnotations.transform`` dispatches to this
        override when ``with_mask`` is set. The parent ``_load_masks`` is
        called only for its in-place effect on ``results``. ``results`` itself
        is always returned, never the parent's return value, so callers that
        use the return value do not receive ``None``.

        Args:
            results: Sample dict; may contain ``'instances'`` (list of dicts),
                ``'ann_info'`` (dict) or a top-level ``'segmentation'`` list.

        Returns:
            The same ``results`` dict, updated in place.
        """
        if not self.with_mask or not isinstance(results, dict):
            return results

        ann_info = results.get('ann_info')
        if isinstance(ann_info, dict):
            self._mirror_ann_info_masks(ann_info)

        instances = results.get('instances')
        if isinstance(instances, list):
            # Every instance is kept; only its 'mask' entry is normalised.
            for instance in instances:
                self._normalize_instance_mask(instance)
            if instances:
                super()._load_masks(results)  # parent modifies results in place
            else:
                # No instances: create an empty mask structure.
                h, w = results.get('img_shape', (0, 0))
                results['gt_masks'] = BitmapMasks([], h, w)
                results['gt_ignore_flags'] = np.array([], dtype=bool)
            return results

        # Fallback: direct segmentation at the top level of results.
        segmentation = results.get('segmentation')
        if segmentation and isinstance(segmentation, list):
            results['masks'] = segmentation
            super()._load_masks(results)  # parent modifies results in place
        return results

    def _log_stats_if_due(self, seen_before: int) -> None:
        """Log ``mask_stats`` if the instance total crossed a multiple of ``stats_log_interval``."""
        interval = self.stats_log_interval
        total = self.mask_stats['total']
        if interval > 0 and total // interval > seen_before // interval:
            logger.info("Mask stats - total: %s, with_masks: %s, without_masks: %s",
                        total, self.mask_stats['with_masks'],
                        self.mask_stats['without_masks'])

    def transform(self, results: dict) -> dict:
        """Load annotations for one sample and periodically log mask statistics.

        Returns ``{}`` (after logging an error) when ``results`` is not a dict.
        """
        if not isinstance(results, dict):
            logger.error(f"Expected dict, got {type(results)}")
            return {}
        seen_before = self.mask_stats['total']
        # The parent transform loads boxes/labels/segmentation maps and, in
        # mmdet 3.x, calls self._load_masks (the override above) when
        # with_mask is set. Masks are therefore not loaded a second time here;
        # doing so double-counted mask_stats and could replace results with None.
        results = super().transform(results)
        self._log_stats_if_due(seen_before)
        return results

    def __repr__(self) -> str:
        """String representation listing the main loader switches."""
        return (f'{self.__class__.__name__}('
                f'with_bbox={self.with_bbox}, '
                f'with_mask={self.with_mask}, '
                f'with_seg={self.with_seg}, '
                f'poly2mask={self.poly2mask})')