Spaces:
Runtime error
Runtime error
| import argparse | |
| import shutil | |
| import gradio as gr | |
| import numpy as np | |
| import os | |
| import torch | |
| import subprocess | |
| import output | |
| from rdkit import Chem | |
| from src import const | |
| from src.datasets import get_dataloader, collate_with_fragment_edges, parse_molecule, MOADDataset | |
| from src.lightning import DDPM | |
| from src.linker_size_lightning import SizeClassifier | |
| from src.generation import generate_linkers, try_to_convert_to_sdf, get_pocket | |
| from zipfile import ZipFile | |
# Registry of the diffusion checkpoints used by the app. Each entry maps a
# model name to the URL it is downloaded from ('link') and the local file it
# is stored in and loaded from ('path').
MODELS_METADATA = dict(
    geom_difflinker=dict(
        link='https://zenodo.org/record/7121300/files/geom_difflinker.ckpt?download=1',
        path='models/geom_difflinker.ckpt',
    ),
    geom_difflinker_given_anchors=dict(
        link='https://zenodo.org/record/7775568/files/geom_difflinker_given_anchors.ckpt?download=1',
        path='models/geom_difflinker_given_anchors.ckpt',
    ),
    pockets_difflinker=dict(
        link='https://zenodo.org/record/7775568/files/pockets_difflinker_full_no_anchors.ckpt?download=1',
        path='models/pockets_difflinker.ckpt',
    ),
    pockets_difflinker_given_anchors=dict(
        link='https://zenodo.org/record/7775568/files/pockets_difflinker_full.ckpt?download=1',
        path='models/pockets_difflinker_given_anchors.ckpt',
    ),
)
# Command-line options: --ip is handed to demo.launch() as the server name.
parser = argparse.ArgumentParser()
parser.add_argument('--ip', type=str, default=None)
args = parser.parse_args()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'Device: {device}')

os.makedirs("results", exist_ok=True)
os.makedirs("models", exist_ok=True)


def _download_checkpoint(link, path):
    """Download ``link`` to ``path`` with wget, failing loudly on errors.

    The file is first written to ``<path>.part`` and only moved to ``path``
    once wget has exited successfully. This way an interrupted or failed
    download never leaves a truncated file at ``path`` that a later start-up
    would mistake for a cached checkpoint.

    Raises:
        subprocess.CalledProcessError: wget exited with a non-zero status.
    """
    partial_path = f'{path}.part'
    try:
        # Argument list instead of a shell string: the URL contains '?',
        # which a shell may treat as a glob character.
        subprocess.run(['wget', link, '-O', partial_path], check=True)
        os.replace(partial_path, path)
    finally:
        # After a successful os.replace the partial file no longer exists.
        if os.path.exists(partial_path):
            os.remove(partial_path)


# Linker-size classifier, used when the user leaves the linker size at 0.
size_gnn_path = 'models/geom_size_gnn.ckpt'
if not os.path.exists(size_gnn_path):
    print('Downloading SizeGNN model...')
    link = 'https://zenodo.org/record/7121300/files/geom_size_gnn.ckpt?download=1'
    _download_checkpoint(link, size_gnn_path)
size_nn = SizeClassifier.load_from_checkpoint(size_gnn_path, map_location=device).eval().to(device)
print('Loaded SizeGNN model')

# All diffusion models are loaded eagerly and kept in memory, keyed by name.
diffusion_models = {}
for model_name, metadata in MODELS_METADATA.items():
    link = metadata['link']
    diffusion_path = metadata['path']
    if not os.path.exists(diffusion_path):
        print(f'Downloading {model_name}...')
        _download_checkpoint(link, diffusion_path)
    diffusion_models[model_name] = DDPM.load_from_checkpoint(diffusion_path, map_location=device).eval().to(device)
    print(f'Loaded model {model_name}')

# Start-up diagnostics: working directory and its contents.
print(os.curdir)
print(os.path.abspath(os.curdir))
print(os.listdir(os.curdir))
def read_molecule_content(path):
    """Return the whole text content of the file at ``path`` as one string."""
    with open(path, "r") as handle:
        return handle.read()
