|
from __future__ import annotations |
|
|
|
import json |
|
import h5py |
|
import traceback |
|
import numpy as np |
|
from tqdm import tqdm |
|
from pathlib import Path |
|
from typing import Optional, Union, BinaryIO, TextIO |
|
from dataclasses import dataclass |
|
from scipy.spatial.distance import cdist |
|
|
|
import torch |
|
|
|
|
|
from esm.utils import residue_constants as RC |
|
from esm.utils.structure.protein_chain import ProteinChain |
|
|
|
|
|
import biotite.structure as bs |
|
from biotite.database import rcsb |
|
from biotite.structure.io.pdb import PDBFile |
|
from biotite.structure import annotate_sse |
|
|
|
from cloudpathlib import CloudPath |
|
from Bio.Data import PDBData |
|
|
|
import py3Dmol |
|
|
|
|
|
from ..utils.constants import BASE_DIR |
|
from ..utils.loading import load_epitopes_csv, load_epitopes_csv_single, load_species |
|
from .pc import AMINO_ACID_1TO3, AMINO_ACID_3TO1, MAX_ASA |
|
from ..model.ReCEP import ReCEP |
|
from ..data.utils import create_graph_data |
|
|
|
|
|
PathOrBuffer = Union[str, Path, BinaryIO, TextIO] |
|
|
|
@dataclass |
|
class AntigenChain(ProteinChain): |
|
""" |
|
Extended ProteinChain class that adds additional functionalities, |
|
such as computing surface residues based on SASA and maxASA constants. |
|
""" |
|
def __post_init__(self, token: Optional[str] = "1mzAo8l1uxaU8UfVcGgV7B"): |
|
super().__post_init__() |
|
|
|
|
|
self.resnum_to_index = {int(rnum): i for i, rnum in enumerate(self.residue_index)} |
|
|
|
|
|
self.epitopes = self.get_epitopes() |
|
|
|
|
|
self.token = token |
|
|
|
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
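
    # Usage sketch: AntigenChain is typically built through the constructors
    # inherited from esm's ProteinChain (e.g. `from_rcsb` / `from_pdb`, assumed
    # available in your esm version); the PDB ID and chain below are placeholders.
    #
    #   chain = AntigenChain.from_rcsb("1ABC", chain_id="A")
    #   print(len(chain.sequence), "residues,", int(chain.epitopes.sum()), "labelled epitopes")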
|
|
|
@staticmethod |
|
def convert_letter_1to3(letter: str) -> str: |
|
""" |
|
Convert a one-letter amino acid code to its corresponding three-letter code. |
|
|
|
Args: |
|
letter (str): A single-character amino acid code (e.g., "A"). |
|
|
|
Returns: |
|
str: The corresponding three-letter code (e.g., "ALA"). |
|
Returns "UNK" if the code is not recognized. |
|
""" |
|
return AMINO_ACID_1TO3.get(letter.upper(), "UNK") |
|
|
|
@staticmethod |
|
def convert_letter_3to1(three_letter: str) -> str: |
|
""" |
|
Convert a three-letter amino acid code to its corresponding one-letter code. |
|
|
|
Args: |
|
three_letter (str): A three-letter amino acid code (e.g., "ALA"). |
|
|
|
Returns: |
|
str: The corresponding one-letter code (e.g., "A"). |
|
Returns "X" if the code is not recognized. |
|
""" |
|
return AMINO_ACID_3TO1.get(three_letter.upper(), "X") |
|
|
|
def get_species(self) -> str: |
|
""" |
|
Get the species of the antigen. |
|
""" |
|
from ..utils.tools import get_chain_organism |
|
|
|
species_dict = load_species() |
|
if self.id in species_dict: |
|
species = species_dict[self.id]['classification'] |
|
else: |
|
try: |
|
species = get_chain_organism(self.id, self.chain_id) |
|
species_dict[self.id] = {'classification': species} |
|
|
|
|
|
species_file_path = Path(f"{BASE_DIR}/data/species.json") |
|
species_file_path.parent.mkdir(parents=True, exist_ok=True) |
|
|
|
with open(species_file_path, "w") as f: |
|
json.dump(species_dict, f, indent=2) |
|
except Exception as e: |
|
print(f"[ERROR] Failed to get species for {self.id}_{self.chain_id}: {str(e)}") |
|
species = "Unknown" |
|
return species |
|
|
|
def get_backbone_atoms(self) -> np.ndarray: |
|
""" |
|
        Get backbone atom coordinates in the order: N, CA, C.
|
|
|
Returns: |
|
            np.ndarray: [L, 3, 3] array where [:, 0] is N, [:, 1] is CA, [:, 2] is C.
|
""" |
|
file = Path(f"{BASE_DIR}/data/coords/{self.id}_{self.chain_id}.npy") |
|
|
|
if file.exists(): |
|
return np.load(file) |
|
else: |
|
idx_CA = RC.atom_order["CA"] |
|
idx_C = RC.atom_order["C"] |
|
idx_N = RC.atom_order["N"] |
|
|
|
backbone_atoms = self.atom37_positions[:, [idx_N, idx_CA, idx_C], :] |
|
|
|
|
|
file.parent.mkdir(parents=True, exist_ok=True) |
|
np.save(file, backbone_atoms) |
|
return backbone_atoms |
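
    # Example (sketch): backbone coordinates are cached under data/coords and
    # returned ordered N, CA, C per residue.
    #
    #   bb = chain.get_backbone_atoms()   # shape (L, 3, 3)
    #   ca_coords = bb[:, 1, :]           # C-alpha coordinates, shape (L, 3)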
|
|
|
    def get_secondary_structure(self) -> Optional[str]:
|
""" |
|
        Predict three-state secondary structure (H/E/C) for this chain using biotite's annotate_sse.
        Returns None if the prediction fails or its length does not match the sequence.
|
""" |
|
try: |
|
ss3_arr = annotate_sse(self.atom_array) |
|
biotite_ss3_str = "".join(ss3_arr) |
|
|
|
if len(biotite_ss3_str) != len(self.sequence): |
|
print(f"[WARNING] Secondary structure prediction length ({len(biotite_ss3_str)}) " |
|
f"doesn't match sequence length ({len(self.sequence)}) " |
|
f"for protein {self.id}_{self.chain_id}") |
|
return None |
|
|
|
translation_table = str.maketrans({ |
|
"a": "H", |
|
"b": "E", |
|
"c": "C", |
|
}) |
|
return biotite_ss3_str.translate(translation_table) |
|
|
|
except Exception as e: |
|
print(f"[ERROR] Failed to predict secondary structure for " |
|
f"{self.id}_{self.chain_id}: {str(e)}") |
|
return None |
|
|
|
def get_ss_onehot(self) -> np.ndarray: |
|
""" |
|
Get one-hot encoded secondary structure information using numpy operations. |
|
Only encode H (helix) and E (sheet), as C (coil) can be inferred. |
|
|
|
Returns: |
|
np.ndarray: One-hot encoded secondary structure array of shape (seq_len, 2) |
|
where 2 represents [H, E] (Helix, Sheet) |
|
""" |
|
        self.secondary_structure = self.get_secondary_structure()
        if self.secondary_structure is None:
            # Fall back to an all-coil (all-zero) encoding when prediction failed.
            return np.zeros((len(self.sequence), 2), dtype=np.float32)
        seq_len = len(self.secondary_structure)
        ss_onehot = np.zeros((seq_len, 2), dtype=np.float32)
|
|
|
|
|
ss_array = np.array(list(self.secondary_structure)) |
|
ss_onehot[:, 0] = (ss_array == 'H') |
|
ss_onehot[:, 1] = (ss_array == 'E') |
|
|
|
return ss_onehot |
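
    # Example (sketch): one-hot secondary structure; column 0 marks helix (H),
    # column 1 marks sheet (E), and coil rows are all zero.
    #
    #   ss = chain.get_ss_onehot()        # shape (L, 2), dtype float32
    #   helix_fraction = ss[:, 0].mean()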
|
|
|
def get_rsa(self) -> np.ndarray: |
|
""" |
|
Calculate relative solvent accessibility (RSA) for all residues. |
|
RSA is the ratio of SASA to maximum ASA for each residue. |
|
|
|
Returns: |
|
np.ndarray: An array of RSA values for each residue in the sequence. |
|
""" |
|
|
|
cache_file = Path(BASE_DIR) / "data" / "rsa" / f"{self.id}_{self.chain_id}.npy" |
|
if cache_file.exists(): |
|
return np.load(cache_file) |
|
|
|
sasa_values = self.sasa() |
|
rsa_values = np.zeros(len(self.sequence), dtype=np.float32) |
|
|
|
|
|
for i, (letter, sasa) in enumerate(zip(self.sequence, sasa_values)): |
|
three_letter = self.convert_letter_1to3(letter) |
|
max_asa = MAX_ASA.get(three_letter) |
|
if max_asa is not None and max_asa != 0: |
|
rsa_values[i] = sasa / max_asa |
|
|
|
|
|
cache_file.parent.mkdir(parents=True, exist_ok=True) |
|
np.save(cache_file, rsa_values) |
|
|
|
return rsa_values |
|
|
|
    def get_surface_residues(self, threshold: float = 0.25) -> tuple:
|
""" |
|
Identify surface-exposed residues using RSA values. |
|
|
|
A residue is considered surface-exposed if its RSA value |
|
is at least `threshold`. |
|
|
|
Args: |
|
threshold (float): The minimum RSA value required to consider |
|
the residue as surface-exposed. |
|
|
|
Returns: |
|
tuple: A tuple of two lists, where the first list contains residue numbers (from the PDB) that are surface-exposed, |
|
and the second list contains the indices of the surface residues in the sequence. |
|
""" |
|
rsa_values = self.get_rsa() |
|
surface_residue_numbers = [] |
|
surface_residue_indices = [] |
|
|
|
|
|
for idx, rsa in enumerate(rsa_values): |
|
if rsa >= threshold: |
|
surface_residue_numbers.append(int(self.residue_index[idx])) |
|
surface_residue_indices.append(idx) |
|
|
|
return surface_residue_numbers, surface_residue_indices |
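
    # Example (sketch): RSA and the surface-residue helper are both derived from
    # self.sasa() and the MAX_ASA table; RSA values are cached under data/rsa.
    #
    #   rsa = chain.get_rsa()                                    # shape (L,)
    #   surf_resnums, surf_indices = chain.get_surface_residues(threshold=0.25)
    #   assert all(rsa[i] >= 0.25 for i in surf_indices)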
|
|
|
def get_epitopes(self, threshold: float = 0.25) -> np.ndarray: |
|
""" |
|
Retrieve epitopes for this chain as a boolean array. |
|
|
|
Args: |
|
            threshold (float): RSA threshold for determining surface residues.
|
|
|
Returns: |
|
np.ndarray: A boolean array of length L (sequence length) where True indicates |
|
epitope positions and False indicates non-epitope positions. |
|
Only surface-exposed residues can be True. |
|
""" |
|
_, _, epitopes = load_epitopes_csv() |
|
|
|
if f'{self.id}_{self.chain_id}' in epitopes: |
|
binary_labels = epitopes.get(f'{self.id}_{self.chain_id}', [0] * len(self.sequence)) |
|
else: |
|
            print(f"[WARNING] Epitopes not found for {self.id}_{self.chain_id}. Falling back to single-chain epitope records.")
|
binary_labels = self.get_epitopes_single() |
|
|
|
|
|
epitope_array = np.zeros(len(self.sequence), dtype=bool) |
|
|
|
|
|
if binary_labels is not None and len(binary_labels) > 0: |
|
|
|
if len(binary_labels) == len(self.sequence): |
|
epitope_array = np.array(binary_labels, dtype=bool) |
|
else: |
|
print(f"[WARNING] Binary labels length ({len(binary_labels)}) doesn't match " |
|
f"sequence length ({len(self.sequence)}) for {self.id}_{self.chain_id}") |
|
return epitope_array |
|
|
|
if threshold == 0.0: |
|
return epitope_array |
|
|
|
|
|
_, surface_indices = self.get_surface_residues(threshold=threshold) |
|
|
|
|
|
surface_mask = np.zeros(len(self.sequence), dtype=bool) |
|
for res_idx in surface_indices: |
|
if 0 <= res_idx < len(self.sequence): |
|
surface_mask[res_idx] = True |
|
|
|
|
|
epitope_array = epitope_array & surface_mask |
|
|
|
return epitope_array |
|
|
|
def get_epitopes_single(self) -> np.ndarray: |
|
""" |
|
Retrieve epitopes for this chain as a boolean array. |
|
""" |
|
_, _, epitopes = load_epitopes_csv_single() |
|
|
|
|
|
possible_keys = [ |
|
f'{self.id.upper()}_{self.chain_id}', |
|
f'{self.id}_{self.chain_id}', |
|
f'{self.id.lower()}_{self.chain_id}' |
|
] |
|
|
|
epitopes_resnums = None |
|
for key in possible_keys: |
|
if key in epitopes: |
|
epitopes_resnums = epitopes.get(key) |
|
break |
|
|
|
if epitopes_resnums is not None: |
|
epitope_array = np.zeros(len(self.sequence), dtype=int) |
|
for resnum in epitopes_resnums: |
|
if resnum in self.resnum_to_index: |
|
epitope_array[self.resnum_to_index[resnum]] = 1 |
|
return epitope_array |
|
else: |
|
            print(f"[WARNING] Single-chain epitopes not found for {self.id}_{self.chain_id}. Using an empty epitope set.")
|
epitope_array = np.zeros(len(self.sequence), dtype=int) |
|
|
|
return epitope_array |
|
|
|
def get_epitope_residue_numbers(self) -> list: |
|
""" |
|
Get epitope residue numbers from the boolean epitope array. |
|
|
|
Returns: |
|
list: List of residue numbers that are epitopes. |
|
""" |
|
epitope_indices = np.where(self.epitopes)[0] |
|
epitope_residue_numbers = [int(self.residue_index[idx]) for idx in epitope_indices] |
|
return epitope_residue_numbers |
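
    # Example (sketch): `self.epitopes` is a boolean array aligned with
    # `self.sequence`; this helper maps the True positions back to PDB residue
    # numbers via `self.residue_index`.
    #
    #   epi_resnums = chain.get_epitope_residue_numbers()
    #   print(len(epi_resnums), "epitope residues:", epi_resnums[:10])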
|
|
|
def get_embeddings(self, override: bool = False, encoder: str = "esmc") -> np.ndarray: |
|
""" |
|
Retrieve or compute per-residue (full) ESM-C embeddings. |
|
|
|
Returns: |
|
np.ndarray: Array of shape (seq_len, embed_dim), dtype float32. |
|
""" |
|
full_file = Path(BASE_DIR) / "data" / "embeddings" / f"{encoder}" / f"{self.id}_{self.chain_id}.h5" |
|
|
|
if full_file.exists() and not override: |
|
with h5py.File(full_file, "r") as h5f: |
|
full_embedding = h5f["embedding"][:] |
|
else: |
|
if encoder == "esmc": |
|
if self.token is None: |
|
raise ValueError("ESM token is not set. Please go to https://forge.evolutionaryscale.ai/ to get a token.") |
|
|
|
else: |
|
print(f"[INFO] Generating with ESM-C...") |
|
|
|
from esm.sdk.api import ESMProtein, LogitsConfig |
|
from esm.sdk.forge import ESM3ForgeInferenceClient |
|
|
|
token = self.token |
|
model = ESM3ForgeInferenceClient( |
|
model="esmc-6b-2024-12", |
|
url="https://forge.evolutionaryscale.ai", |
|
token=token |
|
) |
|
config = LogitsConfig(sequence=True, return_embeddings=True) |
|
|
|
sequence = self.sequence[:2046] |
|
protein = ESMProtein(sequence) |
|
protein_tensor = model.encode(protein) |
|
output = model.logits(protein_tensor, config) |
|
full_embedding = output.embeddings.squeeze(0)[1:-1, :].to(torch.float32).cpu().numpy() |
|
|
|
full_file.parent.mkdir(parents=True, exist_ok=True) |
|
with h5py.File(full_file, "w") as h5f: |
|
h5f.create_dataset("embedding", data=full_embedding, compression="gzip") |
|
|
|
elif encoder == "esm2": |
|
model, alphabet = torch.hub.load("facebookresearch/esm:main", "esm2_t33_650M_UR50D") |
|
batch_converter = alphabet.get_batch_converter() |
|
model.eval() |
|
data = [ |
|
("antigen", self.sequence[:2046]) |
|
] |
|
batch_labels, batch_strs, batch_tokens = batch_converter(data) |
|
batch_lens = (batch_tokens != alphabet.padding_idx).sum(1) |
|
model.to(self.device) |
|
batch_tokens = batch_tokens.to(self.device) |
|
with torch.no_grad(): |
|
results = model(batch_tokens, repr_layers=[33], return_contacts=True) |
|
token_representations = results["representations"][33] |
|
full_embedding = token_representations.squeeze(0)[1:-1, :].to(torch.float32).cpu().numpy() |
|
|
|
full_file.parent.mkdir(parents=True, exist_ok=True) |
|
with h5py.File(full_file, "w") as h5f: |
|
h5f.create_dataset("embedding", data=full_embedding, compression="gzip") |
|
|
|
return full_embedding |
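
    # Example (sketch): embeddings are computed once and cached as HDF5 under
    # data/embeddings/<encoder>/. The "esmc" path calls the EvolutionaryScale
    # Forge API and requires a valid token; the "esm2" path downloads
    # esm2_t33_650M_UR50D via torch.hub and runs locally.
    #
    #   emb = chain.get_embeddings(encoder="esm2")    # shape (L, embed_dim)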
|
|
|
def _scan_surface_residues(self, radius: float, threshold: float = 0.25) -> tuple: |
|
""" |
|
Helper function to compute the surface coverage for each surface residue. |
|
For each surface residue, using its C_alpha coordinate as the center of a sphere with |
|
radius `radius`, determine which surface residues are covered. |
|
|
|
Args: |
|
radius (float): The radius of the sphere (in Ångstroms) |
|
threshold (float): Fraction of maximum ASA to define a residue as surface-exposed |
|
|
|
Returns: |
|
tuple: |
|
- coverage (dict): Mapping from center residue index to: |
|
(list[int]): List of covered residue indices |
|
(list[int]): List of covered epitope residue indices |
|
(float): Precision |
|
(float): Recall |
|
- max_recall_res (int): Center residue index with highest recall |
|
- max_precision_res (int): Center residue index with highest precision |
|
""" |
|
|
|
if radius <= 0: |
|
raise ValueError("Radius must be positive") |
|
if threshold < 0 or threshold > 1: |
|
raise ValueError("Threshold must be between 0 and 1") |
|
|
|
|
|
surface_res_nums, surface_indices = self.get_surface_residues(threshold=threshold) |
|
|
|
|
|
valid_surface_indices = [ |
|
idx for idx in surface_indices |
|
if 0 <= idx < len(self.sequence) |
|
] |
|
valid_surface_res_nums = [ |
|
surface_res_nums[surface_indices.index(idx)] |
|
for idx in valid_surface_indices |
|
] |
|
|
|
if not valid_surface_indices: |
|
return {}, None, None |
|
|
|
|
|
all_atoms = [] |
|
all_res_indices = [] |
|
for idx in valid_surface_indices: |
|
mask = self.atom37_mask[idx] |
|
coords = self.atom37_positions[idx][mask] |
|
if len(coords) > 0: |
|
all_atoms.append(coords) |
|
all_res_indices.extend([idx] * len(coords)) |
|
|
|
if not all_atoms: |
|
return {idx: ([], [], 0.0, 0.0) for idx in valid_surface_indices}, None, None |
|
|
|
all_atoms = np.vstack(all_atoms).astype(np.float32) |
|
all_res_indices = np.array(all_res_indices) |
|
|
|
|
|
surface_ca = [] |
|
valid_center_indices = [] |
|
ca_idx = RC.atom_order["CA"] |
|
|
|
for idx in valid_surface_indices: |
|
|
|
ca_coord = self.atom37_positions[idx, ca_idx, :] |
|
if not np.any(np.isnan(ca_coord)) and self.atom37_mask[idx, ca_idx]: |
|
surface_ca.append(ca_coord) |
|
valid_center_indices.append(idx) |
|
|
|
if not surface_ca: |
|
return {}, None, None |
|
|
|
surface_ca = np.array(surface_ca, dtype=np.float32) |
|
surface_ca = surface_ca.reshape(-1, 3) |
|
|
|
|
|
try: |
|
dist_matrix = cdist(surface_ca, all_atoms) |
|
except ValueError as e: |
|
print(f"Error in distance calculation: {e}") |
|
print(f"surface_ca shape: {surface_ca.shape}") |
|
print(f"all_atoms shape: {all_atoms.shape}") |
|
return {}, None, None |
|
|
|
max_recall = -1 |
|
max_recall_res = None |
|
max_precision = -1 |
|
max_precision_res = None |
|
|
|
coverage = {} |
|
epitope_indices = np.where(self.epitopes)[0] |
|
if len(epitope_indices) == 0: |
|
print(f"No epitopes records for protein {self.id}_{self.chain_id}") |
|
|
|
for i, center_idx in enumerate(valid_center_indices): |
|
within_radius = dist_matrix[i] < radius |
|
covered_indices = np.unique(all_res_indices[within_radius]) |
|
covered_indices_list = covered_indices.tolist() |
|
|
|
|
|
covered_epitope_indices = list(set(covered_indices_list).intersection(set(epitope_indices))) |
|
|
|
|
|
precision = len(covered_epitope_indices) / len(covered_indices_list) if covered_indices_list else 0.0 |
|
recall = len(covered_epitope_indices) / len(epitope_indices) if len(epitope_indices) > 0 else 0.0 |
|
|
|
if recall > max_recall: |
|
max_recall = recall |
|
max_recall_res = center_idx |
|
if precision > max_precision: |
|
max_precision = precision |
|
max_precision_res = center_idx |
|
|
|
|
|
coverage[int(center_idx)] = ( |
|
[int(idx) for idx in covered_indices_list], |
|
[int(idx) for idx in covered_epitope_indices], |
|
float(precision), |
|
float(recall) |
|
) |
|
|
|
return coverage, max_recall_res, max_precision_res |
|
|
|
def get_surface_coverage(self, radius: float = 18, |
|
threshold: float = 0.25, |
|
index: bool = True, |
|
override: bool = False) -> tuple: |
|
""" |
|
Retrieve (or compute and cache) the coverage mapping for surface residues. |
|
For each surface residue, using its C_alpha as the sphere center (with radius `radius`), |
|
determine which surface residues are covered (i.e. if any atom falls within that sphere). |
|
The result is cached to an HDF5 file for faster subsequent retrieval. |
|
|
|
The cache file is saved in BASE_DIR / "data/antigen_sphere", with the file name |
|
"{self.id}_{self.chain_id}.h5", and radius as the first-level key. |
|
|
|
Args: |
|
radius (float): The radius of the sphere (in Ångstroms). |
|
threshold (float): Fraction of maximum ASA to define a residue as surface-exposed. |
|
index (bool): If True, return indices instead of residue numbers for easier embeddings/coords access. |
|
override (bool): If True, recompute even if cache exists. |
|
|
|
Returns: |
|
tuple: |
|
                - coverage (dict): Mapping from each surface residue to a tuple of
                  (covered residues, covered epitope residues, precision, recall).
                  If index=True, keys and covered residues are sequence indices;
                  if index=False, they are PDB residue numbers.
                - max_recall_res (int): The center with the highest recall
                  (index if index=True, residue number otherwise).
                - max_precision_res (int): The center with the highest precision.
                  Both may be None when results are loaded from the cache with index=True.
|
""" |
|
|
|
cache_dir = BASE_DIR / "data" / "antigen_sphere" |
|
cache_dir.mkdir(parents=True, exist_ok=True) |
|
cache_filename = f"{self.id}_{self.chain_id}.h5" |
|
cache_path = cache_dir / cache_filename |
|
radius_key = f"r{radius}" |
|
|
|
|
|
if cache_path.exists() and not override: |
|
try: |
|
with h5py.File(cache_path, "r") as h5f: |
|
if radius_key in h5f: |
|
|
|
radius_group = h5f[radius_key] |
|
|
|
if index: |
|
|
|
coverage = {} |
|
for center_idx_str in radius_group.keys(): |
|
center_idx = int(center_idx_str) |
|
center_group = radius_group[center_idx_str] |
|
covered_indices = center_group['covered_indices'][:].tolist() |
|
covered_epitope_indices = center_group['covered_epitope_indices'][:].tolist() |
|
precision = float(center_group.attrs['precision']) |
|
recall = float(center_group.attrs['recall']) |
|
coverage[center_idx] = (covered_indices, covered_epitope_indices, precision, recall) |
|
return coverage, None, None |
|
else: |
|
|
|
coverage = {} |
|
max_recall = -1 |
|
max_recall_res = None |
|
max_precision = -1 |
|
max_precision_res = None |
|
|
|
for center_idx_str in radius_group.keys(): |
|
center_idx = int(center_idx_str) |
|
center_res_num = int(self.residue_index[center_idx]) |
|
center_group = radius_group[center_idx_str] |
|
|
|
covered_indices = center_group['covered_indices'][:].tolist() |
|
covered_epitope_indices = center_group['covered_epitope_indices'][:].tolist() |
|
precision = float(center_group.attrs['precision']) |
|
recall = float(center_group.attrs['recall']) |
|
|
|
|
|
covered_res_nums = [int(self.residue_index[idx]) for idx in covered_indices if 0 <= idx < len(self.residue_index)] |
|
covered_epitope_res_nums = [int(self.residue_index[idx]) for idx in covered_epitope_indices if 0 <= idx < len(self.residue_index)] |
|
|
|
coverage[center_res_num] = (covered_res_nums, covered_epitope_res_nums, precision, recall) |
|
|
|
if recall > max_recall: |
|
max_recall = recall |
|
max_recall_res = center_res_num |
|
if precision > max_precision: |
|
max_precision = precision |
|
max_precision_res = center_res_num |
|
|
|
return coverage, max_recall_res, max_precision_res |
|
except (OSError, KeyError, ValueError) as e: |
|
print(f"[WARNING] Error reading cache file {cache_path}: {e}") |
|
print(f"[INFO] Recomputing surface coverage...") |
|
|
|
|
|
coverage, max_recall_res, max_precision_res = self._scan_surface_residues(radius, threshold) |
|
|
|
|
|
|
|
with h5py.File(cache_path, "a") as h5f: |
|
|
|
if radius_key in h5f: |
|
del h5f[radius_key] |
|
|
|
radius_group = h5f.create_group(radius_key) |
|
|
|
|
|
for center_idx, (covered_indices, covered_epitope_indices, precision, recall) in coverage.items(): |
|
center_group = radius_group.create_group(str(center_idx)) |
|
center_group.create_dataset('covered_indices', data=np.array(covered_indices, dtype=np.int32), compression='gzip') |
|
center_group.create_dataset('covered_epitope_indices', data=np.array(covered_epitope_indices, dtype=np.int32), compression='gzip') |
|
center_group.attrs['precision'] = precision |
|
center_group.attrs['recall'] = recall |
|
|
|
|
|
if not index: |
|
coverage_resnums = {} |
|
max_recall_res_num = None |
|
max_precision_res_num = None |
|
|
|
if max_recall_res is not None: |
|
max_recall_res_num = int(self.residue_index[max_recall_res]) |
|
if max_precision_res is not None: |
|
max_precision_res_num = int(self.residue_index[max_precision_res]) |
|
|
|
for center_idx, (covered_indices, covered_epitope_indices, precision, recall) in coverage.items(): |
|
center_res_num = int(self.residue_index[center_idx]) |
|
|
|
covered_res_nums = [int(self.residue_index[idx]) for idx in covered_indices if 0 <= idx < len(self.residue_index)] |
|
covered_epitope_res_nums = [int(self.residue_index[idx]) for idx in covered_epitope_indices if 0 <= idx < len(self.residue_index)] |
|
coverage_resnums[center_res_num] = (covered_res_nums, covered_epitope_res_nums, precision, recall) |
|
|
|
return coverage_resnums, max_recall_res_num, max_precision_res_num |
|
|
|
return coverage, max_recall_res, max_precision_res |
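
    # Example (sketch): coverage maps each surface-residue center to its covered
    # residues, covered epitope residues, precision and recall for the given
    # radius; results are cached per radius in data/antigen_sphere/{id}_{chain}.h5.
    # The two "best center" values are None when read back from the cache with
    # index=True.
    #
    #   coverage, best_recall_center, best_precision_center = \
    #       chain.get_surface_coverage(radius=18.0, threshold=0.25)
    #   covered, covered_epi, prec, rec = next(iter(coverage.values()))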
|
|
|
def data_preparation(self, radius: float = None, encoder: str = "esmc", override: bool = False): |
|
""" |
|
        Prepare all per-chain inputs required by the ReCEP pipeline: residue
        embeddings, backbone coordinates, RSA values and (optionally) the
        spherical surface-coverage mapping.

        Args:
            radius (float): Radius (in Ångstroms) used to build spherical surface regions.
                If None, coverage caches are precomputed for radii 16, 18 and 20
                and no coverage dict is returned.
            encoder (str): Embedding encoder to use ("esmc" or "esm2").
            override (bool): Whether to recompute and overwrite cached coverage data.

        Returns:
            tuple:
                - embeddings (np.ndarray): Per-residue embeddings. (seq_len, embed_dim)
                - backbone_atoms (np.ndarray): Backbone N/CA/C coordinates. (seq_len, 3, 3)
                - rsa (np.ndarray): Relative solvent accessibility values. (seq_len,)
                - coverage_dict (dict | None): Surface coverage mapping for `radius`,
                  or None when radius is None.
|
""" |
|
embeddings = self.get_embeddings(encoder=encoder) |
|
backbone_atoms = self.get_backbone_atoms() |
|
rsa = self.get_rsa() |
|
if radius is None: |
|
|
|
            # Precompute coverage caches for radii 16, 18 and 20 Å.
            for i in range(16, 21, 2):
|
_, _, _ = self.get_surface_coverage(radius=i, override=override) |
|
return embeddings, backbone_atoms, rsa, None |
|
else: |
|
coverage_dict, _, _ = self.get_surface_coverage(radius=radius, override=override) |
|
return embeddings, backbone_atoms, rsa, coverage_dict |
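
    # Example (sketch): data_preparation bundles everything the ReCEP pipeline
    # needs for a single chain.
    #
    #   emb, backbone, rsa, coverage = chain.data_preparation(radius=19.0)
    #   # emb: (L, D), backbone: (L, 3, 3), rsa: (L,), coverage: dict or None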
|
|
|
def evaluate(self, model_path: str = None, device_id: int = 1, radius: float = 19.0, k: int = 7, |
|
threshold: float = None, verbose: bool = True, encoder: str = "esmc", use_gpu: bool = True): |
|
""" |
|
Evaluate epitopes using ReCEP model with spherical regions. |
|
|
|
Args: |
|
model_path (str): Path to the trained ReCEP model |
|
device_id (int): GPU device ID to use |
|
radius (float): Radius for spherical regions |
|
k (int): Number of top regions to select |
|
threshold (float): Threshold for node-level epitope prediction |
|
            verbose (bool): Whether to print progress information
            encoder (str): Embedding encoder to use ("esmc" or "esm2")
            use_gpu (bool): Whether to run on GPU when available

        Returns:
            dict: Dictionary containing:
                - 'predicted_epitopes': List of predicted epitope residue numbers (probability-based)
                - 'voted_epitopes': List of epitope residue numbers selected by majority vote
                - 'true_epitopes': Set of true epitope residue numbers
                - 'predicted_precision' / 'predicted_recall': Metrics of the probability-based prediction
                - 'voted_precision' / 'voted_recall': Metrics of the voting-based prediction
                - 'predictions': Dict mapping residue number to mean predicted probability
                - 'top_k_regions': Information about the selected regions
                - 'residue_votes': Per-residue vote lists from the top-k regions
|
""" |
|
|
|
if use_gpu and torch.cuda.is_available() and device_id >= 0: |
|
device = torch.device(f"cuda:{device_id}") |
|
else: |
|
device = torch.device("cpu") |
|
if verbose: |
|
print(f"[INFO] Using device: {device}") |
|
|
|
|
|
try: |
|
if model_path is None: |
|
model_path = f"{BASE_DIR}/models/ReCEP/20250626_110438/best_mcc_model.bin" |
|
|
|
if threshold is None: |
|
model, threshold = ReCEP.load(model_path, device=device, strict=False, verbose=False) |
|
else: |
|
model, _ = ReCEP.load(model_path, device=device, strict=False, verbose=False) |
|
|
|
model.eval() |
|
if verbose: |
|
print(f"[INFO] Loaded ReCEP model from {model_path}") |
|
except Exception as e: |
|
if verbose: |
|
print(f"[ERROR] Failed to load model: {str(e)}") |
|
return {} |
|
|
|
|
|
try: |
|
embeddings, backbone_atoms, rsa, coverage_dict = self.data_preparation(radius=radius, encoder=encoder) |
|
if verbose: |
|
print(f"[INFO] Retrieved protein data for {len(coverage_dict)} surface regions") |
|
except Exception as e: |
|
if verbose: |
|
print(f"[ERROR] Failed to prepare data: {str(e)}") |
|
traceback.print_exc() |
|
return {} |
|
|
|
if not coverage_dict: |
|
if verbose: |
|
print("[WARNING] No surface regions found") |
|
return {} |
|
|
|
|
|
epitope_indices = np.where(self.epitopes)[0].tolist() |
|
|
|
|
|
region_predictions = [] |
|
|
|
with torch.no_grad(): |
|
for center_idx, (covered_indices, covered_epitope_indices, precision, recall) in tqdm( |
|
coverage_dict.items(), desc="Predicting region values", disable=not verbose): |
|
|
|
if len(covered_indices) < 2: |
|
continue |
|
|
|
try: |
|
|
|
graph_data = create_graph_data( |
|
center_idx=center_idx, |
|
covered_indices=covered_indices, |
|
covered_epitope_indices=covered_epitope_indices, |
|
embeddings=embeddings, |
|
backbone_atoms=backbone_atoms, |
|
rsa_values=rsa, |
|
epitope_indices=epitope_indices, |
|
recall=recall, |
|
precision=precision, |
|
pdb_id=self.id, |
|
chain_id=self.chain_id, |
|
verbose=True |
|
) |
|
|
|
if graph_data is None: |
|
if verbose: |
|
print(f"[WARNING] Failed to create graph data for region {center_idx}") |
|
continue |
|
|
|
|
|
graph_data = graph_data.to(device) |
|
|
|
|
|
graph_data.batch = torch.zeros(graph_data.num_nodes, dtype=torch.long, device=device) |
|
|
|
|
|
outputs = model(graph_data) |
|
|
|
|
|
if 'global_pred' in outputs: |
|
graph_pred = torch.sigmoid(outputs['global_pred']).cpu().item() |
|
else: |
|
|
|
node_preds = torch.sigmoid(outputs['node_preds']).cpu().numpy() |
|
graph_pred = float(np.mean(node_preds)) |
|
|
|
region_predictions.append({ |
|
'center_idx': center_idx, |
|
'covered_indices': covered_indices, |
|
'covered_epitope_indices': covered_epitope_indices, |
|
'graph_pred': graph_pred, |
|
'true_recall': recall, |
|
'graph_data': graph_data |
|
}) |
|
|
|
except Exception as e: |
|
if verbose: |
|
print(f"[WARNING] Error processing region {center_idx}: {str(e)}") |
|
traceback.print_exc() |
|
continue |
|
|
|
if not region_predictions: |
|
if verbose: |
|
print("[WARNING] No valid region predictions") |
|
return {} |
|
|
|
|
|
region_predictions.sort(key=lambda x: x['graph_pred'], reverse=True) |
|
top_k_regions = region_predictions[:k] |
|
|
|
if verbose: |
|
print(f"[INFO] Selected top {len(top_k_regions)} regions:") |
|
for i, region in enumerate(top_k_regions): |
|
print(f" Region {i+1}: center={region['center_idx']}, " |
|
f"predicted_value={region['graph_pred']:.3f}, " |
|
f"true_recall={region['true_recall']:.3f}") |
|
|
|
|
|
residue_votes = {} |
|
residue_probs = {} |
|
|
|
with torch.no_grad(): |
|
for region in tqdm(top_k_regions, desc="Predicting node values", disable=not verbose): |
|
try: |
|
graph_data = region['graph_data'] |
|
|
|
|
|
if not hasattr(graph_data, 'batch') or graph_data.batch is None: |
|
graph_data.batch = torch.zeros(graph_data.num_nodes, dtype=torch.long, device=device) |
|
|
|
|
|
outputs = model(graph_data) |
|
|
|
|
|
node_preds = torch.sigmoid(outputs['node_preds']).cpu().numpy() |
|
|
|
|
|
for local_idx, residue_idx in enumerate(region['covered_indices']): |
|
if residue_idx not in residue_votes: |
|
residue_votes[residue_idx] = [] |
|
residue_probs[residue_idx] = [] |
|
|
|
|
|
prob = float(node_preds[local_idx]) |
|
residue_probs[residue_idx].append(prob) |
|
|
|
|
|
vote = 1 if prob >= threshold else 0 |
|
residue_votes[residue_idx].append(vote) |
|
|
|
except Exception as e: |
|
if verbose: |
|
print(f"[WARNING] Error in node prediction for region {region['center_idx']}: {str(e)}") |
|
traceback.print_exc() |
|
continue |
|
|
|
|
|
all_residue_predictions = {} |
|
for idx in range(len(self.residue_index)): |
|
residue_num = int(self.residue_index[idx]) |
|
if idx in residue_probs: |
|
|
|
all_residue_predictions[residue_num] = float(np.mean(residue_probs[idx])) |
|
else: |
|
|
|
all_residue_predictions[residue_num] = 1e-2 |
|
|
|
|
|
voted_epitope_indices = [] |
|
for residue_idx, votes in residue_votes.items(): |
|
|
|
if sum(votes) >= len(votes) / 2: |
|
voted_epitope_indices.append(residue_idx) |
|
|
|
|
|
voted_epitope_resnums = [int(self.residue_index[idx]) for idx in voted_epitope_indices |
|
if 0 <= idx < len(self.residue_index)] |
|
|
|
|
|
predicted_epitope_resnums = [] |
|
for residue_num, prob in all_residue_predictions.items(): |
|
if prob >= threshold: |
|
predicted_epitope_resnums.append(residue_num) |
|
|
|
|
|
true_epitope_resnums = set(self.get_epitope_residue_numbers()) |
|
|
|
|
|
|
|
voted_tp = len(set(voted_epitope_resnums) & true_epitope_resnums) |
|
voted_precision = voted_tp / len(voted_epitope_resnums) if voted_epitope_resnums else 0 |
|
voted_recall = voted_tp / len(true_epitope_resnums) if true_epitope_resnums else 0 |
|
|
|
|
|
predicted_tp = len(set(predicted_epitope_resnums) & true_epitope_resnums) |
|
predicted_precision = predicted_tp / len(predicted_epitope_resnums) if predicted_epitope_resnums else 0 |
|
predicted_recall = predicted_tp / len(true_epitope_resnums) if true_epitope_resnums else 0 |
|
|
|
if verbose: |
|
print(f"\n[INFO] Final Results:") |
|
print(f" True epitopes: {len(true_epitope_resnums)}") |
|
print(f" Residues in top-k regions: {len(residue_probs)}/{len(self.residue_index)}") |
|
print(f"\n Voting-based prediction:") |
|
print(f" Voted epitopes: {len(voted_epitope_resnums)}") |
|
print(f" Voted precision: {voted_precision:.3f}") |
|
print(f" Voted recall: {voted_recall:.3f}") |
|
print(f"\n Probability-based prediction (threshold={threshold}):") |
|
print(f" Predicted epitopes: {len(predicted_epitope_resnums)}") |
|
print(f" Predicted precision: {predicted_precision:.3f}") |
|
print(f" Predicted recall: {predicted_recall:.3f}") |
|
|
|
return { |
|
'predicted_epitopes': predicted_epitope_resnums, |
|
'voted_epitopes': voted_epitope_resnums, |
|
'true_epitopes': true_epitope_resnums, |
|
'predicted_precision': predicted_precision, |
|
'predicted_recall': predicted_recall, |
|
'voted_precision': voted_precision, |
|
'voted_recall': voted_recall, |
|
'predictions': all_residue_predictions, |
|
'top_k_regions': [ |
|
{ |
|
'center_residue': int(self.residue_index[region['center_idx']]), |
|
'center_idx': region['center_idx'], |
|
'predicted_value': region['graph_pred'], |
|
'true_recall': region['true_recall'], |
|
'covered_residues': [int(self.residue_index[idx]) for idx in region['covered_indices']] |
|
} |
|
for region in top_k_regions |
|
], |
|
'residue_votes': { |
|
int(self.residue_index[idx]): votes |
|
for idx, votes in residue_votes.items() |
|
if 0 <= idx < len(self.residue_index) |
|
} |
|
} |
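
    # Example (sketch): evaluation against known epitope labels. The default
    # checkpoint path under models/ReCEP/ is used when model_path is None; pass
    # model_path explicitly if your checkpoint lives elsewhere.
    #
    #   results = chain.evaluate(radius=19.0, k=7, verbose=True)
    #   print(results["predicted_precision"], results["predicted_recall"])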
|
|
|
def predict(self, model_path: str = None, device_id: int = 1, radius: float = 19.0, k: int = 7, |
|
threshold: float = None, verbose: bool = True, encoder: str = "esmc", use_gpu: bool = True, |
|
auto_cleanup: bool = False): |
|
""" |
|
Predict epitopes using ReCEP model with spherical regions (for unknown true epitopes). |
|
|
|
Args: |
|
model_path (str): Path to the trained ReCEP model |
|
device_id (int): GPU device ID to use |
|
radius (float): Radius for spherical regions |
|
k (int): Number of top regions to select |
|
threshold (float): Threshold for node-level epitope prediction |
|
verbose (bool): Whether to print progress information |
|
encoder (str): Encoder type for embeddings |
|
use_gpu (bool): Whether to use GPU for computation |
|
auto_cleanup (bool): Whether to automatically delete generated data files after prediction |
|
|
|
Returns: |
|
dict: Dictionary containing: |
|
- 'predicted_epitopes': List of predicted epitope residue numbers |
|
- 'predictions': Dictionary of all residue probabilities {resnum: probability} |
|
- 'top_k_centers': List of top-k center residue numbers |
|
- 'top_k_region_residues': List of all residues covered by top-k regions (union) |
|
                - 'top_k_regions': Detailed information about selected regions
                - 'antigen_rate': Mean predicted value of the top-k regions
                - 'epitope_rate': Mean predicted probability over all residues
|
""" |
|
|
|
if use_gpu and torch.cuda.is_available() and device_id >= 0: |
|
device = torch.device(f"cuda:{device_id}") |
|
else: |
|
device = torch.device("cpu") |
|
if verbose: |
|
print(f"[INFO] Using device: {device}") |
|
|
|
|
|
try: |
|
if model_path is None: |
|
model_path = f"{BASE_DIR}/models/ReCEP/20250626_110438/best_mcc_model.bin" |
|
|
|
if threshold is None: |
|
model, threshold = ReCEP.load(model_path, device=device, strict=False, verbose=False) |
|
else: |
|
model, _ = ReCEP.load(model_path, device=device, strict=False, verbose=False) |
|
|
|
model.eval() |
|
if verbose: |
|
print(f"[INFO] Loaded ReCEP model from {model_path}") |
|
except Exception as e: |
|
if verbose: |
|
print(f"[ERROR] Failed to load model: {str(e)}") |
|
return {} |
|
|
|
|
|
try: |
|
embeddings, backbone_atoms, rsa, coverage_dict = self.data_preparation(radius=radius, encoder=encoder) |
|
if verbose: |
|
print(f"[INFO] Retrieved protein data for {len(coverage_dict)} surface regions") |
|
except Exception as e: |
|
if verbose: |
|
print(f"[ERROR] Failed to prepare data: {str(e)}") |
|
traceback.print_exc() |
|
return {} |
|
|
|
if not coverage_dict: |
|
if verbose: |
|
print("[WARNING] No surface regions found") |
|
return {} |
|
|
|
|
|
region_predictions = [] |
|
|
|
with torch.no_grad(): |
|
for center_idx, (covered_indices, covered_epitope_indices, precision, recall) in tqdm( |
|
coverage_dict.items(), desc="Predicting region values", disable=not verbose): |
|
|
|
if len(covered_indices) < 2: |
|
continue |
|
|
|
try: |
|
|
|
graph_data = create_graph_data( |
|
center_idx=center_idx, |
|
covered_indices=covered_indices, |
|
covered_epitope_indices=[], |
|
embeddings=embeddings, |
|
backbone_atoms=backbone_atoms, |
|
rsa_values=rsa, |
|
epitope_indices=[], |
|
recall=0.0, |
|
precision=0.0, |
|
pdb_id=self.id, |
|
chain_id=self.chain_id, |
|
verbose=False |
|
) |
|
|
|
if graph_data is None: |
|
if verbose: |
|
print(f"[WARNING] Failed to create graph data for region {center_idx}") |
|
continue |
|
|
|
|
|
graph_data = graph_data.to(device) |
|
|
|
|
|
graph_data.batch = torch.zeros(graph_data.num_nodes, dtype=torch.long, device=device) |
|
|
|
|
|
outputs = model(graph_data) |
|
|
|
|
|
if 'global_pred' in outputs: |
|
graph_pred = torch.sigmoid(outputs['global_pred']).cpu().item() |
|
else: |
|
|
|
node_preds = torch.sigmoid(outputs['node_preds']).cpu().numpy() |
|
graph_pred = float(np.mean(node_preds)) |
|
|
|
region_predictions.append({ |
|
'center_idx': center_idx, |
|
'covered_indices': covered_indices, |
|
'graph_pred': graph_pred, |
|
'graph_data': graph_data |
|
}) |
|
|
|
except Exception as e: |
|
if verbose: |
|
print(f"[WARNING] Error processing region {center_idx}: {str(e)}") |
|
traceback.print_exc() |
|
continue |
|
|
|
if not region_predictions: |
|
if verbose: |
|
print("[WARNING] No valid region predictions") |
|
return {} |
|
|
|
|
|
region_predictions.sort(key=lambda x: x['graph_pred'], reverse=True) |
|
top_k_regions = region_predictions[:k] |
|
|
|
if verbose: |
|
print(f"[INFO] Selected top {len(top_k_regions)} regions:") |
|
for i, region in enumerate(top_k_regions): |
|
print(f" Region {i+1}: center={region['center_idx']}, " |
|
f"predicted_value={region['graph_pred']:.3f}") |
|
|
|
|
|
residue_probs = {} |
|
|
|
with torch.no_grad(): |
|
for region in tqdm(top_k_regions, desc="Predicting node values", disable=not verbose): |
|
try: |
|
graph_data = region['graph_data'] |
|
|
|
|
|
if not hasattr(graph_data, 'batch') or graph_data.batch is None: |
|
graph_data.batch = torch.zeros(graph_data.num_nodes, dtype=torch.long, device=device) |
|
|
|
|
|
outputs = model(graph_data) |
|
|
|
|
|
node_preds = torch.sigmoid(outputs['node_preds']).cpu().numpy() |
|
|
|
|
|
for local_idx, residue_idx in enumerate(region['covered_indices']): |
|
if residue_idx not in residue_probs: |
|
residue_probs[residue_idx] = [] |
|
|
|
|
|
prob = float(node_preds[local_idx]) |
|
residue_probs[residue_idx].append(prob) |
|
|
|
except Exception as e: |
|
if verbose: |
|
print(f"[WARNING] Error in node prediction for region {region['center_idx']}: {str(e)}") |
|
traceback.print_exc() |
|
continue |
|
|
|
|
|
all_residue_predictions = {} |
|
for idx in range(len(self.residue_index)): |
|
residue_num = int(self.residue_index[idx]) |
|
if idx in residue_probs: |
|
|
|
all_residue_predictions[residue_num] = float(np.mean(residue_probs[idx])) |
|
else: |
|
|
|
all_residue_predictions[residue_num] = 0.0 |
|
|
|
|
|
predicted_epitope_resnums = [] |
|
node_mean = 0.0 |
|
for residue_num, prob in all_residue_predictions.items(): |
|
node_mean += prob |
|
if prob >= threshold: |
|
predicted_epitope_resnums.append(residue_num) |
|
node_mean /= len(all_residue_predictions) if all_residue_predictions else 1 |
|
|
|
|
|
top_k_centers = [int(self.residue_index[region['center_idx']]) for region in top_k_regions] |
|
|
|
|
|
graph_mean = 0.0 |
|
all_covered_indices = set() |
|
for region in top_k_regions: |
|
all_covered_indices.update(region['covered_indices']) |
|
graph_mean += region['graph_pred'] |
|
graph_mean /= len(top_k_regions) |
|
|
|
top_k_region_residues = [int(self.residue_index[idx]) for idx in all_covered_indices |
|
if 0 <= idx < len(self.residue_index)] |
|
|
|
if verbose: |
|
print(f"\n[INFO] Prediction Results:") |
|
print(f" Predicted epitopes: {len(predicted_epitope_resnums)}") |
|
print(f" Top-k centers: {top_k_centers}") |
|
print(f" Total residues in top-k regions: {len(top_k_region_residues)}") |
|
|
|
|
|
results = { |
|
'predicted_epitopes': predicted_epitope_resnums, |
|
'predictions': all_residue_predictions, |
|
'top_k_centers': top_k_centers, |
|
'top_k_region_residues': top_k_region_residues, |
|
'top_k_regions': [ |
|
{ |
|
'center_residue': int(self.residue_index[region['center_idx']]), |
|
'center_idx': region['center_idx'], |
|
'predicted_value': region['graph_pred'], |
|
'covered_residues': [int(self.residue_index[idx]) for idx in region['covered_indices']] |
|
} |
|
for region in top_k_regions |
|
], |
|
'antigen_rate': graph_mean, |
|
'epitope_rate': node_mean |
|
} |
|
|
|
|
|
if auto_cleanup: |
|
self._cleanup_generated_data(encoder=encoder, verbose=verbose) |
|
|
|
return results |
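
    # Example (sketch): prediction for a chain without known epitope labels;
    # auto_cleanup=True deletes the cached embeddings/coords/RSA/sphere files
    # afterwards.
    #
    #   pred = chain.predict(radius=19.0, k=7, auto_cleanup=True)
    #   print(pred["predicted_epitopes"][:10], pred["top_k_centers"])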
|
|
|
def _cleanup_generated_data(self, encoder: str = "esmc", verbose: bool = True): |
|
""" |
|
Clean up generated data files for this antigen chain. |
|
|
|
Args: |
|
encoder (str): Encoder type used for embeddings |
|
verbose (bool): Whether to print cleanup information |
|
""" |
|
import os |
|
|
|
|
|
files_to_delete = [ |
|
|
|
Path(BASE_DIR) / "data" / "embeddings" / encoder / f"{self.id}_{self.chain_id}.h5", |
|
|
|
Path(BASE_DIR) / "data" / "coords" / f"{self.id}_{self.chain_id}.npy", |
|
|
|
Path(BASE_DIR) / "data" / "rsa" / f"{self.id}_{self.chain_id}.npy", |
|
|
|
Path(BASE_DIR) / "data" / "antigen_sphere" / f"{self.id}_{self.chain_id}.h5" |
|
] |
|
|
|
deleted_files = [] |
|
failed_deletions = [] |
|
total_size = 0 |
|
|
|
for file_path in files_to_delete: |
|
if file_path.exists(): |
|
try: |
|
|
|
file_size = file_path.stat().st_size |
|
os.remove(file_path) |
|
deleted_files.append(file_path) |
|
total_size += file_size |
|
if verbose: |
|
print(f"[INFO] Deleted: {file_path}") |
|
except Exception as e: |
|
failed_deletions.append((file_path, str(e))) |
|
if verbose: |
|
print(f"[WARNING] Failed to delete {file_path}: {str(e)}") |
|
else: |
|
if verbose: |
|
print(f"[INFO] File not found (already deleted or not generated): {file_path}") |
|
|
|
if verbose: |
|
print(f"[INFO] Cleanup completed for {self.id}_{self.chain_id}") |
|
print(f" - Files deleted: {len(deleted_files)}") |
|
print(f" - Failed deletions: {len(failed_deletions)}") |
|
if total_size > 0: |
|
print(f" - Total space freed: {total_size / (1024**2):.2f} MB") |
|
|
|
def visualize(self, |
|
mode: str = 'normal', |
|
style: str = 'cartoon', |
|
predicted_epitopes: list = None, |
|
predict_results: dict = None, |
|
prediction_mode: str = 'residue', |
|
center_res: int = None, |
|
radius: float = None, |
|
region_index: int = None, |
|
width: int = 800, |
|
height: int = 600, |
|
base_color: str = '#e6e6f7', |
|
true_epitope_color: str = '#f1b54c', |
|
false_positive_color: str = '#ef5331', |
|
true_positive_color: str = '#a0d293', |
|
coverage_color: str = '#9C6ADE', |
|
prediction_color: str = '#9C6ADE', |
|
center_color: str = '#2C3E50', |
|
probability_colormap: str = 'RdYlBu_r', |
|
show_surface: bool = True, |
|
show_shape: bool = True, |
|
show_center: bool = True, |
|
center_radius: float = 0.7, |
|
n_points: int = 50, |
|
shape_opacity: float = 0.3, |
|
surface_opacity: float = 1.0, |
|
wireframe: bool = True, |
|
show_epitope: bool = True, |
|
show_coverage: bool = True, |
|
show_top_regions: bool = True, |
|
max_spheres: int = None, |
|
prob_threshold: float = 0.5): |
|
""" |
|
Visualize the protein chain with various modes and integration with predict results. |
|
|
|
Args: |
|
mode (str): Visualization mode. Options: |
|
- 'normal': Basic protein structure |
|
- 'epitope': Show predicted epitopes vs true epitopes |
|
- 'coverage': Show spherical coverage region |
|
- 'evaluation': Show evaluation results from evaluate() function |
|
- 'prediction': Show prediction results from predict() function |
|
- 'probability': Show residue probabilities as color gradient |
|
- 'top_regions': Show top-k regions from prediction |
|
- 'comparison': Compare voted vs predicted epitopes |
|
prediction_mode (str): Sub-mode for prediction visualization ('residue' or 'region') |
|
- 'residue': Color predicted epitopes by probability (gradient purple) |
|
- 'region': Color all residues in top-k regions uniformly |
|
style (str): Protein representation style ('cartoon', 'stick', 'sphere', 'surface') |
|
predicted_epitopes (list): List of predicted epitope residue numbers |
|
predict_results (dict): Results dictionary from predict() function |
|
center_res (int): Center residue number for coverage visualization |
|
radius (float): Radius for spherical coverage |
|
region_index (int): Index of specific region to show in probability mode (0-based) |
|
If None, shows all regions |
|
Each region uses a distinct color for shape visualization |
|
probability_colormap (str): Colormap name for probability visualization |
|
prob_threshold (float): Threshold for probability-based coloring |
|
... (other parameters as before) |
|
|
|
Returns: |
|
py3Dmol.view: The molecular visualization view object |
|
""" |
|
|
|
view = self._create_base_view(width, height) |
|
|
|
|
|
style_dict = { |
|
'cartoon': {'cartoon': {}}, |
|
'stick': {'stick': {}}, |
|
'sphere': {'sphere': {}}, |
|
'surface': {'surface': {}} |
|
} |
|
base_style = style_dict.get(style, {'cartoon': {}}) |
|
|
|
|
|
if mode == 'epitope' and predicted_epitopes is not None: |
|
self._add_epitope_visualization( |
|
view, style, predicted_epitopes, |
|
base_color, true_epitope_color, false_positive_color, |
|
true_positive_color, coverage_color, |
|
show_surface, surface_opacity, show_coverage, |
|
center_res, radius |
|
) |
|
|
|
|
|
if show_shape and center_res is not None and radius is not None: |
|
self._add_shape_visualization( |
|
view, center_res, radius, |
|
coverage_color, center_color, |
|
show_center, center_radius, |
|
shape_opacity, wireframe |
|
) |
|
|
|
elif mode == 'coverage' and center_res is not None and radius is not None: |
|
self._add_coverage_visualization( |
|
view, style, center_res, radius, |
|
base_color, coverage_color, true_positive_color, true_epitope_color, |
|
show_surface, show_shape, show_center, |
|
surface_opacity, shape_opacity, center_radius, |
|
n_points, center_color, wireframe, show_epitope |
|
) |
|
|
|
elif mode == 'evaluation' and predict_results is not None: |
|
self._add_evaluation_visualization( |
|
view, style, predict_results, |
|
base_color, true_epitope_color, false_positive_color, |
|
true_positive_color, coverage_color, |
|
show_surface, surface_opacity, show_shape, radius, max_spheres |
|
) |
|
|
|
elif mode == 'prediction' and predict_results is not None: |
|
self._add_prediction_visualization( |
|
view, style, predict_results, prediction_mode, |
|
base_color, prediction_color, show_surface, surface_opacity, |
|
show_shape, shape_opacity, show_center, center_radius, |
|
wireframe, radius, max_spheres |
|
) |
|
|
|
elif mode == 'probability' and predict_results is not None: |
|
self._add_probability_visualization( |
|
view, style, predict_results, |
|
base_color, probability_colormap, show_surface, surface_opacity, |
|
prob_threshold, region_index, radius, show_shape, shape_opacity, |
|
show_center, center_radius, wireframe, coverage_color, center_color |
|
) |
|
|
|
elif mode == 'top_regions' and predict_results is not None: |
|
self._add_top_regions_visualization( |
|
view, style, predict_results, |
|
base_color, coverage_color, center_color, |
|
show_surface, show_shape, show_center, |
|
surface_opacity, shape_opacity, center_radius, |
|
wireframe, radius, max_spheres |
|
) |
|
|
|
elif mode == 'comparison' and predict_results is not None: |
|
self._add_comparison_visualization( |
|
view, style, predict_results, |
|
base_color, true_epitope_color, false_positive_color, |
|
true_positive_color, coverage_color, show_surface, surface_opacity |
|
) |
|
|
|
else: |
|
|
|
view.setStyle({'chain': self.chain_id}, base_style) |
|
|
|
|
|
view.zoomTo() |
|
return view |
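
    # Example (sketch): rendering prediction results in a notebook with py3Dmol.
    #
    #   pred = chain.predict(verbose=False)
    #   view = chain.visualize(mode="prediction", predict_results=pred,
    #                          prediction_mode="residue", radius=19.0)
    #   view.show()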
|
|
|
def _add_prediction_visualization(self, view, style, predict_results, prediction_mode, |
|
base_color, prediction_color, show_surface, surface_opacity, |
|
show_shape, shape_opacity, show_center, center_radius, |
|
wireframe, radius, max_spheres): |
|
"""Add visualization for prediction results""" |
|
if prediction_mode == 'residue': |
|
self._add_prediction_residue_mode( |
|
view, style, predict_results, base_color, prediction_color, |
|
show_surface, surface_opacity |
|
) |
|
elif prediction_mode == 'region': |
|
self._add_prediction_region_mode( |
|
view, style, predict_results, base_color, prediction_color, |
|
show_surface, surface_opacity, show_shape, shape_opacity, |
|
show_center, center_radius, wireframe, radius, max_spheres |
|
) |
|
|
|
def _add_prediction_residue_mode(self, view, style, predict_results, base_color, prediction_color, |
|
show_surface, surface_opacity): |
|
"""Add visualization for prediction results in residue mode""" |
|
import matplotlib.pyplot as plt |
|
import matplotlib.colors as mcolors |
|
|
|
|
|
predictions = predict_results.get('predictions', {}) |
|
predicted_epitopes = predict_results.get('predicted_epitopes', []) |
|
|
|
|
|
style_dict = { |
|
'cartoon': {'cartoon': {}}, |
|
'stick': {'stick': {}}, |
|
'sphere': {'sphere': {}}, |
|
'surface': {'surface': {}} |
|
} |
|
base_style = style_dict.get(style, {'cartoon': {}}) |
|
|
|
if not predictions: |
|
|
|
view.setStyle({'chain': self.chain_id}, {**base_style, |
|
list(base_style.keys())[0]: {**list(base_style.values())[0], 'color': base_color}}) |
|
if show_surface: |
|
view.addSurface(py3Dmol.VDW, { |
|
'opacity': surface_opacity * 0.9, |
|
'color': base_color |
|
}, {'chain': self.chain_id}) |
|
return |
|
|
|
|
|
epitope_predictions = {res: prob for res, prob in predictions.items() |
|
if res in predicted_epitopes} |
|
|
|
if not epitope_predictions: |
|
|
|
view.setStyle({'chain': self.chain_id}, {**base_style, list(base_style.keys())[0]: {**list(base_style.values())[0], 'color': base_color}}) |
|
if show_surface: |
|
view.addSurface(py3Dmol.VDW, { |
|
'opacity': surface_opacity * 0.9, |
|
'color': base_color |
|
}, {'chain': self.chain_id}) |
|
return |
|
|
|
|
|
probs = list(epitope_predictions.values()) |
|
min_prob, max_prob = min(probs), max(probs) |
|
|
|
|
|
|
|
epitope_colors = [ |
|
'#FFE4B5', |
|
'#FFD700', |
|
'#FFA500', |
|
'#FF8C00', |
|
'#FF6347', |
|
'#FF4500', |
|
'#DC143C' |
|
] |
|
n_colors = len(epitope_colors) |
|
|
|
|
|
view.setStyle({'chain': self.chain_id}, {**base_style, list(base_style.keys())[0]: {**list(base_style.values())[0], 'color': base_color}}) |
|
|
|
|
|
for residue_num, prob in epitope_predictions.items(): |
|
|
|
if max_prob > min_prob: |
|
norm_prob = (prob - min_prob) / (max_prob - min_prob) |
|
else: |
|
norm_prob = 0.5 |
|
|
|
|
|
color_idx = int(norm_prob * (n_colors - 1)) |
|
color_idx = max(0, min(color_idx, n_colors - 1)) |
|
color = epitope_colors[color_idx] |
|
|
|
|
|
style_name = list(base_style.keys())[0] |
|
colored_style = {style_name: {'color': color}} |
|
view.addStyle( |
|
{'chain': self.chain_id, 'resi': residue_num}, |
|
colored_style |
|
) |
|
|
|
|
|
if show_surface: |
|
|
|
all_residues = set(int(res) for res in self.residue_index) |
|
non_epitope_residues = all_residues - set(predicted_epitopes) |
|
|
|
if non_epitope_residues: |
|
view.addSurface(py3Dmol.VDW, { |
|
'opacity': surface_opacity * 0.9, |
|
'color': base_color |
|
}, {'chain': self.chain_id, 'resi': list(non_epitope_residues)}) |
|
|
|
|
|
for residue_num, prob in epitope_predictions.items(): |
|
|
|
if max_prob > min_prob: |
|
norm_prob = (prob - min_prob) / (max_prob - min_prob) |
|
else: |
|
norm_prob = 0.5 |
|
|
|
|
|
color_idx = int(norm_prob * (n_colors - 1)) |
|
color_idx = max(0, min(color_idx, n_colors - 1)) |
|
color = epitope_colors[color_idx] |
|
|
|
|
|
view.addSurface(py3Dmol.VDW, { |
|
'opacity': surface_opacity, |
|
'color': color |
|
}, {'chain': self.chain_id, 'resi': residue_num}) |
|
|
|
def _add_prediction_region_mode(self, view, style, predict_results, base_color, prediction_color, |
|
show_surface, surface_opacity, show_shape, shape_opacity, |
|
show_center, center_radius, wireframe, radius, max_spheres): |
|
"""Add visualization for prediction results in region mode""" |
|
|
|
top_k_regions = predict_results.get('top_k_regions', []) |
|
top_k_region_residues = predict_results.get('top_k_region_residues', []) |
|
|
|
|
|
style_dict = { |
|
'cartoon': {'cartoon': {}}, |
|
'stick': {'stick': {}}, |
|
'sphere': {'sphere': {}}, |
|
'surface': {'surface': {}} |
|
} |
|
base_style = style_dict.get(style, {'cartoon': {}}) |
|
|
|
if not top_k_region_residues: |
|
|
|
view.setStyle({'chain': self.chain_id}, {**base_style, list(base_style.keys())[0]: {**list(base_style.values())[0], 'color': base_color}}) |
|
if show_surface: |
|
view.addSurface(py3Dmol.VDW, { |
|
'opacity': surface_opacity * 0.9, |
|
'color': base_color |
|
}, {'chain': self.chain_id}) |
|
return |
|
|
|
|
|
view.setStyle({'chain': self.chain_id}, {**base_style, list(base_style.keys())[0]: {**list(base_style.values())[0], 'color': base_color}}) |
|
|
|
|
|
if top_k_region_residues: |
|
style_name = list(base_style.keys())[0] |
|
colored_style = {style_name: {'color': prediction_color}} |
|
view.addStyle( |
|
{'chain': self.chain_id, 'resi': top_k_region_residues}, |
|
colored_style |
|
) |
|
|
|
|
|
if show_surface: |
|
|
|
all_residues = set(int(res) for res in self.residue_index) |
|
non_region_residues = all_residues - set(top_k_region_residues) |
|
|
|
if non_region_residues: |
|
view.addSurface(py3Dmol.VDW, { |
|
'opacity': surface_opacity * 0.9, |
|
'color': base_color |
|
}, {'chain': self.chain_id, 'resi': list(non_region_residues)}) |
|
|
|
|
|
if top_k_region_residues: |
|
view.addSurface(py3Dmol.VDW, { |
|
'opacity': surface_opacity, |
|
'color': prediction_color |
|
}, {'chain': self.chain_id, 'resi': top_k_region_residues}) |
|
|
|
|
|
if show_shape and top_k_regions: |
|
self._add_multi_shape_visualization( |
|
view, top_k_regions, radius, max_spheres, |
|
show_center, center_radius, shape_opacity, wireframe |
|
) |
|
|
|
def _add_evaluation_visualization(self, view, style, predict_results, |
|
base_color, true_epitope_color, false_positive_color, |
|
true_positive_color, coverage_color, |
|
show_surface, surface_opacity, show_shape, radius, max_spheres): |
|
"""Add visualization for evaluation results""" |
|
|
|
predicted_epitopes = set(predict_results.get('predicted_epitopes', [])) |
|
true_epitopes = set(predict_results.get('true_epitopes', [])) |
|
|
|
|
|
true_positives = predicted_epitopes & true_epitopes |
|
false_positives = predicted_epitopes - true_epitopes |
|
false_negatives = true_epitopes - predicted_epitopes |
|
|
|
|
|
style_dict = { |
|
'cartoon': {'cartoon': {}}, |
|
'stick': {'stick': {}}, |
|
'sphere': {'sphere': {}}, |
|
'surface': {'surface': {}} |
|
} |
|
base_style = style_dict.get(style, {'cartoon': {}}) |
|
|
|
|
|
view.setStyle({'chain': self.chain_id}, {**base_style, list(base_style.keys())[0]: {**list(base_style.values())[0], 'color': base_color}}) |
|
|
|
|
|
for residues, color in [ |
|
(true_positives, true_positive_color), |
|
(false_positives, false_positive_color), |
|
(false_negatives, true_epitope_color) |
|
]: |
|
if residues: |
|
|
|
style_name = list(base_style.keys())[0] |
|
colored_style = {style_name: {'color': color}} |
|
view.addStyle( |
|
{'chain': self.chain_id, 'resi': list(residues)}, |
|
colored_style |
|
) |
|
|
|
|
|
if show_surface: |
|
|
|
all_colored_residues = true_positives | false_positives | false_negatives |
|
|
|
|
|
if all_colored_residues: |
|
all_residues = set(int(res) for res in self.residue_index) |
|
non_colored_residues = all_residues - all_colored_residues |
|
|
|
if non_colored_residues: |
|
view.addSurface(py3Dmol.VDW, { |
|
'opacity': surface_opacity * 0.9, |
|
'color': base_color |
|
}, {'chain': self.chain_id, 'resi': list(non_colored_residues)}) |
|
else: |
|
|
|
view.addSurface(py3Dmol.VDW, { |
|
'opacity': surface_opacity * 0.9, |
|
'color': base_color |
|
}, {'chain': self.chain_id}) |
|
|
|
|
|
for residues, color in [ |
|
(true_positives, true_positive_color), |
|
(false_positives, false_positive_color), |
|
(false_negatives, true_epitope_color) |
|
]: |
|
if residues: |
|
view.addSurface(py3Dmol.VDW, { |
|
'opacity': surface_opacity, |
|
'color': color |
|
}, {'chain': self.chain_id, 'resi': list(residues)}) |
|
|
|
|
|
if show_shape and 'top_k_regions' in predict_results: |
|
top_regions = predict_results['top_k_regions'] |
|
self._add_multi_shape_visualization( |
|
view, top_regions, radius, max_spheres, |
|
True, 0.5, 0.2, True |
|
) |
|
|
|
def _add_probability_visualization(self, view, style, predict_results, |
|
base_color, colormap, show_surface, surface_opacity, threshold, |
|
region_index, radius, show_shape, shape_opacity, |
|
show_center, center_radius, wireframe, coverage_color, center_color): |
|
""" |
|
Add visualization based on prediction probabilities with enhanced support for |
|
specific region selection and surface rendering. |
|
|
|
Args: |
|
view: py3Dmol view object |
|
style (str): Protein representation style |
|
predict_results (dict): Results from predict() function |
|
base_color (str): Base color for non-highlighted residues |
|
colormap (str): Colormap name for probability visualization |
|
show_surface (bool): Whether to show surface |
|
surface_opacity (float): Surface opacity |
|
threshold (float): Probability threshold for coloring |
|
region_index (int): Index of specific region to show (0-based), None for all |
|
Each region_index uses a distinct color for shape visualization |
|
radius (float): Radius for spherical regions |
|
show_shape (bool): Whether to show spherical shapes |
|
shape_opacity (float): Shape opacity |
|
show_center (bool): Whether to show center points |
|
center_radius (float): Center point radius |
|
wireframe (bool): Whether to show wireframe spheres |
|
coverage_color (str): Color for coverage regions (not used when region_index is specified) |
|
center_color (str): Color for center points |
|
""" |
|
import matplotlib.pyplot as plt |
|
import matplotlib.colors as mcolors |
|
|
|
|
|
predictions = predict_results.get('predictions', {}) |
|
top_k_regions = predict_results.get('top_k_regions', []) |
|
|
|
|
|
style_dict = { |
|
'cartoon': {'cartoon': {}}, |
|
'stick': {'stick': {}}, |
|
'sphere': {'sphere': {}}, |
|
'surface': {'surface': {}} |
|
} |
|
base_style = style_dict.get(style, {'cartoon': {}}) |
|
|
|
if not predictions: |
|
|
|
style_name = list(base_style.keys())[0]

view.setStyle({'chain': self.chain_id}, {style_name: {**base_style[style_name], 'color': base_color}})
|
if show_surface: |
|
view.addSurface(py3Dmol.VDW, { |
|
'opacity': surface_opacity * 0.9, |
|
'color': base_color |
|
}, {'chain': self.chain_id}) |
|
return |
|
|
|
|
|
style_name = list(base_style.keys())[0]

view.setStyle({'chain': self.chain_id}, {style_name: {**base_style[style_name], 'color': base_color}})
|
|
|
|
|
target_residues = {} |
|
selected_region = None |
|
|
|
if region_index is not None and 0 <= region_index < len(top_k_regions): |
|
|
|
selected_region = top_k_regions[region_index] |
|
covered_residues = selected_region.get('covered_residues', []) |
|
|
|
|
|
for res_num in covered_residues: |
|
if res_num in predictions: |
|
target_residues[res_num] = predictions[res_num] |
|
else: |
|
|
|
target_residues = {res: prob for res, prob in predictions.items() |
|
if prob >= threshold} |
|
|
|
if not target_residues: |
|
|
|
if show_surface: |
|
view.addSurface(py3Dmol.VDW, { |
|
'opacity': surface_opacity * 0.9, |
|
'color': base_color |
|
}, {'chain': self.chain_id}) |
|
return |
|
|
|
|
|
probs = list(target_residues.values()) |
|
min_prob, max_prob = min(probs), max(probs) |
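# Min-max normalization below spreads the colormap over the observed probability
# range, even when all predictions fall in a narrow band.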
|
|
|
|
|
if colormap in ['RdYlBu_r', 'coolwarm', 'RdBu_r']: |
|
|
|
probability_colors = [ |
|
'#c6dbef', |
|
'#9ecae1', |
|
'#6baed6', |
|
'#4292c6', |
|
'#2171b5', |
|
'#fcbba1', |
|
'#fc9272', |
|
'#fb6a4a', |
|
'#ef3b2c', |
|
'#cb181d' |
|
] |
|
n_colors = len(probability_colors) |
|
else: |
|
|
|
cmap = plt.get_cmap(colormap)  # plt.cm.get_cmap was removed in matplotlib 3.9
|
probability_colors = [] |
|
n_colors = 10 |
|
for i in range(n_colors): |
|
color_rgba = cmap(i / (n_colors - 1)) |
|
|
|
softened_rgba = [ |
|
color_rgba[0] * 0.7 + 0.3, |
|
color_rgba[1] * 0.7 + 0.3, |
|
color_rgba[2] * 0.7 + 0.3, |
|
] |
|
|
|
softened_rgba = [min(1.0, val) for val in softened_rgba] |
|
probability_colors.append(mcolors.rgb2hex(softened_rgba)) |
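# Colors sampled from the generic colormap are blended 70/30 toward white so the
# highlights stay readable against the base color.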
|
|
|
|
|
colored_residues = [] |
|
for residue_num, prob in target_residues.items(): |
|
|
|
if max_prob > min_prob: |
|
norm_prob = (prob - min_prob) / (max_prob - min_prob) |
|
else: |
|
norm_prob = 0.5 |
|
|
|
|
|
color_idx = int(norm_prob * (n_colors - 1)) |
|
color_idx = max(0, min(color_idx, n_colors - 1)) |
|
color = probability_colors[color_idx] |
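# Worked example: with n_colors = 10, a normalized probability of 0.34 maps to
# bucket int(0.34 * 9) = 3, i.e. the fourth color in the ramp.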
|
|
|
|
|
style_name = list(base_style.keys())[0] |
|
colored_style = {style_name: {'color': color}} |
|
view.addStyle( |
|
{'chain': self.chain_id, 'resi': residue_num}, |
|
colored_style |
|
) |
|
colored_residues.append(residue_num) |
|
|
|
|
|
if show_surface: |
|
|
|
all_residues = set(int(res) for res in self.residue_index) |
|
non_colored_residues = all_residues - set(colored_residues) |
|
|
|
if non_colored_residues: |
|
view.addSurface(py3Dmol.VDW, { |
|
'opacity': surface_opacity * 0.9, |
|
'color': base_color |
|
}, {'chain': self.chain_id, 'resi': list(non_colored_residues)}) |
|
|
|
|
|
for residue_num, prob in target_residues.items(): |
|
|
|
if max_prob > min_prob: |
|
norm_prob = (prob - min_prob) / (max_prob - min_prob) |
|
else: |
|
norm_prob = 0.5 |
|
|
|
|
|
color_idx = int(norm_prob * (n_colors - 1)) |
|
color_idx = max(0, min(color_idx, n_colors - 1)) |
|
color = probability_colors[color_idx] |
|
|
|
|
|
view.addSurface(py3Dmol.VDW, { |
|
'opacity': surface_opacity * 0.9, |
|
'color': color |
|
}, {'chain': self.chain_id, 'resi': residue_num}) |
|
|
|
|
|
if selected_region is not None and show_shape: |
|
center_res = selected_region['center_residue'] |
|
|
|
|
|
sphere_radius = radius or 19.0 |
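# Note: `radius or 19.0` also replaces an explicit radius of 0 with the 19.0 default.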
|
|
|
|
|
region_colors = [ |
|
'#FF6B6B', |
|
'#4ECDC4', |
|
'#45B7D1', |
|
'#96CEB4', |
|
'#FFEAA7', |
|
'#DDA0DD', |
|
'#87CEEB', |
|
'#F0E68C', |
|
'#FFB6C1', |
|
'#98FB98', |
|
'#9C6ADE', |
|
'#FF9A8B' |
|
] |
|
|
|
|
|
shape_color = region_colors[region_index % len(region_colors)] |
|
|
|
|
|
self._add_shape_visualization( |
|
view, center_res, sphere_radius, |
|
shape_color, center_color, |
|
show_center, center_radius, |
|
shape_opacity * 0.6, |
|
wireframe |
|
) |
|
|
|
|
|
view.addStyle( |
|
{'chain': self.chain_id, 'resi': center_res}, |
|
{list(base_style.keys())[0]: {'color': shape_color}} |
|
) |
|
|
|
def _add_top_regions_visualization(self, view, style, predict_results, |
|
base_color, coverage_color, center_color, |
|
show_surface, show_shape, show_center, |
|
surface_opacity, shape_opacity, center_radius, |
|
wireframe, radius, max_spheres): |
|
"""Add visualization for top-k regions""" |
|
|
|
view.setStyle({'chain': self.chain_id}, {style: {'color': base_color}}) |
|
|
|
|
|
top_regions = predict_results.get('top_k_regions', []) |
|
|
|
|
|
if max_spheres is not None: |
|
top_regions = top_regions[:max_spheres] |
|
|
|
|
|
region_colors = [ |
|
'#FF6B6B', |
|
'#96CEB4', |
|
'#4ECDC4', |
|
'#45B7D1', |
|
'#FFEAA7', |
|
'#DDA0DD', |
|
'#87CEEB', |
|
'#F0E68C', |
|
'#FFB6C1', |
|
'#98FB98' |
|
] |
|
|
|
for i, region in enumerate(top_regions): |
|
center_res = region['center_residue'] |
|
covered_residues = region.get('covered_residues', []) |
|
region_color = region_colors[i % len(region_colors)] |
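# Palette colors cycle when there are more regions than entries.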
|
|
|
|
|
if covered_residues: |
|
view.addStyle( |
|
{'chain': self.chain_id, 'resi': covered_residues}, |
|
{style: {'color': region_color}} |
|
) |
|
|
|
|
|
if show_shape: |
|
self._add_shape_visualization( |
|
view, center_res, radius or 18.0, |
|
region_color, center_color, |
|
show_center, center_radius * 0.8, |
|
shape_opacity, wireframe |
|
) |
|
|
|
|
|
if show_surface: |
|
|
|
view.addSurface(py3Dmol.VDW, { |
|
'opacity': surface_opacity * 0.9, |
|
'color': base_color |
|
}) |
|
|
|
|
|
for i, region in enumerate(top_regions): |
|
covered_residues = region.get('covered_residues', []) |
|
region_color = region_colors[i % len(region_colors)] |
|
|
|
if covered_residues: |
|
view.addSurface(py3Dmol.VDW, { |
|
'opacity': surface_opacity, |
|
'color': region_color |
|
}, {'resi': covered_residues}) |
|
|
|
def _add_comparison_visualization(self, view, style, predict_results, |
|
base_color, true_epitope_color, false_positive_color, |
|
true_positive_color, coverage_color, show_surface, surface_opacity): |
|
"""Add visualization comparing voted vs predicted epitopes""" |
|
|
|
view.setStyle({'chain': self.chain_id}, {style: {'color': base_color}}) |
|
|
|
|
|
predicted_epitopes = set(predict_results.get('predicted_epitopes', [])) |
|
voted_epitopes = set(predict_results.get('voted_epitopes', [])) |
|
true_epitopes = set(predict_results.get('true_epitopes', [])) |
|
|
|
|
|
both_methods = predicted_epitopes & voted_epitopes |
|
only_predicted = predicted_epitopes - voted_epitopes |
|
only_voted = voted_epitopes - predicted_epitopes |
|
|
|
|
|
both_correct = both_methods & true_epitopes |
|
both_incorrect = both_methods - true_epitopes |
|
only_pred_correct = only_predicted & true_epitopes |
|
only_pred_incorrect = only_predicted - true_epitopes |
|
only_vote_correct = only_voted & true_epitopes |
|
only_vote_incorrect = only_voted - true_epitopes |
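# Six-way split: residues called by both methods vs. by only one method, each
# further divided by agreement with the true epitopes.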
|
|
|
|
|
color_assignments = [ |
|
(both_correct, '#00FF00'), |
|
(both_incorrect, '#FF0000'), |
|
(only_pred_correct, '#90EE90'), |
|
(only_pred_incorrect, '#FFB6C1'), |
|
(only_vote_correct, '#87CEEB'), |
|
(only_vote_incorrect, '#DDA0DD') |
|
] |
|
|
|
for residues, color in color_assignments: |
|
if residues: |
|
view.addStyle( |
|
{'chain': self.chain_id, 'resi': list(residues)}, |
|
{style: {'color': color}} |
|
) |
|
|
|
|
|
if show_surface: |
|
view.addSurface(py3Dmol.VDW, { |
|
'opacity': surface_opacity, |
|
'color': base_color |
|
}) |
|
|
|
for residues, color in color_assignments: |
|
if residues: |
|
view.addSurface(py3Dmol.VDW, { |
|
'opacity': surface_opacity, |
|
'color': color |
|
}, {'resi': list(residues)}) |
|
|
|
def _create_base_view(self, width: int, height: int) -> py3Dmol.view: |
|
"""创建基本的py3Dmol视图并添加蛋白质结构""" |
|
view = py3Dmol.view(width=width, height=height) |
|
|
|
|
|
pdb_str = "MODEL 1\n" |
|
atom_num = 1 |
|
for res_idx in range(len(self.sequence)): |
|
one_letter = self.sequence[res_idx] |
|
resname = self.convert_letter_1to3(one_letter) |
|
resnum = self.residue_index[res_idx] |
|
|
|
mask = self.atom37_mask[res_idx] |
|
coords = self.atom37_positions[res_idx][mask] |
|
atoms = [name for name, exists in zip(RC.atom_types, mask) if exists] |
|
|
|
for atom_name, coord in zip(atoms, coords): |
|
x, y, z = coord |
|
pdb_str += (f"ATOM {atom_num:5d} {atom_name:<3s} {resname:>3s} {self.chain_id:1s}{resnum:4d}" |
|
f" {x:8.3f}{y:8.3f}{z:8.3f} 1.00 0.00\n") |
|
atom_num += 1 |
|
|
|
pdb_str += "ENDMDL\n" |
|
view.addModel(pdb_str, "pdb") |
|
return view |
|
|
|
def _add_epitope_visualization(self, view, style, predicted_epitopes, |
|
base_color, true_epitope_color, false_positive_color, true_positive_color, coverage_color, |
|
show_surface, surface_opacity, show_coverage, |
|
center_res=None, radius=None): |
|
"""添加表位可视化""" |
|
|
|
view.setStyle({'chain': self.chain_id}, {style: {'color': base_color}}) |
|
|
|
true_epitopes = set(self.get_epitope_residue_numbers()) |
|
true_positives = set(predicted_epitopes) & true_epitopes |
|
false_positives = set(predicted_epitopes) - true_epitopes |
|
false_negatives = true_epitopes - set(predicted_epitopes) |
|
|
|
|
|
covered_residues = [] |
|
if center_res is not None and radius is not None: |
|
coverage_dict, _, _ = self.get_surface_coverage( |
|
radius=radius, threshold=0.25, index=False |
|
) |
|
covered_res_list = coverage_dict.get(center_res, [[], [], 0, 0])[0] |
|
covered_residues = covered_res_list |
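# Residues that fall inside the sphere of `radius` Å around `center_res`
# (first element of the coverage entry).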
|
|
|
|
|
if covered_residues: |
|
true_negatives = [res for res in covered_residues |
|
if res not in true_epitopes and res not in predicted_epitopes] |
|
|
|
|
|
true_negative_color = '#888888' |
|
|
|
if true_negatives: |
|
view.addStyle( |
|
{'chain': self.chain_id, 'resi': true_negatives}, |
|
{style: {'color': true_negative_color}} |
|
) |
|
|
|
|
|
for residues, color in [ |
|
(true_positives, true_positive_color), |
|
(false_positives, false_positive_color), |
|
(false_negatives, true_epitope_color) |
|
]: |
|
if residues: |
|
view.addStyle( |
|
{'chain': self.chain_id, 'resi': list(residues)}, |
|
{style: {'color': color}} |
|
) |
|
|
|
|
|
if show_surface: |
|
|
|
view.addSurface(py3Dmol.VDW, { |
|
'opacity': surface_opacity * 0.9, |
|
'color': base_color |
|
}) |
|
|
|
|
|
for residues, color in [ |
|
(true_positives, true_positive_color), |
|
(false_positives, false_positive_color), |
|
(false_negatives, true_epitope_color) |
|
]: |
|
if residues: |
|
view.addSurface(py3Dmol.VDW, { |
|
'opacity': surface_opacity, |
|
'color': color |
|
}, {'resi': list(residues)}) |
|
|
|
|
|
if center_res is not None and radius is not None and covered_residues and show_coverage: |
|
true_negatives = [res for res in covered_residues |
|
if res not in true_epitopes and res not in predicted_epitopes] |
|
|
|
if true_negatives: |
|
view.addSurface(py3Dmol.VDW, { |
|
'opacity': surface_opacity, |
|
'color': coverage_color |
|
}, {'resi': true_negatives}) |
|
|
|
def _add_shape_visualization(self, view, center_res, radius, |
|
coverage_color, center_color, |
|
show_center, center_radius, |
|
shape_opacity, wireframe): |
|
"""添加球形可视化""" |
|
center_idx = self.resnum_to_index.get(center_res) |
|
if center_idx is None: |
|
return |
|
|
|
ca_idx = RC.atom_order["CA"] |
|
center_coord = self.atom37_positions[center_idx, ca_idx, :] |
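# Use the CA atom of the center residue as the sphere origin.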
|
|
|
|
|
sphere_params = { |
|
'center': {'x': float(center_coord[0]), |
|
'y': float(center_coord[1]), |
|
'z': float(center_coord[2])}, |
|
'radius': float(radius), |
|
'color': coverage_color |
|
} |
|
if wireframe: |
|
sphere_params.update({'wireframe': True, 'linewidth': 1.5}) |
|
else: |
|
sphere_params.update({'opacity': shape_opacity}) |
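# Wireframe spheres are drawn as outlines; solid spheres honor shape_opacity.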
|
view.addSphere(sphere_params) |
|
|
|
|
|
if show_center: |
|
view.addSphere({ |
|
'center': {'x': float(center_coord[0]), |
|
'y': float(center_coord[1]), |
|
'z': float(center_coord[2])}, |
|
'radius': float(center_radius), |
|
'color': center_color, |
|
'opacity': 1.0 |
|
}) |
|
|
|
def _add_coverage_visualization(self, view, style, center_res, radius, |
|
base_color, coverage_color, true_positive_color, true_epitope_color, |
|
show_surface, show_shape, show_center, |
|
surface_opacity, shape_opacity, center_radius, |
|
n_points, center_color, wireframe, show_epitope): |
|
"""添加覆盖区域可视化""" |
|
|
|
view.setStyle({'chain': self.chain_id}, {style: {'color': base_color}}) |
|
|
|
|
|
coverage_dict, _, _ = self.get_surface_coverage( |
|
radius=radius, threshold=0.25, index=False |
|
) |
|
|
|
covered_res_list = coverage_dict.get(center_res, [[], [], 0, 0])[0] |
|
covered_residues = covered_res_list |
|
|
|
|
|
if show_epitope: |
|
true_epitopes = set(self.get_epitope_residue_numbers()) |
|
else: |
|
true_epitopes = set() |
|
|
|
|
|
true_positives = set(covered_residues) & true_epitopes |
|
false_negatives = true_epitopes - set(covered_residues) |
|
covered_non_epitopes = set(covered_residues) - true_epitopes |
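# Partition the sphere contents: covered epitopes (true positives), epitopes outside
# the sphere (false negatives), and covered non-epitope residues.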
|
|
|
|
|
if show_surface: |
|
|
|
view.addSurface(py3Dmol.VDW, { |
|
'opacity': surface_opacity * 1.0, |
|
'color': base_color |
|
}) |
|
|
|
|
|
if false_negatives: |
|
view.addSurface(py3Dmol.VDW, { |
|
'opacity': surface_opacity, |
|
'color': true_epitope_color |
|
}, {'resi': list(false_negatives)}) |
|
|
|
|
|
if true_positives: |
|
view.addSurface(py3Dmol.VDW, { |
|
'opacity': surface_opacity, |
|
'color': true_positive_color |
|
}, {'resi': list(true_positives)}) |
|
|
|
|
|
if covered_non_epitopes: |
|
view.addSurface(py3Dmol.VDW, { |
|
'opacity': surface_opacity * 0.9, |
|
'color': coverage_color |
|
}, {'resi': list(covered_non_epitopes)}) |
|
|
|
|
|
if false_negatives: |
|
view.addStyle( |
|
{'chain': self.chain_id, 'resi': list(false_negatives)}, |
|
{style: {'color': true_epitope_color}} |
|
) |
|
|
|
if true_positives: |
|
view.addStyle( |
|
{'chain': self.chain_id, 'resi': list(true_positives)}, |
|
{style: {'color': true_positive_color}} |
|
) |
|
|
|
if covered_non_epitopes: |
|
view.addStyle( |
|
{'chain': self.chain_id, 'resi': list(covered_non_epitopes)}, |
|
{style: {'color': coverage_color}} |
|
) |
|
|
|
|
|
view.addStyle( |
|
{'chain': self.chain_id, 'resi': center_res}, |
|
{style: {'color': '#FFD700'}} |
|
) |
|
|
|
|
|
if show_shape: |
|
self._add_shape_visualization( |
|
view, center_res, radius, |
|
coverage_color, |
|
center_color, |
|
show_center, center_radius, |
|
shape_opacity, wireframe |
|
) |
|
|
|
def _add_multi_shape_visualization(self, view, regions_data, radius, max_spheres, |
|
show_center, center_radius, shape_opacity, wireframe): |
|
"""Add multiple spherical regions with different colors""" |
|
if not regions_data: |
|
return |
|
|
|
|
|
regions_to_show = regions_data[:max_spheres] if max_spheres else regions_data |
|
|
|
|
|
sphere_colors = [ |
|
'#d671f1', |
|
'#7190f1', |
|
'#FF6B6B', |
|
'#96CEB4', |
|
'#FFEAA7', |
|
'#FFB6C1', |
|
'#4ECDC4', |
|
'#87CEEB', |
|
'#F0E68C', |
|
'#98FB98', |
|
'#45B7D1' |
|
] |
|
|
|
for i, region_data in enumerate(regions_to_show): |
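# Entries may be region dicts from predict() (with 'center_residue') or bare residue numbers.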
|
if isinstance(region_data, dict): |
|
|
|
center_res = region_data['center_residue'] |
|
else: |
|
|
|
center_res = region_data |
|
|
|
sphere_color = sphere_colors[i % len(sphere_colors)] |
|
self._add_shape_visualization( |
|
view, center_res, radius or 18.0, |
|
sphere_color, '#FFD700', |
|
show_center, center_radius, shape_opacity, wireframe |
|
) |
|
|
|
@classmethod |
|
def from_pdb( |
|
cls, |
|
path: Optional[PathOrBuffer] = None, |
|
chain_id: str = "detect", |
|
id: Optional[str] = None, |
|
is_predicted: bool = False, |
|
) -> "AntigenChain": |
|
""" |
|
Return an AntigenChain object from a PDB file.
|
|
|
If `path` is not provided, the method looks for the structure under data/antigen_structs,

trying these file names in order (the id is lowercased first):

1. {id}_{chain_id}.pdb

2. {id}.pdb

If neither file exists, the full complex is downloaded from RCSB PDB

and saved to the data/pdb directory.
|
|
|
Args: |
|
path (Optional[PathOrBuffer]): Path or buffer to read pdb file from. If None, |
|
the default path is constructed from BASE_DIR.
|
chain_id (str, optional): Select a chain corresponding to (author) chain id. |
|
"detect" uses the first detected chain. |
|
id (Optional[str], optional): Protein identifier (pdb_id). If not provided and `path` |
|
is given, the id will be inferred from the file name. |
|
is_predicted (bool, optional): If True, reads b factor as the confidence readout. |
|
|
|
Returns: |
|
AntigenChain: The constructed antigen chain. |
|
""" |
|
|
|
id = id.lower() if id is not None else None  # avoid crashing when only `path` is given
|
|
|
if path is None: |
|
if id is None: |
|
raise ValueError("Either 'path' or 'id' must be provided to locate the pdb file.") |
|
|
|
|
|
possible_paths = [ |
|
Path(BASE_DIR) / "data" / "antigen_structs" / f"{id}_{chain_id}.pdb", |
|
Path(BASE_DIR) / "data" / "antigen_structs" / f"{id}.pdb", |
|
|
|
|
|
|
|
|
|
] |
|
|
|
|
|
path = None |
|
for p in possible_paths: |
|
if p.exists(): |
|
path = p |
|
print(f"Found pdb file at {path}") |
|
break |
|
|
|
|
|
if path is None: |
|
try: |
|
|
|
save_dir = Path(BASE_DIR) / "data" / "pdb" |
|
save_dir.mkdir(parents=True, exist_ok=True) |
|
|
|
|
|
rcsb.fetch(id, "pdb", target_path=save_dir) |
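# biotite's rcsb.fetch saves the structure as <id>.pdb inside save_dir.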
|
|
|
path = save_dir / f"{id}.pdb" |
|
print(f"No existing pdb file for {id}_{chain_id}, downloaded {id} complex pdb file to {path}") |
|
|
|
except Exception as e: |
|
print(f"[ERROR] Failed to download pdb file for {id}: {str(e)}") |
|
return None |
|
else: |
|
path = Path(path) |
|
|
|
|
|
if id is not None: |
|
file_id = id |
|
else: |
|
|
|
file_id = path.with_suffix("").name |
|
|
|
|
|
try: |
|
atom_array = PDBFile.read(path).get_structure(model=1, extra_fields=["b_factor"]) |
|
except Exception as e: |
|
print(f"[ERROR] Failed to read pdb file {path}: {str(e)}") |
|
return None |
|
|
|
|
|
if chain_id == "detect": |
|
chain_id = atom_array.chain_id[0] |
|
print(f"[WARNING] No chain_id provided, using the first detected chain: {chain_id}") |
|
|
|
|
|
atom_array = atom_array[ |
|
bs.filter_amino_acids(atom_array) |
|
& ~atom_array.hetero |
|
& (atom_array.chain_id == chain_id) |
|
] |
|
|
|
|
|
entity_id = 1 |
|
|
|
|
|
sequence = "".join( |
|
( |
|
r if len((r := PDBData.protein_letters_3to1.get(monomer[0].res_name, "X"))) == 1 else "X" |
|
) |
|
for monomer in bs.residue_iter(atom_array) |
|
) |
|
num_res = len(sequence) |
|
|
|
|
|
atom_positions = np.full([num_res, RC.atom_type_num, 3], np.nan, dtype=np.float32) |
|
atom_mask = np.full([num_res, RC.atom_type_num], False, dtype=bool) |
|
residue_index = np.full([num_res], -1, dtype=np.int64) |
|
insertion_code = np.full([num_res], "", dtype="<U4") |
|
confidence = np.ones([num_res], dtype=np.float32) |
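# Allocate atom37-format arrays; coordinates default to NaN and are filled only
# where atoms are present in the structure.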
|
|
|
|
|
for i, res in enumerate(bs.residue_iter(atom_array)): |
|
for atom in res: |
|
atom_name = atom.atom_name |
|
if atom_name == "SE" and atom.res_name == "MSE": |
|
atom_name = "SD" |
|
if atom_name in RC.atom_order: |
|
atom_positions[i, RC.atom_order[atom_name]] = atom.coord |
|
atom_mask[i, RC.atom_order[atom_name]] = True |
|
if is_predicted and atom_name == "CA": |
|
confidence[i] = atom.b_factor |
|
residue_index[i] = res[0].res_id |
|
insertion_code[i] = res[0].ins_code |
|
|
|
|
|
assert all(sequence), "Some residue name was not specified correctly" |
|
|
|
return cls( |
|
id=file_id, |
|
sequence=sequence, |
|
chain_id=chain_id, |
|
entity_id=entity_id, |
|
atom37_positions=atom_positions, |
|
atom37_mask=atom_mask, |
|
residue_index=residue_index, |
|
insertion_code=insertion_code, |
|
confidence=confidence, |
|
) |