Spaces:
Runtime error
Runtime error
| import numpy as np | |
| from rdkit import Chem | |
| from rdkit.Chem import AllChem | |
| from src import const | |
| from src.molecule_builder import get_bond_order | |
| from scipy.stats import wasserstein_distance | |
| from pdb import set_trace | |
def is_valid(mol):
    """Report whether ``mol`` survives RDKit sanitization.

    ``Chem.SanitizeMol`` works on its argument in place, so calling this
    function modifies ``mol``.

    Returns:
        False if sanitization raises ``ValueError``, True otherwise.
    """
    try:
        Chem.SanitizeMol(mol)
    except ValueError:
        outcome = False
    else:
        outcome = True
    return outcome
def is_connected(mol):
    """Return whether ``mol`` consists of exactly one connected fragment.

    ``Chem.GetMolFrags(..., asMols=True)`` builds (and by default sanitizes)
    a molecule per fragment, so it can raise any RDKit sanitization error,
    not just a valence error. Such molecules are reported as not connected
    instead of propagating the exception.

    Returns:
        True if there is exactly one fragment, False otherwise (including
        when fragment extraction fails).
    """
    try:
        fragments = Chem.GetMolFrags(mol, asMols=True)
    except (Chem.rdchem.MolSanitizeException, ValueError):
        # MolSanitizeException is the common base of AtomValenceException,
        # KekulizeException, etc.; ValueError mirrors is_valid's handling.
        return False
    return len(fragments) == 1
def get_valid_molecules(molecules):
    """Keep, in input order, the molecules for which ``is_valid`` is True.

    Every molecule is handed to ``is_valid``, which sanitizes it in place.
    """
    return [candidate for candidate in molecules if is_valid(candidate)]
def get_connected_molecules(molecules):
    """Keep, in input order, the molecules for which ``is_connected`` is True."""
    return [candidate for candidate in molecules if is_connected(candidate)]
def get_unique_smiles(valid_molecules):
    """Return the distinct ``Chem.MolToSmiles`` strings of the molecules.

    Duplicates collapse; the order of the returned list is arbitrary.
    """
    distinct = {Chem.MolToSmiles(molecule) for molecule in valid_molecules}
    return list(distinct)
def get_novel_smiles(unique_true_smiles, unique_pred_smiles):
    """Return predicted SMILES that do not occur among the reference SMILES.

    Both arguments may be any iterable of strings. Duplicates collapse and
    the order of the returned list is arbitrary.
    """
    predicted = set(unique_pred_smiles)
    reference = set(unique_true_smiles)
    return list(predicted - reference)
def compute_energy(mol):
    """Return the MMFF force-field energy of conformer 0 of ``mol``.

    There is no error handling here. If RDKit cannot set up MMFF for the
    molecule, an exception is expected to propagate. The exact type depends
    on the failure mode and is not checked here.
    """
    properties = AllChem.MMFFGetMoleculeProperties(mol)
    force_field = AllChem.MMFFGetMoleculeForceField(mol, properties, confId=0)
    return force_field.CalcEnergy()