def read_molecule(path):
    """Load one RDKit molecule from ``path``, picking the parser by file suffix.

    Supported suffixes are ``.pdb``, ``.mol``, ``.mol2`` and ``.sdf`` (for
    ``.sdf`` only the first record is returned). All parsers are called with
    ``sanitize=False`` and ``removeHs=True``.

    Args:
        path: path of the molecule file.

    Returns:
        The parsed RDKit molecule; never None.

    Raises:
        ValueError: if the suffix is not supported, or if RDKit could not
            parse the file. Errors raised by RDKit itself propagate unchanged.
    """
    if path.endswith('.pdb'):
        mol = Chem.MolFromPDBFile(path, sanitize=False, removeHs=True)
    elif path.endswith('.mol'):
        mol = Chem.MolFromMolFile(path, sanitize=False, removeHs=True)
    elif path.endswith('.mol2'):
        mol = Chem.MolFromMol2File(path, sanitize=False, removeHs=True)
    elif path.endswith('.sdf'):
        mol = Chem.SDMolSupplier(path, sanitize=False, removeHs=True)[0]
    else:
        raise ValueError('Unknown file extension')

    # RDKit reports a failed parse by returning None rather than raising.
    # Turn that into an exception here so callers, which already wrap this
    # call in `except Exception`, show an error instead of crashing later.
    if mol is None:
        raise ValueError('Could not parse the molecule file')
    return mol
def read_molecule_file(in_file, allowed_extentions):
    """Read an uploaded molecule and serialise it for the HTML viewer.

    Args:
        in_file: a path string, or an object exposing the path as ``.name``.
        allowed_extentions: file extensions (without the dot) accepted for
            this input.

    Returns:
        A ``(content, extension, error_html)`` tuple. On success ``content``
        is a PDB block (for ``pdb``) or a MOL block (for ``mol``, ``mol2``
        and ``sdf``, in which case ``extension`` is reported as ``'mol'``)
        and ``error_html`` is None. On failure ``content`` and ``extension``
        are None and ``error_html`` holds the formatted message.

    Raises:
        NotImplementedError: the extension is allowed by the caller but has
            no serialisation branch here.
    """
    path = in_file if isinstance(in_file, str) else in_file.name

    extension = path.split('.')[-1]
    if extension not in allowed_extentions:
        msg = output.INVALID_FORMAT_MSG.format(extension=extension)
        return None, None, msg

    try:
        mol = read_molecule(path)
        # A parser may hand back None for an unreadable file. Report it
        # through the same error path instead of letting the serialisers
        # below raise on None.
        if mol is None:
            raise ValueError('Could not parse the molecule file')
    except Exception as e:
        # Single quotes are stripped before formatting the message;
        # presumably they would break the HTML/JS template - TODO confirm.
        message = str(e).replace("'", '')
        msg = output.ERROR_FORMAT_MSG.format(message=message)
        return None, None, msg

    if extension == 'pdb':
        content = Chem.MolToPDBBlock(mol)
    elif extension in ['mol', 'mol2', 'sdf']:
        content = Chem.MolToMolBlock(mol, kekulize=False)
        extension = 'mol'
    else:
        raise NotImplementedError

    return content, extension, None
def show_input(in_fragments, in_protein):
    """Render whichever inputs are currently uploaded.

    Returns the values for the ``[visualization, samples, hidden]`` outputs:
    the viewer HTML (empty string when nothing is uploaded), an update that
    empties and hides the samples dropdown, and None for the hidden textbox.
    """
    has_fragments = in_fragments is not None
    has_protein = in_protein is not None

    if has_fragments and has_protein:
        vis = show_fragments_and_target(in_fragments, in_protein)
    elif has_fragments:
        vis = show_fragments(in_fragments)
    elif has_protein:
        vis = show_target(in_protein)
    else:
        vis = ''

    dropdown_reset = gr.Dropdown.update(choices=[], value=None, visible=False)
    return [vis, dropdown_reset, None]
