Spaces:
Running
Running
File size: 7,390 Bytes
85c36de 942bf87 51a3749 ea9a1bf e199881 51a3749 f0f9b27 51a3749 68ded6f e199881 8bc43cc 942bf87 68ded6f f0f9b27 68ded6f f0f9b27 68ded6f 248a61c e199881 11e1095 dc9275e 68ded6f 3b84715 aa6838a f0f9b27 68ded6f c63f76d 68ded6f bee2eef 68ded6f 8319384 bee2eef 8319384 f3b700a f0f9b27 68ded6f f0f9b27 9748994 68ded6f 0c1f1e9 68ded6f 0c1f1e9 f0f9b27 0c1f1e9 48574c5 0c1f1e9 68ded6f 0c1f1e9 68ded6f 48574c5 68ded6f 0c1f1e9 68ded6f 0c1f1e9 48574c5 68ded6f 357b75d e9f8ebb 48574c5 357b75d 68ded6f e9f8ebb 48574c5 68ded6f e9f8ebb 48574c5 68ded6f 48574c5 68ded6f 48574c5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 |
import gradio as gr
import joblib
import numpy as np
import pandas as pd
from propy import AAComposition, Autocorrelation, CTD, PseudoAAC
from sklearn.preprocessing import MinMaxScaler
import torch
from transformers import BertTokenizer, BertModel
from math import expm1
# Load AMP Classifier
# Both artifacts are read once at import time: `model` is the classifier used
# by full_prediction, `scaler` normalises the descriptor table built in
# extract_features.
model = joblib.load("RF.joblib")
scaler = joblib.load("norm (4).joblib")

# Load ProtBert Globally
# A single checkpoint name feeds both the tokenizer and the encoder so the two
# cannot drift apart.
_PROTBERT_CHECKPOINT = "Rostlab/prot_bert"

# do_lower_case=False keeps the residue letters in upper case.
tokenizer = BertTokenizer.from_pretrained(_PROTBERT_CHECKPOINT, do_lower_case=False)

# Prefer the GPU when torch reports one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Move the encoder to the chosen device and switch it to evaluation mode.
protbert_model = BertModel.from_pretrained(_PROTBERT_CHECKPOINT).to(device).eval()
# Selected Features
# Column names picked out of the normalised descriptor table in
# extract_features. The order of this list is the column order handed to the
# classifier, so it must not be rearranged.

# Composition / Transition / Distribution descriptors.
_CTD_FEATURES = [
    # composition
    "_SolventAccessibilityC3",
    "_SecondaryStrC1",
    "_SecondaryStrC3",
    "_ChargeC1",
    "_PolarityC1",
    "_NormalizedVDWVC1",
    "_HydrophobicityC3",
    # transition
    "_SecondaryStrT23",
    # distribution
    "_PolarizabilityD1001",
    "_PolarizabilityD2001",
    "_PolarizabilityD3001",
    "_SolventAccessibilityD1001",
    "_SolventAccessibilityD2001",
    "_SolventAccessibilityD3001",
    "_SecondaryStrD1001",
    "_SecondaryStrD1075",
    "_SecondaryStrD2001",
    "_SecondaryStrD3001",
    "_ChargeD1001",
    "_ChargeD1025",
    "_ChargeD2001",
    "_ChargeD3075",
    "_ChargeD3100",
    "_PolarityD1001",
    "_PolarityD1050",
    "_PolarityD2001",
    "_PolarityD3001",
    "_NormalizedVDWVD1001",
    "_NormalizedVDWVD2001",
    "_NormalizedVDWVD2025",
    "_NormalizedVDWVD2050",
    "_NormalizedVDWVD3001",
    "_HydrophobicityD1001",
    "_HydrophobicityD2001",
    "_HydrophobicityD3001",
    "_HydrophobicityD3025",
]

# Single-residue composition columns.
_AMINO_ACID_FEATURES = "A R D C E Q H I M P Y V".split()

# Dipeptide composition columns.
_DIPEPTIDE_FEATURES = (
    "AR AV RC RL RV CR CC CL CK EE EI EL "
    "HC IA IL IV LA LC LE LI LT LV KC MA "
    "MS SC TC TV YC VC VE VL VK VV"
).split()

# Moreau-Broto and Moran autocorrelation columns.
_MOREAU_MORAN_FEATURES = [
    "MoreauBrotoAuto_FreeEnergy30",
    "MoranAuto_Hydrophobicity2",
    "MoranAuto_Hydrophobicity4",
]

