|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import fnmatch |
|
from contextlib import contextmanager |
|
|
|
from diffusers.models.attention import BasicTransformerBlock, JointTransformerBlock |
|
from diffusers.models.transformers.pixart_transformer_2d import PixArtTransformer2DModel |
|
from diffusers.models.transformers.transformer_sd3 import SD3Transformer2DModel |
|
from diffusers.models.unets.unet_2d_blocks import ( |
|
CrossAttnDownBlock2D, |
|
CrossAttnUpBlock2D, |
|
DownBlock2D, |
|
UNetMidBlock2DCrossAttn, |
|
UpBlock2D, |
|
) |
|
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel |
|
from diffusers.models.unets.unet_3d_blocks import ( |
|
CrossAttnDownBlockSpatioTemporal, |
|
CrossAttnUpBlockSpatioTemporal, |
|
DownBlockSpatioTemporal, |
|
UNetMidBlockSpatioTemporal, |
|
UpBlockSpatioTemporal, |
|
) |
|
from diffusers.models.unets.unet_spatio_temporal_condition import UNetSpatioTemporalConditionModel |
|
|
|
from .module import CachedModule |
|
from .utils import replace_module |
|
|
|
# Model classes supported by ``prepare`` (exact class match on the dict keys),
# each mapped to the block classes inside it that may be wrapped in a
# ``CachedModule``. Every value is a tuple of classes that ends up as the
# second argument of ``isinstance`` in ``_apply_to_modules``.
# Single-class entries need the trailing comma: without it the parentheses
# are mere grouping and the value is a bare class, not a tuple.
CACHED_PIPE = {
    UNet2DConditionModel: (
        DownBlock2D,
        CrossAttnDownBlock2D,
        UNetMidBlock2DCrossAttn,
        CrossAttnUpBlock2D,
        UpBlock2D,
    ),
    PixArtTransformer2DModel: (BasicTransformerBlock,),
    UNetSpatioTemporalConditionModel: (
        CrossAttnDownBlockSpatioTemporal,
        DownBlockSpatioTemporal,
        UpBlockSpatioTemporal,
        CrossAttnUpBlockSpatioTemporal,
        UNetMidBlockSpatioTemporal,
    ),
    SD3Transformer2DModel: (JointTransformerBlock,),
}
|
|
|
|
|
def _last_matching_config(name, config_list):
    """Return the last config whose ``"wildcard_or_filter_func"`` accepts ``name``, or None."""
    matched = None
    for config in config_list:
        if _pass(name, config["wildcard_or_filter_func"]):
            matched = config  # later entries take precedence
    return matched


def _apply_to_modules(model, action, modules=None, config_list=None):
    """Apply ``action`` to existing ``CachedModule``s and optionally wrap new ones.

    Two traversal paths exist:

    * If ``model.use_trt_infer`` is truthy, the entries of ``model.engines``
      are visited by key.
    * Otherwise every ``(name, module)`` pair of ``model.named_modules()`` is
      visited.

    Any visited entry that already is a ``CachedModule`` is passed to
    ``action``. Any other entry is replaced by
    ``CachedModule(entry, config["select_cache_step_func"])`` when a config in
    ``config_list`` selects its key/name via ``"wildcard_or_filter_func"``.
    On the ``named_modules`` path, the entry must also be an instance of
    ``modules``. If several configs select the same entry, the last one in
    ``config_list`` is used, and only that single wrapper is built.

    Args:
        model: The UNet/transformer to process.
        action: Callable invoked with each already-present ``CachedModule``;
            its return value is ignored.
        modules: Class or tuple of classes eligible for wrapping on the
            ``named_modules`` path; no wrapping happens there when falsy.
        config_list: Sequence of config dicts with ``"wildcard_or_filter_func"``
            and ``"select_cache_step_func"`` keys; no wrapping when falsy.

    Raises:
        NotImplementedError: propagated from ``_pass`` for an unsupported
            ``"wildcard_or_filter_func"`` type.
    """
    if getattr(model, "use_trt_infer", False):
        # Snapshot the items: entries are reassigned inside the loop.
        for key, module in list(model.engines.items()):
            if isinstance(module, CachedModule):
                action(module)
            elif config_list:
                config = _last_matching_config(key, config_list)
                if config is not None:
                    model.engines[key] = CachedModule(
                        module, config["select_cache_step_func"]
                    )
    else:
        # Snapshot the traversal: replace_module rewires the very module tree
        # that the lazy named_modules() generator would otherwise be walking.
        for name, module in list(model.named_modules()):
            if isinstance(module, CachedModule):
                action(module)
            elif modules and config_list and isinstance(module, modules):
                config = _last_matching_config(name, config_list)
                if config is not None:
                    replace_module(
                        model,
                        name,
                        CachedModule(module, config["select_cache_step_func"]),
                    )
