Spaces:
Running
on
Zero
Running
on
Zero
File size: 4,556 Bytes
84bfd88 4f0db87 84bfd88 4f0db87 84bfd88 5af1eb8 84bfd88 5af1eb8 d54f04d 84bfd88 5af1eb8 077c500 84bfd88 d54f04d 84bfd88 6250a6b 84bfd88 d54f04d 84bfd88 4f0db87 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 |
import gradio as gr
import plotly.graph_objects as go
import numpy as np
import pandas as pd
from model.model import DTIModel
import spaces
# Identifier embedded in the model file paths built in predict_and_plot
# (Barlow Twins stash directory and XGBoost JSON files).
# NOTE(review): looks like a DDMMYYYY_HHMM training timestamp — confirm.
dt_str = "14062024_0910"
def make_spider_plot(predictions, model_names, smiles_list):
    """Build a radar (spider) chart with one filled trace per molecule.

    Args:
        predictions: Iterable of per-molecule value sequences; each sequence
            is used as the radial coordinates of one trace.
        model_names: Labels used as the angular axis categories for every
            trace.
        smiles_list: Iterable of SMILES strings; each one becomes the legend
            name of the trace at the same position in ``predictions``.

    Returns:
        A ``plotly.graph_objects.Figure`` whose radial axis is visible and
        fixed to the range [0, 1], with the legend shown.
    """
    figure = go.Figure()

    # zip() pairs each molecule with its scores and stops at the shorter input.
    for scores, molecule in zip(predictions, smiles_list):
        trace = go.Scatterpolar(
            r=scores,
            theta=model_names,
            fill='toself',
            name=molecule,
        )
        figure.add_trace(trace)

    radial_axis = {"visible": True, "range": [0, 1]}
    figure.update_layout(
        polar={"radialaxis": radial_axis},
        showlegend=True,
    )
    return figure
@spaces.GPU
def predict_and_plot(amino_acid_sequence, smiles_input, datasets):
    """Predict drug-target interactions and return a plot plus a table.

    Args:
        amino_acid_sequence: Protein sequence text. Surrounding whitespace is
            removed before it is passed to the models.
        smiles_input: Text with one SMILES string per line. Blank lines and
            surrounding whitespace (including ``\\r`` from Windows line
            endings) are ignored.
        datasets: Names of the models to run; each must be a key of the
            internal path table ("BindingDB", "BioSNAP", "DAVIS",
            "BarlowDTI XXL"). A single name may be given as a plain string.
            Duplicates are dropped, keeping first-seen order.

    Returns:
        A ``(fig, df)`` tuple: the spider plot produced by
        ``make_spider_plot`` and a ``pandas.DataFrame`` with a "SMILES"
        column followed by one column per selected model.

    Raises:
        gr.Error: If no model is selected, the protein sequence is empty, or
            no SMILES string is supplied.
        KeyError: If a selected model name is not in the path table.
    """
    # Normalise the selection: accept a bare string, tolerate None, and drop
    # duplicates so the ensemble and the column labels stay aligned.
    if isinstance(datasets, str):
        datasets = [datasets]
    selected = list(dict.fromkeys(datasets or []))
    if not selected:
        raise gr.Error("Please select at least one model for prediction.")

    sequence = (amino_acid_sequence or "").strip()
    if not sequence:
        raise gr.Error("Please enter a protein sequence.")

    # splitlines() handles \n and \r\n alike; empty lines would otherwise be
    # sent to the models as empty molecules.
    smiles_list = [
        line.strip()
        for line in (smiles_input or "").splitlines()
        if line.strip()
    ]
    if not smiles_list:
        raise gr.Error("Please enter at least one SMILES string.")

    gbm_model_paths = {
        "BindingDB": f"model/xgb_models/xgb_model_BindingDB_{dt_str}_bt_optimized_0.json",
        "BioSNAP": f"model/xgb_models/xgb_model_BIOSNAP_full_data_{dt_str}_bt_optimized_0.json",
        "DAVIS": f"model/xgb_models/xgb_model_DAVIS_{dt_str}_bt_optimized_0.json",
        "BarlowDTI XXL": f"model/xgb_models/{dt_str}_barlowdti_xxl_model.json",
    }

    # Models are instantiated on every call, one per selected name.
    model_ensemble = {}
    for name in selected:
        print(f"Loading model {name}")
        model_ensemble[name] = DTIModel(
            bt_model_path=f"model/stash/{dt_str}",
            gbm_model_path=gbm_model_paths[name],
        )

    # One row of scores per model, in selection order.
    # NOTE(review): assumes predict() returns one score per SMILES — confirm.
    per_model = [
        model.predict(smiles_list, sequence)
        for model in model_ensemble.values()
    ]
    # Transpose to one row per molecule, one column per model.
    predictions = np.array(per_model).transpose().tolist()

    df = pd.DataFrame(predictions, index=smiles_list, columns=selected).reset_index()
    df.columns = ["SMILES"] + selected

    fig = make_spider_plot(predictions, selected, smiles_list)
    return fig, df
