multimodalart (HF Staff) committed · Commit 19486d0 (verified) · 1 Parent(s): abdc3ca

Update radial_attn/models/wan/sparse_transformer.py
radial_attn/models/wan/sparse_transformer.py CHANGED
@@ -5,7 +5,8 @@ from typing import Any, Dict, Optional, Tuple, Union
 import torch
 
 from diffusers.models.transformers.transformer_wan import WanTransformerBlock, WanTransformer3DModel
-from diffusers import WanPipeline
+from diffusers import WanPipeline, WanImageToVideoPipeline
+from diffusers.image_processor import PipelineImageInput
 from diffusers.models.modeling_outputs import Transformer2DModelOutput
 from diffusers.utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
 logger = logging.get_logger(__name__)
@@ -13,6 +14,7 @@ from typing import Any, Callable, Dict, List, Optional, Union
 from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
 from diffusers.pipelines.wan.pipeline_output import WanPipelineOutput
 
+
 class WanTransformerBlock_Sparse(WanTransformerBlock):
     def forward(
         self,
@@ -365,7 +367,176 @@ class WanPipeline_Sparse(WanPipeline):
 
         return WanPipelineOutput(frames=video)
 
+# Add this entire function to the file
+@torch.no_grad()
+def wan_i2v_pipeline_call_sparse(
+    self,
+    image: PipelineImageInput,
+    prompt: Union[str, List[str]] = None,
+    negative_prompt: Union[str, List[str]] = None,
+    height: int = 480,
+    width: int = 832,
+    num_frames: int = 81,
+    num_inference_steps: int = 50,
+    guidance_scale: float = 5.0,
+    num_videos_per_prompt: Optional[int] = 1,
+    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+    latents: Optional[torch.Tensor] = None,
+    prompt_embeds: Optional[torch.Tensor] = None,
+    negative_prompt_embeds: Optional[torch.Tensor] = None,
+    image_embeds: Optional[torch.Tensor] = None,
+    last_image: Optional[torch.Tensor] = None,
+    output_type: Optional[str] = "np",
+    return_dict: bool = True,
+    attention_kwargs: Optional[Dict[str, Any]] = None,
+    callback_on_step_end: Optional[
+        Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
+    ] = None,
+    callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+    max_sequence_length: int = 512,
+):
+    if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+        callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+
+    self.check_inputs(
+        prompt,
+        negative_prompt,
+        image,
+        height,
+        width,
+        prompt_embeds,
+        negative_prompt_embeds,
+        image_embeds,
+        callback_on_step_end_tensor_inputs,
+    )
+    if num_frames % self.vae_scale_factor_temporal != 1:
+        logger.warning(
+            f"`num_frames - 1` has to be divisible by {self.vae_scale_factor_temporal}. Rounding to the nearest number."
+        )
+        num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1
+    num_frames = max(num_frames, 1)
+
+    self._guidance_scale = guidance_scale
+    self._attention_kwargs = attention_kwargs
+    self._current_timestep = None
+    self._interrupt = False
+    device = self._execution_device
+
+    if prompt is not None and isinstance(prompt, str):
+        batch_size = 1
+    elif prompt is not None and isinstance(prompt, list):
+        batch_size = len(prompt)
+    else:
+        batch_size = prompt_embeds.shape[0]
+
+    prompt_embeds, negative_prompt_embeds = self.encode_prompt(
+        prompt=prompt,
+        negative_prompt=negative_prompt,
+        do_classifier_free_guidance=self.do_classifier_free_guidance,
+        num_videos_per_prompt=num_videos_per_prompt,
+        prompt_embeds=prompt_embeds,
+        negative_prompt_embeds=negative_prompt_embeds,
+        max_sequence_length=max_sequence_length,
+        device=device,
+    )
+    transformer_dtype = self.transformer.dtype
+    prompt_embeds = prompt_embeds.to(transformer_dtype)
+    if negative_prompt_embeds is not None:
+        negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)
+    if image_embeds is None:
+        if last_image is None:
+            image_embeds = self.encode_image(image, device)
+        else:
+            image_embeds = self.encode_image([image, last_image], device)
+    image_embeds = image_embeds.repeat(batch_size, 1, 1)
+    image_embeds = image_embeds.to(transformer_dtype)
+
+    self.scheduler.set_timesteps(num_inference_steps, device=device)
+    timesteps = self.scheduler.timesteps
+    num_channels_latents = self.vae.config.z_dim
+    image = self.video_processor.preprocess(image, height=height, width=width).to(device, dtype=torch.float32)
+    if last_image is not None:
+        last_image = self.video_processor.preprocess(last_image, height=height, width=width).to(
+            device, dtype=torch.float32
+        )
+    latents, condition = self.prepare_latents(
+        image,
+        batch_size * num_videos_per_prompt,
+        num_channels_latents,
+        height,
+        width,
+        num_frames,
+        torch.float32,
+        device,
+        generator,
+        latents,
+        last_image,
+    )
+    num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+    self._num_timesteps = len(timesteps)
+
+    with self.progress_bar(total=num_inference_steps) as progress_bar:
+        for i, t in enumerate(timesteps):
+            if self.interrupt:
+                continue
+            self._current_timestep = t
+            latent_model_input = torch.cat([latents, condition], dim=1).to(transformer_dtype)
+            timestep = t.expand(latents.shape[0])
+            noise_pred = self.transformer(
+                hidden_states=latent_model_input,
+                timestep=timestep,
+                encoder_hidden_states=prompt_embeds,
+                encoder_hidden_states_image=image_embeds,
+                attention_kwargs=attention_kwargs,
+                return_dict=False,
+                numeral_timestep=i,  # <--- MODIFICATION
+            )[0]
+            if self.do_classifier_free_guidance:
+                noise_uncond = self.transformer(
+                    hidden_states=latent_model_input,
+                    timestep=timestep,
+                    encoder_hidden_states=negative_prompt_embeds,
+                    encoder_hidden_states_image=image_embeds,
+                    attention_kwargs=attention_kwargs,
+                    return_dict=False,
+                    numeral_timestep=i,  # <--- MODIFICATION
+                )[0]
+                noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
+            latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+            if callback_on_step_end is not None:
+                callback_kwargs = {}
+                for k in callback_on_step_end_tensor_inputs:
+                    callback_kwargs[k] = locals()[k]
+                callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+                latents = callback_outputs.pop("latents", latents)
+                prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+                negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+                progress_bar.update()
+
+    self._current_timestep = None
+    if not output_type == "latent":
+        latents = latents.to(self.vae.dtype)
+        latents_mean = (
+            torch.tensor(self.vae.config.latents_mean)
+            .view(1, self.vae.config.z_dim, 1, 1, 1)
+            .to(latents.device, latents.dtype)
+        )
+        latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
+            latents.device, latents.dtype
+        )
+        latents = latents / latents_std + latents_mean
+        video = self.vae.decode(latents, return_dict=False)[0]
+        video = self.video_processor.postprocess_video(video, output_type=output_type)
+    else:
+        video = latents
+    self.maybe_free_model_hooks()
+    if not return_dict:
+        return (video,)
+    return WanPipelineOutput(frames=video)
+
 def replace_sparse_forward():
     WanTransformerBlock.forward = WanTransformerBlock_Sparse.forward
     WanTransformer3DModel.forward = WanTransformer3DModel_Sparse.forward
-    WanPipeline.__call__ = WanPipeline_Sparse.__call__
+    WanPipeline.__call__ = WanPipeline_Sparse.__call__
+    WanImageToVideoPipeline.__call__ = wan_i2v_pipeline_call_sparse
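
For context, here is a minimal usage sketch of the patched image-to-video path. It is not part of the commit: the import of replace_sparse_forward mirrors the file touched here, but the checkpoint ID, input image, prompt, and generation settings are placeholder assumptions, and the radial_attn repo may require additional sparse-attention setup (e.g., mask configuration) that this sketch omits.

# Minimal usage sketch (not from this commit); placeholder assumptions are marked below.
import torch
from diffusers import WanImageToVideoPipeline
from diffusers.utils import export_to_video, load_image

from radial_attn.models.wan.sparse_transformer import replace_sparse_forward

# Monkey-patch the diffusers classes so that WanImageToVideoPipeline.__call__
# routes through wan_i2v_pipeline_call_sparse added in this commit.
replace_sparse_forward()

pipe = WanImageToVideoPipeline.from_pretrained(
    "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers",  # placeholder checkpoint ID
    torch_dtype=torch.bfloat16,
).to("cuda")

image = load_image("input.jpg")  # placeholder conditioning image
video = pipe(
    image=image,
    prompt="a cat walking through tall grass",  # placeholder prompt
    height=480,
    width=832,
    num_frames=81,
    num_inference_steps=50,
    guidance_scale=5.0,
).frames[0]
export_to_video(video, "output.mp4", fps=16)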