Spaces:
Runtime error
Runtime error
| import os | |
| import numpy as np | |
| import pandas as pd | |
| import pickle | |
| import torch | |
| from rdkit import Chem | |
| from torch.utils.data import Dataset, DataLoader | |
| from tqdm import tqdm | |
| from src import const | |
| from pdb import set_trace | |
def read_sdf(sdf_path):
    """Lazily yield every molecule stored in the SDF file at ``sdf_path``.

    The file is opened with ``sanitize=False``, so RDKit sanitization is
    skipped and molecules are yielded exactly as the supplier produces them.
    The supplier is closed when the generator is exhausted or closed.
    """
    with Chem.SDMolSupplier(sdf_path, sanitize=False) as sdf_supplier:
        yield from sdf_supplier
def get_one_hot(atom, atoms_dict):
    """Return a one-hot vector for ``atom``.

    :param atom: key to look up in ``atoms_dict``
    :param atoms_dict: mapping from atom key to its integer position
    :return: float numpy array of length ``len(atoms_dict)`` that is 1 at
        ``atoms_dict[atom]`` and 0 elsewhere
    :raises KeyError: if ``atom`` is not a key of ``atoms_dict``
    """
    encoding = np.zeros(len(atoms_dict))
    hot_index = atoms_dict[atom]
    encoding[hot_index] = 1
    return encoding
def parse_molecule(mol, is_geom):
    """Extract coordinates, one-hot atom types and charges from a molecule.

    :param mol: molecule object exposing ``GetAtoms()`` and ``GetConformer()``
        (presumably an RDKit ``Mol`` with a conformer -- confirm with callers)
    :param is_geom: when truthy the ``GEOM_*`` lookup tables from ``const``
        are used, otherwise the default ``ATOM2IDX`` / ``CHARGES`` tables
    :return: tuple ``(positions, one_hot, charges)`` where ``positions`` is
        whatever ``GetConformer().GetPositions()`` returns and the other two
        are numpy arrays with one entry per atom
    :raises KeyError: if an atom symbol is missing from the lookup tables
    """
    if is_geom:
        atom2idx, charges_dict = const.GEOM_ATOM2IDX, const.GEOM_CHARGES
    else:
        atom2idx, charges_dict = const.ATOM2IDX, const.CHARGES

    symbols = [atom.GetSymbol() for atom in mol.GetAtoms()]
    one_hot = np.array([get_one_hot(symbol, atom2idx) for symbol in symbols])
    charges = np.array([charges_dict[symbol] for symbol in symbols])

    positions = mol.GetConformer().GetPositions()
    return positions, one_hot, charges
