File size: 13,953 Bytes
85c36de
942bf87
51a3749
ea9a1bf
e199881
51a3749
f0f9b27
 
 
51a3749
f0f9b27
c83df39
f0f9b27
c83df39
e199881
8bc43cc
942bf87
f0f9b27
c83df39
f0f9b27
 
 
c83df39
f0f9b27
c83df39
f0f9b27
 
c83df39
f0f9b27
248a61c
e199881
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11e1095
dc9275e
f0f9b27
c83df39
f0f9b27
3b84715
c83df39
 
 
 
aa6838a
c83df39
f0f9b27
c83df39
f0f9b27
c83df39
4d0770a
c83df39
c63f76d
c83df39
 
 
 
bee2eef
c83df39
 
c63f76d
c83df39
8319384
bee2eef
 
8319384
c63f76d
c83df39
f3b700a
c83df39
 
 
 
 
 
 
 
 
 
f0f9b27
c83df39
 
 
4d0770a
f0f9b27
9748994
f0f9b27
c83df39
f0f9b27
c83df39
 
 
 
 
0c1f1e9
 
c83df39
0c1f1e9
c83df39
0c1f1e9
 
 
 
c83df39
0c1f1e9
f0f9b27
c83df39
f0f9b27
0c1f1e9
c83df39
0c1f1e9
c83df39
 
 
 
 
0c1f1e9
c83df39
 
 
 
 
0c1f1e9
c83df39
 
 
 
 
0c1f1e9
c83df39
 
 
 
 
0c1f1e9
 
 
 
c83df39
 
 
 
 
 
0c1f1e9
c83df39
 
 
 
 
 
 
 
0c1f1e9
c83df39
 
 
 
 
 
 
 
 
0c1f1e9
c83df39
0c1f1e9
 
 
f0f9b27
c83df39
f0f9b27
c83df39
 
 
 
 
 
357b75d
c83df39
 
 
357b75d
 
c83df39
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
357b75d
f0f9b27
c83df39
f0f9b27
c83df39
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
from functools import lru_cache
from math import expm1

import gradio as gr
import joblib
import numpy as np
import pandas as pd
import torch
from propy import AAComposition, Autocorrelation, CTD, PseudoAAC
from sklearn.preprocessing import MinMaxScaler
from transformers import BertTokenizer, BertModel

# =====================
# Load AMP Classifier Model (Random Forest)
# =====================
# Ensure 'RF.joblib' and 'norm (4).joblib' are in the same directory or provide full paths.
# Relative paths are resolved against the current working directory.
model = joblib.load("RF.joblib")
scaler = joblib.load("norm (4).joblib")

# =====================
# Load ProtBert Model Globally for MIC Prediction
# =====================
# Pick the inference device first: a GPU when one is visible, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Residues are passed upper-case, so lower-casing is disabled on the tokenizer.
tokenizer = BertTokenizer.from_pretrained("Rostlab/prot_bert", do_lower_case=False)

# Load once at import time, move to the chosen device and switch to eval mode
# (module.to() and module.eval() both return the module itself).
protbert_model = BertModel.from_pretrained("Rostlab/prot_bert")
protbert_model = protbert_model.to(device)
protbert_model = protbert_model.eval()

