"""Gradio web application for B-cell epitope prediction with the ReCEP model.

Workflow: accept a PDB ID or an uploaded PDB file, run epitope prediction via
the project's ``AntigenChain`` API, and return a textual summary, residue
lists, an interactive 3D visualization (HTML), and JSON/CSV downloads.
"""

import gradio as gr
import os
import json
import tempfile
import traceback
import numpy as np
import pandas as pd
from pathlib import Path
from typing import Optional, Tuple, Dict, Any
import torch
import time
import io
import base64
import zipfile
from datetime import datetime


# Dynamic installation of PyTorch Geometric dependencies
def install_torch_geometric_deps():
    """Install PyTorch Geometric dependencies at runtime to avoid compilation issues during Hugging Face Spaces build"""
    import subprocess
    import sys

    # Check if torch-scatter is already installed
    try:
        import torch_scatter
        print("✅ torch-scatter already installed")
        return True
    except ImportError:
        print("🔄 Installing torch-scatter and related packages...")

    # Get PyTorch version and CUDA info
    torch_version = torch.__version__
    torch_version_str = '+'.join(torch_version.split('+')[:1])  # Remove CUDA info

    # Use PyTorch Geometric official recommended installation method
    try:
        # For CPU version, use official CPU wheel
        pip_cmd = [
            sys.executable, "-m", "pip", "install",
            "torch-scatter", "torch-sparse", "torch-cluster", "torch-spline-conv",
            "-f", f"https://data.pyg.org/whl/torch-{torch_version_str}+cpu.html",
            "--no-cache-dir",
        ]
        print(f"Running: {' '.join(pip_cmd)}")
        result = subprocess.run(pip_cmd, capture_output=True, text=True, timeout=300)

        if result.returncode == 0:
            print("✅ Successfully installed torch-scatter and related packages")
            return True
        else:
            print(f"❌ Failed to install packages: {result.stderr}")
            # Try simplified installation method (torch-scatter alone, default index)
            try:
                simple_cmd = [sys.executable, "-m", "pip", "install",
                              "torch-scatter", "--no-cache-dir"]
                result = subprocess.run(simple_cmd, capture_output=True, text=True, timeout=300)
                if result.returncode == 0:
                    print("✅ Successfully installed torch-scatter with simple method")
                    return True
                else:
                    print(f"❌ Simple install also failed: {result.stderr}")
                    return False
            except Exception as e:
                print(f"❌ Exception during simple install: {e}")
                return False
    except subprocess.TimeoutExpired:
        print("❌ Installation timeout - packages may not be available")
        return False
    except Exception as e:
        print(f"❌ Exception during installation: {e}")
        return False


# Try to install PyTorch Geometric dependencies at import time (best effort).
deps_installed = install_torch_geometric_deps()
if not deps_installed:
    print("⚠️ Warning: PyTorch Geometric dependencies not installed. Some features may not work.")
    print("The application will try to continue with limited functionality.")

# Set up paths and imports for different deployment environments
import sys

BASE_DIR = Path(__file__).parent


def setup_imports():
    """Smart import setup for different deployment environments"""
    global AntigenChain, PROJECT_BASE_DIR

    # Method 1: Try importing from src directory (local development)
    if (BASE_DIR / "src").exists():
        sys.path.insert(0, str(BASE_DIR))
        try:
            from src.bce.antigen.antigen import AntigenChain
            from src.bce.utils.constants import BASE_DIR as PROJECT_BASE_DIR
            print("✅ Successfully imported from src/ directory")
            return True
        except ImportError as e:
            print(f"❌ Failed to import from src/: {e}")

    # Method 2: Try adding src to path and direct import (Hugging Face Spaces)
    src_path = BASE_DIR / "src"
    if src_path.exists():
        sys.path.insert(0, str(src_path))
        try:
            from bce.antigen.antigen import AntigenChain
            from bce.utils.constants import BASE_DIR as PROJECT_BASE_DIR
            print("✅ Successfully imported from src/ added to path")
            return True
        except ImportError as e:
            print(f"❌ Failed to import with src/ in path: {e}")

    # Method 3: Try direct import (if package is installed)
    try:
        from bce.antigen.antigen import AntigenChain
        from bce.utils.constants import BASE_DIR as PROJECT_BASE_DIR
        print("✅ Successfully imported from installed package")
        return True
    except ImportError as e:
        print(f"❌ Failed to import from installed package: {e}")

    # If all methods fail, use default settings
    print("⚠️ All import methods failed, using fallback settings")
    PROJECT_BASE_DIR = BASE_DIR
    return False


