Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -1,5 +1,26 @@
|
|
1 |
-
|
2 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
3 |
|
4 |
import os
|
5 |
import torch
|
@@ -12,14 +33,14 @@ from sklearn.cluster import KMeans
|
|
12 |
from tqdm import tqdm
|
13 |
import transformers
|
14 |
|
15 |
-
# Suppress verbose logging from transformers
|
16 |
transformers.logging.set_verbosity_error()
|
17 |
|
18 |
# --------------------------------------------------------------------------
|
19 |
# SECTION 1: CORE CLASS AND FUNCTION DEFINITIONS
|
20 |
# --------------------------------------------------------------------------
|
21 |
|
22 |
-
# --- Vocabulary Definition ---
|
23 |
AMINO_ACIDS = "ACDEFGHIKLMNPQRSTVWY"
|
24 |
token2id = {aa: i + 2 for i, aa in enumerate(AMINO_ACIDS)}
|
25 |
token2id["<PAD>"] = 0
|
@@ -27,71 +48,71 @@ token2id["<EOS>"] = 1
|
|
27 |
id2token = {i: t for t, i in token2id.items()}
|
28 |
VOCAB_SIZE = len(token2id)
|
29 |
|
30 |
-
|
|
|
|
|
31 |
class FeatureProtT5Model:
|
32 |
def __init__(self, model_dir_path, finetuned_weights_path=None):
|
33 |
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
34 |
-
print(f"Initializing ProtT5
|
35 |
|
36 |
-
#
|
37 |
-
# This step requires the original pytorch_model.bin to be in the model_dir_path.
|
38 |
self.tokenizer = transformers.T5Tokenizer.from_pretrained(model_dir_path, do_lower_case=False)
|
39 |
self.model = transformers.T5EncoderModel.from_pretrained(model_dir_path)
|
40 |
|
41 |
-
#
|
42 |
if finetuned_weights_path and os.path.exists(finetuned_weights_path):
|
43 |
-
print(f"
|
44 |
-
# Load the state_dict from your specific fine-tuned file
|
45 |
state_dict = torch.load(finetuned_weights_path, map_location=self.device)
|
46 |
-
# Use strict=False because the fine-tuned model may only contain encoder weights
|
47 |
self.model.load_state_dict(state_dict, strict=False)
|
48 |
-
print("Successfully applied fine-tuned weights
|
49 |
else:
|
50 |
-
print("Warning: Fine-tuned weights
|
51 |
|
52 |
self.model.to(self.device)
|
53 |
self.model.eval()
|
54 |
|
55 |
-
def encode(self, sequence):
|
56 |
-
if not sequence or not isinstance(sequence, str):
|
57 |
-
return np.zeros((1, 1024), dtype=np.float32)
|
58 |
-
seq_spaced = " ".join(list(sequence))
|
59 |
-
encoded_input = self.tokenizer(seq_spaced, return_tensors='pt', padding=True, truncation=True, max_length=1022)
|
60 |
-
encoded_input = {k: v.to(self.device) for k, v in encoded_input.items()}
|
61 |
-
with torch.no_grad():
|
62 |
-
embedding = self.model(**encoded_input).last_hidden_state
|
63 |
-
emb = embedding.squeeze(0).cpu().numpy()
|
64 |
-
return emb if emb.shape[0] > 0 else np.zeros((1, 1024), dtype=np.float32)
|
65 |
-
|
66 |
# --- Predictor Model Architecture ---
|
|
|
|
|
67 |
class AntioxidantPredictor(nn.Module):
|
68 |
-
def __init__(self, input_dim, transformer_layers=3, transformer_heads=4, transformer_dropout=0.1):
|
69 |
super(AntioxidantPredictor, self).__init__()
|
70 |
self.prott5_dim = 1024
|
71 |
self.handcrafted_dim = input_dim - self.prott5_dim
|
72 |
self.seq_len = 16
|
73 |
-
self.prott5_feature_dim = 64
|
|
|
74 |
encoder_layer = nn.TransformerEncoderLayer(d_model=self.prott5_feature_dim, nhead=transformer_heads, dropout=transformer_dropout, batch_first=True)
|
75 |
self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=transformer_layers)
|
|
|
76 |
fused_dim = self.prott5_feature_dim + self.handcrafted_dim
|
77 |
self.fusion_fc = nn.Sequential(nn.Linear(fused_dim, 1024), nn.ReLU(), nn.Dropout(0.3), nn.Linear(1024, 512), nn.ReLU(), nn.Dropout(0.3))
|
78 |
self.classifier = nn.Sequential(nn.Linear(512, 256), nn.ReLU(), nn.Dropout(0.3), nn.Linear(256, 1))
|
79 |
self.temperature = nn.Parameter(torch.ones(1), requires_grad=False)
|
80 |
-
|
|
|
81 |
batch_size = x.size(0)
|
|
|
82 |
prot_t5_features = x[:, :self.prott5_dim]
|
83 |
handcrafted_features = x[:, self.prott5_dim:]
|
|
|
|
|
84 |
prot_t5_seq = prot_t5_features.view(batch_size, self.seq_len, self.prott5_feature_dim)
|
|
|
85 |
encoded_seq = self.transformer_encoder(prot_t5_seq)
|
86 |
refined_prott5 = encoded_seq.mean(dim=1)
|
|
|
87 |
fused_features = torch.cat([refined_prott5, handcrafted_features], dim=1)
|
88 |
-
|
89 |
-
logits = self.classifier(
|
|
|
90 |
return logits / self.temperature
|
91 |
-
def set_temperature(self, temp_value, device): self.temperature = nn.Parameter(torch.tensor([temp_value], device=device), requires_grad=False)
|
92 |
-
def get_temperature(self): return self.temperature.item()
|
93 |
|
94 |
-
|
|
|
|
|
|
|
95 |
class ProtT5Generator(nn.Module):
|
96 |
def __init__(self, vocab_size, embed_dim=512, num_layers=6, num_heads=8, dropout=0.1):
|
97 |
super(ProtT5Generator, self).__init__()
|
@@ -117,49 +138,50 @@ class ProtT5Generator(nn.Module):
|
|
117 |
next_logits = logits[:, -1, :] / temperature
|
118 |
if generated.size(1) < min_decoded_length:
|
119 |
next_logits[:, self.eos_token_id] = -float("inf")
|
|
|
120 |
probs = torch.softmax(next_logits, dim=-1)
|
121 |
next_token = torch.multinomial(probs, num_samples=1)
|
122 |
generated = torch.cat((generated, next_token), dim=1)
|
123 |
-
if (generated == self.eos_token_id).any(dim=1).all():
|
124 |
-
break
|
125 |
return generated
|
126 |
|
127 |
def decode(self, token_ids_batch):
|
128 |
-
|
129 |
for ids_tensor in token_ids_batch:
|
130 |
seq = ""
|
131 |
-
for token_id in ids_tensor.tolist()[1:]: # Skip start token
|
132 |
if token_id == self.eos_token_id: break
|
133 |
if token_id == self.pad_token_id: continue
|
134 |
-
seq += id2token.get(token_id, "
|
135 |
-
|
136 |
-
return
|
137 |
-
|
138 |
-
# ---
|
|
|
|
|
|
|
139 |
try:
|
140 |
-
from feature_extract import
|
141 |
except ImportError:
|
142 |
-
raise gr.Error("
|
143 |
|
144 |
# --- Clustering Logic (from generator.py) ---
|
145 |
def cluster_sequences(generator, sequences, num_clusters, device):
|
146 |
if not sequences or len(sequences) < num_clusters:
|
147 |
return sequences[:num_clusters]
|
|
|
148 |
with torch.no_grad():
|
149 |
token_ids_list = []
|
150 |
-
max_len = max(len(seq) for seq in sequences) + 2
|
151 |
for seq in sequences:
|
152 |
ids = [token2id.get(aa, 0) for aa in seq] + [generator.eos_token_id]
|
153 |
-
ids = [np.random.randint(2, VOCAB_SIZE)] + ids
|
154 |
ids += [token2id["<PAD>"]] * (max_len - len(ids))
|
155 |
token_ids_list.append(ids)
|
156 |
|
157 |
input_ids = torch.tensor(token_ids_list, dtype=torch.long, device=device)
|
158 |
embeddings = generator.embed_tokens(input_ids)
|
159 |
mask = (input_ids != token2id["<PAD>"]).unsqueeze(-1).float()
|
160 |
-
|
161 |
-
lengths = mask.sum(dim=1)
|
162 |
-
seq_embeds = embeddings.sum(dim=1) / (lengths + 1e-9)
|
163 |
seq_embeds_np = seq_embeds.cpu().numpy()
|
164 |
|
165 |
kmeans = KMeans(n_clusters=int(num_clusters), random_state=42, n_init='auto').fit(seq_embeds_np)
|
@@ -175,67 +197,67 @@ def cluster_sequences(generator, sequences, num_clusters, device):
|
|
175 |
return representatives
|
176 |
|
177 |
# --------------------------------------------------------------------------
|
178 |
-
# SECTION 2: GLOBAL MODEL LOADING
|
179 |
# --------------------------------------------------------------------------
|
180 |
-
|
181 |
-
|
|
|
182 |
|
183 |
try:
|
184 |
-
# --- Define file paths
|
185 |
PREDICTOR_CHECKPOINT_PATH = "checkpoints/final_rl_model_logitp0.1_calibrated_FINETUNED_PROTT5.pth"
|
186 |
SCALER_PATH = "checkpoints/scaler_FINETUNED_PROTT5.pkl"
|
187 |
GENERATOR_CHECKPOINT_PATH = "generator_checkpoints_v3.6/final_generator_model.pth"
|
188 |
PROTT5_BASE_MODEL_PATH = "prott5/model/"
|
189 |
-
# This path is now used by the FeatureProtT5Model to load the fine-tuned weights
|
190 |
FINETUNED_PROTT5_FOR_FEATURES_PATH = "prott5/model/finetuned_prott5.bin"
|
191 |
|
192 |
-
# --- Load Predictor ---
|
193 |
-
print("Loading Predictor
|
194 |
-
|
195 |
-
PREDICTOR_MODEL = AntioxidantPredictor(
|
196 |
-
input_dim=1914, transformer_layers=3, transformer_heads=4, transformer_dropout=0.1
|
197 |
-
)
|
198 |
-
# Load the state dict that matches this class
|
199 |
PREDICTOR_MODEL.load_state_dict(torch.load(PREDICTOR_CHECKPOINT_PATH, map_location=DEVICE))
|
200 |
PREDICTOR_MODEL.to(DEVICE)
|
201 |
PREDICTOR_MODEL.eval()
|
202 |
print(f"✅ Predictor model loaded (Temp: {PREDICTOR_MODEL.get_temperature():.4f}).")
|
203 |
|
204 |
# --- Load Scaler & Feature Extractor ---
|
205 |
-
print("Loading Scaler
|
206 |
SCALER = joblib.load(SCALER_PATH)
|
|
|
207 |
PROTT5_EXTRACTOR = FeatureProtT5Model(
|
208 |
-
|
209 |
-
|
210 |
)
|
211 |
print("✅ Scaler and Feature Extractor loaded.")
|
212 |
|
213 |
-
# --- Load Generator ---
|
214 |
-
print("Loading Generator
|
215 |
-
GENERATOR_MODEL = ProtT5Generator(
|
216 |
-
vocab_size=VOCAB_SIZE, embed_dim=512, num_layers=6, num_heads=8, dropout=0.1
|
217 |
-
)
|
218 |
GENERATOR_MODEL.load_state_dict(torch.load(GENERATOR_CHECKPOINT_PATH, map_location=DEVICE))
|
219 |
GENERATOR_MODEL.to(DEVICE)
|
220 |
GENERATOR_MODEL.eval()
|
221 |
print("✅ Generator model loaded.")
|
222 |
-
|
|
|
223 |
|
224 |
except Exception as e:
|
225 |
-
print(f"💥 FATAL ERROR
|
226 |
-
raise gr.Error(f"
|
227 |
|
228 |
# --------------------------------------------------------------------------
|
229 |
-
# SECTION 3: WRAPPER FUNCTIONS FOR GRADIO
|
230 |
# --------------------------------------------------------------------------
|
231 |
|
232 |
def predict_peptide_wrapper(sequence_str):
|
|
|
233 |
if not sequence_str or not isinstance(sequence_str, str) or any(c not in AMINO_ACIDS for c in sequence_str.upper()):
|
234 |
-
return "0.0000", "Error: Please enter a valid sequence
|
235 |
|
236 |
try:
|
237 |
-
#
|
238 |
-
|
|
|
|
|
|
|
239 |
scaled_features = SCALER.transform(features.reshape(1, -1))
|
240 |
|
241 |
with torch.no_grad():
|
@@ -247,21 +269,22 @@ def predict_peptide_wrapper(sequence_str):
|
|
247 |
return f"{probability:.4f}", classification
|
248 |
|
249 |
except Exception as e:
|
250 |
-
print(f"Prediction
|
251 |
-
return "N/A", f"An error occurred during
|
252 |
|
253 |
def generate_peptide_wrapper(num_to_generate, min_len, max_len, temperature, diversity_factor, progress=gr.Progress(track_tqdm=True)):
|
254 |
-
|
255 |
num_to_generate = int(num_to_generate)
|
256 |
min_len = int(min_len)
|
257 |
max_len = int(max_len)
|
258 |
|
259 |
try:
|
260 |
-
# Step 1: Generate a pool of
|
261 |
target_pool_size = int(num_to_generate * diversity_factor)
|
262 |
unique_seqs = set()
|
263 |
|
264 |
-
|
|
|
265 |
while len(unique_seqs) < target_pool_size:
|
266 |
batch_size = max(1, (target_pool_size - len(unique_seqs)))
|
267 |
with torch.no_grad():
|
@@ -269,19 +292,19 @@ def generate_peptide_wrapper(num_to_generate, min_len, max_len, temperature, div
|
|
269 |
batch_size=batch_size, max_length=max_len, device=DEVICE,
|
270 |
temperature=temperature, min_decoded_length=min_len
|
271 |
)
|
272 |
-
|
273 |
|
274 |
initial_count = len(unique_seqs)
|
275 |
-
for seq in
|
276 |
if min_len <= len(seq) <= max_len:
|
277 |
unique_seqs.add(seq)
|
278 |
pbar.update(len(unique_seqs) - initial_count)
|
279 |
|
280 |
candidate_seqs = list(unique_seqs)
|
281 |
|
282 |
-
# Step 2: Validate the generated sequences
|
283 |
validated_pool = {}
|
284 |
-
for seq in tqdm(candidate_seqs, desc="Validating generated sequences"):
|
285 |
prob_str, _ = predict_peptide_wrapper(seq)
|
286 |
try:
|
287 |
prob = float(prob_str)
|
@@ -291,40 +314,41 @@ def generate_peptide_wrapper(num_to_generate, min_len, max_len, temperature, div
|
|
291 |
continue
|
292 |
|
293 |
if not validated_pool:
|
294 |
-
return pd.DataFrame([{"Sequence": "No high-activity peptides (>0.9 prob) were generated.", "Predicted Probability": "N/A"}])
|
295 |
|
296 |
high_quality_sequences = list(validated_pool.keys())
|
297 |
|
298 |
-
# Step 3: Cluster to ensure diversity
|
299 |
-
progress(1.0, desc="Clustering for diversity...")
|
300 |
final_diverse_seqs = cluster_sequences(GENERATOR_MODEL, high_quality_sequences, num_to_generate, DEVICE)
|
301 |
|
302 |
-
# Step 4: Format final results
|
303 |
final_results = [(seq, f"{validated_pool[seq]:.4f}") for seq in final_diverse_seqs]
|
304 |
final_results.sort(key=lambda x: float(x[1]), reverse=True)
|
305 |
|
306 |
return pd.DataFrame(final_results, columns=["Sequence", "Predicted Probability"])
|
307 |
|
308 |
except Exception as e:
|
309 |
-
print(f"Generation
|
310 |
-
return pd.DataFrame([{"Sequence": f"An error occurred: {e}", "Predicted Probability": "N/A"}])
|
311 |
|
312 |
# --------------------------------------------------------------------------
|
313 |
# SECTION 4: GRADIO UI CONSTRUCTION
|
314 |
# --------------------------------------------------------------------------
|
315 |
with gr.Blocks(theme=gr.themes.Soft(), title="RLAnOxPeptide") as demo:
|
316 |
-
gr.Markdown("# RLAnOxPeptide: Intelligent Peptide Design and Prediction
|
317 |
gr.Markdown("An integrated framework combining reinforcement learning and a Transformer model for the efficient prediction and innovative design of antioxidant peptides.")
|
318 |
|
319 |
with gr.Tabs():
|
|
|
320 |
with gr.TabItem("Peptide Activity Predictor"):
|
321 |
gr.Markdown("### Enter an amino acid sequence to predict its antioxidant activity.")
|
322 |
with gr.Row():
|
323 |
peptide_input = gr.Textbox(label="Peptide Sequence", placeholder="e.g., WHYHDYKY", scale=3)
|
324 |
predict_button = gr.Button("Predict", variant="primary", scale=1)
|
325 |
with gr.Row():
|
326 |
-
probability_output = gr.Textbox(label="Predicted Probability")
|
327 |
-
class_output = gr.Textbox(label="Predicted Class")
|
328 |
|
329 |
predict_button.click(
|
330 |
fn=predict_peptide_wrapper,
|
@@ -332,23 +356,27 @@ with gr.Blocks(theme=gr.themes.Soft(), title="RLAnOxPeptide") as demo:
|
|
332 |
outputs=[probability_output, class_output]
|
333 |
)
|
334 |
gr.Examples(
|
335 |
-
examples=[["WHYHDYKY"], ["YPGG"], ["LVLHEHGGN"]],
|
336 |
-
inputs=peptide_input
|
|
|
|
|
|
|
337 |
)
|
338 |
|
|
|
339 |
with gr.TabItem("Novel Sequence Generator"):
|
340 |
gr.Markdown("### Set parameters to generate novel, high-activity antioxidant peptides.")
|
341 |
with gr.Column():
|
342 |
with gr.Row():
|
343 |
-
num_input = gr.Slider(minimum=
|
344 |
-
min_len_input = gr.Slider(minimum=
|
345 |
max_len_input = gr.Slider(minimum=10, maximum=20, value=20, step=1, label="Maximum Length")
|
346 |
with gr.Row():
|
347 |
temp_input = gr.Slider(minimum=0.5, maximum=3.0, value=2.5, step=0.1, label="Temperature (Higher = More random)")
|
348 |
-
diversity_input = gr.Slider(minimum=1.
|
349 |
|
350 |
generate_button = gr.Button("Generate Peptides", variant="primary")
|
351 |
-
results_output = gr.DataFrame(headers=["Sequence", "Predicted Probability"], label="Generated & Validated Peptides", wrap=True)
|
352 |
|
353 |
generate_button.click(
|
354 |
fn=generate_peptide_wrapper,
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
|
4 |
+
# app.py - RLAnOxPeptide Gradio Web Application
|
5 |
+
# This script combines logic from predictor.py, generator.py, and the original app.py
|
6 |
+
# into a single, self-contained file for a Hugging Face Space.
|
7 |
+
#
|
8 |
+
# REQUIRED FILE STRUCTURE IN HUGGING FACE REPO:
|
9 |
+
# .
|
10 |
+
# ├── app.py (This file)
|
11 |
+
# ├── feature_extract.py (CRITICAL: This file with your `extract_features` function MUST be present)
|
12 |
+
# ├── checkpoints/
|
13 |
+
# │ ├── final_rl_model_logitp0.1_calibrated_FINETUNED_PROTT5.pth
|
14 |
+
# │ └── scaler_FINETUNED_PROTT5.pkl
|
15 |
+
# ├── generator_checkpoints_v3.6/
|
16 |
+
# │ └── final_generator_model.pth
|
17 |
+
# ├── prott5/
|
18 |
+
# │ └── model/
|
19 |
+
# │ ├── config.json
|
20 |
+
# │ ├── pytorch_model.bin (The base ProtT5 model from Rostlab)
|
21 |
+
# │ ├── finetuned_prott5.bin (Your fine-tuned feature extractor weights)
|
22 |
+
# │ └── ... (other tokenizer files)
|
23 |
+
# └── requirements.txt
|
24 |
|
25 |
import os
|
26 |
import torch
|
|
|
33 |
from tqdm import tqdm
|
34 |
import transformers
|
35 |
|
36 |
+
# Suppress verbose logging from transformers, which can clutter the app logs
|
37 |
transformers.logging.set_verbosity_error()
|
38 |
|
39 |
# --------------------------------------------------------------------------
|
40 |
# SECTION 1: CORE CLASS AND FUNCTION DEFINITIONS
|
41 |
# --------------------------------------------------------------------------
|
42 |
|
43 |
+
# --- Vocabulary Definition (Consistent across all scripts) ---
|
44 |
AMINO_ACIDS = "ACDEFGHIKLMNPQRSTVWY"
|
45 |
token2id = {aa: i + 2 for i, aa in enumerate(AMINO_ACIDS)}
|
46 |
token2id["<PAD>"] = 0
|
|
|
48 |
id2token = {i: t for t, i in token2id.items()}
|
49 |
VOCAB_SIZE = len(token2id)
|
50 |
|
51 |
+
|
52 |
+
# --- Feature Extractor Model Class (For ProtT5) ---
|
53 |
+
# This class robustly loads the base ProtT5 model and applies your fine-tuned weights.
|
54 |
class FeatureProtT5Model:
    """ProtT5 encoder wrapper used as the embedding source for feature extraction.

    The tokenizer and a ``T5EncoderModel`` are loaded from ``model_dir_path``.
    When ``finetuned_weights_path`` names an existing file, its state dict is
    overlaid on the base weights. The model is placed on CUDA when available
    (CPU otherwise) and switched to eval mode.
    """

    def __init__(self, model_dir_path, finetuned_weights_path=None):
        """Load the base model and optionally apply fine-tuned weights.

        Args:
            model_dir_path: Directory passed to ``from_pretrained`` for both the
                tokenizer and the encoder model.
            finetuned_weights_path: Optional path to a state-dict file. If it is
                falsy or does not exist, the base weights are used as-is.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Initializing ProtT5 for feature extraction on device: {self.device}")

        # Load the base model architecture and tokenizer from the specified directory.
        self.tokenizer = transformers.T5Tokenizer.from_pretrained(model_dir_path, do_lower_case=False)
        self.model = transformers.T5EncoderModel.from_pretrained(model_dir_path)

        # If a path to a fine-tuned weights file is provided, load and apply those weights.
        if finetuned_weights_path and os.path.exists(finetuned_weights_path):
            print(f"Applying fine-tuned weights from: {finetuned_weights_path}")
            state_dict = torch.load(finetuned_weights_path, map_location=self.device)
            # strict=False tolerates partial checkpoints, but it also hides the
            # case where *nothing* matched; inspect the result instead of
            # unconditionally reporting success.
            result = self.model.load_state_dict(state_dict, strict=False)
            applied = len(state_dict) - len(result.unexpected_keys)
            if applied == 0:
                print("Warning: No tensor in the fine-tuned checkpoint matched the model. Using base ProtT5 weights.")
            else:
                print("Successfully applied fine-tuned weights.")
                if result.missing_keys or result.unexpected_keys:
                    print(
                        f"Note: {len(result.missing_keys)} model keys were absent from the checkpoint; "
                        f"{len(result.unexpected_keys)} checkpoint keys were ignored."
                    )
        else:
            print("Warning: Fine-tuned weights not found or not provided. Using base ProtT5 weights.")

        self.model.to(self.device)
        self.model.eval()

    def encode(self, sequence):
        """Return the encoder's last hidden state for one sequence as a NumPy array.

        NOTE(review): restored from the previous revision of this file. The
        instance is still handed to ``extract_features``; confirm whether that
        function calls ``encode()`` or uses ``tokenizer``/``model`` directly.

        Args:
            sequence: Amino-acid string. Residues are space-separated before
                tokenization.

        Returns:
            A float array with one row per token position (batch axis removed).
            For empty or non-string input, or an empty result, a ``(1, 1024)``
            array of zeros is returned.
        """
        if not sequence or not isinstance(sequence, str):
            return np.zeros((1, 1024), dtype=np.float32)
        seq_spaced = " ".join(list(sequence))
        encoded_input = self.tokenizer(seq_spaced, return_tensors='pt', padding=True, truncation=True, max_length=1022)
        encoded_input = {k: v.to(self.device) for k, v in encoded_input.items()}
        with torch.no_grad():
            embedding = self.model(**encoded_input).last_hidden_state
        emb = embedding.squeeze(0).cpu().numpy()
        return emb if emb.shape[0] > 0 else np.zeros((1, 1024), dtype=np.float32)
|
74 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
75 |
# --- Predictor Model Architecture ---
|
76 |
+
# This is the antioxidant activity predictor model. Its architecture must
|
77 |
+
# exactly match the architecture used to save the checkpoint file.
|
78 |
class AntioxidantPredictor(nn.Module):
    """Antioxidant-activity classifier over a flat per-peptide feature vector.

    The first ``prott5_dim`` (1024) input columns are treated as ProtT5
    features, reshaped to ``seq_len x prott5_feature_dim`` (16 x 64), refined
    by a Transformer encoder and mean-pooled. The remaining columns are
    hand-crafted features that are concatenated with the pooled vector and
    passed through ``fusion_fc`` and ``classifier``.

    Sub-module and parameter names (``transformer_encoder``, ``fusion_fc``,
    ``classifier``, ``temperature``) must stay unchanged so that saved
    checkpoints keep loading.
    """

    def __init__(self, input_dim=1914, transformer_layers=3, transformer_heads=4, transformer_dropout=0.1):
        """Build the network.

        Args:
            input_dim: Total width of the input vector; must be at least 1024.
            transformer_layers: Number of Transformer encoder layers.
            transformer_heads: Attention heads per encoder layer.
            transformer_dropout: Dropout used inside the encoder layers.

        Raises:
            ValueError: If ``input_dim`` is smaller than the ProtT5 slice, which
                would otherwise build a model whose ``forward`` cannot run.
        """
        super().__init__()
        self.prott5_dim = 1024
        if input_dim < self.prott5_dim:
            raise ValueError(
                f"input_dim must be at least {self.prott5_dim} (the ProtT5 feature slice), got {input_dim}"
            )
        self.handcrafted_dim = input_dim - self.prott5_dim
        self.seq_len = 16
        self.prott5_feature_dim = 64  # 16 * 64 = 1024

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=self.prott5_feature_dim,
            nhead=transformer_heads,
            dropout=transformer_dropout,
            batch_first=True,
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=transformer_layers)

        fused_dim = self.prott5_feature_dim + self.handcrafted_dim
        self.fusion_fc = nn.Sequential(
            nn.Linear(fused_dim, 1024), nn.ReLU(), nn.Dropout(0.3),
            nn.Linear(1024, 512), nn.ReLU(), nn.Dropout(0.3),
        )
        self.classifier = nn.Sequential(
            nn.Linear(512, 256), nn.ReLU(), nn.Dropout(0.3),
            nn.Linear(256, 1),
        )
        # Calibration temperature; stored as a non-trainable parameter so it is
        # saved and restored with the state dict.
        self.temperature = nn.Parameter(torch.ones(1), requires_grad=False)

    def forward(self, x):
        """Return temperature-scaled logits of shape ``(batch, 1)``.

        Args:
            x: 2-D tensor ``(batch, input_dim)``; the flat feature vector
                produced by ``extract_features()`` after scaling.
        """
        batch_size = x.size(0)
        prot_t5_features = x[:, :self.prott5_dim]
        handcrafted_features = x[:, self.prott5_dim:]

        # Reshape the first 1024 features back into a sequence representation.
        prot_t5_seq = prot_t5_features.view(batch_size, self.seq_len, self.prott5_feature_dim)

        encoded_seq = self.transformer_encoder(prot_t5_seq)
        refined_prott5 = encoded_seq.mean(dim=1)

        fused_features = torch.cat([refined_prott5, handcrafted_features], dim=1)
        fused_output = self.fusion_fc(fused_features)
        logits = self.classifier(fused_output)

        return logits / self.temperature

    def get_temperature(self):
        """Return the calibration temperature as a Python float."""
        return self.temperature.item()
|
114 |
+
|
115 |
+
# --- Generator Model Architecture (from generator.py) ---
|
116 |
class ProtT5Generator(nn.Module):
|
117 |
def __init__(self, vocab_size, embed_dim=512, num_layers=6, num_heads=8, dropout=0.1):
|
118 |
super(ProtT5Generator, self).__init__()
|
|
|
138 |
next_logits = logits[:, -1, :] / temperature
|
139 |
if generated.size(1) < min_decoded_length:
|
140 |
next_logits[:, self.eos_token_id] = -float("inf")
|
141 |
+
|
142 |
probs = torch.softmax(next_logits, dim=-1)
|
143 |
next_token = torch.multinomial(probs, num_samples=1)
|
144 |
generated = torch.cat((generated, next_token), dim=1)
|
|
|
|
|
145 |
return generated
|
146 |
|
147 |
def decode(self, token_ids_batch):
    """Turn a batch of token-id rows into amino-acid strings.

    Args:
        token_ids_batch: Iterable of rows, each exposing ``.tolist()`` (e.g. the
            tensor returned by generation). Position 0 of every row is the
            start token and is never decoded.

    Returns:
        A list with one string per row. Decoding of a row stops at the first
        EOS id; PAD ids are skipped; ids missing from ``id2token`` contribute
        nothing.
    """
    sequences = []
    for ids_tensor in token_ids_batch:
        residues = []
        # Skip the random start token at index 0.
        for token_id in ids_tensor.tolist()[1:]:
            if token_id == self.eos_token_id:
                break
            if token_id == self.pad_token_id:
                continue
            residues.append(id2token.get(token_id, ""))
        # Join once instead of growing a string with += inside the loop.
        sequences.append("".join(residues))
    return sequences
|
157 |
+
|
158 |
+
# --- CRITICAL DEPENDENCY: feature_extract.py ---
|
159 |
+
# This application requires a function named `extract_features` to convert a peptide
|
160 |
+
# sequence into a 1914-dimensional feature vector for the prediction model.
|
161 |
+
# This function must be defined in a file named `feature_extract.py` in the repository root.
|
162 |
try:
|
163 |
+
from feature_extract import extract_features
|
164 |
except ImportError:
|
165 |
+
raise gr.Error("Fatal Error: `feature_extract.py` not found. This file is required for the application to run. Please upload it to your repository.")
|
166 |
|
167 |
# --- Clustering Logic (from generator.py) ---
|
168 |
def cluster_sequences(generator, sequences, num_clusters, device):
|
169 |
if not sequences or len(sequences) < num_clusters:
|
170 |
return sequences[:num_clusters]
|
171 |
+
|
172 |
with torch.no_grad():
|
173 |
token_ids_list = []
|
174 |
+
max_len = max((len(seq) for seq in sequences), default=0) + 2
|
175 |
for seq in sequences:
|
176 |
ids = [token2id.get(aa, 0) for aa in seq] + [generator.eos_token_id]
|
177 |
+
ids = [np.random.randint(2, VOCAB_SIZE)] + ids # Prepend a start token
|
178 |
ids += [token2id["<PAD>"]] * (max_len - len(ids))
|
179 |
token_ids_list.append(ids)
|
180 |
|
181 |
input_ids = torch.tensor(token_ids_list, dtype=torch.long, device=device)
|
182 |
embeddings = generator.embed_tokens(input_ids)
|
183 |
mask = (input_ids != token2id["<PAD>"]).unsqueeze(-1).float()
|
184 |
+
seq_embeds = (embeddings * mask).sum(dim=1) / (mask.sum(dim=1) + 1e-9)
|
|
|
|
|
185 |
seq_embeds_np = seq_embeds.cpu().numpy()
|
186 |
|
187 |
kmeans = KMeans(n_clusters=int(num_clusters), random_state=42, n_init='auto').fit(seq_embeds_np)
|
|
|
197 |
return representatives
|
198 |
|
199 |
# --------------------------------------------------------------------------
|
200 |
+
# SECTION 2: GLOBAL MODEL AND DEPENDENCY LOADING
|
201 |
# --------------------------------------------------------------------------
|
202 |
+
|
203 |
+
# Module-level start-up: everything below runs once at import time and
# publishes the loaded objects as globals (DEVICE, PREDICTOR_MODEL, SCALER,
# PROTT5_EXTRACTOR, GENERATOR_MODEL) that the Gradio wrapper functions read.
print("--- Starting Application: Loading all models and dependencies ---")
# Single device string shared by every model loaded here.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

try:
    # --- Define file paths relative to the repository root ---
    PREDICTOR_CHECKPOINT_PATH = "checkpoints/final_rl_model_logitp0.1_calibrated_FINETUNED_PROTT5.pth"
    SCALER_PATH = "checkpoints/scaler_FINETUNED_PROTT5.pkl"
    GENERATOR_CHECKPOINT_PATH = "generator_checkpoints_v3.6/final_generator_model.pth"
    PROTT5_BASE_MODEL_PATH = "prott5/model/"
    FINETUNED_PROTT5_FOR_FEATURES_PATH = "prott5/model/finetuned_prott5.bin"

    # --- Load Predictor Model ---
    # input_dim must match the width the checkpoint was trained with, otherwise
    # load_state_dict (strict by default) raises on the shape mismatch.
    print(f"Loading Predictor from: {PREDICTOR_CHECKPOINT_PATH}")
    PREDICTOR_MODEL = AntioxidantPredictor(input_dim=1914)
    PREDICTOR_MODEL.load_state_dict(torch.load(PREDICTOR_CHECKPOINT_PATH, map_location=DEVICE))
    PREDICTOR_MODEL.to(DEVICE)
    PREDICTOR_MODEL.eval()
    print(f"✅ Predictor model loaded (Temp: {PREDICTOR_MODEL.get_temperature():.4f}).")

    # --- Load Scaler & Feature Extractor ---
    # NOTE(review): joblib.load unpickles the file; only load scaler files from
    # a trusted source.
    print(f"Loading Scaler from: {SCALER_PATH}")
    SCALER = joblib.load(SCALER_PATH)
    print("Loading ProtT5 Feature Extractor...")
    PROTT5_EXTRACTOR = FeatureProtT5Model(
        model_dir_path=PROTT5_BASE_MODEL_PATH,
        finetuned_weights_path=FINETUNED_PROTT5_FOR_FEATURES_PATH
    )
    print("✅ Scaler and Feature Extractor loaded.")

    # --- Load Generator Model ---
    # Only vocab_size is passed; the remaining constructor arguments use the
    # class defaults.
    print(f"Loading Generator from: {GENERATOR_CHECKPOINT_PATH}")
    GENERATOR_MODEL = ProtT5Generator(vocab_size=VOCAB_SIZE)
    GENERATOR_MODEL.load_state_dict(torch.load(GENERATOR_CHECKPOINT_PATH, map_location=DEVICE))
    GENERATOR_MODEL.to(DEVICE)
    GENERATOR_MODEL.eval()
    print("✅ Generator model loaded.")

    print("\n--- All models loaded! Gradio app is ready. ---\n")

except Exception as e:
    # Any failure above is fatal: log it, then abort start-up by re-raising as
    # gr.Error with the original message embedded.
    print(f"💥 FATAL ERROR during model loading: {e}")
    raise gr.Error(f"A required model or file could not be loaded. Please check your repository file structure and paths. Error details: {e}")
|
245 |
|
246 |
# --------------------------------------------------------------------------
|
247 |
+
# SECTION 3: WRAPPER FUNCTIONS FOR GRADIO UI
|
248 |
# --------------------------------------------------------------------------
|
249 |
|
250 |
def predict_peptide_wrapper(sequence_str):
|
251 |
+
"""Handles the prediction for a single peptide sequence from the UI."""
|
252 |
if not sequence_str or not isinstance(sequence_str, str) or any(c not in AMINO_ACIDS for c in sequence_str.upper()):
|
253 |
+
return "0.0000", "Error: Please enter a valid peptide sequence using standard amino acids (ACDEFGHIKLMNPQRSTVWY)."
|
254 |
|
255 |
try:
|
256 |
+
# Use the imported extract_features function.
|
257 |
+
# The L_fixed and d_model_pe values are taken from your original predictor.py arguments.
|
258 |
+
features = extract_features(sequence_str.upper(), PROTT5_EXTRACTOR, L_fixed=29, d_model_pe=16)
|
259 |
+
|
260 |
+
# Scale the features using the loaded scaler
|
261 |
scaled_features = SCALER.transform(features.reshape(1, -1))
|
262 |
|
263 |
with torch.no_grad():
|
|
|
269 |
return f"{probability:.4f}", classification
|
270 |
|
271 |
except Exception as e:
|
272 |
+
print(f"Prediction Error for sequence '{sequence_str}': {e}")
|
273 |
+
return "N/A", f"An error occurred during prediction: {e}"
|
274 |
|
275 |
def generate_peptide_wrapper(num_to_generate, min_len, max_len, temperature, diversity_factor, progress=gr.Progress(track_tqdm=True)):
|
276 |
+
"""Handles the full generation-validation-clustering pipeline."""
|
277 |
num_to_generate = int(num_to_generate)
|
278 |
min_len = int(min_len)
|
279 |
max_len = int(max_len)
|
280 |
|
281 |
try:
|
282 |
+
# Step 1: Generate a large, unique pool of candidate sequences
|
283 |
target_pool_size = int(num_to_generate * diversity_factor)
|
284 |
unique_seqs = set()
|
285 |
|
286 |
+
pbar_desc = "Step 1/3: Generating candidate sequences"
|
287 |
+
with tqdm(total=target_pool_size, desc=pbar_desc) as pbar:
|
288 |
while len(unique_seqs) < target_pool_size:
|
289 |
batch_size = max(1, (target_pool_size - len(unique_seqs)))
|
290 |
with torch.no_grad():
|
|
|
292 |
batch_size=batch_size, max_length=max_len, device=DEVICE,
|
293 |
temperature=temperature, min_decoded_length=min_len
|
294 |
)
|
295 |
+
decoded_sequences = GENERATOR_MODEL.decode(generated_tokens)
|
296 |
|
297 |
initial_count = len(unique_seqs)
|
298 |
+
for seq in decoded_sequences:
|
299 |
if min_len <= len(seq) <= max_len:
|
300 |
unique_seqs.add(seq)
|
301 |
pbar.update(len(unique_seqs) - initial_count)
|
302 |
|
303 |
candidate_seqs = list(unique_seqs)
|
304 |
|
305 |
+
# Step 2: Validate the generated sequences and filter for high probability
|
306 |
validated_pool = {}
|
307 |
+
for seq in tqdm(candidate_seqs, desc="Step 2/3: Validating generated sequences"):
|
308 |
prob_str, _ = predict_peptide_wrapper(seq)
|
309 |
try:
|
310 |
prob = float(prob_str)
|
|
|
314 |
continue
|
315 |
|
316 |
if not validated_pool:
|
317 |
+
return pd.DataFrame([{"Sequence": "No high-activity peptides (>0.9 prob) were generated. Try increasing the Diversity Factor or changing the Temperature.", "Predicted Probability": "N/A"}])
|
318 |
|
319 |
high_quality_sequences = list(validated_pool.keys())
|
320 |
|
321 |
+
# Step 3: Cluster to ensure diversity in the final set
|
322 |
+
progress(1.0, desc="Step 3/3: Clustering for diversity...")
|
323 |
final_diverse_seqs = cluster_sequences(GENERATOR_MODEL, high_quality_sequences, num_to_generate, DEVICE)
|
324 |
|
325 |
+
# Step 4: Format final results into a DataFrame
|
326 |
final_results = [(seq, f"{validated_pool[seq]:.4f}") for seq in final_diverse_seqs]
|
327 |
final_results.sort(key=lambda x: float(x[1]), reverse=True)
|
328 |
|
329 |
return pd.DataFrame(final_results, columns=["Sequence", "Predicted Probability"])
|
330 |
|
331 |
except Exception as e:
|
332 |
+
print(f"Generation Pipeline Error: {e}")
|
333 |
+
return pd.DataFrame([{"Sequence": f"An error occurred during generation: {e}", "Predicted Probability": "N/A"}])
|
334 |
|
335 |
# --------------------------------------------------------------------------
|
336 |
# SECTION 4: GRADIO UI CONSTRUCTION
|
337 |
# --------------------------------------------------------------------------
|
338 |
with gr.Blocks(theme=gr.themes.Soft(), title="RLAnOxPeptide") as demo:
|
339 |
+
gr.Markdown("# RLAnOxPeptide: Intelligent Peptide Design and Prediction")
|
340 |
gr.Markdown("An integrated framework combining reinforcement learning and a Transformer model for the efficient prediction and innovative design of antioxidant peptides.")
|
341 |
|
342 |
with gr.Tabs():
|
343 |
+
# --- PREDICTION TAB ---
|
344 |
with gr.TabItem("Peptide Activity Predictor"):
|
345 |
gr.Markdown("### Enter an amino acid sequence to predict its antioxidant activity.")
|
346 |
with gr.Row():
|
347 |
peptide_input = gr.Textbox(label="Peptide Sequence", placeholder="e.g., WHYHDYKY", scale=3)
|
348 |
predict_button = gr.Button("Predict", variant="primary", scale=1)
|
349 |
with gr.Row():
|
350 |
+
probability_output = gr.Textbox(label="Predicted Probability", interactive=False)
|
351 |
+
class_output = gr.Textbox(label="Predicted Class", interactive=False)
|
352 |
|
353 |
predict_button.click(
|
354 |
fn=predict_peptide_wrapper,
|
|
|
356 |
outputs=[probability_output, class_output]
|
357 |
)
|
358 |
gr.Examples(
|
359 |
+
examples=[["WHYHDYKY"], ["YPGG"], ["LVLHEHGGN"], ["WKYG"]],
|
360 |
+
inputs=peptide_input,
|
361 |
+
fn=predict_peptide_wrapper,
|
362 |
+
outputs=[probability_output, class_output],
|
363 |
+
cache_examples=True
|
364 |
)
|
365 |
|
366 |
+
# --- GENERATION TAB ---
|
367 |
with gr.TabItem("Novel Sequence Generator"):
|
368 |
gr.Markdown("### Set parameters to generate novel, high-activity antioxidant peptides.")
|
369 |
with gr.Column():
|
370 |
with gr.Row():
|
371 |
+
num_input = gr.Slider(minimum=5, maximum=50, value=10, step=1, label="Number of Final Peptides to Generate")
|
372 |
+
min_len_input = gr.Slider(minimum=3, maximum=10, value=3, step=1, label="Minimum Length")
|
373 |
max_len_input = gr.Slider(minimum=10, maximum=20, value=20, step=1, label="Maximum Length")
|
374 |
with gr.Row():
|
375 |
temp_input = gr.Slider(minimum=0.5, maximum=3.0, value=2.5, step=0.1, label="Temperature (Higher = More random)")
|
376 |
+
diversity_input = gr.Slider(minimum=1.1, maximum=5.0, value=1.5, step=0.1, label="Diversity Factor (Larger initial pool for clustering)")
|
377 |
|
378 |
generate_button = gr.Button("Generate Peptides", variant="primary")
|
379 |
+
results_output = gr.DataFrame(headers=["Sequence", "Predicted Probability"], label="Generated & Validated Peptides (>90% Probability)", wrap=True)
|
380 |
|
381 |
generate_button.click(
|
382 |
fn=generate_peptide_wrapper,
|