# borrowed from svg-project/Sparse-VideoGen
"""Wan transformer / pipeline overrides that thread a step index to self-attention.

The pipelines in this module pass the zero-based denoising-loop index as
``numeral_timestep`` to the transformer. The transformer forwards it, together
with ``timestep``, to every block, and each block hands both values to its
self-attention module (``attn1``). The attention processor that consumes these
extra keyword arguments is not defined in this file.

Call :func:`replace_sparse_forward` to install the overrides on the stock
diffusers classes.
"""
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import torch

from diffusers import WanImageToVideoPipeline, WanPipeline
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
from diffusers.image_processor import PipelineImageInput
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from diffusers.models.transformers.transformer_wan import WanTransformer3DModel, WanTransformerBlock
from diffusers.pipelines.wan.pipeline_output import WanPipelineOutput
from diffusers.utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers

logger = logging.get_logger(__name__)

# NOTE: `replace_sparse_forward` copies the methods below onto the *base*
# diffusers classes. Instances of those base classes do not inherit from the
# `_Sparse` subclasses, so the methods here must not call helper methods that
# exist only on the `_Sparse` subclasses.


class WanTransformerBlock_Sparse(WanTransformerBlock):
    """Transformer block whose self-attention call also receives timestep information."""

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        temb: torch.Tensor,
        rotary_emb: torch.Tensor,
        timestep: int = 0,
        numeral_timestep: Optional[int] = None,
    ) -> torch.Tensor:
        """Run self-attention, cross-attention and feed-forward with modulation.

        Args:
            hidden_states: Token sequence to transform.
            encoder_hidden_states: Conditioning sequence used by cross-attention.
            temb: Modulation input; added to ``self.scale_shift_table`` and
                split into six chunks along dim 1.
            rotary_emb: Rotary embedding forwarded to ``self.attn1``.
            timestep: Forwarded unchanged to ``self.attn1``. The model-level
                forward in this file passes its own ``timestep`` argument here.
            numeral_timestep: Forwarded unchanged to ``self.attn1``. The
                pipelines in this file pass the denoising-loop index.

        Returns:
            The transformed hidden states, cast back to the input dtype.
        """
        # Modulation parameters are computed in float32; results are cast back
        # to the dtype of `hidden_states` after each sub-layer.
        shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
            self.scale_shift_table + temb.float()
        ).chunk(6, dim=1)

        # 1. Self-attention
        norm_hidden_states = (self.norm1(hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(hidden_states)
        # `attn1` must accept `timestep` / `numeral_timestep`; this is the
        # difference from the cross-attention call below.
        attn_output = self.attn1(
            hidden_states=norm_hidden_states,
            rotary_emb=rotary_emb,
            timestep=timestep,
            numeral_timestep=numeral_timestep,
        )
        hidden_states = (hidden_states.float() + attn_output * gate_msa).type_as(hidden_states)

        # 2. Cross-attention
        norm_hidden_states = self.norm2(hidden_states.float()).type_as(hidden_states)
        attn_output = self.attn2(hidden_states=norm_hidden_states, encoder_hidden_states=encoder_hidden_states)
        hidden_states = hidden_states + attn_output

        # 3. Feed-forward
        norm_hidden_states = (self.norm3(hidden_states.float()) * (1 + c_scale_msa) + c_shift_msa).type_as(
            hidden_states
        )
        ff_output = self.ffn(norm_hidden_states)
        hidden_states = (hidden_states.float() + ff_output.float() * c_gate_msa).type_as(hidden_states)

        return hidden_states


class WanTransformer3DModel_Sparse(WanTransformer3DModel):
    """Wan 3D transformer whose forward passes timestep information to every block."""

    def forward(
        self,
        hidden_states: torch.Tensor,
        timestep: torch.LongTensor,
        encoder_hidden_states: torch.Tensor,
        numeral_timestep: Optional[int] = None,
        encoder_hidden_states_image: Optional[torch.Tensor] = None,
        return_dict: bool = True,
        attention_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
        """Predict the model output for a batch of latents.

        Args:
            hidden_states: 5-D input unpacked as
                ``(batch_size, num_channels, num_frames, height, width)``.
            timestep: Passed to ``self.condition_embedder`` and forwarded to
                every block.
            encoder_hidden_states: Conditioning passed to
                ``self.condition_embedder``.
            numeral_timestep: Forwarded unchanged to every block.
            encoder_hidden_states_image: Optional image conditioning; when the
                embedder returns a non-``None`` value it is concatenated in
                front of ``encoder_hidden_states`` along dim 1.
            return_dict: When ``False`` a 1-tuple ``(output,)`` is returned.
            attention_kwargs: Optional dict; only the ``"scale"`` key (LoRA
                scale, default ``1.0``) is read here. The caller's dict is not
                mutated.

        Returns:
            ``Transformer2DModelOutput(sample=output)`` or ``(output,)``.
        """
        if attention_kwargs is not None:
            attention_kwargs = attention_kwargs.copy()
            # Record whether a scale was supplied *before* popping it;
            # checking the dict after the pop would always find nothing and
            # the warning below could never fire.
            scale_was_passed = attention_kwargs.get("scale", None) is not None
            lora_scale = attention_kwargs.pop("scale", 1.0)
        else:
            scale_was_passed = False
            lora_scale = 1.0

        if USE_PEFT_BACKEND:
            # weight the lora layers by setting `lora_scale` for each PEFT layer
            scale_lora_layers(self, lora_scale)
        elif scale_was_passed:
            logger.warning("Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective.")

        batch_size, num_channels, num_frames, height, width = hidden_states.shape
        p_t, p_h, p_w = self.config.patch_size
        post_patch_num_frames = num_frames // p_t
        post_patch_height = height // p_h
        post_patch_width = width // p_w

        # Rotary embeddings are derived from the un-patched input.
        rotary_emb = self.rope(hidden_states)

        hidden_states = self.patch_embedding(hidden_states)
        hidden_states = hidden_states.flatten(2).transpose(1, 2)

        temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder(
            timestep, encoder_hidden_states, encoder_hidden_states_image
        )
        # Split the projection into the six modulation chunks each block expects.
        timestep_proj = timestep_proj.unflatten(1, (6, -1))

        if encoder_hidden_states_image is not None:
            encoder_hidden_states = torch.concat([encoder_hidden_states_image, encoder_hidden_states], dim=1)

        # 4. Transformer blocks
        if torch.is_grad_enabled() and self.gradient_checkpointing:
            for block in self.blocks:
                hidden_states = self._gradient_checkpointing_func(
                    block,
                    hidden_states,
                    encoder_hidden_states,
                    timestep_proj,
                    rotary_emb,
                    timestep=timestep,
                    numeral_timestep=numeral_timestep,
                )
        else:
            for block in self.blocks:
                hidden_states = block(
                    hidden_states,
                    encoder_hidden_states,
                    timestep_proj,
                    rotary_emb,
                    timestep=timestep,
                    numeral_timestep=numeral_timestep,
                )

        # 5. Output norm, projection & unpatchify
        shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1)

        # Move the shift and scale tensors to the same device as hidden_states.
        # When using multi-GPU inference via accelerate these will be on the
        # first device rather than the last device, which hidden_states ends up
        # on.
        shift = shift.to(hidden_states.device)
        scale = scale.to(hidden_states.device)

        hidden_states = (self.norm_out(hidden_states.float()) * (1 + scale) + shift).type_as(hidden_states)
        hidden_states = self.proj_out(hidden_states)

        # Undo the patching: regroup tokens into (frames, height, width) patches
        # and interleave each patch axis with its grid axis.
        hidden_states = hidden_states.reshape(
            batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1
        )
        hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6)
        output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)

        if USE_PEFT_BACKEND:
            # remove `lora_scale` from each PEFT layer
            unscale_lora_layers(self, lora_scale)

        if not return_dict:
            return (output,)

        return Transformer2DModelOutput(sample=output)


