jbilcke-hf's picture
jbilcke-hf HF Staff
we are going to hack into finetrainers
9fd1204
import functools
import os
from typing import Any, Dict, List, Optional, Tuple, Union
import PIL.Image
import torch
from accelerate import init_empty_weights
from diffusers import (
AutoencoderKLWan,
FlowMatchEulerDiscreteScheduler,
WanImageToVideoPipeline,
WanPipeline,
WanTransformer3DModel,
)
from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
from transformers import AutoModel, AutoTokenizer, CLIPImageProcessor, CLIPVisionModel, UMT5EncoderModel
import finetrainers.functional as FF
from finetrainers.data import VideoArtifact
from finetrainers.logging import get_logger
from finetrainers.models.modeling_utils import ModelSpecification
from finetrainers.processors import ProcessorMixin, T5Processor
from finetrainers.typing import ArtifactType, SchedulerType
from finetrainers.utils import get_non_null_items, safetensors_torch_save_function
logger = get_logger()
class WanLatentEncodeProcessor(ProcessorMixin):
r"""
Processor to encode image/video into latents using the Wan VAE.
Args:
output_names (`List[str]`):
The names of the outputs that the processor returns. The outputs are in the following order:
- latents: The latents of the input image/video.
- latents_mean: The channel-wise mean of the latent space.
- latents_std: The channel-wise standard deviation of the latent space.
"""
def __init__(self, output_names: List[str]):
super().__init__()
self.output_names = output_names
assert len(self.output_names) == 3
def forward(
self,
vae: AutoencoderKLWan,
image: Optional[torch.Tensor] = None,
video: Optional[torch.Tensor] = None,
generator: Optional[torch.Generator] = None,
compute_posterior: bool = True,
) -> Dict[str, torch.Tensor]:
device = vae.device
dtype = vae.dtype
if image is not None:
video = image.unsqueeze(1)
assert video.ndim == 5, f"Expected 5D tensor, got {video.ndim}D tensor"
video = video.to(device=device, dtype=dtype)
video = video.permute(0, 2, 1, 3, 4).contiguous() # [B, F, C, H, W] -> [B, C, F, H, W]
if compute_posterior:
latents = vae.encode(video).latent_dist.sample(generator=generator)
latents = latents.to(dtype=dtype)
else:
# TODO(aryan): refactor in diffusers to have use_slicing attribute
# if vae.use_slicing and video.shape[0] > 1:
# encoded_slices = [vae._encode(x_slice) for x_slice in video.split(1)]
# moments = torch.cat(encoded_slices)
# else:
# moments = vae._encode(video)
moments = vae._encode(video)
latents = moments.to(dtype=dtype)
latents_mean = torch.tensor(vae.config.latents_mean)
latents_std = 1.0 / torch.tensor(vae.config.latents_std)
return {self.output_names[0]: latents, self.output_names[1]: latents_mean, self.output_names[2]: latents_std}
class WanImageConditioningLatentEncodeProcessor(ProcessorMixin):
r"""
Processor to encode image/video into latents using the Wan VAE.
Args:
output_names (`List[str]`):
The names of the outputs that the processor returns. The outputs are in the following order:
- latents: The latents of the input image/video.
- latents_mean: The channel-wise mean of the latent space.
- latents_std: The channel-wise standard deviation of the latent space.
- mask: The conditioning frame mask for the input image/video.
"""
def __init__(self, output_names: List[str], *, use_last_frame: bool = False):
super().__init__()
self.output_names = output_names
self.use_last_frame = use_last_frame
assert len(self.output_names) == 4
def forward(
self,
vae: AutoencoderKLWan,
image: Optional[torch.Tensor] = None,
video: Optional[torch.Tensor] = None,
compute_posterior: bool = True,
) -> Dict[str, torch.Tensor]:
device = vae.device
dtype = vae.dtype
if image is not None:
video = image.unsqueeze(1)
assert video.ndim == 5, f"Expected 5D tensor, got {video.ndim}D tensor"
video = video.to(device=device, dtype=dtype)
video = video.permute(0, 2, 1, 3, 4).contiguous() # [B, F, C, H, W] -> [B, C, F, H, W]
num_frames = video.size(2)
if not self.use_last_frame:
first_frame, remaining_frames = video[:, :, :1], video[:, :, 1:]
video = torch.cat([first_frame, torch.zeros_like(remaining_frames)], dim=2)
else:
first_frame, remaining_frames, last_frame = video[:, :, :1], video[:, :, 1:-1], video[:, :, -1:]
video = torch.cat([first_frame, torch.zeros_like(remaining_frames), last_frame], dim=2)
# Image conditioning uses argmax sampling, so we use "mode" here
if compute_posterior:
latents = vae.encode(video).latent_dist.mode()
latents = latents.to(dtype=dtype)
else:
# TODO(aryan): refactor in diffusers to have use_slicing attribute
# if vae.use_slicing and video.shape[0] > 1:
# encoded_slices = [vae._encode(x_slice) for x_slice in video.split(1)]
# moments = torch.cat(encoded_slices)
# else:
# moments = vae._encode(video)
moments = vae._encode(video)
latents = moments.to(dtype=dtype)
latents_mean = torch.tensor(vae.config.latents_mean)
latents_std = 1.0 / torch.tensor(vae.config.latents_std)
temporal_downsample = 2 ** sum(vae.temperal_downsample) if getattr(self, "vae", None) else 4
mask = latents.new_ones(latents.shape[0], 1, num_frames, latents.shape[3], latents.shape[4])
if not self.use_last_frame:
mask[:, :, 1:] = 0
else:
mask[:, :, 1:-1] = 0
first_frame_mask = mask[:, :, :1]
first_frame_mask = torch.repeat_interleave(first_frame_mask, dim=2, repeats=temporal_downsample)
mask = torch.cat([first_frame_mask, mask[:, :, 1:]], dim=2)
mask = mask.view(latents.shape[0], -1, temporal_downsample, latents.shape[3], latents.shape[4])
mask = mask.transpose(1, 2)
return {
self.output_names[0]: latents,
self.output_names[1]: latents_mean,
self.output_names[2]: latents_std,
self.output_names[3]: mask,
}
class WanImageEncodeProcessor(ProcessorMixin):
r"""
Processor to encoding image conditioning for Wan I2V training.
Args:
output_names (`List[str]`):
The names of the outputs that the processor returns. The outputs are in the following order:
- image_embeds: The CLIP vision model image embeddings of the input image.
"""
def __init__(self, output_names: List[str], *, use_last_frame: bool = False):
super().__init__()
self.output_names = output_names
self.use_last_frame = use_last_frame
assert len(self.output_names) == 1
def forward(
self,
image_encoder: CLIPVisionModel,
image_processor: CLIPImageProcessor,
image: Optional[torch.Tensor] = None,
video: Optional[torch.Tensor] = None,
) -> Dict[str, torch.Tensor]:
device = image_encoder.device
dtype = image_encoder.dtype
last_image = None
# We know the image here is in the range [-1, 1] (probably a little overshot if using bilinear interpolation), but
# the processor expects it to be in the range [0, 1].
image = image if video is None else video[:, 0] # [B, F, C, H, W] -> [B, C, H, W] (take first frame)
image = FF.normalize(image, min=0.0, max=1.0, dim=1)
assert image.ndim == 4, f"Expected 4D tensor, got {image.ndim}D tensor"
if self.use_last_frame:
last_image = image if video is None else video[:, -1]
last_image = FF.normalize(last_image, min=0.0, max=1.0, dim=1)
image = torch.stack([image, last_image], dim=0)
image = image_processor(images=image.float(), do_rescale=False, do_convert_rgb=False, return_tensors="pt")
image = image.to(device=device, dtype=dtype)
image_embeds = image_encoder(**image, output_hidden_states=True)
image_embeds = image_embeds.hidden_states[-2]
return {self.output_names[0]: image_embeds}
class WanModelSpecification(ModelSpecification):
def __init__(
self,
pretrained_model_name_or_path: str = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers",
tokenizer_id: Optional[str] = None,
text_encoder_id: Optional[str] = None,
transformer_id: Optional[str] = None,
vae_id: Optional[str] = None,
text_encoder_dtype: torch.dtype = torch.bfloat16,
transformer_dtype: torch.dtype = torch.bfloat16,
vae_dtype: torch.dtype = torch.bfloat16,
revision: Optional[str] = None,
cache_dir: Optional[str] = None,
condition_model_processors: List[ProcessorMixin] = None,
latent_model_processors: List[ProcessorMixin] = None,
**kwargs,
) -> None:
super().__init__(
pretrained_model_name_or_path=pretrained_model_name_or_path,
tokenizer_id=tokenizer_id,
text_encoder_id=text_encoder_id,
transformer_id=transformer_id,
vae_id=vae_id,
text_encoder_dtype=text_encoder_dtype,
transformer_dtype=transformer_dtype,
vae_dtype=vae_dtype,
revision=revision,
cache_dir=cache_dir,
)
use_last_frame = self.transformer_config.get("pos_embed_seq_len", None) is not None
if condition_model_processors is None:
condition_model_processors = [T5Processor(["encoder_hidden_states", "__drop__"])]
if latent_model_processors is None:
latent_model_processors = [WanLatentEncodeProcessor(["latents", "latents_mean", "latents_std"])]
if self.transformer_config.get("image_dim", None) is not None:
latent_model_processors.append(
WanImageConditioningLatentEncodeProcessor(
["latent_condition", "__drop__", "__drop__", "latent_condition_mask"],
use_last_frame=use_last_frame,
)
)
latent_model_processors.append(
WanImageEncodeProcessor(["encoder_hidden_states_image"], use_last_frame=use_last_frame)
)
self.condition_model_processors = condition_model_processors
self.latent_model_processors = latent_model_processors
@property
def _resolution_dim_keys(self):
return {"latents": (2, 3, 4)}
def load_condition_models(self) -> Dict[str, torch.nn.Module]:
common_kwargs = {"revision": self.revision, "cache_dir": self.cache_dir}
if self.tokenizer_id is not None:
tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_id, **common_kwargs)
else:
tokenizer = AutoTokenizer.from_pretrained(
self.pretrained_model_name_or_path, subfolder="tokenizer", **common_kwargs
)
if self.text_encoder_id is not None:
text_encoder = AutoModel.from_pretrained(
self.text_encoder_id, torch_dtype=self.text_encoder_dtype, **common_kwargs
)
else:
text_encoder = UMT5EncoderModel.from_pretrained(
self.pretrained_model_name_or_path,
subfolder="text_encoder",
torch_dtype=self.text_encoder_dtype,
**common_kwargs,
)
return {"tokenizer": tokenizer, "text_encoder": text_encoder}
def load_latent_models(self) -> Dict[str, torch.nn.Module]:
common_kwargs = {"revision": self.revision, "cache_dir": self.cache_dir}
if self.vae_id is not None:
vae = AutoencoderKLWan.from_pretrained(self.vae_id, torch_dtype=self.vae_dtype, **common_kwargs)
else:
vae = AutoencoderKLWan.from_pretrained(
self.pretrained_model_name_or_path, subfolder="vae", torch_dtype=self.vae_dtype, **common_kwargs
)
models = {"vae": vae}
if self.transformer_config.get("image_dim", None) is not None:
# TODO(aryan): refactor the trainer to be able to support these extra models from CLI args more easily
image_encoder = CLIPVisionModel.from_pretrained(
self.pretrained_model_name_or_path, subfolder="image_encoder", torch_dtype=torch.bfloat16
)
image_processor = CLIPImageProcessor.from_pretrained(
self.pretrained_model_name_or_path, subfolder="image_processor"
)
models["image_encoder"] = image_encoder
models["image_processor"] = image_processor
return models
def load_diffusion_models(self) -> Dict[str, torch.nn.Module]:
common_kwargs = {"revision": self.revision, "cache_dir": self.cache_dir}
if self.transformer_id is not None:
transformer = WanTransformer3DModel.from_pretrained(
self.transformer_id, torch_dtype=self.transformer_dtype, **common_kwargs
)
else:
transformer = WanTransformer3DModel.from_pretrained(
self.pretrained_model_name_or_path,
subfolder="transformer",
torch_dtype=self.transformer_dtype,
**common_kwargs,
)
scheduler = FlowMatchEulerDiscreteScheduler()
return {"transformer": transformer, "scheduler": scheduler}
def load_pipeline(
self,
tokenizer: Optional[AutoTokenizer] = None,
text_encoder: Optional[UMT5EncoderModel] = None,
transformer: Optional[WanTransformer3DModel] = None,
vae: Optional[AutoencoderKLWan] = None,
scheduler: Optional[FlowMatchEulerDiscreteScheduler] = None,
image_encoder: Optional[CLIPVisionModel] = None,
image_processor: Optional[CLIPImageProcessor] = None,
enable_slicing: bool = False,
enable_tiling: bool = False,
enable_model_cpu_offload: bool = False,
training: bool = False,
**kwargs,
) -> Union[WanPipeline, WanImageToVideoPipeline]:
components = {
"tokenizer": tokenizer,
"text_encoder": text_encoder,
"transformer": transformer,
"vae": vae,
"scheduler": scheduler,
"image_encoder": image_encoder,
"image_processor": image_processor,
}
components = get_non_null_items(components)
if self.transformer_config.get("image_dim", None) is not None:
pipe = WanPipeline.from_pretrained(
self.pretrained_model_name_or_path, **components, revision=self.revision, cache_dir=self.cache_dir
)
else:
pipe = WanImageToVideoPipeline.from_pretrained(
self.pretrained_model_name_or_path, **components, revision=self.revision, cache_dir=self.cache_dir
)
pipe.text_encoder.to(self.text_encoder_dtype)
pipe.vae.to(self.vae_dtype)
if not training:
pipe.transformer.to(self.transformer_dtype)
# TODO(aryan): add support in diffusers
# if enable_slicing:
# pipe.vae.enable_slicing()
# if enable_tiling:
# pipe.vae.enable_tiling()
if enable_model_cpu_offload:
pipe.enable_model_cpu_offload()
return pipe
@torch.no_grad()
def prepare_conditions(
self,
tokenizer: AutoTokenizer,
text_encoder: UMT5EncoderModel,
caption: str,
max_sequence_length: int = 512,
**kwargs,
) -> Dict[str, Any]:
conditions = {
"tokenizer": tokenizer,
"text_encoder": text_encoder,
"caption": caption,
"max_sequence_length": max_sequence_length,
**kwargs,
}
input_keys = set(conditions.keys())
conditions = super().prepare_conditions(**conditions)
conditions = {k: v for k, v in conditions.items() if k not in input_keys}
return conditions
@torch.no_grad()
def prepare_latents(
self,
vae: AutoencoderKLWan,
image_encoder: Optional[CLIPVisionModel] = None,
image_processor: Optional[CLIPImageProcessor] = None,
image: Optional[torch.Tensor] = None,
video: Optional[torch.Tensor] = None,
generator: Optional[torch.Generator] = None,
compute_posterior: bool = True,
**kwargs,
) -> Dict[str, torch.Tensor]:
conditions = {
"vae": vae,
"image_encoder": image_encoder,
"image_processor": image_processor,
"image": image,
"video": video,
"generator": generator,
# We must force this to False because the latent normalization should be done before
# the posterior is computed. The VAE does not handle this any more:
# https://github.com/huggingface/diffusers/pull/10998
"compute_posterior": False,
**kwargs,
}
input_keys = set(conditions.keys())
conditions = super().prepare_latents(**conditions)
conditions = {k: v for k, v in conditions.items() if k not in input_keys}
return conditions
def forward(
self,
transformer: WanTransformer3DModel,
condition_model_conditions: Dict[str, torch.Tensor],
latent_model_conditions: Dict[str, torch.Tensor],
sigmas: torch.Tensor,
generator: Optional[torch.Generator] = None,
compute_posterior: bool = True,
**kwargs,
) -> Tuple[torch.Tensor, ...]:
compute_posterior = False # See explanation in prepare_latents
latent_condition = latent_condition_mask = None
if compute_posterior:
latents = latent_model_conditions.pop("latents")
latent_condition = latent_model_conditions.pop("latent_condition", None)
latent_condition_mask = latent_model_conditions.pop("latent_condition_mask", None)
else:
latents = latent_model_conditions.pop("latents")
latents_mean = latent_model_conditions.pop("latents_mean")
latents_std = latent_model_conditions.pop("latents_std")
latent_condition = latent_model_conditions.pop("latent_condition", None)
latent_condition_mask = latent_model_conditions.pop("latent_condition_mask", None)
mu, logvar = torch.chunk(latents, 2, dim=1)
mu = self._normalize_latents(mu, latents_mean, latents_std)
logvar = self._normalize_latents(logvar, latents_mean, latents_std)
latents = torch.cat([mu, logvar], dim=1)
posterior = DiagonalGaussianDistribution(latents)
latents = posterior.sample(generator=generator)
if latent_condition is not None:
mu, logvar = torch.chunk(latent_condition, 2, dim=1)
mu = self._normalize_latents(mu, latents_mean, latents_std)
logvar = self._normalize_latents(logvar, latents_mean, latents_std)
latent_condition = torch.cat([mu, logvar], dim=1)
posterior = DiagonalGaussianDistribution(latent_condition)
latent_condition = posterior.mode()
del posterior
noise = torch.zeros_like(latents).normal_(generator=generator)
noisy_latents = FF.flow_match_xt(latents, noise, sigmas)
timesteps = (sigmas.flatten() * 1000.0).long()
if self.transformer_config.get("image_dim", None) is not None:
noisy_latents = torch.cat([noisy_latents, latent_condition_mask, latent_condition], dim=1)
latent_model_conditions["hidden_states"] = noisy_latents.to(latents)
pred = transformer(
**latent_model_conditions,
**condition_model_conditions,
timestep=timesteps,
return_dict=False,
)[0]
target = FF.flow_match_target(noise, latents)
return pred, target, sigmas
def validation(
self,
pipeline: Union[WanPipeline, WanImageToVideoPipeline],
prompt: str,
image: Optional[PIL.Image.Image] = None,
last_image: Optional[PIL.Image.Image] = None,
video: Optional[List[PIL.Image.Image]] = None,
height: Optional[int] = None,
width: Optional[int] = None,
num_frames: Optional[int] = None,
num_inference_steps: int = 50,
generator: Optional[torch.Generator] = None,
**kwargs,
) -> List[ArtifactType]:
generation_kwargs = {
"prompt": prompt,
"height": height,
"width": width,
"num_frames": num_frames,
"num_inference_steps": num_inference_steps,
"generator": generator,
"return_dict": True,
"output_type": "pil",
}
if self.transformer_config.get("image_dim", None) is not None:
if image is None and video is None:
raise ValueError("Either image or video must be provided for Wan I2V validation.")
image = image if image is not None else video[0]
generation_kwargs["image"] = image
if self.transformer_config.get("pos_embed_seq_len", None) is not None:
last_image = last_image if last_image is not None else image if video is None else video[-1]
generation_kwargs["last_image"] = last_image
generation_kwargs = get_non_null_items(generation_kwargs)
video = pipeline(**generation_kwargs).frames[0]
return [VideoArtifact(value=video)]
def _save_lora_weights(
self,
directory: str,
transformer_state_dict: Optional[Dict[str, torch.Tensor]] = None,
scheduler: Optional[SchedulerType] = None,
metadata: Optional[Dict[str, str]] = None,
*args,
**kwargs,
) -> None:
pipeline_cls = (
WanImageToVideoPipeline if self.transformer_config.get("image_dim", None) is not None else WanPipeline
)
# TODO(aryan): this needs refactoring
if transformer_state_dict is not None:
pipeline_cls.save_lora_weights(
directory,
transformer_state_dict,
save_function=functools.partial(safetensors_torch_save_function, metadata=metadata),
safe_serialization=True,
)
if scheduler is not None:
scheduler.save_pretrained(os.path.join(directory, "scheduler"))
def _save_model(
self,
directory: str,
transformer: WanTransformer3DModel,
transformer_state_dict: Optional[Dict[str, torch.Tensor]] = None,
scheduler: Optional[SchedulerType] = None,
) -> None:
# TODO(aryan): this needs refactoring
if transformer_state_dict is not None:
with init_empty_weights():
transformer_copy = WanTransformer3DModel.from_config(transformer.config)
transformer_copy.load_state_dict(transformer_state_dict, strict=True, assign=True)
transformer_copy.save_pretrained(os.path.join(directory, "transformer"))
if scheduler is not None:
scheduler.save_pretrained(os.path.join(directory, "scheduler"))
@staticmethod
def _normalize_latents(
latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor
) -> torch.Tensor:
latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(device=latents.device)
latents_std = latents_std.view(1, -1, 1, 1, 1).to(device=latents.device)
latents = ((latents.float() - latents_mean) * latents_std).to(latents)
return latents