	refactor
- app.py +288 -0
- inference.py +303 -0
- src/__init__.py +0 -0
- src/__pycache__/__init__.cpython-310.pyc +0 -0
- src/data/__init__.py +0 -0
- src/data/__pycache__/__init__.cpython-310.pyc +0 -0
- src/data/__pycache__/dataset.cpython-310.pyc +0 -0
- src/data/__pycache__/utils.cpython-310.pyc +0 -0
- src/data/dataset.py +317 -0
- src/data/utils.py +143 -0
- src/model/__init__.py +0 -0
- src/model/__pycache__/__init__.cpython-310.pyc +0 -0
- src/model/__pycache__/layers.cpython-310.pyc +0 -0
- src/model/__pycache__/loss.cpython-310.pyc +0 -0
- src/model/__pycache__/models.cpython-310.pyc +0 -0
- src/model/layers.py +234 -0
- src/model/loss.py +85 -0
- src/model/models.py +269 -0
- src/util/__init__.py +0 -0
- src/util/__pycache__/__init__.cpython-310.pyc +0 -0
- src/util/__pycache__/smiles_cor.cpython-310.pyc +0 -0
- src/util/__pycache__/utils.cpython-310.pyc +0 -0
- src/util/smiles_cor.py +1284 -0
- src/util/utils.py +930 -0
- train.py +462 -0
    	
app.py
ADDED
@@ -0,0 +1,288 @@
import gradio as gr
from inference import Inference
import PIL
from PIL import Image
import pandas as pd
import random
from rdkit import Chem
from rdkit.Chem import Draw
from rdkit.Chem.Draw import IPythonConsole
import shutil
import os
import time


class DrugGENConfig:
    # Inference configuration
    submodel = 'DrugGEN'
    inference_model = "experiments/models/DrugGEN/"
    sample_num = 100
    disable_correction = False  # corresponds to correct=True in the old config

    # Data configuration
    inf_smiles = 'data/chembl_test.smi'  # corresponds to inf_raw_file in the old config
    train_smiles = 'data/chembl_train.smi'
    train_drug_smiles = 'data/akt1_train.smi'
    inf_batch_size = 1
    mol_data_dir = 'data'
    features = False

    # Model configuration
    act = 'relu'
    max_atom = 45
    dim = 128
    depth = 1
    heads = 8
    mlp_ratio = 3
    dropout = 0.

    # Seed configuration
    set_seed = True
    seed = 10


class DrugGENAKT1Config(DrugGENConfig):
    submodel = 'DrugGEN'
    inference_model = "experiments/models/DrugGEN-AKT1/"
    train_drug_smiles = 'data/akt1_train.smi'
    max_atom = 45


class DrugGENCDK2Config(DrugGENConfig):
    submodel = 'DrugGEN'
    inference_model = "experiments/models/DrugGEN-CDK2/"
    train_drug_smiles = 'data/cdk2_train.smi'
    max_atom = 38


class NoTargetConfig(DrugGENConfig):
    submodel = "NoTarget"
    inference_model = "experiments/models/NoTarget/"
    train_drug_smiles = 'data/chembl_train.smi'  # no specific target, use general ChEMBL data


model_configs = {
    "DrugGEN-AKT1": DrugGENAKT1Config(),
    "DrugGEN-CDK2": DrugGENCDK2Config(),
    "DrugGEN-NoTarget": NoTargetConfig(),
}


def function(model_name: str, num_molecules: int, seed_num: str) -> tuple[PIL.Image.Image, pd.DataFrame, str]:
    '''
    Returns:
    image, score_df, file path
    '''
    config = model_configs[model_name]
    config.sample_num = num_molecules

    # Shorter name used only for naming the output file
    if model_name == "DrugGEN-NoTarget":
        model_name = "NoTarget"

    if config.sample_num > 250:
        raise gr.Error("You have requested to generate more than the allowed limit of 250 molecules. Please reduce your request to 250 or fewer.")

    if seed_num is None or seed_num.strip() == "":
        config.seed = random.randint(0, 10000)
    else:
        try:
            config.seed = int(seed_num)
        except ValueError:
            raise gr.Error("The seed must be an integer value!")

    inferer = Inference(config)
    start_time = time.time()
    scores = inferer.inference()  # one-row DataFrame of metrics
    et = time.time() - start_time

    score_df = pd.DataFrame({
        "Runtime (seconds)": [et],
        "Validity": [scores["validity"].iloc[0]],
        "Uniqueness": [scores["uniqueness"].iloc[0]],
        "Novelty (Train)": [scores["novelty"].iloc[0]],
        "Novelty (Test)": [scores["novelty_test"].iloc[0]],
        "Drug Novelty": [scores["drug_novelty"].iloc[0]],
        "Max Length": [scores["max_len"].iloc[0]],
        "Mean Atom Type": [scores["mean_atom_type"].iloc[0]],
        "SNN ChEMBL": [scores["snn_chembl"].iloc[0]],
        "SNN Drug": [scores["snn_drug"].iloc[0]],
        "Internal Diversity": [scores["IntDiv"].iloc[0]],
        "QED": [scores["qed"].iloc[0]],
        "SA Score": [scores["sa"].iloc[0]]
    })

    # inference.py writes the generated SMILES to "<submodel>_denovo_mols.smi" in the working directory
    output_file_path = f'{config.submodel}_denovo_mols.smi'

    new_path = f'{model_name}_denovo_mols.smi'
    if output_file_path != new_path:
        os.rename(output_file_path, new_path)

    with open(new_path) as f:
        inference_drugs = f.read()

    generated_molecule_list = inference_drugs.split("\n")[:-1]

    rng = random.Random(config.seed)
    if num_molecules > 12:
        selected_molecules = rng.choices(generated_molecule_list, k=12)
    else:
        selected_molecules = generated_molecule_list

    selected_molecules = [Chem.MolFromSmiles(mol) for mol in selected_molecules if Chem.MolFromSmiles(mol) is not None]

    drawOptions = Draw.rdMolDraw2D.MolDrawOptions()
    drawOptions.prepareMolsBeforeDrawing = False
    drawOptions.bondLineWidth = 0.5

    molecule_image = Draw.MolsToGridImage(
        selected_molecules,
        molsPerRow=3,
        subImgSize=(400, 400),
        maxMols=len(selected_molecules),
        # legends=None,
        returnPNG=False,
        drawOptions=drawOptions,
        highlightAtomLists=None,
        highlightBondLists=None,
    )

    return molecule_image, score_df, new_path


with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("# DrugGEN: Target Centric De Novo Design of Drug Candidate Molecules with Graph Generative Deep Adversarial Networks")
            with gr.Row():
                gr.Markdown("[](https://arxiv.org/abs/2302.07868)")
                gr.Markdown("[](https://github.com/HUBioDataLab/DrugGEN)")

            with gr.Accordion("About DrugGEN Models", open=False):
                gr.Markdown("""
## Model Variations

### DrugGEN-AKT1
This model is designed to generate molecules targeting the human AKT1 protein (UniProt ID: P31749), a serine/threonine-protein kinase that plays a key role in regulating cell survival, metabolism, and growth. AKT1 is a significant target in cancer therapy, particularly for breast, colorectal, and ovarian cancers.

The model learns from:
- General drug-like molecules from the ChEMBL database
- Known AKT1 inhibitors
- Maximum atom count: 45

### DrugGEN-CDK2
This model targets the human CDK2 protein (UniProt ID: P24941), a cyclin-dependent kinase involved in cell cycle regulation. CDK2 inhibitors are being investigated for treating various cancers, particularly those with dysregulated cell cycle control.

The model learns from:
- General drug-like molecules from the ChEMBL database
- Known CDK2 inhibitors
- Maximum atom count: 38

### DrugGEN-NoTarget
This is a general-purpose model that generates diverse drug-like molecules without targeting a specific protein. It's useful for:
- Exploring chemical space
- Generating diverse scaffolds
- Creating molecules with drug-like properties

## How It Works
DrugGEN uses a graph-based generative adversarial network (GAN) architecture where:
1. The generator creates molecular graphs
2. The discriminator evaluates them against real molecules
3. The model learns to generate increasingly realistic and target-specific molecules

For more details, see our [paper on arXiv](https://arxiv.org/abs/2302.07868).
                """)

            with gr.Accordion("Understanding the Metrics", open=False):
                gr.Markdown("""
## Evaluation Metrics

### Basic Metrics
- **Validity**: Percentage of generated molecules that are chemically valid
- **Uniqueness**: Percentage of unique molecules among valid ones
- **Runtime**: Time taken to generate the requested molecules

### Novelty Metrics
- **Novelty (Train)**: Percentage of molecules not found in the training set
- **Novelty (Test)**: Percentage of molecules not found in the test set
- **Drug Novelty**: Percentage of molecules not found in known drugs

### Structural Metrics
- **Max Length**: Maximum component length in the generated molecules
- **Mean Atom Type**: Average distribution of atom types
- **Internal Diversity**: Diversity within the generated set (higher is more diverse)

### Drug-likeness Metrics
- **QED (Quantitative Estimate of Drug-likeness)**: Score from 0-1 measuring how drug-like a molecule is (higher is better)
- **SA Score (Synthetic Accessibility)**: Score from 1-10 indicating ease of synthesis (lower is easier)

### Similarity Metrics
- **SNN ChEMBL**: Similarity to ChEMBL molecules (higher means more similar to known drug-like compounds)
- **SNN Drug**: Similarity to known drugs (higher means more similar to approved drugs)
                """)

            model_name = gr.Radio(
                choices=("DrugGEN-AKT1", "DrugGEN-CDK2", "DrugGEN-NoTarget"),
                value="DrugGEN-AKT1",
                label="Select Target Model",
                info="Choose which protein target or general model to use for molecule generation"
            )

            num_molecules = gr.Slider(
                minimum=10,
                maximum=250,
                value=100,
                step=10,
                label="Number of Molecules to Generate",
                info="This space runs on a CPU, which may result in slower performance. Generating 200 molecules takes approximately 6 minutes, so we set a 250-molecule cap. On a GPU, the model can generate 10,000 molecules in the same amount of time. Please check our GitHub repo for running our models on GPU."
            )

            seed_num = gr.Textbox(
                label="Random Seed (Optional)",
                value="",
                info="Set a specific seed for reproducible results, or leave empty for random generation"
            )

            submit_button = gr.Button(
                value="Generate Molecules",
                variant="primary",
                size="lg"
            )

        with gr.Column(scale=2):
            with gr.Tabs():
                with gr.TabItem("Generated Molecules"):
                    image_output = gr.Image(
                        label="Sample of Generated Molecules",
                        elem_id="molecule_display"
                    )
                    file_download = gr.File(
                        label="Download All Generated Molecules (SMILES format)",
                    )

                with gr.TabItem("Performance Metrics"):
                    scores_df = gr.Dataframe(
                        label="Model Performance Metrics",
                        headers=["Runtime (seconds)", "Validity", "Uniqueness", "Novelty (Train)", "Novelty (Test)",
                                 "Drug Novelty", "Max Length", "Mean Atom Type", "SNN ChEMBL", "SNN Drug",
                                 "Internal Diversity", "QED", "SA Score"]
                    )

            with gr.Accordion("Generation Settings", open=False):
                gr.Markdown("""
                ## Technical Details

                - This demo runs on CPU, which limits generation speed
                - Generating 200 molecules takes approximately 6 minutes
                - For faster generation or larger batches, run the model on GPU using our GitHub repository
                - The model uses a graph-based representation of molecules
                - Maximum atom count varies by model (AKT1: 45, CDK2: 38)
                """)

    gr.Markdown("### Created by the HU BioDataLab | [GitHub](https://github.com/HUBioDataLab/DrugGEN) | [Paper](https://arxiv.org/abs/2302.07868)")

    submit_button.click(function, inputs=[model_name, num_molecules, seed_num], outputs=[image_output, scores_df, file_download], api_name="inference")

# demo.queue(concurrency_count=1)
demo.queue()
demo.launch()
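Because the click handler above is registered with api_name="inference", the Space also exposes a named API endpoint that can be called programmatically. Below is a minimal sketch using the gradio_client package; the Space identifier and the variable names are assumptions for illustration, not part of this commit.

from gradio_client import Client

# Minimal sketch, assuming the demo is published under this Space id (hypothetical; adjust to the real one).
client = Client("HUBioDataLab/DrugGEN")

# The three return values mirror the outputs wired to submit_button.click:
# the molecule grid image, the metrics dataframe, and the .smi file path.
image, metrics, smiles_file = client.predict(
    "DrugGEN-AKT1",        # model_name (radio choice)
    50,                    # num_molecules (slider value)
    "42",                  # seed_num (textbox; pass "" for a random seed)
    api_name="/inference",
)
print(metrics)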
    	
inference.py
ADDED
@@ -0,0 +1,303 @@
| 1 | 
            +
            import os
         | 
| 2 | 
            +
            import sys
         | 
| 3 | 
            +
            import time
         | 
| 4 | 
            +
            import random
         | 
| 5 | 
            +
            import pickle
         | 
| 6 | 
            +
            import argparse
         | 
| 7 | 
            +
            import os.path as osp
         | 
| 8 | 
            +
             | 
| 9 | 
            +
            import torch
         | 
| 10 | 
            +
            import torch.utils.data
         | 
| 11 | 
            +
            from torch_geometric.loader import DataLoader
         | 
| 12 | 
            +
             | 
| 13 | 
            +
            import pandas as pd
         | 
| 14 | 
            +
            from tqdm import tqdm
         | 
| 15 | 
            +
             | 
| 16 | 
            +
            from rdkit import RDLogger, Chem
         | 
| 17 | 
            +
            from rdkit.Chem import QED, RDConfig
         | 
| 18 | 
            +
             | 
| 19 | 
            +
            sys.path.append(os.path.join(RDConfig.RDContribDir, 'SA_Score'))
         | 
| 20 | 
            +
            import sascorer
         | 
| 21 | 
            +
             | 
| 22 | 
            +
            from src.util.utils import *
         | 
| 23 | 
            +
            from src.model.models import Generator
         | 
| 24 | 
            +
            from src.data.dataset import DruggenDataset
         | 
| 25 | 
            +
            from src.data.utils import get_encoders_decoders, load_molecules
         | 
| 26 | 
            +
            from src.model.loss import generator_loss
         | 
| 27 | 
            +
            from src.util.smiles_cor import smi_correct
         | 
| 28 | 
            +
             | 
| 29 | 
            +
             | 
| 30 | 
            +
            class Inference(object):
         | 
| 31 | 
            +
                """Inference class for DrugGEN."""
         | 
| 32 | 
            +
             | 
| 33 | 
            +
                def __init__(self, config):
         | 
| 34 | 
            +
                    if config.set_seed:
         | 
| 35 | 
            +
                        np.random.seed(config.seed)
         | 
| 36 | 
            +
                        random.seed(config.seed)
         | 
| 37 | 
            +
                        torch.manual_seed(config.seed)
         | 
| 38 | 
            +
                        torch.cuda.manual_seed_all(config.seed)
         | 
| 39 | 
            +
             | 
| 40 | 
            +
                        torch.backends.cudnn.deterministic = True
         | 
| 41 | 
            +
                        torch.backends.cudnn.benchmark = False
         | 
| 42 | 
            +
             | 
| 43 | 
            +
                        os.environ["PYTHONHASHSEED"] = str(config.seed)
         | 
| 44 | 
            +
             | 
| 45 | 
            +
                        print(f'Using seed {config.seed}')
         | 
| 46 | 
            +
             | 
| 47 | 
            +
                    self.device = torch.device("cuda" if torch.cuda.is_available() else 'cpu')
         | 
| 48 | 
            +
             | 
| 49 | 
            +
                    # Initialize configurations
         | 
| 50 | 
            +
                    self.submodel = config.submodel
         | 
| 51 | 
            +
                    self.inference_model = config.inference_model
         | 
| 52 | 
            +
                    self.sample_num = config.sample_num
         | 
| 53 | 
            +
                    self.disable_correction = config.disable_correction
         | 
| 54 | 
            +
             | 
| 55 | 
            +
                    # Data loader.
         | 
| 56 | 
            +
                    self.inf_smiles = config.inf_smiles  # SMILES containing text file for first dataset. 
         | 
| 57 | 
            +
                                                     # Write the full path to file.
         | 
| 58 | 
            +
                    
         | 
| 59 | 
            +
                    inf_smiles_basename = osp.basename(self.inf_smiles)
         | 
| 60 | 
            +
                    
         | 
| 61 | 
            +
                    # Get the base name without extension and add max_atom to it
         | 
| 62 | 
            +
                    self.max_atom = config.max_atom  # Model is based on one-shot generation.
         | 
| 63 | 
            +
                    inf_smiles_base = os.path.splitext(inf_smiles_basename)[0]
         | 
| 64 | 
            +
                    
         | 
| 65 | 
            +
                    # Change extension from .smi to .pt and add max_atom to the filename
         | 
| 66 | 
            +
                    self.inf_dataset_file = f"{inf_smiles_base}{self.max_atom}.pt"
         | 
| 67 | 
            +
             | 
| 68 | 
            +
                    self.inf_batch_size = config.inf_batch_size
         | 
| 69 | 
            +
                    self.train_smiles = config.train_smiles
         | 
| 70 | 
            +
                    self.train_drug_smiles = config.train_drug_smiles
         | 
| 71 | 
            +
                    self.mol_data_dir = config.mol_data_dir  # Directory where the dataset files are stored.
         | 
| 72 | 
            +
                    self.dataset_name = self.inf_dataset_file.split(".")[0]
         | 
| 73 | 
            +
                    self.features = config.features  # Small model uses atom types as node features. (Boolean, False uses atom types only.)
         | 
| 74 | 
            +
                                                     # Additional node features can be added. Please check new_dataloarder.py Line 102.
         | 
| 75 | 
            +
             | 
| 76 | 
            +
                    # Get atom and bond encoders/decoders
         | 
| 77 | 
            +
                    self.atom_encoder, self.atom_decoder, self.bond_encoder, self.bond_decoder = get_encoders_decoders(
         | 
| 78 | 
            +
                        self.train_smiles,
         | 
| 79 | 
            +
                        self.train_drug_smiles,
         | 
| 80 | 
            +
                        self.max_atom
         | 
| 81 | 
            +
                    )
         | 
| 82 | 
            +
             | 
| 83 | 
            +
                    self.inf_dataset = DruggenDataset(self.mol_data_dir,
         | 
| 84 | 
            +
                                                  self.inf_dataset_file,
         | 
| 85 | 
            +
                                                  self.inf_smiles,
         | 
| 86 | 
            +
                                                  self.max_atom,
         | 
| 87 | 
            +
                                                  self.features,
         | 
| 88 | 
            +
                                                  atom_encoder=self.atom_encoder,
         | 
| 89 | 
            +
                                                  atom_decoder=self.atom_decoder,
         | 
| 90 | 
            +
                                                  bond_encoder=self.bond_encoder,
         | 
| 91 | 
            +
                                                  bond_decoder=self.bond_decoder)
         | 
| 92 | 
            +
             | 
| 93 | 
            +
                    self.inf_loader = DataLoader(self.inf_dataset,
         | 
| 94 | 
            +
                                             shuffle=True,
         | 
| 95 | 
            +
                                             batch_size=self.inf_batch_size,
         | 
| 96 | 
            +
                                             drop_last=True)  # PyG dataloader for the first GAN.
         | 
| 97 | 
            +
             | 
| 98 | 
            +
                    self.m_dim = len(self.atom_decoder) if not self.features else int(self.inf_loader.dataset[0].x.shape[1]) # Atom type dimension.
         | 
| 99 | 
            +
                    self.b_dim = len(self.bond_decoder) # Bond type dimension.
         | 
| 100 | 
            +
                    self.vertexes = int(self.inf_loader.dataset[0].x.shape[0]) # Number of nodes in the graph.
         | 
| 101 | 
            +
             | 
| 102 | 
            +
                    # Model configurations.
         | 
| 103 | 
            +
                    self.act = config.act
         | 
| 104 | 
            +
                    self.dim = config.dim
         | 
| 105 | 
            +
                    self.depth = config.depth
         | 
| 106 | 
            +
                    self.heads = config.heads
         | 
| 107 | 
            +
                    self.mlp_ratio = config.mlp_ratio
         | 
| 108 | 
            +
                    self.dropout = config.dropout
         | 
| 109 | 
            +
             | 
| 110 | 
            +
                    self.build_model()
         | 
| 111 | 
            +
             | 
| 112 | 
            +
                def build_model(self):
         | 
| 113 | 
            +
                    """Create generators and discriminators."""
         | 
| 114 | 
            +
                    self.G = Generator(self.act,
         | 
| 115 | 
            +
                                       self.vertexes,
         | 
| 116 | 
            +
                                       self.b_dim,
         | 
| 117 | 
            +
                                       self.m_dim,
         | 
| 118 | 
            +
                                       self.dropout,
         | 
| 119 | 
            +
                                       dim=self.dim,
         | 
| 120 | 
            +
                                       depth=self.depth,
         | 
| 121 | 
            +
                                       heads=self.heads,
         | 
| 122 | 
            +
                                       mlp_ratio=self.mlp_ratio)
         | 
| 123 | 
            +
                    self.G.to(self.device)
         | 
| 124 | 
            +
                    self.print_network(self.G, 'G')
         | 
| 125 | 
            +
             | 
| 126 | 
            +
                def print_network(self, model, name):
         | 
| 127 | 
            +
                    """Print out the network information."""
         | 
| 128 | 
            +
                    num_params = 0
         | 
| 129 | 
            +
                    for p in model.parameters():
         | 
| 130 | 
            +
                        num_params += p.numel() 
         | 
| 131 | 
            +
                    print(model)
         | 
| 132 | 
            +
                    print(name)
         | 
| 133 | 
            +
                    print("The number of parameters: {}".format(num_params))
         | 
| 134 | 
            +
             | 
| 135 | 
            +
                def restore_model(self, submodel, model_directory):
         | 
| 136 | 
            +
                    """Restore the trained generator and discriminator."""
         | 
| 137 | 
            +
                    print('Loading the model...')
         | 
| 138 | 
            +
                    G_path = os.path.join(model_directory, '{}-G.ckpt'.format(submodel))
         | 
| 139 | 
            +
                    self.G.load_state_dict(torch.load(G_path, map_location=lambda storage, loc: storage))
         | 
| 140 | 
            +
             | 
| 141 | 
            +
                def inference(self):
         | 
| 142 | 
            +
                    # Load the trained generator.
         | 
| 143 | 
            +
                    self.restore_model(self.submodel, self.inference_model)
         | 
| 144 | 
            +
             | 
| 145 | 
            +
                    # smiles data for metrics calculation.
         | 
| 146 | 
            +
                    chembl_smiles = [line for line in open(self.train_smiles, 'r').read().splitlines()]
         | 
| 147 | 
            +
                    chembl_test = [line for line in open(self.inf_smiles, 'r').read().splitlines()]
         | 
| 148 | 
            +
                    drug_smiles = [line for line in open(self.train_drug_smiles, 'r').read().splitlines()]
         | 
| 149 | 
            +
                    drug_mols = [Chem.MolFromSmiles(smi) for smi in drug_smiles]
         | 
| 150 | 
            +
                    drug_vecs = [AllChem.GetMorganFingerprintAsBitVect(x, 2, nBits=1024) for x in drug_mols if x is not None]
         | 
| 151 | 
            +
             | 
| 152 | 
            +
             | 
| 153 | 
            +
                    # Make directories if not exist.
         | 
| 154 | 
            +
                    if not os.path.exists("experiments/inference/{}".format(self.submodel)):
         | 
| 155 | 
            +
                        os.makedirs("experiments/inference/{}".format(self.submodel))
         | 
| 156 | 
            +
             | 
| 157 | 
            +
                    if not self.disable_correction:
         | 
| 158 | 
            +
                        correct = smi_correct(self.submodel, "experiments/inference/{}".format(self.submodel))
         | 
| 159 | 
            +
             | 
| 160 | 
            +
                    search_res = pd.DataFrame(columns=["submodel", "validity",
         | 
| 161 | 
            +
                                                       "uniqueness", "novelty",
         | 
| 162 | 
            +
                                                       "novelty_test", "drug_novelty",
         | 
| 163 | 
            +
                                                       "max_len", "mean_atom_type",
         | 
| 164 | 
            +
                                                       "snn_chembl", "snn_drug", "IntDiv", "qed", "sa"])
         | 
| 165 | 
            +
             | 
| 166 | 
            +
                    self.G.eval()
         | 
| 167 | 
            +
             | 
| 168 | 
            +
                    start_time = time.time()
         | 
| 169 | 
            +
                    metric_calc_dr = []
         | 
| 170 | 
            +
                    uniqueness_calc = []
         | 
| 171 | 
            +
                    real_smiles_snn = []
         | 
| 172 | 
            +
                    nodes_sample = torch.Tensor(size=[1, self.vertexes, 1]).to(self.device)
         | 
| 173 | 
            +
                    generated_smiles = []
         | 
| 174 | 
            +
                    val_counter = 0
         | 
| 175 | 
            +
                    none_counter = 0
         | 
| 176 | 
            +
             | 
| 177 | 
            +
                    # Inference mode
         | 
| 178 | 
            +
                    with torch.inference_mode():
         | 
| 179 | 
            +
                        pbar = tqdm(range(self.sample_num))
         | 
| 180 | 
            +
                        pbar.set_description('Inference mode for {} model started'.format(self.submodel))
         | 
| 181 | 
            +
                        for i, data in enumerate(self.inf_loader):
         | 
| 182 | 
            +
                            val_counter += 1
         | 
| 183 | 
            +
                            # Preprocess dataset 
         | 
| 184 | 
            +
                            _, a_tensor, x_tensor = load_molecules(
         | 
| 185 | 
            +
                                data=data, 
         | 
| 186 | 
            +
                                batch_size=self.inf_batch_size,
         | 
| 187 | 
            +
                                device=self.device,
         | 
| 188 | 
            +
                                b_dim=self.b_dim,
         | 
| 189 | 
            +
                                m_dim=self.m_dim,
         | 
| 190 | 
            +
                            )
         | 
| 191 | 
            +
             | 
| 192 | 
            +
                            _, _, node_sample, edge_sample = self.G(a_tensor, x_tensor)
         | 
| 193 | 
            +
             | 
| 194 | 
            +
                            g_edges_hat_sample = torch.max(edge_sample, -1)[1]
         | 
| 195 | 
            +
                            g_nodes_hat_sample = torch.max(node_sample, -1)[1]
         | 
| 196 | 
            +
             | 
| 197 | 
            +
                            fake_mol_g = [self.inf_dataset.matrices2mol(n_.data.cpu().numpy(), e_.data.cpu().numpy(), strict=False, file_name=self.dataset_name) 
         | 
| 198 | 
            +
                                    for e_, n_ in zip(g_edges_hat_sample, g_nodes_hat_sample)]
         | 
| 199 | 
            +
             | 
| 200 | 
            +
                            a_tensor_sample = torch.max(a_tensor, -1)[1]
         | 
| 201 | 
            +
                            x_tensor_sample = torch.max(x_tensor, -1)[1]
         | 
| 202 | 
            +
                            real_mols = [self.inf_dataset.matrices2mol(n_.data.cpu().numpy(), e_.data.cpu().numpy(), strict=True, file_name=self.dataset_name) 
         | 
| 203 | 
            +
                                    for e_, n_ in zip(a_tensor_sample, x_tensor_sample)]
         | 
| 204 | 
            +
             | 
| 205 | 
            +
                            inference_drugs = [None if line is None else Chem.MolToSmiles(line) for line in fake_mol_g]
         | 
| 206 | 
            +
                            inference_drugs = [None if x is None else max(x.split('.'), key=len) for x in inference_drugs]
         | 
| 207 | 
            +
             | 
| 208 | 
            +
                            for molecules in inference_drugs:
         | 
| 209 | 
            +
                                if molecules is None:
         | 
| 210 | 
            +
                                    none_counter += 1
         | 
| 211 | 
            +
             | 
| 212 | 
            +
                            for molecules in inference_drugs:
         | 
| 213 | 
            +
                                if molecules is not None:
         | 
| 214 | 
            +
                                    molecules = molecules.replace("*", "C")
         | 
| 215 | 
            +
                                    generated_smiles.append(molecules)
         | 
| 216 | 
            +
                                    uniqueness_calc.append(molecules)
         | 
| 217 | 
            +
                                    nodes_sample = torch.cat((nodes_sample, g_nodes_hat_sample.view(1, self.vertexes, 1)), 0)
         | 
| 218 | 
            +
                                    pbar.update(1)
         | 
| 219 | 
            +
                                metric_calc_dr.append(molecules)
         | 
| 220 | 
            +
             | 
| 221 | 
            +
                            real_smiles_snn.append(real_mols[0])
         | 
| 222 | 
            +
                            generation_number = len([x for x in metric_calc_dr if x is not None])
         | 
| 223 | 
            +
                            if generation_number == self.sample_num or none_counter == self.sample_num:
         | 
| 224 | 
            +
                                break
         | 
| 225 | 
            +
             | 
| 226 | 
            +
                    if not self.disable_correction:
         | 
| 227 | 
            +
                        correct = smi_correct(self.submodel, "experiments/inference/{}".format(self.submodel))
         | 
| 228 | 
            +
                        gen_smi = correct.correct_smiles_list(generated_smiles)
         | 
| 229 | 
            +
                    else:
         | 
| 230 | 
            +
                        gen_smi = generated_smiles
         | 
| 231 | 
            +
             | 
| 232 | 
            +
                    et = time.time() - start_time
         | 
| 233 | 
            +
                    
         | 
| 234 | 
            +
                    gen_vecs = [AllChem.GetMorganFingerprintAsBitVect(Chem.MolFromSmiles(x), 2, nBits=1024) for x in uniqueness_calc if Chem.MolFromSmiles(x) is not None]
         | 
| 235 | 
            +
                    real_vecs = [AllChem.GetMorganFingerprintAsBitVect(x, 2, nBits=1024) for x in real_smiles_snn if x is not None]
         | 
| 236 | 
            +
             | 
| 237 | 
            +
                    if not self.disable_correction:
         | 
| 238 | 
            +
                        val = round(len(gen_smi)/self.sample_num, 3)
         | 
| 239 | 
            +
                    else: 
         | 
| 240 | 
            +
                        val = round(fraction_valid(gen_smi), 3)
         | 
| 241 | 
            +
             | 
| 242 | 
            +
                    uniq = round(fraction_unique(gen_smi), 3)
         | 
| 243 | 
            +
                    nov = round(novelty(gen_smi, chembl_smiles), 3)
         | 
| 244 | 
            +
                    nov_test = round(novelty(gen_smi, chembl_test), 3)
         | 
| 245 | 
            +
                    drug_nov = round(novelty(gen_smi, drug_smiles), 3)
         | 
| 246 | 
            +
                    max_len = round(Metrics.max_component(gen_smi, self.vertexes), 3)
         | 
| 247 | 
            +
                    mean_atom = round(Metrics.mean_atom_type(nodes_sample), 3)
         | 
| 248 | 
            +
                    snn_chembl = round(average_agg_tanimoto(np.array(real_vecs), np.array(gen_vecs)), 3)
         | 
| 249 | 
            +
                    snn_drug = round(average_agg_tanimoto(np.array(drug_vecs), np.array(gen_vecs)), 3)
         | 
| 250 | 
            +
                    int_div = round((internal_diversity(np.array(gen_vecs)))[0], 3)
         | 
| 251 | 
            +
                    qed = round(np.mean([QED.qed(Chem.MolFromSmiles(x)) for x in gen_smi if Chem.MolFromSmiles(x) is not None]), 3)
         | 
| 252 | 
            +
                    sa = round(np.mean([sascorer.calculateScore(Chem.MolFromSmiles(x)) for x in gen_smi if Chem.MolFromSmiles(x) is not None]), 3)
         | 
| 253 | 
            +
             | 
| 254 | 
            +
                    model_res = pd.DataFrame({"submodel": [self.submodel], "validity": [val],
         | 
| 255 | 
            +
                                    "uniqueness": [uniq], "novelty": [nov],
         | 
| 256 | 
            +
                                    "novelty_test": [nov_test], "drug_novelty": [drug_nov],
         | 
| 257 | 
            +
                                    "max_len": [max_len], "mean_atom_type": [mean_atom],
         | 
| 258 | 
            +
                                    "snn_chembl": [snn_chembl], "snn_drug": [snn_drug], 
         | 
| 259 | 
            +
                                    "IntDiv": [int_div], "qed": [qed], "sa": [sa]})
         | 
| 260 | 
            +
             | 
| 261 | 
            +
                    # Write generated SMILES to a temporary file for app.py to use
         | 
| 262 | 
            +
                    temp_file = f'{self.submodel}_denovo_mols.smi'
         | 
| 263 | 
            +
                    with open(temp_file, 'w') as f:
         | 
| 264 | 
            +
                        f.write("SMILES\n")
         | 
| 265 | 
            +
                        for smiles in gen_smi:
         | 
| 266 | 
            +
                            f.write(f"{smiles}\n")
         | 
| 267 | 
            +
             | 
| 268 | 
            +
                    return model_res
         | 
| 269 | 
            +

if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    # Inference configuration.
    parser.add_argument('--submodel', type=str, default="DrugGEN", help="Choose model subtype: DrugGEN, NoTarget", choices=['DrugGEN', 'NoTarget'])
    parser.add_argument('--inference_model', type=str, help="Path to the model for inference")
    parser.add_argument('--sample_num', type=int, default=100, help='Number of inference samples')
    parser.add_argument('--disable_correction', action='store_true', help='Disable SMILES correction')

    # Data configuration.
    parser.add_argument('--inf_smiles', type=str, required=True)
    parser.add_argument('--train_smiles', type=str, required=True)
    parser.add_argument('--train_drug_smiles', type=str, required=True)
    parser.add_argument('--inf_batch_size', type=int, default=1, help='Batch size for inference')
    parser.add_argument('--mol_data_dir', type=str, default='data')
    parser.add_argument('--features', action='store_true', help='Use additional node features')

    # Model configuration.
    parser.add_argument('--act', type=str, default="relu", help="Activation function for the model.", choices=['relu', 'tanh', 'leaky', 'sigmoid'])
    parser.add_argument('--max_atom', type=int, default=45, help='Maximum number of atoms per molecule (must be specified).')
    parser.add_argument('--dim', type=int, default=128, help='Dimension of the Transformer Encoder model for the GAN.')
    parser.add_argument('--depth', type=int, default=1, help='Depth of the Transformer model from the GAN.')
    parser.add_argument('--heads', type=int, default=8, help='Number of heads for the MultiHeadAttention module from the GAN.')
    parser.add_argument('--mlp_ratio', type=int, default=3, help='MLP ratio for the Transformer.')
    parser.add_argument('--dropout', type=float, default=0., help='Dropout rate')

    # Seed configuration.
    parser.add_argument('--set_seed', action='store_true', help='Set seed for reproducibility')
    parser.add_argument('--seed', type=int, default=1, help='Seed for reproducibility')

    config = parser.parse_args()
    inference = Inference(config)
    inference.inference()
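For reference, a typical invocation of the script might look like the following; the checkpoint and SMILES paths are placeholders for illustration, not files guaranteed to ship with the Space:

python inference.py \
    --submodel DrugGEN \
    --inference_model checkpoints/DrugGEN_akt1.pt \
    --inf_smiles data/chembl_test.smi \
    --train_smiles data/chembl_train.smi \
    --train_drug_smiles data/akt_train.smi \
    --sample_num 100 \
    --max_atom 45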
    	
    src/__init__.py
ADDED
    File without changes

    src/__pycache__/__init__.cpython-310.pyc
ADDED
    Binary file (150 Bytes)

    src/data/__init__.py
ADDED
    File without changes

    src/data/__pycache__/__init__.cpython-310.pyc
ADDED
    Binary file (155 Bytes)

    src/data/__pycache__/dataset.cpython-310.pyc
ADDED
    Binary file (12.9 kB)

    src/data/__pycache__/utils.cpython-310.pyc
ADDED
    Binary file (4.75 kB)

    src/data/dataset.py
ADDED
    @@ -0,0 +1,317 @@
import os
import os.path as osp
import re
import pickle

import numpy as np
import pandas as pd
from tqdm import tqdm

import torch
from torch_geometric.data import Data, InMemoryDataset

from rdkit import Chem, RDLogger

from src.data.utils import label2onehot

RDLogger.DisableLog('rdApp.*')


class DruggenDataset(InMemoryDataset):
    def __init__(self, root, dataset_file, raw_files, max_atom, features,
                 atom_encoder, atom_decoder, bond_encoder, bond_decoder,
                 transform=None, pre_transform=None, pre_filter=None):
        """
        Initialize the DruggenDataset with pre-loaded encoder/decoder dictionaries.

        Parameters:
            root (str): Root directory.
            dataset_file (str): Name of the processed dataset file.
            raw_files (str): Path to the raw SMILES file.
            max_atom (int): Maximum number of atoms allowed in a molecule.
            features (bool): Whether to include additional node features.
            atom_encoder (dict): Pre-loaded atom encoder dictionary.
            atom_decoder (dict): Pre-loaded atom decoder dictionary.
            bond_encoder (dict): Pre-loaded bond encoder dictionary.
            bond_decoder (dict): Pre-loaded bond decoder dictionary.
            transform, pre_transform, pre_filter: See PyG InMemoryDataset.
        """
        self.dataset_name = dataset_file.split(".")[0]
        self.dataset_file = dataset_file
        self.raw_files = raw_files
        self.max_atom = max_atom
        self.features = features

        # Use the provided encoder/decoder mappings.
        self.atom_encoder_m = atom_encoder
        self.atom_decoder_m = atom_decoder
        self.bond_encoder_m = bond_encoder
        self.bond_decoder_m = bond_decoder

        self.atom_num_types = len(atom_encoder)
        self.bond_num_types = len(bond_encoder)

        super().__init__(root, transform, pre_transform, pre_filter)
        path = osp.join(self.processed_dir, dataset_file)
        self.data, self.slices = torch.load(path)
        self.root = root

    @property
    def processed_dir(self):
        """
        Returns the directory where processed dataset files are stored.
        """
        return self.root

    @property
    def raw_file_names(self):
        """
        Returns the raw SMILES file name.
        """
        return self.raw_files

    @property
    def processed_file_names(self):
        """
        Returns the name of the processed dataset file.
        """
        return self.dataset_file

    def _filter_smiles(self, smiles_list):
        """
        Filters the input list of SMILES strings to keep only valid molecules that:
         - Can be successfully parsed,
         - Have a number of atoms less than or equal to the maximum allowed (max_atom),
         - Contain only atoms present in the atom_encoder,
         - Contain only bonds present in the bond_encoder.

        Parameters:
            smiles_list (list): List of SMILES strings.

        Returns:
            max_length (int): Maximum number of atoms found in the filtered molecules.
            filtered_smiles (list): List of valid SMILES strings.
        """
        max_length = 0
        filtered_smiles = []
        for smiles in tqdm(smiles_list, desc="Filtering SMILES"):
            mol = Chem.MolFromSmiles(smiles)
            if mol is None:
                continue

            # Check molecule size
            molecule_size = mol.GetNumAtoms()
            if molecule_size > self.max_atom:
                continue

            # Filter out molecules with atoms not in the atom_encoder
            if not all(atom.GetAtomicNum() in self.atom_encoder_m for atom in mol.GetAtoms()):
                continue

            # Filter out molecules with bonds not in the bond_encoder
            if not all(bond.GetBondType() in self.bond_encoder_m for bond in mol.GetBonds()):
                continue

            filtered_smiles.append(smiles)
            max_length = max(max_length, molecule_size)
        return max_length, filtered_smiles

    def _genA(self, mol, connected=True, max_length=None):
        """
        Generates the adjacency matrix for a molecule based on its bond structure.

        Parameters:
            mol (rdkit.Chem.Mol): The molecule.
            connected (bool): If True, ensures all atoms are connected.
            max_length (int, optional): The size of the matrix; if None, uses number of atoms in mol.

        Returns:
            np.array: Adjacency matrix with bond types as entries, or None if disconnected.
        """
        max_length = max_length if max_length is not None else mol.GetNumAtoms()
        A = np.zeros((max_length, max_length))
        begin = [b.GetBeginAtomIdx() for b in mol.GetBonds()]
        end = [b.GetEndAtomIdx() for b in mol.GetBonds()]
        bond_type = [self.bond_encoder_m[b.GetBondType()] for b in mol.GetBonds()]
        A[begin, end] = bond_type
        A[end, begin] = bond_type
        degree = np.sum(A[:mol.GetNumAtoms(), :mol.GetNumAtoms()], axis=-1)
        return A if connected and (degree > 0).all() else None

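    # Illustrative note (added for this write-up, assuming single bonds encode to 1):
    # for ethanol "CCO" with max_length=4, _genA returns
    #     [[0., 1., 0., 0.],
    #      [1., 0., 1., 0.],
    #      [0., 1., 0., 0.],
    #      [0., 0., 0., 0.]]
    # Every real atom has a non-zero degree here, so the padded matrix is returned rather than None.
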
    def _genX(self, mol, max_length=None):
        """
        Generates the feature vector for each atom in a molecule by encoding their atomic numbers.

        Parameters:
            mol (rdkit.Chem.Mol): The molecule.
            max_length (int, optional): Length of the feature vector; if None, uses number of atoms in mol.

        Returns:
            np.array: Array of atom feature indices, padded with zeros if necessary, or None on error.
        """
        max_length = max_length if max_length is not None else mol.GetNumAtoms()
        try:
            return np.array([self.atom_encoder_m[atom.GetAtomicNum()] for atom in mol.GetAtoms()] +
                            [0] * (max_length - mol.GetNumAtoms()))
        except KeyError as e:
            print(f"Skipping molecule with unsupported atom: {e}")
            print(f"Skipped SMILES: {Chem.MolToSmiles(mol)}")
            return None

    def _genF(self, mol, max_length=None):
        """
        Generates additional node features for a molecule using various atomic properties.

        Parameters:
            mol (rdkit.Chem.Mol): The molecule.
            max_length (int, optional): Number of rows in the features matrix; if None, uses number of atoms.

        Returns:
            np.array: Array of additional features for each atom, padded with zeros if necessary.
        """
        max_length = max_length if max_length is not None else mol.GetNumAtoms()
        features = np.array([[*[a.GetDegree() == i for i in range(5)],
                               *[a.GetExplicitValence() == i for i in range(9)],
                               *[int(a.GetHybridization()) == i for i in range(1, 7)],
                               *[a.GetImplicitValence() == i for i in range(9)],
                               a.GetIsAromatic(),
                               a.GetNoImplicit(),
                               *[a.GetNumExplicitHs() == i for i in range(5)],
                               *[a.GetNumImplicitHs() == i for i in range(5)],
                               *[a.GetNumRadicalElectrons() == i for i in range(5)],
                               a.IsInRing(),
                               *[a.IsInRingSize(i) for i in range(2, 9)]]
                              for a in mol.GetAtoms()], dtype=np.int32)
        return np.vstack((features, np.zeros((max_length - features.shape[0], features.shape[1]))))

    def decoder_load(self, dictionary_name, file):
        """
        Returns the pre-loaded decoder dictionary based on the dictionary name.

        Parameters:
            dictionary_name (str): Name of the dictionary ("atom" or "bond").
            file: Placeholder parameter for compatibility.

        Returns:
            dict: The corresponding decoder dictionary.
        """
        if dictionary_name == "atom":
            return self.atom_decoder_m
        elif dictionary_name == "bond":
            return self.bond_decoder_m
        else:
            raise ValueError("Unknown dictionary name.")

    def matrices2mol(self, node_labels, edge_labels, strict=True, file_name=None):
        """
        Converts graph representations (node labels and edge labels) back to an RDKit molecule.

        Parameters:
            node_labels (iterable): Encoded atom labels.
            edge_labels (np.array): Adjacency matrix with encoded bond types.
            strict (bool): If True, sanitizes the molecule and returns None on failure.
            file_name: Placeholder parameter for compatibility.

        Returns:
            rdkit.Chem.Mol: The resulting molecule, or None if sanitization fails.
        """
        mol = Chem.RWMol()
        for node_label in node_labels:
            mol.AddAtom(Chem.Atom(self.atom_decoder_m[node_label]))
        for start, end in zip(*np.nonzero(edge_labels)):
            if start > end:
                mol.AddBond(int(start), int(end), self.bond_decoder_m[edge_labels[start, end]])
        if strict:
            try:
                Chem.SanitizeMol(mol)
            except Exception:
                mol = None
        return mol

    def check_valency(self, mol):
        """
        Checks that no atom in the molecule has exceeded its allowed valency.

        Parameters:
            mol (rdkit.Chem.Mol): The molecule.

        Returns:
            tuple: (True, None) if valid; (False, atomid_valence) if there is a valency issue.
        """
        try:
            Chem.SanitizeMol(mol, sanitizeOps=Chem.SanitizeFlags.SANITIZE_PROPERTIES)
            return True, None
        except ValueError as e:
            e = str(e)
            p = e.find('#')
            e_sub = e[p:]
            atomid_valence = list(map(int, re.findall(r'\d+', e_sub)))
            return False, atomid_valence

    def correct_mol(self, mol):
        """
        Corrects a molecule by removing bonds until all atoms satisfy their valency limits.

        Parameters:
            mol (rdkit.Chem.Mol): The molecule.

        Returns:
            rdkit.Chem.Mol: The corrected molecule.
        """
        while True:
            flag, atomid_valence = self.check_valency(mol)
            if flag:
                break
            else:
                # Expecting two numbers: atom index and its valence.
                assert len(atomid_valence) == 2
                idx = atomid_valence[0]
                queue = []
                for b in mol.GetAtomWithIdx(idx).GetBonds():
                    queue.append((b.GetIdx(), int(b.GetBondType()), b.GetBeginAtomIdx(), b.GetEndAtomIdx()))
                queue.sort(key=lambda tup: tup[1], reverse=True)
                if queue:
                    start = queue[0][2]
                    end = queue[0][3]
                    mol.RemoveBond(start, end)
        return mol

    def process(self, size=None):
        """
        Processes the raw SMILES file by filtering and converting each valid SMILES into a PyTorch Geometric Data object.
        The resulting dataset is saved to disk.

        Parameters:
            size (optional): Placeholder parameter for compatibility.

        Side Effects:
            Saves the processed dataset as a file in the processed directory.
        """
        # Read raw SMILES from file (assuming CSV with no header)
        smiles_list = pd.read_csv(self.raw_files, header=None)[0].tolist()
        max_length, filtered_smiles = self._filter_smiles(smiles_list)
        data_list = []
        self.m_dim = len(self.atom_decoder_m)
        for smiles in tqdm(filtered_smiles, desc='Processing dataset', total=len(filtered_smiles)):
            mol = Chem.MolFromSmiles(smiles)
            A = self._genA(mol, connected=True, max_length=max_length)
            if A is not None:
                x_array = self._genX(mol, max_length=max_length)
                if x_array is None:
                    continue
                x = torch.from_numpy(x_array).to(torch.long).view(1, -1)
                x = label2onehot(x, self.m_dim).squeeze()
                if self.features:
                    f = torch.from_numpy(self._genF(mol, max_length=max_length)).to(torch.long).view(x.shape[0], -1)
                    x = torch.concat((x, f), dim=-1)
                adjacency = torch.from_numpy(A)
                edge_index = adjacency.nonzero(as_tuple=False).t().contiguous()
                edge_attr = adjacency[edge_index[0], edge_index[1]].to(torch.long)
                data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, smiles=smiles)
                if self.pre_filter is not None and not self.pre_filter(data):
                    continue
                if self.pre_transform is not None:
                    data = self.pre_transform(data)
                data_list.append(data)
        torch.save(self.collate(data_list), osp.join(self.processed_dir, self.dataset_file))
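To show how the pieces above fit together, here is a minimal sketch of constructing the dataset with the encoder/decoder helpers from src/data/utils.py; the SMILES file names are hypothetical placeholders, and this is illustrative rather than a copy of the repository's training code:

from torch_geometric.loader import DataLoader

from src.data.dataset import DruggenDataset
from src.data.utils import get_encoders_decoders

train_smiles = "data/chembl_train.smi"   # hypothetical path, one SMILES per line
drug_smiles = "data/akt_train.smi"       # hypothetical path

atom_enc, atom_dec, bond_enc, bond_dec = get_encoders_decoders(train_smiles, drug_smiles, max_atom=45)

dataset = DruggenDataset(
    root="data",
    dataset_file="chembl_train.pt",
    raw_files=train_smiles,
    max_atom=45,
    features=False,
    atom_encoder=atom_enc, atom_decoder=atom_dec,
    bond_encoder=bond_enc, bond_decoder=bond_dec,
)
loader = DataLoader(dataset, batch_size=128, shuffle=True, drop_last=True)

Using drop_last keeps every batch at the full batch size, which the dense-tensor conversion in load_molecules below assumes.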
    	
    src/data/utils.py
ADDED
    @@ -0,0 +1,143 @@
import os
import pickle

import pandas as pd
from tqdm import tqdm

import torch
from torch_geometric.data import Data, InMemoryDataset
import torch_geometric.utils as geoutils

from rdkit import Chem, RDLogger


def label2onehot(labels, dim, device=None):
    """Convert label indices to one-hot vectors."""
    out = torch.zeros(list(labels.size()) + [dim])
    if device:
        out = out.to(device)

    out.scatter_(len(out.size()) - 1, labels.unsqueeze(-1), 1.)

    return out.float()

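# Illustrative example (added for this write-up):
#   label2onehot(torch.tensor([[0, 2, 1]]), 3)
# returns a (1, 3, 3) tensor
#   [[[1., 0., 0.],
#     [0., 0., 1.],
#     [0., 1., 0.]]]
# i.e. the last dimension becomes the one-hot axis.
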
def get_encoders_decoders(raw_file1, raw_file2, max_atom):
    """
    Given two raw SMILES files, either load the atom and bond encoders/decoders
    if they exist (naming them based on the file names) or create and save them.

    Parameters:
        raw_file1 (str): Path to the first SMILES file.
        raw_file2 (str): Path to the second SMILES file.
        max_atom (int): Maximum allowed number of atoms in a molecule.

    Returns:
        atom_encoder (dict): Mapping from atomic numbers to indices.
        atom_decoder (dict): Mapping from indices to atomic numbers.
        bond_encoder (dict): Mapping from bond types to indices.
        bond_decoder (dict): Mapping from indices to bond types.
    """
    # Determine a unique suffix based on the two file names (alphabetically sorted for consistency)
    name1 = os.path.splitext(os.path.basename(raw_file1))[0]
    name2 = os.path.splitext(os.path.basename(raw_file2))[0]
    sorted_names = sorted([name1, name2])
    suffix = f"{sorted_names[0]}_{sorted_names[1]}"

    # Define encoder/decoder directories and file paths
    enc_dir = os.path.join("data", "encoders")
    dec_dir = os.path.join("data", "decoders")
    atom_encoder_path = os.path.join(enc_dir, f"atom_{suffix}.pkl")
    atom_decoder_path = os.path.join(dec_dir, f"atom_{suffix}.pkl")
    bond_encoder_path = os.path.join(enc_dir, f"bond_{suffix}.pkl")
    bond_decoder_path = os.path.join(dec_dir, f"bond_{suffix}.pkl")

    # If all files exist, load and return them
    if (os.path.exists(atom_encoder_path) and os.path.exists(atom_decoder_path) and
        os.path.exists(bond_encoder_path) and os.path.exists(bond_decoder_path)):
        with open(atom_encoder_path, "rb") as f:
            atom_encoder = pickle.load(f)
        with open(atom_decoder_path, "rb") as f:
            atom_decoder = pickle.load(f)
        with open(bond_encoder_path, "rb") as f:
            bond_encoder = pickle.load(f)
        with open(bond_decoder_path, "rb") as f:
            bond_decoder = pickle.load(f)
        print("Loaded existing encoders/decoders!")
        return atom_encoder, atom_decoder, bond_encoder, bond_decoder

    # Otherwise, create the encoders/decoders
    print("Creating new encoders/decoders...")
    # Read SMILES from both files (assuming one SMILES per row, no header)
    smiles1 = pd.read_csv(raw_file1, header=None)[0].tolist()
    smiles2 = pd.read_csv(raw_file2, header=None)[0].tolist()
    smiles_combined = smiles1 + smiles2

    atom_labels = set()
    bond_labels = set()
    max_length = 0
    filtered_smiles = []

    # Process each SMILES: keep only valid molecules with <= max_atom atoms
    for smiles in tqdm(smiles_combined, desc="Processing SMILES"):
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            continue
        molecule_size = mol.GetNumAtoms()
        if molecule_size > max_atom:
            continue
        filtered_smiles.append(smiles)
        # Collect atomic numbers
        atom_labels.update([atom.GetAtomicNum() for atom in mol.GetAtoms()])
        max_length = max(max_length, molecule_size)
        # Collect bond types
        bond_labels.update([bond.GetBondType() for bond in mol.GetBonds()])

    # Add a PAD symbol (here using 0 for atoms)
    atom_labels.add(0)
    atom_labels = sorted(atom_labels)

    # For bonds, prepend the PAD bond type (using RDKit's BondType.ZERO)
    bond_labels = sorted(bond_labels)
    bond_labels = [Chem.rdchem.BondType.ZERO] + bond_labels

    # Create encoder and decoder dictionaries
    atom_encoder = {l: i for i, l in enumerate(atom_labels)}
    atom_decoder = {i: l for i, l in enumerate(atom_labels)}
    bond_encoder = {l: i for i, l in enumerate(bond_labels)}
    bond_decoder = {i: l for i, l in enumerate(bond_labels)}

    # Ensure directories exist
    os.makedirs(enc_dir, exist_ok=True)
    os.makedirs(dec_dir, exist_ok=True)

    # Save the encoders/decoders to disk
    with open(atom_encoder_path, "wb") as f:
        pickle.dump(atom_encoder, f)
    with open(atom_decoder_path, "wb") as f:
        pickle.dump(atom_decoder, f)
    with open(bond_encoder_path, "wb") as f:
        pickle.dump(bond_encoder, f)
    with open(bond_decoder_path, "wb") as f:
        pickle.dump(bond_decoder, f)

    print("Encoders/decoders created and saved.")
    return atom_encoder, atom_decoder, bond_encoder, bond_decoder

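# Illustrative usage (hypothetical file names, added for this write-up):
#   atom_enc, atom_dec, bond_enc, bond_dec = get_encoders_decoders(
#       "data/chembl_train.smi", "data/akt_train.smi", max_atom=45)
# The pickled dictionaries are written under data/encoders/ and data/decoders/
# with a suffix derived from the two file names.
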
def load_molecules(data=None, b_dim=32, m_dim=32, device=None, batch_size=32):
    """Convert a PyG batch into dense node/adjacency tensors and a flattened per-graph vector."""
    data = data.to(device)
    a = geoutils.to_dense_adj(
        edge_index=data.edge_index,
        batch=data.batch,
        edge_attr=data.edge_attr,
        max_num_nodes=int(data.batch.shape[0] / batch_size)
    )
    x_tensor = data.x.view(batch_size, int(data.batch.shape[0] / batch_size), -1)
    a_tensor = label2onehot(a, b_dim, device)

    a_tensor_vec = a_tensor.reshape(batch_size, -1)
    x_tensor_vec = x_tensor.reshape(batch_size, -1)
    real_graphs = torch.concat((x_tensor_vec, a_tensor_vec), dim=-1)

    return real_graphs, a_tensor, x_tensor
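As a rough shape walk-through of load_molecules (illustrative numbers, assuming every graph in the batch is padded to the same node count):

# Hypothetical batch of 32 graphs, 45 nodes each, 5 bond types (b_dim) and 10 atom types:
#   data.x        -> (32 * 45, 10)          one-hot node features from the dataset
#   a             -> (32, 45, 45)           dense integer bond labels (to_dense_adj)
#   a_tensor      -> (32, 45, 45, 5)        one-hot bond labels (label2onehot)
#   x_tensor      -> (32, 45, 10)
#   real_graphs   -> (32, 45*10 + 45*45*5)  flattened node + edge features per graph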
    	
    src/model/__init__.py
ADDED
    File without changes

    src/model/__pycache__/__init__.cpython-310.pyc
ADDED
    Binary file (156 Bytes)

    src/model/__pycache__/layers.cpython-310.pyc
ADDED
    Binary file (8.31 kB)

    src/model/__pycache__/loss.cpython-310.pyc
ADDED
    Binary file (2.04 kB)

    src/model/__pycache__/models.cpython-310.pyc
ADDED
    Binary file (7.35 kB)

    src/model/layers.py
ADDED
    @@ -0,0 +1,234 @@
import math

import torch
import torch.nn as nn
from torch.nn import functional as F


class MLP(nn.Module):
    """
    A simple Multi-Layer Perceptron (MLP) module consisting of two linear layers with a ReLU activation in between,
    followed by a dropout on the output.

    Attributes:
        fc1 (nn.Linear): The first fully-connected layer.
        act (nn.ReLU): ReLU activation function.
        fc2 (nn.Linear): The second fully-connected layer.
        droprateout (nn.Dropout): Dropout layer applied to the output.
    """
    def __init__(self, in_feat, hid_feat=None, out_feat=None, dropout=0.):
        """
        Initializes the MLP module.

        Args:
            in_feat (int): Number of input features.
            hid_feat (int, optional): Number of hidden features. Defaults to in_feat if not provided.
            out_feat (int, optional): Number of output features. Defaults to in_feat if not provided.
            dropout (float, optional): Dropout rate. Defaults to 0.
        """
        super().__init__()

        # Set hidden and output dimensions to input dimension if not specified
        if not hid_feat:
            hid_feat = in_feat
        if not out_feat:
            out_feat = in_feat

        self.fc1 = nn.Linear(in_feat, hid_feat)
        self.act = nn.ReLU()
        self.fc2 = nn.Linear(hid_feat, out_feat)
        self.droprateout = nn.Dropout(dropout)

    def forward(self, x):
        """
        Forward pass for the MLP.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Output tensor after applying the linear layers, activation, and dropout.
        """
        x = self.fc1(x)
        x = self.act(x)
        x = self.fc2(x)
        return self.droprateout(x)


class MHA(nn.Module):
    """
    Multi-Head Attention (MHA) module of the graph transformer with edge features incorporated into the attention computation.

    Attributes:
        heads (int): Number of attention heads.
        scale (float): Scaling factor for the attention scores.
        q, k, v (nn.Linear): Linear layers to project the node features into query, key, and value embeddings.
        e (nn.Linear): Linear layer to project the edge features.
        d_k (int): Dimension of each attention head.
        out_e (nn.Linear): Linear layer applied to the computed edge features.
        out_n (nn.Linear): Linear layer applied to the aggregated node features.
    """
    def __init__(self, dim, heads, attention_dropout=0.):
        """
        Initializes the Multi-Head Attention module.

        Args:
            dim (int): Dimensionality of the input features.
            heads (int): Number of attention heads.
            attention_dropout (float, optional): Dropout rate for attention (not used explicitly in this implementation).
        """
        super().__init__()

        # Ensure that the dimension is divisible by the number of heads
        assert dim % heads == 0

        self.heads = heads
        self.scale = 1. / math.sqrt(dim)  # Scaling factor for attention
        # Linear layers for projecting node features
        self.q = nn.Linear(dim, dim)
        self.k = nn.Linear(dim, dim)
        self.v = nn.Linear(dim, dim)
        # Linear layer for projecting edge features
        self.e = nn.Linear(dim, dim)
        self.d_k = dim // heads  # Dimension per head

        # Linear layers for output transformations
        self.out_e = nn.Linear(dim, dim)
        self.out_n = nn.Linear(dim, dim)

    def forward(self, node, edge):
        """
        Forward pass for the Multi-Head Attention.

        Args:
            node (torch.Tensor): Node feature tensor of shape (batch, num_nodes, dim).
            edge (torch.Tensor): Edge feature tensor of shape (batch, num_nodes, num_nodes, dim).

        Returns:
            tuple: (updated node features, updated edge features)
        """
        b, n, c = node.shape

        # Compute query, key, and value embeddings and reshape for multi-head attention
        q_embed = self.q(node).view(b, n, self.heads, c // self.heads)
        k_embed = self.k(node).view(b, n, self.heads, c // self.heads)
        v_embed = self.v(node).view(b, n, self.heads, c // self.heads)

        # Compute edge embeddings
        e_embed = self.e(edge).view(b, n, n, self.heads, c // self.heads)

        # Adjust dimensions for broadcasting: add singleton dimensions to queries and keys
        q_embed = q_embed.unsqueeze(2)  # Shape: (b, n, 1, heads, c//heads)
        k_embed = k_embed.unsqueeze(1)  # Shape: (b, 1, n, heads, c//heads)

        # Compute attention scores
        attn = q_embed * k_embed
        attn = attn / math.sqrt(self.d_k)
        attn = attn * (e_embed + 1) * e_embed   # Modulated attention incorporating edge features

        edge_out = self.out_e(attn.flatten(3))  # Flatten last dimension for linear layer

        # Apply softmax over the node dimension to obtain normalized attention weights
        attn = F.softmax(attn, dim=2)

        v_embed = v_embed.unsqueeze(1)  # Adjust dimensions to broadcast: (b, 1, n, heads, c//heads)
        v_embed = attn * v_embed
        v_embed = v_embed.sum(dim=2).flatten(2)
        node_out = self.out_n(v_embed)

        return node_out, edge_out

| 139 | 
            +
            class Encoder_Block(nn.Module):
         | 
| 140 | 
            +
                """
         | 
| 141 | 
            +
                Transformer encoder block that integrates node and edge features.
         | 
| 142 | 
            +
                
         | 
| 143 | 
            +
                Consists of:
         | 
| 144 | 
            +
                    - A multi-head attention layer with edge modulation.
         | 
| 145 | 
            +
                    - Two MLP layers, each with residual connections and layer normalization.
         | 
| 146 | 
            +
                
         | 
| 147 | 
            +
                Attributes:
         | 
| 148 | 
            +
                    ln1, ln3, ln4, ln5, ln6 (nn.LayerNorm): Layer normalization modules.
         | 
| 149 | 
            +
                    attn (MHA): Multi-head attention module.
         | 
| 150 | 
            +
                    mlp, mlp2 (MLP): MLP modules for further transformation of node and edge features.
         | 
| 151 | 
            +
                """
         | 
| 152 | 
            +
                def __init__(self, dim, heads, act, mlp_ratio=4, drop_rate=0.):
         | 
| 153 | 
            +
                    """
         | 
| 154 | 
            +
                    Initializes the encoder block.
         | 
| 155 | 
            +
             | 
| 156 | 
            +
                    Args:
         | 
| 157 | 
            +
                        dim (int): Dimensionality of the input features.
         | 
| 158 | 
            +
                        heads (int): Number of attention heads.
         | 
| 159 | 
            +
                        act (callable): Activation function (not explicitly used in this block, but provided for potential extensions).
         | 
| 160 | 
            +
                        mlp_ratio (int, optional): Ratio to determine the hidden layer size in the MLP. Defaults to 4.
         | 
| 161 | 
            +
                        drop_rate (float, optional): Dropout rate applied in the MLPs. Defaults to 0.
         | 
| 162 | 
            +
                    """
         | 
| 163 | 
            +
                    super().__init__()
         | 
| 164 | 
            +
             | 
| 165 | 
            +
                    self.ln1 = nn.LayerNorm(dim)
         | 
| 166 | 
            +
                    self.attn = MHA(dim, heads, drop_rate)
         | 
| 167 | 
            +
                    self.ln3 = nn.LayerNorm(dim)
         | 
| 168 | 
            +
                    self.ln4 = nn.LayerNorm(dim)
         | 
| 169 | 
            +
                    self.mlp = MLP(dim, dim * mlp_ratio, dim, dropout=drop_rate)
         | 
| 170 | 
            +
                    self.mlp2 = MLP(dim, dim * mlp_ratio, dim, dropout=drop_rate)
         | 
| 171 | 
            +
                    self.ln5 = nn.LayerNorm(dim)
         | 
| 172 | 
            +
                    self.ln6 = nn.LayerNorm(dim)
         | 
| 173 | 
            +
             | 
| 174 | 
            +
                def forward(self, x, y):
         | 
| 175 | 
            +
                    """
         | 
| 176 | 
            +
                    Forward pass of the encoder block.
         | 
| 177 | 
            +
             | 
| 178 | 
            +
                    Args:
         | 
| 179 | 
            +
                        x (torch.Tensor): Node feature tensor.
         | 
| 180 | 
            +
                        y (torch.Tensor): Edge feature tensor.
         | 
| 181 | 
            +
             | 
| 182 | 
            +
                    Returns:
         | 
| 183 | 
            +
                        tuple: (updated node features, updated edge features)
         | 
| 184 | 
            +
                    """
         | 
| 185 | 
            +
                    x1 = self.ln1(x)
         | 
| 186 | 
            +
                    x2, y1 = self.attn(x1, y)
         | 
| 187 | 
            +
                    x2 = x1 + x2
         | 
| 188 | 
            +
                    y2 = y + y1
         | 
| 189 | 
            +
                    x2 = self.ln3(x2)
         | 
| 190 | 
            +
                    y2 = self.ln4(y2)
         | 
| 191 | 
            +
                    x = self.ln5(x2 + self.mlp(x2))
         | 
| 192 | 
            +
                    y = self.ln6(y2 + self.mlp2(y2))
         | 
| 193 | 
            +
                    return x, y
         | 
| 194 | 
            +
             | 
| 195 | 
            +
            class TransformerEncoder(nn.Module):
         | 
| 196 | 
            +
                """
         | 
| 197 | 
            +
                Transformer Encoder composed of a sequence of encoder blocks.
         | 
| 198 | 
            +
             | 
| 199 | 
            +
                Attributes:
         | 
| 200 | 
            +
                    Encoder_Blocks (nn.ModuleList): A list of Encoder_Block modules stacked sequentially.
         | 
| 201 | 
            +
                """
         | 
| 202 | 
            +
                def __init__(self, dim, depth, heads, act, mlp_ratio=4, drop_rate=0.1):
         | 
| 203 | 
            +
                    """
         | 
| 204 | 
            +
                    Initializes the Transformer Encoder.
         | 
| 205 | 
            +
             | 
| 206 | 
            +
                    Args:
         | 
| 207 | 
            +
                        dim (int): Dimensionality of the input features.
         | 
| 208 | 
            +
                        depth (int): Number of encoder blocks to stack.
         | 
| 209 | 
            +
                        heads (int): Number of attention heads in each block.
         | 
| 210 | 
            +
                        act (callable): Activation function (passed to encoder blocks for potential use).
         | 
| 211 | 
            +
                        mlp_ratio (int, optional): Ratio for determining the hidden layer size in MLP modules. Defaults to 4.
         | 
| 212 | 
            +
                        drop_rate (float, optional): Dropout rate for the MLPs within each block. Defaults to 0.1.
         | 
| 213 | 
            +
                    """
         | 
| 214 | 
            +
                    super().__init__()
         | 
| 215 | 
            +
             | 
| 216 | 
            +
                    self.Encoder_Blocks = nn.ModuleList([
         | 
| 217 | 
            +
                        Encoder_Block(dim, heads, act, mlp_ratio, drop_rate)
         | 
| 218 | 
            +
                        for _ in range(depth)
         | 
| 219 | 
            +
                    ])
         | 
| 220 | 
            +
             | 
| 221 | 
            +
                def forward(self, x, y):
         | 
| 222 | 
            +
                    """
         | 
| 223 | 
            +
                    Forward pass of the Transformer Encoder.
         | 
| 224 | 
            +
             | 
| 225 | 
            +
                    Args:
         | 
| 226 | 
            +
                        x (torch.Tensor): Node feature tensor.
         | 
| 227 | 
            +
                        y (torch.Tensor): Edge feature tensor.
         | 
| 228 | 
            +
             | 
| 229 | 
            +
                    Returns:
         | 
| 230 | 
            +
                        tuple: (final node features, final edge features) after processing through all encoder blocks.
         | 
| 231 | 
            +
                    """
         | 
| 232 | 
            +
                    for block in self.Encoder_Blocks:
         | 
| 233 | 
            +
                        x, y = block(x, y)
         | 
| 234 | 
            +
                    return x, y
         | 
    	
        src/model/loss.py
    ADDED
    
    @@ -0,0 +1,85 @@
import torch


def gradient_penalty(discriminator, real_node, real_edge, fake_node, fake_edge, batch_size, device):
    """
    Calculate gradient penalty for WGAN-GP.

    Args:
        discriminator: The discriminator model
        real_node: Real node features
        real_edge: Real edge features
        fake_node: Generated node features
        fake_edge: Generated edge features
        batch_size: Batch size
        device: Device to compute on

    Returns:
        Gradient penalty term
    """
    # Generate random interpolation factors
    eps_edge = torch.rand(batch_size, 1, 1, 1, device=device)
    eps_node = torch.rand(batch_size, 1, 1, device=device)

    # Create interpolated samples
    int_node = (eps_node * real_node + (1 - eps_node) * fake_node).requires_grad_(True)
    int_edge = (eps_edge * real_edge + (1 - eps_edge) * fake_edge).requires_grad_(True)

    logits_interpolated = discriminator(int_edge, int_node)

    # Calculate gradients for both node and edge inputs
    weight = torch.ones(logits_interpolated.size(), requires_grad=False).to(device)
    gradients = torch.autograd.grad(
        outputs=logits_interpolated,
        inputs=[int_node, int_edge],
        grad_outputs=weight,
        create_graph=True,
        retain_graph=True,
        only_inputs=True
    )

    # Combine gradients from both inputs
    gradients_node = gradients[0].view(batch_size, -1)
    gradients_edge = gradients[1].view(batch_size, -1)
    gradients = torch.cat([gradients_node, gradients_edge], dim=1)

    # Calculate gradient penalty
    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()

    return gradient_penalty


def discriminator_loss(generator, discriminator, drug_adj, drug_annot, mol_adj, mol_annot, batch_size, device, lambda_gp):
    # Compute loss for drugs
    logits_real_disc = discriminator(drug_adj, drug_annot)

    # Use mean reduction for more stable training
    prediction_real = -torch.mean(logits_real_disc)

    # Compute loss for generated molecules
    node, edge, node_sample, edge_sample = generator(mol_adj, mol_annot)

    logits_fake_disc = discriminator(edge_sample.detach(), node_sample.detach())

    prediction_fake = torch.mean(logits_fake_disc)

    # Compute gradient penalty using the new function
    gp = gradient_penalty(discriminator, drug_annot, drug_adj, node_sample.detach(), edge_sample.detach(), batch_size, device)

    # Calculate total discriminator loss
    d_loss = prediction_fake + prediction_real + lambda_gp * gp

    return node, edge, d_loss


def generator_loss(generator, discriminator, mol_adj, mol_annot, batch_size):
    # Generate fake molecules
    node, edge, node_sample, edge_sample = generator(mol_adj, mol_annot)

    # Compute logits for fake molecules
    logits_fake_disc = discriminator(edge_sample, node_sample)

    prediction_fake = -torch.mean(logits_fake_disc)
    g_loss = prediction_fake

    return g_loss, node, edge, node_sample, edge_sample
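
For orientation, the following is a minimal, hypothetical sketch of how these two loss functions could be wired into one WGAN-GP training iteration. The optimizer settings and the pre-existing objects (generator, discriminator, drug_adj, drug_annot, mol_adj, mol_annot, batch_size, device) are assumptions for this sketch and are not taken verbatim from the repository's train.py.

import torch

# Hypothetical optimizers; learning rate and betas are illustrative
d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=1e-4, betas=(0.5, 0.999))
g_optimizer = torch.optim.Adam(generator.parameters(), lr=1e-4, betas=(0.5, 0.999))

# Discriminator step: real graphs scored with a negative-mean term, generated graphs with a
# positive-mean term, plus the gradient penalty weighted by lambda_gp
d_optimizer.zero_grad()
_, _, d_loss = discriminator_loss(generator, discriminator,
                                  drug_adj, drug_annot, mol_adj, mol_annot,
                                  batch_size, device, lambda_gp=10.0)
d_loss.backward()
d_optimizer.step()

# Generator step: push the critic score of generated graphs up
g_optimizer.zero_grad()
g_loss, node, edge, node_sample, edge_sample = generator_loss(generator, discriminator,
                                                              mol_adj, mol_annot, batch_size)
g_loss.backward()
g_optimizer.step()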
    	
        src/model/models.py
    ADDED
    
    @@ -0,0 +1,269 @@
import torch
import torch.nn as nn
from src.model.layers import TransformerEncoder

class Generator(nn.Module):
    """
    Generator network that uses a Transformer Encoder to process node and edge features.

    The network first processes input node and edge features with separate linear layers,
    then applies a Transformer Encoder to model interactions, and finally outputs both transformed
    features and readout samples.
    """
    def __init__(self, act, vertexes, edges, nodes, dropout, dim, depth, heads, mlp_ratio):
        """
        Initializes the Generator.

        Args:
            act (str): Type of activation function to use ("relu", "leaky", "sigmoid", or "tanh").
            vertexes (int): Number of vertexes in the graph.
            edges (int): Number of edge features.
            nodes (int): Number of node features.
            dropout (float): Dropout rate.
            dim (int): Dimensionality used for intermediate features.
            depth (int): Number of Transformer encoder blocks.
            heads (int): Number of attention heads in the Transformer.
            mlp_ratio (int): Ratio for determining hidden layer size in MLP modules.
        """
        super(Generator, self).__init__()
        self.vertexes = vertexes
        self.edges = edges
        self.nodes = nodes
        self.depth = depth
        self.dim = dim
        self.heads = heads
        self.mlp_ratio = mlp_ratio
        self.dropout = dropout

        # Set the activation function based on the provided string
        if act == "relu":
            act = nn.ReLU()
        elif act == "leaky":
            act = nn.LeakyReLU()
        elif act == "sigmoid":
            act = nn.Sigmoid()
        elif act == "tanh":
            act = nn.Tanh()

        # Calculate the total number of features and dimensions for transformer
        self.features = vertexes * vertexes * edges + vertexes * nodes
        self.transformer_dim = vertexes * vertexes * dim + vertexes * dim

        self.node_layers = nn.Sequential(
            nn.Linear(nodes, 64), act,
            nn.Linear(64, dim), act,
            nn.Dropout(self.dropout)
        )
        self.edge_layers = nn.Sequential(
            nn.Linear(edges, 64), act,
            nn.Linear(64, dim), act,
            nn.Dropout(self.dropout)
        )
        self.TransformerEncoder = TransformerEncoder(
            dim=self.dim, depth=self.depth, heads=self.heads, act=act,
            mlp_ratio=self.mlp_ratio, drop_rate=self.dropout
        )

        self.readout_e = nn.Linear(self.dim, edges)
        self.readout_n = nn.Linear(self.dim, nodes)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, z_e, z_n):
        """
        Forward pass of the Generator.

        Args:
            z_e (torch.Tensor): Edge features tensor of shape (batch, vertexes, vertexes, edges).
            z_n (torch.Tensor): Node features tensor of shape (batch, vertexes, nodes).

        Returns:
            tuple: A tuple containing:
                - node: Updated node features after the transformer.
                - edge: Updated edge features after the transformer.
                - node_sample: Readout sample from node features.
                - edge_sample: Readout sample from edge features.
        """
        b, n, c = z_n.shape
        # The fourth dimension of edge features
        _, _, _, d = z_e.shape

        # Process node and edge features through their respective layers
        node = self.node_layers(z_n)
        edge = self.edge_layers(z_e)
        # Symmetrize the edge features by averaging with its transpose along vertex dimensions
        edge = (edge + edge.permute(0, 2, 1, 3)) / 2

        # Pass the features through the Transformer Encoder
        node, edge = self.TransformerEncoder(node, edge)

        # Readout layers to generate final outputs
        node_sample = self.readout_n(node)
        edge_sample = self.readout_e(edge)

        return node, edge, node_sample, edge_sample


class Discriminator(nn.Module):
    """
    Discriminator network that evaluates node and edge features.

    It processes features with linear layers, applies a Transformer Encoder to capture dependencies,
    and finally predicts a scalar value using an MLP on aggregated node features.

    This class is used in the DrugGEN model.
    """
    def __init__(self, act, vertexes, edges, nodes, dropout, dim, depth, heads, mlp_ratio):
        """
        Initializes the Discriminator.

        Args:
            act (str): Activation function type ("relu", "leaky", "sigmoid", or "tanh").
            vertexes (int): Number of vertexes.
            edges (int): Number of edge features.
            nodes (int): Number of node features.
            dropout (float): Dropout rate.
            dim (int): Dimensionality for intermediate representations.
            depth (int): Number of Transformer encoder blocks.
            heads (int): Number of attention heads.
            mlp_ratio (int): MLP ratio for hidden layer dimensions.
        """
        super(Discriminator, self).__init__()
        self.vertexes = vertexes
        self.edges = edges
        self.nodes = nodes
        self.depth = depth
        self.dim = dim
        self.heads = heads
        self.mlp_ratio = mlp_ratio
        self.dropout = dropout

        # Set the activation function
        if act == "relu":
            act = nn.ReLU()
        elif act == "leaky":
            act = nn.LeakyReLU()
        elif act == "sigmoid":
            act = nn.Sigmoid()
        elif act == "tanh":
            act = nn.Tanh()

        self.features = vertexes * vertexes * edges + vertexes * nodes
        self.transformer_dim = vertexes * vertexes * dim + vertexes * dim

        # Define layers for processing node and edge features
        self.node_layers = nn.Sequential(
            nn.Linear(nodes, 64), act,
            nn.Linear(64, dim), act,
            nn.Dropout(self.dropout)
        )
        self.edge_layers = nn.Sequential(
            nn.Linear(edges, 64), act,
            nn.Linear(64, dim), act,
            nn.Dropout(self.dropout)
        )
        # Transformer Encoder for modeling node and edge interactions
        self.TransformerEncoder = TransformerEncoder(
            dim=self.dim, depth=self.depth, heads=self.heads, act=act,
            mlp_ratio=self.mlp_ratio, drop_rate=self.dropout
        )
        # Calculate dimensions for node features aggregation
        self.node_features = vertexes * dim
        self.edge_features = vertexes * vertexes * dim
        # MLP to predict a scalar value from aggregated node features
        self.node_mlp = nn.Sequential(
            nn.Linear(self.node_features, 64), act,
            nn.Linear(64, 32), act,
            nn.Linear(32, 16), act,
            nn.Linear(16, 1)
        )

    def forward(self, z_e, z_n):
        """
        Forward pass of the Discriminator.

        Args:
            z_e (torch.Tensor): Edge features tensor of shape (batch, vertexes, vertexes, edges).
            z_n (torch.Tensor): Node features tensor of shape (batch, vertexes, nodes).

        Returns:
            torch.Tensor: Prediction scores (typically a scalar per sample).
        """
        b, n, c = z_n.shape
        # Unpack the shape of edge features (not used further directly)
        _, _, _, d = z_e.shape

        # Process node and edge features separately
        node = self.node_layers(z_n)
        edge = self.edge_layers(z_e)
        # Symmetrize edge features by averaging with its transpose
        edge = (edge + edge.permute(0, 2, 1, 3)) / 2

        # Process features through the Transformer Encoder
        node, edge = self.TransformerEncoder(node, edge)

        # Flatten node features for MLP
        node = node.view(b, -1)
        # Predict a scalar score using the node MLP
        prediction = self.node_mlp(node)

        return prediction


class simple_disc(nn.Module):
    """
    A simplified discriminator that processes flattened features through an MLP
    to predict a scalar score.

    This class is used in the NoTarget model.
    """
    def __init__(self, act, m_dim, vertexes, b_dim):
        """
        Initializes the simple discriminator.

        Args:
            act (str): Activation function type ("relu", "leaky", "sigmoid", or "tanh").
            m_dim (int): Dimensionality for atom type features.
            vertexes (int): Number of vertexes.
            b_dim (int): Dimensionality for bond type features.
        """
        super().__init__()

        # Set the activation function and check if it's supported
        if act == "relu":
            act = nn.ReLU()
        elif act == "leaky":
            act = nn.LeakyReLU()
        elif act == "sigmoid":
            act = nn.Sigmoid()
        elif act == "tanh":
            act = nn.Tanh()
        else:
            raise ValueError("Unsupported activation function: {}".format(act))

        # Compute total number of features combining both dimensions
        features = vertexes * m_dim + vertexes * vertexes * b_dim
        print(vertexes)
        print(m_dim)
        print(b_dim)
        print(features)
        self.predictor = nn.Sequential(
            nn.Linear(features, 256), act,
            nn.Linear(256, 128), act,
            nn.Linear(128, 64), act,
            nn.Linear(64, 32), act,
            nn.Linear(32, 16), act,
            nn.Linear(16, 1)
        )

    def forward(self, x):
        """
        Forward pass of the simple discriminator.

        Args:
            x (torch.Tensor): Input features tensor.

        Returns:
            torch.Tensor: Prediction scores.
        """
        prediction = self.predictor(x)
        return prediction
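
As a quick sanity check of the interfaces above, the following hypothetical snippet instantiates a small Generator and Discriminator and runs random graph tensors through them; all hyperparameter values here are illustrative, not the project's actual configuration.

import torch
from src.model.models import Generator, Discriminator

# Illustrative sizes: 9-vertex graphs, 5 bond types, 10 atom types, dim divisible by heads
gen = Generator(act="relu", vertexes=9, edges=5, nodes=10, dropout=0.1,
                dim=32, depth=2, heads=4, mlp_ratio=4)
disc = Discriminator(act="relu", vertexes=9, edges=5, nodes=10, dropout=0.1,
                     dim=32, depth=2, heads=4, mlp_ratio=4)

z_e = torch.randn(2, 9, 9, 5)   # (batch, vertexes, vertexes, edges)
z_n = torch.randn(2, 9, 10)     # (batch, vertexes, nodes)

# Generator returns transformer features plus readout samples in the input vocabularies
node, edge, node_sample, edge_sample = gen(z_e, z_n)
# Discriminator scores the generated graph: one scalar per sample, shape (2, 1)
score = disc(edge_sample, node_sample)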
    	
        src/util/__init__.py
    ADDED
    
    File without changes
    	
        src/util/__pycache__/__init__.cpython-310.pyc
    ADDED
    
    Binary file (155 Bytes)
    	
        src/util/__pycache__/smiles_cor.cpython-310.pyc
    ADDED
    
    Binary file (30.2 kB)
    	
        src/util/__pycache__/utils.cpython-310.pyc
    ADDED
    
    Binary file (30 kB)
    	
        src/util/smiles_cor.py
    ADDED
    
    @@ -0,0 +1,1284 @@
| 1 | 
            +
            import os
         | 
| 2 | 
            +
            import time
         | 
| 3 | 
            +
            import random
         | 
| 4 | 
            +
            import re
         | 
| 5 | 
            +
            import itertools
         | 
| 6 | 
            +
            import statistics
         | 
| 7 | 
            +
             | 
| 8 | 
            +
            import numpy as np
         | 
| 9 | 
            +
            import pandas as pd
         | 
| 10 | 
            +
            import torch
         | 
| 11 | 
            +
            import torch.nn as nn
         | 
| 12 | 
            +
            import torch.optim as optim
         | 
| 13 | 
            +
            from torch.utils.data import DataLoader
         | 
| 14 | 
            +
            from torchtext.data import TabularDataset, Field, BucketIterator, Iterator
         | 
| 15 | 
            +
             | 
| 16 | 
            +
            from rdkit import Chem, rdBase, RDLogger
         | 
| 17 | 
            +
            from rdkit.Chem import (
         | 
| 18 | 
            +
                MolStandardize,
         | 
| 19 | 
            +
                GraphDescriptors,
         | 
| 20 | 
            +
                Lipinski,
         | 
| 21 | 
            +
                AllChem,
         | 
| 22 | 
            +
            )
         | 
| 23 | 
            +
            from rdkit.Chem.rdSLNParse import MolFromSLN
         | 
| 24 | 
            +
            from rdkit.Chem.rdmolfiles import MolFromSmiles
         | 
| 25 | 
            +
            from chembl_structure_pipeline import standardizer
         | 
| 26 | 
            +
             | 
| 27 | 
            +
            RDLogger.DisableLog('rdApp.*')
         | 
| 28 | 
            +
             | 
| 29 | 
            +
            SEED = 42
         | 
| 30 | 
            +
            random.seed(SEED)
         | 
| 31 | 
            +
            torch.manual_seed(SEED)
         | 
| 32 | 
            +
            torch.backends.cudnn.deterministic = True
         | 
| 33 | 
            +
             | 
| 34 | 
            +
            ##################################################################################################
         | 
| 35 | 
            +
            ##################################################################################################
         | 
| 36 | 
            +
            #                                                                                                #
         | 
| 37 | 
            +
            #  THIS SCRIPT IS DIRECTLY ADAPTED FROM https://github.com/LindeSchoenmaker/SMILES-corrector     #
         | 
| 38 | 
            +
            #                                                                                                #
         | 
| 39 | 
            +
            ##################################################################################################
         | 
| 40 | 
            +
            ##################################################################################################
         | 
| 41 | 
            +
            def is_smiles(array,
         | 
| 42 | 
            +
                          TRG,
         | 
| 43 | 
            +
                          reverse: bool,
         | 
| 44 | 
            +
                          return_output=False,
         | 
| 45 | 
            +
                          src=None,
         | 
| 46 | 
            +
                          src_field=None):
         | 
| 47 | 
            +
                """Turns predicted tokens within batch into smiles and evaluates their validity
         | 
| 48 | 
            +
                Arguments:
         | 
| 49 | 
            +
                    array: Tensor with most probable token for each location for each sequence in batch
         | 
| 50 | 
            +
                        [trg len, batch size]
         | 
| 51 | 
            +
                    TRG: target field for getting tokens from vocab
         | 
| 52 | 
            +
                    reverse (bool): True if the target sequence is reversed
         | 
| 53 | 
            +
                    return_output (bool): True if output sequences and their validity should be saved
         | 
| 54 | 
            +
                Returns:
         | 
| 55 | 
            +
                    df: dataframe with correct and incorrect sequences
         | 
| 56 | 
            +
                    valids: list with booleans that show if prediction was a valid SMILES (True) or invalid one (False)
         | 
| 57 | 
            +
                    smiless: list of the predicted smiles
         | 
| 58 | 
            +
                """
         | 
| 59 | 
            +
                trg_field = TRG
         | 
| 60 | 
            +
                valids = []
         | 
| 61 | 
            +
                smiless = []
         | 
| 62 | 
            +
                if return_output:
         | 
| 63 | 
            +
                    df = pd.DataFrame()
         | 
| 64 | 
            +
                else:
         | 
| 65 | 
            +
                    df = None
         | 
| 66 | 
            +
                batch_size = array.size(1)
         | 
| 67 | 
            +
                # check if the first token should be removed, first token is zero because
         | 
| 68 | 
            +
                # outputs initaliazed to all be zeros
         | 
| 69 | 
            +
                if int((array[0, 0]).tolist()) == 0:
         | 
| 70 | 
            +
                    start = 1
         | 
| 71 | 
            +
                else:
         | 
| 72 | 
            +
                    start = 0
         | 
| 73 | 
            +
                # for each sequence in the batch
         | 
| 74 | 
            +
                for i in range(0, batch_size):
         | 
| 75 | 
            +
                    # turns sequence from tensor to list skipps first row as this is not
         | 
| 76 | 
            +
                    # filled in in forward
         | 
| 77 | 
            +
                    sequence = (array[start:, i]).tolist()
         | 
| 78 | 
            +
                    # goes from embedded to tokens
         | 
| 79 | 
            +
                    trg_tokens = [trg_field.vocab.itos[int(t)] for t in sequence]
         | 
| 80 | 
            +
                    # print(trg_tokens)
         | 
| 81 | 
            +
                    # takes all tokens untill eos token, model would be faster if did this
         | 
| 82 | 
            +
                    # one step earlier, but then changes in vocab order would disrupt.
         | 
| 83 | 
            +
                    rev_tokens = list(
         | 
| 84 | 
            +
                        itertools.takewhile(lambda x: x != "<eos>", trg_tokens))
         | 
| 85 | 
            +
                    if reverse:
         | 
| 86 | 
            +
                        rev_tokens = rev_tokens[::-1]
         | 
| 87 | 
            +
                    smiles = "".join(rev_tokens)
         | 
| 88 | 
            +
                    # determine how many valid smiles are made
         | 
| 89 | 
            +
                    valid = True if MolFromSmiles(smiles) else False
         | 
| 90 | 
            +
                    valids.append(valid)
         | 
| 91 | 
            +
                    smiless.append(smiles)
         | 
| 92 | 
            +
                    if return_output:
         | 
| 93 | 
            +
                        if valid:
         | 
| 94 | 
            +
                            df.loc[i, "CORRECT"] = smiles
         | 
| 95 | 
            +
                        else:
         | 
| 96 | 
            +
                            df.loc[i, "INCORRECT"] = smiles
         | 
| 97 | 
            +
             | 
| 98 | 
            +
                # add the original drugex outputs to the _de dataframe
         | 
| 99 | 
            +
                if return_output and src is not None:
         | 
| 100 | 
            +
                    for i in range(0, batch_size):
         | 
| 101 | 
            +
                        # turns sequence from tensor to list; skips first row as this is
         | 
| 102 | 
            +
                        # <sos> for src
         | 
| 103 | 
            +
                        sequence = (src[1:, i]).tolist()
         | 
| 104 | 
            +
                        # goes from embedded to tokens
         | 
| 105 | 
            +
                        src_tokens = [src_field.vocab.itos[int(t)] for t in sequence]
         | 
| 106 | 
            +
                        # takes all tokens until the eos token; the model would be faster if it did
         | 
| 107 | 
            +
                        # this one step earlier, but then changes in vocab order would
         | 
| 108 | 
            +
                        # disrupt it.
         | 
| 109 | 
            +
                        rev_tokens = list(
         | 
| 110 | 
            +
                            itertools.takewhile(lambda x: x != "<eos>", src_tokens))
         | 
| 111 | 
            +
                        smiles = "".join(rev_tokens)
         | 
| 112 | 
            +
                        df.loc[i, "ORIGINAL"] = smiles
         | 
| 113 | 
            +
             | 
| 114 | 
            +
                return df, valids, smiless
         | 
| 115 | 
            +
             | 
| 116 | 
            +
             | 
| 117 | 
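# Illustrative sketch (added for clarity; not part of the original file): how
# is_smiles decodes a batch of predicted token indices. The tiny vocabulary and
# the SimpleNamespace stand-in for a torchtext TRG field are assumptions made
# purely for demonstration.
def _example_is_smiles():
    from types import SimpleNamespace
    import torch

    toy_itos = ["<pad>", "<sos>", "<eos>", "C", "O", "N"]  # hypothetical vocab
    trg_stub = SimpleNamespace(vocab=SimpleNamespace(itos=toy_itos))
    # [trg len, batch size]; row 0 is the all-zeros row that is_smiles strips
    preds = torch.tensor([[0, 0], [3, 4], [3, 3], [4, 2], [2, 0]])
    df, valids, smiless = is_smiles(preds, trg_stub, reverse=False,
                                    return_output=True)
    return df, valids, smiless  # smiless == ["CCO", "OC"], valids == [True, True]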
            +
            def is_unchanged(array,
         | 
| 118 | 
            +
                             TRG,
         | 
| 119 | 
            +
                             reverse: bool,
         | 
| 120 | 
            +
                             return_output=False,
         | 
| 121 | 
            +
                             src=None,
         | 
| 122 | 
            +
                             src_field=None):
         | 
| 123 | 
            +
                """Checks is output is different from input
         | 
| 124 | 
            +
                Arguments:
         | 
| 125 | 
            +
                    array: Tensor with most probable token for each location for each sequence in batch
         | 
| 126 | 
            +
                        [trg len, batch size]
         | 
| 127 | 
            +
                    TRG: target field for getting tokens from vocab
         | 
| 128 | 
            +
                    reverse (bool): True if the target sequence is reversed
         | 
| 129 | 
            +
                    return_output (bool): True if output sequences and their validity should be saved
         | 
| 130 | 
            +
                Returns:
         | 
| 131 | 
            +
                    unchanged (int): number of predictions that are invalid SMILES and
         | 
| 132 | 
            +
                        are also identical to their source sequence, i.e. the model
         | 
| 133 | 
            +
                        returned the input unchanged
         | 
| 134 | 
            +
                """
         | 
| 135 | 
            +
                trg_field = TRG
         | 
| 136 | 
            +
                sources = []
         | 
| 137 | 
            +
                batch_size = array.size(1)
         | 
| 138 | 
            +
                unchanged = 0
         | 
| 139 | 
            +
             | 
| 140 | 
            +
                # check if the first token should be removed, first token is zero because
         | 
| 141 | 
            +
                # outputs are initialized to all zeros
         | 
| 142 | 
            +
                if int((array[0, 0]).tolist()) == 0:
         | 
| 143 | 
            +
                    start = 1
         | 
| 144 | 
            +
                else:
         | 
| 145 | 
            +
                    start = 0
         | 
| 146 | 
            +
             | 
| 147 | 
            +
                for i in range(0, batch_size):
         | 
| 148 | 
            +
                    # turns sequence from tensor to list; skips first row as this is <sos>
         | 
| 149 | 
            +
                    # for src
         | 
| 150 | 
            +
                    sequence = (src[1:, i]).tolist()
         | 
| 151 | 
            +
                    # goes from embedded to tokens
         | 
| 152 | 
            +
                    src_tokens = [src_field.vocab.itos[int(t)] for t in sequence]
         | 
| 153 | 
            +
                    # takes all tokens until the eos token; the model would be faster if it did this
         | 
| 154 | 
            +
                    # one step earlier, but then changes in vocab order would disrupt it.
         | 
| 155 | 
            +
                    rev_tokens = list(
         | 
| 156 | 
            +
                        itertools.takewhile(lambda x: x != "<eos>", src_tokens))
         | 
| 157 | 
            +
                    smiles = "".join(rev_tokens)
         | 
| 158 | 
            +
                    sources.append(smiles)
         | 
| 159 | 
            +
             | 
| 160 | 
            +
                # for each sequence in the batch
         | 
| 161 | 
            +
                for i in range(0, batch_size):
         | 
| 162 | 
            +
                    # turns sequence from tensor to list; skips first row as this is not
         | 
| 163 | 
            +
                    # filled in in forward
         | 
| 164 | 
            +
                    sequence = (array[start:, i]).tolist()
         | 
| 165 | 
            +
                    # goes from embedded to tokens
         | 
| 166 | 
            +
                    trg_tokens = [trg_field.vocab.itos[int(t)] for t in sequence]
         | 
| 167 | 
            +
                    # print(trg_tokens)
         | 
| 168 | 
            +
                    # takes all tokens until the eos token; the model would be faster if it did this
         | 
| 169 | 
            +
                    # one step earlier, but then changes in vocab order would disrupt it.
         | 
| 170 | 
            +
                    rev_tokens = list(
         | 
| 171 | 
            +
                        itertools.takewhile(lambda x: x != "<eos>", trg_tokens))
         | 
| 172 | 
            +
                    if reverse:
         | 
| 173 | 
            +
                        rev_tokens = rev_tokens[::-1]
         | 
| 174 | 
            +
                    smiles = "".join(rev_tokens)
         | 
| 175 | 
            +
                    # determine how many valid smiles are made
         | 
| 176 | 
            +
                    valid = True if MolFromSmiles(smiles) else False
         | 
| 177 | 
            +
                    if not valid:
         | 
| 178 | 
            +
                        if smiles == sources[i]:
         | 
| 179 | 
            +
                            unchanged += 1
         | 
| 180 | 
            +
             | 
| 181 | 
            +
                return unchanged
         | 
| 182 | 
            +
             | 
| 183 | 
            +
             | 
| 184 | 
            +
            def molecule_reconstruction(array, TRG, reverse: bool, outputs):
         | 
| 185 | 
            +
                """Turns target tokens within batch into smiles and compares them to predicted output smiles
         | 
| 186 | 
            +
                Arguments:
         | 
| 187 | 
            +
                    array: Tensor with target's token for each location for each sequence in batch
         | 
| 188 | 
            +
                        [trg len, batch size]
         | 
| 189 | 
            +
                    TRG: target field for getting tokens from vocab
         | 
| 190 | 
            +
                    reverse (bool): True if the target sequence is reversed
         | 
| 191 | 
            +
                    outputs: list of predicted SMILES sequences
         | 
| 192 | 
            +
                Returns:
         | 
| 193 | 
            +
                     matches (int): number of predicted molecules that match their targets
         | 
| 194 | 
            +
                """
         | 
| 195 | 
            +
                trg_field = TRG
         | 
| 196 | 
            +
                matches = 0
         | 
| 197 | 
            +
                targets = []
         | 
| 198 | 
            +
                batch_size = array.size(1)
         | 
| 199 | 
            +
                # for each sequence in the batch
         | 
| 200 | 
            +
                for i in range(0, batch_size):
         | 
| 201 | 
            +
                    # turns sequence from tensor to list; skips first row as this is not
         | 
| 202 | 
            +
                    # filled in in forward
         | 
| 203 | 
            +
                    sequence = (array[1:, i]).tolist()
         | 
| 204 | 
            +
                    # goes from embedded to tokens
         | 
| 205 | 
            +
                    trg_tokens = [trg_field.vocab.itos[int(t)] for t in sequence]
         | 
| 206 | 
            +
                    # takes all tokens until the eos token; the model would be faster if it did this
         | 
| 207 | 
            +
                    # one step earlier, but then changes in vocab order would disrupt it.
         | 
| 208 | 
            +
                    rev_tokens = list(
         | 
| 209 | 
            +
                        itertools.takewhile(lambda x: x != "<eos>", trg_tokens))
         | 
| 210 | 
            +
                    if reverse:
         | 
| 211 | 
            +
                        rev_tokens = rev_tokens[::-1]
         | 
| 212 | 
            +
                    smiles = "".join(rev_tokens)
         | 
| 213 | 
            +
                    targets.append(smiles)
         | 
| 214 | 
            +
                for i in range(0, batch_size):
         | 
| 215 | 
            +
                    m = MolFromSmiles(targets[i])
         | 
| 216 | 
            +
                    p = MolFromSmiles(outputs[i])
         | 
| 217 | 
            +
                    if p is not None:
         | 
| 218 | 
            +
                        if m.HasSubstructMatch(p) and p.HasSubstructMatch(m):
         | 
| 219 | 
            +
                            matches += 1
         | 
| 220 | 
            +
                return matches
         | 
| 221 | 
            +
             | 
| 222 | 
            +
             | 
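# Clarifying sketch (added): molecule_reconstruction counts a prediction as a
# match when target and prediction are each a substructure of the other, i.e.
# they denote the same molecule. The SMILES below are arbitrary examples.
def _example_mutual_substructure_match():
    from rdkit.Chem import MolFromSmiles

    ethanol_a = MolFromSmiles("CCO")
    ethanol_b = MolFromSmiles("OCC")  # same molecule, written differently
    other = MolFromSmiles("CCN")      # a different molecule
    same = ethanol_a.HasSubstructMatch(ethanol_b) and ethanol_b.HasSubstructMatch(ethanol_a)
    diff = ethanol_a.HasSubstructMatch(other) and other.HasSubstructMatch(ethanol_a)
    return same, diff  # (True, False)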
| 223 | 
            +
            def complexity_whitlock(mol: Chem.Mol, includeAllDescs=False):
         | 
| 224 | 
            +
                """
         | 
| 225 | 
            +
                Complexity as defined in DOI:10.1021/jo9814546
         | 
| 226 | 
            +
                S: complexity = 4*#rings + 2*#unsat + #hetatm + 2*#chiral
         | 
| 227 | 
            +
                Other descriptors:
         | 
| 228 | 
            +
                    H: size = #bonds (Hydrogen atoms included)
         | 
| 229 | 
            +
                    G: S + H
         | 
| 230 | 
            +
                    Ratio: S / H
         | 
| 231 | 
            +
                """
         | 
| 232 | 
            +
                mol_ = Chem.Mol(mol)
         | 
| 233 | 
            +
                nrings = Lipinski.RingCount(mol_) - Lipinski.NumAromaticRings(mol_)
         | 
| 234 | 
            +
                Chem.rdmolops.SetAromaticity(mol_)
         | 
| 235 | 
            +
                unsat = sum(1 for bond in mol_.GetBonds()
         | 
| 236 | 
            +
                            if bond.GetBondTypeAsDouble() == 2)
         | 
| 237 | 
            +
                hetatm = len(mol_.GetSubstructMatches(Chem.MolFromSmarts("[!#6]")))
         | 
| 238 | 
            +
                AllChem.EmbedMolecule(mol_)
         | 
| 239 | 
            +
                Chem.AssignAtomChiralTagsFromStructure(mol_)
         | 
| 240 | 
            +
                chiral = len(Chem.FindMolChiralCenters(mol_))
         | 
| 241 | 
            +
                S = 4 * nrings + 2 * unsat + hetatm + 2 * chiral
         | 
| 242 | 
            +
                if not includeAllDescs:
         | 
| 243 | 
            +
                    return S
         | 
| 244 | 
            +
                Chem.rdmolops.Kekulize(mol_)
         | 
| 245 | 
            +
                mol_ = Chem.AddHs(mol_)
         | 
| 246 | 
            +
                H = sum(bond.GetBondTypeAsDouble() for bond in mol_.GetBonds())
         | 
| 247 | 
            +
                G = S + H
         | 
| 248 | 
            +
                R = S / H
         | 
| 249 | 
            +
                return {"WhitlockS": S, "WhitlockH": H, "WhitlockG": G, "WhitlockRatio": R}
         | 
| 250 | 
            +
             | 
| 251 | 
            +
             | 
| 252 | 
            +
            def complexity_baronechanon(mol: Chem.Mol):
         | 
| 253 | 
            +
                """
         | 
| 254 | 
            +
                Complexity as defined in DOI:10.1021/ci000145p
         | 
| 255 | 
            +
                """
         | 
| 256 | 
            +
                mol_ = Chem.Mol(mol)
         | 
| 257 | 
            +
                Chem.Kekulize(mol_)
         | 
| 258 | 
            +
                Chem.RemoveStereochemistry(mol_)
         | 
| 259 | 
            +
                mol_ = Chem.RemoveHs(mol_, updateExplicitCount=True)
         | 
| 260 | 
            +
                degree, counts = 0, 0
         | 
| 261 | 
            +
                for atom in mol_.GetAtoms():
         | 
| 262 | 
            +
                    degree += 3 * 2**(atom.GetExplicitValence() - atom.GetNumExplicitHs() -
         | 
| 263 | 
            +
                                      1)
         | 
| 264 | 
            +
                    counts += 3 if atom.GetSymbol() == "C" else 6
         | 
| 265 | 
            +
                ringterm = sum(map(lambda x: 6 * len(x), mol_.GetRingInfo().AtomRings()))
         | 
| 266 | 
            +
                return degree + counts + ringterm
         | 
| 267 | 
            +
             | 
| 268 | 
            +
             | 
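# Usage sketch (added): both complexity measures above act directly on an RDKit
# Mol; aspirin is used here only as an arbitrary example molecule.
def _example_complexity_scores():
    from rdkit.Chem import MolFromSmiles

    mol = MolFromSmiles("CC(=O)Oc1ccccc1C(=O)O")  # aspirin (example only)
    s_score = complexity_whitlock(mol)                          # Whitlock S
    all_descs = complexity_whitlock(mol, includeAllDescs=True)  # S, H, G, Ratio
    bc_score = complexity_baronechanon(mol)                     # Barone-Chanon
    return s_score, all_descs, bc_score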
| 269 | 
            +
            def calc_complexity(array,
         | 
| 270 | 
            +
                                TRG,
         | 
| 271 | 
            +
                                reverse,
         | 
| 272 | 
            +
                                valids,
         | 
| 273 | 
            +
                                complexity_function=GraphDescriptors.BertzCT):
         | 
| 274 | 
            +
                """Calculates the complexity of inputs that are not correct.
         | 
| 275 | 
            +
                Arguments:
         | 
| 276 | 
            +
                    array: Tensor with target's token for each location for each sequence in batch
         | 
| 277 | 
            +
                        [trg len, batch size]
         | 
| 278 | 
            +
                    TRG: target field for getting tokens from vocab
         | 
| 279 | 
            +
                    reverse (bool): True if the target sequence is reversed
         | 
| 280 | 
            +
                    valids: list with booleans that show if prediction was a valid SMILES (True) or invalid one (False)
         | 
| 281 | 
            +
                    complexity_function: the type of complexity measure that will be used
         | 
| 282 | 
            +
                        GraphDescriptors.BertzCT
         | 
| 283 | 
            +
                        complexity_whitlock
         | 
| 284 | 
            +
                        complexity_baronechanon
         | 
| 285 | 
            +
                Returns:
         | 
| 286 | 
            +
                     mean (float): mean of the complexity values (0 if there are none)
         | 
| 287 | 
            +
                """
         | 
| 288 | 
            +
                trg_field = TRG
         | 
| 289 | 
            +
                sources = []
         | 
| 290 | 
            +
                complexities = []
         | 
| 291 | 
            +
                loc = torch.BoolTensor(valids)
         | 
| 292 | 
            +
                # only keeps rows in batch size dimension where valid is false
         | 
| 293 | 
            +
                array = array[:, loc == False]
         | 
| 294 | 
            +
                # should check if this still works
         | 
| 295 | 
            +
                # array = torch.transpose(array, 0, 1)
         | 
| 296 | 
            +
                array_size = array.size(1)
         | 
| 297 | 
            +
                for i in range(0, array_size):
         | 
| 298 | 
            +
                    # turns sequence from tensor to list; skips first row as this is not
         | 
| 299 | 
            +
                    # filled in in forward
         | 
| 300 | 
            +
                    sequence = (array[1:, i]).tolist()
         | 
| 301 | 
            +
                    # goes from embedded to tokens
         | 
| 302 | 
            +
                    trg_tokens = [trg_field.vocab.itos[int(t)] for t in sequence]
         | 
| 303 | 
            +
                    # takes all tokens until the eos token; the model would be faster if it did this
         | 
| 304 | 
            +
                    # one step earlier, but then changes in vocab order would disrupt it.
         | 
| 305 | 
            +
                    rev_tokens = list(
         | 
| 306 | 
            +
                        itertools.takewhile(lambda x: x != "<eos>", trg_tokens))
         | 
| 307 | 
            +
                    if reverse:
         | 
| 308 | 
            +
                        rev_tokens = rev_tokens[::-1]
         | 
| 309 | 
            +
                    smiles = "".join(rev_tokens)
         | 
| 310 | 
            +
                    sources.append(smiles)
         | 
| 311 | 
            +
                for source in sources:
         | 
| 312 | 
            +
                    try:
         | 
| 313 | 
            +
                        m = MolFromSmiles(source)
         | 
| 314 | 
            +
                    except BaseException:
         | 
| 315 | 
            +
                        m = MolFromSLN(source)
         | 
| 316 | 
            +
                    complexities.append(complexity_function(m))
         | 
| 317 | 
            +
                if len(complexities) > 0:
         | 
| 318 | 
            +
                    mean = statistics.mean(complexities)
         | 
| 319 | 
            +
                else:
         | 
| 320 | 
            +
                    mean = 0
         | 
| 321 | 
            +
                return mean
         | 
| 322 | 
            +
             | 
| 323 | 
            +
             | 
| 324 | 
            +
            def epoch_time(start_time, end_time):
         | 
| 325 | 
            +
                elapsed_time = end_time - start_time
         | 
| 326 | 
            +
                elapsed_mins = int(elapsed_time / 60)
         | 
| 327 | 
            +
                elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
         | 
| 328 | 
            +
                return elapsed_mins, elapsed_secs
         | 
| 329 | 
            +
             | 
| 330 | 
            +
             | 
| 331 | 
            +
            class Convo:
         | 
| 332 | 
            +
                """Class for training and evaluating transformer and convolutional neural network
         | 
| 333 | 
            +
                
         | 
| 334 | 
            +
                Methods
         | 
| 335 | 
            +
                -------
         | 
| 336 | 
            +
                train_model()
         | 
| 337 | 
            +
                    train model for initialized number of epochs
         | 
| 338 | 
            +
                evaluate(return_output)
         | 
| 339 | 
            +
                    use model with validation loader (& optionally drugex loader) to get test loss & other metrics
         | 
| 340 | 
            +
                translate(loader)
         | 
| 341 | 
            +
                    translate inputs from loader (different from evaluate in that no target sequence is used)
         | 
| 342 | 
            +
                """
         | 
| 343 | 
            +
             | 
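                # Note (added for clarity): the methods below assume that the concrete
                # subclass has already set self.lr, self.epochs, self.out, self.clip,
                # self.device, self.TRG, self.SRC, self.loader_train, and optionally
                # self.loader_valid and self.loader_drugex, and that it provides a
                # translate_sentence method; all of these are referenced in
                # train_model, evaluate and translate.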
| 344 | 
            +
                def train_model(self):
         | 
| 345 | 
            +
                    optimizer = optim.Adam(self.parameters(), lr=self.lr)
         | 
| 346 | 
            +
                    log = open(f"{self.out}.log", "a")
         | 
| 347 | 
            +
                    best_error = np.inf
         | 
| 348 | 
            +
                    for epoch in range(self.epochs):
         | 
| 349 | 
            +
                        self.train()
         | 
| 350 | 
            +
                        start_time = time.time()
         | 
| 351 | 
            +
                        loss_train = 0
         | 
| 352 | 
            +
                        for i, batch in enumerate(self.loader_train):
         | 
| 353 | 
            +
                            optimizer.zero_grad()
         | 
| 354 | 
            +
                            # changed src,trg call to match with bentrevett
         | 
| 355 | 
            +
                            # src, trg = batch['src'], batch['trg']
         | 
| 356 | 
            +
                            trg = batch.trg
         | 
| 357 | 
            +
                            src = batch.src
         | 
| 358 | 
            +
                            output, attention = self(src, trg[:, :-1])
         | 
| 359 | 
            +
                            # feed the source and target into def forward to get the output
         | 
| 360 | 
            +
                            # Xuhan uses forward for this, with istrain = true
         | 
| 361 | 
            +
                            output_dim = output.shape[-1]
         | 
| 362 | 
            +
                            # changed
         | 
| 363 | 
            +
                            output = output.contiguous().view(-1, output_dim)
         | 
| 364 | 
            +
                            trg = trg[:, 1:].contiguous().view(-1)
         | 
| 365 | 
            +
                            # output = output[:,:,0]#.view(-1)
         | 
| 366 | 
            +
                            # output = output[1:].view(-1, output.shape[-1])
         | 
| 367 | 
            +
                            # trg = trg[1:].view(-1)
         | 
| 368 | 
            +
                            loss = nn.CrossEntropyLoss(
         | 
| 369 | 
            +
                                ignore_index=self.TRG.vocab.stoi[self.TRG.pad_token])
         | 
| 370 | 
            +
                            a, b = output.view(-1), trg.to(self.device).view(-1)
         | 
| 371 | 
            +
                            # changed
         | 
| 372 | 
            +
                            # loss = loss(output.view(0), trg.view(0).to(device))
         | 
| 373 | 
            +
                            loss = loss(output, trg)
         | 
| 374 | 
            +
                            loss.backward()
         | 
| 375 | 
            +
                            torch.nn.utils.clip_grad_norm_(self.parameters(), self.clip)
         | 
| 376 | 
            +
                            optimizer.step()
         | 
| 377 | 
            +
                            loss_train += loss.item()
         | 
| 378 | 
            +
                            # turned off for now: voc is not used, so this won't work (output is a tensor)
         | 
| 379 | 
            +
                            # output = [(trg len - 1) * batch size, output dim]
         | 
| 380 | 
            +
                            # smiles, valid = is_valid_smiles(output, reversed)
         | 
| 381 | 
            +
                            # if valid:
         | 
| 382 | 
            +
                            #    valids += 1
         | 
| 383 | 
            +
                            #    smiless.append(smiles)
         | 
| 384 | 
            +
                        # added .dataset because len(iterator) gives len(self.dataset) /
         | 
| 385 | 
            +
                        # self.batch_size)
         | 
| 386 | 
            +
                        loss_train /= len(self.loader_train)
         | 
| 387 | 
            +
                        info = f"Epoch: {epoch+1:02} step: {i} loss_train: {loss_train:.4g}"
         | 
| 388 | 
            +
                        # model is used to generate trg based on src from the validation set to assess performance
         | 
| 389 | 
            +
                        # similar to Xuhan, although he doesn't use the if loop
         | 
| 390 | 
            +
                        if self.loader_valid is not None:
         | 
| 391 | 
            +
                            return_output = False
         | 
| 392 | 
            +
                            if epoch + 1 == self.epochs:
         | 
| 393 | 
            +
                                return_output = True
         | 
| 394 | 
            +
                            (
         | 
| 395 | 
            +
                                valids,
         | 
| 396 | 
            +
                                loss_valid,
         | 
| 397 | 
            +
                                valids_de,
         | 
| 398 | 
            +
                                df_output,
         | 
| 399 | 
            +
                                df_output_de,
         | 
| 400 | 
            +
                                right_molecules,
         | 
| 401 | 
            +
                                complexity,
         | 
| 402 | 
            +
                                unchanged,
         | 
| 403 | 
            +
                                unchanged_de,
         | 
| 404 | 
            +
                            ) = self.evaluate(return_output)
         | 
| 405 | 
            +
                            reconstruction_error = 1 - right_molecules / len(
         | 
| 406 | 
            +
                                self.loader_valid.dataset)
         | 
| 407 | 
            +
                            error = 1 - valids / len(self.loader_valid.dataset)
         | 
| 408 | 
            +
                            complexity = complexity / len(self.loader_valid)
         | 
| 409 | 
            +
                            unchan = unchanged / (len(self.loader_valid.dataset) - valids)
         | 
| 410 | 
            +
                            info += f" loss_valid: {loss_valid:.4g} error_rate: {error:.4g} molecule_reconstruction_error_rate: {reconstruction_error:.4g} unchanged: {unchan:.4g} invalid_target_complexity: {complexity:.4g}"
         | 
| 411 | 
            +
                            if self.loader_drugex is not None:
         | 
| 412 | 
            +
                                error_de = 1 - valids_de / len(self.loader_drugex.dataset)
         | 
| 413 | 
            +
                                unchan_de = unchanged_de / (
         | 
| 414 | 
            +
                                    len(self.loader_drugex.dataset) - valids_de)
         | 
| 415 | 
            +
                                info += f" error_rate_drugex: {error_de:.4g} unchanged_drugex: {unchan_de:.4g}"
         | 
| 416 | 
            +
             | 
| 417 | 
            +
                            if reconstruction_error < best_error:
         | 
| 418 | 
            +
                                torch.save(self.state_dict(), f"{self.out}.pkg")
         | 
| 419 | 
            +
                                best_error = reconstruction_error
         | 
| 420 | 
            +
                                last_save = epoch
         | 
| 421 | 
            +
                            else:
         | 
| 422 | 
            +
                                if epoch - last_save >= 10 and best_error != 1:
         | 
| 423 | 
            +
                                    torch.save(self.state_dict(), f"{self.out}_last.pkg")
         | 
| 424 | 
            +
                                    (
         | 
| 425 | 
            +
                                        valids,
         | 
| 426 | 
            +
                                        loss_valid,
         | 
| 427 | 
            +
                                        valids_de,
         | 
| 428 | 
            +
                                        df_output,
         | 
| 429 | 
            +
                                        df_output_de,
         | 
| 430 | 
            +
                                        right_molecules,
         | 
| 431 | 
            +
                                        complexity,
         | 
| 432 | 
            +
                                        unchanged,
         | 
| 433 | 
            +
                                        unchanged_de,
         | 
| 434 | 
            +
                                    ) = self.evaluate(True)
         | 
| 435 | 
            +
                                    end_time = time.time()
         | 
| 436 | 
            +
                                    epoch_mins, epoch_secs = epoch_time(
         | 
| 437 | 
            +
                                        start_time, end_time)
         | 
| 438 | 
            +
                                    info += f" Time: {epoch_mins}m {epoch_secs}s"
         | 
| 439 | 
            +
                             
         | 
| 440 | 
            +
                                    break
         | 
| 441 | 
            +
                        elif error < best_error:
         | 
| 442 | 
            +
                            torch.save(self.state_dict(), f"{self.out}.pkg")
         | 
| 443 | 
            +
                            best_error = error
         | 
| 444 | 
            +
                        end_time = time.time()
         | 
| 445 | 
            +
                        epoch_mins, epoch_secs = epoch_time(start_time, end_time)
         | 
| 446 | 
            +
                        info += f" Time: {epoch_mins}m {epoch_secs}s"
         | 
| 447 | 
            +
               
         | 
| 448 | 
            +
                    
         | 
| 449 | 
            +
                    torch.save(self.state_dict(), f"{self.out}_last.pkg")
         | 
| 450 | 
            +
                    log.close()
         | 
| 451 | 
            +
                    self.load_state_dict(torch.load(f"{self.out}.pkg"))
         | 
| 452 | 
            +
                    df_output.to_csv(f"{self.out}.csv", index=False)
         | 
| 453 | 
            +
                    df_output_de.to_csv(f"{self.out}_de.csv", index=False)
         | 
| 454 | 
            +
             | 
| 455 | 
            +
                def evaluate(self, return_output):
         | 
| 456 | 
            +
                    self.eval()
         | 
| 457 | 
            +
                    test_loss = 0
         | 
| 458 | 
            +
                    df_output = pd.DataFrame()
         | 
| 459 | 
            +
                    df_output_de = pd.DataFrame()
         | 
| 460 | 
            +
                    valids = 0
         | 
| 461 | 
            +
                    valids_de = 0
         | 
| 462 | 
            +
                    unchanged = 0
         | 
| 463 | 
            +
                    unchanged_de = 0
         | 
| 464 | 
            +
                    right_molecules = 0
         | 
| 465 | 
            +
                    complexity = 0
         | 
| 466 | 
            +
                    with torch.no_grad():
         | 
| 467 | 
            +
                        for _, batch in enumerate(self.loader_valid):
         | 
| 468 | 
            +
                            trg = batch.trg
         | 
| 469 | 
            +
                            src = batch.src
         | 
| 470 | 
            +
                            output, attention = self.forward(src, trg[:, :-1])
         | 
| 471 | 
            +
                            pred_token = output.argmax(2)
         | 
| 472 | 
            +
                            array = torch.transpose(pred_token, 0, 1)
         | 
| 473 | 
            +
                            trg_trans = torch.transpose(trg, 0, 1)
         | 
| 474 | 
            +
                            output_dim = output.shape[-1]
         | 
| 475 | 
            +
                            output = output.contiguous().view(-1, output_dim)
         | 
| 476 | 
            +
                            trg = trg[:, 1:].contiguous().view(-1)
         | 
| 477 | 
            +
                            src_trans = torch.transpose(src, 0, 1)
         | 
| 478 | 
            +
                            df_batch, valid, smiless = is_smiles(
         | 
| 479 | 
            +
                                array, self.TRG, reverse=True, return_output=return_output)
         | 
| 480 | 
            +
                            unchanged += is_unchanged(
         | 
| 481 | 
            +
                                array,
         | 
| 482 | 
            +
                                self.TRG,
         | 
| 483 | 
            +
                                reverse=True,
         | 
| 484 | 
            +
                                return_output=return_output,
         | 
| 485 | 
            +
                                src=src_trans,
         | 
| 486 | 
            +
                                src_field=self.SRC,
         | 
| 487 | 
            +
                            )
         | 
| 488 | 
            +
                            matches = molecule_reconstruction(trg_trans,
         | 
| 489 | 
            +
                                                              self.TRG,
         | 
| 490 | 
            +
                                                              reverse=True,
         | 
| 491 | 
            +
                                                              outputs=smiless)
         | 
| 492 | 
            +
                            complexity += calc_complexity(trg_trans,
         | 
| 493 | 
            +
                                                          self.TRG,
         | 
| 494 | 
            +
                                                          reverse=True,
         | 
| 495 | 
            +
                                                          valids=valid)
         | 
| 496 | 
            +
                            if df_batch is not None:
         | 
| 497 | 
            +
                                df_output = pd.concat([df_output, df_batch],
         | 
| 498 | 
            +
                                                      ignore_index=True)
         | 
| 499 | 
            +
                            right_molecules += matches
         | 
| 500 | 
            +
                            valids += sum(valid)
         | 
| 501 | 
            +
                            # trg = trg[1:].view(-1)
         | 
| 502 | 
            +
                            # output, trg = output[1:].view(-1, output.shape[-1]), trg[1:].view(-1)
         | 
| 503 | 
            +
                            loss = nn.CrossEntropyLoss(
         | 
| 504 | 
            +
                                ignore_index=self.TRG.vocab.stoi[self.TRG.pad_token])
         | 
| 505 | 
            +
                            loss = loss(output, trg)
         | 
| 506 | 
            +
                        test_loss += loss.item()
         | 
| 507 | 
            +
                        if self.loader_drugex is not None:
         | 
| 508 | 
            +
                            for _, batch in enumerate(self.loader_drugex):
         | 
| 509 | 
            +
                                src = batch.src
         | 
| 510 | 
            +
                                output = self.translate_sentence(src, self.TRG,
         | 
| 511 | 
            +
                                                                 self.device)
         | 
| 512 | 
            +
                                # checks the number of valid smiles
         | 
| 513 | 
            +
                                pred_token = output.argmax(2)
         | 
| 514 | 
            +
                                array = torch.transpose(pred_token, 0, 1)
         | 
| 515 | 
            +
                                src_trans = torch.transpose(src, 0, 1)
         | 
| 516 | 
            +
                                df_batch, valid, smiless = is_smiles(
         | 
| 517 | 
            +
                                    array,
         | 
| 518 | 
            +
                                    self.TRG,
         | 
| 519 | 
            +
                                    reverse=True,
         | 
| 520 | 
            +
                                    return_output=return_output,
         | 
| 521 | 
            +
                                    src=src_trans,
         | 
| 522 | 
            +
                                    src_field=self.SRC,
         | 
| 523 | 
            +
                                )
         | 
| 524 | 
            +
                                unchanged_de += is_unchanged(
         | 
| 525 | 
            +
                                    array,
         | 
| 526 | 
            +
                                    self.TRG,
         | 
| 527 | 
            +
                                    reverse=True,
         | 
| 528 | 
            +
                                    return_output=return_output,
         | 
| 529 | 
            +
                                    src=src_trans,
         | 
| 530 | 
            +
                                    src_field=self.SRC,
         | 
| 531 | 
            +
                                )
         | 
| 532 | 
            +
                                if df_batch is not None:
         | 
| 533 | 
            +
                                    df_output_de = pd.concat([df_output_de, df_batch],
         | 
| 534 | 
            +
                                                             ignore_index=True)
         | 
| 535 | 
            +
                                valids_de += sum(valid)
         | 
| 536 | 
            +
                    return (
         | 
| 537 | 
            +
                        valids,
         | 
| 538 | 
            +
                        test_loss / len(self.loader_valid),
         | 
| 539 | 
            +
                        valids_de,
         | 
| 540 | 
            +
                        df_output,
         | 
| 541 | 
            +
                        df_output_de,
         | 
| 542 | 
            +
                        right_molecules,
         | 
| 543 | 
            +
                        complexity,
         | 
| 544 | 
            +
                        unchanged,
         | 
| 545 | 
            +
                        unchanged_de,
         | 
| 546 | 
            +
                    )
         | 
| 547 | 
            +
             | 
| 548 | 
            +
                def translate(self, loader):
         | 
| 549 | 
            +
                    self.eval()
         | 
| 550 | 
            +
                    df_output_de = pd.DataFrame()
         | 
| 551 | 
            +
                    valids_de = 0
         | 
| 552 | 
            +
                    with torch.no_grad():
         | 
| 553 | 
            +
                        for _, batch in enumerate(loader):
         | 
| 554 | 
            +
                            src = batch.src
         | 
| 555 | 
            +
                            output = self.translate_sentence(src, self.TRG, self.device)
         | 
| 556 | 
            +
                            # checks the number of valid smiles
         | 
| 557 | 
            +
                            pred_token = output.argmax(2)
         | 
| 558 | 
            +
                            array = torch.transpose(pred_token, 0, 1)
         | 
| 559 | 
            +
                            src_trans = torch.transpose(src, 0, 1)
         | 
| 560 | 
            +
                            df_batch, valid, smiless = is_smiles(
         | 
| 561 | 
            +
                                array,
         | 
| 562 | 
            +
                                self.TRG,
         | 
| 563 | 
            +
                                reverse=True,
         | 
| 564 | 
            +
                                return_output=True,
         | 
| 565 | 
            +
                                src=src_trans,
         | 
| 566 | 
            +
                                src_field=self.SRC,
         | 
| 567 | 
            +
                            )
         | 
| 568 | 
            +
                            if df_batch is not None:
         | 
| 569 | 
            +
                                df_output_de = pd.concat([df_output_de, df_batch],
         | 
| 570 | 
            +
                                                         ignore_index=True)
         | 
| 571 | 
            +
                            valids_de += sum(valid)
         | 
| 572 | 
            +
                    return valids_de, df_output_de
         | 
| 573 | 
            +
             | 
| 574 | 
            +
             | 
| 575 | 
            +
            class Encoder(nn.Module):
         | 
| 576 | 
            +
             | 
| 577 | 
            +
                def __init__(self, input_dim, hid_dim, n_layers, n_heads, pf_dim, dropout,
         | 
| 578 | 
            +
                             max_length, device):
         | 
| 579 | 
            +
                    super().__init__()
         | 
| 580 | 
            +
                    self.device = device
         | 
| 581 | 
            +
                    self.tok_embedding = nn.Embedding(input_dim, hid_dim)
         | 
| 582 | 
            +
                    self.pos_embedding = nn.Embedding(max_length, hid_dim)
         | 
| 583 | 
            +
                    self.layers = nn.ModuleList([
         | 
| 584 | 
            +
                        EncoderLayer(hid_dim, n_heads, pf_dim, dropout, device)
         | 
| 585 | 
            +
                        for _ in range(n_layers)
         | 
| 586 | 
            +
                    ])
         | 
| 587 | 
            +
             | 
| 588 | 
            +
                    self.dropout = nn.Dropout(dropout)
         | 
| 589 | 
            +
                    self.scale = torch.sqrt(torch.FloatTensor([hid_dim])).to(device)
         | 
| 590 | 
            +
             | 
| 591 | 
            +
                def forward(self, src, src_mask):
         | 
| 592 | 
            +
                    # src = [batch size, src len]
         | 
| 593 | 
            +
                    # src_mask = [batch size, src len]
         | 
| 594 | 
            +
                    batch_size = src.shape[0]
         | 
| 595 | 
            +
                    src_len = src.shape[1]
         | 
| 596 | 
            +
                    pos = (torch.arange(0, src_len).unsqueeze(0).repeat(batch_size,
         | 
| 597 | 
            +
                                                                        1).to(self.device))
         | 
| 598 | 
            +
                    # pos = [batch size, src len]
         | 
| 599 | 
            +
                    src = self.dropout((self.tok_embedding(src) * self.scale) +
         | 
| 600 | 
            +
                                       self.pos_embedding(pos))
         | 
| 601 | 
            +
                    # src = [batch size, src len, hid dim]
         | 
| 602 | 
            +
                    for layer in self.layers:
         | 
| 603 | 
            +
                        src = layer(src, src_mask)
         | 
| 604 | 
            +
                    # src = [batch size, src len, hid dim]
         | 
| 605 | 
            +
                    return src
         | 
| 606 | 
            +
             | 
| 607 | 
            +
             | 
| 608 | 
            +
            class EncoderLayer(nn.Module):
         | 
| 609 | 
            +
             | 
| 610 | 
            +
                def __init__(self, hid_dim, n_heads, pf_dim, dropout, device):
         | 
| 611 | 
            +
                    super().__init__()
         | 
| 612 | 
            +
             | 
| 613 | 
            +
                    self.self_attn_layer_norm = nn.LayerNorm(hid_dim)
         | 
| 614 | 
            +
                    self.ff_layer_norm = nn.LayerNorm(hid_dim)
         | 
| 615 | 
            +
                    self.self_attention = MultiHeadAttentionLayer(hid_dim, n_heads,
         | 
| 616 | 
            +
                                                                  dropout, device)
         | 
| 617 | 
            +
                    self.positionwise_feedforward = PositionwiseFeedforwardLayer(
         | 
| 618 | 
            +
                        hid_dim, pf_dim, dropout)
         | 
| 619 | 
            +
                    self.dropout = nn.Dropout(dropout)
         | 
| 620 | 
            +
             | 
| 621 | 
            +
                def forward(self, src, src_mask):
         | 
| 622 | 
            +
                    # src = [batch size, src len, hid dim]
         | 
| 623 | 
            +
                    # src_mask = [batch size, src len]
         | 
| 624 | 
            +
                    # self attention
         | 
| 625 | 
            +
                    _src, _ = self.self_attention(src, src, src, src_mask)
         | 
| 626 | 
            +
                    # dropout, residual connection and layer norm
         | 
| 627 | 
            +
                    src = self.self_attn_layer_norm(src + self.dropout(_src))
         | 
| 628 | 
            +
                    # src = [batch size, src len, hid dim]
         | 
| 629 | 
            +
                    # positionwise feedforward
         | 
| 630 | 
            +
                    _src = self.positionwise_feedforward(src)
         | 
| 631 | 
            +
                    # dropout, residual and layer norm
         | 
| 632 | 
            +
                    src = self.ff_layer_norm(src + self.dropout(_src))
         | 
| 633 | 
            +
                    # src = [batch size, src len, hid dim]
         | 
| 634 | 
            +
             | 
| 635 | 
            +
                    return src
         | 
| 636 | 
            +
             | 
| 637 | 
            +
             | 
| 638 | 
            +
            class MultiHeadAttentionLayer(nn.Module):
         | 
| 639 | 
            +
             | 
| 640 | 
            +
                def __init__(self, hid_dim, n_heads, dropout, device):
         | 
| 641 | 
            +
                    super().__init__()
         | 
| 642 | 
            +
                    assert hid_dim % n_heads == 0
         | 
| 643 | 
            +
                    self.hid_dim = hid_dim
         | 
| 644 | 
            +
                    self.n_heads = n_heads
         | 
| 645 | 
            +
                    self.head_dim = hid_dim // n_heads
         | 
| 646 | 
            +
                    self.fc_q = nn.Linear(hid_dim, hid_dim)
         | 
| 647 | 
            +
                    self.fc_k = nn.Linear(hid_dim, hid_dim)
         | 
| 648 | 
            +
                    self.fc_v = nn.Linear(hid_dim, hid_dim)
         | 
| 649 | 
            +
                    self.fc_o = nn.Linear(hid_dim, hid_dim)
         | 
| 650 | 
            +
                    self.dropout = nn.Dropout(dropout)
         | 
| 651 | 
            +
                    self.scale = torch.sqrt(torch.FloatTensor([self.head_dim])).to(device)
         | 
| 652 | 
            +
             | 
| 653 | 
            +
                def forward(self, query, key, value, mask=None):
         | 
| 654 | 
            +
                    batch_size = query.shape[0]
         | 
| 655 | 
            +
                    # query = [batch size, query len, hid dim]
         | 
| 656 | 
            +
                    # key = [batch size, key len, hid dim]
         | 
| 657 | 
            +
                    # value = [batch size, value len, hid dim]
         | 
| 658 | 
            +
                    Q = self.fc_q(query)
         | 
| 659 | 
            +
                    K = self.fc_k(key)
         | 
| 660 | 
            +
                    V = self.fc_v(value)
         | 
| 661 | 
            +
                    # Q = [batch size, query len, hid dim]
         | 
| 662 | 
            +
                    # K = [batch size, key len, hid dim]
         | 
| 663 | 
            +
                    # V = [batch size, value len, hid dim]
         | 
| 664 | 
            +
                    Q = Q.view(batch_size, -1, self.n_heads,
         | 
| 665 | 
            +
                               self.head_dim).permute(0, 2, 1, 3)
         | 
| 666 | 
            +
                    K = K.view(batch_size, -1, self.n_heads,
         | 
| 667 | 
            +
                               self.head_dim).permute(0, 2, 1, 3)
         | 
| 668 | 
            +
                    V = V.view(batch_size, -1, self.n_heads,
         | 
| 669 | 
            +
                               self.head_dim).permute(0, 2, 1, 3)
         | 
| 670 | 
            +
                    # Q = [batch size, n heads, query len, head dim]
         | 
| 671 | 
            +
                    # K = [batch size, n heads, key len, head dim]
         | 
| 672 | 
            +
                    # V = [batch size, n heads, value len, head dim]
         | 
| 673 | 
            +
                    energy = torch.matmul(Q, K.permute(0, 1, 3, 2)) / self.scale
         | 
| 674 | 
            +
                    # energy = [batch size, n heads, query len, key len]
         | 
| 675 | 
            +
                    if mask is not None:
         | 
| 676 | 
            +
                        energy = energy.masked_fill(mask == 0, -1e10)
         | 
| 677 | 
            +
                    attention = torch.softmax(energy, dim=-1)
         | 
| 678 | 
            +
                    # attention = [batch size, n heads, query len, key len]
         | 
| 679 | 
            +
                    x = torch.matmul(self.dropout(attention), V)
         | 
| 680 | 
            +
                    # x = [batch size, n heads, query len, head dim]
         | 
| 681 | 
            +
                    x = x.permute(0, 2, 1, 3).contiguous()
         | 
| 682 | 
            +
                    # x = [batch size, query len, n heads, head dim]
         | 
| 683 | 
            +
                    x = x.view(batch_size, -1, self.hid_dim)
         | 
| 684 | 
            +
                    # x = [batch size, query len, hid dim]
         | 
| 685 | 
            +
                    x = self.fc_o(x)
         | 
| 686 | 
            +
                    # x = [batch size, query len, hid dim]
         | 
| 687 | 
            +
                    return x, attention
         | 
| 688 | 
            +
             | 
| 689 | 
            +
             | 
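# Shape sketch (added): a small self-contained check of the attention layer
# defined above; the sizes are arbitrary toy values chosen for illustration.
def _example_attention_shapes():
    import torch

    mha = MultiHeadAttentionLayer(hid_dim=32, n_heads=4, dropout=0.1,
                                  device=torch.device("cpu"))
    q = torch.rand(2, 5, 32)         # [batch size, query len, hid dim]
    x, attention = mha(q, q, q)      # no mask
    return x.shape, attention.shape  # ([2, 5, 32], [2, 4, 5, 5])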
| 690 | 
            +
            class PositionwiseFeedforwardLayer(nn.Module):
         | 
| 691 | 
            +
             | 
| 692 | 
            +
                def __init__(self, hid_dim, pf_dim, dropout):
         | 
| 693 | 
            +
                    super().__init__()
         | 
| 694 | 
            +
                    self.fc_1 = nn.Linear(hid_dim, pf_dim)
         | 
| 695 | 
            +
                    self.fc_2 = nn.Linear(pf_dim, hid_dim)
         | 
| 696 | 
            +
                    self.dropout = nn.Dropout(dropout)
         | 
| 697 | 
            +
             | 
| 698 | 
            +
                def forward(self, x):
         | 
| 699 | 
            +
                    # x = [batch size, seq len, hid dim]
         | 
| 700 | 
            +
                    x = self.dropout(torch.relu(self.fc_1(x)))
         | 
| 701 | 
            +
                    # x = [batch size, seq len, pf dim]
         | 
| 702 | 
            +
                    x = self.fc_2(x)
         | 
| 703 | 
            +
                    # x = [batch size, seq len, hid dim]
         | 
| 704 | 
            +
             | 
| 705 | 
            +
                    return x
         | 
| 706 | 
            +
             | 
| 707 | 
            +
             | 
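# Shape sketch (added): wiring the encoder stack defined above with toy sizes;
# the dimensions are illustrative assumptions, not the project's settings.
def _example_encoder_shapes():
    import torch

    device = torch.device("cpu")
    enc = Encoder(input_dim=30, hid_dim=32, n_layers=2, n_heads=4,
                  pf_dim=64, dropout=0.1, max_length=50, device=device)
    src = torch.randint(0, 30, (8, 20))  # [batch size, src len] token indices
    src_mask = torch.ones(8, 1, 1, 20)   # all ones = attend to every position
    enc_src = enc(src, src_mask)
    return enc_src.shape                 # torch.Size([8, 20, 32])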
| 708 | 
            +
            class Decoder(nn.Module):
         | 
| 709 | 
            +
             | 
| 710 | 
            +
                def __init__(
         | 
| 711 | 
            +
                    self,
         | 
| 712 | 
            +
                    output_dim,
         | 
| 713 | 
            +
                    hid_dim,
         | 
| 714 | 
            +
                    n_layers,
         | 
| 715 | 
            +
                    n_heads,
         | 
| 716 | 
            +
                    pf_dim,
         | 
| 717 | 
            +
                    dropout,
         | 
| 718 | 
            +
                    max_length,
         | 
| 719 | 
            +
                    device,
         | 
| 720 | 
            +
                ):
         | 
| 721 | 
            +
                    super().__init__()
         | 
| 722 | 
            +
                    self.device = device
         | 
| 723 | 
            +
                    self.tok_embedding = nn.Embedding(output_dim, hid_dim)
         | 
| 724 | 
            +
                    self.pos_embedding = nn.Embedding(max_length, hid_dim)
         | 
| 725 | 
            +
                    self.layers = nn.ModuleList([
         | 
| 726 | 
            +
                        DecoderLayer(hid_dim, n_heads, pf_dim, dropout, device)
         | 
| 727 | 
            +
                        for _ in range(n_layers)
         | 
| 728 | 
            +
                    ])
         | 
| 729 | 
            +
                    self.fc_out = nn.Linear(hid_dim, output_dim)
         | 
| 730 | 
            +
                    self.dropout = nn.Dropout(dropout)
         | 
| 731 | 
            +
                    self.scale = torch.sqrt(torch.FloatTensor([hid_dim])).to(device)
         | 
| 732 | 
            +
             | 
| 733 | 
            +
                def forward(self, trg, enc_src, trg_mask, src_mask):
         | 
| 734 | 
            +
                    # trg = [batch size, trg len]
         | 
| 735 | 
            +
                    # enc_src = [batch size, src len, hid dim]
         | 
| 736 | 
            +
                    # trg_mask = [batch size, trg len]
         | 
| 737 | 
            +
                    # src_mask = [batch size, src len]
         | 
| 738 | 
            +
                    batch_size = trg.shape[0]
         | 
| 739 | 
            +
                    trg_len = trg.shape[1]
         | 
| 740 | 
            +
                    pos = (torch.arange(0, trg_len).unsqueeze(0).repeat(batch_size,
         | 
| 741 | 
            +
                                                                        1).to(self.device))
         | 
| 742 | 
            +
                    # pos = [batch size, trg len]
         | 
| 743 | 
            +
                    trg = self.dropout((self.tok_embedding(trg) * self.scale) +
         | 
| 744 | 
            +
                                       self.pos_embedding(pos))
         | 
| 745 | 
            +
                    # trg = [batch size, trg len, hid dim]
         | 
| 746 | 
            +
                    for layer in self.layers:
         | 
| 747 | 
            +
                        trg, attention = layer(trg, enc_src, trg_mask, src_mask)
         | 
| 748 | 
            +
                    # trg = [batch size, trg len, hid dim]
         | 
| 749 | 
            +
                    # attention = [batch size, n heads, trg len, src len]
         | 
| 750 | 
            +
                    output = self.fc_out(trg)
         | 
| 751 | 
            +
                    # output = [batch size, trg len, output dim]
         | 
| 752 | 
            +
                    return output, attention
         | 
| 753 | 
            +
             | 
| 754 | 
            +
             | 
| 755 | 
            +
            class DecoderLayer(nn.Module):
         | 
| 756 | 
            +
             | 
| 757 | 
            +
                def __init__(self, hid_dim, n_heads, pf_dim, dropout, device):
         | 
| 758 | 
            +
                    super().__init__()
         | 
| 759 | 
            +
                    self.self_attn_layer_norm = nn.LayerNorm(hid_dim)
         | 
| 760 | 
            +
                    self.enc_attn_layer_norm = nn.LayerNorm(hid_dim)
         | 
| 761 | 
            +
                    self.ff_layer_norm = nn.LayerNorm(hid_dim)
         | 
| 762 | 
            +
                    self.self_attention = MultiHeadAttentionLayer(hid_dim, n_heads,
         | 
| 763 | 
            +
                                                                  dropout, device)
         | 
| 764 | 
            +
                    self.encoder_attention = MultiHeadAttentionLayer(
         | 
| 765 | 
            +
                        hid_dim, n_heads, dropout, device)
         | 
| 766 | 
            +
                    self.positionwise_feedforward = PositionwiseFeedforwardLayer(
         | 
| 767 | 
            +
                        hid_dim, pf_dim, dropout)
         | 
| 768 | 
            +
                    self.dropout = nn.Dropout(dropout)
         | 
| 769 | 
            +
             | 
| 770 | 
            +
                def forward(self, trg, enc_src, trg_mask, src_mask):
         | 
| 771 | 
            +
                    # trg = [batch size, trg len, hid dim]
         | 
| 772 | 
            +
                    # enc_src = [batch size, src len, hid dim]
         | 
| 773 | 
            +
                    # trg_mask = [batch size, trg len]
         | 
| 774 | 
            +
                    # src_mask = [batch size, src len]
         | 
| 775 | 
            +
                    # self attention
         | 
| 776 | 
            +
                    _trg, _ = self.self_attention(trg, trg, trg, trg_mask)
         | 
| 777 | 
            +
                    # dropout, residual connection and layer norm
         | 
| 778 | 
            +
                    trg = self.self_attn_layer_norm(trg + self.dropout(_trg))
         | 
| 779 | 
            +
                    # trg = [batch size, trg len, hid dim]
         | 
| 780 | 
            +
                    # encoder attention
         | 
| 781 | 
            +
                    _trg, attention = self.encoder_attention(trg, enc_src, enc_src,
         | 
| 782 | 
            +
                                                             src_mask)
         | 
| 783 | 
            +
                    # dropout, residual connection and layer norm
         | 
| 784 | 
            +
                    trg = self.enc_attn_layer_norm(trg + self.dropout(_trg))
         | 
| 785 | 
            +
                    # trg = [batch size, trg len, hid dim]
         | 
| 786 | 
            +
                    # positionwise feedforward
         | 
| 787 | 
            +
                    _trg = self.positionwise_feedforward(trg)
         | 
| 788 | 
            +
                    # dropout, residual and layer norm
         | 
| 789 | 
            +
                    trg = self.ff_layer_norm(trg + self.dropout(_trg))
         | 
| 790 | 
            +
                    # trg = [batch size, trg len, hid dim]
         | 
| 791 | 
            +
                    # attention = [batch size, n heads, trg len, src len]
         | 
| 792 | 
            +
                    return trg, attention
         | 
| 793 | 
            +
             | 
| 794 | 
            +
             | 
| 795 | 
            +
            class Seq2Seq(nn.Module, Convo):
         | 
| 796 | 
            +
             | 
| 797 | 
            +
                def __init__(
         | 
| 798 | 
            +
                    self,
         | 
| 799 | 
            +
                    encoder,
         | 
| 800 | 
            +
                    decoder,
         | 
| 801 | 
            +
                    src_pad_idx,
         | 
| 802 | 
            +
                    trg_pad_idx,
         | 
| 803 | 
            +
                    device,
         | 
| 804 | 
            +
                    loader_train: DataLoader,
         | 
| 805 | 
            +
                    out: str,
         | 
| 806 | 
            +
                    loader_valid=None,
         | 
| 807 | 
            +
                    loader_drugex=None,
         | 
| 808 | 
            +
                    epochs=100,
         | 
| 809 | 
            +
                    lr=0.0005,
         | 
| 810 | 
            +
                    clip=0.1,
         | 
| 811 | 
            +
                    reverse=True,
         | 
| 812 | 
            +
                    TRG=None,
         | 
| 813 | 
            +
                    SRC=None,
         | 
| 814 | 
            +
                ):
         | 
| 815 | 
            +
                    super().__init__()
         | 
| 816 | 
            +
                    self.encoder = encoder
         | 
| 817 | 
            +
                    self.decoder = decoder
         | 
| 818 | 
            +
                    self.src_pad_idx = src_pad_idx
         | 
| 819 | 
            +
                    self.trg_pad_idx = trg_pad_idx
         | 
| 820 | 
            +
                    self.device = device
         | 
| 821 | 
            +
                    self.loader_train = loader_train
         | 
| 822 | 
            +
                    self.out = out
         | 
| 823 | 
            +
                    self.loader_valid = loader_valid
         | 
| 824 | 
            +
                    self.loader_drugex = loader_drugex
         | 
| 825 | 
            +
                    self.epochs = epochs
         | 
| 826 | 
            +
                    self.lr = lr
         | 
| 827 | 
            +
                    self.clip = clip
         | 
| 828 | 
            +
                    self.reverse = reverse
         | 
| 829 | 
            +
                    self.TRG = TRG
         | 
| 830 | 
            +
                    self.SRC = SRC
         | 
| 831 | 
            +
             | 
| 832 | 
            +
                def make_src_mask(self, src):
         | 
| 833 | 
            +
                    # src = [batch size, src len]
         | 
| 834 | 
            +
                    src_mask = (src != self.src_pad_idx).unsqueeze(1).unsqueeze(2)
         | 
| 835 | 
            +
                    # src_mask = [batch size, 1, 1, src len]
         | 
| 836 | 
            +
                    return src_mask
         | 
| 837 | 
            +
             | 
| 838 | 
            +
                def make_trg_mask(self, trg):
         | 
| 839 | 
            +
                    # trg = [batch size, trg len]
         | 
| 840 | 
            +
                    trg_pad_mask = (trg != self.trg_pad_idx).unsqueeze(1).unsqueeze(2)
         | 
| 841 | 
            +
                    # trg_pad_mask = [batch size, 1, 1, trg len]
         | 
| 842 | 
            +
                    trg_len = trg.shape[1]
         | 
| 843 | 
            +
                    trg_sub_mask = torch.tril(
         | 
| 844 | 
            +
                        torch.ones((trg_len, trg_len), device=self.device)).bool()
         | 
| 845 | 
            +
                    # trg_sub_mask = [trg len, trg len]
         | 
| 846 | 
            +
                    trg_mask = trg_pad_mask & trg_sub_mask
         | 
| 847 | 
            +
                    # trg_mask = [batch size, 1, trg len, trg len]
         | 
| 848 | 
            +
                    return trg_mask
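To make the mask composition above concrete, here is a minimal standalone sketch (illustrative only, with a hypothetical pad index of 0 and a single sequence whose last position is padding):

import torch

pad_idx = 0
trg = torch.tensor([[5, 7, 9, 0]])                     # [batch size = 1, trg len = 4], last token is padding
pad_mask = (trg != pad_idx).unsqueeze(1).unsqueeze(2)  # [1, 1, 1, 4]
sub_mask = torch.tril(torch.ones((4, 4))).bool()       # causal (lower-triangular) mask, [4, 4]
mask = pad_mask & sub_mask                             # broadcasts to [1, 1, 4, 4]
# mask[0, 0]:
# [[ True, False, False, False],
#  [ True,  True, False, False],
#  [ True,  True,  True, False],
#  [ True,  True,  True, False]]   # the padded position is never attended to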
         | 
| 849 | 
            +
             | 
| 850 | 
            +
                def forward(self, src, trg):
         | 
| 851 | 
            +
                    # src = [batch size, src len]
         | 
| 852 | 
            +
                    # trg = [batch size, trg len]
         | 
| 853 | 
            +
                    src_mask = self.make_src_mask(src)
         | 
| 854 | 
            +
                    trg_mask = self.make_trg_mask(trg)
         | 
| 855 | 
            +
                    # src_mask = [batch size, 1, 1, src len]
         | 
| 856 | 
            +
                    # trg_mask = [batch size, 1, trg len, trg len]
         | 
| 857 | 
            +
                    enc_src = self.encoder(src, src_mask)
         | 
| 858 | 
            +
                    # enc_src = [batch size, src len, hid dim]
         | 
| 859 | 
            +
                    output, attention = self.decoder(trg, enc_src, trg_mask, src_mask)
         | 
| 860 | 
            +
                    # output = [batch size, trg len, output dim]
         | 
| 861 | 
            +
                    # attention = [batch size, n heads, trg len, src len]
         | 
| 862 | 
            +
                    return output, attention
         | 
| 863 | 
            +
             | 
| 864 | 
            +
                def translate_sentence(self, src, trg_field, device, max_len=202):
         | 
| 865 | 
            +
                    self.eval()
         | 
| 866 | 
            +
                    src_mask = self.make_src_mask(src)
         | 
| 867 | 
            +
                    with torch.no_grad():
         | 
| 868 | 
            +
                        enc_src = self.encoder(src, src_mask)
         | 
| 869 | 
            +
                    trg_indexes = [trg_field.vocab.stoi[trg_field.init_token]]
         | 
| 870 | 
            +
                    batch_size = src.shape[0]
         | 
| 871 | 
            +
                    trg = torch.LongTensor(trg_indexes).unsqueeze(0).to(device)
         | 
| 872 | 
            +
                    trg = trg.repeat(batch_size, 1)
         | 
| 873 | 
            +
                    for i in range(max_len):
         | 
| 874 | 
            +
                        # decode step by step with self.decoder (no separate model argument needed)
         | 
| 875 | 
            +
                        trg_mask = self.make_trg_mask(trg)
         | 
| 876 | 
            +
                        with torch.no_grad():
         | 
| 877 | 
            +
                            output, attention = self.decoder(trg, enc_src, trg_mask,
         | 
| 878 | 
            +
                                                             src_mask)
         | 
| 879 | 
            +
                        pred_tokens = output.argmax(2)[:, -1].unsqueeze(1)
         | 
| 880 | 
            +
                        trg = torch.cat((trg, pred_tokens), 1)
         | 
| 881 | 
            +
             | 
| 882 | 
            +
                    return output
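For reference, a minimal sketch of how the greedy predictions collected above could be mapped back to SMILES strings. This helper is hypothetical (the class's own translate method, defined elsewhere, does the real work) and assumes trg_field is the torchtext TRG field built with reversed tokenization:

def indexes_to_smiles(pred_indexes, trg_field, reverse=True):
    # convert one sequence of predicted token indexes back into a SMILES string
    tokens = []
    for idx in pred_indexes:
        token = trg_field.vocab.itos[idx]
        if token == trg_field.eos_token:
            break
        if token not in (trg_field.init_token, trg_field.pad_token):
            tokens.append(token)
    # TRG tokens were generated back-to-front, so flip them again before joining
    return "".join(tokens[::-1]) if reverse else "".join(tokens)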
         | 
| 883 | 
            +
             | 
| 884 | 
            +
             | 
| 885 | 
            +
            def remove_floats(df: pd.DataFrame, subset: str):
         | 
| 886 | 
            +
                """Preprocessing step to remove any entries that are not strings"""
         | 
| 887 | 
            +
                df_subset = df[subset]
         | 
| 888 | 
            +
                df[subset] = df[subset].astype(str)
         | 
| 889 | 
            +
                # only keep entries that stayed the same after applying astype str
         | 
| 890 | 
            +
                df = df[df[subset] == df_subset].copy()
         | 
| 891 | 
            +
             | 
| 892 | 
            +
                return df
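For example (illustrative values only), rows whose entry was parsed as a float/NaN are dropped because their string form no longer matches the original object:

df = pd.DataFrame({"SMILES": ["CCO", float("nan"), "c1ccccc1"]})
remove_floats(df, subset="SMILES")["SMILES"].tolist()
# -> ['CCO', 'c1ccccc1']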
         | 
| 893 | 
            +
             | 
| 894 | 
            +
             | 
| 895 | 
            +
            def smi_tokenizer(smi: str, reverse=False) -> list:
         | 
| 896 | 
            +
                """
         | 
| 897 | 
            +
                Tokenize a SMILES molecule
         | 
| 898 | 
            +
                """
         | 
| 899 | 
            +
                pattern = r"(\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\\\|\\|\/|:|~|@|\?|>|\*|\$|\%[0-9]{2}|[0-9])"
         | 
| 900 | 
            +
                regex = re.compile(pattern)
         | 
| 901 | 
            +
                # tokens = ['<sos>'] + [token for token in regex.findall(smi)] + ['<eos>']
         | 
| 902 | 
            +
                tokens = [token for token in regex.findall(smi)]
         | 
| 903 | 
            +
                # assert smi == ''.join(tokens[1:-1])
         | 
| 904 | 
            +
                assert smi == "".join(tokens[:])
         | 
| 905 | 
            +
                # try:
         | 
| 906 | 
            +
                #     assert smi == "".join(tokens[:])
         | 
| 907 | 
            +
                # except:
         | 
| 908 | 
            +
                #     print(smi)
         | 
| 909 | 
            +
                #     print("".join(tokens[:]))
         | 
| 910 | 
            +
                if reverse:
         | 
| 911 | 
            +
                    return tokens[::-1]
         | 
| 912 | 
            +
                return tokens
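A small worked example of the tokenizer (illustrative only): bracket atoms and two-letter elements such as Cl stay intact, and reverse=True flips the token order for the target field:

smi_tokenizer("CC(=O)Oc1ccccc1C(=O)O")
# -> ['C', 'C', '(', '=', 'O', ')', 'O', 'c', '1', 'c', 'c', 'c', 'c', 'c', '1', 'C', '(', '=', 'O', ')', 'O']
smi_tokenizer("C[NH3+].[Cl-]", reverse=True)
# -> ['[Cl-]', '.', '[NH3+]', 'C']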
         | 
| 913 | 
            +
             | 
| 914 | 
            +
             | 
| 915 | 
            +
            def init_weights(m: nn.Module):
         | 
| 916 | 
            +
                if hasattr(m, "weight") and m.weight.dim() > 1:
         | 
| 917 | 
            +
                    nn.init.xavier_uniform_(m.weight.data)
         | 
| 918 | 
            +
             | 
| 919 | 
            +
             | 
| 920 | 
            +
            def count_parameters(model: nn.Module):
         | 
| 921 | 
            +
                return sum(p.numel() for p in model.parameters() if p.requires_grad)
         | 
| 922 | 
            +
             | 
| 923 | 
            +
             | 
| 924 | 
            +
            def epoch_time(start_time, end_time):
         | 
| 925 | 
            +
                elapsed_time = end_time - start_time
         | 
| 926 | 
            +
                elapsed_mins = int(elapsed_time / 60)
         | 
| 927 | 
            +
                elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
         | 
| 928 | 
            +
                return elapsed_mins, elapsed_secs
         | 
| 929 | 
            +
             | 
| 930 | 
            +
             | 
| 931 | 
            +
            def initialize_model(folder_out: str,
         | 
| 932 | 
            +
                                 data_source: str,
         | 
| 933 | 
            +
                                 error_source: str,
         | 
| 934 | 
            +
                                 device: torch.device,
         | 
| 935 | 
            +
                                 threshold: int,
         | 
| 936 | 
            +
                                 epochs: int,
         | 
| 937 | 
            +
                                 layers: int = 3,
         | 
| 938 | 
            +
                                 batch_size: int = 16,
         | 
| 939 | 
            +
                                 invalid_type: str = "all",
         | 
| 940 | 
            +
                                 num_errors: int = 1,
         | 
| 941 | 
            +
                                 validation_step=False):
         | 
| 942 | 
            +
                """Create encoder decoder models for specified model (currently only translator) & type of invalid SMILES
         | 
| 943 | 
            +
             | 
| 944 | 
            +
                param folder_out: folder with the (invalid, valid) SMILES pair data, also used for the model output
         | 
| 945 | 
            +
                param error_source: path to previously generated invalid SMILES that need correcting
         | 
| 946 | 
            +
                param invalid_type: type of errors introduced into invalid SMILES
         | 
| 947 | 
            +
             | 
| 948 | 
            +
                return: initialized Seq2Seq model, output path prefix, and the SRC field
         | 
| 949 | 
            +
             | 
| 950 | 
            +
                """
         | 
| 951 | 
            +
             | 
| 952 | 
            +
                # set fields
         | 
| 953 | 
            +
                SRC = Field(
         | 
| 954 | 
            +
                    tokenize=lambda x: smi_tokenizer(x),
         | 
| 955 | 
            +
                    init_token="<sos>",
         | 
| 956 | 
            +
                    eos_token="<eos>",
         | 
| 957 | 
            +
                    batch_first=True,
         | 
| 958 | 
            +
                )
         | 
| 959 | 
            +
                TRG = Field(
         | 
| 960 | 
            +
                    tokenize=lambda x: smi_tokenizer(x, reverse=True),
         | 
| 961 | 
            +
                    init_token="<sos>",
         | 
| 962 | 
            +
                    eos_token="<eos>",
         | 
| 963 | 
            +
                    batch_first=True,
         | 
| 964 | 
            +
                )
         | 
| 965 | 
            +
             | 
| 966 | 
            +
                if validation_step:
         | 
| 967 | 
            +
                    train, val = TabularDataset.splits(
         | 
| 968 | 
            +
                        path=f'{folder_out}errors/split/',
         | 
| 969 | 
            +
                        train=f"{data_source}_{invalid_type}_{num_errors}_errors_train.csv",
         | 
| 970 | 
            +
                        validation=
         | 
| 971 | 
            +
                        f"{data_source}_{invalid_type}_{num_errors}_errors_dev.csv",
         | 
| 972 | 
            +
                        format="CSV",
         | 
| 973 | 
            +
                        skip_header=False,
         | 
| 974 | 
            +
                        fields={
         | 
| 975 | 
            +
                            "ERROR": ("src", SRC),
         | 
| 976 | 
            +
                            "STD_SMILES": ("trg", TRG)
         | 
| 977 | 
            +
                        },
         | 
| 978 | 
            +
                    )
         | 
| 979 | 
            +
                    SRC.build_vocab(train, val, max_size=1000)
         | 
| 980 | 
            +
                    TRG.build_vocab(train, val, max_size=1000)
         | 
| 981 | 
            +
                else:
         | 
| 982 | 
            +
                    train = TabularDataset(
         | 
| 983 | 
            +
                        path=
         | 
| 984 | 
            +
                        f'{folder_out}{data_source}_{invalid_type}_{num_errors}_errors.csv',
         | 
| 985 | 
            +
                        format="CSV",
         | 
| 986 | 
            +
                        skip_header=False,
         | 
| 987 | 
            +
                        fields={
         | 
| 988 | 
            +
                            "ERROR": ("src", SRC),
         | 
| 989 | 
            +
                            "STD_SMILES": ("trg", TRG)
         | 
| 990 | 
            +
                        },
         | 
| 991 | 
            +
                    )
         | 
| 992 | 
            +
                    SRC.build_vocab(train, max_size=1000)
         | 
| 993 | 
            +
                    TRG.build_vocab(train, max_size=1000)
         | 
| 994 | 
            +
             | 
| 995 | 
            +
                drugex = TabularDataset(
         | 
| 996 | 
            +
                    path=error_source,
         | 
| 997 | 
            +
                    format="csv",
         | 
| 998 | 
            +
                    skip_header=False,
         | 
| 999 | 
            +
                    fields={
         | 
| 1000 | 
            +
                        "SMILES": ("src", SRC),
         | 
| 1001 | 
            +
                        "SMILES_TARGET": ("trg", TRG)
         | 
| 1002 | 
            +
                    },
         | 
| 1003 | 
            +
                )
         | 
| 1004 | 
            +
             | 
| 1005 | 
            +
             | 
| 1006 | 
            +
                #SRC.vocab = torch.load('vocab_src.pth')
         | 
| 1007 | 
            +
                #TRG.vocab = torch.load('vocab_trg.pth')
         | 
| 1008 | 
            +
             | 
| 1009 | 
            +
                # model parameters
         | 
| 1010 | 
            +
                EPOCHS = epochs
         | 
| 1011 | 
            +
                BATCH_SIZE = batch_size
         | 
| 1012 | 
            +
                INPUT_DIM = len(SRC.vocab)
         | 
| 1013 | 
            +
                OUTPUT_DIM = len(TRG.vocab)
         | 
| 1014 | 
            +
                HID_DIM = 256
         | 
| 1015 | 
            +
                ENC_LAYERS = layers
         | 
| 1016 | 
            +
                DEC_LAYERS = layers
         | 
| 1017 | 
            +
                ENC_HEADS = 8
         | 
| 1018 | 
            +
                DEC_HEADS = 8
         | 
| 1019 | 
            +
                ENC_PF_DIM = 512
         | 
| 1020 | 
            +
                DEC_PF_DIM = 512
         | 
| 1021 | 
            +
                ENC_DROPOUT = 0.1
         | 
| 1022 | 
            +
                DEC_DROPOUT = 0.1
         | 
| 1023 | 
            +
                SRC_PAD_IDX = SRC.vocab.stoi[SRC.pad_token]
         | 
| 1024 | 
            +
                TRG_PAD_IDX = TRG.vocab.stoi[TRG.pad_token]
         | 
| 1025 | 
            +
                # add 2 to length for start and stop tokens
         | 
| 1026 | 
            +
                MAX_LENGTH = threshold + 2
         | 
| 1027 | 
            +
             | 
| 1028 | 
            +
                # model name
         | 
| 1029 | 
            +
                MODEL_OUT_FOLDER = f"{folder_out}"
         | 
| 1030 | 
            +
             | 
| 1031 | 
            +
                MODEL_NAME = "transformer_%s_%s_%s_%s_%s" % (
         | 
| 1032 | 
            +
                    invalid_type, num_errors, data_source, BATCH_SIZE, layers)
         | 
| 1033 | 
            +
                if not os.path.exists(MODEL_OUT_FOLDER):
         | 
| 1034 | 
            +
                    os.mkdir(MODEL_OUT_FOLDER)
         | 
| 1035 | 
            +
             | 
| 1036 | 
            +
                out = os.path.join(MODEL_OUT_FOLDER, MODEL_NAME)
         | 
| 1037 | 
            +
             | 
| 1038 | 
            +
                torch.save(SRC.vocab, f'{out}_vocab_src.pth')
         | 
| 1039 | 
            +
                torch.save(TRG.vocab, f'{out}_vocab_trg.pth')
         | 
| 1040 | 
            +
             | 
| 1041 | 
            +
                # iterator is a dataloader
         | 
| 1042 | 
            +
                # iterator to pass to the same length and create batches in which the
         | 
| 1043 | 
            +
                # amount of padding is minimized
         | 
| 1044 | 
            +
                if validation_step:
         | 
| 1045 | 
            +
                    train_iter, val_iter = BucketIterator.splits(
         | 
| 1046 | 
            +
                        (train, val),
         | 
| 1047 | 
            +
                        batch_sizes=(BATCH_SIZE, 256),
         | 
| 1048 | 
            +
                        sort_within_batch=True,
         | 
| 1049 | 
            +
                        shuffle=True,
         | 
| 1050 | 
            +
                        # the BucketIterator needs to be told what function it should use to
         | 
| 1051 | 
            +
                        # group the data.
         | 
| 1052 | 
            +
                        sort_key=lambda x: len(x.src),
         | 
| 1053 | 
            +
                        device=device,
         | 
| 1054 | 
            +
                    )
         | 
| 1055 | 
            +
                else:
         | 
| 1056 | 
            +
                    train_iter = BucketIterator(
         | 
| 1057 | 
            +
                        train,
         | 
| 1058 | 
            +
                        batch_size=BATCH_SIZE,
         | 
| 1059 | 
            +
                        sort_within_batch=True,
         | 
| 1060 | 
            +
                        shuffle=True,
         | 
| 1061 | 
            +
                        # the BucketIterator needs to be told what function it should use to
         | 
| 1062 | 
            +
                        # group the data.
         | 
| 1063 | 
            +
                        sort_key=lambda x: len(x.src),
         | 
| 1064 | 
            +
                        device=device,
         | 
| 1065 | 
            +
                    )
         | 
| 1066 | 
            +
                    val_iter = None
         | 
| 1067 | 
            +
             | 
| 1068 | 
            +
                drugex_iter = Iterator(
         | 
| 1069 | 
            +
                    drugex,
         | 
| 1070 | 
            +
                    batch_size=64,
         | 
| 1071 | 
            +
                    device=device,
         | 
| 1072 | 
            +
                    sort=False,
         | 
| 1073 | 
            +
                    sort_within_batch=True,
         | 
| 1074 | 
            +
                    sort_key=lambda x: len(x.src),
         | 
| 1075 | 
            +
                    repeat=False,
         | 
| 1076 | 
            +
                )
         | 
| 1077 | 
            +
             | 
| 1078 | 
            +
             | 
| 1079 | 
            +
                # model initialization
         | 
| 1080 | 
            +
             | 
| 1081 | 
            +
                enc = Encoder(
         | 
| 1082 | 
            +
                    INPUT_DIM,
         | 
| 1083 | 
            +
                    HID_DIM,
         | 
| 1084 | 
            +
                    ENC_LAYERS,
         | 
| 1085 | 
            +
                    ENC_HEADS,
         | 
| 1086 | 
            +
                    ENC_PF_DIM,
         | 
| 1087 | 
            +
                    ENC_DROPOUT,
         | 
| 1088 | 
            +
                    MAX_LENGTH,
         | 
| 1089 | 
            +
                    device,
         | 
| 1090 | 
            +
                )
         | 
| 1091 | 
            +
                dec = Decoder(
         | 
| 1092 | 
            +
                    OUTPUT_DIM,
         | 
| 1093 | 
            +
                    HID_DIM,
         | 
| 1094 | 
            +
                    DEC_LAYERS,
         | 
| 1095 | 
            +
                    DEC_HEADS,
         | 
| 1096 | 
            +
                    DEC_PF_DIM,
         | 
| 1097 | 
            +
                    DEC_DROPOUT,
         | 
| 1098 | 
            +
                    MAX_LENGTH,
         | 
| 1099 | 
            +
                    device,
         | 
| 1100 | 
            +
                )
         | 
| 1101 | 
            +
             | 
| 1102 | 
            +
                model = Seq2Seq(
         | 
| 1103 | 
            +
                    enc,
         | 
| 1104 | 
            +
                    dec,
         | 
| 1105 | 
            +
                    SRC_PAD_IDX,
         | 
| 1106 | 
            +
                    TRG_PAD_IDX,
         | 
| 1107 | 
            +
                    device,
         | 
| 1108 | 
            +
                    train_iter,
         | 
| 1109 | 
            +
                    out=out,
         | 
| 1110 | 
            +
                    loader_valid=val_iter,
         | 
| 1111 | 
            +
                    loader_drugex=drugex_iter,
         | 
| 1112 | 
            +
                    epochs=EPOCHS,
         | 
| 1113 | 
            +
                    TRG=TRG,
         | 
| 1114 | 
            +
                    SRC=SRC,
         | 
| 1115 | 
            +
                ).to(device)
         | 
| 1116 | 
            +
             | 
| 1117 | 
            +
             | 
| 1118 | 
            +
             | 
| 1119 | 
            +
             | 
| 1120 | 
            +
                return model, out, SRC
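A hedged usage sketch (all paths and values below are placeholders that mirror the defaults used further down; the referenced CSV files must already exist on disk):

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model, out, SRC = initialize_model(
    folder_out="data/",                      # placeholder output folder
    data_source="PAPYRUS_200",               # placeholder dataset prefix
    error_source="data/papyrus_rnn_S.csv",   # placeholder file with erroneous SMILES
    device=device,
    threshold=200,
    epochs=30,
)
print(f"The model has {count_parameters(model):,} trainable parameters")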
         | 
| 1121 | 
            +
             | 
| 1122 | 
            +
             | 
| 1123 | 
            +
            def train_model(model, out, assess):
         | 
| 1124 | 
            +
                """Apply given weights (& assess performance or train further) or start training new model
         | 
| 1125 | 
            +
             | 
| 1126 | 
            +
                Args:
         | 
| 1127 | 
            +
                    model: initialized model
         | 
| 1128 | 
            +
                    out: .pkg file with model parameters
         | 
| 1129 | 
            +
                    assess: bool, if True only load the stored weights and evaluate performance
         | 
| 1130 | 
            +
             | 
| 1131 | 
            +
                Returns:
         | 
| 1132 | 
            +
                    model with (new) weights
         | 
| 1133 | 
            +
                """
         | 
| 1134 | 
            +
             | 
| 1135 | 
            +
                if os.path.exists(f"{out}.pkg") and assess:
         | 
| 1136 | 
            +
             | 
| 1137 | 
            +
             | 
| 1138 | 
            +
                    model.load_state_dict(torch.load(f=out + ".pkg"))
         | 
| 1139 | 
            +
                    (
         | 
| 1140 | 
            +
                        valids,
         | 
| 1141 | 
            +
                        loss_valid,
         | 
| 1142 | 
            +
                        valids_de,
         | 
| 1143 | 
            +
                        df_output,
         | 
| 1144 | 
            +
                        df_output_de,
         | 
| 1145 | 
            +
                        right_molecules,
         | 
| 1146 | 
            +
                        complexity,
         | 
| 1147 | 
            +
                        unchanged,
         | 
| 1148 | 
            +
                        unchanged_de,
         | 
| 1149 | 
            +
                    ) = model.evaluate(True)
         | 
| 1150 | 
            +
             | 
| 1151 | 
            +
             | 
| 1152 | 
            +
                    # log = open('unchanged.log', 'a')
         | 
| 1153 | 
            +
                    # info = f'type: comb unchanged: {unchan:.4g} unchanged_drugex: {unchan_de:.4g}'
         | 
| 1154 | 
            +
                    # print(info, file=log, flush = True)
         | 
| 1155 | 
            +
                    # print(valids_de)
         | 
| 1156 | 
            +
                    # print(unchanged_de)
         | 
| 1157 | 
            +
             | 
| 1158 | 
            +
                    # print(unchan)
         | 
| 1159 | 
            +
                    # print(unchan_de)
         | 
| 1160 | 
            +
                    # df_output_de.to_csv(f'{out}_de_new.csv', index = False)
         | 
| 1161 | 
            +
             | 
| 1162 | 
            +
                    # error_de = 1 - valids_de / len(drugex_iter.dataset)
         | 
| 1163 | 
            +
                    # print(error_de)
         | 
| 1164 | 
            +
                    # df_output.to_csv(f'{out}_par.csv', index = False)
         | 
| 1165 | 
            +
             | 
| 1166 | 
            +
                elif os.path.exists(f"{out}.pkg"):
         | 
| 1167 | 
            +
             | 
| 1168 | 
            +
                    # starts from the model after the last epoch, not the best epoch
         | 
| 1169 | 
            +
                    model.load_state_dict(torch.load(f=out + "_last.pkg"))
         | 
| 1170 | 
            +
                    # need to change how log file names epochs
         | 
| 1171 | 
            +
                    model.train_model()
         | 
| 1172 | 
            +
                else:
         | 
| 1173 | 
            +
             | 
| 1174 | 
            +
                    model = model.apply(init_weights)
         | 
| 1175 | 
            +
                    model.train_model()
         | 
| 1176 | 
            +
             | 
| 1177 | 
            +
                return model
         | 
| 1178 | 
            +
             | 
| 1179 | 
            +
             | 
| 1180 | 
            +
            def correct_SMILES(model, out, error_source, device, SRC):
         | 
| 1181 | 
            +
                """Model that is given corrects SMILES and return number of correct ouputs and dataframe containing all outputs
         | 
| 1182 | 
            +
                Args:
         | 
| 1183 | 
            +
                    model: initialized model
         | 
| 1184 | 
            +
                    out: .pkg file with model parameters
         | 
| 1185 | 
            +
                    error_source: path to a CSV with the SMILES to be corrected
         | 
| 1186 | 
            +
             | 
| 1187 | 
            +
                Returns:
         | 
| 1188 | 
            +
                    valids: number of fixed outputs
         | 
| 1189 | 
            +
                    df_output: dataframe containing output (either correct or incorrect) & original input
         | 
| 1190 | 
            +
                """
         | 
| 1191 | 
            +
                ## account for tokens that are not yet in SRC without changing existing SRC token embeddings
         | 
| 1192 | 
            +
                errors = TabularDataset(
         | 
| 1193 | 
            +
                    path=error_source,
         | 
| 1194 | 
            +
                    format="csv",
         | 
| 1195 | 
            +
                    skip_header=False,
         | 
| 1196 | 
            +
                    fields={"SMILES": ("src", SRC)},
         | 
| 1197 | 
            +
                )
         | 
| 1198 | 
            +
             | 
| 1199 | 
            +
                errors_loader = Iterator(
         | 
| 1200 | 
            +
                    errors,
         | 
| 1201 | 
            +
                    batch_size=64,
         | 
| 1202 | 
            +
                    device=device,
         | 
| 1203 | 
            +
                    sort=False,
         | 
| 1204 | 
            +
                    sort_within_batch=True,
         | 
| 1205 | 
            +
                    sort_key=lambda x: len(x.src),
         | 
| 1206 | 
            +
                    repeat=False,
         | 
| 1207 | 
            +
                )
         | 
| 1208 | 
            +
                model.load_state_dict(torch.load(f=out + ".pkg",map_location=torch.device('cpu')))
         | 
| 1209 | 
            +
                # add option to use different iterator maybe?
         | 
| 1210 | 
            +
             | 
| 1211 | 
            +
                valids, df_output = model.translate(errors_loader)
         | 
| 1212 | 
            +
                #df_output.to_csv(f"{error_source}_fixed.csv", index=False)
         | 
| 1213 | 
            +
             | 
| 1214 | 
            +
             | 
| 1215 | 
            +
                return valids, df_output
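A minimal usage sketch (the CSV path is a placeholder, the weights file f"{out}.pkg" must exist, and model, out, SRC and device are assumed to come from initialize_model above):

valids, df_fixed = correct_SMILES(model, out, "data/generated_smiles.csv", device, SRC)
print(f"{valids} SMILES could be fixed out of {len(df_fixed)}")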
         | 
| 1216 | 
            +
             | 
| 1217 | 
            +
             | 
| 1218 | 
            +
             | 
| 1219 | 
            +
            class smi_correct(object):
         | 
| 1220 | 
            +
                def __init__(self, model_name, trans_file_path):
         | 
| 1221 | 
            +
                # set random seed, used for error generation & transformer initialization
         | 
| 1222 | 
            +
                
         | 
| 1223 | 
            +
                    self.SEED = 42
         | 
| 1224 | 
            +
                    random.seed(self.SEED)
         | 
| 1225 | 
            +
                    self.model_name = model_name
         | 
| 1226 | 
            +
                    self.folder_out = "data/"
         | 
| 1227 | 
            +
                    
         | 
| 1228 | 
            +
                    self.trans_file_path = trans_file_path
         | 
| 1229 | 
            +
             | 
| 1230 | 
            +
                    if not os.path.exists(self.folder_out):
         | 
| 1231 | 
            +
                        os.makedirs(self.folder_out)
         | 
| 1232 | 
            +
                        
         | 
| 1233 | 
            +
                    self.invalid_type = 'multiple'
         | 
| 1234 | 
            +
                    self.num_errors = 12
         | 
| 1235 | 
            +
                    self.threshold = 200
         | 
| 1236 | 
            +
                    self.data_source = f"PAPYRUS_{self.threshold}"
         | 
| 1237 | 
            +
                    os.environ["CUDA_VISIBLE_DEVICES"] = "0"
         | 
| 1238 | 
            +
                    self.initialize_source = 'data/papyrus_rnn_S.csv' # change this path
         | 
| 1239 | 
            +
                    
         | 
| 1240 | 
            +
                def standardization_pipeline(self, smile):
         | 
| 1241 | 
            +
                    desalter = MolStandardize.rdMolStandardize.LargestFragmentChooser()
         | 
| 1242 | 
            +
                    std_smile = None
         | 
| 1243 | 
            +
                    if not isinstance(smile, str): return None
         | 
| 1244 | 
            +
                    m = Chem.MolFromSmiles(smile)
         | 
| 1245 | 
            +
                    # skips smiles for which no mol file could be generated
         | 
| 1246 | 
            +
                    if m is not None:
         | 
| 1247 | 
            +
                        # standardizes
         | 
| 1248 | 
            +
                        std_m = standardizer.standardize_mol(m)
         | 
| 1249 | 
            +
                        # strips salts
         | 
| 1250 | 
            +
                        std_m_p, exclude = standardizer.get_parent_mol(std_m)
         | 
| 1251 | 
            +
                        if not exclude:
         | 
| 1252 | 
            +
                            # choose largest fragment for rare cases where chembl structure
         | 
| 1253 | 
            +
                            # pipeline leaves 2 fragments
         | 
| 1254 | 
            +
                            std_m_p_d = desalter.choose(std_m_p)
         | 
| 1255 | 
            +
                            std_smile = Chem.MolToSmiles(std_m_p_d)
         | 
| 1256 | 
            +
                    return std_smile      
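For instance (illustrative input; the exact output depends on the ChEMBL structure pipeline), counter-ions are stripped and the parent structure is returned as canonical SMILES:

corrector = smi_correct("transformer", None)   # hypothetical constructor arguments
corrector.standardization_pipeline("CC(=O)Oc1ccccc1C(=O)O.[Na+].[Cl-]")
# expected to give roughly 'CC(=O)Oc1ccccc1C(=O)O'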
         | 
| 1257 | 
            +
                
         | 
| 1258 | 
            +
                def remove_smiles_duplicates(self, dataframe: pd.DataFrame,
         | 
| 1259 | 
            +
                                         subset: str) -> pd.DataFrame:
         | 
| 1260 | 
            +
                    return dataframe.drop_duplicates(subset=subset)  
         | 
| 1261 | 
            +
                
         | 
| 1262 | 
            +
                def correct(self, smi):
         | 
| 1263 | 
            +
                
         | 
| 1264 | 
            +
                    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         | 
| 1265 | 
            +
             | 
| 1266 | 
            +
                    model, out, SRC = initialize_model(self.folder_out,
         | 
| 1267 | 
            +
                                                    self.data_source,
         | 
| 1268 | 
            +
                                                    error_source=self.initialize_source,
         | 
| 1269 | 
            +
                                                    device=device,
         | 
| 1270 | 
            +
                                                    threshold=self.threshold,
         | 
| 1271 | 
            +
                                                    epochs=30,
         | 
| 1272 | 
            +
                                                    layers=3,
         | 
| 1273 | 
            +
                                                    batch_size=16,
         | 
| 1274 | 
            +
                                                    invalid_type=self.invalid_type,
         | 
| 1275 | 
            +
                                                    num_errors=self.num_errors)
         | 
| 1276 | 
            +
             | 
| 1277 | 
            +
                    valids, df_output = correct_SMILES(model, out, smi, device,
         | 
| 1278 | 
            +
                                                    SRC)
         | 
| 1279 | 
            +
                    
         | 
| 1280 | 
            +
                    df_output["SMILES"] = df_output.apply(lambda row: self.standardization_pipeline(row["CORRECT"]), axis=1)
         | 
| 1281 | 
            +
                    
         | 
| 1282 | 
            +
                    df_output = self.remove_smiles_duplicates(df_output, subset="SMILES").drop(columns=["CORRECT", "INCORRECT", "ORIGINAL"]).dropna()
         | 
| 1283 | 
            +
                    
         | 
| 1284 | 
            +
                    return df_output
         | 
    	
        src/util/utils.py
    ADDED
    
    | @@ -0,0 +1,930 @@ | |
|  | 
|  | |
| 1 | 
            +
            import os
         | 
| 2 | 
            +
            import time
         | 
| 3 | 
            +
            import math
         | 
| 4 | 
            +
            import datetime
         | 
| 5 | 
            +
            import warnings
         | 
| 6 | 
            +
            import itertools
         | 
| 7 | 
            +
            from copy import deepcopy
         | 
| 8 | 
            +
            from functools import partial
         | 
| 9 | 
            +
            from collections import Counter
         | 
| 10 | 
            +
            from multiprocessing import Pool
         | 
| 11 | 
            +
            from statistics import mean
         | 
| 12 | 
            +
             | 
| 13 | 
            +
            import numpy as np
         | 
| 14 | 
            +
            import matplotlib.pyplot as plt
         | 
| 15 | 
            +
            from matplotlib.lines import Line2D
         | 
| 16 | 
            +
            from scipy.spatial.distance import cosine as cos_distance
         | 
| 17 | 
            +
             | 
| 18 | 
            +
            import torch
         | 
| 19 | 
            +
            import wandb
         | 
| 20 | 
            +
             | 
| 21 | 
            +
            from rdkit import Chem, DataStructs, RDLogger
         | 
| 22 | 
            +
            from rdkit.Chem import (
         | 
| 23 | 
            +
                AllChem,
         | 
| 24 | 
            +
                Draw,
         | 
| 25 | 
            +
                Descriptors,
         | 
| 26 | 
            +
                Lipinski,
         | 
| 27 | 
            +
                Crippen,
         | 
| 28 | 
            +
                rdMolDescriptors,
         | 
| 29 | 
            +
                FilterCatalog,
         | 
| 30 | 
            +
            )
         | 
| 31 | 
            +
            from rdkit.Chem.Scaffolds import MurckoScaffold
         | 
| 32 | 
            +
             | 
| 33 | 
            +
            # Disable RDKit warnings
         | 
| 34 | 
            +
            RDLogger.DisableLog("rdApp.*")
         | 
| 35 | 
            +
             | 
| 36 | 
            +
             | 
| 37 | 
            +
            class Metrics(object):
         | 
| 38 | 
            +
                """
         | 
| 39 | 
            +
                Collection of static methods to compute various metrics for molecules.
         | 
| 40 | 
            +
                """
         | 
| 41 | 
            +
             | 
| 42 | 
            +
                @staticmethod
         | 
| 43 | 
            +
                def valid(x):
         | 
| 44 | 
            +
                    """
         | 
| 45 | 
            +
                    Checks whether the molecule is valid.
         | 
| 46 | 
            +
                    
         | 
| 47 | 
            +
                    Args:
         | 
| 48 | 
            +
                        x: RDKit molecule object.
         | 
| 49 | 
            +
                    
         | 
| 50 | 
            +
                    Returns:
         | 
| 51 | 
            +
                        bool: True if molecule is valid and has a non-empty SMILES representation.
         | 
| 52 | 
            +
                    """
         | 
| 53 | 
            +
                    return x is not None and Chem.MolToSmiles(x) != ''
         | 
| 54 | 
            +
             | 
| 55 | 
            +
                @staticmethod
         | 
| 56 | 
            +
                def tanimoto_sim_1v2(data1, data2):
         | 
| 57 | 
            +
                    """
         | 
| 58 | 
            +
                    Computes the average Tanimoto similarity for paired fingerprints.
         | 
| 59 | 
            +
                    
         | 
| 60 | 
            +
                    Args:
         | 
| 61 | 
            +
                        data1: Fingerprint data for first set.
         | 
| 62 | 
            +
                        data2: Fingerprint data for second set.
         | 
| 63 | 
            +
                    
         | 
| 64 | 
            +
                    Returns:
         | 
| 65 | 
            +
                        float: The average Tanimoto similarity between corresponding fingerprints.
         | 
| 66 | 
            +
                    """
         | 
| 67 | 
            +
                    # Determine the minimum size between two arrays for pairing
         | 
| 68 | 
            +
                    min_len = data1.size if data1.size < data2.size else data2.size
         | 
| 69 | 
            +
                    sims = []
         | 
| 70 | 
            +
                    for i in range(min_len):
         | 
| 71 | 
            +
                        sim = DataStructs.FingerprintSimilarity(data1[i], data2[i])
         | 
| 72 | 
            +
                        sims.append(sim)
         | 
| 73 | 
            +
                    # average the pairwise Tanimoto similarities
         | 
| 74 | 
            +
                    mean_sim = mean(sims)
         | 
| 75 | 
            +
                    return mean_sim
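A small self-contained check of the pairing logic (fingerprints built here purely for illustration):

mols1 = [Chem.MolFromSmiles(s) for s in ("CCO", "c1ccccc1")]
mols2 = [Chem.MolFromSmiles(s) for s in ("CCN", "c1ccccc1O")]
fps1 = np.array([Chem.RDKFingerprint(m) for m in mols1], dtype=object)
fps2 = np.array([Chem.RDKFingerprint(m) for m in mols2], dtype=object)
Metrics.tanimoto_sim_1v2(fps1, fps2)   # average Tanimoto over the two aligned pairs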
         | 
| 76 | 
            +
             | 
| 77 | 
            +
                @staticmethod
         | 
| 78 | 
            +
                def mol_length(x):
         | 
| 79 | 
            +
                    """
         | 
| 80 | 
            +
                    Computes the length of the largest fragment (by character count) in a SMILES string.
         | 
| 81 | 
            +
                    
         | 
| 82 | 
            +
                    Args:
         | 
| 83 | 
            +
                        x (str): SMILES string.
         | 
| 84 | 
            +
                    
         | 
| 85 | 
            +
                    Returns:
         | 
| 86 | 
            +
                        int: Number of alphabetic characters in the longest fragment of the SMILES.
         | 
| 87 | 
            +
                    """
         | 
| 88 | 
            +
                    if x is not None:
         | 
| 89 | 
            +
                        # Split at dots (.) and take the fragment with maximum length, then count alphabetic characters.
         | 
| 90 | 
            +
                        return len([char for char in max(x.split(sep="."), key=len).upper() if char.isalpha()])
         | 
| 91 | 
            +
                    else:
         | 
| 92 | 
            +
                        return 0
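For example, the longest fragment of a salt is used and only alphabetic characters are counted:

Metrics.mol_length("CC(=O)O.[Na+]")   # longest fragment 'CC(=O)O' -> 4 (C, C, O, O)
Metrics.mol_length(None)              # -> 0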
         | 
| 93 | 
            +
             | 
| 94 | 
            +
                @staticmethod
         | 
| 95 | 
            +
                def max_component(data, max_len):
         | 
| 96 | 
            +
                    """
         | 
| 97 | 
            +
                    Returns the average normalized length of molecules in the dataset.
         | 
| 98 | 
            +
                    
         | 
| 99 | 
            +
                    Each molecule's length is computed and divided by max_len, then averaged.
         | 
| 100 | 
            +
                    
         | 
| 101 | 
            +
                    Args:
         | 
| 102 | 
            +
                        data (iterable): Collection of SMILES strings.
         | 
| 103 | 
            +
                        max_len (int): Maximum possible length for normalization.
         | 
| 104 | 
            +
                    
         | 
| 105 | 
            +
                    Returns:
         | 
| 106 | 
            +
                        float: Normalized average length.
         | 
| 107 | 
            +
                    """
         | 
| 108 | 
            +
                    lengths = np.array(list(map(Metrics.mol_length, data)), dtype=np.float32)
         | 
| 109 | 
            +
                    return (lengths / max_len).mean()
         | 
| 110 | 
            +
             | 
| 111 | 
            +
                @staticmethod
         | 
| 112 | 
            +
                def mean_atom_type(data):
         | 
| 113 | 
            +
                    """
         | 
| 114 | 
            +
                    Computes the average number of unique atom types in the provided node data.
         | 
| 115 | 
            +
                    
         | 
| 116 | 
            +
                    Args:
         | 
| 117 | 
            +
                        data (iterable): Iterable containing node data with unique atom types.
         | 
| 118 | 
            +
                    
         | 
| 119 | 
            +
                    Returns:
         | 
| 120 | 
            +
                        float: The average count of unique atom types, subtracting one.
         | 
| 121 | 
            +
                    """
         | 
| 122 | 
            +
                    atom_types_used = []
         | 
| 123 | 
            +
                    for i in data:
         | 
| 124 | 
            +
                        # Each element of data is expected to be a tensor exposing .unique(), giving its distinct atom types.
         | 
| 125 | 
            +
                        atom_types_used.append(len(i.unique().tolist()))
         | 
| 126 | 
            +
                    av_type = np.mean(atom_types_used) - 1
         | 
| 127 | 
            +
                    return av_type
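
             # --- Hypothetical usage sketch (illustration only, not part of the original file). ---
             # Shows how the two length helpers above behave on a couple of toy SMILES strings;
             # the example molecules and the max_len value of 45 are arbitrary choices.
             def _demo_length_metrics():
                 smiles = ["CCO", "c1ccccc1.O"]              # ethanol; benzene with a water fragment
                 print(Metrics.mol_length(smiles[0]))        # 3 alphabetic characters
                 print(Metrics.mol_length(smiles[1]))        # longest fragment "c1ccccc1" -> 6
                 print(Metrics.max_component(smiles, 45))    # mean of (3/45, 6/45) ~ 0.1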
         | 
| 128 | 
            +
             | 
| 129 | 
            +
             | 
| 130 | 
            +
            def mols2grid_image(mols, path):
         | 
| 131 | 
            +
                """
         | 
| 132 | 
            +
                Saves grid images for a list of molecules.
         | 
| 133 | 
            +
                
         | 
| 134 | 
            +
                For each molecule in the list, computes 2D coordinates and saves an image file.
         | 
| 135 | 
            +
                
         | 
| 136 | 
            +
                Args:
         | 
| 137 | 
            +
                    mols (list): List of RDKit molecule objects.
         | 
| 138 | 
            +
                    path (str): Directory where images will be saved.
         | 
| 139 | 
            +
                """
         | 
| 140 | 
            +
                # Replace None molecules with an empty molecule
         | 
| 141 | 
            +
                mols = [e if e is not None else Chem.RWMol() for e in mols]
         | 
| 142 | 
            +
             | 
| 143 | 
            +
                for i in range(len(mols)):
         | 
| 144 | 
            +
                    if Metrics.valid(mols[i]):
         | 
| 145 | 
            +
                        AllChem.Compute2DCoords(mols[i])
         | 
| 146 | 
            +
                        file_path = os.path.join(path, "{}.png".format(i + 1))
         | 
| 147 | 
            +
                        Draw.MolToFile(mols[i], file_path, size=(1200, 1200))
         | 
| 148 | 
            +
                        # wandb.save(file_path)  # Optionally save to Weights & Biases
         | 
| 149 | 
            +
                    else:
         | 
| 150 | 
            +
                        continue
         | 
| 151 | 
            +
             | 
| 152 | 
            +
             | 
| 153 | 
            +
            def save_smiles_matrices(mols, edges_hard, nodes_hard, path, data_source=None):
         | 
| 154 | 
            +
                """
         | 
| 155 | 
            +
                Saves the edge and node matrices along with SMILES strings to text files.
         | 
| 156 | 
            +
                
         | 
| 157 | 
            +
                Each file contains the edge matrix, node matrix, and SMILES representation for a molecule.
         | 
| 158 | 
            +
                
         | 
| 159 | 
            +
                Args:
         | 
| 160 | 
            +
                    mols (list): List of RDKit molecule objects.
         | 
| 161 | 
            +
                    edges_hard (torch.Tensor): Tensor of edge features.
         | 
| 162 | 
            +
                    nodes_hard (torch.Tensor): Tensor of node features.
         | 
| 163 | 
            +
                    path (str): Directory where files will be saved.
         | 
| 164 | 
            +
                    data_source: Optional data source information (not used in function).
         | 
| 165 | 
            +
                """
         | 
| 166 | 
            +
                mols = [e if e is not None else Chem.RWMol() for e in mols]
         | 
| 167 | 
            +
             | 
| 168 | 
            +
                for i in range(len(mols)):
         | 
| 169 | 
            +
                    if Metrics.valid(mols[i]):
         | 
| 170 | 
            +
                        save_path = os.path.join(path, "{}.txt".format(i + 1))
         | 
| 171 | 
            +
                        with open(save_path, "a") as f:
         | 
| 172 | 
            +
                            np.savetxt(f, edges_hard[i].cpu().numpy(), header="edge matrix:\n", fmt='%1.2f')
         | 
| 173 | 
            +
                            f.write("\n")
         | 
| 174 | 
            +
                            np.savetxt(f, nodes_hard[i].cpu().numpy(), header="node matrix:\n", footer="\nsmiles:", fmt='%1.2f')
         | 
| 175 | 
            +
                            f.write("\n")
         | 
| 176 | 
            +
                        # Append the SMILES representation to the file
         | 
| 177 | 
            +
                        with open(save_path, "a") as f:
         | 
| 178 | 
            +
                            print(Chem.MolToSmiles(mols[i]), file=f)
         | 
| 179 | 
            +
                        # wandb.save(save_path)  # Optionally save to Weights & Biases
         | 
| 180 | 
            +
                    else:
         | 
| 181 | 
            +
                        continue
         | 
| 182 | 
            +
             | 
| 183 | 
            +
            def dense_to_sparse_with_attr(adj):
         | 
| 184 | 
            +
                """
         | 
| 185 | 
            +
                Converts a dense adjacency matrix to a sparse representation.
         | 
| 186 | 
            +
                
         | 
| 187 | 
            +
                Args:
         | 
| 188 | 
            +
                    adj (torch.Tensor): Adjacency matrix tensor (2D or 3D) with square last two dimensions.
         | 
| 189 | 
            +
                
         | 
| 190 | 
            +
                Returns:
         | 
| 191 | 
            +
                    tuple: A tuple containing indices and corresponding edge attributes.
         | 
| 192 | 
            +
                """
         | 
| 193 | 
            +
                assert adj.dim() >= 2 and adj.dim() <= 3
         | 
| 194 | 
            +
                assert adj.size(-1) == adj.size(-2)
         | 
| 195 | 
            +
             | 
| 196 | 
            +
                index = adj.nonzero(as_tuple=True)
         | 
| 197 | 
            +
                edge_attr = adj[index]
         | 
| 198 | 
            +
             | 
| 199 | 
            +
                if len(index) == 3:
         | 
| 200 | 
            +
                    batch = index[0] * adj.size(-1)
         | 
| 201 | 
            +
                    index = (batch + index[1], batch + index[2])
         | 
| 202 | 
            +
                return index, edge_attr
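
             # --- Hypothetical usage sketch (illustration only, not part of the original file). ---
             # A batch of two 2x2 adjacency matrices: each nonzero entry becomes a flat edge index
             # offset by graph_id * num_nodes, and its value is returned as the edge attribute.
             # Assumes torch is already imported at the top of this module.
             def _demo_dense_to_sparse():
                 adj = torch.zeros(2, 2, 2)
                 adj[0, 0, 1] = 1.0      # edge 0 -> 1 in graph 0
                 adj[1, 1, 0] = 2.0      # edge 1 -> 0 in graph 1
                 index, edge_attr = dense_to_sparse_with_attr(adj)
                 print(index)            # (tensor([0, 3]), tensor([1, 2]))
                 print(edge_attr)        # tensor([1., 2.])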
         | 
| 203 | 
            +
             | 
| 204 | 
            +
             | 
| 205 | 
            +
            def mol_sample(sample_directory, edges, nodes, idx, i, matrices2mol, dataset_name):
         | 
| 206 | 
            +
                """
         | 
| 207 | 
            +
                Samples molecules from edge and node predictions, then saves grid images and text files.
         | 
| 208 | 
            +
                
         | 
| 209 | 
            +
                Args:
         | 
| 210 | 
            +
                    sample_directory (str): Directory to save the samples.
         | 
| 211 | 
            +
                    edges (torch.Tensor): Edge predictions tensor.
         | 
| 212 | 
            +
                    nodes (torch.Tensor): Node predictions tensor.
         | 
| 213 | 
            +
                    idx (int): Current index for naming the sample.
         | 
| 214 | 
            +
                    i (int): Epoch/iteration index.
         | 
| 215 | 
            +
                    matrices2mol (callable): Function to convert matrices to RDKit molecule.
         | 
| 216 | 
            +
                    dataset_name (str): Name of the dataset for file naming.
         | 
| 217 | 
            +
                """
         | 
| 218 | 
            +
                sample_path = os.path.join(sample_directory, "{}_{}-epoch_iteration".format(idx + 1, i + 1))
         | 
| 219 | 
            +
                # Get the index of the maximum predicted feature along the last dimension
         | 
| 220 | 
            +
                g_edges_hat_sample = torch.max(edges, -1)[1]
         | 
| 221 | 
            +
                g_nodes_hat_sample = torch.max(nodes, -1)[1]
         | 
| 222 | 
            +
                # Convert matrices to molecule objects
         | 
| 223 | 
            +
                mol = [matrices2mol(n_.data.cpu().numpy(), e_.data.cpu().numpy(),
         | 
| 224 | 
            +
                                    strict=True, file_name=dataset_name)
         | 
| 225 | 
            +
                       for e_, n_ in zip(g_edges_hat_sample, g_nodes_hat_sample)]
         | 
| 226 | 
            +
             | 
| 227 | 
            +
                if not os.path.exists(sample_path):
         | 
| 228 | 
            +
                    os.makedirs(sample_path)
         | 
| 229 | 
            +
             | 
| 230 | 
            +
                mols2grid_image(mol, sample_path)
         | 
| 231 | 
            +
                save_smiles_matrices(mol, g_edges_hat_sample.detach(), g_nodes_hat_sample.detach(), sample_path)
         | 
| 232 | 
            +
             | 
| 233 | 
            +
                # Remove the directory if no files were saved
         | 
| 234 | 
            +
                if len(os.listdir(sample_path)) == 0:
         | 
| 235 | 
            +
                    os.rmdir(sample_path)
         | 
| 236 | 
            +
             | 
| 237 | 
            +
                print("Valid molecules are saved.")
         | 
| 238 | 
            +
                print("Valid matrices and SMILES are saved.")
         | 
| 239 | 
            +
             | 
| 240 | 
            +
             | 
| 241 | 
            +
            def logging(log_path, start_time, i, idx, loss, save_path, drug_smiles, edge, node, 
         | 
| 242 | 
            +
                        matrices2mol, dataset_name, real_adj, real_annot, drug_vecs):
         | 
| 243 | 
            +
                """
         | 
| 244 | 
            +
                Logs training statistics and evaluation metrics.
         | 
| 245 | 
            +
                
         | 
| 246 | 
            +
                The function generates molecules from predictions, computes various metrics such as
         | 
| 247 | 
            +
                validity, uniqueness, novelty, and similarity scores, and logs them using wandb and a file.
         | 
| 248 | 
            +
                
         | 
| 249 | 
            +
                Args:
         | 
| 250 | 
            +
                    log_path (str): Path to save the log file.
         | 
| 251 | 
            +
                    start_time (float): Start time to compute elapsed time.
         | 
| 252 | 
            +
                    i (int): Current iteration index.
         | 
| 253 | 
            +
                    idx (int): Current epoch index.
         | 
| 254 | 
            +
                    loss (dict): Dictionary to update with loss and metric values.
         | 
| 255 | 
            +
                    save_path (str): Directory path to save sample outputs.
         | 
| 256 | 
            +
                    drug_smiles (list): List of reference drug SMILES.
         | 
| 257 | 
            +
                    edge (torch.Tensor): Edge prediction tensor.
         | 
| 258 | 
            +
                    node (torch.Tensor): Node prediction tensor.
         | 
| 259 | 
            +
                    matrices2mol (callable): Function to convert matrices to molecules.
         | 
| 260 | 
            +
                    dataset_name (str): Dataset name.
         | 
| 261 | 
            +
                    real_adj (torch.Tensor): Ground truth adjacency matrix tensor.
         | 
| 262 | 
            +
                    real_annot (torch.Tensor): Ground truth annotation tensor.
         | 
| 263 | 
            +
                    drug_vecs (list): List of drug vectors for similarity calculation.
         | 
| 264 | 
            +
                """
         | 
| 265 | 
            +
                g_edges_hat_sample = torch.max(edge, -1)[1]
         | 
| 266 | 
            +
                g_nodes_hat_sample = torch.max(node, -1)[1]
         | 
| 267 | 
            +
             | 
| 268 | 
            +
                a_tensor_sample = torch.max(real_adj, -1)[1].float()
         | 
| 269 | 
            +
                x_tensor_sample = torch.max(real_annot, -1)[1].float()
         | 
| 270 | 
            +
             | 
| 271 | 
            +
                # Generate molecules from predictions and real data
         | 
| 272 | 
            +
                mols = [matrices2mol(n_.data.cpu().numpy(), e_.data.cpu().numpy(),
         | 
| 273 | 
            +
                                     strict=True, file_name=dataset_name)
         | 
| 274 | 
            +
                        for e_, n_ in zip(g_edges_hat_sample, g_nodes_hat_sample)]
         | 
| 275 | 
            +
                real_mol = [matrices2mol(n_.data.cpu().numpy(), e_.data.cpu().numpy(),
         | 
| 276 | 
            +
                                          strict=True, file_name=dataset_name)
         | 
| 277 | 
            +
                            for e_, n_ in zip(a_tensor_sample, x_tensor_sample)]
         | 
| 278 | 
            +
             | 
| 279 | 
            +
                # Compute average number of atom types
         | 
| 280 | 
            +
                atom_types_average = Metrics.mean_atom_type(g_nodes_hat_sample)
         | 
| 281 | 
            +
                real_smiles = [Chem.MolToSmiles(x) for x in real_mol if x is not None]
         | 
| 282 | 
            +
                gen_smiles = []
         | 
| 283 | 
            +
                uniq_smiles = []
         | 
| 284 | 
            +
                for line in mols:
         | 
| 285 | 
            +
                    if line is not None:
         | 
| 286 | 
            +
                        gen_smiles.append(Chem.MolToSmiles(line))
         | 
| 287 | 
            +
                        uniq_smiles.append(Chem.MolToSmiles(line))
         | 
| 288 | 
            +
                    elif line is None:
         | 
| 289 | 
            +
                        gen_smiles.append(None)
         | 
| 290 | 
            +
             | 
| 291 | 
            +
                # Process SMILES to take the longest fragment if multiple are present
         | 
| 292 | 
            +
                gen_smiles_saves = [None if x is None else max(x.split('.'), key=len) for x in gen_smiles]
         | 
| 293 | 
            +
                uniq_smiles_saves = [None if x is None else max(x.split('.'), key=len) for x in uniq_smiles]
         | 
| 294 | 
            +
             | 
| 295 | 
            +
                # Save the generated SMILES to a text file
         | 
| 296 | 
            +
                sample_save_dir = os.path.join(save_path, "samples.txt")
         | 
| 297 | 
            +
                with open(sample_save_dir, "a") as f:
         | 
| 298 | 
            +
                    for s in gen_smiles_saves:
         | 
| 299 | 
            +
                        if s is not None:
         | 
| 300 | 
            +
                            f.write(s + "\n")
         | 
| 301 | 
            +
             | 
| 302 | 
            +
                k = len(set(uniq_smiles_saves) - {None})
         | 
| 303 | 
            +
                et = time.time() - start_time
         | 
| 304 | 
            +
                et = str(datetime.timedelta(seconds=et))[:-7]
         | 
| 305 | 
            +
                log_str = "Elapsed [{}], Epoch/Iteration [{}/{}]".format(et, idx, i + 1)
         | 
| 306 | 
            +
                
         | 
| 307 | 
            +
                # Generate molecular fingerprints for similarity computations
         | 
| 308 | 
            +
                gen_vecs = [AllChem.GetMorganFingerprintAsBitVect(x, 2, nBits=1024) for x in mols if x is not None]
         | 
| 309 | 
            +
                chembl_vecs = [AllChem.GetMorganFingerprintAsBitVect(x, 2, nBits=1024) for x in real_mol if x is not None]
         | 
| 310 | 
            +
             | 
| 311 | 
            +
                # Compute evaluation metrics: validity, uniqueness, novelty, similarity scores, and average maximum molecule length.
         | 
| 312 | 
            +
                valid = fraction_valid(gen_smiles_saves)
         | 
| 313 | 
            +
                unique = fraction_unique(uniq_smiles_saves, k)
         | 
| 314 | 
            +
                novel_starting_mol = novelty(gen_smiles_saves, real_smiles)
         | 
| 315 | 
            +
                novel_akt = novelty(gen_smiles_saves, drug_smiles)
         | 
| 316 | 
            +
                if len(uniq_smiles_saves) == 0:
         | 
| 317 | 
            +
                    snn_chembl = 0
         | 
| 318 | 
            +
                    snn_akt = 0
         | 
| 319 | 
            +
                    maxlen = 0
         | 
| 320 | 
            +
                else:
         | 
| 321 | 
            +
                    snn_chembl = average_agg_tanimoto(np.array(chembl_vecs), np.array(gen_vecs))
         | 
| 322 | 
            +
                    snn_akt = average_agg_tanimoto(np.array(drug_vecs), np.array(gen_vecs))
         | 
| 323 | 
            +
                    maxlen = Metrics.max_component(uniq_smiles_saves, 45)
         | 
| 324 | 
            +
             | 
| 325 | 
            +
                # Update loss dictionary with computed metrics
         | 
| 326 | 
            +
                loss.update({
         | 
| 327 | 
            +
                    'Validity': valid,
         | 
| 328 | 
            +
                    'Uniqueness': unique,
         | 
| 329 | 
            +
                    'Novelty': novel_starting_mol,
         | 
| 330 | 
            +
                    'Novelty_akt': novel_akt,
         | 
| 331 | 
            +
                    'SNN_chembl': snn_chembl,
         | 
| 332 | 
            +
                    'SNN_akt': snn_akt,
         | 
| 333 | 
            +
                    'MaxLen': maxlen,
         | 
| 334 | 
            +
                    'Atom_types': atom_types_average
         | 
| 335 | 
            +
                })
         | 
| 336 | 
            +
             | 
| 337 | 
            +
                # Log metrics using wandb
         | 
| 338 | 
            +
                wandb.log({
         | 
| 339 | 
            +
                    "Validity": valid,
         | 
| 340 | 
            +
                    "Uniqueness": unique,
         | 
| 341 | 
            +
                    "Novelty": novel_starting_mol,
         | 
| 342 | 
            +
                    "Novelty_akt": novel_akt,
         | 
| 343 | 
            +
                    "SNN_chembl": snn_chembl,
         | 
| 344 | 
            +
                    "SNN_akt": snn_akt,
         | 
| 345 | 
            +
                    "MaxLen": maxlen,
         | 
| 346 | 
            +
                    "Atom_types": atom_types_average
         | 
| 347 | 
            +
                })
         | 
| 348 | 
            +
             | 
| 349 | 
            +
                # Append each metric to the log string and write to the log file
         | 
| 350 | 
            +
                for tag, value in loss.items():
         | 
| 351 | 
            +
                    log_str += ", {}: {:.4f}".format(tag, value)
         | 
| 352 | 
            +
                with open(log_path, "a") as f:
         | 
| 353 | 
            +
                    f.write(log_str + "\n")
         | 
| 354 | 
            +
                print(log_str)
         | 
| 355 | 
            +
                print("\n")
         | 
| 356 | 
            +
             | 
| 357 | 
            +
             | 
| 358 | 
            +
            def plot_grad_flow(named_parameters, model, itera, epoch, grad_flow_directory):
         | 
| 359 | 
            +
                """
         | 
| 360 | 
            +
                Plots the gradients flowing through different layers during training.
         | 
| 361 | 
            +
                
         | 
| 362 | 
            +
                This is useful to check for possible gradient vanishing or exploding problems.
         | 
| 363 | 
            +
                
         | 
| 364 | 
            +
                Args:
         | 
| 365 | 
            +
                    named_parameters (iterable): Iterable of (name, parameter) tuples from the model.
         | 
| 366 | 
            +
                    model (str): Name of the model (used for saving the plot).
         | 
| 367 | 
            +
                    itera (int): Iteration index.
         | 
| 368 | 
            +
                    epoch (int): Current epoch.
         | 
| 369 | 
            +
                    grad_flow_directory (str): Directory to save the gradient flow plot.
         | 
| 370 | 
            +
                """
         | 
| 371 | 
            +
                ave_grads = []
         | 
| 372 | 
            +
                max_grads = []
         | 
| 373 | 
            +
                layers = []
         | 
| 374 | 
            +
                for n, p in named_parameters:
         | 
| 375 | 
            +
                    if p.requires_grad and ("bias" not in n):
         | 
| 376 | 
            +
                        layers.append(n)
         | 
| 377 | 
            +
                        ave_grads.append(p.grad.abs().mean().cpu())
         | 
| 378 | 
            +
                        max_grads.append(p.grad.abs().max().cpu())
         | 
| 379 | 
            +
                # Plot maximum gradients and average gradients for each layer
         | 
| 380 | 
            +
                plt.bar(np.arange(len(max_grads)), max_grads, alpha=0.1, lw=1, color="c")
         | 
| 381 | 
            +
                plt.bar(np.arange(len(max_grads)), ave_grads, alpha=0.1, lw=1, color="b")
         | 
| 382 | 
            +
                plt.hlines(0, 0, len(ave_grads) + 1, lw=2, color="k")
         | 
| 383 | 
            +
                plt.xticks(range(0, len(ave_grads), 1), layers, rotation="vertical")
         | 
| 384 | 
            +
                plt.xlim(left=0, right=len(ave_grads))
         | 
| 385 | 
            +
                plt.ylim(bottom=-0.001, top=1)  # Zoom in on lower gradient regions
         | 
| 386 | 
            +
                plt.xlabel("Layers")
         | 
| 387 | 
            +
                plt.ylabel("Average Gradient")
         | 
| 388 | 
            +
                plt.title("Gradient Flow")
         | 
| 389 | 
            +
                plt.grid(True)
         | 
| 390 | 
            +
                plt.legend([
         | 
| 391 | 
            +
                    Line2D([0], [0], color="c", lw=4),
         | 
| 392 | 
            +
                    Line2D([0], [0], color="b", lw=4),
         | 
| 393 | 
            +
                    Line2D([0], [0], color="k", lw=4)
         | 
| 394 | 
            +
                ], ['max-gradient', 'mean-gradient', 'zero-gradient'])
         | 
| 395 | 
            +
                # Save the plot to the specified directory
         | 
| 396 | 
            +
                plt.savefig(os.path.join(grad_flow_directory, "weights_" + model + "_" + str(itera) + "_" + str(epoch) + ".png"), dpi=500, bbox_inches='tight')
         | 
| 397 | 
            +
             | 
| 398 | 
            +
             | 
| 399 | 
            +
            def get_mol(smiles_or_mol):
         | 
| 400 | 
            +
                """
         | 
| 401 | 
            +
                Loads a SMILES string or molecule into an RDKit molecule object.
         | 
| 402 | 
            +
                
         | 
| 403 | 
            +
                Args:
         | 
| 404 | 
            +
                    smiles_or_mol (str or RDKit Mol): SMILES string or RDKit molecule.
         | 
| 405 | 
            +
                
         | 
| 406 | 
            +
                Returns:
         | 
| 407 | 
            +
                    RDKit Mol or None: Sanitized molecule object, or None if invalid.
         | 
| 408 | 
            +
                """
         | 
| 409 | 
            +
                if isinstance(smiles_or_mol, str):
         | 
| 410 | 
            +
                    if len(smiles_or_mol) == 0:
         | 
| 411 | 
            +
                        return None
         | 
| 412 | 
            +
                    mol = Chem.MolFromSmiles(smiles_or_mol)
         | 
| 413 | 
            +
                    if mol is None:
         | 
| 414 | 
            +
                        return None
         | 
| 415 | 
            +
                    try:
         | 
| 416 | 
            +
                        Chem.SanitizeMol(mol)
         | 
| 417 | 
            +
                    except ValueError:
         | 
| 418 | 
            +
                        return None
         | 
| 419 | 
            +
                    return mol
         | 
| 420 | 
            +
                return smiles_or_mol
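
             # --- Hypothetical usage sketch (illustration only, not part of the original file). ---
             # get_mol accepts either a SMILES string or an already-built RDKit Mol and returns
             # a sanitized Mol, or None when parsing/sanitization fails.
             def _demo_get_mol():
                 print(get_mol("CCO"))            # <rdkit.Chem.rdchem.Mol object ...>
                 print(get_mol("not_a_smiles"))   # None (RDKit fails to parse)
                 print(get_mol(""))               # None (empty string short-circuits)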
         | 
| 421 | 
            +
             | 
| 422 | 
            +
             | 
| 423 | 
            +
            def mapper(n_jobs):
         | 
| 424 | 
            +
                """
         | 
| 425 | 
            +
                Returns a mapping function for parallel or serial processing.
         | 
| 426 | 
            +
                
         | 
| 427 | 
            +
                If n_jobs == 1, returns the built-in map function.
         | 
| 428 | 
            +
                If n_jobs > 1, returns a function that uses a multiprocessing pool.
         | 
| 429 | 
            +
                
         | 
| 430 | 
            +
                Args:
         | 
| 431 | 
            +
                    n_jobs (int or pool object): Number of jobs or a Pool instance.
         | 
| 432 | 
            +
                
         | 
| 433 | 
            +
                Returns:
         | 
| 434 | 
            +
                    callable: A function that acts like map.
         | 
| 435 | 
            +
                """
         | 
| 436 | 
            +
                if n_jobs == 1:
         | 
| 437 | 
            +
                    def _mapper(*args, **kwargs):
         | 
| 438 | 
            +
                        return list(map(*args, **kwargs))
         | 
| 439 | 
            +
                    return _mapper
         | 
| 440 | 
            +
                if isinstance(n_jobs, int):
         | 
| 441 | 
            +
                    pool = Pool(n_jobs)
         | 
| 442 | 
            +
                    def _mapper(*args, **kwargs):
         | 
| 443 | 
            +
                        try:
         | 
| 444 | 
            +
                            result = pool.map(*args, **kwargs)
         | 
| 445 | 
            +
                        finally:
         | 
| 446 | 
            +
                            pool.terminate()
         | 
| 447 | 
            +
                        return result
         | 
| 448 | 
            +
                    return _mapper
         | 
| 449 | 
            +
                return n_jobs.map
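
             # --- Hypothetical usage sketch (illustration only, not part of the original file). ---
             # With n_jobs=1 the returned callable is just a list-producing wrapper around map;
             # with an integer n_jobs > 1 a multiprocessing.Pool is created and terminated per call.
             def _demo_mapper():
                 lengths = mapper(1)(len, ["CCO", "c1ccccc1"])
                 print(lengths)   # [3, 8]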
         | 
| 450 | 
            +
             | 
| 451 | 
            +
             | 
| 452 | 
            +
            def remove_invalid(gen, canonize=True, n_jobs=1):
         | 
| 453 | 
            +
                """
         | 
| 454 | 
            +
                Removes invalid molecules from the provided dataset.
         | 
| 455 | 
            +
                
         | 
| 456 | 
            +
                Optionally canonizes the SMILES strings.
         | 
| 457 | 
            +
                
         | 
| 458 | 
            +
                Args:
         | 
| 459 | 
            +
                    gen (list): List of SMILES strings.
         | 
| 460 | 
            +
                    canonize (bool): Whether to convert to canonical SMILES.
         | 
| 461 | 
            +
                    n_jobs (int): Number of parallel jobs.
         | 
| 462 | 
            +
                
         | 
| 463 | 
            +
                Returns:
         | 
| 464 | 
            +
                    list: Filtered list of valid molecules.
         | 
| 465 | 
            +
                """
         | 
| 466 | 
            +
                if not canonize:
         | 
| 467 | 
            +
                    mols = mapper(n_jobs)(get_mol, gen)
         | 
| 468 | 
            +
                    return [gen_ for gen_, mol in zip(gen, mols) if mol is not None]
         | 
| 469 | 
            +
                return [x for x in mapper(n_jobs)(canonic_smiles, gen) if x is not None]
         | 
| 470 | 
            +
             | 
| 471 | 
            +
             | 
| 472 | 
            +
            def fraction_valid(gen, n_jobs=1):
         | 
| 473 | 
            +
                """
         | 
| 474 | 
            +
                Computes the fraction of valid molecules in the dataset.
         | 
| 475 | 
            +
                
         | 
| 476 | 
            +
                Args:
         | 
| 477 | 
            +
                    gen (list): List of SMILES strings.
         | 
| 478 | 
            +
                    n_jobs (int): Number of parallel jobs.
         | 
| 479 | 
            +
                
         | 
| 480 | 
            +
                Returns:
         | 
| 481 | 
            +
                    float: Fraction of molecules that are valid.
         | 
| 482 | 
            +
                """
         | 
| 483 | 
            +
                gen = mapper(n_jobs)(get_mol, gen)
         | 
| 484 | 
            +
                return 1 - gen.count(None) / len(gen)
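
             # --- Hypothetical usage sketch (illustration only, not part of the original file). ---
             # One of the three strings below is not parseable, so a third of the set is invalid.
             def _demo_validity():
                 gen = ["CCO", "XXXX", "c1ccccc1"]
                 print(fraction_valid(gen))                  # 2/3 ~ 0.667
                 print(remove_invalid(gen, canonize=False))  # ['CCO', 'c1ccccc1']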
         | 
| 485 | 
            +
             | 
| 486 | 
            +
             | 
| 487 | 
            +
            def canonic_smiles(smiles_or_mol):
         | 
| 488 | 
            +
                """
         | 
| 489 | 
            +
                Converts a SMILES string or molecule to its canonical SMILES.
         | 
| 490 | 
            +
                
         | 
| 491 | 
            +
                Args:
         | 
| 492 | 
            +
                    smiles_or_mol (str or RDKit Mol): Input molecule.
         | 
| 493 | 
            +
                
         | 
| 494 | 
            +
                Returns:
         | 
| 495 | 
            +
                    str or None: Canonical SMILES string or None if invalid.
         | 
| 496 | 
            +
                """
         | 
| 497 | 
            +
                mol = get_mol(smiles_or_mol)
         | 
| 498 | 
            +
                if mol is None:
         | 
| 499 | 
            +
                    return None
         | 
| 500 | 
            +
                return Chem.MolToSmiles(mol)
         | 
| 501 | 
            +
             | 
| 502 | 
            +
             | 
| 503 | 
            +
            def fraction_unique(gen, k=None, n_jobs=1, check_validity=True):
         | 
| 504 | 
            +
                """
         | 
| 505 | 
            +
                Computes the fraction of unique molecules.
         | 
| 506 | 
            +
                
         | 
| 507 | 
            +
                Optionally computes unique@k, where only the first k molecules are considered.
         | 
| 508 | 
            +
                
         | 
| 509 | 
            +
                Args:
         | 
| 510 | 
            +
                    gen (list): List of SMILES strings.
         | 
| 511 | 
            +
                    k (int): Optional cutoff for unique@k computation.
         | 
| 512 | 
            +
                    n_jobs (int): Number of parallel jobs.
         | 
| 513 | 
            +
                    check_validity (bool): Whether to check for validity of molecules.
         | 
| 514 | 
            +
                
         | 
| 515 | 
            +
                Returns:
         | 
| 516 | 
            +
                    float: Fraction of unique molecules.
         | 
| 517 | 
            +
                """
         | 
| 518 | 
            +
                if k is not None:
         | 
| 519 | 
            +
                    if len(gen) < k:
         | 
| 520 | 
            +
                        warnings.warn("Can't compute unique@{}.".format(k) +
         | 
| 521 | 
            +
                                      " gen contains only {} molecules".format(len(gen)))
         | 
| 522 | 
            +
                    gen = gen[:k]
         | 
| 523 | 
            +
                if check_validity:
         | 
| 524 | 
            +
                    canonic = list(mapper(n_jobs)(canonic_smiles, gen))
         | 
| 525 | 
            +
                    canonic = [i for i in canonic if i is not None]
         | 
| 526 | 
            +
                else:
                    # No validity filtering requested; use the inputs as-is.
                    canonic = list(gen)
                set_canonic = set(canonic)
         | 
| 527 | 
            +
                return 0 if len(canonic) == 0 else len(set_canonic) / len(canonic)
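
             # --- Hypothetical usage sketch (illustration only, not part of the original file). ---
             # "CCO" and "OCC" canonicalize to the same SMILES, so only two of the three inputs are unique.
             def _demo_uniqueness():
                 print(fraction_unique(["CCO", "OCC", "c1ccccc1"]))   # 2/3 ~ 0.667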
         | 
| 528 | 
            +
             | 
| 529 | 
            +
             | 
| 530 | 
            +
            def novelty(gen, train, n_jobs=1):
         | 
| 531 | 
            +
                """
         | 
| 532 | 
            +
                Computes the novelty score of generated molecules.
         | 
| 533 | 
            +
                
         | 
| 534 | 
            +
                Novelty is defined as the fraction of generated molecules that do not appear in the training set.
         | 
| 535 | 
            +
                
         | 
| 536 | 
            +
                Args:
         | 
| 537 | 
            +
                    gen (list): List of generated SMILES strings.
         | 
| 538 | 
            +
                    train (list): List of training SMILES strings.
         | 
| 539 | 
            +
                    n_jobs (int): Number of parallel jobs.
         | 
| 540 | 
            +
                
         | 
| 541 | 
            +
                Returns:
         | 
| 542 | 
            +
                    float: Novelty score.
         | 
| 543 | 
            +
                """
         | 
| 544 | 
            +
                gen_smiles = mapper(n_jobs)(canonic_smiles, gen)
         | 
| 545 | 
            +
                gen_smiles_set = set(gen_smiles) - {None}
         | 
| 546 | 
            +
                train_set = set(train)
         | 
| 547 | 
            +
                return 0 if len(gen_smiles_set) == 0 else len(gen_smiles_set - train_set) / len(gen_smiles_set)
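
             # --- Hypothetical usage sketch (illustration only, not part of the original file). ---
             # Duplicates collapse before scoring; only "CCN" is absent from the training set.
             def _demo_novelty():
                 print(novelty(["CCO", "CCN", "CCN"], train=["CCO"]))   # 1/2 = 0.5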
         | 
| 548 | 
            +
             | 
| 549 | 
            +
             | 
| 550 | 
            +
            def internal_diversity(gen):
         | 
| 551 | 
            +
                """
         | 
| 552 | 
            +
                Computes the internal diversity of a set of molecules.
         | 
| 553 | 
            +
                
         | 
| 554 | 
            +
                Internal diversity is defined as one minus the average Tanimoto similarity between all pairs.
         | 
| 555 | 
            +
                
         | 
| 556 | 
            +
                Args:
         | 
| 557 | 
            +
                    gen: Array-like representation of molecules.
         | 
| 558 | 
            +
                
         | 
| 559 | 
            +
                Returns:
         | 
| 560 | 
            +
                    tuple: Mean and standard deviation of internal diversity.
         | 
| 561 | 
            +
                """
         | 
| 562 | 
            +
                diversity = [1 - x for x in average_agg_tanimoto(gen, gen, agg="mean", intdiv=True)]
         | 
| 563 | 
            +
                return np.mean(diversity), np.std(diversity)
         | 
| 564 | 
            +
             | 
| 565 | 
            +
             | 
| 566 | 
            +
            def average_agg_tanimoto(stock_vecs, gen_vecs, batch_size=5000, agg='max', device='cpu', p=1, intdiv=False):
         | 
| 567 | 
            +
                """
         | 
| 568 | 
            +
                Computes the average aggregated Tanimoto similarity between two sets of molecular fingerprints.
         | 
| 569 | 
            +
                
         | 
| 570 | 
            +
                For each fingerprint in gen_vecs, finds the closest (max or mean) similarity with fingerprints in stock_vecs.
         | 
| 571 | 
            +
                
         | 
| 572 | 
            +
                Args:
         | 
| 573 | 
            +
                    stock_vecs (numpy.ndarray): Array of fingerprint vectors from the reference set.
         | 
| 574 | 
            +
                    gen_vecs (numpy.ndarray): Array of fingerprint vectors from the generated set.
         | 
| 575 | 
            +
                    batch_size (int): Batch size for processing fingerprints.
         | 
| 576 | 
            +
                    agg (str): Aggregation method, either 'max' or 'mean'.
         | 
| 577 | 
            +
                    device (str): Device to perform computations on.
         | 
| 578 | 
            +
                    p (int): Power for averaging.
         | 
| 579 | 
            +
                    intdiv (bool): Whether to return individual similarities or the average.
         | 
| 580 | 
            +
                
         | 
| 581 | 
            +
                Returns:
         | 
| 582 | 
            +
                    float or numpy.ndarray: Average aggregated Tanimoto similarity or array of individual scores.
         | 
| 583 | 
            +
                """
         | 
| 584 | 
            +
                assert agg in ['max', 'mean'], "Can aggregate only max or mean"
         | 
| 585 | 
            +
                agg_tanimoto = np.zeros(len(gen_vecs))
         | 
| 586 | 
            +
                total = np.zeros(len(gen_vecs))
         | 
| 587 | 
            +
                for j in range(0, stock_vecs.shape[0], batch_size):
         | 
| 588 | 
            +
                    x_stock = torch.tensor(stock_vecs[j:j + batch_size]).to(device).float()
         | 
| 589 | 
            +
                    for i in range(0, gen_vecs.shape[0], batch_size):
         | 
| 590 | 
            +
                        y_gen = torch.tensor(gen_vecs[i:i + batch_size]).to(device).float()
         | 
| 591 | 
            +
                        y_gen = y_gen.transpose(0, 1)
         | 
| 592 | 
            +
                        tp = torch.mm(x_stock, y_gen)
         | 
| 593 | 
            +
                        # Compute Jaccard/Tanimoto similarity
         | 
| 594 | 
            +
                        jac = (tp / (x_stock.sum(1, keepdim=True) + y_gen.sum(0, keepdim=True) - tp)).cpu().numpy()
         | 
| 595 | 
            +
                        jac[np.isnan(jac)] = 1
         | 
| 596 | 
            +
                        if p != 1:
         | 
| 597 | 
            +
                            jac = jac ** p
         | 
| 598 | 
            +
                        if agg == 'max':
         | 
| 599 | 
            +
                            agg_tanimoto[i:i + y_gen.shape[1]] = np.maximum(
         | 
| 600 | 
            +
                                agg_tanimoto[i:i + y_gen.shape[1]], jac.max(0))
         | 
| 601 | 
            +
                        elif agg == 'mean':
         | 
| 602 | 
            +
                            agg_tanimoto[i:i + y_gen.shape[1]] += jac.sum(0)
         | 
| 603 | 
            +
                            total[i:i + y_gen.shape[1]] += jac.shape[0]
         | 
| 604 | 
            +
                if agg == 'mean':
         | 
| 605 | 
            +
                    agg_tanimoto /= total
         | 
| 606 | 
            +
                if p != 1:
         | 
| 607 | 
            +
                    agg_tanimoto = (agg_tanimoto) ** (1 / p)
         | 
| 608 | 
            +
                if intdiv:
         | 
| 609 | 
            +
                    return agg_tanimoto
         | 
| 610 | 
            +
                else:
         | 
| 611 | 
            +
                    return np.mean(agg_tanimoto)
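
             # --- Hypothetical usage sketch (illustration only, not part of the original file). ---
             # Two tiny 4-bit "fingerprints" as dense 0/1 arrays; the generated vector matches the
             # first stock vector exactly, so the max-aggregated similarity is 1.0 and the mean is 0.5.
             def _demo_agg_tanimoto():
                 stock = np.array([[1, 1, 0, 0], [0, 0, 1, 1]], dtype=np.float32)
                 gen = np.array([[1, 1, 0, 0]], dtype=np.float32)
                 print(average_agg_tanimoto(stock, gen, agg='max'))    # 1.0
                 print(average_agg_tanimoto(stock, gen, agg='mean'))   # 0.5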
         | 
| 612 | 
            +
             | 
| 613 | 
            +
             | 
| 614 | 
            +
            def str2bool(v):
         | 
| 615 | 
            +
                """
         | 
| 616 | 
            +
                Converts a string to a boolean.
         | 
| 617 | 
            +
                
         | 
| 618 | 
            +
                Args:
         | 
| 619 | 
            +
                    v (str): Input string.
         | 
| 620 | 
            +
                
         | 
| 621 | 
            +
                Returns:
         | 
| 622 | 
            +
                    bool: True if the string is 'true' (case insensitive), else False.
         | 
| 623 | 
            +
                """
         | 
| 624 | 
            +
                return v.lower() in ('true',)
         | 
| 625 | 
            +
             | 
| 626 | 
            +
             | 
| 627 | 
            +
            def obey_lipinski(mol):
         | 
| 628 | 
            +
                """
         | 
| 629 | 
            +
                Checks if a molecule obeys Lipinski's Rule of Five.
         | 
| 630 | 
            +
                
         | 
| 631 | 
            +
                The function evaluates weight, hydrogen bond donors and acceptors, logP, and rotatable bonds.
         | 
| 632 | 
            +
                
         | 
| 633 | 
            +
                Args:
         | 
| 634 | 
            +
                    mol (RDKit Mol): Molecule object.
         | 
| 635 | 
            +
                
         | 
| 636 | 
            +
                Returns:
         | 
| 637 | 
            +
                    int: Number of Lipinski rules satisfied.
         | 
| 638 | 
            +
                """
         | 
| 639 | 
            +
                mol = deepcopy(mol)
         | 
| 640 | 
            +
                Chem.SanitizeMol(mol)
         | 
| 641 | 
            +
                rule_1 = Descriptors.ExactMolWt(mol) < 500
         | 
| 642 | 
            +
                rule_2 = Lipinski.NumHDonors(mol) <= 5
         | 
| 643 | 
            +
                rule_3 = Lipinski.NumHAcceptors(mol) <= 10
         | 
| 644 | 
            +
                logp = Crippen.MolLogP(mol)
                rule_4 = (logp >= -2) and (logp <= 5)
         | 
| 645 | 
            +
                rule_5 = Chem.rdMolDescriptors.CalcNumRotatableBonds(mol) <= 10
         | 
| 646 | 
            +
                return np.sum([int(a) for a in [rule_1, rule_2, rule_3, rule_4, rule_5]])
         | 
| 647 | 
            +
             | 
| 648 | 
            +
             | 
| 649 | 
            +
            def obey_veber(mol):
         | 
| 650 | 
            +
                """
         | 
| 651 | 
            +
                Checks if a molecule obeys Veber's rules.
         | 
| 652 | 
            +
                
         | 
| 653 | 
            +
                Veber's rules focus on the number of rotatable bonds and topological polar surface area.
         | 
| 654 | 
            +
                
         | 
| 655 | 
            +
                Args:
         | 
| 656 | 
            +
                    mol (RDKit Mol): Molecule object.
         | 
| 657 | 
            +
                
         | 
| 658 | 
            +
                Returns:
         | 
| 659 | 
            +
                    int: Number of Veber's rules satisfied.
         | 
| 660 | 
            +
                """
         | 
| 661 | 
            +
                mol = deepcopy(mol)
         | 
| 662 | 
            +
                Chem.SanitizeMol(mol)
         | 
| 663 | 
            +
                rule_1 = rdMolDescriptors.CalcNumRotatableBonds(mol) <= 10
         | 
| 664 | 
            +
                rule_2 = rdMolDescriptors.CalcTPSA(mol) <= 140
         | 
| 665 | 
            +
                return np.sum([int(a) for a in [rule_1, rule_2]])
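
             # --- Hypothetical usage sketch (illustration only, not part of the original file). ---
             # Aspirin satisfies all five Lipinski-style checks and both Veber checks.
             def _demo_drug_likeness():
                 mol = Chem.MolFromSmiles("CC(=O)Oc1ccccc1C(=O)O")   # aspirin
                 print(obey_lipinski(mol))   # 5
                 print(obey_veber(mol))      # 2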
         | 
| 666 | 
            +
             | 
| 667 | 
            +
             | 
| 668 | 
            +
            def load_pains_filters():
         | 
| 669 | 
            +
                """
         | 
| 670 | 
            +
                Loads the PAINS (Pan-Assay INterference compoundS) filters A, B, and C.
         | 
| 671 | 
            +
                
         | 
| 672 | 
            +
                Returns:
         | 
| 673 | 
            +
                    FilterCatalog: An RDKit FilterCatalog object containing PAINS filters.
         | 
| 674 | 
            +
                """
         | 
| 675 | 
            +
                params = FilterCatalog.FilterCatalogParams()
         | 
| 676 | 
            +
                params.AddCatalog(FilterCatalog.FilterCatalogParams.FilterCatalogs.PAINS_A)
         | 
| 677 | 
            +
                params.AddCatalog(FilterCatalog.FilterCatalogParams.FilterCatalogs.PAINS_B)
         | 
| 678 | 
            +
                params.AddCatalog(FilterCatalog.FilterCatalogParams.FilterCatalogs.PAINS_C)
         | 
| 679 | 
            +
                catalog = FilterCatalog.FilterCatalog(params)
         | 
| 680 | 
            +
                return catalog
         | 
| 681 | 
            +
             | 
| 682 | 
            +
             | 
| 683 | 
            +
            def is_pains(mol, catalog):
         | 
| 684 | 
            +
                """
         | 
| 685 | 
            +
                Checks if the given molecule is a PAINS compound.
         | 
| 686 | 
            +
                
         | 
| 687 | 
            +
                Args:
         | 
| 688 | 
            +
                    mol (RDKit Mol): Molecule object.
         | 
| 689 | 
            +
                    catalog (FilterCatalog): A catalog of PAINS filters.
         | 
| 690 | 
            +
                
         | 
| 691 | 
            +
                Returns:
         | 
| 692 | 
            +
                    bool: True if the molecule matches a PAINS filter, else False.
         | 
| 693 | 
            +
                """
         | 
| 694 | 
            +
                entry = catalog.GetFirstMatch(mol)
         | 
| 695 | 
            +
                return entry is not None
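
             # --- Hypothetical usage sketch (illustration only, not part of the original file). ---
             # The catalog is built once and reused; ethanol matches no PAINS pattern.
             def _demo_pains():
                 catalog = load_pains_filters()
                 print(is_pains(Chem.MolFromSmiles("CCO"), catalog))   # False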
         | 
| 696 | 
            +
             | 
| 697 | 
            +
             | 
| 698 | 
            +
            def mapper(n_jobs):
         | 
| 699 | 
            +
                """
         | 
| 700 | 
            +
                Returns a mapping function for parallel or serial processing.
         | 
| 701 | 
            +
                
         | 
| 702 | 
            +
                If n_jobs == 1, returns the built-in map function.
         | 
| 703 | 
            +
                If n_jobs > 1, returns a function that uses a multiprocessing pool.
         | 
| 704 | 
            +
                
         | 
| 705 | 
            +
                Args:
         | 
| 706 | 
            +
                    n_jobs (int or pool object): Number of jobs or a Pool instance.
         | 
| 707 | 
            +
                
         | 
| 708 | 
            +
                Returns:
         | 
| 709 | 
            +
                    callable: A function that acts like map.
         | 
| 710 | 
            +
                """
         | 
| 711 | 
            +
                if n_jobs == 1:
         | 
| 712 | 
            +
                    def _mapper(*args, **kwargs):
         | 
| 713 | 
            +
                        return list(map(*args, **kwargs))
         | 
| 714 | 
            +
                    return _mapper
         | 
| 715 | 
            +
                if isinstance(n_jobs, int):
         | 
| 716 | 
            +
                    pool = Pool(n_jobs)
         | 
| 717 | 
            +
                    def _mapper(*args, **kwargs):
         | 
| 718 | 
            +
                        try:
         | 
| 719 | 
            +
                            result = pool.map(*args, **kwargs)
         | 
| 720 | 
            +
                        finally:
         | 
| 721 | 
            +
                            pool.terminate()
         | 
| 722 | 
            +
                        return result
         | 
| 723 | 
            +
                    return _mapper
         | 
| 724 | 
            +
                return n_jobs.map
         | 
| 725 | 
            +
             | 
| 726 | 
            +
             | 
| 727 | 
            +
            def fragmenter(mol):
         | 
| 728 | 
            +
                """
         | 
| 729 | 
            +
                Fragments a molecule using BRICS and returns a list of fragment SMILES.
         | 
| 730 | 
            +
                
         | 
| 731 | 
            +
                Args:
         | 
| 732 | 
            +
                    mol (str or RDKit Mol): Input molecule.
         | 
| 733 | 
            +
                
         | 
| 734 | 
            +
                Returns:
         | 
| 735 | 
            +
                    list: List of fragment SMILES strings.
         | 
| 736 | 
            +
                """
         | 
| 737 | 
            +
                fgs = AllChem.FragmentOnBRICSBonds(get_mol(mol))
         | 
| 738 | 
            +
                fgs_smi = Chem.MolToSmiles(fgs).split(".")
         | 
| 739 | 
            +
                return fgs_smi
         | 
| 740 | 
            +
             | 
| 741 | 
            +
             | 
| 742 | 
            +
            def get_mol(smiles_or_mol):
         | 
| 743 | 
            +
                """
         | 
| 744 | 
            +
                Loads a SMILES string or molecule into an RDKit molecule object.
         | 
| 745 | 
            +
                
         | 
| 746 | 
            +
                Args:
         | 
| 747 | 
            +
                    smiles_or_mol (str or RDKit Mol): SMILES string or molecule.
         | 
| 748 | 
            +
                
         | 
| 749 | 
            +
                Returns:
         | 
| 750 | 
            +
                    RDKit Mol or None: Sanitized molecule object or None if invalid.
         | 
| 751 | 
            +
                """
         | 
| 752 | 
            +
                if isinstance(smiles_or_mol, str):
         | 
| 753 | 
            +
                    if len(smiles_or_mol) == 0:
         | 
| 754 | 
            +
                        return None
         | 
| 755 | 
            +
                    mol = Chem.MolFromSmiles(smiles_or_mol)
         | 
| 756 | 
            +
                    if mol is None:
         | 
| 757 | 
            +
                        return None
         | 
| 758 | 
            +
                    try:
         | 
| 759 | 
            +
                        Chem.SanitizeMol(mol)
         | 
| 760 | 
            +
                    except ValueError:
         | 
| 761 | 
            +
                        return None
         | 
| 762 | 
            +
                    return mol
         | 
| 763 | 
            +
                return smiles_or_mol
         | 
| 764 | 
            +
             | 
| 765 | 
            +
             | 
| 766 | 
            +
            def compute_fragments(mol_list, n_jobs=1):
         | 
| 767 | 
            +
                """
         | 
| 768 | 
            +
                Fragments a list of molecules using BRICS and returns a counter of fragment occurrences.
         | 
| 769 | 
            +
                
         | 
| 770 | 
            +
                Args:
         | 
| 771 | 
            +
                    mol_list (list): List of molecules (SMILES or RDKit Mol).
         | 
| 772 | 
            +
                    n_jobs (int): Number of parallel jobs.
         | 
| 773 | 
            +
                
         | 
| 774 | 
            +
                Returns:
         | 
| 775 | 
            +
                    Counter: A Counter dictionary mapping fragment SMILES to counts.
         | 
| 776 | 
            +
                """
         | 
| 777 | 
            +
                fragments = Counter()
         | 
| 778 | 
            +
                for mol_frag in mapper(n_jobs)(fragmenter, mol_list):
         | 
| 779 | 
            +
                    fragments.update(mol_frag)
         | 
| 780 | 
            +
                return fragments
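
             # --- Hypothetical usage sketch (illustration only, not part of the original file). ---
             # BRICS fragmentation of phenetole yields dummy-labelled fragment SMILES whose
             # occurrences are tallied in a Counter.
             def _demo_fragments():
                 print(compute_fragments(["CCOc1ccccc1"], n_jobs=1))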
         | 
| 781 | 
            +
             | 
| 782 | 
            +
             | 
| 783 | 
            +
            def compute_scaffolds(mol_list, n_jobs=1, min_rings=2):
         | 
| 784 | 
            +
                """
         | 
| 785 | 
            +
                Extracts scaffolds from a list of molecules as canonical SMILES.
         | 
| 786 | 
            +
                
         | 
| 787 | 
            +
                Only scaffolds with at least min_rings rings are considered.
         | 
| 788 | 
            +
                
         | 
| 789 | 
            +
                Args:
         | 
| 790 | 
            +
                    mol_list (list): List of molecules.
         | 
| 791 | 
            +
                    n_jobs (int): Number of parallel jobs.
         | 
| 792 | 
            +
                    min_rings (int): Minimum number of rings required in a scaffold.
         | 
| 793 | 
            +
                
         | 
| 794 | 
            +
                Returns:
         | 
| 795 | 
            +
                    Counter: A Counter mapping scaffold SMILES to counts.
         | 
| 796 | 
            +
                """
         | 
| 797 | 
            +
                scaffolds = Counter()
         | 
| 798 | 
            +
                map_ = mapper(n_jobs)
         | 
| 799 | 
            +
                scaffolds = Counter(map_(partial(compute_scaffold, min_rings=min_rings), mol_list))
         | 
| 800 | 
            +
                if None in scaffolds:
         | 
| 801 | 
            +
                    scaffolds.pop(None)
         | 
| 802 | 
            +
                return scaffolds
         | 
| 803 | 
            +
             | 
| 804 | 
            +
             | 
| 805 | 
            +
            def get_n_rings(mol):
         | 
| 806 | 
            +
                """
         | 
| 807 | 
            +
                Computes the number of rings in a molecule.
         | 
| 808 | 
            +
                
         | 
| 809 | 
            +
                Args:
         | 
| 810 | 
            +
                    mol (RDKit Mol): Molecule object.
         | 
| 811 | 
            +
                
         | 
| 812 | 
            +
                Returns:
         | 
| 813 | 
            +
                    int: Number of rings.
         | 
| 814 | 
            +
                """
         | 
| 815 | 
            +
                return mol.GetRingInfo().NumRings()
         | 
| 816 | 
            +
             | 
| 817 | 
            +
             | 
| 818 | 
            +
            def compute_scaffold(mol, min_rings=2):
         | 
| 819 | 
            +
                """
         | 
| 820 | 
            +
                Computes the Murcko scaffold of a molecule and returns its canonical SMILES if it has enough rings.
         | 
| 821 | 
            +
                
         | 
| 822 | 
            +
                Args:
         | 
| 823 | 
            +
                    mol (str or RDKit Mol): Input molecule.
         | 
| 824 | 
            +
                    min_rings (int): Minimum number of rings required.
         | 
| 825 | 
            +
                
         | 
| 826 | 
            +
                Returns:
         | 
| 827 | 
            +
                    str or None: Canonical SMILES of the scaffold if valid, else None.
         | 
| 828 | 
            +
                """
         | 
| 829 | 
            +
                mol = get_mol(mol)
         | 
| 830 | 
            +
                try:
         | 
| 831 | 
            +
                    scaffold = MurckoScaffold.GetScaffoldForMol(mol)
         | 
| 832 | 
            +
                except (ValueError, RuntimeError):
         | 
| 833 | 
            +
                    return None
         | 
| 834 | 
            +
                n_rings = get_n_rings(scaffold)
         | 
| 835 | 
            +
                scaffold_smiles = Chem.MolToSmiles(scaffold)
         | 
| 836 | 
            +
                if scaffold_smiles == '' or n_rings < min_rings:
         | 
| 837 | 
            +
                    return None
         | 
| 838 | 
            +
                return scaffold_smiles
         | 
| 839 | 
            +
             | 
| 840 | 
            +
             | 
| 841 | 
            +
            class Metric:
         | 
| 842 | 
            +
                """
         | 
| 843 | 
            +
                Abstract base class for chemical metrics.
         | 
| 844 | 
            +
                
         | 
| 845 | 
            +
                Derived classes should implement the precalc and metric methods.
         | 
| 846 | 
            +
                """
         | 
| 847 | 
            +
                def __init__(self, n_jobs=1, device='cpu', batch_size=512, **kwargs):
         | 
| 848 | 
            +
                    self.n_jobs = n_jobs
         | 
| 849 | 
            +
                    self.device = device
         | 
| 850 | 
            +
                    self.batch_size = batch_size
         | 
| 851 | 
            +
                    for k, v in kwargs.items():
         | 
| 852 | 
            +
                        setattr(self, k, v)
         | 
| 853 | 
            +
             | 
| 854 | 
            +
                def __call__(self, ref=None, gen=None, pref=None, pgen=None):
         | 
| 855 | 
            +
                    """
         | 
| 856 | 
            +
                    Computes the metric between reference and generated molecules.
         | 
| 857 | 
            +
                    
         | 
| 858 | 
            +
                    Exactly one of ref or pref, and gen or pgen should be provided.
         | 
| 859 | 
            +
                    
         | 
| 860 | 
            +
                    Args:
         | 
| 861 | 
            +
                        ref: Reference molecule list.
         | 
| 862 | 
            +
                        gen: Generated molecule list.
         | 
| 863 | 
            +
                        pref: Precalculated reference metric.
         | 
| 864 | 
            +
                        pgen: Precalculated generated metric.
         | 
| 865 | 
            +
                    
         | 
| 866 | 
            +
                    Returns:
         | 
| 867 | 
            +
                        Metric value computed by the metric method.
         | 
| 868 | 
            +
                    """
         | 
| 869 | 
            +
                    assert (ref is None) != (pref is None), "specify ref xor pref"
         | 
| 870 | 
            +
                    assert (gen is None) != (pgen is None), "specify gen xor pgen"
         | 
| 871 | 
            +
                    if pref is None:
         | 
| 872 | 
            +
                        pref = self.precalc(ref)
         | 
| 873 | 
            +
                    if pgen is None:
         | 
| 874 | 
            +
                        pgen = self.precalc(gen)
         | 
| 875 | 
            +
                    return self.metric(pref, pgen)
         | 
| 876 | 
            +
             | 
| 877 | 
            +
                def precalc(self, molecules):
         | 
| 878 | 
            +
                    """
         | 
| 879 | 
            +
                    Pre-calculates necessary representations from a list of molecules.
         | 
| 880 | 
            +
                    Should be implemented by derived classes.
         | 
| 881 | 
            +
                    """
         | 
| 882 | 
            +
                    raise NotImplementedError
         | 
| 883 | 
            +
             | 
| 884 | 
            +
                def metric(self, pref, pgen):
         | 
| 885 | 
            +
                    """
         | 
| 886 | 
            +
                    Computes the metric given precalculated representations.
         | 
| 887 | 
            +
                    Should be implemented by derived classes.
         | 
| 888 | 
            +
                    """
         | 
| 889 | 
            +
                    raise NotImplementedError
         | 
| 890 | 
            +
             | 
| 891 | 
            +
             | 
| 892 | 
            +
            class FragMetric(Metric):
         | 
| 893 | 
            +
                """
         | 
| 894 | 
            +
                Metrics based on molecular fragments.
         | 
| 895 | 
            +
                """
         | 
| 896 | 
            +
                def precalc(self, mols):
         | 
| 897 | 
            +
                    return {'frag': compute_fragments(mols, n_jobs=self.n_jobs)}
         | 
| 898 | 
            +
             | 
| 899 | 
            +
                def metric(self, pref, pgen):
         | 
| 900 | 
            +
                    return cos_similarity(pref['frag'], pgen['frag'])
         | 
| 901 | 
            +
             | 
| 902 | 
            +
             | 
| 903 | 
            +
            class ScafMetric(Metric):
         | 
| 904 | 
            +
                """
         | 
| 905 | 
            +
                Metrics based on molecular scaffolds.
         | 
| 906 | 
            +
                """
         | 
| 907 | 
            +
                def precalc(self, mols):
         | 
| 908 | 
            +
                    return {'scaf': compute_scaffolds(mols, n_jobs=self.n_jobs)}
         | 
| 909 | 
            +
             | 
| 910 | 
            +
                def metric(self, pref, pgen):
         | 
| 911 | 
            +
                    return cos_similarity(pref['scaf'], pgen['scaf'])
         | 
| 912 | 
            +
             | 
| 913 | 
            +
             | 
| 914 | 
            +
            def cos_similarity(ref_counts, gen_counts):
         | 
| 915 | 
            +
                """
         | 
| 916 | 
            +
                Computes cosine similarity between two molecular vectors.
         | 
| 917 | 
            +
                
         | 
| 918 | 
            +
                Args:
         | 
| 919 | 
            +
                    ref_counts (dict): Reference molecular vectors.
         | 
| 920 | 
            +
                    gen_counts (dict): Generated molecular vectors.
         | 
| 921 | 
            +
                
         | 
| 922 | 
            +
                Returns:
         | 
| 923 | 
            +
                    float: Cosine similarity between the two molecular vectors.
         | 
| 924 | 
            +
                """
         | 
| 925 | 
            +
                if len(ref_counts) == 0 or len(gen_counts) == 0:
         | 
| 926 | 
            +
                    return np.nan
         | 
| 927 | 
            +
                keys = np.unique(list(ref_counts.keys()) + list(gen_counts.keys()))
         | 
| 928 | 
            +
                ref_vec = np.array([ref_counts.get(k, 0) for k in keys])
         | 
| 929 | 
            +
                gen_vec = np.array([gen_counts.get(k, 0) for k in keys])
         | 
| 930 | 
            +
                return 1 - cos_distance(ref_vec, gen_vec)
         | 
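For orientation, here is a minimal usage sketch of the metrics added above (illustrative only, not part of this commit). It assumes FragMetric and ScafMetric can be imported from src.util.utils and that the fragmenter/get_mol helpers defined earlier in the file accept raw SMILES strings, as compute_scaffold does:

    # Illustrative only: compare a generated set against a reference set.
    from src.util.utils import FragMetric, ScafMetric

    ref_smiles = ["c1ccc2ccccc2c1", "c1ccc(-c2ccccc2)cc1", "O=C(c1ccccc1)Nc1ccccn1"]
    gen_smiles = ["c1ccc2[nH]ccc2c1", "c1ccc(Oc2ccccc2)cc1"]

    # Each metric precalculates BRICS-fragment / Murcko-scaffold counters for both
    # sets and returns the cosine similarity between the two count vectors.
    print("fragment similarity:", FragMetric(n_jobs=1)(ref=ref_smiles, gen=gen_smiles))
    print("scaffold similarity:", ScafMetric(n_jobs=1)(ref=ref_smiles, gen=gen_smiles))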
    	
        train.py
    ADDED
    
    | @@ -0,0 +1,462 @@ | |
| 1 | 
            +
            import os
         | 
| 2 | 
            +
            import time
         | 
| 3 | 
            +
            import random
         | 
| 4 | 
            +
            import pickle
         | 
| 5 | 
            +
            import argparse
         | 
| 6 | 
            +
            import os.path as osp
         | 
| 7 | 
            +
             | 
| 8 | 
            +
            import torch
         | 
| 9 | 
            +
            import torch.utils.data
         | 
| 10 | 
            +
            from torch import nn
         | 
| 11 | 
            +
            from torch_geometric.loader import DataLoader
         | 
| 12 | 
            +
             | 
| 13 | 
            +
            import wandb
         | 
| 14 | 
            +
            from rdkit import RDLogger
         | 
| 15 | 
            +
             | 
| 16 | 
            +
            torch.set_num_threads(5)
         | 
| 17 | 
            +
            RDLogger.DisableLog('rdApp.*')
         | 
| 18 | 
            +
             | 
| 19 | 
            +
            from src.util.utils import *
         | 
| 20 | 
            +
            from src.model.models import Generator, Discriminator, simple_disc
         | 
| 21 | 
            +
            from src.data.dataset import DruggenDataset
         | 
| 22 | 
            +
            from src.data.utils import get_encoders_decoders, load_molecules
         | 
| 23 | 
            +
            from src.model.loss import discriminator_loss, generator_loss
         | 
| 24 | 
            +
             | 
| 25 | 
            +
            class Train(object):
         | 
| 26 | 
            +
                """Trainer for DrugGEN."""
         | 
| 27 | 
            +
             | 
| 28 | 
            +
                def __init__(self, config):
         | 
| 29 | 
            +
                    if config.set_seed:
         | 
| 30 | 
            +
                        np.random.seed(config.seed)
         | 
| 31 | 
            +
                        random.seed(config.seed)
         | 
| 32 | 
            +
                        torch.manual_seed(config.seed)
         | 
| 33 | 
            +
                        torch.cuda.manual_seed_all(config.seed)
         | 
| 34 | 
            +
             | 
| 35 | 
            +
                        torch.backends.cudnn.deterministic = True
         | 
| 36 | 
            +
                        torch.backends.cudnn.benchmark = False
         | 
| 37 | 
            +
             | 
| 38 | 
            +
                        os.environ["PYTHONHASHSEED"] = str(config.seed)
         | 
| 39 | 
            +
             | 
| 40 | 
            +
                        print(f'Using seed {config.seed}')
         | 
| 41 | 
            +
             | 
| 42 | 
            +
                    self.device = torch.device("cuda" if torch.cuda.is_available() else 'cpu')
         | 
| 43 | 
            +
             | 
| 44 | 
            +
                    # Initialize configurations
         | 
| 45 | 
            +
                    self.submodel = config.submodel
         | 
| 46 | 
            +
             | 
| 47 | 
            +
                    # Data loader.
         | 
| 48 | 
            +
                    self.raw_file = config.raw_file  # SMILES containing text file for dataset. 
         | 
| 49 | 
            +
                                                     # Write the full path to file.
         | 
| 50 | 
            +
                    self.drug_raw_file = config.drug_raw_file  # SMILES containing text file for second dataset. 
         | 
| 51 | 
            +
                                                               # Write the full path to file.
         | 
| 52 | 
            +
                    
         | 
| 53 | 
            +
                    # Automatically infer dataset file names from raw file names
         | 
| 54 | 
            +
                    raw_file_basename = osp.basename(self.raw_file)
         | 
| 55 | 
            +
                    drug_raw_file_basename = osp.basename(self.drug_raw_file)
         | 
| 56 | 
            +
                    
         | 
| 57 | 
            +
                    # Get the base name without extension and add max_atom to it
         | 
| 58 | 
            +
                    self.max_atom = config.max_atom  # Model is based on one-shot generation.
         | 
| 59 | 
            +
                    raw_file_base = os.path.splitext(raw_file_basename)[0]
         | 
| 60 | 
            +
                    drug_raw_file_base = os.path.splitext(drug_raw_file_basename)[0]
         | 
| 61 | 
            +
             | 
| 62 | 
            +
                    # Change extension from .smi to .pt and add max_atom to the filename
         | 
| 63 | 
            +
                    self.dataset_file = f"{raw_file_base}{self.max_atom}.pt"
         | 
| 64 | 
            +
                    self.drugs_dataset_file = f"{drug_raw_file_base}{self.max_atom}.pt"
         | 
| 65 | 
            +
             | 
| 66 | 
            +
                    self.mol_data_dir = config.mol_data_dir  # Directory where the dataset files are stored.
         | 
| 67 | 
            +
                    self.drug_data_dir = config.drug_data_dir  # Directory where the drug dataset files are stored.
         | 
| 68 | 
            +
                    self.dataset_name = self.dataset_file.split(".")[0]
         | 
| 69 | 
            +
                    self.drugs_dataset_name = self.drugs_dataset_file.split(".")[0]
         | 
| 70 | 
            +
        self.features = config.features  # If True, use additional node features; if False, only atom types are used as node features.
         | 
| 71 | 
            +
                                                     # Additional node features can be added. Please check new_dataloarder.py Line 102.
         | 
| 72 | 
            +
                    self.batch_size = config.batch_size  # Batch size for training.
         | 
| 73 | 
            +
                    
         | 
| 74 | 
            +
                    self.parallel = config.parallel
         | 
| 75 | 
            +
             | 
| 76 | 
            +
                    # Get atom and bond encoders/decoders
         | 
| 77 | 
            +
                    atom_encoder, atom_decoder, bond_encoder, bond_decoder = get_encoders_decoders(
         | 
| 78 | 
            +
                        self.raw_file,
         | 
| 79 | 
            +
                        self.drug_raw_file,
         | 
| 80 | 
            +
                        self.max_atom
         | 
| 81 | 
            +
                    )
         | 
| 82 | 
            +
                    self.atom_encoder = atom_encoder
         | 
| 83 | 
            +
                    self.atom_decoder = atom_decoder
         | 
| 84 | 
            +
                    self.bond_encoder = bond_encoder
         | 
| 85 | 
            +
                    self.bond_decoder = bond_decoder
         | 
| 86 | 
            +
             | 
| 87 | 
            +
                    self.dataset = DruggenDataset(self.mol_data_dir,
         | 
| 88 | 
            +
                                                 self.dataset_file,
         | 
| 89 | 
            +
                                                 self.raw_file,
         | 
| 90 | 
            +
                                                 self.max_atom,
         | 
| 91 | 
            +
                                                 self.features,
         | 
| 92 | 
            +
                                                 atom_encoder=atom_encoder,
         | 
| 93 | 
            +
                                                 atom_decoder=atom_decoder,
         | 
| 94 | 
            +
                                                 bond_encoder=bond_encoder,
         | 
| 95 | 
            +
                                                 bond_decoder=bond_decoder)
         | 
| 96 | 
            +
             | 
| 97 | 
            +
                    self.loader = DataLoader(self.dataset,
         | 
| 98 | 
            +
                                             shuffle=True,
         | 
| 99 | 
            +
                                             batch_size=self.batch_size,
         | 
| 100 | 
            +
                                             drop_last=True)  # PyG dataloader for the GAN.
         | 
| 101 | 
            +
             | 
| 102 | 
            +
                    self.drugs = DruggenDataset(self.drug_data_dir, 
         | 
| 103 | 
            +
                                             self.drugs_dataset_file, 
         | 
| 104 | 
            +
                                             self.drug_raw_file, 
         | 
| 105 | 
            +
                                             self.max_atom, 
         | 
| 106 | 
            +
                                             self.features,
         | 
| 107 | 
            +
                                             atom_encoder=atom_encoder,
         | 
| 108 | 
            +
                                             atom_decoder=atom_decoder,
         | 
| 109 | 
            +
                                             bond_encoder=bond_encoder,
         | 
| 110 | 
            +
                                             bond_decoder=bond_decoder)
         | 
| 111 | 
            +
             | 
| 112 | 
            +
                    self.drugs_loader = DataLoader(self.drugs, 
         | 
| 113 | 
            +
                                                   shuffle=True,
         | 
| 114 | 
            +
                                                   batch_size=self.batch_size, 
         | 
| 115 | 
            +
                                                   drop_last=True)  # PyG dataloader for the second GAN.
         | 
| 116 | 
            +
             | 
| 117 | 
            +
                    self.m_dim = len(self.atom_decoder) if not self.features else int(self.loader.dataset[0].x.shape[1]) # Atom type dimension.
         | 
| 118 | 
            +
                    self.b_dim = len(self.bond_decoder) # Bond type dimension.
         | 
| 119 | 
            +
                    self.vertexes = int(self.loader.dataset[0].x.shape[0]) # Number of nodes in the graph.
         | 
| 120 | 
            +
             | 
| 121 | 
            +
                    # Model configurations.
         | 
| 122 | 
            +
                    self.act = config.act
         | 
| 123 | 
            +
                    self.lambda_gp = config.lambda_gp
         | 
| 124 | 
            +
                    self.dim = config.dim
         | 
| 125 | 
            +
                    self.depth = config.depth
         | 
| 126 | 
            +
                    self.heads = config.heads
         | 
| 127 | 
            +
                    self.mlp_ratio = config.mlp_ratio
         | 
| 128 | 
            +
                    self.ddepth = config.ddepth
         | 
| 129 | 
            +
                    self.ddropout = config.ddropout
         | 
| 130 | 
            +
             | 
| 131 | 
            +
                    # Training configurations.
         | 
| 132 | 
            +
                    self.epoch = config.epoch
         | 
| 133 | 
            +
                    self.g_lr = config.g_lr
         | 
| 134 | 
            +
                    self.d_lr = config.d_lr
         | 
| 135 | 
            +
                    self.dropout = config.dropout
         | 
| 136 | 
            +
                    self.beta1 = config.beta1
         | 
| 137 | 
            +
                    self.beta2 = config.beta2
         | 
| 138 | 
            +
             | 
| 139 | 
            +
                    # Directories.
         | 
| 140 | 
            +
                    self.log_dir = config.log_dir
         | 
| 141 | 
            +
                    self.sample_dir = config.sample_dir
         | 
| 142 | 
            +
                    self.model_save_dir = config.model_save_dir
         | 
| 143 | 
            +
             | 
| 144 | 
            +
                    # Step size.
         | 
| 145 | 
            +
                    self.log_step = config.log_sample_step
         | 
| 146 | 
            +
             | 
| 147 | 
            +
                    # resume training
         | 
| 148 | 
            +
                    self.resume = config.resume
         | 
| 149 | 
            +
                    self.resume_epoch = config.resume_epoch
         | 
| 150 | 
            +
                    self.resume_iter = config.resume_iter
         | 
| 151 | 
            +
                    self.resume_directory = config.resume_directory
         | 
| 152 | 
            +
             | 
| 153 | 
            +
                    # wandb configuration
         | 
| 154 | 
            +
                    self.use_wandb = config.use_wandb
         | 
| 155 | 
            +
                    self.online = config.online
         | 
| 156 | 
            +
                    self.exp_name = config.exp_name
         | 
| 157 | 
            +
             | 
| 158 | 
            +
                    # Arguments for the model.
         | 
| 159 | 
            +
                    self.arguments = "{}_{}_glr{}_dlr{}_dim{}_depth{}_heads{}_batch{}_epoch{}_dataset{}_dropout{}".format(self.exp_name, self.submodel, self.g_lr, self.d_lr, self.dim, self.depth, self.heads, self.batch_size, self.epoch, self.dataset_name, self.dropout)
         | 
| 160 | 
            +
             | 
| 161 | 
            +
                    self.build_model(self.model_save_dir, self.arguments)
         | 
| 162 | 
            +
             | 
| 163 | 
            +
             | 
| 164 | 
            +
                def build_model(self, model_save_dir, arguments):
         | 
| 165 | 
            +
                    """Create generators and discriminators."""
         | 
| 166 | 
            +
                    
         | 
| 167 | 
            +
                    ''' Generator is based on Transformer Encoder: 
         | 
| 168 | 
            +
                        
         | 
| 169 | 
            +
                        @ g_conv_dim: Dimensions for MLP layers before Transformer Encoder
         | 
| 170 | 
            +
        @ vertexes: maximum number of atoms in generated molecules
         | 
| 171 | 
            +
                        @ b_dim: number of bond types
         | 
| 172 | 
            +
                        @ m_dim: number of atom types (or number of features used)
         | 
| 173 | 
            +
                        @ dropout: dropout possibility
         | 
| 174 | 
            +
                        @ dim: Hidden dimension of Transformer Encoder
         | 
| 175 | 
            +
                        @ depth: Transformer layer number
         | 
| 176 | 
            +
                        @ heads: Number of multihead-attention heads
         | 
| 177 | 
            +
                        @ mlp_ratio: Read-out layer dimension of Transformer
         | 
| 178 | 
            +
        @ drop_rate: deprecated
         | 
| 179 | 
            +
                        @ tra_conv: Whether module creates output for TransformerConv discriminator
         | 
| 180 | 
            +
                        '''
         | 
| 181 | 
            +
                    self.G = Generator(self.act,
         | 
| 182 | 
            +
                                       self.vertexes,
         | 
| 183 | 
            +
                                       self.b_dim,
         | 
| 184 | 
            +
                                       self.m_dim,
         | 
| 185 | 
            +
                                       self.dropout,
         | 
| 186 | 
            +
                                       dim=self.dim,
         | 
| 187 | 
            +
                                       depth=self.depth,
         | 
| 188 | 
            +
                                       heads=self.heads,
         | 
| 189 | 
            +
                                       mlp_ratio=self.mlp_ratio)
         | 
| 190 | 
            +
             | 
| 191 | 
            +
                    ''' Discriminator implementation with Transformer Encoder:
         | 
| 192 | 
            +
                        
         | 
| 193 | 
            +
                        @ act: Activation function for MLP
         | 
| 194 | 
            +
        @ vertexes: maximum number of atoms in generated molecules
         | 
| 195 | 
            +
                        @ b_dim: number of bond types
         | 
| 196 | 
            +
                        @ m_dim: number of atom types (or number of features used)
         | 
| 197 | 
            +
                        @ dropout: dropout possibility
         | 
| 198 | 
            +
                        @ dim: Hidden dimension of Transformer Encoder
         | 
| 199 | 
            +
                        @ depth: Transformer layer number
         | 
| 200 | 
            +
                        @ heads: Number of multihead-attention heads
         | 
| 201 | 
            +
                        @ mlp_ratio: Read-out layer dimension of Transformer'''
         | 
| 202 | 
            +
             | 
| 203 | 
            +
                    self.D = Discriminator(self.act,
         | 
| 204 | 
            +
                                            self.vertexes,
         | 
| 205 | 
            +
                                            self.b_dim,
         | 
| 206 | 
            +
                                            self.m_dim,
         | 
| 207 | 
            +
                                            self.ddropout,
         | 
| 208 | 
            +
                                            dim=self.dim,
         | 
| 209 | 
            +
                                            depth=self.ddepth,
         | 
| 210 | 
            +
                                            heads=self.heads,
         | 
| 211 | 
            +
                                            mlp_ratio=self.mlp_ratio)
         | 
| 212 | 
            +
             | 
| 213 | 
            +
                    self.g_optimizer = torch.optim.AdamW(self.G.parameters(), self.g_lr, [self.beta1, self.beta2])
         | 
| 214 | 
            +
                    self.d_optimizer = torch.optim.AdamW(self.D.parameters(), self.d_lr, [self.beta1, self.beta2])
         | 
| 215 | 
            +
             | 
| 216 | 
            +
                    network_path = os.path.join(model_save_dir, arguments)
         | 
| 217 | 
            +
                    self.print_network(self.G, 'G', network_path)
         | 
| 218 | 
            +
                    self.print_network(self.D, 'D', network_path)
         | 
| 219 | 
            +
             | 
| 220 | 
            +
                    if self.parallel and torch.cuda.device_count() > 1:
         | 
| 221 | 
            +
                        print(f"Using {torch.cuda.device_count()} GPUs!")
         | 
| 222 | 
            +
                        self.G = nn.DataParallel(self.G)
         | 
| 223 | 
            +
                        self.D = nn.DataParallel(self.D)
         | 
| 224 | 
            +
             | 
| 225 | 
            +
                    self.G.to(self.device)
         | 
| 226 | 
            +
                    self.D.to(self.device)
         | 
| 227 | 
            +
             | 
| 228 | 
            +
                def print_network(self, model, name, save_dir):
         | 
| 229 | 
            +
                    """Print out the network information."""
         | 
| 230 | 
            +
                    num_params = 0
         | 
| 231 | 
            +
                    for p in model.parameters():
         | 
| 232 | 
            +
                        num_params += p.numel()
         | 
| 233 | 
            +
             | 
| 234 | 
            +
                    if not os.path.exists(save_dir):
         | 
| 235 | 
            +
                        os.makedirs(save_dir)
         | 
| 236 | 
            +
             | 
| 237 | 
            +
                    network_path = os.path.join(save_dir, "{}_modules.txt".format(name))
         | 
| 238 | 
            +
                    with open(network_path, "w+") as file:
         | 
| 239 | 
            +
                        for module in model.modules():
         | 
| 240 | 
            +
                            file.write(f"{module.__class__.__name__}:\n")
         | 
| 241 | 
            +
                            print(module.__class__.__name__)
         | 
| 242 | 
            +
                            for n, param in module.named_parameters():
         | 
| 243 | 
            +
                                if param is not None:
         | 
| 244 | 
            +
                                    file.write(f"  - {n}: {param.size()}\n")
         | 
| 245 | 
            +
                                    print(f"  - {n}: {param.size()}")
         | 
| 246 | 
            +
                            break
         | 
| 247 | 
            +
                        file.write(f"Total number of parameters: {num_params}\n")
         | 
| 248 | 
            +
                        print(f"Total number of parameters: {num_params}\n\n")
         | 
| 249 | 
            +
             | 
| 250 | 
            +
                def restore_model(self, epoch, iteration, model_directory):
         | 
| 251 | 
            +
                    """Restore the trained generator and discriminator."""
         | 
| 252 | 
            +
                    print('Loading the trained models from epoch / iteration {}-{}...'.format(epoch, iteration))
         | 
| 253 | 
            +
             | 
| 254 | 
            +
                    G_path = os.path.join(model_directory, '{}-{}-G.ckpt'.format(epoch, iteration))
         | 
| 255 | 
            +
                    D_path = os.path.join(model_directory, '{}-{}-D.ckpt'.format(epoch, iteration))
         | 
| 256 | 
            +
                    self.G.load_state_dict(torch.load(G_path, map_location=lambda storage, loc: storage))
         | 
| 257 | 
            +
                    self.D.load_state_dict(torch.load(D_path, map_location=lambda storage, loc: storage))
         | 
| 258 | 
            +
             | 
| 259 | 
            +
                def save_model(self, model_directory, idx,i):
         | 
| 260 | 
            +
                    G_path = os.path.join(model_directory, '{}-{}-G.ckpt'.format(idx+1,i+1))
         | 
| 261 | 
            +
                    D_path = os.path.join(model_directory, '{}-{}-D.ckpt'.format(idx+1,i+1))
         | 
| 262 | 
            +
                    torch.save(self.G.state_dict(), G_path)
         | 
| 263 | 
            +
                    torch.save(self.D.state_dict(), D_path)
         | 
| 264 | 
            +
             | 
| 265 | 
            +
                def reset_grad(self):
         | 
| 266 | 
            +
                    """Reset the gradient buffers."""
         | 
| 267 | 
            +
                    self.g_optimizer.zero_grad()
         | 
| 268 | 
            +
                    self.d_optimizer.zero_grad()
         | 
| 269 | 
            +
             | 
| 270 | 
            +
                def train(self, config):
         | 
| 271 | 
            +
                    ''' Training Script starts from here'''
         | 
| 272 | 
            +
                    if self.use_wandb:
         | 
| 273 | 
            +
                        mode = 'online' if self.online else 'offline'
         | 
| 274 | 
            +
                    else:
         | 
| 275 | 
            +
                        mode = 'disabled'
         | 
| 276 | 
            +
                    kwargs = {'name': self.exp_name, 'project': 'druggen', 'config': config,
         | 
| 277 | 
            +
                            'settings': wandb.Settings(_disable_stats=True), 'reinit': True, 'mode': mode, 'save_code': True}
         | 
| 278 | 
            +
                    wandb.init(**kwargs)
         | 
| 279 | 
            +
             | 
| 280 | 
            +
                    wandb.save(os.path.join(self.model_save_dir, self.arguments, "G_modules.txt"))
         | 
| 281 | 
            +
                    wandb.save(os.path.join(self.model_save_dir, self.arguments, "D_modules.txt"))
         | 
| 282 | 
            +
             | 
| 283 | 
            +
                    self.model_directory = os.path.join(self.model_save_dir, self.arguments)
         | 
| 284 | 
            +
                    self.sample_directory = os.path.join(self.sample_dir, self.arguments)
         | 
| 285 | 
            +
                    self.log_path = os.path.join(self.log_dir, "{}.txt".format(self.arguments))
         | 
| 286 | 
            +
                    if not os.path.exists(self.model_directory):
         | 
| 287 | 
            +
                        os.makedirs(self.model_directory)
         | 
| 288 | 
            +
                    if not os.path.exists(self.sample_directory):
         | 
| 289 | 
            +
                        os.makedirs(self.sample_directory)
         | 
| 290 | 
            +
             | 
| 291 | 
            +
                    # smiles data for metrics calculation.
         | 
| 292 | 
            +
                    drug_smiles = [line for line in open(self.drug_raw_file, 'r').read().splitlines()]
         | 
| 293 | 
            +
                    drug_mols = [Chem.MolFromSmiles(smi) for smi in drug_smiles]
         | 
| 294 | 
            +
                    drug_vecs = [AllChem.GetMorganFingerprintAsBitVect(x, 2, nBits=1024) for x in drug_mols if x is not None]
         | 
| 295 | 
            +
             | 
| 296 | 
            +
                    if self.resume:
         | 
| 297 | 
            +
                        self.restore_model(self.resume_epoch, self.resume_iter, self.resume_directory)
         | 
| 298 | 
            +
             | 
| 299 | 
            +
                    # Start training.
         | 
| 300 | 
            +
                    print('Start training...')
         | 
| 301 | 
            +
                    self.start_time = time.time()
         | 
| 302 | 
            +
                    for idx in range(self.epoch):
         | 
| 303 | 
            +
                        # =================================================================================== #
         | 
| 304 | 
            +
                        #                             1. Preprocess input data                                #
         | 
| 305 | 
            +
                        # =================================================================================== #
         | 
| 306 | 
            +
                        # Load the data
         | 
| 307 | 
            +
                        dataloader_iterator = iter(self.drugs_loader)
         | 
| 308 | 
            +
             | 
| 309 | 
            +
                        wandb.log({"epoch": idx})
         | 
| 310 | 
            +
             | 
| 311 | 
            +
                        for i, data in enumerate(self.loader):
         | 
| 312 | 
            +
                            try:
         | 
| 313 | 
            +
                                drugs = next(dataloader_iterator)
         | 
| 314 | 
            +
                            except StopIteration:
         | 
| 315 | 
            +
                                dataloader_iterator = iter(self.drugs_loader)
         | 
| 316 | 
            +
                                drugs = next(dataloader_iterator)
         | 
| 317 | 
            +
             | 
| 318 | 
            +
                            wandb.log({"iter": i})
         | 
| 319 | 
            +
             | 
| 320 | 
            +
                            # Preprocess both dataset
         | 
| 321 | 
            +
                            real_graphs, a_tensor, x_tensor = load_molecules(
         | 
| 322 | 
            +
                                data=data,
         | 
| 323 | 
            +
                                batch_size=self.batch_size,
         | 
| 324 | 
            +
                                device=self.device,
         | 
| 325 | 
            +
                                b_dim=self.b_dim,
         | 
| 326 | 
            +
                                m_dim=self.m_dim,
         | 
| 327 | 
            +
                            )
         | 
| 328 | 
            +
             | 
| 329 | 
            +
                            drug_graphs, drugs_a_tensor, drugs_x_tensor = load_molecules(
         | 
| 330 | 
            +
                                data=drugs,
         | 
| 331 | 
            +
                                batch_size=self.batch_size,
         | 
| 332 | 
            +
                                device=self.device,
         | 
| 333 | 
            +
                                b_dim=self.b_dim,
         | 
| 334 | 
            +
                                m_dim=self.m_dim,
         | 
| 335 | 
            +
                            )
         | 
| 336 | 
            +
             | 
| 337 | 
            +
                            # Training configuration.
         | 
| 338 | 
            +
                            GEN_node = x_tensor             # Generator input node features (annotation matrix of real molecules)
         | 
| 339 | 
            +
                            GEN_edge = a_tensor             # Generator input edge features (adjacency matrix of real molecules)
         | 
| 340 | 
            +
                            if self.submodel == "DrugGEN":
         | 
| 341 | 
            +
                                DISC_node = drugs_x_tensor  # Discriminator input node features (annotation matrix of drug molecules)
         | 
| 342 | 
            +
                                DISC_edge = drugs_a_tensor  # Discriminator input edge features (adjacency matrix of drug molecules)
         | 
| 343 | 
            +
                            elif self.submodel == "NoTarget":
         | 
| 344 | 
            +
                                DISC_node = x_tensor      # Discriminator input node features (annotation matrix of real molecules)
         | 
| 345 | 
            +
                                DISC_edge = a_tensor      # Discriminator input edge features (adjacency matrix of real molecules)
         | 
| 346 | 
            +
             | 
| 347 | 
            +
                            # =================================================================================== #
         | 
| 348 | 
            +
                            #                                     2. Train the GAN                                #
         | 
| 349 | 
            +
                            # =================================================================================== #
         | 
| 350 | 
            +
             | 
| 351 | 
            +
                            loss = {}
         | 
| 352 | 
            +
                            self.reset_grad()
         | 
| 353 | 
            +
                            # Compute discriminator loss.
         | 
| 354 | 
            +
                            node, edge, d_loss = discriminator_loss(self.G,
         | 
| 355 | 
            +
                                                        self.D,
         | 
| 356 | 
            +
                                                        DISC_edge,
         | 
| 357 | 
            +
                                                        DISC_node,
         | 
| 358 | 
            +
                                                        GEN_edge,
         | 
| 359 | 
            +
                                                        GEN_node,
         | 
| 360 | 
            +
                                                        self.batch_size,
         | 
| 361 | 
            +
                                                        self.device,
         | 
| 362 | 
            +
                                                        self.lambda_gp)
         | 
| 363 | 
            +
                            d_total = d_loss
         | 
| 364 | 
            +
                            wandb.log({"d_loss": d_total.item()})
         | 
| 365 | 
            +
             | 
| 366 | 
            +
                            loss["d_total"] = d_total.item()
         | 
| 367 | 
            +
                            d_total.backward()
         | 
| 368 | 
            +
                            self.d_optimizer.step()
         | 
| 369 | 
            +
             | 
| 370 | 
            +
                            self.reset_grad()
         | 
| 371 | 
            +
             | 
| 372 | 
            +
                            # Compute generator loss.
         | 
| 373 | 
            +
                            generator_output = generator_loss(self.G,
         | 
| 374 | 
            +
                                                                self.D,
         | 
| 375 | 
            +
                                                                GEN_edge,
         | 
| 376 | 
            +
                                                                GEN_node,
         | 
| 377 | 
            +
                                                                self.batch_size)
         | 
| 378 | 
            +
                            g_loss, node, edge, node_sample, edge_sample = generator_output
         | 
| 379 | 
            +
                            g_total = g_loss
         | 
| 380 | 
            +
                            wandb.log({"g_loss": g_total.item()})
         | 
| 381 | 
            +
             | 
| 382 | 
            +
                            loss["g_total"] = g_total.item()
         | 
| 383 | 
            +
                            g_total.backward()
         | 
| 384 | 
            +
                            self.g_optimizer.step()
         | 
| 385 | 
            +
             | 
| 386 | 
            +
                            # Logging.
         | 
| 387 | 
            +
                            if (i+1) % self.log_step == 0:
         | 
| 388 | 
            +
                                logging(self.log_path, self.start_time, i, idx, loss, self.sample_directory,
         | 
| 389 | 
            +
                                        drug_smiles,edge_sample, node_sample, self.dataset.matrices2mol,
         | 
| 390 | 
            +
                                        self.dataset_name, a_tensor, x_tensor, drug_vecs)
         | 
| 391 | 
            +
             | 
| 392 | 
            +
                                mol_sample(self.sample_directory, edge_sample.detach(), node_sample.detach(),
         | 
| 393 | 
            +
                                           idx, i, self.dataset.matrices2mol, self.dataset_name)
         | 
| 394 | 
            +
                                print("samples saved at epoch {} and iteration {}".format(idx,i))
         | 
| 395 | 
            +
             | 
| 396 | 
            +
                                self.save_model(self.model_directory, idx, i)
         | 
| 397 | 
            +
                                print("model saved at epoch {} and iteration {}".format(idx,i))
         | 
| 398 | 
            +
             | 
| 399 | 
            +
             | 
| 400 | 
            +
            if __name__ == '__main__':
         | 
| 401 | 
            +
                parser = argparse.ArgumentParser()
         | 
| 402 | 
            +
             | 
| 403 | 
            +
                # Data configuration.
         | 
| 404 | 
            +
                parser.add_argument('--raw_file', type=str, required=True)
         | 
| 405 | 
            +
                parser.add_argument('--drug_raw_file', type=str, required=False, help='Required for DrugGEN model, optional for NoTarget')
         | 
| 406 | 
            +
                parser.add_argument('--drug_data_dir', type=str, default='data')
         | 
| 407 | 
            +
                parser.add_argument('--mol_data_dir', type=str, default='data')
         | 
| 408 | 
            +
    parser.add_argument('--features', action='store_true', help='use additional node features (otherwise only atom types are used)')
         | 
| 409 | 
            +
             | 
| 410 | 
            +
                # Model configuration.
         | 
| 411 | 
            +
    parser.add_argument('--submodel', type=str, default="DrugGEN", help="Choose model subtype: DrugGEN, NoTarget", choices=['DrugGEN', 'NoTarget'])
         | 
| 412 | 
            +
                parser.add_argument('--act', type=str, default="relu", help="Activation function for the model.", choices=['relu', 'tanh', 'leaky', 'sigmoid'])
         | 
| 413 | 
            +
    parser.add_argument('--max_atom', type=int, default=45, help='Maximum number of atoms in a generated molecule.')
         | 
| 414 | 
            +
                parser.add_argument('--dim', type=int, default=128, help='Dimension of the Transformer Encoder model for the GAN.')
         | 
| 415 | 
            +
                parser.add_argument('--depth', type=int, default=1, help='Depth of the Transformer model from the GAN.')
         | 
| 416 | 
            +
                parser.add_argument('--ddepth', type=int, default=1, help='Depth of the Transformer model from the discriminator.')
         | 
| 417 | 
            +
                parser.add_argument('--heads', type=int, default=8, help='Number of heads for the MultiHeadAttention module from the GAN.')
         | 
| 418 | 
            +
                parser.add_argument('--mlp_ratio', type=int, default=3, help='MLP ratio for the Transformer.')
         | 
| 419 | 
            +
                parser.add_argument('--dropout', type=float, default=0., help='dropout rate')
         | 
| 420 | 
            +
                parser.add_argument('--ddropout', type=float, default=0., help='dropout rate for the discriminator')
         | 
| 421 | 
            +
                parser.add_argument('--lambda_gp', type=float, default=10, help='Gradient penalty lambda multiplier for the GAN.')
         | 
| 422 | 
            +
             | 
| 423 | 
            +
                # Training configuration.
         | 
| 424 | 
            +
                parser.add_argument('--batch_size', type=int, default=128, help='Batch size for the training.')
         | 
| 425 | 
            +
                parser.add_argument('--epoch', type=int, default=10, help='Epoch number for Training.')
         | 
| 426 | 
            +
                parser.add_argument('--g_lr', type=float, default=0.00001, help='learning rate for G')
         | 
| 427 | 
            +
                parser.add_argument('--d_lr', type=float, default=0.00001, help='learning rate for D')
         | 
| 428 | 
            +
                parser.add_argument('--beta1', type=float, default=0.9, help='beta1 for Adam optimizer')
         | 
| 429 | 
            +
                parser.add_argument('--beta2', type=float, default=0.999, help='beta2 for Adam optimizer')
         | 
| 430 | 
            +
                parser.add_argument('--log_dir', type=str, default='experiments/logs')
         | 
| 431 | 
            +
                parser.add_argument('--sample_dir', type=str, default='experiments/samples')
         | 
| 432 | 
            +
                parser.add_argument('--model_save_dir', type=str, default='experiments/models')
         | 
| 433 | 
            +
                parser.add_argument('--log_sample_step', type=int, default=1000, help='step size for sampling during training')
         | 
| 434 | 
            +
             | 
| 435 | 
            +
                # Resume training.
         | 
| 436 | 
            +
                parser.add_argument('--resume', type=bool, default=False, help='resume training')
         | 
| 437 | 
            +
                parser.add_argument('--resume_epoch', type=int, default=None, help='resume training from this epoch')
         | 
| 438 | 
            +
                parser.add_argument('--resume_iter', type=int, default=None, help='resume training from this step')
         | 
| 439 | 
            +
                parser.add_argument('--resume_directory', type=str, default=None, help='load pretrained weights from this directory')
         | 
| 440 | 
            +
             | 
| 441 | 
            +
                # Seed configuration.
         | 
| 442 | 
            +
                parser.add_argument('--set_seed', action='store_true', help='set seed for reproducibility')
         | 
| 443 | 
            +
                parser.add_argument('--seed', type=int, default=1, help='seed for reproducibility')
         | 
| 444 | 
            +
             | 
| 445 | 
            +
                # wandb configuration.
         | 
| 446 | 
            +
                parser.add_argument('--use_wandb', action='store_true', help='use wandb for logging')
         | 
| 447 | 
            +
                parser.add_argument('--online', action='store_true', help='use wandb online')
         | 
| 448 | 
            +
                parser.add_argument('--exp_name', type=str, default='druggen', help='experiment name')
         | 
| 449 | 
            +
                parser.add_argument('--parallel', action='store_true', help='Parallelize training')
         | 
| 450 | 
            +
             | 
| 451 | 
            +
                config = parser.parse_args()
         | 
| 452 | 
            +
             | 
| 453 | 
            +
                # Check if drug_raw_file is provided when using DrugGEN model
         | 
| 454 | 
            +
                if config.submodel == "DrugGEN" and not config.drug_raw_file:
         | 
| 455 | 
            +
                    parser.error("--drug_raw_file is required when using DrugGEN model")
         | 
| 456 | 
            +
             | 
| 457 | 
            +
                # If using NoTarget model and drug_raw_file is not provided, use a dummy file
         | 
| 458 | 
            +
                if config.submodel == "NoTarget" and not config.drug_raw_file:
         | 
| 459 | 
            +
        config.drug_raw_file = "data/akt_train.smi"  # Reference file (AKT) for the NoTarget model; needed to build the encoders/decoders but not used as a training target.
         | 
| 460 | 
            +
             | 
| 461 | 
            +
                trainer = Train(config)
         | 
| 462 | 
            +
                trainer.train(config)
         | 
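For reference, illustrative ways to launch the training script with the CLI options defined above; the dataset path below is a placeholder, while data/akt_train.smi is the fallback file the script itself references:

    # Hypothetical invocations (adjust paths to your data).
    # Target-aware run (both SMILES files required):
    #   python train.py --raw_file data/your_dataset.smi --drug_raw_file data/akt_train.smi \
    #       --submodel DrugGEN --max_atom 45 --batch_size 128 --epoch 10 --set_seed --seed 1
    # Target-agnostic run (drug file omitted; the script falls back to data/akt_train.smi):
    #   python train.py --raw_file data/your_dataset.smi --submodel NoTarget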
