import functools import os from typing import Any, Dict, List, Optional, Tuple, Union import PIL.Image import torch from accelerate import init_empty_weights from diffusers import ( AutoencoderKLWan, FlowMatchEulerDiscreteScheduler, WanImageToVideoPipeline, WanPipeline, WanTransformer3DModel, ) from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution from transformers import AutoModel, AutoTokenizer, CLIPImageProcessor, CLIPVisionModel, UMT5EncoderModel import finetrainers.functional as FF from finetrainers.data import VideoArtifact from finetrainers.logging import get_logger from finetrainers.models.modeling_utils import ModelSpecification from finetrainers.processors import ProcessorMixin, T5Processor from finetrainers.typing import ArtifactType, SchedulerType from finetrainers.utils import get_non_null_items, safetensors_torch_save_function logger = get_logger() class WanLatentEncodeProcessor(ProcessorMixin): r""" Processor to encode image/video into latents using the Wan VAE. Args: output_names (`List[str]`): The names of the outputs that the processor returns. The outputs are in the following order: - latents: The latents of the input image/video. - latents_mean: The channel-wise mean of the latent space. - latents_std: The channel-wise standard deviation of the latent space. """ def __init__(self, output_names: List[str]): super().__init__() self.output_names = output_names assert len(self.output_names) == 3 def forward( self, vae: AutoencoderKLWan, image: Optional[torch.Tensor] = None, video: Optional[torch.Tensor] = None, generator: Optional[torch.Generator] = None, compute_posterior: bool = True, ) -> Dict[str, torch.Tensor]: device = vae.device dtype = vae.dtype if image is not None: video = image.unsqueeze(1) assert video.ndim == 5, f"Expected 5D tensor, got {video.ndim}D tensor" video = video.to(device=device, dtype=dtype) video = video.permute(0, 2, 1, 3, 4).contiguous() # [B, F, C, H, W] -> [B, C, F, H, W] if compute_posterior: latents = vae.encode(video).latent_dist.sample(generator=generator) latents = latents.to(dtype=dtype) else: # TODO(aryan): refactor in diffusers to have use_slicing attribute # if vae.use_slicing and video.shape[0] > 1: # encoded_slices = [vae._encode(x_slice) for x_slice in video.split(1)] # moments = torch.cat(encoded_slices) # else: # moments = vae._encode(video) moments = vae._encode(video) latents = moments.to(dtype=dtype) latents_mean = torch.tensor(vae.config.latents_mean) latents_std = 1.0 / torch.tensor(vae.config.latents_std) return {self.output_names[0]: latents, self.output_names[1]: latents_mean, self.output_names[2]: latents_std} class WanImageConditioningLatentEncodeProcessor(ProcessorMixin): r""" Processor to encode image/video into latents using the Wan VAE. Args: output_names (`List[str]`): The names of the outputs that the processor returns. The outputs are in the following order: - latents: The latents of the input image/video. - latents_mean: The channel-wise mean of the latent space. - latents_std: The channel-wise standard deviation of the latent space. - mask: The conditioning frame mask for the input image/video. """ def __init__(self, output_names: List[str], *, use_last_frame: bool = False): super().__init__() self.output_names = output_names self.use_last_frame = use_last_frame assert len(self.output_names) == 4 def forward( self, vae: AutoencoderKLWan, image: Optional[torch.Tensor] = None, video: Optional[torch.Tensor] = None, compute_posterior: bool = True, ) -> Dict[str, torch.Tensor]: device = vae.device dtype = vae.dtype if image is not None: video = image.unsqueeze(1) assert video.ndim == 5, f"Expected 5D tensor, got {video.ndim}D tensor" video = video.to(device=device, dtype=dtype) video = video.permute(0, 2, 1, 3, 4).contiguous() # [B, F, C, H, W] -> [B, C, F, H, W] num_frames = video.size(2) if not self.use_last_frame: first_frame, remaining_frames = video[:, :, :1], video[:, :, 1:] video = torch.cat([first_frame, torch.zeros_like(remaining_frames)], dim=2) else: first_frame, remaining_frames, last_frame = video[:, :, :1], video[:, :, 1:-1], video[:, :, -1:] video = torch.cat([first_frame, torch.zeros_like(remaining_frames), last_frame], dim=2) # Image conditioning uses argmax sampling, so we use "mode" here if compute_posterior: latents = vae.encode(video).latent_dist.mode() latents = latents.to(dtype=dtype) else: # TODO(aryan): refactor in diffusers to have use_slicing attribute # if vae.use_slicing and video.shape[0] > 1: # encoded_slices = [vae._encode(x_slice) for x_slice in video.split(1)] # moments = torch.cat(encoded_slices) # else: # moments = vae._encode(video) moments = vae._encode(video) latents = moments.to(dtype=dtype) latents_mean = torch.tensor(vae.config.latents_mean) latents_std = 1.0 / torch.tensor(vae.config.latents_std) temporal_downsample = 2 ** sum(vae.temperal_downsample) if getattr(self, "vae", None) else 4 mask = latents.new_ones(latents.shape[0], 1, num_frames, latents.shape[3], latents.shape[4]) if not self.use_last_frame: mask[:, :, 1:] = 0 else: mask[:, :, 1:-1] = 0 first_frame_mask = mask[:, :, :1] first_frame_mask = torch.repeat_interleave(first_frame_mask, dim=2, repeats=temporal_downsample) mask = torch.cat([first_frame_mask, mask[:, :, 1:]], dim=2) mask = mask.view(latents.shape[0], -1, temporal_downsample, latents.shape[3], latents.shape[4]) mask = mask.transpose(1, 2) return { self.output_names[0]: latents, self.output_names[1]: latents_mean, self.output_names[2]: latents_std, self.output_names[3]: mask, } class WanImageEncodeProcessor(ProcessorMixin): r""" Processor to encoding image conditioning for Wan I2V training. Args: output_names (`List[str]`): The names of the outputs that the processor returns. The outputs are in the following order: - image_embeds: The CLIP vision model image embeddings of the input image. """ def __init__(self, output_names: List[str], *, use_last_frame: bool = False): super().__init__() self.output_names = output_names self.use_last_frame = use_last_frame assert len(self.output_names) == 1 def forward( self, image_encoder: CLIPVisionModel, image_processor: CLIPImageProcessor, image: Optional[torch.Tensor] = None, video: Optional[torch.Tensor] = None, ) -> Dict[str, torch.Tensor]: device = image_encoder.device dtype = image_encoder.dtype last_image = None # We know the image here is in the range [-1, 1] (probably a little overshot if using bilinear interpolation), but # the processor expects it to be in the range [0, 1]. image = image if video is None else video[:, 0] # [B, F, C, H, W] -> [B, C, H, W] (take first frame) image = FF.normalize(image, min=0.0, max=1.0, dim=1) assert image.ndim == 4, f"Expected 4D tensor, got {image.ndim}D tensor" if self.use_last_frame: last_image = image if video is None else video[:, -1] last_image = FF.normalize(last_image, min=0.0, max=1.0, dim=1) image = torch.stack([image, last_image], dim=0) image = image_processor(images=image.float(), do_rescale=False, do_convert_rgb=False, return_tensors="pt") image = image.to(device=device, dtype=dtype) image_embeds = image_encoder(**image, output_hidden_states=True) image_embeds = image_embeds.hidden_states[-2] return {self.output_names[0]: image_embeds} class WanModelSpecification(ModelSpecification): def __init__( self, pretrained_model_name_or_path: str = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers", tokenizer_id: Optional[str] = None, text_encoder_id: Optional[str] = None, transformer_id: Optional[str] = None, vae_id: Optional[str] = None, text_encoder_dtype: torch.dtype = torch.bfloat16, transformer_dtype: torch.dtype = torch.bfloat16, vae_dtype: torch.dtype = torch.bfloat16, revision: Optional[str] = None, cache_dir: Optional[str] = None, condition_model_processors: List[ProcessorMixin] = None, latent_model_processors: List[ProcessorMixin] = None, **kwargs, ) -> None: super().__init__( pretrained_model_name_or_path=pretrained_model_name_or_path, tokenizer_id=tokenizer_id, text_encoder_id=text_encoder_id, transformer_id=transformer_id, vae_id=vae_id, text_encoder_dtype=text_encoder_dtype, transformer_dtype=transformer_dtype, vae_dtype=vae_dtype, revision=revision, cache_dir=cache_dir, ) use_last_frame = self.transformer_config.get("pos_embed_seq_len", None) is not None if condition_model_processors is None: condition_model_processors = [T5Processor(["encoder_hidden_states", "__drop__"])] if latent_model_processors is None: latent_model_processors = [WanLatentEncodeProcessor(["latents", "latents_mean", "latents_std"])] if self.transformer_config.get("image_dim", None) is not None: latent_model_processors.append( WanImageConditioningLatentEncodeProcessor( ["latent_condition", "__drop__", "__drop__", "latent_condition_mask"], use_last_frame=use_last_frame, ) ) latent_model_processors.append( WanImageEncodeProcessor(["encoder_hidden_states_image"], use_last_frame=use_last_frame) ) self.condition_model_processors = condition_model_processors self.latent_model_processors = latent_model_processors @property def _resolution_dim_keys(self): return {"latents": (2, 3, 4)} def load_condition_models(self) -> Dict[str, torch.nn.Module]: common_kwargs = {"revision": self.revision, "cache_dir": self.cache_dir} if self.tokenizer_id is not None: tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_id, **common_kwargs) else: tokenizer = AutoTokenizer.from_pretrained( self.pretrained_model_name_or_path, subfolder="tokenizer", **common_kwargs ) if self.text_encoder_id is not None: text_encoder = AutoModel.from_pretrained( self.text_encoder_id, torch_dtype=self.text_encoder_dtype, **common_kwargs ) else: text_encoder = UMT5EncoderModel.from_pretrained( self.pretrained_model_name_or_path, subfolder="text_encoder", torch_dtype=self.text_encoder_dtype, **common_kwargs, ) return {"tokenizer": tokenizer, "text_encoder": text_encoder} def load_latent_models(self) -> Dict[str, torch.nn.Module]: common_kwargs = {"revision": self.revision, "cache_dir": self.cache_dir} if self.vae_id is not None: vae = AutoencoderKLWan.from_pretrained(self.vae_id, torch_dtype=self.vae_dtype, **common_kwargs) else: vae = AutoencoderKLWan.from_pretrained( self.pretrained_model_name_or_path, subfolder="vae", torch_dtype=self.vae_dtype, **common_kwargs ) models = {"vae": vae} if self.transformer_config.get("image_dim", None) is not None: # TODO(aryan): refactor the trainer to be able to support these extra models from CLI args more easily image_encoder = CLIPVisionModel.from_pretrained( self.pretrained_model_name_or_path, subfolder="image_encoder", torch_dtype=torch.bfloat16 ) image_processor = CLIPImageProcessor.from_pretrained( self.pretrained_model_name_or_path, subfolder="image_processor" ) models["image_encoder"] = image_encoder models["image_processor"] = image_processor return models def load_diffusion_models(self) -> Dict[str, torch.nn.Module]: common_kwargs = {"revision": self.revision, "cache_dir": self.cache_dir} if self.transformer_id is not None: transformer = WanTransformer3DModel.from_pretrained( self.transformer_id, torch_dtype=self.transformer_dtype, **common_kwargs ) else: transformer = WanTransformer3DModel.from_pretrained( self.pretrained_model_name_or_path, subfolder="transformer", torch_dtype=self.transformer_dtype, **common_kwargs, ) scheduler = FlowMatchEulerDiscreteScheduler() return {"transformer": transformer, "scheduler": scheduler} def load_pipeline( self, tokenizer: Optional[AutoTokenizer] = None, text_encoder: Optional[UMT5EncoderModel] = None, transformer: Optional[WanTransformer3DModel] = None, vae: Optional[AutoencoderKLWan] = None, scheduler: Optional[FlowMatchEulerDiscreteScheduler] = None, image_encoder: Optional[CLIPVisionModel] = None, image_processor: Optional[CLIPImageProcessor] = None, enable_slicing: bool = False, enable_tiling: bool = False, enable_model_cpu_offload: bool = False, training: bool = False, **kwargs, ) -> Union[WanPipeline, WanImageToVideoPipeline]: components = { "tokenizer": tokenizer, "text_encoder": text_encoder, "transformer": transformer, "vae": vae, "scheduler": scheduler, "image_encoder": image_encoder, "image_processor": image_processor, } components = get_non_null_items(components) if self.transformer_config.get("image_dim", None) is not None: pipe = WanPipeline.from_pretrained( self.pretrained_model_name_or_path, **components, revision=self.revision, cache_dir=self.cache_dir ) else: pipe = WanImageToVideoPipeline.from_pretrained( self.pretrained_model_name_or_path, **components, revision=self.revision, cache_dir=self.cache_dir ) pipe.text_encoder.to(self.text_encoder_dtype) pipe.vae.to(self.vae_dtype) if not training: pipe.transformer.to(self.transformer_dtype) # TODO(aryan): add support in diffusers # if enable_slicing: # pipe.vae.enable_slicing() # if enable_tiling: # pipe.vae.enable_tiling() if enable_model_cpu_offload: pipe.enable_model_cpu_offload() return pipe @torch.no_grad() def prepare_conditions( self, tokenizer: AutoTokenizer, text_encoder: UMT5EncoderModel, caption: str, max_sequence_length: int = 512, **kwargs, ) -> Dict[str, Any]: conditions = { "tokenizer": tokenizer, "text_encoder": text_encoder, "caption": caption, "max_sequence_length": max_sequence_length, **kwargs, } input_keys = set(conditions.keys()) conditions = super().prepare_conditions(**conditions) conditions = {k: v for k, v in conditions.items() if k not in input_keys} return conditions @torch.no_grad() def prepare_latents( self, vae: AutoencoderKLWan, image_encoder: Optional[CLIPVisionModel] = None, image_processor: Optional[CLIPImageProcessor] = None, image: Optional[torch.Tensor] = None, video: Optional[torch.Tensor] = None, generator: Optional[torch.Generator] = None, compute_posterior: bool = True, **kwargs, ) -> Dict[str, torch.Tensor]: conditions = { "vae": vae, "image_encoder": image_encoder, "image_processor": image_processor, "image": image, "video": video, "generator": generator, # We must force this to False because the latent normalization should be done before # the posterior is computed. The VAE does not handle this any more: # https://github.com/huggingface/diffusers/pull/10998 "compute_posterior": False, **kwargs, } input_keys = set(conditions.keys()) conditions = super().prepare_latents(**conditions) conditions = {k: v for k, v in conditions.items() if k not in input_keys} return conditions def forward( self, transformer: WanTransformer3DModel, condition_model_conditions: Dict[str, torch.Tensor], latent_model_conditions: Dict[str, torch.Tensor], sigmas: torch.Tensor, generator: Optional[torch.Generator] = None, compute_posterior: bool = True, **kwargs, ) -> Tuple[torch.Tensor, ...]: compute_posterior = False # See explanation in prepare_latents latent_condition = latent_condition_mask = None if compute_posterior: latents = latent_model_conditions.pop("latents") latent_condition = latent_model_conditions.pop("latent_condition", None) latent_condition_mask = latent_model_conditions.pop("latent_condition_mask", None) else: latents = latent_model_conditions.pop("latents") latents_mean = latent_model_conditions.pop("latents_mean") latents_std = latent_model_conditions.pop("latents_std") latent_condition = latent_model_conditions.pop("latent_condition", None) latent_condition_mask = latent_model_conditions.pop("latent_condition_mask", None) mu, logvar = torch.chunk(latents, 2, dim=1) mu = self._normalize_latents(mu, latents_mean, latents_std) logvar = self._normalize_latents(logvar, latents_mean, latents_std) latents = torch.cat([mu, logvar], dim=1) posterior = DiagonalGaussianDistribution(latents) latents = posterior.sample(generator=generator) if latent_condition is not None: mu, logvar = torch.chunk(latent_condition, 2, dim=1) mu = self._normalize_latents(mu, latents_mean, latents_std) logvar = self._normalize_latents(logvar, latents_mean, latents_std) latent_condition = torch.cat([mu, logvar], dim=1) posterior = DiagonalGaussianDistribution(latent_condition) latent_condition = posterior.mode() del posterior noise = torch.zeros_like(latents).normal_(generator=generator) noisy_latents = FF.flow_match_xt(latents, noise, sigmas) timesteps = (sigmas.flatten() * 1000.0).long() if self.transformer_config.get("image_dim", None) is not None: noisy_latents = torch.cat([noisy_latents, latent_condition_mask, latent_condition], dim=1) latent_model_conditions["hidden_states"] = noisy_latents.to(latents) pred = transformer( **latent_model_conditions, **condition_model_conditions, timestep=timesteps, return_dict=False, )[0] target = FF.flow_match_target(noise, latents) return pred, target, sigmas def validation( self, pipeline: Union[WanPipeline, WanImageToVideoPipeline], prompt: str, image: Optional[PIL.Image.Image] = None, last_image: Optional[PIL.Image.Image] = None, video: Optional[List[PIL.Image.Image]] = None, height: Optional[int] = None, width: Optional[int] = None, num_frames: Optional[int] = None, num_inference_steps: int = 50, generator: Optional[torch.Generator] = None, **kwargs, ) -> List[ArtifactType]: generation_kwargs = { "prompt": prompt, "height": height, "width": width, "num_frames": num_frames, "num_inference_steps": num_inference_steps, "generator": generator, "return_dict": True, "output_type": "pil", } if self.transformer_config.get("image_dim", None) is not None: if image is None and video is None: raise ValueError("Either image or video must be provided for Wan I2V validation.") image = image if image is not None else video[0] generation_kwargs["image"] = image if self.transformer_config.get("pos_embed_seq_len", None) is not None: last_image = last_image if last_image is not None else image if video is None else video[-1] generation_kwargs["last_image"] = last_image generation_kwargs = get_non_null_items(generation_kwargs) video = pipeline(**generation_kwargs).frames[0] return [VideoArtifact(value=video)] def _save_lora_weights( self, directory: str, transformer_state_dict: Optional[Dict[str, torch.Tensor]] = None, scheduler: Optional[SchedulerType] = None, metadata: Optional[Dict[str, str]] = None, *args, **kwargs, ) -> None: pipeline_cls = ( WanImageToVideoPipeline if self.transformer_config.get("image_dim", None) is not None else WanPipeline ) # TODO(aryan): this needs refactoring if transformer_state_dict is not None: pipeline_cls.save_lora_weights( directory, transformer_state_dict, save_function=functools.partial(safetensors_torch_save_function, metadata=metadata), safe_serialization=True, ) if scheduler is not None: scheduler.save_pretrained(os.path.join(directory, "scheduler")) def _save_model( self, directory: str, transformer: WanTransformer3DModel, transformer_state_dict: Optional[Dict[str, torch.Tensor]] = None, scheduler: Optional[SchedulerType] = None, ) -> None: # TODO(aryan): this needs refactoring if transformer_state_dict is not None: with init_empty_weights(): transformer_copy = WanTransformer3DModel.from_config(transformer.config) transformer_copy.load_state_dict(transformer_state_dict, strict=True, assign=True) transformer_copy.save_pretrained(os.path.join(directory, "transformer")) if scheduler is not None: scheduler.save_pretrained(os.path.join(directory, "scheduler")) @staticmethod def _normalize_latents( latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor ) -> torch.Tensor: latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(device=latents.device) latents_std = latents_std.view(1, -1, 1, 1, 1).to(device=latents.device) latents = ((latents.float() - latents_mean) * latents_std).to(latents) return latents