import torch
import torch.nn as nn
from mmdet.models.detectors import CascadeRCNN
from mmdet.registry import MODELS
from mmdet.structures import DetDataSample


@MODELS.register_module()
class CustomCascadeWithMeta(CascadeRCNN):
    """Cascade R-CNN extended with image-level metadata prediction heads.

    On top of the base detector this model pools the last feature map
    returned by ``extract_feat`` into one vector per image, regresses a
    data-point count from it, and feeds the pooled vector (augmented with
    the squashed count) to up to four optional metadata heads.

    NOTE(review): ``forward_train`` / ``simple_test`` look like the
    MMDetection 2.x hook names, while ``mmdet.registry`` and
    ``DetDataSample`` come from MMDetection 3.x. Confirm that the runner in
    use actually calls these two methods.
    """

    # (attribute name, training loss key, test-time result field) for every
    # optional metadata head. A head takes part only if it was configured.
    _META_HEADS = (
        ('chart_cls_head', 'chart_cls_loss', 'chart_type'),
        ('plot_reg_head', 'plot_reg_loss', 'plot_bb'),
        ('axes_info_head', 'axes_info_loss', 'axes_info'),
        ('data_series_head', 'data_series_loss', 'data_series'),
    )

    def __init__(self,
                 *args,
                 chart_cls_head=None,
                 plot_reg_head=None,
                 axes_info_head=None,
                 data_series_head=None,
                 data_points_count_head=None,
                 coordinate_standardization=None,
                 data_series_config=None,
                 axis_aware_feature=None,
                 count_head_in_channels=2048,
                 **kwargs):
        """Build the base detector and the metadata heads.

        Args:
            *args: Forwarded to ``CascadeRCNN.__init__``.
            chart_cls_head: Optional config built with ``MODELS.build``.
            plot_reg_head: Optional config built with ``MODELS.build``.
            axes_info_head: Optional config built with ``MODELS.build``.
            data_series_head: Optional config built with ``MODELS.build``.
            data_points_count_head: Optional config for the count
                regressor. When ``None`` a small two-layer MLP is created.
            coordinate_standardization: Stored on the instance unchanged.
            data_series_config: Stored on the instance unchanged.
            axis_aware_feature: Stored on the instance unchanged.
            count_head_in_channels: Input width of the default count MLP.
                It must equal the channel count of the last feature map
                returned by ``extract_feat``. Defaults to 2048, the value
                that was previously hard-coded.
            **kwargs: Forwarded to ``CascadeRCNN.__init__``.
        """
        super().__init__(*args, **kwargs)

        # Optional heads are only attached when configured, so that
        # ``hasattr(self, '<head name>')`` keeps signalling their presence.
        optional_head_cfgs = {
            'chart_cls_head': chart_cls_head,
            'plot_reg_head': plot_reg_head,
            'axes_info_head': axes_info_head,
            'data_series_head': data_series_head,
        }
        for head_name, head_cfg in optional_head_cfgs.items():
            if head_cfg is not None:
                setattr(self, head_name, MODELS.build(head_cfg))

        if data_points_count_head is not None:
            self.data_points_count_head = MODELS.build(data_points_count_head)
        else:
            # Default simple regression head producing one scalar per image.
            self.data_points_count_head = nn.Sequential(
                nn.Linear(count_head_in_channels, 512),
                nn.ReLU(),
                nn.Dropout(0.1),
                nn.Linear(512, 1),
            )

        # Store configurations
        self.coordinate_standardization = coordinate_standardization
        self.data_series_config = data_series_config
        self.axis_aware_feature = axis_aware_feature

    def _global_meta_features(self, x):
        """Return ``(pred_counts, enhanced_feat)`` for a batch of features.

        The last entry of ``x`` is averaged over dims 2 and 3 to get one
        vector per image. The count head maps that vector to a scalar, and
        the scalar (divided by 100 and passed through a sigmoid) is appended
        to the pooled vector as one extra channel.
        """
        pooled = x[-1].mean(dim=[2, 3])  # global average pooling
        pred_counts = self.data_points_count_head(pooled).squeeze(-1)
        # Sigmoid bounds the extra channel to (0, 1); a non-negative count
        # therefore maps into [0.5, 1).
        squashed = torch.sigmoid(pred_counts / 100.0)
        enhanced_feat = torch.cat([pooled, squashed.unsqueeze(-1)], dim=-1)
        return pred_counts, enhanced_feat

    def forward_train(self, img, img_metas, gt_bboxes, gt_labels, **kwargs):
        """Compute detection, count and metadata losses for one batch.

        Args:
            img: Input batch passed to ``extract_feat``.
            img_metas: Per-image meta dicts. The count target is read from
                ``img_meta['img_info']['num_data_points']`` and falls back
                to 0 when either key is missing.
            gt_bboxes: Ground-truth boxes forwarded to the RPN and RoI head.
            gt_labels: Ground-truth labels forwarded to the RoI head.
            **kwargs: Forwarded to the RoI head. ``proposals`` is read from
                here when the model has no RPN.

        Returns:
            dict: RPN and RoI losses plus ``data_points_count_loss`` and one
            entry per configured metadata head.
        """
        x = self.extract_feat(img)
        losses = dict()

        # RPN forward and loss
        if self.with_rpn:
            proposal_cfg = self.train_cfg.get('rpn_proposal',
                                              self.test_cfg.rpn)
            rpn_losses, proposal_list = self.rpn_head.forward_train(
                x,
                img_metas,
                gt_bboxes,
                gt_labels=None,
                ann_weight=None,
                proposal_cfg=proposal_cfg)
            losses.update(rpn_losses)
        else:
            proposal_list = kwargs.get('proposals', None)

        # ROI forward and loss
        roi_losses = self.roi_head.forward_train(x, img_metas, proposal_list,
                                                 gt_bboxes, gt_labels,
                                                 **kwargs)
        losses.update(roi_losses)

        pred_counts, enhanced_feat = self._global_meta_features(x)

        # Ground-truth data point counts, one per image.
        gt_counts = torch.tensor(
            [
                img_meta.get('img_info', {}).get('num_data_points', 0)
                for img_meta in img_metas
            ],
            dtype=torch.float32,
            device=enhanced_feat.device)
        losses['data_points_count_loss'] = nn.functional.mse_loss(
            pred_counts, gt_counts)

        # NOTE(review): each head's raw output is stored as a loss and no
        # targets are passed in; confirm the heads really return loss
        # tensors in training mode.
        for head_name, loss_key, _ in self._META_HEADS:
            head = getattr(self, head_name, None)
            if head is not None:
                losses[loss_key] = head(enhanced_feat)

        return losses

    def simple_test(self, img, img_metas, **kwargs):
        """Run detection and metadata prediction without augmentation.

        Args:
            img: Input batch passed to ``extract_feat``.
            img_metas: Per-image meta dicts forwarded to the RPN/RoI heads.
            **kwargs: Forwarded to ``roi_head.simple_test_bboxes``.

        Returns:
            list[DetDataSample]: One sample per image carrying ``bboxes``,
            ``labels``, ``predicted_data_points`` and the output of every
            configured metadata head (``chart_type``, ``plot_bb``,
            ``axes_info``, ``data_series``).
        """
        x = self.extract_feat(img)
        proposal_list = self.rpn_head.simple_test_rpn(x, img_metas)
        det_bboxes, det_labels = self.roi_head.simple_test_bboxes(
            x, img_metas, proposal_list, self.test_cfg.rcnn, **kwargs)

        pred_counts, enhanced_feat = self._global_meta_features(x)
        # One host transfer for the whole batch instead of one per image.
        count_values = pred_counts.tolist()

        # Resolve the configured heads once, outside the per-image loop.
        active_heads = [
            (field, getattr(self, head_name))
            for head_name, _, field in self._META_HEADS
            if hasattr(self, head_name)
        ]

        results = []
        for idx, (bboxes, labels) in enumerate(zip(det_bboxes, det_labels)):
            sample = DetDataSample()
            sample.bboxes = bboxes
            sample.labels = labels
            sample.predicted_data_points = count_values[idx]

            # Keep the batch dimension (size 1) when calling each head.
            image_feat = enhanced_feat[idx:idx + 1]
            for field, head in active_heads:
                setattr(sample, field, head(image_feat))

            results.append(sample)

        return results