import os

import torch
from omegaconf import OmegaConf
from torch import nn

from .utils.misc import instantiate_from_config
from ..utils import default, exists


def load_model():
    # Build the pretrained shape VAE from the config shipped alongside this module.
    model_config = OmegaConf.load(os.path.join(os.path.dirname(__file__), "shapevae-256.yaml"))
    if hasattr(model_config, "model"):
        model_config = model_config.model

    ckpt_path = "./ckpt/checkpoints/aligned_shape_latents/shapevae-256.ckpt"
    model = instantiate_from_config(model_config, ckpt_path=ckpt_path)
    # model = model.cuda()
    model = model.eval()
    return model


class ShapeConditioner(nn.Module):
    def __init__(
        self,
        *,
        dim_latent = None
    ):
        super().__init__()
        self.model = load_model()

        # The shape VAE emits 768-dimensional features; use that width
        # as the latent dimension unless one is given explicitly.
        self.dim_model_out = 768
        dim_latent = default(dim_latent, self.dim_model_out)
        self.dim_latent = dim_latent
    def forward(
        self,
        shape = None,
        shape_embed = None,
    ):
        # Exactly one of `shape` or `shape_embed` must be provided.
        assert exists(shape) ^ exists(shape_embed)

        # shape_head is only produced when we encode the shape here; initialize it
        # so the precomputed-embedding path still returns a defined value.
        shape_head = None

        if not exists(shape_embed):
            # Encode the point cloud, split off the global head token, and
            # concatenate the per-point features with the projected shape latents.
            point_feature = self.model.encode_latents(shape)
            shape_latents = self.model.to_shape_latents(point_feature[:, 1:])
            shape_head = point_feature[:, 0:1]
            shape_embed = torch.cat([point_feature[:, 1:], shape_latents], dim=-1)

        return shape_head, shape_embed
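

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original file). It assumes the
# shapevae-256 checkpoint is present at the path used in load_model() and that
# encode_latents accepts a batch of sampled surface points; the (1, 4096, 6)
# tensor below (xyz + normals) is a hypothetical input shape, not something
# this file guarantees. Because this module uses relative imports, run it as
# part of its package (e.g. with `python -m <package>.<module>`).
if __name__ == "__main__":
    conditioner = ShapeConditioner()

    # Hypothetical point-cloud batch: 1 shape, 4096 surface samples, xyz + normals.
    points = torch.randn(1, 4096, 6)

    with torch.no_grad():
        shape_head, shape_embed = conditioner(shape=points)

    print(shape_head.shape, shape_embed.shape)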