import contextlib
import functools
import os
from typing import Callable, List, Optional, Tuple

import torch
import torch.backends
from diffusers.hooks import HookRegistry, ModelHook

from finetrainers import logging, parallel, patches
from finetrainers.args import BaseArgsType
from finetrainers.logging import get_logger
from finetrainers.models.attention_dispatch import AttentionProvider, _AttentionProviderRegistry
from finetrainers.state import State


logger = get_logger()

_LATEST_ACTIVE_MODULE_HOOK = "latest_active_module_hook"


class Trainer:
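    """Base trainer: sets up the parallel backend and config options, applies training patches, and manages
    per-module attention providers for training and inference."""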
    def __init__(self, args: BaseArgsType):
        self.args = args

        self.state = State()

        self._module_name_providers_training = _parse_attention_providers(args.attn_provider_training)
        self._module_name_providers_inference = _parse_attention_providers(args.attn_provider_inference)

        self._init_distributed()
        self._init_config_options()

        # Perform any patches that might be necessary for training to work as expected
        patches.perform_patches_for_training(self.args, self.state.parallel_backend)

    @contextlib.contextmanager
    def attention_provider_ctx(self, training: bool = True):
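        """Context manager that runs each registered ``torch.nn.Module`` attribute under its configured attention provider.

        Providers are switched lazily through a pre-forward hook so that a module's backward pass uses the same
        provider as its forward pass; modules without an explicit mapping fall back to the provider that was
        active when the context was entered.

        Illustrative usage (assumes the concrete trainer defines a ``transformer`` module):

            with self.attention_provider_ctx(training=True):
                output = self.transformer(...)
        """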
        name_providers_active = (
            self._module_name_providers_training if training else self._module_name_providers_inference
        )
        name_providers_dict = dict(name_providers_active)
        default_provider = _AttentionProviderRegistry._active_provider

        all_registered_module_names = [
            attr for attr in dir(self) if isinstance(getattr(self, attr, None), torch.nn.Module)
        ]
        for module_name in all_registered_module_names:
            if module_name in name_providers_dict:
                continue
            name_providers_dict[module_name] = default_provider

        module_providers_dict = {}
        for module_name, provider in name_providers_dict.items():
            module = getattr(self, module_name, None)
            if module is not None:
                module_providers_dict[module] = (module_name, provider)

        # The attention provider is not reset to the default immediately after a forward pass: if the module is
        # being trained, its backward pass must run under the same provider as its forward pass. Instead, providers
        # are switched lazily, only when the forward pass of a different module is invoked.
        def callback(m: torch.nn.Module):
            module_name, provider = module_providers_dict[m]
            # HACK: context parallelism is currently handled only for the transformer. Other modules and external
            # usage still need proper support.
            if module_name in ["transformer"] and self.state.parallel_backend.context_parallel_enabled:
                if not _AttentionProviderRegistry.supports_context_parallel(provider):
                    raise ValueError(
                        f"Attention provider {provider} does not support context parallel. Please use a different provider."
                    )
                _AttentionProviderRegistry._set_context_parallel(
                    mesh=self.state.parallel_backend.get_mesh()["cp"], convert_to_fp32=True, rotate_method="allgather"
                )
            _AttentionProviderRegistry._active_provider = provider

        # HACK: the VAE is invoked through its encode/decode methods rather than forward, so the pre-forward hook
        # above would never fire for it; wrap those methods directly instead.
        if "vae" in name_providers_dict:
            _apply_forward_hooks_hack(self.vae, name_providers_dict["vae"])

        for module in module_providers_dict.keys():
            registry = HookRegistry.check_if_exists_or_initialize(module)
            hook = LatestActiveModuleHook(callback)
            registry.register_hook(hook, _LATEST_ACTIVE_MODULE_HOOK)

        try:
            yield
        finally:
            # Restore the default provider and remove the hooks even if the wrapped block raises.
            _AttentionProviderRegistry._active_provider = default_provider
            _AttentionProviderRegistry._set_context_parallel(reset=True)
            for module in module_providers_dict.keys():
                registry: HookRegistry = module._diffusers_hook
                registry.remove_hook(_LATEST_ACTIVE_MODULE_HOOK)

    def _init_distributed(self) -> None:
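        """Instantiate the parallel backend from the configured pp/dp/cp/tp degrees and enable determinism if a seed is set."""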
        world_size = int(os.environ.get("WORLD_SIZE", torch.cuda.device_count()))

        # TODO(aryan): handle other backends
        backend_cls: parallel.ParallelBackendType = parallel.get_parallel_backend_cls(self.args.parallel_backend)
        self.state.parallel_backend = backend_cls(
            world_size=world_size,
            pp_degree=self.args.pp_degree,
            dp_degree=self.args.dp_degree,
            dp_shards=self.args.dp_shards,
            cp_degree=self.args.cp_degree,
            tp_degree=self.args.tp_degree,
            backend="nccl",
            timeout=self.args.init_timeout,
            logging_dir=self.args.logging_dir,
            output_dir=self.args.output_dir,
            gradient_accumulation_steps=self.args.gradient_accumulation_steps,
        )

        if self.args.seed is not None:
            self.state.parallel_backend.enable_determinism(self.args.seed)

    def _init_logging(self) -> None:
        logging._set_parallel_backend(self.state.parallel_backend)
        logging.set_dependency_log_level(self.args.verbose, self.state.parallel_backend.is_local_main_process)
        logger.info("Initialized FineTrainers")

    def _init_trackers(self) -> None:
        # TODO(aryan): handle multiple trackers
        trackers = [self.args.report_to]
        experiment_name = self.args.tracker_name or "finetrainers-experiment"
        self.state.parallel_backend.initialize_trackers(
            trackers, experiment_name=experiment_name, config=self._get_training_info(), log_dir=self.args.logging_dir
        )

    def _init_config_options(self) -> None:
        # Enable TF32 for faster training on Ampere GPUs: https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
        if self.args.allow_tf32 and torch.cuda.is_available():
            torch.backends.cuda.matmul.allow_tf32 = True
        torch.set_float32_matmul_precision(self.args.float32_matmul_precision)

    @property
    def tracker(self):
        return self.state.parallel_backend.tracker


class LatestActiveModuleHook(ModelHook):
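    """``ModelHook`` that calls ``callback(module)`` immediately before each forward pass of the hooked module."""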
    def __init__(self, callback: Optional[Callable[[torch.nn.Module], None]] = None):
        super().__init__()
        self.callback = callback

    def pre_forward(self, module, *args, **kwargs):
        self.callback(module)
        return args, kwargs


def _parse_attention_providers(attn_providers: Optional[List[str]] = None) -> List[Tuple[str, AttentionProvider]]:
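    """Parse ``"module_name:provider_name"`` strings into ``(module_name, AttentionProvider)`` tuples.

    For example, ``["transformer:<provider>"]`` maps the trainer's ``transformer`` attribute to the
    ``AttentionProvider`` member constructed from ``<provider>``. Malformed entries raise a ``ValueError``.
    """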
    parsed_providers = []
    if attn_providers:
        for provider_str in attn_providers:
            parts = provider_str.split(":")
            if len(parts) != 2:
                raise ValueError(
                    f"Invalid attention provider format: '{provider_str}'. Expected 'module_name:provider_name'."
                )
            module_name, provider_name = parts
            parsed_providers.append((module_name, AttentionProvider(provider_name)))
    return parsed_providers


# TODO(aryan): instead of this, we could probably just apply the hook to vae.children() as we know their forward methods will be invoked
def _apply_forward_hooks_hack(module: torch.nn.Module, provider: AttentionProvider):
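    """Wrap the VAE's ``encode``/``decode``-style methods so they execute under ``provider``.

    These entry points are typically called instead of ``forward``, so the pre-forward hooks used for the other
    modules would not fire here. Modules that have already been wrapped are left untouched.
    """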
    if hasattr(module, "_finetrainers_wrapped_methods"):
        return

    def create_wrapper(old_method):
        @functools.wraps(old_method)
        def wrapper(*args, **kwargs):
            _AttentionProviderRegistry._set_context_parallel(reset=True)  # HACK: needs improvement
            old_provider = _AttentionProviderRegistry._active_provider
            _AttentionProviderRegistry._active_provider = provider
            try:
                return old_method(*args, **kwargs)
            finally:
                # Restore the previous provider even if the wrapped method raises.
                _AttentionProviderRegistry._active_provider = old_provider

        return wrapper

    methods = ["encode", "decode", "_encode", "_decode", "tiled_encode", "tiled_decode"]
    finetrainers_wrapped_methods = []
    for method_name in methods:
        if not hasattr(module, method_name):
            continue
        method = getattr(module, method_name)
        wrapper = create_wrapper(method)
        setattr(module, method_name, wrapper)
        finetrainers_wrapped_methods.append(method_name)
    module._finetrainers_wrapped_methods = finetrainers_wrapped_methods