class ZincDataset(Dataset):
    """In-memory dataset of fragment/linker samples with an on-disk cache.

    On construction the cache file ``<data_path>/<prefix>.pt`` is loaded if it
    exists; otherwise the raw files are preprocessed and the result is saved
    to that path for the next run.
    """

    def __init__(self, data_path, prefix, device):
        """Load the cached dataset or build it from the raw files.

        :param data_path: directory containing the raw and cached files
        :param prefix: file-name prefix identifying the split
        :param device: device the tensors are loaded to / created on
        """
        dataset_path = os.path.join(data_path, f'{prefix}.pt')
        if os.path.exists(dataset_path):
            self.data = torch.load(dataset_path, map_location=device)
        else:
            print(f'Preprocessing dataset with prefix {prefix}')
            self.data = ZincDataset.preprocess(data_path, prefix, device)
            torch.save(self.data, dataset_path)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, item):
        return self.data[item]

    # Declared static: the function takes no ``self``, so without the
    # decorator calling it on an instance would shift every argument by one.
    @staticmethod
    def preprocess(data_path, prefix, device):
        """Build the list of samples from the raw table and SDF files.

        Reads ``<prefix>_table.csv``, ``<prefix>_frag.sdf`` and
        ``<prefix>_link.sdf`` from ``data_path`` and walks them in lockstep
        (row i of the table goes with molecule i of each SDF file).

        :param data_path: directory containing the raw files
        :param prefix: file-name prefix; ``'geom'`` or ``'MOAD'`` in the
            prefix selects the GEOM lookup tables, ``'multifrag'`` selects the
            multi-anchor table format (``anchors`` column, ``-``-separated)
        :param device: device on which the tensors are created
        :return: list of dicts, one per sample
        """
        data = []
        table_path = os.path.join(data_path, f'{prefix}_table.csv')
        fragments_path = os.path.join(data_path, f'{prefix}_frag.sdf')
        linkers_path = os.path.join(data_path, f'{prefix}_link.sdf')

        is_geom = ('geom' in prefix) or ('MOAD' in prefix)
        is_multifrag = 'multifrag' in prefix

        table = pd.read_csv(table_path)
        generator = tqdm(
            zip(table.iterrows(), read_sdf(fragments_path), read_sdf(linkers_path)),
            total=len(table),
        )
        for (_, row), fragments, linker in generator:
            uuid = row['uuid']
            name = row['molecule']
            frag_pos, frag_one_hot, frag_charges = parse_molecule(fragments, is_geom=is_geom)
            link_pos, link_one_hot, link_charges = parse_molecule(linker, is_geom=is_geom)

            # Fragment atoms always come first, linker atoms last.
            positions = np.concatenate([frag_pos, link_pos], axis=0)
            one_hot = np.concatenate([frag_one_hot, link_one_hot], axis=0)
            charges = np.concatenate([frag_charges, link_charges], axis=0)

            # Anchor indices from the table address the concatenated atom order.
            anchors = np.zeros_like(charges)
            if is_multifrag:
                for anchor_idx in map(int, row['anchors'].split('-')):
                    anchors[anchor_idx] = 1
            else:
                anchors[row['anchor_1']] = 1
                anchors[row['anchor_2']] = 1

            fragment_mask = np.concatenate([np.ones_like(frag_charges), np.zeros_like(link_charges)])
            linker_mask = np.concatenate([np.zeros_like(frag_charges), np.ones_like(link_charges)])

            data.append({
                'uuid': uuid,
                'name': name,
                'positions': torch.tensor(positions, dtype=const.TORCH_FLOAT, device=device),
                'one_hot': torch.tensor(one_hot, dtype=const.TORCH_FLOAT, device=device),
                'charges': torch.tensor(charges, dtype=const.TORCH_FLOAT, device=device),
                'anchors': torch.tensor(anchors, dtype=const.TORCH_FLOAT, device=device),
                'fragment_mask': torch.tensor(fragment_mask, dtype=const.TORCH_FLOAT, device=device),
                'linker_mask': torch.tensor(linker_mask, dtype=const.TORCH_FLOAT, device=device),
                'num_atoms': len(positions),
            })

        return data
