Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -19,8 +19,33 @@ protbert_model = BertModel.from_pretrained("Rostlab/prot_bert")
|
|
19 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
20 |
protbert_model = protbert_model.to(device).eval()
|
21 |
|
22 |
-
#
|
23 |
-
selected_features = [
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
24 |
|
25 |
# LIME Explainer Setup
|
26 |
sample_data = np.random.rand(100, len(selected_features))
|
@@ -31,9 +56,7 @@ explainer = LimeTabularExplainer(
|
|
31 |
mode="classification"
|
32 |
)
|
33 |
|
34 |
-
# Feature Extractor
|
35 |
def extract_features(sequence):
|
36 |
-
all_features_dict = {}
|
37 |
sequence = ''.join([aa for aa in sequence.upper() if aa in "ACDEFGHIKLMNPQRSTVWY"])
|
38 |
if len(sequence) < 10:
|
39 |
return "Error: Sequence too short."
|
@@ -42,6 +65,7 @@ def extract_features(sequence):
|
|
42 |
ctd_features = CTD.CalculateCTD(sequence)
|
43 |
auto_features = Autocorrelation.CalculateAutoTotal(sequence)
|
44 |
pseudo_features = PseudoAAC.GetAPseudoAAC(sequence, lamda=9)
|
|
|
45 |
all_features_dict.update(ctd_features)
|
46 |
all_features_dict.update(filtered_dipeptide_features)
|
47 |
all_features_dict.update(auto_features)
|
@@ -49,10 +73,11 @@ def extract_features(sequence):
|
|
49 |
feature_df_all = pd.DataFrame([all_features_dict])
|
50 |
normalized_array = scaler.transform(feature_df_all.values)
|
51 |
normalized_df = pd.DataFrame(normalized_array, columns=feature_df_all.columns)
|
|
|
|
|
52 |
selected_df = normalized_df[selected_features].fillna(0)
|
53 |
return selected_df.values
|
54 |
|
55 |
-
# MIC Predictor
|
56 |
def predictmic(sequence):
|
57 |
sequence = ''.join([aa for aa in sequence.upper() if aa in "ACDEFGHIKLMNPQRSTVWY"])
|
58 |
if len(sequence) < 10:
|
@@ -83,46 +108,38 @@ def predictmic(sequence):
|
|
83 |
mic_results[bacterium] = f"Error: {str(e)}"
|
84 |
return mic_results
|
85 |
|
86 |
-
# Full Prediction with LIME Explanation
|
87 |
def full_prediction(sequence):
|
88 |
features = extract_features(sequence)
|
89 |
-
if isinstance(features, str):
|
90 |
return features
|
91 |
-
|
92 |
prediction = model.predict(features)[0]
|
93 |
probabilities = model.predict_proba(features)[0]
|
94 |
amp_result = "Antimicrobial Peptide (AMP)" if prediction == 0 else "Non-AMP"
|
95 |
confidence = round(probabilities[0 if prediction == 0 else 1] * 100, 2)
|
96 |
-
|
97 |
result = f"Prediction: {amp_result}\nConfidence: {confidence}%\n"
|
98 |
-
|
99 |
if prediction == 0:
|
100 |
mic_values = predictmic(sequence)
|
101 |
-
result += "\nPredicted MIC Values (
|
102 |
for org, mic in mic_values.items():
|
103 |
result += f"- {org}: {mic}\n"
|
104 |
else:
|
105 |
result += "\nMIC prediction skipped for Non-AMP sequences.\n"
|
106 |
-
|
107 |
-
# LIME explanation
|
108 |
explanation = explainer.explain_instance(
|
109 |
data_row=features[0],
|
110 |
predict_fn=model.predict_proba,
|
111 |
num_features=10
|
112 |
)
|
113 |
-
result += "\nTop Features Influencing
|
114 |
for feat, weight in explanation.as_list():
|
115 |
result += f"- {feat}: {round(weight, 4)}\n"
|
116 |
-
|
117 |
return result
|
118 |
|
119 |
-
# Gradio UI
|
120 |
iface = gr.Interface(
|
121 |
fn=full_prediction,
|
122 |
inputs=gr.Textbox(label="Enter Protein Sequence"),
|
123 |
-
outputs=gr.Textbox(label="
|
124 |
title="AMP & MIC Predictor + LIME Explanation",
|
125 |
-
description="Paste an amino acid sequence (
|
126 |
)
|
127 |
|
128 |
-
iface.launch(share=True)
|
|
|
19 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
20 |
protbert_model = protbert_model.to(device).eval()
|
21 |
|
22 |
+
# Full list of selected features
|
23 |
+
selected_features = ["_SolventAccessibilityC3", "_SecondaryStrC1", "_SecondaryStrC3", "_ChargeC1", "_PolarityC1",
|
24 |
+
"_NormalizedVDWVC1", "_HydrophobicityC3", "_SecondaryStrT23", "_PolarizabilityD1001", "_PolarizabilityD2001",
|
25 |
+
"_PolarizabilityD3001", "_SolventAccessibilityD1001", "_SolventAccessibilityD2001", "_SolventAccessibilityD3001",
|
26 |
+
"_SecondaryStrD1001", "_SecondaryStrD1075", "_SecondaryStrD2001", "_SecondaryStrD3001", "_ChargeD1001",
|
27 |
+
"_ChargeD1025", "_ChargeD2001", "_ChargeD3075", "_ChargeD3100", "_PolarityD1001", "_PolarityD1050",
|
28 |
+
"_PolarityD2001", "_PolarityD3001", "_NormalizedVDWVD1001", "_NormalizedVDWVD2001", "_NormalizedVDWVD2025",
|
29 |
+
"_NormalizedVDWVD2050", "_NormalizedVDWVD3001", "_HydrophobicityD1001", "_HydrophobicityD2001",
|
30 |
+
"_HydrophobicityD3001", "_HydrophobicityD3025", "A", "R", "D", "C", "E", "Q", "H", "I", "M", "P", "Y", "V",
|
31 |
+
"AR", "AV", "RC", "RL", "RV", "CR", "CC", "CL", "CK", "EE", "EI", "EL", "HC", "IA", "IL", "IV", "LA", "LC", "LE",
|
32 |
+
"LI", "LT", "LV", "KC", "MA", "MS", "SC", "TC", "TV", "YC", "VC", "VE", "VL", "VK", "VV",
|
33 |
+
"MoreauBrotoAuto_FreeEnergy30", "MoranAuto_Hydrophobicity2", "MoranAuto_Hydrophobicity4",
|
34 |
+
"GearyAuto_Hydrophobicity20", "GearyAuto_Hydrophobicity24", "GearyAuto_Hydrophobicity26",
|
35 |
+
"GearyAuto_Hydrophobicity27", "GearyAuto_Hydrophobicity28", "GearyAuto_Hydrophobicity29",
|
36 |
+
"GearyAuto_Hydrophobicity30", "GearyAuto_AvFlexibility22", "GearyAuto_AvFlexibility26",
|
37 |
+
"GearyAuto_AvFlexibility27", "GearyAuto_AvFlexibility28", "GearyAuto_AvFlexibility29", "GearyAuto_AvFlexibility30",
|
38 |
+
"GearyAuto_Polarizability22", "GearyAuto_Polarizability24", "GearyAuto_Polarizability25",
|
39 |
+
"GearyAuto_Polarizability27", "GearyAuto_Polarizability28", "GearyAuto_Polarizability29",
|
40 |
+
"GearyAuto_Polarizability30", "GearyAuto_FreeEnergy24", "GearyAuto_FreeEnergy25", "GearyAuto_FreeEnergy30",
|
41 |
+
"GearyAuto_ResidueASA21", "GearyAuto_ResidueASA22", "GearyAuto_ResidueASA23", "GearyAuto_ResidueASA24",
|
42 |
+
"GearyAuto_ResidueASA30", "GearyAuto_ResidueVol21", "GearyAuto_ResidueVol24", "GearyAuto_ResidueVol25",
|
43 |
+
"GearyAuto_ResidueVol26", "GearyAuto_ResidueVol28", "GearyAuto_ResidueVol29", "GearyAuto_ResidueVol30",
|
44 |
+
"GearyAuto_Steric18", "GearyAuto_Steric21", "GearyAuto_Steric26", "GearyAuto_Steric27", "GearyAuto_Steric28",
|
45 |
+
"GearyAuto_Steric29", "GearyAuto_Steric30", "GearyAuto_Mutability23", "GearyAuto_Mutability25",
|
46 |
+
"GearyAuto_Mutability26", "GearyAuto_Mutability27", "GearyAuto_Mutability28", "GearyAuto_Mutability29",
|
47 |
+
"GearyAuto_Mutability30", "APAAC1", "APAAC4", "APAAC5", "APAAC6", "APAAC8", "APAAC9", "APAAC12", "APAAC13",
|
48 |
+
"APAAC15", "APAAC18", "APAAC19", "APAAC24"]
|
49 |
|
50 |
# LIME Explainer Setup
|
51 |
sample_data = np.random.rand(100, len(selected_features))
|
|
|
56 |
mode="classification"
|
57 |
)
|
58 |
|
|
|
59 |
def extract_features(sequence):
|
|
|
60 |
sequence = ''.join([aa for aa in sequence.upper() if aa in "ACDEFGHIKLMNPQRSTVWY"])
|
61 |
if len(sequence) < 10:
|
62 |
return "Error: Sequence too short."
|
|
|
65 |
ctd_features = CTD.CalculateCTD(sequence)
|
66 |
auto_features = Autocorrelation.CalculateAutoTotal(sequence)
|
67 |
pseudo_features = PseudoAAC.GetAPseudoAAC(sequence, lamda=9)
|
68 |
+
all_features_dict = {}
|
69 |
all_features_dict.update(ctd_features)
|
70 |
all_features_dict.update(filtered_dipeptide_features)
|
71 |
all_features_dict.update(auto_features)
|
|
|
73 |
feature_df_all = pd.DataFrame([all_features_dict])
|
74 |
normalized_array = scaler.transform(feature_df_all.values)
|
75 |
normalized_df = pd.DataFrame(normalized_array, columns=feature_df_all.columns)
|
76 |
+
if not set(selected_features).issubset(set(normalized_df.columns)):
|
77 |
+
return "Error: Some selected features are missing from computed features."
|
78 |
selected_df = normalized_df[selected_features].fillna(0)
|
79 |
return selected_df.values
|
80 |
|
|
|
81 |
def predictmic(sequence):
|
82 |
sequence = ''.join([aa for aa in sequence.upper() if aa in "ACDEFGHIKLMNPQRSTVWY"])
|
83 |
if len(sequence) < 10:
|
|
|
108 |
mic_results[bacterium] = f"Error: {str(e)}"
|
109 |
return mic_results
|
110 |
|
|
|
111 |
def full_prediction(sequence):
|
112 |
features = extract_features(sequence)
|
113 |
+
if isinstance(features, str):
|
114 |
return features
|
|
|
115 |
prediction = model.predict(features)[0]
|
116 |
probabilities = model.predict_proba(features)[0]
|
117 |
amp_result = "Antimicrobial Peptide (AMP)" if prediction == 0 else "Non-AMP"
|
118 |
confidence = round(probabilities[0 if prediction == 0 else 1] * 100, 2)
|
|
|
119 |
result = f"Prediction: {amp_result}\nConfidence: {confidence}%\n"
|
|
|
120 |
if prediction == 0:
|
121 |
mic_values = predictmic(sequence)
|
122 |
+
result += "\nPredicted MIC Values (\u00b5M):\n"
|
123 |
for org, mic in mic_values.items():
|
124 |
result += f"- {org}: {mic}\n"
|
125 |
else:
|
126 |
result += "\nMIC prediction skipped for Non-AMP sequences.\n"
|
|
|
|
|
127 |
explanation = explainer.explain_instance(
|
128 |
data_row=features[0],
|
129 |
predict_fn=model.predict_proba,
|
130 |
num_features=10
|
131 |
)
|
132 |
+
result += "\nTop Features Influencing Prediction:\n"
|
133 |
for feat, weight in explanation.as_list():
|
134 |
result += f"- {feat}: {round(weight, 4)}\n"
|
|
|
135 |
return result
|
136 |
|
|
|
137 |
iface = gr.Interface(
|
138 |
fn=full_prediction,
|
139 |
inputs=gr.Textbox(label="Enter Protein Sequence"),
|
140 |
+
outputs=gr.Textbox(label="Results"),
|
141 |
title="AMP & MIC Predictor + LIME Explanation",
|
142 |
+
description="Paste an amino acid sequence (\u226510 characters). Get AMP classification, MIC predictions, and LIME interpretability insights."
|
143 |
)
|
144 |
|
145 |
+
iface.launch(share=True)
|