# Execute import setup; the app cannot run without the BCE modules.
import_success = setup_imports()
if not import_success:
    print("❌ Critical: Could not import BCE modules. Please check the file structure.")
    print("Expected structure:")
    print("- src/bce/antigen/antigen.py")
    print("- src/bce/utils/constants.py")
    print("- src/bce/model/ReCEP.py")
    print("- src/bce/data/utils.py")
    sys.exit(1)

# Configuration
DEFAULT_MODEL_PATH = os.getenv(
    "BCE_MODEL_PATH",
    str(PROJECT_BASE_DIR / "models" / "ReCEP" / "20250626_110438" / "best_mcc_model.bin"),
)
# SECURITY NOTE(review): a literal token is committed as the fallback value.
# Rotate this credential and rely on the ESM_TOKEN environment variable only.
ESM_TOKEN = os.getenv("ESM_TOKEN", "1mzAo8l1uxaU8UfVcGgV7B")

# PDB data directory (uploaded structures and generated HTML are stored here)
PDB_DATA_DIR = PROJECT_BASE_DIR / "data" / "pdb"
PDB_DATA_DIR.mkdir(parents=True, exist_ok=True)


def validate_pdb_id(pdb_id: str) -> bool:
    """Validate PDB ID format"""
    if not pdb_id or len(pdb_id) != 4:
        return False
    return pdb_id.isalnum()


def validate_chain_id(chain_id: str) -> bool:
    """Validate chain ID format"""
    if not chain_id or len(chain_id) != 1:
        return False
    return chain_id.isalnum()


def create_pdb_visualization_html(pdb_data: str, predicted_epitopes: list,
                                  predictions: dict, protein_id: str,
                                  top_k_regions: list = None) -> str:
    """Create HTML with 3Dmol.js visualization compatible with Gradio - enhanced version with more features"""
    # Prepare data for JavaScript
    epitope_residues = predicted_epitopes

    # Process top_k_regions for visualization
    processed_regions = []
    if top_k_regions:
        for i, region in enumerate(top_k_regions):
            if isinstance(region, dict):
                processed_regions.append({
                    'center_idx': region.get('center_idx', 0),
                    'center_residue': region.get('center_residue', region.get('center_idx', 0)),
                    'covered_residues': region.get('covered_residues', region.get('covered_indices', [])),
                    'radius': 19.0,  # Default radius
                    'predicted_value': region.get('graph_pred', 0.0),
                })

    # Create a unique ID for this visualization to avoid conflicts
    import uuid
    viewer_id = f"viewer_{uuid.uuid4().hex[:8]}"

    # NOTE(review): the original embedded HTML/JS payload was not recoverable
    # from the mangled source; this is a minimal standalone 3Dmol.js page that
    # renders the structure, highlights predicted epitopes in red and top-k
    # region residues in orange. Extend as needed.
    html_content = f"""<!DOCTYPE html>
<html>
<head>
    <meta charset="utf-8">
    <title>3D Structure Visualization - {protein_id}</title>
    <script src="https://3dmol.org/build/3Dmol-min.js"></script>
    <style>
        body {{ font-family: Arial, sans-serif; margin: 0; padding: 10px; }}
        #{viewer_id} {{ width: 100%; height: 600px; position: relative; }}
    </style>
</head>
<body>
    <h2>3D Structure Visualization - {protein_id}</h2>
    <div id="{viewer_id}">Loading 3Dmol.js...</div>
    <script>
        var pdbData = {json.dumps(pdb_data)};
        var epitopeResidues = {json.dumps(epitope_residues)};
        var topKRegions = {json.dumps(processed_regions)};
        var viewer = $3Dmol.createViewer("{viewer_id}", {{backgroundColor: "white"}});
        viewer.addModel(pdbData, "pdb");
        viewer.setStyle({{}}, {{cartoon: {{color: "lightgray"}}}});
        // Top-k region residues in orange (drawn first so epitopes win).
        topKRegions.forEach(function(region) {{
            viewer.setStyle({{resi: region.covered_residues}},
                            {{cartoon: {{color: "orange"}}}});
        }});
        // Predicted epitope residues in red.
        viewer.setStyle({{resi: epitopeResidues}}, {{cartoon: {{color: "red"}}}});
        viewer.zoomTo();
        viewer.render();
    </script>
</body>
</html>
"""
    return html_content


