Upload graph_decoder/transformer.py with huggingface_hub
graph_decoder/transformer.py  ADDED  (+180, -0)
@@ -0,0 +1,180 @@
import torch
import torch.nn as nn
from .layers import Attention, MLP
from .conditions import TimestepEmbedder, ConditionEmbedder
from .diffusion_utils import PlaceHolder


def modulate(x, shift, scale):
    # Adaptive LayerNorm (adaLN) modulation: per-sample shift/scale broadcast over the node axis.
    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)


class Transformer(nn.Module):
    def __init__(
        self,
        max_n_nodes,
        hidden_size=384,
        depth=12,
        num_heads=16,
        mlp_ratio=4.0,
        drop_condition=0.1,
        Xdim=118,
        Edim=5,
        ydim=5,
    ):
        super().__init__()
        self.num_heads = num_heads
        self.ydim = ydim
        # Node embedding: each node's features are concatenated with its flattened edge row,
        # so the input width is Xdim + max_n_nodes * Edim.
        self.x_embedder = nn.Sequential(
            nn.Linear(Xdim + max_n_nodes * Edim, hidden_size, bias=False),
            nn.LayerNorm(hidden_size)
        )

        self.t_embedder = TimestepEmbedder(hidden_size)
        self.y_embedder = ConditionEmbedder(ydim, hidden_size, drop_condition)

        self.blocks = nn.ModuleList(
            [
                Block(hidden_size, num_heads, mlp_ratio=mlp_ratio)
                for _ in range(depth)
            ]
        )
        self.output_layer = OutputLayer(
            max_n_nodes=max_n_nodes,
            hidden_size=hidden_size,
            atom_type=Xdim,
            bond_type=Edim,
            mlp_ratio=mlp_ratio,
            num_heads=num_heads,
        )
        self.initialize_weights()

    def initialize_weights(self):
        # Initialize transformer layers: Xavier-uniform weights, zero biases.
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

        def _constant_init(module, i):
            if isinstance(module, nn.Linear):
                nn.init.constant_(module.weight, i)
                if module.bias is not None:
                    nn.init.constant_(module.bias, i)

        self.apply(_basic_init)

        # Zero-initialize the first adaLN linear layer in every block and in the output layer.
        for block in self.blocks:
            _constant_init(block.adaLN_modulation[0], 0)
        _constant_init(self.output_layer.adaLN_modulation[0], 0)

    def disable_grads(self):
        """
        Disable gradients for all parameters in the model.
        """
        for param in self.parameters():
            param.requires_grad = False

    def print_trainable_parameters(self):
        print("Trainable parameters:")
        for name, param in self.named_parameters():
            if param.requires_grad:
                print(f"{name}: {param.size()}")

        # Calculate and print the total number of trainable parameters.
        total_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
        print(f"\nTotal trainable parameters: {total_params}")

    def forward(self, X_in, E_in, node_mask, y_in, t, unconditioned):
        # X_in: (bs, n, Xdim) node features; E_in: (bs, n, n, Edim) edge features.
        bs, n, _ = X_in.size()
        X = torch.cat([X_in, E_in.reshape(bs, n, -1)], dim=-1)
        X = self.x_embedder(X)

        # Conditioning vector: timestep embedding plus condition embedding (subject to condition dropout).
        c1 = self.t_embedder(t)
        c2 = self.y_embedder(y_in, self.training, unconditioned)
        c = c1 + c2

        for i, block in enumerate(self.blocks):
            X = block(X, c, node_mask)

        # X: B * N * dx, E: B * N * N * de
        X, E = self.output_layer(X, X_in, E_in, c, t, node_mask)
        return PlaceHolder(X=X, E=E, y=None).mask(node_mask)

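
# Block: a transformer block with adaptive LayerNorm ("adaLN") conditioning. The
# conditioning vector c is mapped to six per-sample terms: shift, scale, and gate for
# the attention branch and for the MLP branch. Note that LayerNorm is applied to the
# sub-layer outputs here (post-norm), which are then modulated, gated, and added
# residually to x.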
class Block(nn.Module):
    def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs):
        super().__init__()
        self.attn_norm = nn.LayerNorm(hidden_size, eps=1e-05, elementwise_affine=False)
        self.mlp_norm = nn.LayerNorm(hidden_size, eps=1e-05, elementwise_affine=False)

        self.attn = Attention(
            hidden_size, num_heads=num_heads, qkv_bias=False, qk_norm=True, **block_kwargs
        )

        self.mlp = MLP(
            in_features=hidden_size,
            hidden_features=int(hidden_size * mlp_ratio),
        )

        self.adaLN_modulation = nn.Sequential(
            nn.Linear(hidden_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, 6 * hidden_size, bias=True),
            nn.Softsign(),
        )

    def forward(self, x, c, node_mask):
        (
            shift_msa,
            scale_msa,
            gate_msa,
            shift_mlp,
            scale_mlp,
            gate_mlp,
        ) = self.adaLN_modulation(c).chunk(6, dim=1)

        x = x + gate_msa.unsqueeze(1) * modulate(
            self.attn_norm(self.attn(x, node_mask=node_mask)), shift_msa, scale_msa
        )
        x = x + gate_mlp.unsqueeze(1) * modulate(
            self.mlp_norm(self.mlp(x)), shift_mlp, scale_mlp
        )

        return x

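
# OutputLayer: decodes per-node hidden states into atom-type and bond-type predictions.
# Each node is projected to atom_type + max_n_nodes * bond_type values; the bond slice
# is reshaped to a dense (B, N, N, bond_type) tensor. Both outputs are residuals on top
# of the inputs x_in / e_in, and the bond tensor is standardized (entries between padded
# nodes and diagonal entries zeroed, then symmetrized) before being returned.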
class OutputLayer(nn.Module):
    def __init__(self, max_n_nodes, hidden_size, atom_type, bond_type, mlp_ratio, num_heads=None):
        super().__init__()
        self.atom_type = atom_type
        self.bond_type = bond_type
        final_size = atom_type + max_n_nodes * bond_type
        self.xedecoder = MLP(in_features=hidden_size,
                             out_features=final_size, drop=0)

        self.norm_final = nn.LayerNorm(final_size, eps=1e-05, elementwise_affine=False)
        self.adaLN_modulation = nn.Sequential(
            nn.Linear(hidden_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, 2 * final_size, bias=True)
        )

    def forward(self, x, x_in, e_in, c, t, node_mask):
        x_all = self.xedecoder(x)
        B, N, D = x_all.size()
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
        x_all = modulate(self.norm_final(x_all), shift, scale)

        atom_out = x_all[:, :, :self.atom_type]
        atom_out = x_in + atom_out

        bond_out = x_all[:, :, self.atom_type:].reshape(B, N, N, self.bond_type)
        bond_out = e_in + bond_out

        # Standardize bond_out: zero entries between padded nodes and on the diagonal,
        # then symmetrize over the two node axes.
        edge_mask = (~node_mask)[:, :, None] & (~node_mask)[:, None, :]
        diag_mask = (
            torch.eye(N, dtype=torch.bool)
            .unsqueeze(0)
            .expand(B, -1, -1)
            .type_as(edge_mask)
        )
        bond_out.masked_fill_(edge_mask[:, :, :, None], 0)
        bond_out.masked_fill_(diag_mask[:, :, :, None], 0)
        bond_out = 1 / 2 * (bond_out + torch.transpose(bond_out, 1, 2))

        return atom_out, bond_out
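
For reference, here is a minimal, self-contained sketch (not part of the uploaded file) of the adaLN modulation pattern used in Block.forward: the conditioning vector is mapped to six chunks, and the resulting per-sample shift/scale/gate terms of shape (B, hidden) are broadcast over the node axis of the (B, N, hidden) node states. All tensor sizes below are illustrative.

import torch
import torch.nn as nn

B, N, hidden = 2, 4, 8          # illustrative sizes
x = torch.randn(B, N, hidden)   # node states
c = torch.randn(B, hidden)      # conditioning vector (timestep + condition embedding)

# Same structure as Block.adaLN_modulation: condition -> 6 * hidden, squashed by Softsign.
adaLN = nn.Sequential(
    nn.Linear(hidden, hidden), nn.SiLU(), nn.Linear(hidden, 6 * hidden), nn.Softsign()
)
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = adaLN(c).chunk(6, dim=1)

def modulate(x, shift, scale):
    # Broadcast per-sample shift/scale over the node axis (dim=1).
    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)

sublayer_out = torch.randn(B, N, hidden)  # stand-in for an attention or MLP output
x = x + gate_msa.unsqueeze(1) * modulate(sublayer_out, shift_msa, scale_msa)
print(x.shape)  # torch.Size([2, 4, 8])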
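
Similarly, a small self-contained sketch of the bond-tensor standardization at the end of OutputLayer.forward, assuming (as the masking above suggests) that node_mask is a boolean (B, N) tensor with True marking real, non-padded nodes: entries whose endpoints are both padded and all diagonal (self-loop) entries are zeroed, then the tensor is symmetrized.

import torch

B, N, bond_type = 1, 3, 2
bond_out = torch.randn(B, N, N, bond_type)
node_mask = torch.tensor([[True, True, False]])  # last node is padding (assumed convention)

# Pairs where both endpoints are padded, and self-loops, are zeroed out.
edge_mask = (~node_mask)[:, :, None] & (~node_mask)[:, None, :]
diag_mask = torch.eye(N, dtype=torch.bool).unsqueeze(0).expand(B, -1, -1)
bond_out.masked_fill_(edge_mask[:, :, :, None], 0)
bond_out.masked_fill_(diag_mask[:, :, :, None], 0)

# Symmetrize over the two node axes so bond_out[b, i, j] == bond_out[b, j, i].
bond_out = 0.5 * (bond_out + bond_out.transpose(1, 2))
assert torch.allclose(bond_out, bond_out.transpose(1, 2))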