import contextlib
import functools
import os
from typing import Callable, List, Optional, Tuple

import torch
import torch.backends
from diffusers.hooks import HookRegistry, ModelHook

from finetrainers import logging, parallel, patches
from finetrainers.args import BaseArgsType
from finetrainers.logging import get_logger
from finetrainers.models.attention_dispatch import AttentionProvider, _AttentionProviderRegistry
from finetrainers.state import State

logger = get_logger()

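# Key under which `LatestActiveModuleHook` is registered on each module's diffusers
# `HookRegistry`, so that `attention_provider_ctx` can remove it again on exit.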
_LATEST_ACTIVE_MODULE_HOOK = "latest_active_module_hook"


class Trainer:
    def __init__(self, args: BaseArgsType):
        self.args = args
        self.state = State()

        self._module_name_providers_training = _parse_attention_providers(args.attn_provider_training)
        self._module_name_providers_inference = _parse_attention_providers(args.attn_provider_inference)

        self._init_distributed()
        self._init_config_options()

        # Perform any patches that might be necessary for training to work as expected
        patches.perform_patches_for_training(self.args, self.state.parallel_backend)
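
    # NOTE: concrete trainers are expected to expose their models as plain attributes
    # (e.g. `self.transformer`, `self.vae`); `attention_provider_ctx` discovers them by
    # scanning the instance for `torch.nn.Module` attributes.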

    @contextlib.contextmanager
    def attention_provider_ctx(self, training: bool = True):
        name_providers_active = (
            self._module_name_providers_training if training else self._module_name_providers_inference
        )
        name_providers_dict = dict(name_providers_active)
        default_provider = _AttentionProviderRegistry._active_provider

        # Any nn.Module attribute on the trainer without an explicitly configured provider
        # falls back to the currently active (default) provider.
        all_registered_module_names = [
            attr for attr in dir(self) if isinstance(getattr(self, attr, None), torch.nn.Module)
        ]
        for module_name in all_registered_module_names:
            if module_name in name_providers_dict:
                continue
            name_providers_dict[module_name] = default_provider

        module_providers_dict = {}
        for module_name, provider in name_providers_dict.items():
            module = getattr(self, module_name, None)
            if module is not None:
                module_providers_dict[module] = (module_name, provider)

        # We don't want to reset the attention provider to the default immediately after a forward
        # pass, because if the model is being trained, the backward pass must run with the same
        # attention provider. So, we lazily switch providers only when the forward pass of a
        # different module is invoked.
        def callback(m: torch.nn.Module):
            module_name, provider = module_providers_dict[m]
            # HACK: for CP on the transformer. Other modules need to be supported too, and the
            # overall experience for external usage should be improved.
            if module_name in ["transformer"] and self.state.parallel_backend.context_parallel_enabled:
                if not _AttentionProviderRegistry.supports_context_parallel(provider):
                    raise ValueError(
                        f"Attention provider {provider} does not support context parallel. Please use a different provider."
                    )
                _AttentionProviderRegistry._set_context_parallel(
                    mesh=self.state.parallel_backend.get_mesh()["cp"], convert_to_fp32=True, rotate_method="allgather"
                )
            _AttentionProviderRegistry._active_provider = provider

        # HACK: for the VAE
        if "vae" in name_providers_dict:
            _apply_forward_hooks_hack(self.vae, name_providers_dict["vae"])

        for module in module_providers_dict.keys():
            registry = HookRegistry.check_if_exists_or_initialize(module)
            hook = LatestActiveModuleHook(callback)
            registry.register_hook(hook, _LATEST_ACTIVE_MODULE_HOOK)

        yield

        _AttentionProviderRegistry._active_provider = default_provider
        _AttentionProviderRegistry._set_context_parallel(reset=True)
        for module in module_providers_dict.keys():
            registry: HookRegistry = module._diffusers_hook
            registry.remove_hook(_LATEST_ACTIVE_MODULE_HOOK)
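
    # Illustrative usage (a sketch, not from the original file; assumes a concrete trainer
    # subclass that exposes `self.transformer` and a hypothetical `compute_loss` helper):
    #
    #   with trainer.attention_provider_ctx(training=True):
    #       loss = compute_loss(trainer.transformer, batch)
    #       loss.backward()
    #
    # The provider switch happens lazily in `LatestActiveModuleHook.pre_forward`, so the
    # backward pass of a module runs under the same provider as its forward pass.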

    def _init_distributed(self) -> None:
        # Fall back to the number of visible CUDA devices when WORLD_SIZE is not set
        # (i.e. when not launched via torchrun or a similar launcher).
        world_size = int(os.environ.get("WORLD_SIZE", torch.cuda.device_count()))

        # TODO(aryan): handle other backends
        backend_cls: parallel.ParallelBackendType = parallel.get_parallel_backend_cls(self.args.parallel_backend)
        self.state.parallel_backend = backend_cls(
            world_size=world_size,
            pp_degree=self.args.pp_degree,
            dp_degree=self.args.dp_degree,
            dp_shards=self.args.dp_shards,
            cp_degree=self.args.cp_degree,
            tp_degree=self.args.tp_degree,
            backend="nccl",
            timeout=self.args.init_timeout,
            logging_dir=self.args.logging_dir,
            output_dir=self.args.output_dir,
            gradient_accumulation_steps=self.args.gradient_accumulation_steps,
        )

        if self.args.seed is not None:
            self.state.parallel_backend.enable_determinism(self.args.seed)
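
    # NOTE (assumption): the parallel degrees above (pp * dp * dp_shards * cp * tp) are
    # expected to multiply to `world_size`; the parallel backend is assumed to validate
    # the device mesh it builds from them.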

    def _init_logging(self) -> None:
        logging._set_parallel_backend(self.state.parallel_backend)
        logging.set_dependency_log_level(self.args.verbose, self.state.parallel_backend.is_local_main_process)
        logger.info("Initialized FineTrainers")

    def _init_trackers(self) -> None:
        # TODO(aryan): handle multiple trackers
        trackers = [self.args.report_to]
        experiment_name = self.args.tracker_name or "finetrainers-experiment"
        self.state.parallel_backend.initialize_trackers(
            trackers, experiment_name=experiment_name, config=self._get_training_info(), log_dir=self.args.logging_dir
        )

    def _init_config_options(self) -> None:
        # Enable TF32 for faster training on Ampere GPUs: https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
        if self.args.allow_tf32 and torch.cuda.is_available():
            torch.backends.cuda.matmul.allow_tf32 = True
        torch.set_float32_matmul_precision(self.args.float32_matmul_precision)

    def tracker(self):
        return self.state.parallel_backend.tracker


class LatestActiveModuleHook(ModelHook):
    """Hook that reports, via ``callback``, the module whose forward pass is about to run."""

    def __init__(self, callback: Optional[Callable[[torch.nn.Module], None]] = None):
        super().__init__()
        self.callback = callback

    def pre_forward(self, module, *args, **kwargs):
        self.callback(module)
        return args, kwargs


def _parse_attention_providers(attn_providers: Optional[List[str]] = None) -> List[Tuple[str, AttentionProvider]]:
    parsed_providers = []
    if attn_providers:
        for provider_str in attn_providers:
            parts = provider_str.split(":")
            if len(parts) != 2:
                raise ValueError(
                    f"Invalid attention provider format: '{provider_str}'. Expected 'module_name:provider_name'."
                )
            parts[1] = AttentionProvider(parts[1])
            parsed_providers.append(tuple(parts))
    return parsed_providers
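
# Illustrative input/output (provider names here are hypothetical; valid names are the
# values of the `AttentionProvider` enum):
#
#   _parse_attention_providers(["transformer:flash", "vae:native"])
#   # -> [("transformer", AttentionProvider("flash")), ("vae", AttentionProvider("native"))]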


# TODO(aryan): instead of this, we could probably just apply the hook to vae.children(), since we
# know their forward methods will be invoked
def _apply_forward_hooks_hack(module: torch.nn.Module, provider: AttentionProvider):
    if hasattr(module, "_finetrainers_wrapped_methods"):
        # Already wrapped once; don't wrap again.
        return

    def create_wrapper(old_method):
        def wrapper(*args, **kwargs):
            _AttentionProviderRegistry._set_context_parallel(reset=True)  # HACK: needs improvement
            old_provider = _AttentionProviderRegistry._active_provider
            _AttentionProviderRegistry._active_provider = provider
            output = old_method(*args, **kwargs)
            _AttentionProviderRegistry._active_provider = old_provider
            return output

        return wrapper

    methods = ["encode", "decode", "_encode", "_decode", "tiled_encode", "tiled_decode"]
    finetrainers_wrapped_methods = []
    for method_name in methods:
        if not hasattr(module, method_name):
            continue
        method = getattr(module, method_name)
        wrapper = create_wrapper(method)
        setattr(module, method_name, wrapper)
        finetrainers_wrapped_methods.append(method_name)
    module._finetrainers_wrapped_methods = finetrainers_wrapped_methods
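

# Illustrative companion helper (a sketch, not part of the original file): undo the method
# wrapping performed by `_apply_forward_hooks_hack`. The wrappers are plain instance
# attributes that shadow the class-level methods, so deleting them restores the originals.
def _remove_forward_hooks_hack(module: torch.nn.Module) -> None:
    for method_name in getattr(module, "_finetrainers_wrapped_methods", []):
        # Only instance attributes (the wrappers) are deleted; class methods are untouched.
        if method_name in module.__dict__:
            delattr(module, method_name)
    if hasattr(module, "_finetrainers_wrapped_methods"):
        del module._finetrainers_wrapped_methods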