chshan commited on
Commit
ed68bd1
·
verified ·
1 Parent(s): 4f8dbbf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +85 -68
app.py CHANGED
@@ -2,7 +2,7 @@
2
  # -*- coding: utf-8 -*-
3
 
4
  # app.py - RLAnOxPeptide Gradio Web Application
5
- # Final version updated to use a LoRA-finetuned model for feature extraction.
6
 
7
  import os
8
  import torch
@@ -34,28 +34,26 @@ token2id["<EOS>"] = 1
34
  id2token = {i: t for t, i in token2id.items()}
35
  VOCAB_SIZE = len(token2id)
36
 
37
-
38
- # --- LoRA Feature Extractor Model Class ---
39
- # ✅ REPLACED: This new class handles loading the base model and attaching the LoRA adapter.
40
  class LoRAProtT5Extractor:
41
  def __init__(self, base_model_id, lora_adapter_path):
42
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
43
- print(f"Initializing LoRA-enhanced ProtT5 on device: {self.device}")
44
 
45
- print(f" - Loading base model and tokenizer from '{base_model_id}'...")
46
  base_model = transformers.T5EncoderModel.from_pretrained(base_model_id)
47
  self.tokenizer = transformers.T5Tokenizer.from_pretrained(base_model_id)
48
 
49
  if not os.path.exists(lora_adapter_path):
50
- raise FileNotFoundError(f"Error: LoRA adapter directory not found at: {lora_adapter_path}")
51
 
52
- print(f" - Loading and applying LoRA adapter from: {lora_adapter_path}")
53
  lora_model = PeftModel.from_pretrained(base_model, lora_adapter_path)
54
 
55
- print(" - Merging LoRA weights for faster inference...")
56
  self.model = lora_model.merge_and_unload().to(self.device)
57
  self.model.eval()
58
- print(" - LoRA-enhanced feature extractor is ready.")
59
 
60
  def encode(self, sequence):
61
  if not sequence or not isinstance(sequence, str):
@@ -71,7 +69,8 @@ class LoRAProtT5Extractor:
71
  emb_np = embedding.squeeze(0).cpu().numpy()
72
  return emb_np if emb_np.shape[0] > 0 else np.zeros((1, 1024), dtype=np.float32)
73
 
74
- # --- Predictor Model Architecture (Unchanged) ---
 
75
  class AntioxidantPredictor(nn.Module):
76
  def __init__(self, input_dim=1914, transformer_layers=3, transformer_heads=4, transformer_dropout=0.1):
77
  super(AntioxidantPredictor, self).__init__()
@@ -103,22 +102,34 @@ class AntioxidantPredictor(nn.Module):
103
  def get_temperature(self):
104
  return self.temperature.item()
105
 
106
- # --- Generator Model Architecture (Unchanged) ---
107
- class ProtT5Generator(nn.Module):
108
- def __init__(self, vocab_size, embed_dim=512, num_layers=6, num_heads=8, dropout=0.1):
109
- super(ProtT5Generator, self).__init__()
110
- self.embed_tokens = nn.Embedding(vocab_size, embed_dim, padding_idx=token2id["<PAD>"])
111
- encoder_layer = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads, dropout=dropout, batch_first=True)
112
- self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
 
 
 
 
 
 
 
 
 
113
  self.lm_head = nn.Linear(embed_dim, vocab_size)
 
114
  self.vocab_size = vocab_size
115
  self.eos_token_id = token2id["<EOS>"]
116
  self.pad_token_id = token2id["<PAD>"]
 
117
 
118
  def forward(self, input_ids):
119
- embeddings = self.embed_tokens(input_ids)
120
- encoder_output = self.encoder(embeddings)
121
- logits = self.lm_head(encoder_output)
 
122
  return logits
123
 
124
  def sample(self, batch_size, max_length=20, device="cpu", temperature=2.5, min_decoded_length=3):
@@ -132,6 +143,8 @@ class ProtT5Generator(nn.Module):
132
  probs = torch.softmax(next_logits, dim=-1)
133
  next_token = torch.multinomial(probs, num_samples=1)
134
  generated = torch.cat((generated, next_token), dim=1)
 
 
135
  return generated
136
 
137
  def decode(self, token_ids_batch):
@@ -140,8 +153,8 @@ class ProtT5Generator(nn.Module):
140
  seq = ""
141
  for token_id in ids_tensor.tolist()[1:]:
142
  if token_id == self.eos_token_id: break
