chshan committed (verified)
Commit 836aa2b · 1 Parent(s): 02b6e86

Update app.py

Files changed (1)
  1. app.py +28 -35
app.py CHANGED
@@ -2,7 +2,7 @@
  # -*- coding: utf-8 -*-

  # app.py - RLAnOxPeptide Gradio Web Application
- # Final version updated to use an AdvancedProtT5Generator with a LoRA backbone.

  import os
  import torch
@@ -15,6 +15,7 @@ from sklearn.cluster import KMeans
  from tqdm import tqdm
  import transformers
  import time

  # NEW DEPENDENCY: peft library for LoRA
  from peft import PeftModel
@@ -34,20 +35,21 @@ token2id["<EOS>"] = 1
  id2token = {i: t for t, i in token2id.items()}
  VOCAB_SIZE = len(token2id)

  # --- Validator's Feature Extractor Class ---
  class LoRAProtT5Extractor:
-     def __init__(self, base_model_id, lora_adapter_path):
          self.device = "cuda" if torch.cuda.is_available() else "cpu"
          print(f"Initializing Validator Feature Extractor on device: {self.device}")

-         print(f" - [Validator] Loading base model and tokenizer from '{base_model_id}'...")
-         base_model = transformers.T5EncoderModel.from_pretrained(base_model_id)
-         self.tokenizer = transformers.T5Tokenizer.from_pretrained(base_model_id)

          if not os.path.exists(lora_adapter_path):
              raise FileNotFoundError(f"Error: Validator LoRA adapter directory not found at: {lora_adapter_path}")

-         print(f" - [Validator] Loading and applying LoRA adapter from: {lora_adapter_path}")
          lora_model = PeftModel.from_pretrained(base_model, lora_adapter_path)

          print(" - [Validator] Merging LoRA weights for faster inference...")
@@ -58,14 +60,11 @@ class LoRAProtT5Extractor:
      def encode(self, sequence):
          if not sequence or not isinstance(sequence, str):
              return np.zeros((1, 1024), dtype=np.float32)
-
          seq_spaced = " ".join(list(sequence))
          encoded_input = self.tokenizer(seq_spaced, return_tensors='pt', padding=True, truncation=True)
          encoded_input = {k: v.to(self.device) for k, v in encoded_input.items()}
-
          with torch.no_grad():
              embedding = self.model(**encoded_input).last_hidden_state
-
          emb_np = embedding.squeeze(0).cpu().numpy()
          return emb_np if emb_np.shape[0] > 0 else np.zeros((1, 1024), dtype=np.float32)
 
@@ -78,15 +77,12 @@ class AntioxidantPredictor(nn.Module):
          self.handcrafted_dim = input_dim - self.prott5_dim
          self.seq_len = 16
          self.prott5_feature_dim = 64
-
          encoder_layer = nn.TransformerEncoderLayer(d_model=self.prott5_feature_dim, nhead=transformer_heads, dropout=transformer_dropout, batch_first=True)
          self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=transformer_layers)
-
          fused_dim = self.prott5_feature_dim + self.handcrafted_dim
          self.fusion_fc = nn.Sequential(nn.Linear(fused_dim, 1024), nn.ReLU(), nn.Dropout(0.3), nn.Linear(1024, 512), nn.ReLU(), nn.Dropout(0.3))
          self.classifier = nn.Sequential(nn.Linear(512, 256), nn.ReLU(), nn.Dropout(0.3), nn.Linear(256, 1))
          self.temperature = nn.Parameter(torch.ones(1), requires_grad=False)
-
      def forward(self, x):
          batch_size = x.size(0)
          prot_t5_features = x[:, :self.prott5_dim]
@@ -98,26 +94,23 @@ class AntioxidantPredictor(nn.Module):
          fused_output = self.fusion_fc(fused_features)
          logits = self.classifier(fused_output)
          return logits / self.temperature
-
      def get_temperature(self):
          return self.temperature.item()


- # --- ✅ NEW Generator Model Architecture ---
  class AdvancedProtT5Generator(nn.Module):
-     def __init__(self, base_model_id, lora_adapter_path, vocab_size):
          super(AdvancedProtT5Generator, self).__init__()

-         print(f" - [Generator] Loading base ProtT5 model from '{base_model_id}'...")
-         base_model = transformers.T5EncoderModel.from_pretrained(base_model_id)
-
          print(f" - [Generator] Applying LoRA adapter from: {lora_adapter_path}")
          self.backbone = PeftModel.from_pretrained(base_model, lora_adapter_path)

-         # Expose the embedding layer for the clustering function
          self.embed_tokens = self.backbone.get_input_embeddings()

-         embed_dim = self.backbone.config.d_model  # Should be 1024
          self.lm_head = nn.Linear(embed_dim, vocab_size)

          self.vocab_size = vocab_size
@@ -164,7 +157,7 @@ try:
  except ImportError:
      raise gr.Error("Fatal Error: `feature_extract.py` not found. This file is required. Please upload it to your repository.")

- # --- ✅ UPDATED Clustering Logic ---
  def cluster_sequences(generator, sequences, num_clusters, device):
      if not sequences or len(sequences) < num_clusters:
          return sequences[:num_clusters]
@@ -176,14 +169,10 @@ def cluster_sequences(generator, sequences, num_clusters, device):
          ids += [token2id["<PAD>"]] * (max_len - len(ids))
          token_ids_list.append(ids)
      input_ids = torch.tensor(token_ids_list, dtype=torch.long, device=device)
-
-     # Use the generator's exposed embedding layer
      embeddings = generator.embed_tokens(input_ids)
-
      mask = (input_ids != token2id["<PAD>"]).unsqueeze(-1).float()
      seq_embeds = (embeddings * mask).sum(dim=1) / (mask.sum(dim=1) + 1e-9)
      seq_embeds_np = seq_embeds.cpu().numpy()
-
      kmeans = KMeans(n_clusters=int(num_clusters), random_state=42, n_init='auto').fit(seq_embeds_np)
      reps = []
      for i in range(int(num_clusters)):
@@ -205,21 +194,25 @@ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
  try:
      # --- Define file paths ---
      PROTT5_BASE_MODEL_ID = "Rostlab/prot_t5_xl_uniref50"
-
-     # Paths for the Validator System
      VALIDATOR_LORA_PATH = "./lora_finetuned_prott5"
      PREDICTOR_HEAD_CHECKPOINT_PATH = "./predictor_with_lora_checkpoints/final_predictor_with_lora.pth"
      SCALER_PATH = "./predictor_with_lora_checkpoints/scaler_lora.pkl"
-
-     # Paths for the Generator System
      GENERATOR_LORA_DIR = "./generator_with_lora_output/final_lora_generator"
      GENERATOR_LM_HEAD_PATH = os.path.join(GENERATOR_LORA_DIR, "lm_head.pth")

      # --- Load Validator System ---
-     print("--- Loading Validator System ---")
      VALIDATOR_SCALER = joblib.load(SCALER_PATH)
      VALIDATOR_EXTRACTOR = LoRAProtT5Extractor(
-         base_model_id=PROTT5_BASE_MODEL_ID,
          lora_adapter_path=VALIDATOR_LORA_PATH
      )
      PREDICTOR_MODEL = AntioxidantPredictor(input_dim=1914)
@@ -229,9 +222,10 @@ try:
      print("✅ Validator System loaded successfully.")

      # --- Load Generator System ---
-     print("\n--- Loading Generator System ---")
      GENERATOR_MODEL = AdvancedProtT5Generator(
-         base_model_id=PROTT5_BASE_MODEL_ID,
          lora_adapter_path=GENERATOR_LORA_DIR,
          vocab_size=VOCAB_SIZE
      )
@@ -254,10 +248,9 @@ except Exception as e:

  def predict_peptide_wrapper(sequence_str):
      if not sequence_str or not isinstance(sequence_str, str) or any(c not in AMINO_ACIDS for c in sequence_str.upper()):
-         return "0.0000", "Error: Please enter a valid peptide sequence using standard amino acids (ACDEFGHIKLMNPQRSTVWY)."

      try:
-         # Use the VALIDATOR's feature extractor
          features = extract_features(sequence_str.upper(), VALIDATOR_EXTRACTOR, L_fixed=29, d_model_pe=16)
          scaled_features = VALIDATOR_SCALER.transform(features.reshape(1, -1))
 
  # -*- coding: utf-8 -*-

  # app.py - RLAnOxPeptide Gradio Web Application
+

  import os
  import torch
 
  from tqdm import tqdm
  import transformers
  import time
+ import copy  # ✅ ADDED: For deep copying the base model

  # NEW DEPENDENCY: peft library for LoRA
  from peft import PeftModel
 
  id2token = {i: t for t, i in token2id.items()}
  VOCAB_SIZE = len(token2id)

+
  # --- Validator's Feature Extractor Class ---
+ # ✅ MODIFIED: Accepts a pre-loaded model instead of loading its own.
  class LoRAProtT5Extractor:
+     def __init__(self, preloaded_base_model, preloaded_tokenizer, lora_adapter_path):
          self.device = "cuda" if torch.cuda.is_available() else "cpu"
          print(f"Initializing Validator Feature Extractor on device: {self.device}")

+         base_model = preloaded_base_model
+         self.tokenizer = preloaded_tokenizer

          if not os.path.exists(lora_adapter_path):
              raise FileNotFoundError(f"Error: Validator LoRA adapter directory not found at: {lora_adapter_path}")

+         print(f" - [Validator] Applying LoRA adapter from: {lora_adapter_path}")
          lora_model = PeftModel.from_pretrained(base_model, lora_adapter_path)

          print(" - [Validator] Merging LoRA weights for faster inference...")
 
      def encode(self, sequence):
          if not sequence or not isinstance(sequence, str):
              return np.zeros((1, 1024), dtype=np.float32)
          seq_spaced = " ".join(list(sequence))
          encoded_input = self.tokenizer(seq_spaced, return_tensors='pt', padding=True, truncation=True)
          encoded_input = {k: v.to(self.device) for k, v in encoded_input.items()}
          with torch.no_grad():
              embedding = self.model(**encoded_input).last_hidden_state
          emb_np = embedding.squeeze(0).cpu().numpy()
          return emb_np if emb_np.shape[0] > 0 else np.zeros((1, 1024), dtype=np.float32)
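
A note for readers of this diff: the hidden lines of `__init__` evidently merge the adapter into the base weights before `encode()` is used. Below is a minimal standalone sketch of the same extraction pattern, assuming the merge step is peft's `merge_and_unload()`; the model ID and adapter path are the ones defined later in this file.

```python
import torch
import transformers
from peft import PeftModel

device = "cuda" if torch.cuda.is_available() else "cpu"
base = transformers.T5EncoderModel.from_pretrained("Rostlab/prot_t5_xl_uniref50")
tokenizer = transformers.T5Tokenizer.from_pretrained("Rostlab/prot_t5_xl_uniref50")

# Apply the validator adapter and (assumed) merge it into the base weights.
lora = PeftModel.from_pretrained(base, "./lora_finetuned_prott5")
model = lora.merge_and_unload().to(device).eval()

sequence = "ACDEFGHIK"                      # example peptide
spaced = " ".join(list(sequence))           # ProtT5 expects space-separated residues
inputs = tokenizer(spaced, return_tensors="pt", padding=True, truncation=True).to(device)
with torch.no_grad():
    per_residue = model(**inputs).last_hidden_state   # shape: (1, L+1, 1024)
print(per_residue.shape)
```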
 
 
          self.handcrafted_dim = input_dim - self.prott5_dim
          self.seq_len = 16
          self.prott5_feature_dim = 64
          encoder_layer = nn.TransformerEncoderLayer(d_model=self.prott5_feature_dim, nhead=transformer_heads, dropout=transformer_dropout, batch_first=True)
          self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=transformer_layers)
          fused_dim = self.prott5_feature_dim + self.handcrafted_dim
          self.fusion_fc = nn.Sequential(nn.Linear(fused_dim, 1024), nn.ReLU(), nn.Dropout(0.3), nn.Linear(1024, 512), nn.ReLU(), nn.Dropout(0.3))
          self.classifier = nn.Sequential(nn.Linear(512, 256), nn.ReLU(), nn.Dropout(0.3), nn.Linear(256, 1))
          self.temperature = nn.Parameter(torch.ones(1), requires_grad=False)

      def forward(self, x):
          batch_size = x.size(0)
          prot_t5_features = x[:, :self.prott5_dim]

          fused_output = self.fusion_fc(fused_features)
          logits = self.classifier(fused_output)
          return logits / self.temperature

      def get_temperature(self):
          return self.temperature.item()
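
Context on the head above: with `seq_len = 16` and `prott5_feature_dim = 64`, the first 16 * 64 = 1024 input features are presumably the pooled ProtT5 embedding and the remaining 1914 - 1024 = 890 are handcrafted descriptors. The classifier divides its logits by a fixed `temperature` so callers get calibrated probabilities after a sigmoid. A toy illustration of that last step (the numbers are made up, not taken from the checkpoint):

```python
import torch

logits = torch.tensor([2.0, -0.5, 0.8])        # hypothetical raw classifier outputs
temperature = torch.tensor(1.5)                 # hypothetical calibrated temperature

print(torch.sigmoid(logits))                    # uncalibrated probabilities
print(torch.sigmoid(logits / temperature))      # softened, calibrated probabilities
```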
 
+ # --- Generator Model Architecture ---
+ # ✅ MODIFIED: Accepts a pre-loaded model instead of loading its own.
  class AdvancedProtT5Generator(nn.Module):
+     def __init__(self, preloaded_base_model, lora_adapter_path, vocab_size):
          super(AdvancedProtT5Generator, self).__init__()

+         base_model = preloaded_base_model

          print(f" - [Generator] Applying LoRA adapter from: {lora_adapter_path}")
          self.backbone = PeftModel.from_pretrained(base_model, lora_adapter_path)

          self.embed_tokens = self.backbone.get_input_embeddings()

+         embed_dim = self.backbone.config.d_model
          self.lm_head = nn.Linear(embed_dim, vocab_size)

          self.vocab_size = vocab_size
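
The generator pairs the LoRA-adapted encoder backbone with a small `lm_head` that projects hidden states onto the peptide vocabulary. The sampling loop itself is outside this diff, so the following is only a hedged sketch of how such a head yields per-position token logits (dimensions assumed from `d_model = 1024` and a roughly 22-token vocabulary of amino acids plus special tokens):

```python
import torch
import torch.nn as nn

embed_dim, vocab_size = 1024, 22                 # assumed sizes
lm_head = nn.Linear(embed_dim, vocab_size)

hidden = torch.randn(4, 10, embed_dim)           # stand-in for backbone(...).last_hidden_state
logits = lm_head(hidden)                         # (batch, seq_len, vocab_size)
next_tok = torch.distributions.Categorical(logits=logits[:, -1, :]).sample()
print(logits.shape, next_tok)
```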
 
  except ImportError:
      raise gr.Error("Fatal Error: `feature_extract.py` not found. This file is required. Please upload it to your repository.")

+ # --- Clustering Logic (Unchanged) ---
  def cluster_sequences(generator, sequences, num_clusters, device):
      if not sequences or len(sequences) < num_clusters:
          return sequences[:num_clusters]

          ids += [token2id["<PAD>"]] * (max_len - len(ids))
          token_ids_list.append(ids)
      input_ids = torch.tensor(token_ids_list, dtype=torch.long, device=device)
      embeddings = generator.embed_tokens(input_ids)
      mask = (input_ids != token2id["<PAD>"]).unsqueeze(-1).float()
      seq_embeds = (embeddings * mask).sum(dim=1) / (mask.sum(dim=1) + 1e-9)
      seq_embeds_np = seq_embeds.cpu().numpy()
      kmeans = KMeans(n_clusters=int(num_clusters), random_state=42, n_init='auto').fit(seq_embeds_np)
      reps = []
      for i in range(int(num_clusters)):
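
The representative-selection loop is truncated by the diff view right after `for i in range(...)`. One plausible completion, shown purely as an assumption, is to return the member sequence closest to each K-Means centroid; the sketch below uses random embeddings in place of the generator's mask-pooled token embeddings:

```python
import numpy as np
from sklearn.cluster import KMeans

rng = np.random.default_rng(0)
seq_embeds_np = rng.normal(size=(12, 1024)).astype(np.float32)   # stand-in pooled embeddings
sequences = [f"SEQ_{i}" for i in range(12)]                      # placeholder sequences
num_clusters = 3

kmeans = KMeans(n_clusters=num_clusters, random_state=42, n_init="auto").fit(seq_embeds_np)
reps = []
for i in range(num_clusters):
    members = np.where(kmeans.labels_ == i)[0]
    dists = np.linalg.norm(seq_embeds_np[members] - kmeans.cluster_centers_[i], axis=1)
    reps.append(sequences[members[dists.argmin()]])              # closest member to the centroid
print(reps)
```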
 
  try:
      # --- Define file paths ---
      PROTT5_BASE_MODEL_ID = "Rostlab/prot_t5_xl_uniref50"
      VALIDATOR_LORA_PATH = "./lora_finetuned_prott5"
      PREDICTOR_HEAD_CHECKPOINT_PATH = "./predictor_with_lora_checkpoints/final_predictor_with_lora.pth"
      SCALER_PATH = "./predictor_with_lora_checkpoints/scaler_lora.pkl"
      GENERATOR_LORA_DIR = "./generator_with_lora_output/final_lora_generator"
      GENERATOR_LM_HEAD_PATH = os.path.join(GENERATOR_LORA_DIR, "lm_head.pth")

+     # ✅ OPTIMIZED: Load the base model and tokenizer only ONCE
+     print(f"--- Loading Base ProtT5 Model ({PROTT5_BASE_MODEL_ID}) just once... ---")
+     base_prot_t5_model = transformers.T5EncoderModel.from_pretrained(PROTT5_BASE_MODEL_ID)
+     base_tokenizer = transformers.T5Tokenizer.from_pretrained(PROTT5_BASE_MODEL_ID)
+     print("✅ Base ProtT5 Model loaded.")
+
      # --- Load Validator System ---
+     print("\n--- Initializing Validator System ---")
      VALIDATOR_SCALER = joblib.load(SCALER_PATH)
+     # Pass a deep copy of the base model to prevent modification conflicts
      VALIDATOR_EXTRACTOR = LoRAProtT5Extractor(
+         preloaded_base_model=copy.deepcopy(base_prot_t5_model),
+         preloaded_tokenizer=base_tokenizer,
          lora_adapter_path=VALIDATOR_LORA_PATH
      )
      PREDICTOR_MODEL = AntioxidantPredictor(input_dim=1914)

      print("✅ Validator System loaded successfully.")

      # --- Load Generator System ---
+     print("\n--- Initializing Generator System ---")
+     # Pass a deep copy of the base model here as well
      GENERATOR_MODEL = AdvancedProtT5Generator(
+         preloaded_base_model=copy.deepcopy(base_prot_t5_model),
          lora_adapter_path=GENERATOR_LORA_DIR,
          vocab_size=VOCAB_SIZE
      )
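
This block is the heart of the commit: the large ProtT5 checkpoint is now loaded once, and each subsystem receives a `copy.deepcopy` before its own LoRA adapter is attached, so the validator's and generator's adapters cannot clash on shared weights. Reduced to its essentials (adapter paths as defined above), the pattern is roughly:

```python
import copy
import transformers
from peft import PeftModel

# Load the expensive base encoder and tokenizer a single time.
base = transformers.T5EncoderModel.from_pretrained("Rostlab/prot_t5_xl_uniref50")
tokenizer = transformers.T5Tokenizer.from_pretrained("Rostlab/prot_t5_xl_uniref50")

# Each subsystem works on its own deep copy, so applying one LoRA adapter
# cannot mutate the weights the other subsystem sees.
validator_backbone = PeftModel.from_pretrained(copy.deepcopy(base), "./lora_finetuned_prott5")
generator_backbone = PeftModel.from_pretrained(copy.deepcopy(base), "./generator_with_lora_output/final_lora_generator")
```

Compared with the previous version, which called `from_pretrained` on the base checkpoint separately for each subsystem, this trades repeated downloads and initialization for in-memory copies.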
 
  def predict_peptide_wrapper(sequence_str):
      if not sequence_str or not isinstance(sequence_str, str) or any(c not in AMINO_ACIDS for c in sequence_str.upper()):
+         return "0.0000", "Error: Please enter a valid peptide sequence using standard amino acids."

      try:
          features = extract_features(sequence_str.upper(), VALIDATOR_EXTRACTOR, L_fixed=29, d_model_pe=16)
          scaled_features = VALIDATOR_SCALER.transform(features.reshape(1, -1))
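
The diff stops just after feature scaling; presumably the scaled vector is then passed through `PREDICTOR_MODEL` and squashed to a probability. The helper below is a hedged guess at that remaining step, not the code in the repository, and it assumes `extract_features` is importable from the repo's `feature_extract.py`:

```python
import torch
from feature_extract import extract_features   # assumed import path within this repo

def predict_probability(sequence, extractor, scaler, model, device="cpu"):
    """Hypothetical tail of predict_peptide_wrapper: features -> scaler -> model -> probability."""
    features = extract_features(sequence.upper(), extractor, L_fixed=29, d_model_pe=16)
    scaled = scaler.transform(features.reshape(1, -1))
    with torch.no_grad():
        logits = model(torch.tensor(scaled, dtype=torch.float32, device=device))
    return torch.sigmoid(logits).item()         # assumed: single antioxidant-activity probability
```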