# Geary autocorrelation columns as (property, lags); each pair expands to
# "GearyAuto_<property><lag>". A tuple of pairs keeps the order explicit.
_GEARY_LAGS = (
    ("Hydrophobicity", (20, 24, 26, 27, 28, 29, 30)),
    ("AvFlexibility", (22, 26, 27, 28, 29, 30)),
    ("Polarizability", (22, 24, 25, 27, 28, 29, 30)),
    ("FreeEnergy", (24, 25, 30)),
    ("ResidueASA", (21, 22, 23, 24, 30)),
    ("ResidueVol", (21, 24, 25, 26, 28, 29, 30)),
    ("Steric", (18, 21, 26, 27, 28, 29, 30)),
    ("Mutability", (23, 25, 26, 27, 28, 29, 30)),
)

# Amphiphilic pseudo amino acid composition columns ("APAAC<index>").
_APAAC_INDICES = (1, 4, 5, 6, 8, 9, 12, 13, 15, 18, 19, 24)

selected_features = [
    *_CTD_FEATURES,
    *_AMINO_ACID_FEATURES,
    *_DIPEPTIDE_FEATURES,
    *_MOREAU_MORAN_FEATURES,
    *(f"GearyAuto_{prop}{lag}" for prop, lags in _GEARY_LAGS for lag in lags),
    *(f"APAAC{index}" for index in _APAAC_INDICES),
]
# AMP Feature Extractor
def extract_features(sequence):
    """Build the classifier input row for one peptide sequence.

    The sequence is upper-cased and every character outside the 20 standard
    amino-acid letters is dropped. Descriptors from propy (CTD, composition,
    autocorrelation, amphiphilic pseudo-AAC) are merged into one row, scaled
    with the module-level ``scaler`` and reduced to ``selected_features``.

    Args:
        sequence: Raw sequence text as typed by the user.

    Returns:
        A one-row 2-D array of the selected, scaled features, or the string
        ``"Error: Sequence too short."`` when fewer than 10 valid residues
        remain. Callers distinguish the two by type.
    """
    standard_residues = "ACDEFGHIKLMNPQRSTVWY"
    sequence = "".join(
        residue for residue in sequence.upper() if residue in standard_residues
    )
    if len(sequence) < 10:
        return "Error: Sequence too short."

    composition = AAComposition.CalculateAADipeptideComposition(sequence)
    # Keep only the first 420 composition entries (presumably the 20
    # single-residue plus 400 dipeptide fractions; confirm against propy).
    leading_keys = list(composition)[:420]

    # Merge order fixes the column order that the scaler receives.
    descriptor_groups = (
        CTD.CalculateCTD(sequence),
        {key: composition[key] for key in leading_keys},
        Autocorrelation.CalculateAutoTotal(sequence),
        PseudoAAC.GetAPseudoAAC(sequence, lamda=9),
    )
    merged = {}
    for group in descriptor_groups:
        merged.update(group)

    raw_frame = pd.DataFrame([merged])
    scaled_frame = pd.DataFrame(
        scaler.transform(raw_frame.values),
        columns=raw_frame.columns,
    )
    # Any NaN left after scaling is replaced by 0 before prediction.
    return scaled_frame[selected_features].fillna(0).values
# MIC Predictor
# Per-organism artifact files: feature scaler, regressor and optional PCA.
_MIC_MODEL_CONFIG = {
    "E.coli": {"model": "coli_xgboost_model.pkl", "scaler": "coli_scaler.pkl", "pca": None},
    "S.aureus": {"model": "aur_xgboost_model.pkl", "scaler": "aur_scaler.pkl", "pca": None},
    "P.aeruginosa": {"model": "arg_xgboost_model.pkl", "scaler": "arg_scaler.pkl", "pca": None},
    "K.Pneumonia": {"model": "pne_mlp_model.pkl", "scaler": "pne_scaler.pkl", "pca": "pne_pca.pkl"},
}

# Artifacts already read from disk, keyed by file path.
_MIC_ARTIFACT_CACHE = {}


def _load_mic_artifact(path):
    """Return the joblib artifact stored at *path*, reading it only once.

    A failed load is not cached, so a later call retries the file.
    """
    if path not in _MIC_ARTIFACT_CACHE:
        _MIC_ARTIFACT_CACHE[path] = joblib.load(path)
    return _MIC_ARTIFACT_CACHE[path]


