nonzeroexit committed on
Commit
adaeb14
·
verified ·
1 Parent(s): 1b97b02

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +36 -19
app.py CHANGED
@@ -19,8 +19,33 @@ protbert_model = BertModel.from_pretrained("Rostlab/prot_bert")
19
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
20
  protbert_model = protbert_model.to(device).eval()
21
 
22
- # Selected Features
23
- selected_features = [ ... ] # keep your full selected_features list here
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
  # LIME Explainer Setup
26
  sample_data = np.random.rand(100, len(selected_features))
@@ -31,9 +56,7 @@ explainer = LimeTabularExplainer(
31
  mode="classification"
32
  )
33
 
34
- # Feature Extractor
35
  def extract_features(sequence):
36
- all_features_dict = {}
37
  sequence = ''.join([aa for aa in sequence.upper() if aa in "ACDEFGHIKLMNPQRSTVWY"])
38
  if len(sequence) < 10:
39
  return "Error: Sequence too short."
@@ -42,6 +65,7 @@ def extract_features(sequence):
42
  ctd_features = CTD.CalculateCTD(sequence)
43
  auto_features = Autocorrelation.CalculateAutoTotal(sequence)
44
  pseudo_features = PseudoAAC.GetAPseudoAAC(sequence, lamda=9)
 
45
  all_features_dict.update(ctd_features)
46
  all_features_dict.update(filtered_dipeptide_features)
47
  all_features_dict.update(auto_features)
@@ -49,10 +73,11 @@ def extract_features(sequence):
49
  feature_df_all = pd.DataFrame([all_features_dict])
50
  normalized_array = scaler.transform(feature_df_all.values)
51
  normalized_df = pd.DataFrame(normalized_array, columns=feature_df_all.columns)
 
 
52
  selected_df = normalized_df[selected_features].fillna(0)
53
  return selected_df.values
54
 
55
- # MIC Predictor
56
  def predictmic(sequence):
57
  sequence = ''.join([aa for aa in sequence.upper() if aa in "ACDEFGHIKLMNPQRSTVWY"])
58
  if len(sequence) < 10:
@@ -83,46 +108,38 @@ def predictmic(sequence):
83
  mic_results[bacterium] = f"Error: {str(e)}"
84
  return mic_results
85
 
86
- # Full Prediction with LIME Explanation
87
  def full_prediction(sequence):
88
  features = extract_features(sequence)
89
- if isinstance(features, str): # error
90
  return features
91
-
92
  prediction = model.predict(features)[0]
93
  probabilities = model.predict_proba(features)[0]
94
  amp_result = "Antimicrobial Peptide (AMP)" if prediction == 0 else "Non-AMP"
95
  confidence = round(probabilities[0 if prediction == 0 else 1] * 100, 2)
96
-
97
  result = f"Prediction: {amp_result}\nConfidence: {confidence}%\n"
98
-
99
  if prediction == 0:
100
  mic_values = predictmic(sequence)
101
- result += "\nPredicted MIC Values (µM):\n"
102
  for org, mic in mic_values.items():
103
  result += f"- {org}: {mic}\n"
104
  else:
105
  result += "\nMIC prediction skipped for Non-AMP sequences.\n"
106
-
107
- # LIME explanation
108
  explanation = explainer.explain_instance(
109
  data_row=features[0],
110
  predict_fn=model.predict_proba,
111
  num_features=10
112
  )
113
- result += "\nTop Features Influencing AMP Prediction:\n"
114
  for feat, weight in explanation.as_list():
115
  result += f"- {feat}: {round(weight, 4)}\n"
116
-
117
  return result
118
 
119
- # Gradio UI
120
  iface = gr.Interface(
121
  fn=full_prediction,
122
  inputs=gr.Textbox(label="Enter Protein Sequence"),
123
- outputs=gr.Textbox(label="Prediction + MIC + LIME"),
124
  title="AMP & MIC Predictor + LIME Explanation",
125
- description="Paste an amino acid sequence (≥10 characters). Get AMP classification, MIC predictions, and LIME interpretability insights."
126
  )
