chshan commited on
Commit
897c594
·
verified ·
1 Parent(s): 334ea25

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -13
app.py CHANGED
@@ -1,4 +1,4 @@
1
- # app.py - RLAnOxPeptide Gradio Web Application (FINAL CORRECTED VERSION)
2
 
3
  import os
4
  import torch
@@ -10,6 +10,7 @@ import gradio as gr
10
  from sklearn.cluster import KMeans
11
  from tqdm import tqdm
12
  import transformers
 
13
 
14
  # Suppress verbose logging from transformers
15
  transformers.logging.set_verbosity_error()
@@ -19,7 +20,7 @@ transformers.logging.set_verbosity_error()
19
  # These definitions are now synchronized with your provided, working scripts.
20
  # --------------------------------------------------------------------------
21
 
22
- # --- Vocabulary Definition ---
23
  AMINO_ACIDS = "ACDEFGHIKLMNPQRSTVWY"
24
  token2id = {aa: i + 2 for i, aa in enumerate(AMINO_ACIDS)}
25
  token2id["<PAD>"] = 0
@@ -27,7 +28,7 @@ token2id["<EOS>"] = 1
27
  id2token = {i: t for t, i in token2id.items()}
28
  VOCAB_SIZE = len(token2id)
29
 
30
- # --- Predictor Model Architecture (Copied from your LATEST antioxidant_predictor_5.py) ---
31
  class AntioxidantPredictor(nn.Module):
32
  def __init__(self, input_dim, transformer_layers=3, transformer_heads=4, transformer_dropout=0.1):
33
  super(AntioxidantPredictor, self).__init__()
@@ -75,9 +76,7 @@ class AntioxidantPredictor(nn.Module):
75
  fused_features = self.fusion_fc(fused_features)
76
 
77
  logits = self.classifier(fused_features)
78
-
79
  logits_scaled = logits / self.temperature
80
-
81
  return logits_scaled
82
 
83
  def set_temperature(self, temp_value, device):
@@ -86,7 +85,7 @@ class AntioxidantPredictor(nn.Module):
86
  def get_temperature(self):
87
  return self.temperature.item()
88
 
89
- # --- Generator Model Architecture (Copied from your generator.py) ---
90
  class ProtT5Generator(nn.Module):
91
  def __init__(self, vocab_size, embed_dim=512, num_layers=6, num_heads=8, dropout=0.1):
92
  super(ProtT5Generator, self).__init__()
@@ -97,7 +96,7 @@ class ProtT5Generator(nn.Module):
97
  self.vocab_size = vocab_size
98
  self.eos_token_id = token2id["<EOS>"]
99
  self.pad_token_id = token2id["<PAD>"]
100
-
101
  def forward(self, input_ids):
102
  embeddings = self.embed_tokens(input_ids)
103
  encoder_output = self.encoder(embeddings)
@@ -130,7 +129,7 @@ class ProtT5Generator(nn.Module):
130
  seqs.append(seq)
131
  return seqs
132
 
133
- # --- Feature Extraction (needs feature_extract.py) ---
134
  try:
135
  from feature_extract import ProtT5Model as FeatureProtT5Model, extract_features
136
  except ImportError:
@@ -142,10 +141,10 @@ def cluster_sequences(generator, sequences, num_clusters, device):
142
  return sequences[:num_clusters]
143
  with torch.no_grad():
144
  token_ids_list = []
145
- max_len = max(len(seq) for seq in sequences) + 2
146
  for seq in sequences:
147
  ids = [token2id.get(aa, 0) for aa in seq] + [generator.eos_token_id]
148
- ids = [np.random.randint(2, VOCAB_SIZE)] + ids
149
  ids += [token2id["<PAD>"]] * (max_len - len(ids))
150
  token_ids_list.append(ids)
151
 
@@ -181,13 +180,16 @@ try:
181
  SCALER_PATH = "checkpoints/scaler_FINETUNED_PROTT5.pkl"
