import torch
import torch.nn as nn
import torch.nn.functional as F

from typing import Union, Tuple, List, Callable, Optional

import numpy as np
from einops import repeat, rearrange
from tqdm import tqdm
import logging

import comfy.ops
ops = comfy.ops.disable_weight_init


def generate_dense_grid_points(
    bbox_min: np.ndarray,
    bbox_max: np.ndarray,
    octree_resolution: int,
    indexing: str = "ij",
):
    length = bbox_max - bbox_min
    num_cells = octree_resolution

    x = np.linspace(bbox_min[0], bbox_max[0], int(num_cells) + 1, dtype=np.float32)
    y = np.linspace(bbox_min[1], bbox_max[1], int(num_cells) + 1, dtype=np.float32)
    z = np.linspace(bbox_min[2], bbox_max[2], int(num_cells) + 1, dtype=np.float32)
    xs, ys, zs = np.meshgrid(x, y, z, indexing=indexing)
    xyz = np.stack((xs, ys, zs), axis=-1)
    grid_size = [int(num_cells) + 1, int(num_cells) + 1, int(num_cells) + 1]

    return xyz, grid_size, length


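# A usage sketch (hypothetical values, not called like this in the module): the grid
# samples one point per cell corner, so octree_resolution cells give resolution + 1
# points per axis:
#
#   xyz, grid_size, length = generate_dense_grid_points(
#       bbox_min=np.array([-1.01, -1.01, -1.01]),
#       bbox_max=np.array([1.01, 1.01, 1.01]),
#       octree_resolution=256,
#   )
#   # xyz.shape == (257, 257, 257, 3); grid_size == [257, 257, 257]

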
class VanillaVolumeDecoder:
    @torch.no_grad()
    def __call__(
        self,
        latents: torch.FloatTensor,
        geo_decoder: Callable,
        bounds: Union[Tuple[float, ...], List[float], float] = 1.01,
        num_chunks: int = 10000,
        octree_resolution: Optional[int] = None,
        enable_pbar: bool = True,
        **kwargs,
    ):
        device = latents.device
        dtype = latents.dtype
        batch_size = latents.shape[0]

        if isinstance(bounds, float):
            bounds = [-bounds, -bounds, -bounds, bounds, bounds, bounds]

        bbox_min, bbox_max = np.array(bounds[0:3]), np.array(bounds[3:6])
        xyz_samples, grid_size, length = generate_dense_grid_points(
            bbox_min=bbox_min,
            bbox_max=bbox_max,
            octree_resolution=octree_resolution,
            indexing="ij",
        )
        xyz_samples = torch.from_numpy(xyz_samples).to(device, dtype=dtype).contiguous().reshape(-1, 3)

        batch_logits = []
        for start in tqdm(range(0, xyz_samples.shape[0], num_chunks), desc="Volume Decoding",
                          disable=not enable_pbar):
            chunk_queries = xyz_samples[start: start + num_chunks, :]
            chunk_queries = repeat(chunk_queries, "p c -> b p c", b=batch_size)
            logits = geo_decoder(queries=chunk_queries, latents=latents)
            batch_logits.append(logits)

        grid_logits = torch.cat(batch_logits, dim=1)
        grid_logits = grid_logits.view((batch_size, *grid_size)).float()

        return grid_logits


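# Why the chunking above matters: the full query grid has (octree_resolution + 1)^3
# points. A back-of-the-envelope example (assumed values): at octree_resolution=256
# that is 257**3 ~= 17M query points, so at the signature default num_chunks=10000
# the loop makes ceil(257**3 / 10000) = 1698 geo_decoder calls per decode instead of
# one enormous, memory-exhausting batch.

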
class FourierEmbedder(nn.Module):
    """The sin/cos positional embedding. Given an input tensor `x` of shape [n_batch, ..., c_dim], it converts
    each feature dimension of `x[..., i]` into:
        [
            x[..., i],                # only present if include_input is True
            sin(f_1 * x[..., i]),
            sin(f_2 * x[..., i]),
            ...
            sin(f_N * x[..., i]),
            cos(f_1 * x[..., i]),
            cos(f_2 * x[..., i]),
            ...
            cos(f_N * x[..., i]),
        ]
    where f_k denotes the k-th frequency and N = num_freqs.

    If logspace is True, the frequencies are powers of two, f_k = 2^(k - 1), i.e.
    [1, 2, 4, ..., 2^(num_freqs - 1)]; otherwise they are linearly spaced between
    1.0 and 2^(num_freqs - 1). If include_pi is True, every frequency is further
    multiplied by pi.

    Args:
        num_freqs (int): the number of frequencies, default is 6;
        logspace (bool): if True, the frequencies are the powers of two described above,
            otherwise they are linearly spaced between 1.0 and 2^(num_freqs - 1);
        input_dim (int): the input dimension, default is 3;
        include_input (bool): whether to include the input tensor itself, default is True;
        include_pi (bool): whether to multiply each frequency by pi, default is True.

    Attributes:
        frequencies (torch.Tensor): the frequency tensor described above;
        out_dim (int): the embedding size; input_dim * (num_freqs * 2 + 1) if include_input
            is True, otherwise input_dim * num_freqs * 2.
    """

    def __init__(self,
                 num_freqs: int = 6,
                 logspace: bool = True,
                 input_dim: int = 3,
                 include_input: bool = True,
                 include_pi: bool = True) -> None:
        """The initialization"""
        super().__init__()

        if logspace:
            frequencies = 2.0 ** torch.arange(
                num_freqs,
                dtype=torch.float32
            )
        else:
            frequencies = torch.linspace(
                1.0,
                2.0 ** (num_freqs - 1),
                num_freqs,
                dtype=torch.float32
            )

        if include_pi:
            frequencies *= torch.pi

        self.register_buffer("frequencies", frequencies, persistent=False)
        self.include_input = include_input
        self.num_freqs = num_freqs

        self.out_dim = self.get_dims(input_dim)

    def get_dims(self, input_dim):
        # The input contributes one extra copy of each dimension when include_input
        # is True (or when num_freqs == 0, in which case forward passes x through).
        temp = 1 if self.include_input or self.num_freqs == 0 else 0
        out_dim = input_dim * (self.num_freqs * 2 + temp)

        return out_dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward process.

        Args:
            x: tensor of shape [..., dim]

        Returns:
            embedding: an embedding of `x` of shape [..., dim * (num_freqs * 2 + temp)],
                where temp is 1 if include_input is True and 0 otherwise.
        """
        if self.num_freqs > 0:
            embed = (x[..., None].contiguous() * self.frequencies.to(device=x.device, dtype=x.dtype)).view(*x.shape[:-1], -1)
            if self.include_input:
                return torch.cat((x, embed.sin(), embed.cos()), dim=-1)
            else:
                return torch.cat((embed.sin(), embed.cos()), dim=-1)
        else:
            return x


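# A worked example (values assumed, derived from the defaults used by ShapeVAE below):
# with num_freqs=8 and input_dim=3, each xyz query expands to
# out_dim = 3 * (8 * 2 + 1) = 51 features before being projected by query_proj:
#
#   fe = FourierEmbedder(num_freqs=8)
#   fe(torch.zeros(2, 100, 3)).shape  # -> torch.Size([2, 100, 51])

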
class CrossAttentionProcessor:
    def __call__(self, attn, q, k, v):
        out = F.scaled_dot_product_attention(q, k, v)
        return out


class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in the main path of residual blocks).
    """

    def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob
        self.scale_by_keep = scale_by_keep

    def forward(self, x):
        """Drop paths (Stochastic Depth) per sample (when applied in the main path of residual blocks).

        This is the same as the DropConnect impl I created for EfficientNet, etc. networks; however,
        the original name is misleading, as 'Drop Connect' is a different form of dropout in a
        separate paper (see discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956).
        I've opted for changing the layer and argument names to 'drop path' rather than mix
        DropConnect as a layer name and use 'survival rate' as the argument.
        """
        if self.drop_prob == 0. or not self.training:
            return x
        keep_prob = 1 - self.drop_prob
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
        if keep_prob > 0.0 and self.scale_by_keep:
            random_tensor.div_(keep_prob)
        return x * random_tensor

    def extra_repr(self):
        return f'drop_prob={round(self.drop_prob, 3):0.3f}'


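# A behavioral sketch (hypothetical tensors): in training mode, each sample's residual
# branch is either zeroed or kept and rescaled by 1 / keep_prob, so the expected value
# is preserved; in eval mode DropPath is a no-op:
#
#   dp = DropPath(drop_prob=0.1)
#   dp.train()
#   y = dp(torch.ones(8, 4))  # each row is all 0.0 or all ~1.111 (= 1 / 0.9)

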
class MLP(nn.Module):
    def __init__(
        self, *,
        width: int,
        expand_ratio: int = 4,
        output_width: Optional[int] = None,
        drop_path_rate: float = 0.0
    ):
        super().__init__()
        self.width = width
        self.c_fc = ops.Linear(width, width * expand_ratio)
        self.c_proj = ops.Linear(width * expand_ratio, output_width if output_width is not None else width)
        self.gelu = nn.GELU()
        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()

    def forward(self, x):
        return self.drop_path(self.c_proj(self.gelu(self.c_fc(x))))


class QKVMultiheadCrossAttention(nn.Module):
    def __init__(
        self,
        *,
        heads: int,
        width=None,
        qk_norm=False,
        norm_layer=ops.LayerNorm
    ):
        super().__init__()
        self.heads = heads
        self.q_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()

        self.attn_processor = CrossAttentionProcessor()

    def forward(self, q, kv):
        _, n_ctx, _ = q.shape
        bs, n_data, width = kv.shape
        # kv packs keys and values along the channel dimension, so each gets half.
        attn_ch = width // self.heads // 2
        q = q.view(bs, n_ctx, self.heads, -1)
        kv = kv.view(bs, n_data, self.heads, -1)
        k, v = torch.split(kv, attn_ch, dim=-1)

        q = self.q_norm(q)
        k = self.k_norm(k)
        q, k, v = map(lambda t: rearrange(t, 'b n h d -> b h n d', h=self.heads), (q, k, v))
        out = self.attn_processor(self, q, k, v)
        out = out.transpose(1, 2).reshape(bs, n_ctx, -1)
        return out


class MultiheadCrossAttention(nn.Module):
    def __init__(
        self,
        *,
        width: int,
        heads: int,
        qkv_bias: bool = True,
        data_width: Optional[int] = None,
        norm_layer=ops.LayerNorm,
        qk_norm: bool = False,
        kv_cache: bool = False,
    ):
        super().__init__()
        self.width = width
        self.heads = heads
        self.data_width = width if data_width is None else data_width
        self.c_q = ops.Linear(width, width, bias=qkv_bias)
        self.c_kv = ops.Linear(self.data_width, width * 2, bias=qkv_bias)
        self.c_proj = ops.Linear(width, width)
        self.attention = QKVMultiheadCrossAttention(
            heads=heads,
            width=width,
            norm_layer=norm_layer,
            qk_norm=qk_norm
        )
        self.kv_cache = kv_cache
        self.data = None

    def forward(self, x, data):
        x = self.c_q(x)
        if self.kv_cache:
            if self.data is None:
                self.data = self.c_kv(data)
                logging.info('Saving kv cache; this should be called only once per mesh')
            data = self.data
        else:
            data = self.c_kv(data)
        x = self.attention(x, data)
        x = self.c_proj(x)
        return x


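# Note on kv_cache (illustrative, with assumed sizes): when decoding a single mesh,
# `data` (the shape latents) is identical for every chunk of query points, so the
# key/value projection can be computed once and reused across chunks:
#
#   attn = MultiheadCrossAttention(width=1024, heads=16, kv_cache=True)
#   out1 = attn(chunk1_queries, latents)  # projects latents and caches k/v
#   out2 = attn(chunk2_queries, latents)  # reuses the cached k/v

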
class ResidualCrossAttentionBlock(nn.Module):
    def __init__(
        self,
        *,
        width: int,
        heads: int,
        mlp_expand_ratio: int = 4,
        data_width: Optional[int] = None,
        qkv_bias: bool = True,
        norm_layer=ops.LayerNorm,
        qk_norm: bool = False
    ):
        super().__init__()

        if data_width is None:
            data_width = width

        self.attn = MultiheadCrossAttention(
            width=width,
            heads=heads,
            data_width=data_width,
            qkv_bias=qkv_bias,
            norm_layer=norm_layer,
            qk_norm=qk_norm
        )
        self.ln_1 = norm_layer(width, elementwise_affine=True, eps=1e-6)
        self.ln_2 = norm_layer(data_width, elementwise_affine=True, eps=1e-6)
        self.ln_3 = norm_layer(width, elementwise_affine=True, eps=1e-6)
        self.mlp = MLP(width=width, expand_ratio=mlp_expand_ratio)

    def forward(self, x: torch.Tensor, data: torch.Tensor):
        x = x + self.attn(self.ln_1(x), self.ln_2(data))
        x = x + self.mlp(self.ln_3(x))
        return x


class QKVMultiheadAttention(nn.Module):
    def __init__(
        self,
        *,
        heads: int,
        width=None,
        qk_norm=False,
        norm_layer=ops.LayerNorm
    ):
        super().__init__()
        self.heads = heads
        self.q_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()

    def forward(self, qkv):
        bs, n_ctx, width = qkv.shape
        # qkv packs queries, keys and values along the channel dimension, a third each.
        attn_ch = width // self.heads // 3
        qkv = qkv.view(bs, n_ctx, self.heads, -1)
        q, k, v = torch.split(qkv, attn_ch, dim=-1)

        q = self.q_norm(q)
        k = self.k_norm(k)

        q, k, v = map(lambda t: rearrange(t, 'b n h d -> b h n d', h=self.heads), (q, k, v))
        out = F.scaled_dot_product_attention(q, k, v).transpose(1, 2).reshape(bs, n_ctx, -1)
        return out


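# Shape sketch (hypothetical sizes): for a fused qkv of shape (B, N, 3 * W) with
# H heads, attn_ch = W // H, and the split yields q, k, v of shape (B, N, H, W // H),
# each rearranged to (B, H, N, W // H) before F.scaled_dot_product_attention.

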
class MultiheadAttention(nn.Module):
    def __init__(
        self,
        *,
        width: int,
        heads: int,
        qkv_bias: bool,
        norm_layer=ops.LayerNorm,
        qk_norm: bool = False,
        drop_path_rate: float = 0.0
    ):
        super().__init__()
        self.width = width
        self.heads = heads
        self.c_qkv = ops.Linear(width, width * 3, bias=qkv_bias)
        self.c_proj = ops.Linear(width, width)
        self.attention = QKVMultiheadAttention(
            heads=heads,
            width=width,
            norm_layer=norm_layer,
            qk_norm=qk_norm
        )
        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()

    def forward(self, x):
        x = self.c_qkv(x)
        x = self.attention(x)
        x = self.drop_path(self.c_proj(x))
        return x


class ResidualAttentionBlock(nn.Module):
    def __init__(
        self,
        *,
        width: int,
        heads: int,
        qkv_bias: bool = True,
        norm_layer=ops.LayerNorm,
        qk_norm: bool = False,
        drop_path_rate: float = 0.0,
    ):
        super().__init__()
        self.attn = MultiheadAttention(
            width=width,
            heads=heads,
            qkv_bias=qkv_bias,
            norm_layer=norm_layer,
            qk_norm=qk_norm,
            drop_path_rate=drop_path_rate
        )
        self.ln_1 = norm_layer(width, elementwise_affine=True, eps=1e-6)
        self.mlp = MLP(width=width, drop_path_rate=drop_path_rate)
        self.ln_2 = norm_layer(width, elementwise_affine=True, eps=1e-6)

    def forward(self, x: torch.Tensor):
        x = x + self.attn(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x


class Transformer(nn.Module):
    def __init__(
        self,
        *,
        width: int,
        layers: int,
        heads: int,
        qkv_bias: bool = True,
        norm_layer=ops.LayerNorm,
        qk_norm: bool = False,
        drop_path_rate: float = 0.0
    ):
        super().__init__()
        self.width = width
        self.layers = layers
        self.resblocks = nn.ModuleList(
            [
                ResidualAttentionBlock(
                    width=width,
                    heads=heads,
                    qkv_bias=qkv_bias,
                    norm_layer=norm_layer,
                    qk_norm=qk_norm,
                    drop_path_rate=drop_path_rate
                )
                for _ in range(layers)
            ]
        )

    def forward(self, x: torch.Tensor):
        for block in self.resblocks:
            x = block(x)
        return x


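# A shape-preserving usage sketch (hypothetical sizes): the transformer maps
# (B, num_latents, width) -> (B, num_latents, width), each block applying pre-LN
# self-attention and an MLP with residual connections:
#
#   tf = Transformer(width=1024, layers=16, heads=16)
#   tokens = torch.randn(1, 4096, 1024)
#   assert tf(tokens).shape == tokens.shape

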
class CrossAttentionDecoder(nn.Module):

    def __init__(
        self,
        *,
        out_channels: int,
        fourier_embedder: FourierEmbedder,
        width: int,
        heads: int,
        mlp_expand_ratio: int = 4,
        downsample_ratio: int = 1,
        enable_ln_post: bool = True,
        qkv_bias: bool = True,
        qk_norm: bool = False,
        label_type: str = "binary"
    ):
        super().__init__()

        self.enable_ln_post = enable_ln_post
        self.fourier_embedder = fourier_embedder
        self.downsample_ratio = downsample_ratio
        self.query_proj = ops.Linear(self.fourier_embedder.out_dim, width)
        if self.downsample_ratio != 1:
            self.latents_proj = ops.Linear(width * downsample_ratio, width)
        if not self.enable_ln_post:
            qk_norm = False
        self.cross_attn_decoder = ResidualCrossAttentionBlock(
            width=width,
            mlp_expand_ratio=mlp_expand_ratio,
            heads=heads,
            qkv_bias=qkv_bias,
            qk_norm=qk_norm
        )

        if self.enable_ln_post:
            self.ln_post = ops.LayerNorm(width)
        self.output_proj = ops.Linear(width, out_channels)
        self.label_type = label_type
        self.count = 0

    def forward(self, queries=None, query_embeddings=None, latents=None):
        if query_embeddings is None:
            query_embeddings = self.query_proj(self.fourier_embedder(queries).to(latents.dtype))
        self.count += query_embeddings.shape[1]
        if self.downsample_ratio != 1:
            latents = self.latents_proj(latents)
        x = self.cross_attn_decoder(query_embeddings, latents)
        if self.enable_ln_post:
            x = self.ln_post(x)
        occ = self.output_proj(x)
        return occ


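# Data-flow sketch for one chunk of queries (sizes follow the defaults above, assumed):
# xyz queries (B, P, 3) -> Fourier embedding (B, P, 51) -> query_proj (B, P, width)
# -> one cross-attention block into the shape latents -> output_proj (B, P, 1),
# i.e. a single occupancy logit per query point.

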
class ShapeVAE(nn.Module):
    def __init__(
        self,
        *,
        embed_dim: int,
        width: int,
        heads: int,
        num_decoder_layers: int,
        geo_decoder_downsample_ratio: int = 1,
        geo_decoder_mlp_expand_ratio: int = 4,
        geo_decoder_ln_post: bool = True,
        num_freqs: int = 8,
        include_pi: bool = True,
        qkv_bias: bool = True,
        qk_norm: bool = False,
        label_type: str = "binary",
        drop_path_rate: float = 0.0,
        scale_factor: float = 1.0,
    ):
        super().__init__()
        self.geo_decoder_ln_post = geo_decoder_ln_post

        self.fourier_embedder = FourierEmbedder(num_freqs=num_freqs, include_pi=include_pi)

        self.post_kl = ops.Linear(embed_dim, width)

        self.transformer = Transformer(
            width=width,
            layers=num_decoder_layers,
            heads=heads,
            qkv_bias=qkv_bias,
            qk_norm=qk_norm,
            drop_path_rate=drop_path_rate
        )

        self.geo_decoder = CrossAttentionDecoder(
            fourier_embedder=self.fourier_embedder,
            out_channels=1,
            mlp_expand_ratio=geo_decoder_mlp_expand_ratio,
            downsample_ratio=geo_decoder_downsample_ratio,
            enable_ln_post=self.geo_decoder_ln_post,
            width=width // geo_decoder_downsample_ratio,
            heads=heads // geo_decoder_downsample_ratio,
            qkv_bias=qkv_bias,
            qk_norm=qk_norm,
            label_type=label_type,
        )

        self.volume_decoder = VanillaVolumeDecoder()
        self.scale_factor = scale_factor

    def decode(self, latents, **kwargs):
        latents = self.post_kl(latents.movedim(-2, -1))
        latents = self.transformer(latents)

        bounds = kwargs.get("bounds", 1.01)
        num_chunks = kwargs.get("num_chunks", 8000)
        octree_resolution = kwargs.get("octree_resolution", 256)
        enable_pbar = kwargs.get("enable_pbar", True)

        grid_logits = self.volume_decoder(latents, self.geo_decoder, bounds=bounds, num_chunks=num_chunks, octree_resolution=octree_resolution, enable_pbar=enable_pbar)
        return grid_logits.movedim(-2, -1)

    def encode(self, x):
        # Encoding is not implemented; this module only decodes latents into a grid.
        return None


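# End-to-end decode sketch (assumed sizes): latents of shape (B, embed_dim, num_latents)
# are transposed to (B, num_latents, embed_dim), lifted to `width` by post_kl, refined
# by the self-attention transformer, then evaluated on a dense
# (octree_resolution + 1)^3 grid by the volume decoder:
#
#   vae = ShapeVAE(embed_dim=64, width=1024, heads=16, num_decoder_layers=16)
#   grid_logits = vae.decode(torch.randn(1, 64, 4096))  # -> (1, 257, 257, 257)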