Spaces:
Running
on
Zero
Running
on
Zero
File size: 50,678 Bytes
42f2c22 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 865 866 867 868 869 870 871 872 873 874 875 876 877 878 879 880 881 882 883 884 885 886 887 888 889 890 891 892 893 894 895 896 897 898 899 900 901 902 903 904 905 906 907 908 909 910 911 912 913 914 915 916 917 918 919 920 921 922 923 924 925 926 927 928 929 930 931 932 933 934 935 936 937 938 939 940 941 942 943 944 945 946 947 948 949 950 951 952 953 954 955 956 957 958 959 960 961 962 963 964 965 966 967 968 969 970 971 972 973 974 975 976 977 978 979 980 981 982 983 984 985 986 987 988 989 990 991 992 993 994 995 996 997 998 999 1000 1001 1002 1003 1004 1005 1006 1007 1008 1009 1010 1011 1012 1013 1014 1015 1016 1017 1018 1019 1020 1021 
1022 1023 1024 1025 1026 1027 1028 1029 1030 1031 1032 1033 1034 1035 1036 1037 1038 1039 1040 1041 1042 1043 1044 1045 1046 1047 1048 1049 1050 1051 1052 1053 1054 1055 1056 1057 1058 1059 1060 1061 1062 1063 1064 1065 1066 1067 1068 1069 1070 1071 1072 1073 1074 1075 1076 1077 1078 1079 1080 1081 1082 1083 1084 1085 1086 1087 1088 1089 1090 1091 1092 1093 1094 1095 1096 1097 1098 1099 1100 1101 1102 1103 1104 1105 1106 1107 1108 1109 1110 1111 1112 1113 1114 1115 1116 1117 1118 1119 1120 1121 1122 1123 1124 1125 1126 1127 1128 1129 1130 1131 1132 1133 1134 1135 1136 1137 1138 1139 1140 1141 1142 1143 1144 1145 1146 1147 1148 1149 1150 1151 1152 1153 1154 1155 1156 1157 1158 1159 1160 1161 1162 1163 1164 1165 1166 1167 1168 1169 1170 1171 1172 1173 1174 1175 1176 1177 1178 1179 1180 1181 1182 1183 1184 1185 1186 1187 1188 1189 1190 1191 1192 1193 1194 1195 1196 1197 1198 1199 1200 1201 1202 1203 1204 1205 1206 1207 1208 1209 1210 1211 1212 1213 1214 1215 1216 1217 1218 1219 1220 1221 1222 1223 1224 1225 1226 1227 1228 1229 1230 1231 1232 1233 1234 1235 1236 1237 1238 1239 1240 1241 1242 1243 1244 1245 1246 1247 1248 1249 1250 1251 1252 1253 1254 1255 1256 1257 1258 1259 1260 1261 1262 1263 1264 1265 1266 1267 1268 1269 1270 1271 1272 1273 1274 1275 1276 1277 1278 1279 1280 1281 1282 1283 1284 1285 1286 1287 1288 1289 1290 1291 1292 1293 1294 1295 1296 1297 1298 1299 1300 1301 1302 1303 1304 1305 1306 1307 1308 1309 1310 1311 1312 1313 1314 1315 1316 1317 1318 1319 1320 1321 1322 1323 1324 1325 1326 1327 1328 1329 1330 1331 1332 1333 1334 1335 1336 1337 1338 1339 1340 1341 1342 1343 1344 1345 1346 |
# Copyright (c) 2023 HuggingFace Team
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates.
# SPDX-License-Identifier: Apache License, Version 2.0 (the "License")
#
# This file has been modified by ByteDance Ltd. and/or its affiliates. on 1st June 2025
#
# Original file was released under Apache License, Version 2.0 (the "License"), with the full license text
# available at http://www.apache.org/licenses/LICENSE-2.0.
#
# This modified file is released under the same license.
from contextlib import nullcontext
from typing import Literal, Optional, Tuple, Union
import diffusers
import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers.models.attention_processor import Attention, SpatialNorm
from diffusers.models.autoencoders.vae import DecoderOutput, DiagonalGaussianDistribution
from diffusers.models.downsampling import Downsample2D
from diffusers.models.lora import LoRACompatibleConv
from diffusers.models.modeling_outputs import AutoencoderKLOutput
from diffusers.models.resnet import ResnetBlock2D
from diffusers.models.unets.unet_2d_blocks import DownEncoderBlock2D, UpDecoderBlock2D
from diffusers.models.upsampling import Upsample2D
from diffusers.utils import is_torch_version
from diffusers.utils.accelerate_utils import apply_forward_hook
from einops import rearrange
from common.distributed.advanced import get_sequence_parallel_world_size
from common.logger import get_logger
from models.video_vae_v3.modules.causal_inflation_lib import (
InflatedCausalConv3d,
causal_norm_wrapper,
init_causal_conv3d,
remove_head,
)
from models.video_vae_v3.modules.context_parallel_lib import (
causal_conv_gather_outputs,
causal_conv_slice_inputs,
)
from models.video_vae_v3.modules.global_config import set_norm_limit
from models.video_vae_v3.modules.types import (
CausalAutoencoderOutput,
CausalDecoderOutput,
CausalEncoderOutput,
MemoryState,
_inflation_mode_t,
_memory_device_t,
_receptive_field_t,
)
logger = get_logger(__name__) # pylint: disable=invalid-name
class Upsample3D(Upsample2D):
    """Causal 3D upsampler layered on top of diffusers' ``Upsample2D``.

    Upsampling follows the MAGViT-v2 recipe: a 1x1x1 ``nn.Conv3d`` expands the
    channel axis by ``spatial_ratio**2 * temporal_ratio`` (weights initialised as a
    repeated identity, bias zeroed), then ``rearrange`` folds the extra channels into
    the time / height / width axes. An optional causal 3x3x3 conv follows.

    Keyword-only extras (everything else is forwarded to ``Upsample2D``):
        inflation_mode: forwarded to ``init_causal_conv3d``.
        temporal_up: double the temporal axis when True.
        spatial_up: double height and width when True.
        slicing: split the input into two halves along dim 2 and process them as a list.
    """

    def __init__(
        self,
        *args,
        inflation_mode: _inflation_mode_t = "tail",
        temporal_up: bool = False,
        spatial_up: bool = True,
        slicing: bool = False,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        parent_conv = self.conv if self.name == "conv" else self.Conv2d_0
        assert type(parent_conv) is not nn.ConvTranspose2d

        # Note: lora_layer is not passed into constructor in the original implementation,
        # so the 2D conv is simply replaced by a plain causal 3D conv.
        causal_conv = init_causal_conv3d(
            self.channels,
            self.out_channels,
            3,
            padding=1,
            inflation_mode=inflation_mode,
        )

        self.temporal_up = temporal_up
        self.spatial_up = spatial_up
        self.temporal_ratio = 2 if temporal_up else 1
        self.spatial_ratio = 2 if spatial_up else 1
        self.slicing = slicing

        assert not self.interpolate
        # [Override] MAGViT v2 implementation
        if not self.interpolate:
            expand = self.temporal_ratio * (self.spatial_ratio**2)
            self.upscale_conv = nn.Conv3d(
                self.channels, self.channels * expand, kernel_size=1, padding=0
            )
            # Every group of `channels` output channels starts as a copy of the input.
            eye = torch.eye(self.channels).repeat(expand, 1)
            self.upscale_conv.weight.data.copy_(eye.reshape_as(self.upscale_conv.weight))
            nn.init.zeros_(self.upscale_conv.bias)

        # Register the causal conv under the same attribute name the parent used.
        setattr(self, "conv" if self.name == "conv" else "Conv2d_0", causal_conv)

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        output_size: Optional[int] = None,
        memory_state: MemoryState = MemoryState.DISABLED,
        **kwargs,
    ) -> torch.FloatTensor:
        """Upsample a ``(b, c, f, h, w)`` tensor; ``output_size`` is accepted but unused."""
        assert hidden_states.shape[1] == self.channels

        norm = getattr(self, "norm", None)
        if norm is not None:
            # [Overridden] change to causal norm.
            hidden_states = causal_norm_wrapper(norm, hidden_states)

        if self.use_conv_transpose:
            return self.conv(hidden_states)

        num_frames = hidden_states.size(2)
        if self.slicing:
            first_len = num_frames // 2
            chunks = list(hidden_states.split([first_len, num_frames - first_len], dim=2))
        else:
            chunks = [hidden_states]

        chunks = [
            rearrange(
                self.upscale_conv(chunk),
                "b (x y z c) f h w -> b c (f z) (h x) (w y)",
                x=self.spatial_ratio,
                y=self.spatial_ratio,
                z=self.temporal_ratio,
            )
            for chunk in chunks
        ]

        # [Overridden] For causal temporal conv: only the leading chunk carries the head.
        if self.temporal_up and memory_state != MemoryState.ACTIVE:
            chunks[0] = remove_head(chunks[0])

        # With slicing the list itself is handed to the causal conv.
        result = chunks if self.slicing else chunks[0]

        if self.use_conv:
            out_conv = self.conv if self.name == "conv" else self.Conv2d_0
            result = out_conv(result, memory_state=memory_state)

        return torch.cat(result, dim=2) if self.slicing else result
