Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1 |
-
# app.py - RLAnOxPeptide Gradio Web Application (
|
2 |
|
3 |
import os
|
4 |
import torch
|
@@ -27,53 +27,66 @@ token2id["<EOS>"] = 1
|
|
27 |
id2token = {i: t for t, i in token2id.items()}
|
28 |
VOCAB_SIZE = len(token2id)
|
29 |
|
30 |
-
# --- Predictor Model Architecture (
|
31 |
class AntioxidantPredictor(nn.Module):
|
32 |
def __init__(self, input_dim, transformer_layers=3, transformer_heads=4, transformer_dropout=0.1):
|
33 |
super(AntioxidantPredictor, self).__init__()
|
34 |
-
self.
|
35 |
-
self.
|
36 |
-
|
|
|
|
|
37 |
encoder_layer = nn.TransformerEncoderLayer(
|
38 |
-
d_model=self.
|
39 |
-
nhead=transformer_heads,
|
40 |
-
dropout=transformer_dropout,
|
41 |
batch_first=True
|
42 |
)
|
43 |
self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=transformer_layers)
|
44 |
-
|
45 |
-
self.
|
46 |
-
|
|
|
47 |
nn.ReLU(),
|
48 |
-
nn.Dropout(0.
|
|
|
|
|
|
|
|
|
|
|
|
|
49 |
nn.Linear(512, 256),
|
50 |
nn.ReLU(),
|
51 |
-
nn.Dropout(0.
|
52 |
nn.Linear(256, 1)
|
53 |
)
|
54 |
-
self.temperature = nn.Parameter(torch.ones(1))
|
55 |
|
56 |
-
def forward(self,
|
57 |
-
|
58 |
-
|
|
|
59 |
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
|
64 |
-
|
|
|
65 |
|
66 |
-
logits = self.
|
67 |
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
return self.temperature.item()
|
72 |
|
73 |
def set_temperature(self, temp_value, device):
|
74 |
self.temperature = nn.Parameter(torch.tensor([temp_value], device=device), requires_grad=False)
|
75 |
|
76 |
-
|
|
|
|
|
|
|
77 |
class ProtT5Generator(nn.Module):
|
78 |
def __init__(self, vocab_size, embed_dim=512, num_layers=6, num_heads=8, dropout=0.1):
|
79 |
super(ProtT5Generator, self).__init__()
|
@@ -102,7 +115,6 @@ class ProtT5Generator(nn.Module):
|
|
102 |
probs = torch.softmax(next_logits, dim=-1)
|
103 |
next_token = torch.multinomial(probs, num_samples=1)
|
104 |
generated = torch.cat((generated, next_token), dim=1)
|
105 |
-
# Early stop if all sequences in batch have generated an EOS token
|
106 |
if (generated == self.eos_token_id).any(dim=1).all():
|
107 |
break
|
108 |
return generated
|
@@ -111,19 +123,18 @@ class ProtT5Generator(nn.Module):
|
|
111 |
seqs = []
|
112 |
for ids_tensor in token_ids_batch:
|
113 |
seq = ""
|
114 |
-
|
115 |
-
for token_id in ids_tensor.tolist()[1:]:
|
116 |
if token_id == self.eos_token_id: break
|
117 |
if token_id == self.pad_token_id: continue
|
118 |
seq += id2token.get(token_id, "?")
|
119 |
seqs.append(seq)
|
120 |
return seqs
|
121 |
|
122 |
-
# --- Feature Extraction
|
123 |
try:
|
124 |
from feature_extract import ProtT5Model as FeatureProtT5Model, extract_features
|
125 |
except ImportError:
|
126 |
-
raise gr.Error("Failed to import feature_extract.py.
|
127 |
|
128 |
# --- Clustering Logic (from generator.py) ---
|
129 |
def cluster_sequences(generator, sequences, num_clusters, device):
|
@@ -131,10 +142,10 @@ def cluster_sequences(generator, sequences, num_clusters, device):
|
|
131 |
return sequences[:num_clusters]
|
132 |
with torch.no_grad():
|
133 |
token_ids_list = []
|
134 |
-
max_len = max(len(seq) for seq in sequences) + 2
|
135 |
for seq in sequences:
|
136 |
ids = [token2id.get(aa, 0) for aa in seq] + [generator.eos_token_id]
|
137 |
-
ids = [np.random.randint(2, VOCAB_SIZE)] + ids
|
138 |
ids += [token2id["<PAD>"]] * (max_len - len(ids))
|
139 |
token_ids_list.append(ids)
|
140 |
|
@@ -161,7 +172,7 @@ def cluster_sequences(generator, sequences, num_clusters, device):
|
|
161 |
# --------------------------------------------------------------------------
|
162 |
# SECTION 2: GLOBAL MODEL LOADING
|
163 |
# --------------------------------------------------------------------------
|
164 |
-
print("Loading all models and dependencies
|
165 |
DEVICE = "cpu"
|
166 |
|
167 |
try:
|
@@ -174,7 +185,9 @@ try:
|
|
174 |
|
175 |
# --- Load Predictor ---
|
176 |
print("Loading Predictor Model...")
|
177 |
-
PREDICTOR_MODEL = AntioxidantPredictor(
|
|
|
|
|
178 |
PREDICTOR_MODEL.load_state_dict(torch.load(PREDICTOR_CHECKPOINT_PATH, map_location=DEVICE))
|
179 |
PREDICTOR_MODEL.to(DEVICE)
|
180 |
PREDICTOR_MODEL.eval()
|
@@ -191,7 +204,9 @@ try:
|
|
191 |
|
192 |
# --- Load Generator ---
|
193 |
print("Loading Generator Model...")
|
194 |
-
GENERATOR_MODEL = ProtT5Generator(
|
|
|
|
|
195 |
GENERATOR_MODEL.load_state_dict(torch.load(GENERATOR_CHECKPOINT_PATH, map_location=DEVICE))
|
196 |
GENERATOR_MODEL.to(DEVICE)
|
197 |
GENERATOR_MODEL.eval()
|
@@ -211,7 +226,7 @@ def predict_peptide_wrapper(sequence_str):
|
|
211 |
return "0.0000", "Error: Please enter a valid sequence with standard amino acids."
|
212 |
|
213 |
try:
|
214 |
-
#
|
215 |
features = extract_features(sequence_str, PROTT5_EXTRACTOR, L_fixed=29, d_model_pe=16)
|
216 |
scaled_features = SCALER.transform(features.reshape(1, -1))
|
217 |
|
@@ -237,27 +252,22 @@ def generate_peptide_wrapper(num_to_generate, min_len, max_len, temperature, div
|
|
237 |
target_pool_size = int(num_to_generate * diversity_factor)
|
238 |
unique_seqs = set()
|
239 |
|
240 |
-
# A simple generation loop based on generator.py logic
|
241 |
with tqdm(total=target_pool_size, desc="Generating candidate sequences") as pbar:
|
242 |
while len(unique_seqs) < target_pool_size:
|
243 |
-
|
|
|
244 |
with torch.no_grad():
|
245 |
generated_tokens = GENERATOR_MODEL.sample(
|
246 |
-
batch_size=
|
247 |
-
|
248 |
-
device=DEVICE,
|
249 |
-
temperature=temperature,
|
250 |
-
min_decoded_length=min_len
|
251 |
)
|
252 |
decoded = GENERATOR_MODEL.decode(generated_tokens.cpu())
|
253 |
|
254 |
-
|
255 |
for seq in decoded:
|
256 |
if min_len <= len(seq) <= max_len:
|
257 |
-
|
258 |
-
|
259 |
-
newly_added +=1
|
260 |
-
pbar.update(newly_added)
|
261 |
|
262 |
candidate_seqs = list(unique_seqs)
|
263 |
|
@@ -267,13 +277,13 @@ def generate_peptide_wrapper(num_to_generate, min_len, max_len, temperature, div
|
|
267 |
prob_str, _ = predict_peptide_wrapper(seq)
|
268 |
try:
|
269 |
prob = float(prob_str)
|
270 |
-
if prob > 0.90:
|
271 |
validated_pool[seq] = prob
|
272 |
except (ValueError, TypeError):
|
273 |
continue
|
274 |
|
275 |
if not validated_pool:
|
276 |
-
return pd.DataFrame({"Sequence":
|
277 |
|
278 |
high_quality_sequences = list(validated_pool.keys())
|
279 |
|
@@ -289,10 +299,10 @@ def generate_peptide_wrapper(num_to_generate, min_len, max_len, temperature, div
|
|
289 |
|
290 |
except Exception as e:
|
291 |
print(f"Generation error: {e}")
|
292 |
-
return pd.DataFrame({"Sequence":
|
293 |
|
294 |
# --------------------------------------------------------------------------
|
295 |
-
# SECTION 4: GRADIO UI CONSTRUCTION
|
296 |
# --------------------------------------------------------------------------
|
297 |
with gr.Blocks(theme=gr.themes.Soft(), title="RLAnOxPeptide") as demo:
|
298 |
gr.Markdown("# RLAnOxPeptide: Intelligent Peptide Design and Prediction Platform")
|
|
|
1 |
+
# app.py - RLAnOxPeptide Gradio Web Application (FINAL CORRECTED VERSION)
|
2 |
|
3 |
import os
|
4 |
import torch
|
|
|
27 |
id2token = {i: t for t, i in token2id.items()}
|
28 |
VOCAB_SIZE = len(token2id)
|
29 |
|
30 |
+
# --- Predictor Model Architecture (Copied from your LATEST antioxidant_predictor_5.py) ---
|
31 |
class AntioxidantPredictor(nn.Module):
|
32 |
def __init__(self, input_dim, transformer_layers=3, transformer_heads=4, transformer_dropout=0.1):
|
33 |
super(AntioxidantPredictor, self).__init__()
|
34 |
+
self.prott5_dim = 1024
|
35 |
+
self.handcrafted_dim = input_dim - self.prott5_dim
|
36 |
+
self.seq_len = 16
|
37 |
+
self.prott5_feature_dim = 64 # 16 * 64 = 1024
|
38 |
+
|
39 |
encoder_layer = nn.TransformerEncoderLayer(
|
40 |
+
d_model=self.prott5_feature_dim,
|
41 |
+
nhead=transformer_heads,
|
42 |
+
dropout=transformer_dropout,
|
43 |
batch_first=True
|
44 |
)
|
45 |
self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=transformer_layers)
|
46 |
+
|
47 |
+
fused_dim = self.prott5_feature_dim + self.handcrafted_dim
|
48 |
+
self.fusion_fc = nn.Sequential(
|
49 |
+
nn.Linear(fused_dim, 1024),
|
50 |
nn.ReLU(),
|
51 |
+
nn.Dropout(0.3),
|
52 |
+
nn.Linear(1024, 512),
|
53 |
+
nn.ReLU(),
|
54 |
+
nn.Dropout(0.3)
|
55 |
+
)
|
56 |
+
|
57 |
+
self.classifier = nn.Sequential(
|
58 |
nn.Linear(512, 256),
|
59 |
nn.ReLU(),
|
60 |
+
nn.Dropout(0.3),
|
61 |
nn.Linear(256, 1)
|
62 |
)
|
63 |
+
self.temperature = nn.Parameter(torch.ones(1), requires_grad=False)
|
64 |
|
65 |
+
def forward(self, x, *args):
|
66 |
+
batch_size = x.size(0)
|
67 |
+
prot_t5_features = x[:, :self.prott5_dim]
|
68 |
+
handcrafted_features = x[:, self.prott5_dim:]
|
69 |
|
70 |
+
prot_t5_seq = prot_t5_features.view(batch_size, self.seq_len, self.prott5_feature_dim)
|
71 |
+
encoded_seq = self.transformer_encoder(prot_t5_seq)
|
72 |
+
refined_prott5 = encoded_seq.mean(dim=1)
|
73 |
|
74 |
+
fused_features = torch.cat([refined_prott5, handcrafted_features], dim=1)
|
75 |
+
fused_features = self.fusion_fc(fused_features)
|
76 |
|
77 |
+
logits = self.classifier(fused_features)
|
78 |
|
79 |
+
logits_scaled = logits / self.temperature
|
80 |
+
|
81 |
+
return logits_scaled
|
|
|
82 |
|
83 |
def set_temperature(self, temp_value, device):
|
84 |
self.temperature = nn.Parameter(torch.tensor([temp_value], device=device), requires_grad=False)
|
85 |
|
86 |
+
def get_temperature(self):
|
87 |
+
return self.temperature.item()
|
88 |
+
|
89 |
+
# --- Generator Model Architecture (Copied from your generator.py) ---
|
90 |
class ProtT5Generator(nn.Module):
|
91 |
def __init__(self, vocab_size, embed_dim=512, num_layers=6, num_heads=8, dropout=0.1):
|
92 |
super(ProtT5Generator, self).__init__()
|
|
|
115 |
probs = torch.softmax(next_logits, dim=-1)
|
116 |
next_token = torch.multinomial(probs, num_samples=1)
|
117 |
generated = torch.cat((generated, next_token), dim=1)
|
|
|
118 |
if (generated == self.eos_token_id).any(dim=1).all():
|
119 |
break
|
120 |
return generated
|
|
|
123 |
seqs = []
|
124 |
for ids_tensor in token_ids_batch:
|
125 |
seq = ""
|
126 |
+
for token_id in ids_tensor.tolist()[1:]: # Skip start token
|
|
|
127 |
if token_id == self.eos_token_id: break
|
128 |
if token_id == self.pad_token_id: continue
|
129 |
seq += id2token.get(token_id, "?")
|
130 |
seqs.append(seq)
|
131 |
return seqs
|
132 |
|
133 |
+
# --- Feature Extraction (needs feature_extract.py) ---
|
134 |
try:
|
135 |
from feature_extract import ProtT5Model as FeatureProtT5Model, extract_features
|
136 |
except ImportError:
|
137 |
+
raise gr.Error("Failed to import feature_extract.py. Ensure it is in the same directory.")
|
138 |
|
139 |
# --- Clustering Logic (from generator.py) ---
|
140 |
def cluster_sequences(generator, sequences, num_clusters, device):
|
|
|
142 |
return sequences[:num_clusters]
|
143 |
with torch.no_grad():
|
144 |
token_ids_list = []
|
145 |
+
max_len = max(len(seq) for seq in sequences) + 2
|
146 |
for seq in sequences:
|
147 |
ids = [token2id.get(aa, 0) for aa in seq] + [generator.eos_token_id]
|
148 |
+
ids = [np.random.randint(2, VOCAB_SIZE)] + ids
|
149 |
ids += [token2id["<PAD>"]] * (max_len - len(ids))
|
150 |
token_ids_list.append(ids)
|
151 |
|
|
|
172 |
# --------------------------------------------------------------------------
|
173 |
# SECTION 2: GLOBAL MODEL LOADING
|
174 |
# --------------------------------------------------------------------------
|
175 |
+
print("Loading all models and dependencies...")
|
176 |
DEVICE = "cpu"
|
177 |
|
178 |
try:
|
|
|
185 |
|
186 |
# --- Load Predictor ---
|
187 |
print("Loading Predictor Model...")
|
188 |
+
PREDICTOR_MODEL = AntioxidantPredictor(
|
189 |
+
input_dim=1914, transformer_layers=3, transformer_heads=4, transformer_dropout=0.1
|
190 |
+
)
|
191 |
PREDICTOR_MODEL.load_state_dict(torch.load(PREDICTOR_CHECKPOINT_PATH, map_location=DEVICE))
|
192 |
PREDICTOR_MODEL.to(DEVICE)
|
193 |
PREDICTOR_MODEL.eval()
|
|
|
204 |
|
205 |
# --- Load Generator ---
|
206 |
print("Loading Generator Model...")
|
207 |
+
GENERATOR_MODEL = ProtT5Generator(
|
208 |
+
vocab_size=VOCAB_SIZE, embed_dim=512, num_layers=6, num_heads=8, dropout=0.1
|
209 |
+
)
|
210 |
GENERATOR_MODEL.load_state_dict(torch.load(GENERATOR_CHECKPOINT_PATH, map_location=DEVICE))
|
211 |
GENERATOR_MODEL.to(DEVICE)
|
212 |
GENERATOR_MODEL.eval()
|
|
|
226 |
return "0.0000", "Error: Please enter a valid sequence with standard amino acids."
|
227 |
|
228 |
try:
|
229 |
+
# Use feature extraction params from your working predictor.py
|
230 |
features = extract_features(sequence_str, PROTT5_EXTRACTOR, L_fixed=29, d_model_pe=16)
|
231 |
scaled_features = SCALER.transform(features.reshape(1, -1))
|
232 |
|
|
|
252 |
target_pool_size = int(num_to_generate * diversity_factor)
|
253 |
unique_seqs = set()
|
254 |
|
|
|
255 |
with tqdm(total=target_pool_size, desc="Generating candidate sequences") as pbar:
|
256 |
while len(unique_seqs) < target_pool_size:
|
257 |
+
# Generate a surplus to account for filtering
|
258 |
+
batch_size = max(1, (target_pool_size - len(unique_seqs)) * 2)
|
259 |
with torch.no_grad():
|
260 |
generated_tokens = GENERATOR_MODEL.sample(
|
261 |
+
batch_size=batch_size, max_length=max_len, device=DEVICE,
|
262 |
+
temperature=temperature, min_decoded_length=min_len
|
|
|
|
|
|
|
263 |
)
|
264 |
decoded = GENERATOR_MODEL.decode(generated_tokens.cpu())
|
265 |
|
266 |
+
initial_count = len(unique_seqs)
|
267 |
for seq in decoded:
|
268 |
if min_len <= len(seq) <= max_len:
|
269 |
+
unique_seqs.add(seq)
|
270 |
+
pbar.update(len(unique_seqs) - initial_count)
|
|
|
|
|
271 |
|
272 |
candidate_seqs = list(unique_seqs)
|
273 |
|
|
|
277 |
prob_str, _ = predict_peptide_wrapper(seq)
|
278 |
try:
|
279 |
prob = float(prob_str)
|
280 |
+
if prob > 0.90:
|
281 |
validated_pool[seq] = prob
|
282 |
except (ValueError, TypeError):
|
283 |
continue
|
284 |
|
285 |
if not validated_pool:
|
286 |
+
return pd.DataFrame([{"Sequence": "No high-activity peptides (>0.9 prob) were generated.", "Predicted Probability": "N/A"}])
|
287 |
|
288 |
high_quality_sequences = list(validated_pool.keys())
|
289 |
|
|
|
299 |
|
300 |
except Exception as e:
|
301 |
print(f"Generation error: {e}")
|
302 |
+
return pd.DataFrame([{"Sequence": f"An error occurred: {e}", "Predicted Probability": "N/A"}])
|
303 |
|
304 |
# --------------------------------------------------------------------------
|
305 |
+
# SECTION 4: GRADIO UI CONSTRUCTION
|
306 |
# --------------------------------------------------------------------------
|
307 |
with gr.Blocks(theme=gr.themes.Soft(), title="RLAnOxPeptide") as demo:
|
308 |
gr.Markdown("# RLAnOxPeptide: Intelligent Peptide Design and Prediction Platform")
|