|
import argparse |
|
import os |
|
import imageio |
|
import torch |
|
import numpy as np |
|
from einops import rearrange |
|
from torch import Tensor, nn |
|
import torch.nn.functional as F |
|
import torchvision |
|
from torchvision import transforms |
|
from safetensors.torch import load_file |
|
import torch.utils.checkpoint as checkpoint |
|
|
|
from .conv import Conv |
|
from .multiscale_bsq import MultiScaleBSQ |
|
|
|
# Maps config dtype strings to torch dtypes; None falls back to fp32.
ptdtype = {None: torch.float32, "fp32": torch.float32, "bf16": torch.bfloat16}
|
|
|
|
|
class Normalize(nn.Module):
    """Normalization wrapper: GroupNorm, SyncBatchNorm, or identity.

    For 5D inputs with ``norm_axis="spatial"``, time is folded into the
    batch so each frame is normalized independently; with
    ``norm_axis="spatial-temporal"`` the module is applied to the full
    5D tensor directly.
    """

    def __init__(self, in_channels, norm_type, norm_axis="spatial"):
        super().__init__()
        self.norm_axis = norm_axis
        assert norm_type in ["group", "batch", "no"]
        if norm_type == "group":
            # Choose the largest supported group count dividing the channels.
            if in_channels % 32 == 0:
                num_groups = 32
            elif in_channels % 24 == 0:
                num_groups = 24
            else:
                raise NotImplementedError
            self.norm = nn.GroupNorm(
                num_groups=num_groups,
                num_channels=in_channels,
                eps=1e-6,
                affine=True,
            )
        elif norm_type == "batch":
            # No running stats: always normalize with batch statistics.
            self.norm = nn.SyncBatchNorm(in_channels, track_running_stats=False)
        else:  # norm_type == "no"
            self.norm = nn.Identity()

    def forward(self, x):
        if self.norm_axis == "spatial":
            if x.ndim == 4:
                return self.norm(x)
            # 5D input: fold time into batch, normalize, and restore.
            num_frames = x.shape[2]
            x = rearrange(x, "B C T H W -> (B T) C H W")
            x = self.norm(x)
            return rearrange(x, "(B T) C H W -> B C T H W", T=num_frames)
        elif self.norm_axis == "spatial-temporal":
            return self.norm(x)
        else:
            raise NotImplementedError
|
|
|
|
|
def swish(x: Tensor) -> Tensor:
    """SiLU/swish activation: ``x * sigmoid(x)``.

    If the in-place device computation fails (e.g. CUDA out-of-memory,
    which torch raises as a RuntimeError subclass), retry on the CPU via
    pinned memory and move the result back to the original device.
    """
    try:
        return x * torch.sigmoid(x)
    except RuntimeError:  # narrowed from bare except: don't swallow KeyboardInterrupt
        device = x.device
        x = x.cpu().pin_memory()
        return (x * torch.sigmoid(x)).to(device=device)
|
|
|
|
|
class AttnBlock(nn.Module):
    """Single-head self-attention over spatial positions, applied per frame,
    wrapped in a residual connection."""

    def __init__(self, in_channels, norm_type="group", cnn_param=None):
        super().__init__()
        self.in_channels = in_channels

        self.norm = Normalize(
            in_channels, norm_type, norm_axis=cnn_param["cnn_norm_axis"]
        )

        # 1x1 projections for query/key/value and the residual output.
        self.q = Conv(in_channels, in_channels, kernel_size=1)
        self.k = Conv(in_channels, in_channels, kernel_size=1)
        self.v = Conv(in_channels, in_channels, kernel_size=1)
        self.proj_out = Conv(in_channels, in_channels, kernel_size=1)

    def attention(self, h_: Tensor) -> Tensor:
        batch, _, frames, _, _ = h_.shape
        h_ = self.norm(h_)
        # Treat every frame as an independent 2D image.
        h_ = rearrange(h_, "B C T H W -> (B T) C H W")
        q, k, v = self.q(h_), self.k(h_), self.v(h_)

        _, c, h, w = q.shape
        # Flatten spatial grid into a token axis (single attention head).
        q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous()
        k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous()
        v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous()
        attended = nn.functional.scaled_dot_product_attention(q, k, v)

        # Restore the (B, C, T, H, W) layout.
        return rearrange(
            attended,
            "(b t) 1 (h w) c -> b c t h w",
            h=h,
            w=w,
            c=c,
            b=batch,
            t=frames,
        )

    def forward(self, x: Tensor) -> Tensor:
        return x + self.proj_out(self.attention(x))