def show_fragments(in_fragments):
    """Return the iframe HTML showing the fragments, or the read error."""
    content, fmt, error_html = read_molecule_file(
        in_fragments,
        allowed_extentions=['sdf', 'pdb', 'mol', 'mol2'],
    )
    if content is None:
        inner = error_html
    else:
        inner = output.FRAGMENTS_RENDERING_TEMPLATE.format(molecule=content, fmt=fmt)
    return output.IFRAME_TEMPLATE.format(html=inner)
def show_target(in_protein):
    """Return the iframe HTML showing the target protein, or the read error."""
    content, fmt, error_html = read_molecule_file(in_protein, allowed_extentions=['pdb'])
    if content is None:
        inner = error_html
    else:
        inner = output.TARGET_RENDERING_TEMPLATE.format(molecule=content, fmt=fmt)
    return output.IFRAME_TEMPLATE.format(html=inner)
def show_fragments_and_target(in_fragments, in_protein):
    """Return the iframe HTML showing the fragments together with the protein.

    If either file cannot be read, the iframe contains that file's error
    message instead of the 3D view.
    """
    fragments_molecule, fragments_extension, msg = read_molecule_file(in_fragments, ['sdf', 'pdb', 'mol', 'mol2'])
    if fragments_molecule is None:
        return output.IFRAME_TEMPLATE.format(html=msg)

    target_molecule, target_extension, msg = read_molecule_file(in_protein, allowed_extentions=['pdb'])
    # Check the protein result here. The original re-tested the fragments,
    # so a protein read error was never shown and None went into the template.
    if target_molecule is None:
        return output.IFRAME_TEMPLATE.format(html=msg)

    html = output.FRAGMENTS_AND_TARGET_RENDERING_TEMPLATE.format(
        molecule=fragments_molecule,
        fmt=fragments_extension,
        target=target_molecule,
        target_fmt=target_extension,
    )
    return output.IFRAME_TEMPLATE.format(html=html)
def clear_fragments_input(in_protein):
    """Handle clearing of the fragments upload.

    Returns values for ``[input_fragments_file, visualization, samples,
    hidden]``: the file input is reset, the viewer falls back to the protein
    alone (or nothing), and the samples dropdown is emptied and hidden.
    """
    vis = show_target(in_protein) if in_protein is not None else ''
    dropdown_reset = gr.Dropdown.update(choices=[], value=None, visible=False)
    return [None, vis, dropdown_reset, None]
def clear_protein_input(in_fragments):
    """Handle clearing of the protein upload.

    Returns values for ``[input_protein_file, visualization, samples,
    hidden]``: the file input is reset, the viewer falls back to the
    fragments alone (or nothing), and the samples dropdown is emptied and
    hidden.
    """
    vis = show_fragments(in_fragments) if in_fragments is not None else ''
    dropdown_reset = gr.Dropdown.update(choices=[], value=None, visible=False)
    return [None, vis, dropdown_reset, None]
def click_on_example(example):
    """Load a bundled example into both file inputs and render it.

    ``example`` is a ``(fragments file name, protein file name)`` pair; an
    empty string means that input is absent. The result is the two resolved
    paths followed by everything ``show_input`` returns for them.
    """
    fragment_fname, target_fname = example
    fragment_path = None if fragment_fname == '' else f'examples/{fragment_fname}'
    target_path = None if target_fname == '' else f'examples/{target_fname}'
    rendered = show_input(fragment_path, target_path)
    return [fragment_path, target_path] + rendered
