import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange


class Conv(nn.Module):
    """2D or temporally causal 3D convolution over (B, C, T, H, W) inputs."""

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        cnn_type="2d",
        causal_offset=0,
        temporal_down=False,
    ):
        super().__init__()
        self.cnn_type = cnn_type
        # Number of frames processed per chunk in the 3D forward pass.
        self.slice_seq_len = 17

        if cnn_type == "2d":
            self.conv = nn.Conv2d(
                in_channels, out_channels, kernel_size, stride=stride, padding=padding
            )
        if cnn_type == "3d":
            # Stride in time only when temporal downsampling is requested.
            if temporal_down:
                stride = (stride, stride, stride)
            else:
                stride = (1, stride, stride)
            # padding=0 here: causal padding is applied manually in forward().
            self.conv = nn.Conv3d(
                in_channels, out_channels, kernel_size, stride=stride, padding=0
            )
            if isinstance(kernel_size, int):
                kernel_size = (kernel_size, kernel_size, kernel_size)
            self.padding = (
                kernel_size[0] - 1 + causal_offset,  # temporal causal padding (front only)
                padding,  # height padding
                padding,  # width padding
            )
            self.causal_offset = causal_offset
            self.stride = stride
            self.kernel_size = kernel_size

    def forward(self, x):
        if self.cnn_type == "2d":
            if x.ndim == 5:
                # Fold time into the batch axis and convolve frame by frame.
                B, C, T, H, W = x.shape
                x = rearrange(x, "B C T H W -> (B T) C H W")
                x = self.conv(x)
                x = rearrange(x, "(B T) C H W -> B C T H W", T=T)
                return x
            else:
                return self.conv(x)
        if self.cnn_type == "3d":
            assert (
                self.stride[0] == 1 or self.stride[0] == 2
            ), "only temporal stride 1 or 2 is supported"
            # Convolve the sequence in chunks of slice_seq_len frames to bound
            # peak memory; each chunk is padded so the conv remains causal.
            xs = []
            for i in range(0, x.shape[2], self.slice_seq_len + self.stride[0] - 1):
                st = i
                en = min(i + self.slice_seq_len, x.shape[2])
                _x = x[:, :, st:en, :, :]
                if i == 0:
                    # First chunk: zero-pad the front of the time axis.
                    _x = F.pad(
                        _x,
                        (
                            self.padding[2], self.padding[2],  # width
                            self.padding[1], self.padding[1],  # height
                            self.padding[0], 0,  # temporal (front only)
                        ),
                    )
                else:
                    # Later chunks: pad the front, then overwrite the zero
                    # padding with the last kernel_size[0] - 1 real frames of
                    # the previous chunk so the result matches an unchunked conv.
                    padding_0 = self.kernel_size[0] - 1
                    _x = F.pad(
                        _x,
                        (
                            self.padding[2], self.padding[2],  # width
                            self.padding[1], self.padding[1],  # height
                            padding_0, 0,  # temporal (front only)
                        ),
                    )
                    _x[
                        :,
                        :,
                        :padding_0,
                        self.padding[1] : _x.shape[-2] - self.padding[1],
                        self.padding[2] : _x.shape[-1] - self.padding[2],
                    ] += x[:, :, i - padding_0 : i, :, :]
                _x = self.conv(_x)
                xs.append(_x)
            try:
                x = torch.cat(xs, dim=2)
            except RuntimeError:
                # Likely CUDA out of memory: concatenate on CPU, then move back.
                device = x.device
                del x
                xs = [_x.cpu().pin_memory() for _x in xs]
                torch.cuda.empty_cache()
                x = torch.cat(xs, dim=2).to(device=device)
            return x
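

# Minimal usage sketch (not part of the original file; the shapes and the
# equivalence check below are assumptions based on the code above).
if __name__ == "__main__":
    video = torch.randn(1, 3, 20, 32, 32)  # (B, C, T, H, W)

    # 2D mode: frames are convolved independently via the (B T) fold.
    conv2d = Conv(3, 8, kernel_size=3, stride=1, padding=1, cnn_type="2d")
    print(conv2d(video).shape)  # torch.Size([1, 8, 20, 32, 32])

    # 3D mode: front-only temporal padding keeps the conv causal, and the
    # chunked forward should match convolving the fully padded clip at once.
    conv3d = Conv(3, 8, kernel_size=3, stride=1, padding=1, cnn_type="3d")
    out = conv3d(video)
    ref = conv3d.conv(F.pad(video, (1, 1, 1, 1, 2, 0)))
    print(out.shape, torch.allclose(out, ref, atol=1e-5))  # same shape, True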