File size: 1,554 Bytes
e3eae4d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import numpy as np
import torch
from rdkit import Chem
from rdkit.Chem import Descriptors
from torch_geometric.data import Data
from ogb.utils.features import get_atom_feature_dims, get_bond_feature_dims
from ogb.graphproppred.mol_encoder import AtomEncoder, BondEncoder
from ogb.lsc import PCQM4Mv2Evaluator
from ogb.utils import smiles2graph
from torch_geometric.loader import DataLoader

def compute_rdkit_features(smiles):
    """Compute six whole-molecule RDKit descriptors for a SMILES string.

    Args:
        smiles: SMILES string handed to ``Chem.MolFromSmiles``.

    Returns:
        A list of six descriptor values in this fixed order: ``MolWt``,
        ``NumRotatableBonds``, ``TPSA``, ``NumHAcceptors``, ``NumHDonors``,
        ``RingCount``.

    Raises:
        ValueError: If ``Chem.MolFromSmiles`` returns ``None`` for ``smiles``.
    """
    molecule = Chem.MolFromSmiles(smiles)
    if molecule is None:
        raise ValueError("Invalid SMILES")

    # The position of each entry defines the layout of the returned vector,
    # so the order here must not change.
    descriptor_fns = (
        Descriptors.MolWt,
        Descriptors.NumRotatableBonds,
        Descriptors.TPSA,
        Descriptors.NumHAcceptors,
        Descriptors.NumHDonors,
        Descriptors.RingCount,
    )
    return [describe(molecule) for describe in descriptor_fns]

def smiles_to_data(smiles_list, device="cpu"):
    """Convert SMILES strings into PyTorch Geometric ``Data`` graphs.

    Each SMILES is featurized twice: ``compute_rdkit_features`` supplies a
    molecule-level descriptor vector and ``smiles2graph`` supplies the node
    and edge features. Both are packed into a single ``Data`` object.

    Conversion is best-effort: any exception raised while processing one
    SMILES is printed and that entry is skipped, so the returned list may be
    shorter than ``smiles_list`` and positions may not line up with the input.

    Args:
        smiles_list: Iterable of SMILES strings.
        device: Accepted for interface compatibility but not read by this
            function; no device is passed to the tensor constructors.
            NOTE(review): callers presumably move batches to the target
            device themselves -- confirm before wiring this parameter in.

    Returns:
        A list of ``Data`` objects, one per successfully converted SMILES, in
        input order. Each carries ``x``, ``edge_index`` and ``edge_attr`` as
        ``torch.long`` tensors, ``rdkit_feats`` as a ``torch.float32`` tensor
        with a leading dimension of 1, and ``num_nodes``.
    """
    graph_list = []

    for smi in smiles_list:
        try:
            # Compute the descriptors first: an unparseable SMILES is then
            # reported through the explicit "Invalid SMILES" ValueError
            # instead of whatever smiles2graph would raise on it.
            rdkit_feats = compute_rdkit_features(smi)
            graph = smiles2graph(smi)

            data = Data(
                x=torch.tensor(graph['node_feat'], dtype=torch.long),
                edge_index=torch.tensor(graph['edge_index'], dtype=torch.long),
                edge_attr=torch.tensor(graph['edge_feat'], dtype=torch.long),
                # unsqueeze(0) adds a leading axis of size 1; presumably so
                # batching yields one descriptor row per graph -- TODO confirm.
                rdkit_feats=torch.tensor(rdkit_feats, dtype=torch.float32).unsqueeze(0),
                num_nodes=graph['num_nodes'],
            )
        except Exception as e:
            # Deliberate best-effort: report the failure and move on to the
            # next SMILES rather than aborting the whole conversion.
            print(f"Error with SMILES '{smi}': {e}")
        else:
            graph_list.append(data)

    return graph_list