class MOADDataset(Dataset):
    """In-memory dataset of fragment/pocket/linker samples with a disk cache.

    ``prefix`` has the form ``'<name>.<pocket_mode>'``. The cache file is
    ``<data_path>/<name>_<pocket_mode>.pt``; it is loaded if present,
    otherwise the raw files are preprocessed and the result is saved there.
    """

    def __init__(self, data_path, prefix, device):
        """Load the cached dataset or build it from the raw files.

        :param data_path: directory containing the raw and cached files
        :param prefix: ``'<name>.<pocket_mode>'``; must contain exactly one dot
        :param device: device the tensors are loaded to / created on
        :raises ValueError: if ``prefix`` does not split into exactly two parts
        """
        prefix, pocket_mode = prefix.split('.')

        dataset_path = os.path.join(data_path, f'{prefix}_{pocket_mode}.pt')
        if os.path.exists(dataset_path):
            self.data = torch.load(dataset_path, map_location=device)
        else:
            print(f'Preprocessing dataset with prefix {prefix}')
            self.data = MOADDataset.preprocess(data_path, prefix, pocket_mode, device)
            torch.save(self.data, dataset_path)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, item):
        return self.data[item]

    # Declared static: the function takes no ``self``, so without the
    # decorator calling it on an instance would shift every argument by one.
    @staticmethod
    def preprocess(data_path, prefix, pocket_mode, device):
        """Build the list of samples from the raw table, SDF and pocket files.

        Reads ``<prefix>_table.csv``, ``<prefix>_frag.sdf``,
        ``<prefix>_link.sdf`` and ``<prefix>_pockets.pkl`` from ``data_path``
        and walks them in lockstep (entry i of each source forms sample i).

        :param data_path: directory containing the raw files
        :param prefix: file-name prefix; ``'multifrag'`` in the prefix selects
            the multi-anchor table format (``anchors`` column, ``-``-separated)
        :param pocket_mode: selects the ``'<pocket_mode>_coord'`` and
            ``'<pocket_mode>_types'`` entries of every pocket record
        :param device: device on which the tensors are created
        :return: list of dicts, one per sample
        """
        data = []
        table_path = os.path.join(data_path, f'{prefix}_table.csv')
        fragments_path = os.path.join(data_path, f'{prefix}_frag.sdf')
        linkers_path = os.path.join(data_path, f'{prefix}_link.sdf')
        pockets_path = os.path.join(data_path, f'{prefix}_pockets.pkl')

        is_geom = True
        is_multifrag = 'multifrag' in prefix

        # NOTE: unpickling can execute arbitrary code; only load pocket files
        # that come from a trusted source.
        with open(pockets_path, 'rb') as f:
            pockets = pickle.load(f)

        table = pd.read_csv(table_path)
        generator = tqdm(
            zip(table.iterrows(), read_sdf(fragments_path), read_sdf(linkers_path), pockets),
            total=len(table),
        )
        for (_, row), fragments, linker, pocket_data in generator:
            uuid = row['uuid']
            name = row['molecule']
            frag_pos, frag_one_hot, frag_charges = parse_molecule(fragments, is_geom=is_geom)
            link_pos, link_one_hot, link_charges = parse_molecule(linker, is_geom=is_geom)

            # Parsing pocket data
            pocket_pos = pocket_data[f'{pocket_mode}_coord']
            pocket_one_hot = []
            pocket_charges = []
            for atom_type in pocket_data[f'{pocket_mode}_types']:
                pocket_one_hot.append(get_one_hot(atom_type, const.GEOM_ATOM2IDX))
                pocket_charges.append(const.GEOM_CHARGES[atom_type])
            pocket_one_hot = np.array(pocket_one_hot)
            pocket_charges = np.array(pocket_charges)

            # Atom order within a sample: fragments, then pocket, then linker.
            positions = np.concatenate([frag_pos, pocket_pos, link_pos], axis=0)
            one_hot = np.concatenate([frag_one_hot, pocket_one_hot, link_one_hot], axis=0)
            charges = np.concatenate([frag_charges, pocket_charges, link_charges], axis=0)

            # Anchor indices from the table address the concatenated atom order.
            anchors = np.zeros_like(charges)
            if is_multifrag:
                for anchor_idx in map(int, row['anchors'].split('-')):
                    anchors[anchor_idx] = 1
            else:
                anchors[row['anchor_1']] = 1
                anchors[row['anchor_2']] = 1

            fragment_only_mask = np.concatenate([
                np.ones_like(frag_charges),
                np.zeros_like(pocket_charges),
                np.zeros_like(link_charges),
            ])
            pocket_mask = np.concatenate([
                np.zeros_like(frag_charges),
                np.ones_like(pocket_charges),
                np.zeros_like(link_charges),
            ])
            linker_mask = np.concatenate([
                np.zeros_like(frag_charges),
                np.zeros_like(pocket_charges),
                np.ones_like(link_charges),
            ])
            # Here 'fragment_mask' marks fragments AND pocket, i.e. every atom
            # that is not part of the linker.
            fragment_mask = np.concatenate([
                np.ones_like(frag_charges),
                np.ones_like(pocket_charges),
                np.zeros_like(link_charges),
            ])

            data.append({
                'uuid': uuid,
                'name': name,
                'positions': torch.tensor(positions, dtype=const.TORCH_FLOAT, device=device),
                'one_hot': torch.tensor(one_hot, dtype=const.TORCH_FLOAT, device=device),
                'charges': torch.tensor(charges, dtype=const.TORCH_FLOAT, device=device),
                'anchors': torch.tensor(anchors, dtype=const.TORCH_FLOAT, device=device),
                'fragment_only_mask': torch.tensor(fragment_only_mask, dtype=const.TORCH_FLOAT, device=device),
                'pocket_mask': torch.tensor(pocket_mask, dtype=const.TORCH_FLOAT, device=device),
                'fragment_mask': torch.tensor(fragment_mask, dtype=const.TORCH_FLOAT, device=device),
                'linker_mask': torch.tensor(linker_mask, dtype=const.TORCH_FLOAT, device=device),
                'num_atoms': len(positions),
            })

        return data
def create_edges(positions, fragment_mask_only, linker_mask_only, cutoff=6):
    """Return the directed edge list of a sample as ``[rows, cols]``.

    Two distinct nodes are connected when both belong to the ligand (fragment
    or linker) or when their Euclidean distance is at most ``cutoff``.
    Self-loops are never included.

    :param positions: numpy array with one coordinate row per node
    :param fragment_mask_only: per-node array, non-zero for fragment atoms
    :param linker_mask_only: per-node array, non-zero for linker atoms
    :param cutoff: distance threshold in the units of ``positions``
        (defaults to the previously hard-coded value 6)
    :return: list ``[rows, cols]`` of numpy index arrays in row-major order
    """
    ligand_mask = fragment_mask_only.astype(bool) | linker_mask_only.astype(bool)
    ligand_adj = ligand_mask[:, None] & ligand_mask[None, :]

    distances = np.linalg.norm(positions[:, None, :] - positions[None, :, :], axis=-1)
    proximity_adj = distances <= cutoff

    full_adj = ligand_adj | proximity_adj
    # Drop the diagonal so that no node is connected to itself.
    full_adj &= ~np.eye(len(positions), dtype=bool)

    curr_rows, curr_cols = np.where(full_adj)
    return [curr_rows, curr_cols]
