chshan commited on
Commit
bffbeec
·
verified ·
1 Parent(s): a96a115

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +52 -91
app.py CHANGED
@@ -1,5 +1,4 @@
1
- # app.py - RLAnOxPeptide Gradio Web Application
2
- # This script integrates both the predictor and generator into a user-friendly web UI.
3
 
4
  import os
5
  import torch
@@ -17,11 +16,10 @@ transformers.logging.set_verbosity_error()
17
 
18
  # --------------------------------------------------------------------------
19
  # SECTION 1: CORE CLASS AND FUNCTION DEFINITIONS
20
- # To make this app self-contained, we copy necessary class definitions here.
21
- # These should match the versions used during training.
22
  # --------------------------------------------------------------------------
23
 
24
- # --- Vocabulary Definition (from both scripts) ---
25
  AMINO_ACIDS = "ACDEFGHIKLMNPQRSTVWY"
26
  token2id = {aa: i + 2 for i, aa in enumerate(AMINO_ACIDS)}
27
  token2id["<PAD>"] = 0
@@ -29,16 +27,13 @@ token2id["<EOS>"] = 1
29
  id2token = {i: t for t, i in token2id.items()}
30
  VOCAB_SIZE = len(token2id)
31
 
32
- # --- Predictor Model Architecture (Corrected to match saved weights) ---
33
  class AntioxidantPredictor(nn.Module):
34
  def __init__(self, input_dim, transformer_layers=3, transformer_heads=4, transformer_dropout=0.1):
35
  super(AntioxidantPredictor, self).__init__()
36
- # 根据错误日志和您的训练脚本,我们知道输入维度是固定的
37
- # 并且模型内部处理 ProtT5 和传统特征的分离
38
  self.t5_dim = 1024
39
  self.hand_crafted_dim = input_dim - self.t5_dim
40
 
41
- # 定义 Transformer Encoder
42
  encoder_layer = nn.TransformerEncoderLayer(
43
  d_model=self.t5_dim,
44
  nhead=transformer_heads,
@@ -47,9 +42,6 @@ class AntioxidantPredictor(nn.Module):
47
  )
48
  self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=transformer_layers)
49
 
