|
import torch |
|
|
|
def set_seed(seed=42): |
|
"""Set random seed for reproducibility.""" |
|
torch.manual_seed(seed) |
|
if torch.cuda.is_available(): |
|
torch.cuda.manual_seed(seed) |
|
torch.cuda.manual_seed_all(seed) |
|
|
|
def get_chain_organism(pdb_id, chain_id): |
|
""" |
|
Use RCSB PDB API to get the organism of the chain. |
|
""" |
|
import requests |
|
entry_url = f"https://data.rcsb.org/rest/v1/core/entry/{pdb_id.lower()}" |
|
res = requests.get(entry_url) |
|
if res.status_code != 200: |
|
return "Unknown" |
|
entry_data = res.json() |
|
|
|
|
|
chain_to_entity = {} |
|
for entity_id in entry_data.get("rcsb_entry_container_identifiers", {}).get("polymer_entity_ids", []): |
|
entity_url = f"https://data.rcsb.org/rest/v1/core/polymer_entity/{pdb_id.lower()}/{entity_id}" |
|
entity_res = requests.get(entity_url) |
|
if entity_res.status_code != 200: |
|
continue |
|
entity_data = entity_res.json() |
|
chains = entity_data.get("rcsb_polymer_entity_container_identifiers", {}).get("auth_asym_ids", []) |
|
for c in chains: |
|
chain_to_entity[c] = entity_id |
|
if chain_id in chains: |
|
organism = entity_data.get("rcsb_entity_source_organism", [{}])[0].get("scientific_name", "Unknown") |
|
return organism |
|
return "Unknown" |
|
|
|
def classify_antigen(organisms): |
|
for org in organisms: |
|
if "virus" in org.lower() or "coronavirus" in org.lower(): |
|
return "viral" |
|
elif "homo sapiens" in org.lower(): |
|
return "human" |
|
elif "bacteria" in org.lower() or "bacillus" in org.lower(): |
|
return "bacterial" |
|
elif "tumor" in org.lower() or "cancer" in org.lower(): |
|
return "tumor" |
|
return "other" |