"""Script entry point: train the model and conditioner, then sample from them.

When run as ``__main__`` this loads a labelled SMILES subset, converts each
SMILES string to a graph via ``smiles_to_graph``, trains ``EGNNDiffusionModel``
together with ``OlfactoryConditioner``, and finally samples one SMILES string
(or a batch, see ``SHOULD_BATCH``) for a fixed label vector, printing the
result of ``validate_molecule`` for every sample.
"""

import torch

from .model_arch import EGNNDiffusionModel, OlfactoryConditioner
from .utils import load_goodscents_subset, validate_molecule, sample, sample_batch, smiles_to_graph
from .train import train


if __name__ == '__main__':
    # Flip to True to draw samples with `sample_batch` instead of `sample`.
    SHOULD_BATCH: bool = False

    # --- Data -------------------------------------------------------------
    smiles_list, label_map, label_names = load_goodscents_subset(
        filepath="../data/leffingwell-goodscent-merge-dataset.csv",
        index=500,
        shuffle=True,
    )
    num_labels = len(label_names)

    dataset = []
    for smi in smiles_list:
        graph = smiles_to_graph(smi)
        if not graph:
            # Falsy result from the converter: leave this SMILES out.
            continue
        # Attach the labels for this SMILES as the graph's target tensor.
        graph.y = torch.tensor(label_map[smi])
        dataset.append(graph)

    # --- Model + training -------------------------------------------------
    model = EGNNDiffusionModel(node_dim=1, embed_dim=8)
    conditioner = OlfactoryConditioner(num_labels=num_labels, embed_dim=8)
    train(model, conditioner, dataset, epochs=500)

    # --- Conditioning vector ----------------------------------------------
    # Start from all zeros, then write the requested value for each label
    # that is actually present in the loaded label set.
    test_label_vec = torch.zeros(num_labels)
    for label, value in (("floral", 0), ("fruity", 1), ("musky", 0)):
        if label in label_names:
            test_label_vec[label_names.index(label)] = value

    # --- Sampling ---------------------------------------------------------
    model.eval()
    conditioner.eval()

    if SHOULD_BATCH:
        generated = sample_batch(model, conditioner, label_vec=test_label_vec)
    else:
        # Wrap the single sample so both modes share the reporting loop below.
        generated = [sample(model, conditioner, label_vec=test_label_vec)]

    for new_smiles in generated:
        print(new_smiles)
        valid, props = validate_molecule(new_smiles)
        print(f"Generated SMILES: {new_smiles}\nValid: {valid}, Properties: {props}")