Spaces:
Build error
Build error
| from typing import Optional, Tuple, Union | |
| from einops import rearrange | |
| import torch | |
| import torch.nn as nn | |
| from diffusers.models.attention_processor import Attention | |
| from diffusers.models.resnet import ResnetBlock2D | |
| from diffusers.models.upsampling import Upsample2D | |
| from diffusers.models.downsampling import Downsample2D | |
class TemporalConvBlock(nn.Module):
    """
    Temporal convolutional layer that can be used for video (sequence of images) input. Code mostly copied from:
    https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/models/multi_modal/video_synthesis/unet_sd.py#L1016

    The block runs four ``GroupNorm -> SiLU -> Conv3d`` stages over a 5-D tensor whose axis 2 is
    treated as the frame axis (``b c f h w`` in the einops pattern used below) and adds the result
    to a skip connection. The last convolution is zero-initialised, so a freshly constructed block
    returns exactly its skip connection.

    Three modes are selected at construction time:

    * default: frame count is preserved; the skip connection is the input itself.
    * ``down_sample``: ``conv1`` uses a temporal kernel of 2 with stride 2 and the skip connection
      keeps every second frame. The frame count must be even, otherwise the two branches end up
      with different lengths and the final addition fails.
    * ``up_sample``: ``conv1`` doubles the channel count and the extra channels are folded into the
      frame axis, doubling the frame count; the skip connection repeats every frame twice.

    Args:
        in_dim: Number of input (and output) channels. Must be divisible by 32 (GroupNorm groups).
        out_dim: Number of channels used inside the block; defaults to ``in_dim``. Must be
            divisible by 32.
        dropout: Dropout probability used in the second and third stages.
        up_sample: Double the number of frames.
        down_sample: Halve the number of frames.
        spa_stride: Despite its name this is the spatial *kernel size* of every convolution;
            the spatial padding is ``(spa_stride - 1) / 2`` rounded towards zero.

    Raises:
        ValueError: If both ``up_sample`` and ``down_sample`` are requested.
    """

    def __init__(self, in_dim, out_dim=None, dropout=0.0, up_sample=False, down_sample=False, spa_stride=1):
        super().__init__()
        if up_sample and down_sample:
            # This combination previously failed later, inside forward(), with a shape error.
            raise ValueError("`up_sample` and `down_sample` are mutually exclusive.")

        out_dim = out_dim or in_dim
        self.in_dim = in_dim
        self.out_dim = out_dim
        spa_pad = int((spa_stride - 1) * 0.5)
        # No temporal padding inside the convolutions: forward() replicates the edge frames instead.
        temp_pad = 0
        self.temp_pad = temp_pad

        spatial_kernel = (spa_stride, spa_stride)
        padding = (temp_pad, spa_pad, spa_pad)

        if down_sample:
            first_conv = nn.Conv3d(
                in_dim, out_dim, (2, *spatial_kernel), stride=(2, 1, 1), padding=(0, spa_pad, spa_pad)
            )
        elif up_sample:
            # Twice the channels; forward() moves the extra factor of 2 onto the frame axis.
            first_conv = nn.Conv3d(in_dim, out_dim * 2, (1, *spatial_kernel), padding=(0, spa_pad, spa_pad))
        else:
            first_conv = nn.Conv3d(in_dim, out_dim, (3, *spatial_kernel), padding=padding)
        self.conv1 = nn.Sequential(nn.GroupNorm(32, in_dim), nn.SiLU(), first_conv)

        # conv2 and conv3 stay at `out_dim` channels because the stage that follows each of them
        # normalises `out_dim` channels. (They used to project back to `in_dim`, which only worked
        # when in_dim == out_dim; parameter shapes are unchanged in that case.)
        self.conv2 = nn.Sequential(
            nn.GroupNorm(32, out_dim),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Conv3d(out_dim, out_dim, (3, *spatial_kernel), padding=padding),
        )
        self.conv3 = nn.Sequential(
            nn.GroupNorm(32, out_dim),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Conv3d(out_dim, out_dim, (3, *spatial_kernel), padding=padding),
        )
        # Only the last stage maps back to `in_dim` so the result can be added to the skip branch.
        self.conv4 = nn.Sequential(
            nn.GroupNorm(32, out_dim),
            nn.SiLU(),
            nn.Conv3d(out_dim, in_dim, (3, *spatial_kernel), padding=padding),
        )

        # zero out the last layer params, so the conv block is identity
        nn.init.zeros_(self.conv4[-1].weight)
        nn.init.zeros_(self.conv4[-1].bias)

        self.down_sample = down_sample
        self.up_sample = up_sample

    @staticmethod
    def _pad_frames(hidden_states):
        """Repeat the first and the last frame once so a temporal kernel of 3 keeps the frame count."""
        return torch.cat((hidden_states[:, :, :1], hidden_states, hidden_states[:, :, -1:]), dim=2)

    def forward(self, hidden_states):
        """Apply the four convolution stages and add the (possibly resampled) skip connection.

        Args:
            hidden_states: 5-D tensor with frames on axis 2.

        Returns:
            Tensor with ``in_dim`` channels and the same, half, or double the number of frames,
            depending on the mode chosen at construction.
        """
        if self.down_sample:
            # Skip branch: keep every second frame.
            identity = hidden_states[:, :, ::2]
        elif self.up_sample:
            # Skip branch: each frame appears twice in a row.
            identity = hidden_states.repeat_interleave(2, dim=2)
        else:
            identity = hidden_states
            # Only the plain mode uses a temporal kernel of 3 in conv1, so only it needs padding.
            hidden_states = self._pad_frames(hidden_states)

        hidden_states = self.conv1(hidden_states)

        if self.up_sample:
            # Fold the doubled channel axis into the frame axis.
            hidden_states = rearrange(hidden_states, 'b (d c) f h w -> b c (f d) h w', d=2)

        for stage in (self.conv2, self.conv3, self.conv4):
            hidden_states = stage(self._pad_frames(hidden_states))

        return identity + hidden_states
