AMP-Classifier / app.py
nonzeroexit's picture
Update app.py
f0f9b27 verified
raw
history blame
8.26 kB
import gradio as gr
import joblib
import numpy as np
import pandas as pd
from propy import AAComposition, Autocorrelation, CTD, PseudoAAC
from sklearn.preprocessing import MinMaxScaler
import torch
from transformers import BertTokenizer, BertModel
from math import expm1
# =====================
# Load AMP Classifier
# =====================
# Pre-trained AMP classifier and the scaler whose transform() is applied to the
# raw descriptor row before classification. Both are joblib files addressed by
# relative path, so they are resolved against the current working directory.
# NOTE(review): "RF" presumably stands for random forest - confirm against the
# training code; only predict()/predict_proba() are relied on in this file.
model = joblib.load("RF.joblib")
scaler = joblib.load("norm (4).joblib")
# =====================
# Load ProtBert Globally
# =====================
# Use CUDA when PyTorch reports an available GPU, otherwise stay on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Tokenizer for the ProtBert checkpoint; do_lower_case=False is passed so the
# upper-case residue letters are not lower-cased.
tokenizer = BertTokenizer.from_pretrained("Rostlab/prot_bert", do_lower_case=False)
# Load the encoder once at import time, move it to the selected device and put
# it in eval mode; it is only used for inference under torch.no_grad().
protbert_model = BertModel.from_pretrained("Rostlab/prot_bert").to(device).eval()
# =====================
# Feature List (ProPy)
# =====================
# Column names kept from the full, scaled descriptor table built in
# extract_features(); the resulting column order is the order in which features
# are handed to the AMP classifier, so entries must not be reordered.
# NOTE(review): the names presumably match the keys emitted by the propy
# calculators used in extract_features() - confirm against the training code.
selected_features = [
# Underscore-prefixed descriptor names (presumably from CTD.CalculateCTD).
"_SolventAccessibilityC3", "_SecondaryStrC1", "_SecondaryStrC3", "_ChargeC1", "_PolarityC1",
"_NormalizedVDWVC1", "_HydrophobicityC3", "_SecondaryStrT23", "_PolarizabilityD1001",
"_PolarizabilityD2001", "_PolarizabilityD3001", "_SolventAccessibilityD1001",
"_SolventAccessibilityD2001", "_SolventAccessibilityD3001", "_SecondaryStrD1001",
"_SecondaryStrD1075", "_SecondaryStrD2001", "_SecondaryStrD3001", "_ChargeD1001",
"_ChargeD1025", "_ChargeD2001", "_ChargeD3075", "_ChargeD3100", "_PolarityD1001",
"_PolarityD1050", "_PolarityD2001", "_PolarityD3001", "_NormalizedVDWVD1001",
"_NormalizedVDWVD2001", "_NormalizedVDWVD2025", "_NormalizedVDWVD2050", "_NormalizedVDWVD3001",
"_HydrophobicityD1001", "_HydrophobicityD2001", "_HydrophobicityD3001", "_HydrophobicityD3025",
# Single-letter names (presumably per-residue composition).
"A", "R", "D", "C", "E", "Q", "H", "I", "M", "P", "Y", "V",
# Two-letter names (presumably dipeptide composition).
"AR", "AV", "RC", "RL", "RV", "CR", "CC", "CL", "CK", "EE", "EI", "EL",
"HC", "IA", "IL", "IV", "LA", "LC", "LE", "LI", "LT", "LV", "KC", "MA",
"MS", "SC", "TC", "TV", "YC", "VC", "VE", "VL", "VK", "VV",
# Autocorrelation-style names (MoreauBroto / Moran / Geary), followed by the
# APAAC entries at the end of the list.
"MoreauBrotoAuto_FreeEnergy30", "MoranAuto_Hydrophobicity2", "MoranAuto_Hydrophobicity4",
"GearyAuto_Hydrophobicity20", "GearyAuto_Hydrophobicity24", "GearyAuto_Hydrophobicity26",
"GearyAuto_Hydrophobicity27", "GearyAuto_Hydrophobicity28", "GearyAuto_Hydrophobicity29",
"GearyAuto_Hydrophobicity30", "GearyAuto_AvFlexibility22", "GearyAuto_AvFlexibility26",
"GearyAuto_AvFlexibility27", "GearyAuto_AvFlexibility28", "GearyAuto_AvFlexibility29",
"GearyAuto_AvFlexibility30", "GearyAuto_Polarizability22", "GearyAuto_Polarizability24",
"GearyAuto_Polarizability25", "GearyAuto_Polarizability27", "GearyAuto_Polarizability28",
"GearyAuto_Polarizability29", "GearyAuto_Polarizability30", "GearyAuto_FreeEnergy24",
"GearyAuto_FreeEnergy25", "GearyAuto_FreeEnergy30", "GearyAuto_ResidueASA21",
"GearyAuto_ResidueASA22", "GearyAuto_ResidueASA23", "GearyAuto_ResidueASA24",
"GearyAuto_ResidueASA30", "GearyAuto_ResidueVol21", "GearyAuto_ResidueVol24",
"GearyAuto_ResidueVol25", "GearyAuto_ResidueVol26", "GearyAuto_ResidueVol28",
"GearyAuto_ResidueVol29", "GearyAuto_ResidueVol30", "GearyAuto_Steric18",
"GearyAuto_Steric21", "GearyAuto_Steric26", "GearyAuto_Steric27", "GearyAuto_Steric28",
"GearyAuto_Steric29", "GearyAuto_Steric30", "GearyAuto_Mutability23", "GearyAuto_Mutability25",
"GearyAuto_Mutability26", "GearyAuto_Mutability27", "GearyAuto_Mutability28",
"GearyAuto_Mutability29", "GearyAuto_Mutability30", "APAAC1", "APAAC4", "APAAC5",
"APAAC6", "APAAC8", "APAAC9", "APAAC12", "APAAC13", "APAAC15", "APAAC18", "APAAC19",
"APAAC24"
]
# =====================
# AMP Feature Extractor
# =====================
def extract_features(sequence):
    """Build the scaled, feature-selected descriptor row for one peptide.

    The input is upper-cased and every character that is not one of the 20
    standard amino-acid letters is dropped before any descriptor is computed.

    Args:
        sequence: Raw peptide sequence text. ``None`` is treated like an
            empty string.

    Returns:
        A 2-D array with one row and ``len(selected_features)`` columns on
        success, or the string ``"Error: Sequence too short."`` when fewer
        than 10 valid residues remain after cleaning. Callers distinguish the
        two cases with ``isinstance(result, str)``.
    """
    # A cleared input may arrive as None; treat it as empty so the caller gets
    # the regular "too short" message instead of an AttributeError.
    cleaned = "".join(
        aa for aa in (sequence or "").upper() if aa in "ACDEFGHIKLMNPQRSTVWY"
    )
    if len(cleaned) < 10:
        return "Error: Sequence too short."

    composition = AAComposition.CalculateAADipeptideComposition(cleaned)
    # Keep only the first 420 entries of the composition dict.
    # NOTE(review): presumably 20 single-residue + 400 dipeptide values, with
    # the tripeptide entries that follow discarded - confirm against propy.
    kept_composition = {key: composition[key] for key in list(composition)[:420]}
    ctd_features = CTD.CalculateCTD(cleaned)
    auto_features = Autocorrelation.CalculateAutoTotal(cleaned)
    pseudo_features = PseudoAAC.GetAPseudoAAC(cleaned, lamda=9)

    # Merge order fixes the column order of the row handed to the scaler, which
    # receives a bare array (no column names). Do not reorder these updates.
    all_features_dict = {}
    all_features_dict.update(ctd_features)
    all_features_dict.update(kept_composition)
    all_features_dict.update(auto_features)
    all_features_dict.update(pseudo_features)

    feature_df_all = pd.DataFrame([all_features_dict])
    normalized_array = scaler.transform(feature_df_all.values)
    normalized_df = pd.DataFrame(normalized_array, columns=feature_df_all.columns)
    # Select the classifier's columns; any NaN left after scaling becomes 0.
    selected_df = normalized_df[selected_features].fillna(0)
    return selected_df.values