# =====================
# Feature List (ProPy Descriptors) used by AMP Classifier
# =====================
# Ordered names of the scaled descriptor columns fed to the Random Forest.
# extract_features() selects them with `DataFrame(...)[selected_features]`, so
# the column order passed to `model` is exactly this list's order: do not
# reorder, add or remove entries without retraining. Each name must also be one
# of the scaler's training columns (`scaler.feature_names_in_`), otherwise the
# selection raises KeyError.
selected_features = [
    # "_"-prefixed entries: presumably CTD descriptors from CTD.CalculateCTD
    # (C*/T*/D* suffixes likely composition / transition / distribution) — confirm.
    "_SolventAccessibilityC3", "_SecondaryStrC1", "_SecondaryStrC3", "_ChargeC1", "_PolarityC1",
    "_NormalizedVDWVC1", "_HydrophobicityC3", "_SecondaryStrT23", "_PolarizabilityD1001",
    "_PolarizabilityD2001", "_PolarizabilityD3001", "_SolventAccessibilityD1001",
    "_SolventAccessibilityD2001", "_SolventAccessibilityD3001", "_SecondaryStrD1001",
    "_SecondaryStrD1075", "_SecondaryStrD2001", "_SecondaryStrD3001", "_ChargeD1001",
    "_ChargeD1025", "_ChargeD2001", "_ChargeD3075", "_ChargeD3100", "_PolarityD1001",
    "_PolarityD1050", "_PolarityD2001", "_PolarityD3001", "_NormalizedVDWVD1001",
    "_NormalizedVDWVD2001", "_NormalizedVDWVD2025", "_NormalizedVDWVD2050", "_NormalizedVDWVD3001",
    "_HydrophobicityD1001", "_HydrophobicityD2001", "_HydrophobicityD3001", "_HydrophobicityD3025",
    # Single residues, then dipeptides: presumably keys produced by
    # AAComposition.CalculateAADipeptideComposition.
    "A", "R", "D", "C", "E", "Q", "H", "I", "M", "P", "Y", "V",
    "AR", "AV", "RC", "RL", "RV", "CR", "CC", "CL", "CK", "EE", "EI", "EL",
    "HC", "IA", "IL", "IV", "LA", "LC", "LE", "LI", "LT", "LV", "KC", "MA",
    "MS", "SC", "TC", "TV", "YC", "VC", "VE", "VL", "VK", "VV",
    # Moreau-Broto / Moran / Geary autocorrelation descriptors (presumably from
    # Autocorrelation.CalculateAutoTotal), followed at the end by "APAAC*"
    # entries, presumably from PseudoAAC.GetAPseudoAAC(lamda=9).
    "MoreauBrotoAuto_FreeEnergy30", "MoranAuto_Hydrophobicity2", "MoranAuto_Hydrophobicity4",
    "GearyAuto_Hydrophobicity20", "GearyAuto_Hydrophobicity24", "GearyAuto_Hydrophobicity26",
    "GearyAuto_Hydrophobicity27", "GearyAuto_Hydrophobicity28", "GearyAuto_Hydrophobicity29",
    "GearyAuto_Hydrophobicity30", "GearyAuto_AvFlexibility22", "GearyAuto_AvFlexibility26",
    "GearyAuto_AvFlexibility27", "GearyAuto_AvFlexibility28", "GearyAuto_AvFlexibility29",
    "GearyAuto_AvFlexibility30", "GearyAuto_Polarizability22", "GearyAuto_Polarizability24",
    "GearyAuto_Polarizability25", "GearyAuto_Polarizability27", "GearyAuto_Polarizability28",
    "GearyAuto_Polarizability29", "GearyAuto_Polarizability30", "GearyAuto_FreeEnergy24",
    "GearyAuto_FreeEnergy25", "GearyAuto_FreeEnergy30", "GearyAuto_ResidueASA21",
    "GearyAuto_ResidueASA22", "GearyAuto_ResidueASA23", "GearyAuto_ResidueASA24",
    "GearyAuto_ResidueASA30", "GearyAuto_ResidueVol21", "GearyAuto_ResidueVol24",
    "GearyAuto_ResidueVol25", "GearyAuto_ResidueVol26", "GearyAuto_ResidueVol28",
    "GearyAuto_ResidueVol29", "GearyAuto_ResidueVol30", "GearyAuto_Steric18",
    "GearyAuto_Steric21", "GearyAuto_Steric26", "GearyAuto_Steric27", "GearyAuto_Steric28",
    "GearyAuto_Steric29", "GearyAuto_Steric30", "GearyAuto_Mutability23", "GearyAuto_Mutability25",
    "GearyAuto_Mutability26", "GearyAuto_Mutability27", "GearyAuto_Mutability28",
    "GearyAuto_Mutability29", "GearyAuto_Mutability30", "APAAC1", "APAAC4", "APAAC5",
    "APAAC6", "APAAC8", "APAAC9", "APAAC12", "APAAC13", "APAAC15", "APAAC18", "APAAC19",
    "APAAC24"
]

