Spaces:
Runtime error
Runtime error
| import os | |
| import yaml | |
| import json | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from . import diffusion_utils as utils | |
| from .molecule_utils import graph_to_smiles, check_valid | |
| from .transformer import Transformer | |
| from .visualize_utils import MolecularVisualization | |
class GraphDiT(nn.Module):
    """Discrete graph diffusion model built around a ``Transformer`` denoiser.

    The class bundles the denoiser, a predefined discrete noise schedule and a
    marginal transition model, and exposes:

    * ``apply_noise`` - draw a timestep and corrupt a clean graph;
    * ``generate`` - run the reverse chain and decode one molecule to SMILES;
    * ``sample_p_zs_given_zt`` - a single reverse step (with optional guidance).

    ``forward`` is deliberately not usable; the denoiser is called through
    ``_forward``.
    """

    def __init__(
        self,
        model_config_path,
        data_info_path,
        model_dtype,
    ):
        """Build the denoiser, the noise schedule and the transition model.

        Args:
            model_config_path: forwarded to ``utils.load_config``.
            data_info_path: forwarded to ``utils.load_config``.
            model_dtype: torch dtype the marginals/transition statistics are
                cast to, and that ``_forward`` casts denoiser inputs to.
        """
        super().__init__()
        dm_cfg, data_info = utils.load_config(model_config_path, data_info_path)

        self.model_config = dm_cfg
        self.data_info = data_info
        self.T = dm_cfg.diffusion_steps

        # Input / output feature sizes for nodes (X), edges (E) and labels (y).
        input_dims = data_info.input_dims
        output_dims = data_info.output_dims
        self.Xdim = input_dims["X"]
        self.Edim = input_dims["E"]
        self.ydim = input_dims["y"]
        self.Xdim_output = output_dims["X"]
        self.Edim_output = output_dims["E"]
        self.ydim_output = output_dims["y"]

        self.node_dist = data_info.nodes_dist
        self.active_index = data_info.active_index
        self.max_n_nodes = data_info.max_n_nodes
        self.atom_decoder = data_info.atom_decoder
        self.hidden_size = dm_cfg.hidden_size

        self.mol_visualizer = MolecularVisualization(self.atom_decoder)

        self.denoiser = Transformer(
            max_n_nodes=self.max_n_nodes,
            hidden_size=dm_cfg.hidden_size,
            depth=dm_cfg.depth,
            num_heads=dm_cfg.num_heads,
            mlp_ratio=dm_cfg.mlp_ratio,
            drop_condition=dm_cfg.drop_condition,
            Xdim=self.Xdim,
            Edim=self.Edim,
            ydim=self.ydim,
        )

        self.model_dtype = model_dtype
        self.noise_schedule = utils.PredefinedNoiseScheduleDiscrete(
            dm_cfg.diffusion_noise_schedule, timesteps=dm_cfg.diffusion_steps
        )

        # Marginal distributions over node and edge types.
        x_marginals = data_info.node_types.to(self.model_dtype) / torch.sum(
            data_info.node_types.to(self.model_dtype)
        )
        e_marginals = data_info.edge_types.to(self.model_dtype) / torch.sum(
            data_info.edge_types.to(self.model_dtype)
        )
        # Second normalisation pass is mathematically a no-op; kept so the
        # numerics match the original exactly (presumably it absorbs rounding
        # error in low-precision dtypes - TODO confirm).
        x_marginals = x_marginals / x_marginals.sum()
        e_marginals = e_marginals / e_marginals.sum()

        # Node<->edge conditional statistics, restricted to the active node
        # types, reduced over dim 1 and row-normalised in both directions.
        xe_conditions = data_info.transition_E.to(self.model_dtype)
        xe_conditions = xe_conditions[self.active_index][:, self.active_index]
        xe_conditions = xe_conditions.sum(dim=1)
        ex_conditions = xe_conditions.t()
        xe_conditions = xe_conditions / xe_conditions.sum(dim=-1, keepdim=True)
        ex_conditions = ex_conditions / ex_conditions.sum(dim=-1, keepdim=True)

        self.transition_model = utils.MarginalTransition(
            x_marginals=x_marginals,
            e_marginals=e_marginals,
            xe_conditions=xe_conditions,
            ex_conditions=ex_conditions,
            y_classes=self.ydim_output,
            n_nodes=self.max_n_nodes,
        )
        # Distribution the reverse chain starts from (see ``generate``).
        self.limit_dist = utils.PlaceHolder(X=x_marginals, E=e_marginals, y=None)

    def init_model(self, model_dir):
        """Load denoiser weights from ``<model_dir>/model.pt``.

        The checkpoint is loaded onto CPU with ``weights_only=True``.

        Raises:
            FileNotFoundError: if ``model.pt`` does not exist in ``model_dir``.
        """
        model_file = os.path.join(model_dir, 'model.pt')
        if not os.path.exists(model_file):
            raise FileNotFoundError(f"Model file not found: {model_file}")
        state_dict = torch.load(model_file, map_location='cpu', weights_only=True)
        self.denoiser.load_state_dict(state_dict)

    def disable_grads(self):
        """Delegate to ``self.denoiser.disable_grads()``."""
        self.denoiser.disable_grads()

    def forward(
        self, x, edge_index, edge_attr, graph_batch, properties, no_label_index
    ):
        """Not supported; always raises ``ValueError``. Use ``_forward``."""
        raise ValueError('Not Implement')

    def _forward(self, noisy_data, unconditioned=False):
        """Run the denoiser on a noisy graph.

        Args:
            noisy_data: dict providing ``"X_t"``, ``"E_t"``, ``"y_t"``,
                ``"node_mask"`` and ``"t"``.
            unconditioned: forwarded to the denoiser as a keyword argument.

        Returns:
            Whatever ``self.denoiser`` returns.
        """
        noisy_x = noisy_data["X_t"].to(self.model_dtype)
        noisy_e = noisy_data["E_t"].to(self.model_dtype)
        # Cloned so the denoiser never works on the caller's own y tensor.
        properties = noisy_data["y_t"].to(self.model_dtype).clone()
        node_mask = noisy_data["node_mask"]
        timestep = noisy_data["t"]

        return self.denoiser(
            noisy_x,
            noisy_e,
            node_mask,
            properties,
            timestep,
            unconditioned=unconditioned,
        )

    def apply_noise(self, X, E, y, node_mask):
        """Sample a timestep per graph and apply the matching noise.

        Args:
            X: node features; unpacked as ``(bs, n, d)``.
            E: edge features; reshaped to ``(bs, n, -1)`` internally.
            y: labels, passed through unchanged as ``y_t``.
            node_mask: node mask forwarded to the sampler and ``mask``.

        Returns:
            dict with keys ``t_int``, ``t``, ``beta_t``, ``alpha_s_bar``,
            ``alpha_t_bar``, ``X_t``, ``E_t``, ``y_t`` and ``node_mask``.
        """
        # When evaluating, the loss for t=0 is computed separately.
        lowest_t = 0 if self.training else 1
        t_int = torch.randint(
            lowest_t, self.T + 1, size=(X.size(0), 1), device=X.device
        ).to(self.model_dtype)  # (bs, 1)
        s_int = t_int - 1

        t_float = t_int / self.T
        s_float = s_int / self.T

        # beta_t and alpha_s_bar are used for denoising/loss computation.
        beta_t = self.noise_schedule(t_normalized=t_float)  # (bs, 1)
        alpha_s_bar = self.noise_schedule.get_alpha_bar(t_normalized=s_float)  # (bs, 1)
        alpha_t_bar = self.noise_schedule.get_alpha_bar(t_normalized=t_float)  # (bs, 1)

        Qtb = self.transition_model.get_Qt_bar(
            alpha_t_bar, X.device
        )  # (bs, dx_in, dx_out), (bs, de_in, de_out)

        # Node and (flattened) edge features are concatenated per node and
        # pushed through a single joint transition matrix.
        bs, n, _ = X.shape
        X_all = torch.cat([X, E.reshape(bs, n, -1)], dim=-1)
        prob_all = X_all @ Qtb.X
        probX = prob_all[:, :, : self.Xdim_output]
        probE = prob_all[:, :, self.Xdim_output :].reshape(bs, n, n, -1)

        sampled_t = utils.sample_discrete_features(
            probX=probX, probE=probE, node_mask=node_mask
        )
        X_t = F.one_hot(sampled_t.X, num_classes=self.Xdim_output)
        E_t = F.one_hot(sampled_t.E, num_classes=self.Edim_output)
        assert (X.shape == X_t.shape) and (E.shape == E_t.shape)

        y_t = y
        z_t = utils.PlaceHolder(X=X_t, E=E_t, y=y_t).type_as(X_t).mask(node_mask)

        return {
            "t_int": t_int,
            "t": t_float,
            "beta_t": beta_t,
            "alpha_s_bar": alpha_s_bar,
            "alpha_t_bar": alpha_t_bar,
            "X_t": z_t.X,
            "E_t": z_t.E,
            "y_t": z_t.y,
            "node_mask": node_mask,
        }

    # Generation is inference only and returns a string plus images, so the
    # whole reverse chain runs without building autograd graphs.
    @torch.no_grad()
    def generate(
        self,
        properties,
        device,
        guide_scale=1.,
        num_nodes=None,
        number_chain_steps=50,
    ):
        """Generate a single molecule conditioned on ``properties``.

        Args:
            properties: sequence of property values; ``None`` entries are
                replaced with NaN before being turned into a ``(1, -1)`` tensor.
            device: device the sampling tensors are created on / moved to.
            guide_scale: guidance scale forwarded to ``sample_p_zs_given_zt``.
            num_nodes: number of nodes; sampled from ``self.node_dist`` when
                ``None``.
            number_chain_steps: number of intermediate frames recorded for
                visualisation; no chain is recorded or rendered when ``<= 0``.

        Returns:
            ``(smiles, mol_img_list)`` - the first entry returned by
            ``graph_to_smiles`` and the output of
            ``self.mol_visualizer.visualize_chain`` (``[]`` when no chain is
            recorded).
        """
        properties = [float('nan') if x is None else x for x in properties]
        properties = torch.tensor(properties, dtype=torch.float).reshape(1, -1).to(device)
        batch_size = properties.size(0)
        assert batch_size == 1

        if num_nodes is None:
            num_nodes = self.node_dist.sample_n(batch_size, device)
        else:
            num_nodes = torch.LongTensor([num_nodes]).to(device)

        arange = (
            torch.arange(self.max_n_nodes, device=device)
            .unsqueeze(0)
            .expand(batch_size, -1)
        )
        node_mask = arange < num_nodes.unsqueeze(1)

        # Start the reverse chain from the limit distribution.
        z_T = utils.sample_discrete_feature_noise(
            limit_dist=self.limit_dist, node_mask=node_mask
        )
        X, E = z_T.X, z_T.E
        assert (E == torch.transpose(E, 1, 2)).all()

        keep_chain = number_chain_steps > 0
        if keep_chain:
            chain_X = torch.zeros(torch.Size((number_chain_steps, X.size(1))))
            chain_E = torch.zeros(
                torch.Size((number_chain_steps, E.size(1), E.size(2)))
            )

        # Iteratively sample p(z_s | z_t) for t = 1, ..., T, with s = t - 1.
        y = properties
        for s_int in reversed(range(0, self.T)):
            s_array = s_int * torch.ones((batch_size, 1)).type_as(y)
            t_array = s_array + 1
            s_norm = s_array / self.T
            t_norm = t_array / self.T

            # Sample z_s
            sampled_s, discrete_sampled_s = self.sample_p_zs_given_zt(
                s_norm, t_norm, X, E, y, node_mask, guide_scale, device
            )
            X, E, y = sampled_s.X, sampled_s.E, sampled_s.y

            if keep_chain:
                # Several diffusion steps map onto the same frame; later
                # (smaller s) steps overwrite earlier ones.
                write_index = (s_int * number_chain_steps) // self.T
                chain_X[write_index] = discrete_sampled_s.X[:1]
                chain_E[write_index] = discrete_sampled_s.E[:1]

        # Collapse the final one-hot sample to discrete indices.
        sampled_s = sampled_s.mask(node_mask, collapse=True)
        X, E, y = sampled_s.X, sampled_s.E, sampled_s.y

        n = num_nodes[0]
        atom_types = X[0, :n].cpu()
        edge_types = E[0, :n, :n].cpu()
        molecule_list = [[atom_types, edge_types]]
        smiles = graph_to_smiles(molecule_list, self.atom_decoder)[0]

        # Visualize chains
        if keep_chain:
            # Frame 0 is replaced with the resulting X, E before the chain is
            # reversed, so it ends up as the last frame.
            chain_X[0] = X[:1]
            chain_E[0] = E[:1]

            chain_X = utils.reverse_tensor(chain_X)
            chain_E = utils.reverse_tensor(chain_E)

            # Repeat last frame to see final sample better
            chain_X = torch.cat([chain_X, chain_X[-1:].repeat(10, 1)], dim=0)
            chain_E = torch.cat([chain_E, chain_E[-1:].repeat(10, 1, 1)], dim=0)
            mol_img_list = self.mol_visualizer.visualize_chain(
                chain_X.numpy(), chain_E.numpy()
            )
        else:
            mol_img_list = []

        return smiles, mol_img_list

    def check_valid(self, smiles):
        """Return the result of the module-level ``check_valid`` for ``smiles``."""
        return check_valid(smiles)

    def sample_p_zs_given_zt(
        self, s, t, X_t, E_t, properties, node_mask, guide_scale, device
    ):
        """Sample z_s ~ p(z_s | z_t). Only used during sampling.

        Args:
            s: normalised previous timestep; ``s[0, 0]`` is passed to the
                sampler as ``step``.
            t: normalised current timestep.
            X_t: current node features; unpacked as ``(bs, n, _)``.
            E_t: current edge features.
            properties: conditioning labels, passed to the denoiser as ``y_t``.
            node_mask: node mask.
            guide_scale: when different from 1, a second unconditioned pass is
                run and combined with the conditioned probabilities.
            device: device for the transition matrices.

        Returns:
            ``(one_hot, discrete)`` - two ``PlaceHolder`` objects built from
            the same sample: the first masked in one-hot form, the second
            masked with ``collapse=True``; both ``type_as(properties)``.
        """
        bs, n, _ = X_t.shape
        beta_t = self.noise_schedule(t_normalized=t)  # (bs, 1)
        alpha_s_bar = self.noise_schedule.get_alpha_bar(t_normalized=s)
        alpha_t_bar = self.noise_schedule.get_alpha_bar(t_normalized=t)

        # Neural net predictions
        noisy_data = {
            "X_t": X_t,
            "E_t": E_t,
            "y_t": properties,
            "t": t,
            "node_mask": node_mask,
        }

        def get_prob(noisy_data, unconditioned=False):
            """Return normalised (prob_X, prob_E) for one denoiser pass."""
            pred = self._forward(noisy_data, unconditioned=unconditioned)

            # Normalize predictions
            pred_X = F.softmax(pred.X, dim=-1)  # bs, n, d0
            pred_E = F.softmax(pred.E, dim=-1)  # bs, n, n, d0

            # Retrieve transitions matrix
            Qtb = self.transition_model.get_Qt_bar(alpha_t_bar, device)
            Qsb = self.transition_model.get_Qt_bar(alpha_s_bar, device)
            Qt = self.transition_model.get_Qt(beta_t, device)

            Xt_all = torch.cat([X_t, E_t.reshape(bs, n, -1)], dim=-1)
            predX_all = torch.cat([pred_X, pred_E.reshape(bs, n, -1)], dim=-1)

            unnormalized_probX_all = utils.reverse_diffusion(
                predX_0=predX_all, X_t=Xt_all, Qt=Qt.X, Qsb=Qsb.X, Qtb=Qtb.X
            )

            unnormalized_prob_X = unnormalized_probX_all[:, :, : self.Xdim_output]
            unnormalized_prob_E = unnormalized_probX_all[
                :, :, self.Xdim_output :
            ].reshape(bs, n * n, -1)

            # Rows that sum to zero are filled with a small constant so the
            # division below never hits 0/0.
            unnormalized_prob_X[torch.sum(unnormalized_prob_X, dim=-1) == 0] = 1e-5
            unnormalized_prob_E[torch.sum(unnormalized_prob_E, dim=-1) == 0] = 1e-5

            prob_X = unnormalized_prob_X / torch.sum(
                unnormalized_prob_X, dim=-1, keepdim=True
            )  # bs, n, d_t-1
            prob_E = unnormalized_prob_E / torch.sum(
                unnormalized_prob_E, dim=-1, keepdim=True
            )  # bs, n, d_t-1
            prob_E = prob_E.reshape(bs, n, n, pred_E.shape[-1])

            return prob_X, prob_E

        prob_X, prob_E = get_prob(noisy_data)

        ### Guidance
        if guide_scale != 1:
            uncon_prob_X, uncon_prob_E = get_prob(
                noisy_data, unconditioned=True
            )
            # p_uncond * (p_cond / p_uncond) ** scale, then renormalise.
            prob_X = (
                uncon_prob_X
                * (prob_X / uncon_prob_X.clamp_min(1e-5)) ** guide_scale
            )
            prob_E = (
                uncon_prob_E
                * (prob_E / uncon_prob_E.clamp_min(1e-5)) ** guide_scale
            )
            prob_X = prob_X / prob_X.sum(dim=-1, keepdim=True).clamp_min(1e-5)
            prob_E = prob_E / prob_E.sum(dim=-1, keepdim=True).clamp_min(1e-5)

        # Disabled sanity checks:
        # assert ((prob_X.sum(dim=-1) - 1).abs() < 1e-3).all()
        # assert ((prob_E.sum(dim=-1) - 1).abs() < 1e-3).all()

        sampled_s = utils.sample_discrete_features(
            prob_X, prob_E, node_mask=node_mask, step=s[0, 0].item()
        )

        X_s = F.one_hot(sampled_s.X, num_classes=self.Xdim_output).to(self.model_dtype)
        E_s = F.one_hot(sampled_s.E, num_classes=self.Edim_output).to(self.model_dtype)

        assert (E_s == torch.transpose(E_s, 1, 2)).all()
        assert (X_t.shape == X_s.shape) and (E_t.shape == E_s.shape)

        out_one_hot = utils.PlaceHolder(X=X_s, E=E_s, y=properties)
        out_discrete = utils.PlaceHolder(X=X_s, E=E_s, y=properties)

        return out_one_hot.mask(node_mask).type_as(properties), out_discrete.mask(
            node_mask, collapse=True
        ).type_as(properties)