def predict_epitopes(pdb_id: str, pdb_file, chain_id: str, radius: float, k: int,
                     encoder: str, device_config: str, use_threshold: bool,
                     threshold: float, auto_cleanup: bool,
                     progress: gr.Progress = None) -> Tuple[str, str, str, str, str, str]:
    """
    Main prediction function that handles the epitope prediction workflow
    """
    try:
        # Input validation
        if not pdb_file and not pdb_id:
            return "Error: Please provide either a PDB ID or upload a PDB file", "", "", "", "", ""
        if pdb_id and not validate_pdb_id(pdb_id):
            return "Error: PDB ID must be exactly 4 characters (letters and numbers)", "", "", "", "", ""
        if not validate_chain_id(chain_id):
            return "Error: Chain ID must be exactly 1 character", "", "", "", "", ""

        # Update progress
        if progress:
            progress(0.1, desc="Initializing prediction...")

        # Process device configuration ("CPU Only" or "GPU <n>")
        device_id = -1 if device_config == "CPU Only" else int(device_config.split(" ")[1])
        use_gpu = device_id >= 0

        # Load protein structure
        if progress:
            progress(0.2, desc="Loading protein structure...")

        antigen_chain = None
        temp_file_path = None
        try:
            if pdb_file:
                # Handle uploaded file
                if progress:
                    progress(0.25, desc="Processing uploaded PDB file...")

                # Debug: print type and attributes of pdb_file
                print(f"🔍 Debug: pdb_file type = {type(pdb_file)}")
                print(f"🔍 Debug: pdb_file attributes = {dir(pdb_file)}")

                # Extract PDB ID from filename if not provided
                if not pdb_id:
                    if hasattr(pdb_file, 'name'):
                        pdb_id = Path(pdb_file.name).stem.split('_')[0][:4]
                    else:
                        pdb_id = "UNKN"  # Default fallback

                # Save uploaded file to data/pdb/ directory with proper naming
                timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
                filename = f"{pdb_id}_{chain_id}_{timestamp}.pdb"
                temp_file_path = PDB_DATA_DIR / filename

                # Properly read and write the uploaded file; Gradio may pass a
                # tempfile wrapper, a file-like object, or a plain path string.
                try:
                    if hasattr(pdb_file, 'name') and os.path.isfile(pdb_file.name):
                        # pdb_file is a file object with .name attribute
                        print(f"📁 Processing file object: {pdb_file.name}")
                        with open(pdb_file.name, "rb") as src:
                            with open(temp_file_path, "wb") as dst:
                                dst.write(src.read())
                    elif hasattr(pdb_file, 'read'):
                        # pdb_file is a file-like object
                        print(f"📄 Processing file-like object")
                        with open(temp_file_path, "wb") as f:
                            f.write(pdb_file.read())
                    else:
                        # pdb_file is a string (file path)
                        print(f"📁 Processing file path: {pdb_file}")
                        with open(str(pdb_file), "rb") as src:
                            with open(temp_file_path, "wb") as dst:
                                dst.write(src.read())
                    print(f"✅ PDB file saved to: {temp_file_path}")
                except Exception as file_error:
                    print(f"❌ Error processing uploaded file: {file_error}")
                    return f"Error processing uploaded file: {str(file_error)}", "", "", "", "", ""

                antigen_chain = AntigenChain.from_pdb(
                    path=str(temp_file_path),
                    chain_id=chain_id,
                    id=pdb_id,
                )
            else:
                # Load from PDB ID
                if progress:
                    progress(0.25, desc=f"Downloading PDB structure {pdb_id}...")
                antigen_chain = AntigenChain.from_pdb(
                    chain_id=chain_id,
                    id=pdb_id,
                )
        except Exception as e:
            return f"Error loading protein structure: {str(e)}", "", "", "", "", ""

        if antigen_chain is None:
            return "Error: Failed to load protein structure", "", "", "", "", ""

        # Run prediction
        if progress:
            progress(0.4, desc="Running epitope prediction...")
        try:
            # Use threshold only if checkbox is checked
            final_threshold = threshold if use_threshold else None
            predict_results = antigen_chain.predict(
                model_path=DEFAULT_MODEL_PATH,
                device_id=device_id,
                radius=radius,
                k=k,
                threshold=final_threshold,
                verbose=True,
                encoder=encoder,
                use_gpu=use_gpu,
                auto_cleanup=auto_cleanup,
            )
        except Exception as e:
            error_msg = f"Error during prediction: {str(e)}"
            print(f"Prediction error: {error_msg}")
            import traceback
            traceback.print_exc()
            return error_msg, "", "", "", "", ""

        if progress:
            progress(0.8, desc="Processing results...")

        # Process results
        if not predict_results:
            return "Error: No prediction results generated", "", "", "", "", ""

        # Extract prediction data
        predicted_epitopes = predict_results.get("predicted_epitopes", [])
        predictions = predict_results.get("predictions", {})
        top_k_centers = predict_results.get("top_k_centers", [])
        top_k_region_residues = predict_results.get("top_k_region_residues", [])
        top_k_regions = predict_results.get("top_k_regions", [])

        # Calculate summary statistics
        protein_length = len(antigen_chain.sequence)
        epitope_count = len(predicted_epitopes)
        region_count = len(top_k_regions)
        coverage_rate = (len(top_k_region_residues) / protein_length) * 100 if protein_length > 0 else 0

        # Create summary text (Markdown rendered by gr.Markdown)
        summary_text = f"""
## Prediction Results for {pdb_id}_{chain_id}

### Protein Information
- **PDB ID**: {pdb_id}
- **Chain**: {chain_id}
- **Length**: {protein_length} residues
- **Sequence**:
```
{antigen_chain.sequence}
```

### Prediction Summary
- **Predicted Epitopes**: {epitope_count}
- **Top-k Regions**: {region_count}
- **Coverage Rate**: {coverage_rate:.1f}%

### Top-k Region Centers
{', '.join(map(str, top_k_centers))}

### Predicted Epitope Residues
{', '.join(map(str, predicted_epitopes))}

### Binding Region Residues (Top-k Union)
{', '.join(map(str, top_k_region_residues))}
"""

        # Create epitope list text with residue names
        epitope_text = f"Predicted Epitope Residues ({len(predicted_epitopes)}):\n"
        epitope_lines = []
        for res in predicted_epitopes:
            # Get residue index from residue number
            if res in antigen_chain.resnum_to_index:
                res_idx = antigen_chain.resnum_to_index[res]
                res_name = antigen_chain.sequence[res_idx]
                epitope_lines.append(f"Residue {res} ({res_name})")
            else:
                epitope_lines.append(f"Residue {res}")
        epitope_text += "\n".join(epitope_lines)

        # Create binding region text with residue names
        binding_text = f"Binding Region Residues ({len(top_k_region_residues)}):\n"
        binding_lines = []
        for res in top_k_region_residues:
            # Get residue index from residue number
            if res in antigen_chain.resnum_to_index:
                res_idx = antigen_chain.resnum_to_index[res]
                res_name = antigen_chain.sequence[res_idx]
                binding_lines.append(f"Residue {res} ({res_name})")
            else:
                binding_lines.append(f"Residue {res}")
        binding_text += "\n".join(binding_lines)

        # Create downloadable files
        if progress:
            progress(0.9, desc="Preparing download files...")

        # JSON file
        json_data = {
            "protein_info": {
                "id": pdb_id,
                "chain_id": chain_id,
                "length": protein_length,
                "sequence": antigen_chain.sequence,
            },
            "prediction": {
                "predicted_epitopes": predicted_epitopes,
                "predictions": predictions,
                "top_k_centers": top_k_centers,
                "top_k_region_residues": top_k_region_residues,
                "top_k_regions": [
                    {
                        "center_idx": region.get('center_idx', 0),
                        "graph_pred": region.get('graph_pred', 0),
                        "covered_indices": region.get('covered_indices', []),
                    }
                    for region in top_k_regions
                ],
                "coverage_rate": coverage_rate,
                "mean_region_value": 0,  # No longer calculated
            },
            "parameters": {
                "radius": radius,
                "k": k,
                "encoder": encoder,
                "device_config": device_config,
                "use_threshold": use_threshold,
                "threshold": final_threshold,
                "auto_cleanup": auto_cleanup,
            },
        }

        # Save JSON file. mkstemp (not the race-prone, deprecated mktemp)
        # creates the file atomically and returns an open descriptor.
        json_fd, json_file_path = tempfile.mkstemp(suffix=".json")
        with os.fdopen(json_fd, "w") as f:
            json.dump(json_data, f, indent=2)

        # CSV file: one row per residue with probability and membership flags.
        csv_data = []
        for i, residue_num in enumerate(antigen_chain.residue_index):
            residue_num = int(residue_num)
            csv_data.append({
                "Residue_Number": residue_num,
                "Residue_Type": antigen_chain.sequence[i],
                "Prediction_Probability": predictions.get(residue_num, 0.0),
                "Is_Predicted_Epitope": 1 if residue_num in predicted_epitopes else 0,
                "Is_In_TopK_Regions": 1 if residue_num in top_k_region_residues else 0,
            })
        csv_df = pd.DataFrame(csv_data)
        csv_fd, csv_file_path = tempfile.mkstemp(suffix=".csv")
        os.close(csv_fd)
        csv_df.to_csv(csv_file_path, index=False)

        # Create 3D visualization
        if progress:
            progress(0.95, desc="Creating 3D visualization...")

        # Generate PDB string for visualization HTML file
        html_file_path = None
        try:
            pdb_str = generate_pdb_string(antigen_chain)
            html_content = create_pdb_visualization_html(
                pdb_str, predicted_epitopes, predictions,
                f"{pdb_id}_{chain_id}", top_k_regions,
            )
            # Save HTML file to data directory for download
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            html_filename = f"{pdb_id}_{chain_id}_visualization_{timestamp}.html"
            html_file_path = PDB_DATA_DIR / html_filename
            with open(html_file_path, "w", encoding='utf-8') as f:
                f.write(html_content)
            print(f"✅ 3D visualization HTML saved to: {html_file_path}")
        except Exception as e:
            html_file_path = None
            print(f"Warning: Could not create 3D visualization: {str(e)}")

        # Clean up temporary files if auto_cleanup is enabled
        if auto_cleanup and temp_file_path and os.path.exists(temp_file_path):
            os.remove(temp_file_path)
            print(f"🧹 Cleaned up temporary file: {temp_file_path}")
        elif temp_file_path and os.path.exists(temp_file_path):
            print(f"📁 PDB file retained at: {temp_file_path}")

        if progress:
            progress(1.0, desc="Prediction completed!")

        # Return all results including HTML file path for download
        return (
            summary_text,
            epitope_text,
            binding_text,
            str(html_file_path) if html_file_path else None,  # HTML file moved to 4th position
            json_file_path,
            csv_file_path,
        )

    except Exception as e:
        import traceback
        error_msg = f"Error: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
        return error_msg, "", "", "", "", ""


