|
import torch.nn as nn |
|
from einops import rearrange |
|
|
|
|
|
class PixelShuffleND(nn.Module):
    """Pixel-shuffle style upsampling over one, two, or three axes.

    Factors are split off the channel axis (axis 1) and folded into the
    axes that follow it, using the einops patterns below:

    * ``dims == 3``: ``b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)``
    * ``dims == 2``: ``b (c p1 p2) h w -> b c (h p1) (w p2)``
    * ``dims == 1``: ``b (c p1) f h w -> b c (f p1) h w``

    Note that ``dims == 1`` and ``dims == 3`` take a 5-axis input while
    ``dims == 2`` takes a 4-axis input, as the patterns show.

    Args:
        dims: Number of axes to upscale; must be 1, 2, or 3.
        upscale_factors: Indexable sequence of factors. Only the first
            ``dims`` entries are read (as ``p1``, ``p2``, ``p3``), so it may
            be longer than ``dims`` but not shorter.

    Raises:
        ValueError: If ``dims`` is not 1, 2, or 3, or if ``upscale_factors``
            has fewer than ``dims`` entries.
    """

    # dims -> (einops pattern, names of the factor axes that pattern uses).
    _LAYOUTS = {
        1: ("b (c p1) f h w -> b c (f p1) h w", ("p1",)),
        2: ("b (c p1 p2) h w -> b c (h p1) (w p2)", ("p1", "p2")),
        3: (
            "b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)",
            ("p1", "p2", "p3"),
        ),
    }

    def __init__(self, dims, upscale_factors=(2, 2, 2)):
        super().__init__()
        # Raise explicitly rather than `assert`: assertions are removed under
        # `python -O`, which would let an invalid configuration through.
        if dims not in (1, 2, 3):
            raise ValueError("dims must be 1, 2, or 3")
        # Fail at construction instead of with an IndexError inside forward().
        if len(upscale_factors) < dims:
            raise ValueError(
                f"upscale_factors must provide at least {dims} value(s), "
                f"got {len(upscale_factors)}"
            )
        self.dims = dims
        self.upscale_factors = upscale_factors

    def forward(self, x):
        """Rearrange ``x`` according to the pattern selected by ``self.dims``.

        Args:
            x: Input passed straight to ``einops.rearrange``; its axes must
                match the left-hand side of the selected pattern.

        Returns:
            The result of ``rearrange`` for the selected pattern.

        Raises:
            ValueError: If ``self.dims`` is no longer 1, 2, or 3 (e.g. it was
                reassigned after construction).
        """
        layout = self._LAYOUTS.get(self.dims)
        if layout is None:
            # Previously this case fell through and silently returned None.
            raise ValueError("dims must be 1, 2, or 3")
        pattern, axis_names = layout
        factors = {
            axis: self.upscale_factors[index]
            for index, axis in enumerate(axis_names)
        }
        return rearrange(x, pattern, **factors)
|
|