nonzeroexit committed
Commit 0c1f1e9 · verified · 1 parent: 25d4105

Update app.py

Files changed (1): app.py (+82 -1)
app.py CHANGED
@@ -93,8 +93,89 @@ def predict(sequence):
         return f"{probabilities[0] * 100:.2f}% chance of being an Antimicrobial Peptide (AMP)"
     else:
         return f"{probabilities[1] * 100:.2f}% chance of being Non-AMP"
+
+
 def predictmic(sequence):
-    features = 0
+    import torch
+    from transformers import BertTokenizer, BertModel
+    import numpy as np
+    import pickle
+    from math import expm1
+
+    # === Load ProtBert model ===
+    tokenizer = BertTokenizer.from_pretrained("Rostlab/prot_bert", do_lower_case=False)
+    model = BertModel.from_pretrained("Rostlab/prot_bert")
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    model = model.to(device).eval()
+
+    # === Preprocess input sequence ===
+    sequence = ''.join([aa for aa in sequence.upper() if aa in "ACDEFGHIKLMNPQRSTVWY"])
+    if len(sequence) < 10:
+        return {"Error": "Sequence too short or invalid. Must contain at least 10 valid amino acids."}
+
+    # === Tokenize & embed using mean pooling ===
+    seq_spaced = ' '.join(list(sequence))
+    tokens = tokenizer(seq_spaced, return_tensors="pt", padding='max_length', truncation=True, max_length=512)
+    tokens = {k: v.to(device) for k, v in tokens.items()}
+
+    with torch.no_grad():
+        outputs = model(**tokens)
+        embedding = outputs.last_hidden_state.mean(dim=1).squeeze().cpu().numpy().reshape(1, -1)  # Shape: (1, 1024)
+
+    # === MIC models and scalers for each bacterium ===
+    bacteria_config = {
+        "E.coli": {
+            "model": "coli_xgboost_model.pkl",
+            "scaler": "coli_scaler.pkl",
+            "pca": None
+        },
+        "S.aureus": {
+            "model": "aur_xgboost_model.pkl",
+            "scaler": "aur_scaler.pkl",
+            "pca": None
+        },
+        "P.aeruginosa": {
+            "model": "arg_xgboost_model.pkl",
+            "scaler": "arg_scaler.pkl",
+            "pca": None
+        },
+        "K.Pneumonia": {
+            "model": "pne_mlp_model.pkl",
+            "scaler": "pne_scaler.pkl",
+            "pca": "pne_pca"
+        }
+    }
+
+    mic_results = {}
+
+    for bacterium, cfg in bacteria_config.items():
+        try:
+            # === Load scaler and transform ===
+            with open(cfg["scaler"], "rb") as f:
+                scaler = pickle.load(f)
+            scaled = scaler.transform(embedding)
+
+            # === Apply PCA if it exists ===
+            if cfg["pca"] is not None:
+                with open(cfg["pca"], "rb") as f:
+                    pca = pickle.load(f)
+                transformed = pca.transform(scaled)
+            else:
+                transformed = scaled
+
+            # === Load model and predict ===
+            with open(cfg["model"], "rb") as f:
+                mic_model = pickle.load(f)
+            mic_log = mic_model.predict(transformed)[0]
+            mic = round(expm1(mic_log), 3)  # Inverse of the log1p transform used in training
+
+            mic_results[bacterium] = mic
+
+        except Exception as e:
+            mic_results[bacterium] = f"Error: {str(e)}"
+
+    return mic_results
+
 # Gradio interface
 iface = gr.Interface(
     fn=predict,
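
This commit adds predictmic but leaves the Gradio wiring untouched, so the UI still calls only predict. Below is a minimal sketch of how the new function could be exposed alongside the existing classifier. It assumes the same single-textbox input style already used by iface; the mic_iface and demo names and the tab labels are illustrative, not part of this commit.

# Sketch only (not in the commit): expose predictmic next to the existing
# classifier using standard Gradio components (gr.Interface, gr.Textbox,
# gr.JSON, gr.TabbedInterface).
import gradio as gr

mic_iface = gr.Interface(
    fn=predictmic,  # returns a dict like {"E.coli": 12.5, "S.aureus": "Error: ..."}
    inputs=gr.Textbox(label="Peptide sequence"),
    outputs=gr.JSON(label="Predicted MIC per bacterium"),
)

# Group the existing AMP classifier and the new MIC predictor into tabs.
demo = gr.TabbedInterface([iface, mic_iface], ["AMP classifier", "MIC prediction"])
demo.launch()

One design note: as committed, predictmic reloads ProtBert and every pickled scaler/model on each call; hoisting those loads to module scope would make repeated predictions much cheaper.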