def collate(batch):
    """Merge a list of sample dicts into one padded batch dict.

    Samples with too many atoms are dropped first (more than 1000 atoms when
    the samples carry a ``'pocket_mask'``, more than 50 otherwise). Keys listed
    in ``const.DATA_LIST_ATTRS`` stay plain lists, keys listed in
    ``const.DATA_ATTRS_TO_PAD`` are zero-padded and stacked batch-first. The
    result additionally gets ``'atom_mask'`` and ``'edge_mask'``.

    :param batch: non-empty list of sample dicts
    :return: dict of batched values
    :raises Exception: if a sample contains a key from neither list
    """
    out = {}

    # Filter out big molecules
    max_atoms = 1000 if 'pocket_mask' in batch[0] else 50
    batch = [sample for sample in batch if sample['num_atoms'] <= max_atoms]

    for sample in batch:
        for key, value in sample.items():
            out.setdefault(key, []).append(value)

    for key, values in out.items():
        if key in const.DATA_LIST_ATTRS:
            continue
        if key not in const.DATA_ATTRS_TO_PAD:
            raise Exception(f'Unknown batch key: {key}')
        out[key] = torch.nn.utils.rnn.pad_sequence(values, batch_first=True, padding_value=0)

    is_fragment = out['fragment_mask'].bool()
    is_linker = out['linker_mask'].bool()
    atom_mask = (is_fragment | is_linker).to(const.TORCH_INT)
    out['atom_mask'] = atom_mask[:, :, None]

    batch_size, n_nodes = atom_mask.size()

    if 'pocket_mask' in batch[0]:
        # In case of MOAD edge_mask is batch_idx
        per_sample_idx = [
            torch.ones(n_nodes, dtype=const.TORCH_INT) * sample_idx
            for sample_idx in range(batch_size)
        ]
        out['edge_mask'] = torch.cat(per_sample_idx).to(atom_mask.device)
    else:
        pair_mask = atom_mask[:, None, :] * atom_mask[:, :, None]
        # NOTE(review): `~` is a bitwise NOT. If const.TORCH_INT is an integer
        # dtype (not bool) this yields -2 on the diagonal and -1 elsewhere
        # rather than a 0/1 mask -- confirm this is intended before changing
        # it, trained models may depend on the current values.
        identity = torch.eye(pair_mask.size(1), dtype=const.TORCH_INT, device=atom_mask.device)
        diag_mask = ~identity.unsqueeze(0)
        pair_mask *= diag_mask
        out['edge_mask'] = pair_mask.view(batch_size * n_nodes * n_nodes, 1)

    for key in const.DATA_ATTRS_TO_ADD_LAST_DIM:
        if key in out:
            out[key] = out[key][:, :, None]

    return out
