Spaces:
Runtime error
Runtime error
| import torch | |
| import torch.nn as nn | |
| from torch.autograd import Function | |
| from torch.cuda.amp import custom_bwd, custom_fwd | |
| from lam.models.rendering.utils.typing import * | |
def get_activation(name):
    """Return the element-wise activation callable registered under ``name``.

    Args:
        name: Activation name (matched case-insensitively), or ``None``.
            ``None`` and ``"none"`` give the identity function. Besides the
            custom names handled explicitly below, any attribute of
            ``torch.nn.functional`` (e.g. ``"relu"``, ``"softplus"``) is
            accepted.

    Returns:
        A callable applied to a tensor.

    Raises:
        ValueError: If ``name`` is neither a custom activation nor an
            attribute of ``torch.nn.functional``.
    """
    # Local import: the module header does not bind ``F``, and both the
    # "shifted_softplus" branch and the generic fallback depend on it.
    import torch.nn.functional as F

    if name is None:
        return lambda x: x
    name = name.lower()
    if name == "none":
        return lambda x: x
    if name == "lin2srgb":
        # Linear -> sRGB transfer curve, clamped to [0, 1]. Both torch.where
        # branches are evaluated eagerly, so the inner clamp keeps pow() from
        # seeing values below the threshold (negative bases would give NaN).
        return lambda x: torch.where(
            x > 0.0031308,
            torch.pow(torch.clamp(x, min=0.0031308), 1.0 / 2.4) * 1.055 - 0.055,
            12.92 * x,
        ).clamp(0.0, 1.0)
    if name == "exp":
        return lambda x: torch.exp(x)
    if name == "shifted_exp":
        return lambda x: torch.exp(x - 1.0)
    if name == "trunc_exp":
        # ``trunc_exp`` is the module-level autograd wrapper defined in this file.
        return trunc_exp
    if name == "shifted_trunc_exp":
        return lambda x: trunc_exp(x - 1.0)
    if name == "sigmoid":
        return lambda x: torch.sigmoid(x)
    if name == "tanh":
        return lambda x: torch.tanh(x)
    if name == "shifted_softplus":
        return lambda x: F.softplus(x - 1.0)
    if name == "scale_-11_01":
        # Affine map from [-1, 1] to [0, 1].
        return lambda x: x * 0.5 + 0.5
    try:
        return getattr(F, name)
    except AttributeError:
        # The AttributeError from getattr adds nothing for the caller.
        raise ValueError(f"Unknown activation function: {name}") from None
class MLP(nn.Module):
    """Fully connected network: a Linear/activation trunk plus an output head.

    The trunk starts with ``dim_in -> n_neurons``. It is followed by
    ``n_hidden_layers - 1`` further ``n_neurons -> n_neurons`` layers, each
    followed by the hidden activation. A final ``n_neurons -> dim_out`` Linear
    has no hidden activation after it; its result goes through
    ``output_activation`` (resolved with ``get_activation``).

    Args:
        dim_in: Input feature size.
        dim_out: Output feature size.
        n_neurons: Width of every hidden layer.
        n_hidden_layers: Number of hidden Linear+activation pairs. Values
            below 1 still produce the first ``dim_in -> n_neurons`` pair.
        activation: Hidden activation, ``"relu"`` or ``"silu"``.
        output_activation: Name passed to ``get_activation``; ``None`` gives
            the identity.
        bias: Whether every Linear layer has a bias term.
    """

    def __init__(
        self,
        dim_in: int,
        dim_out: int,
        n_neurons: int,
        n_hidden_layers: int,
        activation: str = "relu",
        output_activation: Optional[str] = None,
        bias: bool = True,
    ):
        super().__init__()
        # Input projection and its activation.
        stack = [
            self.make_linear(
                dim_in, n_neurons, is_first=True, is_last=False, bias=bias
            ),
            self.make_activation(activation),
        ]
        # Remaining hidden blocks keep the width fixed at n_neurons.
        for _ in range(n_hidden_layers - 1):
            stack.append(
                self.make_linear(
                    n_neurons, n_neurons, is_first=False, is_last=False, bias=bias
                )
            )
            stack.append(self.make_activation(activation))
        # Output projection; only output_activation follows it.
        stack.append(
            self.make_linear(
                n_neurons, dim_out, is_first=False, is_last=True, bias=bias
            )
        )
        self.layers = nn.Sequential(*stack)
        self.output_activation = get_activation(output_activation)

    def forward(self, x):
        """Run ``x`` through the Linear stack, then the output activation."""
        return self.output_activation(self.layers(x))

    def make_linear(self, dim_in, dim_out, is_first, is_last, bias=True):
        """Build one ``nn.Linear(dim_in, dim_out)`` layer.

        ``is_first`` / ``is_last`` describe the layer's position in the stack
        but are ignored here (presumably a hook for subclasses that
        initialise boundary layers differently; unconfirmed).
        """
        return nn.Linear(dim_in, dim_out, bias=bias)

    def make_activation(self, activation):
        """Build a fresh in-place hidden activation module.

        Raises:
            NotImplementedError: If ``activation`` is not "relu" or "silu".
        """
        for key, factory in (("relu", nn.ReLU), ("silu", nn.SiLU)):
            if activation == key:
                return factory(inplace=True)
        raise NotImplementedError
class _TruncExp(Function):  # pylint: disable=abstract-method
    """Exponential whose backward pass uses a clamped input.

    Forward returns ``exp(x)`` unchanged. Backward returns
    ``g * exp(min(x, 15))``, so the local derivative is capped at ``exp(15)``
    for large inputs.

    Implementation from torch-ngp:
    https://github.com/ashawkey/torch-ngp/blob/93b08a0d4ec1cc6e69d85df7f0acdfb99603b628/activation.py
    """

    # forward/backward must be static: PyTorch rejects "legacy" autograd
    # Functions with instance-method forward when .apply is called.
    @staticmethod
    @custom_fwd(cast_inputs=torch.float32)
    def forward(ctx, x):  # pylint: disable=arguments-differ
        # Under CUDA autocast, custom_fwd casts inputs to float32 so exp()
        # is not evaluated in half precision.
        ctx.save_for_backward(x)
        return torch.exp(x)

    @staticmethod
    @custom_bwd
    def backward(ctx, g):  # pylint: disable=arguments-differ
        x = ctx.saved_tensors[0]
        # Clamp from above only; the forward value itself is not truncated.
        return g * torch.exp(torch.clamp(x, max=15))


# Functional entry point used by get_activation("trunc_exp").
trunc_exp = _TruncExp.apply