nonzeroexit commited on
Commit
c83df39
·
verified ·
1 Parent(s): ce9bf06

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +171 -80
app.py CHANGED
@@ -9,21 +9,23 @@ from transformers import BertTokenizer, BertModel
9
  from math import expm1
10
 
11
  # =====================
12
- # Load AMP Classifier
13
  # =====================
 
14
  model = joblib.load("RF.joblib")
15
  scaler = joblib.load("norm (4).joblib")
16
 
17
  # =====================
18
- # Load ProtBert Globally
19
  # =====================
20
  tokenizer = BertTokenizer.from_pretrained("Rostlab/prot_bert", do_lower_case=False)
21
  protbert_model = BertModel.from_pretrained("Rostlab/prot_bert")
 
22
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
23
- protbert_model = protbert_model.to(device).eval()
24
 
25
  # =====================
26
- # Feature List (ProPy)
27
  # =====================
28
  selected_features = [
29
  "_SolventAccessibilityC3", "_SecondaryStrC1", "_SecondaryStrC3", "_ChargeC1", "_PolarityC1",
@@ -61,133 +63,222 @@ selected_features = [
61
  ]
62
 
63
  # =====================
64
- # AMP Feature Extractor
65
  # =====================
66
  def extract_features(sequence):
 
 
 
 
67
  all_features_dict = {}
 
68
  sequence = ''.join([aa for aa in sequence.upper() if aa in "ACDEFGHIKLMNPQRSTVWY"])
 
69
  if len(sequence) < 10:
70
- return "Error: Sequence too short."
71
 
 
72
  dipeptide_features = AAComposition.CalculateAADipeptideComposition(sequence)
73
- filtered_dipeptide_features = {k: dipeptide_features[k] for k in list(dipeptide_features.keys())[:420]}
 
 
 
74
  ctd_features = CTD.CalculateCTD(sequence)
75
- auto_features = Autocorrelation.CalculateAutoTotal(sequence)
76
- pseudo_features = PseudoAAC.GetAPseudoAAC(sequence, lamda=9)
77
 
 
78
  all_features_dict.update(ctd_features)
79
  all_features_dict.update(filtered_dipeptide_features)
80
  all_features_dict.update(auto_features)
81
  all_features_dict.update(pseudo_features)
82
 
 
83
  feature_df_all = pd.DataFrame([all_features_dict])
 
 
 
 
 
 
 
 
 
 
84
  normalized_array = scaler.transform(feature_df_all.values)
85
- normalized_df = pd.DataFrame(normalized_array, columns=feature_df_all.columns)
86
- selected_df = normalized_df[selected_features].fillna(0)
 
87
 
88
  return selected_df.values
89
 
90
  # =====================
91
- # AMP Classifier
92
- # =====================
93
- def predict(sequence):
94
- features = extract_features(sequence)
95
- if isinstance(features, str):
96
- return features
97
- prediction = model.predict(features)[0]
98
- probabilities = model.predict_proba(features)[0]
99
- if prediction == 0:
100
- return f"{probabilities[0] * 100:.2f}% chance of being an Antimicrobial Peptide (AMP)"
101
- else:
102
- return f"{probabilities[1] * 100:.2f}% chance of being Non-AMP"
103
-
104
- # =====================
105
- # MIC Predictor (ProtBert-based)
106
  # =====================
107
- def predictmic(sequence):
 
 
 
 
108
  sequence = ''.join([aa for aa in sequence.upper() if aa in "ACDEFGHIKLMNPQRSTVWY"])
109
  if len(sequence) < 10:
110
- return {"Error": "Sequence too short or invalid. Must contain at least 10 valid amino acids."}
111
 
112
- # Tokenize
113
  seq_spaced = ' '.join(list(sequence))
114
  tokens = tokenizer(seq_spaced, return_tensors="pt", padding='max_length', truncation=True, max_length=512)
115
  tokens = {k: v.to(device) for k, v in tokens.items()}
116
 
 
117
  with torch.no_grad():
118
  outputs = protbert_model(**tokens)
 
119
  embedding = outputs.last_hidden_state.mean(dim=1).squeeze().cpu().numpy().reshape(1, -1)
120
 
121
- # MIC model config
122
  bacteria_config = {
123
- "E.coli": {
124
- "model": "coli_xgboost_model.pkl",
125
- "scaler": "coli_scaler.pkl",
126
- "pca": None
 
127
  },
128
- "S.aureus": {
129
- "model": "aur_xgboost_model.pkl",
130
- "scaler": "aur_scaler.pkl",
131
- "pca": None
 
132
  },
133
- "P.aeruginosa": {
134
- "model": "arg_xgboost_model.pkl",
135
- "scaler": "arg_scaler.pkl",
136
- "pca": None
 
137
  },
138
- "K.Pneumonia": {
139
- "model": "pne_mlp_model.pkl",
140
- "scaler": "pne_scaler.pkl",
141
- "pca": "pne_pca.pkl"
 
142
  }
143
  }
144
 
145
  mic_results = {}
146
- for bacterium, cfg in bacteria_config.items():
 
 
 
 
 
147
  try:
148
- scaler = joblib.load(cfg["scaler"])
149
- scaled = scaler.transform(embedding)
150
- if cfg["pca"]:
151
- pca = joblib.load(cfg["pca"])
152
- transformed = pca.transform(scaled)
 
 
 
153
  else:
154
- transformed = scaled
155
- model = joblib.load(cfg["model"])
156
- mic_log = model.predict(transformed)[0]
157
- mic = round(expm1(mic_log), 3)
158
- mic_results[bacterium] = mic
 
 
 
 
159
  except Exception as e:
160
- mic_results[bacterium] = f"Error: {str(e)}"
161
 
162
  return mic_results
163
 
164
  # =====================
165
- # Combined Prediction Function
166
  # =====================
167
- def full_prediction(sequence):
 
 
 
 
 
168
  features = extract_features(sequence)
169
- if isinstance(features, str):
170
- return "Error", "0%", {}
 
171
  prediction = model.predict(features)[0]
172
  probabilities = model.predict_proba(features)[0]
173
- amp_result = "Antimicrobial Peptide (AMP)" if prediction == 0 else "Non-AMP"
174
- confidence = round(probabilities[0 if prediction == 0 else 1] * 100, 2)
175
- mic_values = predictmic(sequence)
176
- return amp_result, f"{confidence}%", mic_values
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
177
 
178
  # =====================
179
- # Gradio Interface
180
  # =====================
181
- iface = gr.Interface(
182
- fn=full_prediction,
183
- inputs=gr.Textbox(label="Enter Protein Sequence"),
184
- outputs=[
185
- gr.Label(label="AMP Classification"),
186
- gr.Label(label="Confidence"),
187
- gr.JSON(label="Predicted MIC (µM) for Each Bacterium")
188
- ],
189
- title="AMP & MIC Predictor",
190
- description="Enter an amino acid sequence (≥10 valid letters) to predict AMP class and MIC values."
191
- )
192
-
193
- iface.launch(share=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  from math import expm1
10
 
11
  # =====================
12
+ # Load AMP Classifier Model (Random Forest)
13
  # =====================
14
+ # Ensure 'RF.joblib' and 'norm (4).joblib' are in the same directory or provide full paths
15
  model = joblib.load("RF.joblib")
16
  scaler = joblib.load("norm (4).joblib")
17
 
18
  # =====================
19
+ # Load ProtBert Model Globally for MIC Prediction
20
  # =====================
21
  tokenizer = BertTokenizer.from_pretrained("Rostlab/prot_bert", do_lower_case=False)
22
  protbert_model = BertModel.from_pretrained("Rostlab/prot_bert")
23
+ # Move model to GPU if available for faster inference
24
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
25
+ protbert_model = protbert_model.to(device).eval() # Set to evaluation mode
26
 
27
  # =====================
28
+ # Feature List (ProPy Descriptors) used by AMP Classifier
29
  # =====================
30
  selected_features = [
31
  "_SolventAccessibilityC3", "_SecondaryStrC1", "_SecondaryStrC3", "_ChargeC1", "_PolarityC1",
 
63
  ]
64
 
65
  # =====================
66
+ # AMP Feature Extractor Function
67
  # =====================
68
  def extract_features(sequence):
69
+ """
70
+ Extracts physiochemical and compositional features from a protein sequence using ProPy.
71
+ Applies the pre-trained scaler and selects relevant features.
72
+ """
73
  all_features_dict = {}
74
+ # Clean sequence to include only valid amino acids
75
  sequence = ''.join([aa for aa in sequence.upper() if aa in "ACDEFGHIKLMNPQRSTVWY"])
76
+
77
  if len(sequence) < 10:
78
+ return "Error: Sequence too short or invalid. Must contain at least 10 valid amino acids."
79
 
80
+ # Calculate various ProPy features
81
  dipeptide_features = AAComposition.CalculateAADipeptideComposition(sequence)
82
+ # Note: Dipeptide composition calculates 400 features, using a slice here might be specific to the original model's training
83
+ # If the original model used all 400, this slice needs to be adjusted or removed.
84
+ # For now, keeping as per the provided code.
85
+ filtered_dipeptide_features = {k: dipeptide_features[k] for k in list(dipeptide_features.keys())[:420]} # This slice is unusual if only 400 dipeptides exist.
86
  ctd_features = CTD.CalculateCTD(sequence)
87
+ auto_features = Autocorrelation.CalculateAutoTotal(sequence) # Includes Moran, Geary, Moreau-Broto
88
+ pseudo_features = PseudoAAC.GetAPseudoAAC(sequence, lamda=9) # Pseudo Amino Acid Composition
89
 
90
+ # Combine all extracted features into a single dictionary
91
  all_features_dict.update(ctd_features)
92
  all_features_dict.update(filtered_dipeptide_features)
93
  all_features_dict.update(auto_features)
94
  all_features_dict.update(pseudo_features)
95
 
96
+ # Convert to DataFrame for consistent column handling with scaler
97
  feature_df_all = pd.DataFrame([all_features_dict])
98
+
99
+ # Handle missing features (if any arise from short sequences or specific AA combinations not producing all features)
100
+ # Ensure all selected_features are present, add as 0 if missing.
101
+ for col in selected_features:
102
+ if col not in feature_df_all.columns:
103
+ feature_df_all[col] = 0
104
+
105
+ # Normalize features using the pre-trained scaler
106
+ # Ensure the order of columns matches the scaler's training order before scaling
107
+ feature_df_all = feature_df_all[scaler.feature_names_in_] # Align columns with scaler's expected input
108
  normalized_array = scaler.transform(feature_df_all.values)
109
+
110
+ # Select only the features that the final RF model expects
111
+ selected_df = pd.DataFrame(normalized_array, columns=scaler.feature_names_in_)[selected_features].fillna(0)
112
 
113
  return selected_df.values
114
 
115
  # =====================
116
+ # MIC Predictor Function (ProtBert-based)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
117
  # =====================
118
+ def predict_mic_values(sequence, selected_bacteria_keys):
119
+ """
120
+ Predicts Minimum Inhibitory Concentration (MIC) for a given peptide sequence
121
+ against selected bacteria using ProtBert embeddings and pre-trained models.
122
+ """
123
  sequence = ''.join([aa for aa in sequence.upper() if aa in "ACDEFGHIKLMNPQRSTVWY"])
124
  if len(sequence) < 10:
125
+ return {"Error": "Sequence too short or invalid for MIC prediction."}
126
 
127
+ # Tokenize the sequence for ProtBert
128
  seq_spaced = ' '.join(list(sequence))
129
  tokens = tokenizer(seq_spaced, return_tensors="pt", padding='max_length', truncation=True, max_length=512)
130
  tokens = {k: v.to(device) for k, v in tokens.items()}
131
 
132
+ # Get ProtBert embedding
133
  with torch.no_grad():
134
  outputs = protbert_model(**tokens)
135
+ # Use mean of last hidden state as sequence embedding
136
  embedding = outputs.last_hidden_state.mean(dim=1).squeeze().cpu().numpy().reshape(1, -1)
137
 
138
+ # Configuration for MIC models (paths to joblib files)
139
  bacteria_config = {
140
+ "e_coli": { # Changed keys to match frontend values
141
+ "display_name": "E.coli",
142
+ "model_path": "coli_xgboost_model.pkl",
143
+ "scaler_path": "coli_scaler.pkl",
144
+ "pca_path": None
145
  },
146
+ "s_aureus": { # Changed keys to match frontend values
147
+ "display_name": "S.aureus",
148
+ "model_path": "aur_xgboost_model.pkl",
149
+ "scaler_path": "aur_scaler.pkl",
150
+ "pca_path": None
151
  },
152
+ "p_aeruginosa": { # Changed keys to match frontend values
153
+ "display_name": "P.aeruginosa",
154
+ "model_path": "arg_xgboost_model.pkl",
155
+ "scaler_path": "arg_scaler.pkl",
156
+ "pca_path": None
157
  },
158
+ "k_pneumoniae": { # Changed keys to match frontend values
159
+ "display_name": "K.Pneumoniae",
160
+ "model_path": "pne_mlp_model.pkl",
161
+ "scaler_path": "pne_scaler.pkl",
162
+ "pca_path": "pne_pca.pkl"
163
  }
164
  }
165
 
166
  mic_results = {}
167
+ for bacterium_key in selected_bacteria_keys:
168
+ cfg = bacteria_config.get(bacterium_key)
169
+ if not cfg:
170
+ mic_results[bacterium_key] = "Error: Invalid bacterium key"
171
+ continue
172
+
173
  try:
174
+ # Load scaler and transform embedding
175
+ scaler = joblib.load(cfg["scaler_path"])
176
+ scaled_embedding = scaler.transform(embedding)
177
+
178
+ # Apply PCA if configured
179
+ if cfg["pca_path"]:
180
+ pca = joblib.load(cfg["pca_path"])
181
+ final_features = pca.transform(scaled_embedding)
182
  else:
183
+ final_features = scaled_embedding
184
+
185
+ # Load and predict with the MIC model
186
+ mic_model = joblib.load(cfg["model_path"])
187
+ mic_log = mic_model.predict(final_features)[0]
188
+
189
+ # Convert log-transformed MIC back to original scale (µM)
190
+ mic = round(expm1(mic_log), 3) # expm1(x) is equivalent to exp(x) - 1, robust for small x
191
+ mic_results[cfg["display_name"]] = mic
192
  except Exception as e:
193
+ mic_results[cfg["display_name"]] = f"Prediction Error: {str(e)}"
194
 
195
  return mic_results
196
 
197
  # =====================
198
+ # Gradio Interface Functions
199
  # =====================
200
+
201
+ def amp_classifier_predict(sequence):
202
+ """
203
+ Function for AMP classification endpoint in Gradio.
204
+ Returns the AMP classification label, confidence, and SHAP plot Base64 string.
205
+ """
206
  features = extract_features(sequence)
207
+ if isinstance(features, str): # Handle extraction error
208
+ return gr.Label(f"Error: {features}", label="AMP Classification"), None
209
+
210
  prediction = model.predict(features)[0]
211
  probabilities = model.predict_proba(features)[0]
212
+
213
+ amp_label = "Antimicrobial Peptide (AMP)" if prediction == 0 else "Non-AMP"
214
+ confidence_value = probabilities[prediction] # Confidence of the predicted class
215
+
216
+ # Placeholder for SHAP plot generation (not implemented in this snippet)
217
+ # In a real scenario, you'd generate a SHAP plot image here (e.g., using matplotlib, shap library)
218
+ # and encode it to base64.
219
+ shap_plot_base64 = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNkYAAAAAYAAjCB0C8AAAAASUVORK5CYII=" # A tiny transparent PNG base64
220
+
221
+ # The Gradio `predict` function can return structured data as a dictionary if using `gr.JSON` output
222
+ # However, since the frontend is expecting `data[0].label`, `data[0].confidence`, etc.
223
+ # we'll return a dictionary that matches that structure.
224
+ return {
225
+ "label": amp_label,
226
+ "confidence": confidence_value,
227
+ "shap_plot_base64": shap_plot_base64 # Return SHAP plot as Base64 (placeholder for now)
228
+ }
229
+
230
+ def mic_predictor_predict(sequence, selected_bacteria):
231
+ """
232
+ Function for MIC prediction endpoint in Gradio.
233
+ Takes the sequence and a list of selected bacteria keys.
234
+ """
235
+ # Only predict MIC if AMP (Positive) classification
236
+ # This check would ideally be part of the frontend logic or a combined backend function
237
+ # but for standalone MIC endpoint, we just proceed.
238
+ # The frontend is responsible for calling this only if AMP is positive.
239
+ mic_results = predict_mic_values(sequence, selected_bacteria)
240
+ return mic_results # Returns a dictionary of MIC values
241
 
242
  # =====================
243
+ # Define Gradio Interface (hidden, for client connection)
244
  # =====================
245
+ # This Gradio app is designed to be used as a backend service by your custom HTML frontend.
246
+ # The inputs and outputs here correspond to what the frontend's `gradio.client` expects.
247
+
248
+ with gr.Blocks() as demo:
249
+ gr.Markdown("# BCBU-ZC AMP/MIC Backend Service")
250
+ gr.Markdown("This Gradio application serves as the backend for the AMP classification and MIC prediction. It provides endpoints for sequence analysis and MIC prediction.")
251
+
252
+ with gr.Tab("AMP Classification"):
253
+ gr.Markdown("### AMP Classification Endpoint (`/predict`)")
254
+ amp_input_sequence = gr.Textbox(label="Amino Acid Sequence")
255
+ amp_output_json = gr.JSON(label="Classification Result (Label, Confidence, SHAP Plot Base64)")
256
+ amp_predict_button = gr.Button("Predict AMP")
257
+ amp_predict_button.click(
258
+ fn=amp_classifier_predict,
259
+ inputs=[amp_input_sequence],
260
+ outputs=[amp_output_json],
261
+ api_name="predict" # Define an API endpoint name for `gradio.client`
262
+ )
263
+
264
+ with gr.Tab("MIC Prediction"):
265
+ gr.Markdown("### MIC Prediction Endpoint (`/predict_mic`)")
266
+ mic_input_sequence = gr.Textbox(label="Amino Acid Sequence")
267
+ mic_selected_bacteria = gr.CheckboxGroup(
268
+ label="Select Bacteria",
269
+ choices=["e_coli", "p_aeruginosa", "s_aureus", "k_pneumoniae"],
270
+ value=["e_coli", "p_aeruginosa", "s_aureus", "k_pneumoniae"] # Default for testing
271
+ )
272
+ mic_output_json = gr.JSON(label="Predicted MIC Values (µM)")
273
+ mic_predict_button = gr.Button("Predict MIC")
274
+ mic_predict_button.click(
275
+ fn=mic_predictor_predict,
276
+ inputs=[mic_input_sequence, mic_selected_bacteria],
277
+ outputs=[mic_output_json],
278
+ api_name="predict_mic" # Define a separate API endpoint name
279
+ )
280
+
281
+ # Launch the Gradio app
282
+ # `share=True` creates a public, temporary URL for external access (useful for testing frontend)
283
+ # `allowed_paths` should be set to allow access from specific origins if deploying
284
+ demo.launch(share=True)