143
- if token_id == self.pad_token_id: continue
144
- seq += id2token.get(token_id, "")
145
  sequences.append(seq)
146
  return sequences
147
 
@@ -151,34 +164,36 @@ try:
151
  except ImportError:
152
  raise gr.Error("Fatal Error: `feature_extract.py` not found. This file is required. Please upload it to your repository.")
153
 
154
- # --- Clustering Logic (Unchanged) ---
155
  def cluster_sequences(generator, sequences, num_clusters, device):
156
  if not sequences or len(sequences) < num_clusters:
157
  return sequences[:num_clusters]
158
  with torch.no_grad():
159
  token_ids_list = []
160
- max_len = max((len(seq) for seq in sequences), default=0) + 2
161
  for seq in sequences:
162
- ids = [token2id.get(aa, 0) for aa in seq] + [generator.eos_token_id]
163
- ids = [np.random.randint(2, VOCAB_SIZE)] + ids
164
  ids += [token2id["<PAD>"]] * (max_len - len(ids))
165
  token_ids_list.append(ids)
166
  input_ids = torch.tensor(token_ids_list, dtype=torch.long, device=device)
 
 
167
  embeddings = generator.embed_tokens(input_ids)
 
168
  mask = (input_ids != token2id["<PAD>"]).unsqueeze(-1).float()
169
  seq_embeds = (embeddings * mask).sum(dim=1) / (mask.sum(dim=1) + 1e-9)
170
  seq_embeds_np = seq_embeds.cpu().numpy()
 
171
  kmeans = KMeans(n_clusters=int(num_clusters), random_state=42, n_init='auto').fit(seq_embeds_np)
172
- representatives = []
173
  for i in range(int(num_clusters)):
174
- indices = np.where(kmeans.labels_ == i)[0]
175
- if len(indices) == 0: continue
176
- cluster_center = kmeans.cluster_centers_[i]
177
- cluster_embeddings = seq_embeds_np[indices]
178
- distances = np.linalg.norm(cluster_embeddings - cluster_center, axis=1)
179
- representative_index = indices[np.argmin(distances)]
180
- representatives.append(sequences[representative_index])
181
- return representatives
182
 
183
  # --------------------------------------------------------------------------
184
  # SECTION 2: GLOBAL MODEL AND DEPENDENCY LOADING
@@ -189,40 +204,43 @@ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
189
 
190
  try:
191
  # --- Define file paths ---
192
- PREDICTOR_CHECKPOINT_PATH = "checkpoints/final_rl_model_logitp0.1_calibrated_FINETUNED_PROTT5.pth"
193
- SCALER_PATH = "checkpoints/scaler_FINETUNED_PROTT5.pkl"
194
- GENERATOR_CHECKPOINT_PATH = "generator_checkpoints_v3.6/final_generator_model.pth"
195
-
196
- # ✅ UPDATED: Define paths for LoRA-based loading
197
  PROTT5_BASE_MODEL_ID = "Rostlab/prot_t5_xl_uniref50"
198
- LORA_ADAPTER_PATH = "./lora_finetuned_prott5" # Assumes LoRA files are in this directory
199
-
200
- # --- Load Predictor Model (Head) ---
201
- print(f"Loading Predictor from: {PREDICTOR_CHECKPOINT_PATH}")
 
 
 
 
 
 
 
 
 
 
 
 
 
202
  PREDICTOR_MODEL = AntioxidantPredictor(input_dim=1914)
203
- PREDICTOR_MODEL.load_state_dict(torch.load(PREDICTOR_CHECKPOINT_PATH, map_location=DEVICE))
204
  PREDICTOR_MODEL.to(DEVICE)
205
  PREDICTOR_MODEL.eval()
206
- print(f"✅ Predictor model loaded (Temp: {PREDICTOR_MODEL.get_temperature():.4f}).")
207
-
208
- # --- Load Scaler & LoRA Feature Extractor ---
209
- print(f"Loading Scaler from: {SCALER_PATH}")
210
- SCALER = joblib.load(SCALER_PATH)
211
- print("Loading LoRA-enhanced ProtT5 Feature Extractor...")
212
- # ✅ UPDATED: Instantiate the new LoRA extractor class
213
- PROTT5_EXTRACTOR = LoRAProtT5Extractor(
214
  base_model_id=PROTT5_BASE_MODEL_ID,
215
- lora_adapter_path=LORA_ADAPTER_PATH
 
216
  )
