import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange


class Conv(nn.Module):
    """Convolution wrapper supporting plain 2D convs and temporally causal
    3D convs, with chunked (sliced) processing over long frame sequences."""

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        cnn_type="2d",
        causal_offset=0,
        temporal_down=False,
    ):
        super().__init__()
        self.cnn_type = cnn_type
        # Number of frames processed per chunk in the 3D branch; chunking
        # bounds peak memory on long sequences.
        self.slice_seq_len = 17

        if cnn_type == "2d":
            self.conv = nn.Conv2d(
                in_channels, out_channels, kernel_size, stride=stride, padding=padding
            )
        if cnn_type == "3d":
            if not temporal_down:
                # Stride spatially only; keep the temporal resolution.
                stride = (1, stride, stride)
            else:
                stride = (stride, stride, stride)
            # Padding is applied manually in forward() so it can be causal.
            self.conv = nn.Conv3d(
                in_channels, out_channels, kernel_size, stride=stride, padding=0
            )
            if isinstance(kernel_size, int):
                kernel_size = (kernel_size, kernel_size, kernel_size)
            self.padding = (
                kernel_size[0] - 1 + causal_offset,  # Temporal causal padding (past only)
                padding,  # Height padding
                padding,  # Width padding
            )
        self.causal_offset = causal_offset
        self.stride = stride
        self.kernel_size = kernel_size

    def forward(self, x):
        if self.cnn_type == "2d":
            if x.ndim == 5:
                # Video input: fold time into the batch dimension and apply
                # the 2D conv frame by frame.
                B, C, T, H, W = x.shape
                x = rearrange(x, "B C T H W -> (B T) C H W")
                x = self.conv(x)
                x = rearrange(x, "(B T) C H W -> B C T H W", T=T)
                return x
            else:
                return self.conv(x)

        if self.cnn_type == "3d":
            assert self.stride[0] in (1, 2), "only temporal stride = 1 or 2 is supported"
            xs = []
            # Process the sequence in chunks of slice_seq_len frames.
            for i in range(0, x.shape[2], self.slice_seq_len + self.stride[0] - 1):
                st = i
                en = min(i + self.slice_seq_len, x.shape[2])
                _x = x[:, :, st:en, :, :]
                if i == 0:
                    # First chunk: zero-pad the past so the conv stays causal.
                    _x = F.pad(
                        _x,
                        (
                            self.padding[2], self.padding[2],  # Width
                            self.padding[1], self.padding[1],  # Height
                            self.padding[0], 0,                # Temporal (left only)
                        ),
                    )
                else:
                    # Later chunks: zero-pad, then add the last kernel_size-1
                    # real frames of the previous chunk into the interior of
                    # the temporal padding region, so the chunked result
                    # matches a single causal conv over the full sequence.
                    padding_0 = self.kernel_size[0] - 1
                    _x = F.pad(
                        _x,
                        (
                            self.padding[2], self.padding[2],  # Width
                            self.padding[1], self.padding[1],  # Height
                            padding_0, 0,                      # Temporal (left only)
                        ),
                    )
                    _x[
                        :,
                        :,
                        :padding_0,
                        self.padding[1] : _x.shape[-2] - self.padding[1],
                        self.padding[2] : _x.shape[-1] - self.padding[2],
                    ] += x[:, :, i - padding_0 : i, :, :]
                _x = self.conv(_x)
                xs.append(_x)
            try:
                x = torch.cat(xs, dim=2)
            except RuntimeError:
                # Likely GPU out-of-memory during concatenation: fall back to
                # concatenating on the CPU, then move back to the original device.
                device = x.device
                del x
                xs = [_x.cpu().pin_memory() for _x in xs]
                torch.cuda.empty_cache()
                x = torch.cat([_x.cpu() for _x in xs], dim=2).to(device=device)
            return x
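

# ---------------------------------------------------------------------------
# Usage sketch (an added example, not part of the original module): a minimal
# smoke test of both branches. The tensor shapes and kernel/stride choices
# below are illustrative assumptions. The "2d" branch folds time into the
# batch dimension; the "3d" branch runs a chunked, temporally causal conv.
if __name__ == "__main__":
    video = torch.randn(1, 8, 33, 64, 64)  # (B, C, T, H, W)

    conv2d = Conv(8, 16, kernel_size=3, padding=1, cnn_type="2d")
    print(conv2d(video).shape)  # torch.Size([1, 16, 33, 64, 64])

    conv3d = Conv(8, 16, kernel_size=3, padding=1, cnn_type="3d")
    out3d = conv3d(video)
    print(out3d.shape)  # torch.Size([1, 16, 33, 64, 64])

    # With temporal stride 1, chunked processing should match a single
    # causal conv over the whole padded sequence.
    ref = conv3d.conv(F.pad(video, (1, 1, 1, 1, 2, 0)))
    print(torch.allclose(out3d, ref, atol=1e-5))  # True

    # temporal_down=True also strides over T (here T: 33 -> 17).
    conv3d_down = Conv(
        8, 16, kernel_size=3, stride=2, padding=1, cnn_type="3d", temporal_down=True
    )
    print(conv3d_down(video).shape)  # torch.Size([1, 16, 17, 32, 32])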