class Downsample3D(Downsample2D):
    """A 3D downsampling layer with an optional convolution.

    Extends diffusers' ``Downsample2D``: the parent constructor creates the 2D
    operator, which is then replaced by a causal 3D conv (when the parent built an
    ``nn.Conv2d`` / ``LoRACompatibleConv``) or by an ``nn.AvgPool3d`` (when the parent
    built an ``nn.AvgPool2d``).

    Keyword-only extras (everything else is forwarded to ``Downsample2D``):
        inflation_mode: forwarded to ``init_causal_conv3d``.
        spatial_down: when True, H and W use kernel 3 / stride 2; otherwise kernel 1 / stride 1.
        temporal_down: when True, time uses kernel 3 / stride 2 / padding 1;
            otherwise kernel 1 / stride 1.

    For the pooling variant, kernel and stride are both
    ``(temporal_ratio, spatial_ratio, spatial_ratio)``, so with both flags False the
    pool is 1x1x1.
    """

    def __init__(
        self,
        *args,
        inflation_mode: _inflation_mode_t = "tail",
        spatial_down: bool = False,
        temporal_down: bool = False,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        conv = self.conv
        self.temporal_down = temporal_down
        self.spatial_down = spatial_down
        self.temporal_ratio = 2 if temporal_down else 1
        self.spatial_ratio = 2 if spatial_down else 1
        self.temporal_kernel = 3 if temporal_down else 1
        self.spatial_kernel = 3 if spatial_down else 1
        if type(conv) in [nn.Conv2d, LoRACompatibleConv]:
            # Note: lora_layer is not passed into constructor in the original implementation.
            # So we make a simplification.
            conv = init_causal_conv3d(
                self.channels,
                self.out_channels,
                kernel_size=(self.temporal_kernel, self.spatial_kernel, self.spatial_kernel),
                stride=(self.temporal_ratio, self.spatial_ratio, self.spatial_ratio),
                padding=(
                    1 if self.temporal_down else 0,
                    self.padding if self.spatial_down else 0,
                    self.padding if self.spatial_down else 0,
                ),
                inflation_mode=inflation_mode,
            )
        elif type(conv) is nn.AvgPool2d:
            assert self.channels == self.out_channels
            conv = nn.AvgPool3d(
                kernel_size=(self.temporal_ratio, self.spatial_ratio, self.spatial_ratio),
                stride=(self.temporal_ratio, self.spatial_ratio, self.spatial_ratio),
            )
        else:
            raise NotImplementedError
        # Mirror the parent's attribute layout (name == "conv" exposes both aliases).
        if self.name == "conv":
            self.Conv2d_0 = conv
            self.conv = conv
        else:
            self.conv = conv

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        memory_state: MemoryState = MemoryState.DISABLED,
        **kwargs,
    ) -> torch.FloatTensor:
        """Downsample ``hidden_states`` (channels on dim 1).

        Args:
            hidden_states: input tensor whose dim 1 must equal ``self.channels``.
            memory_state: forwarded to the causal conv; ignored by the pooling variant.

        Returns:
            The downsampled tensor.
        """
        assert hidden_states.shape[1] == self.channels
        if hasattr(self, "norm") and self.norm is not None:
            # [Overridden] change to causal norm.
            hidden_states = causal_norm_wrapper(self.norm, hidden_states)
        if self.use_conv and self.padding == 0 and self.spatial_down:
            # The conv was built without spatial padding; pad right/bottom by one
            # (asymmetric) before the strided conv. Channel count is unchanged.
            pad = (0, 1, 0, 1)
            hidden_states = F.pad(hidden_states, pad, mode="constant", value=0)
        if isinstance(self.conv, nn.AvgPool3d):
            # Pooling is stateless, and nn.AvgPool3d.forward() accepts only the input
            # tensor; passing memory_state here would raise a TypeError.
            return self.conv(hidden_states)
        return self.conv(hidden_states, memory_state=memory_state)
