import base64
import importlib
import io
import json
import os
import subprocess
import sys
import tempfile
import time
import traceback
import zipfile
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, Optional, Tuple

import gradio as gr
import numpy as np
import pandas as pd
import torch

# Seconds each pip invocation may run before it is abandoned.
_PIP_TIMEOUT_SECONDS = 300


# Dynamic installation of PyTorch Geometric dependencies
def install_torch_geometric_deps():
    """Install the PyTorch Geometric extension packages at runtime.

    Installing at runtime avoids compilation issues during the Hugging Face
    Spaces build. If ``torch_scatter`` is already importable nothing is
    installed. Otherwise the CPU wheels matching the running torch version are
    requested from the PyG wheel index; when pip reports a failure, a plain
    ``pip install torch-scatter`` is attempted as a fallback. A timeout or an
    exception in the first attempt returns ``False`` without trying the
    fallback.

    Returns:
        bool: ``True`` when torch-scatter was already present or one of the
        pip invocations exited with status 0, ``False`` otherwise.
    """
    # Check if torch-scatter is already installed
    try:
        import torch_scatter  # noqa: F401
    except ImportError:
        print("[INFO] Installing torch-scatter and related packages...")
    else:
        print("[OK] torch-scatter already installed")
        return True

    # Keep only the public version ("2.1.0+cu121" -> "2.1.0"); the wheel index
    # URL below always requests the "+cpu" build.
    torch_version_str = torch.__version__.split("+", 1)[0]

    # Use the wheel index recommended by PyTorch Geometric (CPU wheels).
    pip_cmd = [
        sys.executable, "-m", "pip", "install",
        "torch-scatter", "torch-sparse", "torch-cluster", "torch-spline-conv",
        "-f", f"https://data.pyg.org/whl/torch-{torch_version_str}+cpu.html",
        "--no-cache-dir",
    ]
    print(f"Running: {' '.join(pip_cmd)}")

    try:
        result = subprocess.run(
            pip_cmd,
            capture_output=True,
            text=True,
            timeout=_PIP_TIMEOUT_SECONDS,
        )
    except subprocess.TimeoutExpired:
        print("[FAIL] Installation timeout - packages may not be available")
        return False
    except Exception as e:
        print(f"[FAIL] Exception during installation: {e}")
        return False

    if result.returncode == 0:
        # Packages installed while the interpreter is running may be hidden by
        # the import system's directory caches until they are invalidated.
        importlib.invalidate_caches()
        print("[OK] Successfully installed torch-scatter and related packages")
        return True

    print(f"[FAIL] Failed to install packages: {result.stderr}")

    # Try simplified installation method (torch-scatter only, default index).
    simple_cmd = [
        sys.executable, "-m", "pip", "install",
        "torch-scatter", "--no-cache-dir",
    ]
    try:
        result = subprocess.run(
            simple_cmd,
            capture_output=True,
            text=True,
            timeout=_PIP_TIMEOUT_SECONDS,
        )
    except Exception as e:
        print(f"[FAIL] Exception during simple install: {e}")
        return False

    if result.returncode == 0:
        importlib.invalidate_caches()
        print("[OK] Successfully installed torch-scatter with simple method")
        return True

    print(f"[FAIL] Simple install also failed: {result.stderr}")
    return False


# Try to install PyTorch Geometric dependencies
deps_installed = install_torch_geometric_deps()
if not deps_installed:
    print("[WARN] Warning: PyTorch Geometric dependencies not installed. Some features may not work.")
    print("The application will try to continue with limited functionality.")

# Set up paths and imports for different deployment environments
BASE_DIR = Path(__file__).parent


