Update app.py

app.py CHANGED
@@ -93,8 +93,89 @@ def predict(sequence):
         return f"{probabilities[0] * 100:.2f}% chance of being an Antimicrobial Peptide (AMP)"
     else:
         return f"{probabilities[1] * 100:.2f}% chance of being Non-AMP"
+
+
 def predictmic(sequence):
-
+    import torch
+    from transformers import BertTokenizer, BertModel
+    import numpy as np
+    import pickle
+    from math import expm1
+
+    # === Load ProtBert model ===
+    tokenizer = BertTokenizer.from_pretrained("Rostlab/prot_bert", do_lower_case=False)
+    model = BertModel.from_pretrained("Rostlab/prot_bert")
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    model = model.to(device).eval()
+
+    # === Preprocess input sequence ===
+    sequence = ''.join([aa for aa in sequence.upper() if aa in "ACDEFGHIKLMNPQRSTVWY"])
+    if len(sequence) < 10:
+        return {"Error": "Sequence too short or invalid. Must contain at least 10 valid amino acids."}
+
+    # === Tokenize & embed using mean pooling ===
+    seq_spaced = ' '.join(list(sequence))
+    tokens = tokenizer(seq_spaced, return_tensors="pt", padding='max_length', truncation=True, max_length=512)
+    tokens = {k: v.to(device) for k, v in tokens.items()}
+
+    with torch.no_grad():
+        outputs = model(**tokens)
+        embedding = outputs.last_hidden_state.mean(dim=1).squeeze().cpu().numpy().reshape(1, -1)  # Shape: (1, 1024)
+
+    # === MIC models and scalers for each bacterium ===
+    bacteria_config = {
+        "E.coli": {
+            "model": "coli_xgboost_model.pkl",
+            "scaler": "coli_scaler.pkl",
+            "pca": None
+        },
+        "S.aureus": {
+            "model": "aur_xgboost_model.pkl",
+            "scaler": "aur_scaler.pkl",
+            "pca": None
+        },
+        "P.aeruginosa": {
+            "model": "arg_xgboost_model.pkl",
+            "scaler": "arg_scaler.pkl",
+            "pca": None
+        },
+        "K.pneumoniae": {
+            "model": "pne_mlp_model.pkl",
+            "scaler": "pne_scaler.pkl",
+            "pca": "pne_pca"
+        }
+    }
+
+    mic_results = {}
+
+    for bacterium, cfg in bacteria_config.items():
+        try:
+            # === Load scaler and transform ===
+            with open(cfg["scaler"], "rb") as f:
+                scaler = pickle.load(f)
+            scaled = scaler.transform(embedding)
+
+            # === Apply PCA if it exists ===
+            if cfg["pca"] is not None:
+                with open(cfg["pca"], "rb") as f:
+                    pca = pickle.load(f)
+                transformed = pca.transform(scaled)
+            else:
+                transformed = scaled
+
+            # === Load model and predict ===
+            with open(cfg["model"], "rb") as f:
+                mic_model = pickle.load(f)
+            mic_log = mic_model.predict(transformed)[0]
+            mic = round(expm1(mic_log), 3)  # Inverse of the log1p transform used in training
+
+            mic_results[bacterium] = mic
+
+        except Exception as e:
+            mic_results[bacterium] = f"Error: {str(e)}"
+
+    return mic_results
+
 # Gradio interface
 iface = gr.Interface(
     fn=predict,