class WanPipeline_Sparse(WanPipeline):
    """Text-to-video pipeline that passes the denoising-loop index to the transformer."""

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        negative_prompt: Union[str, List[str]] = None,
        height: int = 480,
        width: int = 832,
        num_frames: int = 81,
        num_inference_steps: int = 50,
        guidance_scale: float = 5.0,
        num_videos_per_prompt: Optional[int] = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.Tensor] = None,
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        output_type: Optional[str] = "np",
        return_dict: bool = True,
        attention_kwargs: Optional[Dict[str, Any]] = None,
        callback_on_step_end: Optional[
            Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
        ] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        max_sequence_length: int = 512,
    ):
        r"""
        The call function to the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide the generation. If not defined, one has to pass `prompt_embeds`
                instead.
            negative_prompt (`str` or `List[str]`, *optional*):
                Negative prompt or prompts, forwarded to `encode_prompt`.
            height (`int`, defaults to `480`):
                The height in pixels of the generated video.
            width (`int`, defaults to `832`):
                The width in pixels of the generated video.
            num_frames (`int`, defaults to `81`):
                The number of frames in the generated video.
            num_inference_steps (`int`, defaults to `50`):
                The number of denoising steps passed to the scheduler.
            guidance_scale (`float`, defaults to `5.0`):
                Guidance scale as defined in [Classifier-Free Diffusion
                Guidance](https://arxiv.org/abs/2207.12598). When `self.do_classifier_free_guidance` is true the
                prediction is `noise_uncond + guidance_scale * (noise_pred - noise_uncond)`.
            num_videos_per_prompt (`int`, *optional*, defaults to 1):
                The number of videos to generate per prompt.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) forwarded to
                `prepare_latents`.
            latents (`torch.Tensor`, *optional*):
                Pre-generated latents forwarded to `prepare_latents`.
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Used to derive the batch size when `prompt` is not given.
            negative_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated negative text embeddings, forwarded to `encode_prompt`.
            output_type (`str`, *optional*, defaults to `"np"`):
                Output format forwarded to `video_processor.postprocess_video`. Passing `"latent"` skips VAE
                decoding and returns the final latents.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`WanPipelineOutput`] instead of a plain tuple.
            attention_kwargs (`dict`, *optional*):
                A kwargs dictionary passed along to the transformer.
            callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
                Called at the end of each denoising step as `callback_on_step_end(self, step, timestep,
                callback_kwargs)`. `callback_kwargs` holds the tensors named by
                `callback_on_step_end_tensor_inputs`. The returned dict may override `latents`, `prompt_embeds`
                and `negative_prompt_embeds`.
            callback_on_step_end_tensor_inputs (`List`, *optional*):
                Names of the local tensors handed to `callback_on_step_end`. Replaced by
                `callback_on_step_end.tensor_inputs` when the callback is a `PipelineCallback` or
                `MultiPipelineCallbacks`.
            max_sequence_length (`int`, defaults to `512`):
                Forwarded to `encode_prompt`.

        Returns:
            [`~WanPipelineOutput`] or `tuple`:
                If `return_dict` is `True`, `WanPipelineOutput(frames=video)` is returned, otherwise the 1-tuple
                `(video,)`.
        """

        if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
            callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt,
            negative_prompt,
            height,
            width,
            prompt_embeds,
            negative_prompt_embeds,
            callback_on_step_end_tensor_inputs,
        )

        self._guidance_scale = guidance_scale
        self._attention_kwargs = attention_kwargs
        self._current_timestep = None
        self._interrupt = False

        device = self._execution_device

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        # 3. Encode input prompt
        prompt_embeds, negative_prompt_embeds = self.encode_prompt(
            prompt=prompt,
            negative_prompt=negative_prompt,
            do_classifier_free_guidance=self.do_classifier_free_guidance,
            num_videos_per_prompt=num_videos_per_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            max_sequence_length=max_sequence_length,
            device=device,
        )

        transformer_dtype = self.transformer.dtype
        prompt_embeds = prompt_embeds.to(transformer_dtype)
        if negative_prompt_embeds is not None:
            negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)

        # 4. Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        # 5. Prepare latent variables
        num_channels_latents = self.transformer.config.in_channels
        latents = self.prepare_latents(
            batch_size * num_videos_per_prompt,
            num_channels_latents,
            height,
            width,
            num_frames,
            torch.float32,
            device,
            generator,
            latents,
        )

        # 6. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        self._num_timesteps = len(timesteps)

        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # Once interrupted, the remaining steps are skipped rather
                # than breaking out of the loop.
                if self.interrupt:
                    continue

                self._current_timestep = t
                latent_model_input = latents.to(transformer_dtype)
                timestep = t.expand(latents.shape[0])

                # `numeral_timestep=i` hands the loop index to the transformer.
                noise_pred = self.transformer(
                    hidden_states=latent_model_input,
                    timestep=timestep,
                    encoder_hidden_states=prompt_embeds,
                    attention_kwargs=attention_kwargs,
                    return_dict=False,
                    numeral_timestep=i,
                )[0]

                if self.do_classifier_free_guidance:
                    noise_uncond = self.transformer(
                        hidden_states=latent_model_input,
                        timestep=timestep,
                        encoder_hidden_states=negative_prompt_embeds,
                        attention_kwargs=attention_kwargs,
                        return_dict=False,
                        numeral_timestep=i,
                    )[0]
                    noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

                # call the callback, if provided
                if callback_on_step_end is not None:
                    callback_kwargs = {}
                    # Kept as a plain loop: `locals()` must be evaluated in
                    # this function's scope to see the named tensors.
                    for k in callback_on_step_end_tensor_inputs:
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                    latents = callback_outputs.pop("latents", latents)
                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)

                # advance the progress bar on the last step and on every
                # completed scheduler-order group after warmup
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()

        self._current_timestep = None

        if output_type != "latent":
            latents = latents.to(self.vae.dtype)
            latents_mean = (
                torch.tensor(self.vae.config.latents_mean)
                .view(1, self.vae.config.z_dim, 1, 1, 1)
                .to(latents.device, latents.dtype)
            )
            # `latents_std` holds the reciprocal of the configured std, so the
            # division below multiplies by the std.
            latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
                latents.device, latents.dtype
            )
            latents = latents / latents_std + latents_mean
            video = self.vae.decode(latents, return_dict=False)[0]
            video = self.video_processor.postprocess_video(video, output_type=output_type)
        else:
            video = latents

        # Offload all models
        self.maybe_free_model_hooks()

        if not return_dict:
            return (video,)

        return WanPipelineOutput(frames=video)


