import gradio as gr
from model_loader import load_model

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

import re
import numpy as np
import os
import pandas as pd
import copy

import transformers, datasets
from transformers import AutoTokenizer
from transformers import DataCollatorForTokenClassification
from datasets import Dataset

from scipy.special import expit

import requests

# Biopython imports
from Bio.PDB import PDBParser, Select
from Bio.PDB.DSSP import DSSP

from gradio_molecule3d import Molecule3D

# Configuration
checkpoint = 'ThorbenF/prot_t5_xl_uniref50'
max_length = 1500

# Load model and tokenizer, move the model to GPU if available
model, tokenizer = load_model(checkpoint, max_length)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
model.eval()
def create_dataset(tokenizer, seqs, labels, checkpoint):
    tokenized = tokenizer(seqs, max_length=max_length, padding=False, truncation=True)
    dataset = Dataset.from_dict(tokenized)

    # Truncate labels to match the tokenized length: ESM and ProstT5 add two
    # special tokens (BOS/CLS and EOS), while prot_t5 adds only an EOS token
    if ("esm" in checkpoint) or ("ProstT5" in checkpoint):
        labels = [l[:max_length - 2] for l in labels]
    else:
        labels = [l[:max_length - 1] for l in labels]

    dataset = dataset.add_column("labels", labels)
    return dataset
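
# Illustrative usage (hypothetical values; prot_t5-style input is space-separated):
# ds = create_dataset(tokenizer, ["M K T"], [np.zeros(3)], checkpoint)
# ds[0] then holds 'input_ids', 'attention_mask', and 'labels' for one sequence.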
def convert_predictions(input_logits):
    all_probs = []
    for logits in input_logits:
        logits = logits.reshape(-1, 2)
        # For two classes, the sigmoid of the logit difference equals the
        # softmax probability of class 1 (here, the binding-site class)
        probabilities_class1 = expit(logits[:, 1] - logits[:, 0])
        all_probs.append(probabilities_class1)
    return np.concatenate(all_probs)
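
# Sanity check (illustrative): a residue with logits [-1.0, 2.0] maps to
# expit(2.0 - (-1.0)) = expit(3.0) ≈ 0.95, i.e. a high binding-site probability.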
def normalize_scores(scores):
    min_score = np.min(scores)
    max_score = np.max(scores)
    return (scores - min_score) / (max_score - min_score) if max_score > min_score else scores
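
# Example: normalize_scores(np.array([0.2, 0.5, 0.8])) -> array([0. , 0.5, 1. ])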
def predict_protein_sequence(test_one_letter_sequence):
    # Sanitize input: map rare/ambiguous residues to X
    test_one_letter_sequence = test_one_letter_sequence.replace("O", "X") \
        .replace("B", "X").replace("U", "X") \
        .replace("Z", "X").replace("J", "X")

    # One dummy label per residue, computed before any model-specific
    # reformatting so the label count matches the residue count
    dummy_labels = [np.zeros(len(test_one_letter_sequence))]

    # Prepare sequence for different model types
    if ("prot_t5" in checkpoint) or ("ProstT5" in checkpoint):
        test_one_letter_sequence = " ".join(test_one_letter_sequence)
    if "ProstT5" in checkpoint:
        test_one_letter_sequence = "<AA2fold> " + test_one_letter_sequence

    # Create dataset
    test_dataset = create_dataset(tokenizer,
                                  [test_one_letter_sequence],
                                  dummy_labels,
                                  checkpoint)

    # DataCollatorForTokenClassification pads inputs and labels together and
    # works for every checkpoint family supported here
    data_collator = DataCollatorForTokenClassification(tokenizer)

    # Create data loader
    test_loader = DataLoader(test_dataset, batch_size=1, collate_fn=data_collator)
    # Predict (batch_size=1, so the loop runs once per call)
    for batch in test_loader:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        with torch.no_grad():
            outputs = model(input_ids, attention_mask=attention_mask)
            logits = outputs.logits.detach().cpu().numpy()

    # Drop the trailing special token (EOS for prot_t5) so logits align with residues
    logits = logits[:, :-1]
    logits = convert_predictions(logits)

    # Normalize and format results
    normalized_scores = normalize_scores(logits)
    test_one_letter_sequence = test_one_letter_sequence.replace(" ", "")

    return test_one_letter_sequence, normalized_scores
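
# Illustrative usage with a hypothetical sequence:
# seq, scores = predict_protein_sequence("MKTAYIAKQRQISFVK")
# scores is a numpy array with one normalized score per residue in seq.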
def fetch_pdb(pdb_id):
    try:
        # Create a directory to store PDB files if it doesn't exist
        os.makedirs('pdb_files', exist_ok=True)

        # Fetch the PDB structure from RCSB
        pdb_url = f'https://files.rcsb.org/download/{pdb_id}.pdb'
        pdb_path = f'pdb_files/{pdb_id}.pdb'

        # Download the file
        response = requests.get(pdb_url)
        if response.status_code == 200:
            with open(pdb_path, 'wb') as f:
                f.write(response.content)
            return pdb_path
        else:
            return None
    except Exception as e:
        print(f"Error fetching PDB: {e}")
        return None
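
# Example: fetch_pdb("2IWI") returns 'pdb_files/2IWI.pdb' on success, None otherwise.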
def extract_protein_sequence(pdb_path):
    """
    Extract the longest protein sequence from a PDB file.
    """
    parser = PDBParser(QUIET=1)
    structure = parser.get_structure('protein', pdb_path)

    # 3-letter to 1-letter amino acid code mapping
    aa_dict = {
        'ALA': 'A', 'CYS': 'C', 'ASP': 'D', 'GLU': 'E', 'PHE': 'F',
        'GLY': 'G', 'HIS': 'H', 'ILE': 'I', 'LYS': 'K', 'LEU': 'L',
        'MET': 'M', 'ASN': 'N', 'PRO': 'P', 'GLN': 'Q', 'ARG': 'R',
        'SER': 'S', 'THR': 'T', 'VAL': 'V', 'TRP': 'W', 'TYR': 'Y'
    }

    class ProteinSelect(Select):
        def accept_residue(self, residue):
            # Only accept the 20 standard amino acids (3-letter residue names)
            return residue.get_resname() in aa_dict

    selector = ProteinSelect()

    # Find the longest protein chain
    longest_sequence = ""
    longest_chain = None
    for struct_model in structure:  # avoid shadowing the global `model`
        for chain in struct_model:
            # Collect 3-letter residue names, then convert to 1-letter codes
            residue_names = [residue.get_resname() for residue in chain
                             if selector.accept_residue(residue)]
            one_letter_sequence = ''.join(aa_dict.get(res, 'X') for res in residue_names)

            # Track the longest sequence within the model's length limits
            if len(one_letter_sequence) > len(longest_sequence) and \
               10 < len(one_letter_sequence) < 1500:
                longest_sequence = one_letter_sequence
                longest_chain = chain

    return longest_sequence, longest_chain
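
# Example: sequence, chain = extract_protein_sequence('pdb_files/2IWI.pdb')
# sequence is a 1-letter string such as 'MKT...'; chain is the matching Bio.PDB chain.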
def process_pdb(pdb_id):
    # Fetch PDB file
    pdb_path = fetch_pdb(pdb_id)
    if not pdb_path:
        return "Failed to fetch PDB file", None

    # Extract protein sequence and chain
    protein_sequence, chain = extract_protein_sequence(pdb_path)
    if not protein_sequence:
        return "No suitable protein sequence found", None

    # Predict binding sites
    sequence, normalized_scores = predict_protein_sequence(protein_sequence)

    # Per-residue representations (blue -> red with increasing score); currently
    # prepared but not wired into the Molecule3D output below
    reps = [
        {
            "model": 0,
            "chain": chain.id,
            "resname": res,
            "resnum": i + 1,
            "style": "cartoon",
            "color": f'rgb({int(score * 255)}, 0, {int(255 - score * 255)})',
            "residue_range": f"{i + 1}-{i + 1}",
            "around": 0,
            "byres": True,
            "visible": True
        }
        for i, (res, score) in enumerate(zip(sequence, normalized_scores))
    ]

    # Prepare result string
    result_str = "\n".join([f"{aa}: {score:.2f}" for aa, score in zip(sequence, normalized_scores)])

    # Molecule3D accepts the structure file path directly
    return result_str, pdb_path
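
# Example: text, structure_path = process_pdb("2IWI")
# text lists one 'residue: score' line per position; structure_path feeds Molecule3D.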
# Create Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Protein Binding Site Prediction")

    with gr.Row():
        with gr.Column():
            # PDB ID input with default suggestion
            pdb_input = gr.Textbox(
                value="2IWI",
                label="PDB ID",
                placeholder="Enter PDB ID here..."
            )
            # Predict button
            predict_btn = gr.Button("Predict Binding Sites")

        with gr.Column():
            # Binding site predictions output
            predictions_output = gr.Textbox(
                label="Binding Site Predictions"
            )

            # 3D Molecule visualization
            molecule_output = Molecule3D(
                label="Protein Structure",
                reps=[]  # Start with empty representations
            )

    # Prediction logic
    predict_btn.click(
        process_pdb,
        inputs=[pdb_input],
        outputs=[predictions_output, molecule_output]
    )

    # Add some example inputs
    gr.Markdown("## Examples")
    gr.Examples(
        examples=[
            ["2IWI"],
            ["7RPZ"],
            ["3TJN"]
        ],
        inputs=[pdb_input],
        outputs=[predictions_output, molecule_output]
    )

demo.launch(share=True)