182
  GENERATOR_CHECKPOINT_PATH = "generator_checkpoints_v3.6/final_generator_model.pth"
183
  PROTT5_BASE_MODEL_PATH = "prott5/model/"
 
184
  FINETUNED_PROTT5_FOR_FEATURES_PATH = "prott5/model/finetuned_prott5.bin"
185
 
186
  # --- Load Predictor ---
187
  print("Loading Predictor Model...")
 
188
  PREDICTOR_MODEL = AntioxidantPredictor(
189
  input_dim=1914, transformer_layers=3, transformer_heads=4, transformer_dropout=0.1
190
  )
 
191
  PREDICTOR_MODEL.load_state_dict(torch.load(PREDICTOR_CHECKPOINT_PATH, map_location=DEVICE))
192
  PREDICTOR_MODEL.to(DEVICE)
193
  PREDICTOR_MODEL.eval()
@@ -226,7 +228,7 @@ def predict_peptide_wrapper(sequence_str):
226
  return "0.0000", "Error: Please enter a valid sequence with standard amino acids."
227
 
228
  try:
229
- # Use feature extraction params from your working predictor.py
230
  features = extract_features(sequence_str, PROTT5_EXTRACTOR, L_fixed=29, d_model_pe=16)
231
  scaled_features = SCALER.transform(features.reshape(1, -1))
232
 
@@ -243,6 +245,7 @@ def predict_peptide_wrapper(sequence_str):
243
  return "N/A", f"An error occurred during processing: {e}"
244
 
245
  def generate_peptide_wrapper(num_to_generate, min_len, max_len, temperature, diversity_factor, progress=gr.Progress(track_tqdm=True)):
 
246
  num_to_generate = int(num_to_generate)
247
  min_len = int(min_len)
248
  max_len = int(max_len)