# =====================
# AMP Feature Extractor Function
# =====================
def extract_features(sequence):
    """
    Build the scaled ProPy feature vector expected by the AMP Random Forest.

    The sequence is upper-cased and reduced to the 20 standard amino acids.
    ProPy dipeptide/residue composition, CTD, autocorrelation and amphiphilic
    pseudo-amino-acid descriptors are then computed, aligned to the columns
    the pre-trained ``scaler`` was fitted on, scaled, and restricted to
    ``selected_features`` (in that order).

    Args:
        sequence: Amino-acid sequence string. Characters outside the 20
            standard residues are dropped silently; ``None`` is treated as "".

    Returns:
        A NumPy array of shape (1, len(selected_features)) on success, or an
        error string starting with "Error:" when fewer than 10 valid residues
        remain after cleaning.
    """
    # Keep only the 20 standard residues; anything else is discarded.
    sequence = "".join(aa for aa in (sequence or "").upper() if aa in "ACDEFGHIKLMNPQRSTVWY")

    if len(sequence) < 10:
        return "Error: Sequence too short or invalid. Must contain at least 10 valid amino acids."

    dipeptide_features = AAComposition.CalculateAADipeptideComposition(sequence)
    # NOTE(review): selected_features uses both single-residue keys ("A") and
    # dipeptide keys ("AR") from this call, so it appears to return residue +
    # dipeptide composition together. The [:420] cap therefore most likely keeps
    # every entry. It is preserved to match training; confirm before removing.
    filtered_dipeptide_features = dict(list(dipeptide_features.items())[:420])

    # Merge all descriptor families. On a key collision the later update wins,
    # so this merge order is kept exactly as before.
    all_features_dict = {}
    all_features_dict.update(CTD.CalculateCTD(sequence))
    all_features_dict.update(filtered_dipeptide_features)
    all_features_dict.update(Autocorrelation.CalculateAutoTotal(sequence))  # Moran, Geary, Moreau-Broto
    all_features_dict.update(PseudoAAC.GetAPseudoAAC(sequence, lamda=9))

    # Align to the scaler's training columns, in training order. Extra
    # descriptors are dropped. Any training column ProPy did not produce is
    # filled with 0, instead of raising KeyError as the old selected-only fill did.
    feature_df = pd.DataFrame([all_features_dict]).reindex(
        columns=scaler.feature_names_in_, fill_value=0
    )

    # Passing the DataFrame (not .values) keeps the feature names attached.
    # This avoids sklearn's "X does not have valid feature names" warning.
    normalized_array = scaler.transform(feature_df)

    # Keep only the columns the RF model was trained on, in its expected order.
    selected_df = pd.DataFrame(normalized_array, columns=scaler.feature_names_in_)[selected_features].fillna(0)

    return selected_df.values

# =====================
# MIC Predictor Function (ProtBert-based)
# =====================
# Per-bacterium MIC artifacts, loaded with joblib. Keys match the values of
# the "Select Bacteria" checkbox group; "display_name" is the key used in the
# returned result dict. A "pca_path" of None means the scaled embedding is fed
# to the model directly.
_MIC_BACTERIA_CONFIG = {
    "e_coli": {
        "display_name": "E.coli",
        "model_path": "coli_xgboost_model.pkl",
        "scaler_path": "coli_scaler.pkl",
        "pca_path": None,
    },
    "s_aureus": {
        "display_name": "S.aureus",
        "model_path": "aur_xgboost_model.pkl",
        "scaler_path": "aur_scaler.pkl",
        "pca_path": None,
    },
    "p_aeruginosa": {
        "display_name": "P.aeruginosa",
        "model_path": "arg_xgboost_model.pkl",
        "scaler_path": "arg_scaler.pkl",
        "pca_path": None,
    },
    "k_pneumoniae": {
        "display_name": "K.Pneumoniae",
        "model_path": "pne_mlp_model.pkl",
        "scaler_path": "pne_scaler.pkl",
        "pca_path": "pne_pca.pkl",
    },
}


@lru_cache(maxsize=None)
def _load_mic_artifact(path):
    """Load the joblib artifact at *path* once and reuse it on later calls.

    Paths come only from the fixed _MIC_BACTERIA_CONFIG table, so the
    unbounded cache stays small. A load that raises is not cached; it is
    retried on the next request.
    """
    # NOTE(review): joblib.load unpickles the file, so these paths must only
    # ever point at trusted, locally shipped artifacts.
    return joblib.load(path)


