import torch
import torch.nn as nn
import torch.nn.functional as F
from mmdet.registry import MODELS

@MODELS.register_module()
class FCHead(nn.Module):
    """Enhanced fully connected head for classification tasks with attention."""
    
    def __init__(self, in_channels, num_classes, loss=None):
        super().__init__()
        # in_channels must be divisible by num_heads for multi-head attention
        self.attention = nn.MultiheadAttention(in_channels, num_heads=8)
        self.fc1 = nn.Linear(in_channels, in_channels // 2)
        self.fc2 = nn.Linear(in_channels // 2, num_classes)
        self.loss = loss

    def forward(self, x):
        # Self-attention over the input sequence; x follows MultiheadAttention's
        # default (seq_len, batch, in_channels) layout
        x = self.attention(x, x, x)[0]
        # Two-layer MLP projecting attended features down to class logits
        x = F.relu(self.fc1(x))
        return self.fc2(x)
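

# A minimal usage sketch, not part of the original heads. The
# (seq_len, batch, in_channels) layout below is an assumption following
# nn.MultiheadAttention's default convention; actual callers may differ.
def _demo_fc_head():
    head = FCHead(in_channels=64, num_classes=10)
    feats = torch.randn(5, 2, 64)  # 5 tokens, batch of 2, 64 channels
    logits = head(feats)           # -> (5, 2, 10)
    assert logits.shape == (5, 2, 10)
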

@MODELS.register_module()
class RegHead(nn.Module):
    """Enhanced regression head for coordinate prediction with distance-based loss."""
    
    def __init__(self, in_channels, out_dims, max_points=None, loss=None, attention=False, use_axis_info=False):
        super().__init__()
        self.fc = nn.Linear(in_channels, out_dims)
        self.max_points = max_points
        self.loss = loss
        self.attention = attention
        self.use_axis_info = use_axis_info
        
        if attention:
            self.attention_layer = nn.MultiheadAttention(in_channels, num_heads=8)
            
        # Add axis orientation detection
        if use_axis_info:
            self.axis_orientation = nn.Linear(in_channels, 2)  # 2 for x/y axis orientation
            
    def compute_distance_loss(self, pred_points, gt_points):
        """One-directional Chamfer-style loss: each predicted point is pulled
        toward its nearest ground-truth point."""
        # Ensure both tensors are batched: (B, N, 2)
        if pred_points.dim() == 2:
            pred_points = pred_points.unsqueeze(0)
        if gt_points.dim() == 2:
            gt_points = gt_points.unsqueeze(0)

        # Pairwise distances: (B, N_pred, N_gt)
        dist = torch.cdist(pred_points, gt_points)

        # Distance from each predicted point to its nearest ground-truth point
        min_dist, _ = torch.min(dist, dim=2)

        # Smooth L1 against a zero target keeps the loss robust to outliers
        return F.smooth_l1_loss(min_dist, torch.zeros_like(min_dist))
            
    def forward(self, x):
        if self.attention:
            x = self.attention_layer(x, x, x)[0]
            
        # Get base predictions
        pred = self.fc(x)
        
        # If using axis info, also predict axis orientation
        if self.use_axis_info:
            axis_orientation = self.axis_orientation(x)
            return pred, axis_orientation
            
        return pred
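

# A minimal usage sketch (illustrative, not from the original module): with
# use_axis_info=True the head returns a (predictions, axis_orientation) pair.
# Shapes follow the same assumed (seq_len, batch, in_channels) layout as above.
def _demo_reg_head():
    head = RegHead(in_channels=64, out_dims=4, attention=True, use_axis_info=True)
    feats = torch.randn(5, 2, 64)
    pred, axis_orientation = head(feats)
    assert pred.shape == (5, 2, 4)
    assert axis_orientation.shape == (5, 2, 2)
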

class CoordinateTransformer:
    """Helper class to transform coordinates between different spaces."""
    
    @staticmethod
    def to_axis_relative(points, axis_info):
        """Transform points from image coordinates to axis (data) coordinates.

        Args:
            points (torch.Tensor): Points in image coordinates (B, N, 2)
            axis_info (torch.Tensor): Per-sample axis information, shape (B, 8):
                [x_min, x_max, y_min, y_max, x_origin, y_origin, x_scale, y_scale]
        """
        # Unbind the per-sample scalars and add a trailing point dimension so
        # each (B,) column broadcasts against the (B, N) point coordinates
        x_min, x_max, y_min, y_max, x_origin, y_origin, x_scale, y_scale = (
            v.unsqueeze(-1) for v in axis_info.unbind(1))

        # Normalize pixel positions to [0, 1] within the plotted area
        x_norm = (points[..., 0] - x_min) / (x_max - x_min)
        y_norm = (points[..., 1] - y_min) / (y_max - y_min)

        # Scale to axis units
        x_axis = x_norm * x_scale + x_origin
        y_axis = y_norm * y_scale + y_origin

        return torch.stack([x_axis, y_axis], dim=-1)

    @staticmethod
    def to_image_coordinates(points, axis_info):
        """Transform points from axis coordinates back to image coordinates
        (the inverse of :meth:`to_axis_relative`)."""
        x_min, x_max, y_min, y_max, x_origin, y_origin, x_scale, y_scale = (
            v.unsqueeze(-1) for v in axis_info.unbind(1))

        # Convert from axis units to normalized coordinates
        x_norm = (points[..., 0] - x_origin) / x_scale
        y_norm = (points[..., 1] - y_origin) / y_scale

        # Convert to image coordinates
        x_img = x_norm * (x_max - x_min) + x_min
        y_img = y_norm * (y_max - y_min) + y_min

        return torch.stack([x_img, y_img], dim=-1)
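

# Sanity-check sketch (illustrative values, not from the original module): the
# two transforms should invert each other for any well-formed axis_info row.
def _demo_coordinate_round_trip():
    axis_info = torch.tensor([[10.0, 90.0, 20.0, 80.0, 0.0, 0.0, 1.0, 1.0]])  # (1, 8)
    pts_img = torch.tensor([[[10.0, 20.0], [50.0, 50.0], [90.0, 80.0]]])      # (1, 3, 2)
    pts_axis = CoordinateTransformer.to_axis_relative(pts_img, axis_info)
    pts_back = CoordinateTransformer.to_image_coordinates(pts_axis, axis_info)
    assert torch.allclose(pts_img, pts_back, atol=1e-5)
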

@MODELS.register_module()
class DataSeriesHead(nn.Module):
    """Specialized head for data series prediction with dual attention to coordinates and axis-relative positions."""
    
    def __init__(self, in_channels, max_points=50, loss=None):
        super().__init__()
        self.max_points = max_points
        self.loss = loss
        
        # Feature extraction
        self.fc1 = nn.Linear(in_channels, in_channels // 2)
        
        # Separate branches for absolute and relative coordinates
        self.absolute_branch = nn.Sequential(
            nn.Linear(in_channels // 2, in_channels // 4),
            nn.ReLU(),
            nn.Linear(in_channels // 4, max_points * 2)  # 2 coordinates per point
        )
        
        self.relative_branch = nn.Sequential(
            nn.Linear(in_channels // 2, in_channels // 4),
            nn.ReLU(),
            nn.Linear(in_channels // 4, max_points * 2)  # 2 coordinates per point
        )
        
        # Attention mechanisms
        self.coord_attention = nn.MultiheadAttention(in_channels, num_heads=8)
        self.axis_attention = nn.MultiheadAttention(in_channels, num_heads=8)
        self.sequence_attention = nn.MultiheadAttention(in_channels, num_heads=8)
        
        # Sequence-aware processing over the token dimension; batch_first=True
        # matches the (1, num_tokens, in_channels) input built in forward
        # (with the default layout the encoder would see a length-1 sequence
        # and attend over nothing)
        self.sequence_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=in_channels,
                nhead=8,
                dim_feedforward=in_channels * 4,
                dropout=0.1,
                batch_first=True
            ),
            num_layers=2
        )
        
        # Pattern recognition
        self.pattern_recognizer = nn.Sequential(
            nn.Linear(in_channels, in_channels // 2),
            nn.ReLU(),
            nn.Linear(in_channels // 2, 5)  # 5 for different chart patterns
        )
        
        # Coordinate transformer
        self.coord_transformer = CoordinateTransformer()
        
    def check_monotonicity(self, points, chart_type):
        """Boolean diagnostic: do the y-values move in a single monotone
        direction? Only meaningful for line/scatter charts."""
        if chart_type in ['line', 'scatter']:
            # Differences between consecutive y-values along the point dimension
            diffs = points[..., 1].diff(dim=-1)
            return torch.all(diffs >= 0) or torch.all(diffs <= 0)
        return True
        
    def forward(self, x, axis_info=None, chart_type=None):
        # x: a sequence of feature tokens, shape (num_tokens, in_channels)
        # Apply coordinate attention
        coord_feat = self.coord_attention(x, x, x)[0]

        # Apply a second attention branch when axis info is available; note
        # that axis_info itself is not consumed here, it only gates the branch
        if axis_info is not None:
            axis_feat = self.axis_attention(x, x, x)[0]
            # Combine features
            x = coord_feat + axis_feat
        else:
            x = coord_feat

        # Apply sequence attention with a residual connection
        seq_feat = self.sequence_attention(x, x, x)[0]
        x = x + seq_feat

        # Process through the sequence encoder (add/remove the batch dimension
        # expected by the batch_first encoder)
        x = self.sequence_encoder(x.unsqueeze(0)).squeeze(0)
        
        # Extract base features
        x = F.relu(self.fc1(x))
        
        # Get predictions from both branches
        absolute_points = self.absolute_branch(x)
        relative_points = self.relative_branch(x)
        
        # Reshape to (batch_size, max_points, 2)
        absolute_points = absolute_points.view(-1, self.max_points, 2)
        relative_points = relative_points.view(-1, self.max_points, 2)
        
        # If axis information is provided, transform relative points
        if axis_info is not None:
            relative_points = self.coord_transformer.to_axis_relative(relative_points, axis_info)
            
        # Get pattern prediction
        pattern_logits = self.pattern_recognizer(x)
        
        # Check monotonicity if chart type is provided
        if chart_type is not None:
            monotonicity = self.check_monotonicity(absolute_points, chart_type)
        else:
            monotonicity = None
        
        return absolute_points, relative_points, pattern_logits, monotonicity
        
    def compute_loss(self, pred_absolute, pred_relative, gt_absolute, gt_relative, 
                    pattern_logits, gt_pattern, axis_info=None, chart_type=None):
        """Compute combined loss for both absolute and relative coordinates."""
        # Ensure points are in the same format
        if pred_absolute.dim() == 2:
            pred_absolute = pred_absolute.unsqueeze(0)
        if pred_relative.dim() == 2:
            pred_relative = pred_relative.unsqueeze(0)
        if gt_absolute.dim() == 2:
            gt_absolute = gt_absolute.unsqueeze(0)
        if gt_relative.dim() == 2:
            gt_relative = gt_relative.unsqueeze(0)
            
        # Compute absolute coordinate loss
        absolute_loss = self.compute_distance_loss(pred_absolute, gt_absolute)

        # Compute relative coordinate loss. The relative branch is supervised
        # directly (otherwise it would receive no gradient); when axis
        # information is available, the transformed absolute predictions are
        # also supervised so the two branches stay consistent.
        relative_loss = self.compute_distance_loss(pred_relative, gt_relative)
        if axis_info is not None:
            pred_absolute_relative = self.coord_transformer.to_axis_relative(pred_absolute, axis_info)
            relative_loss = relative_loss + self.compute_distance_loss(pred_absolute_relative, gt_relative)

        # Compute pattern recognition loss
        pattern_loss = F.cross_entropy(pattern_logits, gt_pattern)

        # Add monotonicity penalty if applicable. The boolean check used in
        # forward is not differentiable (and a hard BCE against it diverges
        # whenever the check fails), so penalize the sign violations of
        # consecutive y-differences instead.
        if chart_type in ['line', 'scatter']:
            diffs = pred_absolute[..., 1].diff(dim=-1)
            # Penalize whichever monotone direction is closer to holding
            increasing_violation = F.relu(-diffs).mean()
            decreasing_violation = F.relu(diffs).mean()
            monotonicity_loss = torch.minimum(increasing_violation, decreasing_violation)
        else:
            monotonicity_loss = torch.tensor(0.0, device=pred_absolute.device)

        # Combine losses with fixed weights
        total_loss = (absolute_loss + relative_loss + 
                     0.5 * pattern_loss + 0.3 * monotonicity_loss)
        
        return total_loss
        
    def compute_distance_loss(self, pred_points, gt_points):
        """One-directional Chamfer-style loss: each predicted point is pulled
        toward its nearest ground-truth point. Inputs must be batched (B, N, 2)."""
        # Pairwise distances: (B, N_pred, N_gt)
        dist = torch.cdist(pred_points, gt_points)

        # Distance from each predicted point to its nearest ground-truth point
        min_dist, _ = torch.min(dist, dim=2)

        # Smooth L1 against a zero target keeps the loss robust to outliers
        return F.smooth_l1_loss(min_dist, torch.zeros_like(min_dist))
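

# Smoke-test sketch under stated assumptions, not part of the original module:
# feature tokens form an unbatched (num_tokens, in_channels) sequence,
# in_channels is divisible by the 8 attention heads, mmdet is importable, and
# PyTorch is recent enough to accept unbatched attention inputs. Run directly:
#     python this_file.py
if __name__ == "__main__":
    torch.manual_seed(0)
    _demo_fc_head()
    _demo_reg_head()
    _demo_coordinate_round_trip()

    head = DataSeriesHead(in_channels=64, max_points=8)
    tokens = torch.randn(4, 64)  # 4 feature tokens
    absolute, relative, pattern_logits, monotone = head(tokens, chart_type='line')
    assert absolute.shape == (4, 8, 2)
    assert relative.shape == (4, 8, 2)
    assert pattern_logits.shape == (4, 5)

    gt = torch.rand(4, 8, 2)
    loss = head.compute_loss(absolute, relative, gt, gt,
                             pattern_logits, torch.zeros(4, dtype=torch.long),
                             chart_type='line')
    assert loss.requires_grad
    print('smoke test passed, loss =', float(loss))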