File size: 9,250 Bytes
eb4d305
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
import numpy as np
from typing import Dict, Optional
from mmcv.transforms.base import BaseTransform
from mmdet.registry import TRANSFORMS
from mmdet.datasets.transforms.loading import LoadAnnotations
import logging
from mmdet.structures.mask import BitmapMasks

# Module-level logger, named after this module, used by the transform below.
logger = logging.getLogger(__name__)

@TRANSFORMS.register_module()
class FlexibleLoadAnnotations(LoadAnnotations):
    """Annotation loader for datasets mixing masked and bbox-only instances.

    Before the parent loader runs, every entry of ``results['instances']`` is
    normalised so that it carries a ``'mask'`` key: polygon lists and RLE
    dicts are kept, ``{'x0': .., 'y0': .., ...}`` polygon dicts are converted
    to COCO-style segmentations, and instances without usable mask data get
    an empty list instead of being dropped.

    Running counters are kept in ``self.mask_stats`` (keys ``'total'``,
    ``'with_masks'``, ``'without_masks'``) and logged periodically.

    NOTE(review): the mmdet 3.x parent appears to flag instances that have no
    valid polygon as ignored (``ignore_flag = 1``) and to substitute a fake
    mask, so "bbox-only" instances may end up ignored downstream -- confirm
    this is the intended training behaviour.
    """

    # Number of processed instances between two statistics log lines.
    _STATS_LOG_INTERVAL = 1000
    # A polygon needs at least 3 points, i.e. 6 coordinates.
    _MIN_POLYGON_COORDS = 6

    def __init__(self,
                 with_bbox: bool = True,
                 with_mask: bool = True,
                 with_seg: bool = False,
                 poly2mask: bool = True,
                 **kwargs):
        """Forward all options to ``LoadAnnotations`` and reset the counters.

        Args:
            with_bbox: Passed through to the parent loader.
            with_mask: Passed through to the parent loader; also gates
                :meth:`_load_masks`.
            with_seg: Passed through to the parent loader.
            poly2mask: Passed through to the parent loader.
            **kwargs: Any further keyword arguments of the parent loader.
        """
        super().__init__(
            with_bbox=with_bbox,
            with_mask=with_mask,
            with_seg=with_seg,
            poly2mask=poly2mask,
            **kwargs
        )
        self.mask_stats = {'total': 0, 'with_masks': 0, 'without_masks': 0}

    @classmethod
    def _polygon_dict_to_segmentation(cls, polygon: dict,
                                      num_points: int = 4) -> Optional[list]:
        """Convert ``{'x0': .., 'y0': .., ...}`` into ``[[x0, y0, ...]]``.

        Points whose ``x``/``y`` key pair is incomplete are skipped.

        Args:
            polygon: Mapping with ``x{j}``/``y{j}`` keys.
            num_points: How many point indices to look for (the data is
                assumed to hold 4-point polygons).

        Returns:
            A single-polygon segmentation list, or ``None`` when fewer than
            three complete points were found.
        """
        coords = []
        for j in range(num_points):
            x_key, y_key = f'x{j}', f'y{j}'
            if x_key in polygon and y_key in polygon:
                coords.extend([polygon[x_key], polygon[y_key]])
        if len(coords) >= cls._MIN_POLYGON_COORDS:
            return [coords]
        return None

    def _normalize_instance_mask(self, instance: dict) -> bool:
        """Ensure ``instance['mask']`` exists in a form the parent accepts.

        Args:
            instance: One entry of ``results['instances']``; modified in
                place.

        Returns:
            ``True`` when the instance ends up with usable mask data,
            ``False`` when it was given an empty placeholder mask.
        """
        segmentation = instance.get('mask') or instance.get('segmentation')

        # RLE-style dict: hand it over untouched instead of discarding it.
        # presumably the parent validates 'counts'/'size' itself -- confirm.
        if isinstance(segmentation, dict) and segmentation:
            instance['mask'] = segmentation
            return (segmentation.get('counts') is not None
                    and segmentation.get('size') is not None)

        if isinstance(segmentation, list) and segmentation:
            first = segmentation[0]
            if isinstance(first, list):
                # Nested format [[x1, y1, ...], ...]: usable as soon as any
                # part has enough coordinates, not only the first one.
                usable = any(
                    isinstance(part, (list, tuple))
                    and len(part) >= self._MIN_POLYGON_COORDS
                    for part in segmentation)
                instance['mask'] = segmentation if usable else []
                return usable
            if isinstance(first, (int, float, np.integer, np.floating)):
                # Flat format [x1, y1, ...]: wrap it so that it becomes a
                # list of polygons like the nested format.
                usable = len(segmentation) >= self._MIN_POLYGON_COORDS
                instance['mask'] = [segmentation] if usable else []
                return usable
            # Unknown element type: pass through unchanged.
            instance['mask'] = segmentation
            return True

        # No segmentation at all: try the polygon dict representation.
        polygon = instance.get('polygon')
        if polygon and isinstance(polygon, dict):
            converted = self._polygon_dict_to_segmentation(polygon)
            if converted is not None:
                instance['segmentation'] = converted
                instance['mask'] = converted
                return True

        # Empty placeholder prevents a KeyError on 'mask' in the parent.
        instance['mask'] = []
        return False

    def _prepare_legacy_ann_info(self, ann_info) -> bool:
        """Populate ``ann_info['masks']`` from segmentation/polygon data.

        Args:
            ann_info: Value of ``results['ann_info']``; ignored unless it is
                a dict.

        Returns:
            ``True`` when ``ann_info['masks']`` was set.
        """
        if not isinstance(ann_info, dict):
            return False

        segmentation = ann_info.get('segmentation')
        if segmentation and isinstance(segmentation, list):
            ann_info['masks'] = segmentation
            return True

        polygon = ann_info.get('polygon')
        if polygon and isinstance(polygon, dict):
            converted = self._polygon_dict_to_segmentation(polygon)
            if converted is not None:
                ann_info['segmentation'] = converted
                ann_info['masks'] = converted
                return True
        return False

    def _load_masks(self, results: dict) -> dict:
        """Load mask annotations, tolerating instances without masks.

        Args:
            results: Pipeline result dict; modified in place.

        Returns:
            The same ``results`` object. It is returned unchanged when
            ``with_mask`` is disabled or ``results`` is not a dict.
        """
        if not self.with_mask or not isinstance(results, dict):
            return results

        # Legacy ``ann_info`` layout.
        # NOTE(review): presumably only consumed by 2.x-style parents --
        # confirm whether this path is still needed.
        legacy_masks = self._prepare_legacy_ann_info(results.get('ann_info'))

        instances = results.get('instances')
        if isinstance(instances, list):
            # Keep ALL instances; only their 'mask' field is normalised.
            for instance in instances:
                self.mask_stats['total'] += 1
                if self._normalize_instance_mask(instance):
                    self.mask_stats['with_masks'] += 1
                else:
                    self.mask_stats['without_masks'] += 1

            if instances:
                # The parent works in place; its return value is not relied
                # upon because it may be ``None``.
                super()._load_masks(results)
            else:
                # Nothing to load: build an empty mask container.
                shape = results.get('img_shape')
                if shape is None:
                    shape = results.get('ori_shape')
                # Slice so that an (h, w, c) shape does not break unpacking.
                h, w = tuple(shape)[:2] if shape is not None else (0, 0)
                results['gt_masks'] = BitmapMasks([], h, w)
                results['gt_ignore_flags'] = np.array([], dtype=bool)
            return results

        # No instance list: fall back to segmentation stored directly on
        # ``results`` (or to the legacy ann_info prepared above).
        segmentation = results.get('segmentation')
        if segmentation and isinstance(segmentation, list):
            results['masks'] = segmentation
            legacy_masks = True

        if legacy_masks:
            super()._load_masks(results)
        return results

    def transform(self, results: dict) -> dict:
        """Load all requested annotations into ``results``.

        Args:
            results: Pipeline result dict.

        Returns:
            The updated result dict, or an empty dict when ``results`` is not
            a dict.
        """
        # ensure we always return a dict
        if not isinstance(results, dict):
            logger.error("Expected dict, got %s", type(results))
            return {}

        seen_before = self.mask_stats['total']

        # The parent transform dispatches to the overridden ``_load_masks``,
        # so masks must not be loaded a second time here (that would double
        # both the work and the statistics).
        results = super().transform(results)

        # Safety net in case a parent version does not call ``_load_masks``
        # -- TODO confirm against the installed mmdet version.
        if (self.with_mask and isinstance(results, dict)
                and 'gt_masks' not in results):
            results = self._load_masks(results)

        # Log once every time the running total crosses a multiple of the
        # interval (never while nothing has been processed yet).
        seen_after = self.mask_stats['total']
        interval = self._STATS_LOG_INTERVAL
        if seen_after // interval > seen_before // interval:
            logger.info(
                "Mask stats - total: %s, with_masks: %s, without_masks: %s",
                seen_after,
                self.mask_stats['with_masks'],
                self.mask_stats['without_masks'])

        return results

    def __repr__(self) -> str:
        """Return the class name followed by the four loader flags."""
        return (f'{self.__class__.__name__}('
                f'with_bbox={self.with_bbox}, '
                f'with_mask={self.with_mask}, '
                f'with_seg={self.with_seg}, '
                f'poly2mask={self.poly2mask})')