Spaces:
Sleeping
Sleeping
Delete deepscreen/data/featurizers/monn.py
Browse files
deepscreen/data/featurizers/monn.py
DELETED
@@ -1,106 +0,0 @@
|
|
1 |
-
import numpy as np
|
2 |
-
from rdkit.Chem import MolFromSmiles
|
3 |
-
|
4 |
-
from deepscreen.data.featurizers.categorical import FASTA_VOCAB, fasta_to_label
|
5 |
-
from deepscreen.data.featurizers.graph import atom_features, bond_features
|
6 |
-
|
7 |
-
|
8 |
-
def get_mask(arr):
|
9 |
-
a = np.zeros(1, len(arr))
|
10 |
-
a[1, :arr.shape[0]] = 1
|
11 |
-
return a
|
12 |
-
|
13 |
-
|
14 |
-
def add_index(input_array, ebd_size):
|
15 |
-
batch_size, n_vertex, n_nbs = np.shape(input_array)
|
16 |
-
add_idx = np.array(range(0, ebd_size * batch_size, ebd_size) * (n_nbs * n_vertex))
|
17 |
-
add_idx = np.transpose(add_idx.reshape(-1, batch_size))
|
18 |
-
add_idx = add_idx.reshape(-1)
|
19 |
-
new_array = input_array.reshape(-1) + add_idx
|
20 |
-
return new_array
|
21 |
-
|
22 |
-
|
23 |
-
# TODO fix padding and masking
|
24 |
-
def drug_featurizer(smiles, max_neighbors=6):
|
25 |
-
mol = MolFromSmiles(smiles)
|
26 |
-
|
27 |
-
# convert molecule to GNN input
|
28 |
-
n_atoms = mol.GetNumAtoms()
|
29 |
-
assert mol.GetNumBonds() >= 0
|
30 |
-
|
31 |
-
n_bonds = max(mol.GetNumBonds(), 1)
|
32 |
-
feat_atoms = np.zeros((n_atoms,)) # atom feature ID
|
33 |
-
feat_bonds = np.zeros((n_bonds,)) # bond feature ID
|
34 |
-
atom_adj = np.zeros((n_atoms, max_neighbors))
|
35 |
-
bond_adj = np.zeros((n_atoms, max_neighbors))
|
36 |
-
n_neighbors = np.zeros((n_atoms,))
|
37 |
-
neighbor_mask = np.zeros((n_atoms, max_neighbors))
|
38 |
-
|
39 |
-
for atom in mol.GetAtoms():
|
40 |
-
idx = atom.GetIdx()
|
41 |
-
feat_atoms[idx] = atom_features(atom)
|
42 |
-
|
43 |
-
for bond in mol.GetBonds():
|
44 |
-
a1 = bond.GetBeginAtom().GetIdx()
|
45 |
-
a2 = bond.GetEndAtom().GetIdx()
|
46 |
-
idx = bond.GetIdx()
|
47 |
-
feat_bonds[idx] = bond_features(bond)
|
48 |
-
try:
|
49 |
-
atom_adj[a1, n_neighbors[a1]] = a2
|
50 |
-
atom_adj[a2, n_neighbors[a2]] = a1
|
51 |
-
except:
|
52 |
-
return [], [], [], [], []
|
53 |
-
bond_adj[a1, n_neighbors[a1]] = idx
|
54 |
-
bond_adj[a2, n_neighbors[a2]] = idx
|
55 |
-
n_neighbors[a1] += 1
|
56 |
-
n_neighbors[a2] += 1
|
57 |
-
|
58 |
-
for i in range(len(n_neighbors)):
|
59 |
-
neighbor_mask[i, :n_neighbors[i]] = 1
|
60 |
-
|
61 |
-
vertex_mask = get_mask(feat_atoms)
|
62 |
-
# vertex = pack_1d(feat_atoms)
|
63 |
-
# edge = pack_1d(feat_bonds)
|
64 |
-
# atom_adj = pack_2d(atom_adj)
|
65 |
-
# bond_adj = pack_2d(bond_adj)
|
66 |
-
# nbs_mask = pack_2d(n_neighbors_mat)
|
67 |
-
|
68 |
-
atom_adj = add_index(atom_adj, np.shape(atom_adj)[1])
|
69 |
-
bond_adj = add_index(bond_adj, np.shape(feat_bonds)[1])
|
70 |
-
|
71 |
-
return vertex_mask, feat_atoms, feat_bonds, atom_adj, bond_adj, neighbor_mask
|
72 |
-
|
73 |
-
|
74 |
-
# TODO WIP the pairwise_label matrix probably should be generated beforehand and stored as an extra label in the dataset
|
75 |
-
def get_pairwise_label(pdbid, interaction_dict, mol):
|
76 |
-
if pdbid in interaction_dict:
|
77 |
-
sdf_element = np.array([atom.GetSymbol().upper() for atom in mol.GetAtoms()])
|
78 |
-
atom_element = np.array(interaction_dict[pdbid]['atom_element'], dtype=str)
|
79 |
-
atom_name_list = np.array(interaction_dict[pdbid]['atom_name'], dtype=str)
|
80 |
-
atom_interact = np.array(interaction_dict[pdbid]['atom_interact'], dtype=int)
|
81 |
-
nonH_position = np.where(atom_element != 'H')[0]
|
82 |
-
assert sum(atom_element[nonH_position] != sdf_element) == 0
|
83 |
-
|
84 |
-
atom_name_list = atom_name_list[nonH_position].tolist()
|
85 |
-
pairwise_mat = np.zeros((len(nonH_position), len(interaction_dict[pdbid]['uniprot_seq'])), dtype=np.int32)
|
86 |
-
for atom_name, bond_type in interaction_dict[pdbid]['atom_bond_type']:
|
87 |
-
atom_idx = atom_name_list.index(str(atom_name))
|
88 |
-
assert atom_idx < len(nonH_position)
|
89 |
-
|
90 |
-
seq_idx_list = []
|
91 |
-
for seq_idx, bond_type_seq in interaction_dict[pdbid]['residue_bond_type']:
|
92 |
-
if bond_type == bond_type_seq:
|
93 |
-
seq_idx_list.append(seq_idx)
|
94 |
-
pairwise_mat[atom_idx, seq_idx] = 1
|
95 |
-
if len(np.where(pairwise_mat != 0)[0]) != 0:
|
96 |
-
pairwise_mask = True
|
97 |
-
return True, pairwise_mat
|
98 |
-
return False, np.zeros((1, 1))
|
99 |
-
|
100 |
-
|
101 |
-
def protein_featurizer(fasta):
|
102 |
-
sequence = fasta_to_label(fasta)
|
103 |
-
# pad proteins and make masks
|
104 |
-
seq_mask = get_mask(sequence)
|
105 |
-
|
106 |
-
return seq_mask, sequence
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|