chshan committed on
Commit 6f96910 · verified · 1 Parent(s): d1b4723

Update app.py

Files changed (1):
  app.py (+128 -100)
app.py CHANGED
@@ -1,5 +1,26 @@
-
-# app.py - RLAnOxPeptide Gradio Web Application (FINAL CORRECTED VERSION - Robust Loading)
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+# app.py - RLAnOxPeptide Gradio Web Application
+# This script combines logic from predictor.py, generator.py, and the original app.py
+# into a single, self-contained file for a Hugging Face Space.
+#
+# REQUIRED FILE STRUCTURE IN HUGGING FACE REPO:
+# .
+# ├── app.py (This file)
+# ├── feature_extract.py (CRITICAL: This file with your `extract_features` function MUST be present)
+# ├── checkpoints/
+# │   ├── final_rl_model_logitp0.1_calibrated_FINETUNED_PROTT5.pth
+# │   └── scaler_FINETUNED_PROTT5.pkl
+# ├── generator_checkpoints_v3.6/
+# │   └── final_generator_model.pth
+# ├── prott5/
+# │   └── model/
+# │       ├── config.json
+# │       ├── pytorch_model.bin (The base ProtT5 model from Rostlab)
+# │       ├── finetuned_prott5.bin (Your fine-tuned feature extractor weights)
+# │       └── ... (other tokenizer files)
+# └── requirements.txt
 
 import os
 import torch
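Note: the header above references a requirements.txt whose contents are not part of this commit. A minimal sketch, assuming only the packages this file actually imports (plus sentencepiece, which the slow T5Tokenizer depends on); pinned versions are deliberately omitted:

```text
gradio
torch
transformers
sentencepiece
scikit-learn
numpy
pandas
joblib
tqdm
```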
@@ -12,14 +33,14 @@ from sklearn.cluster import KMeans
 from tqdm import tqdm
 import transformers
 
-# Suppress verbose logging from transformers
+# Suppress verbose logging from transformers, which can clutter the app logs
 transformers.logging.set_verbosity_error()
 
 # --------------------------------------------------------------------------
 # SECTION 1: CORE CLASS AND FUNCTION DEFINITIONS
 # --------------------------------------------------------------------------
 
-# --- Vocabulary Definition ---
+# --- Vocabulary Definition (Consistent across all scripts) ---
 AMINO_ACIDS = "ACDEFGHIKLMNPQRSTVWY"
 token2id = {aa: i + 2 for i, aa in enumerate(AMINO_ACIDS)}
 token2id["<PAD>"] = 0
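Note: this vocabulary block fixes the token layout shared by the generator and the clustering code. A standalone sanity check, derived entirely from the lines above:

```python
AMINO_ACIDS = "ACDEFGHIKLMNPQRSTVWY"
token2id = {aa: i + 2 for i, aa in enumerate(AMINO_ACIDS)}
token2id["<PAD>"] = 0
token2id["<EOS>"] = 1
id2token = {i: t for t, i in token2id.items()}

# The 20 amino acids occupy ids 2..21; 0 and 1 are reserved for <PAD>/<EOS>,
# so VOCAB_SIZE is 22 and np.random.randint(2, VOCAB_SIZE) always draws a
# real residue as the random start token used elsewhere in this file.
assert token2id["A"] == 2 and token2id["Y"] == 21
assert len(token2id) == 22
```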
@@ -27,71 +48,71 @@ token2id["<EOS>"] = 1
 id2token = {i: t for t, i in token2id.items()}
 VOCAB_SIZE = len(token2id)
 
-# --- ROBUST FeatureProtT5Model Class for Feature Extraction ---
+
+# --- Feature Extractor Model Class (For ProtT5) ---
+# This class robustly loads the base ProtT5 model and applies your fine-tuned weights.
 class FeatureProtT5Model:
     def __init__(self, model_dir_path, finetuned_weights_path=None):
         self.device = "cuda" if torch.cuda.is_available() else "cpu"
-        print(f"Initializing ProtT5 from base directory: {model_dir_path}")
+        print(f"Initializing ProtT5 for feature extraction on device: {self.device}")
 
-        # Step 1: Load the base model architecture and tokenizer from the directory.
-        # This step requires the original pytorch_model.bin to be in the model_dir_path.
+        # Load the base model architecture and tokenizer from the specified directory.
         self.tokenizer = transformers.T5Tokenizer.from_pretrained(model_dir_path, do_lower_case=False)
         self.model = transformers.T5EncoderModel.from_pretrained(model_dir_path)
 
-        # Step 2: If a separate fine-tuned weights file is provided, load it.
+        # If a path to a fine-tuned weights file is provided, load and apply those weights.
         if finetuned_weights_path and os.path.exists(finetuned_weights_path):
-            print(f"Loading and applying fine-tuned weights from: {finetuned_weights_path}")
-            # Load the state_dict from your specific fine-tuned file
+            print(f"Applying fine-tuned weights from: {finetuned_weights_path}")
             state_dict = torch.load(finetuned_weights_path, map_location=self.device)
-            # Use strict=False because the fine-tuned model may only contain encoder weights
             self.model.load_state_dict(state_dict, strict=False)
-            print("Successfully applied fine-tuned weights to the model.")
+            print("Successfully applied fine-tuned weights.")
         else:
-            print("Warning: Fine-tuned weights file not provided or not found. Using the base ProtT5 model weights.")
+            print("Warning: Fine-tuned weights not found or not provided. Using base ProtT5 weights.")
 
         self.model.to(self.device)
         self.model.eval()
 
-    def encode(self, sequence):
-        if not sequence or not isinstance(sequence, str):
-            return np.zeros((1, 1024), dtype=np.float32)
-        seq_spaced = " ".join(list(sequence))
-        encoded_input = self.tokenizer(seq_spaced, return_tensors='pt', padding=True, truncation=True, max_length=1022)
-        encoded_input = {k: v.to(self.device) for k, v in encoded_input.items()}
-        with torch.no_grad():
-            embedding = self.model(**encoded_input).last_hidden_state
-        emb = embedding.squeeze(0).cpu().numpy()
-        return emb if emb.shape[0] > 0 else np.zeros((1, 1024), dtype=np.float32)
-
 # --- Predictor Model Architecture ---
+# This is the antioxidant activity predictor model. Its architecture must
+# exactly match the architecture used to save the checkpoint file.
 class AntioxidantPredictor(nn.Module):
-    def __init__(self, input_dim, transformer_layers=3, transformer_heads=4, transformer_dropout=0.1):
+    def __init__(self, input_dim=1914, transformer_layers=3, transformer_heads=4, transformer_dropout=0.1):
         super(AntioxidantPredictor, self).__init__()
         self.prott5_dim = 1024
         self.handcrafted_dim = input_dim - self.prott5_dim
         self.seq_len = 16
-        self.prott5_feature_dim = 64
+        self.prott5_feature_dim = 64  # 16 * 64 = 1024
+
         encoder_layer = nn.TransformerEncoderLayer(d_model=self.prott5_feature_dim, nhead=transformer_heads, dropout=transformer_dropout, batch_first=True)
         self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=transformer_layers)
+
         fused_dim = self.prott5_feature_dim + self.handcrafted_dim
         self.fusion_fc = nn.Sequential(nn.Linear(fused_dim, 1024), nn.ReLU(), nn.Dropout(0.3), nn.Linear(1024, 512), nn.ReLU(), nn.Dropout(0.3))
         self.classifier = nn.Sequential(nn.Linear(512, 256), nn.ReLU(), nn.Dropout(0.3), nn.Linear(256, 1))
         self.temperature = nn.Parameter(torch.ones(1), requires_grad=False)
-    def forward(self, x, *args):
+
+    def forward(self, x):
         batch_size = x.size(0)
+        # The input 'x' is a flat 1914-dim vector from extract_features()
         prot_t5_features = x[:, :self.prott5_dim]
         handcrafted_features = x[:, self.prott5_dim:]
+
+        # Reshape the first 1024 features back into a sequence representation
         prot_t5_seq = prot_t5_features.view(batch_size, self.seq_len, self.prott5_feature_dim)
+
        encoded_seq = self.transformer_encoder(prot_t5_seq)
         refined_prott5 = encoded_seq.mean(dim=1)
+
         fused_features = torch.cat([refined_prott5, handcrafted_features], dim=1)
-        fused_features = self.fusion_fc(fused_features)
-        logits = self.classifier(fused_features)
+        fused_output = self.fusion_fc(fused_features)
+        logits = self.classifier(fused_output)
+
         return logits / self.temperature
-    def set_temperature(self, temp_value, device): self.temperature = nn.Parameter(torch.tensor([temp_value], device=device), requires_grad=False)
-    def get_temperature(self): return self.temperature.item()
 
-# --- Generator Model Architecture (Copied VERBATIM from your generator.py) ---
+    def get_temperature(self):
+        return self.temperature.item()
+
+# --- Generator Model Architecture (from generator.py) ---
 class ProtT5Generator(nn.Module):
     def __init__(self, vocab_size, embed_dim=512, num_layers=6, num_heads=8, dropout=0.1):
         super(ProtT5Generator, self).__init__()
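Note: the predictor's forward pass above assumes a strict feature layout: the first 1024 dimensions are the ProtT5 embedding, viewed as 16 pseudo-tokens of 64 channels, and the remaining 890 are handcrafted features. A minimal standalone shape check (random data stands in for real features):

```python
import torch

x = torch.randn(4, 1914)                      # batch of 4 feature vectors
prott5 = x[:, :1024].view(4, 16, 64)          # 16 pseudo-tokens x 64 channels
handcrafted = x[:, 1024:]                     # 1914 - 1024 = 890 dims
refined = prott5.mean(dim=1)                  # mean-pooled to (4, 64)
fused = torch.cat([refined, handcrafted], dim=1)
assert fused.shape == (4, 954)                # 64 + 890, the input to fusion_fc
```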
@@ -117,49 +138,50 @@ class ProtT5Generator(nn.Module):
             next_logits = logits[:, -1, :] / temperature
             if generated.size(1) < min_decoded_length:
                 next_logits[:, self.eos_token_id] = -float("inf")
+
             probs = torch.softmax(next_logits, dim=-1)
             next_token = torch.multinomial(probs, num_samples=1)
             generated = torch.cat((generated, next_token), dim=1)
-            if (generated == self.eos_token_id).any(dim=1).all():
-                break
         return generated
 
     def decode(self, token_ids_batch):
-        seqs = []
+        sequences = []
         for ids_tensor in token_ids_batch:
             seq = ""
-            for token_id in ids_tensor.tolist()[1:]:  # Skip start token
+            for token_id in ids_tensor.tolist()[1:]:  # Skip the random start token
                 if token_id == self.eos_token_id: break
                 if token_id == self.pad_token_id: continue
-                seq += id2token.get(token_id, "?")
-            seqs.append(seq)
-        return seqs
-
-# --- Feature Extraction (needs feature_extract.py in the same directory) ---
+                seq += id2token.get(token_id, "")
+            sequences.append(seq)
+        return sequences
+
+# --- CRITICAL DEPENDENCY: feature_extract.py ---
+# This application requires a function named `extract_features` to convert a peptide
+# sequence into a 1914-dimensional feature vector for the prediction model.
+# This function must be defined in a file named `feature_extract.py` in the repository root.
 try:
-    from feature_extract import ProtT5Model as FeatureProtT5Model, extract_features
+    from feature_extract import extract_features
 except ImportError:
-    raise gr.Error("Failed to import feature_extract.py. Ensure it is in the same directory.")
+    raise gr.Error("Fatal Error: `feature_extract.py` not found. This file is required for the application to run. Please upload it to your repository.")
 
 # --- Clustering Logic (from generator.py) ---
 def cluster_sequences(generator, sequences, num_clusters, device):
     if not sequences or len(sequences) < num_clusters:
         return sequences[:num_clusters]
+
     with torch.no_grad():
         token_ids_list = []
-        max_len = max(len(seq) for seq in sequences) + 2
+        max_len = max((len(seq) for seq in sequences), default=0) + 2
         for seq in sequences:
             ids = [token2id.get(aa, 0) for aa in seq] + [generator.eos_token_id]
-            ids = [np.random.randint(2, VOCAB_SIZE)] + ids
+            ids = [np.random.randint(2, VOCAB_SIZE)] + ids  # Prepend a start token
             ids += [token2id["<PAD>"]] * (max_len - len(ids))
             token_ids_list.append(ids)
 
         input_ids = torch.tensor(token_ids_list, dtype=torch.long, device=device)
         embeddings = generator.embed_tokens(input_ids)
         mask = (input_ids != token2id["<PAD>"]).unsqueeze(-1).float()
-        embeddings = embeddings * mask
-        lengths = mask.sum(dim=1)
-        seq_embeds = embeddings.sum(dim=1) / (lengths + 1e-9)
+        seq_embeds = (embeddings * mask).sum(dim=1) / (mask.sum(dim=1) + 1e-9)
         seq_embeds_np = seq_embeds.cpu().numpy()
 
         kmeans = KMeans(n_clusters=int(num_clusters), random_state=42, n_init='auto').fit(seq_embeds_np)
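Note: the hunk above ends before the representative-selection loop (the diff resumes at `return representatives`). A hedged sketch of a common way to pick one representative per KMeans cluster, the member nearest each centroid; the actual unchanged lines are not shown in this commit, so this is an assumption:

```python
import numpy as np
from sklearn.cluster import KMeans

def pick_representatives(seq_embeds_np, sequences, num_clusters):
    km = KMeans(n_clusters=num_clusters, random_state=42, n_init="auto").fit(seq_embeds_np)
    representatives = []
    for c in range(num_clusters):
        members = np.where(km.labels_ == c)[0]
        if members.size == 0:
            continue
        # Keep the sequence whose embedding lies nearest the cluster centroid
        dists = np.linalg.norm(seq_embeds_np[members] - km.cluster_centers_[c], axis=1)
        representatives.append(sequences[members[np.argmin(dists)]])
    return representatives
```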
@@ -175,67 +197,67 @@ def cluster_sequences(generator, sequences, num_clusters, device):
     return representatives
 
 # --------------------------------------------------------------------------
-# SECTION 2: GLOBAL MODEL LOADING
+# SECTION 2: GLOBAL MODEL AND DEPENDENCY LOADING
 # --------------------------------------------------------------------------
-print("Loading all models and dependencies...")
-DEVICE = "cpu"
+
+print("--- Starting Application: Loading all models and dependencies ---")
+DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 
 try:
-    # --- Define file paths (!! CHECK THESE PATHS !!) ---
+    # --- Define file paths relative to the repository root ---
     PREDICTOR_CHECKPOINT_PATH = "checkpoints/final_rl_model_logitp0.1_calibrated_FINETUNED_PROTT5.pth"
     SCALER_PATH = "checkpoints/scaler_FINETUNED_PROTT5.pkl"
     GENERATOR_CHECKPOINT_PATH = "generator_checkpoints_v3.6/final_generator_model.pth"
     PROTT5_BASE_MODEL_PATH = "prott5/model/"
-    # This path is now used by the FeatureProtT5Model to load the fine-tuned weights
     FINETUNED_PROTT5_FOR_FEATURES_PATH = "prott5/model/finetuned_prott5.bin"
 
-    # --- Load Predictor ---
-    print("Loading Predictor Model...")
-    # Initialize the correct class
-    PREDICTOR_MODEL = AntioxidantPredictor(
-        input_dim=1914, transformer_layers=3, transformer_heads=4, transformer_dropout=0.1
-    )
-    # Load the state dict that matches this class
+    # --- Load Predictor Model ---
+    print(f"Loading Predictor from: {PREDICTOR_CHECKPOINT_PATH}")
+    PREDICTOR_MODEL = AntioxidantPredictor(input_dim=1914)
     PREDICTOR_MODEL.load_state_dict(torch.load(PREDICTOR_CHECKPOINT_PATH, map_location=DEVICE))
     PREDICTOR_MODEL.to(DEVICE)
     PREDICTOR_MODEL.eval()
     print(f"✅ Predictor model loaded (Temp: {PREDICTOR_MODEL.get_temperature():.4f}).")
 
     # --- Load Scaler & Feature Extractor ---
-    print("Loading Scaler and Feature Extractor...")
+    print(f"Loading Scaler from: {SCALER_PATH}")
     SCALER = joblib.load(SCALER_PATH)
+    print("Loading ProtT5 Feature Extractor...")
     PROTT5_EXTRACTOR = FeatureProtT5Model(
-        model_path=PROTT5_BASE_MODEL_PATH,
-        finetuned_model_file=FINETUNED_PROTT5_FOR_FEATURES_PATH
+        model_dir_path=PROTT5_BASE_MODEL_PATH,
+        finetuned_weights_path=FINETUNED_PROTT5_FOR_FEATURES_PATH
     )
     print("✅ Scaler and Feature Extractor loaded.")
 
-    # --- Load Generator ---
-    print("Loading Generator Model...")
-    GENERATOR_MODEL = ProtT5Generator(
-        vocab_size=VOCAB_SIZE, embed_dim=512, num_layers=6, num_heads=8, dropout=0.1
-    )
+    # --- Load Generator Model ---
+    print(f"Loading Generator from: {GENERATOR_CHECKPOINT_PATH}")
+    GENERATOR_MODEL = ProtT5Generator(vocab_size=VOCAB_SIZE)
     GENERATOR_MODEL.load_state_dict(torch.load(GENERATOR_CHECKPOINT_PATH, map_location=DEVICE))
     GENERATOR_MODEL.to(DEVICE)
     GENERATOR_MODEL.eval()
     print("✅ Generator model loaded.")
-    print("\n--- All models loaded successfully! Gradio app is ready. ---\n")
+
+    print("\n--- All models loaded! Gradio app is ready. ---\n")
 
 except Exception as e:
-    print(f"💥 FATAL ERROR: Failed to load a model or dependency file: {e}")
-    raise gr.Error(f"Model or dependency loading failed! Check file paths and integrity. Error: {e}")
+    print(f"💥 FATAL ERROR during model loading: {e}")
+    raise gr.Error(f"A required model or file could not be loaded. Please check your repository file structure and paths. Error details: {e}")
 
 # --------------------------------------------------------------------------
-# SECTION 3: WRAPPER FUNCTIONS FOR GRADIO
+# SECTION 3: WRAPPER FUNCTIONS FOR GRADIO UI
 # --------------------------------------------------------------------------
 
 def predict_peptide_wrapper(sequence_str):
+    """Handles the prediction for a single peptide sequence from the UI."""
     if not sequence_str or not isinstance(sequence_str, str) or any(c not in AMINO_ACIDS for c in sequence_str.upper()):
-        return "0.0000", "Error: Please enter a valid sequence with standard amino acids."
+        return "0.0000", "Error: Please enter a valid peptide sequence using standard amino acids (ACDEFGHIKLMNPQRSTVWY)."
 
     try:
-        # These L_fixed and d_model_pe values are from your predictor.py args
-        features = extract_features(sequence_str, PROTT5_EXTRACTOR, L_fixed=29, d_model_pe=16)
+        # Use the imported extract_features function.
+        # The L_fixed and d_model_pe values are taken from your original predictor.py arguments.
+        features = extract_features(sequence_str.upper(), PROTT5_EXTRACTOR, L_fixed=29, d_model_pe=16)
+
+        # Scale the features using the loaded scaler
         scaled_features = SCALER.transform(features.reshape(1, -1))
 
         with torch.no_grad():
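Note: the hunk above stops inside the `with torch.no_grad():` block; the lines that turn the model output into the displayed values are unchanged context not shown here. Since `forward()` already divides by the calibration temperature, only a sigmoid is needed. A self-contained sketch under that assumption (the 0.5 threshold and the label strings are illustrative, not taken from this diff):

```python
import torch

def logits_to_output(logits: torch.Tensor):
    # forward() already applies temperature scaling, so a plain sigmoid
    # yields the calibrated probability.
    probability = torch.sigmoid(logits).item()
    label = "Antioxidant" if probability > 0.5 else "Non-antioxidant"
    return f"{probability:.4f}", label

print(logits_to_output(torch.tensor([[1.7]])))  # -> ('0.8455', 'Antioxidant')
```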
@@ -247,21 +269,22 @@ def predict_peptide_wrapper(sequence_str):
         return f"{probability:.4f}", classification
 
     except Exception as e:
-        print(f"Prediction error for sequence '{sequence_str}': {e}")
-        return "N/A", f"An error occurred during processing: {e}"
+        print(f"Prediction Error for sequence '{sequence_str}': {e}")
+        return "N/A", f"An error occurred during prediction: {e}"
 
 def generate_peptide_wrapper(num_to_generate, min_len, max_len, temperature, diversity_factor, progress=gr.Progress(track_tqdm=True)):
-    # This logic is a direct adaptation of your generator.py main function
+    """Handles the full generation-validation-clustering pipeline."""
     num_to_generate = int(num_to_generate)
     min_len = int(min_len)
     max_len = int(max_len)
 
     try:
-        # Step 1: Generate a pool of unique sequences
+        # Step 1: Generate a large, unique pool of candidate sequences
         target_pool_size = int(num_to_generate * diversity_factor)
         unique_seqs = set()
 
-        with tqdm(total=target_pool_size, desc="Generating candidate sequences") as pbar:
+        pbar_desc = "Step 1/3: Generating candidate sequences"
+        with tqdm(total=target_pool_size, desc=pbar_desc) as pbar:
             while len(unique_seqs) < target_pool_size:
                 batch_size = max(1, (target_pool_size - len(unique_seqs)))
                 with torch.no_grad():
@@ -269,19 +292,19 @@ def generate_peptide_wrapper(num_to_generate, min_len, max_len, temperature, diversity_factor, progress=gr.Progress(track_tqdm=True)):
                         batch_size=batch_size, max_length=max_len, device=DEVICE,
                         temperature=temperature, min_decoded_length=min_len
                     )
-                decoded = GENERATOR_MODEL.decode(generated_tokens.cpu())
+                decoded_sequences = GENERATOR_MODEL.decode(generated_tokens)
 
                 initial_count = len(unique_seqs)
-                for seq in decoded:
+                for seq in decoded_sequences:
                     if min_len <= len(seq) <= max_len:
                         unique_seqs.add(seq)
                 pbar.update(len(unique_seqs) - initial_count)
 
         candidate_seqs = list(unique_seqs)
 
-        # Step 2: Validate the generated sequences
+        # Step 2: Validate the generated sequences and filter for high probability
         validated_pool = {}
-        for seq in tqdm(candidate_seqs, desc="Validating generated sequences"):
+        for seq in tqdm(candidate_seqs, desc="Step 2/3: Validating generated sequences"):
             prob_str, _ = predict_peptide_wrapper(seq)
             try:
                 prob = float(prob_str)
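Note: the hunk above cuts off right after `prob = float(prob_str)`; the filter that fills `validated_pool` is unchanged context outside this diff. Given the ">0.9 prob" message later in the file, a hedged sketch of the assumed filter:

```python
def keep_if_high_activity(validated_pool: dict, seq: str, prob_str: str, threshold: float = 0.9) -> None:
    """Sketch of the elided filter: keep sequences whose predicted
    probability clears the threshold implied by the '>0.9 prob' message.
    The exact unchanged lines are not shown in this commit."""
    try:
        prob = float(prob_str)
    except ValueError:
        return
    if prob > threshold:
        validated_pool[seq] = prob
```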
@@ -291,40 +314,41 @@ def generate_peptide_wrapper(num_to_generate, min_len, max_len, temperature, diversity_factor, progress=gr.Progress(track_tqdm=True)):
                 continue
 
         if not validated_pool:
-            return pd.DataFrame([{"Sequence": "No high-activity peptides (>0.9 prob) were generated.", "Predicted Probability": "N/A"}])
+            return pd.DataFrame([{"Sequence": "No high-activity peptides (>0.9 prob) were generated. Try increasing the Diversity Factor or changing the Temperature.", "Predicted Probability": "N/A"}])
 
         high_quality_sequences = list(validated_pool.keys())
 
-        # Step 3: Cluster to ensure diversity
-        progress(1.0, desc="Clustering for diversity...")
+        # Step 3: Cluster to ensure diversity in the final set
+        progress(1.0, desc="Step 3/3: Clustering for diversity...")
         final_diverse_seqs = cluster_sequences(GENERATOR_MODEL, high_quality_sequences, num_to_generate, DEVICE)
 
-        # Step 4: Format final results
+        # Step 4: Format final results into a DataFrame
         final_results = [(seq, f"{validated_pool[seq]:.4f}") for seq in final_diverse_seqs]
         final_results.sort(key=lambda x: float(x[1]), reverse=True)
 
         return pd.DataFrame(final_results, columns=["Sequence", "Predicted Probability"])
 
     except Exception as e:
-        print(f"Generation error: {e}")
-        return pd.DataFrame([{"Sequence": f"An error occurred: {e}", "Predicted Probability": "N/A"}])
+        print(f"Generation Pipeline Error: {e}")
+        return pd.DataFrame([{"Sequence": f"An error occurred during generation: {e}", "Predicted Probability": "N/A"}])
 
 # --------------------------------------------------------------------------
 # SECTION 4: GRADIO UI CONSTRUCTION
 # --------------------------------------------------------------------------
 with gr.Blocks(theme=gr.themes.Soft(), title="RLAnOxPeptide") as demo:
-    gr.Markdown("# RLAnOxPeptide: Intelligent Peptide Design and Prediction Platform")
+    gr.Markdown("# RLAnOxPeptide: Intelligent Peptide Design and Prediction")
     gr.Markdown("An integrated framework combining reinforcement learning and a Transformer model for the efficient prediction and innovative design of antioxidant peptides.")
 
     with gr.Tabs():
+        # --- PREDICTION TAB ---
         with gr.TabItem("Peptide Activity Predictor"):
             gr.Markdown("### Enter an amino acid sequence to predict its antioxidant activity.")
             with gr.Row():
                 peptide_input = gr.Textbox(label="Peptide Sequence", placeholder="e.g., WHYHDYKY", scale=3)
                 predict_button = gr.Button("Predict", variant="primary", scale=1)
             with gr.Row():
-                probability_output = gr.Textbox(label="Predicted Probability")
-                class_output = gr.Textbox(label="Predicted Class")
+                probability_output = gr.Textbox(label="Predicted Probability", interactive=False)
+                class_output = gr.Textbox(label="Predicted Class", interactive=False)
 
             predict_button.click(
                 fn=predict_peptide_wrapper,
@@ -332,23 +356,27 @@ with gr.Blocks(theme=gr.themes.Soft(), title="RLAnOxPeptide") as demo:
                 outputs=[probability_output, class_output]
             )
             gr.Examples(
-                examples=[["WHYHDYKY"], ["YPGG"], ["LVLHEHGGN"]],
-                inputs=peptide_input
+                examples=[["WHYHDYKY"], ["YPGG"], ["LVLHEHGGN"], ["WKYG"]],
+                inputs=peptide_input,
+                fn=predict_peptide_wrapper,
+                outputs=[probability_output, class_output],
+                cache_examples=True
             )
 
+        # --- GENERATION TAB ---
         with gr.TabItem("Novel Sequence Generator"):
             gr.Markdown("### Set parameters to generate novel, high-activity antioxidant peptides.")
             with gr.Column():
                 with gr.Row():
-                    num_input = gr.Slider(minimum=1, maximum=50, value=10, step=1, label="Number of Final Peptides to Generate")
-                    min_len_input = gr.Slider(minimum=2, maximum=10, value=3, step=1, label="Minimum Length")
+                    num_input = gr.Slider(minimum=5, maximum=50, value=10, step=1, label="Number of Final Peptides to Generate")
+                    min_len_input = gr.Slider(minimum=3, maximum=10, value=3, step=1, label="Minimum Length")
                     max_len_input = gr.Slider(minimum=10, maximum=20, value=20, step=1, label="Maximum Length")
                 with gr.Row():
                     temp_input = gr.Slider(minimum=0.5, maximum=3.0, value=2.5, step=0.1, label="Temperature (Higher = More random)")
-                    diversity_input = gr.Slider(minimum=1.0, maximum=3.0, value=1.2, step=0.1, label="Diversity Factor (Higher = Larger initial pool for clustering)")
+                    diversity_input = gr.Slider(minimum=1.1, maximum=5.0, value=1.5, step=0.1, label="Diversity Factor (Larger initial pool for clustering)")
 
                 generate_button = gr.Button("Generate Peptides", variant="primary")
-                results_output = gr.DataFrame(headers=["Sequence", "Predicted Probability"], label="Generated & Validated Peptides", wrap=True)
+                results_output = gr.DataFrame(headers=["Sequence", "Predicted Probability"], label="Generated & Validated Peptides (>90% Probability)", wrap=True)
 
                 generate_button.click(
                     fn=generate_peptide_wrapper,
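Note: the diff ends mid-call at `fn=generate_peptide_wrapper,`; the rest of the wiring is unchanged context outside this commit. A hedged sketch of how such a handler is typically completed, using only the components defined above (the actual argument list and launch call are assumptions):

```python
generate_button.click(
    fn=generate_peptide_wrapper,
    inputs=[num_input, min_len_input, max_len_input, temp_input, diversity_input],
    outputs=results_output,
)

demo.launch()  # assumption: a standard Gradio entry point follows the Blocks context
```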
 