nonzeroexit committed on
Commit
77584b9
·
verified ·
1 Parent(s): febb4a6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -43
app.py CHANGED
@@ -65,31 +65,20 @@ selected_features = ["_SolventAccessibilityC3", "_SecondaryStrC1", "_SecondarySt
65
 
66
  # LIME Explainer Setup
67
  try:
68
- # Attempt to load a real sample data for LIME background if available
69
- # e.g., sample_data = np.load(os.path.join(MODEL_DIR, 'sample_training_features_scaled.npy'))
70
  sample_data = np.random.rand(500, len(selected_features)) # Fallback: Generate random sample data
71
  except Exception:
72
  print("Warning: Could not load pre-saved sample data for LIME. Generating random sample data.")
73
- sample_data = np.random.rand(500, len(selected_features)) # Generate enough samples
74
 
75
  explainer = LimeTabularExplainer(
76
  training_data=sample_data,
77
  feature_names=selected_features,
78
- class_names=["AMP", "Non-AMP"], # Assuming 0 is AMP, 1 is Non-AMP as per model prediction
79
  mode="classification"
80
  )
81
 
82
  # --- Feature Extraction Function ---
83
  def extract_features(sequence: str) -> np.ndarray:
84
- """
85
- Extracts biochemical and compositional features from an amino acid sequence.
86
- Args:
87
- sequence (str): The amino acid sequence.
88
- Returns:
89
- np.ndarray: A scaled 2D numpy array of selected features (1, num_features).
90
- Raises:
91
- gr.Error: If the sequence is invalid or feature extraction fails.
92
- """
93
  cleaned_sequence = ''.join([aa for aa in sequence.upper() if aa in "ACDEFGHIKLMNPQRSTVWY"])
94
  if not (10 <= len(cleaned_sequence) <= 100):
95
  raise gr.Error(f"Invalid sequence length ({len(cleaned_sequence)}). Must be between 10 and 100 characters and contain only standard amino acids.")
@@ -119,17 +108,6 @@ def extract_features(sequence: str) -> np.ndarray:
119
 
120
  # --- MIC Prediction Function ---
121
  def predictmic(sequence: str, selected_bacteria_keys: list) -> dict:
122
- """
123
- Predicts Minimum Inhibitory Concentration (MIC) for selected bacteria using ProtBert embeddings.
124
- Args:
125
- sequence (str): The amino acid sequence.
126
- selected_bacteria_keys (list): List of keys for bacteria to predict MIC for (e.g., ['e_coli', 'p_aeruginosa']).
127
- Returns:
128
- dict: A dictionary where keys are bacterium keys and values are predicted MICs in µM.
129
- Returns error messages for individual bacteria if prediction fails.
130
- Raises:
131
- gr.Error: If ProtBert embedding fails or sequence is invalid.
132
- """
133
  cleaned_sequence = ''.join([aa for aa in sequence.upper() if aa in "ACDEFGHIKLMNPQRSTVWY"])
134
  if not (10 <= len(cleaned_sequence) <= 100):
135
  raise gr.Error(f"Invalid sequence length for MIC prediction ({len(cleaned_sequence)}). Must be between 10 and 100 characters.")
@@ -179,13 +157,6 @@ def predictmic(sequence: str, selected_bacteria_keys: list) -> dict:
179
 
180
  # --- LIME Plot Generation Helper ---
181
  def generate_lime_plot_base64(explanation_list: list) -> str:
182
- """
183
- Generates a LIME explanation plot and returns it as a base64 encoded PNG string.
184
- Args:
185
- explanation_list (list): The output from LimeExplanation.as_list().
186
- Returns:
187
- str: Base64 encoded PNG image string.
188
- """
189
  if not explanation_list:
190
  return ""
191
 
@@ -218,11 +189,6 @@ def generate_lime_plot_base64(explanation_list: list) -> str:
218
  # --- Gradio API Endpoints ---
219
 
220
  def classify_and_interpret_amp(sequence: str) -> dict:
221
- """
222
- Gradio API endpoint for AMP classification and interpretability (LIME).
223
- This function processes the sequence, performs classification, generates LIME explanation,
224
- and formats the output as a structured dictionary for the frontend.
225
- """
226
  try:
227
  features = extract_features(sequence)
228
 
@@ -240,8 +206,6 @@ def classify_and_interpret_amp(sequence: str) -> dict:
240
 
241
  top_features = []
242
  for feat_str, weight in explanation.as_list():
243
- # Parse the feature string from LIME (e.g., "APAAC4 <= 0.23")
244
- # This parsing is a heuristic based on LIME's default output format.
245
  parts = feat_str.split(" ", 1)
246
  feature_name = parts[0]
247
  condition = parts[1] if len(parts) > 1 else ""
@@ -267,10 +231,6 @@ def classify_and_interpret_amp(sequence: str) -> dict:
267
  raise gr.Error(f"An unexpected error occurred during AMP classification: {e}")
268
 
269
  def get_mic_predictions_api(sequence: str, selected_bacteria_keys: list) -> dict:
270
- """
271
- Gradio API endpoint for MIC prediction.
272
- This function wraps the `predictmic` function to serve as a separate API endpoint.
273
- """
274
  try:
275
  mic_results = predictmic(sequence, selected_bacteria_keys)
276
  return mic_results
@@ -312,4 +272,5 @@ with gr.Blocks() as demo:
312
  api_name="predict_mic"
313
  )