50
- # 定义 MLP
51
- # 错误日志表明权重文件没有 fusion_fc 和 classifier,只有一个 mlp
52
- # 我们根据 predictor_train.py 的原始结构来重建
53
  self.mlp = nn.Sequential(
54
  nn.Linear(input_dim, 512),
55
  nn.ReLU(),
@@ -62,19 +54,15 @@ class AntioxidantPredictor(nn.Module):
62
  self.temperature = nn.Parameter(torch.ones(1))
63
 
64
  def forward(self, fused_features):
65
- # 这个前向传播逻辑与您的训练脚本 predictor_train.py 更为匹配
66
  prot_t5_features = fused_features[:, :self.t5_dim]
67
  hand_crafted_features = fused_features[:, self.t5_dim:]
68
 
69
- # Transformer 只处理 ProtT5 特征
70
  prot_t5_features_unsqueezed = prot_t5_features.unsqueeze(1)
71
  transformer_output = self.transformer_encoder(prot_t5_features_unsqueezed)
72
  transformer_output_pooled = transformer_output.mean(dim=1)
73
 
74
- # 将处理后的 ProtT5 特征与传统特征拼接
75
  combined_features = torch.cat((transformer_output_pooled, hand_crafted_features), dim=1)
76
 
77
- # 将最终拼接的特征送入 MLP
78
  logits = self.mlp(combined_features)
79
 
80
  return logits / self.temperature
@@ -87,7 +75,6 @@ class AntioxidantPredictor(nn.Module):
87
 
88
  # --- Generator Model Architecture (from generator.py) ---
89
  class ProtT5Generator(nn.Module):
90
- # This class definition should be an exact copy from your project
91
  def __init__(self, vocab_size, embed_dim=512, num_layers=6, num_heads=8, dropout=0.1):
92
  super(ProtT5Generator, self).__init__()
93
  self.embed_tokens = nn.Embedding(vocab_size, embed_dim, padding_idx=token2id["<PAD>"])
@@ -97,7 +84,7 @@ class ProtT5Generator(nn.Module):
97
  self.vocab_size = vocab_size
98
  self.eos_token_id = token2id["<EOS>"]
99
  self.pad_token_id = token2id["<PAD>"]
100
-
101
  def forward(self, input_ids):
102
  embeddings = self.embed_tokens(input_ids)
103
  encoder_output = self.encoder(embeddings)
@@ -112,12 +99,11 @@ class ProtT5Generator(nn.Module):
112
  next_logits = logits[:, -1, :] / temperature
113
  if generated.size(1) < min_decoded_length:
114
  next_logits[:, self.eos_token_id] = -float("inf")
115
-
116
  probs = torch.softmax(next_logits, dim=-1)
117
  next_token = torch.multinomial(probs, num_samples=1)
118
  generated = torch.cat((generated, next_token), dim=1)
119
-
120
- if (next_token == self.eos_token_id).all():
121
  break
122
  return generated
123
 
@@ -133,9 +119,7 @@ class ProtT5Generator(nn.Module):
133
  seqs.append(seq)
134
  return seqs
135
 
136
- # --- Feature Extraction Logic (from feature_extract.py) ---
137
- # Note: You need the actual ProtT5Model and extract_features here.
138
- # Assuming they are in a file named `feature_extract.py` in the same directory.
139
  try:
140
  from feature_extract import ProtT5Model as FeatureProtT5Model, extract_features
141
  except ImportError:
@@ -147,11 +131,10 @@ def cluster_sequences(generator, sequences, num_clusters, device):
147
  return sequences[:num_clusters]
148
  with torch.no_grad():
149
  token_ids_list = []
150
- max_len = max(len(seq) for seq in sequences) + 2 # Start token + EOS
151
  for seq in sequences:
152
- # Recreate encoding to match how generator sees it (with start token)
153
  ids = [token2id.get(aa, 0) for aa in seq] + [generator.eos_token_id]
154
- ids = [np.random.randint(2, VOCAB_SIZE)] + ids # Add a dummy start token
155
  ids += [token2id["<PAD>"]] * (max_len - len(ids))
156
  token_ids_list.append(ids)
157
 
@@ -175,55 +158,44 @@ def cluster_sequences(generator, sequences, num_clusters, device):
175
  representatives.append(sequences[representative_index])
176
  return representatives
177
 
178
-
179
  # --------------------------------------------------------------------------
180
  # SECTION 2: GLOBAL MODEL LOADING
181
- # Load all models and dependencies once when the app starts.
182
  # --------------------------------------------------------------------------
183
  print("Loading all models and dependencies. Please wait...")
184
- DEVICE = "cpu" # Use CPU for compatibility with Hugging Face free tier
185
 
186
  try:
187
- # --- Define all required file paths here ---
188
- # !! IMPORTANT: Ensure these are relative paths to the files in your Space !!
189
  PREDICTOR_CHECKPOINT_PATH = "checkpoints/final_rl_model_logitp0.1_calibrated_FINETUNED_PROTT5.pth"
190
  SCALER_PATH = "checkpoints/scaler_FINETUNED_PROTT5.pkl"
191
  GENERATOR_CHECKPOINT_PATH = "generator_checkpoints_v3.6/final_generator_model.pth"
192
  PROTT5_BASE_MODEL_PATH = "prott5/model/"
193
  FINETUNED_PROTT5_FOR_FEATURES_PATH = "prott5/model/finetuned_prott5.bin"
194
 
195
- # --- Load Predictor Components ---
196
  print("Loading Predictor Model...")
197
- PREDICTOR_MODEL = AntioxidantPredictor(
198
- input_dim=1914, transformer_layers=3, transformer_heads=4, transformer_dropout=0.1
199
- )
200
  PREDICTOR_MODEL.load_state_dict(torch.load(PREDICTOR_CHECKPOINT_PATH, map_location=DEVICE))
201
  PREDICTOR_MODEL.to(DEVICE)
202
  PREDICTOR_MODEL.eval()
203
- print("✅ Predictor model loaded.")
204
 
205
- print("Loading Scaler...")
 
206
  SCALER = joblib.load(SCALER_PATH)
207
- print("✅ Scaler loaded.")
208
-
209
- print("Loading ProtT5 Feature Extractor...")
210
- # This extractor must use the fine-tuned model for features, as per your training logic
211
  PROTT5_EXTRACTOR = FeatureProtT5Model(
212
  model_path=PROTT5_BASE_MODEL_PATH,
213
  finetuned_model_file=FINETUNED_PROTT5_FOR_FEATURES_PATH
214
  )
215
- print("✅ ProtT5 Feature Extractor loaded.")
216
 
217
- # --- Load Generator Model ---
218
  print("Loading Generator Model...")
219
- GENERATOR_MODEL = ProtT5Generator(
220
- vocab_size=VOCAB_SIZE, embed_dim=512, num_layers=6, num_heads=8, dropout=0.1
221
- )
222
  GENERATOR_MODEL.load_state_dict(torch.load(GENERATOR_CHECKPOINT_PATH, map_location=DEVICE))
223
  GENERATOR_MODEL.to(DEVICE)
224
  GENERATOR_MODEL.eval()
225
  print("✅ Generator model loaded.")
226
-
227
  print("\n--- All models loaded successfully! Gradio app is ready. ---\n")
228
 
229
  except Exception as e:
@@ -232,22 +204,17 @@ except Exception as e:
232
 
233
  # --------------------------------------------------------------------------
234
  # SECTION 3: WRAPPER FUNCTIONS FOR GRADIO
235
- # These functions connect the UI to our model's logic.
236
  # --------------------------------------------------------------------------
237
 
238
  def predict_peptide_wrapper(sequence_str):
239
- """Takes a peptide sequence string and returns its predicted probability and class."""
240
  if not sequence_str or not isinstance(sequence_str, str) or any(c not in AMINO_ACIDS for c in sequence_str.upper()):
241
  return "0.0000", "Error: Please enter a valid sequence with standard amino acids."
242
 
243
  try:
244
- # 1. Extract features using the same logic as training/prediction scripts
245
- features = extract_features(sequence_str, PROTT5_EXTRACTOR)
246
-
247
- # 2. Scale features
248
  scaled_features = SCALER.transform(features.reshape(1, -1))
249
 
250
- # 3. Predict with the model
251
  with torch.no_grad():
252
  features_tensor = torch.tensor(scaled_features, dtype=torch.float32).to(DEVICE)
253
  logits = PREDICTOR_MODEL(features_tensor)
@@ -261,47 +228,46 @@ def predict_peptide_wrapper(sequence_str):
261
  return "N/A", f"An error occurred during processing: {e}"
262
 
263
  def generate_peptide_wrapper(num_to_generate, min_len, max_len, temperature, diversity_factor, progress=gr.Progress(track_tqdm=True)):
264
- """Generates, validates, and clusters sequences."""
265
  num_to_generate = int(num_to_generate)
266
  min_len = int(min_len)
267
  max_len = int(max_len)
268
 
269
  try:
270
- # STEP 1: Generate an initial pool of unique sequences
271
  target_pool_size = int(num_to_generate * diversity_factor)
272
  unique_seqs = set()
273
- progress(0, desc="Generating initial peptide pool...")
274
 
275
- max_attempts = 10
276
- attempts = 0
277
- while len(unique_seqs) < target_pool_size and attempts < max_attempts:
278
- batch_size = (target_pool_size - len(unique_seqs)) * 2 # Generate extra to account for duplicates/short ones
279
- with torch.no_grad():
280
- generated_tokens = GENERATOR_MODEL.sample(
281
- batch_size=batch_size,
282
- max_length=max_len,
283
- device=DEVICE,
284
- temperature=temperature,
285
- min_decoded_length=min_len
286
- )
287
- decoded = GENERATOR_MODEL.decode(generated_tokens.cpu())
288
- for seq in decoded:
289
- if min_len <= len(seq) <= max_len:
290
- unique_seqs.add(seq)
291
- attempts += 1
292
- progress(len(unique_seqs) / target_pool_size, desc=f"Generated {len(unique_seqs)} unique sequences...")
 
 
 
293
 
294
  candidate_seqs = list(unique_seqs)
295
- if not candidate_seqs:
296
- return pd.DataFrame({"Sequence": ["Failed to generate valid sequences."], "Predicted Probability": ["N/A"]})
297
 
298
- # STEP 2: Validate the generated sequences
299
  validated_pool = {}
300
  for seq in tqdm(candidate_seqs, desc="Validating generated sequences"):
301
  prob_str, _ = predict_peptide_wrapper(seq)
302
  try:
303
  prob = float(prob_str)
304
- if prob > 0.90: # Filter for high-quality peptides as in generator.py
305
  validated_pool[seq] = prob
306
  except (ValueError, TypeError):
307
  continue
@@ -311,11 +277,11 @@ def generate_peptide_wrapper(num_to_generate, min_len, max_len, temperature, div
311
 
312
  high_quality_sequences = list(validated_pool.keys())
313
 
314
- # STEP 3: Cluster to ensure diversity
315
  progress(1.0, desc="Clustering for diversity...")
316
  final_diverse_seqs = cluster_sequences(GENERATOR_MODEL, high_quality_sequences, num_to_generate, DEVICE)
317
 
318
- # STEP 4: Format final results
319
  final_results = [(seq, f"{validated_pool[seq]:.4f}") for seq in final_diverse_seqs]
320
  final_results.sort(key=lambda x: float(x[1]), reverse=True)
321
 
@@ -325,10 +291,8 @@ def generate_peptide_wrapper(num_to_generate, min_len, max_len, temperature, div
325
  print(f"Generation error: {e}")
326
  return pd.DataFrame({"Sequence": [f"An error occurred during generation: {e}"], "Predicted Probability": ["N/A"]})
327
 
328
-
329
  # --------------------------------------------------------------------------
330
- # SECTION 4: GRADIO UI CONSTRUCTION
331
- # Building the web interface. All text is in English.
332
  # --------------------------------------------------------------------------
333
  with gr.Blocks(theme=gr.themes.Soft(), title="RLAnOxPeptide") as demo:
334
  gr.Markdown("# RLAnOxPeptide: Intelligent Peptide Design and Prediction Platform")
@@ -350,11 +314,8 @@ with gr.Blocks(theme=gr.themes.Soft(), title="RLAnOxPeptide") as demo:
350
  outputs=[probability_output, class_output]
351
  )
352
  gr.Examples(
353
- examples=[["WHYHDYKY"], ["YPGG"], ["LVLHEHGGN"], ["INVALIDSEQUENCE"]],
354
- inputs=peptide_input,
355
- outputs=[probability_output, class_output],
356
- fn=predict_peptide_wrapper,
357
- cache_examples=False,
358
  )
359
 
360
  with gr.TabItem("Novel Sequence Generator"):
@@ -378,4 +339,4 @@ with gr.Blocks(theme=gr.themes.Soft(), title="RLAnOxPeptide") as demo:
378
  )
379
 
380
  if __name__ == "__main__":
381
- demo.launch()
 
1
+ # app.py - RLAnOxPeptide Gradio Web Application (Corrected Version)
 
2
 
3
  import os
4
  import torch
 
16
 
17
  # --------------------------------------------------------------------------
18
  # SECTION 1: CORE CLASS AND FUNCTION DEFINITIONS
19
+ # These definitions are now synchronized with your provided, working scripts.
 
20
  # --------------------------------------------------------------------------
21
 
22
+ # --- Vocabulary Definition ---
23
  AMINO_ACIDS = "ACDEFGHIKLMNPQRSTVWY"
24
  token2id = {aa: i + 2 for i, aa in enumerate(AMINO_ACIDS)}
25
  token2id["<PAD>"] = 0
 
27
  id2token = {i: t for t, i in token2id.items()}
28
  VOCAB_SIZE = len(token2id)
29
 
30
+ # --- Predictor Model Architecture (VERSION THAT MATCHES YOUR .pth FILE) ---
31
  class AntioxidantPredictor(nn.Module):
32
  def __init__(self, input_dim, transformer_layers=3, transformer_heads=4, transformer_dropout=0.1):
33
  super(AntioxidantPredictor, self).__init__()
 
 
34
  self.t5_dim = 1024
35
  self.hand_crafted_dim = input_dim - self.t5_dim
36
 
 
37
  encoder_layer = nn.TransformerEncoderLayer(
38
  d_model=self.t5_dim,
39
  nhead=transformer_heads,
 
42
  )
43
  self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=transformer_layers)
