# custom_faster_rcnn_with_meta.py - Faster R-CNN with coordinate handling for chart data
import torch
import torch.nn as nn
from mmdet.models.detectors.faster_rcnn import FasterRCNN
from mmdet.registry import MODELS
@MODELS.register_module()
class CustomFasterRCNNWithMeta(FasterRCNN):
"""Faster R-CNN with coordinate standardization for chart detection."""
def __init__(self,
*args,
coordinate_standardization=None,
data_points_count_head=None,
**kwargs):
super().__init__(*args, **kwargs)
# Store coordinate standardization settings
self.coord_std = coordinate_standardization or {}
# Initialize data points count head
if data_points_count_head is not None:
self.data_points_count_head = MODELS.build(data_points_count_head)
        else:
            # Default: a simple MLP regression head for the data point count.
            # NOTE: in_features=2048 assumes raw ResNet-50 C5 features; with a
            # standard FPN neck, extract_feat() yields 256-channel maps, so
            # supply a matching head via the config in that case.
            self.data_points_count_head = nn.Sequential(
                nn.Linear(2048, 512),
                nn.ReLU(),
                nn.Dropout(0.1),
                nn.Linear(512, 1)  # single scalar output: the predicted count
            )
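        # A custom head can instead be supplied from the config: anything
        # buildable by MMDet's MODELS registry is accepted. A hypothetical
        # sketch (the type name is illustrative, not a real MMDet module):
        #   data_points_count_head=dict(type='MyCountHead', in_channels=256)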
print(f"🎯 CustomFasterRCNNWithMeta initialized with coordinate handling:")
print(f" • Enabled: {self.coord_std.get('enabled', False)}")
print(f" • Origin: {self.coord_std.get('origin', 'top_left')}")
print(f" • Normalize: {self.coord_std.get('normalize', False)}")
print(f" • Data points count prediction: Enabled")
    def transform_coordinates(self, coords, img_shape, plot_bb=None, axes_info=None):
        """Transform coordinates based on standardization settings."""
        if not self.coord_std.get('enabled', False):
            return coords
        # Get image dimensions; works for (H, W) tuples and (N, C, H, W) shapes
        img_height, img_width = img_shape[-2:]
        # Convert to a float tensor and copy so the caller's data is untouched
        if not isinstance(coords, torch.Tensor):
            coords = torch.tensor(coords, dtype=torch.float32)
        coords = coords.float().clone()
        # Ensure coords is 2D: (N, 2) rows of (x, y)
        if coords.dim() == 1:
            coords = coords.view(-1, 2)
        scale = torch.tensor([img_width, img_height],
                             device=coords.device, dtype=coords.dtype)
        normalize = self.coord_std.get('normalize', True)
        # Normalize coordinates to [0, 1] if requested
        if normalize:
            coords = coords / scale
        # Convert a bottom-left origin (common in chart data) to MMDet's
        # top-left convention by flipping y in whichever space we are in:
        # 1 - y for normalized coordinates, img_height - y for pixels
        if self.coord_std.get('origin', 'top_left') == 'bottom_left':
            coords[:, 1] = (1.0 if normalize else img_height) - coords[:, 1]
        # Convert back to pixel coordinates
        if normalize:
            coords = coords * scale
        return coords
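    # Worked example for the transform above (a sketch; numbers assume
    # coordinate_standardization=dict(enabled=True, origin='bottom_left',
    # normalize=True) and a 200x100 (H x W) image):
    #   (x=10, y=20) --normalize--> (0.10, 0.10)
    #                --flip y-----> (0.10, 0.90)
    #                --rescale----> (10.0, 180.0)
    # A point 20 px above a bottom-left origin lands 180 px below the
    # top-left origin of a 200 px tall image, as expected.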
def forward_train(self,
img,
img_metas,
gt_bboxes,
gt_labels,
gt_bboxes_ignore=None,
**kwargs):
"""Forward function during training with coordinate transformation."""
# Transform ground truth bboxes if coordinate standardization is enabled
if self.coord_std.get('enabled', False) and gt_bboxes is not None:
transformed_gt_bboxes = []
for i, bboxes in enumerate(gt_bboxes):
if len(bboxes) > 0:
# Convert bbox format for transformation
# MMDet uses [x1, y1, x2, y2] format
centers = torch.stack([
(bboxes[:, 0] + bboxes[:, 2]) / 2, # center_x
(bboxes[:, 1] + bboxes[:, 3]) / 2 # center_y
], dim=1)
# Transform centers
                    # Use this image's own (H, W) from img_metas rather than
                    # the padded batch tensor shape
                    img_shape = img_metas[i]['img_shape'][:2]
transformed_centers = self.transform_coordinates(
centers, img_shape,
plot_bb=img_metas[i].get('plot_bb'),
axes_info=img_metas[i].get('axes_info')
)
                    # Reconstruct bboxes with the transformed centers. This is
                    # exact for a pure y-flip: the center maps to H - c_y while
                    # width/height are preserved, which yields precisely the
                    # flipped box [x1, H - y2, x2, H - y1].
widths = bboxes[:, 2] - bboxes[:, 0]
heights = bboxes[:, 3] - bboxes[:, 1]
transformed_bboxes = torch.stack([
transformed_centers[:, 0] - widths / 2, # x1
transformed_centers[:, 1] - heights / 2, # y1
transformed_centers[:, 0] + widths / 2, # x2
transformed_centers[:, 1] + heights / 2 # y2
], dim=1)
transformed_gt_bboxes.append(transformed_bboxes)
else:
transformed_gt_bboxes.append(bboxes)
gt_bboxes = transformed_gt_bboxes
# Call parent forward_train with transformed coordinates to get losses
losses = super().forward_train(
img, img_metas, gt_bboxes, gt_labels, gt_bboxes_ignore, **kwargs)
# Extract features for data point count prediction
x = self.extract_feat(img)
global_feat = x[-1].mean(dim=[2, 3]) # Global average pooling
# Extract ground truth data point counts from img_metas
gt_data_point_counts = []
for img_meta in img_metas:
count = img_meta.get('img_info', {}).get('num_data_points', 0)
gt_data_point_counts.append(count)
gt_data_point_counts = torch.tensor(gt_data_point_counts, dtype=torch.float32, device=global_feat.device)
# Predict data point counts and compute loss
pred_data_point_counts = self.data_points_count_head(global_feat).squeeze(-1)
data_points_count_loss = nn.MSELoss()(pred_data_point_counts, gt_data_point_counts)
losses['data_points_count_loss'] = data_points_count_loss
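        # In MMDet 2.x, BaseDetector._parse_losses sums every dict entry whose
        # key contains 'loss', so the key name above is enough to fold this
        # term (unweighted) into the total objective; scale it here if needed.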
return losses
def simple_test(self, img, img_metas, proposals=None, rescale=False):
"""Simple test function with coordinate inverse transformation."""
# Get predictions from parent
results = super().simple_test(img, img_metas, proposals, rescale)
# Extract features for data point count prediction
x = self.extract_feat(img)
global_feat = x[-1].mean(dim=[2, 3]) # Global average pooling
# Predict data point counts
pred_data_point_counts = self.data_points_count_head(global_feat).squeeze(-1)
        # Attach the per-image count prediction to the results. MMDet 3.x
        # returns DetDataSample objects (first branch); some result objects
        # expose bboxes directly (second branch). Plain 2.x per-class bbox
        # lists match neither branch, so the count is silently dropped there.
        if results is not None:
            for i, result in enumerate(results):
                if hasattr(result, 'pred_instances'):
                    result.pred_instances.predicted_data_points = pred_data_point_counts[i].item()
                elif hasattr(result, 'bboxes'):
                    result.predicted_data_points = pred_data_point_counts[i].item()
        # Inverse coordinate transform is deliberately omitted at test time:
        # the standardization is assumed to be applied consistently during
        # training, so predictions are already in the expected frame.
return results
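
# Minimal usage sketch (assumptions flagged inline): how this detector might
# be referenced from an MMDetection config. Only the custom keys added by this
# class are shown; the backbone/neck/rpn/roi settings are the stock Faster
# R-CNN ones and are elided.
#
# model = dict(
#     type='CustomFasterRCNNWithMeta',
#     coordinate_standardization=dict(
#         enabled=True,             # turn the transform on
#         origin='bottom_left',     # chart annotations often use this origin
#         normalize=True),          # flip y in normalized [0, 1] space
#     data_points_count_head=None,  # None -> use the built-in MLP head
#     ...)  # plus the usual backbone/neck/rpn_head/roi_head settings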