|
|
|
|
|
class ResnetBlock(nn.Module):
    """Two-convolution residual block (norm -> swish -> conv, twice), with a
    1x1 shortcut projection when the channel count changes."""

    def __init__(
        self, in_channels: int, out_channels: int, norm_type="group", cnn_param=None
    ):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels

        norm_axis = cnn_param["cnn_norm_axis"]
        self.norm1 = Normalize(in_channels, norm_type, norm_axis=norm_axis)
        # "half" and "full" force the first conv to 2D regardless of cnn_type.
        first_conv_type = (
            "2d"
            if cnn_param["res_conv_2d"] in ["half", "full"]
            else cnn_param["cnn_type"]
        )
        self.conv1 = Conv(
            in_channels,
            out_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            cnn_type=first_conv_type,
        )
        self.norm2 = Normalize(out_channels, norm_type, norm_axis=norm_axis)
        # Only "full" also forces the second conv to 2D.
        second_conv_type = (
            "2d" if cnn_param["res_conv_2d"] in ["full"] else cnn_param["cnn_type"]
        )
        self.conv2 = Conv(
            out_channels,
            out_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            cnn_type=second_conv_type,
        )
        if self.in_channels != self.out_channels:
            # Project the skip path so it can be added to the block output.
            self.nin_shortcut = Conv(
                in_channels, out_channels, kernel_size=1, stride=1, padding=0
            )

    def forward(self, x):
        residual = x
        h = self.conv1(swish(self.norm1(x)))
        h = self.conv2(swish(self.norm2(h)))

        if self.in_channels != self.out_channels:
            residual = self.nin_shortcut(residual)

        return residual + h
|
|
|
|
|
class Downsample(nn.Module):
    """Downsampling via a stride-2 conv with asymmetric (right/bottom)
    zero padding so spatial dims are halved exactly."""

    def __init__(
        self, in_channels, cnn_type="2d", spatial_down=False, temporal_down=False
    ):
        super().__init__()
        assert spatial_down == True
        # Pad order follows F.pad: (W_left, W_right, H_top, H_bottom[, T_front, T_back]).
        # The 3D variant pads no extra frames on the temporal axis.
        if cnn_type == "2d":
            self.pad = (0, 1, 0, 1)
        if cnn_type == "3d":
            self.pad = (0, 1, 0, 1, 0, 0)

        self.conv = Conv(
            in_channels,
            in_channels,
            kernel_size=3,
            stride=2,
            padding=0,
            cnn_type=cnn_type,
            temporal_down=temporal_down,
        )

    def forward(self, x: Tensor):
        padded = nn.functional.pad(x, self.pad, mode="constant", value=0)
        return self.conv(padded)
|
|
|
|
|
class Upsample(nn.Module):
    """Upsampling block: nearest-neighbor interpolation (or PixelShuffle when
    ``use_pxsl``) followed by a 3x3 conv.

    For 3D inputs the scale factor upsamples H/W by 2 and, when
    ``temporal_up``, T by 2 as well (with a causal offset of -1 passed to the
    conv).
    """

    def __init__(
        self,
        in_channels,
        cnn_type="2d",
        spatial_up=False,
        temporal_up=False,
        use_pxsl=False,
    ):
        super().__init__()
        if cnn_type == "2d":
            self.scale_factor = 2
            self.causal_offset = 0
        else:
            assert spatial_up == True
            if temporal_up:
                self.scale_factor = (2, 2, 2)
                self.causal_offset = -1
            else:
                self.scale_factor = (1, 2, 2)
                self.causal_offset = 0
        self.use_pxsl = use_pxsl
        if self.use_pxsl:
            # PixelShuffle path: conv expands channels 4x, shuffle turns them
            # into a 2x spatial upscale.
            self.conv = Conv(
                in_channels,
                in_channels * 4,
                kernel_size=3,
                stride=1,
                padding=1,
                cnn_type=cnn_type,
                causal_offset=self.causal_offset,
            )
            self.pxsl = nn.PixelShuffle(2)
        else:
            self.conv = Conv(
                in_channels,
                in_channels,
                kernel_size=3,
                stride=1,
                padding=1,
                cnn_type=cnn_type,
                causal_offset=self.causal_offset,
            )

    def forward(self, x: Tensor):
        if self.use_pxsl:
            x = self.conv(x)
            x = self.pxsl(x)
        else:
            try:
                x = F.interpolate(x, scale_factor=self.scale_factor, mode="nearest")
            # Narrowed from a bare except: torch raises RuntimeError (incl.
            # CUDA OOM) here; don't swallow KeyboardInterrupt/SystemExit.
            except RuntimeError:
                # Fallback: interpolate one channel at a time to reduce peak
                # memory, then reassemble along the channel axis.
                _xs = []
                for i in range(x.shape[1]):
                    _x = F.interpolate(
                        x[:, i : i + 1, ...],
                        scale_factor=self.scale_factor,
                        mode="nearest",
                    )
                    _xs.append(_x)
                x = torch.cat(_xs, dim=1)
            x = self.conv(x)
        return x
