import contextlib
import functools
import os
from typing import Callable, List, Optional, Tuple
import torch
import torch.backends
from diffusers.hooks import HookRegistry, ModelHook
from finetrainers import logging, parallel, patches
from finetrainers.args import BaseArgsType
from finetrainers.logging import get_logger
from finetrainers.models.attention_dispatch import AttentionProvider, _AttentionProviderRegistry
from finetrainers.state import State
logger = get_logger()
_LATEST_ACTIVE_MODULE_HOOK = "latest_active_module_hook"
class Trainer:
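    """Trainer that sets up the parallel backend, logging, trackers, and config options,
    and exposes a context manager for switching attention providers on a per-module basis."""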
def __init__(self, args: BaseArgsType):
self.args = args
self.state = State()
self._module_name_providers_training = _parse_attention_providers(args.attn_provider_training)
self._module_name_providers_inference = _parse_attention_providers(args.attn_provider_inference)
self._init_distributed()
self._init_config_options()
# Perform any patches that might be necessary for training to work as expected
patches.perform_patches_for_training(self.args, self.state.parallel_backend)
@contextlib.contextmanager
def attention_provider_ctx(self, training: bool = True):
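        """Context manager that activates per-module attention providers.

        Every ``torch.nn.Module`` registered as an attribute on the trainer is assigned a provider
        (modules without an explicit entry fall back to the currently active provider). A pre-forward
        hook then switches the active provider lazily whenever a different module begins its forward
        pass. On exit, the default provider is restored and the hooks are removed.
        """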
name_providers_active = (
self._module_name_providers_training if training else self._module_name_providers_inference
)
name_providers_dict = dict(name_providers_active)
default_provider = _AttentionProviderRegistry._active_provider
all_registered_module_names = [
attr for attr in dir(self) if isinstance(getattr(self, attr, None), torch.nn.Module)
]
for module_name in all_registered_module_names:
if module_name in name_providers_dict:
continue
name_providers_dict[module_name] = default_provider
module_providers_dict = {}
for module_name, provider in name_providers_dict.items():
module = getattr(self, module_name, None)
if module is not None:
module_providers_dict[module] = (module_name, provider)
        # We don't reset the attention provider to the default immediately after forward because, if the
        # model is being trained, the backward pass must run under the same attention provider.
        # So we switch providers lazily, only when the forward pass of a different module is invoked.
def callback(m: torch.nn.Module):
module_name, provider = module_providers_dict[m]
            # HACK: context parallelism is only wired up for the transformer. Other modules still need
            # support, and the overall experience for external usage needs improvement.
if module_name in ["transformer"] and self.state.parallel_backend.context_parallel_enabled:
if not _AttentionProviderRegistry.supports_context_parallel(provider):
raise ValueError(
f"Attention provider {provider} does not support context parallel. Please use a different provider."
)
_AttentionProviderRegistry._set_context_parallel(
mesh=self.state.parallel_backend.get_mesh()["cp"], convert_to_fp32=True, rotate_method="allgather"
)
_AttentionProviderRegistry._active_provider = provider
# HACK: for VAE
if "vae" in name_providers_dict:
_apply_forward_hooks_hack(self.vae, name_providers_dict["vae"])
for module in module_providers_dict.keys():
registry = HookRegistry.check_if_exists_or_initialize(module)
hook = LatestActiveModuleHook(callback)
registry.register_hook(hook, _LATEST_ACTIVE_MODULE_HOOK)
yield
_AttentionProviderRegistry._active_provider = default_provider
_AttentionProviderRegistry._set_context_parallel(reset=True)
for module in module_providers_dict.keys():
registry: HookRegistry = module._diffusers_hook
registry.remove_hook(_LATEST_ACTIVE_MODULE_HOOK)
def _init_distributed(self) -> None:
world_size = int(os.environ.get("WORLD_SIZE", torch.cuda.device_count()))
# TODO(aryan): handle other backends
backend_cls: parallel.ParallelBackendType = parallel.get_parallel_backend_cls(self.args.parallel_backend)
self.state.parallel_backend = backend_cls(
world_size=world_size,
pp_degree=self.args.pp_degree,
dp_degree=self.args.dp_degree,
dp_shards=self.args.dp_shards,
cp_degree=self.args.cp_degree,
tp_degree=self.args.tp_degree,
backend="nccl",
timeout=self.args.init_timeout,
logging_dir=self.args.logging_dir,
output_dir=self.args.output_dir,
gradient_accumulation_steps=self.args.gradient_accumulation_steps,
)
if self.args.seed is not None:
self.state.parallel_backend.enable_determinism(self.args.seed)
def _init_logging(self) -> None:
logging._set_parallel_backend(self.state.parallel_backend)
logging.set_dependency_log_level(self.args.verbose, self.state.parallel_backend.is_local_main_process)
logger.info("Initialized FineTrainers")
def _init_trackers(self) -> None:
# TODO(aryan): handle multiple trackers
trackers = [self.args.report_to]
experiment_name = self.args.tracker_name or "finetrainers-experiment"
self.state.parallel_backend.initialize_trackers(
trackers, experiment_name=experiment_name, config=self._get_training_info(), log_dir=self.args.logging_dir
)
def _init_config_options(self) -> None:
# Enable TF32 for faster training on Ampere GPUs: https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
if self.args.allow_tf32 and torch.cuda.is_available():
torch.backends.cuda.matmul.allow_tf32 = True
torch.set_float32_matmul_precision(self.args.float32_matmul_precision)
@property
def tracker(self):
return self.state.parallel_backend.tracker
class LatestActiveModuleHook(ModelHook):
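    """``ModelHook`` that invokes ``callback(module)`` right before every forward pass of the module it is registered on."""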
    def __init__(self, callback: Optional[Callable[[torch.nn.Module], None]] = None):
super().__init__()
self.callback = callback
def pre_forward(self, module, *args, **kwargs):
self.callback(module)
return args, kwargs
def _parse_attention_providers(attn_providers: Optional[List[str]] = None) -> List[Tuple[str, AttentionProvider]]:
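    """Parse ``"module_name:provider_name"`` strings into ``(module_name, AttentionProvider)`` tuples."""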
parsed_providers = []
if attn_providers:
for provider_str in attn_providers:
parts = provider_str.split(":")
if len(parts) != 2:
raise ValueError(
f"Invalid attention provider format: '{provider_str}'. Expected 'module_name:provider_name'."
)
parts[1] = AttentionProvider(parts[1])
parsed_providers.append(tuple(parts))
return parsed_providers
# TODO(aryan): instead of this, we could probably just apply the hook to vae.children() as we know their forward methods will be invoked
def _apply_forward_hooks_hack(module: torch.nn.Module, provider: AttentionProvider):
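    """Wrap a module's encode/decode-style methods so that ``provider`` is active (and context
    parallel is reset) for the duration of each call, restoring the previous provider afterwards."""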
if hasattr(module, "_finetrainers_wrapped_methods"):
return
def create_wrapper(old_method):
@functools.wraps(old_method)
def wrapper(*args, **kwargs):
_AttentionProviderRegistry._set_context_parallel(reset=True) # HACK: needs improvement
old_provider = _AttentionProviderRegistry._active_provider
_AttentionProviderRegistry._active_provider = provider
output = old_method(*args, **kwargs)
_AttentionProviderRegistry._active_provider = old_provider
return output
return wrapper
methods = ["encode", "decode", "_encode", "_decode", "tiled_encode", "tiled_decode"]
finetrainers_wrapped_methods = []
for method_name in methods:
if not hasattr(module, method_name):
continue
method = getattr(module, method_name)
wrapper = create_wrapper(method)
setattr(module, method_name, wrapper)
finetrainers_wrapped_methods.append(method_name)
module._finetrainers_wrapped_methods = finetrainers_wrapped_methods
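

# Usage sketch (illustrative only; `MyTrainer`, `compute_loss`, `run_validation`, and `batch` are
# hypothetical names, not part of this module): a concrete trainer that registers its models as
# attributes, e.g. `self.transformer` and `self.vae`, could switch attention providers per phase
# roughly like this:
#
#   trainer = MyTrainer(args)
#   with trainer.attention_provider_ctx(training=True):
#       loss = compute_loss(trainer, batch)   # providers switch lazily on each module's forward
#       loss.backward()                       # backward runs under the same provider as forward
#   with trainer.attention_provider_ctx(training=False):
#       outputs = run_validation(trainer)     # uses the inference provider mapping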