class ResnetBlock3D(ResnetBlock2D):
    """Causal 3D version of diffusers' ``ResnetBlock2D``.

    The parent constructor builds the 2D layers; ``conv1``, ``conv2``, the optional
    up/down sampler and the optional shortcut conv are then replaced by causal 3D
    counterparts, and norms are applied through ``causal_norm_wrapper``.

    Keyword-only extras (everything else is forwarded to ``ResnetBlock2D``):
        inflation_mode: forwarded to ``init_causal_conv3d`` and the samplers.
        time_receptive_field: ``"half"`` makes ``conv1`` purely spatial (kernel
            ``(1, 3, 3)``); any other value uses a full ``(3, 3, 3)`` kernel.
        slicing: forwarded to ``Upsample3D`` when ``up=True``.
    """

    def __init__(
        self,
        *args,
        inflation_mode: _inflation_mode_t = "tail",
        time_receptive_field: _receptive_field_t = "half",
        slicing: bool = False,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.conv1 = init_causal_conv3d(
            self.in_channels,
            self.out_channels,
            kernel_size=(1, 3, 3) if time_receptive_field == "half" else (3, 3, 3),
            stride=1,
            padding=(0, 1, 1) if time_receptive_field == "half" else (1, 1, 1),
            inflation_mode=inflation_mode,
        )
        self.conv2 = init_causal_conv3d(
            self.out_channels,
            self.conv2.out_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            inflation_mode=inflation_mode,
        )
        if self.up:
            assert type(self.upsample) is Upsample2D
            # Upsample3D asserts `not self.interpolate`, while Upsample2D defaults to
            # interpolate=True, so the non-interpolating path must be requested explicitly
            # (as UpDecoderBlock3D does); otherwise construction always fails.
            self.upsample = Upsample3D(
                self.in_channels,
                use_conv=False,
                interpolate=False,
                inflation_mode=inflation_mode,
                slicing=slicing,
            )
        elif self.down:
            assert type(self.downsample) is Downsample2D
            # NOTE(review): spatial_down/temporal_down are left at Downsample3D's defaults
            # (both False), which makes the pooling kernel 1x1x1 — confirm this is intended.
            self.downsample = Downsample3D(
                self.in_channels,
                use_conv=False,
                padding=1,
                name="op",
                inflation_mode=inflation_mode,
            )
        if self.use_in_shortcut:
            self.conv_shortcut = init_causal_conv3d(
                self.in_channels,
                self.conv_shortcut.out_channels,
                kernel_size=1,
                stride=1,
                padding=0,
                bias=(self.conv_shortcut.bias is not None),
                inflation_mode=inflation_mode,
            )

    def forward(
        self, input_tensor, temb, memory_state: MemoryState = MemoryState.DISABLED, **kwargs
    ):
        """Run the residual block on a 5D ``(b, c, f, h, w)`` activation.

        Args:
            input_tensor: block input.
            temb: optional embedding; projected by ``time_emb_proj`` when that layer exists.
            memory_state: forwarded to every causal conv / sampler.

        Returns:
            ``(shortcut(input_tensor) + residual) / output_scale_factor``.
        """
        hidden_states = input_tensor
        hidden_states = causal_norm_wrapper(self.norm1, hidden_states)
        hidden_states = self.nonlinearity(hidden_states)
        if self.upsample is not None:
            # upsample_nearest_nhwc fails with large batch sizes.
            # see https://github.com/huggingface/diffusers/issues/984
            if hidden_states.shape[0] >= 64:
                input_tensor = input_tensor.contiguous()
                hidden_states = hidden_states.contiguous()
            input_tensor = self.upsample(input_tensor, memory_state=memory_state)
            hidden_states = self.upsample(hidden_states, memory_state=memory_state)
        elif self.downsample is not None:
            input_tensor = self.downsample(input_tensor, memory_state=memory_state)
            hidden_states = self.downsample(hidden_states, memory_state=memory_state)
        hidden_states = self.conv1(hidden_states, memory_state=memory_state)
        if self.time_emb_proj is not None and temb is not None:
            if not self.skip_time_act:
                temb = self.nonlinearity(temb)
            # Activations here are 5D, so broadcast the (b, c) projection over f, h and w.
            temb = self.time_emb_proj(temb)[:, :, None, None, None]
        if temb is not None and self.time_embedding_norm == "default":
            hidden_states = hidden_states + temb
        hidden_states = causal_norm_wrapper(self.norm2, hidden_states)
        if temb is not None and self.time_embedding_norm == "scale_shift":
            scale, shift = torch.chunk(temb, 2, dim=1)
            hidden_states = hidden_states * (1 + scale) + shift
        hidden_states = self.nonlinearity(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.conv2(hidden_states, memory_state=memory_state)
        if self.conv_shortcut is not None:
            input_tensor = self.conv_shortcut(input_tensor, memory_state=memory_state)
        output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
        return output_tensor
class DownEncoderBlock3D(DownEncoderBlock2D):
    """Causal 3D counterpart of diffusers' ``DownEncoderBlock2D``.

    The parent constructor builds the 2D modules; they are then replaced by
    ``num_layers`` ``ResnetBlock3D`` layers (each followed by an ``nn.Identity``
    temporal slot) and, when ``add_downsample`` is set, a single ``Downsample3D``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        output_scale_factor: float = 1.0,
        add_downsample: bool = True,
        downsample_padding: int = 1,
        inflation_mode: _inflation_mode_t = "tail",
        time_receptive_field: _receptive_field_t = "half",
        temporal_down: bool = True,
        spatial_down: bool = True,
    ):
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            dropout=dropout,
            num_layers=num_layers,
            resnet_eps=resnet_eps,
            resnet_time_scale_shift=resnet_time_scale_shift,
            resnet_act_fn=resnet_act_fn,
            resnet_groups=resnet_groups,
            resnet_pre_norm=resnet_pre_norm,
            output_scale_factor=output_scale_factor,
            add_downsample=add_downsample,
            downsample_padding=downsample_padding,
        )

        shared_resnet_kwargs = dict(
            out_channels=out_channels,
            temb_channels=None,
            eps=resnet_eps,
            groups=resnet_groups,
            dropout=dropout,
            time_embedding_norm=resnet_time_scale_shift,
            non_linearity=resnet_act_fn,
            output_scale_factor=output_scale_factor,
            pre_norm=resnet_pre_norm,
            inflation_mode=inflation_mode,
            time_receptive_field=time_receptive_field,
        )
        # [Override] Replace module. Only the first resnet sees `in_channels`.
        self.resnets = nn.ModuleList(
            [
                ResnetBlock3D(
                    in_channels=in_channels if layer_idx == 0 else out_channels,
                    **shared_resnet_kwargs,
                )
                for layer_idx in range(num_layers)
            ]
        )
        self.temporal_modules = nn.ModuleList([nn.Identity() for _ in range(num_layers)])

        if not add_downsample:
            self.downsamplers = None
        else:
            # [Override] Replace module.
            downsampler = Downsample3D(
                out_channels,
                use_conv=True,
                out_channels=out_channels,
                padding=downsample_padding,
                name="op",
                temporal_down=temporal_down,
                spatial_down=spatial_down,
                inflation_mode=inflation_mode,
            )
            self.downsamplers = nn.ModuleList([downsampler])

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        memory_state: MemoryState = MemoryState.DISABLED,
        **kwargs,
    ) -> torch.FloatTensor:
        """Apply every resnet (+ its temporal slot), then the optional downsampler."""
        for block, temporal_slot in zip(self.resnets, self.temporal_modules):
            hidden_states = temporal_slot(
                block(hidden_states, temb=None, memory_state=memory_state)
            )
        if self.downsamplers is None:
            return hidden_states
        for downsampler in self.downsamplers:
            hidden_states = downsampler(hidden_states, memory_state=memory_state)
        return hidden_states
class UpDecoderBlock3D(UpDecoderBlock2D):
    """Causal 3D counterpart of diffusers' ``UpDecoderBlock2D``.

    After the parent constructor runs, its 2D modules are swapped for ``num_layers``
    ``ResnetBlock3D`` layers (each followed by an ``nn.Identity`` temporal slot) and,
    when ``add_upsample`` is set, one learnable ``Upsample3D`` (``interpolate=False``).
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",  # default, spatial
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        output_scale_factor: float = 1.0,
        add_upsample: bool = True,
        temb_channels: Optional[int] = None,
        inflation_mode: _inflation_mode_t = "tail",
        time_receptive_field: _receptive_field_t = "half",
        temporal_up: bool = True,
        spatial_up: bool = True,
        slicing: bool = False,
    ):
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            dropout=dropout,
            num_layers=num_layers,
            resnet_eps=resnet_eps,
            resnet_time_scale_shift=resnet_time_scale_shift,
            resnet_act_fn=resnet_act_fn,
            resnet_groups=resnet_groups,
            resnet_pre_norm=resnet_pre_norm,
            output_scale_factor=output_scale_factor,
            add_upsample=add_upsample,
            temb_channels=temb_channels,
        )

        common_kwargs = dict(
            out_channels=out_channels,
            temb_channels=temb_channels,
            eps=resnet_eps,
            groups=resnet_groups,
            dropout=dropout,
            time_embedding_norm=resnet_time_scale_shift,
            non_linearity=resnet_act_fn,
            output_scale_factor=output_scale_factor,
            pre_norm=resnet_pre_norm,
            inflation_mode=inflation_mode,
            time_receptive_field=time_receptive_field,
            slicing=slicing,
        )
        # [Override] Replace module. The first resnet maps in_channels -> out_channels.
        self.resnets = nn.ModuleList(
            [
                ResnetBlock3D(
                    in_channels=out_channels if layer_idx else in_channels,
                    **common_kwargs,
                )
                for layer_idx in range(num_layers)
            ]
        )
        self.temporal_modules = nn.ModuleList([nn.Identity() for _ in range(num_layers)])

        if not add_upsample:
            self.upsamplers = None
        else:
            # [Override] Replace module & use learnable upsample
            upsampler = Upsample3D(
                out_channels,
                use_conv=True,
                out_channels=out_channels,
                temporal_up=temporal_up,
                spatial_up=spatial_up,
                interpolate=False,
                inflation_mode=inflation_mode,
                slicing=slicing,
            )
            self.upsamplers = nn.ModuleList([upsampler])

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        temb: Optional[torch.FloatTensor] = None,
        memory_state: MemoryState = MemoryState.DISABLED,
    ) -> torch.FloatTensor:
        """Apply every resnet (+ its temporal slot), then the optional upsampler.

        NOTE(review): ``temb`` is accepted for interface compatibility but is not
        forwarded — every resnet is called with ``temb=None``.
        """
        for block, temporal_slot in zip(self.resnets, self.temporal_modules):
            out = block(hidden_states, temb=None, memory_state=memory_state)
            hidden_states = temporal_slot(out)
        if self.upsamplers is None:
            return hidden_states
        for upsampler in self.upsamplers:
            hidden_states = upsampler(hidden_states, memory_state=memory_state)
        return hidden_states