@@ -254,8 +257,7 @@ def generate_peptide_wrapper(num_to_generate, min_len, max_len, temperature, div
254
 
255
  with tqdm(total=target_pool_size, desc="Generating candidate sequences") as pbar:
256
  while len(unique_seqs) < target_pool_size:
257
- # Generate a surplus to account for filtering
258
- batch_size = max(1, (target_pool_size - len(unique_seqs)) * 2)
259
  with torch.no_grad():
260
  generated_tokens = GENERATOR_MODEL.sample(
261
  batch_size=batch_size, max_length=max_len, device=DEVICE,
 
1
+ # app.py - RLAnOxPeptide Gradio Web Application (FINAL CORRECTED VERSION - Synced with local scripts)
2
 
3
  import os
4
  import torch
 
10
  from sklearn.cluster import KMeans
11
  from tqdm import tqdm
12
  import transformers
13
+ import argparse # We won't use argparse but might need it for compatibility if any function expects it
14
 
15
  # Suppress verbose logging from transformers
16
  transformers.logging.set_verbosity_error()
 
20
  # These definitions are now synchronized with your provided, working scripts.
21
  # --------------------------------------------------------------------------
22
 
23
+ # --- Vocabulary Definition (from generator.py) ---
24
  AMINO_ACIDS = "ACDEFGHIKLMNPQRSTVWY"
25
  token2id = {aa: i + 2 for i, aa in enumerate(AMINO_ACIDS)}
26
  token2id["<PAD>"] = 0
 
28
  id2token = {i: t for t, i in token2id.items()}
29
  VOCAB_SIZE = len(token2id)
30
 
31
+ # --- Predictor Model Architecture (Copied VERBATIM from your antioxidant_predictor_5.py) ---
32
  class AntioxidantPredictor(nn.Module):
33
  def __init__(self, input_dim, transformer_layers=3, transformer_heads=4, transformer_dropout=0.1):
34
  super(AntioxidantPredictor, self).__init__()
 
76
  fused_features = self.fusion_fc(fused_features)
77
 
78
  logits = self.classifier(fused_features)
 
79
  logits_scaled = logits / self.temperature
 
80
  return logits_scaled
81
 
82
  def set_temperature(self, temp_value, device):
 
85
  def get_temperature(self):
86
  return self.temperature.item()
87
 
88
+ # --- Generator Model Architecture (Copied VERBATIM from your generator.py) ---
89
  class ProtT5Generator(nn.Module):
90
  def __init__(self, vocab_size, embed_dim=512, num_layers=6, num_heads=8, dropout=0.1):
91
  super(ProtT5Generator, self).__init__()
 
96
  self.vocab_size = vocab_size
97
  self.eos_token_id = token2id["<EOS>"]
98
  self.pad_token_id = token2id["<PAD>"]
99
+
100
  def forward(self, input_ids):
101
  embeddings = self.embed_tokens(input_ids)
102
  encoder_output = self.encoder(embeddings)
 
129
  seqs.append(seq)
130
  return seqs
131
 
132
+ # --- Feature Extraction (needs feature_extract.py in the same directory) ---
133
  try:
134
  from feature_extract import ProtT5Model as FeatureProtT5Model, extract_features
135
  except ImportError:
 
141
  return sequences[:num_clusters]
142
  with torch.no_grad():
143
  token_ids_list = []
144
+ max_len = max(len(seq) for seq in sequences) + 2
145
  for seq in sequences:
146
  ids = [token2id.get(aa, 0) for aa in seq] + [generator.eos_token_id]
147
+ ids = [np.random.randint(2, VOCAB_SIZE)] + ids
148
  ids += [token2id["<PAD>"]] * (max_len - len(ids))
149
  token_ids_list.append(ids)
150
 
 
180
  SCALER_PATH = "checkpoints/scaler_FINETUNED_PROTT5.pkl"
181
  GENERATOR_CHECKPOINT_PATH = "generator_checkpoints_v3.6/final_generator_model.pth"
182
  PROTT5_BASE_MODEL_PATH = "prott5/model/"
183
+ # This path is now used by the FeatureProtT5Model to load the fine-tuned weights
184
  FINETUNED_PROTT5_FOR_FEATURES_PATH = "prott5/model/finetuned_prott5.bin"
185
 
186
  # --- Load Predictor ---
187
  print("Loading Predictor Model...")
188
+ # Initialize the correct class
189
  PREDICTOR_MODEL = AntioxidantPredictor(
190
  input_dim=1914, transformer_layers=3, transformer_heads=4, transformer_dropout=0.1
191
  )
192
+ # Load the state dict that matches this class
193
  PREDICTOR_MODEL.load_state_dict(torch.load(PREDICTOR_CHECKPOINT_PATH, map_location=DEVICE))
194
  PREDICTOR_MODEL.to(DEVICE)
195
  PREDICTOR_MODEL.eval()
 
228
  return "0.0000", "Error: Please enter a valid sequence with standard amino acids."
229
 
230
  try:
231
+ # These L_fixed and d_model_pe values are from your predictor.py args
232
  features = extract_features(sequence_str, PROTT5_EXTRACTOR, L_fixed=29, d_model_pe=16)
233
  scaled_features = SCALER.transform(features.reshape(1, -1))
234
 
 
245
  return "N/A", f"An error occurred during processing: {e}"
246
 
247
  def generate_peptide_wrapper(num_to_generate, min_len, max_len, temperature, diversity_factor, progress=gr.Progress(track_tqdm=True)):
248
+ # This logic is a direct adaptation of your generator.py main function
249
  num_to_generate = int(num_to_generate)
250
  min_len = int(min_len)
251
  max_len = int(max_len)
 
257
 
258
  with tqdm(total=target_pool_size, desc="Generating candidate sequences") as pbar:
259
  while len(unique_seqs) < target_pool_size:
260
+ batch_size = max(1, (target_pool_size - len(unique_seqs)))
 
261
  with torch.no_grad():
262
  generated_tokens = GENERATOR_MODEL.sample(
263
  batch_size=batch_size, max_length=max_len, device=DEVICE,