def draw_sample(sample_path, out_files, num_samples):
    """Return the iframe HTML showing one generated sample over the inputs.

    ``out_files`` is laid out as ``[archive, input fragments, (input
    protein,) sample_1, ...]``. The protein entry is considered present when
    the list holds exactly ``num_samples + 3`` items. Every entry, like
    ``sample_path``, may be a path string or an object exposing ``.name``.
    """
    def _as_path(file_or_path):
        # Entries arrive either as plain paths or as file-like objects.
        return file_or_path if isinstance(file_or_path, str) else file_or_path.name

    with_protein = len(out_files) == num_samples + 3

    fragments_path = _as_path(out_files[1])
    template_fields = {
        'fragments': read_molecule_content(fragments_path),
        'fragments_fmt': fragments_path.split('.')[-1],
    }

    if with_protein:
        protein_path = _as_path(out_files[2])
        template_fields['target'] = read_molecule_content(protein_path)
        template_fields['target_fmt'] = protein_path.split('.')[-1]

    generated_path = _as_path(sample_path)
    template_fields['molecule'] = read_molecule_content(generated_path)
    template_fields['molecule_fmt'] = generated_path.split('.')[-1]

    if with_protein:
        template = output.SAMPLES_WITH_TARGET_RENDERING_TEMPLATE
    else:
        template = output.SAMPLES_RENDERING_TEMPLATE
    html = template.format(**template_fields)

    return output.IFRAME_TEMPLATE.format(html=html)
def compress(output_fnames, name):
    """Bundle ``output_fnames`` into ``results/all_files_<name>.zip``.

    Files are stored under the same names they were passed with. Returns the
    path of the archive.
    """
    archive_path = f'results/all_files_{name}.zip'
    with ZipFile(archive_path, 'w') as bundle:
        for member in output_fnames:
            bundle.write(member)
    return archive_path
def generate(in_fragments, in_protein, n_steps, n_atoms, num_samples, selected_atoms):
    """Entry point of the "Generate" button.

    Returns four Nones when no fragments are uploaded; otherwise delegates to
    the pocket-conditioned generator when a protein is given and to the
    pocket-free one when it is not.
    """
    if in_fragments is None:
        return [None] * 4

    if in_protein is not None:
        return generate_with_pocket(in_fragments, in_protein, n_steps, n_atoms, num_samples, selected_atoms)

    return generate_without_pocket(in_fragments, n_steps, n_atoms, num_samples, selected_atoms)
