nonzeroexit committed on
Commit
febb4a6
·
verified ·
1 Parent(s): adaeb14

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +258 -88
app.py CHANGED
@@ -8,21 +8,37 @@ import torch
8
  from transformers import BertTokenizer, BertModel
9
  from lime.lime_tabular import LimeTabularExplainer
10
  from math import expm1
 
 
 
 
 
 
 
11
 
12
  # Load AMP Classifier
13
- model = joblib.load("RF.joblib")
14
- scaler = joblib.load("norm (4).joblib")
 
 
 
 
 
15
 
16
  # Load ProtBert
17
- tokenizer = BertTokenizer.from_pretrained("Rostlab/prot_bert", do_lower_case=False)
18
- protbert_model = BertModel.from_pretrained("Rostlab/prot_bert")
19
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
20
- protbert_model = protbert_model.to(device).eval()
 
 
 
 
21
 
22
- # Full list of selected features
23
  selected_features = ["_SolventAccessibilityC3", "_SecondaryStrC1", "_SecondaryStrC3", "_ChargeC1", "_PolarityC1",
24
  "_NormalizedVDWVC1", "_HydrophobicityC3", "_SecondaryStrT23", "_PolarizabilityD1001", "_PolarizabilityD2001",
25
- "_PolarizabilityD3001", "_SolventAccessibilityD1001", "_SolventAccessibilityD2001", "_SolventAccessibilityD3001",
26
  "_SecondaryStrD1001", "_SecondaryStrD1075", "_SecondaryStrD2001", "_SecondaryStrD3001", "_ChargeD1001",
27
  "_ChargeD1025", "_ChargeD2001", "_ChargeD3075", "_ChargeD3100", "_PolarityD1001", "_PolarityD1050",
28
  "_PolarityD2001", "_PolarityD3001", "_NormalizedVDWVD1001", "_NormalizedVDWVD2001", "_NormalizedVDWVD2025",
@@ -48,98 +64,252 @@ selected_features = ["_SolventAccessibilityC3", "_SecondaryStrC1", "_SecondarySt
48
  "APAAC15", "APAAC18", "APAAC19", "APAAC24"]
49
 
50
  # LIME Explainer Setup
51
- sample_data = np.random.rand(100, len(selected_features))
 
 
 
 
 
 
 
52
  explainer = LimeTabularExplainer(
53
  training_data=sample_data,
54
  feature_names=selected_features,
55
- class_names=["AMP", "Non-AMP"],
56
  mode="classification"
57
  )
58
 
59
- def extract_features(sequence):
60
- sequence = ''.join([aa for aa in sequence.upper() if aa in "ACDEFGHIKLMNPQRSTVWY"])
61
- if len(sequence) < 10:
62
- return "Error: Sequence too short."
63
- dipeptide_features = AAComposition.CalculateAADipeptideComposition(sequence)
64
- filtered_dipeptide_features = {k: dipeptide_features[k] for k in list(dipeptide_features.keys())[:420]}
65
- ctd_features = CTD.CalculateCTD(sequence)
66
- auto_features = Autocorrelation.CalculateAutoTotal(sequence)
67
- pseudo_features = PseudoAAC.GetAPseudoAAC(sequence, lamda=9)
68
- all_features_dict = {}
69
- all_features_dict.update(ctd_features)
70
- all_features_dict.update(filtered_dipeptide_features)
71
- all_features_dict.update(auto_features)
72
- all_features_dict.update(pseudo_features)
73
- feature_df_all = pd.DataFrame([all_features_dict])
74
- normalized_array = scaler.transform(feature_df_all.values)
75
- normalized_df = pd.DataFrame(normalized_array, columns=feature_df_all.columns)
76
- if not set(selected_features).issubset(set(normalized_df.columns)):
77
- return "Error: Some selected features are missing from computed features."
78
- selected_df = normalized_df[selected_features].fillna(0)
79
- return selected_df.values
80
-
81
- def predictmic(sequence):
82
- sequence = ''.join([aa for aa in sequence.upper() if aa in "ACDEFGHIKLMNPQRSTVWY"])
83
- if len(sequence) < 10:
84
- return {"Error": "Sequence too short or invalid."}
85
- seq_spaced = ' '.join(list(sequence))
86
- tokens = tokenizer(seq_spaced, return_tensors="pt", padding='max_length', truncation=True, max_length=512)
87
- tokens = {k: v.to(device) for k, v in tokens.items()}
88
- with torch.no_grad():
89
- outputs = protbert_model(**tokens)
90
- embedding = outputs.last_hidden_state.mean(dim=1).squeeze().cpu().numpy().reshape(1, -1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91
  bacteria_config = {
92
- "E.coli": {"model": "coli_xgboost_model.pkl", "scaler": "coli_scaler.pkl", "pca": None},
93
- "S.aureus": {"model": "aur_xgboost_model.pkl", "scaler": "aur_scaler.pkl", "pca": None},
94
- "P.aeruginosa": {"model": "arg_xgboost_model.pkl", "scaler": "arg_scaler.pkl", "pca": None},
95
- "K.Pneumonia": {"model": "pne_mlp_model.pkl", "scaler": "pne_scaler.pkl", "pca": "pne_pca.pkl"}
96
  }
 
97
  mic_results = {}
98
- for bacterium, cfg in bacteria_config.items():
 
 
 
 
 
99
  try:
100
- scaler = joblib.load(cfg["scaler"])
101
- scaled = scaler.transform(embedding)
102
- transformed = joblib.load(cfg["pca"]).transform(scaled) if cfg["pca"] else scaled
103
- model = joblib.load(cfg["model"])
104
- mic_log = model.predict(transformed)[0]
 
 
 
 
 
105
  mic = round(expm1(mic_log), 3)
106
- mic_results[bacterium] = mic
 
 
107
  except Exception as e:
108
- mic_results[bacterium] = f"Error: {str(e)}"
109
  return mic_results
110
 
111
- def full_prediction(sequence):
112
- features = extract_features(sequence)
113
- if isinstance(features, str):
114
- return features
115
- prediction = model.predict(features)[0]
116
- probabilities = model.predict_proba(features)[0]
117
- amp_result = "Antimicrobial Peptide (AMP)" if prediction == 0 else "Non-AMP"
118
- confidence = round(probabilities[0 if prediction == 0 else 1] * 100, 2)
119
- result = f"Prediction: {amp_result}\nConfidence: {confidence}%\n"
120
- if prediction == 0:
121
- mic_values = predictmic(sequence)
122
- result += "\nPredicted MIC Values (\u00b5M):\n"
123
- for org, mic in mic_values.items():
124
- result += f"- {org}: {mic}\n"
125
- else:
126
- result += "\nMIC prediction skipped for Non-AMP sequences.\n"
127
- explanation = explainer.explain_instance(
128
- data_row=features[0],
129
- predict_fn=model.predict_proba,
130
- num_features=10
131
- )
132
- result += "\nTop Features Influencing Prediction:\n"
133
- for feat, weight in explanation.as_list():
134
- result += f"- {feat}: {round(weight, 4)}\n"
135
- return result
136
-
137
- iface = gr.Interface(
138
- fn=full_prediction,
139
- inputs=gr.Textbox(label="Enter Protein Sequence"),
140
- outputs=gr.Textbox(label="Results"),
141
- title="AMP & MIC Predictor + LIME Explanation",
142
- description="Paste an amino acid sequence (\u226510 characters). Get AMP classification, MIC predictions, and LIME interpretability insights."
143
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
144
 
145
- iface.launch(share=True)
 
8
  from transformers import BertTokenizer, BertModel
9
  from lime.lime_tabular import LimeTabularExplainer
10
  from math import expm1
11
+ import matplotlib.pyplot as plt
12
+ import io
13
+ import base64
14
+ import os
15
+
16
+ # --- Configuration and Model Loading ---
17
# Directory containing this script; every serialized artifact is resolved
# relative to it rather than to the current working directory.
MODEL_DIR = os.path.dirname(os.path.abspath(__file__))

# Load AMP Classifier
# Both the classifier and its feature scaler are required; a missing or
# unreadable file aborts start-up with a descriptive error.
try:
    classifier_path = os.path.join(MODEL_DIR, "RF.joblib")
    classifier_scaler_path = os.path.join(MODEL_DIR, "norm (4).joblib")
    model = joblib.load(classifier_path)
    scaler = joblib.load(classifier_scaler_path)
except FileNotFoundError as e:
    raise gr.Error(f"Classifier model or scaler not found: {e}. Make sure RF.joblib and norm (4).joblib are in the {MODEL_DIR} directory.")
except Exception as e:
    raise gr.Error(f"Error loading classifier components: {e}")

# Load ProtBert
# The encoder is moved to the GPU when one is available and put in eval mode.
try:
    tokenizer = BertTokenizer.from_pretrained("Rostlab/prot_bert", do_lower_case=False)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    protbert_model = BertModel.from_pretrained("Rostlab/prot_bert").to(device).eval()
except Exception as e:
    raise gr.Error(f"Error loading ProtBert model/tokenizer: {e}. Check internet connection or model availability.")
36
+
37
 
38
+ # Full list of selected features (as provided in the original code)
39
  selected_features = ["_SolventAccessibilityC3", "_SecondaryStrC1", "_SecondaryStrC3", "_ChargeC1", "_PolarityC1",
40
  "_NormalizedVDWVC1", "_HydrophobicityC3", "_SecondaryStrT23", "_PolarizabilityD1001", "_PolarizabilityD2001",
41
+ "_PolarizabilityD3001", "_SolventAccessibilityD1001", "_SolventAccessibilityD2001", "_SolventAccessibilityD3001",
42
  "_SecondaryStrD1001", "_SecondaryStrD1075", "_SecondaryStrD2001", "_SecondaryStrD3001", "_ChargeD1001",
43
  "_ChargeD1025", "_ChargeD2001", "_ChargeD3075", "_ChargeD3100", "_PolarityD1001", "_PolarityD1050",
44
  "_PolarityD2001", "_PolarityD3001", "_NormalizedVDWVD1001", "_NormalizedVDWVD2001", "_NormalizedVDWVD2025",
 
64
  "APAAC15", "APAAC18", "APAAC19", "APAAC24"]
65
 
66
  # LIME Explainer Setup
67
# LIME needs background data to perturb around the explained instance.
# No real training sample is loaded here, so uniform random values are used
# as a stand-in.
# NOTE(review): random background data does not reflect the distribution of the
# scaled training features; replace with a saved sample of real scaled
# features (e.g. loaded via np.load from MODEL_DIR) when one is available.
sample_data = np.random.rand(500, len(selected_features))

explainer = LimeTabularExplainer(
    training_data=sample_data,
    feature_names=selected_features,
    class_names=["AMP", "Non-AMP"],  # Assuming 0 is AMP, 1 is Non-AMP as per model prediction
    mode="classification"
)
81
 
82
+ # --- Feature Extraction Function ---
83
def extract_features(sequence: str) -> np.ndarray:
    """
    Extracts biochemical and compositional features from an amino acid sequence.

    The sequence is upper-cased and stripped of every character that is not one
    of the 20 standard amino acids. CTD, dipeptide composition, autocorrelation
    and amphiphilic pseudo-AAC descriptors are computed, reduced to the columns
    named in ``selected_features`` (in that order) and passed through the
    module-level ``scaler``.

    Args:
        sequence (str): The amino acid sequence. ``None`` is treated as empty.
    Returns:
        np.ndarray: A scaled 2D numpy array of selected features (1, num_features).
    Raises:
        gr.Error: If the cleaned sequence is not 10-100 residues long or
            feature extraction/scaling fails.
    """
    import warnings  # function-scope import: the file's import block is left untouched

    valid_residues = set("ACDEFGHIKLMNPQRSTVWY")
    cleaned_sequence = ''.join(aa for aa in (sequence or "").upper() if aa in valid_residues)
    if not (10 <= len(cleaned_sequence) <= 100):
        raise gr.Error(f"Invalid sequence length ({len(cleaned_sequence)}). Must be between 10 and 100 characters and contain only standard amino acids.")

    try:
        # Merge order matters only if descriptor families share a key:
        # later updates win.
        all_features_dict = {}
        all_features_dict.update(CTD.CalculateCTD(cleaned_sequence))
        all_features_dict.update(AAComposition.CalculateAADipeptideComposition(cleaned_sequence))
        all_features_dict.update(Autocorrelation.CalculateAutoTotal(cleaned_sequence))
        all_features_dict.update(PseudoAAC.GetAPseudoAAC(cleaned_sequence, lamda=9))

        # Selected features that were never computed are zero-filled below.
        # Report them so a misspelled feature name does not go unnoticed.
        missing = [name for name in selected_features if name not in all_features_dict]
        if missing:
            warnings.warn(
                f"{len(missing)} selected feature(s) were not computed and are zero-filled: {missing}"
            )

        feature_df_all = pd.DataFrame([all_features_dict])
        computed_features_ordered = (
            feature_df_all.reindex(columns=selected_features, fill_value=0).fillna(0)
        )

        # NOTE(review): the scaler is applied to the selected columns only;
        # confirm it was fitted on exactly these columns in this order.
        return scaler.transform(computed_features_ordered.values)
    except Exception as e:
        raise gr.Error(f"Feature extraction failed: {e}. Ensure sequence is valid and Propy dependencies are met.") from e
119
+
120
+ # --- MIC Prediction Function ---
121
def predictmic(sequence: str, selected_bacteria_keys: list) -> dict:
    """
    Predicts Minimum Inhibitory Concentration (MIC) for selected bacteria using ProtBert embeddings.

    Args:
        sequence (str): The amino acid sequence. ``None`` is treated as empty.
        selected_bacteria_keys (list): Keys of the bacteria to predict for
            (e.g., ['e_coli', 'p_aeruginosa']). ``None`` or an empty list
            yields an empty result; a single string is treated as one key.
    Returns:
        dict: Maps each requested key to the predicted MIC (rounded to 3
            decimals), or to an error message string when that bacterium's
            key is unknown or its prediction fails.
    Raises:
        gr.Error: If the cleaned sequence is not 10-100 residues long or the
            ProtBert embedding cannot be computed.
    """
    valid_residues = set("ACDEFGHIKLMNPQRSTVWY")
    cleaned_sequence = ''.join(aa for aa in (sequence or "").upper() if aa in valid_residues)
    if not (10 <= len(cleaned_sequence) <= 100):
        raise gr.Error(f"Invalid sequence length for MIC prediction ({len(cleaned_sequence)}). Must be between 10 and 100 characters.")

    # Normalise the selection: API callers may send null or a bare string.
    if isinstance(selected_bacteria_keys, str):
        selected_bacteria_keys = [selected_bacteria_keys]
    selected_bacteria_keys = list(selected_bacteria_keys or [])
    if not selected_bacteria_keys:
        # Nothing requested: skip the ProtBert forward pass entirely.
        return {}

    # The tokenizer is fed residues separated by single spaces.
    seq_spaced = ' '.join(cleaned_sequence)
    try:
        tokens = tokenizer(seq_spaced, return_tensors="pt", padding='max_length', truncation=True, max_length=512)
        tokens = {k: v.to(device) for k, v in tokens.items()}
        with torch.no_grad():
            outputs = protbert_model(**tokens)
        # NOTE(review): the mean is taken over all 512 positions, padding
        # included. Presumably the MIC models were trained on embeddings
        # pooled the same way -- confirm before changing the pooling.
        embedding = outputs.last_hidden_state.mean(dim=1).squeeze().cpu().numpy().reshape(1, -1)
    except Exception as e:
        raise gr.Error(f"Error generating ProtBert embedding: {e}. Check sequence format or model availability.") from e

    bacteria_config = {
        "e_coli": {"display_name": "E.coli", "model": "coli_xgboost_model.pkl", "scaler": "coli_scaler.pkl", "pca": None},
        "p_aeruginosa": {"display_name": "P. aeruginosa", "model": "arg_xgboost_model.pkl", "scaler": "arg_scaler.pkl", "pca": None},
        "s_aureus": {"display_name": "S. aureus", "model": "aur_xgboost_model.pkl", "scaler": "aur_scaler.pkl", "pca": None},
        "k_pneumoniae": {"display_name": "K. pneumoniae", "model": "pne_mlp_model.pkl", "scaler": "pne_scaler.pkl", "pca": "pne_pca.pkl"}
    }

    mic_results = {}
    for bacterium_key in selected_bacteria_keys:
        cfg = bacteria_config.get(bacterium_key)
        if not cfg:
            mic_results[bacterium_key] = "Error: Invalid bacterium key provided."
            continue

        # Failures are reported per bacterium so one bad model file does not
        # discard the other predictions.
        try:
            mic_scaler = joblib.load(os.path.join(MODEL_DIR, cfg["scaler"]))
            transformed_embedding = mic_scaler.transform(embedding)

            if cfg["pca"]:
                mic_pca = joblib.load(os.path.join(MODEL_DIR, cfg["pca"]))
                transformed_embedding = mic_pca.transform(transformed_embedding)

            mic_model = joblib.load(os.path.join(MODEL_DIR, cfg["model"]))
            mic_log = mic_model.predict(transformed_embedding)[0]
            # expm1 undoes what is presumably a log1p transform of the
            # training targets -- TODO confirm against the training code.
            mic_results[bacterium_key] = round(expm1(mic_log), 3)
        except FileNotFoundError as e:
            mic_results[bacterium_key] = f"Model file not found for {cfg['display_name']}: {e}"
        except Exception as e:
            mic_results[bacterium_key] = f"Prediction error for {cfg['display_name']}: {e}"
    return mic_results
179
 
180
+ # --- LIME Plot Generation Helper ---
181
def generate_lime_plot_base64(explanation_list: list) -> str:
    """
    Generates a LIME explanation plot and returns it as a base64 encoded PNG string.

    Bars are ordered by absolute weight (largest on top); positive weights are
    drawn green, non-positive weights red.

    Args:
        explanation_list (list): The output from LimeExplanation.as_list(),
            i.e. a list of (feature description, weight) pairs.
    Returns:
        str: Base64 encoded PNG image string, or "" when the list is empty.
    """
    if not explanation_list:
        return ""

    features = [item[0] for item in explanation_list]
    weights = [item[1] for item in explanation_list]

    sorted_indices = np.argsort(np.abs(weights))[::-1]
    features_sorted = [features[i] for i in sorted_indices]
    weights_sorted = [weights[i] for i in sorted_indices]

    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        y_pos = np.arange(len(features_sorted))
        colors = ['green' if w > 0 else 'red' for w in weights_sorted]
        ax.barh(y_pos, weights_sorted, align='center', color=colors)
        ax.set_yticks(y_pos)
        ax.set_yticklabels(features_sorted, fontsize=10)
        ax.invert_yaxis()  # largest contribution at the top
        ax.set_xlabel('Contribution to Prediction (LIME Weight)', fontsize=12)
        ax.set_title('Top Features Influencing Prediction (LIME)', fontsize=14)
        ax.axvline(0, color='grey', linestyle='--', linewidth=0.8)
        # Use the figure/axes objects rather than pyplot's "current figure":
        # concurrent requests could otherwise draw on or save another figure.
        ax.grid(axis='x', linestyle=':', alpha=0.7)

        buf = io.BytesIO()
        fig.savefig(buf, format='png', bbox_inches='tight', dpi=150)
        return base64.b64encode(buf.getvalue()).decode('utf-8')
    finally:
        # Always release the figure, even when rendering fails.
        plt.close(fig)
217
+
218
+ # --- Gradio API Endpoints ---
219
+
220
def _split_lime_rule(rule: str, known_features) -> tuple:
    """Split a LIME rule string into (feature name, condition).

    Handles both one-sided rules ("APAAC4 <= 0.23" -> ("APAAC4", "<= 0.23"))
    and two-sided rules ("0.23 < APAAC4 <= 0.50" ->
    ("APAAC4", "> 0.23 and <= 0.50")). When no token matches a known feature
    name, falls back to splitting on the first space.
    """
    tokens = rule.split()
    for position, token in enumerate(tokens):
        if token not in known_features:
            continue
        lower, upper = tokens[:position], tokens[position + 1:]
        flipped = {"<": ">", "<=": ">="}
        if len(lower) == 2 and lower[1] in flipped:
            # "0.23 < name" reads as "name > 0.23".
            lower_text = f"{flipped[lower[1]]} {lower[0]}"
        else:
            lower_text = " ".join(lower)
        upper_text = " ".join(upper)
        if lower_text and upper_text:
            return token, f"{lower_text} and {upper_text}"
        return token, lower_text or upper_text
    head, _, tail = rule.partition(" ")
    return head, tail.strip()


def classify_and_interpret_amp(sequence: str) -> dict:
    """
    Gradio API endpoint for AMP classification and interpretability (LIME).

    Extracts features for the sequence, classifies it with the module-level
    ``model``, explains the prediction with LIME (top 10 features) and returns
    a JSON-serialisable dictionary with these keys:

    - "label": "AMP (Positive)" when the predicted class is 0, else "Non-AMP".
    - "confidence": probability reported for the predicted class.
    - "shap_plot_base64": base64 PNG of the LIME bar chart (the key name says
      "shap" but the content is the LIME plot).
    - "top_features": list of {"feature", "condition", "value"} dicts.

    Raises:
        gr.Error: For invalid input (propagated from feature extraction) or
            any unexpected failure during classification/explanation.
    """
    try:
        features = extract_features(sequence)

        prediction_class_idx = model.predict(features)[0]
        probabilities = model.predict_proba(features)[0]

        amp_label = "AMP (Positive)" if prediction_class_idx == 0 else "Non-AMP"
        # NOTE(review): indexes probabilities by the predicted label, which
        # assumes the model's classes are exactly 0 and 1 in that order.
        confidence = probabilities[prediction_class_idx]

        explanation = explainer.explain_instance(
            data_row=features[0],
            predict_fn=model.predict_proba,
            num_features=10
        )
        lime_rules = explanation.as_list()

        # LIME rules name the feature either first ("APAAC4 <= 0.23") or in
        # the middle ("0.23 < APAAC4 <= 0.50"), so match against the known
        # feature names instead of assuming the first token is the name.
        known_features = set(selected_features)
        top_features = []
        for feat_str, weight in lime_rules:
            feature_name, condition = _split_lime_rule(feat_str, known_features)
            top_features.append({
                "feature": feature_name,
                "condition": condition,
                "value": round(float(weight), 4)
            })

        return {
            "label": amp_label,
            "confidence": float(confidence),
            "shap_plot_base64": generate_lime_plot_base64(lime_rules),
            "top_features": top_features
        }

    except gr.Error:
        # Already user-facing; re-raise untouched to keep its message.
        raise
    except Exception as e:
        raise gr.Error(f"An unexpected error occurred during AMP classification: {e}") from e
268
+
269
def get_mic_predictions_api(sequence: str, selected_bacteria_keys: list) -> dict:
    """
    Gradio API endpoint for MIC prediction.

    Delegates to `predictmic` and returns its result unchanged. A `gr.Error`
    raised by `predictmic` propagates as-is; any other exception is converted
    into a `gr.Error` carrying the original message.
    """
    try:
        return predictmic(sequence, selected_bacteria_keys)
    except gr.Error:
        raise
    except Exception as e:
        raise gr.Error(f"An unexpected error occurred during MIC prediction API call: {e}")
281
+
282
+ # --- Gradio Interface Definition ---
283
# Two-tab Blocks UI. Each button click is also exposed as a named API
# endpoint (`predict` and `predict_mic`) for the external frontend.
with gr.Blocks() as demo:
    gr.Markdown("# EPIC-AMP Platform Backend API")
    gr.Markdown("This Gradio application provides the backend services for the EPIC-AMP frontend.")

    with gr.Tab("AMP Classification & Interpretability API"):
        gr.Markdown("### `/predict` Endpoint (AMP Classification, Confidence, LIME Plot, Top Features)")
        gr.Markdown("Input an amino acid sequence (10-100 AAs) to get classification details.")
        sequence_input_amp = gr.Textbox(label="Amino Acid Sequence", lines=5, placeholder="Enter sequence here...")
        amp_api_output = gr.Json(label="AMP Prediction Details JSON Output")
        gr.Button("Test Classification").click(
            fn=classify_and_interpret_amp,
            inputs=[sequence_input_amp],
            outputs=[amp_api_output],
            api_name="predict"
        )

    with gr.Tab("MIC Prediction API"):
        gr.Markdown("### `/predict_mic` Endpoint (MIC Values)")
        gr.Markdown("Input an amino acid sequence (only if classified as AMP) and select bacteria to get predicted MIC values.")
        sequence_input_mic = gr.Textbox(label="Amino Acid Sequence", lines=5, placeholder="Enter AMP sequence for MIC prediction...")
        mic_bacteria_checkboxes = gr.CheckboxGroup(
            choices=["e_coli", "p_aeruginosa", "s_aureus", "k_pneumoniae"],
            label="Select Bacteria for MIC Prediction (keys for backend)"
        )
        mic_api_output = gr.Json(label="MIC Prediction JSON Output")
        gr.Button("Test MIC Prediction").click(
            fn=get_mic_predictions_api,
            inputs=[sequence_input_mic, mic_bacteria_checkboxes],
            outputs=[mic_api_output],
            api_name="predict_mic"
        )

# Queueing is enabled through Blocks.queue(): the `enable_queue` argument of
# launch() is deprecated and rejected by newer Gradio releases.
demo.queue()
demo.launch(share=True, show_api=True)