import pandas as pd
import shutil, os
import os.path as osp
import numpy as np
from tqdm import tqdm
import torch
from torch_geometric.data import Data
from torch.autograd import Variable
from rdkit import Chem
from data.features import (
    allowable_features,
    atom_to_feature_vector,
    bond_to_feature_vector,
    atom_feature_vector_to_dict,
    bond_feature_vector_to_dict,
)
from utils.data_util import one_hot_vector_sm, one_hot_vector_am, get_atom_feature_dims


def load_dataset(
    cross_val, binary_task, target, args, use_prot=False, advs=False, test=False
):
    """
    Load the data and return, for each split, its dataframe and its processed set of molecular graphs.

    Args:
        cross_val (int): Data partition being used [1-4].
        binary_task (boolean): Whether to perform binary classification or multiclass classification.
        target (string): Name of the protein target for binary classification.
        args (Namespace): Complete arguments (configuration) of the model.
        use_prot (boolean): Whether to use the PM module.
        advs (boolean): Whether to train the LM module with adversarial augmentations.
        test (boolean): Whether the model is being tested or trained.
    Returns:
        train (list): Training set of graph Data objects.
        valid (list): Validation set of graph Data objects.
        test (list): Test set of graph Data objects.
        data_train (dataframe): Training data dataframe.
        data_val (dataframe): Validation data dataframe.
        data_test (dataframe): Test data dataframe.
    """
    # TODO: Don't we want the multiclass partition here?
    # Read all data files
    if not test:
        # Verify cross validation partition is defined
        assert cross_val in [1, 2, 3, 4], "{} data partition is not defined".format(
            cross_val
        )
        print("Loading data...")
        if binary_task:
            path = "data/datasets/AD/"
            A = pd.read_csv(
                path + "Smiles_AD_1.csv", names=["Smiles", "Target", "Label"]
            )
            B = pd.read_csv(
                path + "Smiles_AD_2.csv", names=["Smiles", "Target", "Label"]
            )
            C = pd.read_csv(
                path + "Smiles_AD_3.csv", names=["Smiles", "Target", "Label"]
            )
            D = pd.read_csv(
                path + "Smiles_AD_4.csv", names=["Smiles", "Target", "Label"]
            )
            data_test = pd.read_csv(
                path + "AD_Test.csv", names=["Smiles", "Target", "Label"]
            )
            if use_prot:
                data_target = pd.read_csv(
                    path + "Targets_Fasta.csv", names=["Fasta", "Target", "Label"]
                )
            else:
                data_target = []
        # Generate train and validation splits according to cross validation number
        if cross_val == 1:
            data_train = pd.concat([A, B, C], ignore_index=True)
            data_val = D
        elif cross_val == 2:
            data_train = pd.concat([A, C, D], ignore_index=True)
            data_val = B
        elif cross_val == 3:
            data_train = pd.concat([A, B, D], ignore_index=True)
            data_val = C
        elif cross_val == 4:
            data_train = pd.concat([B, C, D], ignore_index=True)
            data_val = A
        # For binary classification, keep only the data for the specific target being trained
        if binary_task:
            data_train = data_train[data_train.Target == target]
            data_val = data_val[data_val.Target == target]
            data_test = data_test[data_test.Target == target]
            if use_prot:
                data_target = data_target[data_target.Target == target]
        # Get dataset for each split
        train = get_dataset(data_train, use_prot, data_target, args, advs)
        valid = get_dataset(data_val, use_prot, data_target, args)
        test = get_dataset(data_test, use_prot, data_target, args)
    else:
        # Read test data file
        if binary_task:
            path = "data/datasets/AD/"
            data_test = pd.read_csv(
                path + "Smiles_AD_Test.csv", names=["Smiles", "Target", "Label"]
            )
            data_test = data_test[data_test.Target == target]
            if use_prot:
                data_target = pd.read_csv(
                    path + "Targets_Fasta.csv", names=["Fasta", "Target", "Label"]
                )
                data_target = data_target[data_target.Target == target]
            else:
                data_target = []
        test = get_dataset(
            data_test,
            target=data_target,
            use_prot=use_prot,
            args=args,
            advs=advs,
            saliency=args.saliency,
        )
        train = []
        valid = []
        data_train = []
        data_val = []
    print("Done.")
    return train, valid, test, data_train, data_val, data_test
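

# Usage sketch (illustrative only, not part of the original pipeline): a minimal
# call of load_dataset for one cross-validation fold. The Namespace below only
# carries the fields this module actually reads (advs, saliency, edge_dict); the
# real training configuration defines many more options, and "TARGET_NAME" is a
# hypothetical placeholder for one of the protein targets listed in the AD csv files.
def _example_load_fold():
    from argparse import Namespace

    args = Namespace(advs=False, saliency=False, edge_dict={})
    train, valid, test, data_train, data_val, data_test = load_dataset(
        cross_val=1,
        binary_task=True,
        target="TARGET_NAME",  # hypothetical target identifier
        args=args,
    )
    print(len(train), len(valid), len(test))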


def reload_dataset(cross_val, binary_task, target, args, advs=False):
    """Reload the training split and rebuild its molecular graphs (resets args.edge_dict)."""
    print("Reloading data")
    args.edge_dict = {}
    if binary_task:
        path = "data/datasets/AD/"
        A = pd.read_csv(path + "Smiles_AD_1.csv", names=["Smiles", "Target", "Label"])
        B = pd.read_csv(path + "Smiles_AD_2.csv", names=["Smiles", "Target", "Label"])
        C = pd.read_csv(path + "Smiles_AD_3.csv", names=["Smiles", "Target", "Label"])
        D = pd.read_csv(path + "Smiles_AD_4.csv", names=["Smiles", "Target", "Label"])
        data_test = pd.read_csv(
            path + "AD_Test.csv", names=["Smiles", "Target", "Label"]
        )
    if cross_val == 1:
        data_train = pd.concat([A, B, C], ignore_index=True)
    elif cross_val == 2:
        data_train = pd.concat([A, C, D], ignore_index=True)
    elif cross_val == 3:
        data_train = pd.concat([A, B, D], ignore_index=True)
    else:
        data_train = pd.concat([B, C, D], ignore_index=True)
    if binary_task:
        data_train = data_train[data_train.Target == target]
    train = get_dataset(data_train, args=args, advs=advs)
    print("Done.")
    return train, data_train


def smiles_to_graph(smiles_string, is_prot=False, received_mol=False, saliency=False):
    """
    Converts a SMILES string (or a FASTA sequence when is_prot=True) into graph components.
    :input: SMILES string (str)
    :return: edge_attr, edge_index, x (node features)
    """
    if not is_prot:
        mol = Chem.MolFromSmiles(smiles_string)
    else:
        mol = Chem.MolFromFASTA(smiles_string)
    # atoms
    atom_features_list = []
    atom_feat_dims = get_atom_feature_dims()
    for atom in mol.GetAtoms():
        ftrs = atom_to_feature_vector(atom)
        if saliency:
            ftrs_oh = one_hot_vector_am(ftrs, atom_feat_dims)
            atom_features_list.append(torch.unsqueeze(ftrs_oh, 0))
        else:
            atom_features_list.append(ftrs)
    if saliency:
        x = torch.cat(atom_features_list)
    else:
        x = np.array(atom_features_list, dtype=np.int64)
    # bonds
    num_bond_features = 3  # bond type, bond stereo, is_conjugated
    if len(mol.GetBonds()) > 0:  # mol has bonds
        edges_list = []
        edge_features_list = []
        for bond in mol.GetBonds():
            i = bond.GetBeginAtomIdx()
            j = bond.GetEndAtomIdx()
            edge_feature = bond_to_feature_vector(bond)
            # add edges in both directions
            edges_list.append((i, j))
            edge_features_list.append(edge_feature)
            edges_list.append((j, i))
            edge_features_list.append(edge_feature)
        # data.edge_index: Graph connectivity in COO format with shape [2, num_edges]
        edge_index = np.array(edges_list, dtype=np.int64).T
        # data.edge_attr: Edge feature matrix with shape [num_edges, num_edge_features]
        edge_attr = np.array(edge_features_list, dtype=np.int64)
    else:  # mol has no bonds
        edge_index = np.empty((2, 0), dtype=np.int64)
        edge_attr = np.empty((0, num_bond_features), dtype=np.int64)
    return edge_attr, edge_index, x
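

# Quick illustration (an ad-hoc check, not used by the pipeline): for a small
# molecule such as ethanol, smiles_to_graph returns plain numpy arrays with one
# feature row per atom and two directed edges per bond.
def _example_smiles_to_graph():
    edge_attr, edge_index, x = smiles_to_graph("CCO")  # ethanol: 3 heavy atoms, 2 bonds
    assert x.shape[0] == 3  # one row of atom features per atom
    assert edge_index.shape == (2, 4)  # each bond stored in both directions
    assert edge_attr.shape[0] == 4  # one row of bond features per directed edge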


def smiles_to_graph_advs(
    smiles_string, args, advs=False, received_mol=False, saliency=False
):
    """
    Converts a SMILES string (or an RDKit Mol when received_mol=True) into graph
    components, optionally attaching a learnable weight to every bond (advs=True)
    for adversarial augmentation.
    :input: SMILES string (str) or RDKit Mol
    :return: edge_attr, edge_index, x (node features)
    """
    if not received_mol:
        mol = Chem.MolFromSmiles(smiles_string)
    else:
        mol = smiles_string
    # atoms
    atom_features_list = []
    atom_feat_dims = get_atom_feature_dims()
    for atom in mol.GetAtoms():
        ftrs = atom_to_feature_vector(atom)
        if saliency:
            ftrs_oh = one_hot_vector_am(ftrs, atom_feat_dims)
            atom_features_list.append(torch.unsqueeze(ftrs_oh, 0))
        else:
            atom_features_list.append(ftrs)
    if saliency:
        x = torch.cat(atom_features_list)
    else:
        x = np.array(atom_features_list, dtype=np.int64)
    if advs:
        # bonds
        mol_edge_dict = {}
    num_bond_features = 3  # bond type, bond stereo, is_conjugated
    features_dim1 = torch.eye(5)
    features_dim2 = torch.eye(6)
    features_dim3 = torch.eye(2)
    if len(mol.GetBonds()) > 0:  # mol has bonds
        edges_list = []
        edge_features_list = []
        for bond in mol.GetBonds():
            i = bond.GetBeginAtomIdx()
            j = bond.GetEndAtomIdx()
            edge_feature = bond_to_feature_vector(bond)
            # add edges in both directions
            edges_list.append((i, j))
            edges_list.append((j, i))
            edge_feature_oh = one_hot_vector_sm(
                edge_feature, features_dim1, features_dim2, features_dim3
            )
            if advs:
                mol_edge_dict[(i, j)] = Variable(
                    torch.tensor([1.0]), requires_grad=True
                )
                # add edges in both directions
                edge_features_list.append(
                    torch.unsqueeze(mol_edge_dict[(i, j)] * edge_feature_oh, 0)
                )
                edge_features_list.append(
                    torch.unsqueeze(mol_edge_dict[(i, j)] * edge_feature_oh, 0)
                )
            else:
                # add edges in both directions
                edge_features_list.append(torch.unsqueeze(edge_feature_oh, 0))
                edge_features_list.append(torch.unsqueeze(edge_feature_oh, 0))
        if advs:
            # Update edge dict
            args.edge_dict[smiles_string] = mol_edge_dict
        # data.edge_index: Graph connectivity in COO format with shape [2, num_edges]
        edge_index = np.array(edges_list, dtype=np.int64).T
        # data.edge_attr: Edge feature matrix with shape [num_edges, num_edge_features]
        edge_attr = torch.cat(edge_features_list)
    else:  # mol has no bonds
        edge_index = np.empty((2, 0), dtype=np.int64)
        edge_attr = np.empty((0, num_bond_features), dtype=np.int64)
        args.edge_dict[smiles_string] = {}
    return edge_attr, edge_index, x
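

# Sketch of how the advs path can be exercised (an assumption based on this module
# alone, not the full training loop): with advs=True every bond gets a scalar
# Variable with requires_grad=True, stored in args.edge_dict and multiplied into
# its one-hot edge features, so backpropagating any scalar computed from edge_attr
# yields a gradient for each per-bond weight. This assumes one_hot_vector_sm
# returns a float tensor built from the identity matrices passed to it.
def _example_edge_weight_gradients():
    from argparse import Namespace

    args = Namespace(advs=True, saliency=False, edge_dict={})
    smiles = "CCO"  # small molecule used purely for illustration
    edge_attr, edge_index, x = smiles_to_graph_advs(smiles, args, advs=True)
    # A scalar stand-in for the model loss; during training, gradients w.r.t. the
    # per-bond weights would come from the actual loss instead.
    edge_attr.sum().backward()
    for (i, j), weight in args.edge_dict[smiles].items():
        print((i, j), weight.grad)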


def get_dataset(
    dataset, use_prot=False, target=None, args=None, advs=False, saliency=False
):
    total_dataset = []
    if use_prot:
        prot_graph = transform_molecule_pg(
            target["Fasta"].item(), label=None, is_prot=use_prot
        )
    for mol, label in tqdm(
        zip(dataset["Smiles"], dataset["Label"]), total=len(dataset["Smiles"])
    ):
        if use_prot:
            total_dataset.append(
                [
                    transform_molecule_pg(mol, label, args, advs, saliency=saliency),
                    prot_graph,
                ]
            )
        else:
            total_dataset.append(
                transform_molecule_pg(mol, label, args, advs, saliency=saliency)
            )
    return total_dataset
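

# get_dataset returns a plain Python list of torch_geometric Data objects (or
# [ligand_graph, protein_graph] pairs when use_prot=True). A minimal sketch of
# batching such a list with PyTorch Geometric's DataLoader; the batch size is
# arbitrary, and older PyG versions import DataLoader from torch_geometric.data
# instead of torch_geometric.loader.
def _example_batch_dataset(dataset_list):
    from torch_geometric.loader import DataLoader

    loader = DataLoader(dataset_list, batch_size=32, shuffle=True)
    for batch in loader:
        # batch.x, batch.edge_index, batch.edge_attr and batch.y are collated
        # across all molecules in the mini-batch
        return batch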


def get_perturbed_dataset(mols, labels, args):
    total_dataset = []
    for mol, label in zip(mols, labels):
        total_dataset.append(
            transform_molecule_pg(mol, label, args, received_mol=True)
        )
    return total_dataset
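

# get_perturbed_dataset passes received_mol=True, so it expects RDKit Mol objects
# (e.g. molecules perturbed elsewhere in the adversarial loop) rather than SMILES
# strings. A minimal sketch with placeholder inputs; the two molecules and their
# labels are hypothetical.
def _example_perturbed_dataset():
    from argparse import Namespace

    args = Namespace(advs=True, saliency=False, edge_dict={})
    mols = [Chem.MolFromSmiles("CCO"), Chem.MolFromSmiles("c1ccccc1")]
    labels = [1, 0]
    return get_perturbed_dataset(mols, labels, args)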


def transform_molecule_pg(
    smiles,
    label,
    args=None,
    advs=False,
    received_mol=False,
    saliency=False,
    is_prot=False,
):
    """Wrap a SMILES string (or FASTA sequence / RDKit Mol) into a torch_geometric Data object."""
    if is_prot:
        edge_attr_p, edge_index_p, x_p = smiles_to_graph(smiles, is_prot)
        x_p = torch.tensor(x_p)
        edge_index_p = torch.tensor(edge_index_p)
        edge_attr_p = torch.tensor(edge_attr_p)
        return Data(edge_attr=edge_attr_p, edge_index=edge_index_p, x=x_p)
    else:
        if args.advs or received_mol:
            if advs or received_mol:
                edge_attr, edge_index, x = smiles_to_graph_advs(
                    smiles,
                    args,
                    advs=True,
                    received_mol=received_mol,
                    saliency=saliency,
                )
            else:
                edge_attr, edge_index, x = smiles_to_graph_advs(
                    smiles, args, received_mol=received_mol, saliency=saliency
                )
        else:
            edge_attr, edge_index, x = smiles_to_graph(smiles, saliency=saliency)
        if not saliency:
            x = torch.tensor(x)
        y = torch.tensor([label])
        edge_index = torch.tensor(edge_index)
        if not args.advs and not received_mol:
            edge_attr = torch.tensor(edge_attr)
        if received_mol:
            mol = smiles
        else:
            mol = Chem.MolFromSmiles(smiles)
        return Data(
            edge_attr=edge_attr, edge_index=edge_index, x=x, y=y, mol=mol, smiles=smiles
        )
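

# End-to-end sketch for a single molecule (illustrative; the Namespace only carries
# the flags this module reads): transform_molecule_pg wraps the graph arrays into a
# torch_geometric Data object together with the label, the RDKit Mol and the raw
# SMILES string.
def _example_single_molecule():
    from argparse import Namespace

    args = Namespace(advs=False, saliency=False, edge_dict={})
    data = transform_molecule_pg("CCO", label=1, args=args)
    print(data.x.shape, data.edge_index.shape, data.y, data.smiles)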