Spaces:
Runtime error
Runtime error
Pocket-conditioned generation
Browse files- app.py +331 -61
- examples/3hz1_fragments.sdf +54 -0
- examples/3hz1_protein.pdb +0 -0
- examples/5ou2_fragments.sdf +56 -0
- examples/5ou2_protein.pdb +0 -0
- output.py +184 -9
- src/datasets.py +14 -4
- src/generation.py +83 -4
- src/lightning.py +7 -4
app.py
CHANGED
|
@@ -1,4 +1,5 @@
|
|
| 1 |
import argparse
|
|
|
|
| 2 |
|
| 3 |
import gradio as gr
|
| 4 |
import numpy as np
|
|
@@ -9,10 +10,12 @@ import output
|
|
| 9 |
|
| 10 |
from rdkit import Chem
|
| 11 |
from src import const
|
| 12 |
-
from src.datasets import get_dataloader, collate_with_fragment_edges, parse_molecule
|
| 13 |
from src.lightning import DDPM
|
| 14 |
from src.linker_size_lightning import SizeClassifier
|
| 15 |
-
from src.generation import N_SAMPLES, generate_linkers, try_to_convert_to_sdf
|
|
|
|
|
|
|
| 16 |
|
| 17 |
MODELS_METADATA = {
|
| 18 |
'geom_difflinker': {
|
|
@@ -85,65 +88,167 @@ def read_molecule(path):
|
|
| 85 |
raise Exception('Unknown file extension')
|
| 86 |
|
| 87 |
|
| 88 |
-
def
|
| 89 |
-
if
|
| 90 |
-
|
| 91 |
-
if isinstance(input_file, str):
|
| 92 |
-
path = input_file
|
| 93 |
else:
|
| 94 |
-
path =
|
| 95 |
extension = path.split('.')[-1]
|
| 96 |
-
|
|
|
|
| 97 |
msg = output.INVALID_FORMAT_MSG.format(extension=extension)
|
| 98 |
-
return
|
| 99 |
-
output.IFRAME_TEMPLATE.format(html=msg),
|
| 100 |
-
gr.Radio.update(visible=False),
|
| 101 |
-
None,
|
| 102 |
-
]
|
| 103 |
|
| 104 |
try:
|
| 105 |
-
|
| 106 |
except Exception as e:
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 112 |
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 119 |
|
| 120 |
|
| 121 |
def draw_sample(idx, out_files):
|
|
|
|
|
|
|
| 122 |
if isinstance(idx, str):
|
| 123 |
idx = int(idx.strip().split(' ')[-1]) - 1
|
| 124 |
|
| 125 |
-
in_file = out_files[
|
| 126 |
in_sdf = in_file if isinstance(in_file, str) else in_file.name
|
|
|
|
|
|
|
| 127 |
|
| 128 |
-
|
| 129 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 130 |
|
| 131 |
-
|
|
|
|
| 132 |
generated_molecule_content = read_molecule_content(out_sdf)
|
| 133 |
-
|
| 134 |
-
fragments_fmt = in_sdf.split('.')[-1]
|
| 135 |
molecule_fmt = out_sdf.split('.')[-1]
|
| 136 |
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 143 |
return output.IFRAME_TEMPLATE.format(html=html)
|
| 144 |
|
| 145 |
|
| 146 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 147 |
# Parsing selected atoms (javascript output)
|
| 148 |
selected_atoms = selected_atoms.strip()
|
| 149 |
if selected_atoms == '':
|
|
@@ -157,9 +262,6 @@ def generate(input_file, n_steps, n_atoms, radio_samples, selected_atoms):
|
|
| 157 |
else:
|
| 158 |
selected_model_name = 'geom_difflinker_given_anchors'
|
| 159 |
|
| 160 |
-
if input_file is None:
|
| 161 |
-
return [None, None, None, None]
|
| 162 |
-
|
| 163 |
print(f'Start generating with model {selected_model_name}, selected_atoms:', selected_atoms)
|
| 164 |
ddpm = diffusion_models[selected_model_name]
|
| 165 |
path = input_file.name
|
|
@@ -170,20 +272,25 @@ def generate(input_file, n_steps, n_atoms, radio_samples, selected_atoms):
|
|
| 170 |
|
| 171 |
try:
|
| 172 |
molecule = read_molecule(path)
|
| 173 |
-
|
|
|
|
|
|
|
|
|
|
| 174 |
name = '.'.join(path.split('/')[-1].split('.')[:-1])
|
| 175 |
inp_sdf = f'results/input_{name}.sdf'
|
| 176 |
except Exception as e:
|
|
|
|
| 177 |
error = f'Could not read the molecule: {e}'
|
| 178 |
msg = output.ERROR_FORMAT_MSG.format(message=error)
|
| 179 |
return [output.IFRAME_TEMPLATE.format(html=msg), None, None, None]
|
| 180 |
|
| 181 |
-
if molecule.GetNumAtoms() >
|
| 182 |
-
error = f'Too large molecule: upper limit is
|
| 183 |
msg = output.ERROR_FORMAT_MSG.format(message=error)
|
| 184 |
return [output.IFRAME_TEMPLATE.format(html=msg), None, None, None]
|
| 185 |
|
| 186 |
with Chem.SDWriter(inp_sdf) as w:
|
|
|
|
| 187 |
w.write(molecule)
|
| 188 |
|
| 189 |
positions, one_hot, charges = parse_molecule(molecule, is_geom=True)
|
|
@@ -227,14 +334,152 @@ def generate(input_file, n_steps, n_atoms, radio_samples, selected_atoms):
|
|
| 227 |
|
| 228 |
for data in dataloader:
|
| 229 |
try:
|
| 230 |
-
generate_linkers(ddpm=ddpm, data=data, sample_fn=sample_fn, name=name)
|
| 231 |
except Exception as e:
|
|
|
|
| 232 |
error = f'Caught exception while generating linkers: {e}'
|
| 233 |
msg = output.ERROR_FORMAT_MSG.format(message=error)
|
| 234 |
return [output.IFRAME_TEMPLATE.format(html=msg), None, None, None]
|
| 235 |
|
| 236 |
out_files = try_to_convert_to_sdf(name)
|
| 237 |
out_files = [inp_sdf] + out_files
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 238 |
|
| 239 |
return [
|
| 240 |
draw_sample(radio_samples, out_files),
|
|
@@ -260,19 +505,34 @@ with demo:
|
|
| 260 |
with gr.Box():
|
| 261 |
with gr.Row():
|
| 262 |
with gr.Column():
|
| 263 |
-
gr.Markdown('## Input
|
| 264 |
gr.Markdown('Upload the file with 3D-coordinates of the input fragments in .pdb, .mol2 or .sdf format:')
|
| 265 |
-
|
| 266 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 267 |
n_atoms = gr.Slider(
|
| 268 |
minimum=0, maximum=20,
|
| 269 |
label="Linker Size: DiffLinker will predict it if set to 0",
|
| 270 |
step=1
|
| 271 |
)
|
| 272 |
examples = gr.Dataset(
|
| 273 |
-
components=[gr.File(visible=False)],
|
| 274 |
-
samples=[
|
| 275 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 276 |
)
|
| 277 |
|
| 278 |
button = gr.Button('Generate Linker!')
|
|
@@ -294,24 +554,34 @@ with demo:
|
|
| 294 |
)
|
| 295 |
visualization = gr.HTML()
|
| 296 |
|
| 297 |
-
|
| 298 |
fn=show_input,
|
| 299 |
-
inputs=[
|
| 300 |
outputs=[visualization, samples, hidden],
|
| 301 |
)
|
| 302 |
-
|
| 303 |
-
fn=
|
| 304 |
-
inputs=[],
|
| 305 |
-
outputs=[
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 306 |
)
|
| 307 |
examples.click(
|
| 308 |
-
fn=
|
| 309 |
inputs=[examples],
|
| 310 |
-
outputs=[
|
| 311 |
)
|
| 312 |
button.click(
|
| 313 |
fn=generate,
|
| 314 |
-
inputs=[
|
| 315 |
outputs=[visualization, output_files, samples, hidden],
|
| 316 |
_js=output.RETURN_SELECTION_JS,
|
| 317 |
)
|
|
|
|
| 1 |
import argparse
|
| 2 |
+
import shutil
|
| 3 |
|
| 4 |
import gradio as gr
|
| 5 |
import numpy as np
|
|
|
|
| 10 |
|
| 11 |
from rdkit import Chem
|
| 12 |
from src import const
|
| 13 |
+
from src.datasets import get_dataloader, collate_with_fragment_edges, parse_molecule, MOADDataset
|
| 14 |
from src.lightning import DDPM
|
| 15 |
from src.linker_size_lightning import SizeClassifier
|
| 16 |
+
from src.generation import N_SAMPLES, generate_linkers, try_to_convert_to_sdf, get_pocket
|
| 17 |
+
from zipfile import ZipFile
|
| 18 |
+
|
| 19 |
|
| 20 |
MODELS_METADATA = {
|
| 21 |
'geom_difflinker': {
|
|
|
|
| 88 |
raise Exception('Unknown file extension')
|
| 89 |
|
| 90 |
|
| 91 |
+
def read_molecule_file(in_file, allowed_extentions):
|
| 92 |
+
if isinstance(in_file, str):
|
| 93 |
+
path = in_file
|
|
|
|
|
|
|
| 94 |
else:
|
| 95 |
+
path = in_file.name
|
| 96 |
extension = path.split('.')[-1]
|
| 97 |
+
|
| 98 |
+
if extension not in allowed_extentions:
|
| 99 |
msg = output.INVALID_FORMAT_MSG.format(extension=extension)
|
| 100 |
+
return None, None, msg
|
|
|
|
|
|
|
|
|
|
|
|
|
| 101 |
|
| 102 |
try:
|
| 103 |
+
mol = read_molecule(path)
|
| 104 |
except Exception as e:
|
| 105 |
+
e = str(e).replace('\'', '')
|
| 106 |
+
msg = output.ERROR_FORMAT_MSG.format(message=e)
|
| 107 |
+
return None, None, msg
|
| 108 |
+
|
| 109 |
+
if extension == 'pdb':
|
| 110 |
+
content = Chem.MolToPDBBlock(mol)
|
| 111 |
+
elif extension in ['mol', 'mol2', 'sdf']:
|
| 112 |
+
content = Chem.MolToMolBlock(mol, kekulize=False)
|
| 113 |
+
extension = 'mol'
|
| 114 |
+
else:
|
| 115 |
+
raise NotImplementedError
|
| 116 |
|
| 117 |
+
return content, extension, None
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
def show_input(in_fragments, in_protein):
|
| 121 |
+
vis = ''
|
| 122 |
+
if in_fragments is not None and in_protein is None:
|
| 123 |
+
vis = show_fragments(in_fragments)
|
| 124 |
+
elif in_fragments is None and in_protein is not None:
|
| 125 |
+
vis = show_target(in_protein)
|
| 126 |
+
elif in_fragments is not None and in_protein is not None:
|
| 127 |
+
vis = show_fragments_and_target(in_fragments, in_protein)
|
| 128 |
+
return [vis, gr.Radio.update(visible=False), None]
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
def show_fragments(in_fragments):
|
| 132 |
+
molecule, extension, html = read_molecule_file(in_fragments, allowed_extentions=['sdf', 'pdb', 'mol', 'mol2'])
|
| 133 |
+
if molecule is not None:
|
| 134 |
+
html = output.FRAGMENTS_RENDERING_TEMPLATE.format(molecule=molecule, fmt=extension)
|
| 135 |
+
|
| 136 |
+
return output.IFRAME_TEMPLATE.format(html=html)
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
def show_target(in_protein):
|
| 140 |
+
molecule, extension, html = read_molecule_file(in_protein, allowed_extentions=['pdb'])
|
| 141 |
+
if molecule is not None:
|
| 142 |
+
html = output.TARGET_RENDERING_TEMPLATE.format(molecule=molecule, fmt=extension)
|
| 143 |
+
|
| 144 |
+
return output.IFRAME_TEMPLATE.format(html=html)
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
def show_fragments_and_target(in_fragments, in_protein):
|
| 148 |
+
fragments_molecule, fragments_extension, msg = read_molecule_file(in_fragments, ['sdf', 'pdb', 'mol', 'mol2'])
|
| 149 |
+
if fragments_molecule is None:
|
| 150 |
+
return output.IFRAME_TEMPLATE.format(html=msg)
|
| 151 |
+
|
| 152 |
+
target_molecule, target_extension, msg = read_molecule_file(in_protein, allowed_extentions=['pdb'])
|
| 153 |
+
if fragments_molecule is None:
|
| 154 |
+
return output.IFRAME_TEMPLATE.format(html=msg)
|
| 155 |
+
|
| 156 |
+
html = output.FRAGMENTS_AND_TARGET_RENDERING_TEMPLATE.format(
|
| 157 |
+
molecule=fragments_molecule,
|
| 158 |
+
fmt=fragments_extension,
|
| 159 |
+
target=target_molecule,
|
| 160 |
+
target_fmt=target_extension,
|
| 161 |
+
)
|
| 162 |
+
|
| 163 |
+
return output.IFRAME_TEMPLATE.format(html=html)
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
def clear_fragments_input(in_protein):
|
| 167 |
+
vis = ''
|
| 168 |
+
if in_protein is not None:
|
| 169 |
+
vis = show_target(in_protein)
|
| 170 |
+
return [None, vis, gr.Radio.update(visible=False), None]
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
def clear_protein_input(in_fragments):
|
| 174 |
+
vis = ''
|
| 175 |
+
if in_fragments is not None:
|
| 176 |
+
vis = show_fragments(in_fragments)
|
| 177 |
+
return [None, vis, gr.Radio.update(visible=False), None]
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
def click_on_example(example):
|
| 181 |
+
print('Clicked:', example)
|
| 182 |
+
fragment_fname, target_fname = example
|
| 183 |
+
fragment_path = f'examples/{fragment_fname}' if fragment_fname != '' else None
|
| 184 |
+
target_path = f'examples/{target_fname}' if target_fname != '' else None
|
| 185 |
+
return [fragment_path, target_path, 50, 0] + show_input(fragment_path, target_path)
|
| 186 |
|
| 187 |
|
| 188 |
def draw_sample(idx, out_files):
|
| 189 |
+
with_protein = (len(out_files) == N_SAMPLES + 3)
|
| 190 |
+
|
| 191 |
if isinstance(idx, str):
|
| 192 |
idx = int(idx.strip().split(' ')[-1]) - 1
|
| 193 |
|
| 194 |
+
in_file = out_files[1]
|
| 195 |
in_sdf = in_file if isinstance(in_file, str) else in_file.name
|
| 196 |
+
input_fragments_content = read_molecule_content(in_sdf)
|
| 197 |
+
fragments_fmt = in_sdf.split('.')[-1]
|
| 198 |
|
| 199 |
+
offset = 2
|
| 200 |
+
input_target_content = None
|
| 201 |
+
target_fmt = None
|
| 202 |
+
if with_protein:
|
| 203 |
+
offset += 1
|
| 204 |
+
in_pdb = out_files[2] if isinstance(out_files[2], str) else out_files[2].name
|
| 205 |
+
input_target_content = read_molecule_content(in_pdb)
|
| 206 |
+
target_fmt = in_pdb.split('.')[-1]
|
| 207 |
|
| 208 |
+
out_file = out_files[idx + offset]
|
| 209 |
+
out_sdf = out_file if isinstance(out_file, str) else out_file.name
|
| 210 |
generated_molecule_content = read_molecule_content(out_sdf)
|
|
|
|
|
|
|
| 211 |
molecule_fmt = out_sdf.split('.')[-1]
|
| 212 |
|
| 213 |
+
if with_protein:
|
| 214 |
+
html = output.SAMPLES_WITH_TARGET_RENDERING_TEMPLATE.format(
|
| 215 |
+
fragments=input_fragments_content,
|
| 216 |
+
fragments_fmt=fragments_fmt,
|
| 217 |
+
molecule=generated_molecule_content,
|
| 218 |
+
molecule_fmt=molecule_fmt,
|
| 219 |
+
target=input_target_content,
|
| 220 |
+
target_fmt=target_fmt,
|
| 221 |
+
)
|
| 222 |
+
else:
|
| 223 |
+
html = output.SAMPLES_RENDERING_TEMPLATE.format(
|
| 224 |
+
fragments=input_fragments_content,
|
| 225 |
+
fragments_fmt=fragments_fmt,
|
| 226 |
+
molecule=generated_molecule_content,
|
| 227 |
+
molecule_fmt=molecule_fmt,
|
| 228 |
+
)
|
| 229 |
return output.IFRAME_TEMPLATE.format(html=html)
|
| 230 |
|
| 231 |
|
| 232 |
+
def compress(output_fnames, name):
|
| 233 |
+
archive_path = f'results/all_files_{name}.zip'
|
| 234 |
+
with ZipFile(archive_path, 'w') as archive:
|
| 235 |
+
for fname in output_fnames:
|
| 236 |
+
archive.write(fname)
|
| 237 |
+
|
| 238 |
+
return archive_path
|
| 239 |
+
|
| 240 |
+
|
| 241 |
+
def generate(in_fragments, in_protein, n_steps, n_atoms, radio_samples, selected_atoms):
|
| 242 |
+
if in_fragments is None:
|
| 243 |
+
return [None, None, None, None]
|
| 244 |
+
|
| 245 |
+
if in_protein is None:
|
| 246 |
+
return generate_without_pocket(in_fragments, n_steps, n_atoms, radio_samples, selected_atoms)
|
| 247 |
+
else:
|
| 248 |
+
return generate_with_pocket(in_fragments, in_protein, n_steps, n_atoms, radio_samples, selected_atoms)
|
| 249 |
+
|
| 250 |
+
|
| 251 |
+
def generate_without_pocket(input_file, n_steps, n_atoms, radio_samples, selected_atoms):
|
| 252 |
# Parsing selected atoms (javascript output)
|
| 253 |
selected_atoms = selected_atoms.strip()
|
| 254 |
if selected_atoms == '':
|
|
|
|
| 262 |
else:
|
| 263 |
selected_model_name = 'geom_difflinker_given_anchors'
|
| 264 |
|
|
|
|
|
|
|
|
|
|
| 265 |
print(f'Start generating with model {selected_model_name}, selected_atoms:', selected_atoms)
|
| 266 |
ddpm = diffusion_models[selected_model_name]
|
| 267 |
path = input_file.name
|
|
|
|
| 272 |
|
| 273 |
try:
|
| 274 |
molecule = read_molecule(path)
|
| 275 |
+
try:
|
| 276 |
+
molecule = Chem.RemoveAllHs(molecule)
|
| 277 |
+
except:
|
| 278 |
+
pass
|
| 279 |
name = '.'.join(path.split('/')[-1].split('.')[:-1])
|
| 280 |
inp_sdf = f'results/input_{name}.sdf'
|
| 281 |
except Exception as e:
|
| 282 |
+
e = str(e).replace('\'', '')
|
| 283 |
error = f'Could not read the molecule: {e}'
|
| 284 |
msg = output.ERROR_FORMAT_MSG.format(message=error)
|
| 285 |
return [output.IFRAME_TEMPLATE.format(html=msg), None, None, None]
|
| 286 |
|
| 287 |
+
if molecule.GetNumAtoms() > 100:
|
| 288 |
+
error = f'Too large molecule: upper limit is 100 heavy atoms'
|
| 289 |
msg = output.ERROR_FORMAT_MSG.format(message=error)
|
| 290 |
return [output.IFRAME_TEMPLATE.format(html=msg), None, None, None]
|
| 291 |
|
| 292 |
with Chem.SDWriter(inp_sdf) as w:
|
| 293 |
+
w.SetKekulize(False)
|
| 294 |
w.write(molecule)
|
| 295 |
|
| 296 |
positions, one_hot, charges = parse_molecule(molecule, is_geom=True)
|
|
|
|
| 334 |
|
| 335 |
for data in dataloader:
|
| 336 |
try:
|
| 337 |
+
generate_linkers(ddpm=ddpm, data=data, sample_fn=sample_fn, name=name, with_pocket=False)
|
| 338 |
except Exception as e:
|
| 339 |
+
e = str(e).replace('\'', '')
|
| 340 |
error = f'Caught exception while generating linkers: {e}'
|
| 341 |
msg = output.ERROR_FORMAT_MSG.format(message=error)
|
| 342 |
return [output.IFRAME_TEMPLATE.format(html=msg), None, None, None]
|
| 343 |
|
| 344 |
out_files = try_to_convert_to_sdf(name)
|
| 345 |
out_files = [inp_sdf] + out_files
|
| 346 |
+
out_files = [compress(out_files, name=name)] + out_files
|
| 347 |
+
|
| 348 |
+
return [
|
| 349 |
+
draw_sample(radio_samples, out_files),
|
| 350 |
+
out_files,
|
| 351 |
+
gr.Radio.update(visible=True),
|
| 352 |
+
None
|
| 353 |
+
]
|
| 354 |
+
|
| 355 |
+
|
| 356 |
+
def generate_with_pocket(in_fragments, in_protein, n_steps, n_atoms, radio_samples, selected_atoms):
|
| 357 |
+
# Parsing selected atoms (javascript output)
|
| 358 |
+
selected_atoms = selected_atoms.strip()
|
| 359 |
+
if selected_atoms == '':
|
| 360 |
+
selected_atoms = []
|
| 361 |
+
else:
|
| 362 |
+
selected_atoms = list(map(int, selected_atoms.split(',')))
|
| 363 |
+
|
| 364 |
+
# Selecting model
|
| 365 |
+
if len(selected_atoms) == 0:
|
| 366 |
+
selected_model_name = 'pockets_difflinker'
|
| 367 |
+
else:
|
| 368 |
+
selected_model_name = 'pockets_difflinker_given_anchors'
|
| 369 |
+
|
| 370 |
+
print(f'Start generating with model {selected_model_name}, selected_atoms:', selected_atoms)
|
| 371 |
+
ddpm = diffusion_models[selected_model_name]
|
| 372 |
+
|
| 373 |
+
fragments_path = in_fragments.name
|
| 374 |
+
fragments_extension = fragments_path.split('.')[-1]
|
| 375 |
+
if fragments_extension not in ['sdf', 'pdb', 'mol', 'mol2']:
|
| 376 |
+
msg = output.INVALID_FORMAT_MSG.format(extension=fragments_extension)
|
| 377 |
+
return [output.IFRAME_TEMPLATE.format(html=msg), None, None, None]
|
| 378 |
+
|
| 379 |
+
protein_path = in_protein.name
|
| 380 |
+
protein_extension = protein_path.split('.')[-1]
|
| 381 |
+
if protein_extension not in ['pdb']:
|
| 382 |
+
msg = output.INVALID_FORMAT_MSG.format(extension=protein_extension)
|
| 383 |
+
return [output.IFRAME_TEMPLATE.format(html=msg), None, None, None]
|
| 384 |
+
|
| 385 |
+
try:
|
| 386 |
+
fragments_mol = read_molecule(fragments_path)
|
| 387 |
+
name = '.'.join(fragments_path.split('/')[-1].split('.')[:-1])
|
| 388 |
+
except Exception as e:
|
| 389 |
+
e = str(e).replace('\'', '')
|
| 390 |
+
error = f'Could not read the molecule: {e}'
|
| 391 |
+
msg = output.ERROR_FORMAT_MSG.format(message=error)
|
| 392 |
+
return [output.IFRAME_TEMPLATE.format(html=msg), None, None, None]
|
| 393 |
+
|
| 394 |
+
if fragments_mol.GetNumAtoms() > 100:
|
| 395 |
+
error = f'Too large molecule: upper limit is 100 heavy atoms'
|
| 396 |
+
msg = output.ERROR_FORMAT_MSG.format(message=error)
|
| 397 |
+
return [output.IFRAME_TEMPLATE.format(html=msg), None, None, None]
|
| 398 |
+
|
| 399 |
+
inp_sdf = f'results/input_{name}.sdf'
|
| 400 |
+
with Chem.SDWriter(inp_sdf) as w:
|
| 401 |
+
w.SetKekulize(False)
|
| 402 |
+
w.write(fragments_mol)
|
| 403 |
+
|
| 404 |
+
inp_pdb = f'results/target_{name}.pdb'
|
| 405 |
+
shutil.copy(protein_path, inp_pdb)
|
| 406 |
+
|
| 407 |
+
frag_pos, frag_one_hot, frag_charges = parse_molecule(fragments_mol, is_geom=True)
|
| 408 |
+
pocket_pos, pocket_one_hot, pocket_charges = get_pocket(fragments_mol, protein_path)
|
| 409 |
+
print(f'Detected pocket with {len(pocket_pos)} atoms')
|
| 410 |
+
|
| 411 |
+
positions = np.concatenate([frag_pos, pocket_pos], axis=0)
|
| 412 |
+
one_hot = np.concatenate([frag_one_hot, pocket_one_hot], axis=0)
|
| 413 |
+
charges = np.concatenate([frag_charges, pocket_charges], axis=0)
|
| 414 |
+
anchors = np.zeros_like(charges)
|
| 415 |
+
anchors[selected_atoms] = 1
|
| 416 |
+
|
| 417 |
+
fragment_only_mask = np.concatenate([
|
| 418 |
+
np.ones_like(frag_charges),
|
| 419 |
+
np.zeros_like(pocket_charges),
|
| 420 |
+
])
|
| 421 |
+
pocket_mask = np.concatenate([
|
| 422 |
+
np.zeros_like(frag_charges),
|
| 423 |
+
np.ones_like(pocket_charges),
|
| 424 |
+
])
|
| 425 |
+
linker_mask = np.concatenate([
|
| 426 |
+
np.zeros_like(frag_charges),
|
| 427 |
+
np.zeros_like(pocket_charges),
|
| 428 |
+
])
|
| 429 |
+
fragment_mask = np.concatenate([
|
| 430 |
+
np.ones_like(frag_charges),
|
| 431 |
+
np.ones_like(pocket_charges),
|
| 432 |
+
])
|
| 433 |
+
print('Read and parsed molecule')
|
| 434 |
+
|
| 435 |
+
dataset = [{
|
| 436 |
+
'uuid': '0',
|
| 437 |
+
'name': '0',
|
| 438 |
+
'positions': torch.tensor(positions, dtype=const.TORCH_FLOAT, device=device),
|
| 439 |
+
'one_hot': torch.tensor(one_hot, dtype=const.TORCH_FLOAT, device=device),
|
| 440 |
+
'charges': torch.tensor(charges, dtype=const.TORCH_FLOAT, device=device),
|
| 441 |
+
'anchors': torch.tensor(anchors, dtype=const.TORCH_FLOAT, device=device),
|
| 442 |
+
'fragment_only_mask': torch.tensor(fragment_only_mask, dtype=const.TORCH_FLOAT, device=device),
|
| 443 |
+
'pocket_mask': torch.tensor(pocket_mask, dtype=const.TORCH_FLOAT, device=device),
|
| 444 |
+
'fragment_mask': torch.tensor(fragment_mask, dtype=const.TORCH_FLOAT, device=device),
|
| 445 |
+
'linker_mask': torch.tensor(linker_mask, dtype=const.TORCH_FLOAT, device=device),
|
| 446 |
+
'num_atoms': len(positions),
|
| 447 |
+
}] * N_SAMPLES
|
| 448 |
+
dataset = MOADDataset(data=dataset)
|
| 449 |
+
ddpm.val_dataset = dataset
|
| 450 |
+
|
| 451 |
+
dataloader = get_dataloader(dataset, batch_size=N_SAMPLES, collate_fn=collate_with_fragment_edges)
|
| 452 |
+
print('Created dataloader')
|
| 453 |
+
|
| 454 |
+
ddpm.edm.T = n_steps
|
| 455 |
+
|
| 456 |
+
if n_atoms == 0:
|
| 457 |
+
def sample_fn(_data):
|
| 458 |
+
out, _ = size_nn.forward(_data, return_loss=False)
|
| 459 |
+
probabilities = torch.softmax(out, dim=1)
|
| 460 |
+
distribution = torch.distributions.Categorical(probs=probabilities)
|
| 461 |
+
samples = distribution.sample()
|
| 462 |
+
sizes = []
|
| 463 |
+
for label in samples.detach().cpu().numpy():
|
| 464 |
+
sizes.append(size_nn.linker_id2size[label])
|
| 465 |
+
sizes = torch.tensor(sizes, device=samples.device, dtype=torch.long)
|
| 466 |
+
return sizes
|
| 467 |
+
else:
|
| 468 |
+
def sample_fn(_data):
|
| 469 |
+
return torch.ones(_data['positions'].shape[0], device=device, dtype=torch.long) * n_atoms
|
| 470 |
+
|
| 471 |
+
for data in dataloader:
|
| 472 |
+
try:
|
| 473 |
+
generate_linkers(ddpm=ddpm, data=data, sample_fn=sample_fn, name=name, with_pocket=True)
|
| 474 |
+
except Exception as e:
|
| 475 |
+
e = str(e).replace('\'', '')
|
| 476 |
+
error = f'Caught exception while generating linkers: {e}'
|
| 477 |
+
msg = output.ERROR_FORMAT_MSG.format(message=error)
|
| 478 |
+
return [output.IFRAME_TEMPLATE.format(html=msg), None, None, None]
|
| 479 |
+
|
| 480 |
+
out_files = try_to_convert_to_sdf(name)
|
| 481 |
+
out_files = [inp_sdf, inp_pdb] + out_files
|
| 482 |
+
out_files = [compress(out_files, name=name)] + out_files
|
| 483 |
|
| 484 |
return [
|
| 485 |
draw_sample(radio_samples, out_files),
|
|
|
|
| 505 |
with gr.Box():
|
| 506 |
with gr.Row():
|
| 507 |
with gr.Column():
|
| 508 |
+
gr.Markdown('## Input')
|
| 509 |
gr.Markdown('Upload the file with 3D-coordinates of the input fragments in .pdb, .mol2 or .sdf format:')
|
| 510 |
+
with gr.Column():
|
| 511 |
+
input_fragments_file = gr.File(
|
| 512 |
+
file_count='single',
|
| 513 |
+
label='Input Fragments',
|
| 514 |
+
file_types=['.sdf', '.pdb', '.mol', '.mol2']
|
| 515 |
+
)
|
| 516 |
+
# gr.Markdown('(Optionally) upload the file of the target protein in .pdb format:')
|
| 517 |
+
with gr.Column():
|
| 518 |
+
input_protein_file = gr.File(file_count='single', label='Target Protein', file_types=['.pdb'])
|
| 519 |
+
|
| 520 |
+
n_steps = gr.Slider(minimum=50, maximum=500, label="Number of Denoising Steps", step=10)
|
| 521 |
n_atoms = gr.Slider(
|
| 522 |
minimum=0, maximum=20,
|
| 523 |
label="Linker Size: DiffLinker will predict it if set to 0",
|
| 524 |
step=1
|
| 525 |
)
|
| 526 |
examples = gr.Dataset(
|
| 527 |
+
components=[gr.File(visible=False), gr.File(visible=False)],
|
| 528 |
+
samples=[
|
| 529 |
+
['examples/example_1.sdf', None],
|
| 530 |
+
['examples/example_2.sdf', None],
|
| 531 |
+
['examples/3hz1_fragments.sdf', 'examples/3hz1_protein.pdb'],
|
| 532 |
+
['examples/5ou2_fragments.sdf', 'examples/5ou2_protein.pdb'],
|
| 533 |
+
],
|
| 534 |
+
headers=['Fragments', 'Target Protein'],
|
| 535 |
+
type='values',
|
| 536 |
)
|
| 537 |
|
| 538 |
button = gr.Button('Generate Linker!')
|
|
|
|
| 554 |
)
|
| 555 |
visualization = gr.HTML()
|
| 556 |
|
| 557 |
+
input_fragments_file.change(
|
| 558 |
fn=show_input,
|
| 559 |
+
inputs=[input_fragments_file, input_protein_file],
|
| 560 |
outputs=[visualization, samples, hidden],
|
| 561 |
)
|
| 562 |
+
input_protein_file.change(
|
| 563 |
+
fn=show_input,
|
| 564 |
+
inputs=[input_fragments_file, input_protein_file],
|
| 565 |
+
outputs=[visualization, samples, hidden],
|
| 566 |
+
)
|
| 567 |
+
input_fragments_file.clear(
|
| 568 |
+
fn=clear_fragments_input,
|
| 569 |
+
inputs=[input_protein_file],
|
| 570 |
+
outputs=[input_fragments_file, visualization, samples, hidden],
|
| 571 |
+
)
|
| 572 |
+
input_protein_file.clear(
|
| 573 |
+
fn=clear_protein_input,
|
| 574 |
+
inputs=[input_fragments_file],
|
| 575 |
+
outputs=[input_protein_file, visualization, samples, hidden],
|
| 576 |
)
|
| 577 |
examples.click(
|
| 578 |
+
fn=click_on_example,
|
| 579 |
inputs=[examples],
|
| 580 |
+
outputs=[input_fragments_file, input_protein_file, n_steps, n_atoms, visualization, samples, hidden]
|
| 581 |
)
|
| 582 |
button.click(
|
| 583 |
fn=generate,
|
| 584 |
+
inputs=[input_fragments_file, input_protein_file, n_steps, n_atoms, samples, hidden],
|
| 585 |
outputs=[visualization, output_files, samples, hidden],
|
| 586 |
_js=output.RETURN_SELECTION_JS,
|
| 587 |
)
|
examples/3hz1_fragments.sdf
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
fragments
|
| 2 |
+
PyMOL2.5 3D 0
|
| 3 |
+
|
| 4 |
+
23 25 0 0 0 0 0 0 0 0999 V2000
|
| 5 |
+
0.7050 10.1160 25.5000 C 0 0 0 0 0 0 0 0 0 0 0 0
|
| 6 |
+
-0.4250 10.6930 24.7810 C 0 0 0 0 0 0 0 0 0 0 0 0
|
| 7 |
+
-1.6420 10.9060 25.5370 C 0 0 0 0 0 0 0 0 0 0 0 0
|
| 8 |
+
-1.7510 10.5210 26.8370 N 0 0 0 0 0 0 0 0 0 0 0 0
|
| 9 |
+
-0.6900 9.9510 27.4380 C 0 0 0 0 0 0 0 0 0 0 0 0
|
| 10 |
+
0.4770 9.7630 26.7990 N 0 0 0 0 0 0 0 0 0 0 0 0
|
| 11 |
+
-0.6830 11.1870 23.5600 N 0 0 0 0 0 0 0 0 0 0 0 0
|
| 12 |
+
-1.9660 11.6240 23.5390 C 0 0 0 0 0 0 0 0 0 0 0 0
|
| 13 |
+
-2.5810 11.4250 24.7070 N 0 0 0 0 0 0 0 0 0 0 0 0
|
| 14 |
+
1.9520 9.8170 24.8700 N 0 0 0 0 0 0 0 0 0 0 0 0
|
| 15 |
+
3.1230 9.3980 25.6290 C 0 0 0 0 0 0 0 0 0 0 0 0
|
| 16 |
+
2.1100 9.7530 23.4320 C 0 0 0 0 0 0 0 0 0 0 0 0
|
| 17 |
+
7.8600 10.1360 22.6040 C 0 0 0 0 0 0 0 0 0 0 0 0
|
| 18 |
+
6.5530 9.6800 22.8080 C 0 0 0 0 0 0 0 0 0 0 0 0
|
| 19 |
+
5.8720 10.7150 23.6130 O 0 0 0 0 0 0 0 0 0 0 0 0
|
| 20 |
+
6.8390 11.6780 23.7840 C 0 0 0 0 0 0 0 0 0 0 0 0
|
| 21 |
+
8.0580 11.3690 23.2280 C 0 0 0 0 0 0 0 0 0 0 0 0
|
| 22 |
+
6.6560 12.9400 24.5720 C 0 0 0 0 0 0 0 0 0 0 0 0
|
| 23 |
+
7.6630 13.4980 25.2340 N 0 0 0 0 0 0 0 0 0 0 0 0
|
| 24 |
+
7.1190 14.6210 25.8930 N 0 0 0 0 0 0 0 0 0 0 0 0
|
| 25 |
+
5.8050 14.8140 25.6500 C 0 0 0 0 0 0 0 0 0 0 0 0
|
| 26 |
+
5.4220 13.6990 24.7720 C 0 0 0 0 0 0 0 0 0 0 0 0
|
| 27 |
+
4.9170 15.9400 26.1920 C 0 0 0 0 0 0 0 0 0 0 0 0
|
| 28 |
+
1 2 4 0 0 0 0
|
| 29 |
+
1 6 4 0 0 0 0
|
| 30 |
+
1 10 1 0 0 0 0
|
| 31 |
+
2 3 4 0 0 0 0
|
| 32 |
+
2 7 4 0 0 0 0
|
| 33 |
+
3 4 4 0 0 0 0
|
| 34 |
+
3 9 4 0 0 0 0
|
| 35 |
+
4 5 4 0 0 0 0
|
| 36 |
+
5 6 4 0 0 0 0
|
| 37 |
+
7 8 4 0 0 0 0
|
| 38 |
+
8 9 4 0 0 0 0
|
| 39 |
+
10 11 1 0 0 0 0
|
| 40 |
+
10 12 1 0 0 0 0
|
| 41 |
+
13 14 4 0 0 0 0
|
| 42 |
+
13 17 4 0 0 0 0
|
| 43 |
+
14 15 4 0 0 0 0
|
| 44 |
+
15 16 4 0 0 0 0
|
| 45 |
+
16 17 4 0 0 0 0
|
| 46 |
+
16 18 1 0 0 0 0
|
| 47 |
+
18 19 4 0 0 0 0
|
| 48 |
+
18 22 4 0 0 0 0
|
| 49 |
+
19 20 4 0 0 0 0
|
| 50 |
+
20 21 4 0 0 0 0
|
| 51 |
+
21 22 4 0 0 0 0
|
| 52 |
+
21 23 1 0 0 0 0
|
| 53 |
+
M END
|
| 54 |
+
$$$$
|
examples/3hz1_protein.pdb
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
examples/5ou2_fragments.sdf
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
5ou2_fragments
|
| 2 |
+
PyMOL2.5 3D 0
|
| 3 |
+
|
| 4 |
+
24 26 0 0 0 0 0 0 0 0999 V2000
|
| 5 |
+
135.6651 -15.3583 0.1325 N 0 0 0 0 0 0 0 0 0 0 0 0
|
| 6 |
+
134.8356 -14.4706 -0.4078 C 0 0 0 0 0 0 0 0 0 0 0 0
|
| 7 |
+
134.5969 -13.5549 0.5236 N 0 0 0 0 0 0 0 0 0 0 0 0
|
| 8 |
+
135.2672 -13.8787 1.6104 C 0 0 0 0 0 0 0 0 0 0 0 0
|
| 9 |
+
135.9361 -15.0095 1.3626 C 0 0 0 0 0 0 0 0 0 0 0 0
|
| 10 |
+
135.2407 -13.1072 2.8878 C 0 0 0 0 0 0 0 0 0 0 0 0
|
| 11 |
+
135.5339 -13.7328 4.0539 C 0 0 0 0 0 0 0 0 0 0 0 0
|
| 12 |
+
135.5239 -13.0695 5.2284 C 0 0 0 0 0 0 0 0 0 0 0 0
|
| 13 |
+
135.1995 -11.7489 5.2810 C 0 0 0 0 0 0 0 0 0 0 0 0
|
| 14 |
+
134.9023 -11.1173 4.1089 C 0 0 0 0 0 0 0 0 0 0 0 0
|
| 15 |
+
134.9113 -11.7774 2.9035 C 0 0 0 0 0 0 0 0 0 0 0 0
|
| 16 |
+
135.1362 -10.8138 6.9517 Br 0 0 0 0 0 0 0 0 0 0 0 0
|
| 17 |
+
126.8521 -19.0355 0.2522 N 0 0 0 0 0 0 0 0 0 0 0 0
|
| 18 |
+
126.0921 -18.0299 -0.2360 C 0 0 0 0 0 0 0 0 0 0 0 0
|
| 19 |
+
126.8721 -17.2548 -1.0322 N 0 0 0 0 0 0 0 0 0 0 0 0
|
| 20 |
+
128.1098 -17.7707 -1.0325 C 0 0 0 0 0 0 0 0 0 0 0 0
|
| 21 |
+
128.0889 -18.8815 -0.2256 C 0 0 0 0 0 0 0 0 0 0 0 0
|
| 22 |
+
129.3145 -17.2106 -1.7791 C 0 0 0 0 0 0 0 0 0 0 0 0
|
| 23 |
+
130.5850 -17.7185 -1.5264 C 0 0 0 0 0 0 0 0 0 0 0 0
|
| 24 |
+
131.6879 -17.2095 -2.1865 C 0 0 0 0 0 0 0 0 0 0 0 0
|
| 25 |
+
131.5211 -16.1844 -3.1052 C 0 0 0 0 0 0 0 0 0 0 0 0
|
| 26 |
+
130.2586 -15.6644 -3.3699 C 0 0 0 0 0 0 0 0 0 0 0 0
|
| 27 |
+
129.1548 -16.1795 -2.7058 C 0 0 0 0 0 0 0 0 0 0 0 0
|
| 28 |
+
133.0656 -15.5029 -4.0086 Br 0 0 0 0 0 0 0 0 0 0 0 0
|
| 29 |
+
1 2 4 0 0 0 0
|
| 30 |
+
2 3 4 0 0 0 0
|
| 31 |
+
3 4 4 0 0 0 0
|
| 32 |
+
4 6 1 0 0 0 0
|
| 33 |
+
1 5 4 0 0 0 0
|
| 34 |
+
4 5 4 0 0 0 0
|
| 35 |
+
6 7 4 0 0 0 0
|
| 36 |
+
6 11 4 0 0 0 0
|
| 37 |
+
7 8 4 0 0 0 0
|
| 38 |
+
8 9 4 0 0 0 0
|
| 39 |
+
9 10 4 0 0 0 0
|
| 40 |
+
9 12 1 0 0 0 0
|
| 41 |
+
10 11 4 0 0 0 0
|
| 42 |
+
13 14 4 0 0 0 0
|
| 43 |
+
14 15 4 0 0 0 0
|
| 44 |
+
15 16 4 0 0 0 0
|
| 45 |
+
16 18 1 0 0 0 0
|
| 46 |
+
13 17 4 0 0 0 0
|
| 47 |
+
16 17 4 0 0 0 0
|
| 48 |
+
18 19 4 0 0 0 0
|
| 49 |
+
18 23 4 0 0 0 0
|
| 50 |
+
19 20 4 0 0 0 0
|
| 51 |
+
20 21 4 0 0 0 0
|
| 52 |
+
21 22 4 0 0 0 0
|
| 53 |
+
21 24 1 0 0 0 0
|
| 54 |
+
22 23 4 0 0 0 0
|
| 55 |
+
M END
|
| 56 |
+
$$$$
|
examples/5ou2_protein.pdb
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
output.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
|
| 2 |
<html>
|
| 3 |
<head>
|
| 4 |
<meta http-equiv="content-type" content="text/html; charset=UTF-8" />
|
|
@@ -26,7 +26,6 @@ INITIAL_RENDERING_TEMPLATE = """<!DOCTYPE html>
|
|
| 26 |
let defaultStyle = {{ stick: {{ colorscheme: "greenCarbon" }} }};
|
| 27 |
viewer.addModel(`{molecule}`, "{fmt}");
|
| 28 |
viewer.getModel(0).setStyle(defaultStyle);
|
| 29 |
-
// document.cookie = document.cookie + "|selected_atoms:";
|
| 30 |
|
| 31 |
viewer.getModel(0).setClickable(
|
| 32 |
{{}},
|
|
@@ -38,20 +37,16 @@ INITIAL_RENDERING_TEMPLATE = """<!DOCTYPE html>
|
|
| 38 |
{{"serial": _atom.serial, "model": 0}},
|
| 39 |
{{"sphere": {{"color": "magenta", "radius": 0.4}} }}
|
| 40 |
);
|
| 41 |
-
// document.cookie = document.cookie + "atom_" + String(_atom.serial) + "-";
|
| 42 |
window.parent.postMessage({{
|
| 43 |
name: "atom_selection",
|
| 44 |
data: {{"atom": _atom.serial, "add": true}}
|
| 45 |
-
// data: JSON.stringify({{"add": _atom.serial}})
|
| 46 |
}}, "*");
|
| 47 |
}} else {{
|
| 48 |
delete _atom.isClicked;
|
| 49 |
_viewer.setStyle({{"serial": _atom.serial, "model": 0}}, defaultStyle);
|
| 50 |
-
// document.cookie = document.cookie.replace("atom_" + String(_atom.serial) + "-", "");
|
| 51 |
window.parent.postMessage({{
|
| 52 |
name: "atom_selection",
|
| 53 |
data: {{"atom": _atom.serial, "add": false}}
|
| 54 |
-
// data: JSON.stringify({{"remove": _atom.serial}})
|
| 55 |
}}, "*");
|
| 56 |
}}
|
| 57 |
_viewer.render();
|
|
@@ -67,6 +62,112 @@ INITIAL_RENDERING_TEMPLATE = """<!DOCTYPE html>
|
|
| 67 |
</html>
|
| 68 |
"""
|
| 69 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 70 |
|
| 71 |
SAMPLES_RENDERING_TEMPLATE = """<!DOCTYPE html>
|
| 72 |
<html>
|
|
@@ -88,6 +189,7 @@ SAMPLES_RENDERING_TEMPLATE = """<!DOCTYPE html>
|
|
| 88 |
|
| 89 |
<body>
|
| 90 |
<div id="container" class="mol-container"></div>
|
|
|
|
| 91 |
<button id="fragments">Input Fragments</button>
|
| 92 |
<button id="molecule">Output Molecule</button>
|
| 93 |
<script>
|
|
@@ -120,6 +222,74 @@ SAMPLES_RENDERING_TEMPLATE = """<!DOCTYPE html>
|
|
| 120 |
</html>
|
| 121 |
"""
|
| 122 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 123 |
|
| 124 |
INVALID_FORMAT_MSG = """
|
| 125 |
<!DOCTYPE html>
|
|
@@ -135,13 +305,18 @@ INVALID_FORMAT_MSG = """
|
|
| 135 |
|
| 136 |
<body>
|
| 137 |
<h3>Invalid file format: {extension}</h3>
|
| 138 |
-
|
| 139 |
<ul>
|
| 140 |
<li>.pdb</li>
|
| 141 |
<li>.sdf</li>
|
| 142 |
<li>.mol</li>
|
| 143 |
<li>.mol2</li>
|
| 144 |
</ul>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 145 |
</body>
|
| 146 |
</html>
|
| 147 |
"""
|
|
@@ -190,7 +365,7 @@ STARTUP_JS = """
|
|
| 190 |
"""
|
| 191 |
|
| 192 |
RETURN_SELECTION_JS = """
|
| 193 |
-
(input_file, n_steps, n_atoms, samples, hidden) => {
|
| 194 |
let selected = []
|
| 195 |
for (const [atom, add] of Object.entries(window.selected_elements)) {
|
| 196 |
if (add) {
|
|
@@ -203,6 +378,6 @@ RETURN_SELECTION_JS = """
|
|
| 203 |
}
|
| 204 |
}
|
| 205 |
console.log("Finished parsing");
|
| 206 |
-
return [input_file, n_steps, n_atoms, samples, selected.join(",")];
|
| 207 |
}
|
| 208 |
"""
|
|
|
|
| 1 |
+
FRAGMENTS_RENDERING_TEMPLATE = """<!DOCTYPE html>
|
| 2 |
<html>
|
| 3 |
<head>
|
| 4 |
<meta http-equiv="content-type" content="text/html; charset=UTF-8" />
|
|
|
|
| 26 |
let defaultStyle = {{ stick: {{ colorscheme: "greenCarbon" }} }};
|
| 27 |
viewer.addModel(`{molecule}`, "{fmt}");
|
| 28 |
viewer.getModel(0).setStyle(defaultStyle);
|
|
|
|
| 29 |
|
| 30 |
viewer.getModel(0).setClickable(
|
| 31 |
{{}},
|
|
|
|
| 37 |
{{"serial": _atom.serial, "model": 0}},
|
| 38 |
{{"sphere": {{"color": "magenta", "radius": 0.4}} }}
|
| 39 |
);
|
|
|
|
| 40 |
window.parent.postMessage({{
|
| 41 |
name: "atom_selection",
|
| 42 |
data: {{"atom": _atom.serial, "add": true}}
|
|
|
|
| 43 |
}}, "*");
|
| 44 |
}} else {{
|
| 45 |
delete _atom.isClicked;
|
| 46 |
_viewer.setStyle({{"serial": _atom.serial, "model": 0}}, defaultStyle);
|
|
|
|
| 47 |
window.parent.postMessage({{
|
| 48 |
name: "atom_selection",
|
| 49 |
data: {{"atom": _atom.serial, "add": false}}
|
|
|
|
| 50 |
}}, "*");
|
| 51 |
}}
|
| 52 |
_viewer.render();
|
|
|
|
| 62 |
</html>
|
| 63 |
"""
|
| 64 |
|
| 65 |
+
TARGET_RENDERING_TEMPLATE = """<!DOCTYPE html>
|
| 66 |
+
<html>
|
| 67 |
+
<head>
|
| 68 |
+
<meta http-equiv="content-type" content="text/html; charset=UTF-8" />
|
| 69 |
+
<script src="https://cdnjs.cloudflare.com/ajax/libs/jquery/3.6.3/jquery.min.js" integrity="sha512-STof4xm1wgkfm7heWqFJVn58Hm3EtS31XFaagaa8VMReCXAkQnJZ+jEy8PCC/iT18dFy95WcExNHFTqLyp72eQ==" crossorigin="anonymous" referrerpolicy="no-referrer"></script>
|
| 70 |
+
<script src="https://3Dmol.org/build/3Dmol.js"></script>
|
| 71 |
+
<style>
|
| 72 |
+
.mol-container {{
|
| 73 |
+
width: 600px;
|
| 74 |
+
height: 600px;
|
| 75 |
+
position: relative;
|
| 76 |
+
}}
|
| 77 |
+
.mol-container select{{
|
| 78 |
+
background-image:None;
|
| 79 |
+
}}
|
| 80 |
+
</style>
|
| 81 |
+
</head>
|
| 82 |
+
|
| 83 |
+
<body>
|
| 84 |
+
<div id="container" class="mol-container"></div>
|
| 85 |
+
<script>
|
| 86 |
+
$(document).ready(function() {{
|
| 87 |
+
let element = $("#container");
|
| 88 |
+
let config = {{ backgroundColor: "white" }};
|
| 89 |
+
let viewer = $3Dmol.createViewer(element, config);
|
| 90 |
+
let proteinStyle = {{ cartoon: {{ colorscheme: "ssPyMOL" }} }};
|
| 91 |
+
viewer.addModel(`{molecule}`, "{fmt}");
|
| 92 |
+
viewer.getModel(0).setStyle(proteinStyle);
|
| 93 |
+
|
| 94 |
+
viewer.zoomTo();
|
| 95 |
+
viewer.zoom(0.7);
|
| 96 |
+
viewer.render();
|
| 97 |
+
}});
|
| 98 |
+
</script>
|
| 99 |
+
</body>
|
| 100 |
+
</html>
|
| 101 |
+
"""
|
| 102 |
+
|
| 103 |
+
FRAGMENTS_AND_TARGET_RENDERING_TEMPLATE = """<!DOCTYPE html>
|
| 104 |
+
<html>
|
| 105 |
+
<head>
|
| 106 |
+
<meta http-equiv="content-type" content="text/html; charset=UTF-8" />
|
| 107 |
+
<script src="https://cdnjs.cloudflare.com/ajax/libs/jquery/3.6.3/jquery.min.js" integrity="sha512-STof4xm1wgkfm7heWqFJVn58Hm3EtS31XFaagaa8VMReCXAkQnJZ+jEy8PCC/iT18dFy95WcExNHFTqLyp72eQ==" crossorigin="anonymous" referrerpolicy="no-referrer"></script>
|
| 108 |
+
<script src="https://3Dmol.org/build/3Dmol.js"></script>
|
| 109 |
+
<style>
|
| 110 |
+
.mol-container {{
|
| 111 |
+
width: 600px;
|
| 112 |
+
height: 600px;
|
| 113 |
+
position: relative;
|
| 114 |
+
}}
|
| 115 |
+
.mol-container select{{
|
| 116 |
+
background-image:None;
|
| 117 |
+
}}
|
| 118 |
+
</style>
|
| 119 |
+
</head>
|
| 120 |
+
|
| 121 |
+
<body>
|
| 122 |
+
<div id="container" class="mol-container"></div>
|
| 123 |
+
<script>
|
| 124 |
+
$(document).ready(function() {{
|
| 125 |
+
let element = $("#container");
|
| 126 |
+
let config = {{ backgroundColor: "white" }};
|
| 127 |
+
let viewer = $3Dmol.createViewer(element, config);
|
| 128 |
+
let defaultStyle = {{ stick: {{ colorscheme: "greenCarbon" }} }};
|
| 129 |
+
let proteinStyle = {{ cartoon: {{ colorscheme: "ssPyMOL" }} }};
|
| 130 |
+
|
| 131 |
+
viewer.addModel(`{molecule}`, "{fmt}");
|
| 132 |
+
viewer.getModel(0).setStyle(defaultStyle);
|
| 133 |
+
viewer.getModel(0).setClickable(
|
| 134 |
+
{{}},
|
| 135 |
+
true,
|
| 136 |
+
function (_atom, _viewer, _event, _container) {{
|
| 137 |
+
if (!_atom.isClicked) {{
|
| 138 |
+
_atom.isClicked = true;
|
| 139 |
+
_viewer.addStyle(
|
| 140 |
+
{{"serial": _atom.serial, "model": 0}},
|
| 141 |
+
{{"sphere": {{"color": "magenta", "radius": 0.4}} }}
|
| 142 |
+
);
|
| 143 |
+
window.parent.postMessage({{
|
| 144 |
+
name: "atom_selection",
|
| 145 |
+
data: {{"atom": _atom.serial, "add": true}}
|
| 146 |
+
}}, "*");
|
| 147 |
+
}} else {{
|
| 148 |
+
delete _atom.isClicked;
|
| 149 |
+
_viewer.setStyle({{"serial": _atom.serial, "model": 0}}, defaultStyle);
|
| 150 |
+
window.parent.postMessage({{
|
| 151 |
+
name: "atom_selection",
|
| 152 |
+
data: {{"atom": _atom.serial, "add": false}}
|
| 153 |
+
}}, "*");
|
| 154 |
+
}}
|
| 155 |
+
_viewer.render();
|
| 156 |
+
}}
|
| 157 |
+
);
|
| 158 |
+
|
| 159 |
+
viewer.addModel(`{target}`, "{target_fmt}");
|
| 160 |
+
viewer.getModel(1).setStyle(proteinStyle);
|
| 161 |
+
|
| 162 |
+
viewer.zoomTo();
|
| 163 |
+
viewer.zoom(0.7);
|
| 164 |
+
viewer.render();
|
| 165 |
+
}});
|
| 166 |
+
</script>
|
| 167 |
+
</body>
|
| 168 |
+
</html>
|
| 169 |
+
"""
|
| 170 |
+
|
| 171 |
|
| 172 |
SAMPLES_RENDERING_TEMPLATE = """<!DOCTYPE html>
|
| 173 |
<html>
|
|
|
|
| 189 |
|
| 190 |
<body>
|
| 191 |
<div id="container" class="mol-container"></div>
|
| 192 |
+
<br>
|
| 193 |
<button id="fragments">Input Fragments</button>
|
| 194 |
<button id="molecule">Output Molecule</button>
|
| 195 |
<script>
|
|
|
|
| 222 |
</html>
|
| 223 |
"""
|
| 224 |
|
| 225 |
+
SAMPLES_WITH_TARGET_RENDERING_TEMPLATE = """<!DOCTYPE html>
|
| 226 |
+
<html>
|
| 227 |
+
<head>
|
| 228 |
+
<meta http-equiv="content-type" content="text/html; charset=UTF-8" />
|
| 229 |
+
<script src="https://cdnjs.cloudflare.com/ajax/libs/jquery/3.6.3/jquery.min.js" integrity="sha512-STof4xm1wgkfm7heWqFJVn58Hm3EtS31XFaagaa8VMReCXAkQnJZ+jEy8PCC/iT18dFy95WcExNHFTqLyp72eQ==" crossorigin="anonymous" referrerpolicy="no-referrer"></script>
|
| 230 |
+
<script src="https://3Dmol.org/build/3Dmol.js"></script>
|
| 231 |
+
<style>
|
| 232 |
+
.mol-container {{
|
| 233 |
+
width: 600px;
|
| 234 |
+
height: 600px;
|
| 235 |
+
position: relative;
|
| 236 |
+
}}
|
| 237 |
+
.mol-container select{{
|
| 238 |
+
background-image:None;
|
| 239 |
+
}}
|
| 240 |
+
</style>
|
| 241 |
+
</head>
|
| 242 |
+
|
| 243 |
+
<body>
|
| 244 |
+
<div id="container" class="mol-container"></div>
|
| 245 |
+
<br>
|
| 246 |
+
<button id="fragments">Input Fragments</button>
|
| 247 |
+
<button id="molecule">Output Molecule</button>
|
| 248 |
+
<button id="show-target">Show Target</button>
|
| 249 |
+
<button id="hide-target">Hide Target</button>
|
| 250 |
+
<script>
|
| 251 |
+
let element = $("#container");
|
| 252 |
+
let config = {{ backgroundColor: "white" }};
|
| 253 |
+
let viewer = $3Dmol.createViewer( element, config );
|
| 254 |
+
|
| 255 |
+
$(document).ready(function() {{
|
| 256 |
+
viewer.addModel(`{fragments}`, "{fragments_fmt}")
|
| 257 |
+
viewer.getModel(0).setStyle({{ stick: {{ colorscheme:"greenCarbon" }} }})
|
| 258 |
+
viewer.getModel(0).hide();
|
| 259 |
+
|
| 260 |
+
viewer.addModel(`{molecule}`, "{molecule_fmt}")
|
| 261 |
+
viewer.getModel(1).setStyle({{ stick: {{ colorscheme:"greenCarbon" }} }})
|
| 262 |
+
|
| 263 |
+
viewer.addModel(`{target}`, "{target_fmt}")
|
| 264 |
+
viewer.getModel(2).setStyle({{ cartoon: {{ colorscheme: "ssPyMOL" }} }})
|
| 265 |
+
|
| 266 |
+
viewer.zoomTo();
|
| 267 |
+
viewer.zoom(0.7);
|
| 268 |
+
viewer.render();
|
| 269 |
+
}});
|
| 270 |
+
$("#fragments").click(function() {{
|
| 271 |
+
viewer.getModel(0).show();
|
| 272 |
+
viewer.getModel(1).hide();
|
| 273 |
+
viewer.render();
|
| 274 |
+
}});
|
| 275 |
+
$("#molecule").click(function() {{
|
| 276 |
+
viewer.getModel(1).show();
|
| 277 |
+
viewer.getModel(0).hide();
|
| 278 |
+
viewer.render();
|
| 279 |
+
}});
|
| 280 |
+
$("#show-target").click(function() {{
|
| 281 |
+
viewer.getModel(2).show();
|
| 282 |
+
viewer.render();
|
| 283 |
+
}});
|
| 284 |
+
$("#hide-target").click(function() {{
|
| 285 |
+
viewer.getModel(2).hide();
|
| 286 |
+
viewer.render();
|
| 287 |
+
}});
|
| 288 |
+
</script>
|
| 289 |
+
</body>
|
| 290 |
+
</html>
|
| 291 |
+
"""
|
| 292 |
+
|
| 293 |
|
| 294 |
INVALID_FORMAT_MSG = """
|
| 295 |
<!DOCTYPE html>
|
|
|
|
| 305 |
|
| 306 |
<body>
|
| 307 |
<h3>Invalid file format: {extension}</h3>
|
| 308 |
+
Allowed formats for the fragments file:
|
| 309 |
<ul>
|
| 310 |
<li>.pdb</li>
|
| 311 |
<li>.sdf</li>
|
| 312 |
<li>.mol</li>
|
| 313 |
<li>.mol2</li>
|
| 314 |
</ul>
|
| 315 |
+
|
| 316 |
+
Allowed formats for the optional protein file:
|
| 317 |
+
<ul>
|
| 318 |
+
<li>.pdb</li>
|
| 319 |
+
</ul>
|
| 320 |
</body>
|
| 321 |
</html>
|
| 322 |
"""
|
|
|
|
| 365 |
"""
|
| 366 |
|
| 367 |
RETURN_SELECTION_JS = """
|
| 368 |
+
(input_file, input_protein_file, n_steps, n_atoms, samples, hidden) => {
|
| 369 |
let selected = []
|
| 370 |
for (const [atom, add] of Object.entries(window.selected_elements)) {
|
| 371 |
if (add) {
|
|
|
|
| 378 |
}
|
| 379 |
}
|
| 380 |
console.log("Finished parsing");
|
| 381 |
+
return [input_file, input_protein_file, n_steps, n_atoms, samples, selected.join(",")];
|
| 382 |
}
|
| 383 |
"""
|
src/datasets.py
CHANGED
|
@@ -101,15 +101,25 @@ class ZincDataset(Dataset):
|
|
| 101 |
|
| 102 |
|
| 103 |
class MOADDataset(Dataset):
|
| 104 |
-
def __init__(self, data_path, prefix, device):
|
| 105 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 106 |
|
| 107 |
dataset_path = os.path.join(data_path, f'{prefix}_{pocket_mode}.pt')
|
| 108 |
if os.path.exists(dataset_path):
|
| 109 |
self.data = torch.load(dataset_path, map_location=device)
|
| 110 |
else:
|
| 111 |
print(f'Preprocessing dataset with prefix {prefix}')
|
| 112 |
-
self.data =
|
| 113 |
torch.save(self.data, dataset_path)
|
| 114 |
|
| 115 |
def __len__(self):
|
|
@@ -264,7 +274,7 @@ def collate_with_fragment_edges(batch):
|
|
| 264 |
out = {}
|
| 265 |
|
| 266 |
# Filter out big molecules
|
| 267 |
-
batch = [data for data in batch if data['num_atoms'] <= 50]
|
| 268 |
|
| 269 |
for i, data in enumerate(batch):
|
| 270 |
for key, value in data.items():
|
|
|
|
| 101 |
|
| 102 |
|
| 103 |
class MOADDataset(Dataset):
|
| 104 |
+
def __init__(self, data=None, data_path=None, prefix=None, device=None):
|
| 105 |
+
assert (data is not None) or all(x is not None for x in (data_path, prefix, device))
|
| 106 |
+
if data is not None:
|
| 107 |
+
self.data = data
|
| 108 |
+
return
|
| 109 |
+
|
| 110 |
+
if '.' in prefix:
|
| 111 |
+
prefix, pocket_mode = prefix.split('.')
|
| 112 |
+
else:
|
| 113 |
+
parts = prefix.split('_')
|
| 114 |
+
prefix = '_'.join(parts[:-1])
|
| 115 |
+
pocket_mode = parts[-1]
|
| 116 |
|
| 117 |
dataset_path = os.path.join(data_path, f'{prefix}_{pocket_mode}.pt')
|
| 118 |
if os.path.exists(dataset_path):
|
| 119 |
self.data = torch.load(dataset_path, map_location=device)
|
| 120 |
else:
|
| 121 |
print(f'Preprocessing dataset with prefix {prefix}')
|
| 122 |
+
self.data = self.preprocess(data_path, prefix, pocket_mode, device)
|
| 123 |
torch.save(self.data, dataset_path)
|
| 124 |
|
| 125 |
def __len__(self):
|
|
|
|
| 274 |
out = {}
|
| 275 |
|
| 276 |
# Filter out big molecules
|
| 277 |
+
# batch = [data for data in batch if data['num_atoms'] <= 50]
|
| 278 |
|
| 279 |
for i, data in enumerate(batch):
|
| 280 |
for key, value in data.items():
|
src/generation.py
CHANGED
|
@@ -1,24 +1,44 @@
|
|
|
|
|
| 1 |
import os.path
|
| 2 |
import subprocess
|
| 3 |
import torch
|
| 4 |
|
|
|
|
|
|
|
| 5 |
from src.visualizer import save_xyz_file
|
|
|
|
|
|
|
| 6 |
|
| 7 |
N_SAMPLES = 5
|
| 8 |
|
| 9 |
|
| 10 |
-
def generate_linkers(ddpm, data, sample_fn, name):
|
| 11 |
-
chain
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
print('Generated linker')
|
| 13 |
x = chain[0][:, :, :ddpm.n_dims]
|
| 14 |
h = chain[0][:, :, ddpm.n_dims:]
|
| 15 |
|
| 16 |
# Put the molecule back to the initial orientation
|
| 17 |
-
|
| 18 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
mean = torch.sum(pos_masked, dim=1, keepdim=True) / N
|
| 20 |
x = x + mean * node_mask
|
| 21 |
|
|
|
|
|
|
|
|
|
|
| 22 |
names = [f'output_{i + 1}_{name}' for i in range(N_SAMPLES)]
|
| 23 |
save_xyz_file('results', h, x, node_mask, names=names, is_geom=True, suffix='')
|
| 24 |
print('Saved XYZ files')
|
|
@@ -36,3 +56,62 @@ def try_to_convert_to_sdf(name):
|
|
| 36 |
out_files.append(out_xyz)
|
| 37 |
|
| 38 |
return out_files
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
import os.path
|
| 3 |
import subprocess
|
| 4 |
import torch
|
| 5 |
|
| 6 |
+
from Bio.PDB import PDBParser
|
| 7 |
+
from src import const
|
| 8 |
from src.visualizer import save_xyz_file
|
| 9 |
+
from src.utils import FoundNaNException
|
| 10 |
+
from src.datasets import get_one_hot
|
| 11 |
|
| 12 |
N_SAMPLES = 5
|
| 13 |
|
| 14 |
|
| 15 |
+
def generate_linkers(ddpm, data, sample_fn, name, with_pocket=False):
|
| 16 |
+
chain = node_mask = None
|
| 17 |
+
for i in range(5):
|
| 18 |
+
try:
|
| 19 |
+
chain, node_mask = ddpm.sample_chain(data, sample_fn=sample_fn, keep_frames=1)
|
| 20 |
+
break
|
| 21 |
+
except FoundNaNException:
|
| 22 |
+
continue
|
| 23 |
+
|
| 24 |
print('Generated linker')
|
| 25 |
x = chain[0][:, :, :ddpm.n_dims]
|
| 26 |
h = chain[0][:, :, ddpm.n_dims:]
|
| 27 |
|
| 28 |
# Put the molecule back to the initial orientation
|
| 29 |
+
if with_pocket:
|
| 30 |
+
com_mask = data['fragment_only_mask'] if ddpm.center_of_mass == 'fragments' else data['anchors']
|
| 31 |
+
else:
|
| 32 |
+
com_mask = data['fragment_mask'] if ddpm.center_of_mass == 'fragments' else data['anchors']
|
| 33 |
+
|
| 34 |
+
pos_masked = data['positions'] * com_mask
|
| 35 |
+
N = com_mask.sum(1, keepdims=True)
|
| 36 |
mean = torch.sum(pos_masked, dim=1, keepdim=True) / N
|
| 37 |
x = x + mean * node_mask
|
| 38 |
|
| 39 |
+
if with_pocket:
|
| 40 |
+
node_mask[torch.where(data['pocket_mask'])] = 0
|
| 41 |
+
|
| 42 |
names = [f'output_{i + 1}_{name}' for i in range(N_SAMPLES)]
|
| 43 |
save_xyz_file('results', h, x, node_mask, names=names, is_geom=True, suffix='')
|
| 44 |
print('Saved XYZ files')
|
|
|
|
| 56 |
out_files.append(out_xyz)
|
| 57 |
|
| 58 |
return out_files
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def get_pocket(mol, pdb_path):
|
| 62 |
+
struct = PDBParser().get_structure('', pdb_path)
|
| 63 |
+
residue_ids = []
|
| 64 |
+
atom_coords = []
|
| 65 |
+
|
| 66 |
+
for residue in struct.get_residues():
|
| 67 |
+
resid = residue.get_id()[1]
|
| 68 |
+
for atom in residue.get_atoms():
|
| 69 |
+
atom_coords.append(atom.get_coord())
|
| 70 |
+
residue_ids.append(resid)
|
| 71 |
+
|
| 72 |
+
residue_ids = np.array(residue_ids)
|
| 73 |
+
atom_coords = np.array(atom_coords)
|
| 74 |
+
mol_atom_coords = mol.GetConformer().GetPositions()
|
| 75 |
+
|
| 76 |
+
distances = np.linalg.norm(atom_coords[:, None, :] - mol_atom_coords[None, :, :], axis=-1)
|
| 77 |
+
contact_residues = np.unique(residue_ids[np.where(distances.min(1) <= 6)[0]])
|
| 78 |
+
|
| 79 |
+
pocket_coords_full = []
|
| 80 |
+
pocket_types_full = []
|
| 81 |
+
|
| 82 |
+
pocket_coords_bb = []
|
| 83 |
+
pocket_types_bb = []
|
| 84 |
+
|
| 85 |
+
for residue in struct.get_residues():
|
| 86 |
+
resid = residue.get_id()[1]
|
| 87 |
+
if resid not in contact_residues:
|
| 88 |
+
continue
|
| 89 |
+
|
| 90 |
+
for atom in residue.get_atoms():
|
| 91 |
+
atom_name = atom.get_name()
|
| 92 |
+
atom_type = atom.element.upper()
|
| 93 |
+
atom_coord = atom.get_coord()
|
| 94 |
+
|
| 95 |
+
pocket_coords_full.append(atom_coord.tolist())
|
| 96 |
+
pocket_types_full.append(atom_type)
|
| 97 |
+
|
| 98 |
+
if atom_name in {'N', 'CA', 'C', 'O'}:
|
| 99 |
+
pocket_coords_bb.append(atom_coord.tolist())
|
| 100 |
+
pocket_types_bb.append(atom_type)
|
| 101 |
+
|
| 102 |
+
pocket_pos = []
|
| 103 |
+
pocket_one_hot = []
|
| 104 |
+
pocket_charges = []
|
| 105 |
+
for coord, atom_type in zip(pocket_coords_full, pocket_types_full):
|
| 106 |
+
if atom_type not in const.GEOM_ATOM2IDX.keys():
|
| 107 |
+
continue
|
| 108 |
+
|
| 109 |
+
pocket_pos.append(coord)
|
| 110 |
+
pocket_one_hot.append(get_one_hot(atom_type, const.GEOM_ATOM2IDX))
|
| 111 |
+
pocket_charges.append(const.GEOM_CHARGES[atom_type])
|
| 112 |
+
|
| 113 |
+
pocket_pos = np.array(pocket_pos)
|
| 114 |
+
pocket_one_hot = np.array(pocket_one_hot)
|
| 115 |
+
pocket_charges = np.array(pocket_charges)
|
| 116 |
+
|
| 117 |
+
return pocket_pos, pocket_one_hot, pocket_charges
|
src/lightning.py
CHANGED
|
@@ -21,7 +21,6 @@ from pdb import set_trace
|
|
| 21 |
|
| 22 |
|
| 23 |
def get_activation(activation):
|
| 24 |
-
print(activation)
|
| 25 |
if activation == 'silu':
|
| 26 |
return torch.nn.SiLU()
|
| 27 |
else:
|
|
@@ -158,7 +157,7 @@ class DDPM(pl.LightningModule):
|
|
| 158 |
context = fragment_mask
|
| 159 |
|
| 160 |
# Add information about pocket to the context
|
| 161 |
-
if
|
| 162 |
fragment_pocket_mask = fragment_mask
|
| 163 |
fragment_only_mask = data['fragment_only_mask']
|
| 164 |
pocket_only_mask = fragment_pocket_mask - fragment_only_mask
|
|
@@ -170,6 +169,8 @@ class DDPM(pl.LightningModule):
|
|
| 170 |
# Removing COM of fragment from the atom coordinates
|
| 171 |
if self.inpainting:
|
| 172 |
center_of_mass_mask = node_mask
|
|
|
|
|
|
|
| 173 |
elif self.center_of_mass == 'fragments':
|
| 174 |
center_of_mass_mask = fragment_mask
|
| 175 |
elif self.center_of_mass == 'anchors':
|
|
@@ -423,9 +424,9 @@ class DDPM(pl.LightningModule):
|
|
| 423 |
context = fragment_mask
|
| 424 |
|
| 425 |
# Add information about pocket to the context
|
| 426 |
-
if
|
| 427 |
fragment_pocket_mask = fragment_mask
|
| 428 |
-
fragment_only_mask =
|
| 429 |
pocket_only_mask = fragment_pocket_mask - fragment_only_mask
|
| 430 |
if self.anchors_context:
|
| 431 |
context = torch.cat([anchors, fragment_only_mask, pocket_only_mask], dim=-1)
|
|
@@ -435,6 +436,8 @@ class DDPM(pl.LightningModule):
|
|
| 435 |
# Removing COM of fragment from the atom coordinates
|
| 436 |
if self.inpainting:
|
| 437 |
center_of_mass_mask = node_mask
|
|
|
|
|
|
|
| 438 |
elif self.center_of_mass == 'fragments':
|
| 439 |
center_of_mass_mask = fragment_mask
|
| 440 |
elif self.center_of_mass == 'anchors':
|
|
|
|
| 21 |
|
| 22 |
|
| 23 |
def get_activation(activation):
|
|
|
|
| 24 |
if activation == 'silu':
|
| 25 |
return torch.nn.SiLU()
|
| 26 |
else:
|
|
|
|
| 157 |
context = fragment_mask
|
| 158 |
|
| 159 |
# Add information about pocket to the context
|
| 160 |
+
if isinstance(self.train_dataset, MOADDataset):
|
| 161 |
fragment_pocket_mask = fragment_mask
|
| 162 |
fragment_only_mask = data['fragment_only_mask']
|
| 163 |
pocket_only_mask = fragment_pocket_mask - fragment_only_mask
|
|
|
|
| 169 |
# Removing COM of fragment from the atom coordinates
|
| 170 |
if self.inpainting:
|
| 171 |
center_of_mass_mask = node_mask
|
| 172 |
+
elif isinstance(self.train_dataset, MOADDataset) and self.center_of_mass == 'fragments':
|
| 173 |
+
center_of_mass_mask = data['fragment_only_mask']
|
| 174 |
elif self.center_of_mass == 'fragments':
|
| 175 |
center_of_mass_mask = fragment_mask
|
| 176 |
elif self.center_of_mass == 'anchors':
|
|
|
|
| 424 |
context = fragment_mask
|
| 425 |
|
| 426 |
# Add information about pocket to the context
|
| 427 |
+
if isinstance(self.val_dataset, MOADDataset):
|
| 428 |
fragment_pocket_mask = fragment_mask
|
| 429 |
+
fragment_only_mask = template_data['fragment_only_mask']
|
| 430 |
pocket_only_mask = fragment_pocket_mask - fragment_only_mask
|
| 431 |
if self.anchors_context:
|
| 432 |
context = torch.cat([anchors, fragment_only_mask, pocket_only_mask], dim=-1)
|
|
|
|
| 436 |
# Removing COM of fragment from the atom coordinates
|
| 437 |
if self.inpainting:
|
| 438 |
center_of_mass_mask = node_mask
|
| 439 |
+
elif isinstance(self.val_dataset, MOADDataset) and self.center_of_mass == 'fragments':
|
| 440 |
+
center_of_mass_mask = template_data['fragment_only_mask']
|
| 441 |
elif self.center_of_mass == 'fragments':
|
| 442 |
center_of_mass_mask = fragment_mask
|
| 443 |
elif self.center_of_mass == 'anchors':
|