# (removed: "Spaces / Sleeping" web-page scrape residue — not valid Python)
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from mmdet.registry import MODELS | |
class FCHead(nn.Module):
    """Fully connected classification head with a self-attention stage.

    Runs multi-head self-attention over the input features, then a small
    two-layer MLP projecting down to ``num_classes`` logits.
    """

    def __init__(self, in_channels, num_classes, loss=None):
        super().__init__()
        # NOTE(review): in_channels must be divisible by 8 (num_heads).
        self.attention = nn.MultiheadAttention(in_channels, num_heads=8)
        self.fc1 = nn.Linear(in_channels, in_channels // 2)
        self.fc2 = nn.Linear(in_channels // 2, num_classes)
        self.loss = loss  # loss callable/config; stored, not used in forward

    def forward(self, x):
        """Return class logits for ``x`` (self-attention, then MLP)."""
        attended, _ = self.attention(x, x, x)
        hidden = F.relu(self.fc1(attended))
        return self.fc2(hidden)
class RegHead(nn.Module):
    """Regression head for coordinate prediction.

    A single linear projection from ``in_channels`` to ``out_dims``, with an
    optional self-attention stage and an optional auxiliary branch producing
    a 2-dim axis-orientation score.
    """

    def __init__(self, in_channels, out_dims, max_points=None, loss=None,
                 attention=False, use_axis_info=False):
        super().__init__()
        self.fc = nn.Linear(in_channels, out_dims)
        self.max_points = max_points  # kept for callers; not used internally
        self.loss = loss              # loss callable/config; not used here
        self.attention = attention
        self.use_axis_info = use_axis_info
        if attention:
            # NOTE(review): in_channels must be divisible by 8 (num_heads).
            self.attention_layer = nn.MultiheadAttention(in_channels, num_heads=8)
        if use_axis_info:
            # Auxiliary 2-dim output interpreted as x/y axis orientation.
            self.axis_orientation = nn.Linear(in_channels, 2)

    def compute_distance_loss(self, pred_points, gt_points):
        """Smooth-L1 of each predicted point's distance to its nearest GT point.

        NOTE(review): this is one-sided — ground-truth points with no nearby
        prediction are not penalized.
        """
        # Promote unbatched (N, 2) inputs to a batch of one.
        if pred_points.dim() == 2:
            pred_points = pred_points.unsqueeze(0)
        if gt_points.dim() == 2:
            gt_points = gt_points.unsqueeze(0)
        pairwise = torch.cdist(pred_points, gt_points)
        nearest = pairwise.min(dim=2).values
        # Smooth L1 against zero keeps the loss robust to outlier distances.
        return F.smooth_l1_loss(nearest, torch.zeros_like(nearest))

    def forward(self, x):
        """Return coordinate predictions (plus axis orientation if enabled)."""
        if self.attention:
            x, _ = self.attention_layer(x, x, x)
        pred = self.fc(x)
        if self.use_axis_info:
            return pred, self.axis_orientation(x)
        return pred
class CoordinateTransformer:
    """Transforms point coordinates between image space and axis space.

    Each row of ``axis_info`` is
    ``[x_min, x_max, y_min, y_max, x_origin, y_origin, x_scale, y_scale]``,
    where the min/max values bound the plot area in image coordinates and the
    origin/scale values describe the data axes.

    Fix: both methods are now ``@staticmethod``. They were previously plain
    functions with no ``self`` parameter, so calling them through an instance
    (as ``DataSeriesHead`` does via ``self.coord_transformer.to_axis_relative``)
    raised ``TypeError`` — the instance was bound to ``points``.
    """

    @staticmethod
    def to_axis_relative(points, axis_info):
        """Map image-space points to axis (data) coordinates.

        Args:
            points (torch.Tensor): Points in image coordinates; assumed
                (B, N, 2) so the (B,) axis fields broadcast — TODO confirm.
            axis_info (torch.Tensor): Axis description of shape (B, 8).

        Returns:
            torch.Tensor: Points in axis coordinates, same shape as ``points``.
        """
        x_min, x_max, y_min, y_max, x_origin, y_origin, x_scale, y_scale = axis_info.unbind(1)
        # Normalize pixel coordinates into the [0, 1] plot area...
        x_norm = (points[..., 0] - x_min) / (x_max - x_min)
        y_norm = (points[..., 1] - y_min) / (y_max - y_min)
        # ...then scale/offset into data units.
        x_axis = x_norm * x_scale + x_origin
        y_axis = y_norm * y_scale + y_origin
        return torch.stack([x_axis, y_axis], dim=-1)

    @staticmethod
    def to_image_coordinates(points, axis_info):
        """Inverse of :meth:`to_axis_relative`: axis (data) to image coords.

        Args:
            points (torch.Tensor): Points in axis coordinates (B, N, 2).
            axis_info (torch.Tensor): Axis description of shape (B, 8).

        Returns:
            torch.Tensor: Points in image coordinates, same shape as ``points``.
        """
        x_min, x_max, y_min, y_max, x_origin, y_origin, x_scale, y_scale = axis_info.unbind(1)
        # Undo the axis scale/offset to recover normalized coordinates...
        x_norm = (points[..., 0] - x_origin) / x_scale
        y_norm = (points[..., 1] - y_origin) / y_scale
        # ...then stretch back into the image-space plot area.
        x_img = x_norm * (x_max - x_min) + x_min
        y_img = y_norm * (y_max - y_min) + y_min
        return torch.stack([x_img, y_img], dim=-1)
class DataSeriesHead(nn.Module):
    """Head that predicts a data series as absolute and axis-relative points.

    Produces, per batch element: ``max_points`` absolute (image-space) points,
    ``max_points`` relative (axis-space) points, 5-way chart-pattern logits,
    and an optional monotonicity flag.
    """

    # Chart types whose y-values are expected to be monotone.
    _MONOTONE_CHART_TYPES = ('line', 'scatter')

    def __init__(self, in_channels, max_points=50, loss=None):
        super().__init__()
        self.max_points = max_points
        self.loss = loss  # loss callable/config; stored, not used directly
        # Shared feature reduction ahead of the two coordinate branches.
        self.fc1 = nn.Linear(in_channels, in_channels // 2)
        # Separate branches for absolute (image) and relative (axis) coords.
        self.absolute_branch = nn.Sequential(
            nn.Linear(in_channels // 2, in_channels // 4),
            nn.ReLU(),
            nn.Linear(in_channels // 4, max_points * 2),  # (x, y) per point
        )
        self.relative_branch = nn.Sequential(
            nn.Linear(in_channels // 2, in_channels // 4),
            nn.ReLU(),
            nn.Linear(in_channels // 4, max_points * 2),  # (x, y) per point
        )
        # Attention stages. NOTE(review): in_channels must be divisible by 8.
        self.coord_attention = nn.MultiheadAttention(in_channels, num_heads=8)
        self.axis_attention = nn.MultiheadAttention(in_channels, num_heads=8)
        self.sequence_attention = nn.MultiheadAttention(in_channels, num_heads=8)
        # Sequence-aware processing over the token dimension.
        self.sequence_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=in_channels,
                nhead=8,
                dim_feedforward=in_channels * 4,
                dropout=0.1,
            ),
            num_layers=2,
        )
        # Chart-pattern classifier (5 pattern classes).
        self.pattern_recognizer = nn.Sequential(
            nn.Linear(in_channels, in_channels // 2),
            nn.ReLU(),
            nn.Linear(in_channels // 2, 5),
        )
        # Kept for backward compatibility with existing callers.
        self.coord_transformer = CoordinateTransformer()

    def check_monotonicity(self, points, chart_type):
        """Return truthy iff y-values are entirely non-decreasing or
        non-increasing for monotone chart types; ``True`` otherwise.

        Note: returns a 0-d bool tensor for line/scatter and a Python bool
        for other chart types.
        """
        if chart_type in self._MONOTONE_CHART_TYPES:
            diffs = points[..., 1].diff()
            return torch.all(diffs >= 0) or torch.all(diffs <= 0)
        return True

    def forward(self, x, axis_info=None, chart_type=None):
        """Predict absolute points, relative points, pattern logits and an
        optional monotonicity flag.

        Args:
            x (torch.Tensor): token features; assumed (num_tokens,
                in_channels), matching the unbatched attention calls —
                TODO confirm against the caller.
            axis_info (torch.Tensor | None): per-batch axis description (B, 8).
            chart_type (str | None): chart type for the monotonicity check.
        """
        coord_feat, _ = self.coord_attention(x, x, x)
        if axis_info is not None:
            # Axis-conditioned attention branch, fused additively.
            axis_feat, _ = self.axis_attention(x, x, x)
            x = coord_feat + axis_feat
        else:
            x = coord_feat
        # Residual sequence attention.
        seq_feat, _ = self.sequence_attention(x, x, x)
        x = x + seq_feat
        # Fix: insert the batch dim at position 1 so the (S, B, E) transformer
        # attends across tokens; the original unsqueeze(0) made every token
        # its own batch entry with sequence length 1 (no mixing).
        x = self.sequence_encoder(x.unsqueeze(1)).squeeze(1)
        x = F.relu(self.fc1(x))
        # Reshape flat branch outputs to (batch, max_points, 2).
        absolute_points = self.absolute_branch(x).view(-1, self.max_points, 2)
        relative_points = self.relative_branch(x).view(-1, self.max_points, 2)
        if axis_info is not None:
            # Fix: call through the class — the helper takes no ``self``, so
            # an instance-bound call raised TypeError.
            relative_points = CoordinateTransformer.to_axis_relative(
                relative_points, axis_info)
        pattern_logits = self.pattern_recognizer(x)
        if chart_type is not None:
            monotonicity = self.check_monotonicity(absolute_points, chart_type)
        else:
            monotonicity = None
        return absolute_points, relative_points, pattern_logits, monotonicity

    @staticmethod
    def _as_batched(points):
        """Promote an unbatched (N, 2) point set to (1, N, 2)."""
        return points.unsqueeze(0) if points.dim() == 2 else points

    def _monotonicity_penalty(self, points, chart_type):
        """Differentiable penalty that is zero iff y-values are monotone.

        Measures violation of the better-fitting direction (non-decreasing
        vs. non-increasing) via mean ReLU of the wrong-signed differences.
        """
        if chart_type not in self._MONOTONE_CHART_TYPES:
            return points.new_zeros(())
        diffs = points[..., 1].diff(dim=-1)
        increasing_violation = F.relu(-diffs).mean()
        decreasing_violation = F.relu(diffs).mean()
        return torch.minimum(increasing_violation, decreasing_violation)

    def compute_loss(self, pred_absolute, pred_relative, gt_absolute, gt_relative,
                     pattern_logits, gt_pattern, axis_info=None, chart_type=None):
        """Compute the combined loss.

        total = absolute + relative + 0.5 * pattern + 0.3 * monotonicity.
        Note: ``pred_relative`` and ``gt_absolute`` only enter via the
        absolute term / shape normalization, mirroring the original wiring.
        """
        pred_absolute = self._as_batched(pred_absolute)
        pred_relative = self._as_batched(pred_relative)
        gt_absolute = self._as_batched(gt_absolute)
        gt_relative = self._as_batched(gt_relative)
        absolute_loss = self.compute_distance_loss(pred_absolute, gt_absolute)
        if axis_info is not None:
            # Compare predicted absolute points, mapped into axis space,
            # against the relative ground truth (class call: see forward).
            pred_in_axis = CoordinateTransformer.to_axis_relative(
                pred_absolute, axis_info)
            relative_loss = self.compute_distance_loss(pred_in_axis, gt_relative)
        else:
            relative_loss = pred_absolute.new_zeros(())
        pattern_loss = F.cross_entropy(pattern_logits, gt_pattern)
        if chart_type is not None:
            # Fix: the original applied binary_cross_entropy to the hard
            # boolean from check_monotonicity — non-differentiable, and it
            # crashed (AttributeError on bool.float()) for chart types other
            # than line/scatter. Use a soft violation penalty instead.
            monotonicity_loss = self._monotonicity_penalty(pred_absolute, chart_type)
        else:
            monotonicity_loss = pred_absolute.new_zeros(())
        return (absolute_loss + relative_loss +
                0.5 * pattern_loss + 0.3 * monotonicity_loss)

    def compute_distance_loss(self, pred_points, gt_points):
        """Smooth-L1 of each predicted point's distance to its nearest GT
        point. Expects batched (B, N, 2) inputs (see ``_as_batched``).

        NOTE(review): one-sided — unmatched ground-truth points are not
        penalized.
        """
        dist = torch.cdist(pred_points, gt_points)
        min_dist = dist.min(dim=2).values
        return F.smooth_l1_loss(min_dist, torch.zeros_like(min_dist))