| def wasserstein_distance_between_energies(true_molecules, pred_molecules): | |
| true_energy_dist = [] | |
| for mol in true_molecules: | |
| try: | |
| energy = compute_energy(mol) | |
| true_energy_dist.append(energy) | |
| except: | |
| continue | |
| pred_energy_dist = [] | |
| for mol in pred_molecules: | |
| try: | |
| energy = compute_energy(mol) | |
| pred_energy_dist.append(energy) | |
| except: | |
| continue | |
| if len(true_energy_dist) > 0 and len(pred_energy_dist) > 0: | |
| return wasserstein_distance(true_energy_dist, pred_energy_dist) | |
| else: | |
| return 0 | |
def compute_metrics(pred_molecules, true_molecules):
    """Compute generation-quality metrics for predicted molecules.

    The returned dict has the same keys on every code path:

    * ``validity``: fraction of ``pred_molecules`` that pass ``is_valid``.
    * ``validity_and_connectivity``: fraction of ``pred_molecules`` that are
      valid and form a single connected fragment.
    * ``uniqueness``: number of distinct SMILES among the valid-and-connected
      predictions, divided by the number of such predictions.
    * ``novelty``: fraction of those distinct predicted SMILES that do not
      occur among the valid-and-connected reference molecules.
    * ``energies``: Wasserstein distance between the energies of the
      valid-and-connected reference and predicted molecules.

    Ratios whose denominator would be zero are reported as ``0``. An empty
    ``pred_molecules`` yields ``0`` for every metric.

    Note: ``is_valid`` runs ``Chem.SanitizeMol`` on every input molecule,
    which modifies both input collections' molecules in place.

    Args:
        pred_molecules: sized collection of generated molecules.
        true_molecules: collection of reference molecules.

    Returns:
        dict mapping metric name to its value.
    """
    if len(pred_molecules) == 0:
        # Keep this key set identical to the main return below so callers
        # always see the same schema.
        return {
            'validity': 0,
            'validity_and_connectivity': 0,
            'uniqueness': 0,
            'novelty': 0,
            'energies': 0,
        }

    # Passing rdkit.Chem.Sanitize filter
    true_valid = get_valid_molecules(true_molecules)
    pred_valid = get_valid_molecules(pred_molecules)
    validity = len(pred_valid) / len(pred_molecules)

    # Checking if molecule consists of a single connected part
    true_valid_and_connected = get_connected_molecules(true_valid)
    pred_valid_and_connected = get_connected_molecules(pred_valid)
    validity_and_connectivity = len(pred_valid_and_connected) / len(pred_molecules)

    # Unique molecules
    true_unique = get_unique_smiles(true_valid_and_connected)
    pred_unique = get_unique_smiles(pred_valid_and_connected)
    if pred_valid_and_connected:
        uniqueness = len(pred_unique) / len(pred_valid_and_connected)
    else:
        uniqueness = 0

    # Novel molecules
    pred_novel = get_novel_smiles(true_unique, pred_unique)
    novelty = len(pred_novel) / len(pred_unique) if pred_unique else 0

    # Difference between energy distributions
    energies = wasserstein_distance_between_energies(
        true_valid_and_connected, pred_valid_and_connected
    )

    return {
        'validity': validity,
        'validity_and_connectivity': validity_and_connectivity,
        'uniqueness': uniqueness,
        'novelty': novelty,
        'energies': energies,
    }
| # def check_stability(positions, atom_types): | |
| # assert len(positions.shape) == 2 | |
| # assert positions.shape[1] == 3 | |
| # x = positions[:, 0] | |
| # y = positions[:, 1] | |
| # z = positions[:, 2] | |
| # | |
| # nr_bonds = np.zeros(len(x), dtype='int') | |
| # for i in range(len(x)): | |
| # for j in range(i + 1, len(x)): | |
| # p1 = np.array([x[i], y[i], z[i]]) | |
| # p2 = np.array([x[j], y[j], z[j]]) | |
| # dist = np.sqrt(np.sum((p1 - p2) ** 2)) | |
| # atom1, atom2 = const.IDX2ATOM[atom_types[i].item()], const.IDX2ATOM[atom_types[j].item()] | |
| # order = get_bond_order(atom1, atom2, dist) | |
| # nr_bonds[i] += order | |
| # nr_bonds[j] += order | |
| # nr_stable_bonds = 0 | |
| # for atom_type_i, nr_bonds_i in zip(atom_types, nr_bonds): | |
| # possible_bonds = const.ALLOWED_BONDS[const.IDX2ATOM[atom_type_i.item()]] | |
| # if type(possible_bonds) == int: | |
| # is_stable = possible_bonds == nr_bonds_i | |
| # else: | |
| # is_stable = nr_bonds_i in possible_bonds | |
| # nr_stable_bonds += int(is_stable) | |
| # | |
| # molecule_stable = nr_stable_bonds == len(x) | |
| # return molecule_stable, nr_stable_bonds, len(x) | |
| # | |
| # | |
| # def count_stable_molecules(one_hot, x, node_mask): | |
| # stable_molecules = 0 | |
| # for i in range(len(one_hot)): | |
| # mol_size = node_mask[i].sum() | |
| # atom_types = one_hot[i][:mol_size, :].argmax(dim=1).detach().cpu() | |
| # positions = x[i][:mol_size, :].detach().cpu() | |
| # stable, _, _ = check_stability(positions, atom_types) | |
| # stable_molecules += int(stable) | |
| # | |
| # return stable_molecules | |