127
 
128
- iface.launch(share=True)
 
19
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
20
  protbert_model = protbert_model.to(device).eval()
21
 
22
+ # Full list of selected features
23
+ selected_features = ["_SolventAccessibilityC3", "_SecondaryStrC1", "_SecondaryStrC3", "_ChargeC1", "_PolarityC1",
24
+ "_NormalizedVDWVC1", "_HydrophobicityC3", "_SecondaryStrT23", "_PolarizabilityD1001", "_PolarizabilityD2001",
25
+ "_PolarizabilityD3001", "_SolventAccessibilityD1001", "_SolventAccessibilityD2001", "_SolventAccessibilityD3001",
26
+ "_SecondaryStrD1001", "_SecondaryStrD1075", "_SecondaryStrD2001", "_SecondaryStrD3001", "_ChargeD1001",
27
+ "_ChargeD1025", "_ChargeD2001", "_ChargeD3075", "_ChargeD3100", "_PolarityD1001", "_PolarityD1050",
28
+ "_PolarityD2001", "_PolarityD3001", "_NormalizedVDWVD1001", "_NormalizedVDWVD2001", "_NormalizedVDWVD2025",
29
+ "_NormalizedVDWVD2050", "_NormalizedVDWVD3001", "_HydrophobicityD1001", "_HydrophobicityD2001",
30
+ "_HydrophobicityD3001", "_HydrophobicityD3025", "A", "R", "D", "C", "E", "Q", "H", "I", "M", "P", "Y", "V",
31
+ "AR", "AV", "RC", "RL", "RV", "CR", "CC", "CL", "CK", "EE", "EI", "EL", "HC", "IA", "IL", "IV", "LA", "LC", "LE",
32
+ "LI", "LT", "LV", "KC", "MA", "MS", "SC", "TC", "TV", "YC", "VC", "VE", "VL", "VK", "VV",
33
+ "MoreauBrotoAuto_FreeEnergy30", "MoranAuto_Hydrophobicity2", "MoranAuto_Hydrophobicity4",
34
+ "GearyAuto_Hydrophobicity20", "GearyAuto_Hydrophobicity24", "GearyAuto_Hydrophobicity26",
35
+ "GearyAuto_Hydrophobicity27", "GearyAuto_Hydrophobicity28", "GearyAuto_Hydrophobicity29",
36
+ "GearyAuto_Hydrophobicity30", "GearyAuto_AvFlexibility22", "GearyAuto_AvFlexibility26",
37
+ "GearyAuto_AvFlexibility27", "GearyAuto_AvFlexibility28", "GearyAuto_AvFlexibility29", "GearyAuto_AvFlexibility30",
38
+ "GearyAuto_Polarizability22", "GearyAuto_Polarizability24", "GearyAuto_Polarizability25",
39
+ "GearyAuto_Polarizability27", "GearyAuto_Polarizability28", "GearyAuto_Polarizability29",
40
+ "GearyAuto_Polarizability30", "GearyAuto_FreeEnergy24", "GearyAuto_FreeEnergy25", "GearyAuto_FreeEnergy30",
41
+ "GearyAuto_ResidueASA21", "GearyAuto_ResidueASA22", "GearyAuto_ResidueASA23", "GearyAuto_ResidueASA24",
42
+ "GearyAuto_ResidueASA30", "GearyAuto_ResidueVol21", "GearyAuto_ResidueVol24", "GearyAuto_ResidueVol25",
43
+ "GearyAuto_ResidueVol26", "GearyAuto_ResidueVol28", "GearyAuto_ResidueVol29", "GearyAuto_ResidueVol30",
44
+ "GearyAuto_Steric18", "GearyAuto_Steric21", "GearyAuto_Steric26", "GearyAuto_Steric27", "GearyAuto_Steric28",
45
+ "GearyAuto_Steric29", "GearyAuto_Steric30", "GearyAuto_Mutability23", "GearyAuto_Mutability25",
46
+ "GearyAuto_Mutability26", "GearyAuto_Mutability27", "GearyAuto_Mutability28", "GearyAuto_Mutability29",
47
+ "GearyAuto_Mutability30", "APAAC1", "APAAC4", "APAAC5", "APAAC6", "APAAC8", "APAAC9", "APAAC12", "APAAC13",
48
+ "APAAC15", "APAAC18", "APAAC19", "APAAC24"]
49
 
