Spaces:
Running
Running
# https://github.com/CompVis/latent-diffusion/blob/main/ldm/modules/diffusionmodules/model.py#L368 | |
from typing import Optional

import torch
import torch.nn as nn

from auto_encoder.components.normalize import Normalize
from auto_encoder.components.nonlinearity import nonlinearity
class ResnetBlock(nn.Module):
    """Residual block with two 3x3 convolutions and an optional skip projection.

    Residual path::

        norm1 -> nonlinearity -> conv1 -> norm2 -> nonlinearity -> dropout -> conv2

    Skip path: the input unchanged when ``in_channels == out_channels``,
    otherwise a projection to ``out_channels`` (a 3x3 conv when
    ``conv_shortcut`` is true, a 1x1 "network-in-network" conv otherwise).
    The block returns ``skip + residual``.

    All convolutions use stride 1 and "same" padding, so the spatial size of
    the input is preserved; only the channel count may change.

    ``Normalize`` and ``nonlinearity`` are project helpers not visible here.
    NOTE(review): presumably GroupNorm and swish, as in latent-diffusion.
    Confirm, including any channel-divisibility requirement ``Normalize``
    imposes on ``in_channels`` / ``out_channels``.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of output channels; ``None`` means "same as
            ``in_channels``".
        conv_shortcut: When the channel counts differ, project the skip path
            with a 3x3 conv instead of a 1x1 conv. Has no effect when the
            channel counts match.
        dropout: Drop probability for the ``nn.Dropout`` applied between
            the two convolutions.
    """

    def __init__(
        self,
        *,
        in_channels: int,
        out_channels: Optional[int] = None,
        conv_shortcut: bool = False,
        dropout: float,
    ) -> None:
        super().__init__()
        out_channels = in_channels if out_channels is None else out_channels
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        # Keep these attribute names stable: they become the state_dict keys
        # that saved checkpoints are loaded against.
        self.norm1 = Normalize(in_channels)
        self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.norm2 = Normalize(out_channels)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)

        # A skip projection is only needed (and only registered) when the
        # channel counts differ; otherwise the identity skip is used.
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                self.conv_shortcut = torch.nn.Conv2d(
                    in_channels, out_channels, kernel_size=3, stride=1, padding=1
                )
            else:
                self.nin_shortcut = torch.nn.Conv2d(
                    in_channels, out_channels, kernel_size=1, stride=1, padding=0
                )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``skip(x) + residual(x)``.

        Args:
            x: Input feature map with ``in_channels`` channels, in the layout
                ``nn.Conv2d`` consumes (assumed ``(N, C, H, W)``; confirm
                against callers).

        Returns:
            Tensor with ``out_channels`` channels and the same spatial size
            as ``x``.
        """
        h = self.norm1(x)
        h = nonlinearity(h)
        h = self.conv1(h)
        h = self.norm2(h)
        h = nonlinearity(h)
        h = self.dropout(h)
        h = self.conv2(h)
        if self.in_channels != self.out_channels:
            # Project the skip path so both operands of the sum have
            # out_channels channels.
            if self.use_conv_shortcut:
                x = self.conv_shortcut(x)
            else:
                x = self.nin_shortcut(x)
        return x + h