def _protbert_embedding(sequence):
    """Return the mean-pooled ProtBert last hidden state of *sequence* as a (1, hidden) array."""
    # Residues are space-separated before tokenization.
    tokens = tokenizer(" ".join(sequence), return_tensors="pt", padding="max_length", truncation=True, max_length=512)
    tokens = {name: tensor.to(device) for name, tensor in tokens.items()}

    with torch.no_grad():
        outputs = protbert_model(**tokens)
        # NOTE(review): padding="max_length" pads every input to 512 tokens,
        # and this mean runs over all of them, padding included (no attention
        # mask weighting). Kept as-is because the downstream scalers/models were
        # presumably fitted on embeddings computed the same way. Confirm against
        # the training code before switching to masked mean pooling.
        return outputs.last_hidden_state.mean(dim=1).squeeze().cpu().numpy().reshape(1, -1)


def predict_mic_values(sequence, selected_bacteria_keys):
    """
    Predict Minimum Inhibitory Concentration (MIC) values for a peptide.

    Uses a ProtBert embedding of the sequence plus the per-bacterium scaler,
    optional PCA and regression model listed in _MIC_BACTERIA_CONFIG.

    Args:
        sequence: Amino-acid sequence. It is upper-cased and reduced to the 20
            standard residues; ``None`` is treated as "".
        selected_bacteria_keys: Iterable of config keys (e.g. "e_coli").
            ``None`` is treated as an empty selection.

    Returns:
        ``{"Error": ...}`` when fewer than 10 valid residues remain.
        Otherwise a dict built in selection order:
            - each valid key's display name maps to the predicted MIC
              (``expm1`` of the model output, rounded to 3 decimals) or to a
              "Prediction Error: ..." string;
            - each unknown key maps to "Error: Invalid bacterium key".
        An empty selection yields ``{}``.
    """
    sequence = "".join(aa for aa in (sequence or "").upper() if aa in "ACDEFGHIKLMNPQRSTVWY")
    if len(sequence) < 10:
        return {"Error": "Sequence too short or invalid for MIC prediction."}

    keys = [] if selected_bacteria_keys is None else selected_bacteria_keys

    mic_results = {}
    # Computed lazily, so the expensive ProtBert pass is skipped entirely when
    # no valid bacterium is requested.
    embedding = None
    for bacterium_key in keys:
        cfg = _MIC_BACTERIA_CONFIG.get(bacterium_key)
        if not cfg:
            mic_results[bacterium_key] = "Error: Invalid bacterium key"
            continue

        if embedding is None:
            embedding = _protbert_embedding(sequence)

        try:
            mic_scaler = _load_mic_artifact(cfg["scaler_path"])
            final_features = mic_scaler.transform(embedding)

            if cfg["pca_path"]:
                final_features = _load_mic_artifact(cfg["pca_path"]).transform(final_features)

            mic_model = _load_mic_artifact(cfg["model_path"])
            mic_log = mic_model.predict(final_features)[0]

            # The models output log-transformed MIC. expm1 (exp(x) - 1) maps it
            # back to the original scale (µM).
            mic_results[cfg["display_name"]] = round(expm1(mic_log), 3)
        except Exception as e:
            # Best effort per bacterium: report this failure and keep going.
            mic_results[cfg["display_name"]] = f"Prediction Error: {str(e)}"

    return mic_results

# =====================
# Gradio Interface Functions
# =====================

