# custom_faster_rcnn_with_meta.py - Faster R-CNN with coordinate handling for chart data
import torch
import torch.nn as nn
from mmdet.models.detectors.faster_rcnn import FasterRCNN

try:
    from mmdet.registry import MODELS  # MMDetection 3.x registry
except ImportError:
    # Fall back to the MMDetection 2.x registry; the forward_train/simple_test
    # hooks below are written against the 2.x detector interface.
    from mmdet.models.builder import MODELS

@MODELS.register_module()
class CustomFasterRCNNWithMeta(FasterRCNN):
"""Faster R-CNN with coordinate standardization for chart detection."""
def __init__(self,
*args,
coordinate_standardization=None,
data_points_count_head=None,
**kwargs):
super().__init__(*args, **kwargs)
# Store coordinate standardization settings
self.coord_std = coordinate_standardization or {}
# Initialize data points count head
if data_points_count_head is not None:
self.data_points_count_head = MODELS.build(data_points_count_head)
else:
            # Default regression head for the data point count. The input
            # width of 2048 assumes the final feature map comes straight from
            # a ResNet-50 backbone (C5, 2048 channels); with an FPN neck the
            # feature channels are typically 256, so pass an explicit head
            # config in that case.
            self.data_points_count_head = nn.Sequential(
                nn.Linear(2048, 512),
                nn.ReLU(),
                nn.Dropout(0.1),
                nn.Linear(512, 1)  # single scalar: predicted count
            )
print(f"🎯 CustomFasterRCNNWithMeta initialized with coordinate handling:")
print(f" • Enabled: {self.coord_std.get('enabled', False)}")
print(f" • Origin: {self.coord_std.get('origin', 'top_left')}")
print(f" • Normalize: {self.coord_std.get('normalize', False)}")
print(f" • Data points count prediction: Enabled")
    def transform_coordinates(self, coords, img_shape, plot_bb=None, axes_info=None):
        """Transform coordinates based on the standardization settings.

        ``plot_bb`` and ``axes_info`` are accepted for future axis-aware
        transforms but are not used yet.
        """
        if not self.coord_std.get('enabled', False):
            return coords
# Get image dimensions
img_height, img_width = img_shape[-2:]
        # Convert to a float tensor; clone so the caller's tensor is not
        # mutated by the in-place y-flip below
        if not isinstance(coords, torch.Tensor):
            coords = torch.tensor(coords, dtype=torch.float32)
        else:
            coords = coords.clone().float()
        # Ensure coords is an (N, 2) array of (x, y) pairs
        if coords.dim() == 1:
            coords = coords.view(-1, 2)
        # Normalize coordinates to [0, 1] if requested
        if self.coord_std.get('normalize', True):
            coords = coords / torch.tensor([img_width, img_height], device=coords.device)
        # Convert a bottom-left origin (the usual chart-axis convention) to
        # the image's top-left origin by flipping the y-axis; the flip must
        # happen in whichever space (normalized or pixel) the coords are in
        if self.coord_std.get('origin', 'bottom_left') == 'bottom_left':
            if self.coord_std.get('normalize', True):
                coords[:, 1] = 1.0 - coords[:, 1]
            else:
                coords[:, 1] = img_height - coords[:, 1]
# Convert back to pixel coordinates
if self.coord_std.get('normalize', True):
coords = coords * torch.tensor([img_width, img_height], device=coords.device)
return coords
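    # Worked example for transform_coordinates (assuming enabled=True,
    # normalize=True, origin='bottom_left'): a point (100, 50) in a 200x200
    # image normalizes to (0.5, 0.25), the y-flip gives (0.5, 0.75), and
    # denormalizing returns (100.0, 150.0).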
def forward_train(self,
img,
img_metas,
gt_bboxes,
gt_labels,
gt_bboxes_ignore=None,
**kwargs):
"""Forward function during training with coordinate transformation."""
# Transform ground truth bboxes if coordinate standardization is enabled
if self.coord_std.get('enabled', False) and gt_bboxes is not None:
transformed_gt_bboxes = []
for i, bboxes in enumerate(gt_bboxes):
if len(bboxes) > 0:
# Convert bbox format for transformation
# MMDet uses [x1, y1, x2, y2] format
centers = torch.stack([
(bboxes[:, 0] + bboxes[:, 2]) / 2, # center_x
(bboxes[:, 1] + bboxes[:, 3]) / 2 # center_y
], dim=1)
                    # Transform centers using the per-image shape from the
                    # metas (the batched tensor's shape may include padding)
                    img_h, img_w = img_metas[i]['img_shape'][:2]
                    transformed_centers = self.transform_coordinates(
                        centers, (img_h, img_w),
                        plot_bb=img_metas[i].get('plot_bb'),
                        axes_info=img_metas[i].get('axes_info')
                    )
# Reconstruct bboxes with transformed centers
widths = bboxes[:, 2] - bboxes[:, 0]
heights = bboxes[:, 3] - bboxes[:, 1]
transformed_bboxes = torch.stack([
transformed_centers[:, 0] - widths / 2, # x1
transformed_centers[:, 1] - heights / 2, # y1
transformed_centers[:, 0] + widths / 2, # x2
transformed_centers[:, 1] + heights / 2 # y2
], dim=1)
transformed_gt_bboxes.append(transformed_bboxes)
else:
transformed_gt_bboxes.append(bboxes)
gt_bboxes = transformed_gt_bboxes
# Call parent forward_train with transformed coordinates to get losses
losses = super().forward_train(
img, img_metas, gt_bboxes, gt_labels, gt_bboxes_ignore, **kwargs)
# Extract features for data point count prediction
x = self.extract_feat(img)
global_feat = x[-1].mean(dim=[2, 3]) # Global average pooling
# Extract ground truth data point counts from img_metas
gt_data_point_counts = []
for img_meta in img_metas:
count = img_meta.get('img_info', {}).get('num_data_points', 0)
gt_data_point_counts.append(count)
gt_data_point_counts = torch.tensor(gt_data_point_counts, dtype=torch.float32, device=global_feat.device)
        # Predict data point counts and add an MSE regression loss
        pred_data_point_counts = self.data_points_count_head(global_feat).squeeze(-1)
        losses['data_points_count_loss'] = nn.functional.mse_loss(
            pred_data_point_counts, gt_data_point_counts)
return losses
def simple_test(self, img, img_metas, proposals=None, rescale=False):
"""Simple test function with coordinate inverse transformation."""
# Get predictions from parent
results = super().simple_test(img, img_metas, proposals, rescale)
# Extract features for data point count prediction
x = self.extract_feat(img)
global_feat = x[-1].mean(dim=[2, 3]) # Global average pooling
# Predict data point counts
pred_data_point_counts = self.data_points_count_head(global_feat).squeeze(-1)
        # Attach data point count predictions to results. Depending on the
        # MMDet version, a result may be a DataSample (3.x) or a per-class
        # list of arrays (2.x); plain lists match neither branch, in which
        # case the count is silently dropped.
        if results is not None:
for i, result in enumerate(results):
if hasattr(result, 'pred_instances'):
result.pred_instances.predicted_data_points = pred_data_point_counts[i].item()
elif hasattr(result, 'bboxes'):
# For older MMDet versions, add as additional attribute
result.predicted_data_points = pred_data_point_counts[i].item()
        # Note: no inverse coordinate transform is applied at test time; the
        # coordinate convention is kept consistent between training and
        # inference, so predictions are already in top-left image coordinates.
        return results
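

if __name__ == '__main__':
    # Minimal smoke test for the coordinate transform alone (a sketch:
    # building the full detector would require a complete MMDet config, so
    # we call the method with a stand-in object carrying only coord_std).
    from types import SimpleNamespace

    stub = SimpleNamespace(
        coord_std={'enabled': True, 'normalize': True, 'origin': 'bottom_left'})
    pts = torch.tensor([[100.0, 50.0]])
    out = CustomFasterRCNNWithMeta.transform_coordinates(stub, pts, (3, 200, 200))
    print(out)  # expected: tensor([[100., 150.]])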