from contextlib import contextmanager from typing import List, Union import torch from diffusers.hooks import HookRegistry, ModelHook _CONTROL_CHANNEL_CONCATENATE_HOOK = "FINETRAINERS_CONTROL_CHANNEL_CONCATENATE_HOOK" class ControlChannelConcatenateHook(ModelHook): def __init__(self, input_names: List[str], inputs: List[torch.Tensor], dims: List[int]): self.input_names = input_names self.inputs = inputs self.dims = dims def pre_forward(self, module: torch.nn.Module, *args, **kwargs): for input_name, input_tensor, dim in zip(self.input_names, self.inputs, self.dims): original_tensor = args[input_name] if isinstance(input_name, int) else kwargs[input_name] control_tensor = torch.cat([original_tensor, input_tensor], dim=dim) if isinstance(input_name, int): args[input_name] = control_tensor else: kwargs[input_name] = control_tensor return args, kwargs @contextmanager def control_channel_concat( module: torch.nn.Module, input_names: List[Union[int, str]], inputs: List[torch.Tensor], dims: List[int] ): registry = HookRegistry.check_if_exists_or_initialize(module) hook = ControlChannelConcatenateHook(input_names, inputs, dims) registry.register_hook(hook, _CONTROL_CHANNEL_CONCATENATE_HOOK) yield registry.remove_hook(_CONTROL_CHANNEL_CONCATENATE_HOOK, recurse=False)