RobertML's picture
Add files using upload-large-folder tool
3c436c6 verified
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: MIT
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
import fnmatch
from contextlib import contextmanager
from diffusers.models.attention import BasicTransformerBlock, JointTransformerBlock
from diffusers.models.transformers.pixart_transformer_2d import PixArtTransformer2DModel
from diffusers.models.transformers.transformer_sd3 import SD3Transformer2DModel
from diffusers.models.unets.unet_2d_blocks import (
CrossAttnDownBlock2D,
CrossAttnUpBlock2D,
DownBlock2D,
UNetMidBlock2DCrossAttn,
UpBlock2D,
)
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
from diffusers.models.unets.unet_3d_blocks import (
CrossAttnDownBlockSpatioTemporal,
CrossAttnUpBlockSpatioTemporal,
DownBlockSpatioTemporal,
UNetMidBlockSpatioTemporal,
UpBlockSpatioTemporal,
)
from diffusers.models.unets.unet_spatio_temporal_condition import UNetSpatioTemporalConditionModel
from .module import CachedModule
from .utils import replace_module
# Maps each supported top-level model class to the tuple of block classes
# inside it that are eligible for wrapping in a CachedModule. The value is
# passed as the second argument of isinstance() in _apply_to_modules, so every
# entry is a tuple -- single-class entries need the trailing comma, otherwise
# the parentheses are just grouping and the value is a bare class.
CACHED_PIPE = {
    UNet2DConditionModel: (
        DownBlock2D,
        CrossAttnDownBlock2D,
        UNetMidBlock2DCrossAttn,
        CrossAttnUpBlock2D,
        UpBlock2D,
    ),
    PixArtTransformer2DModel: (BasicTransformerBlock,),
    UNetSpatioTemporalConditionModel: (
        CrossAttnDownBlockSpatioTemporal,
        DownBlockSpatioTemporal,
        UpBlockSpatioTemporal,
        CrossAttnUpBlockSpatioTemporal,
        UNetMidBlockSpatioTemporal,
    ),
    SD3Transformer2DModel: (JointTransformerBlock,),
}
def _apply_to_modules(model, action, modules=None, config_list=None):
    """Visit the model's modules, acting on cached ones and wrapping the rest.

    Every module that is already a ``CachedModule`` is handed to ``action``.
    Any other module is wrapped in a new ``CachedModule`` for each entry of
    ``config_list`` whose ``"wildcard_or_filter_func"`` accepts its name.

    Args:
        model: Object exposing either ``engines`` (when ``use_trt_infer`` is
            set and truthy) or ``named_modules()``.
        action: Callable invoked with each existing ``CachedModule``.
        modules: Class or tuple of classes passed to ``isinstance`` to decide
            which modules may be wrapped. Only consulted on the
            ``named_modules()`` path.
        config_list: Iterable of dicts with the keys
            ``"wildcard_or_filter_func"`` and ``"select_cache_step_func"``.
    """
    trt_mode = hasattr(model, "use_trt_infer") and model.use_trt_infer

    if trt_mode:
        for engine_name, engine in model.engines.items():
            if isinstance(engine, CachedModule):
                action(engine)
                continue
            if not config_list:
                continue
            for cfg in config_list:
                # Engines are filtered by name only; `modules` is not used here.
                if _pass(engine_name, cfg["wildcard_or_filter_func"]):
                    # Re-assigning an existing key, so the mapping being
                    # iterated keeps its size.
                    model.engines[engine_name] = CachedModule(
                        engine, cfg["select_cache_step_func"]
                    )
        return

    for module_name, submodule in model.named_modules():
        if isinstance(submodule, CachedModule):
            action(submodule)
            continue
        # Wrapping needs both a type filter and at least one config entry.
        if not modules or not config_list:
            continue
        for cfg in config_list:
            name_matches = _pass(module_name, cfg["wildcard_or_filter_func"])
            if name_matches and isinstance(submodule, modules):
                wrapped = CachedModule(submodule, cfg["select_cache_step_func"])
                replace_module(model, module_name, wrapped)
def cachify(model, config_list, modules):
    """Wrap the matching modules of ``model`` in ``CachedModule`` instances.

    The wrapping itself is performed inside ``_apply_to_modules``; modules that
    are already cached are passed to a do-nothing callback and left as they are.

    Args:
        model: The model whose modules are inspected.
        config_list: Caching configs forwarded to ``_apply_to_modules``.
        modules: Class or tuple of classes eligible for wrapping.
    """

    def _noop(_cached_module):
        # Already-wrapped modules need no further treatment.
        return None

    _apply_to_modules(model, _noop, modules=modules, config_list=config_list)
def disable(pipe):
    """Call ``disable_cache()`` on every ``CachedModule`` of the pipeline's model."""

    def _turn_off(cached_module):
        return cached_module.disable_cache()

    _apply_to_modules(get_model(pipe), _turn_off)
def enable(pipe):
    """Call ``enable_cache()`` on every ``CachedModule`` of the pipeline's model."""

    def _turn_on(cached_module):
        return cached_module.enable_cache()

    _apply_to_modules(get_model(pipe), _turn_on)
def reset_status(pipe):
    """Set ``cur_step`` back to 0 on every ``CachedModule`` of the pipeline's model."""

    def _rewind(cached_module):
        cached_module.cur_step = 0

    _apply_to_modules(get_model(pipe), _rewind)
def _pass(name, wildcard_or_filter_func):
if isinstance(wildcard_or_filter_func, str):
return fnmatch.fnmatch(name, wildcard_or_filter_func)
elif callable(wildcard_or_filter_func):
return wildcard_or_filter_func(name)
else:
raise NotImplementedError(f"Unsupported type {type(wildcard_or_filter_func)}")
def get_model(pipe):
    """Return the model object that caching is applied to.

    Args:
        pipe: Pipeline-like object. ``pipe.unet`` is preferred; if that
            attribute does not exist, ``pipe.transformer`` is used.

    Returns:
        ``pipe.unet`` if the attribute exists, otherwise ``pipe.transformer``.

    Raises:
        KeyError: If ``pipe`` has neither attribute. The exception type is kept
            for backward compatibility; the message names the offending type.
    """
    # Order matters: "unet" takes precedence when both attributes exist.
    for attr in ("unet", "transformer"):
        if hasattr(pipe, attr):
            return getattr(pipe, attr)
    raise KeyError(
        f"{type(pipe).__name__} has neither a 'unet' nor a 'transformer' attribute"
    )
@contextmanager
def infer(pipe):
    """Context manager that yields ``pipe`` and resets cache status on exit.

    ``reset_status(pipe)`` is called when the ``with`` block ends, whether it
    finished normally or raised.

    Yields:
        The same ``pipe`` object that was passed in.
    """
    try:
        yield pipe
    finally:
        # Always reset, so a failed run does not leave stale step counters.
        reset_status(pipe)
def prepare(pipe, config_list):
    """Install caching on a pipeline's model according to ``config_list``.

    Args:
        pipe: Pipeline-like object accepted by ``get_model``.
        config_list: Caching configs forwarded to ``cachify``.

    Raises:
        AssertionError: If the model's exact class is not a key of
            ``CACHED_PIPE`` (subclasses are not matched).
    """
    model = get_model(pipe)
    model_cls = model.__class__
    if model_cls not in CACHED_PIPE:
        # Raised explicitly instead of via `assert` so the check is not
        # stripped under `python -O`. AssertionError and the message are kept
        # so callers see the same failure as before.
        raise AssertionError(f"{model_cls} is not supported!")
    cachify(model, config_list, CACHED_PIPE[model_cls])