217
- print("✅ Scaler and Feature Extractor loaded.")
218
-
219
- # --- Load Generator Model ---
220
- print(f"Loading Generator from: {GENERATOR_CHECKPOINT_PATH}")
221
- GENERATOR_MODEL = ProtT5Generator(vocab_size=VOCAB_SIZE)
222
- GENERATOR_MODEL.load_state_dict(torch.load(GENERATOR_CHECKPOINT_PATH, map_location=DEVICE))
223
  GENERATOR_MODEL.to(DEVICE)
224
  GENERATOR_MODEL.eval()
225
- print("✅ Generator model loaded.")
226
 
227
  print("\n--- All models loaded! Gradio app is ready. ---\n")
228
 
@@ -231,7 +249,7 @@ except Exception as e:
231
  raise gr.Error(f"A required model or file could not be loaded. Please check your repository file structure and paths. Error details: {e}")
232
 
233
  # --------------------------------------------------------------------------
234
- # SECTION 3: WRAPPER FUNCTIONS FOR GRADIO UI (Unchanged logic)
235
  # --------------------------------------------------------------------------
236
 
237
  def predict_peptide_wrapper(sequence_str):
@@ -239,10 +257,9 @@ def predict_peptide_wrapper(sequence_str):
239
  return "0.0000", "Error: Please enter a valid peptide sequence using standard amino acids (ACDEFGHIKLMNPQRSTVWY)."
240
 
241
  try:
242
- # This function call remains the same because the PROTT5_EXTRACTOR object,
243
- # despite its new internal logic, provides the same interface.
244
- features = extract_features(sequence_str.upper(), PROTT5_EXTRACTOR, L_fixed=29, d_model_pe=16)
245
- scaled_features = SCALER.transform(features.reshape(1, -1))
246
 
247
  with torch.no_grad():
248
  features_tensor = torch.tensor(scaled_features, dtype=torch.float32).to(DEVICE)