class WanImageToVideoPipeline_Sparse(WanImageToVideoPipeline):
    """Image-to-video pipeline that passes the denoising-loop index to the transformer."""

    @torch.no_grad()
    def __call__(
        self,
        image: PipelineImageInput,
        prompt: Union[str, List[str]] = None,
        negative_prompt: Union[str, List[str]] = None,
        height: int = 480,
        width: int = 832,
        num_frames: int = 81,
        num_inference_steps: int = 50,
        guidance_scale: float = 5.0,
        num_videos_per_prompt: Optional[int] = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.Tensor] = None,
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        image_embeds: Optional[torch.Tensor] = None,
        last_image: Optional[torch.Tensor] = None,
        output_type: Optional[str] = "np",
        return_dict: bool = True,
        attention_kwargs: Optional[Dict[str, Any]] = None,
        callback_on_step_end: Optional[
            Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
        ] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        max_sequence_length: int = 512,
    ):
        r"""
        The call function to the pipeline for image-conditioned generation.

        Args:
            image (`PipelineImageInput`):
                Conditioning image. It is encoded with `encode_image` (unless `image_embeds` is given),
                preprocessed to `height` x `width`, and passed to `prepare_latents`.
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts. If not defined, `prompt_embeds` is used to derive the batch size.
            negative_prompt (`str` or `List[str]`, *optional*):
                Negative prompt or prompts, forwarded to `encode_prompt`.
            height (`int`, defaults to `480`), width (`int`, defaults to `832`):
                Size the input image(s) are preprocessed to and that is passed to `prepare_latents`.
            num_frames (`int`, defaults to `81`):
                Number of frames. If `num_frames % self.vae_scale_factor_temporal != 1` a warning is logged and
                the value is rounded down to the nearest valid count (minimum 1).
            num_inference_steps (`int`, defaults to `50`):
                The number of denoising steps passed to the scheduler.
            guidance_scale (`float`, defaults to `5.0`):
                When `self.do_classifier_free_guidance` is true the prediction is
                `noise_uncond + guidance_scale * (noise_pred - noise_uncond)`.
            num_videos_per_prompt (`int`, *optional*, defaults to 1):
                The number of videos to generate per prompt.
            generator, latents:
                Forwarded to `prepare_latents`.
            prompt_embeds, negative_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings, forwarded to `encode_prompt`.
            image_embeds (`torch.Tensor`, *optional*):
                Pre-computed image embeddings. When `None` they are computed from `image` (and `last_image`
                when given).
            last_image (*optional*):
                Optional second conditioning image; encoded together with `image` and passed to
                `prepare_latents`.
            output_type (`str`, *optional*, defaults to `"np"`):
                Output format forwarded to `video_processor.postprocess_video`. Passing `"latent"` skips VAE
                decoding and returns the final latents.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`WanPipelineOutput`] instead of a plain tuple.
            attention_kwargs (`dict`, *optional*):
                A kwargs dictionary passed along to the transformer.
            callback_on_step_end, callback_on_step_end_tensor_inputs:
                Per-step callback and the names of the local tensors handed to it. The returned dict may
                override `latents`, `prompt_embeds` and `negative_prompt_embeds`.
            max_sequence_length (`int`, defaults to `512`):
                Forwarded to `encode_prompt`.

        Returns:
            [`~WanPipelineOutput`] or `tuple`:
                If `return_dict` is `True`, `WanPipelineOutput(frames=video)` is returned, otherwise the 1-tuple
                `(video,)`.
        """

        if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
            callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt,
            negative_prompt,
            image,
            height,
            width,
            prompt_embeds,
            negative_prompt_embeds,
            image_embeds,
            callback_on_step_end_tensor_inputs,
        )

        if num_frames % self.vae_scale_factor_temporal != 1:
            logger.warning(
                f"`num_frames - 1` has to be divisible by {self.vae_scale_factor_temporal}. Rounding to the nearest number."
            )
            num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1
        num_frames = max(num_frames, 1)

        self._guidance_scale = guidance_scale
        self._attention_kwargs = attention_kwargs
        self._current_timestep = None
        self._interrupt = False

        device = self._execution_device

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        # 3. Encode input prompt
        prompt_embeds, negative_prompt_embeds = self.encode_prompt(
            prompt=prompt,
            negative_prompt=negative_prompt,
            do_classifier_free_guidance=self.do_classifier_free_guidance,
            num_videos_per_prompt=num_videos_per_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            max_sequence_length=max_sequence_length,
            device=device,
        )

        # Encode image embedding
        transformer_dtype = self.transformer.dtype
        prompt_embeds = prompt_embeds.to(transformer_dtype)
        if negative_prompt_embeds is not None:
            negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)

        if image_embeds is None:
            if last_image is None:
                image_embeds = self.encode_image(image, device)
            else:
                image_embeds = self.encode_image([image, last_image], device)
        image_embeds = image_embeds.repeat(batch_size, 1, 1)
        image_embeds = image_embeds.to(transformer_dtype)

        # 4. Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        # 5. Prepare latent variables
        num_channels_latents = self.vae.config.z_dim
        image = self.video_processor.preprocess(image, height=height, width=width).to(device, dtype=torch.float32)
        if last_image is not None:
            last_image = self.video_processor.preprocess(last_image, height=height, width=width).to(
                device, dtype=torch.float32
            )
        latents, condition = self.prepare_latents(
            image,
            batch_size * num_videos_per_prompt,
            num_channels_latents,
            height,
            width,
            num_frames,
            torch.float32,
            device,
            generator,
            latents,
            last_image,
        )

        # 6. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        self._num_timesteps = len(timesteps)

        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # Once interrupted, the remaining steps are skipped rather
                # than breaking out of the loop.
                if self.interrupt:
                    continue

                self._current_timestep = t
                # The conditioning latents are concatenated along dim 1.
                latent_model_input = torch.cat([latents, condition], dim=1).to(transformer_dtype)
                timestep = t.expand(latents.shape[0])

                # `numeral_timestep=i` hands the loop index to the transformer.
                noise_pred = self.transformer(
                    hidden_states=latent_model_input,
                    timestep=timestep,
                    encoder_hidden_states=prompt_embeds,
                    encoder_hidden_states_image=image_embeds,
                    attention_kwargs=attention_kwargs,
                    return_dict=False,
                    numeral_timestep=i,
                )[0]

                if self.do_classifier_free_guidance:
                    noise_uncond = self.transformer(
                        hidden_states=latent_model_input,
                        timestep=timestep,
                        encoder_hidden_states=negative_prompt_embeds,
                        encoder_hidden_states_image=image_embeds,
                        attention_kwargs=attention_kwargs,
                        return_dict=False,
                        numeral_timestep=i,
                    )[0]
                    noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

                # call the callback, if provided
                if callback_on_step_end is not None:
                    callback_kwargs = {}
                    # Kept as a plain loop: `locals()` must be evaluated in
                    # this function's scope to see the named tensors.
                    for k in callback_on_step_end_tensor_inputs:
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                    latents = callback_outputs.pop("latents", latents)
                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)

                # advance the progress bar on the last step and on every
                # completed scheduler-order group after warmup
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()

        self._current_timestep = None

        if output_type != "latent":
            latents = latents.to(self.vae.dtype)
            latents_mean = (
                torch.tensor(self.vae.config.latents_mean)
                .view(1, self.vae.config.z_dim, 1, 1, 1)
                .to(latents.device, latents.dtype)
            )
            # `latents_std` holds the reciprocal of the configured std, so the
            # division below multiplies by the std.
            latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
                latents.device, latents.dtype
            )
            latents = latents / latents_std + latents_mean
            video = self.vae.decode(latents, return_dict=False)[0]
            video = self.video_processor.postprocess_video(video, output_type=output_type)
        else:
            video = latents

        # Offload all models
        self.maybe_free_model_hooks()

        if not return_dict:
            return (video,)

        return WanPipelineOutput(frames=video)


def replace_sparse_forward():
    """Monkey-patch the stock diffusers Wan classes with the overrides above.

    The assignment is made on the classes themselves, so it affects every
    existing and future instance of ``WanTransformerBlock``,
    ``WanTransformer3DModel``, ``WanPipeline`` and ``WanImageToVideoPipeline``
    in the current process. Nothing in this module undoes the patch.
    """
    WanTransformerBlock.forward = WanTransformerBlock_Sparse.forward
    WanTransformer3DModel.forward = WanTransformer3DModel_Sparse.forward
    WanPipeline.__call__ = WanPipeline_Sparse.__call__
    WanImageToVideoPipeline.__call__ = WanImageToVideoPipeline_Sparse.__call__