# finetrainers/models/wan/control_specification.py

import functools
import os
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple

import safetensors
import torch
from accelerate import init_empty_weights
from diffusers import (
    AutoencoderKLWan,
    FlowMatchEulerDiscreteScheduler,
    WanPipeline,
    WanTransformer3DModel,
)
from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
from transformers import AutoModel, AutoTokenizer, UMT5EncoderModel

import finetrainers.functional as FF
from finetrainers.data import VideoArtifact
from finetrainers.logging import get_logger
from finetrainers.models.modeling_utils import ControlModelSpecification
from finetrainers.models.utils import _expand_conv3d_with_zeroed_weights
from finetrainers.patches.dependencies.diffusers.control import control_channel_concat
from finetrainers.processors import ProcessorMixin, T5Processor
from finetrainers.typing import ArtifactType, SchedulerType
from finetrainers.utils import get_non_null_items, safetensors_torch_save_function

from .base_specification import WanLatentEncodeProcessor

if TYPE_CHECKING:
    from finetrainers.trainer.control_trainer.config import FrameConditioningType

logger = get_logger()


class WanControlModelSpecification(ControlModelSpecification):
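    """Control-conditioned model specification for Wan2.1 text-to-video models.

    Wires the Wan tokenizer/text encoder, VAE, and 3D transformer into finetrainers'
    control-training loop: control latents are channel-concatenated with the noisy
    video latents, and the transformer's patch embedding is widened (with
    zero-initialized weights) to accept the extra channels.
    """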

    def __init__(
        self,
        pretrained_model_name_or_path: str = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers",
        tokenizer_id: Optional[str] = None,
        text_encoder_id: Optional[str] = None,
        transformer_id: Optional[str] = None,
        vae_id: Optional[str] = None,
        text_encoder_dtype: torch.dtype = torch.bfloat16,
        transformer_dtype: torch.dtype = torch.bfloat16,
        vae_dtype: torch.dtype = torch.bfloat16,
        revision: Optional[str] = None,
        cache_dir: Optional[str] = None,
        condition_model_processors: Optional[List[ProcessorMixin]] = None,
        latent_model_processors: Optional[List[ProcessorMixin]] = None,
        control_model_processors: Optional[List[ProcessorMixin]] = None,
        **kwargs,
    ) -> None:
        super().__init__(
            pretrained_model_name_or_path=pretrained_model_name_or_path,
            tokenizer_id=tokenizer_id,
            text_encoder_id=text_encoder_id,
            transformer_id=transformer_id,
            vae_id=vae_id,
            text_encoder_dtype=text_encoder_dtype,
            transformer_dtype=transformer_dtype,
            vae_dtype=vae_dtype,
            revision=revision,
            cache_dir=cache_dir,
        )

        # "__drop__" output names appear to discard processor outputs that are not
        # needed downstream (e.g. the T5 attention mask).
        if condition_model_processors is None:
            condition_model_processors = [T5Processor(["encoder_hidden_states", "__drop__"])]
        if latent_model_processors is None:
            latent_model_processors = [WanLatentEncodeProcessor(["latents", "latents_mean", "latents_std"])]
        if control_model_processors is None:
            control_model_processors = [WanLatentEncodeProcessor(["control_latents", "__drop__", "__drop__"])]

        self.condition_model_processors = condition_model_processors
        self.latent_model_processors = latent_model_processors
        self.control_model_processors = control_model_processors
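
    # The control signal enters the model through the patch embedding: its input
    # channels are widened in load_diffusion_models() and the control latents are
    # channel-concatenated with the noisy latents in forward().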
    @property
    def control_injection_layer_name(self) -> str:
        return "patch_embedding"

    @property
    def _resolution_dim_keys(self):
        return {"latents": (2, 3, 4)}

    def load_condition_models(self) -> Dict[str, torch.nn.Module]:
        common_kwargs = {"revision": self.revision, "cache_dir": self.cache_dir}

        if self.tokenizer_id is not None:
            tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_id, **common_kwargs)
        else:
            tokenizer = AutoTokenizer.from_pretrained(
                self.pretrained_model_name_or_path, subfolder="tokenizer", **common_kwargs
            )

        if self.text_encoder_id is not None:
            text_encoder = AutoModel.from_pretrained(
                self.text_encoder_id, torch_dtype=self.text_encoder_dtype, **common_kwargs
            )
        else:
            text_encoder = UMT5EncoderModel.from_pretrained(
                self.pretrained_model_name_or_path,
                subfolder="text_encoder",
                torch_dtype=self.text_encoder_dtype,
                **common_kwargs,
            )

        return {"tokenizer": tokenizer, "text_encoder": text_encoder}

    def load_latent_models(self) -> Dict[str, torch.nn.Module]:
        common_kwargs = {"revision": self.revision, "cache_dir": self.cache_dir}

        if self.vae_id is not None:
            vae = AutoencoderKLWan.from_pretrained(self.vae_id, torch_dtype=self.vae_dtype, **common_kwargs)
        else:
            vae = AutoencoderKLWan.from_pretrained(
                self.pretrained_model_name_or_path, subfolder="vae", torch_dtype=self.vae_dtype, **common_kwargs
            )

        return {"vae": vae}

    def load_diffusion_models(self, new_in_features: int) -> Dict[str, torch.nn.Module]:
        common_kwargs = {"revision": self.revision, "cache_dir": self.cache_dir}

        if self.transformer_id is not None:
            transformer = WanTransformer3DModel.from_pretrained(
                self.transformer_id, torch_dtype=self.transformer_dtype, **common_kwargs
            )
        else:
            transformer = WanTransformer3DModel.from_pretrained(
                self.pretrained_model_name_or_path,
                subfolder="transformer",
                torch_dtype=self.transformer_dtype,
                **common_kwargs,
            )
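
        # Widen the conv3d patch embedding so it accepts the channel-concatenated
        # control latents; the new input channels are zero-initialized, so the
        # pretrained mapping is initially unchanged by the control branch.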
        transformer.patch_embedding = _expand_conv3d_with_zeroed_weights(
            transformer.patch_embedding, new_in_channels=new_in_features
        )
        transformer.register_to_config(in_channels=new_in_features)

        scheduler = FlowMatchEulerDiscreteScheduler()
        return {"transformer": transformer, "scheduler": scheduler}

    def load_pipeline(
        self,
        tokenizer: Optional[AutoTokenizer] = None,
        text_encoder: Optional[UMT5EncoderModel] = None,
        transformer: Optional[WanTransformer3DModel] = None,
        vae: Optional[AutoencoderKLWan] = None,
        scheduler: Optional[FlowMatchEulerDiscreteScheduler] = None,
        enable_slicing: bool = False,
        enable_tiling: bool = False,
        enable_model_cpu_offload: bool = False,
        training: bool = False,
        **kwargs,
    ) -> WanPipeline:
        components = {
            "tokenizer": tokenizer,
            "text_encoder": text_encoder,
            "transformer": transformer,
            "vae": vae,
            "scheduler": scheduler,
        }
        components = get_non_null_items(components)

        pipe = WanPipeline.from_pretrained(
            self.pretrained_model_name_or_path, **components, revision=self.revision, cache_dir=self.cache_dir
        )
        pipe.text_encoder.to(self.text_encoder_dtype)
        pipe.vae.to(self.vae_dtype)
        if not training:
            pipe.transformer.to(self.transformer_dtype)

        # TODO(aryan): add support in diffusers
        # if enable_slicing:
        #     pipe.vae.enable_slicing()
        # if enable_tiling:
        #     pipe.vae.enable_tiling()
        if enable_model_cpu_offload:
            pipe.enable_model_cpu_offload()
        return pipe

    def prepare_conditions(
        self,
        tokenizer: AutoTokenizer,
        text_encoder: UMT5EncoderModel,
        caption: str,
        max_sequence_length: int = 512,
        **kwargs,
    ) -> Dict[str, Any]:
        conditions = {
            "tokenizer": tokenizer,
            "text_encoder": text_encoder,
            "caption": caption,
            "max_sequence_length": max_sequence_length,
            **kwargs,
        }
        input_keys = set(conditions.keys())
        conditions = super().prepare_conditions(**conditions)
        # Keep only the entries computed by the processors, dropping the raw inputs.
        conditions = {k: v for k, v in conditions.items() if k not in input_keys}
        return conditions

    def prepare_latents(
        self,
        vae: AutoencoderKLWan,
        image: Optional[torch.Tensor] = None,
        video: Optional[torch.Tensor] = None,
        control_image: Optional[torch.Tensor] = None,
        control_video: Optional[torch.Tensor] = None,
        generator: Optional[torch.Generator] = None,
        compute_posterior: bool = True,
        **kwargs,
    ) -> Dict[str, torch.Tensor]:
        common_kwargs = {
            "vae": vae,
            "generator": generator,
            # This must be forced to False because latent normalization has to happen
            # before the posterior is computed, and the VAE no longer handles it:
            # https://github.com/huggingface/diffusers/pull/10998
            "compute_posterior": False,
            **kwargs,
        }
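
        # Encode the target image/video with the standard latent processors, then run
        # the same pathway with control_model_processors on the control input.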
        conditions = {"image": image, "video": video, **common_kwargs}
        input_keys = set(conditions.keys())
        conditions = super().prepare_latents(**conditions)
        conditions = {k: v for k, v in conditions.items() if k not in input_keys}

        control_conditions = {"image": control_image, "video": control_video, **common_kwargs}
        input_keys = set(control_conditions.keys())
        control_conditions = ControlModelSpecification.prepare_latents(
            self, self.control_model_processors, **control_conditions
        )
        control_conditions = {k: v for k, v in control_conditions.items() if k not in input_keys}

        return {**control_conditions, **conditions}

    def forward(
        self,
        transformer: WanTransformer3DModel,
        condition_model_conditions: Dict[str, torch.Tensor],
        latent_model_conditions: Dict[str, torch.Tensor],
        sigmas: torch.Tensor,
        generator: Optional[torch.Generator] = None,
        compute_posterior: bool = True,
        **kwargs,
    ) -> Tuple[torch.Tensor, ...]:
        from finetrainers.trainer.control_trainer.data import apply_frame_conditioning_on_latents

        compute_posterior = False  # See explanation in prepare_latents
        if compute_posterior:
            latents = latent_model_conditions.pop("latents")
            control_latents = latent_model_conditions.pop("control_latents")
        else:
            latents = latent_model_conditions.pop("latents")
            control_latents = latent_model_conditions.pop("control_latents")
            latents_mean = latent_model_conditions.pop("latents_mean")
            latents_std = latent_model_conditions.pop("latents_std")
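
            # With compute_posterior=False, "latents" holds the raw VAE statistics,
            # [mu, logvar] stacked along the channel dim; normalize both halves before
            # building the posterior and taking its mode.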
            mu, logvar = torch.chunk(latents, 2, dim=1)
            mu = self._normalize_latents(mu, latents_mean, latents_std)
            logvar = self._normalize_latents(logvar, latents_mean, latents_std)
            latents = torch.cat([mu, logvar], dim=1)

            mu, logvar = torch.chunk(control_latents, 2, dim=1)
            mu = self._normalize_latents(mu, latents_mean, latents_std)
            logvar = self._normalize_latents(logvar, latents_mean, latents_std)
            control_latents = torch.cat([mu, logvar], dim=1)

            posterior = DiagonalGaussianDistribution(latents)
            latents = posterior.mode()
            del posterior

            control_posterior = DiagonalGaussianDistribution(control_latents)
            control_latents = control_posterior.mode()
            del control_posterior
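
        # Rectified-flow training step: interpolate toward noise at level sigma and
        # regress the velocity. Assuming finetrainers' functional definitions,
        # flow_match_xt(x0, n, t) = (1 - t) * x0 + t * n and flow_match_target(n, x0) = n - x0.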
        noise = torch.zeros_like(latents).normal_(generator=generator)
        timesteps = (sigmas.flatten() * 1000.0).long()
        noisy_latents = FF.flow_match_xt(latents, noise, sigmas)

        control_latents = apply_frame_conditioning_on_latents(
            control_latents,
            noisy_latents.shape[2],
            channel_dim=1,
            frame_dim=2,
            frame_conditioning_type=self.frame_conditioning_type,
            frame_conditioning_index=self.frame_conditioning_index,
            concatenate_mask=self.frame_conditioning_concatenate_mask,
        )
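
        # Channel-concatenate the control latents onto the noisy latents, mirroring
        # what control_channel_concat does at inference time in validation().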
        noisy_latents = torch.cat([noisy_latents, control_latents], dim=1)
        latent_model_conditions["hidden_states"] = noisy_latents.to(latents)

        pred = transformer(
            **latent_model_conditions,
            **condition_model_conditions,
            timestep=timesteps,
            return_dict=False,
        )[0]
        target = FF.flow_match_target(noise, latents)

        return pred, target, sigmas

    def validation(
        self,
        pipeline: WanPipeline,
        prompt: str,
        control_image: Optional[torch.Tensor] = None,
        control_video: Optional[torch.Tensor] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_frames: Optional[int] = None,
        num_inference_steps: int = 50,
        generator: Optional[torch.Generator] = None,
        frame_conditioning_type: "FrameConditioningType" = "full",
        frame_conditioning_index: int = 0,
        **kwargs,
    ) -> List[ArtifactType]:
        from finetrainers.trainer.control_trainer.data import apply_frame_conditioning_on_latents

        with torch.no_grad():
            dtype = pipeline.vae.dtype
            device = pipeline._execution_device

            in_channels = self.transformer_config.in_channels  # We need to use the original in_channels
            latents = pipeline.prepare_latents(1, in_channels, height, width, num_frames, dtype, device, generator)

            latents_mean = (
                torch.tensor(self.vae_config.latents_mean)
                .view(1, self.vae_config.z_dim, 1, 1, 1)
                .to(latents.device, latents.dtype)
            )
            latents_std = 1.0 / torch.tensor(self.vae_config.latents_std).view(1, self.vae_config.z_dim, 1, 1, 1).to(
                latents.device, latents.dtype
            )
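
            # Note: latents_std stores the reciprocal of the VAE std, so
            # _normalize_latents can multiply by it instead of dividing.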
            if control_image is not None:
                control_video = pipeline.video_processor.preprocess(
                    control_image, height=height, width=width
                ).unsqueeze(2)
            else:
                control_video = pipeline.video_processor.preprocess_video(control_video, height=height, width=width)
            control_video = control_video.to(device=device, dtype=dtype)

            control_latents = pipeline.vae.encode(control_video).latent_dist.mode()
            control_latents = self._normalize_latents(control_latents, latents_mean, latents_std)
            control_latents = apply_frame_conditioning_on_latents(
                control_latents,
                latents.shape[2],
                channel_dim=1,
                frame_dim=2,
                frame_conditioning_type=frame_conditioning_type,
                frame_conditioning_index=frame_conditioning_index,
                concatenate_mask=self.frame_conditioning_concatenate_mask,
            )

            generation_kwargs = {
                "latents": latents,
                "prompt": prompt,
                "height": height,
                "width": width,
                "num_frames": num_frames,
                "num_inference_steps": num_inference_steps,
                "generator": generator,
                "return_dict": True,
                "output_type": "pil",
            }
            generation_kwargs = get_non_null_items(generation_kwargs)
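
            # control_channel_concat temporarily patches the transformer so that the
            # control latents are concatenated onto "hidden_states" along dim 1 for
            # every denoising call inside the pipeline.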
            with control_channel_concat(pipeline.transformer, ["hidden_states"], [control_latents], dims=[1]):
                video = pipeline(**generation_kwargs).frames[0]

        return [VideoArtifact(value=video)]

    def _save_lora_weights(
        self,
        directory: str,
        transformer_state_dict: Optional[Dict[str, torch.Tensor]] = None,
        norm_state_dict: Optional[Dict[str, torch.Tensor]] = None,
        scheduler: Optional[SchedulerType] = None,
        metadata: Optional[Dict[str, str]] = None,
        *args,
        **kwargs,
    ) -> None:
        # TODO(aryan): this needs refactoring
        if transformer_state_dict is not None:
            WanPipeline.save_lora_weights(
                directory,
                transformer_state_dict,
                save_function=functools.partial(safetensors_torch_save_function, metadata=metadata),
                safe_serialization=True,
            )
        if norm_state_dict is not None:
            safetensors.torch.save_file(norm_state_dict, os.path.join(directory, "norm_state_dict.safetensors"))
        if scheduler is not None:
            scheduler.save_pretrained(os.path.join(directory, "scheduler"))

    def _save_model(
        self,
        directory: str,
        transformer: WanTransformer3DModel,
        transformer_state_dict: Optional[Dict[str, torch.Tensor]] = None,
        scheduler: Optional[SchedulerType] = None,
    ) -> None:
        # TODO(aryan): this needs refactoring
        if transformer_state_dict is not None:
            with init_empty_weights():
                transformer_copy = WanTransformer3DModel.from_config(transformer.config)
            transformer_copy.load_state_dict(transformer_state_dict, strict=True, assign=True)
            transformer_copy.save_pretrained(os.path.join(directory, "transformer"))
        if scheduler is not None:
            scheduler.save_pretrained(os.path.join(directory, "scheduler"))

    @staticmethod
    def _normalize_latents(
        latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor
    ) -> torch.Tensor:
        latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(device=latents.device)
        latents_std = latents_std.view(1, -1, 1, 1, 1).to(device=latents.device)
        # latents_std holds the reciprocal std, so multiplication performs the scaling
        latents = ((latents.float() - latents_mean) * latents_std).to(latents)
        return latents

    @property
    def _original_control_layer_in_features(self):
        return self.transformer_config.in_channels

    @property
    def _original_control_layer_out_features(self):
        return self.transformer_config.num_attention_heads * self.transformer_config.attention_head_dim

    @property
    def _qk_norm_identifiers(self):
        return ["norm_q", "norm_k", "norm_added_q", "norm_added_k"]
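

# A minimal usage sketch (hypothetical wiring; the real orchestration lives in
# finetrainers.trainer.control_trainer). It only exercises the loaders defined above;
# `new_in_features=32` assumes the 16-channel Wan latent space doubled by the control
# concatenation:
#
#     spec = WanControlModelSpecification()
#     condition_models = spec.load_condition_models()   # {"tokenizer", "text_encoder"}
#     latent_models = spec.load_latent_models()         # {"vae"}
#     diffusion_models = spec.load_diffusion_models(new_in_features=32)
#     pipe = spec.load_pipeline(**condition_models, **latent_models, **diffusion_models)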