@@ -306,7 +323,7 @@ def generate_peptide_wrapper(num_to_generate, min_len, max_len, temperature, div
306
  time.sleep(1)
307
 
308
  if not validated_pool:
309
- return pd.DataFrame([{"Sequence": "Could not generate any high-activity peptides (>0.9 prob) with the current settings. Try different parameters.", "Predicted Probability": "N/A"}])
310
 
311
  high_quality_sequences = list(validated_pool.keys())
312
  final_diverse_seqs = cluster_sequences(GENERATOR_MODEL, high_quality_sequences, num_to_generate, DEVICE)
 
2
  # -*- coding: utf-8 -*-
3
 
4
  # app.py - RLAnOxPeptide Gradio Web Application
5
+ # Final version updated to use an AdvancedProtT5Generator with a LoRA backbone.
6
 
7
  import os
8
  import torch
 
34
  id2token = {i: t for t, i in token2id.items()}
35
  VOCAB_SIZE = len(token2id)
36
 
37
+ # --- Validator's Feature Extractor Class ---
 
 
38
  class LoRAProtT5Extractor:
39
  def __init__(self, base_model_id, lora_adapter_path):
40
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
41
+ print(f"Initializing Validator Feature Extractor on device: {self.device}")
42
 
43
+ print(f" - [Validator] Loading base model and tokenizer from '{base_model_id}'...")
44
  base_model = transformers.T5EncoderModel.from_pretrained(base_model_id)
45
  self.tokenizer = transformers.T5Tokenizer.from_pretrained(base_model_id)
46
 
47
  if not os.path.exists(lora_adapter_path):
48
+ raise FileNotFoundError(f"Error: Validator LoRA adapter directory not found at: {lora_adapter_path}")
49
 
50
+ print(f" - [Validator] Loading and applying LoRA adapter from: {lora_adapter_path}")
51
  lora_model = PeftModel.from_pretrained(base_model, lora_adapter_path)
52
 
53
+ print(" - [Validator] Merging LoRA weights for faster inference...")
54
  self.model = lora_model.merge_and_unload().to(self.device)
55
  self.model.eval()
56
+ print(" - Validator feature extractor is ready.")
57
 
58
  def encode(self, sequence):
59
  if not sequence or not isinstance(sequence, str):
 
69
  emb_np = embedding.squeeze(0).cpu().numpy()
70
  return emb_np if emb_np.shape[0] > 0 else np.zeros((1, 1024), dtype=np.float32)
71
 
72
+
73
+ # --- Predictor Model Head Architecture (Unchanged) ---
74
  class AntioxidantPredictor(nn.Module):
75
  def __init__(self, input_dim=1914, transformer_layers=3, transformer_heads=4, transformer_dropout=0.1):
76
  super(AntioxidantPredictor, self).__init__()
 
102
  def get_temperature(self):
103
  return self.temperature.item()
104
 
105
+
106
+ # --- ✅ NEW Generator Model Architecture ---
107
+ class AdvancedProtT5Generator(nn.Module):
108
+ def __init__(self, base_model_id, lora_adapter_path, vocab_size):
109
+ super(AdvancedProtT5Generator, self).__init__()
110
+
111
+ print(f" - [Generator] Loading base ProtT5 model from '{base_model_id}'...")
112
+ base_model = transformers.T5EncoderModel.from_pretrained(base_model_id)
113
+
114
+ print(f" - [Generator] Applying LoRA adapter from: {lora_adapter_path}")
115
+ self.backbone = PeftModel.from_pretrained(base_model, lora_adapter_path)
116
+
117
+ # Expose the embedding layer for the clustering function
118
+ self.embed_tokens = self.backbone.get_input_embeddings()
119
+
120
+ embed_dim = self.backbone.config.d_model # Should be 1024
121
  self.lm_head = nn.Linear(embed_dim, vocab_size)
122
+
123
  self.vocab_size = vocab_size
124
  self.eos_token_id = token2id["<EOS>"]
125
  self.pad_token_id = token2id["<PAD>"]
126
+ print(" - Advanced Generator framework initialized.")
127
 
128
  def forward(self, input_ids):
129
+ attention_mask = (input_ids != self.pad_token_id).int()
130
+ outputs = self.backbone(input_ids=input_ids, attention_mask=attention_mask)
131
+ sequence_output = outputs.last_hidden_state
132
+ logits = self.lm_head(sequence_output)
133
  return logits
134
 
135
  def sample(self, batch_size, max_length=20, device="cpu", temperature=2.5, min_decoded_length=3):
 
143
  probs = torch.softmax(next_logits, dim=-1)
144
  next_token = torch.multinomial(probs, num_samples=1)
145
  generated = torch.cat((generated, next_token), dim=1)
146
+ if (generated == self.eos_token_id).any(dim=1).all():
147
+ break
148
  return generated
149
 
150
  def decode(self, token_ids_batch):
 
153
  seq = ""
154
  for token_id in ids_tensor.tolist()[1:]:
155
  if token_id == self.eos_token_id: break
156
+ if token_id == token2id["<PAD>"]: continue
157
+ seq += id2token.get(token_id, "?")
158
  sequences.append(seq)
159
  return sequences
160
 
 
164
  except ImportError:
165
  raise gr.Error("Fatal Error: `feature_extract.py` not found. This file is required. Please upload it to your repository.")
166
 
167
+ # --- ✅ UPDATED Clustering Logic ---
168
  def cluster_sequences(generator, sequences, num_clusters, device):
169
  if not sequences or len(sequences) < num_clusters:
170
  return sequences[:num_clusters]
171
  with torch.no_grad():
172
  token_ids_list = []
173
+ max_len = max(len(seq) for seq in sequences) + 2
174
  for seq in sequences:
175
+ ids = [np.random.randint(2, VOCAB_SIZE)] + [token2id.get(aa, 0) for aa in seq] + [generator.eos_token_id]
 
176
  ids += [token2id["<PAD>"]] * (max_len - len(ids))
177
  token_ids_list.append(ids)
178
  input_ids = torch.tensor(token_ids_list, dtype=torch.long, device=device)
179
+
180
+ # Use the generator's exposed embedding layer
181
  embeddings = generator.embed_tokens(input_ids)
182
+
183
  mask = (input_ids != token2id["<PAD>"]).unsqueeze(-1).float()
184
  seq_embeds = (embeddings * mask).sum(dim=1) / (mask.sum(dim=1) + 1e-9)
185
  seq_embeds_np = seq_embeds.cpu().numpy()
186
+
187
  kmeans = KMeans(n_clusters=int(num_clusters), random_state=42, n_init='auto').fit(seq_embeds_np)
188
+ reps = []
189
  for i in range(int(num_clusters)):
190
+ idxs = np.where(kmeans.labels_ == i)[0]
191
+ if len(idxs) == 0: continue
192
+ center = kmeans.cluster_centers_[i]
193
+ distances = np.linalg.norm(seq_embeds_np[idxs] - center, axis=1)
194
+ rep_idx = idxs[np.argmin(distances)]
195
+ reps.append(sequences[rep_idx])
196
+ return reps
 
197
 
198
  # --------------------------------------------------------------------------
199
  # SECTION 2: GLOBAL MODEL AND DEPENDENCY LOADING
 
204
 
205
  try:
206
  # --- Define file paths ---
 
 
 
 
 
207
  PROTT5_BASE_MODEL_ID = "Rostlab/prot_t5_xl_uniref50"
208
+
209
+ # Paths for the Validator System
210
+ VALIDATOR_LORA_PATH = "./lora_finetuned_prott5"
211
+ PREDICTOR_HEAD_CHECKPOINT_PATH = "./predictor_with_lora_checkpoints/final_predictor_with_lora.pth"
212
+ SCALER_PATH = "./predictor_with_lora_checkpoints/scaler_lora.pkl"
213
+
214
+ # Paths for the Generator System
215
+ GENERATOR_LORA_DIR = "./generator_with_lora_output/final_lora_generator"
216
+ GENERATOR_LM_HEAD_PATH = os.path.join(GENERATOR_LORA_DIR, "lm_head.pth")
217
+
218
+ # --- Load Validator System ---
219
+ print("--- Loading Validator System ---")
220
+ VALIDATOR_SCALER = joblib.load(SCALER_PATH)
221
+ VALIDATOR_EXTRACTOR = LoRAProtT5Extractor(
222
+ base_model_id=PROTT5_BASE_MODEL_ID,
223
+ lora_adapter_path=VALIDATOR_LORA_PATH
224
+ )
225
  PREDICTOR_MODEL = AntioxidantPredictor(input_dim=1914)
226
+ PREDICTOR_MODEL.load_state_dict(torch.load(PREDICTOR_HEAD_CHECKPOINT_PATH, map_location=DEVICE))
227
  PREDICTOR_MODEL.to(DEVICE)
228
  PREDICTOR_MODEL.eval()
229
+ print("✅ Validator System loaded successfully.")
230
+
231
+ # --- Load Generator System ---
232
+ print("\n--- Loading Generator System ---")
233
+ GENERATOR_MODEL = AdvancedProtT5Generator(
 
 
 
234
  base_model_id=PROTT5_BASE_MODEL_ID,
235
+ lora_adapter_path=GENERATOR_LORA_DIR,
236
+ vocab_size=VOCAB_SIZE
237
  )
238
+ if not os.path.exists(GENERATOR_LM_HEAD_PATH):
239
+ raise FileNotFoundError(f"Generator's lm_head weights not found at: {GENERATOR_LM_HEAD_PATH}")
240
+ GENERATOR_MODEL.lm_head.load_state_dict(torch.load(GENERATOR_LM_HEAD_PATH, map_location=DEVICE))
 
 
 
241
  GENERATOR_MODEL.to(DEVICE)
242
  GENERATOR_MODEL.eval()
243
+ print("✅ Generator System loaded successfully.")
244
 
245
  print("\n--- All models loaded! Gradio app is ready. ---\n")
246
 
 
249
  raise gr.Error(f"A required model or file could not be loaded. Please check your repository file structure and paths. Error details: {e}")
250
 
251
  # --------------------------------------------------------------------------
252
+ # SECTION 3: WRAPPER FUNCTIONS FOR GRADIO UI
253
  # --------------------------------------------------------------------------
254
 
255
  def predict_peptide_wrapper(sequence_str):
 
257
  return "0.0000", "Error: Please enter a valid peptide sequence using standard amino acids (ACDEFGHIKLMNPQRSTVWY)."
258
 
259
  try:
260
+ # Use the VALIDATOR's feature extractor
261
+ features = extract_features(sequence_str.upper(), VALIDATOR_EXTRACTOR, L_fixed=29, d_model_pe=16)
262
+ scaled_features = VALIDATOR_SCALER.transform(features.reshape(1, -1))
 
263
 
264
  with torch.no_grad():
265
  features_tensor = torch.tensor(scaled_features, dtype=torch.float32).to(DEVICE)
 
323
  time.sleep(1)
324
 
325
  if not validated_pool:
326
+ return pd.DataFrame([{"Sequence": "Could not generate any high-activity peptides (>0.9 prob) with current settings.", "Predicted Probability": "N/A"}])
327
 
328
  high_quality_sequences = list(validated_pool.keys())
329
  final_diverse_seqs = cluster_sequences(GENERATOR_MODEL, high_quality_sequences, num_to_generate, DEVICE)