Spaces:
Runtime error
Runtime error
Update graph_decoder/diffusion_model.py
Browse files- graph_decoder/diffusion_model.py +67 -70
graph_decoder/diffusion_model.py
CHANGED
|
@@ -1,4 +1,3 @@
|
|
| 1 |
-
import spaces
|
| 2 |
import os
|
| 3 |
import yaml
|
| 4 |
import json
|
|
@@ -20,73 +19,72 @@ class GraphDiT(nn.Module):
|
|
| 20 |
model_dtype,
|
| 21 |
):
|
| 22 |
super().__init__()
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
# self.limit_dist = utils.PlaceHolder(X=x_marginals, E=e_marginals, y=None)
|
| 90 |
|
| 91 |
def init_model(self, model_dir):
|
| 92 |
model_file = os.path.join(model_dir, 'model.pt')
|
|
@@ -179,8 +177,7 @@ class GraphDiT(nn.Module):
|
|
| 179 |
}
|
| 180 |
return noisy_data
|
| 181 |
|
| 182 |
-
|
| 183 |
-
@spaces.GPU
|
| 184 |
def generate(
|
| 185 |
self,
|
| 186 |
properties,
|
|
|
|
|
|
|
| 1 |
import os
|
| 2 |
import yaml
|
| 3 |
import json
|
|
|
|
| 19 |
model_dtype,
|
| 20 |
):
|
| 21 |
super().__init__()
|
| 22 |
+
|
| 23 |
+
dm_cfg, data_info = utils.load_config(model_config_path, data_info_path)
|
| 24 |
+
|
| 25 |
+
input_dims = data_info.input_dims
|
| 26 |
+
output_dims = data_info.output_dims
|
| 27 |
+
nodes_dist = data_info.nodes_dist
|
| 28 |
+
active_index = data_info.active_index
|
| 29 |
+
|
| 30 |
+
self.model_config = dm_cfg
|
| 31 |
+
self.data_info = data_info
|
| 32 |
+
self.T = dm_cfg.diffusion_steps
|
| 33 |
+
self.Xdim = input_dims["X"]
|
| 34 |
+
self.Edim = input_dims["E"]
|
| 35 |
+
self.ydim = input_dims["y"]
|
| 36 |
+
self.Xdim_output = output_dims["X"]
|
| 37 |
+
self.Edim_output = output_dims["E"]
|
| 38 |
+
self.ydim_output = output_dims["y"]
|
| 39 |
+
self.node_dist = nodes_dist
|
| 40 |
+
self.active_index = active_index
|
| 41 |
+
self.max_n_nodes = data_info.max_n_nodes
|
| 42 |
+
self.atom_decoder = data_info.atom_decoder
|
| 43 |
+
self.hidden_size = dm_cfg.hidden_size
|
| 44 |
+
self.mol_visualizer = MolecularVisualization(self.atom_decoder)
|
| 45 |
+
|
| 46 |
+
self.denoiser = Transformer(
|
| 47 |
+
max_n_nodes=self.max_n_nodes,
|
| 48 |
+
hidden_size=dm_cfg.hidden_size,
|
| 49 |
+
depth=dm_cfg.depth,
|
| 50 |
+
num_heads=dm_cfg.num_heads,
|
| 51 |
+
mlp_ratio=dm_cfg.mlp_ratio,
|
| 52 |
+
drop_condition=dm_cfg.drop_condition,
|
| 53 |
+
Xdim=self.Xdim,
|
| 54 |
+
Edim=self.Edim,
|
| 55 |
+
ydim=self.ydim,
|
| 56 |
+
)
|
| 57 |
+
|
| 58 |
+
self.model_dtype = model_dtype
|
| 59 |
+
self.noise_schedule = utils.PredefinedNoiseScheduleDiscrete(
|
| 60 |
+
dm_cfg.diffusion_noise_schedule, timesteps=dm_cfg.diffusion_steps
|
| 61 |
+
)
|
| 62 |
+
x_marginals = data_info.node_types.to(self.model_dtype) / torch.sum(
|
| 63 |
+
data_info.node_types.to(self.model_dtype)
|
| 64 |
+
)
|
| 65 |
+
e_marginals = data_info.edge_types.to(self.model_dtype) / torch.sum(
|
| 66 |
+
data_info.edge_types.to(self.model_dtype)
|
| 67 |
+
)
|
| 68 |
+
x_marginals = x_marginals / x_marginals.sum()
|
| 69 |
+
e_marginals = e_marginals / e_marginals.sum()
|
| 70 |
+
|
| 71 |
+
xe_conditions = data_info.transition_E.to(self.model_dtype)
|
| 72 |
+
xe_conditions = xe_conditions[self.active_index][:, self.active_index]
|
| 73 |
+
|
| 74 |
+
xe_conditions = xe_conditions.sum(dim=1)
|
| 75 |
+
ex_conditions = xe_conditions.t()
|
| 76 |
+
xe_conditions = xe_conditions / xe_conditions.sum(dim=-1, keepdim=True)
|
| 77 |
+
ex_conditions = ex_conditions / ex_conditions.sum(dim=-1, keepdim=True)
|
| 78 |
+
|
| 79 |
+
self.transition_model = utils.MarginalTransition(
|
| 80 |
+
x_marginals=x_marginals,
|
| 81 |
+
e_marginals=e_marginals,
|
| 82 |
+
xe_conditions=xe_conditions,
|
| 83 |
+
ex_conditions=ex_conditions,
|
| 84 |
+
y_classes=self.ydim_output,
|
| 85 |
+
n_nodes=self.max_n_nodes,
|
| 86 |
+
)
|
| 87 |
+
self.limit_dist = utils.PlaceHolder(X=x_marginals, E=e_marginals, y=None)
|
|
|
|
| 88 |
|
| 89 |
def init_model(self, model_dir):
|
| 90 |
model_file = os.path.join(model_dir, 'model.pt')
|
|
|
|
| 177 |
}
|
| 178 |
return noisy_data
|
| 179 |
|
| 180 |
+
@torch.no_grad()
|
|
|
|
| 181 |
def generate(
|
| 182 |
self,
|
| 183 |
properties,
|