def generate_without_pocket(input_file, n_steps, n_atoms, num_samples, selected_atoms):
    """Generate linkers for the uploaded fragments without a protein pocket.

    Args:
        input_file: the fragments file, as a path string or an object
            exposing the path as ``.name``.
        n_steps: number of denoising steps; assigned to ``ddpm.edm.T``.
        n_atoms: linker size to generate; ``0`` means the size is sampled
            from the SizeGNN classifier.
        num_samples: number of samples to generate (also used as batch size).
        selected_atoms: comma-separated atom indices (JavaScript output) to
            use as anchors; blank or None selects the model without anchors.

    Returns:
        A 4-item list for the ``[visualization, output_files, samples,
        hidden]`` outputs. On a handled failure the first item is an iframe
        with the error message and the other three are None.
    """
    # Parsing selected atoms (javascript output)
    selected_atoms = (selected_atoms or '').strip()
    if selected_atoms == '':
        selected_atoms = []
    else:
        selected_atoms = list(map(int, selected_atoms.split(',')))

    # Selecting model
    if len(selected_atoms) == 0:
        selected_model_name = 'geom_difflinker'
    else:
        selected_model_name = 'geom_difflinker_given_anchors'

    print(f'Start generating with model {selected_model_name}, selected_atoms:', selected_atoms)
    ddpm = diffusion_models[selected_model_name]

    # Accept both plain paths and file objects, like the rendering helpers.
    path = input_file if isinstance(input_file, str) else input_file.name
    extension = path.split('.')[-1]
    if extension not in ['sdf', 'pdb', 'mol', 'mol2']:
        msg = output.INVALID_FORMAT_MSG.format(extension=extension)
        return [output.IFRAME_TEMPLATE.format(html=msg), None, None, None]

    try:
        molecule = read_molecule(path)
        # RDKit parsers return None on a failed parse. Report that as a read
        # error here instead of crashing on None further down.
        if molecule is None:
            raise ValueError('the file could not be parsed')
        try:
            molecule = Chem.RemoveAllHs(molecule)
        except Exception as h_error:
            # Best effort: keep the molecule as parsed if hydrogens cannot
            # be stripped, but leave a trace in the log.
            print(f'Could not remove hydrogens, keeping the molecule as is: {h_error}')
        name = '.'.join(path.split('/')[-1].split('.')[:-1])
        inp_sdf = f'results/input_{name}.sdf'
    except Exception as e:
        e = str(e).replace("'", '')
        error = f'Could not read the molecule: {e}'
        msg = output.ERROR_FORMAT_MSG.format(message=error)
        return [output.IFRAME_TEMPLATE.format(html=msg), None, None, None]

    if molecule.GetNumAtoms() > 100:
        error = 'Too large molecule: upper limit is 100 heavy atoms'
        msg = output.ERROR_FORMAT_MSG.format(message=error)
        return [output.IFRAME_TEMPLATE.format(html=msg), None, None, None]

    # Save the parsed input next to the results so it can be downloaded and
    # drawn together with the samples.
    with Chem.SDWriter(inp_sdf) as w:
        w.SetKekulize(False)
        w.write(molecule)

    positions, one_hot, charges = parse_molecule(molecule, is_geom=True)
    anchors = np.zeros_like(charges)
    anchors[selected_atoms] = 1
    fragment_mask = np.ones_like(charges)
    linker_mask = np.zeros_like(charges)
    print('Read and parsed molecule')

    # The same input entry is repeated num_samples times so that one batch
    # yields all samples.
    dataset = [{
        'uuid': '0',
        'name': '0',
        'positions': torch.tensor(positions, dtype=const.TORCH_FLOAT, device=device),
        'one_hot': torch.tensor(one_hot, dtype=const.TORCH_FLOAT, device=device),
        'charges': torch.tensor(charges, dtype=const.TORCH_FLOAT, device=device),
        'anchors': torch.tensor(anchors, dtype=const.TORCH_FLOAT, device=device),
        'fragment_mask': torch.tensor(fragment_mask, dtype=const.TORCH_FLOAT, device=device),
        'linker_mask': torch.tensor(linker_mask, dtype=const.TORCH_FLOAT, device=device),
        'num_atoms': len(positions),
    }] * num_samples
    dataloader = get_dataloader(dataset, batch_size=num_samples, collate_fn=collate_with_fragment_edges)
    print('Created dataloader')

    ddpm.edm.T = n_steps

    if n_atoms == 0:
        def sample_fn(_data):
            # Sample a linker size per batch element from the size classifier.
            out, _ = size_nn.forward(_data, return_loss=False)
            probabilities = torch.softmax(out, dim=1)
            distribution = torch.distributions.Categorical(probs=probabilities)
            samples = distribution.sample()
            sizes = [size_nn.linker_id2size[label] for label in samples.detach().cpu().numpy()]
            return torch.tensor(sizes, device=samples.device, dtype=torch.long)
    else:
        def sample_fn(_data):
            # Fixed, user-requested linker size for every batch element.
            return torch.ones(_data['positions'].shape[0], device=device, dtype=torch.long) * n_atoms

    for data in dataloader:
        try:
            generate_linkers(
                ddpm=ddpm, data=data, num_samples=num_samples, sample_fn=sample_fn, name=name, with_pocket=False
            )
        except Exception as e:
            e = str(e).replace("'", '')
            error = f'Caught exception while generating linkers: {e}'
            msg = output.ERROR_FORMAT_MSG.format(message=error)
            return [output.IFRAME_TEMPLATE.format(html=msg), None, None, None]

    # Layout expected by draw_sample: [archive, input fragments, samples...].
    out_files = try_to_convert_to_sdf(name, num_samples)
    out_files = [inp_sdf] + out_files
    out_files = [compress(out_files, name=name)] + out_files
    choice = out_files[2]

    return [
        draw_sample(choice, out_files, num_samples),
        out_files,
        gr.Dropdown.update(
            choices=out_files[2:],
            value=choice,
            visible=True,
        ),
        None
    ]