def collate_with_fragment_edges(batch):
    """Merge sample dicts into a padded batch with a dense edge index.

    Samples with more than 50 atoms are dropped. Keys listed in
    ``const.DATA_LIST_ATTRS`` stay plain lists, keys listed in
    ``const.DATA_ATTRS_TO_PAD`` are zero-padded and stacked batch-first. The
    result additionally gets:

    * ``'edge_mask'``: fragment-pair mask, flattened to
      ``(batch_size * n_nodes * n_nodes, 1)``
    * ``'edges'``: ``[rows, cols]`` long tensors listing every ordered node
      pair (including i == j) of every sample, offset by
      ``sample_index * n_nodes``
    * ``'atom_mask'``: union of fragment and linker masks

    :param batch: list of sample dicts
    :return: dict of batched values
    :raises Exception: if a sample contains a key from neither list
    """
    out = {}

    # Filter out big molecules
    batch = [data for data in batch if data['num_atoms'] <= 50]

    for data in batch:
        for key, value in data.items():
            out.setdefault(key, []).append(value)

    for key, value in out.items():
        if key in const.DATA_LIST_ATTRS:
            continue
        if key in const.DATA_ATTRS_TO_PAD:
            out[key] = torch.nn.utils.rnn.pad_sequence(value, batch_first=True, padding_value=0)
            continue
        raise Exception(f'Unknown batch key: {key}')

    frag_mask = out['fragment_mask']
    edge_mask = frag_mask[:, None, :] * frag_mask[:, :, None]
    # NOTE(review): `~` is a bitwise NOT. If const.TORCH_INT is an integer
    # dtype (not bool) this yields -2 on the diagonal and -1 elsewhere rather
    # than a 0/1 mask -- confirm this is intended before changing it, trained
    # models may depend on the current values.
    diag_mask = ~torch.eye(edge_mask.size(1), dtype=const.TORCH_INT, device=frag_mask.device).unsqueeze(0)
    edge_mask *= diag_mask

    batch_size, n_nodes = frag_mask.size()
    out['edge_mask'] = edge_mask.view(batch_size * n_nodes * n_nodes, 1)

    # Building edges: every ordered pair (i, j) inside each sample, with the
    # second index varying fastest, shifted by the sample's node offset.
    node_ids = torch.arange(n_nodes, dtype=torch.long, device=frag_mask.device)
    offsets = torch.arange(batch_size, dtype=torch.long, device=frag_mask.device).unsqueeze(1) * n_nodes
    rows = (node_ids.repeat_interleave(n_nodes).unsqueeze(0) + offsets).reshape(-1)
    cols = (node_ids.repeat(n_nodes).unsqueeze(0) + offsets).reshape(-1)
    out['edges'] = [rows, cols]

    atom_mask = (out['fragment_mask'].bool() | out['linker_mask'].bool()).to(const.TORCH_INT)
    out['atom_mask'] = atom_mask[:, :, None]

    for key in const.DATA_ATTRS_TO_ADD_LAST_DIM:
        if key in out.keys():
            out[key] = out[key][:, :, None]

    return out
def get_dataloader(dataset, batch_size, collate_fn=collate, shuffle=False):
    """Wrap ``dataset`` in a ``DataLoader``.

    :param dataset: dataset to iterate over
    :param batch_size: number of samples per batch
    :param collate_fn: function merging a list of samples into a batch
    :param shuffle: whether the loader reshuffles the data
    :return: the configured ``DataLoader``
    """
    return DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        collate_fn=collate_fn,
    )
def create_template(tensor, fragment_size, linker_size, fill=0):
    """Keep the first ``fragment_size`` rows and append constant rows.

    :param tensor: 2-D tensor (``tensor.shape[1]`` is read)
    :param fragment_size: number of leading rows copied from ``tensor``
    :param linker_size: number of rows appended after them
    :param fill: value written into every appended entry
    :return: tensor with ``fragment_size + linker_size`` rows
    """
    kept_rows = tensor[:fragment_size]
    filler_rows = fill * torch.ones(
        linker_size,
        tensor.shape[1],
        dtype=kept_rows.dtype,
        device=kept_rows.device,
    )
    return torch.cat((kept_rows, filler_rows), dim=0)
def create_templates_for_linker_generation(data, linker_sizes):
    """
    Takes a collated data batch and one new linker size per sample and returns
    a new batch in which fragment-related data is unchanged while
    linker-related data is replaced with templates of the requested sizes.

    For every sample the first ``fragment_size`` rows (the sum of its
    ``'fragment_mask'``) of each padded attribute are kept and ``linker_size``
    rows are appended: ones for ``'linker_mask'``, zeros for everything else.
    Keys that are neither ``'num_atoms'`` nor listed in
    ``const.DATA_LIST_ATTRS`` / ``const.DATA_ATTRS_TO_PAD`` are dropped.
    The per-sample dicts are finally re-batched with ``collate``.
    """
    decoupled_data = []
    for sample_idx, linker_size in enumerate(linker_sizes):
        fragment_size = data['fragment_mask'][sample_idx].squeeze().sum().int()

        sample = {}
        for key, batched in data.items():
            if key == 'num_atoms':
                # New number of atoms (fragment_size + linker_size)
                sample[key] = fragment_size + linker_size
            elif key in const.DATA_LIST_ATTRS:
                # These attributes are written without modification
                sample[key] = batched[sample_idx]
            elif key in const.DATA_ATTRS_TO_PAD:
                # Fragment-related rows + constant rows for the new linker
                fill_value = 1 if key == 'linker_mask' else 0
                template = create_template(batched[sample_idx], fragment_size, linker_size, fill=fill_value)
                if key in const.DATA_ATTRS_TO_ADD_LAST_DIM:
                    # collate re-adds the trailing dimension
                    template = template.squeeze(-1)
                sample[key] = template

        decoupled_data.append(sample)

    return collate(decoupled_data)