nonzeroexit committed · verified
Commit 68ded6f · 1 Parent(s): c83df39

Update app.py

Files changed (1):
  1. app.py +77 -195
app.py CHANGED
@@ -8,25 +8,17 @@ import torch
  from transformers import BertTokenizer, BertModel
  from math import expm1

- # =====================
- # Load AMP Classifier Model (Random Forest)
- # =====================
- # Ensure 'RF.joblib' and 'norm (4).joblib' are in the same directory or provide full paths
  model = joblib.load("RF.joblib")
  scaler = joblib.load("norm (4).joblib")

- # =====================
- # Load ProtBert Model Globally for MIC Prediction
- # =====================
  tokenizer = BertTokenizer.from_pretrained("Rostlab/prot_bert", do_lower_case=False)
  protbert_model = BertModel.from_pretrained("Rostlab/prot_bert")
- # Move model to GPU if available for faster inference
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
- protbert_model = protbert_model.to(device).eval()  # Set to evaluation mode

- # =====================
- # Feature List (ProPy Descriptors) used by AMP Classifier
- # =====================
  selected_features = [
      "_SolventAccessibilityC3", "_SecondaryStrC1", "_SecondaryStrC3", "_ChargeC1", "_PolarityC1",
      "_NormalizedVDWVC1", "_HydrophobicityC3", "_SecondaryStrT23", "_PolarizabilityD1001",
@@ -62,223 +54,113 @@ selected_features = [
      "APAAC24"
  ]

- # =====================
- # AMP Feature Extractor Function
- # =====================
  def extract_features(sequence):
-     """
-     Extracts physiochemical and compositional features from a protein sequence using ProPy.
-     Applies the pre-trained scaler and selects relevant features.
-     """
      all_features_dict = {}
-     # Clean sequence to include only valid amino acids
      sequence = ''.join([aa for aa in sequence.upper() if aa in "ACDEFGHIKLMNPQRSTVWY"])
-
      if len(sequence) < 10:
-         return "Error: Sequence too short or invalid. Must contain at least 10 valid amino acids."
-
-     # Calculate various ProPy features
      dipeptide_features = AAComposition.CalculateAADipeptideComposition(sequence)
-     # Note: Dipeptide composition calculates 400 features, using a slice here might be specific to the original model's training
-     # If the original model used all 400, this slice needs to be adjusted or removed.
-     # For now, keeping as per the provided code.
-     filtered_dipeptide_features = {k: dipeptide_features[k] for k in list(dipeptide_features.keys())[:420]}  # This slice is unusual if only 400 dipeptides exist.
      ctd_features = CTD.CalculateCTD(sequence)
-     auto_features = Autocorrelation.CalculateAutoTotal(sequence)  # Includes Moran, Geary, Moreau-Broto
-     pseudo_features = PseudoAAC.GetAPseudoAAC(sequence, lamda=9)  # Pseudo Amino Acid Composition
-
-     # Combine all extracted features into a single dictionary
      all_features_dict.update(ctd_features)
      all_features_dict.update(filtered_dipeptide_features)
      all_features_dict.update(auto_features)
      all_features_dict.update(pseudo_features)
-
-     # Convert to DataFrame for consistent column handling with scaler
      feature_df_all = pd.DataFrame([all_features_dict])
-
-     # Handle missing features (if any arise from short sequences or specific AA combinations not producing all features)
-     # Ensure all selected_features are present, add as 0 if missing.
-     for col in selected_features:
-         if col not in feature_df_all.columns:
-             feature_df_all[col] = 0
-
-     # Normalize features using the pre-trained scaler
-     # Ensure the order of columns matches the scaler's training order before scaling
-     feature_df_all = feature_df_all[scaler.feature_names_in_]  # Align columns with scaler's expected input
      normalized_array = scaler.transform(feature_df_all.values)
-
-     # Select only the features that the final RF model expects
-     selected_df = pd.DataFrame(normalized_array, columns=scaler.feature_names_in_)[selected_features].fillna(0)
-
      return selected_df.values

- # =====================
- # MIC Predictor Function (ProtBert-based)
- # =====================
- def predict_mic_values(sequence, selected_bacteria_keys):
-     """
-     Predicts Minimum Inhibitory Concentration (MIC) for a given peptide sequence
-     against selected bacteria using ProtBert embeddings and pre-trained models.
-     """
      sequence = ''.join([aa for aa in sequence.upper() if aa in "ACDEFGHIKLMNPQRSTVWY"])
      if len(sequence) < 10:
-         return {"Error": "Sequence too short or invalid for MIC prediction."}
-
-     # Tokenize the sequence for ProtBert
      seq_spaced = ' '.join(list(sequence))
      tokens = tokenizer(seq_spaced, return_tensors="pt", padding='max_length', truncation=True, max_length=512)
      tokens = {k: v.to(device) for k, v in tokens.items()}
-
-     # Get ProtBert embedding
      with torch.no_grad():
          outputs = protbert_model(**tokens)
-     # Use mean of last hidden state as sequence embedding
      embedding = outputs.last_hidden_state.mean(dim=1).squeeze().cpu().numpy().reshape(1, -1)
-
-     # Configuration for MIC models (paths to joblib files)
      bacteria_config = {
-         "e_coli": {  # Changed keys to match frontend values
-             "display_name": "E.coli",
-             "model_path": "coli_xgboost_model.pkl",
-             "scaler_path": "coli_scaler.pkl",
-             "pca_path": None
          },
-         "s_aureus": {  # Changed keys to match frontend values
-             "display_name": "S.aureus",
-             "model_path": "aur_xgboost_model.pkl",
-             "scaler_path": "aur_scaler.pkl",
-             "pca_path": None
          },
-         "p_aeruginosa": {  # Changed keys to match frontend values
-             "display_name": "P.aeruginosa",
-             "model_path": "arg_xgboost_model.pkl",
-             "scaler_path": "arg_scaler.pkl",
-             "pca_path": None
          },
-         "k_pneumoniae": {  # Changed keys to match frontend values
-             "display_name": "K.Pneumoniae",
-             "model_path": "pne_mlp_model.pkl",
-             "scaler_path": "pne_scaler.pkl",
-             "pca_path": "pne_pca.pkl"
          }
      }
-
      mic_results = {}
-     for bacterium_key in selected_bacteria_keys:
-         cfg = bacteria_config.get(bacterium_key)
-         if not cfg:
-             mic_results[bacterium_key] = "Error: Invalid bacterium key"
-             continue
-
          try:
-             # Load scaler and transform embedding
-             scaler = joblib.load(cfg["scaler_path"])
-             scaled_embedding = scaler.transform(embedding)
-
-             # Apply PCA if configured
-             if cfg["pca_path"]:
-                 pca = joblib.load(cfg["pca_path"])
-                 final_features = pca.transform(scaled_embedding)
              else:
-                 final_features = scaled_embedding
-
-             # Load and predict with the MIC model
-             mic_model = joblib.load(cfg["model_path"])
-             mic_log = mic_model.predict(final_features)[0]
-
-             # Convert log-transformed MIC back to original scale (µM)
-             mic = round(expm1(mic_log), 3)  # expm1(x) is equivalent to exp(x) - 1, robust for small x
-             mic_results[cfg["display_name"]] = mic
          except Exception as e:
-             mic_results[cfg["display_name"]] = f"Prediction Error: {str(e)}"
-
      return mic_results

- # =====================
- # Gradio Interface Functions
- # =====================
-
- def amp_classifier_predict(sequence):
-     """
-     Function for AMP classification endpoint in Gradio.
-     Returns the AMP classification label, confidence, and SHAP plot Base64 string.
-     """
      features = extract_features(sequence)
-     if isinstance(features, str):  # Handle extraction error
-         return gr.Label(f"Error: {features}", label="AMP Classification"), None
-
      prediction = model.predict(features)[0]
      probabilities = model.predict_proba(features)[0]
-
-     amp_label = "Antimicrobial Peptide (AMP)" if prediction == 0 else "Non-AMP"
-     confidence_value = probabilities[prediction]  # Confidence of the predicted class
-
-     # Placeholder for SHAP plot generation (not implemented in this snippet)
-     # In a real scenario, you'd generate a SHAP plot image here (e.g., using matplotlib, shap library)
-     # and encode it to base64.
-     shap_plot_base64 = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNkYAAAAAYAAjCB0C8AAAAASUVORK5CYII="  # A tiny transparent PNG base64
-
-     # The Gradio `predict` function can return structured data as a dictionary if using `gr.JSON` output
-     # However, since the frontend is expecting `data[0].label`, `data[0].confidence`, etc.
-     # we'll return a dictionary that matches that structure.
-     return {
-         "label": amp_label,
-         "confidence": confidence_value,
-         "shap_plot_base64": shap_plot_base64  # Return SHAP plot as Base64 (placeholder for now)
-     }
-
- def mic_predictor_predict(sequence, selected_bacteria):
-     """
-     Function for MIC prediction endpoint in Gradio.
-     Takes the sequence and a list of selected bacteria keys.
-     """
-     # Only predict MIC if AMP (Positive) classification
-     # This check would ideally be part of the frontend logic or a combined backend function
-     # but for standalone MIC endpoint, we just proceed.
-     # The frontend is responsible for calling this only if AMP is positive.
-     mic_results = predict_mic_values(sequence, selected_bacteria)
-     return mic_results  # Returns a dictionary of MIC values
-
- # =====================
- # Define Gradio Interface (hidden, for client connection)
- # =====================
- # This Gradio app is designed to be used as a backend service by your custom HTML frontend.
- # The inputs and outputs here correspond to what the frontend's `gradio.client` expects.
-
- with gr.Blocks() as demo:
-     gr.Markdown("# BCBU-ZC AMP/MIC Backend Service")
-     gr.Markdown("This Gradio application serves as the backend for the AMP classification and MIC prediction. It provides endpoints for sequence analysis and MIC prediction.")
-
-     with gr.Tab("AMP Classification"):
-         gr.Markdown("### AMP Classification Endpoint (`/predict`)")
-         amp_input_sequence = gr.Textbox(label="Amino Acid Sequence")
-         amp_output_json = gr.JSON(label="Classification Result (Label, Confidence, SHAP Plot Base64)")
-         amp_predict_button = gr.Button("Predict AMP")
-         amp_predict_button.click(
-             fn=amp_classifier_predict,
-             inputs=[amp_input_sequence],
-             outputs=[amp_output_json],
-             api_name="predict"  # Define an API endpoint name for `gradio.client`
-         )
-
-     with gr.Tab("MIC Prediction"):
-         gr.Markdown("### MIC Prediction Endpoint (`/predict_mic`)")
-         mic_input_sequence = gr.Textbox(label="Amino Acid Sequence")
-         mic_selected_bacteria = gr.CheckboxGroup(
-             label="Select Bacteria",
-             choices=["e_coli", "p_aeruginosa", "s_aureus", "k_pneumoniae"],
-             value=["e_coli", "p_aeruginosa", "s_aureus", "k_pneumoniae"]  # Default for testing
-         )
-         mic_output_json = gr.JSON(label="Predicted MIC Values (µM)")
-         mic_predict_button = gr.Button("Predict MIC")
-         mic_predict_button.click(
-             fn=mic_predictor_predict,
-             inputs=[mic_input_sequence, mic_selected_bacteria],
-             outputs=[mic_output_json],
-             api_name="predict_mic"  # Define a separate API endpoint name
-         )
-
- # Launch the Gradio app
- # `share=True` creates a public, temporary URL for external access (useful for testing frontend)
- # `allowed_paths` should be set to allow access from specific origins if deploying
- demo.launch(share=True)
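For reference, the removed Blocks backend above was designed to be driven by a custom HTML frontend through Gradio's client libraries, with one named endpoint per tab (api_name="predict" and api_name="predict_mic"). A minimal client-side sketch of that older API follows, using the Python gradio_client package purely for illustration (the original comments refer to the JS gradio.client); the Space id is a placeholder, not taken from this commit:

# Sketch of a client for the pre-commit, two-endpoint backend (assumptions noted above).
from gradio_client import Client

client = Client("nonzeroexit/AMP-Predictor")  # placeholder Space id

# AMP classification endpoint (api_name="predict" in the removed code)
amp_result = client.predict(
    "GIGKFLHSAKKFGKAFVGEIMNS",   # example peptide sequence
    api_name="/predict",
)

# MIC prediction endpoint (api_name="predict_mic" in the removed code)
mic_result = client.predict(
    "GIGKFLHSAKKFGKAFVGEIMNS",
    ["e_coli", "s_aureus"],      # any subset of the CheckboxGroup choices
    api_name="/predict_mic",
)

The added lines below are the new version of the same region of app.py.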
 
  from transformers import BertTokenizer, BertModel
  from math import expm1

+ # Load AMP Classifier
  model = joblib.load("RF.joblib")
  scaler = joblib.load("norm (4).joblib")

+ # Load ProtBert Globally
  tokenizer = BertTokenizer.from_pretrained("Rostlab/prot_bert", do_lower_case=False)
  protbert_model = BertModel.from_pretrained("Rostlab/prot_bert")
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ protbert_model = protbert_model.to(device).eval()

+ # Selected Features
  selected_features = [
      "_SolventAccessibilityC3", "_SecondaryStrC1", "_SecondaryStrC3", "_ChargeC1", "_PolarityC1",
      "_NormalizedVDWVC1", "_HydrophobicityC3", "_SecondaryStrT23", "_PolarizabilityD1001",

      "APAAC24"
  ]

+ # AMP Feature Extractor
  def extract_features(sequence):
      all_features_dict = {}
      sequence = ''.join([aa for aa in sequence.upper() if aa in "ACDEFGHIKLMNPQRSTVWY"])
      if len(sequence) < 10:
+         return "Error: Sequence too short."
      dipeptide_features = AAComposition.CalculateAADipeptideComposition(sequence)
+     filtered_dipeptide_features = {k: dipeptide_features[k] for k in list(dipeptide_features.keys())[:420]}
      ctd_features = CTD.CalculateCTD(sequence)
+     auto_features = Autocorrelation.CalculateAutoTotal(sequence)
+     pseudo_features = PseudoAAC.GetAPseudoAAC(sequence, lamda=9)
      all_features_dict.update(ctd_features)
      all_features_dict.update(filtered_dipeptide_features)
      all_features_dict.update(auto_features)
      all_features_dict.update(pseudo_features)
      feature_df_all = pd.DataFrame([all_features_dict])
      normalized_array = scaler.transform(feature_df_all.values)
+     normalized_df = pd.DataFrame(normalized_array, columns=feature_df_all.columns)
+     selected_df = normalized_df[selected_features].fillna(0)
      return selected_df.values

+ # AMP Classifier
+ def predict(sequence):
+     features = extract_features(sequence)
+     if isinstance(features, str):
+         return features
+     prediction = model.predict(features)[0]
+     probabilities = model.predict_proba(features)[0]
+     if prediction == 0:
+         return f"{probabilities[0] * 100:.2f}% chance of being an Antimicrobial Peptide (AMP)"
+     else:
+         return f"{probabilities[1] * 100:.2f}% chance of being Non-AMP"
+
+ # MIC Predictor
+ def predictmic(sequence):
      sequence = ''.join([aa for aa in sequence.upper() if aa in "ACDEFGHIKLMNPQRSTVWY"])
      if len(sequence) < 10:
+         return {"Error": "Sequence too short or invalid. Must contain at least 10 valid amino acids."}
      seq_spaced = ' '.join(list(sequence))
      tokens = tokenizer(seq_spaced, return_tensors="pt", padding='max_length', truncation=True, max_length=512)
      tokens = {k: v.to(device) for k, v in tokens.items()}
      with torch.no_grad():
          outputs = protbert_model(**tokens)
      embedding = outputs.last_hidden_state.mean(dim=1).squeeze().cpu().numpy().reshape(1, -1)
      bacteria_config = {
+         "E.coli": {
+             "model": "coli_xgboost_model.pkl",
+             "scaler": "coli_scaler.pkl",
+             "pca": None
          },
+         "S.aureus": {
+             "model": "aur_xgboost_model.pkl",
+             "scaler": "aur_scaler.pkl",
+             "pca": None
          },
+         "P.aeruginosa": {
+             "model": "arg_xgboost_model.pkl",
+             "scaler": "arg_scaler.pkl",
+             "pca": None
          },
+         "K.Pneumonia": {
+             "model": "pne_mlp_model.pkl",
+             "scaler": "pne_scaler.pkl",
+             "pca": "pne_pca.pkl"
          }
      }
      mic_results = {}
+     for bacterium, cfg in bacteria_config.items():
          try:
+             scaler = joblib.load(cfg["scaler"])
+             scaled = scaler.transform(embedding)
+             if cfg["pca"]:
+                 pca = joblib.load(cfg["pca"])
+                 transformed = pca.transform(scaled)
              else:
+                 transformed = scaled
+             model = joblib.load(cfg["model"])
+             mic_log = model.predict(transformed)[0]
+             mic = round(expm1(mic_log), 3)
+             mic_results[bacterium] = mic
          except Exception as e:
+             mic_results[bacterium] = f"Error: {str(e)}"
      return mic_results

+ # Combined Prediction
+ def full_prediction(sequence):
      features = extract_features(sequence)
+     if isinstance(features, str):
+         return "Error", "0%", {}
      prediction = model.predict(features)[0]
      probabilities = model.predict_proba(features)[0]
+     amp_result = "Antimicrobial Peptide (AMP)" if prediction == 0 else "Non-AMP"
+     confidence = round(probabilities[0 if prediction == 0 else 1] * 100, 2)
+     mic_values = predictmic(sequence)
+     return amp_result, f"{confidence}%", mic_values
+
+ # Gradio Interface
+ iface = gr.Interface(
+     fn=full_prediction,
+     inputs=gr.Textbox(label="Enter Protein Sequence"),
+     outputs=[
+         gr.Label(label="AMP Classification"),
+         gr.Label(label="Confidence"),
+         gr.JSON(label="Predicted MIC (µM) for Each Bacterium")
+     ],
+     title="AMP & MIC Predictor",
+     description="Enter an amino acid sequence (≥10 valid letters) to predict AMP class and MIC values."
+ )
+
+ iface.launch(share=True)
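With this commit the app exposes a single gr.Interface instead of named Blocks endpoints, so one call returns the classification label, the confidence string, and the per-bacterium MIC dictionary together; the expm1 call in predictmic maps the models' log-scale outputs back to µM (expm1 being the inverse of log1p). A minimal sketch of calling the updated Space, again assuming the Python gradio_client package and a placeholder Space id (gr.Interface registers its function under the default "/predict" route):

# Sketch of a client for the new single-function interface (assumptions noted above).
from gradio_client import Client

client = Client("nonzeroexit/AMP-Predictor")  # placeholder Space id

label, confidence, mic_values = client.predict(
    "GIGKFLHSAKKFGKAFVGEIMNS",   # example peptide sequence (>= 10 valid residues)
    api_name="/predict",
)
# label      -> AMP classification from full_prediction()
# confidence -> confidence string, e.g. "97.53%"
# mic_values -> dict of predicted MIC values (µM) keyed by bacterium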