# Custom Cascade R-CNN detector with chart-metadata prediction heads.
import torch
import torch.nn as nn

from mmdet.models.detectors import CascadeRCNN
from mmdet.registry import MODELS
from mmdet.structures import DetDataSample
class CustomCascadeWithMeta(CascadeRCNN):
    """Cascade R-CNN extended with per-image chart-metadata heads.

    On top of the standard detection pipeline, this detector predicts
    per-image metadata -- chart type, plot bounding box, axes info, data
    series, and the number of data points -- from a globally pooled
    feature of the top level of ``extract_feat``'s output.  The predicted
    data-point count is normalized and appended to the pooled feature
    before it is fed to the other metadata heads.

    NOTE(review): ``forward_train``/``simple_test`` follow the mmdet 2.x
    detector API while ``MODELS`` comes from the mmdet 3.x registry --
    confirm which mmdet version this project actually targets.
    """

    # (head attribute, training loss key, test result attribute) for the
    # optional metadata heads, shared by the train and test paths.
    _META_HEADS = (
        ('chart_cls_head', 'chart_cls_loss', 'chart_type'),
        ('plot_reg_head', 'plot_reg_loss', 'plot_bb'),
        ('axes_info_head', 'axes_info_loss', 'axes_info'),
        ('data_series_head', 'data_series_loss', 'data_series'),
    )

    def __init__(self,
                 *args,
                 chart_cls_head=None,
                 plot_reg_head=None,
                 axes_info_head=None,
                 data_series_head=None,
                 data_points_count_head=None,
                 coordinate_standardization=None,
                 data_series_config=None,
                 axis_aware_feature=None,
                 **kwargs):
        """Build the detector and its optional metadata heads.

        Args:
            chart_cls_head (dict, optional): Registry config for the chart
                classification head.
            plot_reg_head (dict, optional): Registry config for the plot
                bounding-box regression head.
            axes_info_head (dict, optional): Registry config for the axes
                information head.
            data_series_head (dict, optional): Registry config for the
                data-series head.
            data_points_count_head (dict, optional): Registry config for
                the data-point count head.  When omitted, a small MLP
                regressor is created instead.
            coordinate_standardization (optional): Stored as-is; not used
                by the code visible in this file.
            data_series_config (optional): Stored as-is; not used here.
            axis_aware_feature (optional): Stored as-is; not used here.
        """
        super().__init__(*args, **kwargs)
        # Optional heads are only created when configured; the forward
        # paths guard on their presence with ``hasattr``.
        if chart_cls_head is not None:
            self.chart_cls_head = MODELS.build(chart_cls_head)
        if plot_reg_head is not None:
            self.plot_reg_head = MODELS.build(plot_reg_head)
        if axes_info_head is not None:
            self.axes_info_head = MODELS.build(axes_info_head)
        if data_series_head is not None:
            self.data_series_head = MODELS.build(data_series_head)
        if data_points_count_head is not None:
            self.data_points_count_head = MODELS.build(data_points_count_head)
        else:
            # Fallback count regressor.  NOTE(review): the 2048 input size
            # assumes the pooled feature is a ResNet-50/101 C5 map; with an
            # FPN neck, ``x[-1]`` carries the neck's out_channels (often
            # 256) -- confirm against the model config.
            self.data_points_count_head = nn.Sequential(
                nn.Linear(2048, 512),
                nn.ReLU(),
                nn.Dropout(0.1),
                nn.Linear(512, 1))  # single scalar: the point count
        # Configurations stored verbatim for downstream consumers.
        self.coordinate_standardization = coordinate_standardization
        self.data_series_config = data_series_config
        self.axis_aware_feature = axis_aware_feature

    def _enhanced_global_feat(self, x):
        """Pool the top feature map and append the predicted count.

        Shared by ``forward_train`` and ``simple_test`` (previously this
        logic was duplicated in both).

        Args:
            x (tuple[Tensor]): Multi-level features from ``extract_feat``.

        Returns:
            tuple[Tensor, Tensor]: ``(pred_counts, enhanced_feat)`` with
            shapes ``(N,)`` and ``(N, C + 1)`` respectively.
        """
        # Global average pooling over the spatial dims of the last level.
        global_feat = x[-1].mean(dim=[2, 3])
        pred_counts = self.data_points_count_head(global_feat).squeeze(-1)
        # Squash the unbounded count into (0, 1); the divisor 100 sets the
        # soft scale at which "many points" saturates.
        normalized = torch.sigmoid(pred_counts / 100.0)
        enhanced = torch.cat([global_feat, normalized.unsqueeze(-1)], dim=-1)
        return pred_counts, enhanced

    def forward_train(self, img, img_metas, gt_bboxes, gt_labels, **kwargs):
        """Compute all training losses.

        Args:
            img (Tensor): Input batch of shape ``(N, C, H, W)``.
            img_metas (list[dict]): Per-image meta info; ground-truth
                data-point counts are read from
                ``img_meta['img_info']['num_data_points']`` (0 if absent).
            gt_bboxes (list[Tensor]): Ground-truth boxes per image.
            gt_labels (list[Tensor]): Ground-truth labels per image.

        Returns:
            dict: Detection losses plus the metadata-head losses.
        """
        x = self.extract_feat(img)
        losses = dict()
        # RPN forward and loss.
        if self.with_rpn:
            proposal_cfg = self.train_cfg.get('rpn_proposal',
                                              self.test_cfg.rpn)
            # NOTE(review): ``ann_weight`` is not a stock mmdet RPNHead
            # argument; this presumes a customized RPN head -- confirm.
            rpn_losses, proposal_list = self.rpn_head.forward_train(
                x,
                img_metas,
                gt_bboxes,
                gt_labels=None,
                ann_weight=None,
                proposal_cfg=proposal_cfg)
            losses.update(rpn_losses)
        else:
            proposal_list = kwargs.get('proposals', None)
        # ROI forward and loss.
        roi_losses = self.roi_head.forward_train(x, img_metas, proposal_list,
                                                 gt_bboxes, gt_labels,
                                                 **kwargs)
        losses.update(roi_losses)
        # Metadata branch: count regression plus the optional heads.
        pred_counts, enhanced_feat = self._enhanced_global_feat(x)
        gt_counts = torch.tensor(
            [meta.get('img_info', {}).get('num_data_points', 0)
             for meta in img_metas],
            dtype=torch.float32,
            device=pred_counts.device)
        # Functional form avoids re-instantiating an nn.MSELoss module on
        # every training step (was ``nn.MSELoss()(...)``).
        losses['data_points_count_loss'] = nn.functional.mse_loss(
            pred_counts, gt_counts)
        # Each optional head is assumed to return its loss directly when
        # called on the enhanced feature during training.
        for head_name, loss_key, _ in self._META_HEADS:
            if hasattr(self, head_name):
                losses[loss_key] = getattr(self, head_name)(enhanced_feat)
        return losses

    def simple_test(self, img, img_metas, **kwargs):
        """Test without augmentation.

        Returns:
            list[DetDataSample]: One sample per image carrying detection
            boxes/labels, the predicted data-point count, and the outputs
            of whichever metadata heads are configured.
        """
        x = self.extract_feat(img)
        proposal_list = self.rpn_head.simple_test_rpn(x, img_metas)
        det_bboxes, det_labels = self.roi_head.simple_test_bboxes(
            x, img_metas, proposal_list, self.test_cfg.rcnn, **kwargs)
        pred_counts, enhanced_feat = self._enhanced_global_feat(x)
        results = []
        for i, (bboxes, labels) in enumerate(zip(det_bboxes, det_labels)):
            # ``DetDataSample`` was previously an undefined name here; it
            # is now imported from ``mmdet.structures`` at the file top.
            result = DetDataSample()
            result.bboxes = bboxes
            result.labels = labels
            result.predicted_data_points = pred_counts[i].item()
            feat_i = enhanced_feat[i:i + 1]  # keep the batch dimension
            for head_name, _, attr in self._META_HEADS:
                if hasattr(self, head_name):
                    setattr(result, attr, getattr(self, head_name)(feat_i))
            results.append(result)
        return results