# =====================
# AMP Classifier
# =====================
def predict(sequence):
    """Classify one peptide and describe the result as a sentence.

    Args:
        sequence: Raw peptide sequence text, forwarded to extract_features().

    Returns:
        The error string from extract_features() when it reports one;
        otherwise a sentence giving the probability (two decimals, in percent)
        of the predicted class. Label 0 is reported as AMP, any other label
        as Non-AMP.
    """
    feature_row = extract_features(sequence)
    if isinstance(feature_row, str):
        # extract_features signals failure by returning its message as a string.
        return feature_row

    label = model.predict(feature_row)[0]
    class_probs = model.predict_proba(feature_row)[0]
    if label != 0:
        return f"{class_probs[1] * 100:.2f}% chance of being Non-AMP"
    return f"{class_probs[0] * 100:.2f}% chance of being an Antimicrobial Peptide (AMP)"
# =====================
# MIC Predictor (ProtBert-based)
# =====================
# Files used per bacterium by predictmic(): a scaler, an optional PCA step
# (None means "skip") and the regression model. Paths are relative to the
# current working directory.
_MIC_BACTERIA_CONFIG = {
    "E.coli": {
        "model": "coli_xgboost_model.pkl",
        "scaler": "coli_scaler.pkl",
        "pca": None,
    },
    "S.aureus": {
        "model": "aur_xgboost_model.pkl",
        "scaler": "aur_scaler.pkl",
        "pca": None,
    },
    "P.aeruginosa": {
        "model": "arg_xgboost_model.pkl",
        "scaler": "arg_scaler.pkl",
        "pca": None,
    },
    "K.Pneumonia": {
        "model": "pne_mlp_model.pkl",
        "scaler": "pne_scaler.pkl",
        "pca": "pne_pca",
    },
}