class DownEncoderBlock3D(nn.Module):
    """Encoder stage that alternates per-frame 2D resnets with temporal convolution blocks.

    Each of the ``num_layers`` layers applies a ``ResnetBlock2D`` to every frame independently
    (frames are folded into the batch axis) followed by a ``TemporalConvBlock`` across frames.
    Optionally the stage ends with a temporal down-sampling ``TemporalConvBlock`` and/or a
    ``Downsample2D`` applied per frame.

    Args:
        in_channels: Channels of the incoming tensor (used by the first resnet only).
        out_channels: Channels produced by every layer of the stage.
        dropout: Dropout passed to the 2D resnets. The temporal blocks use a fixed 0.1.
        num_layers: Number of resnet + temporal-conv pairs.
        resnet_eps, resnet_time_scale_shift, resnet_act_fn, resnet_groups, resnet_pre_norm,
        output_scale_factor: Forwarded unchanged to ``ResnetBlock2D``.
        add_downsample: Append a per-frame ``Downsample2D``.
        add_temp_downsample: Append a ``TemporalConvBlock`` built with ``down_sample=True``.
        downsample_padding: Padding forwarded to ``Downsample2D``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        output_scale_factor=1.0,
        add_downsample=True,
        add_temp_downsample=False,
        downsample_padding=1,
    ):
        super().__init__()
        resnets = []
        temp_convs = []

        for i in range(num_layers):
            # Only the first layer sees the stage input; later layers chain on out_channels.
            layer_in_channels = in_channels if i == 0 else out_channels
            resnets.append(
                ResnetBlock2D(
                    in_channels=layer_in_channels,
                    out_channels=out_channels,
                    temb_channels=None,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                )
            )
            temp_convs.append(
                TemporalConvBlock(
                    out_channels,
                    out_channels,
                    dropout=0.1,  # fixed value; does not follow the `dropout` argument
                )
            )

        self.resnets = nn.ModuleList(resnets)
        self.temp_convs = nn.ModuleList(temp_convs)

        # `temp_convs_down` exists only when temporal down-sampling is requested.
        if add_temp_downsample:
            self.temp_convs_down = TemporalConvBlock(
                out_channels,
                out_channels,
                dropout=0.1,
                down_sample=True,
                spa_stride=3,
            )
        self.add_temp_downsample = add_temp_downsample

        if add_downsample:
            self.downsamplers = nn.ModuleList(
                [
                    Downsample2D(
                        out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
                    )
                ]
            )
        else:
            self.downsamplers = None

    def _set_partial_grad(self):
        """Enable gradients for the temporal blocks and the down-samplers only.

        Parameters of the 2D resnets are left as they are. The temporal down-sampling block is
        included: its last convolution is zero-initialised, so if it stayed frozen it could never
        become anything but plain frame dropping.
        """
        for temp_conv in self.temp_convs:
            temp_conv.requires_grad_(True)
        if self.add_temp_downsample:
            self.temp_convs_down.requires_grad_(True)
        if self.downsamplers:
            for down_layer in self.downsamplers:
                down_layer.requires_grad_(True)

    def forward(self, hidden_states):
        """Run the stage on a ``b c n h w`` tensor (``n`` = frames) and return a tensor in the same layout."""
        bz = hidden_states.shape[0]

        for resnet, temp_conv in zip(self.resnets, self.temp_convs):
            # 2D resnet: treat every frame as an independent image.
            hidden_states = rearrange(hidden_states, 'b c n h w -> (b n) c h w')
            hidden_states = resnet(hidden_states, temb=None)
            # Temporal block: restore the frame axis.
            hidden_states = rearrange(hidden_states, '(b n) c h w -> b c n h w', b=bz)
            hidden_states = temp_conv(hidden_states)

        if self.add_temp_downsample:
            hidden_states = self.temp_convs_down(hidden_states)

        if self.downsamplers is not None:
            hidden_states = rearrange(hidden_states, 'b c n h w -> (b n) c h w')
            for downsampler in self.downsamplers:
                hidden_states = downsampler(hidden_states)
            hidden_states = rearrange(hidden_states, '(b n) c h w -> b c n h w', b=bz)

        return hidden_states
class UpDecoderBlock3D(nn.Module):
    """Decoder stage that alternates per-frame 2D resnets with temporal convolution blocks.

    Each of the ``num_layers`` layers applies a ``ResnetBlock2D`` to every frame independently
    (frames are folded into the batch axis) followed by a ``TemporalConvBlock`` across frames.
    Optionally the stage ends with a temporal up-sampling ``TemporalConvBlock`` and/or an
    ``Upsample2D`` applied per frame.

    Args:
        in_channels: Channels of the incoming tensor (used by the first resnet only).
        out_channels: Channels produced by every layer of the stage.
        dropout: Dropout passed to the 2D resnets. The temporal blocks use a fixed 0.1.
        num_layers: Number of resnet + temporal-conv pairs.
        resnet_eps, resnet_time_scale_shift, resnet_act_fn, resnet_groups, resnet_pre_norm,
        output_scale_factor: Forwarded unchanged to ``ResnetBlock2D``.
        add_upsample: Append a per-frame ``Upsample2D``.
        add_temp_upsample: Append a ``TemporalConvBlock`` built with ``up_sample=True``.
        temb_channels: Forwarded to ``ResnetBlock2D``.
            NOTE(review): ``forward`` always calls the resnets with ``temb=None``; confirm that a
            non-None value here is supported by the installed diffusers version.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",  # default, spatial
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        output_scale_factor=1.0,
        add_upsample=True,
        add_temp_upsample=False,
        temb_channels=None,
    ):
        super().__init__()
        self.add_upsample = add_upsample

        resnets = []
        temp_convs = []
        for i in range(num_layers):
            # Only the first layer sees the stage input; later layers chain on out_channels.
            input_channels = in_channels if i == 0 else out_channels
            resnets.append(
                ResnetBlock2D(
                    in_channels=input_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                )
            )
            temp_convs.append(
                TemporalConvBlock(
                    out_channels,
                    out_channels,
                    dropout=0.1,  # fixed value; does not follow the `dropout` argument
                )
            )

        self.resnets = nn.ModuleList(resnets)
        self.temp_convs = nn.ModuleList(temp_convs)

        # `temp_conv_up` exists only when temporal up-sampling is requested.
        self.add_temp_upsample = add_temp_upsample
        if add_temp_upsample:
            self.temp_conv_up = TemporalConvBlock(
                out_channels,
                out_channels,
                dropout=0.1,
                up_sample=True,
                spa_stride=3,
            )

        if self.add_upsample:
            # self.upsamplers = nn.ModuleList([PSUpsample2D(out_channels, use_conv=True, use_pixel_shuffle=True, out_channels=out_channels)])
            self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
        else:
            self.upsamplers = None

    def _set_partial_grad(self):
        """Enable gradients for the temporal blocks and the up-samplers only.

        Parameters of the 2D resnets are left as they are. The temporal up-sampling block is
        included: its last convolution is zero-initialised, so if it stayed frozen it could never
        become anything but plain frame repetition.
        """
        for temp_conv in self.temp_convs:
            temp_conv.requires_grad_(True)
        if self.add_temp_upsample:
            self.temp_conv_up.requires_grad_(True)
        if self.add_upsample:
            self.upsamplers.requires_grad_(True)

    def forward(self, hidden_states):
        """Run the stage on a ``b c n h w`` tensor (``n`` = frames) and return a tensor in the same layout."""
        bz = hidden_states.shape[0]

        for resnet, temp_conv in zip(self.resnets, self.temp_convs):
            # 2D resnet: treat every frame as an independent image.
            hidden_states = rearrange(hidden_states, 'b c n h w -> (b n) c h w')
            hidden_states = resnet(hidden_states, temb=None)
            # Temporal block: restore the frame axis.
            hidden_states = rearrange(hidden_states, '(b n) c h w -> b c n h w', b=bz)
            hidden_states = temp_conv(hidden_states)

        if self.add_temp_upsample:
            hidden_states = self.temp_conv_up(hidden_states)

        if self.upsamplers is not None:
            hidden_states = rearrange(hidden_states, 'b c n h w -> (b n) c h w')
            for upsampler in self.upsamplers:
                hidden_states = upsampler(hidden_states)
            hidden_states = rearrange(hidden_states, '(b n) c h w -> b c n h w', b=bz)

        return hidden_states
