|
"""Siren MLP https://www.vincentsitzmann.com/siren/""" |
|
|
|
from typing import Optional |
|
|
|
import numpy as np |
|
import torch |
|
from torch import nn |
|
|
|
|
|
class SineLayer(nn.Module):
    """Linear map followed by a scaled sine activation, as used in SIREN.

    Computes ``sin(omega_0 * (W x + b))``. Weights are re-initialised after
    construction: the first layer of a network draws from
    ``U(-1/in_features, 1/in_features)``, every other layer from
    ``U(-sqrt(6/in_features)/omega_0, +sqrt(6/in_features)/omega_0)``.
    The bias keeps ``nn.Linear``'s default initialisation.
    """

    def __init__(
        self, in_features, out_features, bias=True, is_first=False, omega_0=30.0
    ):
        super().__init__()
        # Frequency multiplier applied to the pre-activation.
        self.omega_0 = omega_0
        # Selects the first-layer initialisation range in init_weights().
        self.is_first = is_first
        # Fan-in, used to size the uniform initialisation range.
        self.in_features = in_features
        self.linear = nn.Linear(in_features, out_features, bias=bias)
        self.init_weights()

    def init_weights(self):
        """Overwrite the linear weight with a uniform draw in [-limit, limit]."""
        fan_in = self.in_features
        if self.is_first:
            limit = 1 / fan_in
        else:
            # Dividing by omega_0 keeps omega_0 * (W x) in a well-behaved range
            # for layers after the first.
            limit = np.sqrt(6 / fan_in) / self.omega_0
        with torch.no_grad():
            self.linear.weight.uniform_(-limit, limit)

    def forward(self, x):
        """Return ``sin(omega_0 * linear(x))``."""
        pre_activation = self.linear(x)
        return torch.sin(self.omega_0 * pre_activation)
|
|
|
|
|
class Siren(nn.Module):
    """SIREN network: an MLP built from sine-activated layers.

    Layer stack, in order:

    * one ``SineLayer`` mapping ``in_dim -> hidden_features`` (first-layer init,
      frequency ``first_omega_0``);
    * ``hidden_layers`` ``SineLayer`` s mapping
      ``hidden_features -> hidden_features`` (frequency ``hidden_omega_0``);
    * a final layer mapping ``hidden_features -> out_dim``. It is a plain
      ``nn.Linear`` when ``outermost_linear`` is True and a ``SineLayer`` otherwise;
    * ``out_activation``, if given.

    Args:
        in_dim: Input layer dimension. Must be positive.
        hidden_layers: Number of hidden SineLayers placed between the first
            and the final layer (0 gives a two-layer network).
        hidden_features: Width of every hidden layer.
        out_dim: Output layer dimension. Uses ``hidden_features`` if None.
        outermost_linear: If True, the final layer is a linear layer (no sine).
            Its weights are initialised uniformly in
            ``[-sqrt(6/hidden_features)/hidden_omega_0, +...]``.
        first_omega_0: Frequency factor of the first SineLayer.
        hidden_omega_0: Frequency factor of the hidden (and sine final) layers.
        out_activation: Optional module appended after the final layer.

    Raises:
        ValueError: If ``in_dim`` is not positive.
    """

    def __init__(
        self,
        in_dim: int,
        hidden_layers: int,
        hidden_features: int,
        out_dim: Optional[int] = None,
        outermost_linear: bool = False,
        first_omega_0: float = 30,
        hidden_omega_0: float = 30,
        out_activation: Optional[nn.Module] = None,
    ) -> None:
        super().__init__()
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if in_dim <= 0:
            raise ValueError(f"in_dim must be positive, got {in_dim}")

        self.in_dim = in_dim
        self.out_dim = out_dim if out_dim is not None else hidden_features
        self.outermost_linear = outermost_linear
        self.first_omega_0 = first_omega_0
        self.hidden_omega_0 = hidden_omega_0
        self.hidden_layers = hidden_layers
        self.layer_width = hidden_features
        # Note: assigning an nn.Module here also registers it as a submodule
        # (it is additionally placed at the end of `self.net` below).
        self.out_activation = out_activation

        # Build the layers in a local list, then register them once as a
        # Sequential (avoids rebinding `self.net` from a list to a module).
        layers = [
            SineLayer(in_dim, hidden_features, is_first=True, omega_0=first_omega_0)
        ]

        for _ in range(hidden_layers):
            layers.append(
                SineLayer(
                    hidden_features,
                    hidden_features,
                    is_first=False,
                    omega_0=hidden_omega_0,
                )
            )

        if outermost_linear:
            final_layer = nn.Linear(hidden_features, self.out_dim)
            # Same weight range as a non-first SineLayer, so the final
            # projection is scaled consistently with the hidden layers.
            bound = np.sqrt(6 / hidden_features) / hidden_omega_0
            with torch.no_grad():
                final_layer.weight.uniform_(-bound, bound)
            layers.append(final_layer)
        else:
            layers.append(
                SineLayer(
                    hidden_features,
                    self.out_dim,
                    is_first=False,
                    omega_0=hidden_omega_0,
                )
            )

        if self.out_activation is not None:
            layers.append(self.out_activation)

        self.net = nn.Sequential(*layers)

    def forward(self, model_input: torch.Tensor) -> torch.Tensor:
        """Forward pass through the network."""
        return self.net(model_input)
|
|