def predictmic(sequence):
    """Predict a MIC value per organism from a ProtBert embedding.

    The sequence is upper-cased and stripped of every character outside the
    20 standard amino-acid letters, embedded with the module-level ProtBert
    model, then passed through each organism's scaler, optional PCA and
    regressor. The regressor output is mapped back with ``expm1``.

    Args:
        sequence: Raw sequence text.

    Returns:
        A dict mapping organism name to the MIC rounded to 3 decimals, or to
        an ``"Error: ..."`` string when that organism's pipeline failed. If
        fewer than 10 valid residues remain, a dict with the single key
        ``"Error"`` is returned instead.
    """
    sequence = "".join(aa for aa in sequence.upper() if aa in "ACDEFGHIKLMNPQRSTVWY")
    if len(sequence) < 10:
        return {"Error": "Sequence too short or invalid. Must contain at least 10 valid amino acids."}

    # The tokenizer is fed one residue per whitespace-separated token.
    seq_spaced = " ".join(sequence)
    # truncation=True silently cuts inputs that exceed max_length tokens.
    tokens = tokenizer(
        seq_spaced,
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=512,
    )
    tokens = {k: v.to(device) for k, v in tokens.items()}
    with torch.no_grad():
        outputs = protbert_model(**tokens)

    # NOTE(review): the mean runs over all 512 positions, padding included
    # (no attention mask is applied). Presumably the regressors were trained
    # on embeddings pooled the same way; confirm before changing.
    embedding = outputs.last_hidden_state.mean(dim=1).squeeze().cpu().numpy().reshape(1, -1)

    mic_results = {}
    for bacterium, cfg in _MIC_MODEL_CONFIG.items():
        # Best effort per organism: one broken artifact must not hide the
        # predictions of the others, so the error is reported in place.
        try:
            organism_scaler = _load_mic_artifact(cfg["scaler"])
            transformed = organism_scaler.transform(embedding)
            if cfg["pca"]:
                transformed = _load_mic_artifact(cfg["pca"]).transform(transformed)
            regressor = _load_mic_artifact(cfg["model"])
            mic_log = regressor.predict(transformed)[0]
            mic_results[bacterium] = round(expm1(mic_log), 3)
        except Exception as e:
            mic_results[bacterium] = f"Error: {str(e)}"
    return mic_results
# Combined Output as Single String
def full_prediction(sequence):
    """Classify a sequence as AMP / Non-AMP and report MIC values for AMPs.

    Args:
        sequence: Raw sequence text from the UI.

    Returns:
        A multi-line summary string. When feature extraction returns an error
        message, that message is returned unchanged.
    """
    features = extract_features(sequence)
    if isinstance(features, str):
        # extract_features signals failure with a plain message.
        return features

    predicted_label = model.predict(features)[0]
    class_probabilities = model.predict_proba(features)[0]

    # Label 0 is treated as the AMP class; its probability is read from
    # column 0 and the Non-AMP probability from column 1.
    is_amp = predicted_label == 0
    if is_amp:
        amp_result = "Antimicrobial Peptide (AMP)"
        confidence = round(class_probabilities[0] * 100, 2)
    else:
        amp_result = "Non-AMP"
        confidence = round(class_probabilities[1] * 100, 2)

    parts = [f"Prediction: {amp_result}\nConfidence: {confidence}%\n"]

    if not is_amp:
        parts.append("\nMIC prediction is not available because sequence is Non-AMP.")
        return "".join(parts)

    # MIC regression only runs for sequences classified as AMP.
    mic_values = predictmic(sequence)
    parts.append("\nPredicted MIC Values (µM):\n")
    parts.extend(f"- {organism}: {mic}\n" for organism, mic in mic_values.items())
    return "".join(parts)
# Gradio Interface (Single Label Output)
# One text box in, one text box out; full_prediction produces the whole
# summary as a single string.
_sequence_input = gr.Textbox(label="Enter Protein Sequence")
_summary_output = gr.Textbox(label="AMP & MIC Prediction Summary")

iface = gr.Interface(
    fn=full_prediction,
    inputs=_sequence_input,
    outputs=_summary_output,
    title="AMP & MIC Predictor",
    description="Enter an amino acid sequence (≥10 valid letters) to predict AMP class and MIC values.",
)

# The app starts as soon as this module is executed.
# NOTE(review): share=True asks Gradio for a public link; confirm that is
# wanted for the deployment target.
iface.launch(share=True)
|