|
|
|
|
|
class Encoder(nn.Module):
    """Convolutional encoder mapping an image/video to a latent feature map.

    Runs ``len(ch_mult)`` resolution stages of ResnetBlocks with
    downsampling, a middle section (two ResnetBlocks with optional
    attention), and a final projection to ``z_channels`` output channels
    (doubled when ``use_vae`` — presumably mean/logvar halves; confirm
    against the caller).
    """

    def __init__(
        self,
        ch: int,
        ch_mult: list[int],
        num_res_blocks: int,
        z_channels: int,
        in_channels=3,
        patch_size=8,
        temporal_patch_size=4,
        norm_type="group",
        cnn_param=None,
        use_checkpoint=False,
        use_vae=True,
    ):
        super().__init__()
        # Number of spatial / temporal downsampling steps (log2 of patch sizes).
        self.max_down = np.log2(patch_size)
        self.temporal_max_down = np.log2(temporal_patch_size)
        # Temporal downsampling only kicks in this many levels after spatial.
        self.temporal_down_offset = self.max_down - self.temporal_max_down
        self.ch = ch
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.in_channels = in_channels
        self.cnn_param = cnn_param
        self.use_checkpoint = use_checkpoint

        # Input stem: optionally forced to 2D convolution by config.
        if cnn_param["conv_in_out_2d"] == "yes":
            self.conv_in = Conv(
                in_channels, ch, kernel_size=3, stride=1, padding=1, cnn_type="2d"
            )
        else:
            self.conv_in = Conv(
                in_channels,
                ch,
                kernel_size=3,
                stride=1,
                padding=1,
                cnn_type=cnn_param["cnn_type"],
            )

        # Per-level input multipliers: level i consumes ch*in_ch_mult[i]
        # channels and produces ch*ch_mult[i].
        in_ch_mult = (1,) + tuple(ch_mult)
        self.in_ch_mult = in_ch_mult
        self.down = nn.ModuleList()
        block_in = self.ch
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            # attn stays empty here; kept for structural parity with forward().
            attn = nn.ModuleList()
            block_in = ch * in_ch_mult[i_level]
            block_out = ch * ch_mult[i_level]
            for _ in range(self.num_res_blocks):
                block.append(
                    ResnetBlock(
                        in_channels=block_in,
                        out_channels=block_out,
                        norm_type=norm_type,
                        cnn_param=cnn_param,
                    )
                )
                block_in = block_out
            down = nn.Module()
            down.block = block
            down.attn = attn

            # Downsample spatially for the first max_down levels; temporally
            # only once past the temporal offset.
            spatial_down = True if i_level < self.max_down else False
            temporal_down = (
                True
                if i_level < self.max_down and i_level >= self.temporal_down_offset
                else False
            )
            if spatial_down or temporal_down:
                down.downsample = Downsample(
                    block_in,
                    cnn_type=cnn_param["cnn_type"],
                    spatial_down=spatial_down,
                    temporal_down=temporal_down,
                )
            self.down.append(down)

        # Middle section: ResnetBlock -> (optional) attention -> ResnetBlock.
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            norm_type=norm_type,
            cnn_param=cnn_param,
        )
        if cnn_param["cnn_attention"] == "yes":
            self.mid.attn_1 = AttnBlock(block_in, norm_type, cnn_param=cnn_param)
        self.mid.block_2 = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            norm_type=norm_type,
            cnn_param=cnn_param,
        )

        # Output head: norm -> swish -> conv to (2x if VAE) z_channels.
        self.norm_out = Normalize(
            block_in, norm_type, norm_axis=cnn_param["cnn_norm_axis"]
        )
        if cnn_param["conv_inner_2d"] == "yes":
            self.conv_out = Conv(
                block_in,
                (int(use_vae) + 1) * z_channels,
                kernel_size=3,
                stride=1,
                padding=1,
                cnn_type="2d",
            )
        else:
            self.conv_out = Conv(
                block_in,
                (int(use_vae) + 1) * z_channels,
                kernel_size=3,
                stride=1,
                padding=1,
                cnn_type=cnn_param["cnn_type"],
            )

    def forward(self, x, return_hidden=False):
        # Optionally trade compute for memory via activation checkpointing.
        if not self.use_checkpoint:
            return self._forward(x, return_hidden=return_hidden)
        else:
            return checkpoint.checkpoint(
                self._forward, x, return_hidden, use_reentrant=False
            )

    def _forward(self, x: Tensor, return_hidden=False) -> Tensor:
        """Encode ``x``; optionally also return the per-stage hidden states
        (``hs``) and middle-section states (``hs_mid``)."""

        h0 = self.conv_in(x)
        hs = [h0]
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](hs[-1])
                # attn lists are built empty in __init__, so this branch is
                # currently never taken.
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
                hs.append(h)
            if hasattr(self.down[i_level], "downsample"):
                hs.append(self.down[i_level].downsample(hs[-1]))

        # Middle section.
        h = hs[-1]
        hs_mid = [h]
        h = self.mid.block_1(h)
        if self.cnn_param["cnn_attention"] == "yes":
            h = self.mid.attn_1(h)
        h = self.mid.block_2(h)
        hs_mid.append(h)

        # Output head.
        h = self.norm_out(h)
        h = swish(h)
        h = self.conv_out(h)
        if return_hidden:
            return h, hs, hs_mid
        else:
            return h
