import os
import json
import pickle
import random
from multiprocessing import Pool
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union

import h5py
import numpy as np
import torch
from torch_geometric.data import Batch, Data, Dataset
from torch_geometric.loader import DataLoader
from tqdm import tqdm

from ..utils.constants import BASE_DIR
from ..utils.loading import load_epitopes_csv, load_data_split
from .utils import create_graph_data, create_graph_data_full


def apply_undersample(data_list: List, undersample_param: Union[int, float, None], seed: int = 42, verbose: bool = True):
    """
    Apply undersampling to a data list.

    Args:
        data_list: List of data samples
        undersample_param: If int, keep that many samples; if float in (0, 1], keep that
            fraction of the data; if None, return the list unchanged
        seed: Random seed for reproducibility
        verbose: Whether to print sampling information

    Returns:
        Undersampled data list
    """
    if undersample_param is None:
        return data_list

    original_size = len(data_list)

    if isinstance(undersample_param, float):
        if not (0 < undersample_param <= 1.0):
            raise ValueError(f"Float undersample must be between 0 and 1, got {undersample_param}")
        target_size = int(len(data_list) * undersample_param)
    elif isinstance(undersample_param, int):
        if undersample_param <= 0:
            raise ValueError(f"Int undersample must be positive, got {undersample_param}")
        target_size = min(undersample_param, len(data_list))
    else:
        raise ValueError(f"Undersample must be int, float, or None, got {type(undersample_param)}")

    if target_size < len(data_list):
        random.seed(seed)
        sampled_data = random.sample(data_list, target_size)

        if verbose:
            print(f"Applied undersampling: {original_size} -> {target_size} samples")

        return sampled_data
    elif verbose:
        print(f"No undersampling applied: requested {target_size}, available {original_size}")

    return data_list


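# Usage sketch for `apply_undersample` (illustrative; the integer list stands in
# for any list of graph samples):
#
#     >>> samples = list(range(100))
#     >>> len(apply_undersample(samples, 0.25, seed=0, verbose=False))   # keep 25% of the list
#     25
#     >>> len(apply_undersample(samples, 10, seed=0, verbose=False))     # keep at most 10 items
#     10
#     >>> apply_undersample(samples, None) is samples                    # None disables undersampling
#     True

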
class AntigenDataset(Dataset):
    """
    Dataset for antigen chains.
    Each data point represents a complete protein as a graph, with nodes being residues
    and edges connecting residues within `radius` Å of each other (18 Å by default).
    """

    def __init__(
        self,
        data_split: str = "train",
        radius: float = 18,
        threshold: float = 0.25,
        num_posenc: int = 16,
        num_rbf: int = 16,
        undersample: Union[int, float, None] = None,
        cache_dir: Optional[str] = None,
        force_rebuild: bool = False,
        verbose: bool = True,
        seed: int = 42,
        encoder: str = "esmc"
    ):
        """
        Initialize the antigen dataset.

        Args:
            data_split: Data split name ('train', 'val', 'test')
            radius: Distance threshold for edge creation (Å)
            threshold: SASA threshold for surface residues (not used for full-protein graphs)
            num_posenc: Number of positional encoding features
            num_rbf: Number of RBF features
            undersample: Undersample parameter (int for count, float for ratio)
            cache_dir: Directory to cache processed data
            force_rebuild: Whether to force rebuild the dataset
            verbose: Whether to print progress information
            seed: Random seed for reproducibility
            encoder: Encoder type ('esmc' or 'esm2')
        """
        self.data_split = data_split
        self.radius = radius
        self.threshold = threshold
        self.num_posenc = num_posenc
        self.num_rbf = num_rbf
        self.undersample = undersample
        self.verbose = verbose
        self.seed = seed
        self.encoder = encoder

        if cache_dir is None:
            cache_dir = Path(f"{BASE_DIR}/data/full_region_cache/antigen_r{radius}")
        self.cache_dir = Path(cache_dir)
        self.cache_dir.mkdir(parents=True, exist_ok=True)

        self.cache_file = self.cache_dir / f"{data_split}_antigen_dataset.h5"

        self.antigens = load_data_split(data_split, verbose=verbose)
        _, _, self.epitope_dict = load_epitopes_csv()

        self.data_list = []

        if self.cache_file.exists() and not force_rebuild:
            if verbose:
                print(f"Loading cached antigen dataset from {self.cache_file}")
            self._load_cache()
        else:
            if verbose:
                print(f"Building antigen dataset for {data_split} split...")
            self._build_dataset()
            self._save_cache()

        super().__init__()

    def _load_protein_data(self, pdb_id: str, chain_id: str) -> Optional[Dict]:
        """
        Load precomputed protein data from files.

        Args:
            pdb_id: PDB ID
            chain_id: Chain ID

        Returns:
            Dictionary containing all protein data, or None if loading fails
        """
        try:
            protein_key = f"{pdb_id}_{chain_id}"

            embedding_file = Path(BASE_DIR) / "data" / "embeddings" / self.encoder / f"{protein_key}.h5"
            if not embedding_file.exists():
                if self.verbose:
                    print(f"Embedding file not found: {embedding_file}")
                return None

            with h5py.File(embedding_file, "r") as h5f:
                embeddings = h5f["embedding"][:]

            coords_file = Path(BASE_DIR) / "data" / "coords" / f"{protein_key}.npy"
            if not coords_file.exists():
                if self.verbose:
                    print(f"Coords file not found: {coords_file}")
                return None
            backbone_atoms = np.load(coords_file)

            rsa_file = Path(BASE_DIR) / "data" / "rsa" / f"{protein_key}.npy"
            if not rsa_file.exists():
                if self.verbose:
                    print(f"RSA file not found: {rsa_file}")
                return None
            rsa_values = np.load(rsa_file)

            binary_labels = self.epitope_dict.get(protein_key, [])

            epitope_indices = []
            for idx, is_epitope in enumerate(binary_labels):
                if is_epitope == 1:
                    epitope_indices.append(idx)

            return {
                'embeddings': embeddings,
                'backbone_atoms': backbone_atoms,
                'rsa_values': rsa_values,
                'epitope_indices': epitope_indices,
            }

        except Exception as e:
            if self.verbose:
                print(f"Error loading protein data for {pdb_id}_{chain_id}: {str(e)}")
            return None

    def _build_dataset(self):
        """Build the dataset from precomputed data files."""
        failed_proteins = []

        for pdb_id, chain_id in tqdm(self.antigens, desc=f"Processing {self.data_split} antigens",
                                     disable=not self.verbose):
            try:
                protein_data = self._load_protein_data(pdb_id, chain_id)
                if protein_data is None:
                    failed_proteins.append(f"{pdb_id}_{chain_id}")
                    continue

                embeddings = protein_data['embeddings']
                backbone_atoms = protein_data['backbone_atoms']
                rsa_values = protein_data['rsa_values']
                epitope_indices = protein_data['epitope_indices']

                graph_data = create_graph_data_full(
                    embeddings=embeddings,
                    backbone_atoms=backbone_atoms,
                    rsa_values=rsa_values,
                    epitope_indices=epitope_indices,
                    pdb_id=pdb_id,
                    chain_id=chain_id,
                    num_rbf=self.num_rbf,
                    num_posenc=self.num_posenc,
                    radius=self.radius,
                    verbose=self.verbose
                )

                if graph_data is not None:
                    self.data_list.append(graph_data)
                else:
                    failed_proteins.append(f"{pdb_id}_{chain_id}")

            except Exception as e:
                failed_proteins.append(f"{pdb_id}_{chain_id}")
                if self.verbose:
                    print(f"Error processing {pdb_id}_{chain_id}: {str(e)}")

        if failed_proteins and self.verbose:
            print(f"Failed to process {len(failed_proteins)} proteins: {failed_proteins[:5]}...")

        if self.undersample is not None:
            self.data_list = apply_undersample(
                self.data_list,
                self.undersample,
                seed=self.seed,
                verbose=self.verbose
            )

        if self.verbose:
            print(f"Successfully created {len(self.data_list)} protein graphs")

    def _save_cache(self):
        """Save the processed dataset to cache."""
        try:
            self._save_cache_hdf5()
            if self.verbose:
                print(f"Dataset cached to {self.cache_file}")
        except Exception as e:
            if self.verbose:
                print(f"Failed to save cache: {str(e)}")

    def _load_cache(self):
        """Load the processed dataset from cache."""
        try:
            self._load_cache_hdf5()
            if self.verbose:
                print(f"Loaded {len(self.data_list)} samples from cache")
        except Exception as e:
            if self.verbose:
                print(f"Failed to load cache: {str(e)}")
            self.data_list = []

    def _save_cache_hdf5(self):
        """Save the dataset in HDF5 format."""
        with h5py.File(self.cache_file, 'w') as f:
            f.attrs['num_samples'] = len(self.data_list)
            f.attrs['radius'] = self.radius
            f.attrs['threshold'] = self.threshold
            f.attrs['data_split'] = self.data_split
            f.attrs['encoder'] = self.encoder
            f.attrs['dataset_type'] = 'antigen_full'

            for i, data in enumerate(tqdm(self.data_list, desc="Saving dataset...", disable=not self.verbose)):
                group = f.create_group(f'protein_{i}')

                group.create_dataset('x', data=data.x.numpy(), compression='gzip', compression_opts=6)
                group.create_dataset('pos', data=data.pos.numpy(), compression='gzip', compression_opts=6)
                group.create_dataset('rsa', data=data.rsa.numpy(), compression='gzip', compression_opts=6)
                group.create_dataset('edge_index', data=data.edge_index.numpy(), compression='gzip', compression_opts=6)
                group.create_dataset('edge_attr', data=data.edge_attr.numpy(), compression='gzip', compression_opts=6)
                group.create_dataset('y_node', data=data.y_node.numpy(), compression='gzip', compression_opts=6)

                group.attrs['pdb_id'] = data.pdb_id.encode('utf-8')
                group.attrs['chain_id'] = data.chain_id.encode('utf-8')
                group.attrs['num_nodes'] = data.num_nodes
                group.attrs['num_epitopes'] = data.num_epitopes
                group.attrs['epitope_ratio'] = data.epitope_ratio
                group.attrs['radius'] = data.radius

                group.create_dataset('epitope_indices', data=np.array(data.epitope_indices), compression='gzip', compression_opts=6)

    def _load_cache_hdf5(self):
        """Load the dataset from the HDF5 cache."""
        self.data_list = []

        with h5py.File(self.cache_file, 'r') as f:
            total_samples = f.attrs['num_samples']

            for i in tqdm(range(total_samples), desc="Loading dataset...", disable=not self.verbose):
                group = f[f'protein_{i}']
                attrs = dict(group.attrs)

                def safe_decode(attr):
                    val = attrs[attr]
                    return val.decode('utf-8') if isinstance(val, bytes) else str(val)

                data = Data(
                    x=torch.tensor(group['x'][:]),
                    pos=torch.tensor(group['pos'][:]),
                    rsa=torch.tensor(group['rsa'][:]),
                    edge_index=torch.tensor(group['edge_index'][:]),
                    edge_attr=torch.tensor(group['edge_attr'][:]),
                    y_node=torch.tensor(group['y_node'][:]),
                    epitope_indices=group['epitope_indices'][:].tolist(),
                    pdb_id=safe_decode('pdb_id'),
                    chain_id=safe_decode('chain_id'),
                    num_nodes=int(attrs['num_nodes']),
                    num_epitopes=int(attrs['num_epitopes']),
                    epitope_ratio=float(attrs['epitope_ratio']),
                    radius=float(attrs['radius'])
                )
                self.data_list.append(data)

        if self.undersample is not None:
            self.data_list = apply_undersample(
                self.data_list,
                self.undersample,
                seed=self.seed,
                verbose=self.verbose
            )

    def len(self) -> int:
        """Return the number of samples in the dataset."""
        return len(self.data_list)

    def get(self, idx: int) -> Data:
        """Get a sample by index."""
        return self.data_list[idx]

    def get_stats(self) -> Dict:
        """Get dataset statistics."""
        if not self.data_list:
            return {}

        num_nodes_list = [data.num_nodes for data in self.data_list]
        num_edges_list = [data.edge_index.shape[1] for data in self.data_list]
        num_epitopes_list = [data.num_epitopes for data in self.data_list]
        epitope_ratio_list = [data.epitope_ratio for data in self.data_list]

        total_nodes = sum(num_nodes_list)
        total_edges = sum(num_edges_list)
        total_epitopes = sum(num_epitopes_list)

        stats = {
            'num_proteins': len(self.data_list),
            'avg_nodes_per_protein': np.mean(num_nodes_list),
            'std_nodes_per_protein': np.std(num_nodes_list),
            'min_nodes_per_protein': np.min(num_nodes_list),
            'max_nodes_per_protein': np.max(num_nodes_list),
            'avg_edges_per_protein': np.mean(num_edges_list),
            'std_edges_per_protein': np.std(num_edges_list),
            'total_nodes': total_nodes,
            'total_edges': total_edges,
            'total_epitopes': total_epitopes,
            'avg_epitopes_per_protein': np.mean(num_epitopes_list),
            'avg_epitope_ratio': np.mean(epitope_ratio_list),
            'overall_epitope_ratio': total_epitopes / total_nodes if total_nodes > 0 else 0,
        }

        return stats

    def print_stats(self):
        """Print dataset statistics."""
        stats = self.get_stats()
        if not stats:
            print("No statistics available (empty dataset)")
            return

        print(f"\n=== {self.data_split.upper()} Antigen Dataset Statistics ===")
        print(f"Number of proteins: {stats['num_proteins']:,}")
        print(f"Average nodes per protein: {stats['avg_nodes_per_protein']:.1f} ± {stats['std_nodes_per_protein']:.1f}")
        print(f"Nodes per protein range: [{stats['min_nodes_per_protein']}, {stats['max_nodes_per_protein']}]")
        print(f"Average edges per protein: {stats['avg_edges_per_protein']:.1f} ± {stats['std_edges_per_protein']:.1f}")
        print(f"Total nodes: {stats['total_nodes']:,}")
        print(f"Total edges: {stats['total_edges']:,}")
        print(f"Total epitope nodes: {stats['total_epitopes']:,}")
        print(f"Average epitopes per protein: {stats['avg_epitopes_per_protein']:.1f}")
        print(f"Average epitope ratio per protein: {stats['avg_epitope_ratio']:.3f}")
        print(f"Overall epitope ratio: {stats['overall_epitope_ratio']:.3f}")
        print("=" * 50)


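# Usage sketch (illustrative; assumes the precomputed embedding, coords and RSA
# files referenced in `_load_protein_data` exist under BASE_DIR):
#
#     >>> ds = AntigenDataset(data_split="test", radius=18, encoder="esmc")
#     >>> ds.print_stats()
#     >>> g = ds[0]                       # a torch_geometric.data.Data object
#     >>> g.x.shape, g.edge_index.shape   # node features and COO edge list
#     >>> g.y_node                        # per-residue epitope labels

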
class SphereGraphDataset(Dataset):
    """
    Optimized graph dataset for training the ReGEP model on spherical regions of antigen chains.
    Each graph represents a spherical region centered on a surface residue.

    Optimizations:
    - Uses only the HDF5 format for caching
    - Builds the complete dataset without zero_ratio filtering
    - Applies zero_ratio and undersample filtering at load time
    - Faster caching with an optimized HDF5 structure
    """

    def __init__(
        self,
        data_split: str = "train",
        radius: int = 18,
        threshold: float = 0.25,
        num_posenc: int = 16,
        num_rbf: int = 16,
        zero_ratio: float = 0.1,
        undersample: Union[int, float, None] = None,
        cache_dir: Optional[str] = None,
        force_rebuild: bool = False,
        verbose: bool = True,
        seed: int = 42,
        use_embeddings2: bool = False
    ):
        """
        Initialize the spherical graph dataset.

        Args:
            data_split: Data split name ('train', 'val', 'test')
            radius: Radius of the spherical regions (Å)
            threshold: SASA threshold for surface residues
            num_posenc: Number of positional encoding features
            num_rbf: Number of RBF features
            zero_ratio: Fraction of zero-recall graphs to keep (0.3 means keep 30%)
            undersample: Undersample parameter (int for count, float for ratio)
            cache_dir: Directory to cache processed data
            force_rebuild: Whether to force rebuild the dataset
            verbose: Whether to print progress information
            seed: Random seed for reproducibility
            use_embeddings2: Whether to use the secondary (ESM2) embeddings as node features
        """
        self.data_split = data_split
        self.radius = radius
        self.threshold = threshold
        self.num_posenc = num_posenc
        self.num_rbf = num_rbf
        self.zero_ratio = zero_ratio
        self.undersample = undersample
        self.verbose = verbose
        self.seed = seed
        self.use_embeddings2 = use_embeddings2

        if cache_dir is None:
            cache_dir = Path(f"{BASE_DIR}/data/region_cache/sphere_r{radius}")
        self.cache_dir = Path(cache_dir)
        self.cache_dir.mkdir(parents=True, exist_ok=True)

        self.cache_file = self.cache_dir / f"{data_split}_dataset_complete.h5"

        self.antigens = load_data_split(data_split, verbose=verbose)

        self.data_list = []

        if self.cache_file.exists() and not force_rebuild:
            if verbose:
                print(f"Loading cached dataset with radius {self.radius} from {self.cache_file}")
            self._load_cache()
        else:
            if verbose:
                print(f"Building complete dataset with radius {self.radius} for {data_split} split...")
            self._build_dataset()
            self._save_cache()

        super().__init__()

    def _load_protein_data(self, pdb_id: str, chain_id: str) -> Optional[Dict]:
        """
        Load precomputed protein data from files.

        Args:
            pdb_id: PDB ID
            chain_id: Chain ID

        Returns:
            Dictionary containing all protein data, or None if loading fails
        """
        try:
            protein_key = f"{pdb_id}_{chain_id}"

            embedding_file = Path(BASE_DIR) / "data" / "embeddings" / "esmc" / f"{protein_key}.h5"
            if not embedding_file.exists():
                if self.verbose:
                    print(f"Embedding file not found: {embedding_file}")
                return None

            with h5py.File(embedding_file, "r") as h5f:
                embeddings = h5f["embedding"][:]

            esm2_file = Path(BASE_DIR) / "data" / "embeddings" / "esm2" / f"{protein_key}.h5"
            if not esm2_file.exists():
                if self.verbose:
                    print(f"ESM2 file not found: {esm2_file}")
                embeddings2 = None
            else:
                with h5py.File(esm2_file, "r") as h5f:
                    embeddings2 = h5f["embedding"][:]

            coords_file = Path(BASE_DIR) / "data" / "coords" / f"{protein_key}.npy"
            if not coords_file.exists():
                if self.verbose:
                    print(f"Coords file not found: {coords_file}")
                return None
            backbone_atoms = np.load(coords_file)

            rsa_file = Path(BASE_DIR) / "data" / "rsa" / f"{protein_key}.npy"
            if not rsa_file.exists():
                if self.verbose:
                    print(f"RSA file not found: {rsa_file}")
                return None
            rsa_values = np.load(rsa_file)

            sphere_file = Path(BASE_DIR) / "data" / "antigen_sphere" / f"{protein_key}.h5"
            radius_key = f"r{self.radius}"

            if not sphere_file.exists():
                if self.verbose:
                    print(f"Sphere file not found: {sphere_file}")
                return None

            coverage_dict = {}
            with h5py.File(sphere_file, "r") as h5f:
                if radius_key not in h5f:
                    if self.verbose:
                        print(f"Radius {self.radius} not found in {sphere_file}")
                    return None

                radius_group = h5f[radius_key]
                for center_idx_str in radius_group.keys():
                    center_idx = int(center_idx_str)
                    center_group = radius_group[center_idx_str]
                    covered_indices = center_group['covered_indices'][:].tolist()
                    covered_epitope_indices = center_group['covered_epitope_indices'][:].tolist()
                    precision = float(center_group.attrs['precision'])
                    recall = float(center_group.attrs['recall'])
                    coverage_dict[center_idx] = (covered_indices, covered_epitope_indices, precision, recall)

            _, _, epitopes = load_epitopes_csv()
            binary_labels = epitopes.get(protein_key, [])

            epitope_indices = []
            for idx, is_epitope in enumerate(binary_labels):
                if is_epitope == 1:
                    epitope_indices.append(idx)

            return {
                'embeddings': embeddings,
                'backbone_atoms': backbone_atoms,
                'rsa_values': rsa_values,
                'coverage_dict': coverage_dict,
                'epitope_indices': epitope_indices,
                'embeddings2': embeddings2
            }

        except Exception as e:
            if self.verbose:
                print(f"Error loading protein data for {pdb_id}_{chain_id}: {str(e)}")
            return None

    def _build_dataset(self):
        """Build the complete dataset from precomputed data files (no zero_ratio filtering)."""
        failed_proteins = []

        for pdb_id, chain_id in tqdm(self.antigens, desc=f"Processing {self.data_split} antigens",
                                     disable=not self.verbose):
            try:
                protein_data = self._load_protein_data(pdb_id, chain_id)
                if protein_data is None:
                    if self.verbose:
                        print(f"Failed to load data for {pdb_id}_{chain_id}")
                    continue

                embeddings = protein_data['embeddings']
                embeddings2 = protein_data['embeddings2']
                backbone_atoms = protein_data['backbone_atoms']
                rsa_values = protein_data['rsa_values']
                coverage_dict = protein_data['coverage_dict']
                epitope_indices = protein_data['epitope_indices']

                if not coverage_dict:
                    if self.verbose:
                        print(f"No surface regions found for {pdb_id}_{chain_id}")
                    continue

                for center_idx, (covered_indices, covered_epitope_indices, precision, recall) in coverage_dict.items():
                    if len(covered_indices) < 2:
                        continue

                    graph_data = create_graph_data(
                        center_idx=center_idx,
                        covered_indices=covered_indices,
                        covered_epitope_indices=covered_epitope_indices,
                        embeddings=embeddings,
                        embeddings2=embeddings2,
                        backbone_atoms=backbone_atoms,
                        rsa_values=rsa_values,
                        epitope_indices=epitope_indices,
                        recall=recall,
                        precision=precision,
                        pdb_id=pdb_id,
                        chain_id=chain_id,
                        num_rbf=self.num_rbf,
                        num_posenc=self.num_posenc,
                        verbose=self.verbose
                    )

                    if graph_data is not None:
                        self.data_list.append(graph_data)

            except Exception as e:
                failed_proteins.append(f"{pdb_id}_{chain_id}")
                if self.verbose:
                    print(f"Error processing {pdb_id}_{chain_id}: {str(e)}")

        if failed_proteins and self.verbose:
            print(f"Failed to process {len(failed_proteins)} proteins: {failed_proteins[:5]}...")

        if self.verbose:
            print(f"Successfully created {len(self.data_list)} graph samples (complete dataset)")

    def _save_cache(self):
        """Save the processed dataset to cache."""
        try:
            self._save_cache_hdf5()
            if self.verbose:
                print(f"Dataset cached to {self.cache_file}")
        except Exception as e:
            if self.verbose:
                print(f"Failed to save cache: {str(e)}")

    def _load_cache(self):
        """Load the processed dataset from cache."""
        try:
            self._load_cache_hdf5()
            if self.verbose:
                print(f"Loaded {len(self.data_list)} samples from cache")
        except Exception as e:
            if self.verbose:
                print(f"Failed to load cache: {str(e)}")
            self.data_list = []

    def _save_cache_hdf5(self):
        """Save the dataset using an optimized HDF5 layout for faster loading."""
        with h5py.File(self.cache_file, 'w') as f:
            f.attrs['num_samples'] = len(self.data_list)
            f.attrs['radius'] = self.radius
            f.attrs['threshold'] = self.threshold
            f.attrs['data_split'] = self.data_split
            f.attrs['complete_dataset'] = True

            num_samples = len(self.data_list)
            if num_samples == 0:
                return

            all_x = []
            all_pos = []
            all_rsa = []
            all_edge_index = []
            all_edge_attr = []
            all_y = []
            all_y_node = []
            all_center_idx = []
            all_precision = []
            all_pdb_ids = []
            all_chain_ids = []
            all_num_nodes = []
            all_covered_indices = []
            all_embeddings2 = []

            for data in self.data_list:
                all_x.append(data.x.numpy())
                all_pos.append(data.pos.numpy())
                all_rsa.append(data.rsa.numpy())
                all_edge_index.append(data.edge_index.numpy())
                all_edge_attr.append(data.edge_attr.numpy())
                all_y.append(data.y.numpy())
                all_y_node.append(data.y_node.numpy())
                all_center_idx.append(data.center_idx)
                all_precision.append(data.precision)
                all_pdb_ids.append(data.pdb_id.encode('utf-8'))
                all_chain_ids.append(data.chain_id.encode('utf-8'))
                all_num_nodes.append(data.num_nodes)
                all_covered_indices.append(data.covered_indices)

                if hasattr(data, 'embeddings2') and data.embeddings2 is not None:
                    if isinstance(data.embeddings2, np.ndarray):
                        all_embeddings2.append(data.embeddings2)
                    else:
                        all_embeddings2.append(data.embeddings2.numpy())
                else:
                    all_embeddings2.append(np.zeros((data.num_nodes, 1280), dtype=np.float32))

            progress_bar = tqdm(enumerate(self.data_list), total=num_samples, desc="Saving dataset...",
                                disable=not self.verbose)

            for i, data in progress_bar:
                group = f.create_group(f'graph_{i}')

                group.create_dataset('x', data=all_x[i], compression='gzip', compression_opts=6)
                group.create_dataset('pos', data=all_pos[i], compression='gzip', compression_opts=6)
                group.create_dataset('rsa', data=all_rsa[i], compression='gzip', compression_opts=6)
                group.create_dataset('edge_index', data=all_edge_index[i], compression='gzip', compression_opts=6)
                group.create_dataset('edge_attr', data=all_edge_attr[i], compression='gzip', compression_opts=6)
                group.create_dataset('y', data=all_y[i], compression='gzip', compression_opts=6)
                group.create_dataset('y_node', data=all_y_node[i], compression='gzip', compression_opts=6)
                group.create_dataset('embeddings2', data=all_embeddings2[i], compression='gzip', compression_opts=6)

                group.attrs['center_idx'] = all_center_idx[i]
                group.attrs['precision'] = all_precision[i]
                group.attrs['pdb_id'] = all_pdb_ids[i]
                group.attrs['chain_id'] = all_chain_ids[i]
                group.attrs['num_nodes'] = all_num_nodes[i]

                group.create_dataset('covered_indices', data=np.array(all_covered_indices[i]), compression='gzip', compression_opts=6)

    def _load_cache_hdf5(self):
        """Optimized cache loader with robust string handling."""
        self.data_list = []

        with h5py.File(self.cache_file, 'r') as f:
            zero_recall_indices = []
            non_zero_recall_indices = []
            total_samples = f.attrs['num_samples']

            if self.verbose:
                print(f"Scanning {total_samples} samples for recall values...")

            for i in range(total_samples):
                recall = f[f'graph_{i}/y'][0].item()
                if recall == 0.0:
                    zero_recall_indices.append(i)
                else:
                    non_zero_recall_indices.append(i)

            selected_indices = non_zero_recall_indices.copy()

            if isinstance(self.zero_ratio, (int, float)) and 0 <= self.zero_ratio <= 1:
                if self.zero_ratio < 1.0 and zero_recall_indices:
                    random.seed(self.seed)
                    target_count = int(len(zero_recall_indices) * self.zero_ratio)
                    selected_zero_indices = random.sample(zero_recall_indices, target_count)
                    selected_indices.extend(selected_zero_indices)

                    if self.verbose:
                        kept = len(selected_zero_indices)
                        total = len(zero_recall_indices)
                        print(f"Zero-recall filtering: kept {kept}/{total} samples (ratio={self.zero_ratio})")
                else:
                    selected_indices.extend(zero_recall_indices)
            else:
                # zero_ratio is None (or out of range): keep all zero-recall graphs
                selected_indices.extend(zero_recall_indices)

            if self.verbose:
                print(f"Loading {len(selected_indices)} selected samples...")

            for idx in tqdm(selected_indices, disable=not self.verbose):
                group = f[f'graph_{idx}']
                attrs = dict(group.attrs)

                def safe_decode(attr):
                    val = attrs[attr]
                    return val.decode('utf-8') if isinstance(val, bytes) else str(val)

                if 'embeddings2' in group and self.use_embeddings2:
                    if group['embeddings2'] is not None:
                        emb = torch.tensor(group['embeddings2'][:])
                    else:
                        emb = torch.tensor(group['x'][:])
                else:
                    emb = torch.tensor(group['x'][:])

                data = Data(
                    x=emb,
                    pos=torch.tensor(group['pos'][:]),
                    rsa=torch.tensor(group['rsa'][:]),
                    edge_index=torch.tensor(group['edge_index'][:]),
                    edge_attr=torch.tensor(group['edge_attr'][:]),
                    y=torch.tensor(group['y'][:]),
                    y_node=torch.tensor(group['y_node'][:]),
                    center_idx=int(attrs['center_idx']),
                    covered_indices=group['covered_indices'][:].tolist(),
                    precision=float(attrs['precision']),
                    pdb_id=safe_decode('pdb_id'),
                    chain_id=safe_decode('chain_id'),
                    num_nodes=int(attrs['num_nodes'])
                )

                self.data_list.append(data)

        if self.undersample is not None:
            self.data_list = apply_undersample(
                self.data_list,
                self.undersample,
                seed=self.seed,
                verbose=self.verbose
            )

        if self.verbose:
            print(f"Loaded {len(self.data_list)} samples (optimized loader)")

    def len(self) -> int:
        """Return the number of samples in the dataset."""
        return len(self.data_list)

    def get(self, idx: int) -> Data:
        """Get a sample by index."""
        return self.data_list[idx]

    def apply_filters(self, zero_ratio: Optional[float] = None, undersample: Union[int, float, None] = None,
                      seed: Optional[int] = None):
        """
        Re-apply filtering to an already loaded dataset (kept for compatibility).
        Note: it is more efficient to set these parameters during initialization.

        Args:
            zero_ratio: Fraction of zero-recall graphs to keep
            undersample: Undersample parameter
            seed: Random seed for reproducibility
        """
        if zero_ratio is not None:
            self.zero_ratio = zero_ratio
        if undersample is not None:
            self.undersample = undersample
        if seed is not None:
            self.seed = seed

        if self.cache_file.exists():
            if self.verbose:
                print("Re-applying filters to cached dataset...")
            self._load_cache_hdf5()
        else:
            if self.verbose:
                print("Warning: no cache file found, filters cannot be applied")

    def get_stats(self) -> Dict:
        """Get dataset statistics."""
        if not self.data_list:
            return {}

        num_nodes_list = [data.num_nodes for data in self.data_list]
        recall_list = [data.y.item() for data in self.data_list]
        precision_list = [data.precision for data in self.data_list]

        total_nodes = sum(num_nodes_list)
        total_epitopes = sum([data.y_node.sum().item() for data in self.data_list])
        num_zero_recall = sum([1 for data in self.data_list if data.y.item() == 0])

        stats = {
            'num_graphs': len(self.data_list),
            'avg_nodes_per_graph': np.mean(num_nodes_list),
            'std_nodes_per_graph': np.std(num_nodes_list),
            'min_nodes_per_graph': np.min(num_nodes_list),
            'max_nodes_per_graph': np.max(num_nodes_list),
            'total_nodes': total_nodes,
            'total_epitopes': total_epitopes,
            'epitope_ratio': total_epitopes / total_nodes if total_nodes > 0 else 0,
            'avg_recall': np.mean(recall_list),
            'std_recall': np.std(recall_list),
            'avg_precision': np.mean(precision_list),
            'std_precision': np.std(precision_list),
            'num_zero_recall': num_zero_recall,
        }

        return stats

    def print_stats(self):
        """Print dataset statistics."""
        stats = self.get_stats()
        if not stats:
            print("No statistics available (empty dataset)")
            return

        print(f"\n=== {self.data_split.upper()} Dataset Statistics ===")
        print(f"Number of graphs: {stats['num_graphs']:,}")
        print(f"Average nodes per graph: {stats['avg_nodes_per_graph']:.1f} ± {stats['std_nodes_per_graph']:.1f}")
        print(f"Nodes per graph range: [{stats['min_nodes_per_graph']}, {stats['max_nodes_per_graph']}]")
        print(f"Total nodes: {stats['total_nodes']:,}")
        print(f"Total epitope nodes: {stats['total_epitopes']:,}")
        print(f"Epitope ratio: {stats['epitope_ratio']:.3f}")
        print(f"Average recall: {stats['avg_recall']:.3f} ± {stats['std_recall']:.3f}")
        print(f"Average precision: {stats['avg_precision']:.3f} ± {stats['std_precision']:.3f}")
        print(f"Number of graphs with zero recall: {stats['num_zero_recall']:,}")
        print("=" * 40)


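# Usage sketch (illustrative; assumes the ESM-C embeddings and the precomputed
# "antigen_sphere" coverage files exist under BASE_DIR):
#
#     >>> train_ds = SphereGraphDataset(data_split="train", radius=18,
#     ...                               zero_ratio=0.1, undersample=0.5, seed=42)
#     >>> train_ds.print_stats()
#     >>> g = train_ds[0]
#     >>> g.y        # graph-level recall of the spherical region
#     >>> g.y_node   # per-residue epitope labels within the region

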
class MultiRadiusGraphDataset(Dataset):
    """
    Dataset that combines multiple radius datasets for multi-scale training.
    """

    def __init__(
        self,
        data_split: str = "train",
        radii: List[int] = [16, 18, 20],
        threshold: float = 0.25,
        num_posenc: int = 16,
        num_rbf: int = 16,
        zero_ratio: float = 0.1,
        undersample: Union[int, float, None] = None,
        cache_dir: Optional[str] = None,
        force_rebuild: bool = False,
        verbose: bool = True,
        use_embeddings2: bool = False
    ):
        """
        Initialize the multi-radius dataset.

        Args:
            data_split: Data split name
            radii: List of radii to use
            threshold: SASA threshold for surface residues
            num_posenc: Number of positional encoding features
            num_rbf: Number of RBF features
            zero_ratio: Fraction of zero-recall graphs to keep
            undersample: Undersample parameter (int for count, float for ratio)
            cache_dir: Directory to cache processed data
            force_rebuild: Whether to force rebuild the dataset
            verbose: Whether to print progress information
            use_embeddings2: Whether to use the secondary (ESM2) embeddings as node features
        """
        self.data_split = data_split
        self.radii = radii
        self.verbose = verbose

        self.datasets = []
        for radius in radii:
            dataset = SphereGraphDataset(
                data_split=data_split,
                radius=radius,
                threshold=threshold,
                num_posenc=num_posenc,
                num_rbf=num_rbf,
                zero_ratio=zero_ratio,
                undersample=undersample,
                cache_dir=cache_dir,
                force_rebuild=force_rebuild,
                verbose=verbose,
                use_embeddings2=use_embeddings2
            )
            self.datasets.append(dataset)

        self.data_list = []
        for dataset in self.datasets:
            self.data_list.extend(dataset.data_list)

        if verbose:
            print(f"Combined {len(self.datasets)} datasets with {len(self.data_list)} total samples")

        super().__init__()

    def len(self) -> int:
        return len(self.data_list)

    def get(self, idx: int) -> Data:
        return self.data_list[idx]

    def apply_filters(self, undersample: Union[int, float, None] = None, seed: int = 42):
        """
        Apply filtering to the loaded multi-radius dataset.

        Args:
            undersample: Undersample parameter (int for count, float for ratio)
            seed: Random seed for reproducibility
        """
        if undersample is not None:
            self.data_list = apply_undersample(self.data_list, undersample, seed=seed, verbose=True)

    def get_stats(self) -> Dict:
        """Get combined dataset statistics."""
        if not self.data_list:
            return {}

        num_nodes_list = [data.num_nodes for data in self.data_list]
        recall_list = [data.y.item() for data in self.data_list]
        precision_list = [data.precision for data in self.data_list]

        total_nodes = sum(num_nodes_list)
        total_epitopes = sum([data.y_node.sum().item() for data in self.data_list])

        stats = {
            'num_graphs': len(self.data_list),
            'num_radii': len(self.radii),
            'radii': self.radii,
            'avg_nodes_per_graph': np.mean(num_nodes_list),
            'std_nodes_per_graph': np.std(num_nodes_list),
            'min_nodes_per_graph': np.min(num_nodes_list),
            'max_nodes_per_graph': np.max(num_nodes_list),
            'total_nodes': total_nodes,
            'total_epitopes': total_epitopes,
            'epitope_ratio': total_epitopes / total_nodes if total_nodes > 0 else 0,
            'avg_recall': np.mean(recall_list),
            'std_recall': np.std(recall_list),
            'avg_precision': np.mean(precision_list),
            'std_precision': np.std(precision_list),
        }

        return stats

    def print_stats(self):
        """Print dataset statistics."""
        stats = self.get_stats()
        if not stats:
            print("No statistics available (empty dataset)")
            return

        print(f"\n=== {self.data_split.upper()} Dataset Statistics ===")
        print(f"Number of graphs: {stats['num_graphs']:,}")
        print(f"Average nodes per graph: {stats['avg_nodes_per_graph']:.1f} ± {stats['std_nodes_per_graph']:.1f}")
        print(f"Nodes per graph range: [{stats['min_nodes_per_graph']}, {stats['max_nodes_per_graph']}]")
        print(f"Total nodes: {stats['total_nodes']:,}")
        print(f"Total epitope nodes: {stats['total_epitopes']:,}")
        print(f"Epitope ratio: {stats['epitope_ratio']:.3f}")
        print(f"Average recall: {stats['avg_recall']:.3f} ± {stats['std_recall']:.3f}")
        print(f"Average precision: {stats['avg_precision']:.3f} ± {stats['std_precision']:.3f}")
        print("=" * 40)


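# Usage sketch (illustrative): spheres from several radii are concatenated into a
# single training set, each radius backed by its own SphereGraphDataset cache.
#
#     >>> multi_ds = MultiRadiusGraphDataset(data_split="train", radii=[16, 18, 20],
#     ...                                    zero_ratio=0.1)
#     >>> multi_ds.print_stats()
#     >>> len(multi_ds)   # sum of the per-radius sample counts

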
def create_datasets(
    radii: List[int] = [16, 18, 20],
    splits: List[str] = ["train", "test"],
    threshold: float = 0.25,
    zero_ratio: Optional[float] = None,
    undersample: Union[int, float, None] = None,
    cache_dir: Optional[str] = None,
    force_rebuild: bool = False,
    verbose: bool = False,
    seed: int = 42,
    use_embeddings2: bool = False,
) -> Dict[str, Dataset]:
    """
    Create optimized datasets for all splits and radii.

    Args:
        radii: List of radii to use
        splits: List of data splits to create
        threshold: SASA threshold for surface residues
        zero_ratio: Fraction of zero-recall graphs to keep
        undersample: Undersample parameter (int for count, float for ratio)
        cache_dir: Directory to cache processed data
        force_rebuild: Whether to force rebuild datasets
        verbose: Whether to print progress information
        seed: Random seed for reproducibility
        use_embeddings2: Whether to use the secondary (ESM2) embeddings as node features

    Returns:
        Dictionary mapping split names to datasets (a SphereGraphDataset for a single
        radius, a MultiRadiusGraphDataset for multiple radii)
    """
    datasets = {}

    for split in splits:
        if len(radii) == 1:
            dataset = SphereGraphDataset(
                data_split=split,
                radius=radii[0],
                threshold=threshold,
                zero_ratio=zero_ratio,
                undersample=undersample,
                cache_dir=cache_dir,
                force_rebuild=force_rebuild,
                verbose=verbose,
                seed=seed,
                use_embeddings2=use_embeddings2
            )
            if verbose:
                dataset.print_stats()
        else:
            dataset = MultiRadiusGraphDataset(
                data_split=split,
                radii=radii,
                threshold=threshold,
                zero_ratio=zero_ratio,
                undersample=undersample,
                cache_dir=cache_dir,
                force_rebuild=force_rebuild,
                verbose=verbose,
                use_embeddings2=use_embeddings2
            )

        datasets[split] = dataset

    return datasets


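# Usage sketch (illustrative):
#
#     >>> datasets = create_datasets(radii=[18], splits=["train", "test"],
#     ...                            zero_ratio=0.1, undersample=0.5, verbose=True)
#     >>> datasets["train"].print_stats()
#     >>> len(datasets["test"])

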
def custom_collate_fn(batch):
    """
    Custom collate function for the ReGEP model.
    Converts a list of PyG Data objects into a single batched graph in the format expected by ReGEP.
    """
    batched_data = Batch.from_data_list(batch)
    return batched_data


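# Sketch of what the collate step produces (field names follow the Data objects
# built above; `ds` stands for any SphereGraphDataset):
#
#     >>> batch = custom_collate_fn([ds[0], ds[1]])
#     >>> batch.batch     # node-to-graph assignment vector added by Batch
#     >>> batch.x.shape   # node features of both graphs stacked along dim 0
#     >>> batch.y         # one recall value per graph

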
class ReGEPDataLoader(DataLoader):
    """
    Custom DataLoader for the ReGEP model that handles its specific input format.
    Undersampling itself is handled by the datasets; this loader only sets the collate function.
    """

    def __init__(self, dataset, batch_size=32, shuffle=True, **kwargs):
        """
        Initialize the ReGEP DataLoader.

        Args:
            dataset: The dataset to load from
            batch_size: Batch size
            shuffle: Whether to shuffle the data
            **kwargs: Additional arguments for DataLoader
        """
        if 'collate_fn' not in kwargs:
            kwargs['collate_fn'] = custom_collate_fn

        super().__init__(
            dataset=dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            **kwargs
        )


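# Usage sketch (illustrative; `num_workers` is only an example of a pass-through
# kwarg forwarded to the underlying DataLoader):
#
#     >>> loader = ReGEPDataLoader(train_ds, batch_size=32, shuffle=True, num_workers=4)
#     >>> batch = next(iter(loader))   # a torch_geometric.data.Batch

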
def create_data_loader(
    radii=[16, 18, 20],
    batch_size=32,
    zero_ratio=0.1,
    undersample=0.5,
    seed=42,
    verbose=False,
    use_embeddings2=False,
    **kwargs
):
    """
    Create train and test data loaders.

    Args:
        radii (list): List of radii for data processing
        batch_size (int): Batch size for training (the test loader uses 4x this value)
        zero_ratio (float): Fraction of zero-recall graphs kept for training
        undersample (float): Undersampling ratio for the training set
        seed (int): Random seed
        verbose (bool): Whether to print verbose information
        use_embeddings2 (bool): Whether to use the secondary (ESM2) embeddings
        **kwargs: Additional arguments passed to the data loaders

    Returns:
        tuple: (train_loader, test_loader)
    """
    train_dataset = create_datasets(
        radii=radii,
        splits=["train"],
        threshold=0.25,
        undersample=undersample,
        zero_ratio=zero_ratio,
        cache_dir=None,
        seed=seed,
        verbose=verbose,
        use_embeddings2=use_embeddings2
    )["train"]

    test_dataset = create_datasets(
        radii=radii,
        splits=["test"],
        threshold=0.25,
        undersample=None,
        zero_ratio=None,
        cache_dir=None,
        verbose=verbose,
        use_embeddings2=use_embeddings2
    )["test"]

    train_loader = ReGEPDataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=custom_collate_fn,
        **kwargs
    )

    test_loader = ReGEPDataLoader(
        test_dataset,
        batch_size=batch_size * 4,
        shuffle=False,
        **kwargs
    )

    return train_loader, test_loader
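

# End-to-end sketch (illustrative; `model` and `criterion` are assumed to be a
# ReGEP model and loss defined elsewhere in the project, not in this module):
#
#     >>> train_loader, test_loader = create_data_loader(
#     ...     radii=[18], batch_size=32, zero_ratio=0.1, undersample=0.5, seed=42)
#     >>> for batch in train_loader:
#     ...     pred = model(batch)
#     ...     loss = criterion(pred, batch.y)   # e.g. regression on region recall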