import torch
import torch.nn as nn
from typing import Optional
class AdaptiveGroupNorm(nn.Module):
    def __init__(self, num_groups, num_channels, emb_dim, eps=1e-5):
        super().__init__()
        self.num_groups = num_groups
        self.num_channels = num_channels
        self.eps = eps
        # Use a standard GroupNorm, but without learnable affine parameters
        self.norm = nn.GroupNorm(num_groups, num_channels, eps=eps, affine=False)
        # Linear layers to project the embedding to gamma and beta
        self.gamma_proj = nn.Linear(emb_dim, num_channels)
        self.beta_proj = nn.Linear(emb_dim, num_channels)

    def forward(self, x, emb):
        """
        Args:
            x: Input tensor of shape [B, C, H, W].
            emb: Embedding tensor of shape [B, emb_dim].

        Returns:
            Normalized tensor with adaptive scaling and shifting.
        """
        # Normalize as usual with GroupNorm
        normalized = self.norm(x)
        # Get gamma and beta from the embedding
        gamma = self.gamma_proj(emb)
        beta = self.beta_proj(emb)
        # Reshape for broadcasting: [B, C] -> [B, C, 1, 1]
        gamma = gamma.view(-1, self.num_channels, 1, 1)
        beta = beta.view(-1, self.num_channels, 1, 1)
        # Apply adaptive scaling and shifting
        return gamma * normalized + beta
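
# Illustrative usage note (an assumption, not part of the original file):
# AdaptiveGroupNorm(num_groups=8, num_channels=64, emb_dim=256) normalizes a
# [B, 64, H, W] feature map, then rescales and shifts it per channel with
# gamma/beta predicted from a [B, 256] conditioning embedding.
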
class DepthwiseSeparableConv2d(nn.Module):
    def __init__(self, dim_in, dim_out, kernel_size, padding):
        super().__init__()
        # groups=dim_in gives each input channel its own spatial filter
        self.depthwise = nn.Conv2d(dim_in, dim_in, kernel_size, padding=padding, groups=dim_in)
        self.pointwise = nn.Conv2d(dim_in, dim_out, 1)  # 1x1 convolution mixes channels

    def forward(self, x):
        x = self.depthwise(x)
        x = self.pointwise(x)
        return x
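
# Parameter-count sketch (illustrative numbers, not from the original file):
# a dense 3x3 conv mapping 64 -> 128 channels uses 128*64*3*3 = 73,728 weights,
# while the depthwise-separable factorization uses 64*3*3 + 128*64 = 8,768,
# roughly an 8x reduction (bias terms ignored).
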
class Block(nn.Module):
    def __init__(self, dim, dim_out, groups, emb_dim, dropout=0.0, use_depthwise=False):
        super().__init__()
        self.norm = AdaptiveGroupNorm(groups, dim, emb_dim)
        if use_depthwise:
            self.proj = DepthwiseSeparableConv2d(dim, dim_out, kernel_size=3, padding=1)
        else:
            self.proj = nn.Conv2d(dim, dim_out, 3, padding=1)
        self.act = nn.SiLU()
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, emb):
        x = self.norm(x, emb)  # Pre-normalization, conditioned on the embedding
        x = self.proj(x)
        x = self.act(x)
        return self.dropout(x)
class ResnetBlock(nn.Module):
    def __init__(self, dim: int, dim_out: int, t_emb_dim: int, *,
                 y_emb_dim: Optional[int] = None, groups: int = 32,
                 dropout: float = 0.0, residual_scale: float = 1.0):
        super().__init__()
        if y_emb_dim is None:
            y_emb_dim = 0
        # Both blocks are conditioned on the concatenated (t, y) embedding
        emb_dim = t_emb_dim + y_emb_dim
        self.block1 = Block(dim, dim_out, groups, emb_dim, dropout)
        self.block2 = Block(dim_out, dim_out, groups, emb_dim, dropout)
        # 1x1 conv on the skip connection only when the channel count changes
        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
        # Learnable scale on the residual branch
        self.residual_scale = nn.Parameter(torch.tensor(residual_scale))

    def forward(self, x, t_emb, y_emb=None):
        cond_emb = t_emb
        if y_emb is not None:
            # If y_emb_dim was set at construction, y_emb must be provided here;
            # otherwise the concatenated embedding will not match emb_dim.
            cond_emb = torch.cat([cond_emb, y_emb], dim=-1)
        h = self.block1(x, cond_emb)
        h = self.block2(h, cond_emb)
        return self.residual_scale * h + self.res_conv(x)
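

# Minimal smoke test: a sketch with assumed, illustrative shapes and embedding
# dims (none of these values come from the original file).
if __name__ == "__main__":
    block = ResnetBlock(64, 128, t_emb_dim=256, y_emb_dim=128)
    x = torch.randn(2, 64, 32, 32)   # [B, C, H, W]
    t_emb = torch.randn(2, 256)      # timestep embedding
    y_emb = torch.randn(2, 128)      # class/conditioning embedding
    out = block(x, t_emb, y_emb)
    assert out.shape == (2, 128, 32, 32)
    print(out.shape)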