|
|
|
|
|
class Decoder(nn.Module):
    """Convolutional decoder mapping a latent feature map back to pixels.

    Mirrors :class:`Encoder`: input conv, middle section (two ResnetBlocks
    with optional attention), then ``len(ch_mult)`` resolution stages with
    upsampling, and a final conv to ``out_ch`` channels.
    """

    def __init__(
        self,
        ch: int,
        ch_mult: list[int],
        num_res_blocks: int,
        z_channels: int,
        out_ch=3,
        patch_size=8,
        temporal_patch_size=4,
        norm_type="group",
        cnn_param=None,
        use_checkpoint=False,
        use_freq_dec=False,
        use_pxsf=False,
    ):
        super().__init__()
        # Number of spatial / temporal upsampling steps (log2 of patch sizes).
        self.max_up = np.log2(patch_size)
        self.temporal_max_up = np.log2(temporal_patch_size)
        # Temporal upsampling only applies above this level offset.
        self.temporal_up_offset = self.max_up - self.temporal_max_up
        self.ch = ch
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        # Total spatial expansion factor across resolution stages.
        self.ffactor = 2 ** (self.num_resolutions - 1)
        self.cnn_param = cnn_param
        self.use_checkpoint = use_checkpoint
        # NOTE(review): use_freq_dec is stored but not used in this class.
        self.use_freq_dec = use_freq_dec
        self.use_pxsf = use_pxsf

        # Decoding starts at the deepest (widest) channel count.
        block_in = ch * ch_mult[self.num_resolutions - 1]

        # Input conv from latent channels, optionally forced 2D by config.
        if cnn_param["conv_inner_2d"] == "yes":
            self.conv_in = Conv(
                z_channels, block_in, kernel_size=3, stride=1, padding=1, cnn_type="2d"
            )
        else:
            self.conv_in = Conv(
                z_channels,
                block_in,
                kernel_size=3,
                stride=1,
                padding=1,
                cnn_type=cnn_param["cnn_type"],
            )

        # Middle section: ResnetBlock -> (optional) attention -> ResnetBlock.
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            norm_type=norm_type,
            cnn_param=cnn_param,
        )
        if cnn_param["cnn_attention"] == "yes":
            self.mid.attn_1 = AttnBlock(
                block_in, norm_type=norm_type, cnn_param=cnn_param
            )
        self.mid.block_2 = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            norm_type=norm_type,
            cnn_param=cnn_param,
        )

        # Upsampling stages, built deepest-first and inserted at the front so
        # self.up[i_level] indexes the same level as the encoder's self.down.
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            # attn stays empty here; kept for structural parity with forward().
            attn = nn.ModuleList()
            block_out = ch * ch_mult[i_level]
            # One extra ResnetBlock per level compared to the encoder.
            for _ in range(self.num_res_blocks + 1):
                block.append(
                    ResnetBlock(
                        in_channels=block_in,
                        out_channels=block_out,
                        norm_type=norm_type,
                        cnn_param=cnn_param,
                    )
                )
                block_in = block_out
            up = nn.Module()
            up.block = block
            up.attn = attn

            # Upsample spatially for levels 1..max_up; temporally only above
            # the temporal offset. Level 0 never upsamples.
            spatial_up = True if 1 <= i_level <= self.max_up else False
            temporal_up = (
                True
                if 1 <= i_level <= self.max_up
                and i_level >= self.temporal_up_offset + 1
                else False
            )
            if spatial_up or temporal_up:
                up.upsample = Upsample(
                    block_in,
                    cnn_type=cnn_param["cnn_type"],
                    spatial_up=spatial_up,
                    temporal_up=temporal_up,
                    use_pxsl=self.use_pxsf,
                )
            self.up.insert(0, up)

        # Output head: norm -> swish -> conv to out_ch channels.
        self.norm_out = Normalize(
            block_in, norm_type, norm_axis=cnn_param["cnn_norm_axis"]
        )
        if cnn_param["conv_in_out_2d"] == "yes":
            self.conv_out = Conv(
                block_in, out_ch, kernel_size=3, stride=1, padding=1, cnn_type="2d"
            )
        else:
            self.conv_out = Conv(
                block_in,
                out_ch,
                kernel_size=3,
                stride=1,
                padding=1,
                cnn_type=cnn_param["cnn_type"],
            )

    def forward(self, z):
        # Optionally trade compute for memory via activation checkpointing.
        if not self.use_checkpoint:
            return self._forward(z)
        else:
            return checkpoint.checkpoint(self._forward, z, use_reentrant=False)

    def _forward(self, z: Tensor) -> Tensor:
        """Decode latent ``z`` to the reconstructed output tensor."""

        h = self.conv_in(z)

        # Middle section.
        h = self.mid.block_1(h)
        if self.cnn_param["cnn_attention"] == "yes":
            h = self.mid.attn_1(h)
        h = self.mid.block_2(h)

        # Upsampling stages, deepest level first.
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks + 1):
                h = self.up[i_level].block[i_block](h)
                # attn lists are built empty in __init__, so this branch is
                # currently never taken.
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h)
            if hasattr(self.up[i_level], "upsample"):
                h = self.up[i_level].upsample(h)

        # Output head.
        h = self.norm_out(h)
        h = swish(h)
        h = self.conv_out(h)
        return h
