File size: 1,865 Bytes
e611d1f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
from typing import Optional

import torch
import torch.nn as nn
from torch_scatter import scatter_softmax, scatter_sum

from .activation import get_activation

class AttentionPooling(nn.Module):
    """Attention-based pooling layer supporting batched graphs.

    A two-layer MLP produces one score per node, the scores are
    softmax-normalised among the nodes that share a graph id, and the node
    features are summed per graph using those normalised weights.

    Args:
        input_dim: Size of each node feature vector.
        dropout: Dropout probability used inside the scoring MLP and on the
            pooled output.
        activation: Name handed to ``get_activation``; the returned module is
            placed between the two linear layers of the scoring MLP.
    """

    def __init__(self, input_dim: int, dropout: float = 0.2, activation: str = 'gelu'):
        super().__init__()
        # ``input_dim // 2`` is 0 for input_dim == 1, which would create a
        # zero-width hidden layer whose output is just the final bias (i.e.
        # uniform attention). Keep at least one hidden unit.
        hidden_dim = max(1, input_dim // 2)
        self.attention = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            get_activation(activation),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, 1),
        )
        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        x: torch.Tensor,
        batch: torch.Tensor,
        dim_size: Optional[int] = None,
    ) -> torch.Tensor:
        """Pool node features into one vector per graph.

        Args:
            x:        (N, input_dim) Node features from multiple graphs
            batch:    (N,) Graph ID per node
            dim_size: Optional number of graphs in the batch. When omitted,
                      ``scatter_sum`` infers the output size from ``batch``,
                      so graphs with no nodes at the end of the batch do not
                      get a row. Pass it to always obtain ``dim_size`` rows
                      (rows for empty graphs are zero before dropout).
        Returns:
            (num_graphs, input_dim) Pooled graph features
        """
        attn_logits = self.attention(x).squeeze(-1)         # (N,)
        # Normalise scores only among nodes belonging to the same graph.
        attn_weights = scatter_softmax(attn_logits, batch)  # (N,)
        weighted_x = x * attn_weights.unsqueeze(-1)         # (N, D)
        pooled = scatter_sum(weighted_x, batch, dim=0, dim_size=dim_size)  # (num_graphs, D)
        return self.dropout(pooled)

class AddPooling(nn.Module):
    """Simple addition-based pooling layer supporting batched graphs.

    Node features that share a graph id are summed, then dropout is applied
    to the pooled result.

    Args:
        input_dim: Not used by this layer; kept so the constructor signature
            stays unchanged for existing callers.
        dropout: Dropout probability applied to the pooled output.
    """

    def __init__(self, input_dim: int, dropout: float = 0.2):
        super().__init__()
        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        x: torch.Tensor,
        batch: torch.Tensor,
        dim_size: Optional[int] = None,
    ) -> torch.Tensor:
        """Sum node features per graph.

        Args:
            x:        (N, input_dim) Node features from multiple graphs
            batch:    (N,) Graph ID per node
            dim_size: Optional number of graphs in the batch. When omitted,
                      ``scatter_sum`` infers the output size from ``batch``,
                      so graphs with no nodes at the end of the batch do not
                      get a row. Pass it to always obtain ``dim_size`` rows
                      (rows for empty graphs are zero before dropout).
        Returns:
            (num_graphs, input_dim) Pooled graph features
        """
        pooled = scatter_sum(x, batch, dim=0, dim_size=dim_size)  # (num_graphs, input_dim)
        return self.dropout(pooled)