Dense-Captioning-Platform / custom_models /flexible_load_annotations.py
hanszhu's picture
build(space): initial Docker Space with Gradio app, MMDet, SAM integration
eb4d305
import numpy as np
from typing import Dict, Optional
from mmcv.transforms.base import BaseTransform
from mmdet.registry import TRANSFORMS
from mmdet.datasets.transforms.loading import LoadAnnotations
import logging
from mmdet.structures.mask import BitmapMasks
logger = logging.getLogger(__name__)
@TRANSFORMS.register_module()
class FlexibleLoadAnnotations(LoadAnnotations):
    """Annotation loader for datasets mixing masked and bbox-only instances.

    Every instance is kept. An instance without a usable mask gets an empty
    ``mask`` list so that the parent loader still finds the key, and running
    counters are maintained in ``mask_stats``.

    Mask sources tried for each instance, first usable one wins:

    * ``instance['mask']`` then ``instance['segmentation']``, each either a
      list of polygons ``[[x1, y1, x2, y2, ...], ...]``, a single flat
      polygon ``[x1, y1, x2, y2, ...]`` (wrapped into a one-element list), or
      an RLE dict carrying ``counts`` and ``size``;
    * ``instance['polygon']`` given as a corner dict
      ``{'x0': .., 'y0': .., 'x1': .., ...}``.

    NOTE(review): MMDetection's ``LoadAnnotations._process_masks`` appears to
    set ``ignore_flag = 1`` on instances whose mask list is empty -- confirm
    that bbox-only instances still contribute to training as intended.

    Args:
        with_bbox: Forwarded to ``LoadAnnotations``.
        with_mask: Forwarded to ``LoadAnnotations``; when false,
            ``_load_masks`` is a no-op.
        with_seg: Forwarded to ``LoadAnnotations``.
        poly2mask: Forwarded to ``LoadAnnotations``.
        **kwargs: Any further ``LoadAnnotations`` options.
    """

    # Number of processed instances between two "Mask stats" log lines.
    STATS_LOG_INTERVAL = 1000
    # A polygon needs at least 3 points, i.e. 6 flat coordinates.
    MIN_POLYGON_COORDS = 6
    # Corner dicts are read from the keys x0/y0 .. x3/y3.
    MAX_POLYGON_POINTS = 4

    def __init__(self,
                 with_bbox: bool = True,
                 with_mask: bool = True,
                 with_seg: bool = False,
                 poly2mask: bool = True,
                 **kwargs):
        super().__init__(
            with_bbox=with_bbox,
            with_mask=with_mask,
            with_seg=with_seg,
            poly2mask=poly2mask,
            **kwargs
        )
        # Per-object counters; 'total' counts instances, not images.
        self.mask_stats = {'total': 0, 'with_masks': 0, 'without_masks': 0}

    @classmethod
    def _polygon_to_segmentation(cls, polygon) -> list:
        """Convert a corner dict into ``[[x0, y0, x1, y1, ...]]``.

        Returns an empty list when ``polygon`` is not a non-empty dict or
        when fewer than three complete (x, y) pairs are present.
        """
        if not isinstance(polygon, dict) or not polygon:
            return []
        coords = []
        for idx in range(cls.MAX_POLYGON_POINTS):
            x_key, y_key = f'x{idx}', f'y{idx}'
            # Only complete points are used; a lone x or y is skipped.
            if x_key in polygon and y_key in polygon:
                coords.extend((polygon[x_key], polygon[y_key]))
        if len(coords) < cls.MIN_POLYGON_COORDS:
            return []
        return [coords]

    @classmethod
    def _normalize_mask(cls, mask) -> Optional[object]:
        """Return ``mask`` in nested-polygon or RLE form, or ``None``.

        ``None`` means the value cannot be used as a mask (wrong type, empty,
        or no polygon part with an even number of at least six coordinates).
        """
        if isinstance(mask, dict):
            # RLE annotation: keep untouched as long as both fields exist.
            if mask.get('counts') is not None and mask.get('size') is not None:
                return mask
            return None
        if not isinstance(mask, list) or not mask:
            return None
        if all(isinstance(v, (int, float, np.number)) for v in mask):
            # Flat polygon: wrap it, because a mask is a *list* of polygons.
            parts = [mask]
        else:
            parts = mask
        usable = [
            part for part in parts
            if isinstance(part, (list, tuple, np.ndarray))
            and len(part) >= cls.MIN_POLYGON_COORDS
            and len(part) % 2 == 0
        ]
        return usable or None

    def _resolve_instance_mask(self, instance: dict) -> Optional[object]:
        """Pick the first usable mask source of ``instance``.

        When the mask is derived from ``instance['polygon']`` the converted
        polygon is also stored under ``instance['segmentation']``.
        """
        for key in ('mask', 'segmentation'):
            usable = self._normalize_mask(instance.get(key))
            if usable is not None:
                return usable
        segmentation = self._polygon_to_segmentation(instance.get('polygon'))
        if segmentation:
            instance['segmentation'] = segmentation
            return segmentation
        return None

    def _load_masks(self, results: dict) -> dict:
        """Normalise mask annotations and let the parent build ``gt_masks``.

        The parent ``_load_masks`` works in place, so its return value is
        never propagated; this method always returns ``results`` itself.

        Args:
            results: Pipeline result dict. Looked-up keys: ``ann_info``,
                ``instances``, ``segmentation`` and ``img_shape``.

        Returns:
            The same ``results`` object (unchanged when ``with_mask`` is
            false or ``results`` is not a dict).
        """
        if not self.with_mask or not isinstance(results, dict):
            return results

        # Legacy layout: a single ``ann_info`` dict.
        ann_info = results.get('ann_info')
        if isinstance(ann_info, dict):
            segmentation = ann_info.get('segmentation')
            if not (isinstance(segmentation, list) and segmentation):
                segmentation = self._polygon_to_segmentation(
                    ann_info.get('polygon'))
                if segmentation:
                    ann_info['segmentation'] = segmentation
            if segmentation:
                ann_info['masks'] = segmentation
                super()._load_masks(results)
                return results

        # Per-instance layout: a list of instance dicts.
        instances = results.get('instances')
        if isinstance(instances, list):
            for instance in instances:
                self.mask_stats['total'] += 1
                mask = self._resolve_instance_mask(instance)
                if mask is None:
                    # Empty list keeps the 'mask' key present for the parent.
                    instance['mask'] = []
                    self.mask_stats['without_masks'] += 1
                else:
                    instance['mask'] = mask
                    self.mask_stats['with_masks'] += 1

            if instances:
                super()._load_masks(results)
            else:
                # No instances at all: emit empty mask structures.
                shape = results.get('img_shape')
                # Slice so that an (h, w, c) shape is accepted as well.
                h, w = (0, 0) if shape is None else tuple(shape)[:2]
                results['gt_masks'] = BitmapMasks([], h, w)
                results['gt_ignore_flags'] = np.array([], dtype=bool)
            return results

        # Last resort: a segmentation stored directly on ``results``.
        segmentation = results.get('segmentation')
        if isinstance(segmentation, list) and segmentation:
            results['masks'] = segmentation
            super()._load_masks(results)
        return results

    def transform(self, results: dict) -> dict:
        """Load annotations and periodically log the mask statistics.

        Args:
            results: Pipeline result dict.

        Returns:
            The populated ``results`` dict, or an empty dict when the input
            is not a dict.
        """
        if not isinstance(results, dict):
            logger.error("Expected dict, got %s", type(results))
            return {}

        seen_before = self.mask_stats['total']

        # MMDetection's LoadAnnotations.transform already dispatches to the
        # overridden _load_masks when with_mask is set.
        results = super().transform(results)

        # Fallback for a parent that did not load masks; skipped otherwise so
        # instances are neither processed nor counted twice.
        if self.with_mask and 'gt_masks' not in results:
            self._load_masks(results)

        # Log whenever the instance counter crosses a multiple of the
        # interval; the counter moves by several instances per image, so an
        # exact-multiple test would almost never match.
        interval = self.STATS_LOG_INTERVAL
        seen_after = self.mask_stats['total']
        if seen_after // interval > seen_before // interval:
            logger.info(
                "Mask stats - total: %s, with_masks: %s, without_masks: %s",
                seen_after,
                self.mask_stats['with_masks'],
                self.mask_stats['without_masks'],
            )
        return results

    def __repr__(self) -> str:
        """Return the class name followed by the four loading switches."""
        return (f'{self.__class__.__name__}('
                f'with_bbox={self.with_bbox}, '
                f'with_mask={self.with_mask}, '
                f'with_seg={self.with_seg}, '
                f'poly2mask={self.poly2mask})')