class UNetMidBlock3D(nn.Module):
    """Causal 3D mid block: ``resnet -> (attention -> resnet) * num_layers``.

    Resnets are ``ResnetBlock3D``. Attention is diffusers' 2D ``Attention``, applied
    per frame: ``forward`` folds the frame axis into the batch axis around each
    attention call. When ``add_attention`` is False the attention slots hold ``None``
    and are skipped.
    """

    def __init__(
        self,
        in_channels: int,
        temb_channels: Optional[int],
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",  # default, spatial
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        add_attention: bool = True,
        attention_head_dim: int = 1,
        output_scale_factor: float = 1.0,
        inflation_mode: _inflation_mode_t = "tail",
        time_receptive_field: _receptive_field_t = "half",
    ):
        super().__init__()
        resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
        self.add_attention = add_attention
        # there is always at least one resnet
        resnets = [
            # [Override] Replace module.
            ResnetBlock3D(
                in_channels=in_channels,
                out_channels=in_channels,
                temb_channels=temb_channels,
                eps=resnet_eps,
                groups=resnet_groups,
                dropout=dropout,
                time_embedding_norm=resnet_time_scale_shift,
                non_linearity=resnet_act_fn,
                output_scale_factor=output_scale_factor,
                pre_norm=resnet_pre_norm,
                inflation_mode=inflation_mode,
                time_receptive_field=time_receptive_field,
            )
        ]
        attentions = []
        if attention_head_dim is None:
            # `Logger.warn` is a deprecated alias; use `warning`.
            logger.warning(
                f"It is not recommend to pass `attention_head_dim=None`. "
                f"Defaulting `attention_head_dim` to `in_channels`: {in_channels}."
            )
            attention_head_dim = in_channels
        # Modules are created in the same interleaved order (attention, then resnet)
        # so parameter initialisation consumes the RNG identically.
        for _ in range(num_layers):
            if self.add_attention:
                attentions.append(
                    Attention(
                        in_channels,
                        heads=in_channels // attention_head_dim,
                        dim_head=attention_head_dim,
                        rescale_output_factor=output_scale_factor,
                        eps=resnet_eps,
                        norm_num_groups=(
                            resnet_groups if resnet_time_scale_shift == "default" else None
                        ),
                        spatial_norm_dim=(
                            temb_channels if resnet_time_scale_shift == "spatial" else None
                        ),
                        residual_connection=True,
                        bias=True,
                        upcast_softmax=True,
                        _from_deprecated_attn_block=True,
                    )
                )
            else:
                attentions.append(None)
            resnets.append(
                ResnetBlock3D(
                    in_channels=in_channels,
                    out_channels=in_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                    inflation_mode=inflation_mode,
                    time_receptive_field=time_receptive_field,
                )
            )
        self.attentions = nn.ModuleList(attentions)
        self.resnets = nn.ModuleList(resnets)

    def forward(self, hidden_states, temb=None, memory_state: MemoryState = MemoryState.DISABLED):
        """Run the mid block on a ``(b, c, f, h, w)`` tensor.

        Args:
            hidden_states: 5D input (the layout is fixed by the ``rearrange`` patterns below).
            temb: forwarded to every resnet and attention layer.
            memory_state: forwarded to every resnet.
        """
        # Frame count is needed to unfold (b f) back into b, f after attention.
        video_length = hidden_states.size(-3)
        hidden_states = self.resnets[0](hidden_states, temb, memory_state=memory_state)
        for attn, resnet in zip(self.attentions, self.resnets[1:]):
            if attn is not None:
                hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
                hidden_states = attn(hidden_states, temb=temb)
                hidden_states = rearrange(
                    hidden_states, "(b f) c h w -> b c f h w", f=video_length
                )
            hidden_states = resnet(hidden_states, temb, memory_state=memory_state)
        return hidden_states