class UNetMidBlock3DConv(nn.Module):
    """Middle block: resnet + temporal conv, then ``num_layers`` x (attention, resnet, temporal conv).

    The 2D resnets and the attention layers operate on individual frames (frames are folded into
    the batch axis); the ``TemporalConvBlock`` layers operate across frames. The channel count is
    ``in_channels`` throughout.

    Args:
        in_channels: Number of channels of the input and output.
        temb_channels: Forwarded to ``ResnetBlock2D`` and, when ``resnet_time_scale_shift`` is
            ``"spatial"``, to ``Attention`` as ``spatial_norm_dim``.
            NOTE(review): ``forward`` always calls the resnets with ``temb=None``; confirm that a
            non-None value here is supported by the installed diffusers version.
        dropout: Dropout passed to the 2D resnets. The temporal blocks use a fixed 0.1.
        num_layers: Number of (attention, resnet, temporal conv) groups after the first pair.
        resnet_eps, resnet_time_scale_shift, resnet_act_fn, resnet_pre_norm, output_scale_factor:
            Forwarded unchanged to ``ResnetBlock2D`` (some also to ``Attention``).
        resnet_groups: GroupNorm groups for the resnets; ``None`` selects ``min(in_channels // 4, 32)``.
        add_attention: When False no attention is applied (``None`` placeholders are stored so the
            module lists stay index-aligned).
        attention_head_dim: Channels per attention head; ``None`` means a single head.
    """

    def __init__(
        self,
        in_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",  # default, spatial
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        add_attention: bool = True,
        attention_head_dim=1,
        output_scale_factor=1.0,
    ):
        super().__init__()
        resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
        self.add_attention = add_attention

        # there is always at least one resnet
        resnets = [
            ResnetBlock2D(
                in_channels=in_channels,
                out_channels=in_channels,
                temb_channels=temb_channels,
                eps=resnet_eps,
                groups=resnet_groups,
                dropout=dropout,
                time_embedding_norm=resnet_time_scale_shift,
                non_linearity=resnet_act_fn,
                output_scale_factor=output_scale_factor,
                pre_norm=resnet_pre_norm,
            )
        ]
        temp_convs = [
            TemporalConvBlock(
                in_channels,
                in_channels,
                dropout=0.1,  # fixed value; does not follow the `dropout` argument
            )
        ]
        attentions = []

        if attention_head_dim is None:
            attention_head_dim = in_channels

        for _ in range(num_layers):
            if self.add_attention:
                attentions.append(
                    Attention(
                        in_channels,
                        heads=in_channels // attention_head_dim,
                        dim_head=attention_head_dim,
                        rescale_output_factor=output_scale_factor,
                        eps=resnet_eps,
                        norm_num_groups=resnet_groups if resnet_time_scale_shift == "default" else None,
                        spatial_norm_dim=temb_channels if resnet_time_scale_shift == "spatial" else None,
                        residual_connection=True,
                        bias=True,
                        upcast_softmax=True,
                        _from_deprecated_attn_block=True,
                    )
                )
            else:
                # Placeholder keeps `attentions` aligned with `resnets[1:]`; forward() skips it.
                attentions.append(None)

            resnets.append(
                ResnetBlock2D(
                    in_channels=in_channels,
                    out_channels=in_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                )
            )
            temp_convs.append(
                TemporalConvBlock(
                    in_channels,
                    in_channels,
                    dropout=0.1,
                )
            )

        self.resnets = nn.ModuleList(resnets)
        self.temp_convs = nn.ModuleList(temp_convs)
        self.attentions = nn.ModuleList(attentions)

    def _set_partial_grad(self):
        """Enable gradients for the temporal convolution blocks; other parameters are left as they are."""
        for temp_conv in self.temp_convs:
            temp_conv.requires_grad_(True)

    def forward(
        self,
        hidden_states,
    ):
        """Run the block on a ``b c n h w`` tensor (``n`` = frames) and return a tensor in the same layout."""
        bz = hidden_states.shape[0]

        hidden_states = rearrange(hidden_states, 'b c n h w -> (b n) c h w')
        hidden_states = self.resnets[0](hidden_states, temb=None)
        hidden_states = rearrange(hidden_states, '(b n) c h w -> b c n h w', b=bz)
        hidden_states = self.temp_convs[0](hidden_states)

        for attn, resnet, temp_conv in zip(self.attentions, self.resnets[1:], self.temp_convs[1:]):
            # Fold frames into the batch at the start of EVERY iteration: each iteration ends in
            # the 5-D layout, so doing this once before the loop breaks for num_layers > 1 and
            # leaks the flattened layout to the caller for num_layers == 0.
            hidden_states = rearrange(hidden_states, 'b c n h w -> (b n) c h w')
            if attn is not None:  # None placeholders are stored when add_attention=False
                hidden_states = attn(hidden_states)
            hidden_states = resnet(hidden_states, temb=None)
            hidden_states = rearrange(hidden_states, '(b n) c h w -> b c n h w', b=bz)
            hidden_states = temp_conv(hidden_states)

        return hidden_states