File size: 6,810 Bytes
eb4d305
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
import torch
import torch.nn as nn
from mmdet.models.detectors import CascadeRCNN
from mmdet.registry import MODELS
from mmdet.structures import DetDataSample

@MODELS.register_module()
class CustomCascadeWithMeta(CascadeRCNN):
    """Cascade R-CNN extended with chart-metadata prediction heads.

    On top of the regular RPN / RoI pipeline, the detector average-pools the
    last feature map returned by ``extract_feat`` into one vector per image,
    regresses a data-point count from that vector, and feeds the vector (with
    the squashed count appended as one extra channel) to whichever optional
    metadata heads were configured.

    NOTE(review): ``forward_train`` / ``simple_test`` look like the
    MMDetection 2.x detector interface, whereas ``mmdet.registry`` is a 3.x
    module whose detectors are driven through ``loss`` / ``predict``. Confirm
    that these entry points are actually reached in the targeted version.

    Args:
        *args: Positional arguments forwarded to ``CascadeRCNN.__init__``.
        chart_cls_head: Optional config passed to ``MODELS.build``; the
            ``chart_cls_head`` attribute only exists when this is given.
        plot_reg_head: Same, for ``plot_reg_head``.
        axes_info_head: Same, for ``axes_info_head``.
        data_series_head: Same, for ``data_series_head``.
        data_points_count_head: Optional config for the count regressor.
            When ``None`` a small two-layer MLP is created instead.
        coordinate_standardization: Stored on the instance unchanged.
        data_series_config: Stored on the instance unchanged.
        axis_aware_feature: Stored on the instance unchanged.
        global_feat_channels: Input width of the *default* count head. It
            must equal the channel count of the last feature map returned by
            ``extract_feat``. Defaults to 2048, the previously hard-coded
            value; ignored when ``data_points_count_head`` is given.
        **kwargs: Keyword arguments forwarded to ``CascadeRCNN.__init__``.
    """

    # (attribute name, loss key used in training, result field used at test)
    _META_HEADS = (
        ('chart_cls_head', 'chart_cls_loss', 'chart_type'),
        ('plot_reg_head', 'plot_reg_loss', 'plot_bb'),
        ('axes_info_head', 'axes_info_loss', 'axes_info'),
        ('data_series_head', 'data_series_loss', 'data_series'),
    )

    # Divisor applied to the predicted count before the sigmoid squashing.
    _COUNT_SCALE = 100.0

    def __init__(self,
                 *args,
                 chart_cls_head=None,
                 plot_reg_head=None,
                 axes_info_head=None,
                 data_series_head=None,
                 data_points_count_head=None,
                 coordinate_standardization=None,
                 data_series_config=None,
                 axis_aware_feature=None,
                 global_feat_channels=2048,
                 **kwargs):
        super().__init__(*args, **kwargs)

        # Optional metadata heads: the attribute is created only when a
        # config was supplied, so ``hasattr`` keeps working for callers.
        optional_heads = {
            'chart_cls_head': chart_cls_head,
            'plot_reg_head': plot_reg_head,
            'axes_info_head': axes_info_head,
            'data_series_head': data_series_head,
        }
        for attr_name, head_cfg in optional_heads.items():
            if head_cfg is not None:
                setattr(self, attr_name, MODELS.build(head_cfg))

        # The count head always exists; fall back to a simple regressor.
        if data_points_count_head is not None:
            self.data_points_count_head = MODELS.build(data_points_count_head)
        else:
            self.data_points_count_head = nn.Sequential(
                nn.Linear(global_feat_channels, 512),
                nn.ReLU(),
                nn.Dropout(0.1),
                nn.Linear(512, 1),  # single scalar: the count
            )

        # Plain configuration values, kept for use by other components.
        self.coordinate_standardization = coordinate_standardization
        self.data_series_config = data_series_config
        self.axis_aware_feature = axis_aware_feature

    def _active_meta_heads(self):
        """Yield ``(head, loss_key, result_field)`` for every built head."""
        for attr_name, loss_key, result_field in self._META_HEADS:
            head = getattr(self, attr_name, None)
            if head is not None:
                yield head, loss_key, result_field

    def _predict_counts_and_features(self, x):
        """Pool ``x[-1]`` and derive the count prediction and head input.

        Args:
            x: Sequence of feature maps from ``extract_feat``; only the last
                one is used and it is averaged over dims 2 and 3.

        Returns:
            tuple: ``(pred_counts, enhanced_feat)`` where ``pred_counts`` has
            one value per image and ``enhanced_feat`` is the pooled feature
            with the squashed count concatenated as one extra channel.
        """
        pooled = x[-1].mean(dim=[2, 3])  # global average pooling
        pred_counts = self.data_points_count_head(pooled).squeeze(-1)
        # Sigmoid keeps the extra channel bounded in (0, 1); note that
        # non-negative counts land in [0.5, 1).
        squashed = torch.sigmoid(pred_counts / self._COUNT_SCALE)
        enhanced_feat = torch.cat([pooled, squashed.unsqueeze(-1)], dim=-1)
        return pred_counts, enhanced_feat

    def forward_train(self, img, img_metas, gt_bboxes, gt_labels, **kwargs):
        """Compute detection, count and metadata losses for one batch.

        Args:
            img: Batched input images, passed to ``extract_feat``.
            img_metas: Per-image meta dicts. The count target is read from
                ``img_meta['img_info']['num_data_points']`` and falls back to
                0 when either key is missing.
            gt_bboxes: Ground-truth boxes, forwarded to the RPN and RoI heads.
            gt_labels: Ground-truth labels, forwarded to the RoI head.
            **kwargs: Extra arguments forwarded to the RoI head. When the
                detector has no RPN, ``proposals`` is taken from here.

        Returns:
            dict: RPN and RoI losses plus ``data_points_count_loss`` and one
            entry per configured metadata head.
        """
        x = self.extract_feat(img)
        losses = dict()

        # Stage 1: proposals (and their loss) from the RPN, if there is one.
        if self.with_rpn:
            proposal_cfg = self.train_cfg.get('rpn_proposal',
                                              self.test_cfg.rpn)
            rpn_losses, proposal_list = self.rpn_head.forward_train(
                x,
                img_metas,
                gt_bboxes,
                gt_labels=None,
                ann_weight=None,
                proposal_cfg=proposal_cfg)
            losses.update(rpn_losses)
        else:
            # ``pop`` rather than ``get``: the proposals are handed to the RoI
            # head positionally below and must not also travel in ``kwargs``.
            proposal_list = kwargs.pop('proposals', None)

        # Stage 2: cascade RoI heads.
        roi_losses = self.roi_head.forward_train(x, img_metas, proposal_list,
                                                 gt_bboxes, gt_labels,
                                                 **kwargs)
        losses.update(roi_losses)

        pred_counts, enhanced_feat = self._predict_counts_and_features(x)

        # Count regression target, one scalar per image.
        gt_counts = torch.tensor(
            [
                meta.get('img_info', {}).get('num_data_points', 0)
                for meta in img_metas
            ],
            dtype=torch.float32,
            device=enhanced_feat.device)
        losses['data_points_count_loss'] = nn.functional.mse_loss(
            pred_counts, gt_counts)

        # NOTE(review): each head's return value is stored directly as its
        # loss, and no targets are passed in — presumably the heads compute
        # their own loss; verify against the head implementations.
        for head, loss_key, _ in self._active_meta_heads():
            losses[loss_key] = head(enhanced_feat)

        return losses

    def simple_test(self, img, img_metas, **kwargs):
        """Run single-scale inference and attach metadata predictions.

        Args:
            img: Batched input images, passed to ``extract_feat``.
            img_metas: Per-image meta dicts, forwarded to the RPN / RoI heads.
            **kwargs: Extra arguments forwarded to
                ``roi_head.simple_test_bboxes``.

        Returns:
            list: One ``DetDataSample`` per image carrying ``bboxes``,
            ``labels``, ``predicted_data_points`` and one field per
            configured metadata head.
        """
        x = self.extract_feat(img)
        proposal_list = self.rpn_head.simple_test_rpn(x, img_metas)
        det_bboxes, det_labels = self.roi_head.simple_test_bboxes(
            x, img_metas, proposal_list, self.test_cfg.rcnn, **kwargs)

        pred_counts, enhanced_feat = self._predict_counts_and_features(x)
        meta_heads = list(self._active_meta_heads())

        results = []
        for idx, (bboxes, labels) in enumerate(zip(det_bboxes, det_labels)):
            sample = DetDataSample()
            sample.bboxes = bboxes
            sample.labels = labels
            sample.predicted_data_points = pred_counts[idx].item()

            # Heads are run per image on a batch of one (slice keeps dim 0).
            image_feat = enhanced_feat[idx:idx + 1]
            for head, _, result_field in meta_heads:
                setattr(sample, result_field, head(image_feat))

            results.append(sample)

        return results