import torch
import torch.nn as nn

from diffusionsfm.model.dit import TimestepEmbedder

def modulate(x, shift, scale):
    # adaLN-style modulation: broadcast per-sample (B, C) shift/scale vectors
    # over the spatial dimensions of a (B, C, H, W) feature map.
    return x * (1 + scale.unsqueeze(-1).unsqueeze(-1)) + shift.unsqueeze(-1).unsqueeze(
        -1
    )
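

def _example_modulate():
    # Minimal usage sketch (not part of the original module): all shapes and
    # values here are illustrative assumptions, not values from the repo.
    x = torch.randn(2, 64, 16, 16)  # (B, C, H, W) feature map
    shift = torch.randn(2, 64)      # (B, C) per-channel shift
    scale = torch.randn(2, 64)      # (B, C) per-channel scale
    y = modulate(x, shift, scale)
    assert y.shape == x.shape       # broadcasting preserves the spatial shape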


def _make_fusion_block(features, use_bn, use_ln, dpt_time, resolution):
    return FeatureFusionBlock_custom(
        features,
        nn.ReLU(False),
        deconv=False,
        bn=use_bn,
        expand=False,
        align_corners=True,
        dpt_time=dpt_time,
        ln=use_ln,
        resolution=resolution,
    )


def _make_scratch(in_shape, out_shape, groups=1, expand=False):
    scratch = nn.Module()

    out_shape1 = out_shape
    out_shape2 = out_shape
    out_shape3 = out_shape
    out_shape4 = out_shape
    if expand:
        out_shape2 = out_shape * 2
        out_shape3 = out_shape * 4
        out_shape4 = out_shape * 8
scratch.layer1_rn = nn.Conv2d(
in_shape[0],
out_shape1,
kernel_size=3,
stride=1,
padding=1,
bias=False,
groups=groups,
)
scratch.layer2_rn = nn.Conv2d(
in_shape[1],
out_shape2,
kernel_size=3,
stride=1,
padding=1,
bias=False,
groups=groups,
)
scratch.layer3_rn = nn.Conv2d(
in_shape[2],
out_shape3,
kernel_size=3,
stride=1,
padding=1,
bias=False,
groups=groups,
)
scratch.layer4_rn = nn.Conv2d(
in_shape[3],
out_shape4,
kernel_size=3,
stride=1,
padding=1,
bias=False,
groups=groups,
)
return scratch
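

def _example_make_scratch():
    # Minimal usage sketch (not part of the original module): project four
    # encoder stages of different widths onto one common width, as the DPT
    # reassembly ("scratch") stage does. The channel counts are assumptions.
    scratch = _make_scratch(in_shape=[96, 192, 384, 768], out_shape=64)
    feats = [torch.randn(2, c, 16, 16) for c in (96, 192, 384, 768)]
    layers = [scratch.layer1_rn, scratch.layer2_rn,
              scratch.layer3_rn, scratch.layer4_rn]
    outs = [layer(f) for layer, f in zip(layers, feats)]
    assert all(o.shape[1] == 64 for o in outs)  # every stage now has 64 channels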


class ResidualConvUnit_custom(nn.Module):
    """Residual convolution module, optionally conditioned on a diffusion
    timestep via adaLN (shift/scale before the block, gate on the residual)."""

    def __init__(self, features, activation, bn, ln, dpt_time=False, resolution=16):
        """Init.

        Args:
            features (int): number of features
            activation (nn.Module): activation applied before each conv
            bn (bool): use BatchNorm2d after each conv
            ln (bool): use LayerNorm over (features, resolution, resolution)
            dpt_time (bool): condition this unit on a diffusion timestep
            resolution (int): spatial size expected by the LayerNorm layers
        """
        super().__init__()

        self.bn = bn
        self.ln = ln

        self.groups = 1
self.conv1 = nn.Conv2d(
features,
features,
kernel_size=3,
stride=1,
padding=1,
bias=not self.bn,
groups=self.groups,
)
self.conv2 = nn.Conv2d(
features,
features,
kernel_size=3,
stride=1,
padding=1,
bias=not self.bn,
groups=self.groups,
)
nn.init.kaiming_uniform_(self.conv1.weight)
nn.init.kaiming_uniform_(self.conv2.weight)
        if self.bn:
            self.bn1 = nn.BatchNorm2d(features)
            self.bn2 = nn.BatchNorm2d(features)

        if self.ln:
            # LayerNorm reuses the bn1/bn2 attribute names so the rest of the
            # forward pass is identical for both normalization types.
            self.bn1 = nn.LayerNorm((features, resolution, resolution))
            self.bn2 = nn.LayerNorm((features, resolution, resolution))
self.activation = activation
        if dpt_time:
            # adaLN conditioning: embed the timestep, then regress per-channel
            # shift, scale, and gate parameters (hence 3 * features outputs).
            self.t_embedder = TimestepEmbedder(hidden_size=features)
            self.adaLN_modulation = nn.Sequential(
                nn.SiLU(), nn.Linear(features, 3 * features, bias=True)
            )
    def forward(self, x, t=None):
        """Forward pass.

        Args:
            x (tensor): input feature map, (B * N, features, H, W)
            t (tensor, optional): diffusion timesteps, one per sample

        Returns:
            tensor: output
        """
        if t is not None:
            # Embed the timestep and regress the adaLN parameters
            t = self.t_embedder(t)  # (B * N, features)
            shift, scale, gate = self.adaLN_modulation(t).chunk(3, dim=1)  # each (B * N, features)

            # Shift & scale x
            x = modulate(x, shift, scale)  # (B * N, features, H, W)
out = self.activation(x)
out = self.conv1(out)
if self.bn or self.ln:
out = self.bn1(out)
out = self.activation(out)
out = self.conv2(out)
if self.bn or self.ln:
out = self.bn2(out)
        if t is not None:
            # Gate the residual branch with the timestep-dependent gate
            out = gate.unsqueeze(-1).unsqueeze(-1) * out

        return out + x
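

def _example_residual_conv_unit():
    # Minimal usage sketch (not part of the original module): run a
    # time-conditioned residual unit on a 16x16 feature map. Batch size,
    # channel count, and the timestep range are illustrative assumptions.
    unit = ResidualConvUnit_custom(
        features=64, activation=nn.ReLU(False), bn=False, ln=True,
        dpt_time=True, resolution=16,
    )
    x = torch.randn(2, 64, 16, 16)    # (B * N, features, H, W)
    t = torch.randint(0, 1000, (2,))  # one diffusion timestep per sample
    out = unit(x, t)                  # adaLN shift/scale + gated residual
    assert out.shape == x.shape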


class FeatureFusionBlock_custom(nn.Module):
    """Feature fusion block: fuses a lateral (skip) feature map into the
    decoder stream, applies a time-conditioned residual unit, and upsamples 2x."""

    def __init__(
        self,
        features,
        activation,
        deconv=False,
        bn=False,
        ln=False,
        expand=False,
        align_corners=True,
        dpt_time=False,
        resolution=16,
    ):
        """Init.

        Args:
            features (int): number of features
            activation (nn.Module): activation used inside the residual units
            deconv (bool): stored but not used by this block's forward pass
            bn (bool): use BatchNorm2d in the residual units
            ln (bool): use LayerNorm in the residual units
            expand (bool): halve the channel count in the output conv
            align_corners (bool): passed to the bilinear upsampling
            dpt_time (bool): condition the second residual unit on a timestep
            resolution (int): spatial size expected by the LayerNorm layers
        """
        super().__init__()

        self.deconv = deconv
        self.align_corners = align_corners

        self.groups = 1

        self.expand = expand
        out_features = features
        if self.expand:
            out_features = features // 2
self.out_conv = nn.Conv2d(
features,
out_features,
kernel_size=1,
stride=1,
padding=0,
bias=True,
groups=1,
)
nn.init.kaiming_uniform_(self.out_conv.weight)
        # Only the second residual unit is conditioned on the diffusion timestep
self.resConfUnit1 = ResidualConvUnit_custom(
features, activation, bn=bn, ln=ln, dpt_time=False, resolution=resolution
)
self.resConfUnit2 = ResidualConvUnit_custom(
features, activation, bn=bn, ln=ln, dpt_time=dpt_time, resolution=resolution
)
    def forward(self, input, activation=None, t=None):
        """Forward pass.

        Args:
            input (tensor): feature map from the previous decoder stage
            activation (tensor, optional): lateral (skip) feature map to fuse
            t (tensor, optional): diffusion timesteps for adaLN conditioning

        Returns:
            tensor: output
        """
        output = input

        if activation is not None:
            res = self.resConfUnit1(activation)
            # Out-of-place add so the caller's tensor is not mutated in place
            output = output + res

        output = self.resConfUnit2(output, t)
        # Upsample 2x; the .float() cast guards bilinear interpolation against
        # half-precision inputs
        output = torch.nn.functional.interpolate(
            output.float(),
            scale_factor=2,
            mode="bilinear",
            align_corners=self.align_corners,
        )
output = self.out_conv(output)
return output
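

def _example_fusion_block():
    # Minimal usage sketch (not part of the original module): fuse a decoder
    # feature map with an encoder skip connection and upsample 2x, the way a
    # DPT-style decoder stage does. All shapes here are illustrative.
    block = _make_fusion_block(
        features=64, use_bn=False, use_ln=True, dpt_time=True, resolution=16
    )
    coarse = torch.randn(2, 64, 16, 16)  # stream from the previous stage
    skip = torch.randn(2, 64, 16, 16)    # lateral feature map from the encoder
    t = torch.randint(0, 1000, (2,))     # one diffusion timestep per sample
    fused = block(coarse, activation=skip, t=t)
    assert fused.shape == (2, 64, 32, 32)  # bilinear upsample doubles H and W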