import inspect
from pathlib import Path

import torch
import torch.nn as nn
from prettytable import PrettyTable
from torch_geometric.data import Batch

from .dihedral import DihedralFeatures
from .EGNN import EGNNLayer
from .pooling import AttentionPooling, AddPooling
from .activation import get_activation


class ReCEP(nn.Module):
    """
    Refined Graph Epitope Predictor.

    Predicts a graph-level (region) score and per-residue (node) scores from
    residue embeddings and backbone geometry. The EGNN message-passing stack
    can be skipped (use_egnn=False) for ablation.
    """

    def __init__(
        self,
        in_dim: int = 2560,
        rsa: bool = True,
        dihedral: bool = True,
        node_dims: list = [512, 256, 256],
        edge_dim: int = 32,
        dropout: float = 0.3,
        activation: str = "gelu",
        residual: bool = True,
        attention: bool = True,
        normalize: bool = True,
        coords_agg: str = 'mean',
        ffn: bool = True,
        batch_norm: bool = True,
        concat: bool = False,
        addition: bool = False,
        pooling: str = 'attention',
        fusion_type: str = 'concat',
        node_gate: bool = False,
        node_norm: bool = False,
        node_layers: int = 2,
        out_dropout: float = 0.2,
        use_egnn: bool = True,
        encoder: str = 'esmc',
    ):
        super().__init__()
        self.use_egnn = use_egnn
        self.in_dim = in_dim
        self.rsa = rsa
        self.dihedral = dihedral
        self.original_node_dims = node_dims.copy()
        self.edge_dim = edge_dim
        self.dropout = dropout
        self.activation = activation
        self.residual = residual
        self.attention = attention
        self.normalize = normalize
        self.ffn = ffn
        self.batch_norm = batch_norm
        self.coords_agg = coords_agg
        self.concat = concat
        self.addition = addition
        self.fusion_type = fusion_type
        self.use_node_gate = node_gate
        self.node_norm = node_norm
        self.node_layers = node_layers
        self.out_dropout = out_dropout
        self.pooling = pooling

        self.base_node_dim = node_dims[0]
        self.node_dims = node_dims.copy()
        # The RSA scalar is concatenated to the projected embedding, so the first
        # EGNN input dimension grows by one when RSA is used.
        self.node_dims[0] += 1 if rsa else 0
        # For additive skip connections the last EGNN layer must match the input width.
        self.node_dims[-1] = self.node_dims[0] if addition else self.node_dims[-1]

        # The encoder choice fixes the expected embedding width.
        self.encoder = encoder
        if encoder == 'esmc':
            self.in_dim = 2560
        elif encoder == 'esm2':
            self.in_dim = 1280
        else:
            self.in_dim = in_dim

        if self.use_egnn:
            self.final_node_dim = self.node_dims[-1]
        else:
            # Without EGNN layers the projected features are used directly, so the
            # concat/addition skip connections are disabled.
            self.final_node_dim = self.node_dims[0]
            self.concat = False
            self.addition = False

        self.proj_layer = nn.Sequential(
            nn.Linear(self.in_dim, self.base_node_dim),
            get_activation(activation),
            nn.Dropout(dropout),
        )

        if dihedral:
            try:
                self.dihedral_features = DihedralFeatures(self.base_node_dim)
            except Exception:
                print("Warning: DihedralFeatures not found, skipping dihedral features")
                self.dihedral = False

        self.egnn_layers = nn.ModuleList()
        if self.use_egnn:
            for i in range(len(self.node_dims) - 1):
                self.egnn_layers.append(
                    EGNNLayer(
                        input_nf=self.node_dims[i],
                        output_nf=self.node_dims[i + 1],
                        hidden_nf=self.node_dims[i + 1],
                        edges_in_d=edge_dim,
                        act_fn=get_activation(activation),
                        residual=residual,
                        attention=attention,
                        normalize=normalize,
                        coords_agg=coords_agg,
                        dropout=dropout,
                        ffn=ffn,
                        batch_norm=batch_norm
                    )
                )

        if concat and self.use_egnn:
            self.final_node_dim += self.node_dims[0]

        if addition and self.use_egnn:
            assert self.node_dims[0] == self.node_dims[-1], "Node dimension mismatch for addition"
            self.final_node_dim = self.node_dims[0]

        if fusion_type == 'concat':
            self.node_classifier_input_dim = self.final_node_dim * 2
        elif fusion_type == 'add':
            self.node_classifier_input_dim = self.final_node_dim
        else:
            raise ValueError(f"Unsupported fusion type: {fusion_type}")

        if node_gate:
            if fusion_type == 'concat':
                self.node_gate_input_dim = self.final_node_dim * 2
            elif fusion_type == 'add':
                self.node_gate_input_dim = self.final_node_dim
            else:
                raise ValueError(f"Unsupported fusion type: {fusion_type}")

        if pooling == 'attention':
            self.graph_pool = AttentionPooling(
                input_dim=self.final_node_dim,
                dropout=dropout,
                activation=activation
            )
        elif pooling == 'add':
            self.graph_pool = AddPooling(
                input_dim=self.final_node_dim,
                dropout=dropout
            )
        else:
            raise ValueError(f"Unsupported pooling method: {pooling}")

        self.global_predictor = nn.Sequential(
            nn.Linear(self.final_node_dim, self.final_node_dim // 2),
            get_activation(activation),
            nn.Dropout(out_dropout),
            nn.Linear(self.final_node_dim // 2, 1)
        )

        # The gating network keeps the `node_gate` attribute name; the boolean
        # switch itself is stored separately as `self.use_node_gate` (see get_config()).
        if node_gate:
            self.node_gate = nn.Sequential(
                nn.Linear(self.node_gate_input_dim, self.final_node_dim),
                get_activation(activation),
                nn.LayerNorm(self.final_node_dim),
                nn.Linear(self.final_node_dim, self.final_node_dim),
                nn.Sigmoid()
            )

        self.node_classifier = self._build_node_classifier()

        self._param_printed = False
        self.apply(self._init_weights)

    def _build_node_classifier(self):
        """Build the per-node MLP head that maps fused features to a single score."""
        layers = []
        current_dim = self.node_classifier_input_dim
        for i in range(self.node_layers):
            output_dim = 1 if i == self.node_layers - 1 else max(current_dim // 2, 32)
            layers.append(nn.Linear(current_dim, output_dim))
            if self.node_norm and i < self.node_layers - 1:
                layers.append(nn.LayerNorm(output_dim))
            if i < self.node_layers - 1:
                layers.append(get_activation(self.activation))
                layers.append(nn.Dropout(self.out_dropout))
            current_dim = output_dim
        return nn.Sequential(*layers)
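
    # For the default configuration (node_layers=2, fusion_type='concat',
    # final_node_dim=256, hence a 512-dim fused input) the resulting stack is,
    # as a sketch: Linear(512, 256) -> activation -> Dropout(out_dropout) -> Linear(256, 1).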

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            nn.init.xavier_normal_(module.weight)
            if module.bias is not None:
                nn.init.constant_(module.bias, 0.0 if module.out_features == 1 else 0.01)
        elif isinstance(module, nn.LayerNorm):
            nn.init.ones_(module.weight)
            nn.init.zeros_(module.bias)

    def forward(self, data: Batch) -> dict:
        if self.training and not self._param_printed:
            print(f"ReCEP total params: {sum(p.numel() for p in self.parameters()):,}")
            self._param_printed = True

        x = data.x
        coords = data.pos
        batch = data.batch
        e_attr = data.edge_attr
        # Use the C-alpha coordinates (atom index 1) for the EGNN coordinate updates.
        coords_C = coords[:, 1].clone()

        x = self.proj_layer(x)
        if self.dihedral and coords is not None:
            x = x + self.dihedral_features(coords)
        if self.rsa and data.rsa is not None:
            rsa = data.rsa.unsqueeze(-1)
            x = torch.cat([x, rsa], dim=-1)

        h = x
        assert h.shape[1] == self.node_dims[0], f"[ReCEP] Node feature dim mismatch: got {h.shape[1]}, expected {self.node_dims[0]}"

        if self.use_egnn:
            for layer in self.egnn_layers:
                h, coords_C, _ = layer(h, coords_C, data.edge_index, batch, edge_attr=e_attr)

        if self.concat and self.use_egnn:
            h = torch.cat([x, h], dim=-1)
        elif self.addition and self.use_egnn:
            h = h + x

        # Graph-level pooling and region-level prediction.
        graph_feats = self.graph_pool(h, batch)
        global_pred = self.global_predictor(graph_feats).squeeze(-1)

        # Broadcast the pooled graph feature back to its nodes as per-node context.
        context = graph_feats[batch]
        if self.use_node_gate:
            if self.fusion_type == 'concat':
                gate_input = torch.cat([h, context], dim=-1)
            elif self.fusion_type == 'add':
                gate_input = h + context
            else:
                raise ValueError(f"Unsupported fusion type: {self.fusion_type}")
            gate = self.node_gate(gate_input)
            gated_h = h + gate * h
        else:
            gated_h = h

        if self.fusion_type == 'concat':
            cat = torch.cat([gated_h, context], dim=-1)
        elif self.fusion_type == 'add':
            assert gated_h.shape[-1] == context.shape[-1], f"[ReCEP] Dimension mismatch for add fusion: gated_h {gated_h.shape[-1]} vs context {context.shape[-1]}"
            cat = gated_h + context
        else:
            raise ValueError(f"Unsupported fusion type: {self.fusion_type}")

        expected_dim = self.node_classifier_input_dim
        actual_dim = cat.shape[-1]
        assert actual_dim == expected_dim, f"[ReCEP] Node classifier input dim mismatch: got {actual_dim}, expected {expected_dim}"

        node_preds = self.node_classifier(cat).squeeze(-1)

        return {"global_pred": global_pred, "node_preds": node_preds}
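
    # Both outputs are raw, unnormalized scores: "global_pred" holds one value per
    # graph in the batch and "node_preds" one value per node.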

    def print_param_count(self):
        """Print a summary table of parameter counts"""
        table = PrettyTable()
        table.field_names = ["Layer Name", "Type", "Parameters", "Trainable"]
        total_params = 0
        trainable_params = 0

        # Only count leaf modules so parameters are not double-counted.
        for name, module in self.named_modules():
            if not list(module.children()):
                params = sum(p.numel() for p in module.parameters())
                is_trainable = any(p.requires_grad for p in module.parameters())

                if params > 0:
                    total_params += params
                    trainable_params += params if is_trainable else 0

                    table.add_row([
                        name,
                        module.__class__.__name__,
                        f"{params:,}",
                        "✓" if is_trainable else "✗"
                    ])

        table.add_row(["", "", "", ""], divider=True)
        table.add_row([
            "TOTAL",
            "",
            f"{total_params:,}",
            f"Trainable: {trainable_params:,}"
        ])

        print("\nReCEP Model Parameter Summary:")
        print(table)
        print(f"Parameter Density: {trainable_params / total_params:.1%}\n")

    def save(self, path, threshold: float = 0.5):
        """Save model with configuration"""
        path = Path(path)
        try:
            path.parent.mkdir(parents=True, exist_ok=True)
            save_path = path.with_suffix('.bin')

            config = self.get_config()

            torch.save({
                'model_state': self.state_dict(),
                'config': config,
                'model_class': self.__class__.__name__,
                'version': '1.0',
                'threshold': threshold
            }, save_path)
            print(f"ReCEP model saved to {save_path}")
        except Exception as e:
            print(f"Save failed: {str(e)}")
            raise

    @classmethod
    def load(cls, path, device='cpu', strict=True, verbose=True):
        """Load model with configuration"""
        path = Path(path)
        if not path.exists():
            raise FileNotFoundError(f"Model file {path} not found")

        # Resolve the target device before touching the checkpoint.
        if isinstance(device, str):
            device = torch.device(device)
        elif isinstance(device, int):
            if device >= 0 and torch.cuda.is_available():
                device = torch.device(f'cuda:{device}')
            else:
                device = torch.device('cpu')
        elif not isinstance(device, torch.device):
            raise ValueError(f"Unsupported device type: {type(device)}")

        try:
            checkpoint = torch.load(path, map_location=device, weights_only=False)
        except TypeError:
            # Older PyTorch versions do not accept the weights_only argument.
            checkpoint = torch.load(path, map_location=device)

        if 'version' not in checkpoint:
            print("Warning: Loading legacy model without version info")

        config = checkpoint.get('config', {})
        model = cls(**config)

        model_state = checkpoint['model_state']
        current_state = model.state_dict()

        # Keep only parameters whose names and shapes match the current model.
        matched_state = {}
        for name, param in model_state.items():
            if name in current_state:
                if param.shape == current_state[name].shape:
                    matched_state[name] = param
                else:
                    print(f"Size mismatch: {name} (load {param.shape} vs model {current_state[name].shape})")
            else:
                print(f"Parameter not found: {name}")

        current_state.update(matched_state)
        model.load_state_dict(current_state, strict=strict)

        if verbose:
            print(f"Successfully loaded {len(matched_state)}/{len(model_state)} parameters")

        return model.to(device), checkpoint.get('threshold', 0.5)
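
    # Example round trip (illustrative sketch; the checkpoint path is hypothetical):
    #
    #     model = ReCEP(encoder='esm2')
    #     model.save("checkpoints/recep_demo")      # written to checkpoints/recep_demo.bin
    #     model, threshold = ReCEP.load("checkpoints/recep_demo.bin", device="cpu")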

    def get_config(self):
        """Get model configuration"""
        return {
            'in_dim': self.in_dim,
            'rsa': self.rsa,
            'dihedral': self.dihedral,
            'node_dims': self.original_node_dims,
            'edge_dim': self.edge_dim,
            'dropout': self.dropout,
            'activation': self.activation,
            'residual': self.residual,
            'attention': self.attention,
            'normalize': self.normalize,
            'coords_agg': self.coords_agg,
            'ffn': self.ffn,
            'batch_norm': self.batch_norm,
            'concat': self.concat,
            'addition': self.addition,
            'pooling': self.pooling,
            'fusion_type': self.fusion_type,
            'node_gate': self.use_node_gate,
            'node_norm': self.node_norm,
            'node_layers': self.node_layers,
            'out_dropout': self.out_dropout,
            'use_egnn': self.use_egnn,
            'encoder': self.encoder
        }


model_registry = {
    "ReCEP": ReCEP,
}


def get_model(configs):
    """
    Flexible model loader. Accepts either an argparse.Namespace or a dict.
    Returns an instance of the selected model.
    """
    if hasattr(configs, '__dict__'):
        args = vars(configs)
    else:
        args = configs

    model_name = args.get('model', 'ReCEP')

    if model_name not in model_registry:
        valid_models = list(model_registry.keys())
        raise ValueError(f"Invalid model type: {model_name}. Must be one of: {valid_models}")

    model_class = model_registry[model_name]

    # Match the provided settings against the constructor signature, falling back
    # to defaults for anything that is not supplied.
    init_signature = inspect.signature(model_class.__init__)
    parameters = init_signature.parameters

    model_config = {}
    for param_name, param in parameters.items():
        if param_name == 'self':
            continue
        if param_name in args:
            model_config[param_name] = args[param_name]
        elif param.default is not param.empty:
            model_config[param_name] = param.default
        else:
            print(f"[WARNING] Required parameter '{param_name}' not found in args and has no default value")

    model = model_class(**model_config)
    return model
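

# Example (illustrative sketch): build a model from a plain dict of settings.
# Keys that do not match the constructor are ignored, and missing optional
# parameters fall back to their defaults.
#
#     model = get_model({"model": "ReCEP", "encoder": "esm2", "node_layers": 3})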