libokj commited on
Commit
a368343
·
1 Parent(s): 786ef34

Delete deepscreen/data/featurizers/monn.py

Browse files
Files changed (1) hide show
  1. deepscreen/data/featurizers/monn.py +0 -106
deepscreen/data/featurizers/monn.py DELETED
@@ -1,106 +0,0 @@
1
- import numpy as np
2
- from rdkit.Chem import MolFromSmiles
3
-
4
- from deepscreen.data.featurizers.categorical import FASTA_VOCAB, fasta_to_label
5
- from deepscreen.data.featurizers.graph import atom_features, bond_features
6
-
7
-
8
- def get_mask(arr):
9
- a = np.zeros(1, len(arr))
10
- a[1, :arr.shape[0]] = 1
11
- return a
12
-
13
-
14
- def add_index(input_array, ebd_size):
15
- batch_size, n_vertex, n_nbs = np.shape(input_array)
16
- add_idx = np.array(range(0, ebd_size * batch_size, ebd_size) * (n_nbs * n_vertex))
17
- add_idx = np.transpose(add_idx.reshape(-1, batch_size))
18
- add_idx = add_idx.reshape(-1)
19
- new_array = input_array.reshape(-1) + add_idx
20
- return new_array
21
-
22
-
23
- # TODO fix padding and masking
24
- def drug_featurizer(smiles, max_neighbors=6):
25
- mol = MolFromSmiles(smiles)
26
-
27
- # convert molecule to GNN input
28
- n_atoms = mol.GetNumAtoms()
29
- assert mol.GetNumBonds() >= 0
30
-
31
- n_bonds = max(mol.GetNumBonds(), 1)
32
- feat_atoms = np.zeros((n_atoms,)) # atom feature ID
33
- feat_bonds = np.zeros((n_bonds,)) # bond feature ID
34
- atom_adj = np.zeros((n_atoms, max_neighbors))
35
- bond_adj = np.zeros((n_atoms, max_neighbors))
36
- n_neighbors = np.zeros((n_atoms,))
37
- neighbor_mask = np.zeros((n_atoms, max_neighbors))
38
-
39
- for atom in mol.GetAtoms():
40
- idx = atom.GetIdx()
41
- feat_atoms[idx] = atom_features(atom)
42
-
43
- for bond in mol.GetBonds():
44
- a1 = bond.GetBeginAtom().GetIdx()
45
- a2 = bond.GetEndAtom().GetIdx()
46
- idx = bond.GetIdx()
47
- feat_bonds[idx] = bond_features(bond)
48
- try:
49
- atom_adj[a1, n_neighbors[a1]] = a2
50
- atom_adj[a2, n_neighbors[a2]] = a1
51
- except:
52
- return [], [], [], [], []
53
- bond_adj[a1, n_neighbors[a1]] = idx
54
- bond_adj[a2, n_neighbors[a2]] = idx
55
- n_neighbors[a1] += 1
56
- n_neighbors[a2] += 1
57
-
58
- for i in range(len(n_neighbors)):
59
- neighbor_mask[i, :n_neighbors[i]] = 1
60
-
61
- vertex_mask = get_mask(feat_atoms)
62
- # vertex = pack_1d(feat_atoms)
63
- # edge = pack_1d(feat_bonds)
64
- # atom_adj = pack_2d(atom_adj)
65
- # bond_adj = pack_2d(bond_adj)
66
- # nbs_mask = pack_2d(n_neighbors_mat)
67
-
68
- atom_adj = add_index(atom_adj, np.shape(atom_adj)[1])
69
- bond_adj = add_index(bond_adj, np.shape(feat_bonds)[1])
70
-
71
- return vertex_mask, feat_atoms, feat_bonds, atom_adj, bond_adj, neighbor_mask
72
-
73
-
74
- # TODO WIP the pairwise_label matrix probably should be generated beforehand and stored as an extra label in the dataset
75
- def get_pairwise_label(pdbid, interaction_dict, mol):
76
- if pdbid in interaction_dict:
77
- sdf_element = np.array([atom.GetSymbol().upper() for atom in mol.GetAtoms()])
78
- atom_element = np.array(interaction_dict[pdbid]['atom_element'], dtype=str)
79
- atom_name_list = np.array(interaction_dict[pdbid]['atom_name'], dtype=str)
80
- atom_interact = np.array(interaction_dict[pdbid]['atom_interact'], dtype=int)
81
- nonH_position = np.where(atom_element != 'H')[0]
82
- assert sum(atom_element[nonH_position] != sdf_element) == 0
83
-
84
- atom_name_list = atom_name_list[nonH_position].tolist()
85
- pairwise_mat = np.zeros((len(nonH_position), len(interaction_dict[pdbid]['uniprot_seq'])), dtype=np.int32)
86
- for atom_name, bond_type in interaction_dict[pdbid]['atom_bond_type']:
87
- atom_idx = atom_name_list.index(str(atom_name))
88
- assert atom_idx < len(nonH_position)
89
-
90
- seq_idx_list = []
91
- for seq_idx, bond_type_seq in interaction_dict[pdbid]['residue_bond_type']:
92
- if bond_type == bond_type_seq:
93
- seq_idx_list.append(seq_idx)
94
- pairwise_mat[atom_idx, seq_idx] = 1
95
- if len(np.where(pairwise_mat != 0)[0]) != 0:
96
- pairwise_mask = True
97
- return True, pairwise_mat
98
- return False, np.zeros((1, 1))
99
-
100
-
101
- def protein_featurizer(fasta):
102
- sequence = fasta_to_label(fasta)
103
- # pad proteins and make masks
104
- seq_mask = get_mask(sequence)
105
-
106
- return seq_mask, sequence