def generate_with_pocket(in_fragments, in_protein, n_steps, n_atoms, num_samples, selected_atoms):
    """Generate linkers for the uploaded fragments conditioned on a protein pocket.

    Args:
        in_fragments: the fragments file, as a path string or an object
            exposing the path as ``.name``.
        in_protein: the target protein (``.pdb``), same conventions.
        n_steps: number of denoising steps; assigned to ``ddpm.edm.T``.
        n_atoms: linker size to generate; ``0`` means the size is sampled
            from the SizeGNN classifier.
        num_samples: number of samples to generate (also used as batch size).
        selected_atoms: comma-separated atom indices (JavaScript output) to
            use as anchors; blank or None selects the model without anchors.

    Returns:
        A 4-item list for the ``[visualization, output_files, samples,
        hidden]`` outputs. On a handled failure the first item is an iframe
        with the error message and the other three are None.
    """
    # Parsing selected atoms (javascript output)
    selected_atoms = (selected_atoms or '').strip()
    if selected_atoms == '':
        selected_atoms = []
    else:
        selected_atoms = list(map(int, selected_atoms.split(',')))

    # Selecting model
    if len(selected_atoms) == 0:
        selected_model_name = 'pockets_difflinker'
    else:
        selected_model_name = 'pockets_difflinker_given_anchors'

    print(f'Start generating with model {selected_model_name}, selected_atoms:', selected_atoms)
    ddpm = diffusion_models[selected_model_name]

    # Accept both plain paths and file objects, like the rendering helpers.
    fragments_path = in_fragments if isinstance(in_fragments, str) else in_fragments.name
    fragments_extension = fragments_path.split('.')[-1]
    if fragments_extension not in ['sdf', 'pdb', 'mol', 'mol2']:
        msg = output.INVALID_FORMAT_MSG.format(extension=fragments_extension)
        return [output.IFRAME_TEMPLATE.format(html=msg), None, None, None]

    protein_path = in_protein if isinstance(in_protein, str) else in_protein.name
    protein_extension = protein_path.split('.')[-1]
    if protein_extension not in ['pdb']:
        msg = output.INVALID_FORMAT_MSG.format(extension=protein_extension)
        return [output.IFRAME_TEMPLATE.format(html=msg), None, None, None]

    try:
        fragments_mol = read_molecule(fragments_path)
        # RDKit parsers return None on a failed parse. Report that as a read
        # error here instead of crashing on None further down.
        if fragments_mol is None:
            raise ValueError('the file could not be parsed')
        name = '.'.join(fragments_path.split('/')[-1].split('.')[:-1])
    except Exception as e:
        e = str(e).replace("'", '')
        error = f'Could not read the molecule: {e}'
        msg = output.ERROR_FORMAT_MSG.format(message=error)
        return [output.IFRAME_TEMPLATE.format(html=msg), None, None, None]

    if fragments_mol.GetNumAtoms() > 100:
        error = 'Too large molecule: upper limit is 100 heavy atoms'
        msg = output.ERROR_FORMAT_MSG.format(message=error)
        return [output.IFRAME_TEMPLATE.format(html=msg), None, None, None]

    # Save both inputs next to the results so they can be downloaded and
    # drawn together with the samples.
    inp_sdf = f'results/input_{name}.sdf'
    with Chem.SDWriter(inp_sdf) as w:
        w.SetKekulize(False)
        w.write(fragments_mol)

    inp_pdb = f'results/target_{name}.pdb'
    shutil.copy(protein_path, inp_pdb)

    frag_pos, frag_one_hot, frag_charges = parse_molecule(fragments_mol, is_geom=True)
    pocket_pos, pocket_one_hot, pocket_charges = get_pocket(fragments_mol, protein_path)
    print(f'Detected pocket with {len(pocket_pos)} atoms')

    # Fragment atoms come first, pocket atoms second, in every array below.
    positions = np.concatenate([frag_pos, pocket_pos], axis=0)
    one_hot = np.concatenate([frag_one_hot, pocket_one_hot], axis=0)
    charges = np.concatenate([frag_charges, pocket_charges], axis=0)
    anchors = np.zeros_like(charges)
    anchors[selected_atoms] = 1

    fragment_only_mask = np.concatenate([
        np.ones_like(frag_charges),
        np.zeros_like(pocket_charges),
    ])
    pocket_mask = np.concatenate([
        np.zeros_like(frag_charges),
        np.ones_like(pocket_charges),
    ])
    linker_mask = np.concatenate([
        np.zeros_like(frag_charges),
        np.zeros_like(pocket_charges),
    ])
    # Here the fragment mask covers fragment and pocket atoms alike.
    fragment_mask = np.concatenate([
        np.ones_like(frag_charges),
        np.ones_like(pocket_charges),
    ])
    print('Read and parsed molecule')

    # The same input entry is repeated num_samples times so that one batch
    # yields all samples.
    dataset = [{
        'uuid': '0',
        'name': '0',
        'positions': torch.tensor(positions, dtype=const.TORCH_FLOAT, device=device),
        'one_hot': torch.tensor(one_hot, dtype=const.TORCH_FLOAT, device=device),
        'charges': torch.tensor(charges, dtype=const.TORCH_FLOAT, device=device),
        'anchors': torch.tensor(anchors, dtype=const.TORCH_FLOAT, device=device),
        'fragment_only_mask': torch.tensor(fragment_only_mask, dtype=const.TORCH_FLOAT, device=device),
        'pocket_mask': torch.tensor(pocket_mask, dtype=const.TORCH_FLOAT, device=device),
        'fragment_mask': torch.tensor(fragment_mask, dtype=const.TORCH_FLOAT, device=device),
        'linker_mask': torch.tensor(linker_mask, dtype=const.TORCH_FLOAT, device=device),
        'num_atoms': len(positions),
    }] * num_samples
    dataset = MOADDataset(data=dataset)
    ddpm.val_dataset = dataset

    dataloader = get_dataloader(dataset, batch_size=num_samples, collate_fn=collate_with_fragment_edges)
    print('Created dataloader')

    ddpm.edm.T = n_steps

    if n_atoms == 0:
        def sample_fn(_data):
            # Sample a linker size per batch element from the size classifier.
            out, _ = size_nn.forward(_data, return_loss=False)
            probabilities = torch.softmax(out, dim=1)
            distribution = torch.distributions.Categorical(probs=probabilities)
            samples = distribution.sample()
            sizes = [size_nn.linker_id2size[label] for label in samples.detach().cpu().numpy()]
            return torch.tensor(sizes, device=samples.device, dtype=torch.long)
    else:
        def sample_fn(_data):
            # Fixed, user-requested linker size for every batch element.
            return torch.ones(_data['positions'].shape[0], device=device, dtype=torch.long) * n_atoms

    for data in dataloader:
        try:
            generate_linkers(
                ddpm=ddpm, data=data, num_samples=num_samples, sample_fn=sample_fn, name=name, with_pocket=True
            )
        except Exception as e:
            e = str(e).replace("'", '')
            error = f'Caught exception while generating linkers: {e}'
            msg = output.ERROR_FORMAT_MSG.format(message=error)
            return [output.IFRAME_TEMPLATE.format(html=msg), None, None, None]

    # Layout expected by draw_sample:
    # [archive, input fragments, input protein, samples...].
    out_files = try_to_convert_to_sdf(name, num_samples)
    out_files = [inp_sdf, inp_pdb] + out_files
    out_files = [compress(out_files, name=name)] + out_files
    choice = out_files[3]

    return [
        draw_sample(choice, out_files, num_samples),
        out_files,
        gr.Dropdown.update(
            choices=out_files[3:],
            value=choice,
            visible=True,
        ),
        None
    ]
