import torch
import torch.nn as nn


class AdaptiveGroupNorm(nn.Module):
    def __init__(self, num_groups, num_channels, emb_dim, eps=1e-5):
        super().__init__()
        self.num_groups = num_groups
        self.num_channels = num_channels
        self.eps = eps
        # Use a standard GroupNorm, but without learnable affine parameters
        self.norm = nn.GroupNorm(num_groups, num_channels, eps=eps, affine=False)
        # Linear layers to project the embedding to gamma and beta
        self.gamma_proj = nn.Linear(emb_dim, num_channels)
        self.beta_proj = nn.Linear(emb_dim, num_channels)

    def forward(self, x, emb):
        """
        Args:
            x: Input tensor of shape [B, C, H, W].
            emb: Embedding tensor of shape [B, emb_dim].

        Returns:
            Normalized tensor with adaptive scaling and shifting.
        """
        # Normalize as usual with GroupNorm
        normalized = self.norm(x)
        # Get gamma and beta from the embedding
        gamma = self.gamma_proj(emb)
        beta = self.beta_proj(emb)
        # Reshape for broadcasting: [B, C] -> [B, C, 1, 1]
        gamma = gamma.view(-1, self.num_channels, 1, 1)
        beta = beta.view(-1, self.num_channels, 1, 1)
        # Apply adaptive scaling and shifting
        return gamma * normalized + beta
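

# A minimal smoke test for AdaptiveGroupNorm (a sketch; the shapes below are
# illustrative assumptions, not values taken from the original code). It
# checks that a [B, C, H, W] feature map plus a [B, emb_dim] embedding come
# out with the input's shape preserved.
def _demo_adaptive_group_norm():
    norm = AdaptiveGroupNorm(num_groups=8, num_channels=64, emb_dim=128)
    x = torch.randn(2, 64, 32, 32)   # [B, C, H, W]
    emb = torch.randn(2, 128)        # [B, emb_dim]
    out = norm(x, emb)
    assert out.shape == x.shape      # adaptive norm preserves the input shape
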
class DepthwiseSeparableConv2d(nn.Module):
    def __init__(self, dim_in, dim_out, kernel_size, padding):
        super().__init__()
        # Depthwise: one kernel per input channel (groups=dim_in)
        self.depthwise = nn.Conv2d(dim_in, dim_in, kernel_size, padding=padding, groups=dim_in)
        # Pointwise: 1x1 convolution to mix channels and change the width
        self.pointwise = nn.Conv2d(dim_in, dim_out, 1)

    def forward(self, x):
        x = self.depthwise(x)
        x = self.pointwise(x)
        return x
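

# A small parameter-count comparison (a sketch with assumed channel sizes,
# not numbers from the original): a depthwise-separable 3x3 convolution uses
# far fewer weights than a dense 3x3 convolution between the same widths.
def _demo_depthwise_param_savings():
    dense = nn.Conv2d(64, 128, 3, padding=1)
    separable = DepthwiseSeparableConv2d(64, 128, kernel_size=3, padding=1)

    def n_params(m):
        return sum(p.numel() for p in m.parameters())

    # Dense:     64*128*3*3 + 128 bias              = 73,856 parameters
    # Separable: (64*3*3 + 64) + (64*128*1*1 + 128) =  8,960 parameters
    print(n_params(dense), n_params(separable))
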
class Block(nn.Module):
    def __init__(self, dim, dim_out, groups, emb_dim, dropout=0.0, use_depthwise=False):
        super().__init__()
        self.norm = AdaptiveGroupNorm(groups, dim, emb_dim)
        if use_depthwise:
            self.proj = DepthwiseSeparableConv2d(dim, dim_out, kernel_size=3, padding=1)
        else:
            self.proj = nn.Conv2d(dim, dim_out, 3, padding=1)
        self.act = nn.SiLU()
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, emb):
        x = self.norm(x, emb)  # Pre-normalization
        x = self.proj(x)
        x = self.act(x)
        return self.dropout(x)
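

# A quick forward-pass sketch for Block (the dimensions are assumptions for
# the demo). The block normalizes with the conditioning embedding, projects
# with a 3x3 convolution, applies SiLU, then dropout.
def _demo_block():
    block = Block(dim=64, dim_out=128, groups=8, emb_dim=256, use_depthwise=True)
    x = torch.randn(2, 64, 16, 16)
    emb = torch.randn(2, 256)
    out = block(x, emb)
    assert out.shape == (2, 128, 16, 16)  # spatial size kept, channels widened
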
class ResnetBlock(nn.Module):
    def __init__(self, dim: int, dim_out: int, t_emb_dim: int, *,
                 y_emb_dim: int | None = None, groups: int = 32, dropout: float = 0.0,
                 residual_scale: float = 1.0):
        super().__init__()
        if y_emb_dim is None:
            y_emb_dim = 0
        emb_dim = t_emb_dim + y_emb_dim
        self.block1 = Block(dim, dim_out, groups, emb_dim, dropout)  # Pass emb_dim
        self.block2 = Block(dim_out, dim_out, groups, emb_dim, dropout)  # Pass emb_dim
        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
        # Learnable scale on the block branch, initialized to residual_scale
        self.residual_scale = nn.Parameter(torch.tensor(residual_scale))

    def forward(self, x, t_emb, y_emb=None):
        cond_emb = t_emb
        if y_emb is not None:
            cond_emb = torch.cat([cond_emb, y_emb], dim=-1)
        h = self.block1(x, cond_emb)  # Pass combined embedding to Block
        h = self.block2(h, cond_emb)  # Pass combined embedding to Block
        return self.residual_scale * h + self.res_conv(x)  # Scale the residual
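

# End-to-end sketch: a ResnetBlock conditioned on a timestep embedding and an
# optional class embedding. All dimensions are illustrative assumptions, not
# values from the original code.
def _demo_resnet_block():
    block = ResnetBlock(dim=64, dim_out=128, t_emb_dim=256, y_emb_dim=128, groups=8)
    x = torch.randn(2, 64, 16, 16)
    t_emb = torch.randn(2, 256)  # timestep embedding
    y_emb = torch.randn(2, 128)  # class/conditioning embedding
    out = block(x, t_emb, y_emb)
    assert out.shape == (2, 128, 16, 16)


if __name__ == "__main__":
    _demo_adaptive_group_norm()
    _demo_depthwise_param_savings()
    _demo_block()
    _demo_resnet_block()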