Caching methods
Caching methods speed up diffusion transformers by storing and reusing the intermediate outputs of specific layers, such as attention and feedforward layers, instead of recomputing them at every inference step.
CacheMixin
A mixin class for enabling and disabling caching techniques on diffusion models.
Supported caching techniques:
- Pyramid Attention Broadcast
- FasterCache
enable_cache
( config )
Enable caching techniques on the model.
Example:
>>> import torch
>>> from diffusers import CogVideoXPipeline, PyramidAttentionBroadcastConfig
>>> pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
>>> pipe.to("cuda")
>>> config = PyramidAttentionBroadcastConfig(
... spatial_attention_block_skip_range=2,
... spatial_attention_timestep_skip_range=(100, 800),
... current_timestep_callback=lambda: pipe.current_timestep,
... )
>>> pipe.transformer.enable_cache(config)
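Caching can later be turned off with disable_cache(), which removes the hooks installed by enable_cache() and restores the model's original forward pass. A minimal sketch, continuing the example above:
>>> # ... run inference with caching enabled ...
>>> pipe.transformer.disable_cache()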
PyramidAttentionBroadcastConfig
class diffusers.PyramidAttentionBroadcastConfig
( spatial_attention_block_skip_range: typing.Optional[int] = None temporal_attention_block_skip_range: typing.Optional[int] = None cross_attention_block_skip_range: typing.Optional[int] = None spatial_attention_timestep_skip_range: typing.Tuple[int, int] = (100, 800) temporal_attention_timestep_skip_range: typing.Tuple[int, int] = (100, 800) cross_attention_timestep_skip_range: typing.Tuple[int, int] = (100, 800) spatial_attention_block_identifiers: typing.Tuple[str, ...] = ('blocks', 'transformer_blocks', 'single_transformer_blocks') temporal_attention_block_identifiers: typing.Tuple[str, ...] = ('temporal_transformer_blocks',) cross_attention_block_identifiers: typing.Tuple[str, ...] = ('blocks', 'transformer_blocks') current_timestep_callback: typing.Callable[[], int] = None )
Parameters
- spatial_attention_block_skip_range (int, optional, defaults to None) — The number of times a specific spatial attention broadcast is skipped before computing the attention states to re-use. If this is set to the value N, the attention computation will be skipped N - 1 times (i.e., old attention states will be re-used) before computing the new attention states again.
- temporal_attention_block_skip_range (int, optional, defaults to None) — The number of times a specific temporal attention broadcast is skipped before computing the attention states to re-use. If this is set to the value N, the attention computation will be skipped N - 1 times (i.e., old attention states will be re-used) before computing the new attention states again.
- cross_attention_block_skip_range (int, optional, defaults to None) — The number of times a specific cross-attention broadcast is skipped before computing the attention states to re-use. If this is set to the value N, the attention computation will be skipped N - 1 times (i.e., old attention states will be re-used) before computing the new attention states again.
- spatial_attention_timestep_skip_range (Tuple[int, int], defaults to (100, 800)) — The range of timesteps to skip in the spatial attention layer. The attention computations will be conditionally skipped if the current timestep is within the specified range.
- temporal_attention_timestep_skip_range (Tuple[int, int], defaults to (100, 800)) — The range of timesteps to skip in the temporal attention layer. The attention computations will be conditionally skipped if the current timestep is within the specified range.
- cross_attention_timestep_skip_range (Tuple[int, int], defaults to (100, 800)) — The range of timesteps to skip in the cross-attention layer. The attention computations will be conditionally skipped if the current timestep is within the specified range.
- spatial_attention_block_identifiers (Tuple[str, ...], defaults to ("blocks", "transformer_blocks", "single_transformer_blocks")) — The identifiers to match against the layer names to determine if the layer is a spatial attention layer.
- temporal_attention_block_identifiers (Tuple[str, ...], defaults to ("temporal_transformer_blocks",)) — The identifiers to match against the layer names to determine if the layer is a temporal attention layer.
- cross_attention_block_identifiers (Tuple[str, ...], defaults to ("blocks", "transformer_blocks")) — The identifiers to match against the layer names to determine if the layer is a cross-attention layer.
- current_timestep_callback (Callable[[], int], defaults to None) — The callback function that returns the current inference timestep. It is used together with the timestep skip ranges above to decide whether the attention computation can be skipped at the current step.
Configuration for Pyramid Attention Broadcast.
diffusers.apply_pyramid_attention_broadcast
( module: Module config: PyramidAttentionBroadcastConfig )
Apply Pyramid Attention Broadcast (PAB) to a given module, such as a pipeline's transformer.
PAB is an attention approximation method that leverages the similarity in attention states between timesteps to reduce the computational cost of attention computation. The key takeaway from the paper is that the similarity of attention states between timesteps is highest in the cross-attention layers, followed by lower similarity in the temporal and spatial layers. This allows attention computation to be skipped more frequently in the cross-attention layers than in the temporal and spatial layers. Applying PAB therefore speeds up the inference process.
Example:
>>> import torch
>>> from diffusers import CogVideoXPipeline, PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast
>>> from diffusers.utils import export_to_video
>>> pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
>>> pipe.to("cuda")
>>> config = PyramidAttentionBroadcastConfig(
... spatial_attention_block_skip_range=2,
... spatial_attention_timestep_skip_range=(100, 800),
... current_timestep_callback=lambda: pipe.current_timestep,
... )
>>> apply_pyramid_attention_broadcast(pipe.transformer, config)
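Because cross-attention states are the most similar across timesteps, as noted above, they can usually be re-used for longer stretches than spatial attention states. The following is a minimal sketch of such a configuration; the specific skip values are illustrative assumptions, and suitable ranges depend on the model being used:
>>> config = PyramidAttentionBroadcastConfig(
...     spatial_attention_block_skip_range=2,  # recompute spatial attention every 2nd step
...     cross_attention_block_skip_range=6,  # re-use cross-attention states for longer stretches
...     spatial_attention_timestep_skip_range=(100, 800),
...     cross_attention_timestep_skip_range=(100, 800),
...     current_timestep_callback=lambda: pipe.current_timestep,
... )
>>> apply_pyramid_attention_broadcast(pipe.transformer, config)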
FasterCacheConfig
class diffusers.FasterCacheConfig
( spatial_attention_block_skip_range: int = 2 temporal_attention_block_skip_range: typing.Optional[int] = None spatial_attention_timestep_skip_range: typing.Tuple[int, int] = (-1, 681) temporal_attention_timestep_skip_range: typing.Tuple[int, int] = (-1, 681) low_frequency_weight_update_timestep_range: typing.Tuple[int, int] = (99, 901) high_frequency_weight_update_timestep_range: typing.Tuple[int, int] = (-1, 301) alpha_low_frequency: float = 1.1 alpha_high_frequency: float = 1.1 unconditional_batch_skip_range: int = 5 unconditional_batch_timestep_skip_range: typing.Tuple[int, int] = (-1, 641) spatial_attention_block_identifiers: typing.Tuple[str, ...] = ('^blocks.*attn', '^transformer_blocks.*attn', '^single_transformer_blocks.*attn') temporal_attention_block_identifiers: typing.Tuple[str, ...] = ('^temporal_transformer_blocks.*attn',) attention_weight_callback: typing.Callable[[torch.nn.modules.module.Module], float] = None low_frequency_weight_callback: typing.Callable[[torch.nn.modules.module.Module], float] = None high_frequency_weight_callback: typing.Callable[[torch.nn.modules.module.Module], float] = None tensor_format: str = 'BCFHW' is_guidance_distilled: bool = False current_timestep_callback: typing.Callable[[], int] = None _unconditional_conditional_input_kwargs_identifiers: typing.List[str] = ('hidden_states', 'encoder_hidden_states', 'timestep', 'attention_mask', 'encoder_attention_mask') )
Parameters
- spatial_attention_block_skip_range (int, defaults to 2) — Calculate the attention states every N iterations. If this is set to N, the attention computation will be skipped N - 1 times (i.e., cached attention states will be re-used) before computing the new attention states again.
- temporal_attention_block_skip_range (int, optional, defaults to None) — Calculate the attention states every N iterations. If this is set to N, the attention computation will be skipped N - 1 times (i.e., cached attention states will be re-used) before computing the new attention states again.
- spatial_attention_timestep_skip_range (Tuple[int, int], defaults to (-1, 681)) — The timestep range within which the spatial attention computation can be skipped without a significant loss in quality. This is to be determined by the user based on the underlying model. The first value in the tuple is the lower bound and the second value is the upper bound. Typically, diffusion timesteps for denoising run in reverse from 1000 to 0 (i.e. denoising starts at timestep 1000 and ends at timestep 0). With the default values, spatial attention computation skipping becomes applicable only after denoising timestep 681 is reached and continues until the end of the denoising process.
- temporal_attention_timestep_skip_range (Tuple[int, int], defaults to (-1, 681)) — The timestep range within which the temporal attention computation can be skipped without a significant loss in quality. This is to be determined by the user based on the underlying model. The first value in the tuple is the lower bound and the second value is the upper bound.
- low_frequency_weight_update_timestep_range (Tuple[int, int], defaults to (99, 901)) — The timestep range within which the low frequency weight scaling update is applied. The first value in the tuple is the lower bound and the second value is the upper bound of the timestep range. The callback function for the update is called only within this range.
- high_frequency_weight_update_timestep_range (Tuple[int, int], defaults to (-1, 301)) — The timestep range within which the high frequency weight scaling update is applied. The first value in the tuple is the lower bound and the second value is the upper bound of the timestep range. The callback function for the update is called only within this range.
- alpha_low_frequency (float, defaults to 1.1) — The weight used to scale the low frequency updates. This is used to approximate the unconditional branch from the conditional branch outputs.
- alpha_high_frequency (float, defaults to 1.1) — The weight used to scale the high frequency updates. This is used to approximate the unconditional branch from the conditional branch outputs.
- unconditional_batch_skip_range (int, defaults to 5) — Process the unconditional branch every N iterations. If this is set to N, the unconditional branch computation will be skipped N - 1 times (i.e., cached unconditional branch states will be re-used) before computing the new unconditional branch states again.
- unconditional_batch_timestep_skip_range (Tuple[int, int], defaults to (-1, 641)) — The timestep range within which the unconditional branch computation can be skipped without a significant loss in quality. This is to be determined by the user based on the underlying model. The first value in the tuple is the lower bound and the second value is the upper bound.
- spatial_attention_block_identifiers (Tuple[str, ...], defaults to ("^blocks.*attn", "^transformer_blocks.*attn", "^single_transformer_blocks.*attn")) — The identifiers to match the spatial attention blocks in the model. If the name of a block matches any of these identifiers, FasterCache will be applied to that block. These can be full layer names, partial layer names, or regex patterns; matching is always done with a regex match.
- temporal_attention_block_identifiers (Tuple[str, ...], defaults to ("^temporal_transformer_blocks.*attn",)) — The identifiers to match the temporal attention blocks in the model. If the name of a block matches any of these identifiers, FasterCache will be applied to that block. These can be full layer names, partial layer names, or regex patterns; matching is always done with a regex match.
- attention_weight_callback (Callable[[torch.nn.Module], float], defaults to None) — The callback function that determines the weight used to scale the attention outputs. The function should take the attention module as input and return a float value. This is used to approximate the unconditional branch from the conditional branch outputs. If not provided, the default weight is 0.5 for all timesteps. Typically, as described in the paper, this weight should gradually increase from 0 to 1 as inference progresses. Users are encouraged to experiment with custom weight schedules that take into account the number of inference steps and the underlying model's behaviour as denoising progresses (a sketch of such a schedule follows this parameter list).
- low_frequency_weight_callback (Callable[[torch.nn.Module], float], defaults to None) — The callback function that determines the weight used to scale the low frequency updates. If not provided, the default weight is 1.1 for timesteps within the specified range (as described in the paper).
- high_frequency_weight_callback (Callable[[torch.nn.Module], float], defaults to None) — The callback function that determines the weight used to scale the high frequency updates. If not provided, the default weight is 1.1 for timesteps within the specified range (as described in the paper).
- tensor_format (str, defaults to "BCFHW") — The format of the input tensors. This should be one of "BCFHW", "BFCHW", or "BCHW". The format is used to split individual latent frames so that the low and high frequency components can be computed.
- is_guidance_distilled (bool, defaults to False) — Whether the model is guidance distilled. If it is, FasterCache will not be applied at the denoiser level to skip the unconditional branch computation (as there is none).
- current_timestep_callback (Callable[[], int], defaults to None) — The callback function that returns the current inference timestep, used to decide whether caching and skipping should be active at the current step.
- _unconditional_conditional_input_kwargs_identifiers (List[str], defaults to ("hidden_states", "encoder_hidden_states", "timestep", "attention_mask", "encoder_attention_mask")) — The identifiers to match the input kwargs that contain the batchwise-concatenated unconditional and conditional inputs. If the name of an input kwarg matches any of these identifiers, FasterCache will split it into unconditional and conditional branches. This must be a list of exact input kwarg names that contain the batchwise-concatenated unconditional and conditional inputs.
Configuration for FasterCache.
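As a sketch of the kind of custom attention weight schedule suggested for attention_weight_callback (the linear ramp below is an illustrative assumption, not a schedule prescribed by the paper): the callback closes over the pipeline and maps the current timestep, which typically falls from around 1000 to 0 during denoising, onto a weight that grows from 0 to 1. Other configuration fields are omitted for brevity.
>>> # Hypothetical linear schedule: weight grows from 0 to 1 as the timestep falls from ~1000 to 0
>>> def linear_attention_weight(module):
...     return 1.0 - float(pipe.current_timestep) / 1000.0

>>> config = FasterCacheConfig(
...     spatial_attention_block_skip_range=2,
...     spatial_attention_timestep_skip_range=(-1, 681),
...     attention_weight_callback=linear_attention_weight,  # replaces the constant 0.5 default
...     current_timestep_callback=lambda: pipe.current_timestep,
... )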
diffusers.apply_faster_cache
( module: Module config: FasterCacheConfig )
Apply FasterCache to a given module, such as a pipeline's transformer.
Example:
>>> import torch
>>> from diffusers import CogVideoXPipeline, FasterCacheConfig, apply_faster_cache
>>> pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
>>> pipe.to("cuda")
>>> config = FasterCacheConfig(
... spatial_attention_block_skip_range=2,
... spatial_attention_timestep_skip_range=(-1, 681),
... low_frequency_weight_update_timestep_range=(99, 641),
... high_frequency_weight_update_timestep_range=(-1, 301),
... spatial_attention_block_identifiers=["transformer_blocks"],
... attention_weight_callback=lambda _: 0.3,
... tensor_format="BFCHW",
... )
>>> apply_faster_cache(pipe.transformer, config)
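The same FasterCacheConfig can also be enabled through the CacheMixin interface instead of calling apply_faster_cache directly. A minimal sketch, reusing the config from the example above:
>>> pipe.transformer.enable_cache(config)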