# Smart import handling for different environments
def setup_imports():
    """Bind ``AntigenChain`` and ``PROJECT_BASE_DIR`` as module globals.

    Three import strategies are tried in order and the first that succeeds
    wins:

    1. ``src.bce...`` with ``BASE_DIR`` prepended to ``sys.path``
       (local development); only tried when ``BASE_DIR / "src"`` exists.
    2. ``bce...`` with ``BASE_DIR / "src"`` prepended to ``sys.path``
       (Hugging Face Spaces); only tried when that directory exists.
    3. ``bce...`` as an already installed package.

    Returns:
        bool: ``True`` when one strategy succeeded. ``False`` when all failed;
        in that case ``PROJECT_BASE_DIR`` falls back to ``BASE_DIR`` and
        ``AntigenChain`` may be left undefined.
    """
    global AntigenChain, PROJECT_BASE_DIR

    src_path = BASE_DIR / "src"
    has_src_dir = src_path.exists()

    # Method 1: Try importing from src directory (local development)
    if has_src_dir:
        sys.path.insert(0, str(BASE_DIR))
        try:
            from src.bce.antigen.antigen import AntigenChain
            from src.bce.utils.constants import BASE_DIR as PROJECT_BASE_DIR
        except ImportError as e:
            print(f"[FAIL] Failed to import from src/: {e}")
        else:
            print("[OK] Successfully imported from src/ directory")
            return True

    # Method 2: Try adding src to path and direct import (Hugging Face Spaces)
    if has_src_dir:
        sys.path.insert(0, str(src_path))
        try:
            from bce.antigen.antigen import AntigenChain
            from bce.utils.constants import BASE_DIR as PROJECT_BASE_DIR
        except ImportError as e:
            print(f"[FAIL] Failed to import with src/ in path: {e}")
        else:
            print("[OK] Successfully imported from src/ added to path")
            return True

    # Method 3: Try direct import (if package is installed)
    try:
        from bce.antigen.antigen import AntigenChain
        from bce.utils.constants import BASE_DIR as PROJECT_BASE_DIR
    except ImportError as e:
        print(f"[FAIL] Failed to import from installed package: {e}")
    else:
        print("[OK] Successfully imported from installed package")
        return True

    # If all methods fail, use default settings
    print("[WARN] All import methods failed, using fallback settings")
    PROJECT_BASE_DIR = BASE_DIR
    return False