class Encoder3D(nn.Module):
    r"""
    [Override] override most logics to support extra condition input and causal conv

    The `Encoder` layer of a variational autoencoder that encodes
    its input into a latent representation.

    Args:
        in_channels (`int`, *optional*, defaults to 3):
            The number of input channels.
        out_channels (`int`, *optional*, defaults to 3):
            The number of output channels.
        down_block_types (`Tuple[str, ...]`, *optional*, defaults to `("DownEncoderBlock3D",)`):
            The types of down blocks to use. Only `"DownEncoderBlock3D"` is supported.
        block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
            The number of output channels for each block.
        layers_per_block (`int`, *optional*, defaults to 2):
            The number of layers per block.
        norm_num_groups (`int`, *optional*, defaults to 32):
            The number of groups for normalization.
        act_fn (`str`, *optional*, defaults to `"silu"`):
            The activation function used inside the blocks (the final `conv_act` is
            always `nn.SiLU`).
            See `~diffusers.models.activations.get_activation` for available options.
        double_z (`bool`, *optional*, defaults to `True`):
            Whether `conv_out` produces `2 * out_channels` channels instead of `out_channels`.
        mid_block_add_attention (`bool`, *optional*, defaults to `True`):
            Whether the mid block contains attention layers.
        temporal_down_num (`int`, *optional*, defaults to 2):
            How many downsampling blocks (taken from the end; the final block never
            downsamples) also downsample along the temporal axis.
        extra_cond_dim (`int`, *optional*):
            Channels of the optional `extra_cond` input to `forward`. When positive, every
            down block gets a zero-initialised 1x1x1 `nn.Conv3d`; its output is resized to
            the block output with `F.interpolate` and added to it.
        gradient_checkpoint (`bool`, *optional*, defaults to `False`):
            Checkpoint the down blocks while training.
        inflation_mode, time_receptive_field:
            Forwarded to the causal 3D sub-modules.
    """

    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        down_block_types: Tuple[str, ...] = ("DownEncoderBlock3D",),
        block_out_channels: Tuple[int, ...] = (64,),
        layers_per_block: int = 2,
        norm_num_groups: int = 32,
        act_fn: str = "silu",
        double_z: bool = True,
        mid_block_add_attention: bool = True,
        # [Override] add extra_cond_dim, temporal down num
        temporal_down_num: int = 2,
        extra_cond_dim: Optional[int] = None,
        gradient_checkpoint: bool = False,
        inflation_mode: _inflation_mode_t = "tail",
        time_receptive_field: _receptive_field_t = "half",
    ):
        super().__init__()
        self.layers_per_block = layers_per_block
        self.temporal_down_num = temporal_down_num
        self.conv_in = init_causal_conv3d(
            in_channels,
            block_out_channels[0],
            kernel_size=3,
            stride=1,
            padding=1,
            inflation_mode=inflation_mode,
        )
        self.mid_block = None
        self.down_blocks = nn.ModuleList([])
        self.extra_cond_dim = extra_cond_dim
        self.conv_extra_cond = nn.ModuleList([])
        use_extra_cond = self.extra_cond_dim is not None and self.extra_cond_dim > 0

        def zero_module(module):
            """Zero all parameters of `module` in place and return it."""
            for p in module.parameters():
                p.detach().zero_()
            return module

        # down
        output_channel = block_out_channels[0]
        for i, down_block_type in enumerate(down_block_types):
            input_channel = output_channel
            output_channel = block_out_channels[i]
            is_final_block = i == len(block_out_channels) - 1
            # [Override] to support temporal down block design
            # Note: take the last ones
            is_temporal_down_block = i >= len(block_out_channels) - self.temporal_down_num - 1
            assert down_block_type == "DownEncoderBlock3D"
            down_block = DownEncoderBlock3D(
                num_layers=self.layers_per_block,
                in_channels=input_channel,
                out_channels=output_channel,
                add_downsample=not is_final_block,
                resnet_eps=1e-6,
                # With padding 0, Downsample3D pads right/bottom by one in forward instead.
                downsample_padding=0,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                temporal_down=is_temporal_down_block,
                spatial_down=True,
                inflation_mode=inflation_mode,
                time_receptive_field=time_receptive_field,
            )
            self.down_blocks.append(down_block)
            # One (possibly None) extra-cond projection per down block, created right
            # after its block so parameter init order is unchanged.
            self.conv_extra_cond.append(
                zero_module(
                    nn.Conv3d(extra_cond_dim, output_channel, kernel_size=1, stride=1, padding=0)
                )
                if use_extra_cond
                else None
            )
        # mid
        self.mid_block = UNetMidBlock3D(
            in_channels=block_out_channels[-1],
            resnet_eps=1e-6,
            resnet_act_fn=act_fn,
            output_scale_factor=1,
            resnet_time_scale_shift="default",
            attention_head_dim=block_out_channels[-1],
            resnet_groups=norm_num_groups,
            temb_channels=None,
            add_attention=mid_block_add_attention,
            inflation_mode=inflation_mode,
            time_receptive_field=time_receptive_field,
        )
        # out
        self.conv_norm_out = nn.GroupNorm(
            num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6
        )
        self.conv_act = nn.SiLU()
        conv_out_channels = 2 * out_channels if double_z else out_channels
        self.conv_out = init_causal_conv3d(
            block_out_channels[-1], conv_out_channels, 3, padding=1, inflation_mode=inflation_mode
        )
        self.gradient_checkpointing = gradient_checkpoint

    def forward(
        self,
        sample: torch.FloatTensor,
        extra_cond=None,
        memory_state: MemoryState = MemoryState.DISABLED,
    ) -> torch.FloatTensor:
        r"""The forward method of the `Encoder` class.

        Args:
            sample: input video tensor.
            extra_cond: optional condition; required when `extra_cond_dim > 0`.
            memory_state: forwarded to every causal sub-module.
        """
        sample = self.conv_in(sample, memory_state=memory_state)
        use_checkpoint = self.training and self.gradient_checkpointing
        # down
        # [Override] add extra block and extra cond
        for down_block, extra_block in zip(self.down_blocks, self.conv_extra_cond):
            if use_checkpoint:
                sample = torch.utils.checkpoint.checkpoint(
                    down_block, sample, memory_state, use_reentrant=False
                )
            else:
                sample = down_block(sample, memory_state=memory_state)
            if extra_block is not None:
                sample = sample + F.interpolate(extra_block(extra_cond), size=sample.shape[2:])
        # middle (never checkpointed)
        sample = self.mid_block(sample, memory_state=memory_state)
        # post-process
        sample = causal_norm_wrapper(self.conv_norm_out, sample)
        sample = self.conv_act(sample)
        sample = self.conv_out(sample, memory_state=memory_state)
        return sample
