chshan committed
Commit 456a63e · verified · 1 parent: d8d52cc

Update app.py

Files changed (1): app.py (+43, −52)

app.py CHANGED
@@ -2,7 +2,7 @@
 # -*- coding: utf-8 -*-

 # app.py - RLAnOxPeptide Gradio Web Application
-# Final version incorporating user feedback on generator logic and UI controls.
+# Final version updated to use a LoRA-finetuned model for feature extraction.

 import os
 import torch
@@ -16,6 +16,9 @@ from tqdm import tqdm
 import transformers
 import time

+# NEW DEPENDENCY: peft library for LoRA
+from peft import PeftModel
+
 # Suppress verbose logging from transformers
 transformers.logging.set_verbosity_error()

@@ -32,26 +35,27 @@ id2token = {i: t for t, i in token2id.items()}
 VOCAB_SIZE = len(token2id)


-# --- Feature Extractor Model Class (For ProtT5) ---
-class FeatureProtT5Model:
-    def __init__(self, base_model_id, finetuned_weights_path=None):
+# --- LoRA Feature Extractor Model Class ---
+# ✅ REPLACED: This new class handles loading the base model and attaching the LoRA adapter.
+class LoRAProtT5Extractor:
+    def __init__(self, base_model_id, lora_adapter_path):
         self.device = "cuda" if torch.cuda.is_available() else "cpu"
-        print(f"Initializing ProtT5 for feature extraction on device: {self.device}")
-
-        print(f"Loading base model and tokenizer from '{base_model_id}'...")
-        self.tokenizer = transformers.T5Tokenizer.from_pretrained(base_model_id, do_lower_case=False)
-        self.model = transformers.T5EncoderModel.from_pretrained(base_model_id)
-
-        if finetuned_weights_path and os.path.exists(finetuned_weights_path):
-            print(f"Applying local fine-tuned weights from: {finetuned_weights_path}")
-            state_dict = torch.load(finetuned_weights_path, map_location=self.device)
-            self.model.load_state_dict(state_dict, strict=False)
-            print("Successfully applied fine-tuned weights.")
-        else:
-            print("Warning: Fine-tuned weights not found or not provided. Using base ProtT5 weights.")
-
-        self.model.to(self.device)
+        print(f"Initializing LoRA-enhanced ProtT5 on device: {self.device}")
+
+        print(f" - Loading base model and tokenizer from '{base_model_id}'...")
+        base_model = transformers.T5EncoderModel.from_pretrained(base_model_id)
+        self.tokenizer = transformers.T5Tokenizer.from_pretrained(base_model_id)
+
+        if not os.path.exists(lora_adapter_path):
+            raise FileNotFoundError(f"Error: LoRA adapter directory not found at: {lora_adapter_path}")
+
+        print(f" - Loading and applying LoRA adapter from: {lora_adapter_path}")
+        lora_model = PeftModel.from_pretrained(base_model, lora_adapter_path)
+
+        print(" - Merging LoRA weights for faster inference...")
+        self.model = lora_model.merge_and_unload().to(self.device)
         self.model.eval()
+        print(" - LoRA-enhanced feature extractor is ready.")

     def encode(self, sequence):
         if not sequence or not isinstance(sequence, str):
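The core of this commit is the load-then-merge pattern in the new `__init__`. As a minimal standalone sketch (the model ID and adapter path are the ones this commit configures; assumes the `peft` package is installed):

```python
import transformers
from peft import PeftModel

# Load the frozen base encoder, attach the LoRA adapter, then fold the
# low-rank deltas into the base weights so inference needs no peft wrapper.
base = transformers.T5EncoderModel.from_pretrained("Rostlab/prot_t5_xl_uniref50")
model = PeftModel.from_pretrained(base, "./lora_finetuned_prott5")
model = model.merge_and_unload()  # returns a plain T5EncoderModel
model.eval()
```

`merge_and_unload()` trades adapter flexibility for speed: the merged model runs like an ordinary T5 encoder, which is why the class keeps only the merged result.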
@@ -67,7 +71,7 @@ class FeatureProtT5Model:
         emb_np = embedding.squeeze(0).cpu().numpy()
         return emb_np if emb_np.shape[0] > 0 else np.zeros((1, 1024), dtype=np.float32)

-# --- Predictor Model Architecture ---
+# --- Predictor Model Architecture (Unchanged) ---
 class AntioxidantPredictor(nn.Module):
     def __init__(self, input_dim=1914, transformer_layers=3, transformer_heads=4, transformer_dropout=0.1):
         super(AntioxidantPredictor, self).__init__()
@@ -99,7 +103,7 @@ class AntioxidantPredictor(nn.Module):
     def get_temperature(self):
         return self.temperature.item()

-# --- Generator Model Architecture ---
+# --- Generator Model Architecture (Unchanged) ---
 class ProtT5Generator(nn.Module):
     def __init__(self, vocab_size, embed_dim=512, num_layers=6, num_heads=8, dropout=0.1):
         super(ProtT5Generator, self).__init__()
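The `get_temperature()` accessor above implies the predictor is calibrated by temperature scaling. A hedged sketch of that standard scheme — dividing the logits by the learned scalar before the sigmoid is an assumption about the forward pass, which this hunk does not show:

```python
import torch

def calibrated_probability(logits: torch.Tensor, temperature: float) -> torch.Tensor:
    # Temperature scaling: soften logits by a learned scalar T > 0.
    # T > 1 pulls probabilities toward 0.5 without moving the decision boundary.
    return torch.sigmoid(logits / temperature)
```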
@@ -141,17 +145,16 @@ class ProtT5Generator(nn.Module):
         sequences.append(seq)
     return sequences

-# --- CRITICAL DEPENDENCY: feature_extract.py ---
+# --- CRITICAL DEPENDENCY: feature_extract.py (Unchanged) ---
 try:
     from feature_extract import extract_features
 except ImportError:
     raise gr.Error("Fatal Error: `feature_extract.py` not found. This file is required. Please upload it to your repository.")

-# --- Clustering Logic ---
+# --- Clustering Logic (Unchanged) ---
 def cluster_sequences(generator, sequences, num_clusters, device):
     if not sequences or len(sequences) < num_clusters:
         return sequences[:num_clusters]
-
     with torch.no_grad():
         token_ids_list = []
         max_len = max((len(seq) for seq in sequences), default=0) + 2
@@ -160,13 +163,11 @@ def cluster_sequences(generator, sequences, num_clusters, device):
             ids = [np.random.randint(2, VOCAB_SIZE)] + ids
             ids += [token2id["<PAD>"]] * (max_len - len(ids))
             token_ids_list.append(ids)
-
         input_ids = torch.tensor(token_ids_list, dtype=torch.long, device=device)
         embeddings = generator.embed_tokens(input_ids)
         mask = (input_ids != token2id["<PAD>"]).unsqueeze(-1).float()
         seq_embeds = (embeddings * mask).sum(dim=1) / (mask.sum(dim=1) + 1e-9)
         seq_embeds_np = seq_embeds.cpu().numpy()
-
         kmeans = KMeans(n_clusters=int(num_clusters), random_state=42, n_init='auto').fit(seq_embeds_np)
         representatives = []
         for i in range(int(num_clusters)):
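The function above masks out `<PAD>` positions, mean-pools the token embeddings, clusters the pooled vectors, and keeps the sequence nearest each centroid. The selection step in isolation, as a self-contained sketch (`pick_representatives` is an illustrative name, not a function in the app):

```python
import numpy as np
from sklearn.cluster import KMeans

def pick_representatives(embeddings: np.ndarray, k: int) -> list[int]:
    """Return the index of the point nearest each of k cluster centroids."""
    km = KMeans(n_clusters=k, random_state=42, n_init="auto").fit(embeddings)
    reps = []
    for i in range(k):
        members = np.where(km.labels_ == i)[0]
        if len(members) == 0:
            continue  # K-means can leave a cluster empty; skip it
        dists = np.linalg.norm(embeddings[members] - km.cluster_centers_[i], axis=1)
        reps.append(int(members[np.argmin(dists)]))
    return reps
```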
@@ -191,10 +192,12 @@ try:
     PREDICTOR_CHECKPOINT_PATH = "checkpoints/final_rl_model_logitp0.1_calibrated_FINETUNED_PROTT5.pth"
     SCALER_PATH = "checkpoints/scaler_FINETUNED_PROTT5.pkl"
     GENERATOR_CHECKPOINT_PATH = "generator_checkpoints_v3.6/final_generator_model.pth"
+
+    # ✅ UPDATED: Define paths for LoRA-based loading
     PROTT5_BASE_MODEL_ID = "Rostlab/prot_t5_xl_uniref50"
-    FINETUNED_PROTT5_FOR_FEATURES_PATH = "prott5/model/finetuned_prott5.bin"
+    LORA_ADAPTER_PATH = "./lora_finetuned_prott5"  # Assumes LoRA files are in this directory

-    # --- Load Predictor Model ---
+    # --- Load Predictor Model (Head) ---
     print(f"Loading Predictor from: {PREDICTOR_CHECKPOINT_PATH}")
     PREDICTOR_MODEL = AntioxidantPredictor(input_dim=1914)
     PREDICTOR_MODEL.load_state_dict(torch.load(PREDICTOR_CHECKPOINT_PATH, map_location=DEVICE))
@@ -202,13 +205,14 @@ try:
     PREDICTOR_MODEL.eval()
     print(f"✅ Predictor model loaded (Temp: {PREDICTOR_MODEL.get_temperature():.4f}).")

-    # --- Load Scaler & Feature Extractor ---
+    # --- Load Scaler & LoRA Feature Extractor ---
     print(f"Loading Scaler from: {SCALER_PATH}")
     SCALER = joblib.load(SCALER_PATH)
-    print("Loading ProtT5 Feature Extractor...")
-    PROTT5_EXTRACTOR = FeatureProtT5Model(
+    print("Loading LoRA-enhanced ProtT5 Feature Extractor...")
+    # UPDATED: Instantiate the new LoRA extractor class
+    PROTT5_EXTRACTOR = LoRAProtT5Extractor(
         base_model_id=PROTT5_BASE_MODEL_ID,
-        finetuned_weights_path=FINETUNED_PROTT5_FOR_FEATURES_PATH
+        lora_adapter_path=LORA_ADAPTER_PATH
     )
     print("✅ Scaler and Feature Extractor loaded.")

@@ -227,7 +231,7 @@ except Exception as e:
     raise gr.Error(f"A required model or file could not be loaded. Please check your repository file structure and paths. Error details: {e}")

 # --------------------------------------------------------------------------
-# SECTION 3: WRAPPER FUNCTIONS FOR GRADIO UI
+# SECTION 3: WRAPPER FUNCTIONS FOR GRADIO UI (Unchanged logic)
 # --------------------------------------------------------------------------

 def predict_peptide_wrapper(sequence_str):
@@ -235,6 +239,8 @@ def predict_peptide_wrapper(sequence_str):
         return "0.0000", "Error: Please enter a valid peptide sequence using standard amino acids (ACDEFGHIKLMNPQRSTVWY)."

     try:
+        # This function call remains the same because the PROTT5_EXTRACTOR object,
+        # despite its new internal logic, provides the same interface.
         features = extract_features(sequence_str.upper(), PROTT5_EXTRACTOR, L_fixed=29, d_model_pe=16)
         scaled_features = SCALER.transform(features.reshape(1, -1))

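The added comment leans on duck typing: `extract_features` never checks the extractor's class, it only calls its `encode` method. Expressed as an illustrative `typing.Protocol` (not code that exists in the app):

```python
from typing import Protocol
import numpy as np

class SequenceEncoder(Protocol):
    # Any extractor passed to extract_features just needs this method,
    # returning a (sequence_length, 1024) embedding array, as both the old
    # FeatureProtT5Model and the new LoRAProtT5Extractor do.
    def encode(self, sequence: str) -> np.ndarray: ...
```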
@@ -251,29 +257,23 @@ def predict_peptide_wrapper(sequence_str):
         return "N/A", f"An error occurred during prediction: {e}"

 def generate_peptide_wrapper(num_to_generate, min_len, max_len, temperature, diversity_factor, progress=gr.Progress()):
-    """
-    Handles the full generation-validation-clustering pipeline with a loop to ensure
-    the target number of peptides is generated.
-    """
     num_to_generate = int(num_to_generate)
     min_len = int(min_len)
     max_len = int(max_len)

-    # Safety check for length
     if min_len > max_len:
         gr.Warning("Minimum Length cannot be greater than Maximum Length. Adjusting min_len = max_len.")
         min_len = max_len

     try:
-        validated_pool = {} # Use a dictionary to store unique sequences and their probabilities
+        validated_pool = {}
         attempts = 0
-        max_attempts = 20 # Safety break to prevent infinite loops
-        generation_batch_size = 200 # Number of sequences to generate in each attempt
+        max_attempts = 20
+        generation_batch_size = 200

         while len(validated_pool) < num_to_generate and attempts < max_attempts:
             progress(len(validated_pool) / num_to_generate, desc=f"Found {len(validated_pool)} / {num_to_generate} peptides. (Attempt {attempts+1}/{max_attempts})")

-            # Generate a batch of candidate sequences
             with torch.no_grad():
                 generated_tokens = GENERATOR_MODEL.sample(
                     batch_size=generation_batch_size, max_length=max_len, device=DEVICE,
@@ -281,21 +281,18 @@ def generate_peptide_wrapper(num_to_generate, min_len, max_len, temperature, diversity_factor, progress=gr.Progress()):
                 )
             decoded_sequences = GENERATOR_MODEL.decode(generated_tokens)

-            # Filter for length and uniqueness
             new_candidates = []
             for seq in decoded_sequences:
                 if min_len <= len(seq) <= max_len:
                     if seq not in validated_pool:
                         new_candidates.append(seq)

-            # Validate the new, unique candidates
             for seq in new_candidates:
                 prob_str, _ = predict_peptide_wrapper(seq)
                 try:
                     prob = float(prob_str)
                     if prob > 0.90:
                         validated_pool[seq] = prob
-                        # Check if we have reached the target
                         if len(validated_pool) >= num_to_generate:
                             break
                 except (ValueError, TypeError):
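Stripped of the Gradio progress plumbing and the length filter, the two hunks above implement a sample-until-quota loop over a deduplicating pool. A condensed, self-contained sketch — `sample_batch` and `score` are placeholders standing in for `GENERATOR_MODEL.sample`/`decode` and the calibrated predictor:

```python
from typing import Callable, Dict, List

def fill_quota(sample_batch: Callable[[], List[str]],
               score: Callable[[str], float],
               target: int, threshold: float = 0.90,
               max_attempts: int = 20) -> Dict[str, float]:
    pool: Dict[str, float] = {}  # dict keys deduplicate sequences
    attempts = 0
    while len(pool) < target and attempts < max_attempts:
        for seq in sample_batch():
            if seq in pool:
                continue
            p = score(seq)
            if p > threshold:
                pool[seq] = p
                if len(pool) >= target:
                    break
        attempts += 1  # safety break prevents an infinite loop
    return pool
```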
@@ -311,13 +308,9 @@ def generate_peptide_wrapper(num_to_generate, min_len, max_len, temperature, diversity_factor, progress=gr.Progress()):
         if not validated_pool:
             return pd.DataFrame([{"Sequence": "Could not generate any high-activity peptides (>0.9 prob) with the current settings. Try different parameters.", "Predicted Probability": "N/A"}])

-        # --- Final Processing ---
         high_quality_sequences = list(validated_pool.keys())
-
-        # Cluster to ensure diversity, selecting up to the target number
         final_diverse_seqs = cluster_sequences(GENERATOR_MODEL, high_quality_sequences, num_to_generate, DEVICE)

-        # Format final results into a DataFrame
         final_results = [(seq, f"{validated_pool[seq]:.4f}") for seq in final_diverse_seqs]
         final_results.sort(key=lambda x: float(x[1]), reverse=True)

@@ -328,7 +321,7 @@ def generate_peptide_wrapper(num_to_generate, min_len, max_len, temperature, diversity_factor, progress=gr.Progress()):
         return pd.DataFrame([{"Sequence": f"An error occurred during generation: {e}", "Predicted Probability": "N/A"}])

 # --------------------------------------------------------------------------
-# SECTION 4: GRADIO UI CONSTRUCTION
+# SECTION 4: GRADIO UI CONSTRUCTION (Unchanged)
 # --------------------------------------------------------------------------
 with gr.Blocks(theme=gr.themes.Soft(), title="RLAnOxPeptide") as demo:
     gr.Markdown("# RLAnOxPeptide: Intelligent Peptide Design and Prediction")
@@ -364,7 +357,6 @@ with gr.Blocks(theme=gr.themes.Soft(), title="RLAnOxPeptide") as demo:
             with gr.Column():
                 with gr.Row():
                     num_input = gr.Slider(minimum=5, maximum=50, value=10, step=1, label="Number of Final Peptides to Generate")
-                    # ✅ MODIFIED: Length sliders both have a range of 2-20
                     min_len_input = gr.Slider(minimum=2, maximum=20, value=3, step=1, label="Minimum Length")
                     max_len_input = gr.Slider(minimum=2, maximum=20, value=20, step=1, label="Maximum Length")
                 with gr.Row():
@@ -374,10 +366,9 @@ with gr.Blocks(theme=gr.themes.Soft(), title="RLAnOxPeptide") as demo:
     generate_button = gr.Button("Generate Peptides", variant="primary")
     results_output = gr.DataFrame(headers=["Sequence", "Predicted Probability"], label="Generated & Validated Peptides (>90% Probability)", wrap=True)

-    # ✅ ADDED: Dynamic linking of min and max length sliders for better UX
     def update_min_len_range(max_len):
         return gr.Slider(maximum=max_len)
-    max_len_input.change(fn=update_min_len_range, inputs=max_len_input, outputs=min_len_input)
+    max_len_input.change(fn=update_min_len_range, inputs=max_len_input, outputs=max_len_input)

     def update_max_len_range(min_len):
         return gr.Slider(minimum=min_len)
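One caution on this last hunk: the re-added `.change()` call now routes `update_min_len_range` back to `max_len_input` itself rather than to `min_len_input` as before, which would cap the max slider at its own current value instead of constraining the min slider. A minimal standalone sketch of two-way slider linking in which each callback targets the other slider (the usual intent):

```python
import gradio as gr

with gr.Blocks() as demo:
    min_len = gr.Slider(2, 20, value=3, step=1, label="Minimum Length")
    max_len = gr.Slider(2, 20, value=20, step=1, label="Maximum Length")
    # When max changes, cap the min slider's range at the new max.
    max_len.change(lambda m: gr.Slider(maximum=m), inputs=max_len, outputs=min_len)
    # When min changes, floor the max slider's range at the new min.
    min_len.change(lambda m: gr.Slider(minimum=m), inputs=min_len, outputs=max_len)

# demo.launch()  # uncomment to run locally
```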
 