chshan committed on
Commit
334ea25
·
verified ·
1 Parent(s): 1ed0175

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +64 -54
app.py CHANGED
@@ -1,4 +1,4 @@
1
- # app.py - RLAnOxPeptide Gradio Web Application (Corrected Version)
2
 
3
  import os
4
  import torch
@@ -27,53 +27,66 @@ token2id["<EOS>"] = 1
27
  id2token = {i: t for t, i in token2id.items()}
28
  VOCAB_SIZE = len(token2id)
29
 
30
- # --- Predictor Model Architecture (VERSION THAT MATCHES YOUR .pth FILE) ---
31
  class AntioxidantPredictor(nn.Module):
32
  def __init__(self, input_dim, transformer_layers=3, transformer_heads=4, transformer_dropout=0.1):
33
  super(AntioxidantPredictor, self).__init__()
34
- self.t5_dim = 1024
35
- self.hand_crafted_dim = input_dim - self.t5_dim
36
-
 
 
37
  encoder_layer = nn.TransformerEncoderLayer(
38
- d_model=self.t5_dim,
39
- nhead=transformer_heads,
40
- dropout=transformer_dropout,
41
  batch_first=True
42
  )
43
  self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=transformer_layers)
44
-
45
- self.mlp = nn.Sequential(
46
- nn.Linear(input_dim, 512),
 
47
  nn.ReLU(),
48
- nn.Dropout(0.5),
 
 
 
 
 
 
49
  nn.Linear(512, 256),
50
  nn.ReLU(),
51
- nn.Dropout(0.5),
52
  nn.Linear(256, 1)
53
  )
54
- self.temperature = nn.Parameter(torch.ones(1))
55
 
56
- def forward(self, fused_features):
57
- prot_t5_features = fused_features[:, :self.t5_dim]
58
- hand_crafted_features = fused_features[:, self.t5_dim:]
 
59
 
60
- prot_t5_features_unsqueezed = prot_t5_features.unsqueeze(1)
61
- transformer_output = self.transformer_encoder(prot_t5_features_unsqueezed)
62
- transformer_output_pooled = transformer_output.mean(dim=1)
63
 
64
- combined_features = torch.cat((transformer_output_pooled, hand_crafted_features), dim=1)
 
65
 
66
- logits = self.mlp(combined_features)
67
 
68
- return logits / self.temperature
69
-
70
- def get_temperature(self):
71
- return self.temperature.item()
72
 
73
  def set_temperature(self, temp_value, device):
74
  self.temperature = nn.Parameter(torch.tensor([temp_value], device=device), requires_grad=False)
75
 
76
- # --- Generator Model Architecture (from generator.py) ---
 
 
 
77
  class ProtT5Generator(nn.Module):
78
  def __init__(self, vocab_size, embed_dim=512, num_layers=6, num_heads=8, dropout=0.1):
79
  super(ProtT5Generator, self).__init__()
@@ -102,7 +115,6 @@ class ProtT5Generator(nn.Module):
102
  probs = torch.softmax(next_logits, dim=-1)
103
  next_token = torch.multinomial(probs, num_samples=1)
104
  generated = torch.cat((generated, next_token), dim=1)
105
- # Early stop if all sequences in batch have generated an EOS token
106
  if (generated == self.eos_token_id).any(dim=1).all():
107
  break
108
  return generated
@@ -111,19 +123,18 @@ class ProtT5Generator(nn.Module):
111
  seqs = []
112
  for ids_tensor in token_ids_batch:
113
  seq = ""
114
- # Start from index 1 to skip the initial random start token
115
- for token_id in ids_tensor.tolist()[1:]:
116
  if token_id == self.eos_token_id: break
117
  if token_id == self.pad_token_id: continue
118
  seq += id2token.get(token_id, "?")
119
  seqs.append(seq)
120
  return seqs
121
 
122
- # --- Feature Extraction Logic (needs feature_extract.py) ---
123
  try:
124
  from feature_extract import ProtT5Model as FeatureProtT5Model, extract_features
125
  except ImportError:
126
- raise gr.Error("Failed to import feature_extract.py. Please ensure the file is in the same directory as app.py.")
127
 
128
  # --- Clustering Logic (from generator.py) ---
129
  def cluster_sequences(generator, sequences, num_clusters, device):
@@ -131,10 +142,10 @@ def cluster_sequences(generator, sequences, num_clusters, device):
131
  return sequences[:num_clusters]
132
  with torch.no_grad():
133
  token_ids_list = []
134
- max_len = max(len(seq) for seq in sequences) + 2
135
  for seq in sequences:
136
  ids = [token2id.get(aa, 0) for aa in seq] + [generator.eos_token_id]
137
- ids = [np.random.randint(2, VOCAB_SIZE)] + ids
138
  ids += [token2id["<PAD>"]] * (max_len - len(ids))
139
  token_ids_list.append(ids)
140
 
@@ -161,7 +172,7 @@ def cluster_sequences(generator, sequences, num_clusters, device):
161
  # --------------------------------------------------------------------------
162
  # SECTION 2: GLOBAL MODEL LOADING
163
  # --------------------------------------------------------------------------
164
- print("Loading all models and dependencies. Please wait...")
165
  DEVICE = "cpu"
166
 
167
  try:
@@ -174,7 +185,9 @@ try:
174
 
175
  # --- Load Predictor ---
176
  print("Loading Predictor Model...")
177
- PREDICTOR_MODEL = AntioxidantPredictor(input_dim=1914, transformer_layers=3, transformer_heads=4)
 
 
178
  PREDICTOR_MODEL.load_state_dict(torch.load(PREDICTOR_CHECKPOINT_PATH, map_location=DEVICE))
179
  PREDICTOR_MODEL.to(DEVICE)
180
  PREDICTOR_MODEL.eval()
@@ -191,7 +204,9 @@ try:
191
 
192
  # --- Load Generator ---
193
  print("Loading Generator Model...")
194
- GENERATOR_MODEL = ProtT5Generator(vocab_size=VOCAB_SIZE, embed_dim=512, num_layers=6, num_heads=8)
 
 
195
  GENERATOR_MODEL.load_state_dict(torch.load(GENERATOR_CHECKPOINT_PATH, map_location=DEVICE))
196
  GENERATOR_MODEL.to(DEVICE)
197
  GENERATOR_MODEL.eval()
@@ -211,7 +226,7 @@ def predict_peptide_wrapper(sequence_str):
211
  return "0.0000", "Error: Please enter a valid sequence with standard amino acids."
212
 
213
  try:
214
- # These L_fixed and d_model_pe values are from your predictor.py args
215
  features = extract_features(sequence_str, PROTT5_EXTRACTOR, L_fixed=29, d_model_pe=16)
216
  scaled_features = SCALER.transform(features.reshape(1, -1))
217
 