44
 
 
 
 
45
  self.mlp = nn.Sequential(
46
  nn.Linear(input_dim, 512),
47
  nn.ReLU(),
 
54
  self.temperature = nn.Parameter(torch.ones(1))
55
 
56
  def forward(self, fused_features):
 
57
  prot_t5_features = fused_features[:, :self.t5_dim]
58
  hand_crafted_features = fused_features[:, self.t5_dim:]
59
 
 
60
  prot_t5_features_unsqueezed = prot_t5_features.unsqueeze(1)
61
  transformer_output = self.transformer_encoder(prot_t5_features_unsqueezed)
62
  transformer_output_pooled = transformer_output.mean(dim=1)
63
 
 
64
  combined_features = torch.cat((transformer_output_pooled, hand_crafted_features), dim=1)
65
 
 
66
  logits = self.mlp(combined_features)
67
 
68
  return logits / self.temperature
 
75
 
76
  # --- Generator Model Architecture (from generator.py) ---
77
  class ProtT5Generator(nn.Module):
 
78
  def __init__(self, vocab_size, embed_dim=512, num_layers=6, num_heads=8, dropout=0.1):
79
  super(ProtT5Generator, self).__init__()
80
  self.embed_tokens = nn.Embedding(vocab_size, embed_dim, padding_idx=token2id["<PAD>"])
 
84
  self.vocab_size = vocab_size
85
  self.eos_token_id = token2id["<EOS>"]
86
  self.pad_token_id = token2id["<PAD>"]
87
+
88
  def forward(self, input_ids):
89
  embeddings = self.embed_tokens(input_ids)
90
  encoder_output = self.encoder(embeddings)
 
99
  next_logits = logits[:, -1, :] / temperature
100
  if generated.size(1) < min_decoded_length:
101
  next_logits[:, self.eos_token_id] = -float("inf")
 
102
  probs = torch.softmax(next_logits, dim=-1)
103
  next_token = torch.multinomial(probs, num_samples=1)
104
  generated = torch.cat((generated, next_token), dim=1)
105
+ # Early stop if all sequences in batch have generated an EOS token
106
+ if (generated == self.eos_token_id).any(dim=1).all():
107
  break
108
  return generated
109
 
 
119
  seqs.append(seq)
120
  return seqs
121
 
122
+ # --- Feature Extraction Logic (needs feature_extract.py) ---
 
 
123
  try:
124
  from feature_extract import ProtT5Model as FeatureProtT5Model, extract_features
125
  except ImportError:
 
131
  return sequences[:num_clusters]
132
  with torch.no_grad():
133
  token_ids_list = []
134
+ max_len = max(len(seq) for seq in sequences) + 2
135
  for seq in sequences:
 
136
  ids = [token2id.get(aa, 0) for aa in seq] + [generator.eos_token_id]
137
+ ids = [np.random.randint(2, VOCAB_SIZE)] + ids
138
  ids += [token2id["<PAD>"]] * (max_len - len(ids))
139
  token_ids_list.append(ids)
140
 
 
158
  representatives.append(sequences[representative_index])
159
  return representatives
160
 
 
161
  # --------------------------------------------------------------------------
162
  # SECTION 2: GLOBAL MODEL LOADING
 
163
  # --------------------------------------------------------------------------
164
  print("Loading all models and dependencies. Please wait...")
165
+ DEVICE = "cpu"
166
 
167
  try:
168
+ # --- Define file paths (!! CHECK THESE PATHS !!) ---
 
169
  PREDICTOR_CHECKPOINT_PATH = "checkpoints/final_rl_model_logitp0.1_calibrated_FINETUNED_PROTT5.pth"
170
  SCALER_PATH = "checkpoints/scaler_FINETUNED_PROTT5.pkl"
171
  GENERATOR_CHECKPOINT_PATH = "generator_checkpoints_v3.6/final_generator_model.pth"
172
  PROTT5_BASE_MODEL_PATH = "prott5/model/"
173
  FINETUNED_PROTT5_FOR_FEATURES_PATH = "prott5/model/finetuned_prott5.bin"
174
 
175
+ # --- Load Predictor ---
176
  print("Loading Predictor Model...")
177
+ PREDICTOR_MODEL = AntioxidantPredictor(input_dim=1914, transformer_layers=3, transformer_heads=4)
 
 
178
  PREDICTOR_MODEL.load_state_dict(torch.load(PREDICTOR_CHECKPOINT_PATH, map_location=DEVICE))
179
  PREDICTOR_MODEL.to(DEVICE)
180
  PREDICTOR_MODEL.eval()
181
+ print(f"✅ Predictor model loaded (Temp: {PREDICTOR_MODEL.get_temperature():.4f}).")
182
 
183
+ # --- Load Scaler & Feature Extractor ---
184
+ print("Loading Scaler and Feature Extractor...")
185
  SCALER = joblib.load(SCALER_PATH)
 
 
 
 
186
  PROTT5_EXTRACTOR = FeatureProtT5Model(
187
  model_path=PROTT5_BASE_MODEL_PATH,
188
  finetuned_model_file=FINETUNED_PROTT5_FOR_FEATURES_PATH
189
  )
190
+ print("✅ Scaler and Feature Extractor loaded.")
191
 
192
+ # --- Load Generator ---
193
  print("Loading Generator Model...")
194
+ GENERATOR_MODEL = ProtT5Generator(vocab_size=VOCAB_SIZE, embed_dim=512, num_layers=6, num_heads=8)
 
 
195
  GENERATOR_MODEL.load_state_dict(torch.load(GENERATOR_CHECKPOINT_PATH, map_location=DEVICE))
196
  GENERATOR_MODEL.to(DEVICE)
197
  GENERATOR_MODEL.eval()
198
  print("✅ Generator model loaded.")
 
199
  print("\n--- All models loaded successfully! Gradio app is ready. ---\n")
200
 
201
  except Exception as e:
 
204
 
205
  # --------------------------------------------------------------------------
206
  # SECTION 3: WRAPPER FUNCTIONS FOR GRADIO
 
207
  # --------------------------------------------------------------------------
208
 
209
  def predict_peptide_wrapper(sequence_str):
 
210
  if not sequence_str or not isinstance(sequence_str, str) or any(c not in AMINO_ACIDS for c in sequence_str.upper()):
211
  return "0.0000", "Error: Please enter a valid sequence with standard amino acids."
212
 
213
  try:
214
+ # These L_fixed and d_model_pe values are from your predictor.py args
215
+ features = extract_features(sequence_str, PROTT5_EXTRACTOR, L_fixed=29, d_model_pe=16)
 
 
216
  scaled_features = SCALER.transform(features.reshape(1, -1))
217
 
 
218
  with torch.no_grad():
219
  features_tensor = torch.tensor(scaled_features, dtype=torch.float32).to(DEVICE)
220
  logits = PREDICTOR_MODEL(features_tensor)
 
228
  return "N/A", f"An error occurred during processing: {e}"
229
 
230
  def generate_peptide_wrapper(num_to_generate, min_len, max_len, temperature, diversity_factor, progress=gr.Progress(track_tqdm=True)):
 
231
  num_to_generate = int(num_to_generate)
232
  min_len = int(min_len)
233
  max_len = int(max_len)
234
 
235
  try:
236
+ # Step 1: Generate a pool of unique sequences
237
  target_pool_size = int(num_to_generate * diversity_factor)
238
  unique_seqs = set()
 
239
 
240
+ # A simple generation loop based on generator.py logic
241
+ with tqdm(total=target_pool_size, desc="Generating candidate sequences") as pbar:
242
+ while len(unique_seqs) < target_pool_size:
243
+ batch_size = (target_pool_size - len(unique_seqs))
244
+ with torch.no_grad():
245
+ generated_tokens = GENERATOR_MODEL.sample(
246
+ batch_size=max(1, batch_size),
247
+ max_length=max_len,
248
+ device=DEVICE,
249
+ temperature=temperature,
250
+ min_decoded_length=min_len
251
+ )
252
+ decoded = GENERATOR_MODEL.decode(generated_tokens.cpu())
253
+
254
+ newly_added = 0
255
+ for seq in decoded:
256
+ if min_len <= len(seq) <= max_len:
257
+ if seq not in unique_seqs:
258
+ unique_seqs.add(seq)
259
+ newly_added +=1
260
+ pbar.update(newly_added)
261
 
262
  candidate_seqs = list(unique_seqs)
 
 
263
 
264
+ # Step 2: Validate the generated sequences
265
  validated_pool = {}
266
  for seq in tqdm(candidate_seqs, desc="Validating generated sequences"):
267
  prob_str, _ = predict_peptide_wrapper(seq)
268
  try:
269
  prob = float(prob_str)
270
+ if prob > 0.90: # Filter for high-quality peptides
271
  validated_pool[seq] = prob
272
  except (ValueError, TypeError):
273
  continue
 
277
 
278
  high_quality_sequences = list(validated_pool.keys())
279
 
280
+ # Step 3: Cluster to ensure diversity
281
  progress(1.0, desc="Clustering for diversity...")
282
  final_diverse_seqs = cluster_sequences(GENERATOR_MODEL, high_quality_sequences, num_to_generate, DEVICE)
283
 
284
+ # Step 4: Format final results
285
  final_results = [(seq, f"{validated_pool[seq]:.4f}") for seq in final_diverse_seqs]
286
  final_results.sort(key=lambda x: float(x[1]), reverse=True)
287
 
 
291
  print(f"Generation error: {e}")
292
  return pd.DataFrame({"Sequence": [f"An error occurred during generation: {e}"], "Predicted Probability": ["N/A"]})
293
 
 
294
  # --------------------------------------------------------------------------
295
+ # SECTION 4: GRADIO UI CONSTRUCTION (ALL ENGLISH)
 
296
  # --------------------------------------------------------------------------
297
  with gr.Blocks(theme=gr.themes.Soft(), title="RLAnOxPeptide") as demo:
298
  gr.Markdown("# RLAnOxPeptide: Intelligent Peptide Design and Prediction Platform")
 
314
  outputs=[probability_output, class_output]
315
  )
316
  gr.Examples(
317
+ examples=[["WHYHDYKY"], ["YPGG"], ["LVLHEHGGN"]],
318
+ inputs=peptide_input
 
 
 
319
  )
320
 
321
  with gr.TabItem("Novel Sequence Generator"):
 
339
  )
340
 
341
  if __name__ == "__main__":
342
+ demo.launch()