|
|
|
|
|
def cachify(model, config_list, modules):
    """Wrap the sub-modules of ``model`` selected by ``config_list`` in ``CachedModule``.

    Sub-modules that already are ``CachedModule`` instances receive a no-op
    action, so only not-yet-wrapped modules of the ``modules`` types are
    affected.
    """
    _apply_to_modules(
        model,
        lambda _already_cached: None,
        modules=modules,
        config_list=config_list,
    )
|
|
|
|
|
def disable(pipe):
    """Call ``disable_cache()`` on every ``CachedModule`` of the pipeline's model."""
    def _switch_off(cached):
        cached.disable_cache()

    denoiser = get_model(pipe)
    _apply_to_modules(denoiser, _switch_off)
|
|
|
|
|
def enable(pipe):
    """Call ``enable_cache()`` on every ``CachedModule`` of the pipeline's model."""
    def _switch_on(cached):
        cached.enable_cache()

    denoiser = get_model(pipe)
    _apply_to_modules(denoiser, _switch_on)
|
|
|
|
|
def reset_status(pipe):
    """Set ``cur_step`` to 0 on every ``CachedModule`` of the pipeline's model."""
    denoiser = get_model(pipe)

    def _rewind(cached):
        setattr(cached, "cur_step", 0)

    _apply_to_modules(denoiser, _rewind)
|
|
|
|
|
def _pass(name, wildcard_or_filter_func): |
|
if isinstance(wildcard_or_filter_func, str): |
|
return fnmatch.fnmatch(name, wildcard_or_filter_func) |
|
elif callable(wildcard_or_filter_func): |
|
return wildcard_or_filter_func(name) |
|
else: |
|
raise NotImplementedError(f"Unsupported type {type(wildcard_or_filter_func)}") |
|
|
|
|
|
def get_model(pipe):
    """Return the pipeline's UNet, or its transformer if it has no UNet.

    Args:
        pipe: Object expected to expose a ``unet`` or ``transformer`` attribute.

    Returns:
        ``pipe.unet`` when that attribute exists, else ``pipe.transformer``.

    Raises:
        KeyError: if ``pipe`` has neither attribute.
    """
    if hasattr(pipe, "unet"):
        return pipe.unet
    if hasattr(pipe, "transformer"):
        return pipe.transformer
    # Same exception type as before, but with a message saying what is missing.
    raise KeyError(
        f"{type(pipe).__name__} has neither a 'unet' nor a 'transformer' attribute"
    )
|
|
|
|
|
@contextmanager
def infer(pipe):
    """Scope one inference run of a cached pipeline.

    Yields ``pipe`` itself. When the ``with`` block exits, normally or through
    an exception, ``reset_status(pipe)`` sets ``cur_step`` back to 0 on every
    ``CachedModule``. Any exception from the block still propagates.
    """
    try:
        yield pipe
    finally:
        # Unconditional cleanup so the next run starts from step 0.
        reset_status(pipe)
|
|
|
|
|
def prepare(pipe, config_list):
    """Install ``CachedModule`` wrappers on the pipeline's UNet/transformer.

    The model's exact class must be a key of ``CACHED_PIPE``; subclasses are
    not looked up. The matching block classes from ``CACHED_PIPE`` are
    candidates for wrapping, selected by ``config_list``.

    Args:
        pipe: Pipeline exposing ``unet`` or ``transformer`` (see ``get_model``).
        config_list: Sequence of config dicts with ``"wildcard_or_filter_func"``
            and ``"select_cache_step_func"`` keys.

    Raises:
        KeyError: if ``pipe`` has neither ``unet`` nor ``transformer``.
        AssertionError: if the model class is not supported.
    """
    model = get_model(pipe)
    model_cls = model.__class__
    block_types = CACHED_PIPE.get(model_cls)
    if block_types is None:
        # Explicit raise instead of ``assert``: asserts are stripped under
        # ``python -O``. The exception type is kept so existing handlers still work.
        raise AssertionError(f"{model_cls} is not supported!")
    cachify(model, config_list, block_types)
|
|