import math
from abc import ABC, abstractmethod
from collections import OrderedDict
from typing import Any, Dict, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn

from shap_e.util.collections import AttrDict


def flatten_param_shapes(param_shapes: Dict[str, Tuple[int]]):
    flat_shapes = OrderedDict(
        (name, (int(np.prod(shape)) // shape[-1], shape[-1]))
        for name, shape in param_shapes.items()
    )
    return flat_shapes
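
# Quick illustration of flatten_param_shapes (hypothetical shapes, not from the
# original file): a parameter of shape (128, 64, 3) flattens to
# (128 * 64, 3) == (8192, 3), i.e. 8192 "vectors" of 3 channels each; the last
# axis is always kept as the channel dimension.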


class ParamsProj(nn.Module, ABC):
    def __init__(self, *, device: torch.device, param_shapes: Dict[str, Tuple[int]], d_latent: int):
        super().__init__()
        self.device = device
        self.param_shapes = param_shapes
        self.d_latent = d_latent

    @abstractmethod
    def forward(self, x: torch.Tensor, options: Optional[AttrDict] = None) -> AttrDict:
        pass


class LinearParamsProj(ParamsProj):
    def __init__(
        self,
        *,
        device: torch.device,
        param_shapes: Dict[str, Tuple[int]],
        d_latent: int,
        init_scale: Optional[float] = None,
    ):
        super().__init__(device=device, param_shapes=param_shapes, d_latent=d_latent)
        self.param_shapes = param_shapes
        self.projections = nn.ModuleDict({})
        for k, v in param_shapes.items():
            self.projections[_sanitize_name(k)] = nn.Linear(
                d_latent, int(np.prod(v)), device=device
            )
            if init_scale is not None:
                scale = init_scale / math.sqrt(d_latent)
                mod = self.projections[_sanitize_name(k)]
                nn.init.normal_(mod.weight, std=scale)
                nn.init.zeros_(mod.bias)

    def forward(self, x: torch.Tensor, options: Optional[AttrDict] = None) -> AttrDict:
        out = AttrDict()
        for k in self.param_shapes.keys():
            proj = self.projections[_sanitize_name(k)]
            out[k] = proj(x).reshape([len(x), *self.param_shapes[k]])
        return out
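
# Minimal usage sketch for LinearParamsProj (hypothetical shapes and names, not part
# of the original file): each latent vector of size d_latent is mapped by a separate
# nn.Linear to a flat parameter vector, then reshaped to that parameter's full shape.
#
#     proj = LinearParamsProj(
#         device=torch.device("cpu"),
#         param_shapes={"nerf.mlp.0.weight": (64, 3)},
#         d_latent=1024,
#     )
#     latents = torch.randn(2, 1024)       # [batch, d_latent]
#     params = proj(latents)
#     params["nerf.mlp.0.weight"].shape    # -> torch.Size([2, 64, 3])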


class MLPParamsProj(ParamsProj):
    def __init__(
        self,
        *,
        device: torch.device,
        param_shapes: Dict[str, Tuple[int]],
        d_latent: int,
        hidden_size: Optional[int] = None,
    ):
        super().__init__(device=device, param_shapes=param_shapes, d_latent=d_latent)
        if hidden_size is None:
            hidden_size = d_latent
        self.param_shapes = param_shapes
        self.projections = nn.ModuleDict({})
        for k, v in param_shapes.items():
            self.projections[_sanitize_name(k)] = nn.Sequential(
                nn.Linear(d_latent, hidden_size, device=device),
                nn.GELU(),
                nn.Linear(hidden_size, int(np.prod(v)), device=device),
            )

    def forward(self, x: torch.Tensor, options: Optional[AttrDict] = None) -> AttrDict:
        out = AttrDict()
        for k in self.param_shapes.keys():
            proj = self.projections[_sanitize_name(k)]
            out[k] = proj(x).reshape([len(x), *self.param_shapes[k]])
        return out
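
# Note on the class above: MLPParamsProj follows the same per-parameter layout as
# LinearParamsProj but swaps each single linear map for a Linear -> GELU -> Linear
# block, with hidden_size defaulting to d_latent when not specified.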


class ChannelsProj(nn.Module):
    def __init__(
        self,
        *,
        device: torch.device,
        vectors: int,
        channels: int,
        d_latent: int,
        init_scale: float = 1.0,
        learned_scale: Optional[float] = None,
        use_ln: bool = False,
    ):
        super().__init__()
        self.proj = nn.Linear(d_latent, vectors * channels, device=device)
        self.use_ln = use_ln
        self.learned_scale = learned_scale
        if use_ln:
            # LayerNorm over the channel axis; its weight doubles as a learned gain.
            self.norm = nn.LayerNorm(normalized_shape=(channels,), device=device)
            if learned_scale is not None:
                self.norm.weight.data.fill_(learned_scale)
            scale = init_scale / math.sqrt(d_latent)
        elif learned_scale is not None:
            # No LayerNorm: keep an explicit per-channel gain parameter instead.
            gain = torch.ones((channels,), device=device) * learned_scale
            self.register_parameter("gain", nn.Parameter(gain))
            scale = init_scale / math.sqrt(d_latent)
        else:
            # No normalization or gain, so the channel count is folded into the init scale.
            scale = init_scale / math.sqrt(d_latent * channels)
        nn.init.normal_(self.proj.weight, std=scale)
        nn.init.zeros_(self.proj.bias)
        self.d_latent = d_latent
        self.vectors = vectors
        self.channels = channels

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x is [batch, vectors, d_latent]; the nn.Linear from __init__ serves only as
        # a parameter container here. Its weight is viewed as one [channels, d_latent]
        # slice per vector, so each vector position gets its own projection.
        x_bvd = x
        w_vcd = self.proj.weight.view(self.vectors, self.channels, self.d_latent)
        b_vc = self.proj.bias.view(1, self.vectors, self.channels)
        h = torch.einsum("bvd,vcd->bvc", x_bvd, w_vcd)
        if self.use_ln:
            h = self.norm(h)
        elif self.learned_scale is not None:
            h = h * self.gain.view(1, 1, -1)
        h = h + b_vc
        return h


class ChannelsParamsProj(ParamsProj):
    def __init__(
        self,
        *,
        device: torch.device,
        param_shapes: Dict[str, Tuple[int]],
        d_latent: int,
        init_scale: float = 1.0,
        learned_scale: Optional[float] = None,
        use_ln: bool = False,
    ):
        super().__init__(device=device, param_shapes=param_shapes, d_latent=d_latent)
        self.param_shapes = param_shapes
        self.projections = nn.ModuleDict({})
        self.flat_shapes = flatten_param_shapes(param_shapes)
        self.learned_scale = learned_scale
        self.use_ln = use_ln
        for k, (vectors, channels) in self.flat_shapes.items():
            self.projections[_sanitize_name(k)] = ChannelsProj(
                device=device,
                vectors=vectors,
                channels=channels,
                d_latent=d_latent,
                init_scale=init_scale,
                learned_scale=learned_scale,
                use_ln=use_ln,
            )

    def forward(self, x: torch.Tensor, options: Optional[AttrDict] = None) -> AttrDict:
        out = AttrDict()
        start = 0
        for k, shape in self.param_shapes.items():
            vectors, _ = self.flat_shapes[k]
            end = start + vectors
            x_bvd = x[:, start:end]
            out[k] = self.projections[_sanitize_name(k)](x_bvd).reshape(len(x), *shape)
            start = end
        return out
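
# Minimal usage sketch for ChannelsParamsProj (hypothetical shapes and names, not part
# of the original file). Unlike the linear/MLP variants, the latent here is a sequence
# of vectors, [batch, total_vectors, d_latent], and consecutive slices of it are routed
# to each parameter in order:
#
#     shapes = {"a.weight": (4, 8), "b.weight": (2, 8)}   # flatten to 4 and 2 vectors
#     proj = ChannelsParamsProj(
#         device=torch.device("cpu"), param_shapes=shapes, d_latent=16
#     )
#     latents = torch.randn(3, 6, 16)                     # 6 == 4 + 2 vectors total
#     params = proj(latents)
#     params["a.weight"].shape                             # -> torch.Size([3, 4, 8])
#     params["b.weight"].shape                             # -> torch.Size([3, 2, 8])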


def params_proj_from_config(
    config: Dict[str, Any], device: torch.device, param_shapes: Dict[str, Tuple[int]], d_latent: int
):
    name = config.pop("name")
    if name == "linear":
        return LinearParamsProj(
            **config, device=device, param_shapes=param_shapes, d_latent=d_latent
        )
    elif name == "mlp":
        return MLPParamsProj(**config, device=device, param_shapes=param_shapes, d_latent=d_latent)
    elif name == "channels":
        return ChannelsParamsProj(
            **config, device=device, param_shapes=param_shapes, d_latent=d_latent
        )
    else:
        raise ValueError(f"unknown params proj: {name}")


def _sanitize_name(x: str) -> str:
    # nn.ModuleDict keys may not contain ".", so dots in parameter names are replaced.
    return x.replace(".", "__")
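
# End-to-end sketch (hypothetical config and shapes, not part of the original file):
# params_proj_from_config pops "name" from the config and forwards the remaining keys
# to the chosen projection class.
#
#     config = {"name": "channels", "init_scale": 1.0, "use_ln": True}
#     proj = params_proj_from_config(
#         config,
#         device=torch.device("cpu"),
#         param_shapes={"decoder.weight": (4, 8)},
#         d_latent=16,
#     )
#     # proj is a ChannelsParamsProj expecting latents of shape [batch, 4, 16].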