Update app.py
app.py CHANGED

@@ -1,4 +1,4 @@
-# app.py - RLAnOxPeptide Gradio Web Application (FINAL CORRECTED VERSION)
+# app.py - RLAnOxPeptide Gradio Web Application (FINAL CORRECTED VERSION - Synced with local scripts)
 
 import os
 import torch
@@ -10,6 +10,7 @@ import gradio as gr
 from sklearn.cluster import KMeans
 from tqdm import tqdm
 import transformers
+import argparse  # We won't use argparse but might need it for compatibility if any function expects it
 
 # Suppress verbose logging from transformers
 transformers.logging.set_verbosity_error()
@@ -19,7 +20,7 @@ transformers.logging.set_verbosity_error()
 # These definitions are now synchronized with your provided, working scripts.
 # --------------------------------------------------------------------------
 
-# --- Vocabulary Definition ---
+# --- Vocabulary Definition (from generator.py) ---
 AMINO_ACIDS = "ACDEFGHIKLMNPQRSTVWY"
 token2id = {aa: i + 2 for i, aa in enumerate(AMINO_ACIDS)}
 token2id["<PAD>"] = 0
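
The vocabulary above reserves ids 0 and 1 for <PAD> and <EOS> and maps the twenty standard residues to ids 2-21, so VOCAB_SIZE is 22. A minimal round-trip sketch using only the definitions visible in this diff:

# Encode a peptide and decode it back, dropping special tokens.
AMINO_ACIDS = "ACDEFGHIKLMNPQRSTVWY"
token2id = {aa: i + 2 for i, aa in enumerate(AMINO_ACIDS)}
token2id["<PAD>"] = 0
token2id["<EOS>"] = 1
id2token = {i: t for t, i in token2id.items()}

seq = "ACDK"
ids = [token2id[aa] for aa in seq] + [token2id["<EOS>"]]
print(ids)  # [2, 3, 4, 10, 1]
decoded = "".join(id2token[i] for i in ids if i > 1)  # ids 0/1 are specials
assert decoded == seq
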
@@ -27,7 +28,7 @@ token2id["<EOS>"] = 1
 id2token = {i: t for t, i in token2id.items()}
 VOCAB_SIZE = len(token2id)
 
-# --- Predictor Model Architecture (Copied from your antioxidant_predictor_5.py) ---
+# --- Predictor Model Architecture (Copied VERBATIM from your antioxidant_predictor_5.py) ---
 class AntioxidantPredictor(nn.Module):
     def __init__(self, input_dim, transformer_layers=3, transformer_heads=4, transformer_dropout=0.1):
         super(AntioxidantPredictor, self).__init__()
@@ -75,9 +76,7 @@ class AntioxidantPredictor(nn.Module):
         fused_features = self.fusion_fc(fused_features)
 
         logits = self.classifier(fused_features)
-
         logits_scaled = logits / self.temperature
-
         return logits_scaled
 
     def set_temperature(self, temp_value, device):
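
The two deleted blank lines tighten the temperature-scaling step, the predictor's calibration mechanism: dividing logits by a learned temperature T > 1 softens overconfident probabilities without changing their ranking, and set_temperature() lets a validation-set fit install that value. A hedged sketch of the effect, assuming a single-logit binary head (the classifier's output shape is not shown in this diff):

import torch

logit = torch.tensor([2.0])
for T in (1.0, 1.5, 3.0):
    p = torch.sigmoid(logit / T)
    print(f"T={T}: p={p.item():.3f}")
# T=1.0: p=0.881, T=1.5: p=0.791, T=3.0: p=0.661 -- larger T, softer confidence
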
@@ -86,7 +85,7 @@ class AntioxidantPredictor(nn.Module):
     def get_temperature(self):
         return self.temperature.item()
 
-# --- Generator Model Architecture (Copied from your generator.py) ---
+# --- Generator Model Architecture (Copied VERBATIM from your generator.py) ---
 class ProtT5Generator(nn.Module):
     def __init__(self, vocab_size, embed_dim=512, num_layers=6, num_heads=8, dropout=0.1):
         super(ProtT5Generator, self).__init__()
@@ -97,7 +96,7 @@ class ProtT5Generator(nn.Module):
         self.vocab_size = vocab_size
         self.eos_token_id = token2id["<EOS>"]
         self.pad_token_id = token2id["<PAD>"]
-
+
     def forward(self, input_ids):
         embeddings = self.embed_tokens(input_ids)
         encoder_output = self.encoder(embeddings)
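
forward() embeds the token ids and runs them through the encoder; the sample() method called later in the app generates sequences from this model autoregressively. A hedged sketch of what such a loop typically looks like (sample_sketch is hypothetical; the real sample() lives in generator.py, takes extra parameters such as temperature, and this sketch assumes forward() returns per-position vocabulary logits):

import torch

def sample_sketch(model, batch_size, max_length, device, temperature=1.0):
    # Seed each sequence with a random residue id (2..VOCAB_SIZE-1), mirroring
    # the random start token cluster_sequences() prepends below.
    ids = torch.randint(2, VOCAB_SIZE, (batch_size, 1), device=device)
    for _ in range(max_length - 1):
        logits = model(ids)[:, -1, :] / temperature      # next-token logits
        next_ids = torch.multinomial(torch.softmax(logits, dim=-1), 1)
        ids = torch.cat([ids, next_ids], dim=1)
    return ids  # decode via id2token, truncating at <EOS>
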
@@ -130,7 +129,7 @@ class ProtT5Generator(nn.Module):
             seqs.append(seq)
         return seqs
 
-# --- Feature Extraction (needs feature_extract.py) ---
+# --- Feature Extraction (needs feature_extract.py in the same directory) ---
 try:
     from feature_extract import ProtT5Model as FeatureProtT5Model, extract_features
 except ImportError:
@@ -142,10 +141,10 @@ def cluster_sequences(generator, sequences, num_clusters, device):
         return sequences[:num_clusters]
     with torch.no_grad():
         token_ids_list = []
-        max_len = max(len(seq) for seq in sequences) + 2
+        max_len = max(len(seq) for seq in sequences) + 2
         for seq in sequences:
             ids = [token2id.get(aa, 0) for aa in seq] + [generator.eos_token_id]
-            ids = [np.random.randint(2, VOCAB_SIZE)] + ids
+            ids = [np.random.randint(2, VOCAB_SIZE)] + ids
             ids += [token2id["<PAD>"]] * (max_len - len(ids))
             token_ids_list.append(ids)
 
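
These whitespace-only hunk lines sit in cluster_sequences(), which tokenizes each candidate (random start id + residues + <EOS>, padded with <PAD>), embeds the batch, and clusters the embeddings with KMeans so that one diverse representative per cluster survives. A hedged sketch of that selection step (pick_diverse is hypothetical and the embedding source is illustrative; the real function derives embeddings from the generator):

import numpy as np
from sklearn.cluster import KMeans

def pick_diverse(embeddings, sequences, num_clusters):
    km = KMeans(n_clusters=num_clusters, n_init=10).fit(embeddings)
    reps = []
    for c in range(num_clusters):
        members = np.where(km.labels_ == c)[0]
        # keep the member closest to its centroid as the representative
        dists = np.linalg.norm(embeddings[members] - km.cluster_centers_[c], axis=1)
        reps.append(sequences[members[dists.argmin()]])
    return reps
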
@@ -181,13 +180,16 @@ try:
     SCALER_PATH = "checkpoints/scaler_FINETUNED_PROTT5.pkl"
     GENERATOR_CHECKPOINT_PATH = "generator_checkpoints_v3.6/final_generator_model.pth"
     PROTT5_BASE_MODEL_PATH = "prott5/model/"
+    # This path is now used by the FeatureProtT5Model to load the fine-tuned weights
     FINETUNED_PROTT5_FOR_FEATURES_PATH = "prott5/model/finetuned_prott5.bin"
 
     # --- Load Predictor ---
     print("Loading Predictor Model...")
+    # Initialize the correct class
     PREDICTOR_MODEL = AntioxidantPredictor(
         input_dim=1914, transformer_layers=3, transformer_heads=4, transformer_dropout=0.1
     )
+    # Load the state dict that matches this class
     PREDICTOR_MODEL.load_state_dict(torch.load(PREDICTOR_CHECKPOINT_PATH, map_location=DEVICE))
     PREDICTOR_MODEL.to(DEVICE)
     PREDICTOR_MODEL.eval()
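
The three added comments stress the loading contract: the checkpoint's state dict must match the class it is loaded into, key for key and shape for shape. The pattern in isolation (PREDICTOR_CHECKPOINT_PATH is defined earlier in app.py; the DEVICE fallback shown here is an assumption):

import torch

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"  # assumed definition
model = AntioxidantPredictor(input_dim=1914, transformer_layers=3,
                             transformer_heads=4, transformer_dropout=0.1)
state_dict = torch.load(PREDICTOR_CHECKPOINT_PATH, map_location=DEVICE)
model.load_state_dict(state_dict)  # raises if keys or shapes disagree with the class
model.to(DEVICE)
model.eval()  # disable dropout for inference
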
@@ -226,7 +228,7 @@ def predict_peptide_wrapper(sequence_str):
         return "0.0000", "Error: Please enter a valid sequence with standard amino acids."
 
     try:
-        #
+        # These L_fixed and d_model_pe values are from your predictor.py args
         features = extract_features(sequence_str, PROTT5_EXTRACTOR, L_fixed=29, d_model_pe=16)
         scaled_features = SCALER.transform(features.reshape(1, -1))
 
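
The restored comment explains where L_fixed=29 and d_model_pe=16 come from: they must match the settings the scaler and predictor were trained with, together producing the 1914-dim feature vector the model was initialized for. A hedged sketch of the whole predict path (predict_sketch is hypothetical; it assumes the model is called directly on the scaled feature tensor and has a single-logit head):

import torch

def predict_sketch(sequence_str):
    feats = extract_features(sequence_str, PROTT5_EXTRACTOR, L_fixed=29, d_model_pe=16)
    x = SCALER.transform(feats.reshape(1, -1))  # shape (1, 1914)
    with torch.no_grad():
        logits = PREDICTOR_MODEL(torch.as_tensor(x, dtype=torch.float32, device=DEVICE))
        return torch.sigmoid(logits).item()  # temperature-calibrated probability
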
@@ -243,6 +245,7 @@ def predict_peptide_wrapper(sequence_str):
         return "N/A", f"An error occurred during processing: {e}"
 
 def generate_peptide_wrapper(num_to_generate, min_len, max_len, temperature, diversity_factor, progress=gr.Progress(track_tqdm=True)):
+    # This logic is a direct adaptation of your generator.py main function
     num_to_generate = int(num_to_generate)
     min_len = int(min_len)
     max_len = int(max_len)
@@ -254,8 +257,7 @@ def generate_peptide_wrapper(num_to_generate, min_len, max_len, temperature, div
 
     with tqdm(total=target_pool_size, desc="Generating candidate sequences") as pbar:
         while len(unique_seqs) < target_pool_size:
-
-            batch_size = max(1, (target_pool_size - len(unique_seqs)) * 2)
+            batch_size = max(1, (target_pool_size - len(unique_seqs)))
             with torch.no_grad():
                 generated_tokens = GENERATOR_MODEL.sample(
                     batch_size=batch_size, max_length=max_len, device=DEVICE,
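
The batch_size change drops the old 2x over-request: each iteration now asks for exactly as many sequences as the pool still lacks, trading an occasional extra iteration (when duplicates or length rejects occur) for less wasted sampling. A hedged sketch of the surrounding loop with a simplified decode (progress-bar updates omitted; the real code follows generator.py):

unique_seqs = set()
while len(unique_seqs) < target_pool_size:
    batch_size = max(1, target_pool_size - len(unique_seqs))
    with torch.no_grad():
        tokens = GENERATOR_MODEL.sample(batch_size=batch_size, max_length=max_len, device=DEVICE)
    for row in tokens:
        seq = ""
        for i in row.tolist()[1:]:  # drop the (assumed) random start token
            if i == token2id["<EOS>"]:
                break
            if i > 1:               # skip padding
                seq += id2token[i]
        if min_len <= len(seq) <= max_len:
            unique_seqs.add(seq)
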