# Page-scrape residue from the hosting UI (author caption, commit message, commit hash):
# hanszhu's picture
# build(space): initial Docker Space with Gradio app, MMDet, SAM integration
# eb4d305
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmdet.registry import MODELS
@MODELS.register_module()
class FCHead(nn.Module):
    """Fully connected classification head with a self-attention front end.

    Input features are first refined by multi-head self-attention, then
    projected to class logits through a two-layer MLP.

    Args:
        in_channels (int): feature dimensionality; must be divisible by 8
            (the attention head count).
        num_classes (int): number of output logits.
        loss: loss config/callable, stored as-is for the training code to use.
    """

    def __init__(self, in_channels, num_classes, loss=None):
        super().__init__()
        # 8-head self-attention over the input features.
        self.attention = nn.MultiheadAttention(in_channels, num_heads=8)
        self.fc1 = nn.Linear(in_channels, in_channels // 2)
        self.fc2 = nn.Linear(in_channels // 2, num_classes)
        self.loss = loss

    def forward(self, x):
        """Return class logits for ``x`` (attention, then ReLU MLP)."""
        attended, _ = self.attention(x, x, x)
        hidden = F.relu(self.fc1(attended))
        return self.fc2(hidden)
@MODELS.register_module()
class RegHead(nn.Module):
    """Regression head for coordinate prediction.

    Optionally refines features with self-attention and can emit an extra
    2-way axis-orientation prediction alongside the coordinates.

    Args:
        in_channels (int): input feature dimensionality.
        out_dims (int): number of regressed values per sample.
        max_points (int | None): stored as-is; not used inside this head.
        loss: loss config/callable, stored as-is for the training code.
        attention (bool): when True, apply 8-head self-attention first.
        use_axis_info (bool): when True, also predict axis orientation.
    """

    def __init__(self, in_channels, out_dims, max_points=None, loss=None,
                 attention=False, use_axis_info=False):
        super().__init__()
        self.fc = nn.Linear(in_channels, out_dims)
        self.max_points = max_points
        self.loss = loss
        self.attention = attention
        self.use_axis_info = use_axis_info
        if attention:
            self.attention_layer = nn.MultiheadAttention(in_channels, num_heads=8)
        if use_axis_info:
            # Two logits — presumably x- vs y-axis orientation; TODO confirm.
            self.axis_orientation = nn.Linear(in_channels, 2)

    def compute_distance_loss(self, pred_points, gt_points):
        """Smooth-L1 penalty on each predicted point's distance to its
        nearest ground-truth point (zero distance is the target)."""
        # Promote unbatched (N, 2) inputs to (1, N, 2).
        if pred_points.dim() == 2:
            pred_points = pred_points.unsqueeze(0)
        if gt_points.dim() == 2:
            gt_points = gt_points.unsqueeze(0)
        # For every predicted point, the distance to the closest GT point.
        nearest = torch.cdist(pred_points, gt_points).min(dim=2).values
        return F.smooth_l1_loss(nearest, torch.zeros_like(nearest))

    def forward(self, x):
        """Predict coordinates (plus axis orientation when enabled)."""
        if self.attention:
            x, _ = self.attention_layer(x, x, x)
        pred = self.fc(x)
        if not self.use_axis_info:
            return pred
        return pred, self.axis_orientation(x)
class CoordinateTransformer:
    """Convert points between image coordinates and axis (data) coordinates.

    ``axis_info`` packs, along its last dimension, the eight values
    ``[x_min, x_max, y_min, y_max, x_origin, y_origin, x_scale, y_scale]``
    describing the plotted axis span in image space (min/max) and its
    data-space mapping (origin/scale).
    """

    @staticmethod
    def _unbind_axis_info(axis_info, points):
        """Split ``axis_info`` into its 8 components, reshaped so each one
        broadcasts against ``points[..., 0]`` / ``points[..., 1]``.

        Fix: the original ``axis_info.unbind(1)`` left each component with
        shape ``(B,)``, which for batched points ``(B, N, 2)`` broadcast
        against the point dimension ``N`` instead of the batch dimension —
        wrong (or a shape error) whenever ``B > 1`` and ``N != B``. Adding
        trailing singleton dims aligns the batch axes; the unbatched / B=1
        cases behave as before.
        """
        components = axis_info.unbind(-1)  # same as unbind(1) for (B, 8)
        extra = points.dim() - axis_info.dim()  # dims points has beyond (B, 2)
        if extra > 0:
            components = tuple(
                c.reshape(c.shape + (1,) * extra) for c in components
            )
        return components

    @staticmethod
    def to_axis_relative(points, axis_info):
        """Transform points from image coordinates to axis (data) units.

        Args:
            points (torch.Tensor): image-space points, ``(B, N, 2)`` (or
                ``(N, 2)`` paired with ``axis_info`` of shape ``(1, 8)``).
            axis_info (torch.Tensor): ``(B, 8)``; see the class docstring.

        Returns:
            torch.Tensor: points in axis units, same shape as ``points``.
        """
        (x_min, x_max, y_min, y_max, x_origin, y_origin,
         x_scale, y_scale) = CoordinateTransformer._unbind_axis_info(axis_info, points)
        # Normalize pixel positions to [0, 1] within the axis span...
        x_norm = (points[..., 0] - x_min) / (x_max - x_min)
        y_norm = (points[..., 1] - y_min) / (y_max - y_min)
        # ...then map into data units.
        x_axis = x_norm * x_scale + x_origin
        y_axis = y_norm * y_scale + y_origin
        return torch.stack([x_axis, y_axis], dim=-1)

    @staticmethod
    def to_image_coordinates(points, axis_info):
        """Inverse of :meth:`to_axis_relative`: axis units -> image coords."""
        (x_min, x_max, y_min, y_max, x_origin, y_origin,
         x_scale, y_scale) = CoordinateTransformer._unbind_axis_info(axis_info, points)
        # Data units -> normalized [0, 1] coordinates...
        x_norm = (points[..., 0] - x_origin) / x_scale
        y_norm = (points[..., 1] - y_origin) / y_scale
        # ...then back to pixels.
        x_img = x_norm * (x_max - x_min) + x_min
        y_img = y_norm * (y_max - y_min) + y_min
        return torch.stack([x_img, y_img], dim=-1)
@MODELS.register_module()
class DataSeriesHead(nn.Module):
    """Head for predicting a data series (a sequence of chart points).

    From one feature vector per sample it produces:
      * absolute point coordinates (image space),
      * a second coordinate set intended to be axis-relative,
      * chart-pattern logits (5 classes),
      * an optional monotonicity flag for line/scatter charts.

    Args:
        in_channels (int): input feature dimensionality; must be divisible
            by 8 (the attention head count).
        max_points (int): number of points predicted per sample.
        loss: loss config/callable, stored as-is for the training code.
    """

    def __init__(self, in_channels, max_points=50, loss=None):
        super().__init__()
        self.max_points = max_points
        self.loss = loss
        # Shared feature reduction before the two coordinate branches.
        self.fc1 = nn.Linear(in_channels, in_channels // 2)
        # Absolute (image-space) coordinates: max_points * (x, y).
        self.absolute_branch = nn.Sequential(
            nn.Linear(in_channels // 2, in_channels // 4),
            nn.ReLU(),
            nn.Linear(in_channels // 4, max_points * 2),
        )
        # Axis-relative coordinates, same layout.
        self.relative_branch = nn.Sequential(
            nn.Linear(in_channels // 2, in_channels // 4),
            nn.ReLU(),
            nn.Linear(in_channels // 4, max_points * 2),
        )
        # Three self-attention stages applied to the input features.
        self.coord_attention = nn.MultiheadAttention(in_channels, num_heads=8)
        self.axis_attention = nn.MultiheadAttention(in_channels, num_heads=8)
        self.sequence_attention = nn.MultiheadAttention(in_channels, num_heads=8)
        # Sequence-aware refinement.
        self.sequence_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=in_channels,
                nhead=8,
                dim_feedforward=in_channels * 4,
                dropout=0.1,
            ),
            num_layers=2,
        )
        # Chart-pattern classifier (5 pattern classes).
        self.pattern_recognizer = nn.Sequential(
            nn.Linear(in_channels, in_channels // 2),
            nn.ReLU(),
            nn.Linear(in_channels // 2, 5),
        )
        # Image <-> axis coordinate conversion helper.
        self.coord_transformer = CoordinateTransformer()

    def check_monotonicity(self, points, chart_type):
        """Return a 0-dim bool tensor: are the y-values monotonic?

        Only line/scatter charts are checked; every other chart type is
        trivially considered monotonic.

        Fix: the unchecked branch used to return a Python ``bool``, which
        crashed ``compute_loss`` (``True.float()`` raises AttributeError);
        a tensor is now returned on every path.
        """
        if chart_type in ('line', 'scatter'):
            # Consecutive y-differences; monotone iff all >= 0 or all <= 0.
            diffs = points[..., 1].diff()
            return torch.logical_or(torch.all(diffs >= 0), torch.all(diffs <= 0))
        return torch.tensor(True, device=points.device)

    def forward(self, x, axis_info=None, chart_type=None):
        """Run the head.

        Args:
            x (torch.Tensor): input features, last dim ``in_channels``.
            axis_info (torch.Tensor | None): per-sample axis description
                ``(B, 8)``; when given, relative points are converted to
                axis units.
            chart_type (str | None): when given, a monotonicity flag is
                computed for the absolute points.

        Returns:
            tuple: ``(absolute_points, relative_points, pattern_logits,
            monotonicity)`` where the point tensors are
            ``(B, max_points, 2)`` and ``monotonicity`` is a 0-dim bool
            tensor or ``None``.
        """
        # Coordinate-focused attention over the input features.
        coord_feat = self.coord_attention(x, x, x)[0]
        if axis_info is not None:
            # NOTE(review): axis_info is not fed into this attention; it only
            # gates whether the extra branch runs — confirm intent.
            axis_feat = self.axis_attention(x, x, x)[0]
            x = coord_feat + axis_feat
        else:
            x = coord_feat
        # Residual sequence attention.
        x = x + self.sequence_attention(x, x, x)[0]
        # The encoder expects (seq, batch, dim); treat input as a length-1
        # sequence and strip the extra dim afterwards.
        x = self.sequence_encoder(x.unsqueeze(0)).squeeze(0)
        # Shared feature reduction, then the two coordinate branches.
        x = F.relu(self.fc1(x))
        absolute_points = self.absolute_branch(x).view(-1, self.max_points, 2)
        relative_points = self.relative_branch(x).view(-1, self.max_points, 2)
        if axis_info is not None:
            relative_points = self.coord_transformer.to_axis_relative(
                relative_points, axis_info)
        pattern_logits = self.pattern_recognizer(x)
        if chart_type is not None:
            monotonicity = self.check_monotonicity(absolute_points, chart_type)
        else:
            monotonicity = None
        return absolute_points, relative_points, pattern_logits, monotonicity

    def compute_loss(self, pred_absolute, pred_relative, gt_absolute, gt_relative,
                     pattern_logits, gt_pattern, axis_info=None, chart_type=None):
        """Combined loss: absolute + relative distance, pattern CE, and a
        monotonicity penalty.

        Returns a scalar tensor:
        ``absolute + relative + 0.5 * pattern + 0.3 * monotonicity``.
        """
        # Promote unbatched (N, 2) inputs to (1, N, 2).
        if pred_absolute.dim() == 2:
            pred_absolute = pred_absolute.unsqueeze(0)
        if pred_relative.dim() == 2:
            pred_relative = pred_relative.unsqueeze(0)
        if gt_absolute.dim() == 2:
            gt_absolute = gt_absolute.unsqueeze(0)
        if gt_relative.dim() == 2:
            gt_relative = gt_relative.unsqueeze(0)
        absolute_loss = self.compute_distance_loss(pred_absolute, gt_absolute)
        if axis_info is not None:
            # NOTE(review): pred_relative is never used here — the relative
            # loss is computed from the transformed absolute prediction;
            # confirm this is intended.
            pred_absolute_relative = self.coord_transformer.to_axis_relative(
                pred_absolute, axis_info)
            relative_loss = self.compute_distance_loss(
                pred_absolute_relative, gt_relative)
        else:
            relative_loss = torch.tensor(0.0, device=pred_absolute.device)
        pattern_loss = F.cross_entropy(pattern_logits, gt_pattern)
        if chart_type is not None:
            # check_monotonicity yields a 0-dim bool tensor; BCE against 1.0
            # penalizes non-monotonic predictions. The input is hard 0/1, so
            # this acts as a fixed penalty, not a smooth gradient signal.
            monotonicity = self.check_monotonicity(pred_absolute, chart_type).float()
            monotonicity_loss = F.binary_cross_entropy(
                monotonicity, torch.ones_like(monotonicity))
        else:
            monotonicity_loss = torch.tensor(0.0, device=pred_absolute.device)
        # Weighted combination (original hand-picked weights).
        return (absolute_loss + relative_loss
                + 0.5 * pattern_loss + 0.3 * monotonicity_loss)

    def compute_distance_loss(self, pred_points, gt_points):
        """Smooth-L1 of each predicted point's distance to its nearest
        ground-truth point (expects batched ``(B, N, 2)`` inputs)."""
        dist = torch.cdist(pred_points, gt_points)
        min_dist, _ = torch.min(dist, dim=2)
        return F.smooth_l1_loss(min_dist, torch.zeros_like(min_dist))