from __future__ import annotations import torch import comfy.utils from comfy.patcher_extension import WrappersMP from typing import TYPE_CHECKING, Callable, Optional if TYPE_CHECKING: from comfy.model_patcher import ModelPatcher from comfy.patcher_extension import WrapperExecutor COMPILE_KEY = "torch.compile" TORCH_COMPILE_KWARGS = "torch_compile_kwargs" def apply_torch_compile_factory(compiled_module_dict: dict[str, Callable]) -> Callable: ''' Create a wrapper that will refer to the compiled_diffusion_model. ''' def apply_torch_compile_wrapper(executor: WrapperExecutor, *args, **kwargs): try: orig_modules = {} for key, value in compiled_module_dict.items(): orig_modules[key] = comfy.utils.get_attr(executor.class_obj, key) comfy.utils.set_attr(executor.class_obj, key, value) return executor(*args, **kwargs) finally: for key, value in orig_modules.items(): comfy.utils.set_attr(executor.class_obj, key, value) return apply_torch_compile_wrapper def set_torch_compile_wrapper(model: ModelPatcher, backend: str, options: Optional[dict[str,str]]=None, mode: Optional[str]=None, fullgraph=False, dynamic: Optional[bool]=None, keys: list[str]=["diffusion_model"], *args, **kwargs): ''' Perform torch.compile that will be applied at sample time for either the whole model or specific params of the BaseModel instance. When keys is None, it will default to using ["diffusion_model"], compiling the whole diffusion_model. When a list of keys is provided, it will perform torch.compile on only the selected modules. ''' # clear out any other torch.compile wrappers model.remove_wrappers_with_key(WrappersMP.APPLY_MODEL, COMPILE_KEY) # if no keys, default to 'diffusion_model' if not keys: keys = ["diffusion_model"] # create kwargs dict that can be referenced later compile_kwargs = { "backend": backend, "options": options, "mode": mode, "fullgraph": fullgraph, "dynamic": dynamic, } # get a dict of compiled keys compiled_modules = {} for key in keys: compiled_modules[key] = torch.compile( model=model.get_model_object(key), **compile_kwargs, ) # add torch.compile wrapper wrapper_func = apply_torch_compile_factory( compiled_module_dict=compiled_modules, ) # store wrapper to run on BaseModel's apply_model function model.add_wrapper_with_key(WrappersMP.APPLY_MODEL, COMPILE_KEY, wrapper_func) # keep compile kwargs for reference model.model_options[TORCH_COMPILE_KWARGS] = compile_kwargs