def amp_classifier_predict(sequence):
    """
    Gradio handler for the AMP classification endpoint (``/predict``).

    Args:
        sequence: Raw amino-acid sequence from the UI or API client.

    Returns:
        A JSON-serializable dict with these keys:
            - "label": "Antimicrobial Peptide (AMP)" or "Non-AMP";
            - "confidence": the model's probability for the predicted class,
              as a float;
            - "shap_plot_base64": currently a 1x1 placeholder PNG.
        If feature extraction fails, "label" holds the error message,
        "confidence" and "shap_plot_base64" are None, and an extra "error"
        key repeats the message.
    """
    features = extract_features(sequence)
    if isinstance(features, str):
        # The endpoint has a single gr.JSON output, so the error must also be
        # one JSON-serializable dict. The old (gr.Label(...), None) tuple could
        # not be rendered by it. extract_features' message already starts with
        # "Error:", so it is not prefixed a second time.
        return {
            "label": features,
            "confidence": None,
            "shap_plot_base64": None,
            "error": features,
        }

    prediction = model.predict(features)[0]
    probabilities = model.predict_proba(features)[0]

    # Class 0 is treated as AMP (labelling convention of the trained model).
    amp_label = "Antimicrobial Peptide (AMP)" if prediction == 0 else "Non-AMP"

    # predict_proba columns follow model.classes_. Look up the predicted
    # label's position there instead of using the label itself as an index.
    class_index = list(model.classes_).index(prediction)
    confidence_value = float(probabilities[class_index])

    # Placeholder for SHAP plot generation (not implemented yet): a tiny
    # transparent PNG, Base64-encoded, so the frontend always receives an image.
    shap_plot_base64 = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNkYAAAAAYAAjCB0C8AAAAASUVORK5CYII="

    # The frontend reads label / confidence / shap_plot_base64 from this dict.
    return {
        "label": amp_label,
        "confidence": confidence_value,
        "shap_plot_base64": shap_plot_base64,
    }

def mic_predictor_predict(sequence, selected_bacteria):
    """
    Gradio handler for the MIC prediction endpoint (``/predict_mic``).

    Forwards the sequence and the list of selected bacteria keys to
    predict_mic_values and returns its result dict unchanged.

    No AMP gate is applied here. Callers (the frontend) are expected to request
    MIC predictions only for sequences already classified as AMP.
    """
    return predict_mic_values(sequence, selected_bacteria)

# =====================
# Define Gradio Interface (hidden, for client connection)
# =====================
# This Gradio app is designed to be used as a backend service by your custom HTML frontend.
# The inputs and outputs here correspond to what the frontend's `gradio.client` expects.

with gr.Blocks() as demo:
    gr.Markdown("# BCBU-ZC AMP/MIC Backend Service")
    gr.Markdown("This Gradio application serves as the backend for the AMP classification and MIC prediction. It provides endpoints for sequence analysis and MIC prediction.")

    with gr.Tab("AMP Classification"):
        gr.Markdown("### AMP Classification Endpoint (`/predict`)")
        amp_input_sequence = gr.Textbox(label="Amino Acid Sequence")
        amp_output_json = gr.JSON(label="Classification Result (Label, Confidence, SHAP Plot Base64)")
        amp_predict_button = gr.Button("Predict AMP")
        amp_predict_button.click(
            fn=amp_classifier_predict,
            inputs=[amp_input_sequence],
            outputs=[amp_output_json],
            api_name="predict"  # API endpoint name used by `gradio.client`
        )

    with gr.Tab("MIC Prediction"):
        gr.Markdown("### MIC Prediction Endpoint (`/predict_mic`)")
        mic_input_sequence = gr.Textbox(label="Amino Acid Sequence")
        mic_selected_bacteria = gr.CheckboxGroup(
            label="Select Bacteria",
            choices=["e_coli", "p_aeruginosa", "s_aureus", "k_pneumoniae"],
            value=["e_coli", "p_aeruginosa", "s_aureus", "k_pneumoniae"]  # Default for testing
        )
        mic_output_json = gr.JSON(label="Predicted MIC Values (µM)")
        mic_predict_button = gr.Button("Predict MIC")
        mic_predict_button.click(
            fn=mic_predictor_predict,
            inputs=[mic_input_sequence, mic_selected_bacteria],
            outputs=[mic_output_json],
            api_name="predict_mic"  # Separate API endpoint name
        )

# Launch the Gradio app only when this file is run as a script
# (e.g. `python app.py`), so importing the module does not start a server.
# `share=True` creates a public, temporary URL for external access (useful for
# testing the frontend).
# Note: `launch(allowed_paths=...)` lists local files/directories Gradio may
# serve; it does not restrict which origins may call the API.
if __name__ == "__main__":
    demo.launch(share=True)