314
 
315
- demo.launch(share=True, enable_queue=True, show_api=True)
 
 
65
 
66
  # LIME Explainer Setup
67
  try:
 
 
68
  sample_data = np.random.rand(500, len(selected_features)) # Fallback: Generate random sample data
69
  except Exception:
70
  print("Warning: Could not load pre-saved sample data for LIME. Generating random sample data.")
71
+ sample_data = np.random.rand(500, len(selected_features))
72
 
73
  explainer = LimeTabularExplainer(
74
  training_data=sample_data,
75
  feature_names=selected_features,
76
+ class_names=["AMP", "Non-AMP"],
77
  mode="classification"
78
  )
79
 
80
  # --- Feature Extraction Function ---
81
  def extract_features(sequence: str) -> np.ndarray:
 
 
 
 
 
 
 
 
 
82
  cleaned_sequence = ''.join([aa for aa in sequence.upper() if aa in "ACDEFGHIKLMNPQRSTVWY"])
83
  if not (10 <= len(cleaned_sequence) <= 100):
84
  raise gr.Error(f"Invalid sequence length ({len(cleaned_sequence)}). Must be between 10 and 100 characters and contain only standard amino acids.")
 
108
 
109
  # --- MIC Prediction Function ---
110
  def predictmic(sequence: str, selected_bacteria_keys: list) -> dict:
 
 
 
 
 
 
 
 
 
 
 
111
  cleaned_sequence = ''.join([aa for aa in sequence.upper() if aa in "ACDEFGHIKLMNPQRSTVWY"])
112
  if not (10 <= len(cleaned_sequence) <= 100):
113
  raise gr.Error(f"Invalid sequence length for MIC prediction ({len(cleaned_sequence)}). Must be between 10 and 100 characters.")
 
157
 
158
  # --- LIME Plot Generation Helper ---
159
  def generate_lime_plot_base64(explanation_list: list) -> str:
 
 
 
 
 
 
 
160
  if not explanation_list:
161
  return ""
162
 
 
189
  # --- Gradio API Endpoints ---
190
 
191
  def classify_and_interpret_amp(sequence: str) -> dict:
 
 
 
 
 
192
  try:
193
  features = extract_features(sequence)
194
 
 
206
 
207
  top_features = []
208
  for feat_str, weight in explanation.as_list():
 
 
209
  parts = feat_str.split(" ", 1)
210
  feature_name = parts[0]
211
  condition = parts[1] if len(parts) > 1 else ""
 
231
  raise gr.Error(f"An unexpected error occurred during AMP classification: {e}")
232
 
233
  def get_mic_predictions_api(sequence: str, selected_bacteria_keys: list) -> dict:
 
 
 
 
234
  try:
235
  mic_results = predictmic(sequence, selected_bacteria_keys)
236
  return mic_results
 
272
  api_name="predict_mic"
273
  )
274
 
275
+ # Corrected launch command: removed 'enable_queue'
276
+ demo.launch(share=True, show_api=True)