|
import torch.nn as nn
|
|
from einops import rearrange
|
|
|
|
|
|
class PixelShuffleND(nn.Module):
    """Fold channel groups into 1, 2, or 3 trailing axes (N-D pixel shuffle).

    The channel axis is split as ``(c p1 [p2 [p3]])``. Each factor ``p_i`` is
    moved onto the matching trailing axis, multiplying that axis's size:

    * ``dims=3``: ``(b, c*p1*p2*p3, d, h, w) -> (b, c, d*p1, h*p2, w*p3)``
    * ``dims=2``: ``(b, c*p1*p2, h, w) -> (b, c, h*p1, w*p2)``
    * ``dims=1``: ``(b, c*p1, f, h, w) -> (b, c, f*p1, h, w)``. Note that this
      mode still takes a 5-D input and only upscales the ``f`` axis.

    Args:
        dims: Number of axes to upscale; must be 1, 2, or 3.
        upscale_factors: Per-axis factors ``(p1, p2, p3)``. Only the first
            ``dims`` entries are used, so the 3-element default also serves
            ``dims`` of 1 or 2. A single int is broadcast to every axis.

    Raises:
        ValueError: If ``dims`` is not 1, 2, or 3, or if fewer than ``dims``
            upscale factors are supplied.
    """

    # einops patterns keyed by ``dims``; the axis lengths ``p1``..``pN`` are
    # bound from ``upscale_factors`` in ``forward``.
    _PATTERNS = {
        3: "b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)",
        2: "b (c p1 p2) h w -> b c (h p1) (w p2)",
        1: "b (c p1) f h w -> b c (f p1) h w",
    }

    def __init__(self, dims: int, upscale_factors=(2, 2, 2)):
        super().__init__()
        # Explicit raise instead of ``assert`` so validation survives ``python -O``.
        if dims not in (1, 2, 3):
            raise ValueError(f"dims must be 1, 2, or 3, got {dims!r}")
        if isinstance(upscale_factors, int):
            # Broadcast a scalar factor to every upscaled axis.
            upscale_factors = (upscale_factors,) * dims
        elif len(upscale_factors) < dims:
            # Fail at construction time rather than with an IndexError in forward.
            raise ValueError(
                f"expected at least {dims} upscale factors for dims={dims}, "
                f"got {len(upscale_factors)}"
            )
        self.dims = dims
        self.upscale_factors = upscale_factors

    def forward(self, x):
        """Rearrange ``x`` with the pattern selected by ``self.dims``.

        Args:
            x: Input tensor laid out as described in the class docstring. Its
                channel axis must be divisible by the product of the upscale
                factors in use.

        Returns:
            The rearranged tensor.

        Raises:
            ValueError: If ``self.dims`` is not 1, 2, or 3 (e.g. it was
                reassigned after construction). This replaces silently
                returning ``None``.
        """
        pattern = self._PATTERNS.get(self.dims)
        if pattern is None:
            raise ValueError(f"dims must be 1, 2, or 3, got {self.dims!r}")
        # Bind exactly p1..pN (N == dims) so the kwargs match the pattern's axes.
        axis_lengths = {
            f"p{i + 1}": self.upscale_factors[i] for i in range(self.dims)
        }
        return rearrange(x, pattern, **axis_lengths)
|
|
|