Spaces:
Running
Running
File size: 1,439 Bytes
9fd1204 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 |
from contextlib import contextmanager
from typing import List, Union
import torch
from diffusers.hooks import HookRegistry, ModelHook
_CONTROL_CHANNEL_CONCATENATE_HOOK = "FINETRAINERS_CONTROL_CHANNEL_CONCATENATE_HOOK"
class ControlChannelConcatenateHook(ModelHook):
def __init__(self, input_names: List[str], inputs: List[torch.Tensor], dims: List[int]):
self.input_names = input_names
self.inputs = inputs
self.dims = dims
def pre_forward(self, module: torch.nn.Module, *args, **kwargs):
for input_name, input_tensor, dim in zip(self.input_names, self.inputs, self.dims):
original_tensor = args[input_name] if isinstance(input_name, int) else kwargs[input_name]
control_tensor = torch.cat([original_tensor, input_tensor], dim=dim)
if isinstance(input_name, int):
args[input_name] = control_tensor
else:
kwargs[input_name] = control_tensor
return args, kwargs
@contextmanager
def control_channel_concat(
module: torch.nn.Module, input_names: List[Union[int, str]], inputs: List[torch.Tensor], dims: List[int]
):
registry = HookRegistry.check_if_exists_or_initialize(module)
hook = ControlChannelConcatenateHook(input_names, inputs, dims)
registry.register_hook(hook, _CONTROL_CHANNEL_CONCATENATE_HOOK)
yield
registry.remove_hook(_CONTROL_CHANNEL_CONCATENATE_HOOK, recurse=False)
|