Spaces:
Running
Running
import torch | |
import torch.nn as nn | |
class Upsample(nn.Module):
    """Double the spatial resolution of a feature map.

    The input is enlarged by a factor of two with nearest-neighbour
    interpolation. When ``with_conv`` is true, a 3x3 convolution that keeps
    the channel count and spatial size is applied to the enlarged map.
    """

    def __init__(self, in_channels : int, with_conv : bool):
        """Set up the module.

        Args:
            in_channels: Number of input channels; the optional convolution
                produces the same number of output channels.
            with_conv: Whether to apply a convolution after interpolation.
        """
        super().__init__()
        self.with_conv = with_conv
        # The ``conv`` attribute only exists when a convolution was requested.
        if with_conv:
            self.conv = nn.Conv2d(
                in_channels,
                in_channels,
                kernel_size=3,
                stride=1,
                padding=1,
            )

    def forward(self, x):
        """Return ``x`` upsampled by 2, optionally passed through the conv.

        NOTE(review): presumably ``x`` is a 4-D (N, C, H, W) tensor; confirm
        against callers.
        """
        enlarged = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
        return self.conv(enlarged) if self.with_conv else enlarged
class Downsample(nn.Module):
    """Halve the spatial resolution of a feature map.

    With ``with_conv`` enabled, the input is zero-padded by one element on
    the trailing side of its last two dimensions and passed through an
    unpadded 3x3 convolution with stride 2. Otherwise a 2x2 average pool
    with stride 2 is used.
    """

    def __init__(self, in_channels : int, with_conv : bool):
        """Set up the module.

        Args:
            in_channels: Number of input channels; the optional convolution
                produces the same number of output channels.
            with_conv: Whether to downsample with a strided convolution
                instead of average pooling.
        """
        super().__init__()
        self.with_conv = with_conv
        # The ``conv`` attribute only exists when a convolution was requested.
        if with_conv:
            # padding=0 here: the padding is applied manually in forward().
            self.conv = nn.Conv2d(
                in_channels,
                in_channels,
                kernel_size=3,
                stride=2,
                padding=0,
            )

    def forward(self, x):
        """Return ``x`` downsampled by 2 along its last two dimensions.

        NOTE(review): presumably ``x`` is a 4-D (N, C, H, W) tensor; confirm
        against callers.
        """
        if not self.with_conv:
            return nn.functional.avg_pool2d(x, kernel_size=2, stride=2)

        # Asymmetric padding: nothing before, one zero after, on each of the
        # last two dimensions, applied ahead of the unpadded strided conv.
        padded = nn.functional.pad(x, (0, 1, 0, 1), mode="constant", value=0)
        return self.conv(padded)