import types
from ..models import ModelManager
from ..models.wan_video_dit import WanModel
from ..models.wan_video_text_encoder import WanTextEncoder
from ..models.wan_video_vae import WanVideoVAE
from ..models.wan_video_image_encoder import WanImageEncoder
from ..models.wan_video_vace import VaceWanModel
from ..schedulers.flow_match import FlowMatchScheduler
from .base import BasePipeline
from ..prompters import WanPrompter
import torch
import os
from einops import rearrange
import numpy as np
from PIL import Image
from tqdm import tqdm
from typing import Optional

from ..vram_management import enable_vram_management, AutoWrappedModule, AutoWrappedLinear
from ..models.wan_video_text_encoder import T5RelativeEmbedding, T5LayerNorm
from ..models.wan_video_dit import RMSNorm, sinusoidal_embedding_1d
from ..models.wan_video_vae import RMS_norm, CausalConv3d, Upsample
from ..models.wan_video_motion_controller import WanMotionControllerModel


class WanVideoPipeline(BasePipeline):
    """Wan video generation pipeline.

    Combines text/image/video conditioning, ControlNet-style conditioning fed
    through the DiT's ``y`` input, optional VACE, motion control and TeaCache.
    """

    def __init__(self, device="cuda", torch_dtype=torch.float16, tokenizer_path=None):
        """Create an empty pipeline; sub-models are attached later by `fetch_models`."""
        super().__init__(device=device, torch_dtype=torch_dtype)
        self.scheduler = FlowMatchScheduler(shift=5, sigma_min=0.0, extra_one_step=True)
        self.prompter = WanPrompter(tokenizer_path=tokenizer_path)
        self.text_encoder: WanTextEncoder = None
        self.image_encoder: WanImageEncoder = None
        self.dit: WanModel = None
        self.vae: WanVideoVAE = None
        self.motion_controller: WanMotionControllerModel = None
        self.vace: VaceWanModel = None
        self.model_names = ['text_encoder', 'dit', 'vae', 'image_encoder', 'motion_controller', 'vace']
        self.height_division_factor = 16
        self.width_division_factor = 16
        self.use_unified_sequence_parallel = False
        # Forward function exposed as an attribute so external code (e.g. training
        # scripts) can reach it; __call__ below calls model_fn_wan_video directly.
        self.model_fn = model_fn_wan_video

    def enable_vram_management(self, num_persistent_param_in_dit=None):
        """Wrap sub-models with VRAM-managed layers, then enable CPU offload.

        Args:
            num_persistent_param_in_dit: forwarded to ``enable_vram_management`` as
                ``max_num_param`` for the DiT; wrapped DiT modules beyond that budget
                use the CPU-onload ``overflow_module_config`` instead.
        """
        def _config(dtype, onload_device, computation_dtype):
            # Offload to CPU in the model's native dtype, onload to `onload_device`,
            # compute on self.device in `computation_dtype`.
            return dict(
                offload_dtype=dtype,
                offload_device="cpu",
                onload_dtype=dtype,
                onload_device=onload_device,
                computation_dtype=computation_dtype,
                computation_device=self.device,
            )

        dtype = next(iter(self.text_encoder.parameters())).dtype
        enable_vram_management(
            self.text_encoder,
            module_map={
                torch.nn.Linear: AutoWrappedLinear,
                torch.nn.Embedding: AutoWrappedModule,
                T5RelativeEmbedding: AutoWrappedModule,
                T5LayerNorm: AutoWrappedModule,
            },
            module_config=_config(dtype, "cpu", self.torch_dtype),
        )

        dtype = next(iter(self.dit.parameters())).dtype
        enable_vram_management(
            self.dit,
            module_map={
                torch.nn.Linear: AutoWrappedLinear,
                torch.nn.Conv3d: AutoWrappedModule,
                torch.nn.LayerNorm: AutoWrappedModule,
                RMSNorm: AutoWrappedModule,
                torch.nn.Conv2d: AutoWrappedModule,
            },
            module_config=_config(dtype, self.device, self.torch_dtype),
            max_num_param=num_persistent_param_in_dit,
            overflow_module_config=_config(dtype, "cpu", self.torch_dtype),
        )

        dtype = next(iter(self.vae.parameters())).dtype
        enable_vram_management(
            self.vae,
            module_map={
                torch.nn.Linear: AutoWrappedLinear,
                torch.nn.Conv2d: AutoWrappedModule,
                RMS_norm: AutoWrappedModule,
                CausalConv3d: AutoWrappedModule,
                Upsample: AutoWrappedModule,
                torch.nn.SiLU: AutoWrappedModule,
                torch.nn.Dropout: AutoWrappedModule,
            },
            module_config=_config(dtype, self.device, self.torch_dtype),
        )

        if self.image_encoder is not None:
            dtype = next(iter(self.image_encoder.parameters())).dtype
            enable_vram_management(
                self.image_encoder,
                module_map={
                    torch.nn.Linear: AutoWrappedLinear,
                    torch.nn.Conv2d: AutoWrappedModule,
                    torch.nn.LayerNorm: AutoWrappedModule,
                },
                module_config=_config(dtype, "cpu", dtype),
            )

        if self.motion_controller is not None:
            dtype = next(iter(self.motion_controller.parameters())).dtype
            enable_vram_management(
                self.motion_controller,
                module_map={
                    torch.nn.Linear: AutoWrappedLinear,
                },
                module_config=_config(dtype, "cpu", dtype),
            )

        if self.vace is not None:
            # Use VACE's own parameter dtype; previously the dtype of whichever
            # model was wrapped last leaked into this config.
            dtype = next(iter(self.vace.parameters())).dtype
            enable_vram_management(
                self.vace,
                module_map={
                    torch.nn.Linear: AutoWrappedLinear,
                    torch.nn.Conv3d: AutoWrappedModule,
                    torch.nn.LayerNorm: AutoWrappedModule,
                    RMSNorm: AutoWrappedModule,
                },
                module_config=_config(dtype, self.device, self.torch_dtype),
            )

        self.enable_cpu_offload()

    def fetch_models(self, model_manager: ModelManager):
        """Attach the Wan sub-models held by ``model_manager`` to this pipeline."""
        text_encoder_model_and_path = model_manager.fetch_model("wan_video_text_encoder", require_model_path=True)
        if text_encoder_model_and_path is not None:
            self.text_encoder, tokenizer_path = text_encoder_model_and_path
            self.prompter.fetch_models(self.text_encoder)
            # The tokenizer is expected next to the text-encoder checkpoint.
            self.prompter.fetch_tokenizer(os.path.join(os.path.dirname(tokenizer_path), "google/umt5-xxl"))
        self.dit = model_manager.fetch_model("wan_video_dit")
        self.vae = model_manager.fetch_model("wan_video_vae")
        self.image_encoder = model_manager.fetch_model("wan_video_image_encoder")
        self.motion_controller = model_manager.fetch_model("wan_video_motion_controller")
        self.vace = model_manager.fetch_model("wan_video_vace")

    @staticmethod
    def from_model_manager(model_manager: ModelManager, torch_dtype=None, device=None, use_usp=False):
        """Build a pipeline from ``model_manager``.

        ``device``/``torch_dtype`` default to the manager's. With ``use_usp`` the
        DiT's attention and forward are replaced by xfuser's unified
        sequence-parallel versions.
        """
        if device is None:
            device = model_manager.device
        if torch_dtype is None:
            torch_dtype = model_manager.torch_dtype
        pipe = WanVideoPipeline(device=device, torch_dtype=torch_dtype)
        pipe.fetch_models(model_manager)
        if use_usp:
            from xfuser.core.distributed import get_sequence_parallel_world_size
            from ..distributed.xdit_context_parallel import usp_attn_forward, usp_dit_forward

            for block in pipe.dit.blocks:
                block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn)
            pipe.dit.forward = types.MethodType(usp_dit_forward, pipe.dit)
            pipe.sp_size = get_sequence_parallel_world_size()
            pipe.use_unified_sequence_parallel = True
        return pipe

    def denoising_model(self):
        """Return the model being denoised (the DiT)."""
        return self.dit

    def encode_prompt(self, prompt, positive=True):
        """Encode ``prompt`` with the text encoder; returns ``{"context": embedding}``."""
        prompt_emb = self.prompter.encode_prompt(prompt, positive=positive, device=self.device)
        return {"context": prompt_emb}

    # Used by image-to-video ("Inp") models.
    def encode_image(self, image, end_image, num_frames, height, width, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)):
        """Build image-to-video conditioning from a first (and optional last) frame.

        Returns ``{"clip_feature": ..., "y": ...}``. ``y`` is a 4-channel frame mask
        (1 = given frame, 0 = frame to generate) concatenated with the VAE latents of
        the given frames padded with zero frames.
        """
        image = self.preprocess_image(image.resize((width, height))).to(self.device)  # (1, C, H, W) per author's note
        clip_context = self.image_encoder.encode_image([image])
        msk = torch.ones(1, num_frames, height // 8, width // 8, device=self.device)
        msk[:, 1:] = 0  # only the first frame is given
        if end_image is not None:
            end_image = self.preprocess_image(end_image.resize((width, height))).to(self.device)
            vae_input = torch.concat([
                image.transpose(0, 1),
                torch.zeros(3, num_frames - 2, height, width).to(image.device),
                end_image.transpose(0, 1),
            ], dim=1)
            if self.dit.has_image_pos_emb:
                clip_context = torch.concat([clip_context, self.image_encoder.encode_image([end_image])], dim=1)
            msk[:, -1:] = 1  # the last frame is given as well
        else:
            # First frame followed by zero frames; (3, num_frames, H, W).
            vae_input = torch.concat([
                image.transpose(0, 1),
                torch.zeros(3, num_frames - 1, height, width).to(image.device),
            ], dim=1)

        # Repeat the first-frame mask to 4 copies so the frame count becomes a
        # multiple of 4 (e.g. 49 + 3 = 52 -> 52 // 4 = 13 latent frames), then fold
        # groups of 4 frames into the channel axis: (4, num_frames // 4, H/8, W/8).
        msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
        msk = msk.view(1, msk.shape[1] // 4, 4, height // 8, width // 8)
        msk = msk.transpose(1, 2)[0]

        y = self.vae.encode([vae_input.to(dtype=self.torch_dtype, device=self.device)], device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0]
        y = y.to(dtype=self.torch_dtype, device=self.device)
        y = torch.concat([msk, y])
        y = y.unsqueeze(0)
        clip_context = clip_context.to(dtype=self.torch_dtype, device=self.device)
        y = y.to(dtype=self.torch_dtype, device=self.device)
        return {"clip_feature": clip_context, "y": y}

    def encode_control_video(self, control_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
        """VAE-encode a list of control frames into latents (in ``self.torch_dtype``)."""
        # Normalise each frame, then stack along a new frame axis; per the author's
        # debug print this yields e.g. (1, 3, 49, 800, 1920) with values in [-1, 1].
        control_video = self.preprocess_images(control_video)
        control_video = torch.stack(control_video, dim=2).to(dtype=self.torch_dtype, device=self.device)
        latents = self.encode_video(control_video, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=self.torch_dtype, device=self.device)
        return latents

    def prepare_reference_image(self, reference_image, height=480, width=832):
        """VAE-encode a reference image; returns ``{"reference_latents": ...}``, or ``{}`` if None."""
        if reference_image is None:
            return {}
        self.load_models_to_device(["vae"])
        reference_image = reference_image.resize((width, height))
        reference_image = self.preprocess_images([reference_image])
        # Stack into a one-frame video (1, 3, 1, H, W) before VAE encoding; per the
        # author's note the latents come out as e.g. (1, 16, 1, 60, 104) for 480x832.
        reference_image = torch.stack(reference_image, dim=2).to(dtype=self.torch_dtype, device=self.device)
        reference_latents = self.vae.encode(reference_image, device=self.device)
        return {"reference_latents": reference_latents}

    def image_clip_feature(self, image, height, width):
        """Return the image encoder's CLIP feature for one frame.

        ``image`` may be an H x W x C array (converted via ``Image.fromarray``) or a
        PIL image. Per the author's note the output is (1, 257, 1280).
        """
        if not isinstance(image, Image.Image):
            image = Image.fromarray(image)
        image = image.convert('RGB')
        image = self.preprocess_image(image.resize((width, height))).to(self.device)
        clip_feature = self.image_encoder.encode_image([image]).to(self.device)
        clip_feature = clip_feature.to(dtype=self.torch_dtype, device=self.device)
        return clip_feature

    def prepare_controlnet_kwargs(self, control_video, num_frames, height, width, clip_feature=None, more_cond=None, cond_mode=None, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
        """Build ``{"clip_feature", "y"}`` from a control video plus an extra condition.

        - ``more_cond`` None: the extra condition is all-zero latents
          (1, 16, (num_frames - 1) // 4 + 1, H/8, W/8).
        - ``cond_mode`` in ('v2v', 'v2v_bg_fg'): ``more_cond`` is a frame list
          that is VAE-encoded like the control video.
        - otherwise (e.g. 'inp'): ``more_cond`` is used as-is, presumably
          already-latent tensors such as mask latents.

        For ``cond_mode`` in ('inp', 'v2v_bg_fg', 'test') the extra condition comes
        first on the channel axis; otherwise the control latents come first. A missing
        ``clip_feature`` is replaced by zeros (1, 257, 1280). With no control video,
        ``y`` is None.
        """
        y = None
        if control_video is not None:
            control_latents = self.encode_control_video(control_video, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
            if clip_feature is None:
                clip_feature = torch.zeros((1, 257, 1280), dtype=self.torch_dtype, device=self.device)
            if more_cond is None:
                y0 = torch.zeros((1, 16, (num_frames - 1) // 4 + 1, height // 8, width // 8), dtype=self.torch_dtype, device=self.device)
            elif cond_mode in ['v2v', 'v2v_bg_fg']:
                y0 = self.encode_control_video(more_cond, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
            else:  # e.g. cond_mode == 'inp'
                y0 = more_cond.to(dtype=self.torch_dtype, device=self.device)
            if cond_mode in ['inp', 'v2v_bg_fg', 'test']:
                y = torch.concat([y0, control_latents], dim=1)
            else:
                y = torch.concat([control_latents, y0], dim=1)
        # e.g. clip_feature (1, 257, 1280), y (1, 16 + 16, 13, 100, 240)
        return {"clip_feature": clip_feature, "y": y}

    # Original upstream version, kept for reference.
    def prepare_controlnet_kwargs0(self, control_video, num_frames, height, width, clip_feature=None, y=None, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
        """Upstream variant: control latents followed by the last 16 channels of ``y`` (or zeros)."""
        if control_video is not None:
            control_latents = self.encode_control_video(control_video, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
            if clip_feature is None or y is None:
                clip_feature = torch.zeros((1, 257, 1280), dtype=self.torch_dtype, device=self.device)
                y = torch.zeros((1, 16, (num_frames - 1) // 4 + 1, height // 8, width // 8), dtype=self.torch_dtype, device=self.device)
            else:
                y = y[:, -16:]
            y = torch.concat([control_latents, y], dim=1)
        return {"clip_feature": clip_feature, "y": y}

    def tensor2video(self, frames):
        """Convert a (C, T, H, W) tensor in [-1, 1] to a list of uint8 PIL images."""
        frames = rearrange(frames, "C T H W -> T H W C")
        frames = ((frames.float() + 1) * 127.5).clip(0, 255).cpu().numpy().astype(np.uint8)
        frames = [Image.fromarray(frame) for frame in frames]
        return frames

    def prepare_extra_input(self, latents=None):
        """Hook for additional model inputs; none are needed for Wan."""
        return {}

    def encode_video(self, input_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
        """VAE-encode a video tensor."""
        latents = self.vae.encode(input_video, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
        return latents

    def decode_video(self, latents, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
        """VAE-decode latents into a video tensor."""
        frames = self.vae.decode(latents, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
        return frames

    def prepare_unified_sequence_parallel(self):
        """Return the sequence-parallel flag as model kwargs."""
        return {"use_unified_sequence_parallel": self.use_unified_sequence_parallel}

    def prepare_motion_bucket_id(self, motion_bucket_id):
        """Wrap ``motion_bucket_id`` into a 1-element tensor for the motion controller."""
        motion_bucket_id = torch.Tensor((motion_bucket_id,)).to(dtype=self.torch_dtype, device=self.device)
        return {"motion_bucket_id": motion_bucket_id}

    def prepare_vace_kwargs(
        self,
        latents,
        vace_video=None,
        vace_mask=None,
        vace_reference_image=None,
        vace_scale=1.0,
        height=480,
        width=832,
        num_frames=81,
        seed=None,
        rand_device="cpu",
        tiled=True,
        tile_size=(34, 34),
        tile_stride=(18, 16),
    ):
        """Build the VACE context.

        Returns ``(latents, {"vace_context", "vace_scale"})``. When a reference image
        is given, one extra noise frame is prepended to ``latents`` along the frame
        axis (the caller strips it after denoising). With no VACE inputs,
        ``vace_context`` is None.
        """
        if vace_video is None and vace_mask is None and vace_reference_image is None:
            return latents, {"vace_context": None, "vace_scale": vace_scale}

        self.load_models_to_device(["vae"])
        if vace_video is None:
            vace_video = torch.zeros((1, 3, num_frames, height, width), dtype=self.torch_dtype, device=self.device)
        else:
            vace_video = self.preprocess_images(vace_video)
            vace_video = torch.stack(vace_video, dim=2).to(dtype=self.torch_dtype, device=self.device)

        if vace_mask is None:
            vace_mask = torch.ones_like(vace_video)
        else:
            vace_mask = self.preprocess_images(vace_mask)
            vace_mask = torch.stack(vace_mask, dim=2).to(dtype=self.torch_dtype, device=self.device)
            # preprocess_images normalises to [-1, 1] (see the range notes in
            # encode_control_video); the mask is a [0, 1] blend weight below.
            vace_mask = (vace_mask + 1) / 2

        inactive = vace_video * (1 - vace_mask) + 0 * vace_mask
        reactive = vace_video * vace_mask + 0 * (1 - vace_mask)
        inactive = self.encode_video(inactive, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=self.torch_dtype, device=self.device)
        reactive = self.encode_video(reactive, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=self.torch_dtype, device=self.device)
        vace_video_latents = torch.concat((inactive, reactive), dim=1)

        # Fold each 8x8 pixel patch of the mask into channels (64 channels), then
        # resample the frame axis to (T + 3) // 4 latent frames.
        vace_mask_latents = rearrange(vace_mask[0, 0], "T (H P) (W Q) -> 1 (P Q) T H W", P=8, Q=8)
        vace_mask_latents = torch.nn.functional.interpolate(
            vace_mask_latents,
            size=((vace_mask_latents.shape[2] + 3) // 4, vace_mask_latents.shape[3], vace_mask_latents.shape[4]),
            mode='nearest-exact',
        )

        if vace_reference_image is not None:
            vace_reference_image = self.preprocess_images([vace_reference_image])
            vace_reference_image = torch.stack(vace_reference_image, dim=2).to(dtype=self.torch_dtype, device=self.device)
            vace_reference_latents = self.encode_video(vace_reference_image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=self.torch_dtype, device=self.device)
            vace_reference_latents = torch.concat((vace_reference_latents, torch.zeros_like(vace_reference_latents)), dim=1)
            # Prepend the reference as an extra leading frame (mask 0 for it).
            vace_video_latents = torch.concat((vace_reference_latents, vace_video_latents), dim=2)
            vace_mask_latents = torch.concat((torch.zeros_like(vace_mask_latents[:, :, :1]), vace_mask_latents), dim=2)

            noise = self.generate_noise((1, 16, 1, latents.shape[3], latents.shape[4]), seed=seed, device=rand_device, dtype=torch.float32)
            noise = noise.to(dtype=self.torch_dtype, device=self.device)
            latents = torch.concat((noise, latents), dim=2)

        vace_context = torch.concat((vace_video_latents, vace_mask_latents), dim=1)
        return latents, {"vace_context": vace_context, "vace_scale": vace_scale}

    @torch.no_grad()
    def __call__(
        self,
        prompt,
        negative_prompt="",
        input_image=None,
        end_image=None,
        input_video=None,
        control_video=None,
        reference_image=None,
        vace_video=None,
        vace_video_mask=None,
        vace_reference_image=None,
        vace_scale=1.0,
        denoising_strength=1.0,
        seed=None,
        rand_device="cpu",
        height=480,
        width=832,
        num_frames=81,
        cfg_scale=5.0,
        num_inference_steps=50,
        sigma_shift=5.0,
        motion_bucket_id=None,
        tiled=True,
        tile_size=(30, 52),
        tile_stride=(15, 26),
        tea_cache_l1_thresh=None,
        tea_cache_model_id="",
        progress_bar_cmd=tqdm,
        progress_bar_st=None,
        cond_mode=None,
    ):
        """Generate a video and return it as a list of PIL images.

        ``cond_mode`` selects how ``reference_image`` is used:
        - 'i2v': VAE-encoded and fed to the DiT as ``reference_latents``.
        - otherwise: passed to `prepare_controlnet_kwargs` as ``more_cond``, together
          with ``control_video`` ('v2v'/'v2v_bg_fg': a frame list; 'inp': a tensor).

        ``num_frames`` is rounded up to the next value with ``num_frames % 4 == 1``.
        ``progress_bar_st`` is accepted but unused.
        """
        # Parameter check
        height, width = self.check_resize_height_width(height, width)
        if num_frames % 4 != 1:
            num_frames = (num_frames + 2) // 4 * 4 + 1
            print(f"Only `num_frames % 4 == 1` is acceptable. We round it up to {num_frames}.")

        # Tiler parameters
        tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}

        # Scheduler
        self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, shift=sigma_shift)

        # Initialize noise
        noise = self.generate_noise((1, 16, (num_frames - 1) // 4 + 1, height // 8, width // 8), seed=seed, device=rand_device, dtype=torch.float32)
        noise = noise.to(dtype=self.torch_dtype, device=self.device)
        if input_video is not None:
            self.load_models_to_device(['vae'])
            input_video = self.preprocess_images(input_video)
            input_video = torch.stack(input_video, dim=2).to(dtype=self.torch_dtype, device=self.device)
            latents = self.encode_video(input_video, **tiler_kwargs).to(dtype=self.torch_dtype, device=self.device)
            latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
        else:
            latents = noise

        # Encode prompts
        self.load_models_to_device(["text_encoder"])
        prompt_emb_posi = self.encode_prompt(prompt, positive=True)
        if cfg_scale != 1.0:
            prompt_emb_nega = self.encode_prompt(negative_prompt, positive=False)

        # Encode image
        if input_image is not None and self.image_encoder is not None:
            self.load_models_to_device(["image_encoder", "vae"])
            image_emb = self.encode_image(input_image, end_image, num_frames, height, width, **tiler_kwargs)
        else:
            image_emb = {}

        # Reference image: at most one of reference_image_kwargs / more_cond is used.
        if reference_image is not None and cond_mode == 'i2v':
            reference_image_kwargs = self.prepare_reference_image(reference_image, height, width)
            more_cond = None
        else:
            # Background reference video ('v2v') or mask latents ('inp').
            more_cond = reference_image
            reference_image_kwargs = {}

        # ControlNet
        if control_video is not None:
            self.load_models_to_device(["image_encoder", "vae"])
            # CLIP feature of the first control frame helps preserve foreground
            # identity; without an image encoder prepare_controlnet_kwargs zero-fills it.
            if self.image_encoder is not None:
                clip_feature = self.image_clip_feature(control_video[0], height, width)
            else:
                clip_feature = None
            # image_emb is not forwarded: its "clip_feature"/"y" keys would clash
            # with this signature (duplicate/unknown kwargs). The control branch
            # supplies both values itself.
            image_emb = self.prepare_controlnet_kwargs(
                control_video, num_frames, height, width, clip_feature,
                more_cond=more_cond, cond_mode=cond_mode, **tiler_kwargs,
            )

        # Motion Controller
        if self.motion_controller is not None and motion_bucket_id is not None:
            motion_kwargs = self.prepare_motion_bucket_id(motion_bucket_id)
        else:
            motion_kwargs = {}

        # Extra input
        extra_input = self.prepare_extra_input(latents)

        # VACE
        latents, vace_kwargs = self.prepare_vace_kwargs(
            latents, vace_video, vace_video_mask, vace_reference_image, vace_scale,
            height=height, width=width, num_frames=num_frames, seed=seed, rand_device=rand_device, **tiler_kwargs
        )

        # TeaCache: separate caches for the positive and negative branches.
        tea_cache_posi = {"tea_cache": TeaCache(num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh, model_id=tea_cache_model_id) if tea_cache_l1_thresh is not None else None}
        tea_cache_nega = {"tea_cache": TeaCache(num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh, model_id=tea_cache_model_id) if tea_cache_l1_thresh is not None else None}

        # Unified Sequence Parallel
        usp_kwargs = self.prepare_unified_sequence_parallel()

        # Denoise
        self.load_models_to_device(["dit", "motion_controller", "vace"])
        for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
            timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)

            # Inference
            noise_pred_posi = model_fn_wan_video(
                self.dit, motion_controller=self.motion_controller, vace=self.vace,
                x=latents, timestep=timestep,
                **prompt_emb_posi, **image_emb, **extra_input,
                **tea_cache_posi, **usp_kwargs, **motion_kwargs, **vace_kwargs, **reference_image_kwargs,
            )
            if cfg_scale != 1.0:
                noise_pred_nega = model_fn_wan_video(
                    self.dit, motion_controller=self.motion_controller, vace=self.vace,
                    x=latents, timestep=timestep,
                    **prompt_emb_nega, **image_emb, **extra_input,
                    **tea_cache_nega, **usp_kwargs, **motion_kwargs, **vace_kwargs, **reference_image_kwargs,
                )
                noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
            else:
                noise_pred = noise_pred_posi

            # Scheduler
            latents = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents)

        # Drop the extra leading frame prepend_vace_kwargs added for the reference image.
        if vace_reference_image is not None:
            latents = latents[:, :, 1:]

        # Decode
        self.load_models_to_device(['vae'])
        frames = self.decode_video(latents, **tiler_kwargs)
        self.load_models_to_device([])
        frames = self.tensor2video(frames[0])

        return frames


class TeaCache:
    """Step-skipping cache for the DiT blocks (TeaCache).

    Accumulates a polynomially rescaled relative L1 change of the timestep
    modulation between calls. While the accumulated value stays below
    ``rel_l1_thresh``, the blocks are skipped and the last stored residual is
    re-applied instead.
    """

    def __init__(self, num_inference_steps, rel_l1_thresh, model_id):
        self.num_inference_steps = num_inference_steps
        self.step = 0
        self.accumulated_rel_l1_distance = 0
        self.previous_modulated_input = None
        self.rel_l1_thresh = rel_l1_thresh
        self.previous_residual = None
        self.previous_hidden_states = None

        self.coefficients_dict = {
            "Wan2.1-T2V-1.3B": [-5.21862437e+04, 9.23041404e+03, -5.28275948e+02, 1.36987616e+01, -4.99875664e-02],
            "Wan2.1-T2V-14B": [-3.03318725e+05, 4.90537029e+04, -2.65530556e+03, 5.87365115e+01, -3.15583525e-01],
            "Wan2.1-I2V-14B-480P": [2.57151496e+05, -3.54229917e+04, 1.40286849e+03, -1.35890334e+01, 1.32517977e-01],
            "Wan2.1-I2V-14B-720P": [8.10705460e+03, 2.13393892e+03, -3.72934672e+02, 1.66203073e+01, -4.17769401e-02],
        }
        if model_id not in self.coefficients_dict:
            supported_model_ids = ", ".join([i for i in self.coefficients_dict])
            raise ValueError(f"{model_id} is not a supported TeaCache model id. Please choose a valid model id in ({supported_model_ids}).")
        self.coefficients = self.coefficients_dict[model_id]

    def check(self, dit: WanModel, x, t_mod):
        """Return True when this step may skip the blocks and reuse the cached residual.

        The first and last step of each run are always computed. When computing,
        the current input ``x`` is saved so `store` can derive the residual.
        """
        modulated_inp = t_mod.clone()
        if self.step == 0 or self.step == self.num_inference_steps - 1:
            should_calc = True
            self.accumulated_rel_l1_distance = 0
        else:
            coefficients = self.coefficients
            rescale_func = np.poly1d(coefficients)
            self.accumulated_rel_l1_distance += rescale_func(((modulated_inp - self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item())
            if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
                should_calc = False
            else:
                should_calc = True
                self.accumulated_rel_l1_distance = 0
        self.previous_modulated_input = modulated_inp
        self.step += 1
        if self.step == self.num_inference_steps:
            self.step = 0
        if should_calc:
            self.previous_hidden_states = x.clone()
        return not should_calc

    def store(self, hidden_states):
        """Record the residual added by the blocks on a computed step."""
        self.previous_residual = hidden_states - self.previous_hidden_states
        self.previous_hidden_states = None

    def update(self, hidden_states):
        """Apply the cached residual in place of running the blocks."""
        hidden_states = hidden_states + self.previous_residual
        return hidden_states


# Legacy forward function.
def model_fn_wan_video0(
    dit: WanModel,
    motion_controller: WanMotionControllerModel = None,
    vace: VaceWanModel = None,
    x: torch.Tensor = None,
    timestep: torch.Tensor = None,
    context: torch.Tensor = None,
    clip_feature: Optional[torch.Tensor] = None,
    y: Optional[torch.Tensor] = None,
    reference_latents=None,
    vace_context=None,
    vace_scale=1.0,
    tea_cache: TeaCache = None,
    use_unified_sequence_parallel: bool = False,
    motion_bucket_id: Optional[torch.Tensor] = None,
    **kwargs,
):
    """Legacy DiT forward pass (superseded by `model_fn_wan_video`); returns the prediction.

    Reference-token stripping and VACE-hint sharding under sequence parallelism
    follow `model_fn_wan_video`.
    """
    if use_unified_sequence_parallel:
        import torch.distributed as dist
        from xfuser.core.distributed import (get_sequence_parallel_rank,
                                             get_sequence_parallel_world_size,
                                             get_sp_group)

    t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep))
    t_mod = dit.time_projection(t).unflatten(1, (6, dit.dim))
    if motion_bucket_id is not None and motion_controller is not None:
        t_mod = t_mod + motion_controller(motion_bucket_id).unflatten(1, (6, dit.dim))
    context = dit.text_embedding(context)

    if dit.has_image_input:
        # Only this branch consumes y/clip_feature, so callers passing them need
        # dit.has_image_input == True.
        x = torch.cat([x, y], dim=1)  # (b, c_x + c_y, f, h, w)
        clip_embdding = dit.img_emb(clip_feature)
        context = torch.cat([clip_embdding, context], dim=1)

    x, (f, h, w) = dit.patchify(x)

    # Reference image: take frame 0 of the reference latents, project it with
    # ref_conv, flatten to (1, h*w, dim) tokens and prepend them on the sequence
    # axis, i.e. treat it as an extra frame 0 (f += 1, e.g. 13 -> 14).
    if reference_latents is not None:
        reference_latents = dit.ref_conv(reference_latents[:, :, 0]).flatten(2).transpose(1, 2)
        x = torch.concat([reference_latents, x], dim=1)
        f += 1

    freqs = torch.cat([
        dit.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
        dit.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
        dit.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
    ], dim=-1).reshape(f * h * w, 1, -1).to(x.device)

    # TeaCache
    if tea_cache is not None:
        tea_cache_update = tea_cache.check(dit, x, t_mod)
    else:
        tea_cache_update = False

    if vace_context is not None:
        vace_hints = vace(x, vace_context, context, t_mod, freqs)

    # blocks
    if use_unified_sequence_parallel:
        if dist.is_initialized() and dist.get_world_size() > 1:
            x = torch.chunk(x, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()]
    if tea_cache_update:
        x = tea_cache.update(x)
    else:
        for block_id, block in enumerate(dit.blocks):
            x = block(x, context, t_mod, freqs)
            if vace_context is not None and block_id in vace.vace_layers_mapping:
                current_vace_hint = vace_hints[vace.vace_layers_mapping[block_id]]
                # Shard the hint the same way x was sharded.
                if use_unified_sequence_parallel and dist.is_initialized() and dist.get_world_size() > 1:
                    current_vace_hint = torch.chunk(current_vace_hint, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()]
                x = x + current_vace_hint * vace_scale
        if tea_cache is not None:
            tea_cache.store(x)

    x = dit.head(x, t)
    if use_unified_sequence_parallel:
        if dist.is_initialized() and dist.get_world_size() > 1:
            x = get_sp_group().all_gather(x, dim=1)
    # Strip the reference tokens only once the full sequence is gathered; doing it
    # on a local shard would drop the wrong tokens.
    if reference_latents is not None:
        x = x[:, reference_latents.shape[1]:]
        f -= 1
    x = dit.unpatchify(x, (f, h, w))
    return x


# New forward function, copied from
# https://github.com/modelscope/DiffSynth-Studio/blob/main/diffsynth/pipelines/wan_video_new.py (2025-06-30),
# with the `latents` argument renamed to `x`.
def model_fn_wan_video(
    dit: WanModel,
    motion_controller: WanMotionControllerModel = None,
    vace: VaceWanModel = None,
    x: torch.Tensor = None,
    timestep: torch.Tensor = None,
    context: torch.Tensor = None,
    clip_feature: Optional[torch.Tensor] = None,
    y: Optional[torch.Tensor] = None,
    reference_latents=None,
    vace_context=None,
    vace_scale=1.0,
    tea_cache: TeaCache = None,
    use_unified_sequence_parallel: bool = False,
    motion_bucket_id: Optional[torch.Tensor] = None,
    sliding_window_size: Optional[int] = None,
    sliding_window_stride: Optional[int] = None,
    cfg_merge: bool = False,
    use_gradient_checkpointing: bool = False,
    use_gradient_checkpointing_offload: bool = False,
    control_camera_latents_input=None,
    **kwargs,
):
    """Run one DiT forward pass on latents ``x``; returns the model prediction.

    Optional inputs: image conditioning (``y``/``clip_feature``, used when
    ``dit.has_image_input``), reference latents prepended as an extra frame, VACE
    hints, TeaCache, unified sequence parallelism and gradient checkpointing.
    ``cfg_merge`` and ``control_camera_latents_input`` are accepted for interface
    compatibility but not used.

    Raises:
        NotImplementedError: if both sliding-window parameters are given.
    """
    if sliding_window_size is not None and sliding_window_stride is not None:
        # Upstream delegates to `TemporalTiler_BCTHW`, which is not defined or
        # imported in this module, and used a `latents` argument renamed here to
        # `x`; this branch could only fail with a NameError.
        raise NotImplementedError(
            "Sliding-window inference (sliding_window_size / sliding_window_stride) is not "
            "supported by this model_fn_wan_video: TemporalTiler_BCTHW is not available."
        )

    if use_unified_sequence_parallel:
        import torch.distributed as dist
        from xfuser.core.distributed import (get_sequence_parallel_rank,
                                             get_sequence_parallel_world_size,
                                             get_sp_group)

    t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep))
    t_mod = dit.time_projection(t).unflatten(1, (6, dit.dim))
    if motion_bucket_id is not None and motion_controller is not None:
        t_mod = t_mod + motion_controller(motion_bucket_id).unflatten(1, (6, dit.dim))
    context = dit.text_embedding(context)

    # Upstream's merged-CFG batching (repeating x/timestep to context's batch size)
    # is intentionally disabled in this variant.

    if dit.has_image_input:
        # Only this branch consumes y/clip_feature, so callers passing them need
        # dit.has_image_input == True.
        x = torch.cat([x, y], dim=1)  # (b, c_x + c_y, f, h, w)
        clip_embdding = dit.img_emb(clip_feature)
        context = torch.cat([clip_embdding, context], dim=1)

    # Camera control is not wired in: upstream passes control_camera_latents_input
    # to patchify; this variant patchifies x alone.
    x, (f, h, w) = dit.patchify(x)

    # Reference image: (for 5-D input) take frame 0, project with ref_conv, flatten
    # to (1, h*w, dim) tokens and prepend them on the sequence axis, i.e. treat it
    # as an extra frame 0 (f += 1, e.g. 13 -> 14).
    if reference_latents is not None:
        if len(reference_latents.shape) == 5:
            reference_latents = reference_latents[:, :, 0]
        reference_latents = dit.ref_conv(reference_latents).flatten(2).transpose(1, 2)
        x = torch.concat([reference_latents, x], dim=1)
        f += 1

    freqs = torch.cat([
        dit.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
        dit.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
        dit.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
    ], dim=-1).reshape(f * h * w, 1, -1).to(x.device)

    # TeaCache
    if tea_cache is not None:
        tea_cache_update = tea_cache.check(dit, x, t_mod)
    else:
        tea_cache_update = False

    if vace_context is not None:
        vace_hints = vace(x, vace_context, context, t_mod, freqs)

    # blocks
    if use_unified_sequence_parallel:
        if dist.is_initialized() and dist.get_world_size() > 1:
            x = torch.chunk(x, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()]
    if tea_cache_update:
        x = tea_cache.update(x)
    else:
        def create_custom_forward(module):
            def custom_forward(*inputs):
                return module(*inputs)
            return custom_forward

        for block_id, block in enumerate(dit.blocks):
            if use_gradient_checkpointing_offload:
                # Checkpoint and keep saved activations on CPU.
                with torch.autograd.graph.save_on_cpu():
                    x = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(block),
                        x, context, t_mod, freqs,
                        use_reentrant=False,
                    )
            elif use_gradient_checkpointing:  # enabled during training
                x = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    x, context, t_mod, freqs,
                    use_reentrant=False,
                )
            else:
                x = block(x, context, t_mod, freqs)
            if vace_context is not None and block_id in vace.vace_layers_mapping:
                current_vace_hint = vace_hints[vace.vace_layers_mapping[block_id]]
                # Shard the hint the same way x was sharded.
                if use_unified_sequence_parallel and dist.is_initialized() and dist.get_world_size() > 1:
                    current_vace_hint = torch.chunk(current_vace_hint, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()]
                x = x + current_vace_hint * vace_scale
        if tea_cache is not None:
            tea_cache.store(x)

    x = dit.head(x, t)
    if use_unified_sequence_parallel:
        if dist.is_initialized() and dist.get_world_size() > 1:
            x = get_sp_group().all_gather(x, dim=1)
    # Remove reference latents (after the gather, so the full sequence is present).
    if reference_latents is not None:
        x = x[:, reference_latents.shape[1]:]
        f -= 1
    x = dit.unpatchify(x, (f, h, w))
    return x