feat: moler updates
Browse files- app.py +15 -4
- model_cards/article.md +8 -1
- model_cards/examples.csv +7 -5
- requirements.txt +1 -1
- utils.py +4 -2
app.py
CHANGED
|
@@ -17,7 +17,9 @@ TITLE = "MoLeR"
|
|
| 17 |
def run_inference(
|
| 18 |
algorithm_version: str,
|
| 19 |
scaffolds: str,
|
|
|
|
| 20 |
beam_size: int,
|
|
|
|
| 21 |
number_of_samples: int,
|
| 22 |
seed: int,
|
| 23 |
):
|
|
@@ -25,15 +27,18 @@ def run_inference(
|
|
| 25 |
algorithm_version=algorithm_version,
|
| 26 |
scaffolds=scaffolds,
|
| 27 |
beam_size=beam_size,
|
| 28 |
-
num_samples=
|
| 29 |
seed=seed,
|
| 30 |
num_workers=1,
|
|
|
|
|
|
|
| 31 |
)
|
| 32 |
model = MoLeR(configuration=config)
|
| 33 |
samples = list(model.sample(number_of_samples))
|
| 34 |
|
| 35 |
-
|
| 36 |
-
|
|
|
|
| 37 |
|
| 38 |
|
| 39 |
if __name__ == "__main__":
|
|
@@ -67,7 +72,13 @@ if __name__ == "__main__":
|
|
| 67 |
placeholder="CC(C#C)N(C)C(=O)NC1=CC=C(Cl)C=C1",
|
| 68 |
lines=1,
|
| 69 |
),
|
| 70 |
-
gr.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 71 |
gr.Slider(
|
| 72 |
minimum=1, maximum=50, value=10, label="Number of samples", step=1
|
| 73 |
),
|
|
|
|
| 17 |
def run_inference(
|
| 18 |
algorithm_version: str,
|
| 19 |
scaffolds: str,
|
| 20 |
+
seed_smiles: str,
|
| 21 |
beam_size: int,
|
| 22 |
+
sigma: float,
|
| 23 |
number_of_samples: int,
|
| 24 |
seed: int,
|
| 25 |
):
|
|
|
|
| 27 |
algorithm_version=algorithm_version,
|
| 28 |
scaffolds=scaffolds,
|
| 29 |
beam_size=beam_size,
|
| 30 |
+
num_samples=32,
|
| 31 |
seed=seed,
|
| 32 |
num_workers=1,
|
| 33 |
+
seed_smiles=seed_smiles,
|
| 34 |
+
sigma=sigma,
|
| 35 |
)
|
| 36 |
model = MoLeR(configuration=config)
|
| 37 |
samples = list(model.sample(number_of_samples))
|
| 38 |
|
| 39 |
+
scaffold_list = [] if scaffolds == "" else scaffolds.split(".")
|
| 40 |
+
seed_list = [] if seed_smiles == "" else seed_smiles.split(".")
|
| 41 |
+
return draw_grid_generate(seed_list, scaffold_list, samples)
|
| 42 |
|
| 43 |
|
| 44 |
if __name__ == "__main__":
|
|
|
|
| 72 |
placeholder="CC(C#C)N(C)C(=O)NC1=CC=C(Cl)C=C1",
|
| 73 |
lines=1,
|
| 74 |
),
|
| 75 |
+
gr.Textbox(
|
| 76 |
+
label="Seed SMILES",
|
| 77 |
+
placeholder="O=C1C2=CC=C(C3=CC=CC=C3)C=C=C2OC2=CC=CC=C12",
|
| 78 |
+
lines=1,
|
| 79 |
+
),
|
| 80 |
+
gr.Slider(minimum=1, maximum=5, value=1, step=1, label="Beams"),
|
| 81 |
+
gr.Slider(minimum=0.0, maximum=3.0, value=0.01, label="Sigma"),
|
| 82 |
gr.Slider(
|
| 83 |
minimum=1, maximum=50, value=10, label="Number of samples", step=1
|
| 84 |
),
|
model_cards/article.md
CHANGED
|
@@ -2,12 +2,19 @@
|
|
| 2 |
|
| 3 |
**Algorithm Version**: Which model checkpoint to use (trained on different datasets).
|
| 4 |
|
| 5 |
-
**Scaffolds**: One or multiple scaffolds
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
|
| 7 |
**Number of samples**: How many samples should be generated (between 1 and 50).
|
| 8 |
|
| 9 |
**Beam size**: Beam size used in beam search decoding (the higher the slower but better).
|
| 10 |
|
|
|
|
|
|
|
| 11 |
**Seed**: The random seed used for initialization.
|
| 12 |
|
| 13 |
|
|
|
|
| 2 |
|
| 3 |
**Algorithm Version**: Which model checkpoint to use (trained on different datasets).
|
| 4 |
|
| 5 |
+
**Scaffolds**: One or multiple scaffolds, provided as '.'-separated SMILES. If empty, no scaffolds are used. Note that this is a hard-constraint,
|
| 6 |
+
i.e., the scaffold will certainly be present in the generated molecule. If multiple scaffolds are given, they are paired with the seed SMILES
|
| 7 |
+
(if applicable) and every molecule will be guaranteed to contain exactly one scaffold.
|
| 8 |
+
|
| 9 |
+
**Seed SMILES**: One or multiple seed molecules, provided as '.'-separated SMILES. If empty, no scaffolds are used.
|
| 10 |
+
There's no guarantee for a seed SMILES (or a substructure of it) to be present in the generated molecule as it's merely used for decoder initialization.
|
| 11 |
|
| 12 |
**Number of samples**: How many samples should be generated (between 1 and 50).
|
| 13 |
|
| 14 |
**Beam size**: Beam size used in beam search decoding (the higher the slower but better).
|
| 15 |
|
| 16 |
+
**Sigma**: Variance of the Gaussian noise that is added to the latent code (before passing to the decoder).
|
| 17 |
+
|
| 18 |
**Seed**: The random seed used for initialization.
|
| 19 |
|
| 20 |
|
model_cards/examples.csv
CHANGED
|
@@ -1,5 +1,7 @@
|
|
| 1 |
-
v0
|
| 2 |
-
v0,CC(=O)NC1=NC2=CC(OCC3=CC=CN(CC4=CC=C(Cl)C=C4)C3=O)=CC=C2N1
|
| 3 |
-
v0,
|
| 4 |
-
|
| 5 |
-
|
|
|
|
|
|
|
|
|
| 1 |
+
v0,,,1,0.0,4,0
|
| 2 |
+
v0,CC(=O)NC1=NC2=CC(OCC3=CC=CN(CC4=CC=C(Cl)C=C4)C3=O)=CC=C2N1,,1,0.0,10,1
|
| 3 |
+
v0,CC(=O)NC1=NC2=CC(OCC3=CC=CN(CC4=CC=C(Cl)C=C4)C3=O)=CC=C2N1,,1,0.3,10,2
|
| 4 |
+
v0,,CC(=O)NC1=NC2=CC(OCC3=CC=CN(CC4=CC=C(Cl)C=C4)C3=O)=CC=C2N1,1,0.2,10,3
|
| 5 |
+
v0,,C12C=CC=NN1C(C#CC1=C(C)C=CC3C(NC4=CC(C(F)(F)F)=CC=C4)=NOC1=3)=CN=2.CCO,3,0.2,5,5
|
| 6 |
+
v0,,CC(=O)NC1=NC2=CC(OCC3=CC=CN(CC4=CC=C(Cl)C=C4)C3=O)=CC=C2N1,1,0.5,10,9
|
| 7 |
+
v0,CC(=O)NC1=NC2=CC(OCC3=CC=CN(CC4=CC=C(Cl)C=C4)C3=O)=CC=C2N1,c1ccccc1,1,0.2,10,10
|
requirements.txt
CHANGED
|
@@ -8,7 +8,7 @@ torch-sparse
|
|
| 8 |
torch-geometric
|
| 9 |
torchvision==0.13.1
|
| 10 |
torchaudio==0.12.1
|
| 11 |
-
gt4sd>=1.
|
| 12 |
molgx>=0.22.0a1
|
| 13 |
diffusers==0.6.0
|
| 14 |
molecule_generation
|
|
|
|
| 8 |
torch-geometric
|
| 9 |
torchvision==0.13.1
|
| 10 |
torchaudio==0.12.1
|
| 11 |
+
gt4sd>=1.1.12
|
| 12 |
molgx>=0.22.0a1
|
| 13 |
diffusers==0.6.0
|
| 14 |
molecule_generation
|
utils.py
CHANGED
|
@@ -15,8 +15,9 @@ logger.addHandler(logging.NullHandler())
|
|
| 15 |
|
| 16 |
def draw_grid_generate(
|
| 17 |
seeds: List[str],
|
|
|
|
| 18 |
samples: List[str],
|
| 19 |
-
n_cols: int =
|
| 20 |
size=(140, 200),
|
| 21 |
) -> str:
|
| 22 |
"""
|
|
@@ -34,8 +35,9 @@ def draw_grid_generate(
|
|
| 34 |
result = defaultdict(list)
|
| 35 |
result.update(
|
| 36 |
{
|
| 37 |
-
"SMILES": seeds + samples,
|
| 38 |
"Name": [f"Seed_{i}" for i in range(len(seeds))]
|
|
|
|
| 39 |
+ [f"Generated_{i}" for i in range(len(samples))],
|
| 40 |
},
|
| 41 |
)
|
|
|
|
| 15 |
|
| 16 |
def draw_grid_generate(
|
| 17 |
seeds: List[str],
|
| 18 |
+
scaffolds: List[str],
|
| 19 |
samples: List[str],
|
| 20 |
+
n_cols: int = 5,
|
| 21 |
size=(140, 200),
|
| 22 |
) -> str:
|
| 23 |
"""
|
|
|
|
| 35 |
result = defaultdict(list)
|
| 36 |
result.update(
|
| 37 |
{
|
| 38 |
+
"SMILES": seeds + scaffolds + samples,
|
| 39 |
"Name": [f"Seed_{i}" for i in range(len(seeds))]
|
| 40 |
+
+ [f"Scaffold_{i}" for i in range(len(scaffolds))]
|
| 41 |
+ [f"Generated_{i}" for i in range(len(samples))],
|
| 42 |
},
|
| 43 |
)
|