Update app.py

app.py CHANGED
@@ -2,7 +2,7 @@
 # -*- coding: utf-8 -*-

 # app.py - RLAnOxPeptide Gradio Web Application
-# Final version updated to use
+# Final version updated to use an AdvancedProtT5Generator with a LoRA backbone.

 import os
 import torch
@@ -34,28 +34,26 @@ token2id["<EOS>"] = 1
 id2token = {i: t for t, i in token2id.items()}
 VOCAB_SIZE = len(token2id)

-
-# --- LoRA Feature Extractor Model Class ---
-# ✅ REPLACED: This new class handles loading the base model and attaching the LoRA adapter.
+# --- Validator's Feature Extractor Class ---
 class LoRAProtT5Extractor:
     def __init__(self, base_model_id, lora_adapter_path):
         self.device = "cuda" if torch.cuda.is_available() else "cpu"
-        print(f"Initializing
+        print(f"Initializing Validator Feature Extractor on device: {self.device}")

-        print(f" - Loading base model and tokenizer from '{base_model_id}'...")
+        print(f" - [Validator] Loading base model and tokenizer from '{base_model_id}'...")
         base_model = transformers.T5EncoderModel.from_pretrained(base_model_id)
         self.tokenizer = transformers.T5Tokenizer.from_pretrained(base_model_id)

         if not os.path.exists(lora_adapter_path):
-            raise FileNotFoundError(f"Error: LoRA adapter directory not found at: {lora_adapter_path}")
+            raise FileNotFoundError(f"Error: Validator LoRA adapter directory not found at: {lora_adapter_path}")

-        print(f" - Loading and applying LoRA adapter from: {lora_adapter_path}")
+        print(f" - [Validator] Loading and applying LoRA adapter from: {lora_adapter_path}")
         lora_model = PeftModel.from_pretrained(base_model, lora_adapter_path)

-        print(" - Merging LoRA weights for faster inference...")
+        print(" - [Validator] Merging LoRA weights for faster inference...")
         self.model = lora_model.merge_and_unload().to(self.device)
         self.model.eval()
-        print(" -
+        print(" - Validator feature extractor is ready.")

     def encode(self, sequence):
         if not sequence or not isinstance(sequence, str):
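The extractor above uses the standard PEFT load-then-merge flow: load the frozen base encoder, attach the saved adapter with PeftModel.from_pretrained, then fold the low-rank deltas into the base weights with merge_and_unload() so inference pays no adapter overhead. A minimal sketch of the same flow, using t5-small and a hypothetical ./my_adapter directory instead of this repo's files:

import transformers
from peft import LoraConfig, PeftModel, get_peft_model

# Create and save a tiny adapter so the load path below has something to read.
base = transformers.T5EncoderModel.from_pretrained("t5-small")
get_peft_model(base, LoraConfig(r=4, target_modules=["q", "v"])).save_pretrained("./my_adapter")

# The pattern used by LoRAProtT5Extractor: fresh base model + saved adapter,
# then merge the low-rank updates into the base weights for plain inference.
base = transformers.T5EncoderModel.from_pretrained("t5-small")
merged = PeftModel.from_pretrained(base, "./my_adapter").merge_and_unload()
merged.eval()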
@@ -71,7 +69,8 @@ class LoRAProtT5Extractor:
         emb_np = embedding.squeeze(0).cpu().numpy()
         return emb_np if emb_np.shape[0] > 0 else np.zeros((1, 1024), dtype=np.float32)

-
+
+# --- Predictor Model Head Architecture (Unchanged) ---
 class AntioxidantPredictor(nn.Module):
     def __init__(self, input_dim=1914, transformer_layers=3, transformer_heads=4, transformer_dropout=0.1):
         super(AntioxidantPredictor, self).__init__()
@@ -103,22 +102,34 @@ class AntioxidantPredictor(nn.Module):
     def get_temperature(self):
         return self.temperature.item()

-
-
-
-
-
-
-
+
+# --- ✅ NEW Generator Model Architecture ---
+class AdvancedProtT5Generator(nn.Module):
+    def __init__(self, base_model_id, lora_adapter_path, vocab_size):
+        super(AdvancedProtT5Generator, self).__init__()
+
+        print(f" - [Generator] Loading base ProtT5 model from '{base_model_id}'...")
+        base_model = transformers.T5EncoderModel.from_pretrained(base_model_id)
+
+        print(f" - [Generator] Applying LoRA adapter from: {lora_adapter_path}")
+        self.backbone = PeftModel.from_pretrained(base_model, lora_adapter_path)
+
+        # Expose the embedding layer for the clustering function
+        self.embed_tokens = self.backbone.get_input_embeddings()
+
+        embed_dim = self.backbone.config.d_model  # Should be 1024
         self.lm_head = nn.Linear(embed_dim, vocab_size)
+
         self.vocab_size = vocab_size
         self.eos_token_id = token2id["<EOS>"]
         self.pad_token_id = token2id["<PAD>"]
+        print(" - Advanced Generator framework initialized.")

     def forward(self, input_ids):
-
-
-
+        attention_mask = (input_ids != self.pad_token_id).int()
+        outputs = self.backbone(input_ids=input_ids, attention_mask=attention_mask)
+        sequence_output = outputs.last_hidden_state
+        logits = self.lm_head(sequence_output)
         return logits

     def sample(self, batch_size, max_length=20, device="cpu", temperature=2.5, min_decoded_length=3):
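The new generator is an encoder-style LM: the LoRA-wrapped T5 encoder yields one hidden state per position, and a fresh nn.Linear head maps each state to vocabulary logits. A toy shape check of that contract, with a stand-in embedding layer in place of the 1024-dim ProtT5 backbone:

import torch
import torch.nn as nn

PAD_ID, VOCAB, EMBED_DIM = 0, 22, 8          # d_model is 1024 in the real backbone
backbone = nn.Embedding(VOCAB, EMBED_DIM)    # stand-in for the LoRA T5 encoder
lm_head = nn.Linear(EMBED_DIM, VOCAB)

input_ids = torch.tensor([[3, 4, 5, 1, PAD_ID]])   # [B=1, L=5]
attention_mask = (input_ids != PAD_ID).int()       # mirrors forward() above
logits = lm_head(backbone(input_ids))              # [1, 5, 22]: per-token logits
print(attention_mask.tolist(), logits.shape)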
@@ -132,6 +143,8 @@ class ProtT5Generator(nn.Module):
             probs = torch.softmax(next_logits, dim=-1)
             next_token = torch.multinomial(probs, num_samples=1)
             generated = torch.cat((generated, next_token), dim=1)
+            if (generated == self.eos_token_id).any(dim=1).all():
+                break
         return generated

     def decode(self, token_ids_batch):
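The sample() method is ancestral sampling with an early stop: temperature-scaled logits are turned into a distribution, the next token is drawn with torch.multinomial, and generation halts once every row of the batch has emitted <EOS> (the newly added check). A self-contained sketch with random stand-in logits; the temperature placement here is illustrative, since the real method derives next_logits from forward():

import torch

EOS_ID, VOCAB = 1, 22
torch.manual_seed(0)
generated = torch.full((4, 1), 2, dtype=torch.long)    # batch of 4 seed tokens

for _ in range(20):                                    # max_length budget
    next_logits = torch.randn(4, VOCAB)                # stand-in for model logits
    probs = torch.softmax(next_logits / 2.5, dim=-1)   # temperature-flattened distribution
    next_token = torch.multinomial(probs, num_samples=1)
    generated = torch.cat((generated, next_token), dim=1)
    if (generated == EOS_ID).any(dim=1).all():         # every row has an <EOS>
        break
print(generated.shape)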
@@ -140,8 +153,8 @@ class ProtT5Generator(nn.Module):
             seq = ""
             for token_id in ids_tensor.tolist()[1:]:
                 if token_id == self.eos_token_id: break
-                if token_id ==
-                seq += id2token.get(token_id, "")
+                if token_id == token2id["<PAD>"]: continue
+                seq += id2token.get(token_id, "?")
             sequences.append(seq)
         return sequences

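decode() now applies a few per-row conventions: drop position 0 (the random seed token), stop at the first <EOS>, skip <PAD>, and render unknown ids as "?" rather than silently dropping them (the change from id2token.get(token_id, "")). A toy run with a three-letter vocabulary:

PAD_ID, EOS_ID = 0, 1
id2token = {2: "A", 3: "C", 4: "D"}   # toy vocabulary

def decode_row(ids):
    seq = ""
    for token_id in ids[1:]:          # position 0 is the random seed token
        if token_id == EOS_ID: break
        if token_id == PAD_ID: continue
        seq += id2token.get(token_id, "?")
    return seq

print(decode_row([7, 2, 3, 99, 4, EOS_ID, PAD_ID]))   # -> "AC?D"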
@@ -151,34 +164,36 @@ try:
 except ImportError:
     raise gr.Error("Fatal Error: `feature_extract.py` not found. This file is required. Please upload it to your repository.")

-# --- Clustering Logic
+# --- ✅ UPDATED Clustering Logic ---
 def cluster_sequences(generator, sequences, num_clusters, device):
     if not sequences or len(sequences) < num_clusters:
         return sequences[:num_clusters]
     with torch.no_grad():
         token_ids_list = []
-        max_len = max(
+        max_len = max(len(seq) for seq in sequences) + 2
         for seq in sequences:
-            ids = [token2id.get(aa, 0) for aa in seq] + [generator.eos_token_id]
-            ids = [np.random.randint(2, VOCAB_SIZE)] + ids
+            ids = [np.random.randint(2, VOCAB_SIZE)] + [token2id.get(aa, 0) for aa in seq] + [generator.eos_token_id]
             ids += [token2id["<PAD>"]] * (max_len - len(ids))
             token_ids_list.append(ids)
         input_ids = torch.tensor(token_ids_list, dtype=torch.long, device=device)
+
+        # Use the generator's exposed embedding layer
         embeddings = generator.embed_tokens(input_ids)
+
         mask = (input_ids != token2id["<PAD>"]).unsqueeze(-1).float()
         seq_embeds = (embeddings * mask).sum(dim=1) / (mask.sum(dim=1) + 1e-9)
         seq_embeds_np = seq_embeds.cpu().numpy()
+
         kmeans = KMeans(n_clusters=int(num_clusters), random_state=42, n_init='auto').fit(seq_embeds_np)
-
+        reps = []
         for i in range(int(num_clusters)):
-
-            if len(
-
-
-
-
-        return representatives
+            idxs = np.where(kmeans.labels_ == i)[0]
+            if len(idxs) == 0: continue
+            center = kmeans.cluster_centers_[i]
+            distances = np.linalg.norm(seq_embeds_np[idxs] - center, axis=1)
+            rep_idx = idxs[np.argmin(distances)]
+            reps.append(sequences[rep_idx])
+        return reps

 # --------------------------------------------------------------------------
 # SECTION 2: GLOBAL MODEL AND DEPENDENCY LOADING
@@ -189,40 +204,43 @@ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

 try:
     # --- Define file paths ---
-    PREDICTOR_CHECKPOINT_PATH = "checkpoints/final_rl_model_logitp0.1_calibrated_FINETUNED_PROTT5.pth"
-    SCALER_PATH = "checkpoints/scaler_FINETUNED_PROTT5.pkl"
-    GENERATOR_CHECKPOINT_PATH = "generator_checkpoints_v3.6/final_generator_model.pth"
-
-    # ✅ UPDATED: Define paths for LoRA-based loading
     PROTT5_BASE_MODEL_ID = "Rostlab/prot_t5_xl_uniref50"
-
-
-
-
+
+    # Paths for the Validator System
+    VALIDATOR_LORA_PATH = "./lora_finetuned_prott5"
+    PREDICTOR_HEAD_CHECKPOINT_PATH = "./predictor_with_lora_checkpoints/final_predictor_with_lora.pth"
+    SCALER_PATH = "./predictor_with_lora_checkpoints/scaler_lora.pkl"
+
+    # Paths for the Generator System
+    GENERATOR_LORA_DIR = "./generator_with_lora_output/final_lora_generator"
+    GENERATOR_LM_HEAD_PATH = os.path.join(GENERATOR_LORA_DIR, "lm_head.pth")
+
+    # --- Load Validator System ---
+    print("--- Loading Validator System ---")
+    VALIDATOR_SCALER = joblib.load(SCALER_PATH)
+    VALIDATOR_EXTRACTOR = LoRAProtT5Extractor(
+        base_model_id=PROTT5_BASE_MODEL_ID,
+        lora_adapter_path=VALIDATOR_LORA_PATH
+    )
     PREDICTOR_MODEL = AntioxidantPredictor(input_dim=1914)
-    PREDICTOR_MODEL.load_state_dict(torch.load(
+    PREDICTOR_MODEL.load_state_dict(torch.load(PREDICTOR_HEAD_CHECKPOINT_PATH, map_location=DEVICE))
     PREDICTOR_MODEL.to(DEVICE)
     PREDICTOR_MODEL.eval()
-    print(
-
-    # --- Load
-    print(
-
-    print("Loading LoRA-enhanced ProtT5 Feature Extractor...")
-    # ✅ UPDATED: Instantiate the new LoRA extractor class
-    PROTT5_EXTRACTOR = LoRAProtT5Extractor(
+    print("✅ Validator System loaded successfully.")
+
+    # --- Load Generator System ---
+    print("\n--- Loading Generator System ---")
+    GENERATOR_MODEL = AdvancedProtT5Generator(
         base_model_id=PROTT5_BASE_MODEL_ID,
-        lora_adapter_path=
+        lora_adapter_path=GENERATOR_LORA_DIR,
+        vocab_size=VOCAB_SIZE
     )
-
-
-
-    print(f"Loading Generator from: {GENERATOR_CHECKPOINT_PATH}")
-    GENERATOR_MODEL = ProtT5Generator(vocab_size=VOCAB_SIZE)
-    GENERATOR_MODEL.load_state_dict(torch.load(GENERATOR_CHECKPOINT_PATH, map_location=DEVICE))
+    if not os.path.exists(GENERATOR_LM_HEAD_PATH):
+        raise FileNotFoundError(f"Generator's lm_head weights not found at: {GENERATOR_LM_HEAD_PATH}")
+    GENERATOR_MODEL.lm_head.load_state_dict(torch.load(GENERATOR_LM_HEAD_PATH, map_location=DEVICE))
     GENERATOR_MODEL.to(DEVICE)
     GENERATOR_MODEL.eval()
-    print("✅ Generator
+    print("✅ Generator System loaded successfully.")

     print("\n--- All models loaded! Gradio app is ready. ---\n")
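Note the split checkpoint the generator now relies on: PeftModel.from_pretrained restores the LoRA deltas from the adapter directory, while the lm_head is a new layer that PEFT does not manage, so its state dict is saved and restored separately as lm_head.pth. A minimal sketch of that round trip, with toy layer sizes and a hypothetical directory:

import os
import torch
import torch.nn as nn

out_dir = "./final_lora_generator"                 # hypothetical adapter directory
os.makedirs(out_dir, exist_ok=True)

lm_head = nn.Linear(8, 22)                         # 1024 -> vocab_size in the real app
torch.save(lm_head.state_dict(), os.path.join(out_dir, "lm_head.pth"))

# At startup: rebuild the layer with the same shape, then restore its weights.
restored = nn.Linear(8, 22)
restored.load_state_dict(torch.load(os.path.join(out_dir, "lm_head.pth"), map_location="cpu"))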
@@ -231,7 +249,7 @@ except Exception as e:
     raise gr.Error(f"A required model or file could not be loaded. Please check your repository file structure and paths. Error details: {e}")

 # --------------------------------------------------------------------------
-# SECTION 3: WRAPPER FUNCTIONS FOR GRADIO UI
+# SECTION 3: WRAPPER FUNCTIONS FOR GRADIO UI
 # --------------------------------------------------------------------------

 def predict_peptide_wrapper(sequence_str):
@@ -239,10 +257,9 @@ def predict_peptide_wrapper(sequence_str):
         return "0.0000", "Error: Please enter a valid peptide sequence using standard amino acids (ACDEFGHIKLMNPQRSTVWY)."

     try:
-        #
-
-
-        scaled_features = SCALER.transform(features.reshape(1, -1))
+        # Use the VALIDATOR's feature extractor
+        features = extract_features(sequence_str.upper(), VALIDATOR_EXTRACTOR, L_fixed=29, d_model_pe=16)
+        scaled_features = VALIDATOR_SCALER.transform(features.reshape(1, -1))

         with torch.no_grad():
             features_tensor = torch.tensor(scaled_features, dtype=torch.float32).to(DEVICE)
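The prediction path itself is unchanged in shape: a 1914-dim feature vector from the extractor, scaled by the pickled scaler (now VALIDATOR_SCALER), then scored under torch.no_grad(). A stub sketch of the same pipeline; the StandardScaler and linear layer here merely stand in for the saved scaler and AntioxidantPredictor:

import numpy as np
import torch
import torch.nn as nn
from sklearn.preprocessing import StandardScaler

scaler = StandardScaler().fit(np.random.randn(100, 1914))    # stand-in for the pickled scaler
features = np.random.randn(1914).astype(np.float32)          # stand-in for extract_features output

scaled = scaler.transform(features.reshape(1, -1))
with torch.no_grad():
    x = torch.tensor(scaled, dtype=torch.float32)
    prob = torch.sigmoid(nn.Linear(1914, 1)(x)).item()       # stand-in for PREDICTOR_MODEL
print(f"{prob:.4f}")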
@@ -306,7 +323,7 @@ def generate_peptide_wrapper(num_to_generate, min_len, max_len, temperature, div
         time.sleep(1)

     if not validated_pool:
-        return pd.DataFrame([{"Sequence": "Could not generate any high-activity peptides (>0.9 prob) with
+        return pd.DataFrame([{"Sequence": "Could not generate any high-activity peptides (>0.9 prob) with current settings.", "Predicted Probability": "N/A"}])

     high_quality_sequences = list(validated_pool.keys())
     final_diverse_seqs = cluster_sequences(GENERATOR_MODEL, high_quality_sequences, num_to_generate, DEVICE)