# ---------------------------------------------------------------------------
# Gradio UI: inputs and downloads in the left column, 3D viewer in the right.
# ---------------------------------------------------------------------------
with gr.Blocks() as demo:
    gr.Markdown('# DiffLinker: Equivariant 3D-Conditional Diffusion Model for Molecular Linker Design')
    gr.Markdown(
        'Given a set of disconnected fragments in 3D, '
        'DiffLinker places missing atoms in between and designs a molecule incorporating all the initial fragments. '
        'Our method can link an arbitrary number of fragments, requires no information on the attachment atoms '
        'and linker size, and can be conditioned on the protein pockets.'
    )
    gr.Markdown(
        '[**[Paper]**](https://arxiv.org/abs/2210.05274) '
        '[**[Code]**](https://github.com/igashov/DiffLinker)'
    )

    with gr.Box():
        with gr.Row():
            # Left column: uploads, generation settings, examples, downloads.
            with gr.Column():
                gr.Markdown('## Input')
                gr.Markdown('Upload the file with 3D-coordinates of the input fragments in .pdb, .mol2 or .sdf format:')
                input_fragments_file = gr.File(file_count='single', label='Input Fragments')
                gr.Markdown('Upload the file of the target protein in .pdb format (optionally):')
                input_protein_file = gr.File(file_count='single', label='Target Protein (Optional)')

                n_steps = gr.Slider(minimum=50, maximum=500, label="Number of Denoising Steps", step=10)
                n_atoms = gr.Slider(
                    minimum=0,
                    maximum=20,
                    label="Linker Size: DiffLinker will predict it if set to 0",
                    step=1,
                )
                n_samples = gr.Slider(minimum=5, maximum=50, label="Number of Samples", step=5)

                examples = gr.Dataset(
                    components=[gr.File(visible=False), gr.File(visible=False)],
                    samples=[
                        ['examples/example_1.sdf', ''],
                        ['examples/example_2.sdf', ''],
                        ['examples/3hz1_fragments.sdf', 'examples/3hz1_protein.pdb'],
                        ['examples/5ou2_fragments.sdf', 'examples/5ou2_protein.pdb'],
                    ],
                    type='values',
                    headers=['Input Fragments', 'Target Protein'],
                )

                button = gr.Button('Generate Linker!')

                gr.Markdown('')
                gr.Markdown('## Output Files')
                gr.Markdown('Download files with the generated molecules here:')
                output_files = gr.File(file_count='multiple', label='Output Files', interactive=False)
                # Invisible textbox: passed to generate() as selected_atoms
                # and reset to None by every handler below.
                hidden = gr.Textbox(visible=False)

            # Right column: sample selector and the 3D viewer.
            with gr.Column():
                gr.Markdown('## Visualization')
                gr.Markdown('**Hint:** click on atoms to select anchor points (optionally)')
                samples = gr.Dropdown(
                    choices=[],
                    value=None,
                    type='value',
                    multiselect=False,
                    visible=False,
                    interactive=True,
                    label='Samples',
                )
                visualization = gr.HTML()

    # Re-render the viewer whenever either upload changes.
    for file_input in (input_fragments_file, input_protein_file):
        file_input.change(
            fn=show_input,
            inputs=[input_fragments_file, input_protein_file],
            outputs=[visualization, samples, hidden],
        )

    # Clearing one upload keeps the other one displayed.
    input_fragments_file.clear(
        fn=clear_fragments_input,
        inputs=[input_protein_file],
        outputs=[input_fragments_file, visualization, samples, hidden],
    )
    input_protein_file.clear(
        fn=clear_protein_input,
        inputs=[input_fragments_file],
        outputs=[input_protein_file, visualization, samples, hidden],
    )

    examples.click(
        fn=click_on_example,
        inputs=[examples],
        outputs=[input_fragments_file, input_protein_file, visualization, samples, hidden],
    )

    button.click(
        fn=generate,
        inputs=[input_fragments_file, input_protein_file, n_steps, n_atoms, n_samples, hidden],
        outputs=[visualization, output_files, samples, hidden],
        _js=output.RETURN_SELECTION_JS,
    )

    # Picking another entry in the dropdown redraws that sample.
    samples.select(
        fn=draw_sample,
        inputs=[samples, output_files, n_samples],
        outputs=[visualization],
    )

    demo.load(_js=output.STARTUP_JS)

demo.launch(server_name=args.ip)