import contextlib
import functools
import os
from typing import Callable, List, Optional, Tuple

import torch
import torch.backends
from diffusers.hooks import HookRegistry, ModelHook

from finetrainers import logging, parallel, patches
from finetrainers.args import BaseArgsType
from finetrainers.logging import get_logger
from finetrainers.models.attention_dispatch import AttentionProvider, _AttentionProviderRegistry
from finetrainers.state import State


logger = get_logger()

_LATEST_ACTIVE_MODULE_HOOK = "latest_active_module_hook"


class Trainer:
    def __init__(self, args: BaseArgsType):
        self.args = args
        self.state = State()

        self._module_name_providers_training = _parse_attention_providers(args.attn_provider_training)
        self._module_name_providers_inference = _parse_attention_providers(args.attn_provider_inference)

        self._init_distributed()
        self._init_config_options()

        # Perform any patches that might be necessary for training to work as expected
        patches.perform_patches_for_training(self.args, self.state.parallel_backend)

    @contextlib.contextmanager
    def attention_provider_ctx(self, training: bool = True):
        name_providers_active = (
            self._module_name_providers_training if training else self._module_name_providers_inference
        )
        name_providers_dict = dict(name_providers_active)
        default_provider = _AttentionProviderRegistry._active_provider

        # Any registered torch.nn.Module attribute that was not explicitly configured falls back to the
        # currently active (default) provider.
        all_registered_module_names = [
            attr for attr in dir(self) if isinstance(getattr(self, attr, None), torch.nn.Module)
        ]
        for module_name in all_registered_module_names:
            if module_name in name_providers_dict:
                continue
            name_providers_dict[module_name] = default_provider

        module_providers_dict = {}
        for module_name, provider in name_providers_dict.items():
            module = getattr(self, module_name, None)
            if module is not None:
                module_providers_dict[module] = (module_name, provider)

        # We don't want to immediately reset the attention provider to the default after the forward pass,
        # because if the model is being trained, the backward pass must be invoked with the same attention
        # provider. So, we lazily switch attention providers only when the forward pass of a new module is called.
        def callback(m: torch.nn.Module):
            module_name, provider = module_providers_dict[m]
            # HACK: for CP on the transformer. Need to support other modules too and improve the overall
            # experience for external usage.
            if module_name in ["transformer"] and self.state.parallel_backend.context_parallel_enabled:
                if not _AttentionProviderRegistry.supports_context_parallel(provider):
                    raise ValueError(
                        f"Attention provider {provider} does not support context parallel. Please use a different provider."
                    )
                _AttentionProviderRegistry._set_context_parallel(
                    mesh=self.state.parallel_backend.get_mesh()["cp"], convert_to_fp32=True, rotate_method="allgather"
                )
            _AttentionProviderRegistry._active_provider = provider

        # HACK: for the VAE, whose encode/decode paths are not plain forward calls (see `_apply_forward_hooks_hack`)
        if "vae" in name_providers_dict:
            _apply_forward_hooks_hack(self.vae, name_providers_dict["vae"])

        for module in module_providers_dict.keys():
            registry = HookRegistry.check_if_exists_or_initialize(module)
            hook = LatestActiveModuleHook(callback)
            registry.register_hook(hook, _LATEST_ACTIVE_MODULE_HOOK)

        yield

        # Restore the default provider and detach the hooks once the context exits.
        _AttentionProviderRegistry._active_provider = default_provider
        _AttentionProviderRegistry._set_context_parallel(reset=True)
        for module in module_providers_dict.keys():
            registry: HookRegistry = module._diffusers_hook
            registry.remove_hook(_LATEST_ACTIVE_MODULE_HOOK)
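
    # Usage sketch (illustrative, not part of the original file): a training step is expected to run both
    # the forward and backward passes inside this context so that they use the same attention provider.
    # `batch` and `loss_fn` below are hypothetical stand-ins.
    #
    #     with self.attention_provider_ctx(training=True):
    #         output = self.transformer(**batch)
    #         loss = loss_fn(output)
    #         loss.backward()  # backward runs under the provider selected by the lazy hook above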

    def _init_distributed(self) -> None:
        world_size = int(os.environ.get("WORLD_SIZE", torch.cuda.device_count()))

        # TODO(aryan): handle other backends
        backend_cls: parallel.ParallelBackendType = parallel.get_parallel_backend_cls(self.args.parallel_backend)
        self.state.parallel_backend = backend_cls(
            world_size=world_size,
            pp_degree=self.args.pp_degree,
            dp_degree=self.args.dp_degree,
            dp_shards=self.args.dp_shards,
            cp_degree=self.args.cp_degree,
            tp_degree=self.args.tp_degree,
            backend="nccl",
            timeout=self.args.init_timeout,
            logging_dir=self.args.logging_dir,
            output_dir=self.args.output_dir,
            gradient_accumulation_steps=self.args.gradient_accumulation_steps,
        )

        if self.args.seed is not None:
            self.state.parallel_backend.enable_determinism(self.args.seed)

    def _init_logging(self) -> None:
        logging._set_parallel_backend(self.state.parallel_backend)
        logging.set_dependency_log_level(self.args.verbose, self.state.parallel_backend.is_local_main_process)
        logger.info("Initialized FineTrainers")

    def _init_trackers(self) -> None:
        # TODO(aryan): handle multiple trackers
        trackers = [self.args.report_to]
        experiment_name = self.args.tracker_name or "finetrainers-experiment"
        self.state.parallel_backend.initialize_trackers(
            trackers, experiment_name=experiment_name, config=self._get_training_info(), log_dir=self.args.logging_dir
        )

    def _init_config_options(self) -> None:
        # Enable TF32 for faster training on Ampere GPUs: https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
        if self.args.allow_tf32 and torch.cuda.is_available():
            torch.backends.cuda.matmul.allow_tf32 = True
        torch.set_float32_matmul_precision(self.args.float32_matmul_precision)

    @property
    def tracker(self):
        return self.state.parallel_backend.tracker


class LatestActiveModuleHook(ModelHook):
    def __init__(self, callback: Optional[Callable[[torch.nn.Module], None]] = None):
        super().__init__()
        self.callback = callback

    def pre_forward(self, module, *args, **kwargs):
        self.callback(module)
        return args, kwargs
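
# Illustrative sketch (not part of the original file): attaching and detaching the hook above manually,
# mirroring what `attention_provider_ctx` does. This assumes diffusers' HookRegistry can be initialized on
# an arbitrary torch.nn.Module, which is what the code above relies on.
#
#     module = torch.nn.Linear(8, 8)
#     registry = HookRegistry.check_if_exists_or_initialize(module)
#     hook = LatestActiveModuleHook(lambda m: print("pre-forward:", type(m).__name__))
#     registry.register_hook(hook, _LATEST_ACTIVE_MODULE_HOOK)
#     module(torch.randn(1, 8))  # the callback fires before forward runs
#     registry.remove_hook(_LATEST_ACTIVE_MODULE_HOOK)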


def _parse_attention_providers(attn_providers: Optional[List[str]] = None) -> List[Tuple[str, AttentionProvider]]:
    parsed_providers = []
    if attn_providers:
        for provider_str in attn_providers:
            parts = provider_str.split(":")
            if len(parts) != 2:
                raise ValueError(
                    f"Invalid attention provider format: '{provider_str}'. Expected 'module_name:provider_name'."
                )
            parts[1] = AttentionProvider(parts[1])
            parsed_providers.append(tuple(parts))
    return parsed_providers
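
# Example (illustrative): entries in `args.attn_provider_training` / `args.attn_provider_inference` are
# "module_name:provider_name" strings. The provider names below are hypothetical; valid values are whatever
# the AttentionProvider enum defines.
#
#     _parse_attention_providers(["transformer:flash", "vae:native"])
#     # -> [("transformer", AttentionProvider.FLASH), ("vae", AttentionProvider.NATIVE)]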


# TODO(aryan): instead of this, we could probably just apply the hook to vae.children() as we know their forward methods will be invoked
def _apply_forward_hooks_hack(module: torch.nn.Module, provider: AttentionProvider):
    if hasattr(module, "_finetrainers_wrapped_methods"):
        return

    def create_wrapper(old_method):
        @functools.wraps(old_method)
        def wrapper(*args, **kwargs):
            _AttentionProviderRegistry._set_context_parallel(reset=True)  # HACK: needs improvement
            old_provider = _AttentionProviderRegistry._active_provider
            _AttentionProviderRegistry._active_provider = provider
            output = old_method(*args, **kwargs)
            _AttentionProviderRegistry._active_provider = old_provider
            return output

        return wrapper

    methods = ["encode", "decode", "_encode", "_decode", "tiled_encode", "tiled_decode"]
    finetrainers_wrapped_methods = []
    for method_name in methods:
        if not hasattr(module, method_name):
            continue
        method = getattr(module, method_name)
        wrapper = create_wrapper(method)
        setattr(module, method_name, wrapper)
        finetrainers_wrapped_methods.append(method_name)
    module._finetrainers_wrapped_methods = finetrainers_wrapped_methods
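
# Usage sketch (illustrative, not part of the original file): `vae` here is any module exposing
# encode/decode methods (e.g. an autoencoder); the names are hypothetical. After wrapping, each call
# temporarily switches `_AttentionProviderRegistry._active_provider` and then restores the previous one;
# re-applying the hack is a no-op because `_finetrainers_wrapped_methods` is already set.
#
#     _apply_forward_hooks_hack(vae, provider)
#     latents = vae.encode(pixels)  # runs with `provider` active, then restores the old provider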