chshan committed
Commit f1d641f · verified · 1 Parent(s): 6a02daf

Update app.py

Files changed (1)
  1. app.py +79 -106
app.py CHANGED
@@ -2,22 +2,7 @@
 # -*- coding: utf-8 -*-
 
 # app.py - RLAnOxPeptide Gradio Web Application
-# This script combines logic from predictor.py, generator.py, and the original app.py
-# into a single, self-contained file for a Hugging Face Space.
-#
-# REQUIRED FILE STRUCTURE IN HUGGING FACE REPO:
-# .
-# ├── app.py (This file)
-# ├── feature_extract.py (CRITICAL: This file with your `extract_features` function MUST be present)
-# ├── checkpoints/
-# │   ├── final_rl_model_logitp0.1_calibrated_FINETUNED_PROTT5.pth
-# │   └── scaler_FINETUNED_PROTT5.pkl
-# ├── generator_checkpoints_v3.6/
-# │   └── final_generator_model.pth
-# ├── prott5/
-# │   └── model/
-# │       └── finetuned_prott5.bin (Your fine-tuned feature extractor weights)
-# └── requirements.txt
+# Final version incorporating user feedback on generator logic and UI controls.
 
 import os
 import torch
@@ -29,15 +14,16 @@ import gradio as gr
 from sklearn.cluster import KMeans
 from tqdm import tqdm
 import transformers
+import time
 
-# Suppress verbose logging from transformers, which can clutter the app logs
+# Suppress verbose logging from transformers
 transformers.logging.set_verbosity_error()
 
 # --------------------------------------------------------------------------
 # SECTION 1: CORE CLASS AND FUNCTION DEFINITIONS
 # --------------------------------------------------------------------------
 
-# --- Vocabulary Definition (Consistent across all scripts) ---
+# --- Vocabulary Definition ---
 AMINO_ACIDS = "ACDEFGHIKLMNPQRSTVWY"
 token2id = {aa: i + 2 for i, aa in enumerate(AMINO_ACIDS)}
 token2id["<PAD>"] = 0
@@ -47,19 +33,15 @@ VOCAB_SIZE = len(token2id)
 
 
 # --- Feature Extractor Model Class (For ProtT5) ---
-# MODIFIED: This class now loads the base model from the Hugging Face Hub ID
-# and then applies your local fine-tuned weights.
 class FeatureProtT5Model:
     def __init__(self, base_model_id, finetuned_weights_path=None):
         self.device = "cuda" if torch.cuda.is_available() else "cpu"
         print(f"Initializing ProtT5 for feature extraction on device: {self.device}")
 
-        # Load the base model architecture and tokenizer directly from the Hub ID.
         print(f"Loading base model and tokenizer from '{base_model_id}'...")
         self.tokenizer = transformers.T5Tokenizer.from_pretrained(base_model_id, do_lower_case=False)
         self.model = transformers.T5EncoderModel.from_pretrained(base_model_id)
 
-        # If a path to a fine-tuned weights file is provided, load and apply those weights.
         if finetuned_weights_path and os.path.exists(finetuned_weights_path):
             print(f"Applying local fine-tuned weights from: {finetuned_weights_path}")
             state_dict = torch.load(finetuned_weights_path, map_location=self.device)
@@ -71,46 +53,28 @@ class FeatureProtT5Model:
         self.model.to(self.device)
         self.model.eval()
 
-    # ✅ NEWLY ADDED METHOD: This provides the functionality to encode sequences.
     def encode(self, sequence):
-        """
-        Takes a peptide sequence string and returns its ProtT5 embedding.
-        """
-        # The extract_features function expects this method to exist.
         if not sequence or not isinstance(sequence, str):
-            # Return a zero vector of the correct shape if input is invalid
             return np.zeros((1, 1024), dtype=np.float32)
 
-        # ProtT5 expects amino acids to be separated by spaces.
         seq_spaced = " ".join(list(sequence))
-
-        # Tokenize the input sequence.
         encoded_input = self.tokenizer(seq_spaced, return_tensors='pt', padding=True, truncation=True)
         encoded_input = {k: v.to(self.device) for k, v in encoded_input.items()}
-
-        # Get embeddings from the model.
+
         with torch.no_grad():
             embedding = self.model(**encoded_input).last_hidden_state
 
-        # Move the embedding to CPU and convert to a NumPy array.
-        # Squeeze to remove the batch dimension.
         emb_np = embedding.squeeze(0).cpu().numpy()
-
-        # Handle cases where the embedding might be empty.
         return emb_np if emb_np.shape[0] > 0 else np.zeros((1, 1024), dtype=np.float32)
 
-
-
 # --- Predictor Model Architecture ---
-# This is the antioxidant activity predictor model. Its architecture must
-# exactly match the architecture used to save the checkpoint file.
 class AntioxidantPredictor(nn.Module):
     def __init__(self, input_dim=1914, transformer_layers=3, transformer_heads=4, transformer_dropout=0.1):
         super(AntioxidantPredictor, self).__init__()
         self.prott5_dim = 1024
         self.handcrafted_dim = input_dim - self.prott5_dim
         self.seq_len = 16
-        self.prott5_feature_dim = 64  # 16 * 64 = 1024
+        self.prott5_feature_dim = 64
 
         encoder_layer = nn.TransformerEncoderLayer(d_model=self.prott5_feature_dim, nhead=transformer_heads, dropout=transformer_dropout, batch_first=True)
         self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=transformer_layers)
@@ -122,26 +86,20 @@ class AntioxidantPredictor(nn.Module):
 
     def forward(self, x):
         batch_size = x.size(0)
-        # The input 'x' is a flat 1914-dim vector from extract_features()
         prot_t5_features = x[:, :self.prott5_dim]
         handcrafted_features = x[:, self.prott5_dim:]
-
-        # Reshape the first 1024 features back into a sequence representation
         prot_t5_seq = prot_t5_features.view(batch_size, self.seq_len, self.prott5_feature_dim)
-
         encoded_seq = self.transformer_encoder(prot_t5_seq)
         refined_prott5 = encoded_seq.mean(dim=1)
-
         fused_features = torch.cat([refined_prott5, handcrafted_features], dim=1)
         fused_output = self.fusion_fc(fused_features)
         logits = self.classifier(fused_output)
-
         return logits / self.temperature
 
     def get_temperature(self):
         return self.temperature.item()
 
-# --- Generator Model Architecture (from generator.py) ---
+# --- Generator Model Architecture ---
 class ProtT5Generator(nn.Module):
     def __init__(self, vocab_size, embed_dim=512, num_layers=6, num_heads=8, dropout=0.1):
         super(ProtT5Generator, self).__init__()
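Aside on `return logits / self.temperature` above: this is standard temperature scaling, with the calibrated temperature stored in the checkpoint. A minimal sketch of how a displayed probability would follow from the scaled logit, assuming a single-logit classifier head passed through a sigmoid (the head's exact shape is not visible in this diff, and the numbers are made up):

import torch

raw_logit = torch.tensor([2.0])              # hypothetical classifier output
temperature = 1.5                            # hypothetical calibrated value
calibrated_logit = raw_logit / temperature   # what forward() returns
probability = torch.sigmoid(calibrated_logit)
print(f"{probability.item():.4f}")           # 0.7914; the app keeps peptides > 0.90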
@@ -167,7 +125,6 @@ class ProtT5Generator(nn.Module):
             next_logits = logits[:, -1, :] / temperature
             if generated.size(1) < min_decoded_length:
                 next_logits[:, self.eos_token_id] = -float("inf")
-
             probs = torch.softmax(next_logits, dim=-1)
             next_token = torch.multinomial(probs, num_samples=1)
             generated = torch.cat((generated, next_token), dim=1)
@@ -177,7 +134,7 @@ class ProtT5Generator(nn.Module):
         sequences = []
         for ids_tensor in token_ids_batch:
             seq = ""
-            for token_id in ids_tensor.tolist()[1:]:  # Skip the random start token
+            for token_id in ids_tensor.tolist()[1:]:
                 if token_id == self.eos_token_id: break
                 if token_id == self.pad_token_id: continue
                 seq += id2token.get(token_id, "")
@@ -185,15 +142,12 @@
         return sequences
 
 # --- CRITICAL DEPENDENCY: feature_extract.py ---
-# This application requires a function named `extract_features` to convert a peptide
-# sequence into a 1914-dimensional feature vector for the prediction model.
-# This function must be defined in a file named `feature_extract.py` in the repository root.
 try:
     from feature_extract import extract_features
 except ImportError:
-    raise gr.Error("Fatal Error: `feature_extract.py` not found. This file is required for the application to run. Please upload it to your repository.")
+    raise gr.Error("Fatal Error: `feature_extract.py` not found. This file is required. Please upload it to your repository.")
 
-# --- Clustering Logic (from generator.py) ---
+# --- Clustering Logic ---
 def cluster_sequences(generator, sequences, num_clusters, device):
     if not sequences or len(sequences) < num_clusters:
         return sequences[:num_clusters]
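For readers without the repository checked out: the contract `feature_extract.py` must satisfy, as implied by this file, is a single `extract_features(sequence, prott5_model, L_fixed, d_model_pe)` call returning a flat 1914-dimensional vector, whose first 1024 entries the predictor reshapes into a 16 x 64 sequence and whose remaining 890 entries are handcrafted descriptors. A minimal stub sketch, assuming mean pooling for the ProtT5 part and zero placeholders for the handcrafted part (the real implementation computes actual descriptors):

import numpy as np

def extract_features(sequence, prott5_model, L_fixed=29, d_model_pe=16):
    # ProtT5 block: per-residue embedding from encode(), mean-pooled to 1024 dims.
    emb = prott5_model.encode(sequence)                # shape (len(sequence), 1024)
    prott5_part = emb.mean(axis=0).astype(np.float32)
    # Handcrafted block: 1914 - 1024 = 890 dims; zeros stand in for the real
    # descriptors (L_fixed and d_model_pe are unused in this illustrative stub).
    handcrafted_part = np.zeros(890, dtype=np.float32)
    return np.concatenate([prott5_part, handcrafted_part])  # shape (1914,)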
@@ -203,7 +157,7 @@ def cluster_sequences(generator, sequences, num_clusters, device):
     max_len = max((len(seq) for seq in sequences), default=0) + 2
     for seq in sequences:
         ids = [token2id.get(aa, 0) for aa in seq] + [generator.eos_token_id]
-        ids = [np.random.randint(2, VOCAB_SIZE)] + ids  # Prepend a start token
+        ids = [np.random.randint(2, VOCAB_SIZE)] + ids
         ids += [token2id["<PAD>"]] * (max_len - len(ids))
         token_ids_list.append(ids)
 
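`cluster_sequences` is only partially visible in this diff: it tokenizes each surviving sequence, embeds it with the generator, and KMeans-clusters the embeddings so the final picks are diverse. A sketch of the selection step under that reading, choosing one representative per cluster by distance to the centroid (the helper name and details are illustrative, not the repository's code):

import numpy as np
from sklearn.cluster import KMeans

def pick_one_per_cluster(embeddings, sequences, k):
    # embeddings: (N, D) array, one row per candidate sequence
    km = KMeans(n_clusters=k, n_init=10).fit(embeddings)
    picks = []
    for c in range(k):
        members = np.where(km.labels_ == c)[0]
        dists = np.linalg.norm(embeddings[members] - km.cluster_centers_[c], axis=1)
        picks.append(sequences[members[np.argmin(dists)]])
    return picks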
@@ -233,12 +187,10 @@ print("--- Starting Application: Loading all models and dependencies ---")
 DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 
 try:
-    # --- Define file paths relative to the repository root ---
+    # --- Define file paths ---
     PREDICTOR_CHECKPOINT_PATH = "checkpoints/final_rl_model_logitp0.1_calibrated_FINETUNED_PROTT5.pth"
     SCALER_PATH = "checkpoints/scaler_FINETUNED_PROTT5.pkl"
     GENERATOR_CHECKPOINT_PATH = "generator_checkpoints_v3.6/final_generator_model.pth"
-
-    # Define the base model ID from the Hub and the path to your local fine-tuned weights.
     PROTT5_BASE_MODEL_ID = "Rostlab/prot_t5_xl_uniref50"
     FINETUNED_PROTT5_FOR_FEATURES_PATH = "prott5/model/finetuned_prott5.bin"
 
@@ -254,7 +206,6 @@ try:
    print(f"Loading Scaler from: {SCALER_PATH}")
    SCALER = joblib.load(SCALER_PATH)
    print("Loading ProtT5 Feature Extractor...")
-    # Pass the Hub ID to the updated class to load the base model.
    PROTT5_EXTRACTOR = FeatureProtT5Model(
        base_model_id=PROTT5_BASE_MODEL_ID,
        finetuned_weights_path=FINETUNED_PROTT5_FOR_FEATURES_PATH
@@ -280,16 +231,11 @@ except Exception as e:
 # --------------------------------------------------------------------------
 
 def predict_peptide_wrapper(sequence_str):
-    """Handles the prediction for a single peptide sequence from the UI."""
     if not sequence_str or not isinstance(sequence_str, str) or any(c not in AMINO_ACIDS for c in sequence_str.upper()):
         return "0.0000", "Error: Please enter a valid peptide sequence using standard amino acids (ACDEFGHIKLMNPQRSTVWY)."
 
     try:
-        # Use the imported extract_features function.
-        # The L_fixed and d_model_pe values are taken from your original predictor.py arguments.
         features = extract_features(sequence_str.upper(), PROTT5_EXTRACTOR, L_fixed=29, d_model_pe=16)
-
-        # Scale the features using the loaded scaler
         scaled_features = SCALER.transform(features.reshape(1, -1))
 
         with torch.no_grad():
@@ -304,57 +250,74 @@
         print(f"Prediction Error for sequence '{sequence_str}': {e}")
         return "N/A", f"An error occurred during prediction: {e}"
 
-def generate_peptide_wrapper(num_to_generate, min_len, max_len, temperature, diversity_factor, progress=gr.Progress(track_tqdm=True)):
-    """Handles the full generation-validation-clustering pipeline."""
+def generate_peptide_wrapper(num_to_generate, min_len, max_len, temperature, diversity_factor, progress=gr.Progress()):
+    """
+    Handles the full generation-validation-clustering pipeline with a loop to ensure
+    the target number of peptides is generated.
+    """
     num_to_generate = int(num_to_generate)
     min_len = int(min_len)
     max_len = int(max_len)
+
+    # Safety check for length
+    if min_len > max_len:
+        gr.Warning("Minimum Length cannot be greater than Maximum Length. Adjusting min_len = max_len.")
+        min_len = max_len
 
     try:
-        # Step 1: Generate a large, unique pool of candidate sequences
-        target_pool_size = int(num_to_generate * diversity_factor)
-        unique_seqs = set()
-
-        pbar_desc = "Step 1/3: Generating candidate sequences"
-        with tqdm(total=target_pool_size, desc=pbar_desc) as pbar:
-            while len(unique_seqs) < target_pool_size:
-                batch_size = max(1, (target_pool_size - len(unique_seqs)))
-                with torch.no_grad():
-                    generated_tokens = GENERATOR_MODEL.sample(
-                        batch_size=batch_size, max_length=max_len, device=DEVICE,
-                        temperature=temperature, min_decoded_length=min_len
-                    )
-                decoded_sequences = GENERATOR_MODEL.decode(generated_tokens)
-
-                initial_count = len(unique_seqs)
-                for seq in decoded_sequences:
-                    if min_len <= len(seq) <= max_len:
-                        unique_seqs.add(seq)
-                pbar.update(len(unique_seqs) - initial_count)
-
-        candidate_seqs = list(unique_seqs)
-
-        # Step 2: Validate the generated sequences and filter for high probability
-        validated_pool = {}
-        for seq in tqdm(candidate_seqs, desc="Step 2/3: Validating generated sequences"):
-            prob_str, _ = predict_peptide_wrapper(seq)
-            try:
-                prob = float(prob_str)
-                if prob > 0.90:
-                    validated_pool[seq] = prob
-            except (ValueError, TypeError):
-                continue
+        validated_pool = {}  # Use a dictionary to store unique sequences and their probabilities
+        attempts = 0
+        max_attempts = 20  # Safety break to prevent infinite loops
+        generation_batch_size = 200  # Number of sequences to generate in each attempt
+
+        while len(validated_pool) < num_to_generate and attempts < max_attempts:
+            progress(len(validated_pool) / num_to_generate, desc=f"Found {len(validated_pool)} / {num_to_generate} peptides. (Attempt {attempts+1}/{max_attempts})")
+
+            # Generate a batch of candidate sequences
+            with torch.no_grad():
+                generated_tokens = GENERATOR_MODEL.sample(
+                    batch_size=generation_batch_size, max_length=max_len, device=DEVICE,
+                    temperature=temperature, min_decoded_length=min_len
+                )
+            decoded_sequences = GENERATOR_MODEL.decode(generated_tokens)
+
+            # Filter for length and uniqueness
+            new_candidates = []
+            for seq in decoded_sequences:
+                if min_len <= len(seq) <= max_len:
+                    if seq not in validated_pool:
+                        new_candidates.append(seq)
+
+            # Validate the new, unique candidates
+            for seq in new_candidates:
+                prob_str, _ = predict_peptide_wrapper(seq)
+                try:
+                    prob = float(prob_str)
+                    if prob > 0.90:
+                        validated_pool[seq] = prob
+                        # Check if we have reached the target
+                        if len(validated_pool) >= num_to_generate:
+                            break
+                except (ValueError, TypeError):
+                    continue
+
+            attempts += 1
+            if len(validated_pool) >= num_to_generate:
+                break
 
+        progress(1.0, desc=f"Collected {len(validated_pool)} high-quality peptides. Clustering for diversity...")
+        time.sleep(1)
+
         if not validated_pool:
-            return pd.DataFrame([{"Sequence": "No high-activity peptides (>0.9 prob) were generated. Try increasing the Diversity Factor or changing the Temperature.", "Predicted Probability": "N/A"}])
+            return pd.DataFrame([{"Sequence": "Could not generate any high-activity peptides (>0.9 prob) with the current settings. Try different parameters.", "Predicted Probability": "N/A"}])
 
+        # --- Final Processing ---
         high_quality_sequences = list(validated_pool.keys())
 
-        # Step 3: Cluster to ensure diversity in the final set
-        progress(1.0, desc="Step 3/3: Clustering for diversity...")
+        # Cluster to ensure diversity, selecting up to the target number
         final_diverse_seqs = cluster_sequences(GENERATOR_MODEL, high_quality_sequences, num_to_generate, DEVICE)
 
-        # Step 4: Format final results into a DataFrame
+        # Format final results into a DataFrame
         final_results = [(seq, f"{validated_pool[seq]:.4f}") for seq in final_diverse_seqs]
         final_results.sort(key=lambda x: float(x[1]), reverse=True)
 
@@ -401,8 +364,9 @@ with gr.Blocks(theme=gr.themes.Soft(), title="RLAnOxPeptide") as demo:
         with gr.Column():
             with gr.Row():
                 num_input = gr.Slider(minimum=5, maximum=50, value=10, step=1, label="Number of Final Peptides to Generate")
-                min_len_input = gr.Slider(minimum=3, maximum=10, value=3, step=1, label="Minimum Length")
-                max_len_input = gr.Slider(minimum=10, maximum=20, value=20, step=1, label="Maximum Length")
+                # ✅ MODIFIED: Length sliders both have a range of 2-20
+                min_len_input = gr.Slider(minimum=2, maximum=20, value=3, step=1, label="Minimum Length")
+                max_len_input = gr.Slider(minimum=2, maximum=20, value=20, step=1, label="Maximum Length")
             with gr.Row():
                 temp_input = gr.Slider(minimum=0.5, maximum=3.0, value=2.5, step=0.1, label="Temperature (Higher = More random)")
                 diversity_input = gr.Slider(minimum=1.1, maximum=5.0, value=1.5, step=0.1, label="Diversity Factor (Larger initial pool for clustering)")
@@ -410,6 +374,15 @@ with gr.Blocks(theme=gr.themes.Soft(), title="RLAnOxPeptide") as demo:
         generate_button = gr.Button("Generate Peptides", variant="primary")
         results_output = gr.DataFrame(headers=["Sequence", "Predicted Probability"], label="Generated & Validated Peptides (>90% Probability)", wrap=True)
 
+        # ✅ ADDED: Dynamic linking of min and max length sliders for better UX
+        def update_min_len_range(max_len):
+            return gr.Slider(maximum=max_len)
+        max_len_input.change(fn=update_min_len_range, inputs=max_len_input, outputs=min_len_input)
+
+        def update_max_len_range(min_len):
+            return gr.Slider(minimum=min_len)
+        min_len_input.change(fn=update_max_len_range, inputs=min_len_input, outputs=max_len_input)
+
         generate_button.click(
             fn=generate_peptide_wrapper,
             inputs=[num_input, min_len_input, max_len_input, temp_input, diversity_input],