class Decoder3D(nn.Module):
r"""
The `Decoder` layer of a variational autoencoder that
decodes its latent representation into an output sample.
Args:
in_channels (`int`, *optional*, defaults to 3):
The number of input channels.
out_channels (`int`, *optional*, defaults to 3):
The number of output channels.
up_block_types (`Tuple[str, ...]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
The types of up blocks to use.
See `~diffusers.models.unet_2d_blocks.get_up_block` for available options.
block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
The number of output channels for each block.
layers_per_block (`int`, *optional*, defaults to 2):
The number of layers per block.
norm_num_groups (`int`, *optional*, defaults to 32):
The number of groups for normalization.
act_fn (`str`, *optional*, defaults to `"silu"`):
The activation function to use.
See `~diffusers.models.activations.get_activation` for available options.
norm_type (`str`, *optional*, defaults to `"group"`):
The normalization type to use. Can be either `"group"` or `"spatial"`.
"""
def __init__(
self,
in_channels: int = 3,
out_channels: int = 3,
up_block_types: Tuple[str, ...] = ("UpDecoderBlock3D",),
block_out_channels: Tuple[int, ...] = (64,),
layers_per_block: int = 2,
norm_num_groups: int = 32,
act_fn: str = "silu",
norm_type: str = "group", # group, spatial
mid_block_add_attention=True,
# [Override] add temporal up block
inflation_mode: _inflation_mode_t = "tail",
time_receptive_field: _receptive_field_t = "half",
temporal_up_num: int = 2,
slicing_up_num: int = 0,
gradient_checkpoint: bool = False,
):
super().__init__()
self.layers_per_block = layers_per_block
self.temporal_up_num = temporal_up_num
self.conv_in = init_causal_conv3d(
in_channels,
block_out_channels[-1],
kernel_size=3,
stride=1,
padding=1,
inflation_mode=inflation_mode,
)
self.mid_block = None
self.up_blocks = nn.ModuleList([])
temb_channels = in_channels if norm_type == "spatial" else None
# mid
self.mid_block = UNetMidBlock3D(
in_channels=block_out_channels[-1],
resnet_eps=1e-6,
resnet_act_fn=act_fn,
output_scale_factor=1,
resnet_time_scale_shift="default" if norm_type == "group" else norm_type,
attention_head_dim=block_out_channels[-1],
resnet_groups=norm_num_groups,
temb_channels=temb_channels,
add_attention=mid_block_add_attention,
inflation_mode=inflation_mode,
time_receptive_field=time_receptive_field,
)
# up
reversed_block_out_channels = list(reversed(block_out_channels))
output_channel = reversed_block_out_channels[0]
print(f"slicing_up_num: {slicing_up_num}")
for i, up_block_type in enumerate(up_block_types):
prev_output_channel = output_channel
output_channel = reversed_block_out_channels[i]
is_final_block = i == len(block_out_channels) - 1
is_temporal_up_block = i < self.temporal_up_num
is_slicing_up_block = i >= len(block_out_channels) - slicing_up_num
# Note: Keep symmetric
assert up_block_type == "UpDecoderBlock3D"
up_block = UpDecoderBlock3D(
num_layers=self.layers_per_block + 1,
in_channels=prev_output_channel,
out_channels=output_channel,
add_upsample=not is_final_block,
resnet_eps=1e-6,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
resnet_time_scale_shift=norm_type,
temb_channels=temb_channels,
temporal_up=is_temporal_up_block,
slicing=is_slicing_up_block,
inflation_mode=inflation_mode,
time_receptive_field=time_receptive_field,
)
self.up_blocks.append(up_block)
prev_output_channel = output_channel
# out
if norm_type == "spatial":
self.conv_norm_out = SpatialNorm(block_out_channels[0], temb_channels)
else:
self.conv_norm_out = nn.GroupNorm(
num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6
)
self.conv_act = nn.SiLU()
self.conv_out = init_causal_conv3d(
block_out_channels[0], out_channels, 3, padding=1, inflation_mode=inflation_mode
)
self.gradient_checkpointing = gradient_checkpoint
# Note: Just copy from Decoder.
def forward(
    self,
    sample: torch.FloatTensor,
    latent_embeds: Optional[torch.FloatTensor] = None,
    memory_state: MemoryState = MemoryState.DISABLED,
) -> torch.FloatTensor:
    r"""Decode ``sample``: conv_in -> mid block -> up blocks -> norm/act/conv_out.

    Args:
        sample: Input fed to ``self.conv_in``.
        latent_embeds: Forwarded as the second positional argument to the mid
            block and every up block (NOTE(review): presumably the spatial-norm
            conditioning; only meaningful when ``norm_type == "spatial"``).
        memory_state: Forwarded to ``conv_in``, the mid block, every up block
            and ``conv_out``.

    Returns:
        The output of ``self.conv_out``.
    """
    sample = self.conv_in(sample, memory_state=memory_state)
    # Cast the mid-block output to the dtype of the up blocks' parameters.
    upscale_dtype = next(iter(self.up_blocks.parameters())).dtype

    # The middle block runs the same way in every path (it is never checkpointed).
    sample = self.mid_block(sample, latent_embeds, memory_state=memory_state)
    sample = sample.to(upscale_dtype)

    if not (self.training and self.gradient_checkpointing):
        for block in self.up_blocks:
            sample = block(sample, latent_embeds, memory_state=memory_state)
    else:
        def _as_callable(module):
            # Wrap the module so checkpoint receives a plain callable.
            def _run(*inputs):
                return module(*inputs)

            return _run

        # ``use_reentrant`` is only passed on torch >= 1.11.0.
        extra = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
        for block in self.up_blocks:
            sample = torch.utils.checkpoint.checkpoint(
                _as_callable(block), sample, latent_embeds, memory_state, **extra
            )

    # Post-process.
    # NOTE(review): ``latent_embeds`` is not forwarded to ``conv_norm_out``; with
    # ``norm_type == "spatial"`` (SpatialNorm) this may be incorrect — confirm.
    sample = causal_norm_wrapper(self.conv_norm_out, sample)
    sample = self.conv_act(sample)
    return self.conv_out(sample, memory_state=memory_state)