def generate_pdb_string(antigen_chain) -> str:
    """Generate PDB string for 3D visualization"""
    from esm.utils import residue_constants as RC

    pdb_str = "MODEL     1\n"
    atom_num = 1
    for res_idx in range(len(antigen_chain.sequence)):
        one_letter = antigen_chain.sequence[res_idx]
        resname = antigen_chain.convert_letter_1to3(one_letter)
        resnum = antigen_chain.residue_index[res_idx]
        mask = antigen_chain.atom37_mask[res_idx]
        coords = antigen_chain.atom37_positions[res_idx][mask]
        atoms = [name for name, exists in zip(RC.atom_types, mask) if exists]
        for atom_name, coord in zip(atoms, coords):
            x, y, z = coord
            pdb_str += (f"ATOM  {atom_num:5d}  {atom_name:<3s} {resname:>3s} "
                        f"{antigen_chain.chain_id:1s}{resnum:4d}"
                        f"    {x:8.3f}{y:8.3f}{z:8.3f}  1.00  0.00\n")
            atom_num += 1
    pdb_str += "ENDMDL\n"
    return pdb_str


def create_interface():
    """Create the Gradio interface"""
    with gr.Blocks(css="""
        .container { max-width: 1200px; margin: 0 auto; padding: 20px; }
        .header { text-align: center; margin-bottom: 30px; padding: 20px;
                  background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
                  color: white; border-radius: 10px; }
        .header h1 { font-size: 2.5em; margin-bottom: 10px; }
        .form-row { display: flex; gap: 20px; align-items: end; }
        .form-row > * { flex: 1; }
        .section { margin: 20px 0; padding: 15px; background: #f8f9fa;
                   border-radius: 8px; border-left: 4px solid #007bff; }
        .section h2 { color: #333; margin-bottom: 15px; }
        .results-section { margin-top: 30px; padding: 20px; background: #f0f8ff;
                           border-radius: 8px; border: 1px solid #e0e8f0; }
        .download-section { margin-top: 20px; padding: 15px; background: #f9f9f9;
                            border-radius: 8px; }
        .download-section h3 { color: #333; margin-bottom: 10px; }
    """) as interface:
        # Header
        gr.HTML("""
        <div class="header">
            <h1>🧬 B-cell Epitope Prediction Server</h1>
            <p>Predict epitopes using the ReCEP model</p>
        </div>
        """)

        with gr.Row():
            with gr.Column(scale=1):
                gr.HTML("""
                <div class="section">
                    <h2>📋 Input Protein Structure</h2>
                </div>
                """)
                input_method = gr.Radio(
                    choices=["PDB ID", "Upload PDB File"],
                    value="PDB ID",
                    label="Input Method",
                )
                pdb_id = gr.Textbox(
                    label="PDB ID",
                    placeholder="e.g., 5I9Q",
                    max_lines=1,
                    visible=True,
                )
                pdb_file = gr.File(
                    label="Upload PDB File",
                    file_types=[".pdb", ".ent"],
                    visible=False,
                )
                chain_id = gr.Textbox(
                    label="Chain ID",
                    value="A",
                    max_lines=1,
                )

                with gr.Accordion("🔧 Advanced Parameters", open=False):
                    radius = gr.Slider(
                        label="Radius (Å)",
                        minimum=1.0, maximum=50.0, step=0.1, value=19.0,
                    )
                    k = gr.Slider(
                        label="Top-k Regions",
                        minimum=1, maximum=20, step=1, value=7,
                    )
                    encoder = gr.Dropdown(
                        label="Encoder",
                        choices=["esmc", "esm2"],
                        value="esmc",
                    )
                    device_config = gr.Dropdown(
                        label="Device Configuration",
                        choices=["CPU Only", "GPU 0", "GPU 1", "GPU 2", "GPU 3"],
                        value="CPU Only",
                    )
                    use_threshold = gr.Checkbox(
                        label="Use Custom Threshold",
                        value=False,
                    )
                    threshold = gr.Number(
                        label="Threshold Value",
                        value=0.366,
                        visible=False,
                    )
                    auto_cleanup = gr.Checkbox(
                        label="Auto-cleanup Generated Data",
                        value=True,
                    )

                predict_btn = gr.Button("🧮 Predict Epitopes", variant="primary", size="lg")

            with gr.Column(scale=2):
                gr.HTML("""
                <div class="section">
                    <h2>📊 Prediction Results</h2>
                </div>
                """)

                # 3D Visualization download (moved to top)
                gr.HTML("""
                <div class="section">
                    <h3>🧬 3D Visualization</h3>
                    <p>You can download the HTML to visualize the prediction results and the spheres used.</p>
                </div>
                """)
                html_download = gr.File(
                    label="Download Interactive 3D Visualization HTML",
                    visible=True,
                )

                results_text = gr.Markdown(label="Prediction Summary", visible=True)

                with gr.Row():
                    epitope_list = gr.Textbox(
                        label="Predicted Epitope Residues",
                        max_lines=10,
                        interactive=False,
                    )
                    binding_regions = gr.Textbox(
                        label="Binding Region Residues",
                        max_lines=10,
                        interactive=False,
                    )

                gr.HTML("""
                <div class="download-section">
                    <h3>📥 Download Data Results</h3>
                </div>
                """)
                with gr.Row():
                    json_download = gr.File(
                        label="JSON Results",
                        visible=True,
                    )
                    csv_download = gr.File(
                        label="CSV Results",
                        visible=True,
                    )

        def toggle_input_method(method):
            # Show exactly one of the two input widgets.
            return (gr.update(visible=method == "PDB ID"),
                    gr.update(visible=method == "Upload PDB File"))

        def toggle_threshold(use_threshold):
            return gr.update(visible=use_threshold)

        input_method.change(toggle_input_method,
                            inputs=[input_method],
                            outputs=[pdb_id, pdb_file])
        use_threshold.change(toggle_threshold,
                             inputs=[use_threshold],
                             outputs=[threshold])

        predict_btn.click(
            predict_epitopes,
            inputs=[
                pdb_id, pdb_file, chain_id, radius, k, encoder,
                device_config, use_threshold, threshold, auto_cleanup,
            ],
            outputs=[
                results_text, epitope_list, binding_regions,
                html_download, json_download, csv_download,
            ],
            show_progress=True,
        )

        gr.HTML("""
        <div style="text-align: center; margin-top: 30px; color: #666;">
            <p>© 2024 B-cell Epitope Prediction Server | Powered by ReCEP model</p>
            <p>Features: PDB ID/File support • 3D visualization • Multiple export formats</p>
        </div>
        """)

    return interface


if __name__ == "__main__":
    # Create and launch the interface
    try:
        interface = create_interface()
        # Check if running on Hugging Face Spaces
        is_spaces = os.getenv("SPACE_ID") is not None
        interface.launch(
            server_name="0.0.0.0",
            server_port=7860,
            share=is_spaces,  # Use share=True on Spaces, False locally
            show_error=True,
            max_threads=4 if is_spaces else 8,
        )
    except Exception as e:
        print(f"Error launching application: {e}")
        print("Please ensure all dependencies are installed correctly.")
        import traceback
        traceback.print_exc()