# Deserialized artifacts keyed by path, so each file is read from disk once
# per process instead of on every prediction request.
_MIC_ARTIFACT_CACHE = {}


def _load_mic_artifact(path):
    """Return the joblib artifact stored at *path*, loading it at most once."""
    try:
        return _MIC_ARTIFACT_CACHE[path]
    except KeyError:
        # A failed load raises before anything is stored, so it is retried on
        # the next call rather than being cached as a failure.
        artifact = _MIC_ARTIFACT_CACHE[path] = joblib.load(path)
        return artifact


def predictmic(sequence):
    """Predict a MIC value per bacterium from a ProtBert embedding.

    The sequence is upper-cased, stripped of every character outside the 20
    standard amino-acid letters, embedded with ProtBert (mean over the token
    axis of the last hidden state) and passed through each bacterium's
    scaler, optional PCA and regression model.

    Args:
        sequence: Raw peptide sequence text. ``None`` is treated like an
            empty string.

    Returns:
        A dict mapping bacterium name to the predicted value rounded to three
        decimals, or to an ``"Error: ..."`` string when that bacterium's
        pipeline raised. When fewer than 10 valid residues remain, a dict with
        the single key ``"Error"`` is returned instead.
    """
    cleaned = "".join(
        aa for aa in (sequence or "").upper() if aa in "ACDEFGHIKLMNPQRSTVWY"
    )
    if len(cleaned) < 10:
        return {"Error": "Sequence too short or invalid. Must contain at least 10 valid amino acids."}

    # The tokenizer is fed one residue per whitespace-separated token.
    spaced = " ".join(cleaned)
    encoded = tokenizer(spaced, return_tensors="pt", padding='max_length', truncation=True, max_length=512)
    encoded = {name: tensor.to(device) for name, tensor in encoded.items()}
    with torch.no_grad():
        outputs = protbert_model(**encoded)
    # NOTE(review): the input is padded to 512 tokens and the mean runs over
    # every position, padding included (the attention mask is not used for
    # pooling). Left unchanged on purpose: the downstream regressors were
    # presumably fitted on embeddings pooled this way - confirm before touching.
    embedding = outputs.last_hidden_state.mean(dim=1).squeeze().cpu().numpy().reshape(1, -1)

    mic_results = {}
    for bacterium, cfg in _MIC_BACTERIA_CONFIG.items():
        # Best effort per bacterium: a missing file or failing model is
        # reported for that entry only and does not abort the others.
        try:
            features = _load_mic_artifact(cfg["scaler"]).transform(embedding)
            if cfg["pca"]:
                features = _load_mic_artifact(cfg["pca"]).transform(features)
            mic_log = _load_mic_artifact(cfg["model"]).predict(features)[0]
            # expm1 inverts a log1p transform; presumably the regressors were
            # trained on log1p(MIC) - confirm against the training code.
            mic_results[bacterium] = round(expm1(mic_log), 3)
        except Exception as e:
            mic_results[bacterium] = f"Error: {str(e)}"
    return mic_results
# =====================
# Combined Prediction Function
# =====================
def full_prediction(sequence):
    """Run AMP classification and MIC prediction for one sequence.

    Args:
        sequence: Raw peptide sequence text as entered by the user.

    Returns:
        A 3-tuple ``(classification, confidence, mic_values)``:
        the class name, the predicted class probability formatted as a
        percentage string, and the dict returned by predictmic(). When
        extract_features() reports an error the tuple is
        ``("Error", "0%", {"Error": <message>})`` so the reason reaches the
        user instead of being dropped.
    """
    features = extract_features(sequence)
    if isinstance(features, str):
        # extract_features returns its error message as a string; forward it in
        # the JSON slot, using the same {"Error": ...} shape predictmic uses.
        return "Error", "0%", {"Error": features}

    prediction = model.predict(features)[0]
    probabilities = model.predict_proba(features)[0]
    # Label 0 is treated as AMP throughout this file; probabilities are indexed
    # with the same 0/1 convention.
    is_amp = prediction == 0
    amp_result = "Antimicrobial Peptide (AMP)" if is_amp else "Non-AMP"
    confidence = round(probabilities[0 if is_amp else 1] * 100, 2)
    mic_values = predictmic(sequence)
    return amp_result, f"{confidence}%", mic_values
# =====================
# Gradio Interface
# =====================
# Single free-text input holding the peptide sequence.
_sequence_input = gr.Textbox(label="Enter Protein Sequence")
# One output component per element of the tuple returned by full_prediction.
_result_outputs = [
    gr.Label(label="AMP Classification"),
    gr.Label(label="Confidence"),
    gr.JSON(label="Predicted MIC (µg/mL) for Each Bacterium"),
]
iface = gr.Interface(
    fn=full_prediction,
    inputs=_sequence_input,
    outputs=_result_outputs,
    title="AMP & MIC Predictor",
    description="Enter an amino acid sequence (≥10 valid letters) to predict AMP class and MIC values.",
)
# Starts the web server as soon as this module is executed.
iface.launch(share=True)