# Execute import setup import_success = setup_imports() if not import_success: print("โ Critical: Could not import BCE modules. Please check the file structure.") print("Expected structure:") print("- src/bce/antigen/antigen.py") print("- src/bce/utils/constants.py") print("- src/bce/model/ReCEP.py") print("- src/bce/data/utils.py") sys.exit(1) # Configuration DEFAULT_MODEL_PATH = os.getenv("BCE_MODEL_PATH", str(PROJECT_BASE_DIR / "models" / "ReCEP" / "20250626_110438" / "best_mcc_model.bin")) ESM_TOKEN = os.getenv("ESM_TOKEN", "1mzAo8l1uxaU8UfVcGgV7B") # PDB data directory PDB_DATA_DIR = PROJECT_BASE_DIR / "data" / "pdb" PDB_DATA_DIR.mkdir(parents=True, exist_ok=True) def validate_pdb_id(pdb_id: str) -> bool: """Validate PDB ID format""" if not pdb_id or len(pdb_id) != 4: return False return pdb_id.isalnum() def validate_chain_id(chain_id: str) -> bool: """Validate chain ID format""" if not chain_id or len(chain_id) != 1: return False return chain_id.isalnum() def create_pdb_visualization_html(pdb_data: str, predicted_epitopes: list, predictions: dict, protein_id: str, top_k_regions: list = None) -> str: """Create HTML with 3Dmol.js visualization compatible with Gradio - enhanced version with more features""" # Prepare data for JavaScript epitope_residues = predicted_epitopes # Process top_k_regions for visualization processed_regions = [] if top_k_regions: for i, region in enumerate(top_k_regions): if isinstance(region, dict): processed_regions.append({ 'center_idx': region.get('center_idx', 0), 'center_residue': region.get('center_residue', region.get('center_idx', 0)), 'covered_residues': region.get('covered_residues', region.get('covered_indices', [])), 'radius': 19.0, # Default radius 'predicted_value': region.get('graph_pred', 0.0) }) # Create a unique ID for this visualization to avoid conflicts import uuid viewer_id = f"viewer_{uuid.uuid4().hex[:8]}" html_content = f"""
""" return html_content def predict_epitopes(pdb_id: str, pdb_file, chain_id: str, radius: float, k: int, encoder: str, device_config: str, use_threshold: bool, threshold: float, auto_cleanup: bool, progress: gr.Progress = None) -> Tuple[str, str, str, str, str, str]: """ Main prediction function that handles the epitope prediction workflow """ try: # Input validation if not pdb_file and not pdb_id: return "Error: Please provide either a PDB ID or upload a PDB file", "", "", "", "", "" if pdb_id and not validate_pdb_id(pdb_id): return "Error: PDB ID must be exactly 4 characters (letters and numbers)", "", "", "", "", "" if not validate_chain_id(chain_id): return "Error: Chain ID must be exactly 1 character", "", "", "", "", "" # Update progress if progress: progress(0.1, desc="Initializing prediction...") # Process device configuration device_id = -1 if device_config == "CPU Only" else int(device_config.split(" ")[1]) use_gpu = device_id >= 0 # Load protein structure if progress: progress(0.2, desc="Loading protein structure...") antigen_chain = None temp_file_path = None try: if pdb_file: # Handle uploaded file if progress: progress(0.25, desc="Processing uploaded PDB file...") # Debug: print type and attributes of pdb_file print(f"๐ Debug: pdb_file type = {type(pdb_file)}") print(f"๐ Debug: pdb_file attributes = {dir(pdb_file)}") # Extract PDB ID from filename if not provided if not pdb_id: if hasattr(pdb_file, 'name'): pdb_id = Path(pdb_file.name).stem.split('_')[0][:4] else: pdb_id = "UNKN" # Default fallback # Save uploaded file to data/pdb/ directory with proper naming timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") filename = f"{pdb_id}_{chain_id}_{timestamp}.pdb" temp_file_path = PDB_DATA_DIR / filename # Properly read and write the uploaded file try: if hasattr(pdb_file, 'name') and os.path.isfile(pdb_file.name): # pdb_file is a file object with .name attribute print(f"๐ Processing file object: {pdb_file.name}") with open(pdb_file.name, "rb") as src: 
with open(temp_file_path, "wb") as dst: dst.write(src.read()) elif hasattr(pdb_file, 'read'): # pdb_file is a file-like object print(f"๐ Processing file-like object") with open(temp_file_path, "wb") as f: f.write(pdb_file.read()) else: # pdb_file is a string (file path) print(f"๐ Processing file path: {pdb_file}") with open(str(pdb_file), "rb") as src: with open(temp_file_path, "wb") as dst: dst.write(src.read()) print(f"โ PDB file saved to: {temp_file_path}") except Exception as file_error: print(f"โ Error processing uploaded file: {file_error}") return f"Error processing uploaded file: {str(file_error)}", "", "", "", "", "" antigen_chain = AntigenChain.from_pdb( path=str(temp_file_path), chain_id=chain_id, id=pdb_id ) else: # Load from PDB ID if progress: progress(0.25, desc=f"Downloading PDB structure {pdb_id}...") antigen_chain = AntigenChain.from_pdb( chain_id=chain_id, id=pdb_id ) except Exception as e: return f"Error loading protein structure: {str(e)}", "", "", "", "", "" if antigen_chain is None: return "Error: Failed to load protein structure", "", "", "", "", "" # Run prediction if progress: progress(0.4, desc="Running epitope prediction...") try: # Use threshold only if checkbox is checked final_threshold = threshold if use_threshold else None predict_results = antigen_chain.predict( model_path=DEFAULT_MODEL_PATH, device_id=device_id, radius=radius, k=k, threshold=final_threshold, verbose=True, encoder=encoder, use_gpu=use_gpu, auto_cleanup=auto_cleanup ) except Exception as e: error_msg = f"Error during prediction: {str(e)}" print(f"Prediction error: {error_msg}") import traceback traceback.print_exc() return error_msg, "", "", "", "", "" if progress: progress(0.8, desc="Processing results...") # Process results if not predict_results: return "Error: No prediction results generated", "", "", "", "", "" # Extract prediction data predicted_epitopes = predict_results.get("predicted_epitopes", []) predictions = predict_results.get("predictions", {}) 
top_k_centers = predict_results.get("top_k_centers", []) top_k_region_residues = predict_results.get("top_k_region_residues", []) top_k_regions = predict_results.get("top_k_regions", []) # Calculate summary statistics protein_length = len(antigen_chain.sequence) epitope_count = len(predicted_epitopes) region_count = len(top_k_regions) coverage_rate = (len(top_k_region_residues) / protein_length) * 100 if protein_length > 0 else 0 # Create summary text summary_text = f""" ## Prediction Results for {pdb_id}_{chain_id} ### Protein Information - **PDB ID**: {pdb_id} - **Chain**: {chain_id} - **Length**: {protein_length} residues - **Sequence**:Predict epitopes using the ReCEP model
You can download the HTML to visualize the prediction results and the spheres used.
© 2024 B-cell Epitope Prediction Server | Powered by ReCEP model
Features: PDB ID/File support • 3D visualization • Multiple export formats