# Model labels offered as choices in the UI's checkbox group, in display order.
dataset_names = ["BarlowDTI XXL", "BindingDB", "BioSNAP", "DAVIS"]
# Page heading passed to gr.Interface(title=...); contains inline HTML that
# renders the model name in small caps.
title = "Predict Drug-Target Interactions with <span style='font-variant:small-caps;'>BarlowDTI</span>"
# Introductory text passed to gr.Interface(description=...); mixes Markdown
# emphasis with inline HTML. The string content is runtime text — edit with care.
description = """
Enter the amino acid sequence and SMILES to get interaction predictions visualized as a spider graph and in a table.
The values can be interpreted as the probability of interaction between the drug and the target (0 = no interaction, 1 = interaction).
Thank you for using <span style='font-variant:small-caps;'>BarlowDTI</span>!
*Note: Thanks to ZeroGPU, you can run this model on a GPU for free.*
"""
# Footer text passed to gr.Interface(article=...): background on the model
# ensemble with dataset links, followed by a fenced BibTeX citation entry.
# This is a plain (non-f) string, so the BibTeX braces need no escaping.
article = """
This interface lets the scientific community use <span style='font-variant:small-caps;'>BarlowDTI</span><sub>XXL</sub> to predict drug-target interactions.
The model ensemble consists of four models trained on different datasets: our own curated and refined dataset based on
[Golts et. al](https://doi.org/10.48550/arXiv.2401.17174)
in combination with
[BindingDB](https://doi.org/10.1093/nar/gkl999),
[BioSNAP](https://snap.stanford.edu/index.html), and
[DAVIS](https://doi.org/10.1038/nbt.1990).
If you use our model in your research, please cite our paper:
```
@article{schuh2025barlow,
title = {Barlow {{Twins}} Deep Neural Network for Advanced {{1D}} Drug--Target Interaction Prediction},
author = {Schuh, Maximilian G. and Boldini, Davide and Bohne, Annkathrin I. and Sieber, Stephan A.},
year = {2025},
month = dec,
journal = {Journal of Cheminformatics},
volume = {17},
number = {1},
pages = {1--14},
publisher = {BioMed Central},
issn = {1758-2946},
doi = {10.1186/s13321-025-00952-2},
urldate = {2025-02-06},
copyright = {2025 The Author(s)},
langid = {english},
}
```
"""
# Visual theme for the interface: Gradio's Base theme with a violet primary
# hue and IBM Plex Sans (Google Font) ahead of generic sans-serif fallbacks.
theme = gr.themes.Base(
    primary_hue="violet",
    font=[
        gr.themes.GoogleFont("IBM Plex Sans"),
        "ui-sans-serif",
        "system-ui",
        "sans-serif",
    ],
)
# Assemble the Gradio UI. The three inputs map positionally onto
# predict_and_plot(amino_acid_sequence, smiles_input, datasets), and its
# (figure, dataframe) return value fills the two outputs in order.
iface = gr.Interface(
    fn=predict_and_plot,
    inputs=[
        gr.Textbox(
            label="Protein Sequence",
            info="Just one sequence is allowed. Remove FASTA syntax (e.g. > ABC).",
            placeholder="MRSWSTVMLAVLATAATVFGHDADPEMKMTTPQIIMRWGYPAMIYDVTTEDGYILELHRI",
        ),
        gr.Textbox(
            label="Molecule SMILES",
            info="One per line, multiple allowed.",
            placeholder="C1CSSC1CCCCC(=O)O\nCC1=CC(=C(C=C1)C(=O)O)O",
        ),
        gr.CheckboxGroup(
            choices=dataset_names,
            label="Select Models for Prediction",
            value="BarlowDTI XXL",
        ),
    ],
    outputs=[
        gr.Plot(label="Predictions Visualization"),
        gr.DataFrame(label="Predictions DataFrame"),
    ],
    title=title,
    description=description,
    article=article,
    theme=theme,
)

# Start serving the app as soon as the module is executed.
iface.launch()
|