class AutoencoderKL(diffusers.AutoencoderKL):
    """
    Thin subclass of ``diffusers.AutoencoderKL``.

    Adds two things on top of the diffusers model:

    * an ``attention`` switch: when False, the encoder and decoder mid-block
      ``attentions`` lists are replaced by ``ModuleList([None])``;
    * a ``load_state_dict`` that first runs diffusers'
      ``_convert_deprecated_attention_blocks`` (when available) so that
      checkpoints saved with older attention key names still load.
    """

    def __init__(self, attention: bool = True, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # A hacky way to remove attention: keep a one-element list whose entry
        # is None. NOTE(review): assumes the diffusers mid block skips a None
        # attention entry — confirm against the installed diffusers version.
        if not attention:
            self.encoder.mid_block.attentions = torch.nn.ModuleList([None])
            self.decoder.mid_block.attentions = torch.nn.ModuleList([None])

    def load_state_dict(self, state_dict, strict=True, *args, **kwargs):
        """Load ``state_dict`` after converting deprecated attention keys.

        Newer diffusers versions renamed the attention parameter keys, which
        breaks old checkpoints; diffusers ships a conversion method that is
        called here (if present) before delegating to the parent.

        Args:
            state_dict: Parameter dict to load. The conversion's return value
                is ignored, so it is expected to rewrite keys of this dict in
                place (the caller's dict may be modified).
            strict: Forwarded to the parent ``load_state_dict``.
            *args, **kwargs: Any further arguments accepted by the parent
                (e.g. ``assign`` on recent PyTorch) are forwarded unchanged.

        Returns:
            Whatever the parent ``load_state_dict`` returns.
        """
        convert_deprecated_attention_blocks = getattr(
            self, "_convert_deprecated_attention_blocks", None
        )
        if callable(convert_deprecated_attention_blocks):
            convert_deprecated_attention_blocks(state_dict)
        return super().load_state_dict(state_dict, strict, *args, **kwargs)
class VideoAutoencoderKL(diffusers.AutoencoderKL):
    """
    Causal 3D (video) autoencoder built on top of ``diffusers.AutoencoderKL``.

    The diffusers base class is constructed with every ``"3D"`` block type
    renamed to ``"2D"`` so its own constructor accepts the configuration. The
    ``encoder``, ``decoder``, ``quant_conv`` and ``post_quant_conv`` it builds
    are then replaced with ``Encoder3D``/``Decoder3D`` and causal 3D convs.

    Inputs are laid out as ``[b c t h w]``. When ``self.use_slicing`` is on,
    ``encode``/``decode`` split the time axis (dim 2) into chunks. The first
    chunk (which includes frame 0) is processed with
    ``MemoryState.INITIALIZING`` and later chunks with ``MemoryState.ACTIVE``.
    NOTE(review): presumably this carries causal-conv context across chunks —
    confirm in ``InflatedCausalConv3d``.
    """

    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        down_block_types: Tuple[str, ...] = ("DownEncoderBlock3D",),
        up_block_types: Tuple[str, ...] = ("UpDecoderBlock3D",),
        block_out_channels: Tuple[int, ...] = (64,),
        layers_per_block: int = 1,
        act_fn: str = "silu",
        latent_channels: int = 4,
        norm_num_groups: int = 32,
        sample_size: int = 32,
        scaling_factor: float = 0.18215,
        force_upcast: bool = True,
        attention: bool = True,
        temporal_scale_num: int = 2,
        slicing_up_num: int = 0,
        gradient_checkpoint: bool = False,
        inflation_mode: _inflation_mode_t = "tail",
        time_receptive_field: _receptive_field_t = "full",
        slicing_sample_min_size: int = 32,
        use_quant_conv: bool = True,
        use_post_quant_conv: bool = True,
        *args,
        **kwargs,
    ):
        # Consumed here (only Encoder3D uses it) so the diffusers base
        # constructor never receives it.
        extra_cond_dim = kwargs.pop("extra_cond_dim", None)
        # Frames per slice in pixel space, and the matching number of latent
        # frames (the encoder downsamples time by 2**temporal_scale_num).
        self.slicing_sample_min_size = slicing_sample_min_size
        self.slicing_latent_min_size = slicing_sample_min_size // (2**temporal_scale_num)
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            # [Override] make sure it can be normally initialized
            down_block_types=tuple(
                [down_block_type.replace("3D", "2D") for down_block_type in down_block_types]
            ),
            up_block_types=tuple(
                [up_block_type.replace("3D", "2D") for up_block_type in up_block_types]
            ),
            block_out_channels=block_out_channels,
            layers_per_block=layers_per_block,
            act_fn=act_fn,
            latent_channels=latent_channels,
            norm_num_groups=norm_num_groups,
            sample_size=sample_size,
            scaling_factor=scaling_factor,
            force_upcast=force_upcast,
            *args,
            **kwargs,
        )
        # pass init params to Encoder
        self.encoder = Encoder3D(
            in_channels=in_channels,
            out_channels=latent_channels,
            down_block_types=down_block_types,
            block_out_channels=block_out_channels,
            layers_per_block=layers_per_block,
            act_fn=act_fn,
            norm_num_groups=norm_num_groups,
            double_z=True,
            extra_cond_dim=extra_cond_dim,
            # [Override] add temporal_down_num parameter
            temporal_down_num=temporal_scale_num,
            gradient_checkpoint=gradient_checkpoint,
            inflation_mode=inflation_mode,
            time_receptive_field=time_receptive_field,
        )
        # pass init params to Decoder
        self.decoder = Decoder3D(
            in_channels=latent_channels,
            out_channels=out_channels,
            up_block_types=up_block_types,
            block_out_channels=block_out_channels,
            layers_per_block=layers_per_block,
            norm_num_groups=norm_num_groups,
            act_fn=act_fn,
            # [Override] add temporal_up_num parameter
            temporal_up_num=temporal_scale_num,
            slicing_up_num=slicing_up_num,
            gradient_checkpoint=gradient_checkpoint,
            inflation_mode=inflation_mode,
            time_receptive_field=time_receptive_field,
        )
        # 1x1 causal convs replacing the 2D ones built by the base class;
        # either can be disabled (set to None) and is then skipped.
        self.quant_conv = (
            init_causal_conv3d(
                in_channels=2 * latent_channels,
                out_channels=2 * latent_channels,
                kernel_size=1,
                inflation_mode=inflation_mode,
            )
            if use_quant_conv
            else None
        )
        self.post_quant_conv = (
            init_causal_conv3d(
                in_channels=latent_channels,
                out_channels=latent_channels,
                kernel_size=1,
                inflation_mode=inflation_mode,
            )
            if use_post_quant_conv
            else None
        )
        # A hacky way to remove attention.
        if not attention:
            self.encoder.mid_block.attentions = torch.nn.ModuleList([None])
            self.decoder.mid_block.attentions = torch.nn.ModuleList([None])

    @apply_forward_hook
    def encode(
        self, x: torch.FloatTensor, return_dict: bool = True
    ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
        """Encode ``x`` into a ``DiagonalGaussianDistribution`` posterior.

        Returns ``AutoencoderKLOutput(latent_dist=...)``, or a 1-tuple
        ``(posterior,)`` when ``return_dict`` is False.
        """
        h = self.slicing_encode(x)
        posterior = DiagonalGaussianDistribution(h)
        if not return_dict:
            return (posterior,)
        return AutoencoderKLOutput(latent_dist=posterior)

    @apply_forward_hook
    def decode(
        self, z: torch.Tensor, return_dict: bool = True
    ) -> Union[DecoderOutput, Tuple[torch.Tensor]]:
        """Decode latents ``z``.

        Returns ``DecoderOutput(sample=...)``, or a 1-tuple ``(decoded,)``
        when ``return_dict`` is False.
        """
        decoded = self.slicing_decode(z)
        if not return_dict:
            return (decoded,)
        return DecoderOutput(sample=decoded)

    def _encode(
        self, x: torch.Tensor, memory_state: MemoryState = MemoryState.DISABLED
    ) -> torch.Tensor:
        """Run one chunk through encoder (+ quant_conv); result is on ``x.device``.

        NOTE(review): ``causal_conv_slice_inputs``/``causal_conv_gather_outputs``
        presumably scatter/gather across sequence-parallel ranks — confirm.
        """
        _x = x.to(self.device)
        _x = causal_conv_slice_inputs(_x, self.slicing_sample_min_size, memory_state=memory_state)
        h = self.encoder(_x, memory_state=memory_state)
        if self.quant_conv is not None:
            output = self.quant_conv(h, memory_state=memory_state)
        else:
            output = h
        output = causal_conv_gather_outputs(output)
        return output.to(x.device)

    def _decode(
        self, z: torch.Tensor, memory_state: MemoryState = MemoryState.DISABLED
    ) -> torch.Tensor:
        """Run one chunk through (post_quant_conv +) decoder; result is on ``z.device``."""
        _z = z.to(self.device)
        _z = causal_conv_slice_inputs(_z, self.slicing_latent_min_size, memory_state=memory_state)
        if self.post_quant_conv is not None:
            _z = self.post_quant_conv(_z, memory_state=memory_state)
        output = self.decoder(_z, memory_state=memory_state)
        output = causal_conv_gather_outputs(output)
        return output.to(z.device)

    def _slicing_apply(self, x: torch.Tensor, min_size: int, fn) -> torch.Tensor:
        """Apply ``fn`` to ``x`` whole, or in temporal chunks when slicing is on.

        Shared by ``slicing_encode`` and ``slicing_decode``.

        Args:
            x: Tensor whose dim 2 is the time axis.
            min_size: Per-rank chunk length; the chunk length actually used is
                ``min_size * get_sequence_parallel_world_size()``.
            fn: ``self._encode`` or ``self._decode``.

        Returns:
            ``fn(x)`` when slicing is disabled or ``x`` has at most
            ``chunk + 1`` frames. Otherwise the dim-2 concatenation of ``fn``
            applied to: frame 0 joined with the first chunk (INITIALIZING),
            then each following chunk in order (ACTIVE).
        """
        sp_size = get_sequence_parallel_world_size()
        chunk = min_size * sp_size
        if not self.use_slicing or (x.shape[2] - 1) <= chunk:
            return fn(x)
        # Frame 0 rides along with the first chunk; the rest of the time axis
        # is split into pieces of ``chunk`` frames (the last may be shorter).
        rest = x[:, :, 1:].split(split_size=chunk, dim=2)
        outputs = [
            fn(
                torch.cat((x[:, :, :1], rest[0]), dim=2),
                memory_state=MemoryState.INITIALIZING,
            )
        ]
        for piece in rest[1:]:
            outputs.append(fn(piece, memory_state=MemoryState.ACTIVE))
        return torch.cat(outputs, dim=2)

    def slicing_encode(self, x: torch.Tensor) -> torch.Tensor:
        """Encode ``x``, chunking along time by ``slicing_sample_min_size`` when enabled."""
        return self._slicing_apply(x, self.slicing_sample_min_size, self._encode)

    def slicing_decode(self, z: torch.Tensor) -> torch.Tensor:
        """Decode ``z``, chunking along time by ``slicing_latent_min_size`` when enabled."""
        return self._slicing_apply(z, self.slicing_latent_min_size, self._decode)

    def tiled_encode(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
        raise NotImplementedError

    def tiled_decode(self, z: torch.Tensor, **kwargs) -> torch.Tensor:
        raise NotImplementedError

    def forward(
        self, x: torch.FloatTensor, mode: Literal["encode", "decode", "all"] = "all", **kwargs
    ):
        """Dispatch on ``mode``.

        * ``"encode"``: return the posterior distribution of ``x``.
        * ``"decode"``: treat ``x`` as latents and return the decoded sample.
        * anything else (``"all"``): encode, decode the posterior mode, and
          return the reconstruction.

        ``kwargs`` are accepted and ignored.
        """
        # x: [b c t h w]
        if mode == "encode":
            h = self.encode(x)
            return h.latent_dist
        elif mode == "decode":
            h = self.decode(x)
            return h.sample
        else:
            h = self.encode(x)
            h = self.decode(h.latent_dist.mode())
            return h.sample

    def load_state_dict(self, state_dict, strict=False, *args, **kwargs):
        """Load ``state_dict`` after converting deprecated attention keys.

        Newer diffusers versions renamed the attention parameter keys, which
        breaks old checkpoints; diffusers' conversion method (if present) is
        called first. Its return value is ignored, so it is expected to
        rewrite ``state_dict`` in place.

        ``strict`` defaults to False here. Extra positional/keyword arguments
        accepted by the parent (e.g. ``assign`` on recent PyTorch) are
        forwarded unchanged.
        """
        convert_deprecated_attention_blocks = getattr(
            self, "_convert_deprecated_attention_blocks", None
        )
        if callable(convert_deprecated_attention_blocks):
            convert_deprecated_attention_blocks(state_dict)
        return super().load_state_dict(state_dict, strict, *args, **kwargs)
class VideoAutoencoderKLWrapper(VideoAutoencoderKL):
    """
    ``VideoAutoencoderKL`` with a tensor-in / ``Causal*Output``-out interface.

    * ``encode``/``decode``/``forward`` return ``CausalEncoderOutput``,
      ``CausalDecoderOutput`` and ``CausalAutoencoderOutput``.
    * 4D inputs get a length-1 axis inserted at dim 2. Dim 2 is squeezed again
      on output, which only has an effect when its length is 1.
    * ``freeze_encoder`` runs the encode step of ``forward`` under
      ``torch.no_grad()``.
    * ``set_causal_slicing`` / ``set_memory_limit`` configure chunked
      processing and per-conv memory limits.
    """

    def __init__(
        self,
        *args,
        spatial_downsample_factor: int,
        temporal_downsample_factor: int,
        freeze_encoder: bool,
        **kwargs,
    ):
        # Stored before ``super().__init__()``; these are plain Python values,
        # not parameters or submodules.
        self.spatial_downsample_factor = spatial_downsample_factor
        self.temporal_downsample_factor = temporal_downsample_factor
        self.freeze_encoder = freeze_encoder
        super().__init__(*args, **kwargs)

    def forward(self, x: torch.FloatTensor) -> CausalAutoencoderOutput:
        """Encode ``x`` (without grad if ``freeze_encoder``), then decode the sample."""
        with torch.no_grad() if self.freeze_encoder else nullcontext():
            z, p = self.encode(x)
        x = self.decode(z).sample
        return CausalAutoencoderOutput(x, z, p)

    def encode(self, x: torch.FloatTensor) -> CausalEncoderOutput:
        """Encode ``x`` and return a sampled latent together with its posterior."""
        if x.ndim == 4:
            x = x.unsqueeze(2)
        p = super().encode(x).latent_dist
        z = p.sample().squeeze(2)
        return CausalEncoderOutput(z, p)

    def decode(self, z: torch.FloatTensor) -> CausalDecoderOutput:
        """Decode latents ``z`` (4D latents get a length-1 time axis at dim 2)."""
        if z.ndim == 4:
            z = z.unsqueeze(2)
        x = super().decode(z).sample.squeeze(2)
        return CausalDecoderOutput(x)

    def preprocess(self, x: torch.Tensor):
        """Validate the input layout and return ``x`` unchanged.

        ``x`` should be ``[B, C, H, W]`` or ``[B, C, T, H, W]``. For video, the
        causal layout is one leading frame plus a whole number of
        ``temporal_downsample_factor``-frame groups. That is
        ``(T - 1) % temporal_downsample_factor == 0``, which for the default
        factor of 4 is the same as ``T % 4 == 1``.
        """
        assert x.ndim == 4 or (x.size(2) - 1) % self.temporal_downsample_factor == 0
        return x

    def postprocess(self, x: torch.Tensor):
        """Return ``x`` unchanged (``[B, C, T, H, W]`` or ``[B, C, H, W]``)."""
        return x

    def set_causal_slicing(
        self,
        *,
        split_size: Optional[int],
        memory_device: _memory_device_t,
    ):
        """Enable temporal slicing with ``split_size`` frames, or disable it with None.

        The latent slice length is ``split_size // temporal_downsample_factor``.
        ``memory_device`` is pushed to every ``InflatedCausalConv3d`` in either
        case, and must not be None when ``split_size`` is set.
        """
        assert (
            split_size is None or memory_device is not None
        ), "if split_size is set, memory_device must not be None."
        if split_size is not None:
            self.enable_slicing()
            self.slicing_sample_min_size = split_size
            self.slicing_latent_min_size = split_size // self.temporal_downsample_factor
        else:
            self.disable_slicing()
        for module in self.modules():
            if isinstance(module, InflatedCausalConv3d):
                module.set_memory_device(memory_device)

    def set_memory_limit(self, conv_max_mem: Optional[float], norm_max_mem: Optional[float]):
        """Set the norm memory limit and the per-conv limit (None means unlimited for convs)."""
        set_norm_limit(norm_max_mem)
        for m in self.modules():
            if isinstance(m, InflatedCausalConv3d):
                m.set_memory_limit(conv_max_mem if conv_max_mem is not None else float("inf"))
|