Update app.py
app.py CHANGED
@@ -1,5 +1,4 @@
-# app.py - RLAnOxPeptide Gradio Web Application
-# This script integrates both the predictor and generator into a user-friendly web UI.
+# app.py - RLAnOxPeptide Gradio Web Application (Corrected Version)
 
 import os
 import torch
@@ -17,11 +16,10 @@ transformers.logging.set_verbosity_error()
 
 # --------------------------------------------------------------------------
 # SECTION 1: CORE CLASS AND FUNCTION DEFINITIONS
-#
-# These should match the versions used during training.
+# These definitions are now synchronized with your provided, working scripts.
 # --------------------------------------------------------------------------
 
-# --- Vocabulary Definition
+# --- Vocabulary Definition ---
 AMINO_ACIDS = "ACDEFGHIKLMNPQRSTVWY"
 token2id = {aa: i + 2 for i, aa in enumerate(AMINO_ACIDS)}
 token2id["<PAD>"] = 0
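
A quick round-trip check of the vocabulary in this hunk (20 amino acids at ids 2-21, plus `<PAD>`=0 and `<EOS>`=1, so VOCAB_SIZE is 22); the snippet repeats the definitions so it runs standalone:

```python
# Round-trip check for the vocabulary defined above.
AMINO_ACIDS = "ACDEFGHIKLMNPQRSTVWY"
token2id = {aa: i + 2 for i, aa in enumerate(AMINO_ACIDS)}
token2id["<PAD>"] = 0
token2id["<EOS>"] = 1
id2token = {i: t for t, i in token2id.items()}
assert len(token2id) == 22  # VOCAB_SIZE

seq = "YPGG"
ids = [token2id[aa] for aa in seq] + [token2id["<EOS>"]]
print(ids)  # [21, 14, 7, 7, 1]
assert "".join(id2token[i] for i in ids[:-1]) == seq
```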
@@ -29,16 +27,13 @@ token2id["<EOS>"] = 1
 id2token = {i: t for t, i in token2id.items()}
 VOCAB_SIZE = len(token2id)
 
-# --- Predictor Model Architecture (
+# --- Predictor Model Architecture (VERSION THAT MATCHES YOUR .pth FILE) ---
 class AntioxidantPredictor(nn.Module):
     def __init__(self, input_dim, transformer_layers=3, transformer_heads=4, transformer_dropout=0.1):
         super(AntioxidantPredictor, self).__init__()
-        # From the error log and your training script, we know the input dimension is fixed
-        # and that the model internally handles splitting ProtT5 from the traditional features
        self.t5_dim = 1024
         self.hand_crafted_dim = input_dim - self.t5_dim
 
-        # Define the Transformer encoder
         encoder_layer = nn.TransformerEncoderLayer(
             d_model=self.t5_dim,
             nhead=transformer_heads,
@@ -47,9 +42,6 @@ class AntioxidantPredictor(nn.Module):
         )
         self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=transformer_layers)
 
-        # Define the MLP
-        # The error log shows the weight file has no fusion_fc or classifier, only a single mlp
-        # We rebuild it according to the original structure of predictor_train.py
         self.mlp = nn.Sequential(
             nn.Linear(input_dim, 512),
             nn.ReLU(),
@@ -62,19 +54,15 @@ class AntioxidantPredictor(nn.Module):
         self.temperature = nn.Parameter(torch.ones(1))
 
     def forward(self, fused_features):
-        # This forward-pass logic more closely matches your training script predictor_train.py
         prot_t5_features = fused_features[:, :self.t5_dim]
         hand_crafted_features = fused_features[:, self.t5_dim:]
 
-        # The Transformer processes only the ProtT5 features
         prot_t5_features_unsqueezed = prot_t5_features.unsqueeze(1)
         transformer_output = self.transformer_encoder(prot_t5_features_unsqueezed)
         transformer_output_pooled = transformer_output.mean(dim=1)
 
-        # Concatenate the processed ProtT5 features with the traditional features
         combined_features = torch.cat((transformer_output_pooled, hand_crafted_features), dim=1)
 
-        # Feed the final concatenated features into the MLP
         logits = self.mlp(combined_features)
 
         return logits / self.temperature
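
Shape walkthrough for this forward pass, using dummy tensors in place of the encoder: with input_dim=1914 (from the loading code below) the hand-crafted block is 1914 - 1024 = 890 dimensions. A minimal sketch:

```python
import torch

batch, input_dim, t5_dim = 4, 1914, 1024
fused = torch.randn(batch, input_dim)
prot_t5 = fused[:, :t5_dim]        # (4, 1024): routed through the transformer
hand_crafted = fused[:, t5_dim:]   # (4, 890): bypasses it
seq_in = prot_t5.unsqueeze(1)      # (4, 1, 1024): a length-1 "sequence"
pooled = seq_in.mean(dim=1)        # stand-in for encoder + mean-pool: (4, 1024)
combined = torch.cat((pooled, hand_crafted), dim=1)
print(combined.shape)              # torch.Size([4, 1914]), the MLP's input_dim
```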
@@ -87,7 +75,6 @@ class AntioxidantPredictor(nn.Module):
 
 # --- Generator Model Architecture (from generator.py) ---
 class ProtT5Generator(nn.Module):
-    # This class definition should be an exact copy from your project
     def __init__(self, vocab_size, embed_dim=512, num_layers=6, num_heads=8, dropout=0.1):
         super(ProtT5Generator, self).__init__()
         self.embed_tokens = nn.Embedding(vocab_size, embed_dim, padding_idx=token2id["<PAD>"])
@@ -97,7 +84,7 @@ class ProtT5Generator(nn.Module):
         self.vocab_size = vocab_size
         self.eos_token_id = token2id["<EOS>"]
         self.pad_token_id = token2id["<PAD>"]
-
+
     def forward(self, input_ids):
         embeddings = self.embed_tokens(input_ids)
         encoder_output = self.encoder(embeddings)
@@ -112,12 +99,11 @@ class ProtT5Generator(nn.Module):
             next_logits = logits[:, -1, :] / temperature
             if generated.size(1) < min_decoded_length:
                 next_logits[:, self.eos_token_id] = -float("inf")
-
             probs = torch.softmax(next_logits, dim=-1)
             next_token = torch.multinomial(probs, num_samples=1)
             generated = torch.cat((generated, next_token), dim=1)
-
-            if (
+            # Early stop if all sequences in batch have generated an EOS token
+            if (generated == self.eos_token_id).any(dim=1).all():
                 break
         return generated
 
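
The early-stop condition added in this hunk breaks the sampling loop once every row of the batch contains at least one EOS token (id 1). A small illustration of the predicate:

```python
import torch

eos_id = 1  # token2id["<EOS>"]
generated = torch.tensor([[5, 9, 1, 0],   # contains an EOS
                          [7, 2, 4, 1]])  # contains an EOS
print((generated == eos_id).any(dim=1).all())  # tensor(True) -> loop breaks
```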
@@ -133,9 +119,7 @@ class ProtT5Generator(nn.Module):
             seqs.append(seq)
         return seqs
 
-# --- Feature Extraction Logic (
-# Note: You need the actual ProtT5Model and extract_features here.
-# Assuming they are in a file named `feature_extract.py` in the same directory.
+# --- Feature Extraction Logic (needs feature_extract.py) ---
 try:
     from feature_extract import ProtT5Model as FeatureProtT5Model, extract_features
 except ImportError:
@@ -147,11 +131,10 @@ def cluster_sequences(generator, sequences, num_clusters, device):
         return sequences[:num_clusters]
     with torch.no_grad():
         token_ids_list = []
-        max_len = max(len(seq) for seq in sequences) + 2
+        max_len = max(len(seq) for seq in sequences) + 2
         for seq in sequences:
-            # Recreate encoding to match how generator sees it (with start token)
             ids = [token2id.get(aa, 0) for aa in seq] + [generator.eos_token_id]
-            ids = [np.random.randint(2, VOCAB_SIZE)] + ids
+            ids = [np.random.randint(2, VOCAB_SIZE)] + ids
             ids += [token2id["<PAD>"]] * (max_len - len(ids))
             token_ids_list.append(ids)
 
@@ -175,55 +158,44 @@ def cluster_sequences(generator, sequences, num_clusters, device):
         representatives.append(sequences[representative_index])
     return representatives
 
-
 # --------------------------------------------------------------------------
 # SECTION 2: GLOBAL MODEL LOADING
-# Load all models and dependencies once when the app starts.
 # --------------------------------------------------------------------------
 print("Loading all models and dependencies. Please wait...")
-DEVICE = "cpu"
+DEVICE = "cpu"
 
 try:
-    # --- Define
-    # !! IMPORTANT: Ensure these are relative paths to the files in your Space !!
+    # --- Define file paths (!! CHECK THESE PATHS !!) ---
     PREDICTOR_CHECKPOINT_PATH = "checkpoints/final_rl_model_logitp0.1_calibrated_FINETUNED_PROTT5.pth"
     SCALER_PATH = "checkpoints/scaler_FINETUNED_PROTT5.pkl"
     GENERATOR_CHECKPOINT_PATH = "generator_checkpoints_v3.6/final_generator_model.pth"
     PROTT5_BASE_MODEL_PATH = "prott5/model/"
     FINETUNED_PROTT5_FOR_FEATURES_PATH = "prott5/model/finetuned_prott5.bin"
 
-    # --- Load Predictor
+    # --- Load Predictor ---
     print("Loading Predictor Model...")
-    PREDICTOR_MODEL = AntioxidantPredictor(
-        input_dim=1914, transformer_layers=3, transformer_heads=4, transformer_dropout=0.1
-    )
+    PREDICTOR_MODEL = AntioxidantPredictor(input_dim=1914, transformer_layers=3, transformer_heads=4)
     PREDICTOR_MODEL.load_state_dict(torch.load(PREDICTOR_CHECKPOINT_PATH, map_location=DEVICE))
     PREDICTOR_MODEL.to(DEVICE)
     PREDICTOR_MODEL.eval()
-    print("✅ Predictor model loaded.")
+    print(f"✅ Predictor model loaded (Temp: {PREDICTOR_MODEL.get_temperature():.4f}).")
 
-
+    # --- Load Scaler & Feature Extractor ---
+    print("Loading Scaler and Feature Extractor...")
     SCALER = joblib.load(SCALER_PATH)
-    print("✅ Scaler loaded.")
-
-    print("Loading ProtT5 Feature Extractor...")
-    # This extractor must use the fine-tuned model for features, as per your training logic
     PROTT5_EXTRACTOR = FeatureProtT5Model(
         model_path=PROTT5_BASE_MODEL_PATH,
         finetuned_model_file=FINETUNED_PROTT5_FOR_FEATURES_PATH
     )
-    print("✅
+    print("✅ Scaler and Feature Extractor loaded.")
 
-    # --- Load Generator
+    # --- Load Generator ---
    print("Loading Generator Model...")
-    GENERATOR_MODEL = ProtT5Generator(
-        vocab_size=VOCAB_SIZE, embed_dim=512, num_layers=6, num_heads=8, dropout=0.1
-    )
+    GENERATOR_MODEL = ProtT5Generator(vocab_size=VOCAB_SIZE, embed_dim=512, num_layers=6, num_heads=8)
     GENERATOR_MODEL.load_state_dict(torch.load(GENERATOR_CHECKPOINT_PATH, map_location=DEVICE))
     GENERATOR_MODEL.to(DEVICE)
     GENERATOR_MODEL.eval()
     print("✅ Generator model loaded.")
-
     print("\n--- All models loaded successfully! Gradio app is ready. ---\n")
 
 except Exception as e:
@@ -232,22 +204,17 @@ except Exception as e:
 
 # --------------------------------------------------------------------------
 # SECTION 3: WRAPPER FUNCTIONS FOR GRADIO
-# These functions connect the UI to our model's logic.
 # --------------------------------------------------------------------------
 
 def predict_peptide_wrapper(sequence_str):
-    """Takes a peptide sequence string and returns its predicted probability and class."""
     if not sequence_str or not isinstance(sequence_str, str) or any(c not in AMINO_ACIDS for c in sequence_str.upper()):
         return "0.0000", "Error: Please enter a valid sequence with standard amino acids."
 
     try:
-        #
-        features = extract_features(sequence_str, PROTT5_EXTRACTOR)
-
-        # 2. Scale features
+        # These L_fixed and d_model_pe values are from your predictor.py args
+        features = extract_features(sequence_str, PROTT5_EXTRACTOR, L_fixed=29, d_model_pe=16)
         scaled_features = SCALER.transform(features.reshape(1, -1))
 
-        # 3. Predict with the model
         with torch.no_grad():
             features_tensor = torch.tensor(scaled_features, dtype=torch.float32).to(DEVICE)
             logits = PREDICTOR_MODEL(features_tensor)
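
Taken together, the wrapper runs extract, scale, predict. The hunk cuts off before the probability is produced, so the sigmoid in this sketch is an assumption about how the single calibrated logit becomes a probability; the globals come from the loading section above:

```python
import numpy as np
import torch

def predict_probability(seq: str) -> float:
    # Same three steps as predict_peptide_wrapper above.
    feats = extract_features(seq, PROTT5_EXTRACTOR, L_fixed=29, d_model_pe=16)
    scaled = SCALER.transform(np.asarray(feats).reshape(1, -1))
    with torch.no_grad():
        logits = PREDICTOR_MODEL(torch.tensor(scaled, dtype=torch.float32).to(DEVICE))
    # Assumption: a single-logit binary head, squashed to a probability.
    return torch.sigmoid(logits).item()
```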
@@ -261,47 +228,46 @@ def predict_peptide_wrapper(sequence_str):
         return "N/A", f"An error occurred during processing: {e}"
 
 def generate_peptide_wrapper(num_to_generate, min_len, max_len, temperature, diversity_factor, progress=gr.Progress(track_tqdm=True)):
-    """Generates, validates, and clusters sequences."""
     num_to_generate = int(num_to_generate)
     min_len = int(min_len)
     max_len = int(max_len)
 
     try:
-        #
+        # Step 1: Generate a pool of unique sequences
         target_pool_size = int(num_to_generate * diversity_factor)
         unique_seqs = set()
-        progress(0, desc="Generating initial peptide pool...")
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        # A simple generation loop based on generator.py logic
+        with tqdm(total=target_pool_size, desc="Generating candidate sequences") as pbar:
+            while len(unique_seqs) < target_pool_size:
+                batch_size = (target_pool_size - len(unique_seqs))
+                with torch.no_grad():
+                    generated_tokens = GENERATOR_MODEL.sample(
+                        batch_size=max(1, batch_size),
+                        max_length=max_len,
+                        device=DEVICE,
+                        temperature=temperature,
+                        min_decoded_length=min_len
+                    )
+                decoded = GENERATOR_MODEL.decode(generated_tokens.cpu())
+
+                newly_added = 0
+                for seq in decoded:
+                    if min_len <= len(seq) <= max_len:
+                        if seq not in unique_seqs:
+                            unique_seqs.add(seq)
+                            newly_added += 1
+                pbar.update(newly_added)
 
         candidate_seqs = list(unique_seqs)
-        if not candidate_seqs:
-            return pd.DataFrame({"Sequence": ["Failed to generate valid sequences."], "Predicted Probability": ["N/A"]})
 
-        #
+        # Step 2: Validate the generated sequences
         validated_pool = {}
         for seq in tqdm(candidate_seqs, desc="Validating generated sequences"):
             prob_str, _ = predict_peptide_wrapper(seq)
             try:
                 prob = float(prob_str)
-                if prob > 0.90:
+                if prob > 0.90:  # Filter for high-quality peptides
                     validated_pool[seq] = prob
             except (ValueError, TypeError):
                 continue
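
The heart of the new loop is the sample/decode pair. A standalone call with the same keyword signature the hunk uses (argument values here are illustrative only):

```python
import torch

with torch.no_grad():
    tokens = GENERATOR_MODEL.sample(
        batch_size=8, max_length=20, device=DEVICE,
        temperature=1.0, min_decoded_length=5,
    )
seqs = GENERATOR_MODEL.decode(tokens.cpu())  # list of amino-acid strings
print(seqs[:3])
```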
@@ -311,11 +277,11 @@ def generate_peptide_wrapper(num_to_generate, min_len, max_len, temperature, diversity_factor, progress=gr.Progress(track_tqdm=True)):
 
         high_quality_sequences = list(validated_pool.keys())
 
-        #
+        # Step 3: Cluster to ensure diversity
         progress(1.0, desc="Clustering for diversity...")
         final_diverse_seqs = cluster_sequences(GENERATOR_MODEL, high_quality_sequences, num_to_generate, DEVICE)
 
-        #
+        # Step 4: Format final results
         final_results = [(seq, f"{validated_pool[seq]:.4f}") for seq in final_diverse_seqs]
         final_results.sort(key=lambda x: float(x[1]), reverse=True)
 
@@ -325,10 +291,8 @@ def generate_peptide_wrapper(num_to_generate, min_len, max_len, temperature, diversity_factor, progress=gr.Progress(track_tqdm=True)):
         print(f"Generation error: {e}")
         return pd.DataFrame({"Sequence": [f"An error occurred during generation: {e}"], "Predicted Probability": ["N/A"]})
 
-
 # --------------------------------------------------------------------------
-# SECTION 4: GRADIO UI CONSTRUCTION
-# Building the web interface. All text is in English.
+# SECTION 4: GRADIO UI CONSTRUCTION (ALL ENGLISH)
 # --------------------------------------------------------------------------
 with gr.Blocks(theme=gr.themes.Soft(), title="RLAnOxPeptide") as demo:
     gr.Markdown("# RLAnOxPeptide: Intelligent Peptide Design and Prediction Platform")
@@ -350,11 +314,8 @@ with gr.Blocks(theme=gr.themes.Soft(), title="RLAnOxPeptide") as demo:
         outputs=[probability_output, class_output]
     )
     gr.Examples(
-        examples=[["WHYHDYKY"], ["YPGG"], ["LVLHEHGGN"]
-        inputs=peptide_input
-        outputs=[probability_output, class_output],
-        fn=predict_peptide_wrapper,
-        cache_examples=False,
+        examples=[["WHYHDYKY"], ["YPGG"], ["LVLHEHGGN"]],
+        inputs=peptide_input
     )
 
     with gr.TabItem("Novel Sequence Generator"):
@@ -378,4 +339,4 @@ with gr.Blocks(theme=gr.themes.Soft(), title="RLAnOxPeptide") as demo:
     )
 
 if __name__ == "__main__":
-    demo.launch()
+    demo.launch()