# Source: Dense-Captioning-Platform/custom_models/custom_cascade_with_meta.py
# Author: hanszhu
# Commit eb4d305: build(space): initial Docker Space with Gradio app, MMDet, SAM integration
import torch
import torch.nn as nn
from mmdet.models.detectors import CascadeRCNN
from mmdet.registry import MODELS
from mmdet.structures import DetDataSample
@MODELS.register_module()
class CustomCascadeWithMeta(CascadeRCNN):
    """Cascade R-CNN extended with image-level metadata prediction heads.

    On top of the inherited detector, the model pools the last feature level
    returned by ``extract_feat`` into one vector per image, regresses a
    data-point count from it, appends a squashed version of that count to the
    pooled vector, and feeds the result to whichever optional metadata heads
    were configured (chart class, plot region, axes info, data series).

    NOTE(review): ``forward_train`` / ``simple_test`` look like MMDetection
    2.x-style entry points, while ``mmdet.registry`` is a 3.x module; confirm
    that the runner in use actually invokes these methods.
    """

    # (attribute name, loss key written in training, result field at test time)
    _META_HEADS = (
        ('chart_cls_head', 'chart_cls_loss', 'chart_type'),
        ('plot_reg_head', 'plot_reg_loss', 'plot_bb'),
        ('axes_info_head', 'axes_info_loss', 'axes_info'),
        ('data_series_head', 'data_series_loss', 'data_series'),
    )

    def __init__(self,
                 *args,
                 chart_cls_head=None,
                 plot_reg_head=None,
                 axes_info_head=None,
                 data_series_head=None,
                 data_points_count_head=None,
                 coordinate_standardization=None,
                 data_series_config=None,
                 axis_aware_feature=None,
                 count_head_in_channels=2048,
                 count_head_hidden_channels=512,
                 **kwargs):
        """Build the base detector and the optional metadata heads.

        Args:
            *args: Positional arguments forwarded to ``CascadeRCNN``.
            chart_cls_head: Config for the chart-classification head, or
                ``None`` to omit the head.
            plot_reg_head: Config for the plot-region head, or ``None``.
            axes_info_head: Config for the axes-info head, or ``None``.
            data_series_head: Config for the data-series head, or ``None``.
            data_points_count_head: Config for the data-point-count head.
                When ``None`` a small MLP regressor is created instead.
            coordinate_standardization: Stored on the instance unchanged.
            data_series_config: Stored on the instance unchanged.
            axis_aware_feature: Stored on the instance unchanged.
            count_head_in_channels: Input width of the default count head.
                Must equal the channel count of the last feature level
                returned by ``extract_feat``. The default of 2048 matches a
                bare ResNet-50 final stage; a neck may change this width
                (TODO confirm against the model config in use).
            count_head_hidden_channels: Hidden width of the default count
                head.
            **kwargs: Keyword arguments forwarded to ``CascadeRCNN``.
        """
        super().__init__(*args, **kwargs)

        # Optional heads are only attached when a config is given, so that
        # their absence can later be detected by attribute lookup.
        optional_head_cfgs = (
            ('chart_cls_head', chart_cls_head),
            ('plot_reg_head', plot_reg_head),
            ('axes_info_head', axes_info_head),
            ('data_series_head', data_series_head),
        )
        for attr_name, head_cfg in optional_head_cfgs:
            if head_cfg is not None:
                setattr(self, attr_name, MODELS.build(head_cfg))

        # The count head always exists: either built from config or a default
        # two-layer regressor producing a single scalar per image.
        if data_points_count_head is not None:
            self.data_points_count_head = MODELS.build(data_points_count_head)
        else:
            self.data_points_count_head = nn.Sequential(
                nn.Linear(count_head_in_channels, count_head_hidden_channels),
                nn.ReLU(),
                nn.Dropout(0.1),
                nn.Linear(count_head_hidden_channels, 1),
            )

        # Configurations are kept verbatim for use elsewhere.
        self.coordinate_standardization = coordinate_standardization
        self.data_series_config = data_series_config
        self.axis_aware_feature = axis_aware_feature

    def _predict_counts_and_features(self, x):
        """Return per-image count predictions and count-augmented features.

        Args:
            x: Sequence of feature maps as returned by ``extract_feat``; only
                the last element is used and it is averaged over dims 2 and 3.

        Returns:
            tuple: ``(pred_counts, enhanced_feat)`` where ``pred_counts`` is
            the count-head output with its last dim squeezed, and
            ``enhanced_feat`` is the pooled feature with one extra column
            holding ``sigmoid(pred_counts / 100)``.
        """
        # Global average pooling over the two trailing (spatial) dims.
        pooled = x[-1].mean(dim=[2, 3])
        pred_counts = self.data_points_count_head(pooled).squeeze(-1)
        # Squash the raw count into (0, 1); note that non-negative counts map
        # to [0.5, 1) under a sigmoid.
        squashed = torch.sigmoid(pred_counts / 100.0)
        enhanced_feat = torch.cat([pooled, squashed.unsqueeze(-1)], dim=-1)
        return pred_counts, enhanced_feat

    def forward_train(self, img, img_metas, gt_bboxes, gt_labels, **kwargs):
        """Forward function during training.

        Args:
            img: Input batch passed to ``extract_feat``.
            img_metas: One mapping per image. The ground-truth data-point
                count is read from ``img_meta['img_info']['num_data_points']``
                and defaults to 0 when either key is missing.
            gt_bboxes: Ground-truth boxes forwarded to the RPN and RoI heads.
            gt_labels: Ground-truth labels forwarded to the RoI head.
            **kwargs: Extra arguments forwarded to the RoI head; ``proposals``
                is read from here when the model has no RPN.

        Returns:
            dict: RPN losses (if any), RoI losses, ``data_points_count_loss``
            and one entry per configured metadata head.
        """
        x = self.extract_feat(img)
        losses = dict()

        # RPN forward and loss
        if self.with_rpn:
            proposal_cfg = self.train_cfg.get('rpn_proposal',
                                              self.test_cfg.rpn)
            rpn_losses, proposal_list = self.rpn_head.forward_train(
                x,
                img_metas,
                gt_bboxes,
                gt_labels=None,
                ann_weight=None,
                proposal_cfg=proposal_cfg)
            losses.update(rpn_losses)
        else:
            proposal_list = kwargs.get('proposals', None)

        # RoI forward and loss
        roi_losses = self.roi_head.forward_train(x, img_metas, proposal_list,
                                                 gt_bboxes, gt_labels,
                                                 **kwargs)
        losses.update(roi_losses)

        pred_counts, enhanced_feat = self._predict_counts_and_features(x)

        # Regress the data-point count against the per-image annotation.
        gt_counts = torch.tensor(
            [
                meta.get('img_info', {}).get('num_data_points', 0)
                for meta in img_metas
            ],
            dtype=torch.float32,
            device=enhanced_feat.device)
        losses['data_points_count_loss'] = nn.functional.mse_loss(
            pred_counts, gt_counts)

        # NOTE(review): each head's output is stored directly as a loss term,
        # which presumes the head returns a loss when called on features and
        # needs no targets -- confirm against the head implementations.
        for attr_name, loss_key, _ in self._META_HEADS:
            head = getattr(self, attr_name, None)
            if head is not None:
                losses[loss_key] = head(enhanced_feat)

        return losses

    def simple_test(self, img, img_metas, **kwargs):
        """Test without augmentation.

        Args:
            img: Input batch passed to ``extract_feat``.
            img_metas: Image meta information forwarded to the RPN and RoI
                heads.
            **kwargs: Extra arguments forwarded to
                ``roi_head.simple_test_bboxes``.

        Returns:
            list: One ``DetDataSample`` per image carrying ``bboxes``,
            ``labels``, ``predicted_data_points`` (a Python float) and, for
            every configured metadata head, that head's output on the image's
            count-augmented feature (``chart_type``, ``plot_bb``,
            ``axes_info``, ``data_series``).
        """
        x = self.extract_feat(img)
        proposal_list = self.rpn_head.simple_test_rpn(x, img_metas)
        det_bboxes, det_labels = self.roi_head.simple_test_bboxes(
            x, img_metas, proposal_list, self.test_cfg.rcnn, **kwargs)

        pred_counts, enhanced_feat = self._predict_counts_and_features(x)

        # Resolve the configured heads once instead of once per image.
        active_heads = [
            (field, getattr(self, attr_name))
            for attr_name, _, field in self._META_HEADS
            if hasattr(self, attr_name)
        ]

        results = []
        for i, (bboxes, labels) in enumerate(zip(det_bboxes, det_labels)):
            result = DetDataSample()
            result.bboxes = bboxes
            result.labels = labels
            result.predicted_data_points = pred_counts[i].item()
            # Slice (not index) so each head still sees a batch dimension.
            per_image_feat = enhanced_feat[i:i + 1]
            for field, head in active_heads:
                setattr(result, field, head(per_image_feat))
            results.append(result)
        return results