Diffusers documentation
Wan2.1
Wan2.1 is a series of large diffusion transformers available in two versions, a high-performance 14B parameter model and a more accessible 1.3B version. Trained on billions of images and videos, it supports tasks like text-to-video (T2V) and image-to-video (I2V) while enabling features such as camera control and stylistic diversity. The Wan-VAE features better image data compression and a feature cache mechanism that encodes and decodes a video in chunks. To maintain continuity, features from previous chunks are cached and reused when processing subsequent chunks. This improves inference efficiency by reducing memory usage. Wan2.1 also uses a multilingual text encoder, and the diffusion transformer models spatial and temporal relationships together with the text conditions at each timestep to capture more complex video dynamics.
You can find all the original Wan2.1 checkpoints under the Wan-AI organization.
Click on the Wan2.1 models in the right sidebar for more examples of video generation.
The example below demonstrates how to generate a video from text, optimized for memory usage and inference speed.
Refer to the Reduce memory usage guide for more details about the various memory saving techniques.
The Wan2.1 text-to-video model below requires ~13GB of VRAM.
# pip install ftfy
import torch
from diffusers import AutoModel, WanPipeline
from diffusers.hooks.group_offloading import apply_group_offloading
from diffusers.utils import export_to_video
from transformers import UMT5EncoderModel

text_encoder = UMT5EncoderModel.from_pretrained("Wan-AI/Wan2.1-T2V-14B-Diffusers", subfolder="text_encoder", torch_dtype=torch.bfloat16)
vae = AutoModel.from_pretrained("Wan-AI/Wan2.1-T2V-14B-Diffusers", subfolder="vae", torch_dtype=torch.float32)
transformer = AutoModel.from_pretrained("Wan-AI/Wan2.1-T2V-14B-Diffusers", subfolder="transformer", torch_dtype=torch.bfloat16)

# group offloading keeps only the currently needed blocks on the GPU
onload_device = torch.device("cuda")
offload_device = torch.device("cpu")
apply_group_offloading(
    text_encoder,
    onload_device=onload_device,
    offload_device=offload_device,
    offload_type="block_level",
    num_blocks_per_group=4,
)
transformer.enable_group_offload(
    onload_device=onload_device,
    offload_device=offload_device,
    offload_type="leaf_level",
    use_stream=True,
)
pipeline = WanPipeline.from_pretrained(
"Wan-AI/Wan2.1-T2V-14B-Diffusers",
vae=vae,
transformer=transformer,
text_encoder=text_encoder,
torch_dtype=torch.bfloat16
)
pipeline.to("cuda")
prompt = """
The camera rushes from far to near in a low-angle shot,
revealing a white ferret on a log. It plays, leaps into the water, and emerges, as the camera zooms in
for a close-up. Water splashes berry bushes nearby, while moss, snow, and leaves blanket the ground.
Birch trees and a light blue sky frame the scene, with ferns in the foreground. Side lighting casts dynamic
shadows and warm highlights. Medium composition, front view, low angle, with depth of field.
"""
negative_prompt = """
Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality,
low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured,
misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards
"""
output = pipeline(
prompt=prompt,
negative_prompt=negative_prompt,
num_frames=81,
guidance_scale=5.0,
).frames[0]
export_to_video(output, "output.mp4", fps=16)
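The example above saves memory through offloading. As an alternative, components can be quantized at load time with PipelineQuantizationConfig. Below is a minimal sketch, assuming the bitsandbytes backend is installed; the exact quantization settings shown are illustrative starting points, not recommended values from the Wan2.1 authors.

# pip install ftfy bitsandbytes
import torch
from diffusers import WanPipeline
from diffusers.quantizers import PipelineQuantizationConfig
from diffusers.utils import export_to_video

# quantize the two largest components to 4-bit to reduce VRAM usage
pipeline = WanPipeline.from_pretrained(
    "Wan-AI/Wan2.1-T2V-14B-Diffusers",
    quantization_config=PipelineQuantizationConfig(
        quant_backend="bitsandbytes_4bit",
        quant_kwargs={
            "load_in_4bit": True,
            "bnb_4bit_quant_type": "nf4",
            "bnb_4bit_compute_dtype": torch.bfloat16,
        },
        components_to_quantize=["transformer", "text_encoder"],
    ),
    torch_dtype=torch.bfloat16,
)
pipeline.to("cuda")

output = pipeline(prompt="A white ferret playing on a log", num_frames=81).frames[0]
export_to_video(output, "output.mp4", fps=16)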
Notes
Wan2.1 supports LoRAs with load_lora_weights().
# pip install ftfy
import torch
from diffusers import AutoModel, WanPipeline
from diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler
from diffusers.utils import export_to_video

vae = AutoModel.from_pretrained(
    "Wan-AI/Wan2.1-T2V-1.3B-Diffusers", subfolder="vae", torch_dtype=torch.float32
)
pipeline = WanPipeline.from_pretrained(
    "Wan-AI/Wan2.1-T2V-1.3B-Diffusers", vae=vae, torch_dtype=torch.bfloat16
)
pipeline.scheduler = UniPCMultistepScheduler.from_config(
    pipeline.scheduler.config, flow_shift=5.0
)
pipeline.to("cuda")

pipeline.load_lora_weights("benjamin-paine/steamboat-willie-1.3b", adapter_name="steamboat-willie")
pipeline.set_adapters("steamboat-willie")
pipeline.enable_model_cpu_offload()

# use "steamboat willie style" to trigger the LoRA
prompt = """
steamboat willie style, golden era animation, The camera rushes from far to near in a low-angle shot,
revealing a white ferret on a log. It plays, leaps into the water, and emerges, as the camera zooms in
for a close-up. Water splashes berry bushes nearby, while moss, snow, and leaves blanket the ground.
Birch trees and a light blue sky frame the scene, with ferns in the foreground. Side lighting casts dynamic
shadows and warm highlights. Medium composition, front view, low angle, with depth of field.
"""

output = pipeline(
    prompt=prompt,
    num_frames=81,
    guidance_scale=5.0,
).frames[0]
export_to_video(output, "output.mp4", fps=16)
WanTransformer3DModel and AutoencoderKLWan support loading from single files with from_single_file().
# pip install ftfy
import torch
from diffusers import WanPipeline, AutoModel

vae = AutoModel.from_single_file(
    "https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/blob/main/split_files/vae/wan_2.1_vae.safetensors"
)
transformer = AutoModel.from_single_file(
    "https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/blob/main/split_files/diffusion_models/wan2.1_t2v_1.3B_bf16.safetensors",
    torch_dtype=torch.bfloat16
)
pipeline = WanPipeline.from_pretrained(
    "Wan-AI/Wan2.1-T2V-1.3B-Diffusers",
    vae=vae,
    transformer=transformer,
    torch_dtype=torch.bfloat16
)
- Set the AutoencoderKLWan dtype to torch.float32 for better decoding quality.
- The number of frames (num_frames) should satisfy 4 * k + 1 for some integer k; the examples above use 81.
- Try lower shift values (2.0 to 5.0) for lower resolution videos and higher shift values (7.0 to 12.0) for higher resolution videos, as sketched below.
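A minimal sketch applying the last two notes, assuming the 1.3B text-to-video checkpoint; the flow_shift value is a starting point to tune, not a fixed rule.

import torch
from diffusers import WanPipeline
from diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler

pipeline = WanPipeline.from_pretrained(
    "Wan-AI/Wan2.1-T2V-1.3B-Diffusers", torch_dtype=torch.bfloat16
)

# pick num_frames as 4 * k + 1, e.g. k = 20 gives the default 81 frames
k = 20
num_frames = 4 * k + 1

# lower flow_shift for lower resolutions (e.g. 3.0 for 480P),
# higher for higher resolutions (e.g. 5.0+ for 720P)
pipeline.scheduler = UniPCMultistepScheduler.from_config(
    pipeline.scheduler.config, flow_shift=3.0
)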
WanPipeline
class diffusers.WanPipeline
< source >( tokenizer: AutoTokenizer text_encoder: UMT5EncoderModel transformer: WanTransformer3DModel vae: AutoencoderKLWan scheduler: FlowMatchEulerDiscreteScheduler )
Parameters

- tokenizer (T5Tokenizer) — Tokenizer from T5, specifically the google/umt5-xxl variant.
- text_encoder (T5EncoderModel) — T5, specifically the google/umt5-xxl variant.
- transformer (WanTransformer3DModel) — Conditional Transformer to denoise the input latents.
- vae (AutoencoderKLWan) — Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
- scheduler (UniPCMultistepScheduler) — A scheduler to be used in combination with transformer to denoise the encoded image latents.
Pipeline for text-to-video generation using Wan.
This model inherits from DiffusionPipeline. Check the superclass documentation for the generic methods implemented for all pipelines (downloading, saving, running on a particular device, etc.).
__call__
< source >( prompt: typing.Union[str, typing.List[str]] = None negative_prompt: typing.Union[str, typing.List[str]] = None height: int = 480 width: int = 832 num_frames: int = 81 num_inference_steps: int = 50 guidance_scale: float = 5.0 num_videos_per_prompt: typing.Optional[int] = 1 generator: typing.Union[torch._C.Generator, typing.List[torch._C.Generator], NoneType] = None latents: typing.Optional[torch.Tensor] = None prompt_embeds: typing.Optional[torch.Tensor] = None negative_prompt_embeds: typing.Optional[torch.Tensor] = None output_type: typing.Optional[str] = 'np' return_dict: bool = True attention_kwargs: typing.Optional[typing.Dict[str, typing.Any]] = None callback_on_step_end: typing.Union[typing.Callable[[int, int, typing.Dict], NoneType], diffusers.callbacks.PipelineCallback, diffusers.callbacks.MultiPipelineCallbacks, NoneType] = None callback_on_step_end_tensor_inputs: typing.List[str] = ['latents'] max_sequence_length: int = 512 ) → ~WanPipelineOutput
or tuple
Parameters

- prompt (str or List[str], optional) — The prompt or prompts to guide the video generation. If not defined, you need to pass prompt_embeds instead.
- negative_prompt (str or List[str], optional) — The prompt or prompts not to guide the video generation. If not defined, you need to pass negative_prompt_embeds instead. Ignored when not using guidance (i.e., ignored if guidance_scale is less than 1).
- height (int, defaults to 480) — The height in pixels of the generated video.
- width (int, defaults to 832) — The width in pixels of the generated video.
- num_frames (int, defaults to 81) — The number of frames in the generated video.
- num_inference_steps (int, defaults to 50) — The number of denoising steps. More denoising steps usually lead to a higher quality video at the expense of slower inference.
- guidance_scale (float, defaults to 5.0) — Guidance scale as defined in Classifier-Free Diffusion Guidance. guidance_scale is defined as w of equation 2 of the Imagen paper. Guidance scale is enabled by setting guidance_scale > 1. A higher guidance scale encourages generating videos closely linked to the text prompt, usually at the expense of lower quality.
- num_videos_per_prompt (int, optional, defaults to 1) — The number of videos to generate per prompt.
- generator (torch.Generator or List[torch.Generator], optional) — A torch.Generator to make generation deterministic.
- latents (torch.Tensor, optional) — Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor is generated by sampling using the supplied random generator.
- prompt_embeds (torch.Tensor, optional) — Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not provided, text embeddings are generated from the prompt input argument.
- negative_prompt_embeds (torch.Tensor, optional) — Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not provided, negative text embeddings are generated from the negative_prompt input argument.
- output_type (str, optional, defaults to "np") — The output format of the generated video. Choose between PIL.Image or np.array.
- return_dict (bool, optional, defaults to True) — Whether or not to return a WanPipelineOutput instead of a plain tuple.
- attention_kwargs (dict, optional) — A kwargs dictionary that, if specified, is passed along to the AttentionProcessor as defined under self.processor in diffusers.models.attention_processor.
- callback_on_step_end (Callable, PipelineCallback, MultiPipelineCallbacks, optional) — A function or a subclass of PipelineCallback or MultiPipelineCallbacks that is called at the end of each denoising step during inference with the following arguments: callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict). callback_kwargs will include a list of all tensors as specified by callback_on_step_end_tensor_inputs.
- callback_on_step_end_tensor_inputs (List, optional) — The list of tensor inputs for the callback_on_step_end function. The tensors specified in the list will be passed as the callback_kwargs argument. You can only include variables listed in the ._callback_tensor_inputs attribute of your pipeline class.
- max_sequence_length (int, optional, defaults to 512) — The maximum sequence length of the prompt.
- autocast_dtype (torch.dtype, optional, defaults to torch.bfloat16) — The dtype to use for torch.amp.autocast.
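For reference, w above is the classifier-free guidance weight from equation 2 of the Imagen paper, which combines the conditional and unconditional noise predictions as:

\hat{\epsilon}_\theta(z_t, c) = (1 + w)\,\epsilon_\theta(z_t, c) - w\,\epsilon_\theta(z_t)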
Returns

~WanPipelineOutput or tuple

If return_dict is True, WanPipelineOutput is returned; otherwise a tuple is returned where the first element is a list with the generated video frames.
The call function to the pipeline for generation.
Examples:
>>> import torch
>>> from diffusers.utils import export_to_video
>>> from diffusers import AutoencoderKLWan, WanPipeline
>>> from diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler
>>> # Available models: Wan-AI/Wan2.1-T2V-14B-Diffusers, Wan-AI/Wan2.1-T2V-1.3B-Diffusers
>>> model_id = "Wan-AI/Wan2.1-T2V-14B-Diffusers"
>>> vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
>>> pipe = WanPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16)
>>> flow_shift = 5.0 # 5.0 for 720P, 3.0 for 480P
>>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=flow_shift)
>>> pipe.to("cuda")
>>> prompt = "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window."
>>> negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
>>> output = pipe(
... prompt=prompt,
... negative_prompt=negative_prompt,
... height=720,
... width=1280,
... num_frames=81,
... guidance_scale=5.0,
... ).frames[0]
>>> export_to_video(output, "output.mp4", fps=16)
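As a sketch of the callback mechanism described above, the hypothetical log_step hook below prints progress at the end of each denoising step and returns the (optionally modified) tensors:

>>> def log_step(pipe, step, timestep, callback_kwargs):
...     # inspect (or modify) the tensors requested via callback_on_step_end_tensor_inputs
...     latents = callback_kwargs["latents"]
...     print(f"step {step}, timestep {timestep}, latents shape {tuple(latents.shape)}")
...     return callback_kwargs

>>> output = pipe(
...     prompt=prompt,
...     callback_on_step_end=log_step,
...     callback_on_step_end_tensor_inputs=["latents"],
... ).frames[0]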
encode_prompt
< source >( prompt: typing.Union[str, typing.List[str]] negative_prompt: typing.Union[str, typing.List[str], NoneType] = None do_classifier_free_guidance: bool = True num_videos_per_prompt: int = 1 prompt_embeds: typing.Optional[torch.Tensor] = None negative_prompt_embeds: typing.Optional[torch.Tensor] = None max_sequence_length: int = 226 device: typing.Optional[torch.device] = None dtype: typing.Optional[torch.dtype] = None )
Parameters

- prompt (str or List[str], optional) — prompt to be encoded
- negative_prompt (str or List[str], optional) — The prompt or prompts not to guide the video generation. If not defined, you need to pass negative_prompt_embeds instead. Ignored when not using guidance (i.e., ignored if guidance_scale is less than 1).
- do_classifier_free_guidance (bool, optional, defaults to True) — Whether to use classifier-free guidance or not.
- num_videos_per_prompt (int, optional, defaults to 1) — Number of videos that should be generated per prompt.
- prompt_embeds (torch.Tensor, optional) — Pre-generated text embeddings. Can be used to easily tweak text inputs, e.g. prompt weighting. If not provided, text embeddings will be generated from the prompt input argument.
- negative_prompt_embeds (torch.Tensor, optional) — Pre-generated negative text embeddings. Can be used to easily tweak text inputs, e.g. prompt weighting. If not provided, negative_prompt_embeds will be generated from the negative_prompt input argument.
- device (torch.device, optional) — torch device to place the resulting embeddings on
- dtype (torch.dtype, optional) — torch dtype
Encodes the prompt into text encoder hidden states.
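Because encode_prompt returns the prompt and negative prompt embeddings directly, you can compute them once and reuse them across calls. A minimal sketch, assuming pipe is a WanPipeline already loaded on "cuda"; the prompts are placeholders:

# encode once, then reuse the embeddings for several generations
prompt_embeds, negative_prompt_embeds = pipe.encode_prompt(
    prompt="A white ferret playing on a log",
    negative_prompt="Bright tones, overexposed, static, blurred details",
    device="cuda",
)
output = pipe(
    prompt_embeds=prompt_embeds,
    negative_prompt_embeds=negative_prompt_embeds,
    num_frames=81,
).frames[0]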
WanImageToVideoPipeline
class diffusers.WanImageToVideoPipeline
< source >( tokenizer: AutoTokenizer text_encoder: UMT5EncoderModel image_encoder: CLIPVisionModel image_processor: CLIPImageProcessor transformer: WanTransformer3DModel vae: AutoencoderKLWan scheduler: FlowMatchEulerDiscreteScheduler )
Parameters

- tokenizer (T5Tokenizer) — Tokenizer from T5, specifically the google/umt5-xxl variant.
- text_encoder (T5EncoderModel) — T5, specifically the google/umt5-xxl variant.
- image_encoder (CLIPVisionModel) — CLIP, specifically the clip-vit-huge-patch14 variant.
- transformer (WanTransformer3DModel) — Conditional Transformer to denoise the input latents.
- vae (AutoencoderKLWan) — Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
- scheduler (UniPCMultistepScheduler) — A scheduler to be used in combination with transformer to denoise the encoded image latents.
Pipeline for image-to-video generation using Wan.
This model inherits from DiffusionPipeline. Check the superclass documentation for the generic methods implemented for all pipelines (downloading, saving, running on a particular device, etc.).
__call__
< source >( image: typing.Union[PIL.Image.Image, numpy.ndarray, torch.Tensor, typing.List[PIL.Image.Image], typing.List[numpy.ndarray], typing.List[torch.Tensor]] prompt: typing.Union[str, typing.List[str]] = None negative_prompt: typing.Union[str, typing.List[str]] = None height: int = 480 width: int = 832 num_frames: int = 81 num_inference_steps: int = 50 guidance_scale: float = 5.0 num_videos_per_prompt: typing.Optional[int] = 1 generator: typing.Union[torch._C.Generator, typing.List[torch._C.Generator], NoneType] = None latents: typing.Optional[torch.Tensor] = None prompt_embeds: typing.Optional[torch.Tensor] = None negative_prompt_embeds: typing.Optional[torch.Tensor] = None image_embeds: typing.Optional[torch.Tensor] = None last_image: typing.Optional[torch.Tensor] = None output_type: typing.Optional[str] = 'np' return_dict: bool = True attention_kwargs: typing.Optional[typing.Dict[str, typing.Any]] = None callback_on_step_end: typing.Union[typing.Callable[[int, int, typing.Dict], NoneType], diffusers.callbacks.PipelineCallback, diffusers.callbacks.MultiPipelineCallbacks, NoneType] = None callback_on_step_end_tensor_inputs: typing.List[str] = ['latents'] max_sequence_length: int = 512 ) → ~WanPipelineOutput
or tuple
Parameters

- image (PipelineImageInput) — The input image to condition the generation on. Must be an image, a list of images, or a torch.Tensor.
- prompt (str or List[str], optional) — The prompt or prompts to guide the video generation. If not defined, you need to pass prompt_embeds instead.
- negative_prompt (str or List[str], optional) — The prompt or prompts not to guide the video generation. If not defined, you need to pass negative_prompt_embeds instead. Ignored when not using guidance (i.e., ignored if guidance_scale is less than 1).
- height (int, defaults to 480) — The height of the generated video.
- width (int, defaults to 832) — The width of the generated video.
- num_frames (int, defaults to 81) — The number of frames in the generated video.
- num_inference_steps (int, defaults to 50) — The number of denoising steps. More denoising steps usually lead to a higher quality video at the expense of slower inference.
- guidance_scale (float, defaults to 5.0) — Guidance scale as defined in Classifier-Free Diffusion Guidance. guidance_scale is defined as w of equation 2 of the Imagen paper. Guidance scale is enabled by setting guidance_scale > 1. A higher guidance scale encourages generating videos closely linked to the text prompt, usually at the expense of lower quality.
- num_videos_per_prompt (int, optional, defaults to 1) — The number of videos to generate per prompt.
- generator (torch.Generator or List[torch.Generator], optional) — A torch.Generator to make generation deterministic.
- latents (torch.Tensor, optional) — Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor is generated by sampling using the supplied random generator.
- prompt_embeds (torch.Tensor, optional) — Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not provided, text embeddings are generated from the prompt input argument.
- negative_prompt_embeds (torch.Tensor, optional) — Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not provided, negative text embeddings are generated from the negative_prompt input argument.
- image_embeds (torch.Tensor, optional) — Pre-generated image embeddings. Can be used to easily tweak image inputs (weighting). If not provided, image embeddings are generated from the image input argument.
- output_type (str, optional, defaults to "np") — The output format of the generated video. Choose between PIL.Image or np.array.
- return_dict (bool, optional, defaults to True) — Whether or not to return a WanPipelineOutput instead of a plain tuple.
- attention_kwargs (dict, optional) — A kwargs dictionary that, if specified, is passed along to the AttentionProcessor as defined under self.processor in diffusers.models.attention_processor.
- callback_on_step_end (Callable, PipelineCallback, MultiPipelineCallbacks, optional) — A function or a subclass of PipelineCallback or MultiPipelineCallbacks that is called at the end of each denoising step during inference with the following arguments: callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict). callback_kwargs will include a list of all tensors as specified by callback_on_step_end_tensor_inputs.
- callback_on_step_end_tensor_inputs (List, optional) — The list of tensor inputs for the callback_on_step_end function. The tensors specified in the list will be passed as the callback_kwargs argument. You can only include variables listed in the ._callback_tensor_inputs attribute of your pipeline class.
- max_sequence_length (int, optional, defaults to 512) — The maximum sequence length of the prompt.
- shift (float, optional, defaults to 5.0) — The shift of the flow.
- autocast_dtype (torch.dtype, optional, defaults to torch.bfloat16) — The dtype to use for torch.amp.autocast.
Returns

~WanPipelineOutput or tuple

If return_dict is True, WanPipelineOutput is returned; otherwise a tuple is returned where the first element is a list with the generated video frames.
The call function to the pipeline for generation.
Examples:
>>> import torch
>>> import numpy as np
>>> from diffusers import AutoencoderKLWan, WanImageToVideoPipeline
>>> from diffusers.utils import export_to_video, load_image
>>> from transformers import CLIPVisionModel
>>> # Available models: Wan-AI/Wan2.1-I2V-14B-480P-Diffusers, Wan-AI/Wan2.1-I2V-14B-720P-Diffusers
>>> model_id = "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"
>>> image_encoder = CLIPVisionModel.from_pretrained(
... model_id, subfolder="image_encoder", torch_dtype=torch.float32
... )
>>> vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
>>> pipe = WanImageToVideoPipeline.from_pretrained(
... model_id, vae=vae, image_encoder=image_encoder, torch_dtype=torch.bfloat16
... )
>>> pipe.to("cuda")
>>> image = load_image(
... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg"
... )
>>> max_area = 480 * 832
>>> aspect_ratio = image.height / image.width
>>> mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
>>> height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
>>> width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
>>> image = image.resize((width, height))
>>> prompt = (
... "An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in "
... "the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot."
... )
>>> negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
>>> output = pipe(
... image=image,
... prompt=prompt,
... negative_prompt=negative_prompt,
... height=height,
... width=width,
... num_frames=81,
... guidance_scale=5.0,
... ).frames[0]
>>> export_to_video(output, "output.mp4", fps=16)
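The signature above also accepts a last_image argument. Below is a hedged sketch of conditioning on both endpoints (first-last-frame generation, typically used with the FLF2V checkpoints rather than the plain I2V ones); first_frame and final_frame are assumed to be PIL images resized as above:

>>> output = pipe(
...     image=first_frame,  # conditions the first frame
...     last_image=final_frame,  # conditions the final frame
...     prompt=prompt,
...     height=height,
...     width=width,
...     num_frames=81,
...     guidance_scale=5.0,
... ).frames[0]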
encode_prompt
< source >( prompt: typing.Union[str, typing.List[str]] negative_prompt: typing.Union[str, typing.List[str], NoneType] = None do_classifier_free_guidance: bool = True num_videos_per_prompt: int = 1 prompt_embeds: typing.Optional[torch.Tensor] = None negative_prompt_embeds: typing.Optional[torch.Tensor] = None max_sequence_length: int = 226 device: typing.Optional[torch.device] = None dtype: typing.Optional[torch.dtype] = None )
Parameters

- prompt (str or List[str], optional) — prompt to be encoded
- negative_prompt (str or List[str], optional) — The prompt or prompts not to guide the video generation. If not defined, you need to pass negative_prompt_embeds instead. Ignored when not using guidance (i.e., ignored if guidance_scale is less than 1).
- do_classifier_free_guidance (bool, optional, defaults to True) — Whether to use classifier-free guidance or not.
- num_videos_per_prompt (int, optional, defaults to 1) — Number of videos that should be generated per prompt.
- prompt_embeds (torch.Tensor, optional) — Pre-generated text embeddings. Can be used to easily tweak text inputs, e.g. prompt weighting. If not provided, text embeddings will be generated from the prompt input argument.
- negative_prompt_embeds (torch.Tensor, optional) — Pre-generated negative text embeddings. Can be used to easily tweak text inputs, e.g. prompt weighting. If not provided, negative_prompt_embeds will be generated from the negative_prompt input argument.
- device (torch.device, optional) — torch device to place the resulting embeddings on
- dtype (torch.dtype, optional) — torch dtype
Encodes the prompt into text encoder hidden states.
WanPipelineOutput
class diffusers.pipelines.wan.pipeline_output.WanPipelineOutput
< source >( frames: Tensor )
Parameters

- frames (torch.Tensor, np.ndarray, or List[List[PIL.Image.Image]]) — List of video outputs. It can be a nested list of length batch_size, with each sub-list containing denoised PIL image sequences of length num_frames. It can also be a NumPy array or Torch tensor of shape (batch_size, num_frames, channels, height, width).
Output class for Wan pipelines.
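A short sketch of consuming this output, assuming output is the WanPipelineOutput returned by a pipeline call with the default output_type="np":

from diffusers.utils import export_to_video

# frames is batched (see the shape description above), so index the
# batch dimension to get a single video's frame sequence before export
video = output.frames[0]
export_to_video(video, "output.mp4", fps=16)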