Update app.py
app.py CHANGED
@@ -2,7 +2,7 @@
 # -*- coding: utf-8 -*-
 
 # app.py - RLAnOxPeptide Gradio Web Application
-# Final version
+# Final version updated to use a LoRA-finetuned model for feature extraction.
 
 import os
 import torch
@@ -16,6 +16,9 @@ from tqdm import tqdm
 import transformers
 import time
 
+# NEW DEPENDENCY: peft library for LoRA
+from peft import PeftModel
+
 # Suppress verbose logging from transformers
 transformers.logging.set_verbosity_error()
 
@@ -32,26 +35,27 @@ id2token = {i: t for t, i in token2id.items()}
 VOCAB_SIZE = len(token2id)
 
 
-# --- Feature Extractor Model Class ---
-class FeatureProtT5Model:
-    def __init__(self, base_model_id, finetuned_weights_path=None):
+# --- LoRA Feature Extractor Model Class ---
+# ✅ REPLACED: This new class handles loading the base model and attaching the LoRA adapter.
+class LoRAProtT5Extractor:
+    def __init__(self, base_model_id, lora_adapter_path):
         self.device = "cuda" if torch.cuda.is_available() else "cpu"
-        print(f"Initializing ProtT5
-
-        print(f"Loading base model and tokenizer from '{base_model_id}'...")
-        self.tokenizer = transformers.T5Tokenizer.from_pretrained(base_model_id, do_lower_case=False)
-        self.model = transformers.T5EncoderModel.from_pretrained(base_model_id)
-
-        if finetuned_weights_path and os.path.exists(finetuned_weights_path):
-            print(f"Applying local fine-tuned weights from: {finetuned_weights_path}")
-            state_dict = torch.load(finetuned_weights_path, map_location=self.device)
-            self.model.load_state_dict(state_dict, strict=False)
-            print("Successfully applied fine-tuned weights.")
-        else:
-            print("Warning: Fine-tuned weights not found or not provided. Using base ProtT5 weights.")
-
-        self.model = self.model.to(self.device)
+        print(f"Initializing LoRA-enhanced ProtT5 on device: {self.device}")
+
+        print(f"  - Loading base model and tokenizer from '{base_model_id}'...")
+        base_model = transformers.T5EncoderModel.from_pretrained(base_model_id)
+        self.tokenizer = transformers.T5Tokenizer.from_pretrained(base_model_id)
+
+        if not os.path.exists(lora_adapter_path):
+            raise FileNotFoundError(f"Error: LoRA adapter directory not found at: {lora_adapter_path}")
+
+        print(f"  - Loading and applying LoRA adapter from: {lora_adapter_path}")
+        lora_model = PeftModel.from_pretrained(base_model, lora_adapter_path)
+
+        print("  - Merging LoRA weights for faster inference...")
+        self.model = lora_model.merge_and_unload().to(self.device)
         self.model.eval()
+        print("  - LoRA-enhanced feature extractor is ready.")
 
     def encode(self, sequence):
         if not sequence or not isinstance(sequence, str):
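
Note on the load-attach-merge pattern above: `PeftModel.from_pretrained` wraps the base encoder with the low-rank adapter, and `merge_and_unload()` folds the adapter deltas (W + B·A) back into the base weights, returning a plain `T5EncoderModel` with no adapter overhead at inference. A minimal self-contained sketch of the same pattern (the model ID and adapter path are the ones this commit uses; everything else is illustrative):

    import transformers
    from peft import PeftModel

    base = transformers.T5EncoderModel.from_pretrained("Rostlab/prot_t5_xl_uniref50")
    model = PeftModel.from_pretrained(base, "./lora_finetuned_prott5")  # attach adapter
    model = model.merge_and_unload()  # fold low-rank deltas into the base weights
    model.eval()

    # After merging, no LoRA parameters remain in the state dict.
    assert not any("lora" in k for k in model.state_dict())
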
@@ -67,7 +71,7 @@ class FeatureProtT5Model:
         emb_np = embedding.squeeze(0).cpu().numpy()
         return emb_np if emb_np.shape[0] > 0 else np.zeros((1, 1024), dtype=np.float32)
 
-# --- Predictor Model Architecture ---
+# --- Predictor Model Architecture (Unchanged) ---
 class AntioxidantPredictor(nn.Module):
     def __init__(self, input_dim=1914, transformer_layers=3, transformer_heads=4, transformer_dropout=0.1):
         super(AntioxidantPredictor, self).__init__()
@@ -99,7 +103,7 @@ class AntioxidantPredictor(nn.Module):
     def get_temperature(self):
         return self.temperature.item()
 
-# --- Generator Model Architecture ---
+# --- Generator Model Architecture (Unchanged) ---
 class ProtT5Generator(nn.Module):
     def __init__(self, vocab_size, embed_dim=512, num_layers=6, num_heads=8, dropout=0.1):
         super(ProtT5Generator, self).__init__()
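
The `temperature` argument that the UI later passes to `GENERATOR_MODEL.sample` has the usual autoregressive-sampling meaning. A generic sketch for reference (not the model's actual `sample` implementation, which is outside this diff):

    import torch

    def sample_next_token(logits: torch.Tensor, temperature: float = 1.0) -> torch.Tensor:
        # Higher temperature flattens the token distribution (more diverse peptides);
        # lower temperature concentrates probability on the top tokens.
        probs = torch.softmax(logits / temperature, dim=-1)
        return torch.multinomial(probs, num_samples=1)
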
@@ -141,17 +145,16 @@ class ProtT5Generator(nn.Module):
             sequences.append(seq)
         return sequences
 
-# --- CRITICAL DEPENDENCY: feature_extract.py ---
+# --- CRITICAL DEPENDENCY: feature_extract.py (Unchanged) ---
 try:
     from feature_extract import extract_features
 except ImportError:
     raise gr.Error("Fatal Error: `feature_extract.py` not found. This file is required. Please upload it to your repository.")
 
-# --- Clustering Logic ---
+# --- Clustering Logic (Unchanged) ---
 def cluster_sequences(generator, sequences, num_clusters, device):
     if not sequences or len(sequences) < num_clusters:
         return sequences[:num_clusters]
-
     with torch.no_grad():
         token_ids_list = []
         max_len = max((len(seq) for seq in sequences), default=0) + 2
@@ -160,13 +163,11 @@ def cluster_sequences(generator, sequences, num_clusters, device):
             ids = [np.random.randint(2, VOCAB_SIZE)] + ids
             ids += [token2id["<PAD>"]] * (max_len - len(ids))
             token_ids_list.append(ids)
-
         input_ids = torch.tensor(token_ids_list, dtype=torch.long, device=device)
         embeddings = generator.embed_tokens(input_ids)
         mask = (input_ids != token2id["<PAD>"]).unsqueeze(-1).float()
         seq_embeds = (embeddings * mask).sum(dim=1) / (mask.sum(dim=1) + 1e-9)
         seq_embeds_np = seq_embeds.cpu().numpy()
-
         kmeans = KMeans(n_clusters=int(num_clusters), random_state=42, n_init='auto').fit(seq_embeds_np)
         representatives = []
         for i in range(int(num_clusters)):
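
The per-cluster representative selection that follows the `KMeans` fit falls outside this hunk. A sketch of the standard pattern, assuming `cluster_sequences` keeps, for each cluster, the member closest to its centroid:

    import numpy as np
    from sklearn.cluster import KMeans

    def pick_representatives(seq_embeds_np, sequences, num_clusters):
        kmeans = KMeans(n_clusters=num_clusters, random_state=42, n_init="auto").fit(seq_embeds_np)
        representatives = []
        for i in range(num_clusters):
            members = np.where(kmeans.labels_ == i)[0]  # indices of cluster i
            if len(members) == 0:
                continue
            dists = np.linalg.norm(seq_embeds_np[members] - kmeans.cluster_centers_[i], axis=1)
            representatives.append(sequences[members[np.argmin(dists)]])
        return representatives
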
@@ -191,10 +192,12 @@ try:
     PREDICTOR_CHECKPOINT_PATH = "checkpoints/final_rl_model_logitp0.1_calibrated_FINETUNED_PROTT5.pth"
     SCALER_PATH = "checkpoints/scaler_FINETUNED_PROTT5.pkl"
     GENERATOR_CHECKPOINT_PATH = "generator_checkpoints_v3.6/final_generator_model.pth"
+
+    # ✅ UPDATED: Define paths for LoRA-based loading
     PROTT5_BASE_MODEL_ID = "Rostlab/prot_t5_xl_uniref50"
-
+    LORA_ADAPTER_PATH = "./lora_finetuned_prott5"  # Assumes LoRA files are in this directory
 
-    # --- Load Predictor Model ---
+    # --- Load Predictor Model (Head) ---
     print(f"Loading Predictor from: {PREDICTOR_CHECKPOINT_PATH}")
     PREDICTOR_MODEL = AntioxidantPredictor(input_dim=1914)
     PREDICTOR_MODEL.load_state_dict(torch.load(PREDICTOR_CHECKPOINT_PATH, map_location=DEVICE))
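
`./lora_finetuned_prott5` is expected to be a standard PEFT adapter directory (typically `adapter_config.json` plus `adapter_model.safetensors` or `adapter_model.bin`). For reference, a hedged sketch of how such an adapter is usually produced and saved; the rank, alpha, and `target_modules` below are illustrative, not taken from this commit:

    import transformers
    from peft import LoraConfig, TaskType, get_peft_model

    base = transformers.T5EncoderModel.from_pretrained("Rostlab/prot_t5_xl_uniref50")
    config = LoraConfig(
        task_type=TaskType.FEATURE_EXTRACTION,
        r=8, lora_alpha=16, lora_dropout=0.05,
        target_modules=["q", "v"],  # T5 attention projections; illustrative choice
    )
    model = get_peft_model(base, config)
    model.print_trainable_parameters()  # only the low-rank matrices are trainable
    # ... fine-tuning loop ...
    model.save_pretrained("./lora_finetuned_prott5")  # writes adapter_config.json + weights
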
@@ -202,13 +205,14 @@ try:
     PREDICTOR_MODEL.eval()
     print(f"✅ Predictor model loaded (Temp: {PREDICTOR_MODEL.get_temperature():.4f}).")
 
-    # --- Load Scaler & Feature Extractor ---
+    # --- Load Scaler & LoRA Feature Extractor ---
     print(f"Loading Scaler from: {SCALER_PATH}")
     SCALER = joblib.load(SCALER_PATH)
-    print("Loading ProtT5 Feature Extractor...")
-    PROTT5_EXTRACTOR = FeatureProtT5Model(
+    print("Loading LoRA-enhanced ProtT5 Feature Extractor...")
+    # ✅ UPDATED: Instantiate the new LoRA extractor class
+    PROTT5_EXTRACTOR = LoRAProtT5Extractor(
         base_model_id=PROTT5_BASE_MODEL_ID,
-        finetuned_weights_path=...
+        lora_adapter_path=LORA_ADAPTER_PATH
     )
     print("✅ Scaler and Feature Extractor loaded.")
 
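
The `Temp:` value printed above is the scalar the predictor stores for calibration. Assuming the standard temperature-scaling setup (the logit is divided by T before the sigmoid), a minimal sketch:

    import torch

    def calibrated_probability(logit: torch.Tensor, temperature: float) -> torch.Tensor:
        # T > 1 softens over-confident outputs; T = 1 is a no-op.
        return torch.sigmoid(logit / temperature)

    # e.g. a raw logit of 3.0 with T = 1.5 yields sigmoid(2.0) ≈ 0.881
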
@@ -227,7 +231,7 @@ except Exception as e:
     raise gr.Error(f"A required model or file could not be loaded. Please check your repository file structure and paths. Error details: {e}")
 
 # --------------------------------------------------------------------------
-# SECTION 3: WRAPPER FUNCTIONS FOR GRADIO UI
+# SECTION 3: WRAPPER FUNCTIONS FOR GRADIO UI (Unchanged logic)
 # --------------------------------------------------------------------------
 
 def predict_peptide_wrapper(sequence_str):
@@ -235,6 +239,8 @@ def predict_peptide_wrapper(sequence_str):
         return "0.0000", "Error: Please enter a valid peptide sequence using standard amino acids (ACDEFGHIKLMNPQRSTVWY)."
 
     try:
+        # This function call remains the same because the PROTT5_EXTRACTOR object,
+        # despite its new internal logic, provides the same interface.
         features = extract_features(sequence_str.upper(), PROTT5_EXTRACTOR, L_fixed=29, d_model_pe=16)
         scaled_features = SCALER.transform(features.reshape(1, -1))
 
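
The "same interface" noted in the new comment comes down to duck typing: `extract_features` only needs an object exposing `encode(sequence) -> np.ndarray`. Under that assumption, a stub like the one below would let the pipeline run in tests without downloading ProtT5 (`DummyExtractor` is hypothetical, not part of the app):

    import numpy as np

    class DummyExtractor:
        # Same contract as LoRAProtT5Extractor.encode: per-residue 1024-dim embeddings.
        def encode(self, sequence: str) -> np.ndarray:
            if not sequence or not isinstance(sequence, str):
                return np.zeros((1, 1024), dtype=np.float32)
            return np.zeros((len(sequence), 1024), dtype=np.float32)
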
@@ -251,29 +257,23 @@ def predict_peptide_wrapper(sequence_str):
         return "N/A", f"An error occurred during prediction: {e}"
 
 def generate_peptide_wrapper(num_to_generate, min_len, max_len, temperature, diversity_factor, progress=gr.Progress()):
-    """
-    Handles the full generation-validation-clustering pipeline with a loop to ensure
-    the target number of peptides is generated.
-    """
     num_to_generate = int(num_to_generate)
     min_len = int(min_len)
     max_len = int(max_len)
 
-    # Safety check for length
     if min_len > max_len:
         gr.Warning("Minimum Length cannot be greater than Maximum Length. Adjusting min_len = max_len.")
         min_len = max_len
 
     try:
-        validated_pool = {}
+        validated_pool = {}
         attempts = 0
-        max_attempts = 20
-        generation_batch_size = 200
+        max_attempts = 20
+        generation_batch_size = 200
 
         while len(validated_pool) < num_to_generate and attempts < max_attempts:
             progress(len(validated_pool) / num_to_generate, desc=f"Found {len(validated_pool)} / {num_to_generate} peptides. (Attempt {attempts+1}/{max_attempts})")
 
-            # Generate a batch of candidate sequences
             with torch.no_grad():
                 generated_tokens = GENERATOR_MODEL.sample(
                     batch_size=generation_batch_size, max_length=max_len, device=DEVICE,
@@ -281,21 +281,18 @@ def generate_peptide_wrapper(num_to_generate, min_len, max_len, temperature, diversity_factor, progress=gr.Progress()):
                 )
             decoded_sequences = GENERATOR_MODEL.decode(generated_tokens)
 
-            # Filter for length and uniqueness
             new_candidates = []
             for seq in decoded_sequences:
                 if min_len <= len(seq) <= max_len:
                     if seq not in validated_pool:
                         new_candidates.append(seq)
 
-            # Validate the new, unique candidates
             for seq in new_candidates:
                 prob_str, _ = predict_peptide_wrapper(seq)
                 try:
                     prob = float(prob_str)
                     if prob > 0.90:
                         validated_pool[seq] = prob
-                        # Check if we have reached the target
                         if len(validated_pool) >= num_to_generate:
                             break
                 except (ValueError, TypeError):
@@ -311,13 +308,9 @@ def generate_peptide_wrapper(num_to_generate, min_len, max_len, temperature, diversity_factor, progress=gr.Progress()):
         if not validated_pool:
             return pd.DataFrame([{"Sequence": "Could not generate any high-activity peptides (>0.9 prob) with the current settings. Try different parameters.", "Predicted Probability": "N/A"}])
 
-        # --- Final Processing ---
         high_quality_sequences = list(validated_pool.keys())
-
-        # Cluster to ensure diversity, selecting up to the target number
         final_diverse_seqs = cluster_sequences(GENERATOR_MODEL, high_quality_sequences, num_to_generate, DEVICE)
 
-        # Format final results into a DataFrame
         final_results = [(seq, f"{validated_pool[seq]:.4f}") for seq in final_diverse_seqs]
         final_results.sort(key=lambda x: float(x[1]), reverse=True)
 
@@ -328,7 +321,7 @@ def generate_peptide_wrapper(num_to_generate, min_len, max_len, temperature, diversity_factor, progress=gr.Progress()):
         return pd.DataFrame([{"Sequence": f"An error occurred during generation: {e}", "Predicted Probability": "N/A"}])
 
 # --------------------------------------------------------------------------
-# SECTION 4: GRADIO UI CONSTRUCTION
+# SECTION 4: GRADIO UI CONSTRUCTION (Unchanged)
 # --------------------------------------------------------------------------
 with gr.Blocks(theme=gr.themes.Soft(), title="RLAnOxPeptide") as demo:
     gr.Markdown("# RLAnOxPeptide: Intelligent Peptide Design and Prediction")
@@ -364,7 +357,6 @@ with gr.Blocks(theme=gr.themes.Soft(), title="RLAnOxPeptide") as demo:
         with gr.Column():
             with gr.Row():
                 num_input = gr.Slider(minimum=5, maximum=50, value=10, step=1, label="Number of Final Peptides to Generate")
-                # ✅ MODIFIED: Length sliders both have a range of 2-20
                 min_len_input = gr.Slider(minimum=2, maximum=20, value=3, step=1, label="Minimum Length")
                 max_len_input = gr.Slider(minimum=2, maximum=20, value=20, step=1, label="Maximum Length")
             with gr.Row():
@@ -374,10 +366,9 @@ with gr.Blocks(theme=gr.themes.Soft(), title="RLAnOxPeptide") as demo:
             generate_button = gr.Button("Generate Peptides", variant="primary")
             results_output = gr.DataFrame(headers=["Sequence", "Predicted Probability"], label="Generated & Validated Peptides (>90% Probability)", wrap=True)
 
-    # ✅ ADDED: Dynamic linking of min and max length sliders for better UX
     def update_min_len_range(max_len):
         return gr.Slider(maximum=max_len)
-    max_len_input.change(fn=update_min_len_range, inputs=max_len_input, outputs=
+    max_len_input.change(fn=update_min_len_range, inputs=max_len_input, outputs=max_len_input)
 
     def update_max_len_range(min_len):
         return gr.Slider(minimum=min_len)
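
On the slider linkage: `update_min_len_range` returns `gr.Slider(maximum=max_len)`, i.e. an update meant to cap the minimum-length slider, yet the new line wires its output back to `max_len_input`. If two-way linking is the intent, each handler would target the opposite slider; a minimal sketch under that assumption:

    import gradio as gr

    with gr.Blocks() as demo:
        min_len = gr.Slider(minimum=2, maximum=20, value=3, step=1, label="Minimum Length")
        max_len = gr.Slider(minimum=2, maximum=20, value=20, step=1, label="Maximum Length")

        # Moving one slider adjusts the other slider's allowed range.
        max_len.change(fn=lambda m: gr.Slider(maximum=m), inputs=max_len, outputs=min_len)
        min_len.change(fn=lambda m: gr.Slider(minimum=m), inputs=min_len, outputs=max_len)
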