|
|
|
|
|
class AutoEncoder(nn.Module):
    """Encoder + (optional multi-scale BSQ) quantizer + decoder autoencoder.

    All configuration comes from the ``args`` namespace (see
    ``add_model_specific_args`` and the training script for the full set of
    options). When ``args.use_vae`` is False, a ``MultiScaleBSQ`` quantizer
    is built between encoder and decoder.
    """

    def __init__(self, args):
        super().__init__()
        self.args = args
        # Bundle all CNN layout switches so sub-modules share one config dict.
        cnn_param = dict(
            cnn_type=args.cnn_type,
            conv_in_out_2d=args.conv_in_out_2d,
            res_conv_2d=args.res_conv_2d,
            cnn_attention=args.cnn_attention,
            cnn_norm_axis=args.cnn_norm_axis,
            conv_inner_2d=args.conv_inner_2d,
        )
        self.encoder = Encoder(
            ch=args.base_ch,
            ch_mult=args.encoder_ch_mult,
            num_res_blocks=args.num_res_blocks,
            z_channels=args.codebook_dim,
            patch_size=args.patch_size,
            temporal_patch_size=args.temporal_patch_size,
            cnn_param=cnn_param,
            use_checkpoint=args.use_checkpoint,
            use_vae=args.use_vae,
        )
        self.decoder = Decoder(
            ch=args.base_ch,
            ch_mult=args.decoder_ch_mult,
            num_res_blocks=args.num_res_blocks,
            z_channels=args.codebook_dim,
            patch_size=args.patch_size,
            temporal_patch_size=args.temporal_patch_size,
            cnn_param=cnn_param,
            use_checkpoint=args.use_checkpoint,
            use_freq_dec=args.use_freq_dec,
            use_pxsf=args.use_pxsf,
        )
        # NOTE(review): z_drop is built but not applied anywhere in this class.
        self.z_drop = nn.Dropout(args.z_drop)
        # NOTE(review): hard-coded latent scale/shift constants; their origin
        # is not visible here — confirm against the pretrained checkpoint.
        self.scale_factor = 0.3611
        self.shift_factor = 0.1159
        self.codebook_dim = self.embed_dim = args.codebook_dim

        # Loss weights (consumed by the training loop, not used here).
        self.gan_feat_weight = args.gan_feat_weight
        self.video_perceptual_weight = args.video_perceptual_weight
        self.recon_loss_type = args.recon_loss_type
        self.l1_weight = args.l1_weight
        self.use_vae = args.use_vae
        self.kl_weight = args.kl_weight
        self.lfq_weight = args.lfq_weight
        self.image_gan_weight = args.image_gan_weight
        self.video_gan_weight = args.video_gan_weight
        self.perceptual_weight = args.perceptual_weight
        self.flux_weight = args.flux_weight
        self.cycle_weight = args.cycle_weight
        self.cycle_feat_weight = args.cycle_feat_weight
        self.cycle_gan_weight = args.cycle_gan_weight

        self.flux_image_encoder = None

        # Quantizer is only constructed in the non-VAE (discrete) setting.
        if not args.use_vae:
            if args.quantizer_type == "MultiScaleBSQ":
                self.quantizer = MultiScaleBSQ(
                    dim=args.codebook_dim,
                    codebook_size=args.codebook_size,
                    entropy_loss_weight=args.entropy_loss_weight,
                    diversity_gamma=args.diversity_gamma,
                    preserve_norm=args.preserve_norm,
                    ln_before_quant=args.ln_before_quant,
                    ln_init_by_sqrt=args.ln_init_by_sqrt,
                    commitment_loss_weight=args.commitment_loss_weight,
                    new_quant=args.new_quant,
                    use_decay_factor=args.use_decay_factor,
                    mask_out=args.mask_out,
                    use_stochastic_depth=args.use_stochastic_depth,
                    drop_rate=args.drop_rate,
                    schedule_mode=args.schedule_mode,
                    keep_first_quant=args.keep_first_quant,
                    keep_last_quant=args.keep_last_quant,
                    remove_residual_detach=args.remove_residual_detach,
                    use_out_phi=args.use_out_phi,
                    use_out_phi_res=args.use_out_phi_res,
                    random_flip=args.random_flip,
                    flip_prob=args.flip_prob,
                    flip_mode=args.flip_mode,
                    max_flip_lvl=args.max_flip_lvl,
                    random_flip_1lvl=args.random_flip_1lvl,
                    flip_lvl_idx=args.flip_lvl_idx,
                    drop_when_test=args.drop_when_test,
                    drop_lvl_idx=args.drop_lvl_idx,
                    drop_lvl_num=args.drop_lvl_num,
                )
                self.quantize = self.quantizer
                self.vocab_size = args.codebook_size
            else:
                raise NotImplementedError(f"{args.quantizer_type} not supported")

    def forward(self, x):
        """Encode, quantize, and decode ``x``; returns (recon, None, z).

        NOTE(review): assumes the quantized path — ``self.quantizer`` only
        exists when ``args.use_vae`` is False.
        """
        is_image = x.ndim == 4
        # NOTE(review): B/C/T/H/W are unpacked but not used below.
        if not is_image:
            B, C, T, H, W = x.shape
        else:
            B, C, H, W = x.shape
            T = 1
        enc_dtype = ptdtype[self.args.encoder_dtype]

        # Encode under autocast at the configured dtype; hidden states are
        # detached so gradients only flow through the final latent.
        with torch.amp.autocast("cuda", dtype=enc_dtype):
            h, hs, hs_mid = self.encoder(x, return_hidden=True)
        hs = [_h.detach() for _h in hs]
        hs_mid = [_h.detach() for _h in hs_mid]
        h = h.to(dtype=torch.float32)

        # Quantize and decode.
        z, all_indices, _, _, all_loss, _ = self.quantizer(h)
        x_recon = self.decoder(z)
        # NOTE(review): vq_output is computed but never returned or used.
        vq_output = {
            "commitment_loss": torch.mean(all_loss)
            * self.lfq_weight,
            "encodings": all_indices,
        }

        return x_recon, None, z

    def encode_for_raw_features(
        self, x, scale_schedule, return_residual_norm_per_scale=False
    ):
        """Run only the encoder, returning (h, hs, hs_mid) with hidden states
        detached. ``scale_schedule`` and the flag are accepted for interface
        parity with ``encode`` but unused here."""
        is_image = x.ndim == 4
        # NOTE(review): B/C/T/H/W are unpacked but not used below.
        if not is_image:
            B, C, T, H, W = x.shape
        else:
            B, C, H, W = x.shape
            T = 1

        enc_dtype = ptdtype[self.args.encoder_dtype]
        with torch.amp.autocast("cuda", dtype=enc_dtype):
            h, hs, hs_mid = self.encoder(x, return_hidden=True)

        hs = [_h.detach() for _h in hs]
        hs_mid = [_h.detach() for _h in hs_mid]
        h = h.to(dtype=torch.float32)
        return h, hs, hs_mid

    def encode(self, x, scale_schedule, return_residual_norm_per_scale=False):
        """Encode and quantize ``x`` over ``scale_schedule`` scales."""
        h, hs, hs_mid = self.encode_for_raw_features(
            x, scale_schedule, return_residual_norm_per_scale
        )

        (
            z,
            all_indices,
            all_bit_indices,
            residual_norm_per_scale,
            all_loss,
            var_input,
        ) = self.quantizer(
            h,
            scale_schedule=scale_schedule,
            return_residual_norm_per_scale=return_residual_norm_per_scale,
        )
        return h, z, all_indices, all_bit_indices, residual_norm_per_scale, var_input

    def decode(self, z):
        """Decode latent ``z`` and clamp the reconstruction to [-1, 1]."""
        x_recon = self.decoder(z)
        x_recon = torch.clamp(x_recon, min=-1, max=1)
        return x_recon

    def decode_from_indices(self, all_indices, scale_schedule, label_type):
        """Reconstruct from per-scale code indices.

        Codes from every scale are upsampled to the final scale and summed
        before decoding. Returns (summed_codes, x_recon).
        """
        summed_codes = 0
        for idx_Bl in all_indices:
            codes = self.quantizer.lfq.indices_to_codes(idx_Bl, label_type)
            summed_codes += F.interpolate(
                codes, size=scale_schedule[-1], mode=self.quantizer.z_interplote_up
            )
        # Expect a singleton temporal axis, squeezed away before decoding.
        assert summed_codes.shape[-3] == 1
        x_recon = self.decoder(summed_codes.squeeze(-3))
        x_recon = torch.clamp(x_recon, min=-1, max=1)
        return summed_codes, x_recon

    @staticmethod
    def add_model_specific_args(parent_parser):
        """Append this model's CLI options to ``parent_parser``."""
        parser = argparse.ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument("--flux_weight", type=float, default=0)
        parser.add_argument("--cycle_weight", type=float, default=0)
        parser.add_argument("--cycle_feat_weight", type=float, default=0)
        parser.add_argument("--cycle_gan_weight", type=float, default=0)
        parser.add_argument("--cycle_loop", type=int, default=0)
        parser.add_argument("--z_drop", type=float, default=0.0)
        return parser
|
|