@@ -237,27 +252,22 @@ def generate_peptide_wrapper(num_to_generate, min_len, max_len, temperature, div
237
  target_pool_size = int(num_to_generate * diversity_factor)
238
  unique_seqs = set()
239
 
240
- # A simple generation loop based on generator.py logic
241
  with tqdm(total=target_pool_size, desc="Generating candidate sequences") as pbar:
242
  while len(unique_seqs) < target_pool_size:
243
- batch_size = (target_pool_size - len(unique_seqs))
 
244
  with torch.no_grad():
245
  generated_tokens = GENERATOR_MODEL.sample(
246
- batch_size=max(1, batch_size),
247
- max_length=max_len,
248
- device=DEVICE,
249
- temperature=temperature,
250
- min_decoded_length=min_len
251
  )
252
  decoded = GENERATOR_MODEL.decode(generated_tokens.cpu())
253
 
254
- newly_added = 0
255
  for seq in decoded:
256
  if min_len <= len(seq) <= max_len:
257
- if seq not in unique_seqs:
258
- unique_seqs.add(seq)
259
- newly_added +=1
260
- pbar.update(newly_added)
261
 
262
  candidate_seqs = list(unique_seqs)
263
 
@@ -267,13 +277,13 @@ def generate_peptide_wrapper(num_to_generate, min_len, max_len, temperature, div
267
  prob_str, _ = predict_peptide_wrapper(seq)
268
  try:
269
  prob = float(prob_str)
270
- if prob > 0.90: # Filter for high-quality peptides
271
  validated_pool[seq] = prob
272
  except (ValueError, TypeError):
273
  continue
274
 
275
  if not validated_pool:
276
- return pd.DataFrame({"Sequence": ["No high-activity peptides (>0.9 prob) were generated."], "Predicted Probability": ["N/A"]})
277
 
278
  high_quality_sequences = list(validated_pool.keys())
279
 
@@ -289,10 +299,10 @@ def generate_peptide_wrapper(num_to_generate, min_len, max_len, temperature, div
289
 
290
  except Exception as e:
291
  print(f"Generation error: {e}")
292
- return pd.DataFrame({"Sequence": [f"An error occurred during generation: {e}"], "Predicted Probability": ["N/A"]})
293
 
294
  # --------------------------------------------------------------------------
295
- # SECTION 4: GRADIO UI CONSTRUCTION (ALL ENGLISH)
296
  # --------------------------------------------------------------------------
297
  with gr.Blocks(theme=gr.themes.Soft(), title="RLAnOxPeptide") as demo:
298
  gr.Markdown("# RLAnOxPeptide: Intelligent Peptide Design and Prediction Platform")
 
1
+ # app.py - RLAnOxPeptide Gradio Web Application (FINAL CORRECTED VERSION)
2
 
3
  import os
4
  import torch
 
27
  id2token = {i: t for t, i in token2id.items()}
28
  VOCAB_SIZE = len(token2id)
29
 
30
+ # --- Predictor Model Architecture (Copied from your LATEST antioxidant_predictor_5.py) ---
31
  class AntioxidantPredictor(nn.Module):
32
  def __init__(self, input_dim, transformer_layers=3, transformer_heads=4, transformer_dropout=0.1):
33
  super(AntioxidantPredictor, self).__init__()
34
+ self.prott5_dim = 1024
35
+ self.handcrafted_dim = input_dim - self.prott5_dim
36
+ self.seq_len = 16
37
+ self.prott5_feature_dim = 64 # 16 * 64 = 1024
38
+
39
  encoder_layer = nn.TransformerEncoderLayer(
40
+ d_model=self.prott5_feature_dim,
41
+ nhead=transformer_heads,
42
+ dropout=transformer_dropout,
43
  batch_first=True
44
  )
45
  self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=transformer_layers)
46
+
47
+ fused_dim = self.prott5_feature_dim + self.handcrafted_dim
48
+ self.fusion_fc = nn.Sequential(
49
+ nn.Linear(fused_dim, 1024),
50
  nn.ReLU(),
51
+ nn.Dropout(0.3),
52
+ nn.Linear(1024, 512),
53
+ nn.ReLU(),
54
+ nn.Dropout(0.3)
55
+ )
56
+
57
+ self.classifier = nn.Sequential(
58
  nn.Linear(512, 256),
59
  nn.ReLU(),
60
+ nn.Dropout(0.3),
61
  nn.Linear(256, 1)
62
  )
63
+ self.temperature = nn.Parameter(torch.ones(1), requires_grad=False)
64
 
65
+ def forward(self, x, *args):
66
+ batch_size = x.size(0)
67
+ prot_t5_features = x[:, :self.prott5_dim]
68
+ handcrafted_features = x[:, self.prott5_dim:]
69
 
70
+ prot_t5_seq = prot_t5_features.view(batch_size, self.seq_len, self.prott5_feature_dim)
71
+ encoded_seq = self.transformer_encoder(prot_t5_seq)
72
+ refined_prott5 = encoded_seq.mean(dim=1)
73
 
74
+ fused_features = torch.cat([refined_prott5, handcrafted_features], dim=1)
75
+ fused_features = self.fusion_fc(fused_features)
76
 
77
+ logits = self.classifier(fused_features)
78
 
79
+ logits_scaled = logits / self.temperature
80
+
81
+ return logits_scaled
 
82
 
83
  def set_temperature(self, temp_value, device):
84
  self.temperature = nn.Parameter(torch.tensor([temp_value], device=device), requires_grad=False)
85
 
86
+ def get_temperature(self):
87
+ return self.temperature.item()
88
+
89
+ # --- Generator Model Architecture (Copied from your generator.py) ---
90
  class ProtT5Generator(nn.Module):
91
  def __init__(self, vocab_size, embed_dim=512, num_layers=6, num_heads=8, dropout=0.1):
92
  super(ProtT5Generator, self).__init__()
 
115
  probs = torch.softmax(next_logits, dim=-1)
116
  next_token = torch.multinomial(probs, num_samples=1)
117
  generated = torch.cat((generated, next_token), dim=1)
 
118
  if (generated == self.eos_token_id).any(dim=1).all():
119
  break
120
  return generated
 
123
  seqs = []
124
  for ids_tensor in token_ids_batch:
125
  seq = ""
126
+ for token_id in ids_tensor.tolist()[1:]: # Skip start token
 
127
  if token_id == self.eos_token_id: break
128
  if token_id == self.pad_token_id: continue
129
  seq += id2token.get(token_id, "?")
130
  seqs.append(seq)
131
  return seqs
132
 
133
+ # --- Feature Extraction (needs feature_extract.py) ---
134
  try:
135
  from feature_extract import ProtT5Model as FeatureProtT5Model, extract_features
136
  except ImportError:
137
+ raise gr.Error("Failed to import feature_extract.py. Ensure it is in the same directory.")
138
 
139
  # --- Clustering Logic (from generator.py) ---
140
  def cluster_sequences(generator, sequences, num_clusters, device):
 
142
  return sequences[:num_clusters]
143
  with torch.no_grad():
144
  token_ids_list = []
145
+ max_len = max(len(seq) for seq in sequences) + 2
146
  for seq in sequences:
147
  ids = [token2id.get(aa, 0) for aa in seq] + [generator.eos_token_id]
148
+ ids = [np.random.randint(2, VOCAB_SIZE)] + ids
149
  ids += [token2id["<PAD>"]] * (max_len - len(ids))
150
  token_ids_list.append(ids)
151
 
 
172
  # --------------------------------------------------------------------------
173
  # SECTION 2: GLOBAL MODEL LOADING
174
  # --------------------------------------------------------------------------
175
+ print("Loading all models and dependencies...")
176
  DEVICE = "cpu"
177
 
178
  try:
 
185
 
186
  # --- Load Predictor ---
187
  print("Loading Predictor Model...")
188
+ PREDICTOR_MODEL = AntioxidantPredictor(
189
+ input_dim=1914, transformer_layers=3, transformer_heads=4, transformer_dropout=0.1
190
+ )
191
  PREDICTOR_MODEL.load_state_dict(torch.load(PREDICTOR_CHECKPOINT_PATH, map_location=DEVICE))
192
  PREDICTOR_MODEL.to(DEVICE)
193
  PREDICTOR_MODEL.eval()
 
204
 
205
  # --- Load Generator ---
206
  print("Loading Generator Model...")
207
+ GENERATOR_MODEL = ProtT5Generator(
208
+ vocab_size=VOCAB_SIZE, embed_dim=512, num_layers=6, num_heads=8, dropout=0.1
209
+ )
210
  GENERATOR_MODEL.load_state_dict(torch.load(GENERATOR_CHECKPOINT_PATH, map_location=DEVICE))
211
  GENERATOR_MODEL.to(DEVICE)
212
  GENERATOR_MODEL.eval()
 
226
  return "0.0000", "Error: Please enter a valid sequence with standard amino acids."
227
 
228
  try:
229
+ # Use feature extraction params from your working predictor.py
230
  features = extract_features(sequence_str, PROTT5_EXTRACTOR, L_fixed=29, d_model_pe=16)
231
  scaled_features = SCALER.transform(features.reshape(1, -1))
232
 
 
252
  target_pool_size = int(num_to_generate * diversity_factor)
253
  unique_seqs = set()
254
 
 
255
  with tqdm(total=target_pool_size, desc="Generating candidate sequences") as pbar:
256
  while len(unique_seqs) < target_pool_size:
257
+ # Generate a surplus to account for filtering
258
+ batch_size = max(1, (target_pool_size - len(unique_seqs)) * 2)
259
  with torch.no_grad():
260
  generated_tokens = GENERATOR_MODEL.sample(
261
+ batch_size=batch_size, max_length=max_len, device=DEVICE,
262
+ temperature=temperature, min_decoded_length=min_len
 
 
 
263
  )
264
  decoded = GENERATOR_MODEL.decode(generated_tokens.cpu())
265
 
266
+ initial_count = len(unique_seqs)
267
  for seq in decoded:
268
  if min_len <= len(seq) <= max_len:
269
+ unique_seqs.add(seq)
270
+ pbar.update(len(unique_seqs) - initial_count)
 
 
271
 
272
  candidate_seqs = list(unique_seqs)
273
 
 
277
  prob_str, _ = predict_peptide_wrapper(seq)
278
  try:
279
  prob = float(prob_str)
280
+ if prob > 0.90:
281
  validated_pool[seq] = prob
282
  except (ValueError, TypeError):
283
  continue
284
 
285
  if not validated_pool:
286
+ return pd.DataFrame([{"Sequence": "No high-activity peptides (>0.9 prob) were generated.", "Predicted Probability": "N/A"}])
287
 
288
  high_quality_sequences = list(validated_pool.keys())
289
 
 
299
 
300
  except Exception as e:
301
  print(f"Generation error: {e}")
302
+ return pd.DataFrame([{"Sequence": f"An error occurred: {e}", "Predicted Probability": "N/A"}])
303
 
304
  # --------------------------------------------------------------------------
305
+ # SECTION 4: GRADIO UI CONSTRUCTION
306
  # --------------------------------------------------------------------------
307
  with gr.Blocks(theme=gr.themes.Soft(), title="RLAnOxPeptide") as demo:
308
  gr.Markdown("# RLAnOxPeptide: Intelligent Peptide Design and Prediction Platform")