from typing import Tuple

import torch
from diffusers import AutoencoderKL
from einops import rearrange
from torch import Tensor

from ltx_video.models.autoencoders.causal_video_autoencoder import (
    CausalVideoAutoencoder,
)
from ltx_video.models.autoencoders.video_autoencoder import (
    Downsample3D,
    VideoAutoencoder,
)

try:
    import torch_xla.core.xla_model as xm
except ImportError:
    xm = None


def vae_encode(
    media_items: Tensor,
    vae: AutoencoderKL,
    split_size: int = 1,
    vae_per_channel_normalize=False,
) -> Tensor:
    """
    Encodes media items (images or videos) into latent representations using a specified VAE model.
    The function supports processing batches of images or video frames and can handle the processing
    in smaller sub-batches if needed.

    Args:
        media_items (Tensor): A torch Tensor containing the media items to encode. The expected
            shape is (batch_size, channels, height, width) for images or (batch_size, channels,
            frames, height, width) for videos.
        vae (AutoencoderKL): An instance of the `AutoencoderKL` class from the `diffusers` library,
            pre-configured and loaded with the appropriate model weights.
        split_size (int, optional): The number of sub-batches to split the input batch into for encoding.
            If set to more than 1, the input media items are processed in smaller batches according to
            this value. Defaults to 1, which processes all items in a single batch.
        vae_per_channel_normalize (bool, optional): If True, normalizes the latents per channel using
            the VAE's per-channel statistics instead of the global scaling factor. Defaults to False.

    Returns:
        Tensor: A torch Tensor of the encoded latent representations. The shape of the tensor is adjusted
        to match the input shape, scaled by the model's configuration.

    Examples:
        >>> import torch
        >>> from diffusers import AutoencoderKL
        >>> vae = AutoencoderKL.from_pretrained('your-model-name')
        >>> media_items = torch.rand(10, 3, 8, 256, 256)  # Example tensor with 10 videos of 8 frames.
        >>> latents = vae_encode(media_items, vae)
        >>> print(latents.shape)  # Output shape will depend on the model's latent configuration.

    Note:
        In the case of video input to a non-video VAE, the function encodes the media item
        frame by frame.
    """
    is_video_shaped = media_items.dim() == 5
    batch_size, channels = media_items.shape[0:2]

    if channels != 3:
        raise ValueError(f"Expects tensors with 3 channels, got {channels}.")

    if is_video_shaped and not isinstance(
        vae, (VideoAutoencoder, CausalVideoAutoencoder)
    ):
        media_items = rearrange(media_items, "b c n h w -> (b n) c h w")
    if split_size > 1:
        if len(media_items) % split_size != 0:
            raise ValueError(
                "Error: The batch size must be divisible by 'train.vae_bs_split'"
            )
        encode_bs = len(media_items) // split_size

        latents = []
        if media_items.device.type == "xla":
            xm.mark_step()
        for image_batch in media_items.split(encode_bs):
            latents.append(vae.encode(image_batch).latent_dist.sample())
            if media_items.device.type == "xla":
                xm.mark_step()
        latents = torch.cat(latents, dim=0)
    else:
        latents = vae.encode(media_items).latent_dist.sample()

    latents = normalize_latents(latents, vae, vae_per_channel_normalize)
    if is_video_shaped and not isinstance(
        vae, (VideoAutoencoder, CausalVideoAutoencoder)
    ):
        latents = rearrange(latents, "(b n) c h w -> b c n h w", b=batch_size)
    return latents


def vae_decode(
    latents: Tensor,
    vae: AutoencoderKL,
    is_video: bool = True,
    split_size: int = 1,
    vae_per_channel_normalize=False,
    timestep=None,
) -> Tensor:
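    """
    Decodes latent representations back into pixel-space media items (images or videos) using the
    given VAE model. When `split_size` > 1, the batch is decoded in sub-batches to reduce peak
    memory usage. Latents are un-normalized via `un_normalize_latents` before decoding.

    Example (illustrative; 'your-model-name' is a placeholder checkpoint):
        >>> import torch
        >>> from diffusers import AutoencoderKL
        >>> vae = AutoencoderKL.from_pretrained('your-model-name')
        >>> latents = vae_encode(torch.rand(2, 3, 8, 256, 256), vae)
        >>> media = vae_decode(latents, vae, is_video=True)
    """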
    is_video_shaped = latents.dim() == 5
    batch_size = latents.shape[0]

    if is_video_shaped and not isinstance(
        vae, (VideoAutoencoder, CausalVideoAutoencoder)
    ):
        latents = rearrange(latents, "b c n h w -> (b n) c h w")
    if split_size > 1:
        if len(latents) % split_size != 0:
            raise ValueError(
                "Error: The batch size must be divisible by 'train.vae_bs_split'"
            )
        encode_bs = len(latents) // split_size
        image_batch = [
            _run_decoder(
                latent_batch, vae, is_video, vae_per_channel_normalize, timestep
            )
            for latent_batch in latents.split(encode_bs)
        ]
        images = torch.cat(image_batch, dim=0)
    else:
        images = _run_decoder(
            latents, vae, is_video, vae_per_channel_normalize, timestep
        )

    if is_video_shaped and not isinstance(
        vae, (VideoAutoencoder, CausalVideoAutoencoder)
    ):
        images = rearrange(images, "(b n) c h w -> b c n h w", b=batch_size)
    return images


def _run_decoder(
    latents: Tensor,
    vae: AutoencoderKL,
    is_video: bool,
    vae_per_channel_normalize=False,
    timestep=None,
) -> Tensor:
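    """
    Decodes a single batch of latents with the given VAE. For `VideoAutoencoder` /
    `CausalVideoAutoencoder` instances, the pixel-space `target_shape` is computed from the latent
    shape and the VAE's scale factors (and a `timestep` is forwarded when provided); other VAEs
    are called with the un-normalized latents only.
    """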
    if isinstance(vae, (VideoAutoencoder, CausalVideoAutoencoder)):
        *_, fl, hl, wl = latents.shape
        temporal_scale, spatial_scale, _ = get_vae_size_scale_factor(vae)
        latents = latents.to(vae.dtype)
        vae_decode_kwargs = {}
        if timestep is not None:
            vae_decode_kwargs["timestep"] = timestep
        image = vae.decode(
            un_normalize_latents(latents, vae, vae_per_channel_normalize),
            return_dict=False,
            target_shape=(
                1,
                3,
                fl * temporal_scale if is_video else 1,
                hl * spatial_scale,
                wl * spatial_scale,
            ),
            **vae_decode_kwargs,
        )[0]
    else:
        image = vae.decode(
            un_normalize_latents(latents, vae, vae_per_channel_normalize),
            return_dict=False,
        )[0]
    return image


def get_vae_size_scale_factor(vae: AutoencoderKL) -> Tuple[float, float, float]:
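    """
    Returns the VAE's downscale factors as a (temporal, spatial, spatial) tuple. For
    `CausalVideoAutoencoder` these come from the model's attributes; otherwise they are derived
    from the patch size and the number of `Downsample3D` blocks in the encoder.
    """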
    if isinstance(vae, CausalVideoAutoencoder):
        spatial = vae.spatial_downscale_factor
        temporal = vae.temporal_downscale_factor
    else:
        down_blocks = len(
            [
                block
                for block in vae.encoder.down_blocks
                if isinstance(block.downsample, Downsample3D)
            ]
        )
        spatial = vae.config.patch_size * 2**down_blocks
        temporal = (
            vae.config.patch_size_t * 2**down_blocks
            if isinstance(vae, VideoAutoencoder)
            else 1
        )

    return (temporal, spatial, spatial)


def latent_to_pixel_coords(
    latent_coords: Tensor, vae: AutoencoderKL, causal_fix: bool = False
) -> Tensor:
    """
    Converts latent coordinates to pixel coordinates by scaling them according to the VAE's
    configuration.

    Args:
        latent_coords (Tensor): A tensor of shape [batch_size, 3, num_latents]
            containing the latent corner coordinates of each token.
        vae (AutoencoderKL): The VAE model.
        causal_fix (bool): Whether to take into account the different temporal scale
            of the first frame. Default = False for backwards compatibility.

    Returns:
        Tensor: A tensor of pixel coordinates corresponding to the input latent coordinates.
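
    Example (illustrative, assuming a VAE whose (temporal, spatial, spatial) factors are (8, 32, 32)):
        A latent temporal coordinate of 2 maps to pixel frame 2 * 8 = 16; with `causal_fix=True`
        (and a causal VAE) it becomes max(0, 16 + 1 - 8) = 9. A latent spatial coordinate of 3
        maps to pixel coordinate 3 * 32 = 96.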
"""
|
|
|
|
scale_factors = get_vae_size_scale_factor(vae)
|
|
causal_fix = isinstance(vae, CausalVideoAutoencoder) and causal_fix
|
|
pixel_coords = latent_to_pixel_coords_from_factors(
|
|
latent_coords, scale_factors, causal_fix
|
|
)
|
|
return pixel_coords
|
|
|
|
|
|
def latent_to_pixel_coords_from_factors(
    latent_coords: Tensor, scale_factors: Tuple, causal_fix: bool = False
) -> Tensor:
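    """
    Scales latent coordinates by the given (temporal, spatial, spatial) factors, optionally
    applying the causal first-frame correction to the temporal coordinate.
    """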
    pixel_coords = (
        latent_coords
        * torch.tensor(scale_factors, device=latent_coords.device)[None, :, None]
    )
    if causal_fix:
        # Account for the causal VAE's first frame, which has a temporal scale of 1
        # rather than scale_factors[0]: shift the temporal coordinate and clamp at 0.
        pixel_coords[:, 0] = (pixel_coords[:, 0] + 1 - scale_factors[0]).clamp(min=0)
    return pixel_coords


def normalize_latents(
    latents: Tensor, vae: AutoencoderKL, vae_per_channel_normalize: bool = False
) -> Tensor:
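    """
    Normalizes latents either per channel, using the VAE's `mean_of_means` / `std_of_means`
    statistics, or globally, by multiplying with `vae.config.scaling_factor`.
    """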
    return (
        (latents - vae.mean_of_means.to(latents.dtype).to(latents.device).view(1, -1, 1, 1, 1))
        / vae.std_of_means.to(latents.dtype).to(latents.device).view(1, -1, 1, 1, 1)
        if vae_per_channel_normalize
        else latents * vae.config.scaling_factor
    )


def un_normalize_latents(
    latents: Tensor, vae: AutoencoderKL, vae_per_channel_normalize: bool = False
) -> Tensor:
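    """
    Inverse of `normalize_latents`: maps normalized latents back to the VAE's latent space.
    """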
    return (
        latents * vae.std_of_means.to(latents.dtype).to(latents.device).view(1, -1, 1, 1, 1)
        + vae.mean_of_means.to(latents.dtype).to(latents.device).view(1, -1, 1, 1, 1)
        if vae_per_channel_normalize
        else latents / vae.config.scaling_factor
    )