from pathlib import Path from prettytable import PrettyTable import torch import torch.nn as nn from torch_scatter import scatter_softmax, scatter_sum from torch_geometric.data import Data, Batch from .dihedral import DihedralFeatures from .EGNN import EGNNLayer from .pooling import AttentionPooling, AddPooling from .activation import get_activation # from .baseline import EP class ReCEP(nn.Module): """ Refined Graph Epitope Predictor with optional EGNN layer skipping for ablation. """ def __init__( self, in_dim: int = 2560, rsa: bool = True, dihedral: bool = True, node_dims: list = [512, 256, 256], edge_dim: int = 32, dropout: float = 0.3, activation: str = "gelu", residual: bool = True, attention: bool = True, normalize: bool = True, coords_agg: str = 'mean', ffn: bool = True, batch_norm: bool = True, concat: bool = False, addition: bool = False, # Global predictor pooling: str = 'attention', # Node classifier fusion_type: str = 'concat', node_gate: bool = False, node_norm: bool = False, node_layers: int = 2, out_dropout: float = 0.2, use_egnn: bool = True, # NEW: toggle for EGNN layer usage encoder: str = 'esmc', ): super().__init__() self.use_egnn = use_egnn self.in_dim = in_dim self.rsa = rsa self.dihedral = dihedral self.original_node_dims = node_dims.copy() self.edge_dim = edge_dim self.dropout = dropout self.activation = activation self.residual = residual self.attention = attention self.normalize = normalize self.ffn = ffn self.batch_norm = batch_norm self.coords_agg = coords_agg self.concat = concat self.addition = addition self.fusion_type = fusion_type self.node_gate = node_gate self.node_norm = node_norm self.node_layers = node_layers self.out_dropout = out_dropout self.pooling = pooling self.base_node_dim = node_dims[0] self.node_dims = node_dims.copy() self.node_dims[0] += 1 if rsa else 0 self.node_dims[-1] = self.node_dims[0] if addition else self.node_dims[-1] # Modify input dimension based on encoder self.encoder = encoder if encoder == 'esmc': self.in_dim = 2560 elif encoder == 'esm2': self.in_dim = 1280 else: self.in_dim = in_dim # Calculate actual final node dimension based on whether EGNN is used if self.use_egnn: self.final_node_dim = self.node_dims[-1] else: self.final_node_dim = self.node_dims[0] self.concat = False self.addition = False self.proj_layer = nn.Sequential( nn.Linear(self.in_dim, self.base_node_dim), get_activation(activation), nn.Dropout(dropout), ) if dihedral: try: self.dihedral_features = DihedralFeatures(self.base_node_dim) except: print("Warning: DihedralFeatures not found, skipping dihedral features") self.dihedral = False self.egnn_layers = nn.ModuleList() if self.use_egnn: for i in range(len(self.node_dims) - 1): self.egnn_layers.append( EGNNLayer( input_nf=self.node_dims[i], output_nf=self.node_dims[i+1], hidden_nf=self.node_dims[i+1], edges_in_d=edge_dim, act_fn=get_activation(activation), residual=residual, attention=attention, normalize=normalize, coords_agg=coords_agg, dropout=dropout, ffn=ffn, batch_norm=batch_norm ) ) if concat and self.use_egnn: self.final_node_dim += self.node_dims[0] if addition and self.use_egnn: assert self.node_dims[0] == self.node_dims[-1], "Node dimension mismatch for addition" self.final_node_dim = self.node_dims[0] # Calculate node classifier input dimension based on fusion type if fusion_type == 'concat': self.node_classifier_input_dim = self.final_node_dim * 2 elif fusion_type == 'add': self.node_classifier_input_dim = self.final_node_dim else: raise ValueError(f"Unsupported fusion type: {fusion_type}") # Calculate node gate input dimension if node_gate: if fusion_type == 'concat': self.node_gate_input_dim = self.final_node_dim * 2 elif fusion_type == 'add': self.node_gate_input_dim = self.final_node_dim else: raise ValueError(f"Unsupported fusion type: {fusion_type}") if pooling == 'attention': self.graph_pool = AttentionPooling( input_dim=self.final_node_dim, dropout=dropout, activation=activation ) elif pooling == 'add': self.graph_pool = AddPooling( input_dim=self.final_node_dim, dropout=dropout ) else: raise ValueError(f"Unsupported pooling method: {pooling}") self.global_predictor = nn.Sequential( nn.Linear(self.final_node_dim, self.final_node_dim // 2), get_activation(activation), nn.Dropout(out_dropout), nn.Linear(self.final_node_dim // 2, 1) ) if node_gate: self.node_gate = nn.Sequential( nn.Linear(self.node_gate_input_dim, self.final_node_dim), get_activation(activation), nn.LayerNorm(self.final_node_dim), nn.Linear(self.final_node_dim, self.final_node_dim), nn.Sigmoid() ) self.node_classifier = self._build_node_classifier() self._param_printed = False self.apply(self._init_weights) def _build_node_classifier(self): layers = [] input_dim = self.node_classifier_input_dim current_dim = input_dim for i in range(self.node_layers): output_dim = 1 if i == self.node_layers - 1 else max(current_dim // 2, 32) layers.append(nn.Linear(current_dim, output_dim)) if self.node_norm and i < self.node_layers - 1: layers.append(nn.LayerNorm(output_dim)) if i < self.node_layers - 1: layers.append(get_activation(self.activation)) layers.append(nn.Dropout(self.out_dropout)) current_dim = output_dim return nn.Sequential(*layers) def _init_weights(self, module): if isinstance(module, nn.Linear): nn.init.xavier_normal_(module.weight) if module.bias is not None: nn.init.constant_(module.bias, 0.0 if module.out_features == 1 else 0.01) elif isinstance(module, nn.LayerNorm): nn.init.ones_(module.weight) nn.init.zeros_(module.bias) def forward(self, data: Batch) -> dict: if self.training and not self._param_printed: print(f"ReCEP total params: {sum(p.numel() for p in self.parameters()):,}") self._param_printed = True x = data.x coords = data.pos batch = data.batch e_attr = data.edge_attr coords_C = coords[:, 1].clone() x = self.proj_layer(x) if self.dihedral and coords is not None: x = x + self.dihedral_features(coords) if self.rsa and data.rsa is not None: rsa = data.rsa.unsqueeze(-1) x = torch.cat([x, rsa], dim=-1) h = x assert h.shape[1] == self.node_dims[0], f"[ReCEP] Node feature dim mismatch: got {h.shape[1]}, expected {self.node_dims[0]}" if self.use_egnn: for layer in self.egnn_layers: h, coords_C, _ = layer(h, coords_C, data.edge_index, batch, edge_attr=e_attr) if self.concat and self.use_egnn: h = torch.cat([x, h], dim=-1) elif self.addition and self.use_egnn: h = h + x graph_feats = self.graph_pool(h, batch) global_pred = self.global_predictor(graph_feats).squeeze(-1) context = graph_feats[batch] if self.node_gate and hasattr(self, 'node_gate'): if self.fusion_type == 'concat': gate_input = torch.cat([h, context], dim=-1) elif self.fusion_type == 'add': gate_input = h + context else: raise ValueError(f"Unsupported fusion type: {self.fusion_type}") gate = self.node_gate(gate_input) gated_h = h + gate * h else: gated_h = h if self.fusion_type == 'concat': cat = torch.cat([gated_h, context], dim=-1) elif self.fusion_type == 'add': # Ensure dimensions match for addition assert gated_h.shape[-1] == context.shape[-1], f"[ReCEP] Dimension mismatch for add fusion: gated_h {gated_h.shape[-1]} vs context {context.shape[-1]}" cat = gated_h + context else: raise ValueError(f"Unsupported fusion type: {self.fusion_type}") # Verify input dimension matches node classifier expectation expected_dim = self.node_classifier_input_dim actual_dim = cat.shape[-1] assert actual_dim == expected_dim, f"[ReCEP] Node classifier input dim mismatch: got {actual_dim}, expected {expected_dim}" node_preds = self.node_classifier(cat).squeeze(-1) return {"global_pred": global_pred, "node_preds": node_preds} def print_param_count(self): """Print a summary table of parameter counts""" table = PrettyTable() table.field_names = ["Layer Name", "Type", "Parameters", "Trainable"] total_params = 0 trainable_params = 0 for name, module in self.named_modules(): if not list(module.children()): # Only leaf nodes params = sum(p.numel() for p in module.parameters()) is_trainable = any(p.requires_grad for p in module.parameters()) if params > 0: total_params += params trainable_params += params if is_trainable else 0 table.add_row([ name, module.__class__.__name__, f"{params:,}", "✓" if is_trainable else "✗" ]) table.add_row(["", "", "", ""], divider=True) table.add_row([ "TOTAL", "", f"{total_params:,}", f"Trainable: {trainable_params:,}" ]) print("\nReCEP Model Parameter Summary:") print(table) print(f"Parameter Density: {trainable_params/total_params:.1%}\n") def save(self, path, threshold: float = 0.5): """Save model with configuration""" path = Path(path) try: path.parent.mkdir(parents=True, exist_ok=True) save_path = path.with_suffix('.bin') config = self.get_config() # config = { # 'in_dim': self.in_dim, # 'rsa': self.rsa, # 'dihedral': self.dihedral, # 'node_dims': self.original_node_dims, # Use original node_dims # 'edge_dim': self.edge_dim, # 'dropout': self.dropout, # 'activation': self.activation, # 'residual': self.residual, # 'attention': self.attention, # 'normalize': self.normalize, # 'coords_agg': self.coords_agg, # 'ffn': self.ffn, # 'batch_norm': self.batch_norm, # 'concat': self.concat, # 'node_norm': self.node_norm, # 'node_layers': self.node_layers, # 'node_gate': self.node_gate, # 'out_dropout': self.out_dropout # } torch.save({ 'model_state': self.state_dict(), 'config': config, 'model_class': self.__class__.__name__, 'version': '1.0', 'threshold': threshold }, save_path) print(f"ReCEP model saved to {save_path}") except Exception as e: print(f"Save failed: {str(e)}") raise @classmethod def load(cls, path, device='cpu', strict=True, verbose=True): """Load model with configuration""" path = Path(path) if not path.exists(): raise FileNotFoundError(f"Model file {path} not found") try: if isinstance(device, str): device = torch.device(device) elif isinstance(device, int): if device >= 0 and torch.cuda.is_available(): device = torch.device(f'cuda:{device}') else: device = torch.device('cpu') elif not isinstance(device, torch.device): raise ValueError(f"Unsupported device type: {type(device)}") checkpoint = torch.load( path, map_location=device, weights_only=False ) except RuntimeError: print("Warning: Using unsafe load due to weights_only restriction") checkpoint = torch.load(path, map_location=device) # Version compatibility check if 'version' not in checkpoint: print("Warning: Loading legacy model without version info") # Rebuild configuration config = checkpoint.get('config', {}) model = cls(**config) # Load state dict model_state = checkpoint['model_state'] current_state = model.state_dict() # Auto-match parameters matched_state = {} for name, param in model_state.items(): if name in current_state: if param.shape == current_state[name].shape: matched_state[name] = param else: print(f"Size mismatch: {name} (load {param.shape} vs model {current_state[name].shape})") else: print(f"Parameter not found: {name}") current_state.update(matched_state) model.load_state_dict(current_state, strict=strict) if verbose: print(f"Successfully loaded {len(matched_state)}/{len(model_state)} parameters") return model.to(device), checkpoint.get('threshold', 0.5) def get_config(self): """Get model configuration""" return { 'in_dim': self.in_dim, 'rsa': self.rsa, 'dihedral': self.dihedral, 'node_dims': self.original_node_dims, 'edge_dim': self.edge_dim, 'dropout': self.dropout, 'activation': self.activation, 'residual': self.residual, 'attention': self.attention, 'normalize': self.normalize, 'coords_agg': self.coords_agg, 'ffn': self.ffn, 'batch_norm': self.batch_norm, 'concat': self.concat, 'addition': self.addition, 'pooling': self.pooling, 'fusion_type': self.fusion_type, 'node_gate': self.node_gate, 'node_norm': self.node_norm, 'node_layers': self.node_layers, 'out_dropout': self.out_dropout, 'use_egnn': self.use_egnn, 'encoder': self.encoder } model_registry = { "ReCEP": ReCEP, } def get_model(configs): """ Flexible model loader. Accepts either an argparse.Namespace or a dict. Returns an instance of the selected model. """ # Support both argparse.Namespace and dict if hasattr(configs, '__dict__'): args = vars(configs) else: args = configs # Default to ReCEP if no model specified model_name = args.get('model', 'ReCEP') if model_name not in model_registry: valid_models = list(model_registry.keys()) raise ValueError(f"Invalid model type: {model_name}. Must be one of: {valid_models}") model_class = model_registry[model_name] # Use inspect to get the model's __init__ parameters import inspect init_signature = inspect.signature(model_class.__init__) parameters = init_signature.parameters # Build model configuration from args model_config = {} for param_name, param in parameters.items(): if param_name == 'self': continue if param_name in args: model_config[param_name] = args[param_name] elif param.default is not param.empty: model_config[param_name] = param.default else: print(f"[WARNING] Required parameter '{param_name}' not found in args and has no default value") # print(f"[INFO] Creating {model_name} model with config: {list(model_config.keys())}") model = model_class(**model_config) return model