50
  # LIME Explainer Setup
51
  sample_data = np.random.rand(100, len(selected_features))
 
56
  mode="classification"
57
  )
58
 
 
59
  def extract_features(sequence):
 
60
  sequence = ''.join([aa for aa in sequence.upper() if aa in "ACDEFGHIKLMNPQRSTVWY"])
61
  if len(sequence) < 10:
62
  return "Error: Sequence too short."
 
65
  ctd_features = CTD.CalculateCTD(sequence)
66
  auto_features = Autocorrelation.CalculateAutoTotal(sequence)
67
  pseudo_features = PseudoAAC.GetAPseudoAAC(sequence, lamda=9)
68
+ all_features_dict = {}
69
  all_features_dict.update(ctd_features)
70
  all_features_dict.update(filtered_dipeptide_features)
71
  all_features_dict.update(auto_features)
 
73
  feature_df_all = pd.DataFrame([all_features_dict])
74
  normalized_array = scaler.transform(feature_df_all.values)
75
  normalized_df = pd.DataFrame(normalized_array, columns=feature_df_all.columns)
76
+ if not set(selected_features).issubset(set(normalized_df.columns)):
77
+ return "Error: Some selected features are missing from computed features."
78
  selected_df = normalized_df[selected_features].fillna(0)
79
  return selected_df.values
80
 
 
81
  def predictmic(sequence):
82
  sequence = ''.join([aa for aa in sequence.upper() if aa in "ACDEFGHIKLMNPQRSTVWY"])
83
  if len(sequence) < 10:
 
108
  mic_results[bacterium] = f"Error: {str(e)}"
109
  return mic_results
110
 
 
111
  def full_prediction(sequence):
112
  features = extract_features(sequence)
113
+ if isinstance(features, str):
114
  return features
 
115
  prediction = model.predict(features)[0]
116
  probabilities = model.predict_proba(features)[0]
117
  amp_result = "Antimicrobial Peptide (AMP)" if prediction == 0 else "Non-AMP"
118
  confidence = round(probabilities[0 if prediction == 0 else 1] * 100, 2)
 
119
  result = f"Prediction: {amp_result}\nConfidence: {confidence}%\n"
 
120
  if prediction == 0:
121
  mic_values = predictmic(sequence)
122
+ result += "\nPredicted MIC Values (\u00b5M):\n"
123
  for org, mic in mic_values.items():
124
  result += f"- {org}: {mic}\n"
125
  else:
126
  result += "\nMIC prediction skipped for Non-AMP sequences.\n"
 
 
127
  explanation = explainer.explain_instance(
128
  data_row=features[0],
129
  predict_fn=model.predict_proba,
130
  num_features=10
131
  )
132
+ result += "\nTop Features Influencing Prediction:\n"
133
  for feat, weight in explanation.as_list():
134
  result += f"- {feat}: {round(weight, 4)}\n"
 
135
  return result
136
 
 
137
  iface = gr.Interface(
138
  fn=full_prediction,
139
  inputs=gr.Textbox(label="Enter Protein Sequence"),
140
+ outputs=gr.Textbox(label="Results"),
141
  title="AMP & MIC Predictor + LIME Explanation",
142
+ description="Paste an amino acid sequence (\u226510 characters). Get AMP classification, MIC predictions, and LIME interpretability insights."
143
  )
144
 
145
+ iface.launch(share=True)